1//===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file declares an applicator that applies pattern rewrites based upon a
// user-defined cost model.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
15#define MLIR_REWRITE_PATTERNAPPLICATOR_H
16
17#include "mlir/Rewrite/FrozenRewritePatternSet.h"
18
19#include "mlir/IR/Action.h"
20
21namespace mlir {
22class PatternRewriter;
23
24namespace detail {
25class PDLByteCodeMutableState;
26} // namespace detail
27
28/// This is the type of Action that is dispatched when a pattern is applied.
29/// It captures the pattern to apply on top of the usual context.
30class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
31public:
32 using Base = tracing::ActionImpl<ApplyPatternAction>;
33 ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
34 : Base(irUnits), pattern(pattern) {}
35 static constexpr StringLiteral tag = "apply-pattern";
36 static constexpr StringLiteral desc =
37 "Encapsulate the application of rewrite patterns";
38
39 void print(raw_ostream &os) const override {
40 os << "`" << tag << " pattern: " << pattern.getDebugName();
41 }
42
43private:
44 const Pattern &pattern;
45};
46
47/// This class manages the application of a group of rewrite patterns, with a
48/// user-provided cost model.
49class PatternApplicator {
50public:
51 /// The cost model dynamically assigns a PatternBenefit to a particular
52 /// pattern. Users can query contained patterns and pass analysis results to
53 /// applyCostModel. Patterns to be discarded should have a benefit of
54 /// `impossibleToMatch`.
55 using CostModel = function_ref<PatternBenefit(const Pattern &)>;
56
57 explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList);
58 ~PatternApplicator();
59
60 /// Attempt to match and rewrite the given op with any pattern, allowing a
61 /// predicate to decide if a pattern can be applied or not, and hooks for if
62 /// the pattern match was a success or failure.
63 ///
64 /// canApply: called before each match and rewrite attempt; return false to
65 /// skip pattern.
66 /// onFailure: called when a pattern fails to match to perform cleanup.
67 /// onSuccess: called when a pattern match succeeds; return failure() to
68 /// invalidate the match and try another pattern.
69 LogicalResult
70 matchAndRewrite(Operation *op, PatternRewriter &rewriter,
71 function_ref<bool(const Pattern &)> canApply = {},
72 function_ref<void(const Pattern &)> onFailure = {},
73 function_ref<LogicalResult(const Pattern &)> onSuccess = {});
74
75 /// Apply a cost model to the patterns within this applicator.
76 void applyCostModel(CostModel model);
77
78 /// Apply the default cost model that solely uses the pattern's static
79 /// benefit.
80 void applyDefaultCostModel() {
81 applyCostModel(model: [](const Pattern &pattern) { return pattern.getBenefit(); });
82 }
83
84 /// Walk all of the patterns within the applicator.
85 void walkAllPatterns(function_ref<void(const Pattern &)> walk);
86
87private:
88 /// The list that owns the patterns used within this applicator.
89 const FrozenRewritePatternSet &frozenPatternList;
90 /// The set of patterns to match for each operation, stable sorted by benefit.
91 DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
92 /// The set of patterns that may match against any operation type, stable
93 /// sorted by benefit.
94 SmallVector<const RewritePattern *, 1> anyOpPatterns;
95 /// The mutable state used during execution of the PDL bytecode.
96 std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
97};
98
99} // namespace mlir
100
101#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
102

// Source: mlir/include/mlir/Rewrite/PatternApplicator.h