1 | //===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements an applicator that applies pattern rewrites based upon a |
10 | // user defined cost model. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H |
15 | #define MLIR_REWRITE_PATTERNAPPLICATOR_H |
16 | |
17 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
18 | |
19 | #include "mlir/IR/Action.h" |
20 | |
21 | namespace mlir { |
22 | class PatternRewriter; |
23 | |
24 | namespace detail { |
25 | class PDLByteCodeMutableState; |
26 | } // namespace detail |
27 | |
28 | /// This is the type of Action that is dispatched when a pattern is applied. |
29 | /// It captures the pattern to apply on top of the usual context. |
30 | class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> { |
31 | public: |
32 | using Base = tracing::ActionImpl<ApplyPatternAction>; |
33 | ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern) |
34 | : Base(irUnits), pattern(pattern) {} |
35 | static constexpr StringLiteral tag = "apply-pattern" ; |
36 | static constexpr StringLiteral desc = |
37 | "Encapsulate the application of rewrite patterns" ; |
38 | |
39 | void print(raw_ostream &os) const override { |
40 | os << "`" << tag << " pattern: " << pattern.getDebugName(); |
41 | } |
42 | |
43 | private: |
44 | const Pattern &pattern; |
45 | }; |
46 | |
47 | /// This class manages the application of a group of rewrite patterns, with a |
48 | /// user-provided cost model. |
49 | class PatternApplicator { |
50 | public: |
51 | /// The cost model dynamically assigns a PatternBenefit to a particular |
52 | /// pattern. Users can query contained patterns and pass analysis results to |
53 | /// applyCostModel. Patterns to be discarded should have a benefit of |
54 | /// `impossibleToMatch`. |
55 | using CostModel = function_ref<PatternBenefit(const Pattern &)>; |
56 | |
57 | explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList); |
58 | ~PatternApplicator(); |
59 | |
60 | /// Attempt to match and rewrite the given op with any pattern, allowing a |
61 | /// predicate to decide if a pattern can be applied or not, and hooks for if |
62 | /// the pattern match was a success or failure. |
63 | /// |
64 | /// canApply: called before each match and rewrite attempt; return false to |
65 | /// skip pattern. |
66 | /// onFailure: called when a pattern fails to match to perform cleanup. |
67 | /// onSuccess: called when a pattern match succeeds; return failure() to |
68 | /// invalidate the match and try another pattern. |
69 | LogicalResult |
70 | matchAndRewrite(Operation *op, PatternRewriter &rewriter, |
71 | function_ref<bool(const Pattern &)> canApply = {}, |
72 | function_ref<void(const Pattern &)> onFailure = {}, |
73 | function_ref<LogicalResult(const Pattern &)> onSuccess = {}); |
74 | |
75 | /// Apply a cost model to the patterns within this applicator. |
76 | void applyCostModel(CostModel model); |
77 | |
78 | /// Apply the default cost model that solely uses the pattern's static |
79 | /// benefit. |
80 | void applyDefaultCostModel() { |
81 | applyCostModel(model: [](const Pattern &pattern) { return pattern.getBenefit(); }); |
82 | } |
83 | |
84 | /// Walk all of the patterns within the applicator. |
85 | void walkAllPatterns(function_ref<void(const Pattern &)> walk); |
86 | |
87 | private: |
88 | /// The list that owns the patterns used within this applicator. |
89 | const FrozenRewritePatternSet &frozenPatternList; |
90 | /// The set of patterns to match for each operation, stable sorted by benefit. |
91 | DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns; |
92 | /// The set of patterns that may match against any operation type, stable |
93 | /// sorted by benefit. |
94 | SmallVector<const RewritePattern *, 1> anyOpPatterns; |
95 | /// The mutable state used during execution of the PDL bytecode. |
96 | std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState; |
97 | }; |
98 | |
99 | } // namespace mlir |
100 | |
101 | #endif // MLIR_REWRITE_PATTERNAPPLICATOR_H |
102 | |