1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
14#include "mlir/Analysis/Liveness.h"
15#include "mlir/Dialect/PDL/IR/PDLTypes.h"
16#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/RegionGraphTraits.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/Format.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <numeric>
26#include <optional>
27
28#define DEBUG_TYPE "pdl-bytecode"
29
30using namespace mlir;
31using namespace mlir::detail;
32
33//===----------------------------------------------------------------------===//
34// PDLByteCodePattern
35//===----------------------------------------------------------------------===//
36
37PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
40 PatternBenefit benefit = matchOp.getBenefit();
41 MLIRContext *ctx = matchOp.getContext();
42
43 // Collect the set of generated operations.
44 SmallVector<StringRef, 8> generatedOps;
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46 generatedOps =
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48
49 // Check to see if this is pattern matches a specific operation type.
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52 generatedOps);
53 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54 benefit, ctx, generatedOps);
55}
56
57//===----------------------------------------------------------------------===//
58// PDLByteCodeMutableState
59//===----------------------------------------------------------------------===//
60
61/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62/// to the position of the pattern within the range returned by
63/// `PDLByteCode::getPatterns`.
64void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65 PatternBenefit benefit) {
66 currentPatternBenefits[patternIndex] = benefit;
67}
68
69/// Cleanup any allocated state after a full match/rewrite has been completed.
70/// This method should be called irregardless of whether the match+rewrite was a
71/// success or not.
72void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
75}
76
77//===----------------------------------------------------------------------===//
78// Bytecode OpCodes
79//===----------------------------------------------------------------------===//
80
namespace {
/// The opcodes of the PDL bytecode. Each instruction written by the generator
/// starts with one of these values (see `ByteCodeWriter::append(OpCode)`),
/// followed by that instruction's operands.
///
/// NOTE(review): the numeric value of each enumerator is its position in this
/// list, so reordering changes the encoding. The numbered variants
/// (GetOperand0..3, GetResult0..3) are presumably selected by offsetting from
/// the `0` variant; keep them contiguous and in order -- confirm against the
/// generator and interpreter code.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. The numbered forms name operands
  /// 0-3 directly; `GetOperandN` covers any other index.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. The numbered forms name results
  /// 0-3 directly; `GetResultN` covers any other index.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171
172/// A marker used to indicate if an operation should infer types.
173static constexpr ByteCodeField kInferTypesMarker =
174 std::numeric_limits<ByteCodeField>::max();
175
176//===----------------------------------------------------------------------===//
177// ByteCode Generation
178//===----------------------------------------------------------------------===//
179
180//===----------------------------------------------------------------------===//
181// Generator
182
183namespace {
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection alias: names the result type of `T::getAsOpaquePointer()`, and is
/// ill-formed (so detection fails) when `T` cannot be turned into an opaque
/// pointer. The trailing pack is accepted but ignored.
template <typename ValueT, typename... IgnoredArgs>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
190
/// This class represents the main generator for the pattern bytecode.
///
/// The generator does not own its outputs: every reference passed to the
/// constructor (bytecode buffers, uniqued data, patterns, and the various
/// `max*` sizes) is a slot of the ByteCode class that this generator populates.
class Generator {
public:
  /// Bind the generator to the output slots it will populate. Each externally
  /// registered constraint and rewrite function is assigned, as its bytecode
  /// index, its position when iterating `constraintFns` / `rewriteFns`.
  // NOTE(review): StringMap iteration order is unspecified; the indices are
  // presumably valid because the same maps are iterated the same way wherever
  // the functions are looked up by index -- confirm.
  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
            SmallVectorImpl<ByteCodeField> &matcherByteCode,
            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
            SmallVectorImpl<PDLByteCodePattern> &patterns,
            ByteCodeField &maxValueMemoryIndex,
            ByteCodeField &maxOpRangeMemoryIndex,
            ByteCodeField &maxTypeRangeMemoryIndex,
            ByteCodeField &maxValueRangeMemoryIndex,
            ByteCodeField &maxLoopLevel,
            llvm::StringMap<PDLConstraintFunction> &constraintFns,
            llvm::StringMap<PDLRewriteFunction> &rewriteFns,
            const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
        rewriterByteCode(rewriterByteCode), patterns(patterns),
        maxValueMemoryIndex(maxValueMemoryIndex),
        maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
        maxLoopLevel(maxLoopLevel), configMap(configMap) {
    for (const auto &it : llvm::enumerate(constraintFns))
      constraintToMemIndex.try_emplace(it.value().first(), it.index());
    for (const auto &it : llvm::enumerate(rewriteFns))
      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
  }

  /// Generate the bytecode for the given PDL interpreter module.
  void generate(ModuleOp module);

  /// Return the memory index to use for the given value. The index must
  /// already have been assigned (asserted in debug builds).
  ByteCodeField &getMemIndex(Value value) {
    assert(valueToMemIndex.count(value) &&
           "expected memory index to be assigned");
    return valueToMemIndex[value];
  }

  /// Return the range memory index used to store the given range value. The
  /// index must already have been assigned (asserted in debug builds).
  ByteCodeField &getRangeStorageIndex(Value value) {
    assert(valueToRangeIndex.count(value) &&
           "expected range index to be assigned");
    return valueToRangeIndex[value];
  }

  /// Return an index to use when referring to the given data that is uniqued in
  /// the MLIR context.
  ///
  /// Uniqued data is numbered after all value memory slots: the first distinct
  /// object gets `maxValueMemoryIndex`, the next `maxValueMemoryIndex + 1`, and
  /// so on. Repeated requests for the same object return the same index, and
  /// each new object is appended to `uniquedData`. This therefore relies on
  /// `maxValueMemoryIndex` being final before the first call.
  // NOTE(review): the sum is stored in a ByteCodeField, so it would wrap
  // silently if the index space were exhausted -- confirm this cannot happen.
  template <typename T>
  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
  getMemIndex(T val) {
    const void *opaqueVal = val.getAsOpaquePointer();

    // Get or insert a reference to this value.
    auto it = uniquedDataToMemIndex.try_emplace(
        opaqueVal, maxValueMemoryIndex + uniquedData.size());
    if (it.second)
      uniquedData.push_back(opaqueVal);
    return it.first->second;
  }

private:
  /// Allocate memory indices for the results of operations within the matcher
  /// and rewriters.
  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
                             ModuleOp rewriterModule);

  /// Generate the bytecode for the given operation.
  void generate(Region *region, ByteCodeWriter &writer);
  void generate(Operation *op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);

  /// Mapping from value to its corresponding memory index.
  DenseMap<Value, ByteCodeField> valueToMemIndex;

  /// Mapping from a range value to its corresponding range storage index.
  DenseMap<Value, ByteCodeField> valueToRangeIndex;

  /// Mapping from the name of an externally registered rewrite to its index in
  /// the bytecode registry.
  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;

  /// Mapping from the name of an externally registered constraint to its index
  /// in the bytecode registry.
  llvm::StringMap<ByteCodeField> constraintToMemIndex;

  /// Mapping from rewriter function name to the bytecode address of the
  /// rewriter function in byte.
  llvm::StringMap<ByteCodeAddr> rewriterToAddr;

  /// Mapping from a uniqued storage object to its memory index within
  /// `uniquedData`.
  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;

  /// The current level of the foreach loop.
  ByteCodeField curLoopLevel = 0;

  /// The current MLIR context.
  MLIRContext *ctx;

  /// Mapping from block to its address.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;

  /// Data of the ByteCode class to be populated.
  /// (Members below are references; they are initialized in declaration
  /// order, which the constructor's initializer list follows.)
  std::vector<const void *> &uniquedData;
  SmallVectorImpl<ByteCodeField> &matcherByteCode;
  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
  SmallVectorImpl<PDLByteCodePattern> &patterns;
  ByteCodeField &maxValueMemoryIndex;
  ByteCodeField &maxOpRangeMemoryIndex;
  ByteCodeField &maxTypeRangeMemoryIndex;
  ByteCodeField &maxValueRangeMemoryIndex;
  ByteCodeField &maxLoopLevel;

  /// A map of pattern configurations.
  const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
};
344
/// This class provides utilities for writing a bytecode stream.
///
/// Fields are appended to `bytecode`. Values that live in memory slots are
/// written as the slot index obtained from `generator.getMemIndex`. Successor
/// blocks are written as zero placeholder addresses whose offsets are recorded
/// in `unresolvedSuccessorRefs`, so that the owner can patch them later.
struct ByteCodeWriter {
  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
      : bytecode(bytecode), generator(generator) {}

  /// Append a field to the bytecode.
  void append(ByteCodeField field) { bytecode.push_back(field); }
  void append(OpCode opCode) { bytecode.push_back(opCode); }

  /// Append an address to the bytecode. The address occupies two consecutive
  /// fields, produced by copying its bytes as laid out in memory.
  void append(ByteCodeAddr field) {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");

    ByteCodeField fieldParts[2];
    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
    bytecode.append({fieldParts[0], fieldParts[1]});
  }

