1//===- IRPrinting.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/SymbolTable.h"
11#include "mlir/Pass/PassManager.h"
12#include "mlir/Support/FileUtilities.h"
13#include "llvm/ADT/STLExtras.h"
14#include "llvm/ADT/StringExtras.h"
15#include "llvm/Support/FileSystem.h"
16#include "llvm/Support/FormatVariadic.h"
17#include "llvm/Support/Path.h"
18#include "llvm/Support/ToolOutputFile.h"
19
20using namespace mlir;
21using namespace mlir::detail;
22
23namespace {
24//===----------------------------------------------------------------------===//
25// IRPrinter
26//===----------------------------------------------------------------------===//
27
28class IRPrinterInstrumentation : public PassInstrumentation {
29public:
30 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
31 : config(std::move(config)) {}
32
33private:
34 /// Instrumentation hooks.
35 void runBeforePass(Pass *pass, Operation *op) override;
36 void runAfterPass(Pass *pass, Operation *op) override;
37 void runAfterPassFailed(Pass *pass, Operation *op) override;
38
39 /// Configuration to use.
40 std::unique_ptr<PassManager::IRPrinterConfig> config;
41
42 /// The following is a set of fingerprints for operations that are currently
43 /// being operated on in a pass. This field is only used when the
44 /// configuration asked for change detection.
45 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
46};
47} // namespace
48
49static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
50 OpPrintingFlags flags) {
51 // Otherwise, check to see if we are not printing at module scope.
52 if (!printModuleScope)
53 return op->print(os&: out << " //----- //\n",
54 flags: op->getBlock() ? flags.useLocalScope() : flags);
55
56 // Otherwise, we are printing at module scope.
57 out << " ('" << op->getName() << "' operation";
58 if (auto symbolName =
59 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
60 out << ": @" << symbolName.getValue();
61 out << ") //----- //\n";
62
63 // Find the top-level operation.
64 auto *topLevelOp = op;
65 while (auto *parentOp = topLevelOp->getParentOp())
66 topLevelOp = parentOp;
67 topLevelOp->print(os&: out, flags);
68}
69
70/// Instrumentation hooks.
71void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
72 if (isa<OpToOpPassAdaptor>(Val: pass))
73 return;
74 // If the config asked to detect changes, record the current fingerprint.
75 if (config->shouldPrintAfterOnlyOnChange())
76 beforePassFingerPrints.try_emplace(Key: pass, Args&: op);
77
78 config->printBeforeIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) {
79 out << "// -----// IR Dump Before " << pass->getName() << " ("
80 << pass->getArgument() << ")";
81 printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out,
82 flags: config->getOpPrintingFlags());
83 out << "\n\n";
84 });
85}
86
87void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
88 if (isa<OpToOpPassAdaptor>(Val: pass))
89 return;
90
91 // Check to see if we are only printing on failure.
92 if (config->shouldPrintAfterOnlyOnFailure())
93 return;
94
95 // If the config asked to detect changes, compare the current fingerprint with
96 // the previous.
97 if (config->shouldPrintAfterOnlyOnChange()) {
98 auto fingerPrintIt = beforePassFingerPrints.find(Val: pass);
99 assert(fingerPrintIt != beforePassFingerPrints.end() &&
100 "expected valid fingerprint");
101 // If the fingerprints are the same, we don't print the IR.
102 if (fingerPrintIt->second == OperationFingerPrint(op)) {
103 beforePassFingerPrints.erase(I: fingerPrintIt);
104 return;
105 }
106 beforePassFingerPrints.erase(I: fingerPrintIt);
107 }
108
109 config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) {
110 out << "// -----// IR Dump After " << pass->getName() << " ("
111 << pass->getArgument() << ")";
112 printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out,
113 flags: config->getOpPrintingFlags());
114 out << "\n\n";
115 });
116}
117
118void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
119 if (isa<OpToOpPassAdaptor>(Val: pass))
120 return;
121 if (config->shouldPrintAfterOnlyOnChange())
122 beforePassFingerPrints.erase(Val: pass);
123
124 config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) {
125 out << formatv(Fmt: "// -----// IR Dump After {0} Failed ({1})", Vals: pass->getName(),
126 Vals: pass->getArgument());
127 printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out,
128 flags: config->getOpPrintingFlags());
129 out << "\n\n";
130 });
131}
132
133//===----------------------------------------------------------------------===//
134// IRPrinterConfig
135//===----------------------------------------------------------------------===//
136
/// Initialize the configuration.
///
/// \param printModuleScope        print the top-level ancestor rather than only
///                                the operation the pass ran on.
/// \param printAfterOnlyOnChange  only dump after a pass if the operation's
///                                fingerprint changed.
/// \param printAfterOnlyOnFailure only dump after a pass when it failed.
/// \param opPrintingFlags         flags forwarded to `Operation::print`.
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
                                              bool printAfterOnlyOnChange,
                                              bool printAfterOnlyOnFailure,
                                              OpPrintingFlags opPrintingFlags)
    : printModuleScope(printModuleScope),
      printAfterOnlyOnChange(printAfterOnlyOnChange),
      printAfterOnlyOnFailure(printAfterOnlyOnFailure),
      opPrintingFlags(opPrintingFlags) {}
