1 | //===- IRPrinting.cpp -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PassDetail.h" |
10 | #include "mlir/IR/SymbolTable.h" |
11 | #include "mlir/Pass/PassManager.h" |
12 | #include "mlir/Support/FileUtilities.h" |
13 | #include "llvm/ADT/STLExtras.h" |
14 | #include "llvm/ADT/StringExtras.h" |
15 | #include "llvm/Support/FileSystem.h" |
16 | #include "llvm/Support/FormatVariadic.h" |
17 | #include "llvm/Support/Path.h" |
18 | #include "llvm/Support/ToolOutputFile.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::detail; |
22 | |
23 | namespace { |
24 | //===----------------------------------------------------------------------===// |
25 | // IRPrinter |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | class IRPrinterInstrumentation : public PassInstrumentation { |
29 | public: |
30 | IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) |
31 | : config(std::move(config)) {} |
32 | |
33 | private: |
34 | /// Instrumentation hooks. |
35 | void runBeforePass(Pass *pass, Operation *op) override; |
36 | void runAfterPass(Pass *pass, Operation *op) override; |
37 | void runAfterPassFailed(Pass *pass, Operation *op) override; |
38 | |
39 | /// Configuration to use. |
40 | std::unique_ptr<PassManager::IRPrinterConfig> config; |
41 | |
42 | /// The following is a set of fingerprints for operations that are currently |
43 | /// being operated on in a pass. This field is only used when the |
44 | /// configuration asked for change detection. |
45 | DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; |
46 | }; |
47 | } // namespace |
48 | |
49 | static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, |
50 | OpPrintingFlags flags) { |
51 | // Otherwise, check to see if we are not printing at module scope. |
52 | if (!printModuleScope) |
53 | return op->print(os&: out << " //----- //\n" , |
54 | flags: op->getBlock() ? flags.useLocalScope() : flags); |
55 | |
56 | // Otherwise, we are printing at module scope. |
57 | out << " ('" << op->getName() << "' operation" ; |
58 | if (auto symbolName = |
59 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) |
60 | out << ": @" << symbolName.getValue(); |
61 | out << ") //----- //\n" ; |
62 | |
63 | // Find the top-level operation. |
64 | auto *topLevelOp = op; |
65 | while (auto *parentOp = topLevelOp->getParentOp()) |
66 | topLevelOp = parentOp; |
67 | topLevelOp->print(os&: out, flags); |
68 | } |
69 | |
70 | /// Instrumentation hooks. |
71 | void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { |
72 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
73 | return; |
74 | // If the config asked to detect changes, record the current fingerprint. |
75 | if (config->shouldPrintAfterOnlyOnChange()) |
76 | beforePassFingerPrints.try_emplace(Key: pass, Args&: op); |
77 | |
78 | config->printBeforeIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
79 | out << "// -----// IR Dump Before " << pass->getName() << " (" |
80 | << pass->getArgument() << ")" ; |
81 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
82 | flags: config->getOpPrintingFlags()); |
83 | out << "\n\n" ; |
84 | }); |
85 | } |
86 | |
87 | void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { |
88 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
89 | return; |
90 | |
91 | // Check to see if we are only printing on failure. |
92 | if (config->shouldPrintAfterOnlyOnFailure()) |
93 | return; |
94 | |
95 | // If the config asked to detect changes, compare the current fingerprint with |
96 | // the previous. |
97 | if (config->shouldPrintAfterOnlyOnChange()) { |
98 | auto fingerPrintIt = beforePassFingerPrints.find(Val: pass); |
99 | assert(fingerPrintIt != beforePassFingerPrints.end() && |
100 | "expected valid fingerprint" ); |
101 | // If the fingerprints are the same, we don't print the IR. |
102 | if (fingerPrintIt->second == OperationFingerPrint(op)) { |
103 | beforePassFingerPrints.erase(I: fingerPrintIt); |
104 | return; |
105 | } |
106 | beforePassFingerPrints.erase(I: fingerPrintIt); |
107 | } |
108 | |
109 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
110 | out << "// -----// IR Dump After " << pass->getName() << " (" |
111 | << pass->getArgument() << ")" ; |
112 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
113 | flags: config->getOpPrintingFlags()); |
114 | out << "\n\n" ; |
115 | }); |
116 | } |
117 | |
118 | void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { |
119 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
120 | return; |
121 | if (config->shouldPrintAfterOnlyOnChange()) |
122 | beforePassFingerPrints.erase(Val: pass); |
123 | |
124 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
125 | out << formatv(Fmt: "// -----// IR Dump After {0} Failed ({1})" , Vals: pass->getName(), |
126 | Vals: pass->getArgument()); |
127 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
128 | flags: config->getOpPrintingFlags()); |
129 | out << "\n\n" ; |
130 | }); |
131 | } |
132 | |
133 | //===----------------------------------------------------------------------===// |
134 | // IRPrinterConfig |
135 | //===----------------------------------------------------------------------===// |
136 | |
137 | /// Initialize the configuration. |
138 | PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, |
139 | bool printAfterOnlyOnChange, |
140 | bool printAfterOnlyOnFailure, |
141 | OpPrintingFlags opPrintingFlags) |
142 | : printModuleScope(printModuleScope), |
143 | printAfterOnlyOnChange(printAfterOnlyOnChange), |
144 | printAfterOnlyOnFailure(printAfterOnlyOnFailure), |
145 | opPrintingFlags(opPrintingFlags) {} |
146 | PassManager::IRPrinterConfig::~IRPrinterConfig() = default; |
147 | |
148 | /// A hook that may be overridden by a derived config that checks if the IR |
149 | /// of 'operation' should be dumped *before* the pass 'pass' has been |
150 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
151 | /// with the stream to dump into. |
152 | void PassManager::IRPrinterConfig::printBeforeIfEnabled( |
153 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
154 | // By default, never print. |
155 | } |
156 | |
157 | /// A hook that may be overridden by a derived config that checks if the IR |
158 | /// of 'operation' should be dumped *after* the pass 'pass' has been |
159 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
160 | /// with the stream to dump into. |
161 | void PassManager::IRPrinterConfig::printAfterIfEnabled( |
162 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
163 | // By default, never print. |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // PassManager |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | namespace { |
171 | /// Simple wrapper config that allows for the simpler interface defined above. |
172 | struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { |
173 | BasicIRPrinterConfig( |
174 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
175 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
176 | bool printModuleScope, bool printAfterOnlyOnChange, |
177 | bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, |
178 | raw_ostream &out) |
179 | : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, |
180 | printAfterOnlyOnFailure, opPrintingFlags), |
181 | shouldPrintBeforePass(std::move(shouldPrintBeforePass)), |
182 | shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { |
183 | assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && |
184 | "expected at least one valid filter function" ); |
185 | } |
186 | |
187 | void printBeforeIfEnabled(Pass *pass, Operation *operation, |
188 | PrintCallbackFn printCallback) final { |
189 | if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) |
190 | printCallback(out); |
191 | } |
192 | |
193 | void printAfterIfEnabled(Pass *pass, Operation *operation, |
194 | PrintCallbackFn printCallback) final { |
195 | if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) |
196 | printCallback(out); |
197 | } |
198 | |
199 | /// Filter functions for before and after pass execution. |
200 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; |
201 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; |
202 | |
203 | /// The stream to output to. |
204 | raw_ostream &out; |
205 | }; |
206 | } // namespace |
207 | |
208 | /// Return pairs of (sanitized op name, symbol name) for `op` and all parent |
209 | /// operations. Op names are sanitized by replacing periods with underscores. |
210 | /// The pairs are returned in order of outer-most to inner-most (ancestors of |
211 | /// `op` first, `op` last). This information is used to construct the directory |
212 | /// tree for the `FileTreeIRPrinterConfig` below. |
213 | /// The counter for `op` will be incremented by this call. |
214 | static std::pair<SmallVector<std::pair<std::string, std::string>>, std::string> |
215 | getOpAndSymbolNames(Operation *op, StringRef passName, |
216 | llvm::DenseMap<Operation *, unsigned> &counters) { |
217 | SmallVector<std::pair<std::string, std::string>> pathElements; |
218 | SmallVector<unsigned> countPrefix; |
219 | |
220 | Operation *iter = op; |
221 | ++counters.try_emplace(Key: op, Args: -1).first->second; |
222 | while (iter) { |
223 | countPrefix.push_back(Elt: counters[iter]); |
224 | StringAttr symbolNameAttr = |
225 | iter->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); |
226 | std::string symbolName = |
227 | symbolNameAttr ? symbolNameAttr.str() : "no-symbol-name" ; |
228 | llvm::replace(Range&: symbolName, OldValue: '/', NewValue: '_'); |
229 | llvm::replace(Range&: symbolName, OldValue: '\\', NewValue: '_'); |
230 | |
231 | std::string opName = |
232 | llvm::join(R: llvm::split(Str: iter->getName().getStringRef().str(), Separator: '.'), Separator: "_" ); |
233 | pathElements.emplace_back(Args: std::move(opName), Args: std::move(symbolName)); |
234 | iter = iter->getParentOp(); |
235 | } |
236 | // Return in the order of top level (module) down to `op`. |
237 | std::reverse(first: countPrefix.begin(), last: countPrefix.end()); |
238 | std::reverse(first: pathElements.begin(), last: pathElements.end()); |
239 | |
240 | std::string passFileName = llvm::formatv( |
241 | Fmt: "{0:$[_]}_{1}.mlir" , |
242 | Vals: llvm::make_range(x: countPrefix.begin(), y: countPrefix.end()), Vals&: passName); |
243 | |
244 | return {pathElements, passFileName}; |
245 | } |
246 | |
247 | static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath) { |
248 | if (std::error_code ec = |
249 | llvm::sys::fs::create_directory(path: dirPath, /*IgnoreExisting=*/true)) { |
250 | llvm::errs() << "Error while creating directory " << dirPath << ": " |
251 | << ec.message() << "\n" ; |
252 | return failure(); |
253 | } |
254 | return success(); |
255 | } |
256 | |
257 | /// Creates directories (if required) and opens an output file for the |
258 | /// FileTreeIRPrinterConfig. |
259 | static std::unique_ptr<llvm::ToolOutputFile> |
260 | createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument, |
261 | llvm::StringRef rootDir, |
262 | llvm::DenseMap<Operation *, unsigned> &counters) { |
263 | // Create the path. We will create a tree rooted at the given 'rootDir' |
264 | // directory. The root directory will contain folders with the names of |
265 | // modules. Sub-directories within those folders mirror the nesting |
266 | // structure of the pass manager, using symbol names for directory names. |
267 | auto [opAndSymbolNames, fileName] = |
268 | getOpAndSymbolNames(op, passName: passArgument, counters); |
269 | |
270 | // Create all the directories, starting at the root. Abort early if we fail to |
271 | // create any directory. |
272 | llvm::SmallString<128> path(rootDir); |
273 | if (failed(Result: createDirectoryOrPrintErr(dirPath: path))) |
274 | return nullptr; |
275 | |
276 | for (const auto &[opName, symbolName] : opAndSymbolNames) { |
277 | llvm::sys::path::append(path, a: opName + "_" + symbolName); |
278 | if (failed(Result: createDirectoryOrPrintErr(dirPath: path))) |
279 | return nullptr; |
280 | } |
281 | |
282 | // Open output file. |
283 | llvm::sys::path::append(path, a: fileName); |
284 | std::string error; |
285 | std::unique_ptr<llvm::ToolOutputFile> file = openOutputFile(outputFilename: path, errorMessage: &error); |
286 | if (!file) { |
287 | llvm::errs() << "Error opening output file " << path << ": " << error |
288 | << "\n" ; |
289 | return nullptr; |
290 | } |
291 | return file; |
292 | } |
293 | |
294 | namespace { |
295 | /// A configuration that prints the IR before/after each pass to a set of files |
296 | /// in the specified directory. The files are organized into subdirectories that |
297 | /// mirror the nesting structure of the IR. |
298 | struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig { |
299 | FileTreeIRPrinterConfig( |
300 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
301 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
302 | bool printModuleScope, bool printAfterOnlyOnChange, |
303 | bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, |
304 | llvm::StringRef treeDir) |
305 | : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, |
306 | printAfterOnlyOnFailure, opPrintingFlags), |
307 | shouldPrintBeforePass(std::move(shouldPrintBeforePass)), |
308 | shouldPrintAfterPass(std::move(shouldPrintAfterPass)), |
309 | treeDir(treeDir) { |
310 | assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && |
311 | "expected at least one valid filter function" ); |
312 | } |
313 | |
314 | void printBeforeIfEnabled(Pass *pass, Operation *operation, |
315 | PrintCallbackFn printCallback) final { |
316 | if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation)) |
317 | return; |
318 | std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( |
319 | op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters); |
320 | if (!file) |
321 | return; |
322 | printCallback(file->os()); |
323 | file->keep(); |
324 | } |
325 | |
326 | void printAfterIfEnabled(Pass *pass, Operation *operation, |
327 | PrintCallbackFn printCallback) final { |
328 | if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation)) |
329 | return; |
330 | std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath( |
331 | op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters); |
332 | if (!file) |
333 | return; |
334 | printCallback(file->os()); |
335 | file->keep(); |
336 | } |
337 | |
338 | /// Filter functions for before and after pass execution. |
339 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; |
340 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; |
341 | |
342 | /// Directory that should be used as the root of the file tree. |
343 | std::string treeDir; |
344 | |
345 | /// Counters used for labeling the prefix. Every op which could be targeted by |
346 | /// a pass gets its own counter. |
347 | llvm::DenseMap<Operation *, unsigned> counters; |
348 | }; |
349 | |
350 | } // namespace |
351 | |
352 | /// Add an instrumentation to print the IR before and after pass execution, |
353 | /// using the provided configuration. |
354 | void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { |
355 | if (config->shouldPrintAtModuleScope() && |
356 | getContext()->isMultithreadingEnabled()) |
357 | llvm::report_fatal_error(reason: "IR printing can't be setup on a pass-manager " |
358 | "without disabling multi-threading first." ); |
359 | addInstrumentation( |
360 | pi: std::make_unique<IRPrinterInstrumentation>(args: std::move(config))); |
361 | } |
362 | |
363 | /// Add an instrumentation to print the IR before and after pass execution. |
364 | void PassManager::enableIRPrinting( |
365 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
366 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
367 | bool printModuleScope, bool printAfterOnlyOnChange, |
368 | bool printAfterOnlyOnFailure, raw_ostream &out, |
369 | OpPrintingFlags opPrintingFlags) { |
370 | enableIRPrinting(config: std::make_unique<BasicIRPrinterConfig>( |
371 | args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass), |
372 | args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure, |
373 | args&: opPrintingFlags, args&: out)); |
374 | } |
375 | |
376 | /// Add an instrumentation to print the IR before and after pass execution. |
377 | void PassManager::enableIRPrintingToFileTree( |
378 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
379 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
380 | bool printModuleScope, bool printAfterOnlyOnChange, |
381 | bool printAfterOnlyOnFailure, StringRef printTreeDir, |
382 | OpPrintingFlags opPrintingFlags) { |
383 | enableIRPrinting(config: std::make_unique<FileTreeIRPrinterConfig>( |
384 | args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass), |
385 | args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure, |
386 | args&: opPrintingFlags, args&: printTreeDir)); |
387 | } |
388 | |