| 1 | //===- IRPrinting.cpp -----------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "PassDetail.h" |
| 10 | #include "mlir/IR/SymbolTable.h" |
| 11 | #include "mlir/Pass/PassManager.h" |
| 12 | #include "mlir/Support/FileUtilities.h" |
| 13 | #include "llvm/ADT/STLExtras.h" |
| 14 | #include "llvm/ADT/StringExtras.h" |
| 15 | #include "llvm/Support/FileSystem.h" |
| 16 | #include "llvm/Support/FormatVariadic.h" |
| 17 | #include "llvm/Support/Path.h" |
| 18 | #include "llvm/Support/ToolOutputFile.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::detail; |
| 22 | |
| 23 | namespace { |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | // IRPrinter |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | |
| 28 | class IRPrinterInstrumentation : public PassInstrumentation { |
| 29 | public: |
| 30 | IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) |
| 31 | : config(std::move(config)) {} |
| 32 | |
| 33 | private: |
| 34 | /// Instrumentation hooks. |
| 35 | void runBeforePass(Pass *pass, Operation *op) override; |
| 36 | void runAfterPass(Pass *pass, Operation *op) override; |
| 37 | void runAfterPassFailed(Pass *pass, Operation *op) override; |
| 38 | |
| 39 | /// Configuration to use. |
| 40 | std::unique_ptr<PassManager::IRPrinterConfig> config; |
| 41 | |
| 42 | /// The following is a set of fingerprints for operations that are currently |
| 43 | /// being operated on in a pass. This field is only used when the |
| 44 | /// configuration asked for change detection. |
| 45 | DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; |
| 46 | }; |
| 47 | } // namespace |
| 48 | |
| 49 | static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, |
| 50 | OpPrintingFlags flags) { |
| 51 | // Otherwise, check to see if we are not printing at module scope. |
| 52 | if (!printModuleScope) |
| 53 | return op->print(os&: out << " //----- //\n" , |
| 54 | flags: op->getBlock() ? flags.useLocalScope() : flags); |
| 55 | |
| 56 | // Otherwise, we are printing at module scope. |
| 57 | out << " ('" << op->getName() << "' operation" ; |
| 58 | if (auto symbolName = |
| 59 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) |
| 60 | out << ": @" << symbolName.getValue(); |
| 61 | out << ") //----- //\n" ; |
| 62 | |
| 63 | // Find the top-level operation. |
| 64 | auto *topLevelOp = op; |
| 65 | while (auto *parentOp = topLevelOp->getParentOp()) |
| 66 | topLevelOp = parentOp; |
| 67 | topLevelOp->print(os&: out, flags); |
| 68 | } |
| 69 | |
| 70 | /// Instrumentation hooks. |
| 71 | void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { |
| 72 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
| 73 | return; |
| 74 | // If the config asked to detect changes, record the current fingerprint. |
| 75 | if (config->shouldPrintAfterOnlyOnChange()) |
| 76 | beforePassFingerPrints.try_emplace(Key: pass, Args&: op); |
| 77 | |
| 78 | config->printBeforeIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
| 79 | out << "// -----// IR Dump Before " << pass->getName() << " (" |
| 80 | << pass->getArgument() << ")" ; |
| 81 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
| 82 | flags: config->getOpPrintingFlags()); |
| 83 | out << "\n\n" ; |
| 84 | }); |
| 85 | } |
| 86 | |
| 87 | void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { |
| 88 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
| 89 | return; |
| 90 | |
| 91 | // Check to see if we are only printing on failure. |
| 92 | if (config->shouldPrintAfterOnlyOnFailure()) |
| 93 | return; |
| 94 | |
| 95 | // If the config asked to detect changes, compare the current fingerprint with |
| 96 | // the previous. |
| 97 | if (config->shouldPrintAfterOnlyOnChange()) { |
| 98 | auto fingerPrintIt = beforePassFingerPrints.find(Val: pass); |
| 99 | assert(fingerPrintIt != beforePassFingerPrints.end() && |
| 100 | "expected valid fingerprint" ); |
| 101 | // If the fingerprints are the same, we don't print the IR. |
| 102 | if (fingerPrintIt->second == OperationFingerPrint(op)) { |
| 103 | beforePassFingerPrints.erase(I: fingerPrintIt); |
| 104 | return; |
| 105 | } |
| 106 | beforePassFingerPrints.erase(I: fingerPrintIt); |
| 107 | } |
| 108 | |
| 109 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
| 110 | out << "// -----// IR Dump After " << pass->getName() << " (" |
| 111 | << pass->getArgument() << ")" ; |
| 112 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
| 113 | flags: config->getOpPrintingFlags()); |
| 114 | out << "\n\n" ; |
| 115 | }); |
| 116 | } |
| 117 | |
| 118 | void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { |
| 119 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
| 120 | return; |
| 121 | if (config->shouldPrintAfterOnlyOnChange()) |
| 122 | beforePassFingerPrints.erase(Val: pass); |
| 123 | |
| 124 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
| 125 | out << formatv(Fmt: "// -----// IR Dump After {0} Failed ({1})" , Vals: pass->getName(), |
| 126 | Vals: pass->getArgument()); |
| 127 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
| 128 | flags: config->getOpPrintingFlags()); |
| 129 | out << "\n\n" ; |
| 130 | }); |
| 131 | } |
| 132 | |
| 133 | //===----------------------------------------------------------------------===// |
| 134 | // IRPrinterConfig |
| 135 | //===----------------------------------------------------------------------===// |
| 136 | |
| 137 | /// Initialize the configuration. |
| 138 | PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, |
| 139 | bool printAfterOnlyOnChange, |
| 140 | bool printAfterOnlyOnFailure, |
| 141 | OpPrintingFlags opPrintingFlags) |
| 142 | : printModuleScope(printModuleScope), |
| 143 | printAfterOnlyOnChange(printAfterOnlyOnChange), |
| 144 | printAfterOnlyOnFailure(printAfterOnlyOnFailure), |
| 145 | opPrintingFlags(opPrintingFlags) {} |
| 146 | PassManager::IRPrinterConfig::~IRPrinterConfig() = default; |
| 147 | |
| 148 | /// A hook that may be overridden by a derived config that checks if the IR |
| 149 | /// of 'operation' should be dumped *before* the pass 'pass' has been |
| 150 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
| 151 | /// with the stream to dump into. |
| 152 | void PassManager::IRPrinterConfig::printBeforeIfEnabled( |
| 153 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
| 154 | // By default, never print. |
| 155 | } |
| 156 | |
| 157 | /// A hook that may be overridden by a derived config that checks if the IR |
| 158 | /// of 'operation' should be dumped *after* the pass 'pass' has been |
| 159 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
| 160 | /// with the stream to dump into. |
| 161 | void PassManager::IRPrinterConfig::printAfterIfEnabled( |
| 162 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
| 163 | // By default, never print. |
| 164 | } |
| 165 | |
| 166 | //===----------------------------------------------------------------------===// |
| 167 | // PassManager |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | |
| 170 | namespace { |
| 171 | /// Simple wrapper config that allows for the simpler interface defined above. |
| 172 | struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { |
| 173 | BasicIRPrinterConfig( |
| 174 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
| 175 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
| 176 | bool printModuleScope, bool printAfterOnlyOnChange, |
| 177 | bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, |
| 178 | raw_ostream &out) |
| 179 | : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, |
| 180 | printAfterOnlyOnFailure, opPrintingFlags), |
| 181 | shouldPrintBeforePass(std::move(shouldPrintBeforePass)), |
| 182 | shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { |
| 183 | assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && |
| 184 | "expected at least one valid filter function" ); |
| 185 | } |
| 186 | |
| 187 | void printBeforeIfEnabled(Pass *pass, Operation *operation, |
| 188 | PrintCallbackFn printCallback) final { |
| 189 | if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) |
| 190 | printCallback(out); |
| 191 | } |
| 192 | |
| 193 | void printAfterIfEnabled(Pass *pass, Operation *operation, |
| 194 | PrintCallbackFn printCallback) final { |
| 195 | if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) |
| 196 | printCallback(out); |
| 197 | } |
| 198 | |
| 199 | /// Filter functions for before and after pass execution. |
| 200 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; |
| 201 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; |
| 202 | |
| 203 | /// The stream to output to. |
| 204 | raw_ostream &out; |
| 205 | }; |
| 206 | } // namespace |
| 207 | |
| 208 | /// Return pairs of (sanitized op name, symbol name) for `op` and all parent |
| 209 | /// operations. Op names are sanitized by replacing periods with underscores. |
| 210 | /// The pairs are returned in order of outer-most to inner-most (ancestors of |
| 211 | /// `op` first, `op` last). This information is used to construct the directory |
| 212 | /// tree for the `FileTreeIRPrinterConfig` below. |
| 213 | /// The counter for `op` will be incremented by this call. |
| 214 | static std::pair<SmallVector<std::pair<std::string, std::string>>, std::string> |
| 215 | getOpAndSymbolNames(Operation *op, StringRef passName, |
| 216 | llvm::DenseMap<Operation *, unsigned> &counters) { |
| 217 | SmallVector<std::pair<std::string, std::string>> pathElements; |
| 218 | SmallVector<unsigned> countPrefix; |
| 219 | |
| 220 | Operation *iter = op; |
| 221 | ++counters.try_emplace(Key: op, Args: -1).first->second; |
| 222 | while (iter) { |
| 223 | countPrefix.push_back(Elt: counters[iter]); |
| 224 | StringAttr symbolNameAttr = |
| 225 | iter->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); |
| 226 | std::string symbolName = |
| 227 | symbolNameAttr ? symbolNameAttr.str() : "no-symbol-name" ; |
| 228 | llvm::replace(Range&: symbolName, OldValue: '/', NewValue: '_'); |
| 229 | llvm::replace(Range&: symbolName, OldValue: '\\', NewValue: '_'); |
| 230 | |
| 231 | std::string opName = |
| 232 | llvm::join(R: llvm::split(Str: iter->getName().getStringRef().str(), Separator: '.'), Separator: "_" ); |
| 233 | pathElements.emplace_back(Args: std::move(opName), Args: std::move(symbolName)); |
| 234 | iter = iter->getParentOp(); |
| 235 | } |
| 236 | // Return in the order of top level (module) down to `op`. |
| 237 | std::reverse(first: countPrefix.begin(), last: countPrefix.end()); |
| 238 | std::reverse(first: pathElements.begin(), last: pathElements.end()); |
| 239 | |
| 240 | std::string passFileName = llvm::formatv( |
| 241 | Fmt: "{0:$[_]}_{1}.mlir" , |
| 242 | Vals: llvm::make_range(x: countPrefix.begin(), y: countPrefix.end()), Vals&: passName); |
| 243 | |
| 244 | return {pathElements, passFileName}; |
| 245 | } |
| 246 | |
| 247 | static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath) { |
| 248 | if (std::error_code ec = |
| 249 | llvm::sys::fs::create_directory(path: dirPath, /*IgnoreExisting=*/true)) { |
| 250 | llvm::errs() << "Error while creating directory " << dirPath << ": " |
| 251 | << ec.message() << "\n" ; |
| 252 | return failure(); |
| 253 | } |
| 254 | return success(); |
| 255 | } |
| 256 | |
| 257 | /// Creates directories (if required) and opens an output file for the |
| 258 | /// FileTreeIRPrinterConfig. |
| 259 | static std::unique_ptr<llvm::ToolOutputFile> |
| 260 | createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument, |
| 261 | llvm::StringRef rootDir, |
| 262 | llvm::DenseMap<Operation *, unsigned> &counters) { |
| 263 | // Create the path. We will create a tree rooted at the given 'rootDir' |
| 264 | // directory. The root directory will contain folders with the names of |
| 265 | // modules. Sub-directories within those folders mirror the nesting |
| 266 | // structure of the pass manager, using symbol names for directory names. |
| 267 | auto [opAndSymbolNames, fileName] = |
| 268 | getOpAndSymbolNames(op, passName: passArgument, counters); |
| 269 | |
| 270 | // Create all the directories, starting at the root. Abort early if we fail to |
| 271 | // create any directory. |
| 272 | llvm::SmallString<128> path(rootDir); |
| 273 | if (failed(Result: createDirectoryOrPrintErr(dirPath: path))) |
| 274 | return nullptr; |
| 275 | |
| 276 | for (const auto &[opName, symbolName] : opAndSymbolNames) { |
| 277 | llvm::sys::path::append(path, a: opName + "_" + symbolName); |
| 278 | if (failed(Result: createDirectoryOrPrintErr(dirPath: path))) |
| 279 | return nullptr; |
| 280 | } |
| 281 | |
| 282 | // Open output file. |
| 283 | llvm::sys::path::append(path, a: fileName); |
| 284 | std::string error; |
| 285 | std::unique_ptr<llvm::ToolOutputFile> file = openOutputFile(outputFilename: path, errorMessage: &error); |
| 286 | if (!file) { |
| 287 | llvm::errs() << "Error opening output file " << path << ": " << error |
| 288 | << "\n" ; |
| 289 | return nullptr; |
| 290 | } |
| 291 | return file; |
| 292 | } |
| 293 | |
| 294 | namespace { |
| 295 | /// A configuration that prints the IR before/after each pass to a set of files |
| 296 | /// in the specified directory. The files are organized into subdirectories that |
| 297 | /// mirror the nesting structure of the IR. |
| 298 | struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig { |
| 299 | FileTreeIRPrinterConfig( |
| 300 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
| 301 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
| 302 | bool printModuleScope, bool printAfterOnlyOnChange, |
| 303 | bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, |
| 304 | llvm::StringRef treeDir) |
| 305 | : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, |
| 306 | printAfterOnlyOnFailure, opPrintingFlags), |
| 307 | shouldPrintBeforePass(std::move(shouldPrintBeforePass)), |
| 308 | shouldPrintAfterPass(std::move(shouldPrintAfterPass)), |
| 309 | treeDir(treeDir) { |
| 310 | assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && |
| 311 | "expected at least one valid filter function" ); |
| 312 | } |
| 313 | |
| 314 | void printBeforeIfEnabled(Pass *pass, Operation *operation, |
| 315 | PrintCallbackFn printCallback) final { |
| 316 | if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation)) |
| 317 | return; |
| 318 | std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( |
| 319 | op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters); |
| 320 | if (!file) |
| 321 | return; |
| 322 | printCallback(file->os()); |
| 323 | file->keep(); |
| 324 | } |
| 325 | |
| 326 | void printAfterIfEnabled(Pass *pass, Operation *operation, |
| 327 | PrintCallbackFn printCallback) final { |
| 328 | if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation)) |
| 329 | return; |
| 330 | std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( |
| 331 | op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters); |
| 332 | if (!file) |
| 333 | return; |
| 334 | printCallback(file->os()); |
| 335 | file->keep(); |
| 336 | } |
| 337 | |
| 338 | /// Filter functions for before and after pass execution. |
| 339 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; |
| 340 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; |
| 341 | |
| 342 | /// Directory that should be used as the root of the file tree. |
| 343 | std::string treeDir; |
| 344 | |
| 345 | /// Counters used for labeling the prefix. Every op which could be targeted by |
| 346 | /// a pass gets its own counter. |
| 347 | llvm::DenseMap<Operation *, unsigned> counters; |
| 348 | }; |
| 349 | |
| 350 | } // namespace |
| 351 | |
| 352 | /// Add an instrumentation to print the IR before and after pass execution, |
| 353 | /// using the provided configuration. |
| 354 | void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { |
| 355 | if (config->shouldPrintAtModuleScope() && |
| 356 | getContext()->isMultithreadingEnabled()) |
| 357 | llvm::report_fatal_error(reason: "IR printing can't be setup on a pass-manager " |
| 358 | "without disabling multi-threading first." ); |
| 359 | addInstrumentation( |
| 360 | pi: std::make_unique<IRPrinterInstrumentation>(args: std::move(config))); |
| 361 | } |
| 362 | |
| 363 | /// Add an instrumentation to print the IR before and after pass execution. |
| 364 | void PassManager::enableIRPrinting( |
| 365 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
| 366 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
| 367 | bool printModuleScope, bool printAfterOnlyOnChange, |
| 368 | bool printAfterOnlyOnFailure, raw_ostream &out, |
| 369 | OpPrintingFlags opPrintingFlags) { |
| 370 | enableIRPrinting(config: std::make_unique<BasicIRPrinterConfig>( |
| 371 | args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass), |
| 372 | args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure, |
| 373 | args&: opPrintingFlags, args&: out)); |
| 374 | } |
| 375 | |
| 376 | /// Add an instrumentation to print the IR before and after pass execution. |
| 377 | void PassManager::enableIRPrintingToFileTree( |
| 378 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
| 379 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
| 380 | bool printModuleScope, bool printAfterOnlyOnChange, |
| 381 | bool printAfterOnlyOnFailure, StringRef printTreeDir, |
| 382 | OpPrintingFlags opPrintingFlags) { |
| 383 | enableIRPrinting(config: std::make_unique<FileTreeIRPrinterConfig>( |
| 384 | args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass), |
| 385 | args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure, |
| 386 | args&: opPrintingFlags, args&: printTreeDir)); |
| 387 | } |
| 388 | |