1//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/ViewOpGraph.h"
10
11#include "mlir/IR/Block.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Support/IndentedOstream.h"
16#include "mlir/Transforms/TopologicalSortUtils.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/GraphWriter.h"
19#include <map>
20#include <optional>
21#include <utility>
22
23namespace mlir {
24#define GEN_PASS_DEF_VIEWOPGRAPH
25#include "mlir/Transforms/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30static const StringRef kLineStyleControlFlow = "dashed";
31static const StringRef kLineStyleDataFlow = "solid";
32static const StringRef kShapeNode = "ellipse";
33static const StringRef kShapeNone = "plain";
34
35/// Return the size limits for eliding large attributes.
36static int64_t getLargeAttributeSizeLimit() {
37 // Use the default from the printer flags if possible.
38 if (std::optional<int64_t> limit =
39 OpPrintingFlags().getLargeElementsAttrLimit())
40 return *limit;
41 return 16;
42}
43
44/// Return all values printed onto a stream as a string.
45static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46 std::string buf;
47 llvm::raw_string_ostream os(buf);
48 func(os);
49 return os.str();
50}
51
52/// Escape special characters such as '\n' and quotation marks.
53static std::string escapeString(std::string str) {
54 return strFromOs([&](raw_ostream &os) { os.write_escaped(Str: str); });
55}
56
/// Wrap `str` in double quotation marks. The contents are copied verbatim;
/// no escaping is performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
61
/// Key/value pairs of a DOT attribute list. `std::map` iterates in key order,
/// so attributes are emitted sorted by key.
using AttributeMap = std::map<std::string, std::string>;
63
64namespace {
65
/// A node in the DOT language, identified by an integer id. A node may
/// additionally carry the id of the cluster (subgraph) it stands for.
///
/// Note: DOT edges connect nodes only, never clusters. An edge can however be
/// clipped at a cluster boundary through the `lhead` / `ltail` attributes, so
/// every newly created cluster gets an invisible "anchor" node that edges can
/// attach to.
struct Node {
  /// Identifier of the node.
  int id;
  /// Identifier of the associated cluster, if there is one.
  std::optional<int> clusterId;

  /// Build a node; with no arguments this yields id 0 and no cluster.
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id{id}, clusterId{clusterId} {}
};
81
82/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84/// about the Graphviz DOT language.
85class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
86public:
87 PrintOpPass(raw_ostream &os) : os(os) {}
88 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
89
90 void runOnOperation() override {
91 initColorMapping(*getOperation());
92 emitGraph(builder: [&]() {
93 processOperation(op: getOperation());
94 emitAllEdgeStmts();
95 });
96 }
97
98 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
99 void emitRegionCFG(Region &region) {
100 printControlFlowEdges = true;
101 printDataFlowEdges = false;
102 initColorMapping(irEntity&: region);
103 emitGraph(builder: [&]() { processRegion(region); });
104 }
105
106private:
107 /// Generate a color mapping that will color every operation with the same
108 /// name the same way. It'll interpolate the hue in the HSV color-space,
109 /// attempting to keep the contrast suitable for black text.
110 template <typename T>
111 void initColorMapping(T &irEntity) {
112 backgroundColors.clear();
113 SmallVector<Operation *> ops;
114 irEntity.walk([&](Operation *op) {
115 auto &entry = backgroundColors[op->getName()];
116 if (entry.first == 0)
117 ops.push_back(Elt: op);
118 ++entry.first;
119 });
120 for (auto indexedOps : llvm::enumerate(First&: ops)) {
121 double hue = ((double)indexedOps.index()) / ops.size();
122 backgroundColors[indexedOps.value()->getName()].second =
123 std::to_string(val: hue) + " 1.0 1.0";
124 }
125 }
126
127 /// Emit all edges. This function should be called after all nodes have been
128 /// emitted.
129 void emitAllEdgeStmts() {
130 if (printDataFlowEdges) {
131 for (const auto &[value, node, label] : dataFlowEdges) {
132 emitEdgeStmt(n1: valueToNode[value], n2: node, label, style: kLineStyleDataFlow);
133 }
134 }
135
136 for (const std::string &edge : edges)
137 os << edge << ";\n";
138 edges.clear();
139 }
140
141 /// Emit a cluster (subgraph). The specified builder generates the body of the
142 /// cluster. Return the anchor node of the cluster.
143 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
144 int clusterId = ++counter;
145 os << "subgraph cluster_" << clusterId << " {\n";
146 os.indent();
147 // Emit invisible anchor node from/to which arrows can be drawn.
148 Node anchorNode = emitNodeStmt(label: " ", shape: kShapeNone);
149 os << attrStmt(key: "label", value: quoteString(str: escapeString(str: std::move(label))))
150 << ";\n";
151 builder();
152 os.unindent();
153 os << "}\n";
154 return Node(anchorNode.id, clusterId);
155 }
156
157 /// Generate an attribute statement.
158 std::string attrStmt(const Twine &key, const Twine &value) {
159 return (key + " = " + value).str();
160 }
161
162 /// Emit an attribute list.
163 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
164 os << "[";
165 interleaveComma(c: map, os, each_fn: [&](const auto &it) {
166 os << this->attrStmt(key: it.first, value: it.second);
167 });
168 os << "]";
169 }
170
171 // Print an MLIR attribute to `os`. Large attributes are truncated.
172 void emitMlirAttr(raw_ostream &os, Attribute attr) {
173 // A value used to elide large container attribute.
174 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
175
176 // Always emit splat attributes.
177 if (isa<SplatElementsAttr>(Val: attr)) {
178 attr.print(os);
179 return;
180 }
181
182 // Elide "big" elements attributes.
183 auto elements = dyn_cast<ElementsAttr>(attr);
184 if (elements && elements.getNumElements() > largeAttrLimit) {
185 os << std::string(elements.getShapedType().getRank(), '[') << "..."
186 << std::string(elements.getShapedType().getRank(), ']') << " : "
187 << elements.getType();
188 return;
189 }
190
191 auto array = dyn_cast<ArrayAttr>(attr);
192 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
193 os << "[...]";
194 return;
195 }
196
197 // Print all other attributes.
198 std::string buf;
199 llvm::raw_string_ostream ss(buf);
200 attr.print(os&: ss);
201 os << truncateString(str: ss.str());
202 }
203
204 /// Append an edge to the list of edges.
205 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
206 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
207 AttributeMap attrs;
208 attrs["style"] = style.str();
209 // Do not label edges that start/end at a cluster boundary. Such edges are
210 // clipped at the boundary, but labels are not. This can lead to labels
211 // floating around without any edge next to them.
212 if (!n1.clusterId && !n2.clusterId)
213 attrs["label"] = quoteString(str: escapeString(str: std::move(label)));
214 // Use `ltail` and `lhead` to draw edges between clusters.
215 if (n1.clusterId)
216 attrs["ltail"] = "cluster_" + std::to_string(val: *n1.clusterId);
217 if (n2.clusterId)
218 attrs["lhead"] = "cluster_" + std::to_string(val: *n2.clusterId);
219
220 edges.push_back(x: strFromOs([&](raw_ostream &os) {
221 os << llvm::format(Fmt: "v%i -> v%i ", Vals: n1.id, Vals: n2.id);
222 emitAttrList(os, map: attrs);
223 }));
224 }
225
226 /// Emit a graph. The specified builder generates the body of the graph.
227 void emitGraph(function_ref<void()> builder) {
228 os << "digraph G {\n";
229 os.indent();
230 // Edges between clusters are allowed only in compound mode.
231 os << attrStmt(key: "compound", value: "true") << ";\n";
232 builder();
233 os.unindent();
234 os << "}\n";
235 }
236
237 /// Emit a node statement.
238 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
239 StringRef background = "") {
240 int nodeId = ++counter;
241 AttributeMap attrs;
242 attrs["label"] = quoteString(str: escapeString(str: std::move(label)));
243 attrs["shape"] = shape.str();
244 if (!background.empty()) {
245 attrs["style"] = "filled";
246 attrs["fillcolor"] = ("\"" + background + "\"").str();
247 }
248 os << llvm::format(Fmt: "v%i ", Vals: nodeId);
249 emitAttrList(os, map: attrs);
250 os << ";\n";
251 return Node(nodeId);
252 }
253
254 /// Generate a label for an operation.
255 std::string getLabel(Operation *op) {
256 return strFromOs([&](raw_ostream &os) {
257 // Print operation name and type.
258 os << op->getName();
259 if (printResultTypes) {
260 os << " : (";
261 std::string buf;
262 llvm::raw_string_ostream ss(buf);
263 interleaveComma(c: op->getResultTypes(), os&: ss);
264 os << truncateString(str: ss.str()) << ")";
265 }
266
267 // Print attributes.
268 if (printAttrs) {
269 os << "\n";
270 for (const NamedAttribute &attr : op->getAttrs()) {
271 os << '\n' << attr.getName().getValue() << ": ";
272 emitMlirAttr(os, attr: attr.getValue());
273 }
274 }
275 });
276 }
277
278 /// Generate a label for a block argument.
279 std::string getLabel(BlockArgument arg) {
280 return "arg" + std::to_string(val: arg.getArgNumber());
281 }
282
283 /// Process a block. Emit a cluster and one node per block argument and
284 /// operation inside the cluster.
285 void processBlock(Block &block) {
286 emitClusterStmt(builder: [&]() {
287 for (BlockArgument &blockArg : block.getArguments())
288 valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg));
289
290 // Emit a node for each operation.
291 std::optional<Node> prevNode;
292 for (Operation &op : block) {
293 Node nextNode = processOperation(op: &op);
294 if (printControlFlowEdges && prevNode)
295 emitEdgeStmt(n1: *prevNode, n2: nextNode, /*label=*/"",
296 style: kLineStyleControlFlow);
297 prevNode = nextNode;
298 }
299 });
300 }
301
302 /// Process an operation. If the operation has regions, emit a cluster.
303 /// Otherwise, emit a node.
304 Node processOperation(Operation *op) {
305 Node node;
306 if (op->getNumRegions() > 0) {
307 // Emit cluster for op with regions.
308 node = emitClusterStmt(
309 builder: [&]() {
310 for (Region &region : op->getRegions())
311 processRegion(region);
312 },
313 label: getLabel(op));
314 } else {
315 node = emitNodeStmt(label: getLabel(op), shape: kShapeNode,
316 background: backgroundColors[op->getName()].second);
317 }
318
319 // Insert data flow edges originating from each operand.
320 if (printDataFlowEdges) {
321 unsigned numOperands = op->getNumOperands();
322 for (unsigned i = 0; i < numOperands; i++)
323 dataFlowEdges.push_back(x: {op->getOperand(idx: i), node,
324 numOperands == 1 ? "" : std::to_string(val: i)});
325 }
326
327 for (Value result : op->getResults())
328 valueToNode[result] = node;
329
330 return node;
331 }
332
333 /// Process a region.
334 void processRegion(Region &region) {
335 for (Block &block : region.getBlocks())
336 processBlock(block);
337 }
338
339 /// Truncate long strings.
340 std::string truncateString(std::string str) {
341 if (str.length() <= maxLabelLen)
342 return str;
343 return str.substr(0, maxLabelLen) + "...";
344 }
345
346 /// Output stream to write DOT file to.
347 raw_indented_ostream os;
348 /// A list of edges. For simplicity, should be emitted after all nodes were
349 /// emitted.
350 std::vector<std::string> edges;
351 /// Mapping of SSA values to Graphviz nodes/clusters.
352 DenseMap<Value, Node> valueToNode;
353 /// Output for data flow edges is delayed until the end to handle cycles
354 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
355 /// Counter for generating unique node/subgraph identifiers.
356 int counter = 0;
357
358 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
359};
360
361} // namespace
362
363std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
364 return std::make_unique<PrintOpPass>(args&: os);
365}
366
367/// Generate a CFG for a region and show it in a window.
368static void llvmViewGraph(Region &region, const Twine &name) {
369 int fd;
370 std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd);
371 {
372 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
373 if (fd == -1) {
374 llvm::errs() << "error opening file '" << filename << "' for writing\n";
375 return;
376 }
377 PrintOpPass pass(os);
378 pass.emitRegionCFG(region);
379 }
380 llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT);
381}
382
383void mlir::Region::viewGraph(const Twine &regionName) {
384 llvmViewGraph(region&: *this, name: regionName);
385}
386
387void mlir::Region::viewGraph() { viewGraph(regionName: "region"); }
388

// Source: mlir/lib/Transforms/ViewOpGraph.cpp