| 1 | //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "TestOps.h" |
| 11 | #include "TestTypes.h" |
| 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 14 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| 15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 16 | #include "mlir/IR/BuiltinAttributes.h" |
| 17 | #include "mlir/IR/Matchers.h" |
| 18 | #include "mlir/IR/Visitors.h" |
| 19 | #include "mlir/Pass/Pass.h" |
| 20 | #include "mlir/Transforms/DialectConversion.h" |
| 21 | #include "mlir/Transforms/FoldUtils.h" |
| 22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 23 | #include "mlir/Transforms/WalkPatternRewriteDriver.h" |
| 24 | #include "llvm/ADT/ScopeExit.h" |
| 25 | #include <cstdint> |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace test; |
| 29 | |
| 30 | // Native function for testing NativeCodeCall |
| 31 | static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { |
| 32 | return choice.getValue() ? input1 : input2; |
| 33 | } |
| 34 | |
| 35 | static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { |
| 36 | rewriter.create<OpI>(loc, input); |
| 37 | } |
| 38 | |
| 39 | static void handleNoResultOp(PatternRewriter &rewriter, |
| 40 | OpSymbolBindingNoResult op) { |
| 41 | // Turn the no result op to a one-result op. |
| 42 | rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), |
| 43 | op.getOperand()); |
| 44 | } |
| 45 | |
| 46 | static bool getFirstI32Result(Operation *op, Value &value) { |
| 47 | if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32)) |
| 48 | return false; |
| 49 | value = op->getResult(idx: 0); |
| 50 | return true; |
| 51 | } |
| 52 | |
| 53 | static Value bindNativeCodeCallResult(Value value) { return value; } |
| 54 | |
| 55 | static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, |
| 56 | Value input2) { |
| 57 | return SmallVector<Value, 2>({input2, input1}); |
| 58 | } |
| 59 | |
| 60 | // Test that natives calls are only called once during rewrites. |
| 61 | // OpM_Test will return Pi, increased by 1 for each subsequent calls. |
| 62 | // This let us check the number of times OpM_Test was called by inspecting |
| 63 | // the returned value in the MLIR output. |
| 64 | static int64_t opMIncreasingValue = 314159265; |
| 65 | static Attribute opMTest(PatternRewriter &rewriter, Value val) { |
| 66 | int64_t i = opMIncreasingValue++; |
| 67 | return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); |
| 68 | } |
| 69 | |
| 70 | namespace { |
| 71 | #include "TestPatterns.inc" |
| 72 | } // namespace |
| 73 | |
| 74 | //===----------------------------------------------------------------------===// |
| 75 | // Test Reduce Pattern Interface |
| 76 | //===----------------------------------------------------------------------===// |
| 77 | |
| 78 | void test::populateTestReductionPatterns(RewritePatternSet &patterns) { |
| 79 | populateWithGenerated(patterns); |
| 80 | } |
| 81 | |
| 82 | //===----------------------------------------------------------------------===// |
| 83 | // Canonicalizer Driver. |
| 84 | //===----------------------------------------------------------------------===// |
| 85 | |
| 86 | namespace { |
| 87 | struct FoldingPattern : public RewritePattern { |
| 88 | public: |
| 89 | FoldingPattern(MLIRContext *context) |
| 90 | : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), |
| 91 | /*benefit=*/1, context) {} |
| 92 | |
| 93 | LogicalResult matchAndRewrite(Operation *op, |
| 94 | PatternRewriter &rewriter) const override { |
| 95 | // Exercise createOrFold API for a single-result operation that is folded |
| 96 | // upon construction. The operation being created has an in-place folder, |
| 97 | // and it should be still present in the output. Furthermore, the folder |
| 98 | // should not crash when attempting to recover the (unchanged) operation |
| 99 | // result. |
| 100 | Value result = rewriter.createOrFold<TestOpInPlaceFold>( |
| 101 | op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); |
| 102 | assert(result); |
| 103 | rewriter.replaceOp(op, newValues: result); |
| 104 | return success(); |
| 105 | } |
| 106 | }; |
| 107 | |
| 108 | /// This pattern creates a foldable operation at the entry point of the block. |
| 109 | /// This tests the situation where the operation folder will need to replace an |
| 110 | /// operation with a previously created constant that does not initially |
| 111 | /// dominate the operation to replace. |
| 112 | struct FolderInsertBeforePreviouslyFoldedConstantPattern |
| 113 | : public OpRewritePattern<TestCastOp> { |
| 114 | public: |
| 115 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
| 116 | |
| 117 | LogicalResult matchAndRewrite(TestCastOp op, |
| 118 | PatternRewriter &rewriter) const override { |
| 119 | if (!op->hasAttr("test_fold_before_previously_folded_op" )) |
| 120 | return failure(); |
| 121 | rewriter.setInsertionPointToStart(op->getBlock()); |
| 122 | |
| 123 | auto constOp = rewriter.create<arith::ConstantOp>( |
| 124 | op.getLoc(), rewriter.getBoolAttr(true)); |
| 125 | rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), |
| 126 | Value(constOp)); |
| 127 | return success(); |
| 128 | } |
| 129 | }; |
| 130 | |
| 131 | /// This pattern matches test.op_commutative2 with the first operand being |
| 132 | /// another test.op_commutative2 with a constant on the right side and fold it |
| 133 | /// away by propagating it as its result. This is intend to check that patterns |
| 134 | /// are applied after the commutative property moves constant to the right. |
| 135 | struct FolderCommutativeOp2WithConstant |
| 136 | : public OpRewritePattern<TestCommutative2Op> { |
| 137 | public: |
| 138 | using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; |
| 139 | |
| 140 | LogicalResult matchAndRewrite(TestCommutative2Op op, |
| 141 | PatternRewriter &rewriter) const override { |
| 142 | auto operand = |
| 143 | dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); |
| 144 | if (!operand) |
| 145 | return failure(); |
| 146 | Attribute constInput; |
| 147 | if (!matchPattern(operand->getOperand(1), m_Constant(bind_value: &constInput))) |
| 148 | return failure(); |
| 149 | rewriter.replaceOp(op, operand->getOperand(1)); |
| 150 | return success(); |
| 151 | } |
| 152 | }; |
| 153 | |
| 154 | /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer |
| 155 | /// attribute with value smaller than MaxVal, it increments the value by 1. |
| 156 | template <int MaxVal> |
| 157 | struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { |
| 158 | using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; |
| 159 | |
| 160 | LogicalResult matchAndRewrite(AnyAttrOfOp op, |
| 161 | PatternRewriter &rewriter) const override { |
| 162 | auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); |
| 163 | if (!intAttr) |
| 164 | return failure(); |
| 165 | int64_t val = intAttr.getInt(); |
| 166 | if (val >= MaxVal) |
| 167 | return failure(); |
| 168 | rewriter.modifyOpInPlace( |
| 169 | op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); |
| 170 | return success(); |
| 171 | } |
| 172 | }; |
| 173 | |
| 174 | /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". |
| 175 | struct MakeOpEligible : public RewritePattern { |
| 176 | MakeOpEligible(MLIRContext *context) |
| 177 | : RewritePattern("foo.maybe_eligible_op" , /*benefit=*/1, context) {} |
| 178 | |
| 179 | LogicalResult matchAndRewrite(Operation *op, |
| 180 | PatternRewriter &rewriter) const override { |
| 181 | if (op->hasAttr(name: "eligible" )) |
| 182 | return failure(); |
| 183 | rewriter.modifyOpInPlace( |
| 184 | root: op, callable: [&]() { op->setAttr("eligible" , rewriter.getUnitAttr()); }); |
| 185 | return success(); |
| 186 | } |
| 187 | }; |
| 188 | |
| 189 | /// This pattern hoists eligible ops out of a "test.one_region_op". |
| 190 | struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { |
| 191 | using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; |
| 192 | |
| 193 | LogicalResult matchAndRewrite(test::OneRegionOp op, |
| 194 | PatternRewriter &rewriter) const override { |
| 195 | Operation *terminator = op.getRegion().front().getTerminator(); |
| 196 | Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); |
| 197 | if (toBeHoisted->getParentOp() != op) |
| 198 | return failure(); |
| 199 | if (!toBeHoisted->hasAttr(name: "eligible" )) |
| 200 | return failure(); |
| 201 | rewriter.moveOpBefore(toBeHoisted, op); |
| 202 | return success(); |
| 203 | } |
| 204 | }; |
| 205 | |
| 206 | /// This pattern moves "test.move_before_parent_op" before the parent op. |
| 207 | struct MoveBeforeParentOp : public RewritePattern { |
| 208 | MoveBeforeParentOp(MLIRContext *context) |
| 209 | : RewritePattern("test.move_before_parent_op" , /*benefit=*/1, context) {} |
| 210 | |
| 211 | LogicalResult matchAndRewrite(Operation *op, |
| 212 | PatternRewriter &rewriter) const override { |
| 213 | // Do not hoist past functions. |
| 214 | if (isa<FunctionOpInterface>(Val: op->getParentOp())) |
| 215 | return failure(); |
| 216 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
| 217 | return success(); |
| 218 | } |
| 219 | }; |
| 220 | |
| 221 | /// This pattern moves "test.move_after_parent_op" after the parent op. |
| 222 | struct MoveAfterParentOp : public RewritePattern { |
| 223 | MoveAfterParentOp(MLIRContext *context) |
| 224 | : RewritePattern("test.move_after_parent_op" , /*benefit=*/1, context) {} |
| 225 | |
| 226 | LogicalResult matchAndRewrite(Operation *op, |
| 227 | PatternRewriter &rewriter) const override { |
| 228 | // Do not hoist past functions. |
| 229 | if (isa<FunctionOpInterface>(Val: op->getParentOp())) |
| 230 | return failure(); |
| 231 | |
| 232 | int64_t moveForwardBy = 0; |
| 233 | if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance" )) |
| 234 | moveForwardBy = advanceBy.getInt(); |
| 235 | |
| 236 | Operation *moveAfter = op->getParentOp(); |
| 237 | for (int64_t i = 0; i < moveForwardBy; ++i) |
| 238 | moveAfter = moveAfter->getNextNode(); |
| 239 | |
| 240 | rewriter.moveOpAfter(op, existingOp: moveAfter); |
| 241 | return success(); |
| 242 | } |
| 243 | }; |
| 244 | |
| 245 | /// This pattern inlines blocks that are nested in |
| 246 | /// "test.inline_blocks_into_parent" into the parent block. |
| 247 | struct InlineBlocksIntoParent : public RewritePattern { |
| 248 | InlineBlocksIntoParent(MLIRContext *context) |
| 249 | : RewritePattern("test.inline_blocks_into_parent" , /*benefit=*/1, |
| 250 | context) {} |
| 251 | |
| 252 | LogicalResult matchAndRewrite(Operation *op, |
| 253 | PatternRewriter &rewriter) const override { |
| 254 | bool changed = false; |
| 255 | for (Region &r : op->getRegions()) { |
| 256 | while (!r.empty()) { |
| 257 | rewriter.inlineBlockBefore(source: &r.front(), op); |
| 258 | changed = true; |
| 259 | } |
| 260 | } |
| 261 | return success(IsSuccess: changed); |
| 262 | } |
| 263 | }; |
| 264 | |
| 265 | /// This pattern splits blocks at "test.split_block_here" and replaces the op |
| 266 | /// with a new op (to prevent an infinite loop of block splitting). |
| 267 | struct SplitBlockHere : public RewritePattern { |
| 268 | SplitBlockHere(MLIRContext *context) |
| 269 | : RewritePattern("test.split_block_here" , /*benefit=*/1, context) {} |
| 270 | |
| 271 | LogicalResult matchAndRewrite(Operation *op, |
| 272 | PatternRewriter &rewriter) const override { |
| 273 | rewriter.splitBlock(block: op->getBlock(), before: op->getIterator()); |
| 274 | Operation *newOp = rewriter.create( |
| 275 | op->getLoc(), |
| 276 | OperationName("test.new_op" , op->getContext()).getIdentifier(), |
| 277 | op->getOperands(), op->getResultTypes()); |
| 278 | rewriter.replaceOp(op, newOp); |
| 279 | return success(); |
| 280 | } |
| 281 | }; |
| 282 | |
| 283 | /// This pattern clones "test.clone_me" ops. |
| 284 | struct CloneOp : public RewritePattern { |
| 285 | CloneOp(MLIRContext *context) |
| 286 | : RewritePattern("test.clone_me" , /*benefit=*/1, context) {} |
| 287 | |
| 288 | LogicalResult matchAndRewrite(Operation *op, |
| 289 | PatternRewriter &rewriter) const override { |
| 290 | // Do not clone already cloned ops to avoid going into an infinite loop. |
| 291 | if (op->hasAttr(name: "was_cloned" )) |
| 292 | return failure(); |
| 293 | Operation *cloned = rewriter.clone(op&: *op); |
| 294 | cloned->setAttr("was_cloned" , rewriter.getUnitAttr()); |
| 295 | return success(); |
| 296 | } |
| 297 | }; |
| 298 | |
| 299 | /// This pattern clones regions of "test.clone_region_before" ops before the |
| 300 | /// parent block. |
| 301 | struct CloneRegionBeforeOp : public RewritePattern { |
| 302 | CloneRegionBeforeOp(MLIRContext *context) |
| 303 | : RewritePattern("test.clone_region_before" , /*benefit=*/1, context) {} |
| 304 | |
| 305 | LogicalResult matchAndRewrite(Operation *op, |
| 306 | PatternRewriter &rewriter) const override { |
| 307 | // Do not clone already cloned ops to avoid going into an infinite loop. |
| 308 | if (op->hasAttr(name: "was_cloned" )) |
| 309 | return failure(); |
| 310 | for (Region &r : op->getRegions()) |
| 311 | rewriter.cloneRegionBefore(region&: r, before: op->getBlock()); |
| 312 | op->setAttr("was_cloned" , rewriter.getUnitAttr()); |
| 313 | return success(); |
| 314 | } |
| 315 | }; |
| 316 | |
| 317 | /// Replace an operation may introduce the re-visiting of its users. |
| 318 | class ReplaceWithNewOp : public RewritePattern { |
| 319 | public: |
| 320 | ReplaceWithNewOp(MLIRContext *context) |
| 321 | : RewritePattern("test.replace_with_new_op" , /*benefit=*/1, context) {} |
| 322 | |
| 323 | LogicalResult matchAndRewrite(Operation *op, |
| 324 | PatternRewriter &rewriter) const override { |
| 325 | Operation *newOp; |
| 326 | if (op->hasAttr(name: "create_erase_op" )) { |
| 327 | newOp = rewriter.create( |
| 328 | op->getLoc(), |
| 329 | OperationName("test.erase_op" , op->getContext()).getIdentifier(), |
| 330 | ValueRange(), TypeRange()); |
| 331 | } else { |
| 332 | newOp = rewriter.create( |
| 333 | op->getLoc(), |
| 334 | OperationName("test.new_op" , op->getContext()).getIdentifier(), |
| 335 | op->getOperands(), op->getResultTypes()); |
| 336 | } |
| 337 | // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". |
| 338 | // A "notifyOperationReplaced" callback is triggered in either case. |
| 339 | rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults()); |
| 340 | rewriter.eraseOp(op); |
| 341 | return success(); |
| 342 | } |
| 343 | }; |
| 344 | |
| 345 | /// Erases the first child block of the matched "test.erase_first_block" |
| 346 | /// operation. |
| 347 | class EraseFirstBlock : public RewritePattern { |
| 348 | public: |
| 349 | EraseFirstBlock(MLIRContext *context) |
| 350 | : RewritePattern("test.erase_first_block" , /*benefit=*/1, context) {} |
| 351 | |
| 352 | LogicalResult matchAndRewrite(Operation *op, |
| 353 | PatternRewriter &rewriter) const override { |
| 354 | for (Region &r : op->getRegions()) { |
| 355 | for (Block &b : r.getBlocks()) { |
| 356 | rewriter.eraseBlock(block: &b); |
| 357 | return success(); |
| 358 | } |
| 359 | } |
| 360 | |
| 361 | return failure(); |
| 362 | } |
| 363 | }; |
| 364 | |
| 365 | struct TestGreedyPatternDriver |
| 366 | : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { |
| 367 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) |
| 368 | |
| 369 | TestGreedyPatternDriver() = default; |
| 370 | TestGreedyPatternDriver(const TestGreedyPatternDriver &other) |
| 371 | : PassWrapper(other) {} |
| 372 | |
| 373 | StringRef getArgument() const final { return "test-greedy-patterns" ; } |
| 374 | StringRef getDescription() const final { return "Run test dialect patterns" ; } |
| 375 | void runOnOperation() override { |
| 376 | mlir::RewritePatternSet patterns(&getContext()); |
| 377 | populateWithGenerated(patterns); |
| 378 | |
| 379 | // Verify named pattern is generated with expected name. |
| 380 | patterns.add<FoldingPattern, TestNamedPatternRule, |
| 381 | FolderInsertBeforePreviouslyFoldedConstantPattern, |
| 382 | FolderCommutativeOp2WithConstant, HoistEligibleOps, |
| 383 | MakeOpEligible>(&getContext()); |
| 384 | |
| 385 | // Additional patterns for testing the GreedyPatternRewriteDriver. |
| 386 | patterns.insert<IncrementIntAttribute<3>>(arg: &getContext()); |
| 387 | |
| 388 | GreedyRewriteConfig config; |
| 389 | config.setUseTopDownTraversal(useTopDownTraversal) |
| 390 | .setMaxIterations(this->maxIterations) |
| 391 | .enableFolding(enable: this->fold) |
| 392 | .enableConstantCSE(enable: this->cseConstants); |
| 393 | (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); |
| 394 | } |
| 395 | |
| 396 | Option<bool> useTopDownTraversal{ |
| 397 | *this, "top-down" , |
| 398 | llvm::cl::desc("Seed the worklist in general top-down order" ), |
| 399 | llvm::cl::init(Val: GreedyRewriteConfig().getUseTopDownTraversal())}; |
| 400 | Option<int> maxIterations{ |
| 401 | *this, "max-iterations" , |
| 402 | llvm::cl::desc("Max. iterations in the GreedyRewriteConfig" ), |
| 403 | llvm::cl::init(Val: GreedyRewriteConfig().getMaxIterations())}; |
| 404 | Option<bool> fold{*this, "fold" , llvm::cl::desc("Whether to fold" ), |
| 405 | llvm::cl::init(Val: GreedyRewriteConfig().isFoldingEnabled())}; |
| 406 | Option<bool> cseConstants{ |
| 407 | *this, "cse-constants" , llvm::cl::desc("Whether to CSE constants" ), |
| 408 | llvm::cl::init(Val: GreedyRewriteConfig().isConstantCSEEnabled())}; |
| 409 | }; |
| 410 | |
| 411 | struct DumpNotifications : public RewriterBase::Listener { |
| 412 | void notifyBlockInserted(Block *block, Region *previous, |
| 413 | Region::iterator previousIt) override { |
| 414 | llvm::outs() << "notifyBlockInserted" ; |
| 415 | if (block->getParentOp()) { |
| 416 | llvm::outs() << " into " << block->getParentOp()->getName() << ": " ; |
| 417 | } else { |
| 418 | llvm::outs() << " into unknown op: " ; |
| 419 | } |
| 420 | if (previous == nullptr) { |
| 421 | llvm::outs() << "was unlinked\n" ; |
| 422 | } else { |
| 423 | llvm::outs() << "was linked\n" ; |
| 424 | } |
| 425 | } |
| 426 | void notifyOperationInserted(Operation *op, |
| 427 | OpBuilder::InsertPoint previous) override { |
| 428 | llvm::outs() << "notifyOperationInserted: " << op->getName(); |
| 429 | if (!previous.isSet()) { |
| 430 | llvm::outs() << ", was unlinked\n" ; |
| 431 | } else { |
| 432 | if (!previous.getPoint().getNodePtr()) { |
| 433 | llvm::outs() << ", was linked, exact position unknown\n" ; |
| 434 | } else if (previous.getPoint() == previous.getBlock()->end()) { |
| 435 | llvm::outs() << ", was last in block\n" ; |
| 436 | } else { |
| 437 | llvm::outs() << ", previous = " << previous.getPoint()->getName() |
| 438 | << "\n" ; |
| 439 | } |
| 440 | } |
| 441 | } |
| 442 | void notifyBlockErased(Block *block) override { |
| 443 | llvm::outs() << "notifyBlockErased\n" ; |
| 444 | } |
| 445 | void notifyOperationErased(Operation *op) override { |
| 446 | llvm::outs() << "notifyOperationErased: " << op->getName() << "\n" ; |
| 447 | } |
| 448 | void notifyOperationModified(Operation *op) override { |
| 449 | llvm::outs() << "notifyOperationModified: " << op->getName() << "\n" ; |
| 450 | } |
| 451 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
| 452 | llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n" ; |
| 453 | } |
| 454 | }; |
| 455 | |
| 456 | struct TestStrictPatternDriver |
| 457 | : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { |
| 458 | public: |
| 459 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) |
| 460 | |
| 461 | TestStrictPatternDriver() = default; |
| 462 | TestStrictPatternDriver(const TestStrictPatternDriver &other) |
| 463 | : PassWrapper(other) { |
| 464 | strictMode = other.strictMode; |
| 465 | } |
| 466 | |
| 467 | StringRef getArgument() const final { return "test-strict-pattern-driver" ; } |
| 468 | StringRef getDescription() const final { |
| 469 | return "Test strict mode of pattern driver" ; |
| 470 | } |
| 471 | |
| 472 | void runOnOperation() override { |
| 473 | MLIRContext *ctx = &getContext(); |
| 474 | mlir::RewritePatternSet patterns(ctx); |
| 475 | patterns.add< |
| 476 | // clang-format off |
| 477 | ChangeBlockOp, |
| 478 | CloneOp, |
| 479 | CloneRegionBeforeOp, |
| 480 | EraseOp, |
| 481 | ImplicitChangeOp, |
| 482 | InlineBlocksIntoParent, |
| 483 | InsertSameOp, |
| 484 | MoveBeforeParentOp, |
| 485 | ReplaceWithNewOp, |
| 486 | SplitBlockHere |
| 487 | // clang-format on |
| 488 | >(arg&: ctx); |
| 489 | SmallVector<Operation *> ops; |
| 490 | getOperation()->walk([&](Operation *op) { |
| 491 | StringRef opName = op->getName().getStringRef(); |
| 492 | if (opName == "test.insert_same_op" || opName == "test.change_block_op" || |
| 493 | opName == "test.replace_with_new_op" || opName == "test.erase_op" || |
| 494 | opName == "test.move_before_parent_op" || |
| 495 | opName == "test.inline_blocks_into_parent" || |
| 496 | opName == "test.split_block_here" || opName == "test.clone_me" || |
| 497 | opName == "test.clone_region_before" ) { |
| 498 | ops.push_back(Elt: op); |
| 499 | } |
| 500 | }); |
| 501 | |
| 502 | DumpNotifications dumpNotifications; |
| 503 | GreedyRewriteConfig config; |
| 504 | config.setListener(&dumpNotifications); |
| 505 | if (strictMode == "AnyOp" ) { |
| 506 | config.setStrictness(GreedyRewriteStrictness::AnyOp); |
| 507 | } else if (strictMode == "ExistingAndNewOps" ) { |
| 508 | config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps); |
| 509 | } else if (strictMode == "ExistingOps" ) { |
| 510 | config.setStrictness(GreedyRewriteStrictness::ExistingOps); |
| 511 | } else { |
| 512 | llvm_unreachable("invalid strictness option" ); |
| 513 | } |
| 514 | |
| 515 | // Check if these transformations introduce visiting of operations that |
| 516 | // are not in the `ops` set (The new created ops are valid). An invalid |
| 517 | // operation will trigger the assertion while processing. |
| 518 | bool changed = false; |
| 519 | bool allErased = false; |
| 520 | (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, |
| 521 | &changed, &allErased); |
| 522 | Builder b(ctx); |
| 523 | getOperation()->setAttr("pattern_driver_changed" , b.getBoolAttr(value: changed)); |
| 524 | getOperation()->setAttr("pattern_driver_all_erased" , |
| 525 | b.getBoolAttr(value: allErased)); |
| 526 | } |
| 527 | |
| 528 | Option<std::string> strictMode{ |
| 529 | *this, "strictness" , |
| 530 | llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}" ), |
| 531 | llvm::cl::init("AnyOp" )}; |
| 532 | |
| 533 | private: |
| 534 | // New inserted operation is valid for further transformation. |
| 535 | class InsertSameOp : public RewritePattern { |
| 536 | public: |
| 537 | InsertSameOp(MLIRContext *context) |
| 538 | : RewritePattern("test.insert_same_op" , /*benefit=*/1, context) {} |
| 539 | |
| 540 | LogicalResult matchAndRewrite(Operation *op, |
| 541 | PatternRewriter &rewriter) const override { |
| 542 | if (op->hasAttr(name: "skip" )) |
| 543 | return failure(); |
| 544 | |
| 545 | Operation *newOp = |
| 546 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), |
| 547 | op->getOperands(), op->getResultTypes()); |
| 548 | rewriter.modifyOpInPlace( |
| 549 | root: op, callable: [&]() { op->setAttr(name: "skip" , value: rewriter.getBoolAttr(value: true)); }); |
| 550 | newOp->setAttr(name: "skip" , value: rewriter.getBoolAttr(value: true)); |
| 551 | |
| 552 | return success(); |
| 553 | } |
| 554 | }; |
| 555 | |
| 556 | // Remove an operation may introduce the re-visiting of its operands. |
| 557 | class EraseOp : public RewritePattern { |
| 558 | public: |
| 559 | EraseOp(MLIRContext *context) |
| 560 | : RewritePattern("test.erase_op" , /*benefit=*/1, context) {} |
| 561 | LogicalResult matchAndRewrite(Operation *op, |
| 562 | PatternRewriter &rewriter) const override { |
| 563 | rewriter.eraseOp(op); |
| 564 | return success(); |
| 565 | } |
| 566 | }; |
| 567 | |
| 568 | // The following two patterns test RewriterBase::replaceAllUsesWith. |
| 569 | // |
| 570 | // That function replaces all usages of a Block (or a Value) with another one |
| 571 | // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver |
| 572 | // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its |
| 573 | // worklist: when an op is modified, it is added to the worklist. The two |
| 574 | // patterns below make the tracking observable: ChangeBlockOp replaces all |
| 575 | // usages of a block and that pattern is applied because the corresponding ops |
| 576 | // are put on the initial worklist (see above). ImplicitChangeOp does an |
| 577 | // unrelated change but ops of the corresponding type are *not* on the initial |
| 578 | // worklist, so the effect of the second pattern is only visible if the |
| 579 | // tracking and subsequent adding to the worklist actually works. |
| 580 | |
| 581 | // Replace all usages of the first successor with the second successor. |
| 582 | class ChangeBlockOp : public RewritePattern { |
| 583 | public: |
| 584 | ChangeBlockOp(MLIRContext *context) |
| 585 | : RewritePattern("test.change_block_op" , /*benefit=*/1, context) {} |
| 586 | LogicalResult matchAndRewrite(Operation *op, |
| 587 | PatternRewriter &rewriter) const override { |
| 588 | if (op->getNumSuccessors() < 2) |
| 589 | return failure(); |
| 590 | Block *firstSuccessor = op->getSuccessor(index: 0); |
| 591 | Block *secondSuccessor = op->getSuccessor(index: 1); |
| 592 | if (firstSuccessor == secondSuccessor) |
| 593 | return failure(); |
| 594 | // This is the function being tested: |
| 595 | rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor); |
| 596 | // Using the following line instead would make the test fail: |
| 597 | // firstSuccessor->replaceAllUsesWith(secondSuccessor); |
| 598 | return success(); |
| 599 | } |
| 600 | }; |
| 601 | |
| 602 | // Changes the successor to the parent block. |
| 603 | class ImplicitChangeOp : public RewritePattern { |
| 604 | public: |
| 605 | ImplicitChangeOp(MLIRContext *context) |
| 606 | : RewritePattern("test.implicit_change_op" , /*benefit=*/1, context) {} |
| 607 | LogicalResult matchAndRewrite(Operation *op, |
| 608 | PatternRewriter &rewriter) const override { |
| 609 | if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock()) |
| 610 | return failure(); |
| 611 | rewriter.modifyOpInPlace(root: op, |
| 612 | callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); }); |
| 613 | return success(); |
| 614 | } |
| 615 | }; |
| 616 | }; |
| 617 | |
| 618 | struct TestWalkPatternDriver final |
| 619 | : PassWrapper<TestWalkPatternDriver, OperationPass<>> { |
| 620 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) |
| 621 | |
| 622 | TestWalkPatternDriver() = default; |
| 623 | TestWalkPatternDriver(const TestWalkPatternDriver &other) |
| 624 | : PassWrapper(other) {} |
| 625 | |
| 626 | StringRef getArgument() const override { |
| 627 | return "test-walk-pattern-rewrite-driver" ; |
| 628 | } |
| 629 | StringRef getDescription() const override { |
| 630 | return "Run test walk pattern rewrite driver" ; |
| 631 | } |
| 632 | void runOnOperation() override { |
| 633 | mlir::RewritePatternSet patterns(&getContext()); |
| 634 | |
| 635 | // Patterns for testing the WalkPatternRewriteDriver. |
| 636 | patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, |
| 637 | MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( |
| 638 | arg: &getContext()); |
| 639 | |
| 640 | DumpNotifications dumpListener; |
| 641 | walkAndApplyPatterns(getOperation(), std::move(patterns), |
| 642 | dumpNotifications ? &dumpListener : nullptr); |
| 643 | } |
| 644 | |
| 645 | Option<bool> dumpNotifications{ |
| 646 | *this, "dump-notifications" , |
| 647 | llvm::cl::desc("Print rewrite listener notifications" ), |
| 648 | llvm::cl::init(Val: false)}; |
| 649 | }; |
| 650 | |
| 651 | } // namespace |
| 652 | |
| 653 | //===----------------------------------------------------------------------===// |
| 654 | // ReturnType Driver. |
| 655 | //===----------------------------------------------------------------------===// |
| 656 | |
| 657 | namespace { |
| 658 | // Generate ops for each instance where the type can be successfully inferred. |
| 659 | template <typename OpTy> |
| 660 | static void invokeCreateWithInferredReturnType(Operation *op) { |
| 661 | auto *context = op->getContext(); |
| 662 | auto fop = op->getParentOfType<func::FuncOp>(); |
| 663 | auto location = UnknownLoc::get(context); |
| 664 | OpBuilder b(op); |
| 665 | b.setInsertionPointAfter(op); |
| 666 | |
| 667 | // Use permutations of 2 args as operands. |
| 668 | assert(fop.getNumArguments() >= 2); |
| 669 | for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { |
| 670 | for (int j = 0; j < e; ++j) { |
| 671 | std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; |
| 672 | SmallVector<Type, 2> inferredReturnTypes; |
| 673 | if (succeeded(OpTy::inferReturnTypes( |
| 674 | context, std::nullopt, values, op->getDiscardableAttrDictionary(), |
| 675 | op->getPropertiesStorage(), op->getRegions(), |
| 676 | inferredReturnTypes))) { |
| 677 | OperationState state(location, OpTy::getOperationName()); |
| 678 | // TODO: Expand to regions. |
| 679 | OpTy::build(b, state, values, op->getAttrs()); |
| 680 | (void)b.create(state); |
| 681 | } |
| 682 | } |
| 683 | } |
| 684 | } |
| 685 | |
| 686 | static void reifyReturnShape(Operation *op) { |
| 687 | OpBuilder b(op); |
| 688 | |
| 689 | // Use permutations of 2 args as operands. |
| 690 | auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); |
| 691 | SmallVector<Value, 2> shapes; |
| 692 | if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || |
| 693 | !llvm::hasSingleElement(C&: shapes)) |
| 694 | return; |
| 695 | for (const auto &it : llvm::enumerate(First&: shapes)) { |
| 696 | op->emitRemark() << "value " << it.index() << ": " |
| 697 | << it.value().getDefiningOp(); |
| 698 | } |
| 699 | } |
| 700 | |
| 701 | struct TestReturnTypeDriver |
| 702 | : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { |
| 703 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) |
| 704 | |
| 705 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 706 | registry.insert<tensor::TensorDialect>(); |
| 707 | } |
| 708 | StringRef getArgument() const final { return "test-return-type" ; } |
| 709 | StringRef getDescription() const final { return "Run return type functions" ; } |
| 710 | |
| 711 | void runOnOperation() override { |
| 712 | if (getOperation().getName() == "testCreateFunctions" ) { |
| 713 | std::vector<Operation *> ops; |
| 714 | // Collect ops to avoid triggering on inserted ops. |
| 715 | for (auto &op : getOperation().getBody().front()) |
| 716 | ops.push_back(&op); |
| 717 | // Generate test patterns for each, but skip terminator. |
| 718 | for (auto *op : llvm::ArrayRef(ops).drop_back()) { |
| 719 | // Test create method of each of the Op classes below. The resultant |
| 720 | // output would be in reverse order underneath `op` from which |
| 721 | // the attributes and regions are used. |
| 722 | invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); |
| 723 | invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( |
| 724 | op); |
| 725 | invokeCreateWithInferredReturnType< |
| 726 | OpWithShapedTypeInferTypeInterfaceOp>(op); |
| 727 | }; |
| 728 | return; |
| 729 | } |
| 730 | if (getOperation().getName() == "testReifyFunctions" ) { |
| 731 | std::vector<Operation *> ops; |
| 732 | // Collect ops to avoid triggering on inserted ops. |
| 733 | for (auto &op : getOperation().getBody().front()) |
| 734 | if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) |
| 735 | ops.push_back(&op); |
| 736 | // Generate test patterns for each, but skip terminator. |
| 737 | for (auto *op : ops) |
| 738 | reifyReturnShape(op); |
| 739 | } |
| 740 | } |
| 741 | }; |
| 742 | } // namespace |
| 743 | |
| 744 | namespace { |
| 745 | struct TestDerivedAttributeDriver |
| 746 | : public PassWrapper<TestDerivedAttributeDriver, |
| 747 | OperationPass<func::FuncOp>> { |
| 748 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) |
| 749 | |
| 750 | StringRef getArgument() const final { return "test-derived-attr" ; } |
| 751 | StringRef getDescription() const final { |
| 752 | return "Run test derived attributes" ; |
| 753 | } |
| 754 | void runOnOperation() override; |
| 755 | }; |
| 756 | } // namespace |
| 757 | |
| 758 | void TestDerivedAttributeDriver::runOnOperation() { |
| 759 | getOperation().walk([](DerivedAttributeOpInterface dOp) { |
| 760 | auto dAttr = dOp.materializeDerivedAttributes(); |
| 761 | if (!dAttr) |
| 762 | return; |
| 763 | for (auto d : dAttr) |
| 764 | dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); |
| 765 | }); |
| 766 | } |
| 767 | |
| 768 | //===----------------------------------------------------------------------===// |
| 769 | // Legalization Driver. |
| 770 | //===----------------------------------------------------------------------===// |
| 771 | |
| 772 | namespace { |
| 773 | //===----------------------------------------------------------------------===// |
| 774 | // Region-Block Rewrite Testing |
| 775 | //===----------------------------------------------------------------------===// |
| 776 | |
| 777 | /// This pattern applies a signature conversion to a block inside a detached |
| 778 | /// region. |
| 779 | struct TestDetachedSignatureConversion : public ConversionPattern { |
| 780 | TestDetachedSignatureConversion(MLIRContext *ctx) |
| 781 | : ConversionPattern("test.detached_signature_conversion" , /*benefit=*/1, |
| 782 | ctx) {} |
| 783 | |
| 784 | LogicalResult |
| 785 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 786 | ConversionPatternRewriter &rewriter) const final { |
| 787 | if (op->getNumRegions() != 1) |
| 788 | return failure(); |
| 789 | OperationState state(op->getLoc(), "test.legal_op" , operands, |
| 790 | op->getResultTypes(), {}, BlockRange()); |
| 791 | Region *newRegion = state.addRegion(); |
| 792 | rewriter.inlineRegionBefore(region&: op->getRegion(index: 0), parent&: *newRegion, |
| 793 | before: newRegion->begin()); |
| 794 | TypeConverter::SignatureConversion result(newRegion->getNumArguments()); |
| 795 | for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) |
| 796 | result.addInputs(i, rewriter.getF64Type()); |
| 797 | rewriter.applySignatureConversion(block: &newRegion->front(), conversion&: result); |
| 798 | Operation *newOp = rewriter.create(state); |
| 799 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
| 800 | return success(); |
| 801 | } |
| 802 | }; |
| 803 | |
| 804 | /// This pattern is a simple pattern that inlines the first region of a given |
| 805 | /// operation into the parent region. |
| 806 | struct TestRegionRewriteBlockMovement : public ConversionPattern { |
| 807 | TestRegionRewriteBlockMovement(MLIRContext *ctx) |
| 808 | : ConversionPattern("test.region" , 1, ctx) {} |
| 809 | |
| 810 | LogicalResult |
| 811 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 812 | ConversionPatternRewriter &rewriter) const final { |
| 813 | // Inline this region into the parent region. |
| 814 | auto &parentRegion = *op->getParentRegion(); |
| 815 | auto &opRegion = op->getRegion(index: 0); |
| 816 | if (op->getDiscardableAttr(name: "legalizer.should_clone" )) |
| 817 | rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
| 818 | else |
| 819 | rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
| 820 | |
| 821 | if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks" )) { |
| 822 | while (!opRegion.empty()) |
| 823 | rewriter.eraseBlock(block: &opRegion.front()); |
| 824 | } |
| 825 | |
| 826 | // Drop this operation. |
| 827 | rewriter.eraseOp(op); |
| 828 | return success(); |
| 829 | } |
| 830 | }; |
| 831 | /// This pattern is a simple pattern that generates a region containing an |
| 832 | /// illegal operation. |
| 833 | struct TestRegionRewriteUndo : public RewritePattern { |
| 834 | TestRegionRewriteUndo(MLIRContext *ctx) |
| 835 | : RewritePattern("test.region_builder" , 1, ctx) {} |
| 836 | |
| 837 | LogicalResult matchAndRewrite(Operation *op, |
| 838 | PatternRewriter &rewriter) const final { |
| 839 | // Create the region operation with an entry block containing arguments. |
| 840 | OperationState newRegion(op->getLoc(), "test.region" ); |
| 841 | newRegion.addRegion(); |
| 842 | auto *regionOp = rewriter.create(state: newRegion); |
| 843 | auto *entryBlock = rewriter.createBlock(parent: ®ionOp->getRegion(index: 0)); |
| 844 | entryBlock->addArgument(rewriter.getIntegerType(64), |
| 845 | rewriter.getUnknownLoc()); |
| 846 | |
| 847 | // Add an explicitly illegal operation to ensure the conversion fails. |
| 848 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); |
| 849 | rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); |
| 850 | |
| 851 | // Drop this operation. |
| 852 | rewriter.eraseOp(op); |
| 853 | return success(); |
| 854 | } |
| 855 | }; |
| 856 | /// A simple pattern that creates a block at the end of the parent region of the |
| 857 | /// matched operation. |
| 858 | struct TestCreateBlock : public RewritePattern { |
| 859 | TestCreateBlock(MLIRContext *ctx) |
| 860 | : RewritePattern("test.create_block" , /*benefit=*/1, ctx) {} |
| 861 | |
| 862 | LogicalResult matchAndRewrite(Operation *op, |
| 863 | PatternRewriter &rewriter) const final { |
| 864 | Region ®ion = *op->getParentRegion(); |
| 865 | Type i32Type = rewriter.getIntegerType(32); |
| 866 | Location loc = op->getLoc(); |
| 867 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
| 868 | rewriter.create<TerminatorOp>(loc); |
| 869 | rewriter.eraseOp(op); |
| 870 | return success(); |
| 871 | } |
| 872 | }; |
| 873 | |
| 874 | /// A simple pattern that creates a block containing an invalid operation in |
| 875 | /// order to trigger the block creation undo mechanism. |
| 876 | struct TestCreateIllegalBlock : public RewritePattern { |
| 877 | TestCreateIllegalBlock(MLIRContext *ctx) |
| 878 | : RewritePattern("test.create_illegal_block" , /*benefit=*/1, ctx) {} |
| 879 | |
| 880 | LogicalResult matchAndRewrite(Operation *op, |
| 881 | PatternRewriter &rewriter) const final { |
| 882 | Region ®ion = *op->getParentRegion(); |
| 883 | Type i32Type = rewriter.getIntegerType(32); |
| 884 | Location loc = op->getLoc(); |
| 885 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
| 886 | // Create an illegal op to ensure the conversion fails. |
| 887 | rewriter.create<ILLegalOpF>(loc, i32Type); |
| 888 | rewriter.create<TerminatorOp>(loc); |
| 889 | rewriter.eraseOp(op); |
| 890 | return success(); |
| 891 | } |
| 892 | }; |
| 893 | |
| 894 | /// A simple pattern that tests the undo mechanism when replacing the uses of a |
| 895 | /// block argument. |
| 896 | struct TestUndoBlockArgReplace : public ConversionPattern { |
| 897 | TestUndoBlockArgReplace(MLIRContext *ctx) |
| 898 | : ConversionPattern("test.undo_block_arg_replace" , /*benefit=*/1, ctx) {} |
| 899 | |
| 900 | LogicalResult |
| 901 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 902 | ConversionPatternRewriter &rewriter) const final { |
| 903 | auto illegalOp = |
| 904 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
| 905 | rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0), |
| 906 | to: illegalOp->getResult(0)); |
| 907 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
| 908 | return success(); |
| 909 | } |
| 910 | }; |
| 911 | |
| 912 | /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. |
| 913 | /// This is to test the rollback logic. |
| 914 | struct TestUndoMoveOpBefore : public ConversionPattern { |
| 915 | TestUndoMoveOpBefore(MLIRContext *ctx) |
| 916 | : ConversionPattern("test.hoist_me" , /*benefit=*/1, ctx) {} |
| 917 | |
| 918 | LogicalResult |
| 919 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 920 | ConversionPatternRewriter &rewriter) const override { |
| 921 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
| 922 | // Replace with an illegal op to ensure the conversion fails. |
| 923 | rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); |
| 924 | return success(); |
| 925 | } |
| 926 | }; |
| 927 | |
| 928 | /// A rewrite pattern that tests the undo mechanism when erasing a block. |
| 929 | struct TestUndoBlockErase : public ConversionPattern { |
| 930 | TestUndoBlockErase(MLIRContext *ctx) |
| 931 | : ConversionPattern("test.undo_block_erase" , /*benefit=*/1, ctx) {} |
| 932 | |
| 933 | LogicalResult |
| 934 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 935 | ConversionPatternRewriter &rewriter) const final { |
| 936 | Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin()); |
| 937 | rewriter.setInsertionPointToStart(secondBlock); |
| 938 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
| 939 | rewriter.eraseBlock(block: secondBlock); |
| 940 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
| 941 | return success(); |
| 942 | } |
| 943 | }; |
| 944 | |
| 945 | /// A pattern that modifies a property in-place, but keeps the op illegal. |
| 946 | struct TestUndoPropertiesModification : public ConversionPattern { |
| 947 | TestUndoPropertiesModification(MLIRContext *ctx) |
| 948 | : ConversionPattern("test.with_properties" , /*benefit=*/1, ctx) {} |
| 949 | LogicalResult |
| 950 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 951 | ConversionPatternRewriter &rewriter) const final { |
| 952 | if (!op->hasAttr(name: "modify_inplace" )) |
| 953 | return failure(); |
| 954 | rewriter.modifyOpInPlace( |
| 955 | root: op, callable: [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); |
| 956 | return success(); |
| 957 | } |
| 958 | }; |
| 959 | |
| 960 | //===----------------------------------------------------------------------===// |
| 961 | // Type-Conversion Rewrite Testing |
| 962 | //===----------------------------------------------------------------------===// |
| 963 | |
| 964 | /// This patterns erases a region operation that has had a type conversion. |
| 965 | struct TestDropOpSignatureConversion : public ConversionPattern { |
| 966 | TestDropOpSignatureConversion(MLIRContext *ctx, |
| 967 | const TypeConverter &converter) |
| 968 | : ConversionPattern(converter, "test.drop_region_op" , 1, ctx) {} |
| 969 | LogicalResult |
| 970 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 971 | ConversionPatternRewriter &rewriter) const override { |
| 972 | Region ®ion = op->getRegion(index: 0); |
| 973 | Block *entry = ®ion.front(); |
| 974 | |
| 975 | // Convert the original entry arguments. |
| 976 | const TypeConverter &converter = *getTypeConverter(); |
| 977 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
| 978 | if (failed(Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), |
| 979 | result)) || |
| 980 | failed(Result: rewriter.convertRegionTypes(region: ®ion, converter, entryConversion: &result))) |
| 981 | return failure(); |
| 982 | |
| 983 | // Convert the region signature and just drop the operation. |
| 984 | rewriter.eraseOp(op); |
| 985 | return success(); |
| 986 | } |
| 987 | }; |
| 988 | /// This pattern simply updates the operands of the given operation. |
| 989 | struct TestPassthroughInvalidOp : public ConversionPattern { |
| 990 | TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) |
| 991 | : ConversionPattern(converter, "test.invalid" , 1, ctx) {} |
| 992 | LogicalResult |
| 993 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
| 994 | ConversionPatternRewriter &rewriter) const final { |
| 995 | SmallVector<Value> flattened; |
| 996 | for (auto it : llvm::enumerate(First&: operands)) { |
| 997 | ValueRange range = it.value(); |
| 998 | if (range.size() == 1) { |
| 999 | flattened.push_back(Elt: range.front()); |
| 1000 | continue; |
| 1001 | } |
| 1002 | |
| 1003 | // This is a 1:N replacement. Insert a test.cast op. (That's what the |
| 1004 | // argument materialization used to do.) |
| 1005 | flattened.push_back( |
| 1006 | rewriter |
| 1007 | .create<TestCastOp>(op->getLoc(), |
| 1008 | op->getOperand(it.index()).getType(), range) |
| 1009 | .getResult()); |
| 1010 | } |
| 1011 | rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, |
| 1012 | std::nullopt); |
| 1013 | return success(); |
| 1014 | } |
| 1015 | }; |
| 1016 | /// Replace with valid op, but simply drop the operands. This is used in a |
| 1017 | /// regression where we used to generate circular unrealized_conversion_cast |
| 1018 | /// ops. |
| 1019 | struct TestDropAndReplaceInvalidOp : public ConversionPattern { |
| 1020 | TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) |
| 1021 | : ConversionPattern(converter, |
| 1022 | "test.drop_operands_and_replace_with_valid" , 1, ctx) { |
| 1023 | } |
| 1024 | LogicalResult |
| 1025 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1026 | ConversionPatternRewriter &rewriter) const final { |
| 1027 | rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), |
| 1028 | std::nullopt); |
| 1029 | return success(); |
| 1030 | } |
| 1031 | }; |
| 1032 | /// This pattern handles the case of a split return value. |
| 1033 | struct TestSplitReturnType : public ConversionPattern { |
| 1034 | TestSplitReturnType(MLIRContext *ctx) |
| 1035 | : ConversionPattern("test.return" , 1, ctx) {} |
| 1036 | LogicalResult |
| 1037 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
| 1038 | ConversionPatternRewriter &rewriter) const final { |
| 1039 | // Check for a return of F32. |
| 1040 | if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32()) |
| 1041 | return failure(); |
| 1042 | rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); |
| 1043 | return success(); |
| 1044 | } |
| 1045 | }; |
| 1046 | |
| 1047 | //===----------------------------------------------------------------------===// |
| 1048 | // Multi-Level Type-Conversion Rewrite Testing |
| 1049 | struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { |
| 1050 | TestChangeProducerTypeI32ToF32(MLIRContext *ctx) |
| 1051 | : ConversionPattern("test.type_producer" , 1, ctx) {} |
| 1052 | LogicalResult |
| 1053 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1054 | ConversionPatternRewriter &rewriter) const final { |
| 1055 | // If the type is I32, change the type to F32. |
| 1056 | if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32)) |
| 1057 | return failure(); |
| 1058 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); |
| 1059 | return success(); |
| 1060 | } |
| 1061 | }; |
| 1062 | struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { |
| 1063 | TestChangeProducerTypeF32ToF64(MLIRContext *ctx) |
| 1064 | : ConversionPattern("test.type_producer" , 1, ctx) {} |
| 1065 | LogicalResult |
| 1066 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1067 | ConversionPatternRewriter &rewriter) const final { |
| 1068 | // If the type is F32, change the type to F64. |
| 1069 | if (!Type(*op->result_type_begin()).isF32()) |
| 1070 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand" ); |
| 1071 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); |
| 1072 | return success(); |
| 1073 | } |
| 1074 | }; |
| 1075 | struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { |
| 1076 | TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) |
| 1077 | : ConversionPattern("test.type_producer" , 10, ctx) {} |
| 1078 | LogicalResult |
| 1079 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1080 | ConversionPatternRewriter &rewriter) const final { |
| 1081 | // Always convert to B16, even though it is not a legal type. This tests |
| 1082 | // that values are unmapped correctly. |
| 1083 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); |
| 1084 | return success(); |
| 1085 | } |
| 1086 | }; |
| 1087 | struct TestUpdateConsumerType : public ConversionPattern { |
| 1088 | TestUpdateConsumerType(MLIRContext *ctx) |
| 1089 | : ConversionPattern("test.type_consumer" , 1, ctx) {} |
| 1090 | LogicalResult |
| 1091 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1092 | ConversionPatternRewriter &rewriter) const final { |
| 1093 | // Verify that the incoming operand has been successfully remapped to F64. |
| 1094 | if (!operands[0].getType().isF64()) |
| 1095 | return failure(); |
| 1096 | rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); |
| 1097 | return success(); |
| 1098 | } |
| 1099 | }; |
| 1100 | |
| 1101 | //===----------------------------------------------------------------------===// |
| 1102 | // Non-Root Replacement Rewrite Testing |
| 1103 | /// This pattern generates an invalid operation, but replaces it before the |
| 1104 | /// pattern is finished. This checks that we don't need to legalize the |
| 1105 | /// temporary op. |
| 1106 | struct TestNonRootReplacement : public RewritePattern { |
| 1107 | TestNonRootReplacement(MLIRContext *ctx) |
| 1108 | : RewritePattern("test.replace_non_root" , 1, ctx) {} |
| 1109 | |
| 1110 | LogicalResult matchAndRewrite(Operation *op, |
| 1111 | PatternRewriter &rewriter) const final { |
| 1112 | auto resultType = *op->result_type_begin(); |
| 1113 | auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); |
| 1114 | auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); |
| 1115 | |
| 1116 | rewriter.replaceOp(illegalOp, legalOp); |
| 1117 | rewriter.replaceOp(op, illegalOp); |
| 1118 | return success(); |
| 1119 | } |
| 1120 | }; |
| 1121 | |
| 1122 | //===----------------------------------------------------------------------===// |
| 1123 | // Recursive Rewrite Testing |
| 1124 | /// This pattern is applied to the same operation multiple times, but has a |
| 1125 | /// bounded recursion. |
| 1126 | struct TestBoundedRecursiveRewrite |
| 1127 | : public OpRewritePattern<TestRecursiveRewriteOp> { |
| 1128 | using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; |
| 1129 | |
| 1130 | void initialize() { |
| 1131 | // The conversion target handles bounding the recursion of this pattern. |
| 1132 | setHasBoundedRewriteRecursion(); |
| 1133 | } |
| 1134 | |
| 1135 | LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, |
| 1136 | PatternRewriter &rewriter) const final { |
| 1137 | // Decrement the depth of the op in-place. |
| 1138 | rewriter.modifyOpInPlace(op, [&] { |
| 1139 | op->setAttr("depth" , rewriter.getI64IntegerAttr(value: op.getDepth() - 1)); |
| 1140 | }); |
| 1141 | return success(); |
| 1142 | } |
| 1143 | }; |
| 1144 | |
| 1145 | struct TestNestedOpCreationUndoRewrite |
| 1146 | : public OpRewritePattern<IllegalOpWithRegionAnchor> { |
| 1147 | using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; |
| 1148 | |
| 1149 | LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, |
| 1150 | PatternRewriter &rewriter) const final { |
| 1151 | // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
| 1152 | rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
| 1153 | return success(); |
| 1154 | }; |
| 1155 | }; |
| 1156 | |
| 1157 | // This pattern matches `test.blackhole` and delete this op and its producer. |
| 1158 | struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { |
| 1159 | using OpRewritePattern<BlackHoleOp>::OpRewritePattern; |
| 1160 | |
| 1161 | LogicalResult matchAndRewrite(BlackHoleOp op, |
| 1162 | PatternRewriter &rewriter) const final { |
| 1163 | Operation *producer = op.getOperand().getDefiningOp(); |
| 1164 | // Always erase the user before the producer, the framework should handle |
| 1165 | // this correctly. |
| 1166 | rewriter.eraseOp(op: op); |
| 1167 | rewriter.eraseOp(op: producer); |
| 1168 | return success(); |
| 1169 | }; |
| 1170 | }; |
| 1171 | |
| 1172 | // This pattern replaces explicitly illegal op with explicitly legal op, |
| 1173 | // but in addition creates unregistered operation. |
| 1174 | struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { |
| 1175 | using OpRewritePattern<ILLegalOpG>::OpRewritePattern; |
| 1176 | |
| 1177 | LogicalResult matchAndRewrite(ILLegalOpG op, |
| 1178 | PatternRewriter &rewriter) const final { |
| 1179 | IntegerAttr attr = rewriter.getI32IntegerAttr(0); |
| 1180 | Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); |
| 1181 | rewriter.replaceOpWithNewOp<LegalOpC>(op, val); |
| 1182 | return success(); |
| 1183 | }; |
| 1184 | }; |
| 1185 | |
| 1186 | class TestEraseOp : public ConversionPattern { |
| 1187 | public: |
| 1188 | TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op" , 1, ctx) {} |
| 1189 | LogicalResult |
| 1190 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1191 | ConversionPatternRewriter &rewriter) const final { |
| 1192 | // Erase op without replacements. |
| 1193 | rewriter.eraseOp(op); |
| 1194 | return success(); |
| 1195 | } |
| 1196 | }; |
| 1197 | |
| 1198 | /// This pattern matches a test.convert_block_args op. It either: |
| 1199 | /// a) Duplicates all block arguments, |
| 1200 | /// b) or: drops all block arguments and replaces each with 2x the first |
| 1201 | /// operand. |
| 1202 | class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> { |
| 1203 | using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern; |
| 1204 | |
| 1205 | LogicalResult |
| 1206 | matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor, |
| 1207 | ConversionPatternRewriter &rewriter) const override { |
| 1208 | if (op.getIsLegal()) |
| 1209 | return failure(); |
| 1210 | Block *body = &op.getBody().front(); |
| 1211 | TypeConverter::SignatureConversion result(body->getNumArguments()); |
| 1212 | for (auto it : llvm::enumerate(body->getArgumentTypes())) { |
| 1213 | if (op.getReplaceWithOperand()) { |
| 1214 | result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()}); |
| 1215 | } else if (op.getDuplicate()) { |
| 1216 | result.addInputs(it.index(), {it.value(), it.value()}); |
| 1217 | } else { |
| 1218 | // No action specified. Pattern does not apply. |
| 1219 | return failure(); |
| 1220 | } |
| 1221 | } |
| 1222 | rewriter.startOpModification(op: op); |
| 1223 | rewriter.applySignatureConversion(block: body, conversion&: result, converter: getTypeConverter()); |
| 1224 | op.setIsLegal(true); |
| 1225 | rewriter.finalizeOpModification(op: op); |
| 1226 | return success(); |
| 1227 | } |
| 1228 | }; |
| 1229 | |
| 1230 | /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid |
| 1231 | /// op. The pattern supports 1:N replacements and forwards the replacement |
| 1232 | /// values of the single operand as test.valid operands. |
| 1233 | class TestRepetitive1ToNConsumer : public ConversionPattern { |
| 1234 | public: |
| 1235 | TestRepetitive1ToNConsumer(MLIRContext *ctx) |
| 1236 | : ConversionPattern("test.repetitive_1_to_n_consumer" , 1, ctx) {} |
| 1237 | LogicalResult |
| 1238 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
| 1239 | ConversionPatternRewriter &rewriter) const final { |
| 1240 | // A single operand is expected. |
| 1241 | if (op->getNumOperands() != 1) |
| 1242 | return failure(); |
| 1243 | rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); |
| 1244 | return success(); |
| 1245 | } |
| 1246 | }; |
| 1247 | |
| 1248 | /// A pattern that tests two back-to-back 1 -> 2 op replacements. |
| 1249 | class TestMultiple1ToNReplacement : public ConversionPattern { |
| 1250 | public: |
| 1251 | TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) |
| 1252 | : ConversionPattern(converter, "test.multiple_1_to_n_replacement" , 1, |
| 1253 | ctx) {} |
| 1254 | LogicalResult |
| 1255 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
| 1256 | ConversionPatternRewriter &rewriter) const final { |
| 1257 | // Helper function that replaces the given op with a new op of the given |
| 1258 | // name and doubles each result (1 -> 2 replacement of each result). |
| 1259 | auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { |
| 1260 | SmallVector<Type> types; |
| 1261 | for (Type t : op->getResultTypes()) { |
| 1262 | types.push_back(Elt: t); |
| 1263 | types.push_back(Elt: t); |
| 1264 | } |
| 1265 | OperationState state(op->getLoc(), name, |
| 1266 | /*operands=*/{}, types, op->getAttrs()); |
| 1267 | auto *newOp = rewriter.create(state); |
| 1268 | SmallVector<ValueRange> repls; |
| 1269 | for (size_t i = 0, e = op->getNumResults(); i < e; ++i) |
| 1270 | repls.push_back(Elt: newOp->getResults().slice(n: 2 * i, m: 2)); |
| 1271 | rewriter.replaceOpWithMultiple(op, newValues&: repls); |
| 1272 | return newOp; |
| 1273 | }; |
| 1274 | |
| 1275 | // Replace test.multiple_1_to_n_replacement with test.step_1. |
| 1276 | Operation *repl1 = replaceWithDoubleResults(op, "test.step_1" ); |
| 1277 | // Now replace test.step_1 with test.legal_op. |
| 1278 | replaceWithDoubleResults(repl1, "test.legal_op" ); |
| 1279 | return success(); |
| 1280 | } |
| 1281 | }; |
| 1282 | |
| 1283 | /// Test unambiguous overload resolution of replaceOpWithMultiple. This |
| 1284 | /// function is just to trigger compiler errors. It is never executed. |
| 1285 | [[maybe_unused]] void testReplaceOpWithMultipleOverloads( |
| 1286 | ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1, |
| 1287 | SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3, |
| 1288 | SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5, |
| 1289 | SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7, |
| 1290 | Value v, ValueRange vr, ArrayRef<Value> ar) { |
| 1291 | rewriter.replaceOpWithMultiple(op, newValues: r1); |
| 1292 | rewriter.replaceOpWithMultiple(op, newValues&: r2); |
| 1293 | rewriter.replaceOpWithMultiple(op, newValues: r3); |
| 1294 | rewriter.replaceOpWithMultiple(op, newValues&: r4); |
| 1295 | rewriter.replaceOpWithMultiple(op, newValues: r5); |
| 1296 | rewriter.replaceOpWithMultiple(op, newValues&: r6); |
| 1297 | rewriter.replaceOpWithMultiple(op, newValues: std::move(r7)); |
| 1298 | rewriter.replaceOpWithMultiple(op, newValues: {vr}); |
| 1299 | rewriter.replaceOpWithMultiple(op, newValues: {ar}); |
| 1300 | rewriter.replaceOpWithMultiple(op, newValues: {{v}}); |
| 1301 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}}); |
| 1302 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, vr}); |
| 1303 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, ar}); |
| 1304 | rewriter.replaceOpWithMultiple(op, newValues: {ar, {v, v}, vr}); |
| 1305 | } |
| 1306 | } // namespace |
| 1307 | |
| 1308 | namespace { |
| 1309 | struct TestTypeConverter : public TypeConverter { |
| 1310 | using TypeConverter::TypeConverter; |
| 1311 | TestTypeConverter() { |
| 1312 | addConversion(callback&: convertType); |
| 1313 | addSourceMaterialization(callback&: materializeCast); |
| 1314 | } |
| 1315 | |
| 1316 | static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { |
| 1317 | // Drop I16 types. |
| 1318 | if (t.isSignlessInteger(width: 16)) |
| 1319 | return success(); |
| 1320 | |
| 1321 | // Convert I64 to F64. |
| 1322 | if (t.isSignlessInteger(width: 64)) { |
| 1323 | results.push_back(Float64Type::get(t.getContext())); |
| 1324 | return success(); |
| 1325 | } |
| 1326 | |
| 1327 | // Convert I42 to I43. |
| 1328 | if (t.isInteger(width: 42)) { |
| 1329 | results.push_back(IntegerType::get(t.getContext(), 43)); |
| 1330 | return success(); |
| 1331 | } |
| 1332 | |
| 1333 | // Split F32 into F16,F16. |
| 1334 | if (t.isF32()) { |
| 1335 | results.assign(2, Float16Type::get(t.getContext())); |
| 1336 | return success(); |
| 1337 | } |
| 1338 | |
| 1339 | // Drop I24 types. |
| 1340 | if (t.isInteger(width: 24)) { |
| 1341 | return success(); |
| 1342 | } |
| 1343 | |
| 1344 | // Otherwise, convert the type directly. |
| 1345 | results.push_back(Elt: t); |
| 1346 | return success(); |
| 1347 | } |
| 1348 | |
| 1349 | /// Hook for materializing a conversion. This is necessary because we generate |
| 1350 | /// 1->N type mappings. |
| 1351 | static Value materializeCast(OpBuilder &builder, Type resultType, |
| 1352 | ValueRange inputs, Location loc) { |
| 1353 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
| 1354 | } |
| 1355 | }; |
| 1356 | |
| 1357 | struct TestLegalizePatternDriver |
| 1358 | : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { |
| 1359 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) |
| 1360 | |
| 1361 | StringRef getArgument() const final { return "test-legalize-patterns" ; } |
| 1362 | StringRef getDescription() const final { |
| 1363 | return "Run test dialect legalization patterns" ; |
| 1364 | } |
| 1365 | /// The mode of conversion to use with the driver. |
| 1366 | enum class ConversionMode { Analysis, Full, Partial }; |
| 1367 | |
| 1368 | TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} |
| 1369 | |
| 1370 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 1371 | registry.insert<func::FuncDialect, test::TestDialect>(); |
| 1372 | } |
| 1373 | |
| 1374 | void runOnOperation() override { |
| 1375 | TestTypeConverter converter; |
| 1376 | mlir::RewritePatternSet patterns(&getContext()); |
| 1377 | populateWithGenerated(patterns); |
| 1378 | patterns |
| 1379 | .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, |
| 1380 | TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, |
| 1381 | TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, |
| 1382 | TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, |
| 1383 | TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, |
| 1384 | TestNonRootReplacement, TestBoundedRecursiveRewrite, |
| 1385 | TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, |
| 1386 | TestCreateUnregisteredOp, TestUndoMoveOpBefore, |
| 1387 | TestUndoPropertiesModification, TestEraseOp, |
| 1388 | TestRepetitive1ToNConsumer>(arg: &getContext()); |
| 1389 | patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, |
| 1390 | TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( |
| 1391 | arg: &getContext(), args&: converter); |
| 1392 | patterns.add<TestConvertBlockArgs>(arg&: converter, args: &getContext()); |
| 1393 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
| 1394 | converter); |
| 1395 | mlir::populateCallOpTypeConversionPattern(patterns, converter); |
| 1396 | |
| 1397 | // Define the conversion target used for the test. |
| 1398 | ConversionTarget target(getContext()); |
| 1399 | target.addLegalOp<ModuleOp>(); |
| 1400 | target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, |
| 1401 | TerminatorOp, OneRegionOp>(); |
| 1402 | target.addLegalOp(op: OperationName("test.legal_op" , &getContext())); |
| 1403 | target |
| 1404 | .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); |
| 1405 | target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { |
| 1406 | // Don't allow F32 operands. |
| 1407 | return llvm::none_of(op.getOperandTypes(), |
| 1408 | [](Type type) { return type.isF32(); }); |
| 1409 | }); |
| 1410 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
| 1411 | return converter.isSignatureLegal(op.getFunctionType()) && |
| 1412 | converter.isLegal(&op.getBody()); |
| 1413 | }); |
| 1414 | target.addDynamicallyLegalOp<func::CallOp>( |
| 1415 | [&](func::CallOp op) { return converter.isLegal(op); }); |
| 1416 | |
| 1417 | // TestCreateUnregisteredOp creates `arith.constant` operation, |
| 1418 | // which was not added to target intentionally to test |
| 1419 | // correct error code from conversion driver. |
| 1420 | target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); |
| 1421 | |
| 1422 | // Expect the type_producer/type_consumer operations to only operate on f64. |
| 1423 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
| 1424 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
| 1425 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
| 1426 | return op.getOperand().getType().isF64(); |
| 1427 | }); |
| 1428 | |
| 1429 | // Check support for marking certain operations as recursively legal. |
| 1430 | target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { |
| 1431 | return static_cast<bool>( |
| 1432 | op->getAttrOfType<UnitAttr>("test.recursively_legal" )); |
| 1433 | }); |
| 1434 | |
| 1435 | // Mark the bound recursion operation as dynamically legal. |
| 1436 | target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( |
| 1437 | [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); |
| 1438 | |
| 1439 | // Create a dynamically legal rule that can only be legalized by folding it. |
| 1440 | target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( |
| 1441 | [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); |
| 1442 | |
| 1443 | target.addDynamicallyLegalOp<ConvertBlockArgsOp>( |
| 1444 | [](ConvertBlockArgsOp op) { return op.getIsLegal(); }); |
| 1445 | |
| 1446 | // Handle a partial conversion. |
| 1447 | if (mode == ConversionMode::Partial) { |
| 1448 | DenseSet<Operation *> unlegalizedOps; |
| 1449 | ConversionConfig config; |
| 1450 | DumpNotifications dumpNotifications; |
| 1451 | config.listener = &dumpNotifications; |
| 1452 | config.unlegalizedOps = &unlegalizedOps; |
| 1453 | if (failed(applyPartialConversion(getOperation(), target, |
| 1454 | std::move(patterns), config))) { |
| 1455 | getOperation()->emitRemark() << "applyPartialConversion failed" ; |
| 1456 | } |
| 1457 | // Emit remarks for each legalizable operation. |
| 1458 | for (auto *op : unlegalizedOps) |
| 1459 | op->emitRemark() << "op '" << op->getName() << "' is not legalizable" ; |
| 1460 | return; |
| 1461 | } |
| 1462 | |
| 1463 | // Handle a full conversion. |
| 1464 | if (mode == ConversionMode::Full) { |
| 1465 | // Check support for marking unknown operations as dynamically legal. |
| 1466 | target.markUnknownOpDynamicallyLegal([](Operation *op) { |
| 1467 | return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal" ); |
| 1468 | }); |
| 1469 | |
| 1470 | ConversionConfig config; |
| 1471 | DumpNotifications dumpNotifications; |
| 1472 | config.listener = &dumpNotifications; |
| 1473 | if (failed(applyFullConversion(getOperation(), target, |
| 1474 | std::move(patterns), config))) { |
| 1475 | getOperation()->emitRemark() << "applyFullConversion failed" ; |
| 1476 | } |
| 1477 | return; |
| 1478 | } |
| 1479 | |
| 1480 | // Otherwise, handle an analysis conversion. |
| 1481 | assert(mode == ConversionMode::Analysis); |
| 1482 | |
| 1483 | // Analyze the convertible operations. |
| 1484 | DenseSet<Operation *> legalizedOps; |
| 1485 | ConversionConfig config; |
| 1486 | config.legalizableOps = &legalizedOps; |
| 1487 | if (failed(applyAnalysisConversion(getOperation(), target, |
| 1488 | std::move(patterns), config))) |
| 1489 | return signalPassFailure(); |
| 1490 | |
| 1491 | // Emit remarks for each legalizable operation. |
| 1492 | for (auto *op : legalizedOps) |
| 1493 | op->emitRemark() << "op '" << op->getName() << "' is legalizable" ; |
| 1494 | } |
| 1495 | |
| 1496 | /// The mode of conversion to use. |
| 1497 | ConversionMode mode; |
| 1498 | }; |
| 1499 | } // namespace |
| 1500 | |
| 1501 | static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> |
| 1502 | legalizerConversionMode( |
| 1503 | "test-legalize-mode" , |
| 1504 | llvm::cl::desc("The legalization mode to use with the test driver" ), |
| 1505 | llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial), |
| 1506 | llvm::cl::values( |
| 1507 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, |
| 1508 | "analysis" , "Perform an analysis conversion" ), |
| 1509 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full" , |
| 1510 | "Perform a full conversion" ), |
| 1511 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, |
| 1512 | "partial" , "Perform a partial conversion" ))); |
| 1513 | |
| 1514 | //===----------------------------------------------------------------------===// |
| 1515 | // ConversionPatternRewriter::getRemappedValue testing. This method is used |
| 1516 | // to get the remapped value of an original value that was replaced using |
| 1517 | // ConversionPatternRewriter. |
| 1518 | namespace { |
| 1519 | struct TestRemapValueTypeConverter : public TypeConverter { |
| 1520 | using TypeConverter::TypeConverter; |
| 1521 | |
| 1522 | TestRemapValueTypeConverter() { |
| 1523 | addConversion( |
| 1524 | callback: [](Float32Type type) { return Float64Type::get(type.getContext()); }); |
| 1525 | addConversion(callback: [](Type type) { return type; }); |
| 1526 | } |
| 1527 | }; |
| 1528 | |
| 1529 | /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with |
| 1530 | /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original |
| 1531 | /// operand twice. |
| 1532 | /// |
| 1533 | /// Example: |
| 1534 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0) |
| 1535 | /// is replaced with: |
| 1536 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) |
| 1537 | struct OneVResOneVOperandOp1Converter |
| 1538 | : public OpConversionPattern<OneVResOneVOperandOp1> { |
| 1539 | using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; |
| 1540 | |
| 1541 | LogicalResult |
| 1542 | matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, |
| 1543 | ConversionPatternRewriter &rewriter) const override { |
| 1544 | auto origOps = op.getOperands(); |
| 1545 | assert(std::distance(origOps.begin(), origOps.end()) == 1 && |
| 1546 | "One operand expected" ); |
| 1547 | Value origOp = *origOps.begin(); |
| 1548 | SmallVector<Value, 2> remappedOperands; |
| 1549 | // Replicate the remapped original operand twice. Note that we don't used |
| 1550 | // the remapped 'operand' since the goal is testing 'getRemappedValue'. |
| 1551 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
| 1552 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
| 1553 | |
| 1554 | rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), |
| 1555 | remappedOperands); |
| 1556 | return success(); |
| 1557 | } |
| 1558 | }; |
| 1559 | |
| 1560 | /// A rewriter pattern that tests that blocks can be merged. |
| 1561 | struct TestRemapValueInRegion |
| 1562 | : public OpConversionPattern<TestRemappedValueRegionOp> { |
| 1563 | using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; |
| 1564 | |
| 1565 | LogicalResult |
| 1566 | matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, |
| 1567 | ConversionPatternRewriter &rewriter) const final { |
| 1568 | Block &block = op.getBody().front(); |
| 1569 | Operation *terminator = block.getTerminator(); |
| 1570 | |
| 1571 | // Merge the block into the parent region. |
| 1572 | Block *parentBlock = op->getBlock(); |
| 1573 | Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator()); |
| 1574 | rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange()); |
| 1575 | rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange()); |
| 1576 | |
| 1577 | // Replace the results of this operation with the remapped terminator |
| 1578 | // values. |
| 1579 | SmallVector<Value> terminatorOperands; |
| 1580 | if (failed(Result: rewriter.getRemappedValues(keys: terminator->getOperands(), |
| 1581 | results&: terminatorOperands))) |
| 1582 | return failure(); |
| 1583 | |
| 1584 | rewriter.eraseOp(op: terminator); |
| 1585 | rewriter.replaceOp(op, terminatorOperands); |
| 1586 | return success(); |
| 1587 | } |
| 1588 | }; |
| 1589 | |
| 1590 | struct TestRemappedValue |
| 1591 | : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { |
| 1592 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) |
| 1593 | |
| 1594 | StringRef getArgument() const final { return "test-remapped-value" ; } |
| 1595 | StringRef getDescription() const final { |
| 1596 | return "Test public remapped value mechanism in ConversionPatternRewriter" ; |
| 1597 | } |
| 1598 | void runOnOperation() override { |
| 1599 | TestRemapValueTypeConverter typeConverter; |
| 1600 | |
| 1601 | mlir::RewritePatternSet patterns(&getContext()); |
| 1602 | patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext()); |
| 1603 | patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( |
| 1604 | arg: &getContext()); |
| 1605 | patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext()); |
| 1606 | |
| 1607 | mlir::ConversionTarget target(getContext()); |
| 1608 | target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); |
| 1609 | |
| 1610 | // Expect the type_producer/type_consumer operations to only operate on f64. |
| 1611 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
| 1612 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
| 1613 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
| 1614 | return op.getOperand().getType().isF64(); |
| 1615 | }); |
| 1616 | |
| 1617 | // We make OneVResOneVOperandOp1 legal only when it has more that one |
| 1618 | // operand. This will trigger the conversion that will replace one-operand |
| 1619 | // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. |
| 1620 | target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( |
| 1621 | [](Operation *op) { return op->getNumOperands() > 1; }); |
| 1622 | |
| 1623 | if (failed(mlir::applyFullConversion(getOperation(), target, |
| 1624 | std::move(patterns)))) { |
| 1625 | signalPassFailure(); |
| 1626 | } |
| 1627 | } |
| 1628 | }; |
| 1629 | } // namespace |
| 1630 | |
| 1631 | //===----------------------------------------------------------------------===// |
| 1632 | // Test patterns without a specific root operation kind |
| 1633 | //===----------------------------------------------------------------------===// |
| 1634 | |
| 1635 | namespace { |
| 1636 | /// This pattern matches and removes any operation in the test dialect. |
| 1637 | struct RemoveTestDialectOps : public RewritePattern { |
| 1638 | RemoveTestDialectOps(MLIRContext *context) |
| 1639 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
| 1640 | |
| 1641 | LogicalResult matchAndRewrite(Operation *op, |
| 1642 | PatternRewriter &rewriter) const override { |
| 1643 | if (!isa<TestDialect>(Val: op->getDialect())) |
| 1644 | return failure(); |
| 1645 | rewriter.eraseOp(op); |
| 1646 | return success(); |
| 1647 | } |
| 1648 | }; |
| 1649 | |
| 1650 | struct TestUnknownRootOpDriver |
| 1651 | : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { |
| 1652 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) |
| 1653 | |
| 1654 | StringRef getArgument() const final { |
| 1655 | return "test-legalize-unknown-root-patterns" ; |
| 1656 | } |
| 1657 | StringRef getDescription() const final { |
| 1658 | return "Test public remapped value mechanism in ConversionPatternRewriter" ; |
| 1659 | } |
| 1660 | void runOnOperation() override { |
| 1661 | mlir::RewritePatternSet patterns(&getContext()); |
| 1662 | patterns.add<RemoveTestDialectOps>(arg: &getContext()); |
| 1663 | |
| 1664 | mlir::ConversionTarget target(getContext()); |
| 1665 | target.addIllegalDialect<TestDialect>(); |
| 1666 | if (failed(applyPartialConversion(getOperation(), target, |
| 1667 | std::move(patterns)))) |
| 1668 | signalPassFailure(); |
| 1669 | } |
| 1670 | }; |
| 1671 | } // namespace |
| 1672 | |
| 1673 | //===----------------------------------------------------------------------===// |
| 1674 | // Test patterns that uses operations and types defined at runtime |
| 1675 | //===----------------------------------------------------------------------===// |
| 1676 | |
| 1677 | namespace { |
| 1678 | /// This pattern matches dynamic operations 'test.one_operand_two_results' and |
| 1679 | /// replace them with dynamic operations 'test.generic_dynamic_op'. |
| 1680 | struct RewriteDynamicOp : public RewritePattern { |
| 1681 | RewriteDynamicOp(MLIRContext *context) |
| 1682 | : RewritePattern("test.dynamic_one_operand_two_results" , /*benefit=*/1, |
| 1683 | context) {} |
| 1684 | |
| 1685 | LogicalResult matchAndRewrite(Operation *op, |
| 1686 | PatternRewriter &rewriter) const override { |
| 1687 | assert(op->getName().getStringRef() == |
| 1688 | "test.dynamic_one_operand_two_results" && |
| 1689 | "rewrite pattern should only match operations with the right name" ); |
| 1690 | |
| 1691 | OperationState state(op->getLoc(), "test.dynamic_generic" , |
| 1692 | op->getOperands(), op->getResultTypes(), |
| 1693 | op->getAttrs()); |
| 1694 | auto *newOp = rewriter.create(state); |
| 1695 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
| 1696 | return success(); |
| 1697 | } |
| 1698 | }; |
| 1699 | |
| 1700 | struct TestRewriteDynamicOpDriver |
| 1701 | : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { |
| 1702 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) |
| 1703 | |
| 1704 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 1705 | registry.insert<TestDialect>(); |
| 1706 | } |
| 1707 | StringRef getArgument() const final { return "test-rewrite-dynamic-op" ; } |
| 1708 | StringRef getDescription() const final { |
| 1709 | return "Test rewritting on dynamic operations" ; |
| 1710 | } |
| 1711 | void runOnOperation() override { |
| 1712 | RewritePatternSet patterns(&getContext()); |
| 1713 | patterns.add<RewriteDynamicOp>(arg: &getContext()); |
| 1714 | |
| 1715 | ConversionTarget target(getContext()); |
| 1716 | target.addIllegalOp( |
| 1717 | op: OperationName("test.dynamic_one_operand_two_results" , &getContext())); |
| 1718 | target.addLegalOp(op: OperationName("test.dynamic_generic" , &getContext())); |
| 1719 | if (failed(applyPartialConversion(getOperation(), target, |
| 1720 | std::move(patterns)))) |
| 1721 | signalPassFailure(); |
| 1722 | } |
| 1723 | }; |
| 1724 | } // end anonymous namespace |
| 1725 | |
| 1726 | //===----------------------------------------------------------------------===// |
| 1727 | // Test type conversions |
| 1728 | //===----------------------------------------------------------------------===// |
| 1729 | |
| 1730 | namespace { |
| 1731 | struct TestTypeConversionProducer |
| 1732 | : public OpConversionPattern<TestTypeProducerOp> { |
| 1733 | using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; |
| 1734 | LogicalResult |
| 1735 | matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, |
| 1736 | ConversionPatternRewriter &rewriter) const final { |
| 1737 | Type resultType = op.getType(); |
| 1738 | Type convertedType = getTypeConverter() |
| 1739 | ? getTypeConverter()->convertType(resultType) |
| 1740 | : resultType; |
| 1741 | if (isa<FloatType>(Val: resultType)) |
| 1742 | resultType = rewriter.getF64Type(); |
| 1743 | else if (resultType.isInteger(width: 16)) |
| 1744 | resultType = rewriter.getIntegerType(64); |
| 1745 | else if (isa<test::TestRecursiveType>(Val: resultType) && |
| 1746 | convertedType != resultType) |
| 1747 | resultType = convertedType; |
| 1748 | else |
| 1749 | return failure(); |
| 1750 | |
| 1751 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); |
| 1752 | return success(); |
| 1753 | } |
| 1754 | }; |
| 1755 | |
| 1756 | /// Call signature conversion and then fail the rewrite to trigger the undo |
| 1757 | /// mechanism. |
| 1758 | struct TestSignatureConversionUndo |
| 1759 | : public OpConversionPattern<TestSignatureConversionUndoOp> { |
| 1760 | using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; |
| 1761 | |
| 1762 | LogicalResult |
| 1763 | matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, |
| 1764 | ConversionPatternRewriter &rewriter) const final { |
| 1765 | (void)rewriter.convertRegionTypes(region: &op->getRegion(0), converter: *getTypeConverter()); |
| 1766 | return failure(); |
| 1767 | } |
| 1768 | }; |
| 1769 | |
| 1770 | /// Call signature conversion without providing a type converter to handle |
| 1771 | /// materializations. |
| 1772 | struct TestTestSignatureConversionNoConverter |
| 1773 | : public OpConversionPattern<TestSignatureConversionNoConverterOp> { |
| 1774 | TestTestSignatureConversionNoConverter(const TypeConverter &converter, |
| 1775 | MLIRContext *context) |
| 1776 | : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), |
| 1777 | converter(converter) {} |
| 1778 | |
| 1779 | LogicalResult |
| 1780 | matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, |
| 1781 | ConversionPatternRewriter &rewriter) const final { |
| 1782 | Region ®ion = op->getRegion(0); |
| 1783 | Block *entry = ®ion.front(); |
| 1784 | |
| 1785 | // Convert the original entry arguments. |
| 1786 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
| 1787 | if (failed( |
| 1788 | Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result))) |
| 1789 | return failure(); |
| 1790 | rewriter.modifyOpInPlace(op, [&] { |
| 1791 | rewriter.applySignatureConversion(block: ®ion.front(), conversion&: result); |
| 1792 | }); |
| 1793 | return success(); |
| 1794 | } |
| 1795 | |
| 1796 | const TypeConverter &converter; |
| 1797 | }; |
| 1798 | |
| 1799 | /// Just forward the operands to the root op. This is essentially a no-op |
| 1800 | /// pattern that is used to trigger target materialization. |
| 1801 | struct TestTypeConsumerForward |
| 1802 | : public OpConversionPattern<TestTypeConsumerOp> { |
| 1803 | using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; |
| 1804 | |
| 1805 | LogicalResult |
| 1806 | matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, |
| 1807 | ConversionPatternRewriter &rewriter) const final { |
| 1808 | rewriter.modifyOpInPlace(op, |
| 1809 | [&] { op->setOperands(adaptor.getOperands()); }); |
| 1810 | return success(); |
| 1811 | } |
| 1812 | }; |
| 1813 | |
| 1814 | struct TestTypeConversionAnotherProducer |
| 1815 | : public OpRewritePattern<TestAnotherTypeProducerOp> { |
| 1816 | using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; |
| 1817 | |
| 1818 | LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, |
| 1819 | PatternRewriter &rewriter) const final { |
| 1820 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); |
| 1821 | return success(); |
| 1822 | } |
| 1823 | }; |
| 1824 | |
| 1825 | struct TestReplaceWithLegalOp : public ConversionPattern { |
| 1826 | TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) |
| 1827 | : ConversionPattern(converter, "test.replace_with_legal_op" , |
| 1828 | /*benefit=*/1, ctx) {} |
| 1829 | LogicalResult |
| 1830 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 1831 | ConversionPatternRewriter &rewriter) const final { |
| 1832 | rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); |
| 1833 | return success(); |
| 1834 | } |
| 1835 | }; |
| 1836 | |
| 1837 | struct TestTypeConversionDriver |
| 1838 | : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { |
| 1839 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) |
| 1840 | |
| 1841 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 1842 | registry.insert<TestDialect>(); |
| 1843 | } |
| 1844 | StringRef getArgument() const final { |
| 1845 | return "test-legalize-type-conversion" ; |
| 1846 | } |
| 1847 | StringRef getDescription() const final { |
| 1848 | return "Test various type conversion functionalities in DialectConversion" ; |
| 1849 | } |
| 1850 | |
| 1851 | void runOnOperation() override { |
| 1852 | // Initialize the type converter. |
| 1853 | SmallVector<Type, 2> conversionCallStack; |
| 1854 | TypeConverter converter; |
| 1855 | |
| 1856 | /// Add the legal set of type conversions. |
| 1857 | converter.addConversion(callback: [](Type type) -> Type { |
| 1858 | // Treat F64 as legal. |
| 1859 | if (type.isF64()) |
| 1860 | return type; |
| 1861 | // Allow converting BF16/F16/F32 to F64. |
| 1862 | if (type.isBF16() || type.isF16() || type.isF32()) |
| 1863 | return Float64Type::get(type.getContext()); |
| 1864 | // Otherwise, the type is illegal. |
| 1865 | return nullptr; |
| 1866 | }); |
| 1867 | converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) { |
| 1868 | // Drop all integer types. |
| 1869 | return success(); |
| 1870 | }); |
| 1871 | converter.addConversion( |
| 1872 | // Convert a recursive self-referring type into a non-self-referring |
| 1873 | // type named "outer_converted_type" that contains a SimpleAType. |
| 1874 | callback: [&](test::TestRecursiveType type, |
| 1875 | SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
| 1876 | // If the type is already converted, return it to indicate that it is |
| 1877 | // legal. |
| 1878 | if (type.getName() == "outer_converted_type" ) { |
| 1879 | results.push_back(type); |
| 1880 | return success(); |
| 1881 | } |
| 1882 | |
| 1883 | conversionCallStack.push_back(type); |
| 1884 | auto popConversionCallStack = llvm::make_scope_exit( |
| 1885 | F: [&conversionCallStack]() { conversionCallStack.pop_back(); }); |
| 1886 | |
| 1887 | // If the type is on the call stack more than once (it is there at |
| 1888 | // least once because of the _current_ call, which is always the last |
| 1889 | // element on the stack), we've hit the recursive case. Just return |
| 1890 | // SimpleAType here to create a non-recursive type as a result. |
| 1891 | if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(), |
| 1892 | Element: type)) { |
| 1893 | results.push_back(test::SimpleAType::get(type.getContext())); |
| 1894 | return success(); |
| 1895 | } |
| 1896 | |
| 1897 | // Convert the body recursively. |
| 1898 | auto result = test::TestRecursiveType::get(ctx: type.getContext(), |
| 1899 | name: "outer_converted_type" ); |
| 1900 | if (failed(result.setBody(converter.convertType(t: type.getBody())))) |
| 1901 | return failure(); |
| 1902 | results.push_back(Elt: result); |
| 1903 | return success(); |
| 1904 | }); |
| 1905 | |
| 1906 | /// Add the legal set of type materializations. |
| 1907 | converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType, |
| 1908 | ValueRange inputs, |
| 1909 | Location loc) -> Value { |
| 1910 | // Allow casting from F64 back to F32. |
| 1911 | if (!resultType.isF16() && inputs.size() == 1 && |
| 1912 | inputs[0].getType().isF64()) |
| 1913 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
| 1914 | // Allow producing an i32 or i64 from nothing. |
| 1915 | if ((resultType.isInteger(32) || resultType.isInteger(64)) && |
| 1916 | inputs.empty()) |
| 1917 | return builder.create<TestTypeProducerOp>(loc, resultType); |
| 1918 | // Allow producing an i64 from an integer. |
| 1919 | if (isa<IntegerType>(resultType) && inputs.size() == 1 && |
| 1920 | isa<IntegerType>(inputs[0].getType())) |
| 1921 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
| 1922 | // Otherwise, fail. |
| 1923 | return nullptr; |
| 1924 | }); |
| 1925 | |
| 1926 | // Initialize the conversion target. |
| 1927 | mlir::ConversionTarget target(getContext()); |
| 1928 | target.addLegalOp<LegalOpD>(); |
| 1929 | target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { |
| 1930 | auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); |
| 1931 | return op.getType().isF64() || op.getType().isInteger(64) || |
| 1932 | (recursiveType && |
| 1933 | recursiveType.getName() == "outer_converted_type" ); |
| 1934 | }); |
| 1935 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
| 1936 | return converter.isSignatureLegal(op.getFunctionType()) && |
| 1937 | converter.isLegal(&op.getBody()); |
| 1938 | }); |
| 1939 | target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { |
| 1940 | // Allow casts from F64 to F32. |
| 1941 | return (*op.operand_type_begin()).isF64() && op.getType().isF32(); |
| 1942 | }); |
| 1943 | target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( |
| 1944 | [&](TestSignatureConversionNoConverterOp op) { |
| 1945 | return converter.isLegal(op.getRegion().front().getArgumentTypes()); |
| 1946 | }); |
| 1947 | |
| 1948 | // Initialize the set of rewrite patterns. |
| 1949 | RewritePatternSet patterns(&getContext()); |
| 1950 | patterns |
| 1951 | .add<TestTypeConsumerForward, TestTypeConversionProducer, |
| 1952 | TestSignatureConversionUndo, |
| 1953 | TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( |
| 1954 | arg&: converter, args: &getContext()); |
| 1955 | patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext()); |
| 1956 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
| 1957 | converter); |
| 1958 | |
| 1959 | if (failed(applyPartialConversion(getOperation(), target, |
| 1960 | std::move(patterns)))) |
| 1961 | signalPassFailure(); |
| 1962 | } |
| 1963 | }; |
| 1964 | } // namespace |
| 1965 | |
| 1966 | //===----------------------------------------------------------------------===// |
| 1967 | // Test Target Materialization With No Uses |
| 1968 | //===----------------------------------------------------------------------===// |
| 1969 | |
| 1970 | namespace { |
| 1971 | struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { |
| 1972 | using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; |
| 1973 | |
| 1974 | LogicalResult |
| 1975 | matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, |
| 1976 | ConversionPatternRewriter &rewriter) const final { |
| 1977 | rewriter.replaceOp(op, adaptor.getOperands()); |
| 1978 | return success(); |
| 1979 | } |
| 1980 | }; |
| 1981 | |
| 1982 | struct TestTargetMaterializationWithNoUses |
| 1983 | : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { |
| 1984 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 1985 | TestTargetMaterializationWithNoUses) |
| 1986 | |
| 1987 | StringRef getArgument() const final { |
| 1988 | return "test-target-materialization-with-no-uses" ; |
| 1989 | } |
| 1990 | StringRef getDescription() const final { |
| 1991 | return "Test a special case of target materialization in DialectConversion" ; |
| 1992 | } |
| 1993 | |
| 1994 | void runOnOperation() override { |
| 1995 | TypeConverter converter; |
| 1996 | converter.addConversion(callback: [](Type t) { return t; }); |
| 1997 | converter.addConversion(callback: [](IntegerType intTy) -> Type { |
| 1998 | if (intTy.getWidth() == 16) |
| 1999 | return IntegerType::get(intTy.getContext(), 64); |
| 2000 | return intTy; |
| 2001 | }); |
| 2002 | converter.addTargetMaterialization( |
| 2003 | callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { |
| 2004 | return builder.create<TestCastOp>(loc, type, inputs).getResult(); |
| 2005 | }); |
| 2006 | |
| 2007 | ConversionTarget target(getContext()); |
| 2008 | target.addIllegalOp<TestTypeChangerOp>(); |
| 2009 | |
| 2010 | RewritePatternSet patterns(&getContext()); |
| 2011 | patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext()); |
| 2012 | |
| 2013 | if (failed(applyPartialConversion(getOperation(), target, |
| 2014 | std::move(patterns)))) |
| 2015 | signalPassFailure(); |
| 2016 | } |
| 2017 | }; |
| 2018 | } // namespace |
| 2019 | |
| 2020 | //===----------------------------------------------------------------------===// |
| 2021 | // Test Block Merging |
| 2022 | //===----------------------------------------------------------------------===// |
| 2023 | |
| 2024 | namespace { |
| 2025 | /// A rewriter pattern that tests that blocks can be merged. |
| 2026 | struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { |
| 2027 | using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; |
| 2028 | |
| 2029 | LogicalResult |
| 2030 | matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, |
| 2031 | ConversionPatternRewriter &rewriter) const final { |
| 2032 | Block &firstBlock = op.getBody().front(); |
| 2033 | Operation *branchOp = firstBlock.getTerminator(); |
| 2034 | Block *secondBlock = &*(std::next(op.getBody().begin())); |
| 2035 | auto succOperands = branchOp->getOperands(); |
| 2036 | SmallVector<Value, 2> replacements(succOperands); |
| 2037 | rewriter.eraseOp(op: branchOp); |
| 2038 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
| 2039 | rewriter.modifyOpInPlace(op, [] {}); |
| 2040 | return success(); |
| 2041 | } |
| 2042 | }; |
| 2043 | |
| 2044 | /// A rewrite pattern to tests the undo mechanism of blocks being merged. |
| 2045 | struct TestUndoBlocksMerge : public ConversionPattern { |
| 2046 | TestUndoBlocksMerge(MLIRContext *ctx) |
| 2047 | : ConversionPattern("test.undo_blocks_merge" , /*benefit=*/1, ctx) {} |
| 2048 | LogicalResult |
| 2049 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 2050 | ConversionPatternRewriter &rewriter) const final { |
| 2051 | Block &firstBlock = op->getRegion(index: 0).front(); |
| 2052 | Operation *branchOp = firstBlock.getTerminator(); |
| 2053 | Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin())); |
| 2054 | rewriter.setInsertionPointToStart(secondBlock); |
| 2055 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
| 2056 | auto succOperands = branchOp->getOperands(); |
| 2057 | SmallVector<Value, 2> replacements(succOperands); |
| 2058 | rewriter.eraseOp(op: branchOp); |
| 2059 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
| 2060 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
| 2061 | return success(); |
| 2062 | } |
| 2063 | }; |
| 2064 | |
| 2065 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
| 2066 | /// ops can have a single block. |
| 2067 | struct TestMergeSingleBlockOps |
| 2068 | : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { |
| 2069 | using OpConversionPattern< |
| 2070 | SingleBlockImplicitTerminatorOp>::OpConversionPattern; |
| 2071 | |
| 2072 | LogicalResult |
| 2073 | matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, |
| 2074 | ConversionPatternRewriter &rewriter) const final { |
| 2075 | SingleBlockImplicitTerminatorOp parentOp = |
| 2076 | op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
| 2077 | if (!parentOp) |
| 2078 | return failure(); |
| 2079 | Block &innerBlock = op.getRegion().front(); |
| 2080 | TerminatorOp innerTerminator = |
| 2081 | cast<TerminatorOp>(innerBlock.getTerminator()); |
| 2082 | rewriter.inlineBlockBefore(&innerBlock, op); |
| 2083 | rewriter.eraseOp(op: innerTerminator); |
| 2084 | rewriter.eraseOp(op: op); |
| 2085 | return success(); |
| 2086 | } |
| 2087 | }; |
| 2088 | |
| 2089 | struct TestMergeBlocksPatternDriver |
| 2090 | : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { |
| 2091 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) |
| 2092 | |
| 2093 | StringRef getArgument() const final { return "test-merge-blocks" ; } |
| 2094 | StringRef getDescription() const final { |
| 2095 | return "Test Merging operation in ConversionPatternRewriter" ; |
| 2096 | } |
| 2097 | void runOnOperation() override { |
| 2098 | MLIRContext *context = &getContext(); |
| 2099 | mlir::RewritePatternSet patterns(context); |
| 2100 | patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( |
| 2101 | arg&: context); |
| 2102 | ConversionTarget target(*context); |
| 2103 | target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, |
| 2104 | TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); |
| 2105 | target.addIllegalOp<ILLegalOpF>(); |
| 2106 | |
| 2107 | /// Expect the op to have a single block after legalization. |
| 2108 | target.addDynamicallyLegalOp<TestMergeBlocksOp>( |
| 2109 | [&](TestMergeBlocksOp op) -> bool { |
| 2110 | return llvm::hasSingleElement(op.getBody()); |
| 2111 | }); |
| 2112 | |
| 2113 | /// Only allow `test.br` within test.merge_blocks op. |
| 2114 | target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { |
| 2115 | return op->getParentOfType<TestMergeBlocksOp>(); |
| 2116 | }); |
| 2117 | |
| 2118 | /// Expect that all nested test.SingleBlockImplicitTerminator ops are |
| 2119 | /// inlined. |
| 2120 | target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( |
| 2121 | [&](SingleBlockImplicitTerminatorOp op) -> bool { |
| 2122 | return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
| 2123 | }); |
| 2124 | |
| 2125 | DenseSet<Operation *> unlegalizedOps; |
| 2126 | ConversionConfig config; |
| 2127 | config.unlegalizedOps = &unlegalizedOps; |
| 2128 | (void)applyPartialConversion(getOperation(), target, std::move(patterns), |
| 2129 | config); |
| 2130 | for (auto *op : unlegalizedOps) |
| 2131 | op->emitRemark() << "op '" << op->getName() << "' is not legalizable" ; |
| 2132 | } |
| 2133 | }; |
| 2134 | } // namespace |
| 2135 | |
| 2136 | //===----------------------------------------------------------------------===// |
| 2137 | // Test Selective Replacement |
| 2138 | //===----------------------------------------------------------------------===// |
| 2139 | |
| 2140 | namespace { |
| 2141 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
| 2142 | /// ops can have a single block. |
| 2143 | struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { |
| 2144 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
| 2145 | |
| 2146 | LogicalResult matchAndRewrite(TestCastOp op, |
| 2147 | PatternRewriter &rewriter) const final { |
| 2148 | if (op.getNumOperands() != 2) |
| 2149 | return failure(); |
| 2150 | OperandRange operands = op.getOperands(); |
| 2151 | |
| 2152 | // Replace non-terminator uses with the first operand. |
| 2153 | rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { |
| 2154 | return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); |
| 2155 | }); |
| 2156 | // Replace everything else with the second operand if the operation isn't |
| 2157 | // dead. |
| 2158 | rewriter.replaceOp(op, op.getOperand(1)); |
| 2159 | return success(); |
| 2160 | } |
| 2161 | }; |
| 2162 | |
| 2163 | struct TestSelectiveReplacementPatternDriver |
| 2164 | : public PassWrapper<TestSelectiveReplacementPatternDriver, |
| 2165 | OperationPass<>> { |
| 2166 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 2167 | TestSelectiveReplacementPatternDriver) |
| 2168 | |
| 2169 | StringRef getArgument() const final { |
| 2170 | return "test-pattern-selective-replacement" ; |
| 2171 | } |
| 2172 | StringRef getDescription() const final { |
| 2173 | return "Test selective replacement in the PatternRewriter" ; |
| 2174 | } |
| 2175 | void runOnOperation() override { |
| 2176 | MLIRContext *context = &getContext(); |
| 2177 | mlir::RewritePatternSet patterns(context); |
| 2178 | patterns.add<TestSelectiveOpReplacementPattern>(arg&: context); |
| 2179 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
| 2180 | } |
| 2181 | }; |
| 2182 | } // namespace |
| 2183 | |
| 2184 | //===----------------------------------------------------------------------===// |
| 2185 | // PassRegistration |
| 2186 | //===----------------------------------------------------------------------===// |
| 2187 | |
| 2188 | namespace mlir { |
| 2189 | namespace test { |
| 2190 | void registerPatternsTestPass() { |
| 2191 | PassRegistration<TestReturnTypeDriver>(); |
| 2192 | |
| 2193 | PassRegistration<TestDerivedAttributeDriver>(); |
| 2194 | |
| 2195 | PassRegistration<TestGreedyPatternDriver>(); |
| 2196 | PassRegistration<TestStrictPatternDriver>(); |
| 2197 | PassRegistration<TestWalkPatternDriver>(); |
| 2198 | |
| 2199 | PassRegistration<TestLegalizePatternDriver>([] { |
| 2200 | return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode); |
| 2201 | }); |
| 2202 | |
| 2203 | PassRegistration<TestRemappedValue>(); |
| 2204 | |
| 2205 | PassRegistration<TestUnknownRootOpDriver>(); |
| 2206 | |
| 2207 | PassRegistration<TestTypeConversionDriver>(); |
| 2208 | PassRegistration<TestTargetMaterializationWithNoUses>(); |
| 2209 | |
| 2210 | PassRegistration<TestRewriteDynamicOpDriver>(); |
| 2211 | |
| 2212 | PassRegistration<TestMergeBlocksPatternDriver>(); |
| 2213 | PassRegistration<TestSelectiveReplacementPatternDriver>(); |
| 2214 | } |
| 2215 | } // namespace test |
| 2216 | } // namespace mlir |
| 2217 | |