| 1 | //===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestOps.h" |
| 10 | #include "mlir/Pass/Pass.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | |
| 14 | static std::string getStageDescription(const WalkStage &stage) { |
| 15 | if (stage.isBeforeAllRegions()) |
| 16 | return "before all regions" ; |
| 17 | if (stage.isAfterAllRegions()) |
| 18 | return "after all regions" ; |
| 19 | return "before region #" + std::to_string(val: stage.getNextRegion()); |
| 20 | } |
| 21 | |
| 22 | namespace { |
| 23 | /// This pass exercises generic visitor with void callbacks and prints the order |
| 24 | /// and stage in which operations are visited. |
| 25 | struct TestGenericIRVisitorPass |
| 26 | : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> { |
| 27 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass) |
| 28 | |
| 29 | StringRef getArgument() const final { return "test-generic-ir-visitors" ; } |
| 30 | StringRef getDescription() const final { return "Test generic IR visitors." ; } |
| 31 | void runOnOperation() override { |
| 32 | Operation *outerOp = getOperation(); |
| 33 | int stepNo = 0; |
| 34 | outerOp->walk(callback: [&](Operation *op, const WalkStage &stage) { |
| 35 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
| 36 | << getStageDescription(stage) << "\n" ; |
| 37 | }); |
| 38 | |
| 39 | // Exercise static inference of operation type. |
| 40 | outerOp->walk(callback: [&](test::TwoRegionOp op, const WalkStage &stage) { |
| 41 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
| 42 | << getStageDescription(stage) << "\n" ; |
| 43 | }); |
| 44 | } |
| 45 | }; |
| 46 | |
| 47 | /// This pass exercises the generic visitor with non-void callbacks and prints |
| 48 | /// the order and stage in which operations are visited. It will interrupt the |
| 49 | /// walk based on attributes peesent in the IR. |
| 50 | struct TestGenericIRVisitorInterruptPass |
| 51 | : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> { |
| 52 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 53 | TestGenericIRVisitorInterruptPass) |
| 54 | |
| 55 | StringRef getArgument() const final { |
| 56 | return "test-generic-ir-visitors-interrupt" ; |
| 57 | } |
| 58 | StringRef getDescription() const final { |
| 59 | return "Test generic IR visitors with interrupts." ; |
| 60 | } |
| 61 | void runOnOperation() override { |
| 62 | Operation *outerOp = getOperation(); |
| 63 | int stepNo = 0; |
| 64 | |
| 65 | auto walker = [&](Operation *op, const WalkStage &stage) { |
| 66 | if (auto interruptBeforeAall = |
| 67 | op->getAttrOfType<BoolAttr>(name: "interrupt_before_all" )) |
| 68 | if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) |
| 69 | return WalkResult::interrupt(); |
| 70 | |
| 71 | if (auto interruptAfterAll = |
| 72 | op->getAttrOfType<BoolAttr>(name: "interrupt_after_all" )) |
| 73 | if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) |
| 74 | return WalkResult::interrupt(); |
| 75 | |
| 76 | if (auto interruptAfterRegion = |
| 77 | op->getAttrOfType<IntegerAttr>("interrupt_after_region" )) |
| 78 | if (stage.isAfterRegion( |
| 79 | region: static_cast<int>(interruptAfterRegion.getInt()))) |
| 80 | return WalkResult::interrupt(); |
| 81 | |
| 82 | if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>(name: "skip_before_all" )) |
| 83 | if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) |
| 84 | return WalkResult::skip(); |
| 85 | |
| 86 | if (auto skipAfterAll = op->getAttrOfType<BoolAttr>(name: "skip_after_all" )) |
| 87 | if (skipAfterAll.getValue() && stage.isAfterAllRegions()) |
| 88 | return WalkResult::skip(); |
| 89 | |
| 90 | if (auto skipAfterRegion = |
| 91 | op->getAttrOfType<IntegerAttr>("skip_after_region" )) |
| 92 | if (stage.isAfterRegion(region: static_cast<int>(skipAfterRegion.getInt()))) |
| 93 | return WalkResult::skip(); |
| 94 | |
| 95 | llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " |
| 96 | << getStageDescription(stage) << "\n" ; |
| 97 | return WalkResult::advance(); |
| 98 | }; |
| 99 | |
| 100 | // Interrupt the walk based on attributes on the operation. |
| 101 | auto result = outerOp->walk(callback&: walker); |
| 102 | |
| 103 | if (result.wasInterrupted()) |
| 104 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
| 105 | |
| 106 | // Exercise static inference of operation type. |
| 107 | result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { |
| 108 | return walker(op, stage); |
| 109 | }); |
| 110 | |
| 111 | if (result.wasInterrupted()) |
| 112 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
| 113 | } |
| 114 | }; |
| 115 | |
| 116 | struct TestGenericIRBlockVisitorInterruptPass |
| 117 | : public PassWrapper<TestGenericIRBlockVisitorInterruptPass, |
| 118 | OperationPass<ModuleOp>> { |
| 119 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 120 | TestGenericIRBlockVisitorInterruptPass) |
| 121 | |
| 122 | StringRef getArgument() const final { |
| 123 | return "test-generic-ir-block-visitors-interrupt" ; |
| 124 | } |
| 125 | StringRef getDescription() const final { |
| 126 | return "Test generic IR visitors with interrupts, starting with Blocks." ; |
| 127 | } |
| 128 | |
| 129 | void runOnOperation() override { |
| 130 | int stepNo = 0; |
| 131 | |
| 132 | auto walker = [&](Block *block) { |
| 133 | for (Operation &op : *block) |
| 134 | if (op.getAttrOfType<BoolAttr>(name: "interrupt" )) |
| 135 | return WalkResult::interrupt(); |
| 136 | |
| 137 | llvm::outs() << "step " << stepNo++ << "\n" ; |
| 138 | return WalkResult::advance(); |
| 139 | }; |
| 140 | |
| 141 | auto result = getOperation()->walk(walker); |
| 142 | if (result.wasInterrupted()) |
| 143 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
| 144 | } |
| 145 | }; |
| 146 | |
| 147 | struct TestGenericIRRegionVisitorInterruptPass |
| 148 | : public PassWrapper<TestGenericIRRegionVisitorInterruptPass, |
| 149 | OperationPass<ModuleOp>> { |
| 150 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 151 | TestGenericIRRegionVisitorInterruptPass) |
| 152 | |
| 153 | StringRef getArgument() const final { |
| 154 | return "test-generic-ir-region-visitors-interrupt" ; |
| 155 | } |
| 156 | StringRef getDescription() const final { |
| 157 | return "Test generic IR visitors with interrupts, starting with Regions." ; |
| 158 | } |
| 159 | |
| 160 | void runOnOperation() override { |
| 161 | int stepNo = 0; |
| 162 | |
| 163 | auto walker = [&](Region *region) { |
| 164 | for (Operation &op : region->getOps()) |
| 165 | if (op.getAttrOfType<BoolAttr>(name: "interrupt" )) |
| 166 | return WalkResult::interrupt(); |
| 167 | |
| 168 | llvm::outs() << "step " << stepNo++ << "\n" ; |
| 169 | return WalkResult::advance(); |
| 170 | }; |
| 171 | |
| 172 | auto result = getOperation()->walk(walker); |
| 173 | if (result.wasInterrupted()) |
| 174 | llvm::outs() << "step " << stepNo++ << " walk was interrupted\n" ; |
| 175 | } |
| 176 | }; |
| 177 | |
| 178 | } // namespace |
| 179 | |
| 180 | namespace mlir { |
| 181 | namespace test { |
| 182 | void registerTestGenericIRVisitorsPass() { |
| 183 | PassRegistration<TestGenericIRVisitorPass>(); |
| 184 | PassRegistration<TestGenericIRVisitorInterruptPass>(); |
| 185 | PassRegistration<TestGenericIRBlockVisitorInterruptPass>(); |
| 186 | PassRegistration<TestGenericIRRegionVisitorInterruptPass>(); |
| 187 | } |
| 188 | |
| 189 | } // namespace test |
| 190 | } // namespace mlir |
| 191 | |