1//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements mlir::walkAndApplyPatterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/WalkPatternRewriteDriver.h"
14
15#include "mlir/IR/MLIRContext.h"
16#include "mlir/IR/Operation.h"
17#include "mlir/IR/OperationSupport.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/Verifier.h"
20#include "mlir/IR/Visitors.h"
21#include "mlir/Rewrite/PatternApplicator.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/ErrorHandling.h"
25
26#define DEBUG_TYPE "walk-rewriter"
27
28namespace mlir {
29
30// Find all reachable blocks in the region and add them to the visitedBlocks
31// set.
32static void findReachableBlocks(Region &region,
33 DenseSet<Block *> &reachableBlocks) {
34 Block *entryBlock = &region.front();
35 reachableBlocks.insert(V: entryBlock);
36 // Traverse the CFG and add all reachable blocks to the blockList.
37 SmallVector<Block *> worklist({entryBlock});
38 while (!worklist.empty()) {
39 Block *block = worklist.pop_back_val();
40 Operation *terminator = &block->back();
41 for (Block *successor : terminator->getSuccessors()) {
42 if (reachableBlocks.contains(V: successor))
43 continue;
44 worklist.push_back(Elt: successor);
45 reachableBlocks.insert(V: successor);
46 }
47 }
48}
49
50namespace {
51struct WalkAndApplyPatternsAction final
52 : tracing::ActionImpl<WalkAndApplyPatternsAction> {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
54 using ActionImpl::ActionImpl;
55 static constexpr StringLiteral tag = "walk-and-apply-patterns";
56 void print(raw_ostream &os) const override { os << tag; }
57};
58
59#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
60// Forwarding listener to guard against unsupported erasures of non-descendant
61// ops/blocks. Because we use walk-based pattern application, erasing the
62// op/block from the *next* iteration (e.g., a user of the visited op) is not
63// valid. Note that this is only used with expensive pattern API checks.
64struct ErasedOpsListener final : RewriterBase::ForwardingListener {
65 using RewriterBase::ForwardingListener::ForwardingListener;
66
67 void notifyOperationErased(Operation *op) override {
68 checkErasure(op);
69 ForwardingListener::notifyOperationErased(op);
70 }
71
72 void notifyBlockErased(Block *block) override {
73 checkErasure(block->getParentOp());
74 ForwardingListener::notifyBlockErased(block);
75 }
76
77 void checkErasure(Operation *op) const {
78 Operation *ancestorOp = op;
79 while (ancestorOp && ancestorOp != visitedOp)
80 ancestorOp = ancestorOp->getParentOp();
81
82 if (ancestorOp != visitedOp)
83 llvm::report_fatal_error(
84 "unsupported erasure in WalkPatternRewriter; "
85 "erasure is only supported for matched ops and their descendants");
86 }
87
88 Operation *visitedOp = nullptr;
89};
90#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
91} // namespace
92
93void walkAndApplyPatterns(Operation *op,
94 const FrozenRewritePatternSet &patterns,
95 RewriterBase::Listener *listener) {
96#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
97 if (failed(verify(op)))
98 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
99#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
100
101 MLIRContext *ctx = op->getContext();
102 PatternRewriter rewriter(ctx);
103#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
104 ErasedOpsListener erasedListener(listener);
105 rewriter.setListener(&erasedListener);
106#else
107 rewriter.setListener(listener);
108#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
109
110 PatternApplicator applicator(patterns);
111 applicator.applyDefaultCostModel();
112
113 // Iterator on all reachable operations in the region.
114 // Also keep track if we visited the nested regions of the current op
115 // already to drive the post-order traversal.
116 struct RegionReachableOpIterator {
117 RegionReachableOpIterator(Region *region) : region(region) {
118 regionIt = region->begin();
119 if (regionIt != region->end())
120 blockIt = regionIt->begin();
121 if (!llvm::hasSingleElement(C&: *region))
122 findReachableBlocks(region&: *region, reachableBlocks);
123 }
124 // Advance the iterator to the next reachable operation.
125 void advance() {
126 assert(regionIt != region->end());
127 hasVisitedRegions = false;
128 if (blockIt == regionIt->end()) {
129 ++regionIt;
130 while (regionIt != region->end() &&
131 !reachableBlocks.contains(V: &*regionIt))
132 ++regionIt;
133 if (regionIt != region->end())
134 blockIt = regionIt->begin();
135 return;
136 }
137 ++blockIt;
138 if (blockIt != regionIt->end()) {
139 LLVM_DEBUG({
140 llvm::dbgs() << "Incrementing block iterator, next op: "
141 << OpWithFlags(&*blockIt,
142 OpPrintingFlags().skipRegions())
143 << "\n";
144 });
145 }
146 }
147 // The region we're iterating over.
148 Region *region;
149 // The Block currently being iterated over.
150 Region::iterator regionIt;
151 // The Operation currently being iterated over.
152 Block::iterator blockIt;
153 // The set of blocks that are reachable in the current region.
154 DenseSet<Block *> reachableBlocks;
155 // Whether we've visited the nested regions of the current op already.
156 bool hasVisitedRegions = false;
157 };
158
159 // Worklist of regions to visit to drive the post-order traversal.
160 SmallVector<RegionReachableOpIterator> worklist;
161
162 LLVM_DEBUG(
163 { llvm::dbgs() << "Starting walk-based pattern rewrite driver\n"; });
164 ctx->executeAction<WalkAndApplyPatternsAction>(
165 actionFn: [&] {
166 // Perform a post-order traversal of the regions, visiting each
167 // reachable operation.
168 for (Region &region : op->getRegions()) {
169 assert(worklist.empty());
170 if (region.empty())
171 continue;
172
173 // Prime the worklist with the entry block of this region.
174 worklist.push_back(Elt: {&region});
175 while (!worklist.empty()) {
176 RegionReachableOpIterator &it = worklist.back();
177 if (it.regionIt == it.region->end()) {
178 // We're done with this region.
179 worklist.pop_back();
180 continue;
181 }
182 if (it.blockIt == it.regionIt->end()) {
183 // We're done with this block.
184 it.advance();
185 continue;
186 }
187 Operation *op = &*it.blockIt;
188 // If we haven't visited the nested regions of this op yet,
189 // enqueue them.
190 if (!it.hasVisitedRegions) {
191 it.hasVisitedRegions = true;
192 for (Region &nestedRegion : llvm::reverse(C: op->getRegions())) {
193 if (nestedRegion.empty())
194 continue;
195 worklist.push_back(Elt: {&nestedRegion});
196 }
197 }
198 // If we're not at the back of the worklist, we've enqueued some
199 // nested region for processing. We'll come back to this op later
200 // (post-order)
201 if (&it != &worklist.back())
202 continue;
203
204 // Preemptively increment the iterator, in case the current op
205 // would be erased.
206 it.advance();
207
208 LLVM_DEBUG({
209 llvm::dbgs() << "Visiting op: "
210 << OpWithFlags(op, OpPrintingFlags().skipRegions())
211 << "\n";
212 });
213#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
214 erasedListener.visitedOp = op;
215#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
216 if (succeeded(Result: applicator.matchAndRewrite(op, rewriter)))
217 LLVM_DEBUG({ llvm::dbgs() << "\tOp matched and rewritten\n"; });
218 }
219 }
220 },
221 irUnits: {op});
222
223#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
224 if (failed(verify(op)))
225 llvm::report_fatal_error(
226 "walk pattern rewriter result IR failed to verify");
227#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
228}
229
230} // namespace mlir
231

// Source: mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp