1//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestOps.h"
10#include "mlir/Pass/Pass.h"
11
12using namespace mlir;
13
14static std::string getStageDescription(const WalkStage &stage) {
15 if (stage.isBeforeAllRegions())
16 return "before all regions";
17 if (stage.isAfterAllRegions())
18 return "after all regions";
19 return "before region #" + std::to_string(val: stage.getNextRegion());
20}
21
22namespace {
23/// This pass exercises generic visitor with void callbacks and prints the order
24/// and stage in which operations are visited.
25struct TestGenericIRVisitorPass
26 : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
28
29 StringRef getArgument() const final { return "test-generic-ir-visitors"; }
30 StringRef getDescription() const final { return "Test generic IR visitors."; }
31 void runOnOperation() override {
32 Operation *outerOp = getOperation();
33 int stepNo = 0;
34 outerOp->walk(callback: [&](Operation *op, const WalkStage &stage) {
35 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
36 << getStageDescription(stage) << "\n";
37 });
38
39 // Exercise static inference of operation type.
40 outerOp->walk(callback: [&](test::TwoRegionOp op, const WalkStage &stage) {
41 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
42 << getStageDescription(stage) << "\n";
43 });
44 }
45};
46
47/// This pass exercises the generic visitor with non-void callbacks and prints
48/// the order and stage in which operations are visited. It will interrupt the
49/// walk based on attributes peesent in the IR.
50struct TestGenericIRVisitorInterruptPass
51 : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
53 TestGenericIRVisitorInterruptPass)
54
55 StringRef getArgument() const final {
56 return "test-generic-ir-visitors-interrupt";
57 }
58 StringRef getDescription() const final {
59 return "Test generic IR visitors with interrupts.";
60 }
61 void runOnOperation() override {
62 Operation *outerOp = getOperation();
63 int stepNo = 0;
64
65 auto walker = [&](Operation *op, const WalkStage &stage) {
66 if (auto interruptBeforeAall =
67 op->getAttrOfType<BoolAttr>(name: "interrupt_before_all"))
68 if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
69 return WalkResult::interrupt();
70
71 if (auto interruptAfterAll =
72 op->getAttrOfType<BoolAttr>(name: "interrupt_after_all"))
73 if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
74 return WalkResult::interrupt();
75
76 if (auto interruptAfterRegion =
77 op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
78 if (stage.isAfterRegion(
79 region: static_cast<int>(interruptAfterRegion.getInt())))
80 return WalkResult::interrupt();
81
82 if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>(name: "skip_before_all"))
83 if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
84 return WalkResult::skip();
85
86 if (auto skipAfterAll = op->getAttrOfType<BoolAttr>(name: "skip_after_all"))
87 if (skipAfterAll.getValue() && stage.isAfterAllRegions())
88 return WalkResult::skip();
89
90 if (auto skipAfterRegion =
91 op->getAttrOfType<IntegerAttr>("skip_after_region"))
92 if (stage.isAfterRegion(region: static_cast<int>(skipAfterRegion.getInt())))
93 return WalkResult::skip();
94
95 llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
96 << getStageDescription(stage) << "\n";
97 return WalkResult::advance();
98 };
99
100 // Interrupt the walk based on attributes on the operation.
101 auto result = outerOp->walk(callback&: walker);
102
103 if (result.wasInterrupted())
104 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
105
106 // Exercise static inference of operation type.
107 result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
108 return walker(op, stage);
109 });
110
111 if (result.wasInterrupted())
112 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
113 }
114};
115
116struct TestGenericIRBlockVisitorInterruptPass
117 : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118 OperationPass<ModuleOp>> {
119 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120 TestGenericIRBlockVisitorInterruptPass)
121
122 StringRef getArgument() const final {
123 return "test-generic-ir-block-visitors-interrupt";
124 }
125 StringRef getDescription() const final {
126 return "Test generic IR visitors with interrupts, starting with Blocks.";
127 }
128
129 void runOnOperation() override {
130 int stepNo = 0;
131
132 auto walker = [&](Block *block) {
133 for (Operation &op : *block)
134 if (op.getAttrOfType<BoolAttr>(name: "interrupt"))
135 return WalkResult::interrupt();
136
137 llvm::outs() << "step " << stepNo++ << "\n";
138 return WalkResult::advance();
139 };
140
141 auto result = getOperation()->walk(walker);
142 if (result.wasInterrupted())
143 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
144 }
145};
146
147struct TestGenericIRRegionVisitorInterruptPass
148 : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
149 OperationPass<ModuleOp>> {
150 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
151 TestGenericIRRegionVisitorInterruptPass)
152
153 StringRef getArgument() const final {
154 return "test-generic-ir-region-visitors-interrupt";
155 }
156 StringRef getDescription() const final {
157 return "Test generic IR visitors with interrupts, starting with Regions.";
158 }
159
160 void runOnOperation() override {
161 int stepNo = 0;
162
163 auto walker = [&](Region *region) {
164 for (Operation &op : region->getOps())
165 if (op.getAttrOfType<BoolAttr>(name: "interrupt"))
166 return WalkResult::interrupt();
167
168 llvm::outs() << "step " << stepNo++ << "\n";
169 return WalkResult::advance();
170 };
171
172 auto result = getOperation()->walk(walker);
173 if (result.wasInterrupted())
174 llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
175 }
176};
177
178} // namespace
179
180namespace mlir {
181namespace test {
182void registerTestGenericIRVisitorsPass() {
183 PassRegistration<TestGenericIRVisitorPass>();
184 PassRegistration<TestGenericIRVisitorInterruptPass>();
185 PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
186 PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
187}
188
189} // namespace test
190} // namespace mlir
191

source code of mlir/test/lib/IR/TestVisitorsGeneric.cpp