  /// Append a single successor to the bytecode, the exact address will need to
  /// be resolved later.
  void append(Block *successor) {
    // Add back a reference to the successor so that the address can be resolved
    // later.
    unresolvedSuccessorRefs[successor].push_back(bytecode.size());
    append(ByteCodeAddr(0));
  }

  /// Append a successor range to the bytecode, the exact address will need to
  /// be resolved later.
  void append(SuccessorRange successors) {
    for (Block *successor : successors)
      append(successor);
  }

  /// Append a range of values that will be read as generic PDLValues: a count
  /// field followed by a (kind, value) pair for each element.
  void appendPDLValueList(OperandRange values) {
    bytecode.push_back(values.size());
    for (Value value : values)
      appendPDLValue(value);
  }

  /// Append a value as a PDLValue.
  void appendPDLValue(Value value) {
    appendPDLValueKind(value);
    append(value);
  }

  /// Append the PDLValue::Kind of the given value.
  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }

  /// Append the PDLValue::Kind of the given type. Ranges of `pdl.type` map to
  /// TypeRange; every other range maps to ValueRange.
  void appendPDLValueKind(Type type) {
    PDLValue::Kind kind =
        TypeSwitch<Type, PDLValue::Kind>(type)
            .Case<pdl::AttributeType>(
                [](Type) { return PDLValue::Kind::Attribute; })
            .Case<pdl::OperationType>(
                [](Type) { return PDLValue::Kind::Operation; })
            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
              if (isa<pdl::TypeType>(rangeTy.getElementType()))
                return PDLValue::Kind::TypeRange;
              return PDLValue::Kind::ValueRange;
            })
            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
    bytecode.push_back(static_cast<ByteCodeField>(kind));
  }

  /// Append a value that will be stored in a memory slot and not inline within
  /// the bytecode. Enabled for pointers and for types that provide
  /// `getAsOpaquePointer()`.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
                   std::is_pointer<T>::value>
  append(T value) {
    bytecode.push_back(generator.getMemIndex(value));
  }

  /// Append a range of values: the element count, followed by each element
  /// appended through the matching `append` overload.
  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
  append(T range) {
    bytecode.push_back(llvm::size(range));
    for (auto it : range)
      append(it);
  }

  /// Append a variadic number of fields to the bytecode, in argument order.
  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
    append(field);
    append(field2, fields...);
  }

  /// Appends a value as a pointer, stored inline within the bytecode. The
  /// pointer's bytes are split across `sizeof(void *) / sizeof(ByteCodeField)`
  /// fields.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  appendInline(T value) {
    constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
    const void *pointer = value.getAsOpaquePointer();
    ByteCodeField fieldParts[numParts];
    std::memcpy(fieldParts, &pointer, sizeof(const void *));
    bytecode.append(fieldParts, fieldParts + numParts);
  }

  /// Successor references in the bytecode that have yet to be resolved. Maps
  /// each block to the offsets of its placeholder addresses in `bytecode`.
  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;

  /// The underlying bytecode buffer.
  SmallVectorImpl<ByteCodeField> &bytecode;

  /// The main generator producing PDL.
  Generator &generator;
};
459
460/// This class represents a live range of PDL Interpreter values, containing
461/// information about when values are live within a match/rewrite.
462struct ByteCodeLiveRange {
463 using Set = llvm::IntervalMap<uint64_t, char, 16>;
464 using Allocator = Set::Allocator;
465
466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467
468 /// Union this live range with the one provided.
469 void unionWith(const ByteCodeLiveRange &rhs) {
470 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471 ++it)
472 liveness->insert(a: it.start(), b: it.stop(), /*dummyValue*/ y: 0);
473 }
474
475 /// Returns true if this range overlaps with the one provided.
476 bool overlaps(const ByteCodeLiveRange &rhs) const {
477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478 .valid();
479 }
480
481 /// A map representing the ranges of the match/rewrite that a value is live in
482 /// the interpreter.
483 ///
484 /// We use std::unique_ptr here, because IntervalMap does not provide a
485 /// correct copy or move constructor. We can eliminate the pointer once
486 /// https://reviews.llvm.org/D113240 lands.
487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488
489 /// The operation range storage index for this range.
490 std::optional<unsigned> opRangeIndex;
491
492 /// The type range storage index for this range.
493 std::optional<unsigned> typeRangeIndex;
494
495 /// The value range storage index for this range.
496 std::optional<unsigned> valueRangeIndex;
497};
498} // namespace
499
500void Generator::generate(ModuleOp module) {
501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506
507 // Allocate memory indices for the results of operations within the matcher
508 // and rewriters.
509 allocateMemoryIndices(matcherFunc, rewriterModule);
510
511 // Generate code for the rewriter functions.
512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515 for (Operation &op : rewriterFunc.getOps())
516 generate(&op, rewriterByteCodeWriter);
517 }
518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519 "unexpected branches in rewriter function");
520
521 // Generate code for the matcher function.
522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524
525 // Resolve successor references in the matcher.
526 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527 ByteCodeAddr addr = blockToAddr[it.first];
528 for (unsigned offsetToFix : it.second)
529 std::memcpy(dest: &matcherByteCode[offsetToFix], src: &addr, n: sizeof(ByteCodeAddr));
530 }
531}
532
533void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534 ModuleOp rewriterModule) {
535 // Rewriters use simplistic allocation scheme that simply assigns an index to
536 // each result.
537 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539 auto processRewriterValue = [&](Value val) {
540 valueToMemIndex.try_emplace(val, index++);
541 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542 Type elementTy = rangeType.getElementType();
543 if (isa<pdl::TypeType>(elementTy))
544 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545 else if (isa<pdl::ValueType>(elementTy))
546 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547 }
548 };
549
550 for (BlockArgument arg : rewriterFunc.getArguments())
551 processRewriterValue(arg);
552 rewriterFunc.getBody().walk([&](Operation *op) {
553 for (Value result : op->getResults())
554 processRewriterValue(result);
555 });
556 if (index > maxValueMemoryIndex)
557 maxValueMemoryIndex = index;
558 if (typeRangeIndex > maxTypeRangeMemoryIndex)
559 maxTypeRangeMemoryIndex = typeRangeIndex;
560 if (valueRangeIndex > maxValueRangeMemoryIndex)
561 maxValueRangeMemoryIndex = valueRangeIndex;
562 }
563
564 // The matcher function uses a more sophisticated numbering that tries to
565 // minimize the number of memory indices assigned. This is done by determining
566 // a live range of the values within the matcher, then the allocation is just
567 // finding the minimal number of overlapping live ranges. This is essentially
568 // a simplified form of register allocation where we don't necessarily have a
569 // limited number of registers, but we still want to minimize the number used.
570 DenseMap<Operation *, unsigned> opToFirstIndex;
571 DenseMap<Operation *, unsigned> opToLastIndex;
572
573 // A custom walk that marks the first and the last index of each operation.
574 // The entry marks the beginning of the liveness range for this operation,
575 // followed by nested operations, followed by the end of the liveness range.
576 unsigned index = 0;
577 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578 opToFirstIndex.try_emplace(Key: op, Args: index++);
579 for (Region &region : op->getRegions())
580 for (Block &block : region.getBlocks())
581 for (Operation &nested : block)
582 walk(&nested);
583 opToLastIndex.try_emplace(Key: op, Args: index++);
584 };
585 walk(matcherFunc);
586
587 // Liveness info for each of the defs within the matcher.
588 ByteCodeLiveRange::Allocator allocator;
589 DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590
591 // Assign the root operation being matched to slot 0.
592 BlockArgument rootOpArg = matcherFunc.getArgument(0);
593 valueToMemIndex[rootOpArg] = 0;
594
595 // Walk each of the blocks, computing the def interval that the value is used.
596 Liveness matcherLiveness(matcherFunc);
597 matcherFunc->walk([&](Block *block) {
598 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599 assert(info && "expected liveness info for block");
600 auto processValue = [&](Value value, Operation *firstUseOrDef) {
601 // We don't need to process the root op argument, this value is always
602 // assigned to the first memory slot.
603 if (value == rootOpArg)
604 return;
605
606 // Set indices for the range of this block that the value is used.
607 auto defRangeIt = valueDefRanges.try_emplace(Key: value, Args&: allocator).first;
608 defRangeIt->second.liveness->insert(
609 a: opToFirstIndex[firstUseOrDef],
610 b: opToLastIndex[info->getEndOperation(value, startOperation: firstUseOrDef)],
611 /*dummyValue*/ y: 0);
612
613 // Check to see if this value is a range type.
614 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615 Type eleType = rangeTy.getElementType();
616 if (isa<pdl::OperationType>(eleType))
617 defRangeIt->second.opRangeIndex = 0;
618 else if (isa<pdl::TypeType>(eleType))
619 defRangeIt->second.typeRangeIndex = 0;
620 else if (isa<pdl::ValueType>(eleType))
621 defRangeIt->second.valueRangeIndex = 0;
622 }
623 };
624
625 // Process the live-ins of this block.
626 for (Value liveIn : info->in()) {
627 // Only process the value if it has been defined in the current region.
628 // Other values that span across pdl_interp.foreach will be added higher
629 // up. This ensures that the we keep them alive for the entire duration
630 // of the loop.
631 if (liveIn.getParentRegion() == block->getParent())
632 processValue(liveIn, &block->front());
633 }
634
635 // Process the block arguments for the entry block (those are not live-in).
636 if (block->isEntryBlock()) {
637 for (Value argument : block->getArguments())
638 processValue(argument, &block->front());
639 }
640
641 // Process any new defs within this block.
642 for (Operation &op : *block)
643 for (Value result : op.getResults())
644 processValue(result, &op);
645 });
646
647 // Greedily allocate memory slots using the computed def live ranges.
648 std::vector<ByteCodeLiveRange> allocatedIndices;
649
650 // The number of memory indices currently allocated (and its next value).
651 // Recall that the root gets allocated memory index 0.
652 ByteCodeField numIndices = 1;
653
654 // The number of memory ranges of various types (and their next values).
655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656
657 for (auto &defIt : valueDefRanges) {
658 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659 ByteCodeLiveRange &defRange = defIt.second;
660
661 // Try to allocate to an existing index.
662 for (const auto &existingIndexIt : llvm::enumerate(First&: allocatedIndices)) {
663 ByteCodeLiveRange &existingRange = existingIndexIt.value();
664 if (!defRange.overlaps(rhs: existingRange)) {
665 existingRange.unionWith(rhs: defRange);
666 memIndex = existingIndexIt.index() + 1;
667
668 if (defRange.opRangeIndex) {
669 if (!existingRange.opRangeIndex)
670 existingRange.opRangeIndex = numOpRanges++;
671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672 } else if (defRange.typeRangeIndex) {
673 if (!existingRange.typeRangeIndex)
674 existingRange.typeRangeIndex = numTypeRanges++;
675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676 } else if (defRange.valueRangeIndex) {
677 if (!existingRange.valueRangeIndex)
678 existingRange.valueRangeIndex = numValueRanges++;
679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680 }
681 break;
682 }
683 }
684
685 // If no existing index could be used, add a new one.
686 if (memIndex == 0) {
687 allocatedIndices.emplace_back(args&: allocator);
688 ByteCodeLiveRange &newRange = allocatedIndices.back();
689 newRange.unionWith(rhs: defRange);
690
691 // Allocate an index for op/type/value ranges.
692 if (defRange.opRangeIndex) {
693 newRange.opRangeIndex = numOpRanges;
694 valueToRangeIndex[defIt.first] = numOpRanges++;
695 } else if (defRange.typeRangeIndex) {
696 newRange.typeRangeIndex = numTypeRanges;
697 valueToRangeIndex[defIt.first] = numTypeRanges++;
698 } else if (defRange.valueRangeIndex) {
699 newRange.valueRangeIndex = numValueRanges;
700 valueToRangeIndex[defIt.first] = numValueRanges++;
701 }
702
703 memIndex = allocatedIndices.size();
704 ++numIndices;
705 }
706 }
707
708 // Print the index usage and ensure that we did not run out of index space.
709 LLVM_DEBUG({
710 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711 << "(down from initial " << valueDefRanges.size() << ").\n";
712 });
713 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714 "Ran out of memory for allocated indices");
715
716 // Update the max number of indices.
717 if (numIndices > maxValueMemoryIndex)
718 maxValueMemoryIndex = numIndices;
719 if (numOpRanges > maxOpRangeMemoryIndex)
720 maxOpRangeMemoryIndex = numOpRanges;
721 if (numTypeRanges > maxTypeRangeMemoryIndex)
722 maxTypeRangeMemoryIndex = numTypeRanges;
723 if (numValueRanges > maxValueRangeMemoryIndex)
724 maxValueRangeMemoryIndex = numValueRanges;
725}
726
727void Generator::generate(Region *region, ByteCodeWriter &writer) {
728 llvm::ReversePostOrderTraversal<Region *> rpot(region);
729 for (Block *block : rpot) {
730 // Keep track of where this block begins within the matcher function.
731 blockToAddr.try_emplace(Key: block, Args: matcherByteCode.size());
732 for (Operation &op : *block)
733 generate(op: &op, writer);
734 }
735}
736
737void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738 LLVM_DEBUG({
739 // The following list must contain all the operations that do not
740 // produce any bytecode.
741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742 writer.appendInline(op->getLoc());
743 });
744 TypeSwitch<Operation *>(op)
745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763 pdl_interp::SwitchResultCountOp>(
764 [&](auto interpOp) { this->generate(interpOp, writer); })
765 .Default([](Operation *) {
766 llvm_unreachable("unknown `pdl_interp` operation");
767 });
768}
769
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 assert(constraintToMemIndex.count(op.getName()) &&
773 "expected index for constraint function");
774 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775 writer.appendPDLValueList(values: op.getArgs());
776 writer.append(field: ByteCodeField(op.getIsNegated()));
777 writer.append(op.getSuccessors());
778}
779void Generator::generate(pdl_interp::ApplyRewriteOp op,
780 ByteCodeWriter &writer) {
781 assert(externalRewriterToMemIndex.count(op.getName()) &&
782 "expected index for rewrite function");
783 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
784 writer.appendPDLValueList(values: op.getArgs());
785
786 ResultRange results = op.getResults();
787 writer.append(field: ByteCodeField(results.size()));
788 for (Value result : results) {
789 // In debug mode we also record the expected kind of the result, so that we
790 // can provide extra verification of the native rewrite function.
791#ifndef NDEBUG
792 writer.appendPDLValueKind(result);
793#endif
794
795 // Range results also need to append the range storage index.
796 if (isa<pdl::RangeType>(result.getType()))
797 writer.append(getRangeStorageIndex(result));
798 writer.append(result);
799 }
800}
801void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
802 Value lhs = op.getLhs();
803 if (isa<pdl::RangeType>(lhs.getType())) {
804 writer.append(opCode: OpCode::AreRangesEqual);
805 writer.appendPDLValueKind(value: lhs);
806 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
807 return;
808 }
809
810 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
811}
812void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
813 writer.append(field: OpCode::Branch, field2: SuccessorRange(op.getOperation()));
814}
815void Generator::generate(pdl_interp::CheckAttributeOp op,
816 ByteCodeWriter &writer) {
817 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
818 op.getSuccessors());
819}
820void Generator::generate(pdl_interp::CheckOperandCountOp op,
821 ByteCodeWriter &writer) {
822 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
823 static_cast<ByteCodeField>(op.getCompareAtLeast()),
824 op.getSuccessors());
825}
826void Generator::generate(pdl_interp::CheckOperationNameOp op,
827 ByteCodeWriter &writer) {
828 writer.append(OpCode::CheckOperationName, op.getInputOp(),
829 OperationName(op.getName(), ctx), op.getSuccessors());
830}
831void Generator::generate(pdl_interp::CheckResultCountOp op,
832 ByteCodeWriter &writer) {
833 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
834 static_cast<ByteCodeField>(op.getCompareAtLeast()),
835 op.getSuccessors());
836}
837void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
838 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
839 op.getSuccessors());
840}
841void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
842 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
843 op.getSuccessors());
844}
845void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
846 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
847 writer.append(field: OpCode::Continue, field2: ByteCodeField(curLoopLevel - 1));
848}
849void Generator::generate(pdl_interp::CreateAttributeOp op,
850 ByteCodeWriter &writer) {
851 // Simply repoint the memory index of the result to the constant.
852 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
853}
854void Generator::generate(pdl_interp::CreateOperationOp op,
855 ByteCodeWriter &writer) {
856 writer.append(OpCode::CreateOperation, op.getResultOp(),
857 OperationName(op.getName(), ctx));
858 writer.appendPDLValueList(values: op.getInputOperands());
859
860 // Add the attributes.
861 OperandRange attributes = op.getInputAttributes();
862 writer.append(field: static_cast<ByteCodeField>(attributes.size()));
863 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
864 writer.append(std::get<0>(it), std::get<1>(it));
865
866 // Add the result types. If the operation has inferred results, we use a
867 // marker "size" value. Otherwise, we add the list of explicit result types.
868 if (op.getInferredResultTypes())
869 writer.append(field: kInferTypesMarker);
870 else
871 writer.appendPDLValueList(values: op.getInputResultTypes());
872}
873void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
874 // Append the correct opcode for the range type.
875 TypeSwitch<Type>(op.getType().getElementType())
876 .Case(
877 caseFn: [&](pdl::TypeType) { writer.append(opCode: OpCode::CreateDynamicTypeRange); })
878 .Case(caseFn: [&](pdl::ValueType) {
879 writer.append(opCode: OpCode::CreateDynamicValueRange);
880 });
881
882 writer.append(op.getResult(), getRangeStorageIndex(value: op.getResult()));
883 writer.appendPDLValueList(values: op->getOperands());
884}
885void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
886 // Simply repoint the memory index of the result to the constant.
887 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
888}
889void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
890 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
891 getRangeStorageIndex(value: op.getResult()), op.getValue());
892}
893void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
894 writer.append(OpCode::EraseOp, op.getInputOp());
895}
896void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
897 OpCode opCode =
898 TypeSwitch<Type, OpCode>(op.getResult().getType())
899 .Case(caseFn: [](pdl::OperationType) { return OpCode::ExtractOp; })
900 .Case(caseFn: [](pdl::ValueType) { return OpCode::ExtractValue; })
901 .Case(caseFn: [](pdl::TypeType) { return OpCode::ExtractType; })
902 .Default(defaultFn: [](Type) -> OpCode {
903 llvm_unreachable("unsupported element type");
904 });
905 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
906}
907void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
908 writer.append(opCode: OpCode::Finalize);
909}
910void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
911 BlockArgument arg = op.getLoopVariable();
912 writer.append(OpCode::ForEach, getRangeStorageIndex(value: op.getValues()), arg);
913 writer.appendPDLValueKind(type: arg.getType());
914 writer.append(curLoopLevel, op.getSuccessor());
915 ++curLoopLevel;
916 if (curLoopLevel > maxLoopLevel)
917 maxLoopLevel = curLoopLevel;
918 generate(&op.getRegion(), writer);
919 --curLoopLevel;
920}
921void Generator::generate(pdl_interp::GetAttributeOp op,
922 ByteCodeWriter &writer) {
923 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
924 op.getNameAttr());
925}
926void Generator::generate(pdl_interp::GetAttributeTypeOp op,
927 ByteCodeWriter &writer) {
928 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
929}
930void Generator::generate(pdl_interp::GetDefiningOpOp op,
931 ByteCodeWriter &writer) {
932 writer.append(OpCode::GetDefiningOp, op.getInputOp());
933 writer.appendPDLValue(value: op.getValue());
934}
935void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
936 uint32_t index = op.getIndex();
937 if (index < 4)
938 writer.append(opCode: static_cast<OpCode>(OpCode::GetOperand0 + index));
939 else
940 writer.append(field: OpCode::GetOperandN, field2: index);
941 writer.append(op.getInputOp(), op.getValue());
942}
943void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
944 Value result = op.getValue();
945 std::optional<uint32_t> index = op.getIndex();
946 writer.append(OpCode::GetOperands,
947 index.value_or(u: std::numeric_limits<uint32_t>::max()),
948 op.getInputOp());
949 if (isa<pdl::RangeType>(result.getType()))
950 writer.append(field: getRangeStorageIndex(value: result));
951 else
952 writer.append(field: std::numeric_limits<ByteCodeField>::max());
953 writer.append(value: result);
954}
955void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
956 uint32_t index = op.getIndex();
957 if (index < 4)
958 writer.append(opCode: static_cast<OpCode>(OpCode::GetResult0 + index));
959 else
960 writer.append(field: OpCode::GetResultN, field2: index);
961 writer.append(op.getInputOp(), op.getValue());
962}
963void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
964 Value result = op.getValue();
965 std::optional<uint32_t> index = op.getIndex();
966 writer.append(OpCode::GetResults,
967 index.value_or(u: std::numeric_limits<uint32_t>::max()),
968 op.getInputOp());
969 if (isa<pdl::RangeType>(result.getType()))
970 writer.append(field: getRangeStorageIndex(value: result));
971 else
972 writer.append(field: std::numeric_limits<ByteCodeField>::max());
973 writer.append(value: result);
974}
975void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
976 Value operations = op.getOperations();
977 ByteCodeField rangeIndex = getRangeStorageIndex(value: operations);
978 writer.append(field: OpCode::GetUsers, field2: operations, fields: rangeIndex);
979 writer.appendPDLValue(value: op.getValue());
980}
981void Generator::generate(pdl_interp::GetValueTypeOp op,
982 ByteCodeWriter &writer) {
983 if (isa<pdl::RangeType>(op.getType())) {
984 Value result = op.getResult();
985 writer.append(OpCode::GetValueRangeTypes, result,
986 getRangeStorageIndex(value: result), op.getValue());
987 } else {
988 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
989 }
990}
991void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
992 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
993}
994void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
995 ByteCodeField patternIndex = patterns.size();
996 patterns.emplace_back(PDLByteCodePattern::create(
997 matchOp: op, configSet: configMap.lookup(Val: op),
998 rewriterAddr: rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
999 writer.append(OpCode::RecordMatch, patternIndex,
1000 SuccessorRange(op.getOperation()), op.getMatchedOps());
1001 writer.appendPDLValueList(values: op.getInputs());
1002}
1003void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1004 writer.append(OpCode::ReplaceOp, op.getInputOp());
1005 writer.appendPDLValueList(values: op.getReplValues());
1006}
1007void Generator::generate(pdl_interp::SwitchAttributeOp op,
1008 ByteCodeWriter &writer) {
1009 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1010 op.getCaseValuesAttr(), op.getSuccessors());
1011}
1012void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1013 ByteCodeWriter &writer) {
1014 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1015 op.getCaseValuesAttr(), op.getSuccessors());
1016}
1017void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1018 ByteCodeWriter &writer) {
1019 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1020 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1021 });
1022 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1023 op.getSuccessors());
1024}
1025void Generator::generate(pdl_interp::SwitchResultCountOp op,
1026 ByteCodeWriter &writer) {
1027 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1028 op.getCaseValuesAttr(), op.getSuccessors());
1029}
1030void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1031 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1032 op.getSuccessors());
1033}
1034void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1035 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1036 op.getSuccessors());
1037}
1038
1039//===----------------------------------------------------------------------===//
1040// PDLByteCode
1041//===----------------------------------------------------------------------===//
1042
1043PDLByteCode::PDLByteCode(
1044 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1045 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1046 llvm::StringMap<PDLConstraintFunction> constraintFns,
1047 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1048 : configs(std::move(configs)) {
1049 Generator generator(module.getContext(), uniquedData, matcherByteCode,
1050 rewriterByteCode, patterns, maxValueMemoryIndex,
1051 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1052 maxLoopLevel, constraintFns, rewriteFns, configMap);
1053 generator.generate(module);
1054
1055 // Initialize the external functions.
1056 for (auto &it : constraintFns)
1057 constraintFunctions.push_back(x: std::move(it.second));
1058 for (auto &it : rewriteFns)
1059 rewriteFunctions.push_back(x: std::move(it.second));
1060}
1061
1062/// Initialize the given state such that it can be used to execute the current
1063/// bytecode.
1064void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1065 state.memory.resize(new_size: maxValueMemoryIndex, x: nullptr);
1066 state.opRangeMemory.resize(new_size: maxOpRangeCount);
1067 state.typeRangeMemory.resize(new_size: maxTypeRangeCount, x: TypeRange());
1068 state.valueRangeMemory.resize(new_size: maxValueRangeCount, x: ValueRange());
1069 state.loopIndex.resize(new_size: maxLoopLevel, x: 0);
1070 state.currentPatternBenefits.reserve(n: patterns.size());
1071 for (const PDLByteCodePattern &pattern : patterns)
1072 state.currentPatternBenefits.push_back(x: pattern.getBenefit());
1073}
1074
1075//===----------------------------------------------------------------------===//
1076// ByteCode Execution
1077
namespace {
/// This class provides support for executing a bytecode stream.
///
/// The executor owns none of its working storage. The bytecode, the value
/// memory, the range slots and the loop counters are all views
/// (`ArrayRef`/`MutableArrayRef`) or references supplied by the creator, and
/// they must outlive the executor. Only the resume stack (`resumeCodeIt`) is
/// owned by the executor.
class ByteCodeExecutor {
public:
  /// Bind the executor to a starting bytecode position and to the execution
  /// storage. Every argument is kept as a non-owning view or reference.
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context.
  /// NOTE(review): `mainRewriteLoc` is presumably the location handed to
  /// `executeCreateOperation` for newly created operations; confirm in
  /// `execute`.
  LogicalResult
  execute(PatternRewriter &rewriter,
          SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
          std::optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  /// Each handler reads its own operands from the stream, starting just after
  /// the opcode.
  void executeApplyConstraint(PatternRewriter &rewriter);
  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateConstantTypeRange();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  template <typename T>
  void executeDynamicCreateRange(StringRef type);
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();

  /// Push a bytecode position to resume at later (see `popCodeIt`).
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(Elt: it); }

