1 | //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/ViewOpGraph.h" |
10 | |
11 | #include "mlir/Analysis/TopologicalSortUtils.h" |
12 | #include "mlir/IR/Block.h" |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/Operation.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | #include "mlir/Support/IndentedOstream.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include "llvm/Support/Format.h" |
19 | #include "llvm/Support/GraphWriter.h" |
20 | #include <map> |
21 | #include <optional> |
22 | #include <utility> |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_VIEWOPGRAPH |
26 | #include "mlir/Transforms/Passes.h.inc" |
27 | } // namespace mlir |
28 | |
29 | using namespace mlir; |
30 | |
31 | static const StringRef kLineStyleControlFlow = "dashed"; |
32 | static const StringRef kLineStyleDataFlow = "solid"; |
33 | static const StringRef kShapeNode = "Mrecord"; |
34 | static const StringRef kShapeNone = "plain"; |
35 | |
36 | /// Return the size limits for eliding large attributes. |
37 | static int64_t getLargeAttributeSizeLimit() { |
38 | // Use the default from the printer flags if possible. |
39 | if (std::optional<int64_t> limit = |
40 | OpPrintingFlags().getLargeElementsAttrLimit()) |
41 | return *limit; |
42 | return 16; |
43 | } |
44 | |
45 | /// Return all values printed onto a stream as a string. |
46 | static std::string strFromOs(function_ref<void(raw_ostream &)> func) { |
47 | std::string buf; |
48 | llvm::raw_string_ostream os(buf); |
49 | func(os); |
50 | return buf; |
51 | } |
52 | |
/// Wrap `str` in double quotes. No escaping is performed: callers are expected
/// to pass text that is already safe inside a quoted DOT string.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
57 | |
/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
///
/// Escape `str` so that it can be embedded in a double-quoted DOT label and
/// rendered literally. Besides the record metacharacters above, this escapes:
///  - `"`, which would otherwise terminate the quoted DOT string;
///  - `\`, which Graphviz treats as an escape introducer inside labels, so any
///    backslash already present in printed MLIR text would otherwise be
///    consumed instead of displayed;
///  - newline, producing a backslash-newline sequence (a line continuation in
///    DOT quoted strings).
/// NOTE(review): this helper has external linkage; it looks file-local and
/// could probably be made `static` once no other TU is confirmed to use it.
std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  escaped.reserve(str.size());
  for (char c : str) {
    switch (c) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '"':
    case '\\':
    case '\n':
      escaped += '\\';
      break;
    default:
      break;
    }
    escaped += c;
  }
  return escaped;
}
71 | |
/// Graphviz attribute list: attribute name -> already-formatted value. Being a
/// std::map, iteration (and therefore emission) is in sorted key order.
using AttributeMap =
    std::map<std::string, std::string, std::less<std::string>>;
