1 | //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares a byte-code and interpreter for pattern rewrites in MLIR. |
10 | // The byte-code is constructed from the PDL Interpreter dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_REWRITE_BYTECODE_H_ |
15 | #define MLIR_REWRITE_BYTECODE_H_ |
16 | |
17 | #include "mlir/IR/PatternMatch.h" |
18 | |
19 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
20 | |
21 | namespace mlir { |
// Forward declaration of the PDL interpreter operation that
// `PDLByteCodePattern::create` consumes; only a pointer/handle-level
// declaration is needed by this header.
namespace pdl_interp {
class RecordMatchOp;
} // namespace pdl_interp
25 | |
26 | namespace detail { |
// Forward declaration: `PDLByteCodeMutableState` names `PDLByteCode` as a
// friend before the full class definition appears later in this header.
class PDLByteCode;

/// Use generic bytecode types. `ByteCodeField` is the type of an individual
/// bytecode entry. `ByteCodeAddr` is the type used for addresses (indices)
/// into the bytecode, e.g. the address of a pattern's rewriter.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
/// An owning array of operation pointers, used to hold ranges of operations
/// by value in the interpreter's mutable state.
using OwningOpRange = llvm::OwningArrayRef<Operation *>;
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // PDLByteCodePattern |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | /// All of the data pertaining to a specific pattern within the bytecode. |
40 | class PDLByteCodePattern : public Pattern { |
41 | public: |
42 | static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, |
43 | PDLPatternConfigSet *configSet, |
44 | ByteCodeAddr rewriterAddr); |
45 | |
46 | /// Return the bytecode address of the rewriter for this pattern. |
47 | ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } |
48 | |
49 | /// Return the configuration set for this pattern, or null if there is none. |
50 | PDLPatternConfigSet *getConfigSet() const { return configSet; } |
51 | |
52 | private: |
53 | template <typename... Args> |
54 | PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet, |
55 | Args &&...patternArgs) |
56 | : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr), |
57 | configSet(configSet) {} |
58 | |
59 | /// The address of the rewriter for this pattern. |
60 | ByteCodeAddr rewriterAddr; |
61 | |
62 | /// The optional config set for this pattern. |
63 | PDLPatternConfigSet *configSet; |
64 | }; |
65 | |
66 | //===----------------------------------------------------------------------===// |
67 | // PDLByteCodeMutableState |
68 | //===----------------------------------------------------------------------===// |
69 | |
70 | /// This class contains the mutable state of a bytecode instance. This allows |
71 | /// for a bytecode instance to be cached and reused across various different |
72 | /// threads/drivers. |
73 | class PDLByteCodeMutableState { |
74 | public: |
75 | /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds |
76 | /// to the position of the pattern within the range returned by |
77 | /// `PDLByteCode::getPatterns`. |
78 | void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); |
79 | |
80 | /// Cleanup any allocated state after a match/rewrite has been completed. This |
81 | /// method should be called irregardless of whether the match+rewrite was a |
82 | /// success or not. |
83 | void cleanupAfterMatchAndRewrite(); |
84 | |
85 | private: |
86 | /// Allow access to data fields. |
87 | friend class PDLByteCode; |
88 | |
89 | /// The mutable block of memory used during the matching and rewriting phases |
90 | /// of the bytecode. |
91 | std::vector<const void *> memory; |
92 | |
93 | /// A mutable block of memory used during the matching and rewriting phase of |
94 | /// the bytecode to store ranges of operations. These are always stored by |
95 | /// owning references, because at no point in the execution of the byte code |
96 | /// we get an indexed range (view) of operations. |
97 | std::vector<OwningOpRange> opRangeMemory; |
98 | |
99 | /// A mutable block of memory used during the matching and rewriting phase of |
100 | /// the bytecode to store ranges of types. |
101 | std::vector<TypeRange> typeRangeMemory; |
102 | /// A set of type ranges that have been allocated by the byte code interpreter |
103 | /// to provide a guaranteed lifetime. |
104 | std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; |
105 | |
106 | /// A mutable block of memory used during the matching and rewriting phase of |
107 | /// the bytecode to store ranges of values. |
108 | std::vector<ValueRange> valueRangeMemory; |
109 | /// A set of value ranges that have been allocated by the byte code |
110 | /// interpreter to provide a guaranteed lifetime. |
111 | std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; |
112 | |
113 | /// The current index of ranges being iterated over for each level of nesting. |
114 | /// These are always maintained at 0 for the loops that are not active, so we |
115 | /// do not need to have a separate initialization phase for each loop. |
116 | std::vector<unsigned> loopIndex; |
117 | |
118 | /// The up-to-date benefits of the patterns held by the bytecode. The order |
119 | /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. |
120 | std::vector<PatternBenefit> currentPatternBenefits; |
121 | }; |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // PDLByteCode |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | /// The bytecode class is also the interpreter. Contains the bytecode itself, |
128 | /// the static info, addresses of the rewriter functions, the interpreter |
129 | /// memory buffer, and the execution context. |
130 | class PDLByteCode { |
131 | public: |
132 | /// Each successful match returns a MatchResult, which contains information |
133 | /// necessary to execute the rewriter and indicates the originating pattern. |
134 | struct MatchResult { |
135 | MatchResult(Location loc, const PDLByteCodePattern &pattern, |
136 | PatternBenefit benefit) |
137 | : location(loc), pattern(&pattern), benefit(benefit) {} |
138 | MatchResult(const MatchResult &) = delete; |
139 | MatchResult &operator=(const MatchResult &) = delete; |
140 | MatchResult(MatchResult &&other) = default; |
141 | MatchResult &operator=(MatchResult &&) = default; |
142 | |
143 | /// The location of operations to be replaced. |
144 | Location location; |
145 | /// Memory values defined in the matcher that are passed to the rewriter. |
146 | SmallVector<const void *> values; |
147 | /// Memory used for the range input values. |
148 | SmallVector<TypeRange, 0> typeRangeValues; |
149 | SmallVector<ValueRange, 0> valueRangeValues; |
150 | |
151 | /// The originating pattern that was matched. This is always non-null, but |
152 | /// represented with a pointer to allow for assignment. |
153 | const PDLByteCodePattern *pattern; |
154 | /// The current benefit of the pattern that was matched. |
155 | PatternBenefit benefit; |
156 | }; |
157 | |
158 | /// Create a ByteCode instance from the given module containing operations in |
159 | /// the PDL interpreter dialect. |
160 | PDLByteCode(ModuleOp module, |
161 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, |
162 | const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, |
163 | llvm::StringMap<PDLConstraintFunction> constraintFns, |
164 | llvm::StringMap<PDLRewriteFunction> rewriteFns); |
165 | |
166 | /// Return the patterns held by the bytecode. |
167 | ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } |
168 | |
169 | /// Initialize the given state such that it can be used to execute the current |
170 | /// bytecode. |
171 | void initializeMutableState(PDLByteCodeMutableState &state) const; |
172 | |
173 | /// Run the pattern matcher on the given root operation, collecting the |
174 | /// matched patterns in `matches`. |
175 | void match(Operation *op, PatternRewriter &rewriter, |
176 | SmallVectorImpl<MatchResult> &matches, |
177 | PDLByteCodeMutableState &state) const; |
178 | |
179 | /// Run the rewriter of the given pattern that was previously matched in |
180 | /// `match`. Returns if a failure was encountered during the rewrite. |
181 | LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, |
182 | PDLByteCodeMutableState &state) const; |
183 | |
184 | private: |
185 | /// Execute the given byte code starting at the provided instruction `inst`. |
186 | /// `matches` is an optional field provided when this function is executed in |
187 | /// a matching context. |
188 | void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, |
189 | PDLByteCodeMutableState &state, |
190 | SmallVectorImpl<MatchResult> *matches) const; |
191 | |
192 | /// The set of pattern configs referenced within the bytecode. |
193 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs; |
194 | |
195 | /// A vector containing pointers to uniqued data. The storage is intentionally |
196 | /// opaque such that we can store a wide range of data types. The types of |
197 | /// data stored here include: |
198 | /// * Attribute, OperationName, Type |
199 | std::vector<const void *> uniquedData; |
200 | |
201 | /// A vector containing the generated bytecode for the matcher. |
202 | SmallVector<ByteCodeField, 64> matcherByteCode; |
203 | |
204 | /// A vector containing the generated bytecode for all of the rewriters. |
205 | SmallVector<ByteCodeField, 64> rewriterByteCode; |
206 | |
207 | /// The set of patterns contained within the bytecode. |
208 | SmallVector<PDLByteCodePattern, 32> patterns; |
209 | |
210 | /// A set of user defined functions invoked via PDL. |
211 | std::vector<PDLConstraintFunction> constraintFunctions; |
212 | std::vector<PDLRewriteFunction> rewriteFunctions; |
213 | |
214 | /// The maximum memory index used by a value. |
215 | ByteCodeField maxValueMemoryIndex = 0; |
216 | |
217 | /// The maximum number of different types of ranges. |
218 | ByteCodeField maxOpRangeCount = 0; |
219 | ByteCodeField maxTypeRangeCount = 0; |
220 | ByteCodeField maxValueRangeCount = 0; |
221 | |
222 | /// The maximum number of nested loops. |
223 | ByteCodeField maxLoopLevel = 0; |
224 | }; |
225 | |
226 | } // namespace detail |
227 | } // namespace mlir |
228 | |
229 | #else |
230 | |
231 | namespace mlir::detail { |
232 | |
233 | class PDLByteCodeMutableState { |
234 | public: |
235 | void cleanupAfterMatchAndRewrite() {} |
236 | void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {} |
237 | }; |
238 | |
/// Empty stand-in for a bytecode pattern when PDL support in pattern matching
/// is compiled out; no instances are produced by the stub `PDLByteCode`.
class PDLByteCodePattern : public Pattern {};
240 | |
241 | class PDLByteCode { |
242 | public: |
243 | struct MatchResult { |
244 | const PDLByteCodePattern *pattern = nullptr; |
245 | PatternBenefit benefit; |
246 | }; |
247 | |
248 | void initializeMutableState(PDLByteCodeMutableState &state) const {} |
249 | void match(Operation *op, PatternRewriter &rewriter, |
250 | SmallVectorImpl<MatchResult> &matches, |
251 | PDLByteCodeMutableState &state) const {} |
252 | LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, |
253 | PDLByteCodeMutableState &state) const { |
254 | return failure(); |
255 | } |
256 | ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; } |
257 | }; |
258 | |
259 | } // namespace mlir::detail |
260 | |
261 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
262 | |
263 | #endif // MLIR_REWRITE_BYTECODE_H_ |
264 | |