1 | //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/ViewOpGraph.h" |
10 | |
11 | #include "mlir/IR/Block.h" |
12 | #include "mlir/IR/BuiltinTypes.h" |
13 | #include "mlir/IR/Operation.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Support/IndentedOstream.h" |
16 | #include "mlir/Transforms/TopologicalSortUtils.h" |
17 | #include "llvm/Support/Format.h" |
18 | #include "llvm/Support/GraphWriter.h" |
19 | #include <map> |
20 | #include <optional> |
21 | #include <utility> |
22 | |
// Pull in the generated definition of the ViewOpGraph pass. This provides
// `impl::ViewOpGraphBase`, which `PrintOpPass` below derives from (including
// its options such as `printDataFlowEdges` and `maxLabelLen`).
namespace mlir {
#define GEN_PASS_DEF_VIEWOPGRAPH
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
29 | |
30 | static const StringRef kLineStyleControlFlow = "dashed" ; |
31 | static const StringRef kLineStyleDataFlow = "solid" ; |
32 | static const StringRef kShapeNode = "ellipse" ; |
33 | static const StringRef kShapeNone = "plain" ; |
34 | |
35 | /// Return the size limits for eliding large attributes. |
36 | static int64_t getLargeAttributeSizeLimit() { |
37 | // Use the default from the printer flags if possible. |
38 | if (std::optional<int64_t> limit = |
39 | OpPrintingFlags().getLargeElementsAttrLimit()) |
40 | return *limit; |
41 | return 16; |
42 | } |
43 | |
44 | /// Return all values printed onto a stream as a string. |
45 | static std::string strFromOs(function_ref<void(raw_ostream &)> func) { |
46 | std::string buf; |
47 | llvm::raw_string_ostream os(buf); |
48 | func(os); |
49 | return os.str(); |
50 | } |
51 | |
52 | /// Escape special characters such as '\n' and quotation marks. |
53 | static std::string escapeString(std::string str) { |
54 | return strFromOs([&](raw_ostream &os) { os.write_escaped(Str: str); }); |
55 | } |
56 | |
/// Wrap `str` in double quotation marks. The content is not escaped; callers
/// escape first when needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
61 | |
/// DOT attribute name -> value. A `std::map` keeps the entries sorted by name,
/// so attribute lists are emitted in a deterministic order.
using AttributeMap = std::map<std::string, std::string>;
63 | |
64 | namespace { |
65 | |
/// A node in the DOT language; it is written as `v<id>` in the output.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created, and the `Node` that stands
/// for the cluster is that anchor with `clusterId` set to the cluster's id.
struct Node {
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}

  /// Numeric identifier used to name the node.
  int id;
  /// Set only when this node represents a whole cluster (`cluster_<id>`).
  std::optional<int> clusterId;
};
81 | |
82 | /// This pass generates a Graphviz dataflow visualization of an MLIR operation. |
83 | /// Note: See https://www.graphviz.org/doc/info/lang.html for more information |
84 | /// about the Graphviz DOT language. |
85 | class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { |
86 | public: |
87 | PrintOpPass(raw_ostream &os) : os(os) {} |
88 | PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} |
89 | |
90 | void runOnOperation() override { |
91 | initColorMapping(*getOperation()); |
92 | emitGraph(builder: [&]() { |
93 | processOperation(op: getOperation()); |
94 | emitAllEdgeStmts(); |
95 | }); |
96 | } |
97 | |
98 | /// Create a CFG graph for a region. Used in `Region::viewGraph`. |
99 | void emitRegionCFG(Region ®ion) { |
100 | printControlFlowEdges = true; |
101 | printDataFlowEdges = false; |
102 | initColorMapping(irEntity&: region); |
103 | emitGraph(builder: [&]() { processRegion(region); }); |
104 | } |
105 | |
106 | private: |
107 | /// Generate a color mapping that will color every operation with the same |
108 | /// name the same way. It'll interpolate the hue in the HSV color-space, |
109 | /// attempting to keep the contrast suitable for black text. |
110 | template <typename T> |
111 | void initColorMapping(T &irEntity) { |
112 | backgroundColors.clear(); |
113 | SmallVector<Operation *> ops; |
114 | irEntity.walk([&](Operation *op) { |
115 | auto &entry = backgroundColors[op->getName()]; |
116 | if (entry.first == 0) |
117 | ops.push_back(Elt: op); |
118 | ++entry.first; |
119 | }); |
120 | for (auto indexedOps : llvm::enumerate(First&: ops)) { |
121 | double hue = ((double)indexedOps.index()) / ops.size(); |
122 | backgroundColors[indexedOps.value()->getName()].second = |
123 | std::to_string(val: hue) + " 1.0 1.0" ; |
124 | } |
125 | } |
126 | |
127 | /// Emit all edges. This function should be called after all nodes have been |
128 | /// emitted. |
129 | void emitAllEdgeStmts() { |
130 | if (printDataFlowEdges) { |
131 | for (const auto &[value, node, label] : dataFlowEdges) { |
132 | emitEdgeStmt(n1: valueToNode[value], n2: node, label, style: kLineStyleDataFlow); |
133 | } |
134 | } |
135 | |
136 | for (const std::string &edge : edges) |
137 | os << edge << ";\n" ; |
138 | edges.clear(); |
139 | } |
140 | |
141 | /// Emit a cluster (subgraph). The specified builder generates the body of the |
142 | /// cluster. Return the anchor node of the cluster. |
143 | Node emitClusterStmt(function_ref<void()> builder, std::string label = "" ) { |
144 | int clusterId = ++counter; |
145 | os << "subgraph cluster_" << clusterId << " {\n" ; |
146 | os.indent(); |
147 | // Emit invisible anchor node from/to which arrows can be drawn. |
148 | Node anchorNode = emitNodeStmt(label: " " , shape: kShapeNone); |
149 | os << attrStmt(key: "label" , value: quoteString(str: escapeString(str: std::move(label)))) |
150 | << ";\n" ; |
151 | builder(); |
152 | os.unindent(); |
153 | os << "}\n" ; |
154 | return Node(anchorNode.id, clusterId); |
155 | } |
156 | |
157 | /// Generate an attribute statement. |
158 | std::string attrStmt(const Twine &key, const Twine &value) { |
159 | return (key + " = " + value).str(); |
160 | } |
161 | |
162 | /// Emit an attribute list. |
163 | void emitAttrList(raw_ostream &os, const AttributeMap &map) { |
164 | os << "[" ; |
165 | interleaveComma(c: map, os, each_fn: [&](const auto &it) { |
166 | os << this->attrStmt(key: it.first, value: it.second); |
167 | }); |
168 | os << "]" ; |
169 | } |
170 | |
171 | // Print an MLIR attribute to `os`. Large attributes are truncated. |
172 | void emitMlirAttr(raw_ostream &os, Attribute attr) { |
173 | // A value used to elide large container attribute. |
174 | int64_t largeAttrLimit = getLargeAttributeSizeLimit(); |
175 | |
176 | // Always emit splat attributes. |
177 | if (isa<SplatElementsAttr>(Val: attr)) { |
178 | attr.print(os); |
179 | return; |
180 | } |
181 | |
182 | // Elide "big" elements attributes. |
183 | auto elements = dyn_cast<ElementsAttr>(attr); |
184 | if (elements && elements.getNumElements() > largeAttrLimit) { |
185 | os << std::string(elements.getShapedType().getRank(), '[') << "..." |
186 | << std::string(elements.getShapedType().getRank(), ']') << " : " |
187 | << elements.getType(); |
188 | return; |
189 | } |
190 | |
191 | auto array = dyn_cast<ArrayAttr>(attr); |
192 | if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { |
193 | os << "[...]" ; |
194 | return; |
195 | } |
196 | |
197 | // Print all other attributes. |
198 | std::string buf; |
199 | llvm::raw_string_ostream ss(buf); |
200 | attr.print(os&: ss); |
201 | os << truncateString(str: ss.str()); |
202 | } |
203 | |
204 | /// Append an edge to the list of edges. |
205 | /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. |
206 | void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { |
207 | AttributeMap attrs; |
208 | attrs["style" ] = style.str(); |
209 | // Do not label edges that start/end at a cluster boundary. Such edges are |
210 | // clipped at the boundary, but labels are not. This can lead to labels |
211 | // floating around without any edge next to them. |
212 | if (!n1.clusterId && !n2.clusterId) |
213 | attrs["label" ] = quoteString(str: escapeString(str: std::move(label))); |
214 | // Use `ltail` and `lhead` to draw edges between clusters. |
215 | if (n1.clusterId) |
216 | attrs["ltail" ] = "cluster_" + std::to_string(val: *n1.clusterId); |
217 | if (n2.clusterId) |
218 | attrs["lhead" ] = "cluster_" + std::to_string(val: *n2.clusterId); |
219 | |
220 | edges.push_back(x: strFromOs([&](raw_ostream &os) { |
221 | os << llvm::format(Fmt: "v%i -> v%i " , Vals: n1.id, Vals: n2.id); |
222 | emitAttrList(os, map: attrs); |
223 | })); |
224 | } |
225 | |
226 | /// Emit a graph. The specified builder generates the body of the graph. |
227 | void emitGraph(function_ref<void()> builder) { |
228 | os << "digraph G {\n" ; |
229 | os.indent(); |
230 | // Edges between clusters are allowed only in compound mode. |
231 | os << attrStmt(key: "compound" , value: "true" ) << ";\n" ; |
232 | builder(); |
233 | os.unindent(); |
234 | os << "}\n" ; |
235 | } |
236 | |
237 | /// Emit a node statement. |
238 | Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, |
239 | StringRef background = "" ) { |
240 | int nodeId = ++counter; |
241 | AttributeMap attrs; |
242 | attrs["label" ] = quoteString(str: escapeString(str: std::move(label))); |
243 | attrs["shape" ] = shape.str(); |
244 | if (!background.empty()) { |
245 | attrs["style" ] = "filled" ; |
246 | attrs["fillcolor" ] = ("\"" + background + "\"" ).str(); |
247 | } |
248 | os << llvm::format(Fmt: "v%i " , Vals: nodeId); |
249 | emitAttrList(os, map: attrs); |
250 | os << ";\n" ; |
251 | return Node(nodeId); |
252 | } |
253 | |
254 | /// Generate a label for an operation. |
255 | std::string getLabel(Operation *op) { |
256 | return strFromOs([&](raw_ostream &os) { |
257 | // Print operation name and type. |
258 | os << op->getName(); |
259 | if (printResultTypes) { |
260 | os << " : (" ; |
261 | std::string buf; |
262 | llvm::raw_string_ostream ss(buf); |
263 | interleaveComma(c: op->getResultTypes(), os&: ss); |
264 | os << truncateString(str: ss.str()) << ")" ; |
265 | } |
266 | |
267 | // Print attributes. |
268 | if (printAttrs) { |
269 | os << "\n" ; |
270 | for (const NamedAttribute &attr : op->getAttrs()) { |
271 | os << '\n' << attr.getName().getValue() << ": " ; |
272 | emitMlirAttr(os, attr: attr.getValue()); |
273 | } |
274 | } |
275 | }); |
276 | } |
277 | |
278 | /// Generate a label for a block argument. |
279 | std::string getLabel(BlockArgument arg) { |
280 | return "arg" + std::to_string(val: arg.getArgNumber()); |
281 | } |
282 | |
283 | /// Process a block. Emit a cluster and one node per block argument and |
284 | /// operation inside the cluster. |
285 | void processBlock(Block &block) { |
286 | emitClusterStmt(builder: [&]() { |
287 | for (BlockArgument &blockArg : block.getArguments()) |
288 | valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg)); |
289 | |
290 | // Emit a node for each operation. |
291 | std::optional<Node> prevNode; |
292 | for (Operation &op : block) { |
293 | Node nextNode = processOperation(op: &op); |
294 | if (printControlFlowEdges && prevNode) |
295 | emitEdgeStmt(n1: *prevNode, n2: nextNode, /*label=*/"" , |
296 | style: kLineStyleControlFlow); |
297 | prevNode = nextNode; |
298 | } |
299 | }); |
300 | } |
301 | |
302 | /// Process an operation. If the operation has regions, emit a cluster. |
303 | /// Otherwise, emit a node. |
304 | Node processOperation(Operation *op) { |
305 | Node node; |
306 | if (op->getNumRegions() > 0) { |
307 | // Emit cluster for op with regions. |
308 | node = emitClusterStmt( |
309 | builder: [&]() { |
310 | for (Region ®ion : op->getRegions()) |
311 | processRegion(region); |
312 | }, |
313 | label: getLabel(op)); |
314 | } else { |
315 | node = emitNodeStmt(label: getLabel(op), shape: kShapeNode, |
316 | background: backgroundColors[op->getName()].second); |
317 | } |
318 | |
319 | // Insert data flow edges originating from each operand. |
320 | if (printDataFlowEdges) { |
321 | unsigned numOperands = op->getNumOperands(); |
322 | for (unsigned i = 0; i < numOperands; i++) |
323 | dataFlowEdges.push_back(x: {op->getOperand(idx: i), node, |
324 | numOperands == 1 ? "" : std::to_string(val: i)}); |
325 | } |
326 | |
327 | for (Value result : op->getResults()) |
328 | valueToNode[result] = node; |
329 | |
330 | return node; |
331 | } |
332 | |
333 | /// Process a region. |
334 | void processRegion(Region ®ion) { |
335 | for (Block &block : region.getBlocks()) |
336 | processBlock(block); |
337 | } |
338 | |
339 | /// Truncate long strings. |
340 | std::string truncateString(std::string str) { |
341 | if (str.length() <= maxLabelLen) |
342 | return str; |
343 | return str.substr(0, maxLabelLen) + "..." ; |
344 | } |
345 | |
346 | /// Output stream to write DOT file to. |
347 | raw_indented_ostream os; |
348 | /// A list of edges. For simplicity, should be emitted after all nodes were |
349 | /// emitted. |
350 | std::vector<std::string> edges; |
351 | /// Mapping of SSA values to Graphviz nodes/clusters. |
352 | DenseMap<Value, Node> valueToNode; |
353 | /// Output for data flow edges is delayed until the end to handle cycles |
354 | std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges; |
355 | /// Counter for generating unique node/subgraph identifiers. |
356 | int counter = 0; |
357 | |
358 | DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; |
359 | }; |
360 | |
361 | } // namespace |
362 | |
363 | std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { |
364 | return std::make_unique<PrintOpPass>(args&: os); |
365 | } |
366 | |
367 | /// Generate a CFG for a region and show it in a window. |
368 | static void llvmViewGraph(Region ®ion, const Twine &name) { |
369 | int fd; |
370 | std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd); |
371 | { |
372 | llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); |
373 | if (fd == -1) { |
374 | llvm::errs() << "error opening file '" << filename << "' for writing\n" ; |
375 | return; |
376 | } |
377 | PrintOpPass pass(os); |
378 | pass.emitRegionCFG(region); |
379 | } |
380 | llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT); |
381 | } |
382 | |
383 | void mlir::Region::viewGraph(const Twine ®ionName) { |
384 | llvmViewGraph(region&: *this, name: regionName); |
385 | } |
386 | |
387 | void mlir::Region::viewGraph() { viewGraph(regionName: "region" ); } |
388 | |