//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>
| 23 | |
| 24 | namespace mlir { |
| 25 | #define GEN_PASS_DEF_VIEWOPGRAPH |
| 26 | #include "mlir/Transforms/Passes.h.inc" |
| 27 | } // namespace mlir |
| 28 | |
| 29 | using namespace mlir; |
| 30 | |
| 31 | static const StringRef kLineStyleControlFlow = "dashed" ; |
| 32 | static const StringRef kLineStyleDataFlow = "solid" ; |
| 33 | static const StringRef kShapeNode = "Mrecord" ; |
| 34 | static const StringRef kShapeNone = "plain" ; |
| 35 | |
| 36 | /// Return the size limits for eliding large attributes. |
| 37 | static int64_t getLargeAttributeSizeLimit() { |
| 38 | // Use the default from the printer flags if possible. |
| 39 | if (std::optional<int64_t> limit = |
| 40 | OpPrintingFlags().getLargeElementsAttrLimit()) |
| 41 | return *limit; |
| 42 | return 16; |
| 43 | } |
| 44 | |
| 45 | /// Return all values printed onto a stream as a string. |
| 46 | static std::string strFromOs(function_ref<void(raw_ostream &)> func) { |
| 47 | std::string buf; |
| 48 | llvm::raw_string_ostream os(buf); |
| 49 | func(os); |
| 50 | return buf; |
| 51 | } |
| 52 | |
/// Wrap `str` in double quotation marks. The contents are not escaped; callers
/// are responsible for escaping embedded quotes where needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
| 57 | |
/// Escape `str` so it can be embedded in a Graphviz record label.
///
/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
/// In addition, '"' is escaped because labels are emitted inside quoted DOT
/// strings, and '\n' is prefixed with a backslash as well.
/// NOTE: in a DOT quoted string a backslash immediately followed by a newline
/// is a line continuation, so such newlines do not show up in the rendering.
///
/// The helper is `static`: it is only used in this file and must not leak an
/// external symbol that could collide with another translation unit.
static std::string escapeLabelString(const std::string &str) {
  std::string buf;
  buf.reserve(str.size());
  for (char c : str) {
    switch (c) {
    case '{':
    case '|':
    case '<':
    case '}':
    case '>':
    case '\n':
    case '"':
      buf.push_back('\\');
      break;
    default:
      break;
    }
    buf.push_back(c);
  }
  return buf;
}
| 71 | |
/// Graphviz attribute name -> already-formatted attribute value (values that
/// need quoting are quoted by the caller). `std::map` keeps the entries sorted
/// by name, so attribute lists are emitted in a deterministic order.
using AttributeMap = std::map<std::string, std::string>;
| 73 | |
| 74 | namespace { |
| 75 | |
/// A node in the DOT language, written to the output as `v<id>`.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created, and the `Node` representing
/// the cluster is that anchor together with the cluster's identifier.
struct Node {
  /// Numeric identifier of the node.
  int id;
  /// For a cluster's anchor node: the id of the enclosing `cluster_<id>`
  /// subgraph. Unset for ordinary nodes.
  std::optional<int> clusterId;

  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}
};
| 91 | |
| 92 | struct DataFlowEdge { |
| 93 | Value value; |
| 94 | Node node; |
| 95 | std::string port; |
| 96 | }; |
| 97 | |
| 98 | /// This pass generates a Graphviz dataflow visualization of an MLIR operation. |
| 99 | /// Note: See https://www.graphviz.org/doc/info/lang.html for more information |
| 100 | /// about the Graphviz DOT language. |
| 101 | class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> { |
| 102 | public: |
| 103 | PrintOpPass(raw_ostream &os) : os(os) {} |
| 104 | PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} |
| 105 | |
| 106 | void runOnOperation() override { |
| 107 | initColorMapping(*getOperation()); |
| 108 | emitGraph(builder: [&]() { |
| 109 | processOperation(op: getOperation()); |
| 110 | emitAllEdgeStmts(); |
| 111 | }); |
| 112 | markAllAnalysesPreserved(); |
| 113 | } |
| 114 | |
| 115 | /// Create a CFG graph for a region. Used in `Region::viewGraph`. |
| 116 | void emitRegionCFG(Region ®ion) { |
| 117 | printControlFlowEdges = true; |
| 118 | printDataFlowEdges = false; |
| 119 | initColorMapping(irEntity&: region); |
| 120 | emitGraph(builder: [&]() { processRegion(region); }); |
| 121 | } |
| 122 | |
| 123 | private: |
| 124 | /// Generate a color mapping that will color every operation with the same |
| 125 | /// name the same way. It'll interpolate the hue in the HSV color-space, |
| 126 | /// using muted colors that provide good contrast for black text. |
| 127 | template <typename T> |
| 128 | void initColorMapping(T &irEntity) { |
| 129 | backgroundColors.clear(); |
| 130 | SmallVector<Operation *> ops; |
| 131 | irEntity.walk([&](Operation *op) { |
| 132 | auto &entry = backgroundColors[op->getName()]; |
| 133 | if (entry.first == 0) |
| 134 | ops.push_back(Elt: op); |
| 135 | ++entry.first; |
| 136 | }); |
| 137 | for (auto indexedOps : llvm::enumerate(First&: ops)) { |
| 138 | double hue = ((double)indexedOps.index()) / ops.size(); |
| 139 | // Use lower saturation (0.3) and higher value (0.95) for better |
| 140 | // readability |
| 141 | backgroundColors[indexedOps.value()->getName()].second = |
| 142 | std::to_string(val: hue) + " 0.3 0.95" ; |
| 143 | } |
| 144 | } |
| 145 | |
| 146 | /// Emit all edges. This function should be called after all nodes have been |
| 147 | /// emitted. |
| 148 | void emitAllEdgeStmts() { |
| 149 | if (printDataFlowEdges) { |
| 150 | for (const auto &e : dataFlowEdges) { |
| 151 | emitEdgeStmt(n1: valueToNode[e.value], n2: e.node, port: e.port, style: kLineStyleDataFlow); |
| 152 | } |
| 153 | } |
| 154 | |
| 155 | for (const std::string &edge : edges) |
| 156 | os << edge << ";\n" ; |
| 157 | edges.clear(); |
| 158 | } |
| 159 | |
| 160 | /// Emit a cluster (subgraph). The specified builder generates the body of the |
| 161 | /// cluster. Return the anchor node of the cluster. |
| 162 | Node emitClusterStmt(function_ref<void()> builder, std::string label = "" ) { |
| 163 | int clusterId = ++counter; |
| 164 | os << "subgraph cluster_" << clusterId << " {\n" ; |
| 165 | os.indent(); |
| 166 | // Emit invisible anchor node from/to which arrows can be drawn. |
| 167 | Node anchorNode = emitNodeStmt(label: " " , shape: kShapeNone); |
| 168 | os << attrStmt(key: "label" , value: quoteString(str: label)) << ";\n" ; |
| 169 | builder(); |
| 170 | os.unindent(); |
| 171 | os << "}\n" ; |
| 172 | return Node(anchorNode.id, clusterId); |
| 173 | } |
| 174 | |
| 175 | /// Generate an attribute statement. |
| 176 | std::string attrStmt(const Twine &key, const Twine &value) { |
| 177 | return (key + " = " + value).str(); |
| 178 | } |
| 179 | |
| 180 | /// Emit an attribute list. |
| 181 | void emitAttrList(raw_ostream &os, const AttributeMap &map) { |
| 182 | os << "[" ; |
| 183 | interleaveComma(c: map, os, each_fn: [&](const auto &it) { |
| 184 | os << this->attrStmt(key: it.first, value: it.second); |
| 185 | }); |
| 186 | os << "]" ; |
| 187 | } |
| 188 | |
| 189 | // Print an MLIR attribute to `os`. Large attributes are truncated. |
| 190 | void emitMlirAttr(raw_ostream &os, Attribute attr) { |
| 191 | // A value used to elide large container attribute. |
| 192 | int64_t largeAttrLimit = getLargeAttributeSizeLimit(); |
| 193 | |
| 194 | // Always emit splat attributes. |
| 195 | if (isa<SplatElementsAttr>(Val: attr)) { |
| 196 | os << escapeLabelString( |
| 197 | str: strFromOs([&](raw_ostream &os) { attr.print(os); })); |
| 198 | return; |
| 199 | } |
| 200 | |
| 201 | // Elide "big" elements attributes. |
| 202 | auto elements = dyn_cast<ElementsAttr>(attr); |
| 203 | if (elements && elements.getNumElements() > largeAttrLimit) { |
| 204 | os << std::string(elements.getShapedType().getRank(), '[') << "..." |
| 205 | << std::string(elements.getShapedType().getRank(), ']') << " : " ; |
| 206 | emitMlirType(os, type: elements.getType()); |
| 207 | return; |
| 208 | } |
| 209 | |
| 210 | auto array = dyn_cast<ArrayAttr>(attr); |
| 211 | if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { |
| 212 | os << "[...]" ; |
| 213 | return; |
| 214 | } |
| 215 | |
| 216 | // Print all other attributes. |
| 217 | std::string buf; |
| 218 | llvm::raw_string_ostream ss(buf); |
| 219 | attr.print(os&: ss); |
| 220 | os << escapeLabelString(str: truncateString(str: buf)); |
| 221 | } |
| 222 | |
| 223 | // Print a truncated and escaped MLIR type to `os`. |
| 224 | void emitMlirType(raw_ostream &os, Type type) { |
| 225 | std::string buf; |
| 226 | llvm::raw_string_ostream ss(buf); |
| 227 | type.print(os&: ss); |
| 228 | os << escapeLabelString(str: truncateString(str: buf)); |
| 229 | } |
| 230 | |
| 231 | // Print a truncated and escaped MLIR operand to `os`. |
| 232 | void emitMlirOperand(raw_ostream &os, Value operand) { |
| 233 | operand.printAsOperand(os, flags: OpPrintingFlags()); |
| 234 | } |
| 235 | |
| 236 | /// Append an edge to the list of edges. |
| 237 | /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. |
| 238 | void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) { |
| 239 | AttributeMap attrs; |
| 240 | attrs["style" ] = style.str(); |
| 241 | // Use `ltail` and `lhead` to draw edges between clusters. |
| 242 | if (n1.clusterId) |
| 243 | attrs["ltail" ] = "cluster_" + std::to_string(val: *n1.clusterId); |
| 244 | if (n2.clusterId) |
| 245 | attrs["lhead" ] = "cluster_" + std::to_string(val: *n2.clusterId); |
| 246 | |
| 247 | edges.push_back(x: strFromOs([&](raw_ostream &os) { |
| 248 | os << "v" << n1.id; |
| 249 | if (!port.empty() && !n1.clusterId) |
| 250 | // Attach edge to south compass point of the result |
| 251 | os << ":res" << port << ":s" ; |
| 252 | os << " -> " ; |
| 253 | os << "v" << n2.id; |
| 254 | if (!port.empty() && !n2.clusterId) |
| 255 | // Attach edge to north compass point of the operand |
| 256 | os << ":arg" << port << ":n" ; |
| 257 | emitAttrList(os, map: attrs); |
| 258 | })); |
| 259 | } |
| 260 | |
| 261 | /// Emit a graph. The specified builder generates the body of the graph. |
| 262 | void emitGraph(function_ref<void()> builder) { |
| 263 | os << "digraph G {\n" ; |
| 264 | os.indent(); |
| 265 | // Edges between clusters are allowed only in compound mode. |
| 266 | os << attrStmt(key: "compound" , value: "true" ) << ";\n" ; |
| 267 | builder(); |
| 268 | os.unindent(); |
| 269 | os << "}\n" ; |
| 270 | } |
| 271 | |
| 272 | /// Emit a node statement. |
| 273 | Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, |
| 274 | StringRef background = "" ) { |
| 275 | int nodeId = ++counter; |
| 276 | AttributeMap attrs; |
| 277 | attrs["label" ] = quoteString(str: label); |
| 278 | attrs["shape" ] = shape.str(); |
| 279 | if (!background.empty()) { |
| 280 | attrs["style" ] = "filled" ; |
| 281 | attrs["fillcolor" ] = quoteString(str: background.str()); |
| 282 | } |
| 283 | os << llvm::format(Fmt: "v%i " , Vals: nodeId); |
| 284 | emitAttrList(os, map: attrs); |
| 285 | os << ";\n" ; |
| 286 | return Node(nodeId); |
| 287 | } |
| 288 | |
| 289 | std::string getValuePortName(Value operand) { |
| 290 | // Print value as an operand and omit the leading '%' character. |
| 291 | auto str = strFromOs([&](raw_ostream &os) { |
| 292 | operand.printAsOperand(os, flags: OpPrintingFlags()); |
| 293 | }); |
| 294 | // Replace % and # with _ |
| 295 | llvm::replace(str, '%', '_'); |
| 296 | llvm::replace(str, '#', '_'); |
| 297 | return str; |
| 298 | } |
| 299 | |
| 300 | std::string getClusterLabel(Operation *op) { |
| 301 | return strFromOs([&](raw_ostream &os) { |
| 302 | // Print operation name and type. |
| 303 | os << op->getName(); |
| 304 | if (printResultTypes) { |
| 305 | os << " : (" ; |
| 306 | std::string buf; |
| 307 | llvm::raw_string_ostream ss(buf); |
| 308 | interleaveComma(c: op->getResultTypes(), os&: ss); |
| 309 | os << truncateString(str: buf) << ")" ; |
| 310 | } |
| 311 | |
| 312 | // Print attributes. |
| 313 | if (printAttrs) { |
| 314 | os << "\\l" ; |
| 315 | for (const NamedAttribute &attr : op->getAttrs()) { |
| 316 | os << escapeLabelString(attr.getName().getValue().str()) << ": " ; |
| 317 | emitMlirAttr(os, attr: attr.getValue()); |
| 318 | os << "\\l" ; |
| 319 | } |
| 320 | } |
| 321 | }); |
| 322 | } |
| 323 | |
| 324 | /// Generate a label for an operation. |
| 325 | std::string getRecordLabel(Operation *op) { |
| 326 | return strFromOs([&](raw_ostream &os) { |
| 327 | os << "{" ; |
| 328 | |
| 329 | // Print operation inputs. |
| 330 | if (op->getNumOperands() > 0) { |
| 331 | os << "{" ; |
| 332 | auto operandToPort = [&](Value operand) { |
| 333 | os << "<arg" << getValuePortName(operand) << "> " ; |
| 334 | emitMlirOperand(os, operand); |
| 335 | }; |
| 336 | interleave(c: op->getOperands(), os, each_fn: operandToPort, separator: "|" ); |
| 337 | os << "}|" ; |
| 338 | } |
| 339 | // Print operation name and type. |
| 340 | os << op->getName() << "\\l" ; |
| 341 | |
| 342 | // Print attributes. |
| 343 | if (printAttrs && !op->getAttrs().empty()) { |
| 344 | // Extra line break to separate attributes from the operation name. |
| 345 | os << "\\l" ; |
| 346 | for (const NamedAttribute &attr : op->getAttrs()) { |
| 347 | os << attr.getName().getValue() << ": " ; |
| 348 | emitMlirAttr(os, attr: attr.getValue()); |
| 349 | os << "\\l" ; |
| 350 | } |
| 351 | } |
| 352 | |
| 353 | if (op->getNumResults() > 0) { |
| 354 | os << "|{" ; |
| 355 | auto resultToPort = [&](Value result) { |
| 356 | os << "<res" << getValuePortName(operand: result) << "> " ; |
| 357 | emitMlirOperand(os, operand: result); |
| 358 | if (printResultTypes) { |
| 359 | os << " " ; |
| 360 | emitMlirType(os, type: result.getType()); |
| 361 | } |
| 362 | }; |
| 363 | interleave(c: op->getResults(), os, each_fn: resultToPort, separator: "|" ); |
| 364 | os << "}" ; |
| 365 | } |
| 366 | |
| 367 | os << "}" ; |
| 368 | }); |
| 369 | } |
| 370 | |
| 371 | /// Generate a label for a block argument. |
| 372 | std::string getLabel(BlockArgument arg) { |
| 373 | return strFromOs([&](raw_ostream &os) { |
| 374 | os << "<res" << getValuePortName(operand: arg) << "> " ; |
| 375 | arg.printAsOperand(os, flags: OpPrintingFlags()); |
| 376 | if (printResultTypes) { |
| 377 | os << " " ; |
| 378 | emitMlirType(os, type: arg.getType()); |
| 379 | } |
| 380 | }); |
| 381 | } |
| 382 | |
| 383 | /// Process a block. Emit a cluster and one node per block argument and |
| 384 | /// operation inside the cluster. |
| 385 | void processBlock(Block &block) { |
| 386 | emitClusterStmt(builder: [&]() { |
| 387 | for (BlockArgument &blockArg : block.getArguments()) |
| 388 | valueToNode[blockArg] = emitNodeStmt(label: getLabel(arg: blockArg)); |
| 389 | // Emit a node for each operation. |
| 390 | std::optional<Node> prevNode; |
| 391 | for (Operation &op : block) { |
| 392 | Node nextNode = processOperation(op: &op); |
| 393 | if (printControlFlowEdges && prevNode) |
| 394 | emitEdgeStmt(n1: *prevNode, n2: nextNode, /*port=*/"" , style: kLineStyleControlFlow); |
| 395 | prevNode = nextNode; |
| 396 | } |
| 397 | }); |
| 398 | } |
| 399 | |
| 400 | /// Process an operation. If the operation has regions, emit a cluster. |
| 401 | /// Otherwise, emit a node. |
| 402 | Node processOperation(Operation *op) { |
| 403 | Node node; |
| 404 | if (op->getNumRegions() > 0) { |
| 405 | // Emit cluster for op with regions. |
| 406 | node = emitClusterStmt( |
| 407 | builder: [&]() { |
| 408 | for (Region ®ion : op->getRegions()) |
| 409 | processRegion(region); |
| 410 | }, |
| 411 | label: getClusterLabel(op)); |
| 412 | } else { |
| 413 | node = emitNodeStmt(label: getRecordLabel(op), shape: kShapeNode, |
| 414 | background: backgroundColors[op->getName()].second); |
| 415 | } |
| 416 | |
| 417 | // Insert data flow edges originating from each operand. |
| 418 | if (printDataFlowEdges) { |
| 419 | unsigned numOperands = op->getNumOperands(); |
| 420 | for (unsigned i = 0; i < numOperands; i++) { |
| 421 | auto operand = op->getOperand(idx: i); |
| 422 | dataFlowEdges.push_back(x: {.value: operand, .node: node, .port: getValuePortName(operand)}); |
| 423 | } |
| 424 | } |
| 425 | |
| 426 | for (Value result : op->getResults()) |
| 427 | valueToNode[result] = node; |
| 428 | |
| 429 | return node; |
| 430 | } |
| 431 | |
| 432 | /// Process a region. |
| 433 | void processRegion(Region ®ion) { |
| 434 | for (Block &block : region.getBlocks()) |
| 435 | processBlock(block); |
| 436 | } |
| 437 | |
| 438 | /// Truncate long strings. |
| 439 | std::string truncateString(std::string str) { |
| 440 | if (str.length() <= maxLabelLen) |
| 441 | return str; |
| 442 | return str.substr(0, maxLabelLen) + "..." ; |
| 443 | } |
| 444 | |
| 445 | /// Output stream to write DOT file to. |
| 446 | raw_indented_ostream os; |
| 447 | /// A list of edges. For simplicity, should be emitted after all nodes were |
| 448 | /// emitted. |
| 449 | std::vector<std::string> edges; |
| 450 | /// Mapping of SSA values to Graphviz nodes/clusters. |
| 451 | DenseMap<Value, Node> valueToNode; |
| 452 | /// Output for data flow edges is delayed until the end to handle cycles |
| 453 | std::vector<DataFlowEdge> dataFlowEdges; |
| 454 | /// Counter for generating unique node/subgraph identifiers. |
| 455 | int counter = 0; |
| 456 | |
| 457 | DenseMap<OperationName, std::pair<int, std::string>> backgroundColors; |
| 458 | }; |
| 459 | |
| 460 | } // namespace |
| 461 | |
| 462 | std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) { |
| 463 | return std::make_unique<PrintOpPass>(args&: os); |
| 464 | } |
| 465 | |
| 466 | /// Generate a CFG for a region and show it in a window. |
| 467 | static void llvmViewGraph(Region ®ion, const Twine &name) { |
| 468 | int fd; |
| 469 | std::string filename = llvm::createGraphFilename(Name: name.str(), FD&: fd); |
| 470 | { |
| 471 | llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); |
| 472 | if (fd == -1) { |
| 473 | llvm::errs() << "error opening file '" << filename << "' for writing\n" ; |
| 474 | return; |
| 475 | } |
| 476 | PrintOpPass pass(os); |
| 477 | pass.emitRegionCFG(region); |
| 478 | } |
| 479 | llvm::DisplayGraph(Filename: filename, /*wait=*/false, program: llvm::GraphProgram::DOT); |
| 480 | } |
| 481 | |
| 482 | void mlir::Region::viewGraph(const Twine ®ionName) { |
| 483 | llvmViewGraph(region&: *this, name: regionName); |
| 484 | } |
| 485 | |
| 486 | void mlir::Region::viewGraph() { viewGraph(regionName: "region" ); } |
| 487 | |