| 1 | //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===// |
|---|---|
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Implements mlir::walkAndApplyPatterns. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Transforms/WalkPatternRewriteDriver.h" |
| 14 | |
| 15 | #include "mlir/IR/MLIRContext.h" |
| 16 | #include "mlir/IR/OperationSupport.h" |
| 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | #include "mlir/IR/Verifier.h" |
| 19 | #include "mlir/IR/Visitors.h" |
| 20 | #include "mlir/Rewrite/PatternApplicator.h" |
| 21 | #include "llvm/Support/Debug.h" |
| 22 | #include "llvm/Support/ErrorHandling.h" |
| 23 | |
| 24 | #define DEBUG_TYPE "walk-rewriter" |
| 25 | |
| 26 | namespace mlir { |
| 27 | |
| 28 | namespace { |
| 29 | struct WalkAndApplyPatternsAction final |
| 30 | : tracing::ActionImpl<WalkAndApplyPatternsAction> { |
| 31 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction) |
| 32 | using ActionImpl::ActionImpl; |
| 33 | static constexpr StringLiteral tag = "walk-and-apply-patterns"; |
| 34 | void print(raw_ostream &os) const override { os << tag; } |
| 35 | }; |
| 36 | |
| 37 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 38 | // Forwarding listener to guard against unsupported erasures of non-descendant |
| 39 | // ops/blocks. Because we use walk-based pattern application, erasing the |
| 40 | // op/block from the *next* iteration (e.g., a user of the visited op) is not |
| 41 | // valid. Note that this is only used with expensive pattern API checks. |
| 42 | struct ErasedOpsListener final : RewriterBase::ForwardingListener { |
| 43 | using RewriterBase::ForwardingListener::ForwardingListener; |
| 44 | |
| 45 | void notifyOperationErased(Operation *op) override { |
| 46 | checkErasure(op); |
| 47 | ForwardingListener::notifyOperationErased(op); |
| 48 | } |
| 49 | |
| 50 | void notifyBlockErased(Block *block) override { |
| 51 | checkErasure(block->getParentOp()); |
| 52 | ForwardingListener::notifyBlockErased(block); |
| 53 | } |
| 54 | |
| 55 | void checkErasure(Operation *op) const { |
| 56 | Operation *ancestorOp = op; |
| 57 | while (ancestorOp && ancestorOp != visitedOp) |
| 58 | ancestorOp = ancestorOp->getParentOp(); |
| 59 | |
| 60 | if (ancestorOp != visitedOp) |
| 61 | llvm::report_fatal_error( |
| 62 | "unsupported erasure in WalkPatternRewriter; " |
| 63 | "erasure is only supported for matched ops and their descendants"); |
| 64 | } |
| 65 | |
| 66 | Operation *visitedOp = nullptr; |
| 67 | }; |
| 68 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 69 | } // namespace |
| 70 | |
| 71 | void walkAndApplyPatterns(Operation *op, |
| 72 | const FrozenRewritePatternSet &patterns, |
| 73 | RewriterBase::Listener *listener) { |
| 74 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 75 | if (failed(verify(op))) |
| 76 | llvm::report_fatal_error("walk pattern rewriter input IR failed to verify"); |
| 77 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 78 | |
| 79 | MLIRContext *ctx = op->getContext(); |
| 80 | PatternRewriter rewriter(ctx); |
| 81 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 82 | ErasedOpsListener erasedListener(listener); |
| 83 | rewriter.setListener(&erasedListener); |
| 84 | #else |
| 85 | rewriter.setListener(listener); |
| 86 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 87 | |
| 88 | PatternApplicator applicator(patterns); |
| 89 | applicator.applyDefaultCostModel(); |
| 90 | |
| 91 | ctx->executeAction<WalkAndApplyPatternsAction>( |
| 92 | actionFn: [&] { |
| 93 | for (Region ®ion : op->getRegions()) { |
| 94 | region.walk(callback: [&](Operation *visitedOp) { |
| 95 | LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( |
| 96 | llvm::dbgs(), OpPrintingFlags().skipRegions()); |
| 97 | llvm::dbgs() << "\n";); |
| 98 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 99 | erasedListener.visitedOp = visitedOp; |
| 100 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 101 | if (succeeded(Result: applicator.matchAndRewrite(op: visitedOp, rewriter))) { |
| 102 | LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); |
| 103 | } |
| 104 | }); |
| 105 | } |
| 106 | }, |
| 107 | irUnits: {op}); |
| 108 | |
| 109 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 110 | if (failed(verify(op))) |
| 111 | llvm::report_fatal_error( |
| 112 | "walk pattern rewriter result IR failed to verify"); |
| 113 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| 114 | } |
| 115 | |
| 116 | } // namespace mlir |
| 117 |
