1//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/ViewOpGraph.h"
10
11#include "mlir/IR/Block.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Support/IndentedOstream.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/GraphWriter.h"
19#include <map>
20#include <optional>
21#include <utility>
22
23namespace mlir {
24#define GEN_PASS_DEF_VIEWOPGRAPH
25#include "mlir/Transforms/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30static const StringRef kLineStyleControlFlow = "dashed";
31static const StringRef kLineStyleDataFlow = "solid";
32static const StringRef kShapeNode = "Mrecord";
33static const StringRef kShapeNone = "plain";
34
35/// Return the size limits for eliding large attributes.
36static int64_t getLargeAttributeSizeLimit() {
37 // Use the default from the printer flags if possible.
38 if (std::optional<int64_t> limit =
39 OpPrintingFlags().getLargeElementsAttrLimit())
40 return *limit;
41 return 16;
42}
43
44/// Return all values printed onto a stream as a string.
45static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46 std::string buf;
47 llvm::raw_string_ostream os(buf);
48 func(os);
49 return buf;
50}
51
/// Wrap `str` in double quotes. The contents are copied verbatim; no escaping
/// is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
56
/// Escape `str` for use inside a Graphviz record-node label.
///
/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
///
/// In addition to those characters, double quotes and newlines are prefixed
/// with a backslash. All other characters (including backslashes already
/// present in `str`) are copied through unchanged.
///
/// NOTE(review): DOT appears to treat a backslash immediately followed by a
/// newline inside a quoted string as a line continuation, so an embedded
/// newline is likely dropped from the rendered label -- confirm this is the
/// intended outcome.
///
/// This helper is file-local, hence `static` (internal linkage) like the other
/// helpers in this file.
static std::string escapeLabelString(const std::string &str) {
  std::string escaped;
  // At least one output character per input character.
  escaped.reserve(str.size());
  for (char c : str) {
    switch (c) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      escaped.push_back('\\');
      break;
    default:
      break;
    }
    escaped.push_back(c);
  }
  return escaped;
}
70
/// Key/value pairs of one DOT attribute list (e.g. `label`, `shape`). Being a
/// `std::map`, the entries are iterated in key order, so attribute lists are
/// printed in a stable, sorted order.
using AttributeMap = std::map<std::string, std::string>;
72
73namespace {
74
/// A node in the DOT language: a numeric identifier plus, optionally, the
/// identifier of the cluster (subgraph) the node stands for.
///
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created; a `Node` with a `clusterId`
/// refers to such an anchor.
struct Node {
  /// A default-constructed node has id 0 and no cluster.
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}

  /// Identifier of the node.
  int id;
  /// Identifier of the associated cluster, if any.
  std::optional<int> clusterId;
};
90
91struct DataFlowEdge {
92 Value value;
93 Node node;
94 std::string port;
95};
96
97/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
98/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
99/// about the Graphviz DOT language.
100class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
101public:
102 PrintOpPass(raw_ostream &os) : os(os) {}
103 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
104
105 void runOnOperation() override {
106 initColorMapping(irEntity&: *getOperation());
107 emitGraph(builder: [&]() {
108 processOperation(op: getOperation());
109 emitAllEdgeStmts();
110 });
111 markAllAnalysesPreserved();
112 }
113
114 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
115 void emitRegionCFG(Region &region) {
116 printControlFlowEdges = true;
117 printDataFlowEdges = false;
118 initColorMapping(irEntity&: region);
119 emitGraph(builder: [&]() { processRegion(region); });
120 }
121
122private:
123 /// Generate a color mapping that will color every operation with the same
124 /// name the same way. It'll interpolate the hue in the HSV color-space,
125 /// using muted colors that provide good contrast for black text.
126 template <typename T>
127 void initColorMapping(T &irEntity) {
128 backgroundColors.clear();
129 SmallVector<Operation *> ops;
130 irEntity.walk([&](Operation *op) {
131 auto &entry = backgroundColors[op->getName()];
132 if (entry.first == 0)
133 ops.push_back(Elt: op);
134 ++entry.first;
135 });
136 for (auto indexedOps : llvm::enumerate(First&: ops)) {
137 double hue = ((double)indexedOps.index()) / ops.size();
138 // Use lower saturation (0.3) and higher value (0.95) for better
139 // readability
140 backgroundColors[indexedOps.value()->getName()].second =
141 std::to_string(val: hue) + " 0.3 0.95";
142 }
143 }
144
145 /// Emit all edges. This function should be called after all nodes have been
146 /// emitted.
147 void emitAllEdgeStmts() {
148 if (printDataFlowEdges) {
149 for (const auto &e : dataFlowEdges) {
150 emitEdgeStmt(n1: valueToNode[e.value], n2: e.node, port: e.port, style: kLineStyleDataFlow);
151 }
152 }
153
154 for (const std::string &edge : edges)
155 os << edge << ";\n";
156 edges.clear();
157 }
158
159 /// Emit a cluster (subgraph). The specified builder generates the body of the
160 /// cluster. Return the anchor node of the cluster.
161 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
162 int clusterId = ++counter;
163 os << "subgraph cluster_" << clusterId << " {\n";
164 os.indent();
165 // Emit invisible anchor node from/to which arrows can be drawn.
166 Node anchorNode = emitNodeStmt(label: " ", shape: kShapeNone);
167 os << attrStmt(key: "label", value: quoteString(str: label)) << ";\n";
168 builder();
169 os.unindent();
170 os << "}\n";
171 return Node(anchorNode.id, clusterId);
172 }
173
174 /// Generate an attribute statement.
175 std::string attrStmt(const Twine &key, const Twine &value) {
176 return (key + " = " + value).str();
177 }
178
179 /// Emit an attribute list.
180 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
181 os << "[";
182 interleaveComma(c: map, os, each_fn: [&](const auto &it) {
183 os << this->attrStmt(key: it.first, value: it.second);
184 });
185 os << "]";
186 }
187
188 // Print an MLIR attribute to `os`. Large attributes are truncated.
189 void emitMlirAttr(raw_ostream &os, Attribute attr) {
190 // A value used to elide large container attribute.
191 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
192
193 // Always emit splat attributes.
194 if (isa<SplatElementsAttr>(Val: attr)) {
195 os << escapeLabelString(
196 str: strFromOs(func: [&](raw_ostream &os) { attr.print(os); }));
197 return;
198 }
199
200 // Elide "big" elements attributes.
201 auto elements = dyn_cast<ElementsAttr>(Val&: attr);
202 if (elements && elements.getNumElements() > largeAttrLimit) {
203 os << std::string(elements.getShapedType().getRank(), '[') << "..."
204 << std::string(elements.getShapedType().getRank(), ']') << " : ";
205 emitMlirType(os, type: elements.getType());
206 return;
207 }
208
209 auto array = dyn_cast<ArrayAttr>(Val&: attr);
210 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
211 os << "[...]";
212 return;
213 }
214
215 // Print all other attributes.
216 std::string buf;
217 llvm::raw_string_ostream ss(buf);
218 attr.print(os&: ss);
219 os << escapeLabelString(str: truncateString(str: buf));
220 }
221
222 // Print a truncated and escaped MLIR type to `os`.
223 void emitMlirType(raw_ostream &os, Type type) {
224 std::string buf;
225 llvm::raw_string_ostream ss(buf);
226 type.print(os&: ss);
227 os << escapeLabelString(str: truncateString(str: buf));
228 }
229
230 // Print a truncated and escaped MLIR operand to `os`.
231 void emitMlirOperand(raw_ostream &os, Value operand) {
232 operand.printAsOperand(os, flags: OpPrintingFlags());
233 }
234
235 /// Append an edge to the list of edges.
236 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
237 void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
238 AttributeMap attrs;
239 attrs["style"] = style.str();
240 // Use `ltail` and `lhead` to draw edges between clusters.
241 if (n1.clusterId)
242 attrs["ltail"] = "cluster_" + std::to_string(val: *n1.clusterId);
243 if (n2.clusterId)
244 attrs["lhead"] = "cluster_" + std::to_string(val: *n2.clusterId);
245
246 edges.push_back(x: strFromOs(func: [&](raw_ostream &os) {
247 os << "v" << n1.id;
248 if (!port.empty() && !n1.clusterId)
249 // Attach edge to south compass point of the result
250 os << ":res" << port << ":s";
251 os << " -> ";
252 os << "v" << n2.id;
253 if (!port.empty() && !n2.clusterId)
254 // Attach edge to north compass point of the operand
255 os << ":arg" << port << ":n";
256 emitAttrList(os, map: attrs);
257 }));
258 }
259
260 /// Emit a graph. The specified builder generates the body of the graph.
261 void emitGraph(function_ref<void()> builder) {
262 os << "digraph G {\n";
263 os.indent();
264 // Edges between clusters are allowed only in compound mode.
265 os << attrStmt(key: "compound", value: "true") << ";\n";
266 builder();
267 os.unindent();
268 os << "}\n";
269 }
270
271 /// Emit a node statement.
272 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
273 StringRef background = "") {
274 int nodeId = ++counter;
275 AttributeMap attrs;
276 attrs["label"] = quoteString(str: label);
277 attrs["shape"] = shape.str();
278 if (!background.empty()) {
279 attrs["style"] = "filled";
280 attrs["fillcolor"] = quoteString(str: background.str());
281 }
282 os << llvm::format(Fmt: "v%i ", Vals: nodeId);
283 emitAttrList(os, map: attrs);
284 os << ";\n";
285 return Node(nodeId);
286 }
287
288 std::string getValuePortName(Value operand) {
289 // Print value as an operand and omit the leading '%' character.
290 auto str = strFromOs(func: [&](raw_ostream &os) {
291 operand.printAsOperand(os, flags: OpPrintingFlags());
292 });
293 // Replace % and # with _
294 llvm::replace(Range&: str, OldValue: '%', NewValue: '_');
295 llvm::replace(Range&: str, OldValue: '#', NewValue: '_');
296 return str;
297 }
298
299 std::string getClusterLabel(Operation *op) {
300 return strFromOs(func: [&](raw_ostream &os) {
301 // Print operation name and type.
302 os << op->getName();
303 if (printResultTypes) {
304 os << " : (";
305 std::string buf;
306 llvm::raw_string_ostream ss(buf);
307 interleaveComma(c: op->getResultTypes(), os&: ss);
308 os << truncateString(str: buf) << ")";
309 }
310
311 // Print attributes.
312 if (printAttrs) {
313 os << "\\l";
314 for (const NamedAttribute &attr : op->getAttrs()) {
315 os << escapeLabelString(str: attr.getName().getValue().str()) << ": ";
316 emitMlirAttr(os, attr: attr.getValue());
317 os << "\\l";
318 }
319 }
320 });
321 }
322
323 /// Generate a label for an operation.
324 std::string getRecordLabel(Operation *op) {
325 return strFromOs(func: [&](raw_ostream &os) {
326 os << "{";
327
328 // Print operation inputs.
329 if (op->getNumOperands() > 0) {
330 os << "{";
331 auto operandToPort = [&](Value operand) {
332 os << "<arg" << getValuePortName(operand) << "> ";
333 emitMlirOperand(os, operand);
334 };
335 interleave(c: op->getOperands(), os, each_fn: operandToPort, separator: "|");
336 os << "}|";
337 }
338 // Print operation name and type.
339 os << op->getName() << "\\l";
340
341 // Print attributes.
342 if (printAttrs && !op->getAttrs().empty()) {
343 // Extra line break to separate attributes from the operation name.
344 os << "\\l";
345 for (const NamedAttribute &attr : op->getAttrs()) {
346 os << attr.getName().getValue() << ": ";
347 emitMlirAttr(os, attr: attr.getValue());
348 os << "\\l";
349 }
350 }
351
352 if (op->getNumResults() > 0) {
353 os << "|{";
354 auto resultToPort = [&](Value result) {
355 os << "<res" << getValuePortName(operand: result) << "> ";
356 emitMlirOperand(os, operand: result);
357 if (printResultTypes) {
358 os << " ";
359 emitMlirType(os, type: result.getType());
360 }
361 };
362 interleave(c: op->getResults(), os, each_fn: resultToPort, separator: "|");
363 os << "}";
364 }
365
366 os << "}";
367 });
368 }
369
370 /// Generate a label for a block argument.
371 std::string getLabel(BlockArgument arg) {
372 return strFromOs(func: [&](raw_ostream &os) {
373 os << "<res" << getValuePortName(operand: arg) << "> ";
374 arg.printAsOperand(os, flags: OpPrintingFlags());
375 if (printResultTypes) {
376 os << " ";
377 emitMlirType(os, type: arg.getType());
378 }
379 });
380 }
381
382 /// Process a block. Emit a cluster and one node per block argument and
383 /// operation inside the cluster.
384 void processBlock(Block &block) {
385 emitClusterStmt(builder: [&]() {
386 for (BlockArgument &blockArg : block.getArguments())
387 valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg));
388 // Emit a node for each operation.
389 std::optional<Node> prevNode;
390 for (Operation &op : block) {
391 Node nextNode = processOperation(op: &op);
392 if (printControlFlowEdges && prevNode)
393 emitEdgeStmt(n1: *prevNode, n2: nextNode, /*port=*/"", style: kLineStyleControlFlow);
394 prevNode = nextNode;
395 }
396 });
397 }
398
399 /// Process an operation. If the operation has regions, emit a cluster.
400 /// Otherwise, emit a node.
401 Node processOperation(Operation *op) {
402 Node node;
403 if (op->getNumRegions() > 0) {
404 // Emit cluster for op with regions.
405 node = emitClusterStmt(
406 builder: [&]() {
407 for (Region &region : op->getRegions())
408 processRegion(region);
409 },
410 label: getClusterLabel(op));
411 } else {
412 node = emitNodeStmt(label: getRecordLabel(op), shape: kShapeNode,
413 background: backgroundColors[op->getName()].second);
414 }
415
416 // Insert data flow edges originating from each operand.
417 if (printDataFlowEdges) {
418 unsigned numOperands = op->getNumOperands();
419 for (unsigned i = 0; i < numOperands; i++) {
420 auto operand = op->getOperand(idx: i);
421 dataFlowEdges.push_back(x: {.value: operand, .node: node, .port: getValuePortName(operand)});
422 }
423 }
424
425 for (Value result : op->getResults())
426 valueToNode[result] = node;
427
428 return node;
429 }
430
431 /// Process a region.
432 void processRegion(Region &region) {
433 for (Block &block : region.getBlocks())
434 processBlock(block);
435 }
436
437 /// Truncate long strings.
438 std::string truncateString(std::string str) {
439 if (str.length() <= maxLabelLen)
440 return str;
441 return str.substr(pos: 0, n: maxLabelLen) + "...";
442 }
443
444 /// Output stream to write DOT file to.
445 raw_indented_ostream os;
446 /// A list of edges. For simplicity, should be emitted after all nodes were
447 /// emitted.
448 std::vector<std::string> edges;
449 /// Mapping of SSA values to Graphviz nodes/clusters.
450 DenseMap<Value, Node> valueToNode;
451 /// Output for data flow edges is delayed until the end to handle cycles
452 std::vector<DataFlowEdge> dataFlowEdges;
453 /// Counter for generating unique node/subgraph identifiers.
454 int counter = 0;
455
456 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
457};
458
459} // namespace
460
461std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
462 return std::make_unique<PrintOpPass>(args&: os);
463}
464
465/// Generate a CFG for a region and show it in a window.
466static void llvmViewGraph(Region &region, const Twine &name) {
467 int fd;
468 std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd);
469 {
470 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
471 if (fd == -1) {
472 llvm::errs() << "error opening file '" << filename << "' for writing\n";
473 return;
474 }
475 PrintOpPass pass(os);
476 pass.emitRegionCFG(region);
477 }
478 llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT);
479}
480
481void mlir::Region::viewGraph(const Twine &regionName) {
482 llvmViewGraph(region&: *this, name: regionName);
483}
484
485void mlir::Region::viewGraph() { viewGraph(regionName: "region"); }
486

// Source code of mlir/lib/Transforms/ViewOpGraph.cpp