1//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements an applicator that applies pattern rewrites based upon a
10// user defined cost model.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Rewrite/PatternApplicator.h"
15#include "ByteCode.h"
16#include "llvm/Support/Debug.h"
17
18#define DEBUG_TYPE "pattern-application"
19
20using namespace mlir;
21using namespace mlir::detail;
22
23PatternApplicator::PatternApplicator(
24 const FrozenRewritePatternSet &frozenPatternList)
25 : frozenPatternList(frozenPatternList) {
26 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28 bytecode->initializeMutableState(state&: *mutableByteCodeState);
29 }
30}
31PatternApplicator::~PatternApplicator() = default;
32
33#ifndef NDEBUG
34/// Log a message for a pattern that is impossible to match.
35static void logImpossibleToMatch(const Pattern &pattern) {
36 llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37 << "' because it is impossible to match or cannot lead "
38 "to legal IR (by cost model)\n";
39}
40
41/// Log IR after pattern application.
42static Operation *getDumpRootOp(Operation *op) {
43 Operation *isolatedParent =
44 op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
45 if (isolatedParent)
46 return isolatedParent;
47 return op;
48}
49static void logSucessfulPatternApplication(Operation *op) {
50 llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
51 op->dump();
52 llvm::dbgs() << "\n\n";
53}
54#endif
55
56void PatternApplicator::applyCostModel(CostModel model) {
57 // Apply the cost model to the bytecode patterns first, and then the native
58 // patterns.
59 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
60 for (const auto &it : llvm::enumerate(First: bytecode->getPatterns()))
61 mutableByteCodeState->updatePatternBenefit(patternIndex: it.index(), benefit: model(it.value()));
62 }
63
64 // Copy over the patterns so that we can sort by benefit based on the cost
65 // model. Patterns that are already impossible to match are ignored.
66 patterns.clear();
67 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
68 for (const RewritePattern *pattern : it.second) {
69 if (pattern->getBenefit().isImpossibleToMatch())
70 LLVM_DEBUG(logImpossibleToMatch(*pattern));
71 else
72 patterns[it.first].push_back(Elt: pattern);
73 }
74 }
75 anyOpPatterns.clear();
76 for (const RewritePattern &pattern :
77 frozenPatternList.getMatchAnyOpNativePatterns()) {
78 if (pattern.getBenefit().isImpossibleToMatch())
79 LLVM_DEBUG(logImpossibleToMatch(pattern));
80 else
81 anyOpPatterns.push_back(Elt: &pattern);
82 }
83
84 // Sort the patterns using the provided cost model.
85 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
86 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
87 return benefits[lhs] > benefits[rhs];
88 };
89 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
90 // Special case for one pattern in the list, which is the most common case.
91 if (list.size() == 1) {
92 if (model(*list.front()).isImpossibleToMatch()) {
93 LLVM_DEBUG(logImpossibleToMatch(*list.front()));
94 list.clear();
95 }
96 return;
97 }
98
99 // Collect the dynamic benefits for the current pattern list.
100 benefits.clear();
101 for (const Pattern *pat : list)
102 benefits.try_emplace(Key: pat, Args: model(*pat));
103
104 // Sort patterns with highest benefit first, and remove those that are
105 // impossible to match.
106 std::stable_sort(first: list.begin(), last: list.end(), comp: cmp);
107 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
108 LLVM_DEBUG(logImpossibleToMatch(*list.back()));
109 list.pop_back();
110 }
111 };
112 for (auto &it : patterns)
113 processPatternList(it.second);
114 processPatternList(anyOpPatterns);
115}
116
117void PatternApplicator::walkAllPatterns(
118 function_ref<void(const Pattern &)> walk) {
119 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
120 for (const auto &pattern : it.second)
121 walk(*pattern);
122 for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
123 walk(it);
124 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
125 for (const Pattern &it : bytecode->getPatterns())
126 walk(it);
127 }
128}
129
130LogicalResult PatternApplicator::matchAndRewrite(
131 Operation *op, PatternRewriter &rewriter,
132 function_ref<bool(const Pattern &)> canApply,
133 function_ref<void(const Pattern &)> onFailure,
134 function_ref<LogicalResult(const Pattern &)> onSuccess) {
135 // Before checking native patterns, first match against the bytecode. This
136 // won't automatically perform any rewrites so there is no need to worry about
137 // conflicts.
138 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
139 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
140 if (bytecode)
141 bytecode->match(op, rewriter, matches&: pdlMatches, state&: *mutableByteCodeState);
142
143 // Check to see if there are patterns matching this specific operation type.
144 MutableArrayRef<const RewritePattern *> opPatterns;
145 auto patternIt = patterns.find(Val: op->getName());
146 if (patternIt != patterns.end())
147 opPatterns = patternIt->second;
148
149 // Process the patterns for that match the specific operation type, and any
150 // operation type in an interleaved fashion.
151 unsigned opIt = 0, opE = opPatterns.size();
152 unsigned anyIt = 0, anyE = anyOpPatterns.size();
153 unsigned pdlIt = 0, pdlE = pdlMatches.size();
154 LogicalResult result = failure();
155 do {
156 // Find the next pattern with the highest benefit.
157 const Pattern *bestPattern = nullptr;
158 unsigned *bestPatternIt = &opIt;
159
160 /// Operation specific patterns.
161 if (opIt < opE)
162 bestPattern = opPatterns[opIt];
163 /// Operation agnostic patterns.
164 if (anyIt < anyE &&
165 (!bestPattern ||
166 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
167 bestPatternIt = &anyIt;
168 bestPattern = anyOpPatterns[anyIt];
169 }
170
171 const PDLByteCode::MatchResult *pdlMatch = nullptr;
172 /// PDL patterns.
173 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
174 pdlMatches[pdlIt].benefit)) {
175 bestPatternIt = &pdlIt;
176 pdlMatch = &pdlMatches[pdlIt];
177 bestPattern = pdlMatch->pattern;
178 }
179
180 if (!bestPattern)
181 break;
182
183 // Update the pattern iterator on failure so that this pattern isn't
184 // attempted again.
185 ++(*bestPatternIt);
186
187 // Check that the pattern can be applied.
188 if (canApply && !canApply(*bestPattern))
189 continue;
190
191 // Try to match and rewrite this pattern. The patterns are sorted by
192 // benefit, so if we match we can immediately rewrite. For PDL patterns, the
193 // match has already been performed, we just need to rewrite.
194 bool matched = false;
195 op->getContext()->executeAction<ApplyPatternAction>(
196 actionFn: [&]() {
197 rewriter.setInsertionPoint(op);
198#ifndef NDEBUG
199 // Operation `op` may be invalidated after applying the rewrite
200 // pattern.
201 Operation *dumpRootOp = getDumpRootOp(op);
202#endif
203 if (pdlMatch) {
204 result =
205 bytecode->rewrite(rewriter, match: *pdlMatch, state&: *mutableByteCodeState);
206 } else {
207 LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208 << bestPattern->getDebugName() << "\"\n");
209
210 const auto *pattern =
211 static_cast<const RewritePattern *>(bestPattern);
212 result = pattern->matchAndRewrite(op, rewriter);
213
214 LLVM_DEBUG(llvm::dbgs()
215 << "\"" << bestPattern->getDebugName() << "\" result "
216 << succeeded(result) << "\n");
217 }
218
219 // Process the result of the pattern application.
220 if (succeeded(result) && onSuccess && failed(result: onSuccess(*bestPattern)))
221 result = failure();
222 if (succeeded(result)) {
223 LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
224 matched = true;
225 return;
226 }
227
228 // Perform any necessary cleanups.
229 if (onFailure)
230 onFailure(*bestPattern);
231 },
232 irUnits: {op}, args: *bestPattern);
233 if (matched)
234 break;
235 } while (true);
236
237 if (mutableByteCodeState)
238 mutableByteCodeState->cleanupAfterMatchAndRewrite();
239 return result;
240}
241

source code of mlir/lib/Rewrite/PatternApplicator.cpp