73 | |
74 | namespace { |
75 | |
/// A node in the DOT language: a numeric identifier and, optionally, the
/// identifier of the cluster (subgraph) that contains it.
/// Note: DOT can only draw edges between nodes, never between clusters. Edges
/// can, however, be clipped to a cluster's boundary with the `lhead` and
/// `ltail` attributes, so every new cluster gets an invisible "anchor" node
/// that edges to/from the cluster attach to.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}

  /// Node identifier; the node is written as `v<id>`.
  int id;
  /// Identifier of the containing cluster (written as `cluster_<id>`); set for
  /// cluster anchor nodes, empty for plain nodes.
  std::optional<int> clusterId;
};
91 | |
92 | struct DataFlowEdge { |
93 | Value value; |
94 | Node node; |
95 | std::string port; |
96 | }; |
97 | |
98 | /// This pass generates a Graphviz dataflow visualization of an MLIR operation. |
99 | /// Note: See https://www.graphviz.org/doc/info/lang.html for more information |
100 | /// about the Graphviz DOT language. |
101 | class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { |
102 | public: |
103 | PrintOpPass(raw_ostream &os) : os(os) {} |
104 | PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} |
105 | |
106 | void runOnOperation() override { |
107 | initColorMapping(*getOperation()); |
108 | emitGraph(builder: [&]() { |
109 | processOperation(op: getOperation()); |
110 | emitAllEdgeStmts(); |
111 | }); |
112 | markAllAnalysesPreserved(); |
113 | } |
114 | |
115 | /// Create a CFG graph for a region. Used in `Region::viewGraph`. |
116 | void emitRegionCFG(Region ®ion) { |
117 | printControlFlowEdges = true; |
118 | printDataFlowEdges = false; |
119 | initColorMapping(irEntity&: region); |
120 | emitGraph(builder: [&]() { processRegion(region); }); |
121 | } |
122 | |
123 | private: |
124 | /// Generate a color mapping that will color every operation with the same |
125 | /// name the same way. It'll interpolate the hue in the HSV color-space, |
126 | /// using muted colors that provide good contrast for black text. |
127 | template <typename T> |
128 | void initColorMapping(T &irEntity) { |
129 | backgroundColors.clear(); |
130 | SmallVector<Operation *> ops; |
131 | irEntity.walk([&](Operation *op) { |
132 | auto &entry = backgroundColors[op->getName()]; |
133 | if (entry.first == 0) |
134 | ops.push_back(Elt: op); |
135 | ++entry.first; |
136 | }); |
137 | for (auto indexedOps : llvm::enumerate(First&: ops)) { |
138 | double hue = ((double)indexedOps.index()) / ops.size(); |
139 | // Use lower saturation (0.3) and higher value (0.95) for better |
140 | // readability |
141 | backgroundColors[indexedOps.value()->getName()].second = |
142 | std::to_string(val: hue) + " 0.3 0.95"; |
143 | } |
144 | } |
145 | |
146 | /// Emit all edges. This function should be called after all nodes have been |
147 | /// emitted. |
148 | void emitAllEdgeStmts() { |
149 | if (printDataFlowEdges) { |
150 | for (const auto &e : dataFlowEdges) { |
151 | emitEdgeStmt(n1: valueToNode[e.value], n2: e.node, port: e.port, style: kLineStyleDataFlow); |
152 | } |
153 | } |
154 | |
155 | for (const std::string &edge : edges) |
156 | os << edge << ";\n"; |
157 | edges.clear(); |
158 | } |
159 | |
160 | /// Emit a cluster (subgraph). The specified builder generates the body of the |
161 | /// cluster. Return the anchor node of the cluster. |
162 | Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { |
163 | int clusterId = ++counter; |
164 | os << "subgraph cluster_"<< clusterId << " {\n"; |
165 | os.indent(); |
166 | // Emit invisible anchor node from/to which arrows can be drawn. |
167 | Node anchorNode = emitNodeStmt(label: " ", shape: kShapeNone); |
168 | os << attrStmt(key: "label", value: quoteString(str: label)) << ";\n"; |
169 | builder(); |
170 | os.unindent(); |
171 | os << "}\n"; |
172 | return Node(anchorNode.id, clusterId); |
173 | } |
174 | |
175 | /// Generate an attribute statement. |
176 | std::string attrStmt(const Twine &key, const Twine &value) { |
177 | return (key + " = "+ value).str(); |
178 | } |
179 | |
180 | /// Emit an attribute list. |
181 | void emitAttrList(raw_ostream &os, const AttributeMap &map) { |
182 | os << "["; |
183 | interleaveComma(c: map, os, each_fn: [&](const auto &it) { |
184 | os << this->attrStmt(key: it.first, value: it.second); |
185 | }); |
186 | os << "]"; |
187 | } |
188 | |
189 | // Print an MLIR attribute to `os`. Large attributes are truncated. |
190 | void emitMlirAttr(raw_ostream &os, Attribute attr) { |
191 | // A value used to elide large container attribute. |
192 | int64_t largeAttrLimit = getLargeAttributeSizeLimit(); |
193 | |
194 | // Always emit splat attributes. |
195 | if (isa<SplatElementsAttr>(Val: attr)) { |
196 | os << escapeLabelString( |
197 | str: strFromOs([&](raw_ostream &os) { attr.print(os); })); |
198 | return; |
199 | } |
200 | |
201 | // Elide "big" elements attributes. |
202 | auto elements = dyn_cast<ElementsAttr>(attr); |
203 | if (elements && elements.getNumElements() > largeAttrLimit) { |
204 | os << std::string(elements.getShapedType().getRank(), '[') << "..." |
205 | << std::string(elements.getShapedType().getRank(), ']') << " : "; |
206 | emitMlirType(os, type: elements.getType()); |
207 | return; |
208 | } |
209 | |
210 | auto array = dyn_cast<ArrayAttr>(attr); |
211 | if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { |
212 | os << "[...]"; |
213 | return; |
214 | } |
215 | |
216 | // Print all other attributes. |
217 | std::string buf; |
218 | llvm::raw_string_ostream ss(buf); |
219 | attr.print(os&: ss); |
220 | os << escapeLabelString(str: truncateString(str: buf)); |
221 | } |
222 | |
223 | // Print a truncated and escaped MLIR type to `os`. |
224 | void emitMlirType(raw_ostream &os, Type type) { |
225 | std::string buf; |
226 | llvm::raw_string_ostream ss(buf); |
227 | type.print(os&: ss); |
228 | os << escapeLabelString(str: truncateString(str: buf)); |
229 | } |
230 | |
231 | // Print a truncated and escaped MLIR operand to `os`. |
232 | void emitMlirOperand(raw_ostream &os, Value operand) { |
233 | operand.printAsOperand(os, flags: OpPrintingFlags()); |
234 | } |
235 | |
236 | /// Append an edge to the list of edges. |
237 | /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. |
238 | void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) { |
239 | AttributeMap attrs; |
240 | attrs["style"] = style.str(); |
241 | // Use `ltail` and `lhead` to draw edges between clusters. |
242 | if (n1.clusterId) |
243 | attrs["ltail"] = "cluster_"+ std::to_string(val: *n1.clusterId); |
244 | if (n2.clusterId) |
245 | attrs["lhead"] = "cluster_"+ std::to_string(val: *n2.clusterId); |
246 | |
247 | edges.push_back(x: strFromOs([&](raw_ostream &os) { |
248 | os << "v"<< n1.id; |
249 | if (!port.empty() && !n1.clusterId) |
250 | // Attach edge to south compass point of the result |
251 | os << ":res"<< port << ":s"; |
252 | os << " -> "; |
253 | os << "v"<< n2.id; |
254 | if (!port.empty() && !n2.clusterId) |
255 | // Attach edge to north compass point of the operand |
256 | os << ":arg"<< port << ":n"; |
257 | emitAttrList(os, map: attrs); |
258 | })); |
259 | } |
260 | |
261 | /// Emit a graph. The specified builder generates the body of the graph. |
262 | void emitGraph(function_ref<void()> builder) { |
263 | os << "digraph G {\n"; |
264 | os.indent(); |
265 | // Edges between clusters are allowed only in compound mode. |
266 | os << attrStmt(key: "compound", value: "true") << ";\n"; |
267 | builder(); |
268 | os.unindent(); |
269 | os << "}\n"; |
270 | } |
271 | |
272 | /// Emit a node statement. |
273 | Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, |
274 | StringRef background = "") { |
275 | int nodeId = ++counter; |
276 | AttributeMap attrs; |
277 | attrs["label"] = quoteString(str: label); |
278 | attrs["shape"] = shape.str(); |
279 | if (!background.empty()) { |
280 | attrs["style"] = "filled"; |
281 | attrs["fillcolor"] = quoteString(str: background.str()); |
282 | } |
283 | os << llvm::format(Fmt: "v%i ", Vals: nodeId); |
284 | emitAttrList(os, map: attrs); |
285 | os << ";\n"; |
286 | return Node(nodeId); |
287 | } |
288 | |
289 | std::string getValuePortName(Value operand) { |
290 | // Print value as an operand and omit the leading '%' character. |
291 | auto str = strFromOs([&](raw_ostream &os) { |
292 | operand.printAsOperand(os, flags: OpPrintingFlags()); |
293 | }); |
294 | // Replace % and # with _ |
295 | llvm::replace(str, '%', '_'); |
296 | llvm::replace(str, '#', '_'); |
297 | return str; |
298 | } |
299 | |
300 | std::string getClusterLabel(Operation *op) { |
301 | return strFromOs([&](raw_ostream &os) { |
302 | // Print operation name and type. |
303 | os << op->getName(); |
304 | if (printResultTypes) { |
305 | os << " : ("; |
306 | std::string buf; |
307 | llvm::raw_string_ostream ss(buf); |
308 | interleaveComma(c: op->getResultTypes(), os&: ss); |
309 | os << truncateString(str: buf) << ")"; |
310 | } |
311 | |
312 | // Print attributes. |
313 | if (printAttrs) { |
314 | os << "\\l"; |
315 | for (const NamedAttribute &attr : op->getAttrs()) { |
316 | os << escapeLabelString(attr.getName().getValue().str()) << ": "; |
317 | emitMlirAttr(os, attr: attr.getValue()); |
318 | os << "\\l"; |
319 | } |
320 | } |
321 | }); |
322 | } |
323 | |
324 | /// Generate a label for an operation. |
325 | std::string getRecordLabel(Operation *op) { |
326 | return strFromOs([&](raw_ostream &os) { |
327 | os << "{"; |
328 | |
329 | // Print operation inputs. |
330 | if (op->getNumOperands() > 0) { |
331 | os << "{"; |
332 | auto operandToPort = [&](Value operand) { |
333 | os << "<arg"<< getValuePortName(operand) << "> "; |
334 | emitMlirOperand(os, operand); |
335 | }; |
336 | interleave(c: op->getOperands(), os, each_fn: operandToPort, separator: "|"); |
337 | os << "}|"; |
338 | } |
339 | // Print operation name and type. |
340 | os << op->getName() << "\\l"; |
341 | |
342 | // Print attributes. |
343 | if (printAttrs && !op->getAttrs().empty()) { |
344 | // Extra line break to separate attributes from the operation name. |
345 | os << "\\l"; |
346 | for (const NamedAttribute &attr : op->getAttrs()) { |
347 | os << attr.getName().getValue() << ": "; |
348 | emitMlirAttr(os, attr: attr.getValue()); |
349 | os << "\\l"; |
350 | } |
351 | } |
352 | |
353 | if (op->getNumResults() > 0) { |
354 | os << "|{"; |
355 | auto resultToPort = [&](Value result) { |
356 | os << "<res"<< getValuePortName(operand: result) << "> "; |
357 | emitMlirOperand(os, operand: result); |
358 | if (printResultTypes) { |
359 | os << " "; |
360 | emitMlirType(os, type: result.getType()); |
361 | } |
362 | }; |
363 | interleave(c: op->getResults(), os, each_fn: resultToPort, separator: "|"); |
364 | os << "}"; |
365 | } |
366 | |
367 | os << "}"; |
368 | }); |
369 | } |
370 | |
371 | /// Generate a label for a block argument. |
372 | std::string getLabel(BlockArgument arg) { |
373 | return strFromOs([&](raw_ostream &os) { |
374 | os << "<res"<< getValuePortName(operand: arg) << "> "; |
375 | arg.printAsOperand(os, flags: OpPrintingFlags()); |
376 | if (printResultTypes) { |
377 | os << " "; |
378 | emitMlirType(os, type: arg.getType()); |
379 | } |
380 | }); |
381 | } |
382 | |
383 | /// Process a block. Emit a cluster and one node per block argument and |
384 | /// operation inside the cluster. |
385 | void processBlock(Block &block) { |
386 | emitClusterStmt(builder: [&]() { |
387 | for (BlockArgument &blockArg : block.getArguments()) |
388 | valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg)); |
389 | // Emit a node for each operation. |
390 | std::optional<Node> prevNode; |
391 | for (Operation &op : block) { |
392 | Node nextNode = processOperation(op: &op); |
393 | if (printControlFlowEdges && prevNode) |
394 | emitEdgeStmt(n1: *prevNode, n2: nextNode, /*port=*/"", style: kLineStyleControlFlow); |
395 | prevNode = nextNode; |
396 | } |
397 | }); |
398 | } |
399 | |
400 | /// Process an operation. If the operation has regions, emit a cluster. |
401 | /// Otherwise, emit a node. |
402 | Node processOperation(Operation *op) { |
403 | Node node; |
404 | if (op->getNumRegions() > 0) { |
405 | // Emit cluster for op with regions. |
406 | node = emitClusterStmt( |
407 | builder: [&]() { |
408 | for (Region ®ion : op->getRegions()) |
409 | processRegion(region); |
410 | }, |
411 | label: getClusterLabel(op)); |
412 | } else { |
413 | node = emitNodeStmt(label: getRecordLabel(op), shape: kShapeNode, |
414 | background: backgroundColors[op->getName()].second); |
415 | } |
416 | |
417 | // Insert data flow edges originating from each operand. |
418 | if (printDataFlowEdges) { |
419 | unsigned numOperands = op->getNumOperands(); |
420 | for (unsigned i = 0; i < numOperands; i++) { |
421 | auto operand = op->getOperand(idx: i); |
422 | dataFlowEdges.push_back(x: {.value: operand, .node: node, .port: getValuePortName(operand)}); |
423 | } |
424 | } |
425 | |
426 | for (Value result : op->getResults()) |
427 | valueToNode[result] = node; |
428 | |
429 | return node; |
430 | } |
431 | |
432 | /// Process a region. |
433 | void processRegion(Region ®ion) { |
434 | for (Block &block : region.getBlocks()) |
435 | processBlock(block); |
436 | } |
437 | |
438 | /// Truncate long strings. |
439 | std::string truncateString(std::string str) { |
440 | if (str.length() <= maxLabelLen) |
441 | return str; |
442 | return str.substr(0, maxLabelLen) + "..."; |
443 | } |
444 | |
445 | /// Output stream to write DOT file to. |
446 | raw_indented_ostream os; |
447 | /// A list of edges. For simplicity, should be emitted after all nodes were |
448 | /// emitted. |
449 | std::vector<std::string> edges; |
450 | /// Mapping of SSA values to Graphviz nodes/clusters. |
451 | DenseMap<Value, Node> valueToNode; |
452 | /// Output for data flow edges is delayed until the end to handle cycles |
453 | std::vector<DataFlowEdge> dataFlowEdges; |
454 | /// Counter for generating unique node/subgraph identifiers. |
455 | int counter = 0; |
456 | |
457 | DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; |
458 | }; |
459 | |
460 | } // namespace |
461 | |
462 | std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { |
463 | return std::make_unique<PrintOpPass>(args&: os); |
464 | } |
465 | |
466 | /// Generate a CFG for a region and show it in a window. |
467 | static void llvmViewGraph(Region ®ion, const Twine &name) { |
468 | int fd; |
469 | std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd); |
470 | { |
471 | llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); |
472 | if (fd == -1) { |
473 | llvm::errs() << "error opening file '"<< filename << "' for writing\n"; |
474 | return; |
475 | } |
476 | PrintOpPass pass(os); |
477 | pass.emitRegionCFG(region); |
478 | } |
479 | llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT); |
480 | } |
481 | |
482 | void mlir::Region::viewGraph(const Twine ®ionName) { |
483 | llvmViewGraph(region&: *this, name: regionName); |
484 | } |
485 | |
486 | void mlir::Region::viewGraph() { viewGraph(regionName: "region"); } |
487 |
Definitions
- kLineStyleControlFlow
- kLineStyleDataFlow
- kShapeNode
- kShapeNone
- getLargeAttributeSizeLimit
- strFromOs
- quoteString
- escapeLabelString
- Node
- Node
- DataFlowEdge
- PrintOpPass
- PrintOpPass
- PrintOpPass
- runOnOperation
- emitRegionCFG
- initColorMapping
- emitAllEdgeStmts
- emitClusterStmt
- attrStmt
- emitAttrList
- emitMlirAttr
- emitMlirType
- emitMlirOperand
- emitEdgeStmt
- emitGraph
- emitNodeStmt
- getValuePortName
- getClusterLabel
- getRecordLabel
- getLabel
- processBlock
- processOperation
- processRegion
- truncateString
- createPrintOpGraphPass
- llvmViewGraph
- viewGraph
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more