| 1 | //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file declares a byte-code and interpreter for pattern rewrites in MLIR. |
| 10 | // The byte-code is constructed from the PDL Interpreter dialect. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef MLIR_REWRITE_BYTECODE_H_ |
| 15 | #define MLIR_REWRITE_BYTECODE_H_ |
| 16 | |
| 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | |
| 19 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| 20 | |
| 21 | namespace mlir { |
namespace pdl_interp {
// Forward declaration only: this header names the op type in the signature of
// `detail::PDLByteCodePattern::create` and nowhere needs its definition.
class RecordMatchOp;
} // namespace pdl_interp
| 25 | |
| 26 | namespace detail { |
| 27 | class PDLByteCode; |
| 28 | |
| 29 | /// Use generic bytecode types. ByteCodeField refers to the actual bytecode |
| 30 | /// entries. ByteCodeAddr refers to size of indices into the bytecode. |
| 31 | using ByteCodeField = uint16_t; |
| 32 | using ByteCodeAddr = uint32_t; |
| 33 | using OwningOpRange = llvm::OwningArrayRef<Operation *>; |
| 34 | |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | // PDLByteCodePattern |
| 37 | //===----------------------------------------------------------------------===// |
| 38 | |
| 39 | /// All of the data pertaining to a specific pattern within the bytecode. |
| 40 | class PDLByteCodePattern : public Pattern { |
| 41 | public: |
| 42 | static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, |
| 43 | PDLPatternConfigSet *configSet, |
| 44 | ByteCodeAddr rewriterAddr); |
| 45 | |
| 46 | /// Return the bytecode address of the rewriter for this pattern. |
| 47 | ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } |
| 48 | |
| 49 | /// Return the configuration set for this pattern, or null if there is none. |
| 50 | PDLPatternConfigSet *getConfigSet() const { return configSet; } |
| 51 | |
| 52 | private: |
| 53 | template <typename... Args> |
| 54 | PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet, |
| 55 | Args &&...patternArgs) |
| 56 | : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr), |
| 57 | configSet(configSet) {} |
| 58 | |
| 59 | /// The address of the rewriter for this pattern. |
| 60 | ByteCodeAddr rewriterAddr; |
| 61 | |
| 62 | /// The optional config set for this pattern. |
| 63 | PDLPatternConfigSet *configSet; |
| 64 | }; |
| 65 | |
| 66 | //===----------------------------------------------------------------------===// |
| 67 | // PDLByteCodeMutableState |
| 68 | //===----------------------------------------------------------------------===// |
| 69 | |
| 70 | /// This class contains the mutable state of a bytecode instance. This allows |
| 71 | /// for a bytecode instance to be cached and reused across various different |
| 72 | /// threads/drivers. |
| 73 | class PDLByteCodeMutableState { |
| 74 | public: |
| 75 | /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds |
| 76 | /// to the position of the pattern within the range returned by |
| 77 | /// `PDLByteCode::getPatterns`. |
| 78 | void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); |
| 79 | |
| 80 | /// Cleanup any allocated state after a match/rewrite has been completed. This |
| 81 | /// method should be called irregardless of whether the match+rewrite was a |
| 82 | /// success or not. |
| 83 | void cleanupAfterMatchAndRewrite(); |
| 84 | |
| 85 | private: |
| 86 | /// Allow access to data fields. |
| 87 | friend class PDLByteCode; |
| 88 | |
| 89 | /// The mutable block of memory used during the matching and rewriting phases |
| 90 | /// of the bytecode. |
| 91 | std::vector<const void *> memory; |
| 92 | |
| 93 | /// A mutable block of memory used during the matching and rewriting phase of |
| 94 | /// the bytecode to store ranges of operations. These are always stored by |
| 95 | /// owning references, because at no point in the execution of the byte code |
| 96 | /// we get an indexed range (view) of operations. |
| 97 | std::vector<OwningOpRange> opRangeMemory; |
| 98 | |
| 99 | /// A mutable block of memory used during the matching and rewriting phase of |
| 100 | /// the bytecode to store ranges of types. |
| 101 | std::vector<TypeRange> typeRangeMemory; |
| 102 | /// A set of type ranges that have been allocated by the byte code interpreter |
| 103 | /// to provide a guaranteed lifetime. |
| 104 | std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; |
| 105 | |
| 106 | /// A mutable block of memory used during the matching and rewriting phase of |
| 107 | /// the bytecode to store ranges of values. |
| 108 | std::vector<ValueRange> valueRangeMemory; |
| 109 | /// A set of value ranges that have been allocated by the byte code |
| 110 | /// interpreter to provide a guaranteed lifetime. |
| 111 | std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; |
| 112 | |
| 113 | /// The current index of ranges being iterated over for each level of nesting. |
| 114 | /// These are always maintained at 0 for the loops that are not active, so we |
| 115 | /// do not need to have a separate initialization phase for each loop. |
| 116 | std::vector<unsigned> loopIndex; |
| 117 | |
| 118 | /// The up-to-date benefits of the patterns held by the bytecode. The order |
| 119 | /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. |
| 120 | std::vector<PatternBenefit> currentPatternBenefits; |
| 121 | }; |
| 122 | |
| 123 | //===----------------------------------------------------------------------===// |
| 124 | // PDLByteCode |
| 125 | //===----------------------------------------------------------------------===// |
| 126 | |
| 127 | /// The bytecode class is also the interpreter. Contains the bytecode itself, |
| 128 | /// the static info, addresses of the rewriter functions, the interpreter |
| 129 | /// memory buffer, and the execution context. |
| 130 | class PDLByteCode { |
| 131 | public: |
| 132 | /// Each successful match returns a MatchResult, which contains information |
| 133 | /// necessary to execute the rewriter and indicates the originating pattern. |
| 134 | struct MatchResult { |
| 135 | MatchResult(Location loc, const PDLByteCodePattern &pattern, |
| 136 | PatternBenefit benefit) |
| 137 | : location(loc), pattern(&pattern), benefit(benefit) {} |
| 138 | MatchResult(const MatchResult &) = delete; |
| 139 | MatchResult &operator=(const MatchResult &) = delete; |
| 140 | MatchResult(MatchResult &&other) = default; |
| 141 | MatchResult &operator=(MatchResult &&) = default; |
| 142 | |
| 143 | /// The location of operations to be replaced. |
| 144 | Location location; |
| 145 | /// Memory values defined in the matcher that are passed to the rewriter. |
| 146 | SmallVector<const void *> values; |
| 147 | /// Memory used for the range input values. |
| 148 | SmallVector<TypeRange, 0> typeRangeValues; |
| 149 | SmallVector<ValueRange, 0> valueRangeValues; |
| 150 | |
| 151 | /// The originating pattern that was matched. This is always non-null, but |
| 152 | /// represented with a pointer to allow for assignment. |
| 153 | const PDLByteCodePattern *pattern; |
| 154 | /// The current benefit of the pattern that was matched. |
| 155 | PatternBenefit benefit; |
| 156 | }; |
| 157 | |
| 158 | /// Create a ByteCode instance from the given module containing operations in |
| 159 | /// the PDL interpreter dialect. |
| 160 | PDLByteCode(ModuleOp module, |
| 161 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, |
| 162 | const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, |
| 163 | llvm::StringMap<PDLConstraintFunction> constraintFns, |
| 164 | llvm::StringMap<PDLRewriteFunction> rewriteFns); |
| 165 | |
| 166 | /// Return the patterns held by the bytecode. |
| 167 | ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } |
| 168 | |
| 169 | /// Initialize the given state such that it can be used to execute the current |
| 170 | /// bytecode. |
| 171 | void initializeMutableState(PDLByteCodeMutableState &state) const; |
| 172 | |
| 173 | /// Run the pattern matcher on the given root operation, collecting the |
| 174 | /// matched patterns in `matches`. |
| 175 | void match(Operation *op, PatternRewriter &rewriter, |
| 176 | SmallVectorImpl<MatchResult> &matches, |
| 177 | PDLByteCodeMutableState &state) const; |
| 178 | |
| 179 | /// Run the rewriter of the given pattern that was previously matched in |
| 180 | /// `match`. Returns if a failure was encountered during the rewrite. |
| 181 | LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, |
| 182 | PDLByteCodeMutableState &state) const; |
| 183 | |
| 184 | private: |
| 185 | /// Execute the given byte code starting at the provided instruction `inst`. |
| 186 | /// `matches` is an optional field provided when this function is executed in |
| 187 | /// a matching context. |
| 188 | void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, |
| 189 | PDLByteCodeMutableState &state, |
| 190 | SmallVectorImpl<MatchResult> *matches) const; |
| 191 | |
| 192 | /// The set of pattern configs referenced within the bytecode. |
| 193 | SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs; |
| 194 | |
| 195 | /// A vector containing pointers to uniqued data. The storage is intentionally |
| 196 | /// opaque such that we can store a wide range of data types. The types of |
| 197 | /// data stored here include: |
| 198 | /// * Attribute, OperationName, Type |
| 199 | std::vector<const void *> uniquedData; |
| 200 | |
| 201 | /// A vector containing the generated bytecode for the matcher. |
| 202 | SmallVector<ByteCodeField, 64> matcherByteCode; |
| 203 | |
| 204 | /// A vector containing the generated bytecode for all of the rewriters. |
| 205 | SmallVector<ByteCodeField, 64> rewriterByteCode; |
| 206 | |
| 207 | /// The set of patterns contained within the bytecode. |
| 208 | SmallVector<PDLByteCodePattern, 32> patterns; |
| 209 | |
| 210 | /// A set of user defined functions invoked via PDL. |
| 211 | std::vector<PDLConstraintFunction> constraintFunctions; |
| 212 | std::vector<PDLRewriteFunction> rewriteFunctions; |
| 213 | |
| 214 | /// The maximum memory index used by a value. |
| 215 | ByteCodeField maxValueMemoryIndex = 0; |
| 216 | |
| 217 | /// The maximum number of different types of ranges. |
| 218 | ByteCodeField maxOpRangeCount = 0; |
| 219 | ByteCodeField maxTypeRangeCount = 0; |
| 220 | ByteCodeField maxValueRangeCount = 0; |
| 221 | |
| 222 | /// The maximum number of nested loops. |
| 223 | ByteCodeField maxLoopLevel = 0; |
| 224 | }; |
| 225 | |
| 226 | } // namespace detail |
| 227 | } // namespace mlir |
| 228 | |
| 229 | #else |
| 230 | |
| 231 | namespace mlir::detail { |
| 232 | |
| 233 | class PDLByteCodeMutableState { |
| 234 | public: |
| 235 | void cleanupAfterMatchAndRewrite() {} |
| 236 | void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {} |
| 237 | }; |
| 238 | |
| 239 | class PDLByteCodePattern : public Pattern {}; |
| 240 | |
| 241 | class PDLByteCode { |
| 242 | public: |
| 243 | struct MatchResult { |
| 244 | const PDLByteCodePattern *pattern = nullptr; |
| 245 | PatternBenefit benefit; |
| 246 | }; |
| 247 | |
| 248 | void initializeMutableState(PDLByteCodeMutableState &state) const {} |
| 249 | void match(Operation *op, PatternRewriter &rewriter, |
| 250 | SmallVectorImpl<MatchResult> &matches, |
| 251 | PDLByteCodeMutableState &state) const {} |
| 252 | LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, |
| 253 | PDLByteCodeMutableState &state) const { |
| 254 | return failure(); |
| 255 | } |
| 256 | ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; } |
| 257 | }; |
| 258 | |
| 259 | } // namespace mlir::detail |
| 260 | |
| 261 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| 262 | |
| 263 | #endif // MLIR_REWRITE_BYTECODE_H_ |
| 264 | |