1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
14#include "mlir/Analysis/Liveness.h"
15#include "mlir/Dialect/PDL/IR/PDLTypes.h"
16#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/RegionGraphTraits.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/Format.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <numeric>
26#include <optional>
27
28#define DEBUG_TYPE "pdl-bytecode"
29
30using namespace mlir;
31using namespace mlir::detail;
32
33//===----------------------------------------------------------------------===//
34// PDLByteCodePattern
35//===----------------------------------------------------------------------===//
36
37PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
40 PatternBenefit benefit = matchOp.getBenefit();
41 MLIRContext *ctx = matchOp.getContext();
42
43 // Collect the set of generated operations.
44 SmallVector<StringRef, 8> generatedOps;
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46 generatedOps =
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48
49 // Check to see if this is pattern matches a specific operation type.
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52 generatedOps);
53 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54 benefit, ctx, generatedOps);
55}
56
57//===----------------------------------------------------------------------===//
58// PDLByteCodeMutableState
59//===----------------------------------------------------------------------===//
60
61/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62/// to the position of the pattern within the range returned by
63/// `PDLByteCode::getPatterns`.
64void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65 PatternBenefit benefit) {
66 currentPatternBenefits[patternIndex] = benefit;
67}
68
69/// Cleanup any allocated state after a full match/rewrite has been completed.
70/// This method should be called irregardless of whether the match+rewrite was a
71/// success or not.
72void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
75}
76
77//===----------------------------------------------------------------------===//
78// Bytecode OpCodes
79//===----------------------------------------------------------------------===//
80
namespace {
/// Opcodes of the PDL bytecode. The underlying type is `ByteCodeField`, so an
/// opcode is written into a bytecode stream as a single field (see
/// `ByteCodeWriter::append(OpCode)`); the generator emits it as the first
/// field of an instruction, followed by that instruction's operands.
/// Enumerator order determines the numeric value of each opcode.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode a small constant
  /// operand index (0-3) in the opcode itself, while `GetOperandN` carries the
  /// index explicitly — confirm against the GetOperandOp generator.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): as with GetOperand*, the numbered variants presumably
  /// encode a constant result index and `GetResultN` an explicit one —
  /// confirm against the GetResultOp generator.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171
172/// A marker used to indicate if an operation should infer types.
173static constexpr ByteCodeField kInferTypesMarker =
174 std::numeric_limits<ByteCodeField>::max();
175
176//===----------------------------------------------------------------------===//
177// ByteCode Generation
178//===----------------------------------------------------------------------===//
179
180//===----------------------------------------------------------------------===//
181// Generator
182
183namespace {
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection alias naming the type returned by `T::getAsOpaquePointer()`.
/// It is only well-formed when `T` provides that method, which makes it usable
/// with `llvm::is_detected` to test whether `T` converts to an opaque pointer.
/// Any trailing template arguments are accepted and ignored.
template <class T, class...>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
190
191/// This class represents the main generator for the pattern bytecode.
192class Generator {
193public:
194 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195 SmallVectorImpl<ByteCodeField> &matcherByteCode,
196 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
197 SmallVectorImpl<PDLByteCodePattern> &patterns,
198 ByteCodeField &maxValueMemoryIndex,
199 ByteCodeField &maxOpRangeMemoryIndex,
200 ByteCodeField &maxTypeRangeMemoryIndex,
201 ByteCodeField &maxValueRangeMemoryIndex,
202 ByteCodeField &maxLoopLevel,
203 llvm::StringMap<PDLConstraintFunction> &constraintFns,
204 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
205 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
206 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207 rewriterByteCode(rewriterByteCode), patterns(patterns),
208 maxValueMemoryIndex(maxValueMemoryIndex),
209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212 maxLoopLevel(maxLoopLevel), configMap(configMap) {
213 for (const auto &it : llvm::enumerate(First&: constraintFns))
214 constraintToMemIndex.try_emplace(Key: it.value().first(), Args: it.index());
215 for (const auto &it : llvm::enumerate(First&: rewriteFns))
216 externalRewriterToMemIndex.try_emplace(Key: it.value().first(), Args: it.index());
217 }
218
219 /// Generate the bytecode for the given PDL interpreter module.
220 void generate(ModuleOp module);
221
222 /// Return the memory index to use for the given value.
223 ByteCodeField &getMemIndex(Value value) {
224 assert(valueToMemIndex.count(value) &&
225 "expected memory index to be assigned");
226 return valueToMemIndex[value];
227 }
228
229 /// Return the range memory index used to store the given range value.
230 ByteCodeField &getRangeStorageIndex(Value value) {
231 assert(valueToRangeIndex.count(value) &&
232 "expected range index to be assigned");
233 return valueToRangeIndex[value];
234 }
235
236 /// Return an index to use when referring to the given data that is uniqued in
237 /// the MLIR context.
238 template <typename T>
239 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240 getMemIndex(T val) {
241 const void *opaqueVal = val.getAsOpaquePointer();
242
243 // Get or insert a reference to this value.
244 auto it = uniquedDataToMemIndex.try_emplace(
245 Key: opaqueVal, Args: maxValueMemoryIndex + uniquedData.size());
246 if (it.second)
247 uniquedData.push_back(x: opaqueVal);
248 return it.first->second;
249 }
250
251private:
252 /// Allocate memory indices for the results of operations within the matcher
253 /// and rewriters.
254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255 ModuleOp rewriterModule);
256
257 /// Generate the bytecode for the given operation.
258 void generate(Region *region, ByteCodeWriter &writer);
259 void generate(Operation *op, ByteCodeWriter &writer);
260 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298
299 /// Mapping from value to its corresponding memory index.
300 DenseMap<Value, ByteCodeField> valueToMemIndex;
301
302 /// Mapping from a range value to its corresponding range storage index.
303 DenseMap<Value, ByteCodeField> valueToRangeIndex;
304
305 /// Mapping from the name of an externally registered rewrite to its index in
306 /// the bytecode registry.
307 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308
309 /// Mapping from the name of an externally registered constraint to its index
310 /// in the bytecode registry.
311 llvm::StringMap<ByteCodeField> constraintToMemIndex;
312
313 /// Mapping from rewriter function name to the bytecode address of the
314 /// rewriter function in byte.
315 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316
317 /// Mapping from a uniqued storage object to its memory index within
318 /// `uniquedData`.
319 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320
321 /// The current level of the foreach loop.
322 ByteCodeField curLoopLevel = 0;
323
324 /// The current MLIR context.
325 MLIRContext *ctx;
326
327 /// Mapping from block to its address.
328 DenseMap<Block *, ByteCodeAddr> blockToAddr;
329
330 /// Data of the ByteCode class to be populated.
331 std::vector<const void *> &uniquedData;
332 SmallVectorImpl<ByteCodeField> &matcherByteCode;
333 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
334 SmallVectorImpl<PDLByteCodePattern> &patterns;
335 ByteCodeField &maxValueMemoryIndex;
336 ByteCodeField &maxOpRangeMemoryIndex;
337 ByteCodeField &maxTypeRangeMemoryIndex;
338 ByteCodeField &maxValueRangeMemoryIndex;
339 ByteCodeField &maxLoopLevel;
340
341 /// A map of pattern configurations.
342 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
343};
344
345/// This class provides utilities for writing a bytecode stream.
346struct ByteCodeWriter {
347 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348 : bytecode(bytecode), generator(generator) {}
349
350 /// Append a field to the bytecode.
351 void append(ByteCodeField field) { bytecode.push_back(Elt: field); }
352 void append(OpCode opCode) { bytecode.push_back(Elt: opCode); }
353
354 /// Append an address to the bytecode.
355 void append(ByteCodeAddr field) {
356 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357 "unexpected ByteCode address size");
358
359 ByteCodeField fieldParts[2];
360 std::memcpy(dest: fieldParts, src: &field, n: sizeof(ByteCodeAddr));
361 bytecode.append(IL: {fieldParts[0], fieldParts[1]});
362 }
363
364 /// Append a single successor to the bytecode, the exact address will need to
365 /// be resolved later.
366 void append(Block *successor) {
367 // Add back a reference to the successor so that the address can be resolved
368 // later.
369 unresolvedSuccessorRefs[successor].push_back(Elt: bytecode.size());
370 append(field: ByteCodeAddr(0));
371 }
372
373 /// Append a successor range to the bytecode, the exact address will need to
374 /// be resolved later.
375 void append(SuccessorRange successors) {
376 for (Block *successor : successors)
377 append(successor);
378 }
379
380 /// Append a range of values that will be read as generic PDLValues.
381 void appendPDLValueList(OperandRange values) {
382 bytecode.push_back(Elt: values.size());
383 for (Value value : values)
384 appendPDLValue(value);
385 }
386
387 /// Append a value as a PDLValue.
388 void appendPDLValue(Value value) {
389 appendPDLValueKind(value);
390 append(value);
391 }
392
393 /// Append the PDLValue::Kind of the given value.
394 void appendPDLValueKind(Value value) { appendPDLValueKind(type: value.getType()); }
395
396 /// Append the PDLValue::Kind of the given type.
397 void appendPDLValueKind(Type type) {
398 PDLValue::Kind kind =
399 TypeSwitch<Type, PDLValue::Kind>(type)
400 .Case<pdl::AttributeType>(
401 [](Type) { return PDLValue::Kind::Attribute; })
402 .Case<pdl::OperationType>(
403 [](Type) { return PDLValue::Kind::Operation; })
404 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405 if (isa<pdl::TypeType>(rangeTy.getElementType()))
406 return PDLValue::Kind::TypeRange;
407 return PDLValue::Kind::ValueRange;
408 })
409 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411 bytecode.push_back(Elt: static_cast<ByteCodeField>(kind));
412 }
413
414 /// Append a value that will be stored in a memory slot and not inline within
415 /// the bytecode.
416 template <typename T>
417 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418 std::is_pointer<T>::value>
419 append(T value) {
420 bytecode.push_back(Elt: generator.getMemIndex(value));
421 }
422
423 /// Append a range of values.
424 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426 append(T range) {
427 bytecode.push_back(Elt: llvm::size(range));
428 for (auto it : range)
429 append(it);
430 }
431
432 /// Append a variadic number of fields to the bytecode.
433 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435 append(field);
436 append(field2, fields...);
437 }
438
439 /// Appends a value as a pointer, stored inline within the bytecode.
440 template <typename T>
441 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442 appendInline(T value) {
443 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444 const void *pointer = value.getAsOpaquePointer();
445 ByteCodeField fieldParts[numParts];
446 std::memcpy(dest: fieldParts, src: &pointer, n: sizeof(const void *));
447 bytecode.append(in_start: fieldParts, in_end: fieldParts + numParts);
448 }
449
450 /// Successor references in the bytecode that have yet to be resolved.
451 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452
453 /// The underlying bytecode buffer.
454 SmallVectorImpl<ByteCodeField> &bytecode;
455
456 /// The main generator producing PDL.
457 Generator &generator;
458};
459
460/// This class represents a live range of PDL Interpreter values, containing
461/// information about when values are live within a match/rewrite.
462struct ByteCodeLiveRange {
463 using Set = llvm::IntervalMap<uint64_t, char, 16>;
464 using Allocator = Set::Allocator;
465
466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467
468 /// Union this live range with the one provided.
469 void unionWith(const ByteCodeLiveRange &rhs) {
470 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471 ++it)
472 liveness->insert(a: it.start(), b: it.stop(), /*dummyValue*/ y: 0);
473 }
474
475 /// Returns true if this range overlaps with the one provided.
476 bool overlaps(const ByteCodeLiveRange &rhs) const {
477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478 .valid();
479 }
480
481 /// A map representing the ranges of the match/rewrite that a value is live in
482 /// the interpreter.
483 ///
484 /// We use std::unique_ptr here, because IntervalMap does not provide a
485 /// correct copy or move constructor. We can eliminate the pointer once
486 /// https://reviews.llvm.org/D113240 lands.
487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488
489 /// The operation range storage index for this range.
490 std::optional<unsigned> opRangeIndex;
491
492 /// The type range storage index for this range.
493 std::optional<unsigned> typeRangeIndex;
494
495 /// The value range storage index for this range.
496 std::optional<unsigned> valueRangeIndex;
497};
498} // namespace
499
500void Generator::generate(ModuleOp module) {
501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506
507 // Allocate memory indices for the results of operations within the matcher
508 // and rewriters.
509 allocateMemoryIndices(matcherFunc, rewriterModule);
510
511 // Generate code for the rewriter functions.
512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515 for (Operation &op : rewriterFunc.getOps())
516 generate(&op, rewriterByteCodeWriter);
517 }
518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519 "unexpected branches in rewriter function");
520
521 // Generate code for the matcher function.
522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524
525 // Resolve successor references in the matcher.
526 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527 ByteCodeAddr addr = blockToAddr[it.first];
528 for (unsigned offsetToFix : it.second)
529 std::memcpy(dest: &matcherByteCode[offsetToFix], src: &addr, n: sizeof(ByteCodeAddr));
530 }
531}
532
533void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534 ModuleOp rewriterModule) {
535 // Rewriters use simplistic allocation scheme that simply assigns an index to
536 // each result.
537 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539 auto processRewriterValue = [&](Value val) {
540 valueToMemIndex.try_emplace(val, index++);
541 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542 Type elementTy = rangeType.getElementType();
543 if (isa<pdl::TypeType>(elementTy))
544 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545 else if (isa<pdl::ValueType>(elementTy))
546 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547 }
548 };
549
550 for (BlockArgument arg : rewriterFunc.getArguments())
551 processRewriterValue(arg);
552 rewriterFunc.getBody().walk([&](Operation *op) {
553 for (Value result : op->getResults())
554 processRewriterValue(result);
555 });
556 if (index > maxValueMemoryIndex)
557 maxValueMemoryIndex = index;
558 if (typeRangeIndex > maxTypeRangeMemoryIndex)
559 maxTypeRangeMemoryIndex = typeRangeIndex;
560 if (valueRangeIndex > maxValueRangeMemoryIndex)
561 maxValueRangeMemoryIndex = valueRangeIndex;
562 }
563
564 // The matcher function uses a more sophisticated numbering that tries to
565 // minimize the number of memory indices assigned. This is done by determining
566 // a live range of the values within the matcher, then the allocation is just
567 // finding the minimal number of overlapping live ranges. This is essentially
568 // a simplified form of register allocation where we don't necessarily have a
569 // limited number of registers, but we still want to minimize the number used.
570 DenseMap<Operation *, unsigned> opToFirstIndex;
571 DenseMap<Operation *, unsigned> opToLastIndex;
572
573 // A custom walk that marks the first and the last index of each operation.
574 // The entry marks the beginning of the liveness range for this operation,
575 // followed by nested operations, followed by the end of the liveness range.
576 unsigned index = 0;
577 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578 opToFirstIndex.try_emplace(Key: op, Args: index++);
579 for (Region &region : op->getRegions())
580 for (Block &block : region.getBlocks())
581 for (Operation &nested : block)
582 walk(&nested);
583 opToLastIndex.try_emplace(Key: op, Args: index++);
584 };
585 walk(matcherFunc);
586
587 // Liveness info for each of the defs within the matcher.
588 ByteCodeLiveRange::Allocator allocator;
589 DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590
591 // Assign the root operation being matched to slot 0.
592 BlockArgument rootOpArg = matcherFunc.getArgument(0);
593 valueToMemIndex[rootOpArg] = 0;
594
595 // Walk each of the blocks, computing the def interval that the value is used.
596 Liveness matcherLiveness(matcherFunc);
597 matcherFunc->walk([&](Block *block) {
598 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599 assert(info && "expected liveness info for block");
600 auto processValue = [&](Value value, Operation *firstUseOrDef) {
601 // We don't need to process the root op argument, this value is always
602 // assigned to the first memory slot.
603 if (value == rootOpArg)
604 return;
605
606 // Set indices for the range of this block that the value is used.
607 auto defRangeIt = valueDefRanges.try_emplace(Key: value, Args&: allocator).first;
608 defRangeIt->second.liveness->insert(
609 a: opToFirstIndex[firstUseOrDef],
610 b: opToLastIndex[info->getEndOperation(value, startOperation: firstUseOrDef)],
611 /*dummyValue*/ y: 0);
612
613 // Check to see if this value is a range type.
614 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615 Type eleType = rangeTy.getElementType();
616 if (isa<pdl::OperationType>(eleType))
617 defRangeIt->second.opRangeIndex = 0;
618 else if (isa<pdl::TypeType>(eleType))
619 defRangeIt->second.typeRangeIndex = 0;
620 else if (isa<pdl::ValueType>(eleType))
621 defRangeIt->second.valueRangeIndex = 0;
622 }
623 };
624
625 // Process the live-ins of this block.
626 for (Value liveIn : info->in()) {
627 // Only process the value if it has been defined in the current region.
628 // Other values that span across pdl_interp.foreach will be added higher
629 // up. This ensures that the we keep them alive for the entire duration
630 // of the loop.
631 if (liveIn.getParentRegion() == block->getParent())
632 processValue(liveIn, &block->front());
633 }
634
635 // Process the block arguments for the entry block (those are not live-in).
636 if (block->isEntryBlock()) {
637 for (Value argument : block->getArguments())
638 processValue(argument, &block->front());
639 }
640
641 // Process any new defs within this block.
642 for (Operation &op : *block)
643 for (Value result : op.getResults())
644 processValue(result, &op);
645 });
646
647 // Greedily allocate memory slots using the computed def live ranges.
648 std::vector<ByteCodeLiveRange> allocatedIndices;
649
650 // The number of memory indices currently allocated (and its next value).
651 // Recall that the root gets allocated memory index 0.
652 ByteCodeField numIndices = 1;
653
654 // The number of memory ranges of various types (and their next values).
655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656
657 for (auto &defIt : valueDefRanges) {
658 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659 ByteCodeLiveRange &defRange = defIt.second;
660
661 // Try to allocate to an existing index.
662 for (const auto &existingIndexIt : llvm::enumerate(First&: allocatedIndices)) {
663 ByteCodeLiveRange &existingRange = existingIndexIt.value();
664 if (!defRange.overlaps(rhs: existingRange)) {
665 existingRange.unionWith(rhs: defRange);
666 memIndex = existingIndexIt.index() + 1;
667
668 if (defRange.opRangeIndex) {
669 if (!existingRange.opRangeIndex)
670 existingRange.opRangeIndex = numOpRanges++;
671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672 } else if (defRange.typeRangeIndex) {
673 if (!existingRange.typeRangeIndex)
674 existingRange.typeRangeIndex = numTypeRanges++;
675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676 } else if (defRange.valueRangeIndex) {
677 if (!existingRange.valueRangeIndex)
678 existingRange.valueRangeIndex = numValueRanges++;
679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680 }
681 break;
682 }
683 }
684
685 // If no existing index could be used, add a new one.
686 if (memIndex == 0) {
687 allocatedIndices.emplace_back(args&: allocator);
688 ByteCodeLiveRange &newRange = allocatedIndices.back();
689 newRange.unionWith(rhs: defRange);
690
691 // Allocate an index for op/type/value ranges.
692 if (defRange.opRangeIndex) {
693 newRange.opRangeIndex = numOpRanges;
694 valueToRangeIndex[defIt.first] = numOpRanges++;
695 } else if (defRange.typeRangeIndex) {
696 newRange.typeRangeIndex = numTypeRanges;
697 valueToRangeIndex[defIt.first] = numTypeRanges++;
698 } else if (defRange.valueRangeIndex) {
699 newRange.valueRangeIndex = numValueRanges;
700 valueToRangeIndex[defIt.first] = numValueRanges++;
701 }
702
703 memIndex = allocatedIndices.size();
704 ++numIndices;
705 }
706 }
707
708 // Print the index usage and ensure that we did not run out of index space.
709 LLVM_DEBUG({
710 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711 << "(down from initial " << valueDefRanges.size() << ").\n";
712 });
713 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714 "Ran out of memory for allocated indices");
715
716 // Update the max number of indices.
717 if (numIndices > maxValueMemoryIndex)
718 maxValueMemoryIndex = numIndices;
719 if (numOpRanges > maxOpRangeMemoryIndex)
720 maxOpRangeMemoryIndex = numOpRanges;
721 if (numTypeRanges > maxTypeRangeMemoryIndex)
722 maxTypeRangeMemoryIndex = numTypeRanges;
723 if (numValueRanges > maxValueRangeMemoryIndex)
724 maxValueRangeMemoryIndex = numValueRanges;
725}
726
727void Generator::generate(Region *region, ByteCodeWriter &writer) {
728 llvm::ReversePostOrderTraversal<Region *> rpot(region);
729 for (Block *block : rpot) {
730 // Keep track of where this block begins within the matcher function.
731 blockToAddr.try_emplace(Key: block, Args: matcherByteCode.size());
732 for (Operation &op : *block)
733 generate(op: &op, writer);
734 }
735}
736
737void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738 LLVM_DEBUG({
739 // The following list must contain all the operations that do not
740 // produce any bytecode.
741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742 writer.appendInline(op->getLoc());
743 });
744 TypeSwitch<Operation *>(op)
745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763 pdl_interp::SwitchResultCountOp>(
764 [&](auto interpOp) { this->generate(interpOp, writer); })
765 .Default([](Operation *) {
766 llvm_unreachable("unknown `pdl_interp` operation");
767 });
768}
769
770void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 // Constraints that should return a value have to be registered as rewrites.
773 // If a constraint and a rewrite of similar name are registered the
774 // constraint takes precedence
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(values: op.getArgs());
777 writer.append(field: ByteCodeField(op.getIsNegated()));
778 ResultRange results = op.getResults();
779 writer.append(field: ByteCodeField(results.size()));
780 for (Value result : results) {
781 // We record the expected kind of the result, so that we can provide extra
782 // verification of the native rewrite function and handle the failure case
783 // of constraints accordingly.
784 writer.appendPDLValueKind(result);
785
786 // Range results also need to append the range storage index.
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
790 }
791 writer.append(op.getSuccessors());
792}
793void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(values: op.getArgs());
799
800 ResultRange results = op.getResults();
801 writer.append(field: ByteCodeField(results.size()));
802 for (Value result : results) {
803 // We record the expected kind of the result, so that we
804 // can provide extra verification of the native rewrite function.
805 writer.appendPDLValueKind(result);
806
807 // Range results also need to append the range storage index.
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
811 }
812}
813void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.getType())) {
816 writer.append(opCode: OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(value: lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819 return;
820 }
821
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823}
824void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825 writer.append(field: OpCode::Branch, field2: SuccessorRange(op.getOperation()));
826}
827void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830 op.getSuccessors());
831}
832void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField>(op.getCompareAtLeast()),
836 op.getSuccessors());
837}
838void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
841 OperationName(op.getName(), ctx), op.getSuccessors());
842}
843void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField>(op.getCompareAtLeast()),
847 op.getSuccessors());
848}
849void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851 op.getSuccessors());
852}
853void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855 op.getSuccessors());
856}
857void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859 writer.append(field: OpCode::Continue, field2: ByteCodeField(curLoopLevel - 1));
860}
861void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
863 // Simply repoint the memory index of the result to the constant.
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865}
866void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
869 OperationName(op.getName(), ctx));
870 writer.appendPDLValueList(values: op.getInputOperands());
871
872 // Add the attributes.
873 OperandRange attributes = op.getInputAttributes();
874 writer.append(field: static_cast<ByteCodeField>(attributes.size()));
875 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
877
878 // Add the result types. If the operation has inferred results, we use a
879 // marker "size" value. Otherwise, we add the list of explicit result types.
880 if (op.getInferredResultTypes())
881 writer.append(field: kInferTypesMarker);
882 else
883 writer.appendPDLValueList(values: op.getInputResultTypes());
884}
885void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886 // Append the correct opcode for the range type.
887 TypeSwitch<Type>(op.getType().getElementType())
888 .Case(
889 caseFn: [&](pdl::TypeType) { writer.append(opCode: OpCode::CreateDynamicTypeRange); })
890 .Case(caseFn: [&](pdl::ValueType) {
891 writer.append(opCode: OpCode::CreateDynamicValueRange);
892 });
893
894 writer.append(op.getResult(), getRangeStorageIndex(value: op.getResult()));
895 writer.appendPDLValueList(values: op->getOperands());
896}
897void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898 // Simply repoint the memory index of the result to the constant.
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900}
901void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(value: op.getResult()), op.getValue());
904}
905void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
907}
908void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909 OpCode opCode =
910 TypeSwitch<Type, OpCode>(op.getResult().getType())
911 .Case(caseFn: [](pdl::OperationType) { return OpCode::ExtractOp; })
912 .Case(caseFn: [](pdl::ValueType) { return OpCode::ExtractValue; })
913 .Case(caseFn: [](pdl::TypeType) { return OpCode::ExtractType; })
914 .Default(defaultFn: [](Type) -> OpCode {
915 llvm_unreachable("unsupported element type");
916 });
917 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
918}
919void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
920 writer.append(opCode: OpCode::Finalize);
921}
922void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
923 BlockArgument arg = op.getLoopVariable();
924 writer.append(OpCode::ForEach, getRangeStorageIndex(value: op.getValues()), arg);
925 writer.appendPDLValueKind(type: arg.getType());
926 writer.append(curLoopLevel, op.getSuccessor());
927 ++curLoopLevel;
928 if (curLoopLevel > maxLoopLevel)
929 maxLoopLevel = curLoopLevel;
930 generate(&op.getRegion(), writer);
931 --curLoopLevel;
932}
933void Generator::generate(pdl_interp::GetAttributeOp op,
934 ByteCodeWriter &writer) {
935 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
936 op.getNameAttr());
937}
938void Generator::generate(pdl_interp::GetAttributeTypeOp op,
939 ByteCodeWriter &writer) {
940 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
941}
942void Generator::generate(pdl_interp::GetDefiningOpOp op,
943 ByteCodeWriter &writer) {
944 writer.append(OpCode::GetDefiningOp, op.getInputOp());
945 writer.appendPDLValue(value: op.getValue());
946}
947void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
948 uint32_t index = op.getIndex();
949 if (index < 4)
950 writer.append(opCode: static_cast<OpCode>(OpCode::GetOperand0 + index));
951 else
952 writer.append(field: OpCode::GetOperandN, field2: index);
953 writer.append(op.getInputOp(), op.getValue());
954}
955void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
956 Value result = op.getValue();
957 std::optional<uint32_t> index = op.getIndex();
958 writer.append(OpCode::GetOperands,
959 index.value_or(u: std::numeric_limits<uint32_t>::max()),
960 op.getInputOp());
961 if (isa<pdl::RangeType>(result.getType()))
962 writer.append(field: getRangeStorageIndex(value: result));
963 else
964 writer.append(field: std::numeric_limits<ByteCodeField>::max());
965 writer.append(value: result);
966}
967void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
968 uint32_t index = op.getIndex();
969 if (index < 4)
970 writer.append(opCode: static_cast<OpCode>(OpCode::GetResult0 + index));
971 else
972 writer.append(field: OpCode::GetResultN, field2: index);
973 writer.append(op.getInputOp(), op.getValue());
974}
975void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
976 Value result = op.getValue();
977 std::optional<uint32_t> index = op.getIndex();
978 writer.append(OpCode::GetResults,
979 index.value_or(u: std::numeric_limits<uint32_t>::max()),
980 op.getInputOp());
981 if (isa<pdl::RangeType>(result.getType()))
982 writer.append(field: getRangeStorageIndex(value: result));
983 else
984 writer.append(field: std::numeric_limits<ByteCodeField>::max());
985 writer.append(value: result);
986}
987void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
988 Value operations = op.getOperations();
989 ByteCodeField rangeIndex = getRangeStorageIndex(value: operations);
990 writer.append(field: OpCode::GetUsers, field2: operations, fields: rangeIndex);
991 writer.appendPDLValue(value: op.getValue());
992}
993void Generator::generate(pdl_interp::GetValueTypeOp op,
994 ByteCodeWriter &writer) {
995 if (isa<pdl::RangeType>(op.getType())) {
996 Value result = op.getResult();
997 writer.append(OpCode::GetValueRangeTypes, result,
998 getRangeStorageIndex(value: result), op.getValue());
999 } else {
1000 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1001 }
1002}
1003void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1004 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1005}
1006void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1007 ByteCodeField patternIndex = patterns.size();
1008 patterns.emplace_back(PDLByteCodePattern::create(
1009 matchOp: op, configSet: configMap.lookup(Val: op),
1010 rewriterAddr: rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1011 writer.append(OpCode::RecordMatch, patternIndex,
1012 SuccessorRange(op.getOperation()), op.getMatchedOps());
1013 writer.appendPDLValueList(values: op.getInputs());
1014}
1015void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1016 writer.append(OpCode::ReplaceOp, op.getInputOp());
1017 writer.appendPDLValueList(values: op.getReplValues());
1018}
1019void Generator::generate(pdl_interp::SwitchAttributeOp op,
1020 ByteCodeWriter &writer) {
1021 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1022 op.getCaseValuesAttr(), op.getSuccessors());
1023}
1024void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1025 ByteCodeWriter &writer) {
1026 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1027 op.getCaseValuesAttr(), op.getSuccessors());
1028}
1029void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1030 ByteCodeWriter &writer) {
1031 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1032 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1033 });
1034 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1035 op.getSuccessors());
1036}
1037void Generator::generate(pdl_interp::SwitchResultCountOp op,
1038 ByteCodeWriter &writer) {
1039 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1040 op.getCaseValuesAttr(), op.getSuccessors());
1041}
1042void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1043 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1044 op.getSuccessors());
1045}
1046void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1047 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1048 op.getSuccessors());
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// PDLByteCode
1053//===----------------------------------------------------------------------===//
1054
1055PDLByteCode::PDLByteCode(
1056 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1057 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1058 llvm::StringMap<PDLConstraintFunction> constraintFns,
1059 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1060 : configs(std::move(configs)) {
1061 Generator generator(module.getContext(), uniquedData, matcherByteCode,
1062 rewriterByteCode, patterns, maxValueMemoryIndex,
1063 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1064 maxLoopLevel, constraintFns, rewriteFns, configMap);
1065 generator.generate(module);
1066
1067 // Initialize the external functions.
1068 for (auto &it : constraintFns)
1069 constraintFunctions.push_back(x: std::move(it.second));
1070 for (auto &it : rewriteFns)
1071 rewriteFunctions.push_back(x: std::move(it.second));
1072}
1073
1074/// Initialize the given state such that it can be used to execute the current
1075/// bytecode.
1076void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1077 state.memory.resize(new_size: maxValueMemoryIndex, x: nullptr);
1078 state.opRangeMemory.resize(new_size: maxOpRangeCount);
1079 state.typeRangeMemory.resize(new_size: maxTypeRangeCount, x: TypeRange());
1080 state.valueRangeMemory.resize(new_size: maxValueRangeCount, x: ValueRange());
1081 state.loopIndex.resize(new_size: maxLoopLevel, x: 0);
1082 state.currentPatternBenefits.reserve(n: patterns.size());
1083 for (const PDLByteCodePattern &pattern : patterns)
1084 state.currentPatternBenefits.push_back(x: pattern.getBenefit());
1085}
1086
1087//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//

