1//===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinOps.h"
10#include "mlir/IR/Iterators.h"
11#include "mlir/Interfaces/FunctionInterfaces.h"
12#include "mlir/Pass/Pass.h"
13
14using namespace mlir;
15
16static void printRegion(Region *region) {
17 llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
18 << region->getParentOp()->getName() << "'";
19}
20
21static void printBlock(Block *block) {
22 llvm::outs() << "block ";
23 block->printAsOperand(os&: llvm::outs(), /*printType=*/false);
24 llvm::outs() << " from ";
25 printRegion(region: block->getParent());
26}
27
28static void printOperation(Operation *op) {
29 llvm::outs() << "op '" << op->getName() << "'";
30}
31
32/// Tests pure callbacks.
33static void testPureCallbacks(Operation *op) {
34 auto opPure = [](Operation *op) {
35 llvm::outs() << "Visiting ";
36 printOperation(op);
37 llvm::outs() << "\n";
38 };
39 auto blockPure = [](Block *block) {
40 llvm::outs() << "Visiting ";
41 printBlock(block);
42 llvm::outs() << "\n";
43 };
44 auto regionPure = [](Region *region) {
45 llvm::outs() << "Visiting ";
46 printRegion(region);
47 llvm::outs() << "\n";
48 };
49
50 llvm::outs() << "Op pre-order visits"
51 << "\n";
52 op->walk<WalkOrder::PreOrder>(callback&: opPure);
53 llvm::outs() << "Block pre-order visits"
54 << "\n";
55 op->walk<WalkOrder::PreOrder>(callback&: blockPure);
56 llvm::outs() << "Region pre-order visits"
57 << "\n";
58 op->walk<WalkOrder::PreOrder>(callback&: regionPure);
59
60 llvm::outs() << "Op post-order visits"
61 << "\n";
62 op->walk<WalkOrder::PostOrder>(callback&: opPure);
63 llvm::outs() << "Block post-order visits"
64 << "\n";
65 op->walk<WalkOrder::PostOrder>(callback&: blockPure);
66 llvm::outs() << "Region post-order visits"
67 << "\n";
68 op->walk<WalkOrder::PostOrder>(callback&: regionPure);
69
70 llvm::outs() << "Op reverse post-order visits"
71 << "\n";
72 op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: opPure);
73 llvm::outs() << "Block reverse post-order visits"
74 << "\n";
75 op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: blockPure);
76 llvm::outs() << "Region reverse post-order visits"
77 << "\n";
78 op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: regionPure);
79
80 // This test case tests "NoGraphRegions = true", so start the walk with
81 // functions.
82 op->walk([&](FunctionOpInterface funcOp) {
83 llvm::outs() << "Op forward dominance post-order visits"
84 << "\n";
85 funcOp->walk<WalkOrder::PostOrder,
86 ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
87 llvm::outs() << "Block forward dominance post-order visits"
88 << "\n";
89 funcOp->walk<WalkOrder::PostOrder,
90 ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
91 llvm::outs() << "Region forward dominance post-order visits"
92 << "\n";
93 funcOp->walk<WalkOrder::PostOrder,
94 ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
95
96 llvm::outs() << "Op reverse dominance post-order visits"
97 << "\n";
98 funcOp->walk<WalkOrder::PostOrder,
99 ReverseDominanceIterator</*NoGraphRegions=*/true>>(opPure);
100 llvm::outs() << "Block reverse dominance post-order visits"
101 << "\n";
102 funcOp->walk<WalkOrder::PostOrder,
103 ReverseDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
104 llvm::outs() << "Region reverse dominance post-order visits"
105 << "\n";
106 funcOp->walk<WalkOrder::PostOrder,
107 ReverseDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
108 });
109}
110
111/// Tests erasure callbacks that skip the walk.
112static void testSkipErasureCallbacks(Operation *op) {
113 auto skipOpErasure = [](Operation *op) {
114 // Do not erase module and module children operations. Otherwise, there
115 // wouldn't be too much to test in pre-order.
116 if (isa<ModuleOp>(Val: op) || isa<ModuleOp>(Val: op->getParentOp()))
117 return WalkResult::advance();
118
119 llvm::outs() << "Erasing ";
120 printOperation(op);
121 llvm::outs() << "\n";
122 op->dropAllUses();
123 op->erase();
124 return WalkResult::skip();
125 };
126 auto skipBlockErasure = [](Block *block) {
127 // Do not erase module and module children blocks. Otherwise there wouldn't
128 // be too much to test in pre-order.
129 Operation *parentOp = block->getParentOp();
130 if (isa<ModuleOp>(Val: parentOp) || isa<ModuleOp>(Val: parentOp->getParentOp()))
131 return WalkResult::advance();
132
133 if (block->use_empty()) {
134 llvm::outs() << "Erasing ";
135 printBlock(block);
136 llvm::outs() << "\n";
137 block->erase();
138 return WalkResult::skip();
139 }
140 llvm::outs() << "Cannot erase ";
141 printBlock(block);
142 llvm::outs() << ", still has uses\n";
143 return WalkResult::advance();
144
145 };
146
147 llvm::outs() << "Op pre-order erasures (skip)"
148 << "\n";
149 Operation *cloned = op->clone();
150 cloned->walk<WalkOrder::PreOrder>(callback&: skipOpErasure);
151 cloned->erase();
152
153 llvm::outs() << "Block pre-order erasures (skip)"
154 << "\n";
155 cloned = op->clone();
156 cloned->walk<WalkOrder::PreOrder>(callback&: skipBlockErasure);
157 cloned->erase();
158
159 llvm::outs() << "Op post-order erasures (skip)"
160 << "\n";
161 cloned = op->clone();
162 cloned->walk<WalkOrder::PostOrder>(callback&: skipOpErasure);
163 cloned->erase();
164
165 llvm::outs() << "Block post-order erasures (skip)"
166 << "\n";
167 cloned = op->clone();
168 cloned->walk<WalkOrder::PostOrder>(callback&: skipBlockErasure);
169 cloned->erase();
170}
171
172/// Tests callbacks that erase the op or block but don't return 'Skip'. This
173/// callbacks are only valid in post-order.
174static void testNoSkipErasureCallbacks(Operation *op) {
175 auto noSkipOpErasure = [](Operation *op) {
176 llvm::outs() << "Erasing ";
177 printOperation(op);
178 llvm::outs() << "\n";
179 op->dropAllUses();
180 op->erase();
181 };
182 auto noSkipBlockErasure = [](Block *block) {
183 if (block->use_empty()) {
184 llvm::outs() << "Erasing ";
185 printBlock(block);
186 llvm::outs() << "\n";
187 block->erase();
188 } else {
189 llvm::outs() << "Cannot erase ";
190 printBlock(block);
191 llvm::outs() << ", still has uses\n";
192 }
193 };
194
195 llvm::outs() << "Op post-order erasures (no skip)"
196 << "\n";
197 Operation *cloned = op->clone();
198 cloned->walk<WalkOrder::PostOrder>(callback&: noSkipOpErasure);
199
200 llvm::outs() << "Block post-order erasures (no skip)"
201 << "\n";
202 cloned = op->clone();
203 cloned->walk<WalkOrder::PostOrder>(callback&: noSkipBlockErasure);
204 cloned->erase();
205}
206
207/// Invoke region/block walks on regions/blocks.
208static void testBlockAndRegionWalkers(Operation *op) {
209 auto blockPure = [](Block *block) {
210 llvm::outs() << "Visiting ";
211 printBlock(block);
212 llvm::outs() << "\n";
213 };
214 auto regionPure = [](Region *region) {
215 llvm::outs() << "Visiting ";
216 printRegion(region);
217 llvm::outs() << "\n";
218 };
219
220 llvm::outs() << "Invoke block pre-order visits on blocks\n";
221 op->walk(callback: [&](Operation *op) {
222 if (!op->hasAttr(name: "walk_blocks"))
223 return;
224 for (Region &region : op->getRegions()) {
225 for (Block &block : region.getBlocks()) {
226 block.walk<WalkOrder::PreOrder>(callback&: blockPure);
227 }
228 }
229 });
230
231 llvm::outs() << "Invoke block post-order visits on blocks\n";
232 op->walk(callback: [&](Operation *op) {
233 if (!op->hasAttr(name: "walk_blocks"))
234 return;
235 for (Region &region : op->getRegions()) {
236 for (Block &block : region.getBlocks()) {
237 block.walk<WalkOrder::PostOrder>(callback&: blockPure);
238 }
239 }
240 });
241
242 llvm::outs() << "Invoke region pre-order visits on region\n";
243 op->walk(callback: [&](Operation *op) {
244 if (!op->hasAttr(name: "walk_regions"))
245 return;
246 for (Region &region : op->getRegions()) {
247 region.walk<WalkOrder::PreOrder>(callback&: regionPure);
248 }
249 });
250
251 llvm::outs() << "Invoke region post-order visits on region\n";
252 op->walk(callback: [&](Operation *op) {
253 if (!op->hasAttr(name: "walk_regions"))
254 return;
255 for (Region &region : op->getRegions()) {
256 region.walk<WalkOrder::PostOrder>(callback&: regionPure);
257 }
258 });
259}
260
261namespace {
262/// This pass exercises the different configurations of the IR visitors.
263struct TestIRVisitorsPass
264 : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
265 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass)
266
267 StringRef getArgument() const final { return "test-ir-visitors"; }
268 StringRef getDescription() const final { return "Test various visitors."; }
269 void runOnOperation() override {
270 Operation *op = getOperation();
271 testPureCallbacks(op);
272 testBlockAndRegionWalkers(op);
273 testSkipErasureCallbacks(op);
274 testNoSkipErasureCallbacks(op);
275 }
276};
277} // namespace
278
279namespace mlir {
280namespace test {
281void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); }
282} // namespace test
283} // namespace mlir
284

source code of mlir/test/lib/IR/TestVisitors.cpp