| 1 | //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements an applicator that applies pattern rewrites based upon a |
| 10 | // user defined cost model. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H |
| 15 | #define MLIR_REWRITE_PATTERNAPPLICATOR_H |
| 16 | |
| 17 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 18 | |
| 19 | #include "mlir/IR/Action.h" |
| 20 | |
| 21 | namespace mlir { |
| 22 | class PatternRewriter; |
| 23 | |
| 24 | namespace detail { |
| 25 | class PDLByteCodeMutableState; |
| 26 | } // namespace detail |
| 27 | |
| 28 | /// This is the type of Action that is dispatched when a pattern is applied. |
| 29 | /// It captures the pattern to apply on top of the usual context. |
| 30 | class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> { |
| 31 | public: |
| 32 | using Base = tracing::ActionImpl<ApplyPatternAction>; |
| 33 | ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern) |
| 34 | : Base(irUnits), pattern(pattern) {} |
| 35 | static constexpr StringLiteral tag = "apply-pattern" ; |
| 36 | static constexpr StringLiteral desc = |
| 37 | "Encapsulate the application of rewrite patterns" ; |
| 38 | |
| 39 | void print(raw_ostream &os) const override { |
| 40 | os << "`" << tag << " pattern: " << pattern.getDebugName(); |
| 41 | } |
| 42 | |
| 43 | private: |
| 44 | const Pattern &pattern; |
| 45 | }; |
| 46 | |
| 47 | /// This class manages the application of a group of rewrite patterns, with a |
| 48 | /// user-provided cost model. |
| 49 | class PatternApplicator { |
| 50 | public: |
| 51 | /// The cost model dynamically assigns a PatternBenefit to a particular |
| 52 | /// pattern. Users can query contained patterns and pass analysis results to |
| 53 | /// applyCostModel. Patterns to be discarded should have a benefit of |
| 54 | /// `impossibleToMatch`. |
| 55 | using CostModel = function_ref<PatternBenefit(const Pattern &)>; |
| 56 | |
| 57 | explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList); |
| 58 | ~PatternApplicator(); |
| 59 | |
| 60 | /// Attempt to match and rewrite the given op with any pattern, allowing a |
| 61 | /// predicate to decide if a pattern can be applied or not, and hooks for if |
| 62 | /// the pattern match was a success or failure. |
| 63 | /// |
| 64 | /// canApply: called before each match and rewrite attempt; return false to |
| 65 | /// skip pattern. |
| 66 | /// onFailure: called when a pattern fails to match to perform cleanup. |
| 67 | /// onSuccess: called when a pattern match succeeds; return failure() to |
| 68 | /// invalidate the match and try another pattern. |
| 69 | LogicalResult |
| 70 | matchAndRewrite(Operation *op, PatternRewriter &rewriter, |
| 71 | function_ref<bool(const Pattern &)> canApply = {}, |
| 72 | function_ref<void(const Pattern &)> onFailure = {}, |
| 73 | function_ref<LogicalResult(const Pattern &)> onSuccess = {}); |
| 74 | |
| 75 | /// Apply a cost model to the patterns within this applicator. |
| 76 | void applyCostModel(CostModel model); |
| 77 | |
| 78 | /// Apply the default cost model that solely uses the pattern's static |
| 79 | /// benefit. |
| 80 | void applyDefaultCostModel() { |
| 81 | applyCostModel(model: [](const Pattern &pattern) { return pattern.getBenefit(); }); |
| 82 | } |
| 83 | |
| 84 | /// Walk all of the patterns within the applicator. |
| 85 | void walkAllPatterns(function_ref<void(const Pattern &)> walk); |
| 86 | |
| 87 | private: |
| 88 | /// The list that owns the patterns used within this applicator. |
| 89 | const FrozenRewritePatternSet &frozenPatternList; |
| 90 | /// The set of patterns to match for each operation, stable sorted by benefit. |
| 91 | DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns; |
| 92 | /// The set of patterns that may match against any operation type, stable |
| 93 | /// sorted by benefit. |
| 94 | SmallVector<const RewritePattern *, 1> anyOpPatterns; |
| 95 | /// The mutable state used during execution of the PDL bytecode. |
| 96 | std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState; |
| 97 | }; |
| 98 | |
| 99 | } // namespace mlir |
| 100 | |
| 101 | #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H |
| 102 | |