//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/ViewOpGraph.h"
10
11#include "mlir/Analysis/TopologicalSortUtils.h"
12#include "mlir/IR/Block.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/Operation.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Support/IndentedOstream.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/Support/Format.h"
19#include "llvm/Support/GraphWriter.h"
20#include <map>
21#include <optional>
22#include <utility>
23
24namespace mlir {
25#define GEN_PASS_DEF_VIEWOPGRAPH
26#include "mlir/Transforms/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30
31static const StringRef kLineStyleControlFlow = "dashed";
32static const StringRef kLineStyleDataFlow = "solid";
33static const StringRef kShapeNode = "Mrecord";
34static const StringRef kShapeNone = "plain";
35
36/// Return the size limits for eliding large attributes.
37static int64_t getLargeAttributeSizeLimit() {
38 // Use the default from the printer flags if possible.
39 if (std::optional<int64_t> limit =
40 OpPrintingFlags().getLargeElementsAttrLimit())
41 return *limit;
42 return 16;
43}
44
45/// Return all values printed onto a stream as a string.
46static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
47 std::string buf;
48 llvm::raw_string_ostream os(buf);
49 func(os);
50 return buf;
51}
52
/// Wrap `str` in double quotation marks. The contents are copied verbatim;
/// no escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
57
/// Escape a string for use inside a Graphviz record label.
///
/// From the Graphviz documentation on record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
///
/// In addition to those characters, newlines and double quotes are prefixed
/// with a backslash as well. All other characters are copied unchanged.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (const char ch : str) {
    switch (ch) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(ch);
  }
  return escaped;
}
71
/// Attribute key/value pairs of a DOT statement. Because this is a `std::map`,
/// attributes are iterated (and therefore printed) in sorted key order.
using AttributeMap = std::map<std::string, std::string>;
73
74namespace {
75
/// A node in the DOT language, identified by a numeric id. A node may also
/// carry the id of the cluster (subgraph) it stands for.
///
/// Note: DOT only allows edges between nodes, never between clusters. Edges
/// can however be clipped to a cluster boundary through the `lhead` and
/// `ltail` attributes, which is why every new cluster is given an invisible
/// "anchor" node.
struct Node {
  /// Build a node with the given id and, optionally, the id of its cluster.
  Node(int nodeId = 0, std::optional<int> cluster = std::nullopt)
      : id(nodeId), clusterId(cluster) {}

  /// Identifier of the node.
  int id;
  /// Identifier of the associated cluster, if any.
  std::optional<int> clusterId;
};
91
92struct DataFlowEdge {
93 Value value;
94 Node node;
95 std::string port;
96};
97
98/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
99/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
100/// about the Graphviz DOT language.
101class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
102public:
103 PrintOpPass(raw_ostream &os) : os(os) {}
104 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
105
106 void runOnOperation() override {
107 initColorMapping(*getOperation());
108 emitGraph(builder: [&]() {
109 processOperation(op: getOperation());
110 emitAllEdgeStmts();
111 });
112 markAllAnalysesPreserved();
113 }
114
115 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
116 void emitRegionCFG(Region &region) {
117 printControlFlowEdges = true;
118 printDataFlowEdges = false;
119 initColorMapping(irEntity&: region);
120 emitGraph(builder: [&]() { processRegion(region); });
121 }
122
123private:
124 /// Generate a color mapping that will color every operation with the same
125 /// name the same way. It'll interpolate the hue in the HSV color-space,
126 /// using muted colors that provide good contrast for black text.
127 template <typename T>
128 void initColorMapping(T &irEntity) {
129 backgroundColors.clear();
130 SmallVector<Operation *> ops;
131 irEntity.walk([&](Operation *op) {
132 auto &entry = backgroundColors[op->getName()];
133 if (entry.first == 0)
134 ops.push_back(Elt: op);
135 ++entry.first;
136 });
137 for (auto indexedOps : llvm::enumerate(First&: ops)) {
138 double hue = ((double)indexedOps.index()) / ops.size();
139 // Use lower saturation (0.3) and higher value (0.95) for better
140 // readability
141 backgroundColors[indexedOps.value()->getName()].second =
142 std::to_string(val: hue) + " 0.3 0.95";
143 }
144 }
145
146 /// Emit all edges. This function should be called after all nodes have been
147 /// emitted.
148 void emitAllEdgeStmts() {
149 if (printDataFlowEdges) {
150 for (const auto &e : dataFlowEdges) {
151 emitEdgeStmt(n1: valueToNode[e.value], n2: e.node, port: e.port, style: kLineStyleDataFlow);
152 }
153 }
154
155 for (const std::string &edge : edges)
156 os << edge << ";\n";
157 edges.clear();
158 }
159
160 /// Emit a cluster (subgraph). The specified builder generates the body of the
161 /// cluster. Return the anchor node of the cluster.
162 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
163 int clusterId = ++counter;
164 os << "subgraph cluster_" << clusterId << " {\n";
165 os.indent();
166 // Emit invisible anchor node from/to which arrows can be drawn.
167 Node anchorNode = emitNodeStmt(label: " ", shape: kShapeNone);
168 os << attrStmt(key: "label", value: quoteString(str: label)) << ";\n";
169 builder();
170 os.unindent();
171 os << "}\n";
172 return Node(anchorNode.id, clusterId);
173 }
174
175 /// Generate an attribute statement.
176 std::string attrStmt(const Twine &key, const Twine &value) {
177 return (key + " = " + value).str();
178 }
179
180 /// Emit an attribute list.
181 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
182 os << "[";
183 interleaveComma(c: map, os, each_fn: [&](const auto &it) {
184 os << this->attrStmt(key: it.first, value: it.second);
185 });
186 os << "]";
187 }
188
189 // Print an MLIR attribute to `os`. Large attributes are truncated.
190 void emitMlirAttr(raw_ostream &os, Attribute attr) {
191 // A value used to elide large container attribute.
192 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
193
194 // Always emit splat attributes.
195 if (isa<SplatElementsAttr>(Val: attr)) {
196 os << escapeLabelString(
197 str: strFromOs([&](raw_ostream &os) { attr.print(os); }));
198 return;
199 }
200
201 // Elide "big" elements attributes.
202 auto elements = dyn_cast<ElementsAttr>(attr);
203 if (elements && elements.getNumElements() > largeAttrLimit) {
204 os << std::string(elements.getShapedType().getRank(), '[') << "..."
205 << std::string(elements.getShapedType().getRank(), ']') << " : ";
206 emitMlirType(os, type: elements.getType());
207 return;
208 }
209
210 auto array = dyn_cast<ArrayAttr>(attr);
211 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
212 os << "[...]";
213 return;
214 }
215
216 // Print all other attributes.
217 std::string buf;
218 llvm::raw_string_ostream ss(buf);
219 attr.print(os&: ss);
220 os << escapeLabelString(str: truncateString(str: buf));
221 }
222
223 // Print a truncated and escaped MLIR type to `os`.
224 void emitMlirType(raw_ostream &os, Type type) {
225 std::string buf;
226 llvm::raw_string_ostream ss(buf);
227 type.print(os&: ss);
228 os << escapeLabelString(str: truncateString(str: buf));
229 }
230
231 // Print a truncated and escaped MLIR operand to `os`.
232 void emitMlirOperand(raw_ostream &os, Value operand) {
233 operand.printAsOperand(os, flags: OpPrintingFlags());
234 }
235
236 /// Append an edge to the list of edges.
237 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
238 void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
239 AttributeMap attrs;
240 attrs["style"] = style.str();
241 // Use `ltail` and `lhead` to draw edges between clusters.
242 if (n1.clusterId)
243 attrs["ltail"] = "cluster_" + std::to_string(val: *n1.clusterId);
244 if (n2.clusterId)
245 attrs["lhead"] = "cluster_" + std::to_string(val: *n2.clusterId);
246
247 edges.push_back(x: strFromOs([&](raw_ostream &os) {
248 os << "v" << n1.id;
249 if (!port.empty() && !n1.clusterId)
250 // Attach edge to south compass point of the result
251 os << ":res" << port << ":s";
252 os << " -> ";
253 os << "v" << n2.id;
254 if (!port.empty() && !n2.clusterId)
255 // Attach edge to north compass point of the operand
256 os << ":arg" << port << ":n";
257 emitAttrList(os, map: attrs);
258 }));
259 }
260
261 /// Emit a graph. The specified builder generates the body of the graph.
262 void emitGraph(function_ref<void()> builder) {
263 os << "digraph G {\n";
264 os.indent();
265 // Edges between clusters are allowed only in compound mode.
266 os << attrStmt(key: "compound", value: "true") << ";\n";
267 builder();
268 os.unindent();
269 os << "}\n";
270 }
271
272 /// Emit a node statement.
273 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
274 StringRef background = "") {
275 int nodeId = ++counter;
276 AttributeMap attrs;
277 attrs["label"] = quoteString(str: label);
278 attrs["shape"] = shape.str();
279 if (!background.empty()) {
280 attrs["style"] = "filled";
281 attrs["fillcolor"] = quoteString(str: background.str());
282 }
283 os << llvm::format(Fmt: "v%i ", Vals: nodeId);
284 emitAttrList(os, map: attrs);
285 os << ";\n";
286 return Node(nodeId);
287 }
288
289 std::string getValuePortName(Value operand) {
290 // Print value as an operand and omit the leading '%' character.
291 auto str = strFromOs([&](raw_ostream &os) {
292 operand.printAsOperand(os, flags: OpPrintingFlags());
293 });
294 // Replace % and # with _
295 llvm::replace(str, '%', '_');
296 llvm::replace(str, '#', '_');
297 return str;
298 }
299
300 std::string getClusterLabel(Operation *op) {
301 return strFromOs([&](raw_ostream &os) {
302 // Print operation name and type.
303 os << op->getName();
304 if (printResultTypes) {
305 os << " : (";
306 std::string buf;
307 llvm::raw_string_ostream ss(buf);
308 interleaveComma(c: op->getResultTypes(), os&: ss);
309 os << truncateString(str: buf) << ")";
310 }
311
312 // Print attributes.
313 if (printAttrs) {
314 os << "\\l";
315 for (const NamedAttribute &attr : op->getAttrs()) {
316 os << escapeLabelString(attr.getName().getValue().str()) << ": ";
317 emitMlirAttr(os, attr: attr.getValue());
318 os << "\\l";
319 }
320 }
321 });
322 }
323
324 /// Generate a label for an operation.
325 std::string getRecordLabel(Operation *op) {
326 return strFromOs([&](raw_ostream &os) {
327 os << "{";
328
329 // Print operation inputs.
330 if (op->getNumOperands() > 0) {
331 os << "{";
332 auto operandToPort = [&](Value operand) {
333 os << "<arg" << getValuePortName(operand) << "> ";
334 emitMlirOperand(os, operand);
335 };
336 interleave(c: op->getOperands(), os, each_fn: operandToPort, separator: "|");
337 os << "}|";
338 }
339 // Print operation name and type.
340 os << op->getName() << "\\l";
341
342 // Print attributes.
343 if (printAttrs && !op->getAttrs().empty()) {
344 // Extra line break to separate attributes from the operation name.
345 os << "\\l";
346 for (const NamedAttribute &attr : op->getAttrs()) {
347 os << attr.getName().getValue() << ": ";
348 emitMlirAttr(os, attr: attr.getValue());
349 os << "\\l";
350 }
351 }
352
353 if (op->getNumResults() > 0) {
354 os << "|{";
355 auto resultToPort = [&](Value result) {
356 os << "<res" << getValuePortName(operand: result) << "> ";
357 emitMlirOperand(os, operand: result);
358 if (printResultTypes) {
359 os << " ";
360 emitMlirType(os, type: result.getType());
361 }
362 };
363 interleave(c: op->getResults(), os, each_fn: resultToPort, separator: "|");
364 os << "}";
365 }
366
367 os << "}";
368 });
369 }
370
371 /// Generate a label for a block argument.
372 std::string getLabel(BlockArgument arg) {
373 return strFromOs([&](raw_ostream &os) {
374 os << "<res" << getValuePortName(operand: arg) << "> ";
375 arg.printAsOperand(os, flags: OpPrintingFlags());
376 if (printResultTypes) {
377 os << " ";
378 emitMlirType(os, type: arg.getType());
379 }
380 });
381 }
382
383 /// Process a block. Emit a cluster and one node per block argument and
384 /// operation inside the cluster.
385 void processBlock(Block &block) {
386 emitClusterStmt(builder: [&]() {
387 for (BlockArgument &blockArg : block.getArguments())
388 valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg));
389 // Emit a node for each operation.
390 std::optional<Node> prevNode;
391 for (Operation &op : block) {
392 Node nextNode = processOperation(op: &op);
393 if (printControlFlowEdges && prevNode)
394 emitEdgeStmt(n1: *prevNode, n2: nextNode, /*port=*/"", style: kLineStyleControlFlow);
395 prevNode = nextNode;
396 }
397 });
398 }
399
400 /// Process an operation. If the operation has regions, emit a cluster.
401 /// Otherwise, emit a node.
402 Node processOperation(Operation *op) {
403 Node node;
404 if (op->getNumRegions() > 0) {
405 // Emit cluster for op with regions.
406 node = emitClusterStmt(
407 builder: [&]() {
408 for (Region &region : op->getRegions())
409 processRegion(region);
410 },
411 label: getClusterLabel(op));
412 } else {
413 node = emitNodeStmt(label: getRecordLabel(op), shape: kShapeNode,
414 background: backgroundColors[op->getName()].second);
415 }
416
417 // Insert data flow edges originating from each operand.
418 if (printDataFlowEdges) {
419 unsigned numOperands = op->getNumOperands();
420 for (unsigned i = 0; i < numOperands; i++) {
421 auto operand = op->getOperand(idx: i);
422 dataFlowEdges.push_back(x: {.value: operand, .node: node, .port: getValuePortName(operand)});
423 }
424 }
425
426 for (Value result : op->getResults())
427 valueToNode[result] = node;
428
429 return node;
430 }
431
432 /// Process a region.
433 void processRegion(Region &region) {
434 for (Block &block : region.getBlocks())
435 processBlock(block);
436 }
437
438 /// Truncate long strings.
439 std::string truncateString(std::string str) {
440 if (str.length() <= maxLabelLen)
441 return str;
442 return str.substr(0, maxLabelLen) + "...";
443 }
444
445 /// Output stream to write DOT file to.
446 raw_indented_ostream os;
447 /// A list of edges. For simplicity, should be emitted after all nodes were
448 /// emitted.
449 std::vector<std::string> edges;
450 /// Mapping of SSA values to Graphviz nodes/clusters.
451 DenseMap<Value, Node> valueToNode;
452 /// Output for data flow edges is delayed until the end to handle cycles
453 std::vector<DataFlowEdge> dataFlowEdges;
454 /// Counter for generating unique node/subgraph identifiers.
455 int counter = 0;
456
457 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
458};
459
460} // namespace
461
462std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
463 return std::make_unique<PrintOpPass>(args&: os);
464}
465
466/// Generate a CFG for a region and show it in a window.
467static void llvmViewGraph(Region &region, const Twine &name) {
468 int fd;
469 std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd);
470 {
471 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
472 if (fd == -1) {
473 llvm::errs() << "error opening file '" << filename << "' for writing\n";
474 return;
475 }
476 PrintOpPass pass(os);
477 pass.emitRegionCFG(region);
478 }
479 llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT);
480}
481
482void mlir::Region::viewGraph(const Twine &regionName) {
483 llvmViewGraph(region&: *this, name: regionName);
484}
485
486void mlir::Region::viewGraph() { viewGraph(regionName: "region"); }
487

// Source: mlir/lib/Transforms/ViewOpGraph.cpp