1 | //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implements mlir::walkAndApplyPatterns. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Transforms/WalkPatternRewriteDriver.h" |
14 | |
15 | #include "mlir/IR/MLIRContext.h" |
16 | #include "mlir/IR/OperationSupport.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/IR/Verifier.h" |
19 | #include "mlir/IR/Visitors.h" |
20 | #include "mlir/Rewrite/PatternApplicator.h" |
21 | #include "llvm/Support/Debug.h" |
22 | #include "llvm/Support/ErrorHandling.h" |
23 | |
24 | #define DEBUG_TYPE "walk-rewriter" |
25 | |
26 | namespace mlir { |
27 | |
28 | namespace { |
29 | struct WalkAndApplyPatternsAction final |
30 | : tracing::ActionImpl<WalkAndApplyPatternsAction> { |
31 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction) |
32 | using ActionImpl::ActionImpl; |
33 | static constexpr StringLiteral tag = "walk-and-apply-patterns"; |
34 | void print(raw_ostream &os) const override { os << tag; } |
35 | }; |
36 | |
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Listener that forwards every notification to an underlying listener while
/// policing erasures. The walk-based driver applies patterns to one op at a
/// time, so a pattern may only erase the op being visited (`visitedOp`) or
/// something nested inside it. Erasing anything else (for instance a user of
/// the visited op that the walk has not reached yet) is rejected with a fatal
/// error. Only compiled when expensive pattern API checks are enabled.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// The op whose patterns are currently being applied; the driver updates it
  /// before each match attempt. While it is null, every erasure is accepted.
  Operation *visitedOp = nullptr;

  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(Block *block) override {
    // A block is validated through the operation that owns it.
    Operation *owningOp = block->getParentOp();
    checkErasure(owningOp);
    ForwardingListener::notifyBlockErased(block);
  }

  /// Reports a fatal error unless `op` is `visitedOp` or (transitively)
  /// nested inside it. A null `op` is rejected whenever `visitedOp` is set.
  void checkErasure(Operation *op) const {
    if (!visitedOp)
      return;

    // Climb the parent chain looking for the visited op.
    for (Operation *current = op; current; current = current->getParentOp()) {
      if (current == visitedOp)
        return;
    }

    llvm::report_fatal_error(
        "unsupported erasure in WalkPatternRewriter; "
        "erasure is only supported for matched ops and their descendants");
  }
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69 | } // namespace |
70 | |
71 | void walkAndApplyPatterns(Operation *op, |
72 | const FrozenRewritePatternSet &patterns, |
73 | RewriterBase::Listener *listener) { |
74 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
75 | if (failed(verify(op))) |
76 | llvm::report_fatal_error("walk pattern rewriter input IR failed to verify"); |
77 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
78 | |
79 | MLIRContext *ctx = op->getContext(); |
80 | PatternRewriter rewriter(ctx); |
81 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
82 | ErasedOpsListener erasedListener(listener); |
83 | rewriter.setListener(&erasedListener); |
84 | #else |
85 | rewriter.setListener(listener); |
86 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
87 | |
88 | PatternApplicator applicator(patterns); |
89 | applicator.applyDefaultCostModel(); |
90 | |
91 | ctx->executeAction<WalkAndApplyPatternsAction>( |
92 | actionFn: [&] { |
93 | for (Region ®ion : op->getRegions()) { |
94 | region.walk(callback: [&](Operation *visitedOp) { |
95 | LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( |
96 | llvm::dbgs(), OpPrintingFlags().skipRegions()); |
97 | llvm::dbgs() << "\n";); |
98 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
99 | erasedListener.visitedOp = visitedOp; |
100 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
101 | if (succeeded(Result: applicator.matchAndRewrite(op: visitedOp, rewriter))) { |
102 | LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); |
103 | } |
104 | }); |
105 | } |
106 | }, |
107 | irUnits: {op}); |
108 | |
109 | #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
110 | if (failed(verify(op))) |
111 | llvm::report_fatal_error( |
112 | "walk pattern rewriter result IR failed to verify"); |
113 | #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
114 | } |
115 | |
116 | } // namespace mlir |
117 |