1 | //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/BuiltinOps.h" |
10 | #include "mlir/IR/Iterators.h" |
11 | #include "mlir/Interfaces/FunctionInterfaces.h" |
12 | #include "mlir/Pass/Pass.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | static void printRegion(Region *region) { |
17 | llvm::outs() << "region " << region->getRegionNumber() << " from operation '" |
18 | << region->getParentOp()->getName() << "'" ; |
19 | } |
20 | |
21 | static void printBlock(Block *block) { |
22 | llvm::outs() << "block " ; |
23 | block->printAsOperand(os&: llvm::outs(), /*printType=*/false); |
24 | llvm::outs() << " from " ; |
25 | printRegion(region: block->getParent()); |
26 | } |
27 | |
28 | static void printOperation(Operation *op) { |
29 | llvm::outs() << "op '" << op->getName() << "'" ; |
30 | } |
31 | |
32 | /// Tests pure callbacks. |
33 | static void testPureCallbacks(Operation *op) { |
34 | auto opPure = [](Operation *op) { |
35 | llvm::outs() << "Visiting " ; |
36 | printOperation(op); |
37 | llvm::outs() << "\n" ; |
38 | }; |
39 | auto blockPure = [](Block *block) { |
40 | llvm::outs() << "Visiting " ; |
41 | printBlock(block); |
42 | llvm::outs() << "\n" ; |
43 | }; |
44 | auto regionPure = [](Region *region) { |
45 | llvm::outs() << "Visiting " ; |
46 | printRegion(region); |
47 | llvm::outs() << "\n" ; |
48 | }; |
49 | |
50 | llvm::outs() << "Op pre-order visits" |
51 | << "\n" ; |
52 | op->walk<WalkOrder::PreOrder>(callback&: opPure); |
53 | llvm::outs() << "Block pre-order visits" |
54 | << "\n" ; |
55 | op->walk<WalkOrder::PreOrder>(callback&: blockPure); |
56 | llvm::outs() << "Region pre-order visits" |
57 | << "\n" ; |
58 | op->walk<WalkOrder::PreOrder>(callback&: regionPure); |
59 | |
60 | llvm::outs() << "Op post-order visits" |
61 | << "\n" ; |
62 | op->walk<WalkOrder::PostOrder>(callback&: opPure); |
63 | llvm::outs() << "Block post-order visits" |
64 | << "\n" ; |
65 | op->walk<WalkOrder::PostOrder>(callback&: blockPure); |
66 | llvm::outs() << "Region post-order visits" |
67 | << "\n" ; |
68 | op->walk<WalkOrder::PostOrder>(callback&: regionPure); |
69 | |
70 | llvm::outs() << "Op reverse post-order visits" |
71 | << "\n" ; |
72 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: opPure); |
73 | llvm::outs() << "Block reverse post-order visits" |
74 | << "\n" ; |
75 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: blockPure); |
76 | llvm::outs() << "Region reverse post-order visits" |
77 | << "\n" ; |
78 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: regionPure); |
79 | |
80 | // This test case tests "NoGraphRegions = true", so start the walk with |
81 | // functions. |
82 | op->walk([&](FunctionOpInterface funcOp) { |
83 | llvm::outs() << "Op forward dominance post-order visits" |
84 | << "\n" ; |
85 | funcOp->walk<WalkOrder::PostOrder, |
86 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure); |
87 | llvm::outs() << "Block forward dominance post-order visits" |
88 | << "\n" ; |
89 | funcOp->walk<WalkOrder::PostOrder, |
90 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure); |
91 | llvm::outs() << "Region forward dominance post-order visits" |
92 | << "\n" ; |
93 | funcOp->walk<WalkOrder::PostOrder, |
94 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure); |
95 | |
96 | llvm::outs() << "Op reverse dominance post-order visits" |
97 | << "\n" ; |
98 | funcOp->walk<WalkOrder::PostOrder, |
99 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(opPure); |
100 | llvm::outs() << "Block reverse dominance post-order visits" |
101 | << "\n" ; |
102 | funcOp->walk<WalkOrder::PostOrder, |
103 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(blockPure); |
104 | llvm::outs() << "Region reverse dominance post-order visits" |
105 | << "\n" ; |
106 | funcOp->walk<WalkOrder::PostOrder, |
107 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(regionPure); |
108 | }); |
109 | } |
110 | |
111 | /// Tests erasure callbacks that skip the walk. |
112 | static void testSkipErasureCallbacks(Operation *op) { |
113 | auto skipOpErasure = [](Operation *op) { |
114 | // Do not erase module and module children operations. Otherwise, there |
115 | // wouldn't be too much to test in pre-order. |
116 | if (isa<ModuleOp>(Val: op) || isa<ModuleOp>(Val: op->getParentOp())) |
117 | return WalkResult::advance(); |
118 | |
119 | llvm::outs() << "Erasing " ; |
120 | printOperation(op); |
121 | llvm::outs() << "\n" ; |
122 | op->dropAllUses(); |
123 | op->erase(); |
124 | return WalkResult::skip(); |
125 | }; |
126 | auto skipBlockErasure = [](Block *block) { |
127 | // Do not erase module and module children blocks. Otherwise there wouldn't |
128 | // be too much to test in pre-order. |
129 | Operation *parentOp = block->getParentOp(); |
130 | if (isa<ModuleOp>(Val: parentOp) || isa<ModuleOp>(Val: parentOp->getParentOp())) |
131 | return WalkResult::advance(); |
132 | |
133 | if (block->use_empty()) { |
134 | llvm::outs() << "Erasing " ; |
135 | printBlock(block); |
136 | llvm::outs() << "\n" ; |
137 | block->erase(); |
138 | return WalkResult::skip(); |
139 | } |
140 | llvm::outs() << "Cannot erase " ; |
141 | printBlock(block); |
142 | llvm::outs() << ", still has uses\n" ; |
143 | return WalkResult::advance(); |
144 | |
145 | }; |
146 | |
147 | llvm::outs() << "Op pre-order erasures (skip)" |
148 | << "\n" ; |
149 | Operation *cloned = op->clone(); |
150 | cloned->walk<WalkOrder::PreOrder>(callback&: skipOpErasure); |
151 | cloned->erase(); |
152 | |
153 | llvm::outs() << "Block pre-order erasures (skip)" |
154 | << "\n" ; |
155 | cloned = op->clone(); |
156 | cloned->walk<WalkOrder::PreOrder>(callback&: skipBlockErasure); |
157 | cloned->erase(); |
158 | |
159 | llvm::outs() << "Op post-order erasures (skip)" |
160 | << "\n" ; |
161 | cloned = op->clone(); |
162 | cloned->walk<WalkOrder::PostOrder>(callback&: skipOpErasure); |
163 | cloned->erase(); |
164 | |
165 | llvm::outs() << "Block post-order erasures (skip)" |
166 | << "\n" ; |
167 | cloned = op->clone(); |
168 | cloned->walk<WalkOrder::PostOrder>(callback&: skipBlockErasure); |
169 | cloned->erase(); |
170 | } |
171 | |
172 | /// Tests callbacks that erase the op or block but don't return 'Skip'. This |
173 | /// callbacks are only valid in post-order. |
174 | static void testNoSkipErasureCallbacks(Operation *op) { |
175 | auto noSkipOpErasure = [](Operation *op) { |
176 | llvm::outs() << "Erasing " ; |
177 | printOperation(op); |
178 | llvm::outs() << "\n" ; |
179 | op->dropAllUses(); |
180 | op->erase(); |
181 | }; |
182 | auto noSkipBlockErasure = [](Block *block) { |
183 | if (block->use_empty()) { |
184 | llvm::outs() << "Erasing " ; |
185 | printBlock(block); |
186 | llvm::outs() << "\n" ; |
187 | block->erase(); |
188 | } else { |
189 | llvm::outs() << "Cannot erase " ; |
190 | printBlock(block); |
191 | llvm::outs() << ", still has uses\n" ; |
192 | } |
193 | }; |
194 | |
195 | llvm::outs() << "Op post-order erasures (no skip)" |
196 | << "\n" ; |
197 | Operation *cloned = op->clone(); |
198 | cloned->walk<WalkOrder::PostOrder>(callback&: noSkipOpErasure); |
199 | |
200 | llvm::outs() << "Block post-order erasures (no skip)" |
201 | << "\n" ; |
202 | cloned = op->clone(); |
203 | cloned->walk<WalkOrder::PostOrder>(callback&: noSkipBlockErasure); |
204 | cloned->erase(); |
205 | } |
206 | |
207 | /// Invoke region/block walks on regions/blocks. |
208 | static void testBlockAndRegionWalkers(Operation *op) { |
209 | auto blockPure = [](Block *block) { |
210 | llvm::outs() << "Visiting " ; |
211 | printBlock(block); |
212 | llvm::outs() << "\n" ; |
213 | }; |
214 | auto regionPure = [](Region *region) { |
215 | llvm::outs() << "Visiting " ; |
216 | printRegion(region); |
217 | llvm::outs() << "\n" ; |
218 | }; |
219 | |
220 | llvm::outs() << "Invoke block pre-order visits on blocks\n" ; |
221 | op->walk(callback: [&](Operation *op) { |
222 | if (!op->hasAttr(name: "walk_blocks" )) |
223 | return; |
224 | for (Region ®ion : op->getRegions()) { |
225 | for (Block &block : region.getBlocks()) { |
226 | block.walk<WalkOrder::PreOrder>(callback&: blockPure); |
227 | } |
228 | } |
229 | }); |
230 | |
231 | llvm::outs() << "Invoke block post-order visits on blocks\n" ; |
232 | op->walk(callback: [&](Operation *op) { |
233 | if (!op->hasAttr(name: "walk_blocks" )) |
234 | return; |
235 | for (Region ®ion : op->getRegions()) { |
236 | for (Block &block : region.getBlocks()) { |
237 | block.walk<WalkOrder::PostOrder>(callback&: blockPure); |
238 | } |
239 | } |
240 | }); |
241 | |
242 | llvm::outs() << "Invoke region pre-order visits on region\n" ; |
243 | op->walk(callback: [&](Operation *op) { |
244 | if (!op->hasAttr(name: "walk_regions" )) |
245 | return; |
246 | for (Region ®ion : op->getRegions()) { |
247 | region.walk<WalkOrder::PreOrder>(callback&: regionPure); |
248 | } |
249 | }); |
250 | |
251 | llvm::outs() << "Invoke region post-order visits on region\n" ; |
252 | op->walk(callback: [&](Operation *op) { |
253 | if (!op->hasAttr(name: "walk_regions" )) |
254 | return; |
255 | for (Region ®ion : op->getRegions()) { |
256 | region.walk<WalkOrder::PostOrder>(callback&: regionPure); |
257 | } |
258 | }); |
259 | } |
260 | |
261 | namespace { |
262 | /// This pass exercises the different configurations of the IR visitors. |
263 | struct TestIRVisitorsPass |
264 | : public PassWrapper<TestIRVisitorsPass, OperationPass<>> { |
265 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass) |
266 | |
267 | StringRef getArgument() const final { return "test-ir-visitors" ; } |
268 | StringRef getDescription() const final { return "Test various visitors." ; } |
269 | void runOnOperation() override { |
270 | Operation *op = getOperation(); |
271 | testPureCallbacks(op); |
272 | testBlockAndRegionWalkers(op); |
273 | testSkipErasureCallbacks(op); |
274 | testNoSkipErasureCallbacks(op); |
275 | } |
276 | }; |
277 | } // namespace |
278 | |
279 | namespace mlir { |
280 | namespace test { |
281 | void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); } |
282 | } // namespace test |
283 | } // namespace mlir |
284 | |