1090namespace {
1091/// This class is an instantiation of the PDLResultList that provides access to
1092/// the returned results. This API is not on `PDLResultList` to avoid
1093/// overexposing access to information specific solely to the ByteCode.
1094class ByteCodeRewriteResultList : public PDLResultList {
1095public:
1096 ByteCodeRewriteResultList(unsigned maxNumResults)
1097 : PDLResultList(maxNumResults) {}
1098
1099 /// Return the list of PDL results.
1100 MutableArrayRef<PDLValue> getResults() { return results; }
1101
1102 /// Return the type ranges allocated by this list.
1103 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1104 return allocatedTypeRanges;
1105 }
1106
1107 /// Return the value ranges allocated by this list.
1108 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1109 return allocatedValueRanges;
1110 }
1111};
1112
1113/// This class provides support for executing a bytecode stream.
1114class ByteCodeExecutor {
1115public:
1116 ByteCodeExecutor(
1117 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1118 MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1119 MutableArrayRef<TypeRange> typeRangeMemory,
1120 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1121 MutableArrayRef<ValueRange> valueRangeMemory,
1122 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1123 MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1124 ArrayRef<ByteCodeField> code,
1125 ArrayRef<PatternBenefit> currentPatternBenefits,
1126 ArrayRef<PDLByteCodePattern> patterns,
1127 ArrayRef<PDLConstraintFunction> constraintFunctions,
1128 ArrayRef<PDLRewriteFunction> rewriteFunctions)
1129 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130 typeRangeMemory(typeRangeMemory),
1131 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132 valueRangeMemory(valueRangeMemory),
1133 allocatedValueRangeMemory(allocatedValueRangeMemory),
1134 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1135 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1136 constraintFunctions(constraintFunctions),
1137 rewriteFunctions(rewriteFunctions) {}
1138
1139 /// Start executing the code at the current bytecode index. `matches` is an
1140 /// optional field provided when this function is executed in a matching
1141 /// context.
1142 LogicalResult
1143 execute(PatternRewriter &rewriter,
1144 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1145 std::optional<Location> mainRewriteLoc = {});
1146
1147private:
1148 /// Internal implementation of executing each of the bytecode commands.
1149 void executeApplyConstraint(PatternRewriter &rewriter);
1150 LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151 void executeAreEqual();
1152 void executeAreRangesEqual();
1153 void executeBranch();
1154 void executeCheckOperandCount();
1155 void executeCheckOperationName();
1156 void executeCheckResultCount();
1157 void executeCheckTypes();
1158 void executeContinue();
1159 void executeCreateConstantTypeRange();
1160 void executeCreateOperation(PatternRewriter &rewriter,
1161 Location mainRewriteLoc);
1162 template <typename T>
1163 void executeDynamicCreateRange(StringRef type);
1164 void executeEraseOp(PatternRewriter &rewriter);
1165 template <typename T, typename Range, PDLValue::Kind kind>
1166 void executeExtract();
1167 void executeFinalize();
1168 void executeForEach();
1169 void executeGetAttribute();
1170 void executeGetAttributeType();
1171 void executeGetDefiningOp();
1172 void executeGetOperand(unsigned index);
1173 void executeGetOperands();
1174 void executeGetResult(unsigned index);
1175 void executeGetResults();
1176 void executeGetUsers();
1177 void executeGetValueType();
1178 void executeGetValueRangeTypes();
1179 void executeIsNotNull();
1180 void executeRecordMatch(PatternRewriter &rewriter,
1181 SmallVectorImpl<PDLByteCode::MatchResult> &matches);
1182 void executeReplaceOp(PatternRewriter &rewriter);
1183 void executeSwitchAttribute();
1184 void executeSwitchOperandCount();
1185 void executeSwitchOperationName();
1186 void executeSwitchResultCount();
1187 void executeSwitchType();
1188 void executeSwitchTypes();
1189 void processNativeFunResults(ByteCodeRewriteResultList &results,
1190 unsigned numResults,
1191 LogicalResult &rewriteResult);
1192
1193 /// Pushes a code iterator to the stack.
1194 void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(Elt: it); }
1195
1196 /// Pops a code iterator from the stack, returning true on success.
1197 void popCodeIt() {
1198 assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1199 curCodeIt = resumeCodeIt.back();
1200 resumeCodeIt.pop_back();
1201 }
1202
1203 /// Return the bytecode iterator at the start of the current op code.
1204 const ByteCodeField *getPrevCodeIt() const {
1205 LLVM_DEBUG({
1206 // Account for the op code and the Location stored inline.
1207 return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1208 });
1209
1210 // Account for the op code only.
1211 return curCodeIt - 1;
1212 }
1213
1214 /// Read a value from the bytecode buffer, optionally skipping a certain
1215 /// number of prefix values. These methods always update the buffer to point
1216 /// to the next field after the read data.
1217 template <typename T = ByteCodeField>
1218 T read(size_t skipN = 0) {
1219 curCodeIt += skipN;
1220 return readImpl<T>();
1221 }
1222 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1223
1224 /// Read a list of values from the bytecode buffer.
1225 template <typename ValueT, typename T>
1226 void readList(SmallVectorImpl<T> &list) {
1227 list.clear();
1228 for (unsigned i = 0, e = read(); i != e; ++i)
1229 list.push_back(read<ValueT>());
1230 }
1231
1232 /// Read a list of values from the bytecode buffer. The values may be encoded
1233 /// either as a single element or a range of elements.
1234 void readList(SmallVectorImpl<Type> &list) {
1235 for (unsigned i = 0, e = read(); i != e; ++i) {
1236 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1237 list.push_back(Elt: read<Type>());
1238 } else {
1239 TypeRange *values = read<TypeRange *>();
1240 list.append(in_start: values->begin(), in_end: values->end());
1241 }
1242 }
1243 }
1244 void readList(SmallVectorImpl<Value> &list) {
1245 for (unsigned i = 0, e = read(); i != e; ++i) {
1246 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1247 list.push_back(Elt: read<Value>());
1248 } else {
1249 ValueRange *values = read<ValueRange *>();
1250 list.append(in_start: values->begin(), in_end: values->end());
1251 }
1252 }
1253 }
1254
1255 /// Read a value stored inline as a pointer.
1256 template <typename T>
1257 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1258 readInline() {
1259 const void *pointer;
1260 std::memcpy(dest: &pointer, src: curCodeIt, n: sizeof(const void *));
1261 curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1262 return T::getFromOpaquePointer(pointer);
1263 }
1264
1265 void skip(size_t skipN) { curCodeIt += skipN; }
1266
1267 /// Jump to a specific successor based on a predicate value.
1268 void selectJump(bool isTrue) { selectJump(destIndex: size_t(isTrue ? 0 : 1)); }
1269 /// Jump to a specific successor based on a destination index.
1270 void selectJump(size_t destIndex) {
1271 curCodeIt = &code[read<ByteCodeAddr>(skipN: destIndex * 2)];
1272 }
1273
1274 /// Handle a switch operation with the provided value and cases.
1275 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1276 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1277 LLVM_DEBUG({
1278 llvm::dbgs() << " * Value: " << value << "\n"
1279 << " * Cases: ";
1280 llvm::interleaveComma(cases, llvm::dbgs());
1281 llvm::dbgs() << "\n";
1282 });
1283
1284 // Check to see if the attribute value is within the case list. Jump to
1285 // the correct successor index based on the result.
1286 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1287 if (cmp(*it, value))
1288 return selectJump(destIndex: size_t((it - cases.begin()) + 1));
1289 selectJump(destIndex: size_t(0));
1290 }
1291
1292 /// Store a pointer to memory.
1293 void storeToMemory(unsigned index, const void *value) {
1294 memory[index] = value;
1295 }
1296
1297 /// Store a value to memory as an opaque pointer.
1298 template <typename T>
1299 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1300 storeToMemory(unsigned index, T value) {
1301 memory[index] = value.getAsOpaquePointer();
1302 }
1303
1304 /// Internal implementation of reading various data types from the bytecode
1305 /// stream.
1306 template <typename T>
1307 const void *readFromMemory() {
1308 size_t index = *curCodeIt++;
1309
1310 // If this type is an SSA value, it can only be stored in non-const memory.
1311 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1312 Value>::value ||
1313 index < memory.size())
1314 return memory[index];
1315
1316 // Otherwise, if this index is not inbounds it is uniqued.
1317 return uniquedMemory[index - memory.size()];
1318 }
1319 template <typename T>
1320 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1321 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1322 }
1323 template <typename T>
1324 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1325 T>
1326 readImpl() {
1327 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1328 }
1329 template <typename T>
1330 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1331 switch (read<PDLValue::Kind>()) {
1332 case PDLValue::Kind::Attribute:
1333 return read<Attribute>();
1334 case PDLValue::Kind::Operation:
1335 return read<Operation *>();
1336 case PDLValue::Kind::Type:
1337 return read<Type>();
1338 case PDLValue::Kind::Value:
1339 return read<Value>();
1340 case PDLValue::Kind::TypeRange:
1341 return read<TypeRange *>();
1342 case PDLValue::Kind::ValueRange:
1343 return read<ValueRange *>();
1344 }
1345 llvm_unreachable("unhandled PDLValue::Kind");
1346 }
1347 template <typename T>
1348 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1349 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1350 "unexpected ByteCode address size");
1351 ByteCodeAddr result;
1352 std::memcpy(dest: &result, src: curCodeIt, n: sizeof(ByteCodeAddr));
1353 curCodeIt += 2;
1354 return result;
1355 }
1356 template <typename T>
1357 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1358 return *curCodeIt++;
1359 }
1360 template <typename T>
1361 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1362 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1363 }
1364
1365 /// Assign the given range to the given memory index. This allocates a new
1366 /// range object if necessary.
1367 template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1368 void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1369 unsigned rangeIndex) {
1370 // Utility functor used to type-erase the assignment.
1371 auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1372 // If the input range is empty, we don't need to allocate anything.
1373 if (range.empty()) {
1374 rangeMemory[rangeIndex] = {};
1375 } else {
1376 // Allocate a buffer for this type range.
1377 llvm::OwningArrayRef<T> storage(llvm::size(range));
1378 llvm::copy(range, storage.begin());
1379
1380 // Assign this to the range slot and use the range as the value for the
1381 // memory index.
1382 allocatedRangeMemory.emplace_back(std::move(storage));
1383 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1384 }
1385 memory[memIndex] = &rangeMemory[rangeIndex];
1386 };
1387
1388 // Dispatch based on the concrete range type.
1389 if constexpr (std::is_same_v<T, Type>) {
1390 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1391 } else if constexpr (std::is_same_v<T, Value>) {
1392 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1393 } else {
1394 llvm_unreachable("unhandled range type");
1395 }
1396 }
1397
1398 /// The underlying bytecode buffer.
1399 const ByteCodeField *curCodeIt;
1400
1401 /// The stack of bytecode positions at which to resume operation.
1402 SmallVector<const ByteCodeField *> resumeCodeIt;
1403
1404 /// The current execution memory.
1405 MutableArrayRef<const void *> memory;
1406 MutableArrayRef<OwningOpRange> opRangeMemory;
1407 MutableArrayRef<TypeRange> typeRangeMemory;
1408 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1409 MutableArrayRef<ValueRange> valueRangeMemory;
1410 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1411
1412 /// The current loop indices.
1413 MutableArrayRef<unsigned> loopIndex;
1414
1415 /// References to ByteCode data necessary for execution.
1416 ArrayRef<const void *> uniquedMemory;
1417 ArrayRef<ByteCodeField> code;
1418 ArrayRef<PatternBenefit> currentPatternBenefits;
1419 ArrayRef<PDLByteCodePattern> patterns;
1420 ArrayRef<PDLConstraintFunction> constraintFunctions;
1421 ArrayRef<PDLRewriteFunction> rewriteFunctions;
1422};
1423} // namespace
1424
1425void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1426 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1427 ByteCodeField fun_idx = read();
1428 SmallVector<PDLValue, 16> args;
1429 readList<PDLValue>(list&: args);
1430
1431 LLVM_DEBUG({
1432 llvm::dbgs() << " * Arguments: ";
1433 llvm::interleaveComma(args, llvm::dbgs());
1434 llvm::dbgs() << "\n";
1435 });
1436
1437 ByteCodeField isNegated = read();
1438 LLVM_DEBUG({
1439 llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1440 llvm::interleaveComma(args, llvm::dbgs());
1441 });
1442
1443 ByteCodeField numResults = read();
1444 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1445 ByteCodeRewriteResultList results(numResults);
1446 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1447 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1448 LLVM_DEBUG({
1449 if (succeeded(rewriteResult)) {
1450 llvm::dbgs() << " * Constraint succeeded\n";
1451 llvm::dbgs() << " * Results: ";
1452 llvm::interleaveComma(constraintResults, llvm::dbgs());
1453 llvm::dbgs() << "\n";
1454 } else {
1455 llvm::dbgs() << " * Constraint failed\n";
1456 }
1457 });
1458 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1459 "native PDL rewrite function succeeded but returned "
1460 "unexpected number of results");
1461 processNativeFunResults(results, numResults, rewriteResult);
1462
1463 // Depending on the constraint jump to the proper destination.
1464 selectJump(isTrue: isNegated != succeeded(result: rewriteResult));
1465}
1466
1467LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1468 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1469 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1470 SmallVector<PDLValue, 16> args;
1471 readList<PDLValue>(list&: args);
1472
1473 LLVM_DEBUG({
1474 llvm::dbgs() << " * Arguments: ";
1475 llvm::interleaveComma(args, llvm::dbgs());
1476 });
1477
1478 // Execute the rewrite function.
1479 ByteCodeField numResults = read();
1480 ByteCodeRewriteResultList results(numResults);
1481 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1482
1483 assert(results.getResults().size() == numResults &&
1484 "native PDL rewrite function returned unexpected number of results");
1485
1486 processNativeFunResults(results, numResults, rewriteResult);
1487
1488 if (failed(result: rewriteResult)) {
1489 LLVM_DEBUG(llvm::dbgs() << " - Failed");
1490 return failure();
1491 }
1492 return success();
1493}
1494
1495void ByteCodeExecutor::processNativeFunResults(
1496 ByteCodeRewriteResultList &results, unsigned numResults,
1497 LogicalResult &rewriteResult) {
1498 // Store the results in the bytecode memory or handle missing results on
1499 // failure.
1500 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1501 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1502
1503 // Skip the according number of values on the buffer on failure and exit
1504 // early as there are no results to process.
1505 if (failed(result: rewriteResult)) {
1506 if (resultKind == PDLValue::Kind::TypeRange ||
1507 resultKind == PDLValue::Kind::ValueRange) {
1508 skip(skipN: 2);
1509 } else {
1510 skip(skipN: 1);
1511 }
1512 return;
1513 }
1514 PDLValue result = results.getResults()[resultIdx];
1515 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1516 assert(result.getKind() == resultKind &&
1517 "native PDL rewrite function returned an unexpected type of "
1518 "result");
1519 // If the result is a range, we need to copy it over to the bytecodes
1520 // range memory.
1521 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1522 unsigned rangeIndex = read();
1523 typeRangeMemory[rangeIndex] = *typeRange;
1524 memory[read()] = &typeRangeMemory[rangeIndex];
1525 } else if (std::optional<ValueRange> valueRange =
1526 result.dyn_cast<ValueRange>()) {
1527 unsigned rangeIndex = read();
1528 valueRangeMemory[rangeIndex] = *valueRange;
1529 memory[read()] = &valueRangeMemory[rangeIndex];
1530 } else {
1531 memory[read()] = result.getAsOpaquePointer();
1532 }
1533 }
1534
1535 // Copy over any underlying storage allocated for result ranges.
1536 for (auto &it : results.getAllocatedTypeRanges())
1537 allocatedTypeRangeMemory.push_back(x: std::move(it));
1538 for (auto &it : results.getAllocatedValueRanges())
1539 allocatedValueRangeMemory.push_back(x: std::move(it));
1540}
1541
1542void ByteCodeExecutor::executeAreEqual() {
1543 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1544 const void *lhs = read<const void *>();
1545 const void *rhs = read<const void *>();
1546
1547 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1548 selectJump(isTrue: lhs == rhs);
1549}
1550
1551void ByteCodeExecutor::executeAreRangesEqual() {
1552 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1553 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1554 const void *lhs = read<const void *>();
1555 const void *rhs = read<const void *>();
1556
1557 switch (valueKind) {
1558 case PDLValue::Kind::TypeRange: {
1559 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1560 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1561 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1562 selectJump(isTrue: *lhsRange == *rhsRange);
1563 break;
1564 }
1565 case PDLValue::Kind::ValueRange: {
1566 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1567 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1568 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1569 selectJump(isTrue: *lhsRange == *rhsRange);
1570 break;
1571 }
1572 default:
1573 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1574 }
1575}
1576
1577void ByteCodeExecutor::executeBranch() {
1578 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1579 curCodeIt = &code[read<ByteCodeAddr>()];
1580}
1581
1582void ByteCodeExecutor::executeCheckOperandCount() {
1583 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1584 Operation *op = read<Operation *>();
1585 uint32_t expectedCount = read<uint32_t>();
1586 bool compareAtLeast = read();
1587
1588 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1589 << " * Expected: " << expectedCount << "\n"
1590 << " * Comparator: "
1591 << (compareAtLeast ? ">=" : "==") << "\n");
1592 if (compareAtLeast)
1593 selectJump(isTrue: op->getNumOperands() >= expectedCount);
1594 else
1595 selectJump(isTrue: op->getNumOperands() == expectedCount);
1596}
1597
1598void ByteCodeExecutor::executeCheckOperationName() {
1599 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1600 Operation *op = read<Operation *>();
1601 OperationName expectedName = read<OperationName>();
1602
1603 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1604 << " * Expected: \"" << expectedName << "\"\n");
1605 selectJump(isTrue: op->getName() == expectedName);
1606}
1607
1608void ByteCodeExecutor::executeCheckResultCount() {
1609 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1610 Operation *op = read<Operation *>();
1611 uint32_t expectedCount = read<uint32_t>();
1612 bool compareAtLeast = read();
1613
1614 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1615 << " * Expected: " << expectedCount << "\n"
1616 << " * Comparator: "
1617 << (compareAtLeast ? ">=" : "==") << "\n");
1618 if (compareAtLeast)
1619 selectJump(isTrue: op->getNumResults() >= expectedCount);
1620 else
1621 selectJump(isTrue: op->getNumResults() == expectedCount);
1622}
1623
1624void ByteCodeExecutor::executeCheckTypes() {
1625 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1626 TypeRange *lhs = read<TypeRange *>();
1627 Attribute rhs = read<Attribute>();
1628 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1629
1630 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1631}
1632
1633void ByteCodeExecutor::executeContinue() {
1634 ByteCodeField level = read();
1635 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1636 << " * Level: " << level << "\n");
1637 ++loopIndex[level];
1638 popCodeIt();
1639}
1640
1641void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1643 unsigned memIndex = read();
1644 unsigned rangeIndex = read();
1645 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1646
1647 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1648 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1649 rangeIndex);
1650}
1651
1652void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1653 Location mainRewriteLoc) {
1654 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1655
1656 unsigned memIndex = read();
1657 OperationState state(mainRewriteLoc, read<OperationName>());
1658 readList(list&: state.operands);
1659 for (unsigned i = 0, e = read(); i != e; ++i) {
1660 StringAttr name = read<StringAttr>();
1661 if (Attribute attr = read<Attribute>())
1662 state.addAttribute(name, attr);
1663 }
1664
1665 // Read in the result types. If the "size" is the sentinel value, this
1666 // indicates that the result types should be inferred.
1667 unsigned numResults = read();
1668 if (numResults == kInferTypesMarker) {
1669 InferTypeOpInterface::Concept *inferInterface =
1670 state.name.getInterface<InferTypeOpInterface>();
1671 assert(inferInterface &&
1672 "expected operation to provide InferTypeOpInterface");
1673
1674 // TODO: Handle failure.
1675 if (failed(inferInterface->inferReturnTypes(
1676 state.getContext(), state.location, state.operands,
1677 state.attributes.getDictionary(state.getContext()),
1678 state.getRawProperties(), state.regions, state.types)))
1679 return;
1680 } else {
1681 // Otherwise, this is a fixed number of results.
1682 for (unsigned i = 0; i != numResults; ++i) {
1683 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1684 state.types.push_back(Elt: read<Type>());
1685 } else {
1686 TypeRange *resultTypes = read<TypeRange *>();
1687 state.types.append(in_start: resultTypes->begin(), in_end: resultTypes->end());
1688 }
1689 }
1690 }
1691
1692 Operation *resultOp = rewriter.create(state);
1693 memory[memIndex] = resultOp;
1694
1695 LLVM_DEBUG({
1696 llvm::dbgs() << " * Attributes: "
1697 << state.attributes.getDictionary(state.getContext())
1698 << "\n * Operands: ";
1699 llvm::interleaveComma(state.operands, llvm::dbgs());
1700 llvm::dbgs() << "\n * Result Types: ";
1701 llvm::interleaveComma(state.types, llvm::dbgs());
1702 llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1703 });
1704}
1705
1706template <typename T>
1707void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1708 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1709 unsigned memIndex = read();
1710 unsigned rangeIndex = read();
1711 SmallVector<T> values;
1712 readList(values);
1713
1714 LLVM_DEBUG({
1715 llvm::dbgs() << "\n * " << type << "s: ";
1716 llvm::interleaveComma(values, llvm::dbgs());
1717 llvm::dbgs() << "\n";
1718 });
1719
1720 assignRangeToMemory(values, memIndex, rangeIndex);
1721}
1722
1723void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1724 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1725 Operation *op = read<Operation *>();
1726
1727 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1728 rewriter.eraseOp(op);
1729}
1730
1731template <typename T, typename Range, PDLValue::Kind kind>
1732void ByteCodeExecutor::executeExtract() {
1733 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1734 Range *range = read<Range *>();
1735 unsigned index = read<uint32_t>();
1736 unsigned memIndex = read();
1737
1738 if (!range) {
1739 memory[memIndex] = nullptr;
1740 return;
1741 }
1742
1743 T result = index < range->size() ? (*range)[index] : T();
1744 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1745 << " * Index: " << index << "\n"
1746 << " * Result: " << result << "\n");
1747 storeToMemory(memIndex, result);
1748}
1749
1750void ByteCodeExecutor::executeFinalize() {
1751 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1752}
1753
1754void ByteCodeExecutor::executeForEach() {
1755 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1756 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1757 unsigned rangeIndex = read();
1758 unsigned memIndex = read();
1759 const void *value = nullptr;
1760
1761 switch (read<PDLValue::Kind>()) {
1762 case PDLValue::Kind::Operation: {
1763 unsigned &index = loopIndex[read()];
1764 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1765 assert(index <= array.size() && "iterated past the end");
1766 if (index < array.size()) {
1767 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1768 value = array[index];
1769 break;
1770 }
1771
1772 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1773 index = 0;
1774 selectJump(destIndex: size_t(0));
1775 return;
1776 }
1777 default:
1778 llvm_unreachable("unexpected `ForEach` value kind");
1779 }
1780
1781 // Store the iterate value and the stack address.
1782 memory[memIndex] = value;
1783 pushCodeIt(it: prevCodeIt);
1784
1785 // Skip over the successor (we will enter the body of the loop).
1786 read<ByteCodeAddr>();
1787}
1788
1789void ByteCodeExecutor::executeGetAttribute() {
1790 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1791 unsigned memIndex = read();
1792 Operation *op = read<Operation *>();
1793 StringAttr attrName = read<StringAttr>();
1794 Attribute attr = op->getAttr(attrName);
1795
1796 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1797 << " * Attribute: " << attrName << "\n"
1798 << " * Result: " << attr << "\n");
1799 memory[memIndex] = attr.getAsOpaquePointer();
1800}
1801
1802void ByteCodeExecutor::executeGetAttributeType() {
1803 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1804 unsigned memIndex = read();
1805 Attribute attr = read<Attribute>();
1806 Type type;
1807 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1808 type = typedAttr.getType();
1809
1810 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1811 << " * Result: " << type << "\n");
1812 memory[memIndex] = type.getAsOpaquePointer();
1813}
1814
1815void ByteCodeExecutor::executeGetDefiningOp() {
1816 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1817 unsigned memIndex = read();
1818 Operation *op = nullptr;
1819 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1820 Value value = read<Value>();
1821 if (value)
1822 op = value.getDefiningOp();
1823 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1824 } else {
1825 ValueRange *values = read<ValueRange *>();
1826 if (values && !values->empty()) {
1827 op = values->front().getDefiningOp();
1828 }
1829 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1830 }
1831
1832 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1833 memory[memIndex] = op;
1834}
1835
1836void ByteCodeExecutor::executeGetOperand(unsigned index) {
1837 Operation *op = read<Operation *>();
1838 unsigned memIndex = read();
1839 Value operand =
1840 index < op->getNumOperands() ? op->getOperand(idx: index) : Value();
1841
1842 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1843 << " * Index: " << index << "\n"
1844 << " * Result: " << operand << "\n");
1845 memory[memIndex] = operand.getAsOpaquePointer();
1846}
1847
1848/// This function is the internal implementation of `GetResults` and
1849/// `GetOperands` that provides support for extracting a value range from the
1850/// given operation.
1851template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1852static void *
1853executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1854 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1855 MutableArrayRef<ValueRange> valueRangeMemory) {
1856 // Check for the sentinel index that signals that all values should be
1857 // returned.
1858 if (index == std::numeric_limits<uint32_t>::max()) {
1859 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1860 // `values` is already the full value range.
1861
1862 // Otherwise, check to see if this operation uses AttrSizedSegments.
1863 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1864 LLVM_DEBUG(llvm::dbgs()
1865 << " * Extracting values from `" << attrSizedSegments << "`\n");
1866
1867 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1868 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1869 return nullptr;
1870
1871 ArrayRef<int32_t> segments = segmentAttr;
1872 unsigned startIndex =
1873 std::accumulate(first: segments.begin(), last: segments.begin() + index, init: 0);
1874 values = values.slice(startIndex, *std::next(x: segments.begin(), n: index));
1875
1876 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1877 << *std::next(segments.begin(), index) << "]\n");
1878
1879 // Otherwise, assume this is the last operand group of the operation.
1880 // FIXME: We currently don't support operations with
1881 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1882 // have a way to detect it's presence.
1883 } else if (values.size() >= index) {
1884 LLVM_DEBUG(llvm::dbgs()
1885 << " * Treating values as trailing variadic range\n");
1886 values = values.drop_front(index);
1887
1888 // If we couldn't detect a way to compute the values, bail out.
1889 } else {
1890 return nullptr;
1891 }
1892
1893 // If the range index is valid, we are returning a range.
1894 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1895 valueRangeMemory[rangeIndex] = values;
1896 return &valueRangeMemory[rangeIndex];
1897 }
1898
1899 // If a range index wasn't provided, the range is required to be non-variadic.
1900 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1901}
1902
1903void ByteCodeExecutor::executeGetOperands() {
1904 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1905 unsigned index = read<uint32_t>();
1906 Operation *op = read<Operation *>();
1907 ByteCodeField rangeIndex = read();
1908
1909 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1910 values: op->getOperands(), op, index, rangeIndex, attrSizedSegments: "operandSegmentSizes",
1911 valueRangeMemory);
1912 if (!result)
1913 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1914 memory[read()] = result;
1915}
1916
1917void ByteCodeExecutor::executeGetResult(unsigned index) {
1918 Operation *op = read<Operation *>();
1919 unsigned memIndex = read();
1920 OpResult result =
1921 index < op->getNumResults() ? op->getResult(idx: index) : OpResult();
1922
1923 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1924 << " * Index: " << index << "\n"
1925 << " * Result: " << result << "\n");
1926 memory[memIndex] = result.getAsOpaquePointer();
1927}
1928
1929void ByteCodeExecutor::executeGetResults() {
1930 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1931 unsigned index = read<uint32_t>();
1932 Operation *op = read<Operation *>();
1933 ByteCodeField rangeIndex = read();
1934
1935 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1936 values: op->getResults(), op, index, rangeIndex, attrSizedSegments: "resultSegmentSizes",
1937 valueRangeMemory);
1938 if (!result)
1939 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1940 memory[read()] = result;
1941}
1942
1943void ByteCodeExecutor::executeGetUsers() {
1944 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1945 unsigned memIndex = read();
1946 unsigned rangeIndex = read();
1947 OwningOpRange &range = opRangeMemory[rangeIndex];
1948 memory[memIndex] = &range;
1949
1950 range = OwningOpRange();
1951 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1952 // Read the value.
1953 Value value = read<Value>();
1954 if (!value)
1955 return;
1956 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1957
1958 // Extract the users of a single value.
1959 range = OwningOpRange(std::distance(first: value.user_begin(), last: value.user_end()));
1960 llvm::copy(Range: value.getUsers(), Out: range.begin());
1961 } else {
1962 // Read a range of values.
1963 ValueRange *values = read<ValueRange *>();
1964 if (!values)
1965 return;
1966 LLVM_DEBUG({
1967 llvm::dbgs() << " * Values (" << values->size() << "): ";
1968 llvm::interleaveComma(*values, llvm::dbgs());
1969 llvm::dbgs() << "\n";
1970 });
1971
1972 // Extract all the users of a range of values.
1973 SmallVector<Operation *> users;
1974 for (Value value : *values)
1975 users.append(in_start: value.user_begin(), in_end: value.user_end());
1976 range = OwningOpRange(users.size());
1977 llvm::copy(Range&: users, Out: range.begin());
1978 }
1979
1980 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1981}
1982
1983void ByteCodeExecutor::executeGetValueType() {
1984 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1985 unsigned memIndex = read();
1986 Value value = read<Value>();
1987 Type type = value ? value.getType() : Type();
1988
1989 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1990 << " * Result: " << type << "\n");
1991 memory[memIndex] = type.getAsOpaquePointer();
1992}
1993
1994void ByteCodeExecutor::executeGetValueRangeTypes() {
1995 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1996 unsigned memIndex = read();
1997 unsigned rangeIndex = read();
1998 ValueRange *values = read<ValueRange *>();
1999 if (!values) {
2000 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2001 memory[memIndex] = nullptr;
2002 return;
2003 }
2004
2005 LLVM_DEBUG({
2006 llvm::dbgs() << " * Values (" << values->size() << "): ";
2007 llvm::interleaveComma(*values, llvm::dbgs());
2008 llvm::dbgs() << "\n * Result: ";
2009 llvm::interleaveComma(values->getType(), llvm::dbgs());
2010 llvm::dbgs() << "\n";
2011 });
2012 typeRangeMemory[rangeIndex] = values->getType();
2013 memory[memIndex] = &typeRangeMemory[rangeIndex];
2014}
2015
2016void ByteCodeExecutor::executeIsNotNull() {
2017 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2018 const void *value = read<const void *>();
2019
2020 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2021 selectJump(isTrue: value != nullptr);
2022}
2023
2024void ByteCodeExecutor::executeRecordMatch(
2025 PatternRewriter &rewriter,
2026 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
2027 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2028 unsigned patternIndex = read();
2029 PatternBenefit benefit = currentPatternBenefits[patternIndex];
2030 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2031
2032 // If the benefit of the pattern is impossible, skip the processing of the
2033 // rest of the pattern.
2034 if (benefit.isImpossibleToMatch()) {
2035 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2036 curCodeIt = dest;
2037 return;
2038 }
2039
2040 // Create a fused location containing the locations of each of the
2041 // operations used in the match. This will be used as the location for
2042 // created operations during the rewrite that don't already have an
2043 // explicit location set.
2044 unsigned numMatchLocs = read();
2045 SmallVector<Location, 4> matchLocs;
2046 matchLocs.reserve(N: numMatchLocs);
2047 for (unsigned i = 0; i != numMatchLocs; ++i)
2048 matchLocs.push_back(Elt: read<Operation *>()->getLoc());
2049 Location matchLoc = rewriter.getFusedLoc(locs: matchLocs);
2050
2051 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2052 << " * Location: " << matchLoc << "\n");
2053 matches.emplace_back(Args&: matchLoc, Args: patterns[patternIndex], Args&: benefit);
2054 PDLByteCode::MatchResult &match = matches.back();
2055
2056 // Record all of the inputs to the match. If any of the inputs are ranges, we
2057 // will also need to remap the range pointer to memory stored in the match
2058 // state.
2059 unsigned numInputs = read();
2060 match.values.reserve(N: numInputs);
2061 match.typeRangeValues.reserve(N: numInputs);
2062 match.valueRangeValues.reserve(N: numInputs);
2063 for (unsigned i = 0; i < numInputs; ++i) {
2064 switch (read<PDLValue::Kind>()) {
2065 case PDLValue::Kind::TypeRange:
2066 match.typeRangeValues.push_back(Elt: *read<TypeRange *>());
2067 match.values.push_back(Elt: &match.typeRangeValues.back());
2068 break;
2069 case PDLValue::Kind::ValueRange:
2070 match.valueRangeValues.push_back(Elt: *read<ValueRange *>());
2071 match.values.push_back(Elt: &match.valueRangeValues.back());
2072 break;
2073 default:
2074 match.values.push_back(Elt: read<const void *>());
2075 break;
2076 }
2077 }
2078 curCodeIt = dest;
2079}
2080
2081void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2082 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2083 Operation *op = read<Operation *>();
2084 SmallVector<Value, 16> args;
2085 readList(list&: args);
2086
2087 LLVM_DEBUG({
2088 llvm::dbgs() << " * Operation: " << *op << "\n"
2089 << " * Values: ";
2090 llvm::interleaveComma(args, llvm::dbgs());
2091 llvm::dbgs() << "\n";
2092 });
2093 rewriter.replaceOp(op, newValues: args);
2094}
2095
2096void ByteCodeExecutor::executeSwitchAttribute() {
2097 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2098 Attribute value = read<Attribute>();
2099 ArrayAttr cases = read<ArrayAttr>();
2100 handleSwitch(value, cases);
2101}
2102
2103void ByteCodeExecutor::executeSwitchOperandCount() {
2104 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2105 Operation *op = read<Operation *>();
2106 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2107
2108 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2109 handleSwitch(op->getNumOperands(), cases);
2110}
2111
2112void ByteCodeExecutor::executeSwitchOperationName() {
2113 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2114 OperationName value = read<Operation *>()->getName();
2115 size_t caseCount = read();
2116
2117 // The operation names are stored in-line, so to print them out for
2118 // debugging purposes we need to read the array before executing the
2119 // switch so that we can display all of the possible values.
2120 LLVM_DEBUG({
2121 const ByteCodeField *prevCodeIt = curCodeIt;
2122 llvm::dbgs() << " * Value: " << value << "\n"
2123 << " * Cases: ";
2124 llvm::interleaveComma(
2125 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2126 [&](size_t) { return read<OperationName>(); }),
2127 llvm::dbgs());
2128 llvm::dbgs() << "\n";
2129 curCodeIt = prevCodeIt;
2130 });
2131
2132 // Try to find the switch value within any of the cases.
2133 for (size_t i = 0; i != caseCount; ++i) {
2134 if (read<OperationName>() == value) {
2135 curCodeIt += (caseCount - i - 1);
2136 return selectJump(destIndex: i + 1);
2137 }
2138 }
2139 selectJump(destIndex: size_t(0));
2140}
2141
2142void ByteCodeExecutor::executeSwitchResultCount() {
2143 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2144 Operation *op = read<Operation *>();
2145 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2146
2147 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2148 handleSwitch(op->getNumResults(), cases);
2149}
2150
2151void ByteCodeExecutor::executeSwitchType() {
2152 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2153 Type value = read<Type>();
2154 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2155 handleSwitch(value, cases);
2156}
2157
2158void ByteCodeExecutor::executeSwitchTypes() {
2159 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2160 TypeRange *value = read<TypeRange *>();
2161 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2162 if (!value) {
2163 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2164 return selectJump(destIndex: size_t(0));
2165 }
2166 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2167 return value == caseValue.getAsValueRange<TypeAttr>();
2168 });
2169}
2170
2171LogicalResult
2172ByteCodeExecutor::execute(PatternRewriter &rewriter,
2173 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2174 std::optional<Location> mainRewriteLoc) {
2175 while (true) {
2176 // Print the location of the operation being executed.
2177 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2178
2179 OpCode opCode = static_cast<OpCode>(read());
2180 switch (opCode) {
2181 case ApplyConstraint:
2182 executeApplyConstraint(rewriter);
2183 break;
2184 case ApplyRewrite:
2185 if (failed(result: executeApplyRewrite(rewriter)))
2186 return failure();
2187 break;
2188 case AreEqual:
2189 executeAreEqual();
2190 break;
2191 case AreRangesEqual:
2192 executeAreRangesEqual();
2193 break;
2194 case Branch:
2195 executeBranch();
2196 break;
2197 case CheckOperandCount:
2198 executeCheckOperandCount();
2199 break;
2200 case CheckOperationName:
2201 executeCheckOperationName();
2202 break;
2203 case CheckResultCount:
2204 executeCheckResultCount();
2205 break;
2206 case CheckTypes:
2207 executeCheckTypes();
2208 break;
2209 case Continue:
2210 executeContinue();
2211 break;
2212 case CreateConstantTypeRange:
2213 executeCreateConstantTypeRange();
2214 break;
2215 case CreateOperation:
2216 executeCreateOperation(rewriter, mainRewriteLoc: *mainRewriteLoc);
2217 break;
2218 case CreateDynamicTypeRange:
2219 executeDynamicCreateRange<Type>(type: "Type");
2220 break;
2221 case CreateDynamicValueRange:
2222 executeDynamicCreateRange<Value>(type: "Value");
2223 break;
2224 case EraseOp:
2225 executeEraseOp(rewriter);
2226 break;
2227 case ExtractOp:
2228 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2229 break;
2230 case ExtractType:
2231 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2232 break;
2233 case ExtractValue:
2234 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2235 break;
2236 case Finalize:
2237 executeFinalize();
2238 LLVM_DEBUG(llvm::dbgs() << "\n");
2239 return success();
2240 case ForEach:
2241 executeForEach();
2242 break;
2243 case GetAttribute:
2244 executeGetAttribute();
2245 break;
2246 case GetAttributeType:
2247 executeGetAttributeType();
2248 break;
2249 case GetDefiningOp:
2250 executeGetDefiningOp();
2251 break;
2252 case GetOperand0:
2253 case GetOperand1:
2254 case GetOperand2:
2255 case GetOperand3: {
2256 unsigned index = opCode - GetOperand0;
2257 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2258 executeGetOperand(index);
2259 break;
2260 }
2261 case GetOperandN:
2262 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2263 executeGetOperand(index: read<uint32_t>());
2264 break;
2265 case GetOperands:
2266 executeGetOperands();
2267 break;
2268 case GetResult0:
2269 case GetResult1:
2270 case GetResult2:
2271 case GetResult3: {
2272 unsigned index = opCode - GetResult0;
2273 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2274 executeGetResult(index);
2275 break;
2276 }
2277 case GetResultN:
2278 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2279 executeGetResult(index: read<uint32_t>());
2280 break;
2281 case GetResults:
2282 executeGetResults();
2283 break;
2284 case GetUsers:
2285 executeGetUsers();
2286 break;
2287 case GetValueType:
2288 executeGetValueType();
2289 break;
2290 case GetValueRangeTypes:
2291 executeGetValueRangeTypes();
2292 break;
2293 case IsNotNull:
2294 executeIsNotNull();
2295 break;
2296 case RecordMatch:
2297 assert(matches &&
2298 "expected matches to be provided when executing the matcher");
2299 executeRecordMatch(rewriter, matches&: *matches);
2300 break;
2301 case ReplaceOp:
2302 executeReplaceOp(rewriter);
2303 break;
2304 case SwitchAttribute:
2305 executeSwitchAttribute();
2306 break;
2307 case SwitchOperandCount:
2308 executeSwitchOperandCount();
2309 break;
2310 case SwitchOperationName:
2311 executeSwitchOperationName();
2312 break;
2313 case SwitchResultCount:
2314 executeSwitchResultCount();
2315 break;
2316 case SwitchType:
2317 executeSwitchType();
2318 break;
2319 case SwitchTypes:
2320 executeSwitchTypes();
2321 break;
2322 }
2323 LLVM_DEBUG(llvm::dbgs() << "\n");
2324 }
2325}
2326
2327void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2328 SmallVectorImpl<MatchResult> &matches,
2329 PDLByteCodeMutableState &state) const {
2330 // The first memory slot is always the root operation.
2331 state.memory[0] = op;
2332
2333 // The matcher function always starts at code address 0.
2334 ByteCodeExecutor executor(
2335 matcherByteCode.data(), state.memory, state.opRangeMemory,
2336 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2337 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2338 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2339 constraintFunctions, rewriteFunctions);
2340 LogicalResult executeResult = executor.execute(rewriter, matches: &matches);
2341 (void)executeResult;
2342 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2343
2344 // Order the found matches by benefit.
2345 std::stable_sort(first: matches.begin(), last: matches.end(),
2346 comp: [](const MatchResult &lhs, const MatchResult &rhs) {
2347 return lhs.benefit > rhs.benefit;
2348 });
2349}
2350
2351LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2352 const MatchResult &match,
2353 PDLByteCodeMutableState &state) const {
2354 auto *configSet = match.pattern->getConfigSet();
2355 if (configSet)
2356 configSet->notifyRewriteBegin(rewriter);
2357
2358 // The arguments of the rewrite function are stored at the start of the
2359 // memory buffer.
2360 llvm::copy(Range: match.values, Out: state.memory.begin());
2361
2362 ByteCodeExecutor executor(
2363 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2364 state.opRangeMemory, state.typeRangeMemory,
2365 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2366 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2367 rewriterByteCode, state.currentPatternBenefits, patterns,
2368 constraintFunctions, rewriteFunctions);
2369 LogicalResult result =
2370 executor.execute(rewriter, /*matches=*/nullptr, mainRewriteLoc: match.location);
2371
2372 if (configSet)
2373 configSet->notifyRewriteEnd(rewriter);
2374
2375 // If the rewrite failed, check if the pattern rewriter can recover. If it
2376 // can, we can signal to the pattern applicator to keep trying patterns. If it
2377 // doesn't, we need to bail. Bailing here should be fine, given that we have
2378 // no means to propagate such a failure to the user, and it also indicates a
2379 // bug in the user code (i.e. failable rewrites should not be used with
2380 // pattern rewriters that don't support it).
2381 if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2382 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2383 llvm::report_fatal_error(
2384 reason: "Native PDL Rewrite failed, but the pattern "
2385 "rewriter doesn't support recovery. Failable pattern rewrites should "
2386 "not be used with pattern rewriters that do not support them.");
2387 }
2388 return result;
2389}
2390

source code of mlir/lib/Rewrite/ByteCode.cpp