1 | //===- IRPrinting.cpp -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PassDetail.h" |
10 | #include "mlir/IR/SymbolTable.h" |
11 | #include "mlir/Pass/PassManager.h" |
12 | #include "llvm/Support/Format.h" |
13 | #include "llvm/Support/FormatVariadic.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::detail; |
17 | |
18 | namespace { |
19 | //===----------------------------------------------------------------------===// |
20 | // IRPrinter |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | class IRPrinterInstrumentation : public PassInstrumentation { |
24 | public: |
25 | IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) |
26 | : config(std::move(config)) {} |
27 | |
28 | private: |
29 | /// Instrumentation hooks. |
30 | void runBeforePass(Pass *pass, Operation *op) override; |
31 | void runAfterPass(Pass *pass, Operation *op) override; |
32 | void runAfterPassFailed(Pass *pass, Operation *op) override; |
33 | |
34 | /// Configuration to use. |
35 | std::unique_ptr<PassManager::IRPrinterConfig> config; |
36 | |
37 | /// The following is a set of fingerprints for operations that are currently |
38 | /// being operated on in a pass. This field is only used when the |
39 | /// configuration asked for change detection. |
40 | DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; |
41 | }; |
42 | } // namespace |
43 | |
44 | static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, |
45 | OpPrintingFlags flags) { |
46 | // Otherwise, check to see if we are not printing at module scope. |
47 | if (!printModuleScope) |
48 | return op->print(os&: out << " //----- //\n" , |
49 | flags: op->getBlock() ? flags.useLocalScope() : flags); |
50 | |
51 | // Otherwise, we are printing at module scope. |
52 | out << " ('" << op->getName() << "' operation" ; |
53 | if (auto symbolName = |
54 | op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) |
55 | out << ": @" << symbolName.getValue(); |
56 | out << ") //----- //\n" ; |
57 | |
58 | // Find the top-level operation. |
59 | auto *topLevelOp = op; |
60 | while (auto *parentOp = topLevelOp->getParentOp()) |
61 | topLevelOp = parentOp; |
62 | topLevelOp->print(os&: out, flags); |
63 | } |
64 | |
65 | /// Instrumentation hooks. |
66 | void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { |
67 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
68 | return; |
69 | // If the config asked to detect changes, record the current fingerprint. |
70 | if (config->shouldPrintAfterOnlyOnChange()) |
71 | beforePassFingerPrints.try_emplace(Key: pass, Args&: op); |
72 | |
73 | config->printBeforeIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
74 | out << "// -----// IR Dump Before " << pass->getName() << " (" |
75 | << pass->getArgument() << ")" ; |
76 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
77 | flags: config->getOpPrintingFlags()); |
78 | out << "\n\n" ; |
79 | }); |
80 | } |
81 | |
82 | void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { |
83 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
84 | return; |
85 | |
86 | // Check to see if we are only printing on failure. |
87 | if (config->shouldPrintAfterOnlyOnFailure()) |
88 | return; |
89 | |
90 | // If the config asked to detect changes, compare the current fingerprint with |
91 | // the previous. |
92 | if (config->shouldPrintAfterOnlyOnChange()) { |
93 | auto fingerPrintIt = beforePassFingerPrints.find(Val: pass); |
94 | assert(fingerPrintIt != beforePassFingerPrints.end() && |
95 | "expected valid fingerprint" ); |
96 | // If the fingerprints are the same, we don't print the IR. |
97 | if (fingerPrintIt->second == OperationFingerPrint(op)) { |
98 | beforePassFingerPrints.erase(I: fingerPrintIt); |
99 | return; |
100 | } |
101 | beforePassFingerPrints.erase(I: fingerPrintIt); |
102 | } |
103 | |
104 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
105 | out << "// -----// IR Dump After " << pass->getName() << " (" |
106 | << pass->getArgument() << ")" ; |
107 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
108 | flags: config->getOpPrintingFlags()); |
109 | out << "\n\n" ; |
110 | }); |
111 | } |
112 | |
113 | void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { |
114 | if (isa<OpToOpPassAdaptor>(Val: pass)) |
115 | return; |
116 | if (config->shouldPrintAfterOnlyOnChange()) |
117 | beforePassFingerPrints.erase(Val: pass); |
118 | |
119 | config->printAfterIfEnabled(pass, operation: op, printCallback: [&](raw_ostream &out) { |
120 | out << formatv(Fmt: "// -----// IR Dump After {0} Failed ({1})" , Vals: pass->getName(), |
121 | Vals: pass->getArgument()); |
122 | printIR(op, printModuleScope: config->shouldPrintAtModuleScope(), out, |
123 | flags: config->getOpPrintingFlags()); |
124 | out << "\n\n" ; |
125 | }); |
126 | } |
127 | |
128 | //===----------------------------------------------------------------------===// |
129 | // IRPrinterConfig |
130 | //===----------------------------------------------------------------------===// |
131 | |
132 | /// Initialize the configuration. |
133 | PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, |
134 | bool printAfterOnlyOnChange, |
135 | bool printAfterOnlyOnFailure, |
136 | OpPrintingFlags opPrintingFlags) |
137 | : printModuleScope(printModuleScope), |
138 | printAfterOnlyOnChange(printAfterOnlyOnChange), |
139 | printAfterOnlyOnFailure(printAfterOnlyOnFailure), |
140 | opPrintingFlags(opPrintingFlags) {} |
141 | PassManager::IRPrinterConfig::~IRPrinterConfig() = default; |
142 | |
143 | /// A hook that may be overridden by a derived config that checks if the IR |
144 | /// of 'operation' should be dumped *before* the pass 'pass' has been |
145 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
146 | /// with the stream to dump into. |
147 | void PassManager::IRPrinterConfig::printBeforeIfEnabled( |
148 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
149 | // By default, never print. |
150 | } |
151 | |
152 | /// A hook that may be overridden by a derived config that checks if the IR |
153 | /// of 'operation' should be dumped *after* the pass 'pass' has been |
154 | /// executed. If the IR should be dumped, 'printCallback' should be invoked |
155 | /// with the stream to dump into. |
156 | void PassManager::IRPrinterConfig::printAfterIfEnabled( |
157 | Pass *pass, Operation *operation, PrintCallbackFn printCallback) { |
158 | // By default, never print. |
159 | } |
160 | |
161 | //===----------------------------------------------------------------------===// |
162 | // PassManager |
163 | //===----------------------------------------------------------------------===// |
164 | |
165 | namespace { |
166 | /// Simple wrapper config that allows for the simpler interface defined above. |
167 | struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { |
168 | BasicIRPrinterConfig( |
169 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
170 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
171 | bool printModuleScope, bool printAfterOnlyOnChange, |
172 | bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, |
173 | raw_ostream &out) |
174 | : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, |
175 | printAfterOnlyOnFailure, opPrintingFlags), |
176 | shouldPrintBeforePass(std::move(shouldPrintBeforePass)), |
177 | shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { |
178 | assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && |
179 | "expected at least one valid filter function" ); |
180 | } |
181 | |
182 | void printBeforeIfEnabled(Pass *pass, Operation *operation, |
183 | PrintCallbackFn printCallback) final { |
184 | if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) |
185 | printCallback(out); |
186 | } |
187 | |
188 | void printAfterIfEnabled(Pass *pass, Operation *operation, |
189 | PrintCallbackFn printCallback) final { |
190 | if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) |
191 | printCallback(out); |
192 | } |
193 | |
194 | /// Filter functions for before and after pass execution. |
195 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; |
196 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; |
197 | |
198 | /// The stream to output to. |
199 | raw_ostream &out; |
200 | }; |
201 | } // namespace |
202 | |
203 | /// Add an instrumentation to print the IR before and after pass execution, |
204 | /// using the provided configuration. |
205 | void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { |
206 | if (config->shouldPrintAtModuleScope() && |
207 | getContext()->isMultithreadingEnabled()) |
208 | llvm::report_fatal_error(reason: "IR printing can't be setup on a pass-manager " |
209 | "without disabling multi-threading first." ); |
210 | addInstrumentation( |
211 | pi: std::make_unique<IRPrinterInstrumentation>(args: std::move(config))); |
212 | } |
213 | |
214 | /// Add an instrumentation to print the IR before and after pass execution. |
215 | void PassManager::enableIRPrinting( |
216 | std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, |
217 | std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, |
218 | bool printModuleScope, bool printAfterOnlyOnChange, |
219 | bool printAfterOnlyOnFailure, raw_ostream &out, |
220 | OpPrintingFlags opPrintingFlags) { |
221 | enableIRPrinting(config: std::make_unique<BasicIRPrinterConfig>( |
222 | args: std::move(shouldPrintBeforePass), args: std::move(shouldPrintAfterPass), |
223 | args&: printModuleScope, args&: printAfterOnlyOnChange, args&: printAfterOnlyOnFailure, |
224 | args&: opPrintingFlags, args&: out)); |
225 | } |
226 | |