  /// Resume execution at the most recently pushed bytecode position and
  /// remove it from the stack. The stack must not be empty (asserted).
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.back();
    resumeCodeIt.pop_back();
  }

  /// Return the bytecode iterator at the start of the current op code. This
  /// assumes it is called right after the opcode has been read.
  ///
  /// In debug builds the generator writes the operation's Location inline in
  /// front of the opcode, so the start lies that many fields further back.
  /// NOTE(review): the debug branch is taken only when LLVM_DEBUG output is
  /// enabled for this DEBUG_TYPE. It assumes the same condition held when the
  /// generator emitted the inline Location; confirm these cannot diverge.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template overload so that a plain `read()` reads one ByteCodeField.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a count followed by that many values of type `ValueT`. The previous
  /// contents of `list` are cleared first.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// either as a single element or a range of elements.
  ///
  /// Each entry is prefixed by a PDLValue::Kind tag. A single element is
  /// pushed directly; a range has all of its elements appended.
  /// NOTE: unlike the templated overload above, these two overloads append to
  /// `list` without clearing it first.
  void readList(SmallVectorImpl<Type> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
        list.push_back(Elt: read<Type>());
      } else {
        TypeRange *values = read<TypeRange *>();
        list.append(in_start: values->begin(), in_end: values->end());
      }
    }
  }
  void readList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(Elt: read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(in_start: values->begin(), in_end: values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer.
  ///
  /// The value is a pointer-like value written directly into the bytecode
  /// rather than through a memory slot. It occupies
  /// `sizeof(const void *) / sizeof(ByteCodeField)` fields, which are skipped.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(dest: &pointer, src: curCodeIt, n: sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Jump to a specific successor based on a predicate value: successor 0 when
  /// `isTrue`, successor 1 otherwise.
  void selectJump(bool isTrue) { selectJump(destIndex: size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. The successor
  /// addresses follow the current position as consecutive ByteCodeAddr
  /// values, each two fields wide, so the earlier ones are skipped.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(skipN: destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases.
  ///
  /// Jumps to successor `i + 1` for the first case `i` with
  /// `cmp(case, value)`. If no case matches, jumps to successor 0 (the
  /// default destination).
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << " * Value: " << value << "\n"
                   << " * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the attribute value is within the case list. Jump to
    // the correct successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(destIndex: size_t((it - cases.begin()) + 1));
    selectJump(destIndex: size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// `readFromMemory` reads one memory index and returns the opaque pointer
  /// stored at it. Indices at or past `memory.size()` address
  /// `uniquedMemory`, offset by `memory.size()`. Operations, ranges and SSA
  /// values are always read from `memory`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types: reinterpret the stored opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types other than PDLValue: rebuild the value from its opaque
  /// pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// PDLValue: a PDLValue::Kind tag followed by a value of that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// Addresses are stored inline and are two fields wide. They are copied out
  /// with memcpy, presumably to avoid alignment/aliasing assumptions.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(dest: &result, src: curCodeIt, n: sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// A single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// A PDLValue::Kind stored as a single field.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// Assign the given range to the given memory index. This allocates a new
  /// range object if necessary.
  ///
  /// A non-empty range is copied into fresh storage. That storage is kept
  /// alive by appending it to the matching `allocated*RangeMemory` list, and
  /// range slot `rangeIndex` is pointed at it. Memory slot `memIndex` then
  /// points at that range slot. Empty ranges allocate nothing.
  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
                           unsigned rangeIndex) {
    // Utility functor used to type-erase the assignment.
    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
      // If the input range is empty, we don't need to allocate anything.
      if (range.empty()) {
        rangeMemory[rangeIndex] = {};
      } else {
        // Allocate a buffer for this type range.
        llvm::OwningArrayRef<T> storage(llvm::size(range));
        llvm::copy(range, storage.begin());

        // Assign this to the range slot and use the range as the value for the
        // memory index.
        allocatedRangeMemory.emplace_back(std::move(storage));
        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
      }
      memory[memIndex] = &rangeMemory[rangeIndex];
    };

    // Dispatch based on the concrete range type.
    if constexpr (std::is_same_v<T, Type>) {
      return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
    } else if constexpr (std::is_same_v<T, Value>) {
      return assignRange(allocatedValueRangeMemory, valueRangeMemory);
    } else {
      llvm_unreachable("unhandled range type");
    }
  }

  /// The underlying bytecode buffer: the current read position.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory: the value slots, the range slots for each
  /// kind of range, and the storage backing ranges allocated during
  /// execution.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices, one counter per loop nesting level.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};

/// This class is an instantiation of the PDLResultList that provides access to
/// the returned results. This API is not on `PDLResultList` to avoid
/// overexposing access to information specific solely to the ByteCode.
///
/// The exposed `results`, `allocatedTypeRanges` and `allocatedValueRanges`
/// are members inherited from `PDLResultList`.
class ByteCodeRewriteResultList : public PDLResultList {
public:
  /// Create a result list sized for at most `maxNumResults` results.
  ByteCodeRewriteResultList(unsigned maxNumResults)
      : PDLResultList(maxNumResults) {}

  /// Return the list of PDL results.
  MutableArrayRef<PDLValue> getResults() { return results; }

  /// Return the type ranges allocated by this list.
  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
    return allocatedTypeRanges;
  }

  /// Return the value ranges allocated by this list.
  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
    return allocatedValueRanges;
  }
};
} // namespace
1407
1408void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1409 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1410 const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1411 SmallVector<PDLValue, 16> args;
1412 readList<PDLValue>(list&: args);
1413
1414 LLVM_DEBUG({
1415 llvm::dbgs() << " * Arguments: ";
1416 llvm::interleaveComma(args, llvm::dbgs());
1417 llvm::dbgs() << "\n";
1418 });
1419
1420 ByteCodeField isNegated = read();
1421 LLVM_DEBUG({
1422 llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1423 llvm::interleaveComma(args, llvm::dbgs());
1424 });
1425 // Invoke the constraint and jump to the proper destination.
1426 selectJump(isTrue: isNegated != succeeded(result: constraintFn(rewriter, args)));
1427}
1428
1429LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1430 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1431 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1432 SmallVector<PDLValue, 16> args;
1433 readList<PDLValue>(list&: args);
1434
1435 LLVM_DEBUG({
1436 llvm::dbgs() << " * Arguments: ";
1437 llvm::interleaveComma(args, llvm::dbgs());
1438 });
1439
1440 // Execute the rewrite function.
1441 ByteCodeField numResults = read();
1442 ByteCodeRewriteResultList results(numResults);
1443 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1444
1445 assert(results.getResults().size() == numResults &&
1446 "native PDL rewrite function returned unexpected number of results");
1447
1448 // Store the results in the bytecode memory.
1449 for (PDLValue &result : results.getResults()) {
1450 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1451
1452// In debug mode we also verify the expected kind of the result.
1453#ifndef NDEBUG
1454 assert(result.getKind() == read<PDLValue::Kind>() &&
1455 "native PDL rewrite function returned an unexpected type of result");
1456#endif
1457
1458 // If the result is a range, we need to copy it over to the bytecodes
1459 // range memory.
1460 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1461 unsigned rangeIndex = read();
1462 typeRangeMemory[rangeIndex] = *typeRange;
1463 memory[read()] = &typeRangeMemory[rangeIndex];
1464 } else if (std::optional<ValueRange> valueRange =
1465 result.dyn_cast<ValueRange>()) {
1466 unsigned rangeIndex = read();
1467 valueRangeMemory[rangeIndex] = *valueRange;
1468 memory[read()] = &valueRangeMemory[rangeIndex];
1469 } else {
1470 memory[read()] = result.getAsOpaquePointer();
1471 }
1472 }
1473
1474 // Copy over any underlying storage allocated for result ranges.
1475 for (auto &it : results.getAllocatedTypeRanges())
1476 allocatedTypeRangeMemory.push_back(x: std::move(it));
1477 for (auto &it : results.getAllocatedValueRanges())
1478 allocatedValueRangeMemory.push_back(x: std::move(it));
1479
1480 // Process the result of the rewrite.
1481 if (failed(result: rewriteResult)) {
1482 LLVM_DEBUG(llvm::dbgs() << " - Failed");
1483 return failure();
1484 }
1485 return success();
1486}
1487
1488void ByteCodeExecutor::executeAreEqual() {
1489 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1490 const void *lhs = read<const void *>();
1491 const void *rhs = read<const void *>();
1492
1493 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1494 selectJump(isTrue: lhs == rhs);
1495}
1496
1497void ByteCodeExecutor::executeAreRangesEqual() {
1498 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1499 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1500 const void *lhs = read<const void *>();
1501 const void *rhs = read<const void *>();
1502
1503 switch (valueKind) {
1504 case PDLValue::Kind::TypeRange: {
1505 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1506 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1507 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1508 selectJump(isTrue: *lhsRange == *rhsRange);
1509 break;
1510 }
1511 case PDLValue::Kind::ValueRange: {
1512 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1513 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1514 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1515 selectJump(isTrue: *lhsRange == *rhsRange);
1516 break;
1517 }
1518 default:
1519 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1520 }
1521}
1522
1523void ByteCodeExecutor::executeBranch() {
1524 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1525 curCodeIt = &code[read<ByteCodeAddr>()];
1526}
1527
1528void ByteCodeExecutor::executeCheckOperandCount() {
1529 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1530 Operation *op = read<Operation *>();
1531 uint32_t expectedCount = read<uint32_t>();
1532 bool compareAtLeast = read();
1533
1534 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1535 << " * Expected: " << expectedCount << "\n"
1536 << " * Comparator: "
1537 << (compareAtLeast ? ">=" : "==") << "\n");
1538 if (compareAtLeast)
1539 selectJump(isTrue: op->getNumOperands() >= expectedCount);
1540 else
1541 selectJump(isTrue: op->getNumOperands() == expectedCount);
1542}
1543
1544void ByteCodeExecutor::executeCheckOperationName() {
1545 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1546 Operation *op = read<Operation *>();
1547 OperationName expectedName = read<OperationName>();
1548
1549 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1550 << " * Expected: \"" << expectedName << "\"\n");
1551 selectJump(isTrue: op->getName() == expectedName);
1552}
1553
1554void ByteCodeExecutor::executeCheckResultCount() {
1555 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1556 Operation *op = read<Operation *>();
1557 uint32_t expectedCount = read<uint32_t>();
1558 bool compareAtLeast = read();
1559
1560 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1561 << " * Expected: " << expectedCount << "\n"
1562 << " * Comparator: "
1563 << (compareAtLeast ? ">=" : "==") << "\n");
1564 if (compareAtLeast)
1565 selectJump(isTrue: op->getNumResults() >= expectedCount);
1566 else
1567 selectJump(isTrue: op->getNumResults() == expectedCount);
1568}
1569
1570void ByteCodeExecutor::executeCheckTypes() {
1571 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1572 TypeRange *lhs = read<TypeRange *>();
1573 Attribute rhs = read<Attribute>();
1574 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1575
1576 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1577}
1578
1579void ByteCodeExecutor::executeContinue() {
1580 ByteCodeField level = read();
1581 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1582 << " * Level: " << level << "\n");
1583 ++loopIndex[level];
1584 popCodeIt();
1585}
1586
1587void ByteCodeExecutor::executeCreateConstantTypeRange() {
1588 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1589 unsigned memIndex = read();
1590 unsigned rangeIndex = read();
1591 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1592
1593 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1594 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1595 rangeIndex);
1596}
1597
1598void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1599 Location mainRewriteLoc) {
1600 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1601
1602 unsigned memIndex = read();
1603 OperationState state(mainRewriteLoc, read<OperationName>());
1604 readList(list&: state.operands);
1605 for (unsigned i = 0, e = read(); i != e; ++i) {
1606 StringAttr name = read<StringAttr>();
1607 if (Attribute attr = read<Attribute>())
1608 state.addAttribute(name, attr);
1609 }
1610
1611 // Read in the result types. If the "size" is the sentinel value, this
1612 // indicates that the result types should be inferred.
1613 unsigned numResults = read();
1614 if (numResults == kInferTypesMarker) {
1615 InferTypeOpInterface::Concept *inferInterface =
1616 state.name.getInterface<InferTypeOpInterface>();
1617 assert(inferInterface &&
1618 "expected operation to provide InferTypeOpInterface");
1619
1620 // TODO: Handle failure.
1621 if (failed(inferInterface->inferReturnTypes(
1622 state.getContext(), state.location, state.operands,
1623 state.attributes.getDictionary(state.getContext()),
1624 state.getRawProperties(), state.regions, state.types)))
1625 return;
1626 } else {
1627 // Otherwise, this is a fixed number of results.
1628 for (unsigned i = 0; i != numResults; ++i) {
1629 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1630 state.types.push_back(Elt: read<Type>());
1631 } else {
1632 TypeRange *resultTypes = read<TypeRange *>();
1633 state.types.append(in_start: resultTypes->begin(), in_end: resultTypes->end());
1634 }
1635 }
1636 }
1637
1638 Operation *resultOp = rewriter.create(state);
1639 memory[memIndex] = resultOp;
1640
1641 LLVM_DEBUG({
1642 llvm::dbgs() << " * Attributes: "
1643 << state.attributes.getDictionary(state.getContext())
1644 << "\n * Operands: ";
1645 llvm::interleaveComma(state.operands, llvm::dbgs());
1646 llvm::dbgs() << "\n * Result Types: ";
1647 llvm::interleaveComma(state.types, llvm::dbgs());
1648 llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1649 });
1650}
1651
1652template <typename T>
1653void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1654 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1655 unsigned memIndex = read();
1656 unsigned rangeIndex = read();
1657 SmallVector<T> values;
1658 readList(values);
1659
1660 LLVM_DEBUG({
1661 llvm::dbgs() << "\n * " << type << "s: ";
1662 llvm::interleaveComma(values, llvm::dbgs());
1663 llvm::dbgs() << "\n";
1664 });
1665
1666 assignRangeToMemory(values, memIndex, rangeIndex);
1667}
1668
1669void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1670 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1671 Operation *op = read<Operation *>();
1672
1673 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1674 rewriter.eraseOp(op);
1675}
1676
1677template <typename T, typename Range, PDLValue::Kind kind>
1678void ByteCodeExecutor::executeExtract() {
1679 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1680 Range *range = read<Range *>();
1681 unsigned index = read<uint32_t>();
1682 unsigned memIndex = read();
1683
1684 if (!range) {
1685 memory[memIndex] = nullptr;
1686 return;
1687 }
1688
1689 T result = index < range->size() ? (*range)[index] : T();
1690 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1691 << " * Index: " << index << "\n"
1692 << " * Result: " << result << "\n");
1693 storeToMemory(memIndex, result);
1694}
1695
1696void ByteCodeExecutor::executeFinalize() {
1697 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1698}
1699
1700void ByteCodeExecutor::executeForEach() {
1701 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1702 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1703 unsigned rangeIndex = read();
1704 unsigned memIndex = read();
1705 const void *value = nullptr;
1706
1707 switch (read<PDLValue::Kind>()) {
1708 case PDLValue::Kind::Operation: {
1709 unsigned &index = loopIndex[read()];
1710 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1711 assert(index <= array.size() && "iterated past the end");
1712 if (index < array.size()) {
1713 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1714 value = array[index];
1715 break;
1716 }
1717
1718 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1719 index = 0;
1720 selectJump(destIndex: size_t(0));
1721 return;
1722 }
1723 default:
1724 llvm_unreachable("unexpected `ForEach` value kind");
1725 }
1726
1727 // Store the iterate value and the stack address.
1728 memory[memIndex] = value;
1729 pushCodeIt(it: prevCodeIt);
1730
1731 // Skip over the successor (we will enter the body of the loop).
1732 read<ByteCodeAddr>();
1733}
1734
1735void ByteCodeExecutor::executeGetAttribute() {
1736 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1737 unsigned memIndex = read();
1738 Operation *op = read<Operation *>();
1739 StringAttr attrName = read<StringAttr>();
1740 Attribute attr = op->getAttr(attrName);
1741
1742 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1743 << " * Attribute: " << attrName << "\n"
1744 << " * Result: " << attr << "\n");
1745 memory[memIndex] = attr.getAsOpaquePointer();
1746}
1747
1748void ByteCodeExecutor::executeGetAttributeType() {
1749 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1750 unsigned memIndex = read();
1751 Attribute attr = read<Attribute>();
1752 Type type;
1753 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1754 type = typedAttr.getType();
1755
1756 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1757 << " * Result: " << type << "\n");
1758 memory[memIndex] = type.getAsOpaquePointer();
1759}
1760
1761void ByteCodeExecutor::executeGetDefiningOp() {
1762 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1763 unsigned memIndex = read();
1764 Operation *op = nullptr;
1765 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1766 Value value = read<Value>();
1767 if (value)
1768 op = value.getDefiningOp();
1769 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1770 } else {
1771 ValueRange *values = read<ValueRange *>();
1772 if (values && !values->empty()) {
1773 op = values->front().getDefiningOp();
1774 }
1775 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1776 }
1777
1778 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1779 memory[memIndex] = op;
1780}
1781
1782void ByteCodeExecutor::executeGetOperand(unsigned index) {
1783 Operation *op = read<Operation *>();
1784 unsigned memIndex = read();
1785 Value operand =
1786 index < op->getNumOperands() ? op->getOperand(idx: index) : Value();
1787
1788 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1789 << " * Index: " << index << "\n"
1790 << " * Result: " << operand << "\n");
1791 memory[memIndex] = operand.getAsOpaquePointer();
1792}
1793
1794/// This function is the internal implementation of `GetResults` and
1795/// `GetOperands` that provides support for extracting a value range from the
1796/// given operation.
1797template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1798static void *
1799executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1800 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1801 MutableArrayRef<ValueRange> valueRangeMemory) {
1802 // Check for the sentinel index that signals that all values should be
1803 // returned.
1804 if (index == std::numeric_limits<uint32_t>::max()) {
1805 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1806 // `values` is already the full value range.
1807
1808 // Otherwise, check to see if this operation uses AttrSizedSegments.
1809 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1810 LLVM_DEBUG(llvm::dbgs()
1811 << " * Extracting values from `" << attrSizedSegments << "`\n");
1812
1813 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1814 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1815 return nullptr;
1816
1817 ArrayRef<int32_t> segments = segmentAttr;
1818 unsigned startIndex =
1819 std::accumulate(first: segments.begin(), last: segments.begin() + index, init: 0);
1820 values = values.slice(startIndex, *std::next(x: segments.begin(), n: index));
1821
1822 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1823 << *std::next(segments.begin(), index) << "]\n");
1824
1825 // Otherwise, assume this is the last operand group of the operation.
1826 // FIXME: We currently don't support operations with
1827 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1828 // have a way to detect it's presence.
1829 } else if (values.size() >= index) {
1830 LLVM_DEBUG(llvm::dbgs()
1831 << " * Treating values as trailing variadic range\n");
1832 values = values.drop_front(index);
1833
1834 // If we couldn't detect a way to compute the values, bail out.
1835 } else {
1836 return nullptr;
1837 }
1838
1839 // If the range index is valid, we are returning a range.
1840 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1841 valueRangeMemory[rangeIndex] = values;
1842 return &valueRangeMemory[rangeIndex];
1843 }
1844
1845 // If a range index wasn't provided, the range is required to be non-variadic.
1846 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1847}
1848
1849void ByteCodeExecutor::executeGetOperands() {
1850 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1851 unsigned index = read<uint32_t>();
1852 Operation *op = read<Operation *>();
1853 ByteCodeField rangeIndex = read();
1854
1855 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1856 values: op->getOperands(), op, index, rangeIndex, attrSizedSegments: "operandSegmentSizes",
1857 valueRangeMemory);
1858 if (!result)
1859 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1860 memory[read()] = result;
1861}
1862
1863void ByteCodeExecutor::executeGetResult(unsigned index) {
1864 Operation *op = read<Operation *>();
1865 unsigned memIndex = read();
1866 OpResult result =
1867 index < op->getNumResults() ? op->getResult(idx: index) : OpResult();
1868
1869 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1870 << " * Index: " << index << "\n"
1871 << " * Result: " << result << "\n");
1872 memory[memIndex] = result.getAsOpaquePointer();
1873}
1874
1875void ByteCodeExecutor::executeGetResults() {
1876 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1877 unsigned index = read<uint32_t>();
1878 Operation *op = read<Operation *>();
1879 ByteCodeField rangeIndex = read();
1880
1881 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1882 values: op->getResults(), op, index, rangeIndex, attrSizedSegments: "resultSegmentSizes",
1883 valueRangeMemory);
1884 if (!result)
1885 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1886 memory[read()] = result;
1887}
1888
1889void ByteCodeExecutor::executeGetUsers() {
1890 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1891 unsigned memIndex = read();
1892 unsigned rangeIndex = read();
1893 OwningOpRange &range = opRangeMemory[rangeIndex];
1894 memory[memIndex] = &range;
1895
1896 range = OwningOpRange();
1897 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1898 // Read the value.
1899 Value value = read<Value>();
1900 if (!value)
1901 return;
1902 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1903
1904 // Extract the users of a single value.
1905 range = OwningOpRange(std::distance(first: value.user_begin(), last: value.user_end()));
1906 llvm::copy(Range: value.getUsers(), Out: range.begin());
1907 } else {
1908 // Read a range of values.
1909 ValueRange *values = read<ValueRange *>();
1910 if (!values)
1911 return;
1912 LLVM_DEBUG({
1913 llvm::dbgs() << " * Values (" << values->size() << "): ";
1914 llvm::interleaveComma(*values, llvm::dbgs());
1915 llvm::dbgs() << "\n";
1916 });
1917
1918 // Extract all the users of a range of values.
1919 SmallVector<Operation *> users;
1920 for (Value value : *values)
1921 users.append(in_start: value.user_begin(), in_end: value.user_end());
1922 range = OwningOpRange(users.size());
1923 llvm::copy(Range&: users, Out: range.begin());
1924 }
1925
1926 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1927}
1928
1929void ByteCodeExecutor::executeGetValueType() {
1930 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1931 unsigned memIndex = read();
1932 Value value = read<Value>();
1933 Type type = value ? value.getType() : Type();
1934
1935 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1936 << " * Result: " << type << "\n");
1937 memory[memIndex] = type.getAsOpaquePointer();
1938}
1939
1940void ByteCodeExecutor::executeGetValueRangeTypes() {
1941 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1942 unsigned memIndex = read();
1943 unsigned rangeIndex = read();
1944 ValueRange *values = read<ValueRange *>();
1945 if (!values) {
1946 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1947 memory[memIndex] = nullptr;
1948 return;
1949 }
1950
1951 LLVM_DEBUG({
1952 llvm::dbgs() << " * Values (" << values->size() << "): ";
1953 llvm::interleaveComma(*values, llvm::dbgs());
1954 llvm::dbgs() << "\n * Result: ";
1955 llvm::interleaveComma(values->getType(), llvm::dbgs());
1956 llvm::dbgs() << "\n";
1957 });
1958 typeRangeMemory[rangeIndex] = values->getType();
1959 memory[memIndex] = &typeRangeMemory[rangeIndex];
1960}
1961
1962void ByteCodeExecutor::executeIsNotNull() {
1963 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1964 const void *value = read<const void *>();
1965
1966 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1967 selectJump(isTrue: value != nullptr);
1968}
1969
1970void ByteCodeExecutor::executeRecordMatch(
1971 PatternRewriter &rewriter,
1972 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1973 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1974 unsigned patternIndex = read();
1975 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1976 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1977
1978 // If the benefit of the pattern is impossible, skip the processing of the
1979 // rest of the pattern.
1980 if (benefit.isImpossibleToMatch()) {
1981 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1982 curCodeIt = dest;
1983 return;
1984 }
1985
1986 // Create a fused location containing the locations of each of the
1987 // operations used in the match. This will be used as the location for
1988 // created operations during the rewrite that don't already have an
1989 // explicit location set.
1990 unsigned numMatchLocs = read();
1991 SmallVector<Location, 4> matchLocs;
1992 matchLocs.reserve(N: numMatchLocs);
1993 for (unsigned i = 0; i != numMatchLocs; ++i)
1994 matchLocs.push_back(Elt: read<Operation *>()->getLoc());
1995 Location matchLoc = rewriter.getFusedLoc(locs: matchLocs);
1996
1997 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1998 << " * Location: " << matchLoc << "\n");
1999 matches.emplace_back(Args&: matchLoc, Args: patterns[patternIndex], Args&: benefit);
2000 PDLByteCode::MatchResult &match = matches.back();
2001
2002 // Record all of the inputs to the match. If any of the inputs are ranges, we
2003 // will also need to remap the range pointer to memory stored in the match
2004 // state.
2005 unsigned numInputs = read();
2006 match.values.reserve(N: numInputs);
2007 match.typeRangeValues.reserve(N: numInputs);
2008 match.valueRangeValues.reserve(N: numInputs);
2009 for (unsigned i = 0; i < numInputs; ++i) {
2010 switch (read<PDLValue::Kind>()) {
2011 case PDLValue::Kind::TypeRange:
2012 match.typeRangeValues.push_back(Elt: *read<TypeRange *>());
2013 match.values.push_back(Elt: &match.typeRangeValues.back());
2014 break;
2015 case PDLValue::Kind::ValueRange:
2016 match.valueRangeValues.push_back(Elt: *read<ValueRange *>());
2017 match.values.push_back(Elt: &match.valueRangeValues.back());
2018 break;
2019 default:
2020 match.values.push_back(Elt: read<const void *>());
2021 break;
2022 }
2023 }
2024 curCodeIt = dest;
2025}
2026
2027void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2028 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2029 Operation *op = read<Operation *>();
2030 SmallVector<Value, 16> args;
2031 readList(list&: args);
2032
2033 LLVM_DEBUG({
2034 llvm::dbgs() << " * Operation: " << *op << "\n"
2035 << " * Values: ";
2036 llvm::interleaveComma(args, llvm::dbgs());
2037 llvm::dbgs() << "\n";
2038 });
2039 rewriter.replaceOp(op, newValues: args);
2040}
2041
2042void ByteCodeExecutor::executeSwitchAttribute() {
2043 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2044 Attribute value = read<Attribute>();
2045 ArrayAttr cases = read<ArrayAttr>();
2046 handleSwitch(value, cases);
2047}
2048
2049void ByteCodeExecutor::executeSwitchOperandCount() {
2050 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2051 Operation *op = read<Operation *>();
2052 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2053
2054 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2055 handleSwitch(op->getNumOperands(), cases);
2056}
2057
2058void ByteCodeExecutor::executeSwitchOperationName() {
2059 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2060 OperationName value = read<Operation *>()->getName();
2061 size_t caseCount = read();
2062
2063 // The operation names are stored in-line, so to print them out for
2064 // debugging purposes we need to read the array before executing the
2065 // switch so that we can display all of the possible values.
2066 LLVM_DEBUG({
2067 const ByteCodeField *prevCodeIt = curCodeIt;
2068 llvm::dbgs() << " * Value: " << value << "\n"
2069 << " * Cases: ";
2070 llvm::interleaveComma(
2071 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2072 [&](size_t) { return read<OperationName>(); }),
2073 llvm::dbgs());
2074 llvm::dbgs() << "\n";
2075 curCodeIt = prevCodeIt;
2076 });
2077
2078 // Try to find the switch value within any of the cases.
2079 for (size_t i = 0; i != caseCount; ++i) {
2080 if (read<OperationName>() == value) {
2081 curCodeIt += (caseCount - i - 1);
2082 return selectJump(destIndex: i + 1);
2083 }
2084 }
2085 selectJump(destIndex: size_t(0));
2086}
2087
2088void ByteCodeExecutor::executeSwitchResultCount() {
2089 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2090 Operation *op = read<Operation *>();
2091 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2092
2093 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2094 handleSwitch(op->getNumResults(), cases);
2095}
2096
2097void ByteCodeExecutor::executeSwitchType() {
2098 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2099 Type value = read<Type>();
2100 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2101 handleSwitch(value, cases);
2102}
2103
2104void ByteCodeExecutor::executeSwitchTypes() {
2105 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2106 TypeRange *value = read<TypeRange *>();
2107 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2108 if (!value) {
2109 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2110 return selectJump(destIndex: size_t(0));
2111 }
2112 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2113 return value == caseValue.getAsValueRange<TypeAttr>();
2114 });
2115}
2116
2117LogicalResult
2118ByteCodeExecutor::execute(PatternRewriter &rewriter,
2119 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2120 std::optional<Location> mainRewriteLoc) {
2121 while (true) {
2122 // Print the location of the operation being executed.
2123 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2124
2125 OpCode opCode = static_cast<OpCode>(read());
2126 switch (opCode) {
2127 case ApplyConstraint:
2128 executeApplyConstraint(rewriter);
2129 break;
2130 case ApplyRewrite:
2131 if (failed(result: executeApplyRewrite(rewriter)))
2132 return failure();
2133 break;
2134 case AreEqual:
2135 executeAreEqual();
2136 break;
2137 case AreRangesEqual:
2138 executeAreRangesEqual();
2139 break;
2140 case Branch:
2141 executeBranch();
2142 break;
2143 case CheckOperandCount:
2144 executeCheckOperandCount();
2145 break;
2146 case CheckOperationName:
2147 executeCheckOperationName();
2148 break;
2149 case CheckResultCount:
2150 executeCheckResultCount();
2151 break;
2152 case CheckTypes:
2153 executeCheckTypes();
2154 break;
2155 case Continue:
2156 executeContinue();
2157 break;
2158 case CreateConstantTypeRange:
2159 executeCreateConstantTypeRange();
2160 break;
2161 case CreateOperation:
2162 executeCreateOperation(rewriter, mainRewriteLoc: *mainRewriteLoc);
2163 break;
2164 case CreateDynamicTypeRange:
2165 executeDynamicCreateRange<Type>(type: "Type");
2166 break;
2167 case CreateDynamicValueRange:
2168 executeDynamicCreateRange<Value>(type: "Value");
2169 break;
2170 case EraseOp:
2171 executeEraseOp(rewriter);
2172 break;
2173 case ExtractOp:
2174 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2175 break;
2176 case ExtractType:
2177 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2178 break;
2179 case ExtractValue:
2180 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2181 break;
2182 case Finalize:
2183 executeFinalize();
2184 LLVM_DEBUG(llvm::dbgs() << "\n");
2185 return success();
2186 case ForEach:
2187 executeForEach();
2188 break;
2189 case GetAttribute:
2190 executeGetAttribute();
2191 break;
2192 case GetAttributeType:
2193 executeGetAttributeType();
2194 break;
2195 case GetDefiningOp:
2196 executeGetDefiningOp();
2197 break;
2198 case GetOperand0:
2199 case GetOperand1:
2200 case GetOperand2:
2201 case GetOperand3: {
2202 unsigned index = opCode - GetOperand0;
2203 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2204 executeGetOperand(index);
2205 break;
2206 }
2207 case GetOperandN:
2208 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2209 executeGetOperand(index: read<uint32_t>());
2210 break;
2211 case GetOperands:
2212 executeGetOperands();
2213 break;
2214 case GetResult0:
2215 case GetResult1:
2216 case GetResult2:
2217 case GetResult3: {
2218 unsigned index = opCode - GetResult0;
2219 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2220 executeGetResult(index);
2221 break;
2222 }
2223 case GetResultN:
2224 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2225 executeGetResult(index: read<uint32_t>());
2226 break;
2227 case GetResults:
2228 executeGetResults();
2229 break;
2230 case GetUsers:
2231 executeGetUsers();
2232 break;
2233 case GetValueType:
2234 executeGetValueType();
2235 break;
2236 case GetValueRangeTypes:
2237 executeGetValueRangeTypes();
2238 break;
2239 case IsNotNull:
2240 executeIsNotNull();
2241 break;
2242 case RecordMatch:
2243 assert(matches &&
2244 "expected matches to be provided when executing the matcher");
2245 executeRecordMatch(rewriter, matches&: *matches);
2246 break;
2247 case ReplaceOp:
2248 executeReplaceOp(rewriter);
2249 break;
2250 case SwitchAttribute:
2251 executeSwitchAttribute();
2252 break;
2253 case SwitchOperandCount:
2254 executeSwitchOperandCount();
2255 break;
2256 case SwitchOperationName:
2257 executeSwitchOperationName();
2258 break;
2259 case SwitchResultCount:
2260 executeSwitchResultCount();
2261 break;
2262 case SwitchType:
2263 executeSwitchType();
2264 break;
2265 case SwitchTypes:
2266 executeSwitchTypes();
2267 break;
2268 }
2269 LLVM_DEBUG(llvm::dbgs() << "\n");
2270 }
2271}
2272
2273void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2274 SmallVectorImpl<MatchResult> &matches,
2275 PDLByteCodeMutableState &state) const {
2276 // The first memory slot is always the root operation.
2277 state.memory[0] = op;
2278
2279 // The matcher function always starts at code address 0.
2280 ByteCodeExecutor executor(
2281 matcherByteCode.data(), state.memory, state.opRangeMemory,
2282 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2283 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2284 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2285 constraintFunctions, rewriteFunctions);
2286 LogicalResult executeResult = executor.execute(rewriter, matches: &matches);
2287 (void)executeResult;
2288 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2289
2290 // Order the found matches by benefit.
2291 std::stable_sort(first: matches.begin(), last: matches.end(),
2292 comp: [](const MatchResult &lhs, const MatchResult &rhs) {
2293 return lhs.benefit > rhs.benefit;
2294 });
2295}
2296
2297LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2298 const MatchResult &match,
2299 PDLByteCodeMutableState &state) const {
2300 auto *configSet = match.pattern->getConfigSet();
2301 if (configSet)
2302 configSet->notifyRewriteBegin(rewriter);
2303
2304 // The arguments of the rewrite function are stored at the start of the
2305 // memory buffer.
2306 llvm::copy(Range: match.values, Out: state.memory.begin());
2307
2308 ByteCodeExecutor executor(
2309 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2310 state.opRangeMemory, state.typeRangeMemory,
2311 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2312 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2313 rewriterByteCode, state.currentPatternBenefits, patterns,
2314 constraintFunctions, rewriteFunctions);
2315 LogicalResult result =
2316 executor.execute(rewriter, /*matches=*/nullptr, mainRewriteLoc: match.location);
2317
2318 if (configSet)
2319 configSet->notifyRewriteEnd(rewriter);
2320
2321 // If the rewrite failed, check if the pattern rewriter can recover. If it
2322 // can, we can signal to the pattern applicator to keep trying patterns. If it
2323 // doesn't, we need to bail. Bailing here should be fine, given that we have
2324 // no means to propagate such a failure to the user, and it also indicates a
2325 // bug in the user code (i.e. failable rewrites should not be used with
2326 // pattern rewriters that don't support it).
2327 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2328 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2329 llvm::report_fatal_error(
2330 reason: "Native PDL Rewrite failed, but the pattern "
2331 "rewriter doesn't support recovery. Failable pattern rewrites should "
2332 "not be used with pattern rewriters that do not support them.");
2333 }
2334 return result;
2335}
2336

source code of mlir/lib/Rewrite/ByteCode.cpp