| 1 | //===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/BuiltinOps.h" |
| 10 | #include "mlir/IR/Iterators.h" |
| 11 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 12 | #include "mlir/Pass/Pass.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | |
| 16 | static void printRegion(Region *region) { |
| 17 | llvm::outs() << "region " << region->getRegionNumber() << " from operation '" |
| 18 | << region->getParentOp()->getName() << "'" ; |
| 19 | } |
| 20 | |
| 21 | static void printBlock(Block *block) { |
| 22 | llvm::outs() << "block " ; |
| 23 | block->printAsOperand(os&: llvm::outs(), /*printType=*/false); |
| 24 | llvm::outs() << " from " ; |
| 25 | printRegion(region: block->getParent()); |
| 26 | } |
| 27 | |
| 28 | static void printOperation(Operation *op) { |
| 29 | llvm::outs() << "op '" << op->getName() << "'" ; |
| 30 | } |
| 31 | |
| 32 | /// Tests pure callbacks. |
| 33 | static void testPureCallbacks(Operation *op) { |
| 34 | auto opPure = [](Operation *op) { |
| 35 | llvm::outs() << "Visiting " ; |
| 36 | printOperation(op); |
| 37 | llvm::outs() << "\n" ; |
| 38 | }; |
| 39 | auto blockPure = [](Block *block) { |
| 40 | llvm::outs() << "Visiting " ; |
| 41 | printBlock(block); |
| 42 | llvm::outs() << "\n" ; |
| 43 | }; |
| 44 | auto regionPure = [](Region *region) { |
| 45 | llvm::outs() << "Visiting " ; |
| 46 | printRegion(region); |
| 47 | llvm::outs() << "\n" ; |
| 48 | }; |
| 49 | |
| 50 | llvm::outs() << "Op pre-order visits" |
| 51 | << "\n" ; |
| 52 | op->walk<WalkOrder::PreOrder>(callback&: opPure); |
| 53 | llvm::outs() << "Block pre-order visits" |
| 54 | << "\n" ; |
| 55 | op->walk<WalkOrder::PreOrder>(callback&: blockPure); |
| 56 | llvm::outs() << "Region pre-order visits" |
| 57 | << "\n" ; |
| 58 | op->walk<WalkOrder::PreOrder>(callback&: regionPure); |
| 59 | |
| 60 | llvm::outs() << "Op post-order visits" |
| 61 | << "\n" ; |
| 62 | op->walk<WalkOrder::PostOrder>(callback&: opPure); |
| 63 | llvm::outs() << "Block post-order visits" |
| 64 | << "\n" ; |
| 65 | op->walk<WalkOrder::PostOrder>(callback&: blockPure); |
| 66 | llvm::outs() << "Region post-order visits" |
| 67 | << "\n" ; |
| 68 | op->walk<WalkOrder::PostOrder>(callback&: regionPure); |
| 69 | |
| 70 | llvm::outs() << "Op reverse post-order visits" |
| 71 | << "\n" ; |
| 72 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: opPure); |
| 73 | llvm::outs() << "Block reverse post-order visits" |
| 74 | << "\n" ; |
| 75 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: blockPure); |
| 76 | llvm::outs() << "Region reverse post-order visits" |
| 77 | << "\n" ; |
| 78 | op->walk<WalkOrder::PostOrder, ReverseIterator>(callback&: regionPure); |
| 79 | |
| 80 | // This test case tests "NoGraphRegions = true", so start the walk with |
| 81 | // functions. |
| 82 | op->walk([&](FunctionOpInterface funcOp) { |
| 83 | llvm::outs() << "Op forward dominance post-order visits" |
| 84 | << "\n" ; |
| 85 | funcOp->walk<WalkOrder::PostOrder, |
| 86 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure); |
| 87 | llvm::outs() << "Block forward dominance post-order visits" |
| 88 | << "\n" ; |
| 89 | funcOp->walk<WalkOrder::PostOrder, |
| 90 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure); |
| 91 | llvm::outs() << "Region forward dominance post-order visits" |
| 92 | << "\n" ; |
| 93 | funcOp->walk<WalkOrder::PostOrder, |
| 94 | ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure); |
| 95 | |
| 96 | llvm::outs() << "Op reverse dominance post-order visits" |
| 97 | << "\n" ; |
| 98 | funcOp->walk<WalkOrder::PostOrder, |
| 99 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(opPure); |
| 100 | llvm::outs() << "Block reverse dominance post-order visits" |
| 101 | << "\n" ; |
| 102 | funcOp->walk<WalkOrder::PostOrder, |
| 103 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(blockPure); |
| 104 | llvm::outs() << "Region reverse dominance post-order visits" |
| 105 | << "\n" ; |
| 106 | funcOp->walk<WalkOrder::PostOrder, |
| 107 | ReverseDominanceIterator</*NoGraphRegions=*/true>>(regionPure); |
| 108 | }); |
| 109 | } |
| 110 | |
| 111 | /// Tests erasure callbacks that skip the walk. |
| 112 | static void testSkipErasureCallbacks(Operation *op) { |
| 113 | auto skipOpErasure = [](Operation *op) { |
| 114 | // Do not erase module and module children operations. Otherwise, there |
| 115 | // wouldn't be too much to test in pre-order. |
| 116 | if (isa<ModuleOp>(Val: op) || isa<ModuleOp>(Val: op->getParentOp())) |
| 117 | return WalkResult::advance(); |
| 118 | |
| 119 | llvm::outs() << "Erasing " ; |
| 120 | printOperation(op); |
| 121 | llvm::outs() << "\n" ; |
| 122 | op->dropAllUses(); |
| 123 | op->erase(); |
| 124 | return WalkResult::skip(); |
| 125 | }; |
| 126 | auto skipBlockErasure = [](Block *block) { |
| 127 | // Do not erase module and module children blocks. Otherwise there wouldn't |
| 128 | // be too much to test in pre-order. |
| 129 | Operation *parentOp = block->getParentOp(); |
| 130 | if (isa<ModuleOp>(Val: parentOp) || isa<ModuleOp>(Val: parentOp->getParentOp())) |
| 131 | return WalkResult::advance(); |
| 132 | |
| 133 | if (block->use_empty()) { |
| 134 | llvm::outs() << "Erasing " ; |
| 135 | printBlock(block); |
| 136 | llvm::outs() << "\n" ; |
| 137 | block->erase(); |
| 138 | return WalkResult::skip(); |
| 139 | } |
| 140 | llvm::outs() << "Cannot erase " ; |
| 141 | printBlock(block); |
| 142 | llvm::outs() << ", still has uses\n" ; |
| 143 | return WalkResult::advance(); |
| 144 | |
| 145 | }; |
| 146 | |
| 147 | llvm::outs() << "Op pre-order erasures (skip)" |
| 148 | << "\n" ; |
| 149 | Operation *cloned = op->clone(); |
| 150 | cloned->walk<WalkOrder::PreOrder>(callback&: skipOpErasure); |
| 151 | cloned->erase(); |
| 152 | |
| 153 | llvm::outs() << "Block pre-order erasures (skip)" |
| 154 | << "\n" ; |
| 155 | cloned = op->clone(); |
| 156 | cloned->walk<WalkOrder::PreOrder>(callback&: skipBlockErasure); |
| 157 | cloned->erase(); |
| 158 | |
| 159 | llvm::outs() << "Op post-order erasures (skip)" |
| 160 | << "\n" ; |
| 161 | cloned = op->clone(); |
| 162 | cloned->walk<WalkOrder::PostOrder>(callback&: skipOpErasure); |
| 163 | cloned->erase(); |
| 164 | |
| 165 | llvm::outs() << "Block post-order erasures (skip)" |
| 166 | << "\n" ; |
| 167 | cloned = op->clone(); |
| 168 | cloned->walk<WalkOrder::PostOrder>(callback&: skipBlockErasure); |
| 169 | cloned->erase(); |
| 170 | } |
| 171 | |
| 172 | /// Tests callbacks that erase the op or block but don't return 'Skip'. This |
| 173 | /// callbacks are only valid in post-order. |
| 174 | static void testNoSkipErasureCallbacks(Operation *op) { |
| 175 | auto noSkipOpErasure = [](Operation *op) { |
| 176 | llvm::outs() << "Erasing " ; |
| 177 | printOperation(op); |
| 178 | llvm::outs() << "\n" ; |
| 179 | op->dropAllUses(); |
| 180 | op->erase(); |
| 181 | }; |
| 182 | auto noSkipBlockErasure = [](Block *block) { |
| 183 | if (block->use_empty()) { |
| 184 | llvm::outs() << "Erasing " ; |
| 185 | printBlock(block); |
| 186 | llvm::outs() << "\n" ; |
| 187 | block->erase(); |
| 188 | } else { |
| 189 | llvm::outs() << "Cannot erase " ; |
| 190 | printBlock(block); |
| 191 | llvm::outs() << ", still has uses\n" ; |
| 192 | } |
| 193 | }; |
| 194 | |
| 195 | llvm::outs() << "Op post-order erasures (no skip)" |
| 196 | << "\n" ; |
| 197 | Operation *cloned = op->clone(); |
| 198 | cloned->walk<WalkOrder::PostOrder>(callback&: noSkipOpErasure); |
| 199 | |
| 200 | llvm::outs() << "Block post-order erasures (no skip)" |
| 201 | << "\n" ; |
| 202 | cloned = op->clone(); |
| 203 | cloned->walk<WalkOrder::PostOrder>(callback&: noSkipBlockErasure); |
| 204 | cloned->erase(); |
| 205 | } |
| 206 | |
| 207 | /// Invoke region/block walks on regions/blocks. |
| 208 | static void testBlockAndRegionWalkers(Operation *op) { |
| 209 | auto blockPure = [](Block *block) { |
| 210 | llvm::outs() << "Visiting " ; |
| 211 | printBlock(block); |
| 212 | llvm::outs() << "\n" ; |
| 213 | }; |
| 214 | auto regionPure = [](Region *region) { |
| 215 | llvm::outs() << "Visiting " ; |
| 216 | printRegion(region); |
| 217 | llvm::outs() << "\n" ; |
| 218 | }; |
| 219 | |
| 220 | llvm::outs() << "Invoke block pre-order visits on blocks\n" ; |
| 221 | op->walk(callback: [&](Operation *op) { |
| 222 | if (!op->hasAttr(name: "walk_blocks" )) |
| 223 | return; |
| 224 | for (Region ®ion : op->getRegions()) { |
| 225 | for (Block &block : region.getBlocks()) { |
| 226 | block.walk<WalkOrder::PreOrder>(callback&: blockPure); |
| 227 | } |
| 228 | } |
| 229 | }); |
| 230 | |
| 231 | llvm::outs() << "Invoke block post-order visits on blocks\n" ; |
| 232 | op->walk(callback: [&](Operation *op) { |
| 233 | if (!op->hasAttr(name: "walk_blocks" )) |
| 234 | return; |
| 235 | for (Region ®ion : op->getRegions()) { |
| 236 | for (Block &block : region.getBlocks()) { |
| 237 | block.walk<WalkOrder::PostOrder>(callback&: blockPure); |
| 238 | } |
| 239 | } |
| 240 | }); |
| 241 | |
| 242 | llvm::outs() << "Invoke region pre-order visits on region\n" ; |
| 243 | op->walk(callback: [&](Operation *op) { |
| 244 | if (!op->hasAttr(name: "walk_regions" )) |
| 245 | return; |
| 246 | for (Region ®ion : op->getRegions()) { |
| 247 | region.walk<WalkOrder::PreOrder>(callback&: regionPure); |
| 248 | } |
| 249 | }); |
| 250 | |
| 251 | llvm::outs() << "Invoke region post-order visits on region\n" ; |
| 252 | op->walk(callback: [&](Operation *op) { |
| 253 | if (!op->hasAttr(name: "walk_regions" )) |
| 254 | return; |
| 255 | for (Region ®ion : op->getRegions()) { |
| 256 | region.walk<WalkOrder::PostOrder>(callback&: regionPure); |
| 257 | } |
| 258 | }); |
| 259 | } |
| 260 | |
| 261 | namespace { |
| 262 | /// This pass exercises the different configurations of the IR visitors. |
| 263 | struct TestIRVisitorsPass |
| 264 | : public PassWrapper<TestIRVisitorsPass, OperationPass<>> { |
| 265 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass) |
| 266 | |
| 267 | StringRef getArgument() const final { return "test-ir-visitors" ; } |
| 268 | StringRef getDescription() const final { return "Test various visitors." ; } |
| 269 | void runOnOperation() override { |
| 270 | Operation *op = getOperation(); |
| 271 | testPureCallbacks(op); |
| 272 | testBlockAndRegionWalkers(op); |
| 273 | testSkipErasureCallbacks(op); |
| 274 | testNoSkipErasureCallbacks(op); |
| 275 | } |
| 276 | }; |
| 277 | } // namespace |
| 278 | |
| 279 | namespace mlir { |
| 280 | namespace test { |
| 281 | void registerTestIRVisitorsPass() { PassRegistration<TestIRVisitorsPass>(); } |
| 282 | } // namespace test |
| 283 | } // namespace mlir |
| 284 | |