/// Defaulted destructor, defined out of line.
PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
147
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *before* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
void PassManager::IRPrinterConfig::printBeforeIfEnabled(
    Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
  // By default, never print: the base config ignores all arguments and never
  // invokes 'printCallback'.
}
156
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *after* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
void PassManager::IRPrinterConfig::printAfterIfEnabled(
    Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
  // By default, never print: the base config ignores all arguments and never
  // invokes 'printCallback'.
}
165
166//===----------------------------------------------------------------------===//
167// PassManager
168//===----------------------------------------------------------------------===//
169
170namespace {
171/// Simple wrapper config that allows for the simpler interface defined above.
172struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
173 BasicIRPrinterConfig(
174 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
175 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
176 bool printModuleScope, bool printAfterOnlyOnChange,
177 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
178 raw_ostream &out)
179 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
180 printAfterOnlyOnFailure, opPrintingFlags),
181 shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
182 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
183 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
184 "expected at least one valid filter function");
185 }
186
187 void printBeforeIfEnabled(Pass *pass, Operation *operation,
188 PrintCallbackFn printCallback) final {
189 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
190 printCallback(out);
191 }
192
193 void printAfterIfEnabled(Pass *pass, Operation *operation,
194 PrintCallbackFn printCallback) final {
195 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
196 printCallback(out);
197 }
198
199 /// Filter functions for before and after pass execution.
200 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
201 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
202
203 /// The stream to output to.
204 raw_ostream &out;
205};
206} // namespace
207
208/// Return pairs of (sanitized op name, symbol name) for `op` and all parent
209/// operations. Op names are sanitized by replacing periods with underscores.
210/// The pairs are returned in order of outer-most to inner-most (ancestors of
211/// `op` first, `op` last). This information is used to construct the directory
212/// tree for the `FileTreeIRPrinterConfig` below.
213/// The counter for `op` will be incremented by this call.
214static std::pair<SmallVector<std::pair<std::string, std::string>>, std::string>
215getOpAndSymbolNames(Operation *op, StringRef passName,
216 llvm::DenseMap<Operation *, unsigned> &counters) {
217 SmallVector<std::pair<std::string, std::string>> pathElements;
218 SmallVector<unsigned> countPrefix;
219
220 Operation *iter = op;
221 ++counters.try_emplace(Key: op, Args: -1).first->second;
222 while (iter) {
223 countPrefix.push_back(Elt: counters[iter]);
224 StringAttr symbolNameAttr =
225 iter->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
226 std::string symbolName =
227 symbolNameAttr ? symbolNameAttr.str() : "no-symbol-name";
228 llvm::replace(Range&: symbolName, OldValue: '/', NewValue: '_');
229 llvm::replace(Range&: symbolName, OldValue: '\\', NewValue: '_');
230
231 std::string opName =
232 llvm::join(R: llvm::split(Str: iter->getName().getStringRef().str(), Separator: '.'), Separator: "_");
233 pathElements.emplace_back(Args: std::move(opName), Args: std::move(symbolName));
234 iter = iter->getParentOp();
235 }
236 // Return in the order of top level (module) down to `op`.
237 std::reverse(first: countPrefix.begin(), last: countPrefix.end());
238 std::reverse(first: pathElements.begin(), last: pathElements.end());
239
240 std::string passFileName = llvm::formatv(
241 Fmt: "{0:$[_]}_{1}.mlir",
242 Vals: llvm::make_range(x: countPrefix.begin(), y: countPrefix.end()), Vals&: passName);
243
244 return {pathElements, passFileName};
245}
246
247static LogicalResult createDirectoryOrPrintErr(llvm::StringRef dirPath) {
248 if (std::error_code ec =
249 llvm::sys::fs::create_directory(path: dirPath, /*IgnoreExisting=*/true)) {
250 llvm::errs() << "Error while creating directory " << dirPath << ": "
251 << ec.message() << "\n";
252 return failure();
253 }
254 return success();
255}
256
257/// Creates directories (if required) and opens an output file for the
258/// FileTreeIRPrinterConfig.
259static std::unique_ptr<llvm::ToolOutputFile>
260createTreePrinterOutputPath(Operation *op, llvm::StringRef passArgument,
261 llvm::StringRef rootDir,
262 llvm::DenseMap<Operation *, unsigned> &counters) {
263 // Create the path. We will create a tree rooted at the given 'rootDir'
264 // directory. The root directory will contain folders with the names of
265 // modules. Sub-directories within those folders mirror the nesting
266 // structure of the pass manager, using symbol names for directory names.
267 auto [opAndSymbolNames, fileName] =
268 getOpAndSymbolNames(op, passName: passArgument, counters);
269
270 // Create all the directories, starting at the root. Abort early if we fail to
271 // create any directory.
272 llvm::SmallString<128> path(rootDir);
273 if (failed(Result: createDirectoryOrPrintErr(dirPath: path)))
274 return nullptr;
275
276 for (const auto &[opName, symbolName] : opAndSymbolNames) {
277 llvm::sys::path::append(path, a: opName + "_" + symbolName);
278 if (failed(Result: createDirectoryOrPrintErr(dirPath: path)))
279 return nullptr;
280 }
281
282 // Open output file.
283 llvm::sys::path::append(path, a: fileName);
284 std::string error;
285 std::unique_ptr<llvm::ToolOutputFile> file = openOutputFile(outputFilename: path, errorMessage: &error);
286 if (!file) {
287 llvm::errs() << "Error opening output file " << path << ": " << error
288 << "\n";
289 return nullptr;
290 }
291 return file;
292}
293
294namespace {
295/// A configuration that prints the IR before/after each pass to a set of files
296/// in the specified directory. The files are organized into subdirectories that
297/// mirror the nesting structure of the IR.
298struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig {
299 FileTreeIRPrinterConfig(
300 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
301 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
302 bool printModuleScope, bool printAfterOnlyOnChange,
303 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
304 llvm::StringRef treeDir)
305 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
306 printAfterOnlyOnFailure, opPrintingFlags),
307 shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
308 shouldPrintAfterPass(std::move(shouldPrintAfterPass)),
309 treeDir(treeDir) {
310 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
311 "expected at least one valid filter function");
312 }
313
314 void printBeforeIfEnabled(Pass *pass, Operation *operation,
315 PrintCallbackFn printCallback) final {
316 if (!shouldPrintBeforePass || !shouldPrintBeforePass(pass, operation))
317 return;
318 std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath(
319 op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters);
320 if (!file)
321 return;
322 printCallback(file->os());
323 file->keep();
324 }
325
326 void printAfterIfEnabled(Pass *pass, Operation *operation,
327 PrintCallbackFn printCallback) final {
328 if (!shouldPrintAfterPass || !shouldPrintAfterPass(pass, operation))
329 return;
330 std::unique_ptr<llvm::ToolOutputFile> file = createTreePrinterOutputPath(
331 op: operation, passArgument: pass->getArgument(), rootDir: treeDir, counters);
332 if (!file)
333 return;
334 printCallback(file->os());
335 file->keep();
336 }
337
338 /// Filter functions for before and after pass execution.
339 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
340 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
341
342 /// Directory that should be used as the root of the file tree.
343 std::string treeDir;
344
345 /// Counters used for labeling the prefix. Every op which could be targeted by
346 /// a pass gets its own counter.
347 llvm::DenseMap<Operation *, unsigned> counters;
348};
349
350} // namespace
351
352/// Add an instrumentation to print the IR before and after pass execution,
353/// using the provided configuration.
354void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
355 if (config->shouldPrintAtModuleScope() &&
356 getContext()->isMultithreadingEnabled())
357 llvm::report_fatal_error(reason: "IR printing can't be setup on a pass-manager "
358 "without disabling multi-threading first.");
359 addInstrumentation(
360 pi: std::make_unique<IRPrinterInstrumentation>(args: std::move(config)));
361}
362
363/// Add an instrumentation to print the IR before and after pass execution.
364void PassManager::enableIRPrinting(
365 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
366 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
367 bool printModuleScope, bool printAfterOnlyOnChange,
368 bool printAfterOnlyOnFailure, raw_ostream &out,
369 OpPrintingFlags opPrintingFlags) {
370 enableIRPrinting(config: std::make_unique<BasicIRPrinterConfig>(
371 args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass),
372 args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure,
373 args&: opPrintingFlags, args&: out));
374}
375
376/// Add an instrumentation to print the IR before and after pass execution.
377void PassManager::enableIRPrintingToFileTree(
378 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
379 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
380 bool printModuleScope, bool printAfterOnlyOnChange,
381 bool printAfterOnlyOnFailure, StringRef printTreeDir,
382 OpPrintingFlags opPrintingFlags) {
383 enableIRPrinting(config: std::make_unique<FileTreeIRPrinterConfig>(
384 args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass),
385 args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure,
386 args&: opPrintingFlags, args&: printTreeDir));
387}
388

source code of mlir/lib/Pass/IRPrinting.cpp