1 | //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestOps.h" |
10 | #include "mlir/Pass/Pass.h" |
11 | |
12 | using namespace mlir; |
13 | |
14 | static std::string getStageDescription(const WalkStage &stage) { |
15 | if (stage.isBeforeAllRegions()) |
16 | return "before all regions" ; |
17 | if (stage.isAfterAllRegions()) |
18 | return "after all regions" ; |
19 | return "before region #" + std::to_string(val: stage.getNextRegion()); |
20 | } |
21 | |
22 | namespace { |
23 | /// This pass exercises generic visitor with void callbacks and prints the order |
24 | /// and stage in which operations are visited. |
25 | struct TestGenericIRVisitorPass |
26 | : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> { |
27 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass) |
28 | |
29 | StringRef getArgument() const final { return "test-generic-ir-visitors" ; } |
30 | StringRef getDescription() const final { return "Test generic IR visitors." ; } |
31 | void runOnOperation() override { |
32 | Operation *outerOp = getOperation(); |
33 | int stepNo = 0; |
34 | outerOp->walk(callback: [&](Operation *op, const WalkStage &stage) { |
35 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
36 | << getStageDescription(stage) << "\n" ; |
37 | }); |
38 | |
39 | // Exercise static inference of operation type. |
40 | outerOp->walk(callback: [&](test::TwoRegionOp op, const WalkStage &stage) { |
41 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
42 | << getStageDescription(stage) << "\n" ; |
43 | }); |
44 | } |
45 | }; |
46 | |
47 | /// This pass exercises the generic visitor with non-void callbacks and prints |
48 | /// the order and stage in which operations are visited. It will interrupt the |
49 | /// walk based on attributes peesent in the IR. |
50 | struct TestGenericIRVisitorInterruptPass |
51 | : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> { |
52 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
53 | TestGenericIRVisitorInterruptPass) |
54 | |
55 | StringRef getArgument() const final { |
56 | return "test-generic-ir-visitors-interrupt" ; |
57 | } |
58 | StringRef getDescription() const final { |
59 | return "Test generic IR visitors with interrupts." ; |
60 | } |
61 | void runOnOperation() override { |
62 | Operation *outerOp = getOperation(); |
63 | int stepNo = 0; |
64 | |
65 | auto walker = [&](Operation *op, const WalkStage &stage) { |
66 | if (auto interruptBeforeAall = |
67 | op->getAttrOfType<BoolAttr>(name: "interrupt_before_all" )) |
68 | if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) |
69 | return WalkResult::interrupt(); |
70 | |
71 | if (auto interruptAfterAll = |
72 | op->getAttrOfType<BoolAttr>(name: "interrupt_after_all" )) |
73 | if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) |
74 | return WalkResult::interrupt(); |
75 | |
76 | if (auto interruptAfterRegion = |
77 | op->getAttrOfType<IntegerAttr>("interrupt_after_region" )) |
78 | if (stage.isAfterRegion( |
79 | region: static_cast<int>(interruptAfterRegion.getInt()))) |
80 | return WalkResult::interrupt(); |
81 | |
82 | if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>(name: "skip_before_all" )) |
83 | if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) |
84 | return WalkResult::skip(); |
85 | |
86 | if (auto skipAfterAll = op->getAttrOfType<BoolAttr>(name: "skip_after_all" )) |
87 | if (skipAfterAll.getValue() && stage.isAfterAllRegions()) |
88 | return WalkResult::skip(); |
89 | |
90 | if (auto skipAfterRegion = |
91 | op->getAttrOfType<IntegerAttr>("skip_after_region" )) |
92 | if (stage.isAfterRegion(region: static_cast<int>(skipAfterRegion.getInt()))) |
93 | return WalkResult::skip(); |
94 | |
95 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
96 | << getStageDescription(stage) << "\n" ; |
97 | return WalkResult::advance(); |
98 | }; |
99 | |
100 | // Interrupt the walk based on attributes on the operation. |
101 | auto result = outerOp->walk(callback&: walker); |
102 | |
103 | if (result.wasInterrupted()) |
104 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
105 | |
106 | // Exercise static inference of operation type. |
107 | result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { |
108 | return walker(op, stage); |
109 | }); |
110 | |
111 | if (result.wasInterrupted()) |
112 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
113 | } |
114 | }; |
115 | |
116 | struct TestGenericIRBlockVisitorInterruptPass |
117 | : public PassWrapper<TestGenericIRBlockVisitorInterruptPass, |
118 | OperationPass<ModuleOp>> { |
119 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
120 | TestGenericIRBlockVisitorInterruptPass) |
121 | |
122 | StringRef getArgument() const final { |
123 | return "test-generic-ir-block-visitors-interrupt" ; |
124 | } |
125 | StringRef getDescription() const final { |
126 | return "Test generic IR visitors with interrupts, starting with Blocks." ; |
127 | } |
128 | |
129 | void runOnOperation() override { |
130 | int stepNo = 0; |
131 | |
132 | auto walker = [&](Block *block) { |
133 | for (Operation &op : *block) |
134 | if (op.getAttrOfType<BoolAttr>(name: "interrupt" )) |
135 | return WalkResult::interrupt(); |
136 | |
137 | llvm::outs() << "step " << stepNo++ << "\n" ; |
138 | return WalkResult::advance(); |
139 | }; |
140 | |
141 | auto result = getOperation()->walk(walker); |
142 | if (result.wasInterrupted()) |
143 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
144 | } |
145 | }; |
146 | |
147 | struct TestGenericIRRegionVisitorInterruptPass |
148 | : public PassWrapper<TestGenericIRRegionVisitorInterruptPass, |
149 | OperationPass<ModuleOp>> { |
150 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
151 | TestGenericIRRegionVisitorInterruptPass) |
152 | |
153 | StringRef getArgument() const final { |
154 | return "test-generic-ir-region-visitors-interrupt" ; |
155 | } |
156 | StringRef getDescription() const final { |
157 | return "Test generic IR visitors with interrupts, starting with Regions." ; |
158 | } |
159 | |
160 | void runOnOperation() override { |
161 | int stepNo = 0; |
162 | |
163 | auto walker = [&](Region *region) { |
164 | for (Operation &op : region->getOps()) |
165 | if (op.getAttrOfType<BoolAttr>(name: "interrupt" )) |
166 | return WalkResult::interrupt(); |
167 | |
168 | llvm::outs() << "step " << stepNo++ << "\n" ; |
169 | return WalkResult::advance(); |
170 | }; |
171 | |
172 | auto result = getOperation()->walk(walker); |
173 | if (result.wasInterrupted()) |
174 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
175 | } |
176 | }; |
177 | |
178 | } // namespace |
179 | |
180 | namespace mlir { |
181 | namespace test { |
182 | void registerTestGenericIRVisitorsPass() { |
183 | PassRegistration<TestGenericIRVisitorPass>(); |
184 | PassRegistration<TestGenericIRVisitorInterruptPass>(); |
185 | PassRegistration<TestGenericIRBlockVisitorInterruptPass>(); |
186 | PassRegistration<TestGenericIRRegionVisitorInterruptPass>(); |
187 | } |
188 | |
189 | } // namespace test |
190 | } // namespace mlir |
191 | |