1//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10// The byte-code is constructed from the PDL Interpreter dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_REWRITE_BYTECODE_H_
15#define MLIR_REWRITE_BYTECODE_H_
16
17#include "mlir/IR/PatternMatch.h"
18
19#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
20
21namespace mlir {
// Only referenced by the signature of `PDLByteCodePattern::create`, so a
// forward declaration is sufficient in this header.
namespace pdl_interp {
class RecordMatchOp;
} // namespace pdl_interp
25
26namespace detail {
27class PDLByteCode;
28
29/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
30/// entries. ByteCodeAddr refers to size of indices into the bytecode.
31using ByteCodeField = uint16_t;
32using ByteCodeAddr = uint32_t;
33using OwningOpRange = llvm::OwningArrayRef<Operation *>;
34
35//===----------------------------------------------------------------------===//
36// PDLByteCodePattern
37//===----------------------------------------------------------------------===//
38
39/// All of the data pertaining to a specific pattern within the bytecode.
40class PDLByteCodePattern : public Pattern {
41public:
42 static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
43 PDLPatternConfigSet *configSet,
44 ByteCodeAddr rewriterAddr);
45
46 /// Return the bytecode address of the rewriter for this pattern.
47 ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
48
49 /// Return the configuration set for this pattern, or null if there is none.
50 PDLPatternConfigSet *getConfigSet() const { return configSet; }
51
52private:
53 template <typename... Args>
54 PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
55 Args &&...patternArgs)
56 : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
57 configSet(configSet) {}
58
59 /// The address of the rewriter for this pattern.
60 ByteCodeAddr rewriterAddr;
61
62 /// The optional config set for this pattern.
63 PDLPatternConfigSet *configSet;
64};
65
66//===----------------------------------------------------------------------===//
67// PDLByteCodeMutableState
68//===----------------------------------------------------------------------===//
69
70/// This class contains the mutable state of a bytecode instance. This allows
71/// for a bytecode instance to be cached and reused across various different
72/// threads/drivers.
73class PDLByteCodeMutableState {
74public:
75 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
76 /// to the position of the pattern within the range returned by
77 /// `PDLByteCode::getPatterns`.
78 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
79
80 /// Cleanup any allocated state after a match/rewrite has been completed. This
81 /// method should be called irregardless of whether the match+rewrite was a
82 /// success or not.
83 void cleanupAfterMatchAndRewrite();
84
85private:
86 /// Allow access to data fields.
87 friend class PDLByteCode;
88
89 /// The mutable block of memory used during the matching and rewriting phases
90 /// of the bytecode.
91 std::vector<const void *> memory;
92
93 /// A mutable block of memory used during the matching and rewriting phase of
94 /// the bytecode to store ranges of operations. These are always stored by
95 /// owning references, because at no point in the execution of the byte code
96 /// we get an indexed range (view) of operations.
97 std::vector<OwningOpRange> opRangeMemory;
98
99 /// A mutable block of memory used during the matching and rewriting phase of
100 /// the bytecode to store ranges of types.
101 std::vector<TypeRange> typeRangeMemory;
102 /// A set of type ranges that have been allocated by the byte code interpreter
103 /// to provide a guaranteed lifetime.
104 std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
105
106 /// A mutable block of memory used during the matching and rewriting phase of
107 /// the bytecode to store ranges of values.
108 std::vector<ValueRange> valueRangeMemory;
109 /// A set of value ranges that have been allocated by the byte code
110 /// interpreter to provide a guaranteed lifetime.
111 std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
112
113 /// The current index of ranges being iterated over for each level of nesting.
114 /// These are always maintained at 0 for the loops that are not active, so we
115 /// do not need to have a separate initialization phase for each loop.
116 std::vector<unsigned> loopIndex;
117
118 /// The up-to-date benefits of the patterns held by the bytecode. The order
119 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
120 std::vector<PatternBenefit> currentPatternBenefits;
121};
122
123//===----------------------------------------------------------------------===//
124// PDLByteCode
125//===----------------------------------------------------------------------===//
126
127/// The bytecode class is also the interpreter. Contains the bytecode itself,
128/// the static info, addresses of the rewriter functions, the interpreter
129/// memory buffer, and the execution context.
130class PDLByteCode {
131public:
132 /// Each successful match returns a MatchResult, which contains information
133 /// necessary to execute the rewriter and indicates the originating pattern.
134 struct MatchResult {
135 MatchResult(Location loc, const PDLByteCodePattern &pattern,
136 PatternBenefit benefit)
137 : location(loc), pattern(&pattern), benefit(benefit) {}
138 MatchResult(const MatchResult &) = delete;
139 MatchResult &operator=(const MatchResult &) = delete;
140 MatchResult(MatchResult &&other) = default;
141 MatchResult &operator=(MatchResult &&) = default;
142
143 /// The location of operations to be replaced.
144 Location location;
145 /// Memory values defined in the matcher that are passed to the rewriter.
146 SmallVector<const void *> values;
147 /// Memory used for the range input values.
148 SmallVector<TypeRange, 0> typeRangeValues;
149 SmallVector<ValueRange, 0> valueRangeValues;
150
151 /// The originating pattern that was matched. This is always non-null, but
152 /// represented with a pointer to allow for assignment.
153 const PDLByteCodePattern *pattern;
154 /// The current benefit of the pattern that was matched.
155 PatternBenefit benefit;
156 };
157
158 /// Create a ByteCode instance from the given module containing operations in
159 /// the PDL interpreter dialect.
160 PDLByteCode(ModuleOp module,
161 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
162 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
163 llvm::StringMap<PDLConstraintFunction> constraintFns,
164 llvm::StringMap<PDLRewriteFunction> rewriteFns);
165
166 /// Return the patterns held by the bytecode.
167 ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
168
169 /// Initialize the given state such that it can be used to execute the current
170 /// bytecode.
171 void initializeMutableState(PDLByteCodeMutableState &state) const;
172
173 /// Run the pattern matcher on the given root operation, collecting the
174 /// matched patterns in `matches`.
175 void match(Operation *op, PatternRewriter &rewriter,
176 SmallVectorImpl<MatchResult> &matches,
177 PDLByteCodeMutableState &state) const;
178
179 /// Run the rewriter of the given pattern that was previously matched in
180 /// `match`. Returns if a failure was encountered during the rewrite.
181 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
182 PDLByteCodeMutableState &state) const;
183
184private:
185 /// Execute the given byte code starting at the provided instruction `inst`.
186 /// `matches` is an optional field provided when this function is executed in
187 /// a matching context.
188 void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
189 PDLByteCodeMutableState &state,
190 SmallVectorImpl<MatchResult> *matches) const;
191
192 /// The set of pattern configs referenced within the bytecode.
193 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
194
195 /// A vector containing pointers to uniqued data. The storage is intentionally
196 /// opaque such that we can store a wide range of data types. The types of
197 /// data stored here include:
198 /// * Attribute, OperationName, Type
199 std::vector<const void *> uniquedData;
200
201 /// A vector containing the generated bytecode for the matcher.
202 SmallVector<ByteCodeField, 64> matcherByteCode;
203
204 /// A vector containing the generated bytecode for all of the rewriters.
205 SmallVector<ByteCodeField, 64> rewriterByteCode;
206
207 /// The set of patterns contained within the bytecode.
208 SmallVector<PDLByteCodePattern, 32> patterns;
209
210 /// A set of user defined functions invoked via PDL.
211 std::vector<PDLConstraintFunction> constraintFunctions;
212 std::vector<PDLRewriteFunction> rewriteFunctions;
213
214 /// The maximum memory index used by a value.
215 ByteCodeField maxValueMemoryIndex = 0;
216
217 /// The maximum number of different types of ranges.
218 ByteCodeField maxOpRangeCount = 0;
219 ByteCodeField maxTypeRangeCount = 0;
220 ByteCodeField maxValueRangeCount = 0;
221
222 /// The maximum number of nested loops.
223 ByteCodeField maxLoopLevel = 0;
224};
225
226} // namespace detail
227} // namespace mlir
228
229#else
230
231namespace mlir::detail {
232
233class PDLByteCodeMutableState {
234public:
235 void cleanupAfterMatchAndRewrite() {}
236 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {}
237};
238
239class PDLByteCodePattern : public Pattern {};
240
241class PDLByteCode {
242public:
243 struct MatchResult {
244 const PDLByteCodePattern *pattern = nullptr;
245 PatternBenefit benefit;
246 };
247
248 void initializeMutableState(PDLByteCodeMutableState &state) const {}
249 void match(Operation *op, PatternRewriter &rewriter,
250 SmallVectorImpl<MatchResult> &matches,
251 PDLByteCodeMutableState &state) const {}
252 LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
253 PDLByteCodeMutableState &state) const {
254 return failure();
255 }
256 ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; }
257};
258
259} // namespace mlir::detail
260
261#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
262
263#endif // MLIR_REWRITE_BYTECODE_H_
264

source code of mlir/lib/Rewrite/ByteCode.h