| 1 | //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements an applicator that applies pattern rewrites based upon a |
| 10 | // user defined cost model. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Rewrite/PatternApplicator.h" |
| 15 | #include "ByteCode.h" |
| 16 | #include "llvm/Support/Debug.h" |
| 17 | |
| 18 | #define DEBUG_TYPE "pattern-application" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::detail; |
| 22 | |
| 23 | PatternApplicator::PatternApplicator( |
| 24 | const FrozenRewritePatternSet &frozenPatternList) |
| 25 | : frozenPatternList(frozenPatternList) { |
| 26 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| 27 | mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); |
| 28 | bytecode->initializeMutableState(state&: *mutableByteCodeState); |
| 29 | } |
| 30 | } |
| 31 | PatternApplicator::~PatternApplicator() = default; |
| 32 | |
| 33 | #ifndef NDEBUG |
| 34 | /// Log a message for a pattern that is impossible to match. |
| 35 | static void logImpossibleToMatch(const Pattern &pattern) { |
| 36 | llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() |
| 37 | << "' because it is impossible to match or cannot lead " |
| 38 | "to legal IR (by cost model)\n" ; |
| 39 | } |
| 40 | |
| 41 | /// Log IR after pattern application. |
| 42 | static Operation *getDumpRootOp(Operation *op) { |
| 43 | Operation *isolatedParent = |
| 44 | op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>(); |
| 45 | if (isolatedParent) |
| 46 | return isolatedParent; |
| 47 | return op; |
| 48 | } |
| 49 | static void logSucessfulPatternApplication(Operation *op) { |
| 50 | llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n" ; |
| 51 | op->dump(); |
| 52 | llvm::dbgs() << "\n\n" ; |
| 53 | } |
| 54 | #endif |
| 55 | |
| 56 | void PatternApplicator::applyCostModel(CostModel model) { |
| 57 | // Apply the cost model to the bytecode patterns first, and then the native |
| 58 | // patterns. |
| 59 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| 60 | for (const auto &it : llvm::enumerate(First: bytecode->getPatterns())) |
| 61 | mutableByteCodeState->updatePatternBenefit(patternIndex: it.index(), benefit: model(it.value())); |
| 62 | } |
| 63 | |
| 64 | // Copy over the patterns so that we can sort by benefit based on the cost |
| 65 | // model. Patterns that are already impossible to match are ignored. |
| 66 | patterns.clear(); |
| 67 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { |
| 68 | for (const RewritePattern *pattern : it.second) { |
| 69 | if (pattern->getBenefit().isImpossibleToMatch()) |
| 70 | LLVM_DEBUG(logImpossibleToMatch(*pattern)); |
| 71 | else |
| 72 | patterns[it.first].push_back(Elt: pattern); |
| 73 | } |
| 74 | } |
| 75 | anyOpPatterns.clear(); |
| 76 | for (const RewritePattern &pattern : |
| 77 | frozenPatternList.getMatchAnyOpNativePatterns()) { |
| 78 | if (pattern.getBenefit().isImpossibleToMatch()) |
| 79 | LLVM_DEBUG(logImpossibleToMatch(pattern)); |
| 80 | else |
| 81 | anyOpPatterns.push_back(Elt: &pattern); |
| 82 | } |
| 83 | |
| 84 | // Sort the patterns using the provided cost model. |
| 85 | llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; |
| 86 | auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { |
| 87 | return benefits[lhs] > benefits[rhs]; |
| 88 | }; |
| 89 | auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { |
| 90 | // Special case for one pattern in the list, which is the most common case. |
| 91 | if (list.size() == 1) { |
| 92 | if (model(*list.front()).isImpossibleToMatch()) { |
| 93 | LLVM_DEBUG(logImpossibleToMatch(*list.front())); |
| 94 | list.clear(); |
| 95 | } |
| 96 | return; |
| 97 | } |
| 98 | |
| 99 | // Collect the dynamic benefits for the current pattern list. |
| 100 | benefits.clear(); |
| 101 | for (const Pattern *pat : list) |
| 102 | benefits.try_emplace(Key: pat, Args: model(*pat)); |
| 103 | |
| 104 | // Sort patterns with highest benefit first, and remove those that are |
| 105 | // impossible to match. |
| 106 | llvm::stable_sort(Range&: list, C: cmp); |
| 107 | while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { |
| 108 | LLVM_DEBUG(logImpossibleToMatch(*list.back())); |
| 109 | list.pop_back(); |
| 110 | } |
| 111 | }; |
| 112 | for (auto &it : patterns) |
| 113 | processPatternList(it.second); |
| 114 | processPatternList(anyOpPatterns); |
| 115 | } |
| 116 | |
| 117 | void PatternApplicator::walkAllPatterns( |
| 118 | function_ref<void(const Pattern &)> walk) { |
| 119 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) |
| 120 | for (const auto &pattern : it.second) |
| 121 | walk(*pattern); |
| 122 | for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) |
| 123 | walk(it); |
| 124 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| 125 | for (const Pattern &it : bytecode->getPatterns()) |
| 126 | walk(it); |
| 127 | } |
| 128 | } |
| 129 | |
| 130 | LogicalResult PatternApplicator::matchAndRewrite( |
| 131 | Operation *op, PatternRewriter &rewriter, |
| 132 | function_ref<bool(const Pattern &)> canApply, |
| 133 | function_ref<void(const Pattern &)> onFailure, |
| 134 | function_ref<LogicalResult(const Pattern &)> onSuccess) { |
| 135 | // Before checking native patterns, first match against the bytecode. This |
| 136 | // won't automatically perform any rewrites so there is no need to worry about |
| 137 | // conflicts. |
| 138 | SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; |
| 139 | const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); |
| 140 | if (bytecode) |
| 141 | bytecode->match(op, rewriter, matches&: pdlMatches, state&: *mutableByteCodeState); |
| 142 | |
| 143 | // Check to see if there are patterns matching this specific operation type. |
| 144 | MutableArrayRef<const RewritePattern *> opPatterns; |
| 145 | auto patternIt = patterns.find(Val: op->getName()); |
| 146 | if (patternIt != patterns.end()) |
| 147 | opPatterns = patternIt->second; |
| 148 | |
| 149 | // Process the patterns for that match the specific operation type, and any |
| 150 | // operation type in an interleaved fashion. |
| 151 | unsigned opIt = 0, opE = opPatterns.size(); |
| 152 | unsigned anyIt = 0, anyE = anyOpPatterns.size(); |
| 153 | unsigned pdlIt = 0, pdlE = pdlMatches.size(); |
| 154 | LogicalResult result = failure(); |
| 155 | do { |
| 156 | // Find the next pattern with the highest benefit. |
| 157 | const Pattern *bestPattern = nullptr; |
| 158 | unsigned *bestPatternIt = &opIt; |
| 159 | |
| 160 | /// Operation specific patterns. |
| 161 | if (opIt < opE) |
| 162 | bestPattern = opPatterns[opIt]; |
| 163 | /// Operation agnostic patterns. |
| 164 | if (anyIt < anyE && |
| 165 | (!bestPattern || |
| 166 | bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { |
| 167 | bestPatternIt = &anyIt; |
| 168 | bestPattern = anyOpPatterns[anyIt]; |
| 169 | } |
| 170 | |
| 171 | const PDLByteCode::MatchResult *pdlMatch = nullptr; |
| 172 | /// PDL patterns. |
| 173 | if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < |
| 174 | pdlMatches[pdlIt].benefit)) { |
| 175 | bestPatternIt = &pdlIt; |
| 176 | pdlMatch = &pdlMatches[pdlIt]; |
| 177 | bestPattern = pdlMatch->pattern; |
| 178 | } |
| 179 | |
| 180 | if (!bestPattern) |
| 181 | break; |
| 182 | |
| 183 | // Update the pattern iterator on failure so that this pattern isn't |
| 184 | // attempted again. |
| 185 | ++(*bestPatternIt); |
| 186 | |
| 187 | // Check that the pattern can be applied. |
| 188 | if (canApply && !canApply(*bestPattern)) |
| 189 | continue; |
| 190 | |
| 191 | // Try to match and rewrite this pattern. The patterns are sorted by |
| 192 | // benefit, so if we match we can immediately rewrite. For PDL patterns, the |
| 193 | // match has already been performed, we just need to rewrite. |
| 194 | bool matched = false; |
| 195 | op->getContext()->executeAction<ApplyPatternAction>( |
| 196 | actionFn: [&]() { |
| 197 | rewriter.setInsertionPoint(op); |
| 198 | #ifndef NDEBUG |
| 199 | // Operation `op` may be invalidated after applying the rewrite |
| 200 | // pattern. |
| 201 | Operation *dumpRootOp = getDumpRootOp(op); |
| 202 | #endif |
| 203 | if (pdlMatch) { |
| 204 | result = |
| 205 | bytecode->rewrite(rewriter, match: *pdlMatch, state&: *mutableByteCodeState); |
| 206 | } else { |
| 207 | LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" |
| 208 | << bestPattern->getDebugName() << "\"\n" ); |
| 209 | |
| 210 | const auto *pattern = |
| 211 | static_cast<const RewritePattern *>(bestPattern); |
| 212 | result = pattern->matchAndRewrite(op, rewriter); |
| 213 | |
| 214 | LLVM_DEBUG(llvm::dbgs() |
| 215 | << "\"" << bestPattern->getDebugName() << "\" result " |
| 216 | << succeeded(result) << "\n" ); |
| 217 | } |
| 218 | |
| 219 | // Process the result of the pattern application. |
| 220 | if (succeeded(Result: result) && onSuccess && failed(Result: onSuccess(*bestPattern))) |
| 221 | result = failure(); |
| 222 | if (succeeded(Result: result)) { |
| 223 | LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); |
| 224 | matched = true; |
| 225 | return; |
| 226 | } |
| 227 | |
| 228 | // Perform any necessary cleanups. |
| 229 | if (onFailure) |
| 230 | onFailure(*bestPattern); |
| 231 | }, |
| 232 | irUnits: {op}, args: *bestPattern); |
| 233 | if (matched) |
| 234 | break; |
| 235 | } while (true); |
| 236 | |
| 237 | if (mutableByteCodeState) |
| 238 | mutableByteCodeState->cleanupAfterMatchAndRewrite(); |
| 239 | return result; |
| 240 | } |
| 241 | |