1//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements MLIR to byte-code generation and the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "ByteCode.h"
14#include "mlir/Analysis/Liveness.h"
15#include "mlir/Dialect/PDL/IR/PDLTypes.h"
16#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/RegionGraphTraits.h"
19#include "llvm/ADT/IntervalMap.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/Format.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <numeric>
26#include <optional>
27
28#define DEBUG_TYPE "pdl-bytecode"
29
30using namespace mlir;
31using namespace mlir::detail;
32
33//===----------------------------------------------------------------------===//
34// PDLByteCodePattern
35//===----------------------------------------------------------------------===//
36
37PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
40 PatternBenefit benefit = matchOp.getBenefit();
41 MLIRContext *ctx = matchOp.getContext();
42
43 // Collect the set of generated operations.
44 SmallVector<StringRef, 8> generatedOps;
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46 generatedOps =
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48
49 // Check to see if this is pattern matches a specific operation type.
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51 return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52 generatedOps);
53 return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54 benefit, ctx, generatedOps);
55}
56
57//===----------------------------------------------------------------------===//
58// PDLByteCodeMutableState
59//===----------------------------------------------------------------------===//
60
61/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62/// to the position of the pattern within the range returned by
63/// `PDLByteCode::getPatterns`.
64void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65 PatternBenefit benefit) {
66 currentPatternBenefits[patternIndex] = benefit;
67}
68
69/// Cleanup any allocated state after a full match/rewrite has been completed.
70/// This method should be called irregardless of whether the match+rewrite was a
71/// success or not.
72void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
75}
76
77//===----------------------------------------------------------------------===//
78// Bytecode OpCodes
79//===----------------------------------------------------------------------===//
80
81namespace {
/// Opcodes of the PDL bytecode. Each opcode is written into the bytecode
/// stream as a single ByteCodeField (see `ByteCodeWriter::append(OpCode)`).
/// Enumerator values are implicit, so the declaration order defines the
/// encoding: inserting or reordering entries changes every later opcode value.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): GetOperand0-3 appear to be specialized forms for the first
  /// four operand indices, with GetOperandN handling any other index; confirm
  /// against the GetOperandOp generator and the executor.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): as with GetOperand*, GetResult0-3 appear to be specialized
  /// forms for fixed indices and GetResultN the general case; confirm.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
170} // namespace
171
172/// A marker used to indicate if an operation should infer types.
173static constexpr ByteCodeField kInferTypesMarker =
174 std::numeric_limits<ByteCodeField>::max();
175
176//===----------------------------------------------------------------------===//
177// ByteCode Generation
178//===----------------------------------------------------------------------===//
179
180//===----------------------------------------------------------------------===//
181// Generator
182//===----------------------------------------------------------------------===//
183
184namespace {
185struct ByteCodeLiveRange;
186struct ByteCodeWriter;
187
/// Detection-idiom helper naming the type returned by
/// `T::getAsOpaquePointer()`. The alias is only well-formed when `T` provides
/// that method, so `llvm::is_detected<has_pointer_traits, T>` tells whether a
/// value of type `T` can be converted to an opaque pointer.
///
/// The alias takes exactly one template parameter; the previously declared
/// (and never used) trailing parameter pack was dropped. A single-parameter
/// alias still binds to a `template <class...> class` template template
/// parameter such as the one used by `is_detected`.
template <typename T>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
191
192/// This class represents the main generator for the pattern bytecode.
193class Generator {
194public:
195 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
196 SmallVectorImpl<ByteCodeField> &matcherByteCode,
197 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
198 SmallVectorImpl<PDLByteCodePattern> &patterns,
199 ByteCodeField &maxValueMemoryIndex,
200 ByteCodeField &maxOpRangeMemoryIndex,
201 ByteCodeField &maxTypeRangeMemoryIndex,
202 ByteCodeField &maxValueRangeMemoryIndex,
203 ByteCodeField &maxLoopLevel,
204 llvm::StringMap<PDLConstraintFunction> &constraintFns,
205 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
207 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
208 rewriterByteCode(rewriterByteCode), patterns(patterns),
209 maxValueMemoryIndex(maxValueMemoryIndex),
210 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
211 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
212 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
213 maxLoopLevel(maxLoopLevel), configMap(configMap) {
214 for (const auto &it : llvm::enumerate(First&: constraintFns))
215 constraintToMemIndex.try_emplace(Key: it.value().first(), Args: it.index());
216 for (const auto &it : llvm::enumerate(First&: rewriteFns))
217 externalRewriterToMemIndex.try_emplace(Key: it.value().first(), Args: it.index());
218 }
219
220 /// Generate the bytecode for the given PDL interpreter module.
221 void generate(ModuleOp module);
222
223 /// Return the memory index to use for the given value.
224 ByteCodeField &getMemIndex(Value value) {
225 assert(valueToMemIndex.count(value) &&
226 "expected memory index to be assigned");
227 return valueToMemIndex[value];
228 }
229
230 /// Return the range memory index used to store the given range value.
231 ByteCodeField &getRangeStorageIndex(Value value) {
232 assert(valueToRangeIndex.count(value) &&
233 "expected range index to be assigned");
234 return valueToRangeIndex[value];
235 }
236
237 /// Return an index to use when referring to the given data that is uniqued in
238 /// the MLIR context.
239 template <typename T>
240 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
241 getMemIndex(T val) {
242 const void *opaqueVal = val.getAsOpaquePointer();
243
244 // Get or insert a reference to this value.
245 auto it = uniquedDataToMemIndex.try_emplace(
246 Key: opaqueVal, Args: maxValueMemoryIndex + uniquedData.size());
247 if (it.second)
248 uniquedData.push_back(x: opaqueVal);
249 return it.first->second;
250 }
251
252private:
253 /// Allocate memory indices for the results of operations within the matcher
254 /// and rewriters.
255 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
256 ModuleOp rewriterModule);
257
258 /// Generate the bytecode for the given operation.
259 void generate(Region *region, ByteCodeWriter &writer);
260 void generate(Operation *op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
299
300 /// Mapping from value to its corresponding memory index.
301 DenseMap<Value, ByteCodeField> valueToMemIndex;
302
303 /// Mapping from a range value to its corresponding range storage index.
304 DenseMap<Value, ByteCodeField> valueToRangeIndex;
305
306 /// Mapping from the name of an externally registered rewrite to its index in
307 /// the bytecode registry.
308 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
309
310 /// Mapping from the name of an externally registered constraint to its index
311 /// in the bytecode registry.
312 llvm::StringMap<ByteCodeField> constraintToMemIndex;
313
314 /// Mapping from rewriter function name to the bytecode address of the
315 /// rewriter function in byte.
316 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
317
318 /// Mapping from a uniqued storage object to its memory index within
319 /// `uniquedData`.
320 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
321
322 /// The current level of the foreach loop.
323 ByteCodeField curLoopLevel = 0;
324
325 /// The current MLIR context.
326 MLIRContext *ctx;
327
328 /// Mapping from block to its address.
329 DenseMap<Block *, ByteCodeAddr> blockToAddr;
330
331 /// Data of the ByteCode class to be populated.
332 std::vector<const void *> &uniquedData;
333 SmallVectorImpl<ByteCodeField> &matcherByteCode;
334 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
335 SmallVectorImpl<PDLByteCodePattern> &patterns;
336 ByteCodeField &maxValueMemoryIndex;
337 ByteCodeField &maxOpRangeMemoryIndex;
338 ByteCodeField &maxTypeRangeMemoryIndex;
339 ByteCodeField &maxValueRangeMemoryIndex;
340 ByteCodeField &maxLoopLevel;
341
342 /// A map of pattern configurations.
343 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
344};
345
346/// This class provides utilities for writing a bytecode stream.
347struct ByteCodeWriter {
348 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
349 : bytecode(bytecode), generator(generator) {}
350
351 /// Append a field to the bytecode.
352 void append(ByteCodeField field) { bytecode.push_back(Elt: field); }
353 void append(OpCode opCode) { bytecode.push_back(Elt: opCode); }
354
355 /// Append an address to the bytecode.
356 void append(ByteCodeAddr field) {
357 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
358 "unexpected ByteCode address size");
359
360 ByteCodeField fieldParts[2];
361 std::memcpy(dest: fieldParts, src: &field, n: sizeof(ByteCodeAddr));
362 bytecode.append(IL: {fieldParts[0], fieldParts[1]});
363 }
364
365 /// Append a single successor to the bytecode, the exact address will need to
366 /// be resolved later.
367 void append(Block *successor) {
368 // Add back a reference to the successor so that the address can be resolved
369 // later.
370 unresolvedSuccessorRefs[successor].push_back(Elt: bytecode.size());
371 append(field: ByteCodeAddr(0));
372 }
373
374 /// Append a successor range to the bytecode, the exact address will need to
375 /// be resolved later.
376 void append(SuccessorRange successors) {
377 for (Block *successor : successors)
378 append(successor);
379 }
380
381 /// Append a range of values that will be read as generic PDLValues.
382 void appendPDLValueList(OperandRange values) {
383 bytecode.push_back(Elt: values.size());
384 for (Value value : values)
385 appendPDLValue(value);
386 }
387
388 /// Append a value as a PDLValue.
389 void appendPDLValue(Value value) {
390 appendPDLValueKind(value);
391 append(value);
392 }
393
394 /// Append the PDLValue::Kind of the given value.
395 void appendPDLValueKind(Value value) { appendPDLValueKind(type: value.getType()); }
396
397 /// Append the PDLValue::Kind of the given type.
398 void appendPDLValueKind(Type type) {
399 PDLValue::Kind kind =
400 TypeSwitch<Type, PDLValue::Kind>(type)
401 .Case<pdl::AttributeType>(
402 [](Type) { return PDLValue::Kind::Attribute; })
403 .Case<pdl::OperationType>(
404 [](Type) { return PDLValue::Kind::Operation; })
405 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
406 if (isa<pdl::TypeType>(rangeTy.getElementType()))
407 return PDLValue::Kind::TypeRange;
408 return PDLValue::Kind::ValueRange;
409 })
410 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
411 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
412 bytecode.push_back(Elt: static_cast<ByteCodeField>(kind));
413 }
414
415 /// Append a value that will be stored in a memory slot and not inline within
416 /// the bytecode.
417 template <typename T>
418 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
419 std::is_pointer<T>::value>
420 append(T value) {
421 bytecode.push_back(Elt: generator.getMemIndex(value));
422 }
423
424 /// Append a range of values.
425 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
426 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427 append(T range) {
428 bytecode.push_back(Elt: llvm::size(range));
429 for (auto it : range)
430 append(it);
431 }
432
433 /// Append a variadic number of fields to the bytecode.
434 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
435 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436 append(field);
437 append(field2, fields...);
438 }
439
440 /// Appends a value as a pointer, stored inline within the bytecode.
441 template <typename T>
442 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
443 appendInline(T value) {
444 constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
445 const void *pointer = value.getAsOpaquePointer();
446 ByteCodeField fieldParts[numParts];
447 std::memcpy(dest: fieldParts, src: &pointer, n: sizeof(const void *));
448 bytecode.append(in_start: fieldParts, in_end: fieldParts + numParts);
449 }
450
451 /// Successor references in the bytecode that have yet to be resolved.
452 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
453
454 /// The underlying bytecode buffer.
455 SmallVectorImpl<ByteCodeField> &bytecode;
456
457 /// The main generator producing PDL.
458 Generator &generator;
459};
460
461/// This class represents a live range of PDL Interpreter values, containing
462/// information about when values are live within a match/rewrite.
463struct ByteCodeLiveRange {
464 using Set = llvm::IntervalMap<uint64_t, char, 16>;
465 using Allocator = Set::Allocator;
466
467 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
468
469 /// Union this live range with the one provided.
470 void unionWith(const ByteCodeLiveRange &rhs) {
471 for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
472 ++it)
473 liveness->insert(a: it.start(), b: it.stop(), /*dummyValue*/ y: 0);
474 }
475
476 /// Returns true if this range overlaps with the one provided.
477 bool overlaps(const ByteCodeLiveRange &rhs) const {
478 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
479 .valid();
480 }
481
482 /// A map representing the ranges of the match/rewrite that a value is live in
483 /// the interpreter.
484 ///
485 /// We use std::unique_ptr here, because IntervalMap does not provide a
486 /// correct copy or move constructor. We can eliminate the pointer once
487 /// https://reviews.llvm.org/D113240 lands.
488 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
489
490 /// The operation range storage index for this range.
491 std::optional<unsigned> opRangeIndex;
492
493 /// The type range storage index for this range.
494 std::optional<unsigned> typeRangeIndex;
495
496 /// The value range storage index for this range.
497 std::optional<unsigned> valueRangeIndex;
498};
499} // namespace
500
501void Generator::generate(ModuleOp module) {
502 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
503 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
504 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
505 pdl_interp::PDLInterpDialect::getRewriterModuleName());
506 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
507
508 // Allocate memory indices for the results of operations within the matcher
509 // and rewriters.
510 allocateMemoryIndices(matcherFunc, rewriterModule);
511
512 // Generate code for the rewriter functions.
513 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
514 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
515 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
516 for (Operation &op : rewriterFunc.getOps())
517 generate(&op, rewriterByteCodeWriter);
518 }
519 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
520 "unexpected branches in rewriter function");
521
522 // Generate code for the matcher function.
523 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
524 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
525
526 // Resolve successor references in the matcher.
527 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528 ByteCodeAddr addr = blockToAddr[it.first];
529 for (unsigned offsetToFix : it.second)
530 std::memcpy(dest: &matcherByteCode[offsetToFix], src: &addr, n: sizeof(ByteCodeAddr));
531 }
532}
533
534void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
535 ModuleOp rewriterModule) {
536 // Rewriters use simplistic allocation scheme that simply assigns an index to
537 // each result.
538 for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
539 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
540 auto processRewriterValue = [&](Value val) {
541 valueToMemIndex.try_emplace(val, index++);
542 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
543 Type elementTy = rangeType.getElementType();
544 if (isa<pdl::TypeType>(elementTy))
545 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
546 else if (isa<pdl::ValueType>(elementTy))
547 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
548 }
549 };
550
551 for (BlockArgument arg : rewriterFunc.getArguments())
552 processRewriterValue(arg);
553 rewriterFunc.getBody().walk([&](Operation *op) {
554 for (Value result : op->getResults())
555 processRewriterValue(result);
556 });
557 if (index > maxValueMemoryIndex)
558 maxValueMemoryIndex = index;
559 if (typeRangeIndex > maxTypeRangeMemoryIndex)
560 maxTypeRangeMemoryIndex = typeRangeIndex;
561 if (valueRangeIndex > maxValueRangeMemoryIndex)
562 maxValueRangeMemoryIndex = valueRangeIndex;
563 }
564
565 // The matcher function uses a more sophisticated numbering that tries to
566 // minimize the number of memory indices assigned. This is done by determining
567 // a live range of the values within the matcher, then the allocation is just
568 // finding the minimal number of overlapping live ranges. This is essentially
569 // a simplified form of register allocation where we don't necessarily have a
570 // limited number of registers, but we still want to minimize the number used.
571 DenseMap<Operation *, unsigned> opToFirstIndex;
572 DenseMap<Operation *, unsigned> opToLastIndex;
573
574 // A custom walk that marks the first and the last index of each operation.
575 // The entry marks the beginning of the liveness range for this operation,
576 // followed by nested operations, followed by the end of the liveness range.
577 unsigned index = 0;
578 llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
579 opToFirstIndex.try_emplace(Key: op, Args: index++);
580 for (Region &region : op->getRegions())
581 for (Block &block : region.getBlocks())
582 for (Operation &nested : block)
583 walk(&nested);
584 opToLastIndex.try_emplace(Key: op, Args: index++);
585 };
586 walk(matcherFunc);
587
588 // Liveness info for each of the defs within the matcher.
589 ByteCodeLiveRange::Allocator allocator;
590 DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
591
592 // Assign the root operation being matched to slot 0.
593 BlockArgument rootOpArg = matcherFunc.getArgument(0);
594 valueToMemIndex[rootOpArg] = 0;
595
596 // Walk each of the blocks, computing the def interval that the value is used.
597 Liveness matcherLiveness(matcherFunc);
598 matcherFunc->walk([&](Block *block) {
599 const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
600 assert(info && "expected liveness info for block");
601 auto processValue = [&](Value value, Operation *firstUseOrDef) {
602 // We don't need to process the root op argument, this value is always
603 // assigned to the first memory slot.
604 if (value == rootOpArg)
605 return;
606
607 // Set indices for the range of this block that the value is used.
608 auto defRangeIt = valueDefRanges.try_emplace(Key: value, Args&: allocator).first;
609 defRangeIt->second.liveness->insert(
610 a: opToFirstIndex[firstUseOrDef],
611 b: opToLastIndex[info->getEndOperation(value, startOperation: firstUseOrDef)],
612 /*dummyValue*/ y: 0);
613
614 // Check to see if this value is a range type.
615 if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
616 Type eleType = rangeTy.getElementType();
617 if (isa<pdl::OperationType>(eleType))
618 defRangeIt->second.opRangeIndex = 0;
619 else if (isa<pdl::TypeType>(eleType))
620 defRangeIt->second.typeRangeIndex = 0;
621 else if (isa<pdl::ValueType>(eleType))
622 defRangeIt->second.valueRangeIndex = 0;
623 }
624 };
625
626 // Process the live-ins of this block.
627 for (Value liveIn : info->in()) {
628 // Only process the value if it has been defined in the current region.
629 // Other values that span across pdl_interp.foreach will be added higher
630 // up. This ensures that the we keep them alive for the entire duration
631 // of the loop.
632 if (liveIn.getParentRegion() == block->getParent())
633 processValue(liveIn, &block->front());
634 }
635
636 // Process the block arguments for the entry block (those are not live-in).
637 if (block->isEntryBlock()) {
638 for (Value argument : block->getArguments())
639 processValue(argument, &block->front());
640 }
641
642 // Process any new defs within this block.
643 for (Operation &op : *block)
644 for (Value result : op.getResults())
645 processValue(result, &op);
646 });
647
648 // Greedily allocate memory slots using the computed def live ranges.
649 std::vector<ByteCodeLiveRange> allocatedIndices;
650
651 // The number of memory indices currently allocated (and its next value).
652 // Recall that the root gets allocated memory index 0.
653 ByteCodeField numIndices = 1;
654
655 // The number of memory ranges of various types (and their next values).
656 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657
658 for (auto &defIt : valueDefRanges) {
659 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
660 ByteCodeLiveRange &defRange = defIt.second;
661
662 // Try to allocate to an existing index.
663 for (const auto &existingIndexIt : llvm::enumerate(First&: allocatedIndices)) {
664 ByteCodeLiveRange &existingRange = existingIndexIt.value();
665 if (!defRange.overlaps(rhs: existingRange)) {
666 existingRange.unionWith(rhs: defRange);
667 memIndex = existingIndexIt.index() + 1;
668
669 if (defRange.opRangeIndex) {
670 if (!existingRange.opRangeIndex)
671 existingRange.opRangeIndex = numOpRanges++;
672 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
673 } else if (defRange.typeRangeIndex) {
674 if (!existingRange.typeRangeIndex)
675 existingRange.typeRangeIndex = numTypeRanges++;
676 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
677 } else if (defRange.valueRangeIndex) {
678 if (!existingRange.valueRangeIndex)
679 existingRange.valueRangeIndex = numValueRanges++;
680 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
681 }
682 break;
683 }
684 }
685
686 // If no existing index could be used, add a new one.
687 if (memIndex == 0) {
688 allocatedIndices.emplace_back(args&: allocator);
689 ByteCodeLiveRange &newRange = allocatedIndices.back();
690 newRange.unionWith(rhs: defRange);
691
692 // Allocate an index for op/type/value ranges.
693 if (defRange.opRangeIndex) {
694 newRange.opRangeIndex = numOpRanges;
695 valueToRangeIndex[defIt.first] = numOpRanges++;
696 } else if (defRange.typeRangeIndex) {
697 newRange.typeRangeIndex = numTypeRanges;
698 valueToRangeIndex[defIt.first] = numTypeRanges++;
699 } else if (defRange.valueRangeIndex) {
700 newRange.valueRangeIndex = numValueRanges;
701 valueToRangeIndex[defIt.first] = numValueRanges++;
702 }
703
704 memIndex = allocatedIndices.size();
705 ++numIndices;
706 }
707 }
708
709 // Print the index usage and ensure that we did not run out of index space.
710 LLVM_DEBUG({
711 llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
712 << "(down from initial " << valueDefRanges.size() << ").\n";
713 });
714 assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715 "Ran out of memory for allocated indices");
716
717 // Update the max number of indices.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
726}
727
728void Generator::generate(Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (Block *block : rpot) {
731 // Keep track of where this block begins within the matcher function.
732 blockToAddr.try_emplace(Key: block, Args: matcherByteCode.size());
733 for (Operation &op : *block)
734 generate(op: &op, writer);
735 }
736}
737
738void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739 LLVM_DEBUG({
740 // The following list must contain all the operations that do not
741 // produce any bytecode.
742 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
743 writer.appendInline(op->getLoc());
744 });
745 TypeSwitch<Operation *>(op)
746 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
747 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
748 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
749 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
750 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
751 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
752 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
753 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
754 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
755 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
756 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
757 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
758 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
759 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
760 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
761 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
762 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
763 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
764 pdl_interp::SwitchResultCountOp>(
765 [&](auto interpOp) { this->generate(interpOp, writer); })
766 .Default([](Operation *) {
767 llvm_unreachable("unknown `pdl_interp` operation");
768 });
769}
770
771void Generator::generate(pdl_interp::ApplyConstraintOp op,
772 ByteCodeWriter &writer) {
773 // Constraints that should return a value have to be registered as rewrites.
774 // If a constraint and a rewrite of similar name are registered the
775 // constraint takes precedence
776 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
777 writer.appendPDLValueList(values: op.getArgs());
778 writer.append(field: ByteCodeField(op.getIsNegated()));
779 ResultRange results = op.getResults();
780 writer.append(field: ByteCodeField(results.size()));
781 for (Value result : results) {
782 // We record the expected kind of the result, so that we can provide extra
783 // verification of the native rewrite function and handle the failure case
784 // of constraints accordingly.
785 writer.appendPDLValueKind(result);
786
787 // Range results also need to append the range storage index.
788 if (isa<pdl::RangeType>(result.getType()))
789 writer.append(getRangeStorageIndex(result));
790 writer.append(result);
791 }
792 writer.append(op.getSuccessors());
793}
794void Generator::generate(pdl_interp::ApplyRewriteOp op,
795 ByteCodeWriter &writer) {
796 assert(externalRewriterToMemIndex.count(op.getName()) &&
797 "expected index for rewrite function");
798 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
799 writer.appendPDLValueList(values: op.getArgs());
800
801 ResultRange results = op.getResults();
802 writer.append(field: ByteCodeField(results.size()));
803 for (Value result : results) {
804 // We record the expected kind of the result, so that we
805 // can provide extra verification of the native rewrite function.
806 writer.appendPDLValueKind(result);
807
808 // Range results also need to append the range storage index.
809 if (isa<pdl::RangeType>(result.getType()))
810 writer.append(getRangeStorageIndex(result));
811 writer.append(result);
812 }
813}
814void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
815 Value lhs = op.getLhs();
816 if (isa<pdl::RangeType>(lhs.getType())) {
817 writer.append(opCode: OpCode::AreRangesEqual);
818 writer.appendPDLValueKind(value: lhs);
819 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
820 return;
821 }
822
823 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
824}
825void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
826 writer.append(field: OpCode::Branch, field2: SuccessorRange(op.getOperation()));
827}
828void Generator::generate(pdl_interp::CheckAttributeOp op,
829 ByteCodeWriter &writer) {
830 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
831 op.getSuccessors());
832}
833void Generator::generate(pdl_interp::CheckOperandCountOp op,
834 ByteCodeWriter &writer) {
835 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
836 static_cast<ByteCodeField>(op.getCompareAtLeast()),
837 op.getSuccessors());
838}
839void Generator::generate(pdl_interp::CheckOperationNameOp op,
840 ByteCodeWriter &writer) {
841 writer.append(OpCode::CheckOperationName, op.getInputOp(),
842 OperationName(op.getName(), ctx), op.getSuccessors());
843}
844void Generator::generate(pdl_interp::CheckResultCountOp op,
845 ByteCodeWriter &writer) {
846 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
847 static_cast<ByteCodeField>(op.getCompareAtLeast()),
848 op.getSuccessors());
849}
850void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
851 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
852 op.getSuccessors());
853}
854void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
855 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
856 op.getSuccessors());
857}
858void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
859 assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
860 writer.append(field: OpCode::Continue, field2: ByteCodeField(curLoopLevel - 1));
861}
862void Generator::generate(pdl_interp::CreateAttributeOp op,
863 ByteCodeWriter &writer) {
864 // Simply repoint the memory index of the result to the constant.
865 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
866}
867void Generator::generate(pdl_interp::CreateOperationOp op,
868 ByteCodeWriter &writer) {
869 writer.append(OpCode::CreateOperation, op.getResultOp(),
870 OperationName(op.getName(), ctx));
871 writer.appendPDLValueList(values: op.getInputOperands());
872
873 // Add the attributes.
874 OperandRange attributes = op.getInputAttributes();
875 writer.append(field: static_cast<ByteCodeField>(attributes.size()));
876 for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
877 writer.append(std::get<0>(it), std::get<1>(it));
878
879 // Add the result types. If the operation has inferred results, we use a
880 // marker "size" value. Otherwise, we add the list of explicit result types.
881 if (op.getInferredResultTypes())
882 writer.append(field: kInferTypesMarker);
883 else
884 writer.appendPDLValueList(values: op.getInputResultTypes());
885}
886void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
887 // Append the correct opcode for the range type.
888 TypeSwitch<Type>(op.getType().getElementType())
889 .Case(
890 caseFn: [&](pdl::TypeType) { writer.append(opCode: OpCode::CreateDynamicTypeRange); })
891 .Case(caseFn: [&](pdl::ValueType) {
892 writer.append(opCode: OpCode::CreateDynamicValueRange);
893 });
894
895 writer.append(op.getResult(), getRangeStorageIndex(value: op.getResult()));
896 writer.appendPDLValueList(values: op->getOperands());
897}
898void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899 // Simply repoint the memory index of the result to the constant.
900 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
901}
902void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
903 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
904 getRangeStorageIndex(value: op.getResult()), op.getValue());
905}
906void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
907 writer.append(OpCode::EraseOp, op.getInputOp());
908}
909void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
910 OpCode opCode =
911 TypeSwitch<Type, OpCode>(op.getResult().getType())
912 .Case(caseFn: [](pdl::OperationType) { return OpCode::ExtractOp; })
913 .Case(caseFn: [](pdl::ValueType) { return OpCode::ExtractValue; })
914 .Case(caseFn: [](pdl::TypeType) { return OpCode::ExtractType; })
915 .Default(defaultFn: [](Type) -> OpCode {
916 llvm_unreachable("unsupported element type");
917 });
918 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
919}
920void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
921 writer.append(opCode: OpCode::Finalize);
922}
923void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
924 BlockArgument arg = op.getLoopVariable();
925 writer.append(OpCode::ForEach, getRangeStorageIndex(value: op.getValues()), arg);
926 writer.appendPDLValueKind(type: arg.getType());
927 writer.append(curLoopLevel, op.getSuccessor());
928 ++curLoopLevel;
929 if (curLoopLevel > maxLoopLevel)
930 maxLoopLevel = curLoopLevel;
931 generate(&op.getRegion(), writer);
932 --curLoopLevel;
933}
934void Generator::generate(pdl_interp::GetAttributeOp op,
935 ByteCodeWriter &writer) {
936 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
937 op.getNameAttr());
938}
939void Generator::generate(pdl_interp::GetAttributeTypeOp op,
940 ByteCodeWriter &writer) {
941 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
942}
943void Generator::generate(pdl_interp::GetDefiningOpOp op,
944 ByteCodeWriter &writer) {
945 writer.append(OpCode::GetDefiningOp, op.getInputOp());
946 writer.appendPDLValue(value: op.getValue());
947}
948void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
949 uint32_t index = op.getIndex();
950 if (index < 4)
951 writer.append(opCode: static_cast<OpCode>(OpCode::GetOperand0 + index));
952 else
953 writer.append(field: OpCode::GetOperandN, field2: index);
954 writer.append(op.getInputOp(), op.getValue());
955}
956void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
957 Value result = op.getValue();
958 std::optional<uint32_t> index = op.getIndex();
959 writer.append(OpCode::GetOperands,
960 index.value_or(u: std::numeric_limits<uint32_t>::max()),
961 op.getInputOp());
962 if (isa<pdl::RangeType>(result.getType()))
963 writer.append(field: getRangeStorageIndex(value: result));
964 else
965 writer.append(field: std::numeric_limits<ByteCodeField>::max());
966 writer.append(value: result);
967}
968void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
969 uint32_t index = op.getIndex();
970 if (index < 4)
971 writer.append(opCode: static_cast<OpCode>(OpCode::GetResult0 + index));
972 else
973 writer.append(field: OpCode::GetResultN, field2: index);
974 writer.append(op.getInputOp(), op.getValue());
975}
976void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
977 Value result = op.getValue();
978 std::optional<uint32_t> index = op.getIndex();
979 writer.append(OpCode::GetResults,
980 index.value_or(u: std::numeric_limits<uint32_t>::max()),
981 op.getInputOp());
982 if (isa<pdl::RangeType>(result.getType()))
983 writer.append(field: getRangeStorageIndex(value: result));
984 else
985 writer.append(field: std::numeric_limits<ByteCodeField>::max());
986 writer.append(value: result);
987}
988void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
989 Value operations = op.getOperations();
990 ByteCodeField rangeIndex = getRangeStorageIndex(value: operations);
991 writer.append(field: OpCode::GetUsers, field2: operations, fields: rangeIndex);
992 writer.appendPDLValue(value: op.getValue());
993}
994void Generator::generate(pdl_interp::GetValueTypeOp op,
995 ByteCodeWriter &writer) {
996 if (isa<pdl::RangeType>(op.getType())) {
997 Value result = op.getResult();
998 writer.append(OpCode::GetValueRangeTypes, result,
999 getRangeStorageIndex(value: result), op.getValue());
1000 } else {
1001 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1002 }
1003}
1004void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1005 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1006}
1007void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1008 ByteCodeField patternIndex = patterns.size();
1009 patterns.emplace_back(PDLByteCodePattern::create(
1010 matchOp: op, configSet: configMap.lookup(Val: op),
1011 rewriterAddr: rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1012 writer.append(OpCode::RecordMatch, patternIndex,
1013 SuccessorRange(op.getOperation()), op.getMatchedOps());
1014 writer.appendPDLValueList(values: op.getInputs());
1015}
1016void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1017 writer.append(OpCode::ReplaceOp, op.getInputOp());
1018 writer.appendPDLValueList(values: op.getReplValues());
1019}
1020void Generator::generate(pdl_interp::SwitchAttributeOp op,
1021 ByteCodeWriter &writer) {
1022 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1023 op.getCaseValuesAttr(), op.getSuccessors());
1024}
1025void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1026 ByteCodeWriter &writer) {
1027 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1028 op.getCaseValuesAttr(), op.getSuccessors());
1029}
1030void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1031 ByteCodeWriter &writer) {
1032 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1033 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1034 });
1035 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1036 op.getSuccessors());
1037}
1038void Generator::generate(pdl_interp::SwitchResultCountOp op,
1039 ByteCodeWriter &writer) {
1040 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1041 op.getCaseValuesAttr(), op.getSuccessors());
1042}
1043void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1044 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1045 op.getSuccessors());
1046}
1047void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1048 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1049 op.getSuccessors());
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// PDLByteCode
1054//===----------------------------------------------------------------------===//
1055
1056PDLByteCode::PDLByteCode(
1057 ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1058 const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
1059 llvm::StringMap<PDLConstraintFunction> constraintFns,
1060 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1061 : configs(std::move(configs)) {
1062 Generator generator(module.getContext(), uniquedData, matcherByteCode,
1063 rewriterByteCode, patterns, maxValueMemoryIndex,
1064 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1065 maxLoopLevel, constraintFns, rewriteFns, configMap);
1066 generator.generate(module);
1067
1068 // Initialize the external functions.
1069 for (auto &it : constraintFns)
1070 constraintFunctions.push_back(x: std::move(it.second));
1071 for (auto &it : rewriteFns)
1072 rewriteFunctions.push_back(x: std::move(it.second));
1073}
1074
1075/// Initialize the given state such that it can be used to execute the current
1076/// bytecode.
1077void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1078 state.memory.resize(new_size: maxValueMemoryIndex, x: nullptr);
1079 state.opRangeMemory.resize(new_size: maxOpRangeCount);
1080 state.typeRangeMemory.resize(new_size: maxTypeRangeCount, x: TypeRange());
1081 state.valueRangeMemory.resize(new_size: maxValueRangeCount, x: ValueRange());
1082 state.loopIndex.resize(new_size: maxLoopLevel, x: 0);
1083 state.currentPatternBenefits.reserve(n: patterns.size());
1084 for (const PDLByteCodePattern &pattern : patterns)
1085 state.currentPatternBenefits.push_back(x: pattern.getBenefit());
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// ByteCode Execution
1090//===----------------------------------------------------------------------===//
1091
1092namespace {
1093/// This class is an instantiation of the PDLResultList that provides access to
1094/// the returned results. This API is not on `PDLResultList` to avoid
1095/// overexposing access to information specific solely to the ByteCode.
1096class ByteCodeRewriteResultList : public PDLResultList {
1097public:
1098 ByteCodeRewriteResultList(unsigned maxNumResults)
1099 : PDLResultList(maxNumResults) {}
1100
1101 /// Return the list of PDL results.
1102 MutableArrayRef<PDLValue> getResults() { return results; }
1103
1104 /// Return the type ranges allocated by this list.
1105 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1106 return allocatedTypeRanges;
1107 }
1108
1109 /// Return the value ranges allocated by this list.
1110 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1111 return allocatedValueRanges;
1112 }
1113};
1114
/// This class provides support for executing a bytecode stream.
///
/// The executor walks the op codes emitted by the Generator, starting at
/// `curCodeIt`. It owns none of the storage it operates on:
///   * memory slots, range slots and loop indices are non-owning views
///     supplied by the caller;
///   * dynamically allocated range storage is appended to caller-owned
///     vectors that are held by reference.
class ByteCodeExecutor {
public:
  ByteCodeExecutor(
      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
      MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
      MutableArrayRef<TypeRange> typeRangeMemory,
      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
      MutableArrayRef<ValueRange> valueRangeMemory,
      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
      MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
      ArrayRef<ByteCodeField> code,
      ArrayRef<PatternBenefit> currentPatternBenefits,
      ArrayRef<PDLByteCodePattern> patterns,
      ArrayRef<PDLConstraintFunction> constraintFunctions,
      ArrayRef<PDLRewriteFunction> rewriteFunctions)
      : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
        typeRangeMemory(typeRangeMemory),
        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
        valueRangeMemory(valueRangeMemory),
        allocatedValueRangeMemory(allocatedValueRangeMemory),
        loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
        constraintFunctions(constraintFunctions),
        rewriteFunctions(rewriteFunctions) {}

  /// Start executing the code at the current bytecode index. `matches` is an
  /// optional field provided when this function is executed in a matching
  /// context. NOTE(review): `mainRewriteLoc` appears to be forwarded to
  /// operation creation (see executeCreateOperation) — confirm in execute().
  LogicalResult
  execute(PatternRewriter &rewriter,
          SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
          std::optional<Location> mainRewriteLoc = {});

private:
  /// Internal implementation of executing each of the bytecode commands.
  /// Each handler is entered with `curCodeIt` pointing just past the op code
  /// and reads that op's operands from the stream.
  void executeApplyConstraint(PatternRewriter &rewriter);
  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
  void executeAreEqual();
  void executeAreRangesEqual();
  void executeBranch();
  void executeCheckOperandCount();
  void executeCheckOperationName();
  void executeCheckResultCount();
  void executeCheckTypes();
  void executeContinue();
  void executeCreateConstantTypeRange();
  void executeCreateOperation(PatternRewriter &rewriter,
                              Location mainRewriteLoc);
  template <typename T>
  void executeDynamicCreateRange(StringRef type);
  void executeEraseOp(PatternRewriter &rewriter);
  template <typename T, typename Range, PDLValue::Kind kind>
  void executeExtract();
  void executeFinalize();
  void executeForEach();
  void executeGetAttribute();
  void executeGetAttributeType();
  void executeGetDefiningOp();
  void executeGetOperand(unsigned index);
  void executeGetOperands();
  void executeGetResult(unsigned index);
  void executeGetResults();
  void executeGetUsers();
  void executeGetValueType();
  void executeGetValueRangeTypes();
  void executeIsNotNull();
  void executeRecordMatch(PatternRewriter &rewriter,
                          SmallVectorImpl<PDLByteCode::MatchResult> &matches);
  void executeReplaceOp(PatternRewriter &rewriter);
  void executeSwitchAttribute();
  void executeSwitchOperandCount();
  void executeSwitchOperationName();
  void executeSwitchResultCount();
  void executeSwitchType();
  void executeSwitchTypes();
  /// Store (on success) or skip (on failure) the per-result descriptors that
  /// follow a native constraint/rewrite call.
  void processNativeFunResults(ByteCodeRewriteResultList &results,
                               unsigned numResults,
                               LogicalResult &rewriteResult);

  /// Pushes a code iterator to the stack.
  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(Elt: it); }

  /// Pops the most recently pushed code iterator off the stack and resumes
  /// execution from it. The stack must not be empty (asserted).
  void popCodeIt() {
    assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
    curCodeIt = resumeCodeIt.pop_back_val();
  }

  /// Return the bytecode iterator at the start of the current op code. This
  /// must be called right after the op code has been read. When debug output
  /// is active, the op's Location is stored inline (pointer-sized) ahead of
  /// the op code, so the result steps back over it as well.
  const ByteCodeField *getPrevCodeIt() const {
    LLVM_DEBUG({
      // Account for the op code and the Location stored inline.
      return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
    });

    // Account for the op code only.
    return curCodeIt - 1;
  }

  /// Read a value from the bytecode buffer, optionally skipping a certain
  /// number of prefix values. These methods always update the buffer to point
  /// to the next field after the read data.
  template <typename T = ByteCodeField>
  T read(size_t skipN = 0) {
    curCodeIt += skipN;
    return readImpl<T>();
  }
  /// Non-template convenience overload for reading a raw ByteCodeField.
  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }

  /// Read a list of values from the bytecode buffer. The list is cleared
  /// first, then a count field is read followed by that many values.
  template <typename ValueT, typename T>
  void readList(SmallVectorImpl<T> &list) {
    list.clear();
    for (unsigned i = 0, e = read(); i != e; ++i)
      list.push_back(read<ValueT>());
  }

  /// Read a list of values from the bytecode buffer. The values may be encoded
  /// either as a single element or a range of elements. Each entry is prefixed
  /// by its PDLValue::Kind; range entries are flattened into `list`. Unlike
  /// the template overload above, these do NOT clear `list` first — they
  /// append to it.
  void readList(SmallVectorImpl<Type> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
        list.push_back(Elt: read<Type>());
      } else {
        TypeRange *values = read<TypeRange *>();
        list.append(in_start: values->begin(), in_end: values->end());
      }
    }
  }
  void readList(SmallVectorImpl<Value> &list) {
    for (unsigned i = 0, e = read(); i != e; ++i) {
      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
        list.push_back(Elt: read<Value>());
      } else {
        ValueRange *values = read<ValueRange *>();
        list.append(in_start: values->begin(), in_end: values->end());
      }
    }
  }

  /// Read a value stored inline as a pointer. The pointer is copied out of
  /// the stream with memcpy (the fields need not be pointer-aligned), and the
  /// iterator advances by the number of ByteCodeFields a pointer occupies.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
  readInline() {
    const void *pointer;
    std::memcpy(dest: &pointer, src: curCodeIt, n: sizeof(const void *));
    curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
    return T::getFromOpaquePointer(pointer);
  }

  /// Advance the iterator by `skipN` fields without interpreting them.
  void skip(size_t skipN) { curCodeIt += skipN; }

  /// Jump to a specific successor based on a predicate value: `true` selects
  /// successor 0 and `false` selects successor 1.
  void selectJump(bool isTrue) { selectJump(destIndex: size_t(isTrue ? 0 : 1)); }
  /// Jump to a specific successor based on a destination index. Each
  /// successor address is a ByteCodeAddr spanning two fields, hence the
  /// `* 2` stride.
  void selectJump(size_t destIndex) {
    curCodeIt = &code[read<ByteCodeAddr>(skipN: destIndex * 2)];
  }

  /// Handle a switch operation with the provided value and cases. Successor
  /// 0 is the default destination; a match on case `i` jumps to successor
  /// `i + 1`.
  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
    LLVM_DEBUG({
      llvm::dbgs() << " * Value: " << value << "\n"
                   << " * Cases: ";
      llvm::interleaveComma(cases, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Check to see if the value is within the case list. Jump to the correct
    // successor index based on the result.
    for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
      if (cmp(*it, value))
        return selectJump(destIndex: size_t((it - cases.begin()) + 1));
    selectJump(destIndex: size_t(0));
  }

  /// Store a pointer to memory.
  void storeToMemory(unsigned index, const void *value) {
    memory[index] = value;
  }

  /// Store a value to memory as an opaque pointer.
  template <typename T>
  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
  storeToMemory(unsigned index, T value) {
    memory[index] = value.getAsOpaquePointer();
  }

  /// Internal implementation of reading various data types from the bytecode
  /// stream.
  ///
  /// Reads a memory index from the stream. Indices below `memory.size()`
  /// address execution memory; larger indices address the uniqued (constant)
  /// data, offset by `memory.size()`.
  template <typename T>
  const void *readFromMemory() {
    size_t index = *curCodeIt++;

    // If this type is an SSA value, it can only be stored in non-const memory.
    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
                        Value>::value ||
        index < memory.size())
      return memory[index];

    // Otherwise, if this index is not inbounds it is uniqued.
    return uniquedMemory[index - memory.size()];
  }
  /// Pointer types: the memory slot holds the pointer itself.
  template <typename T>
  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
  }
  /// Class types other than PDLValue: rebuilt from an opaque pointer.
  template <typename T>
  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
                   T>
  readImpl() {
    return T(T::getFromOpaquePointer(readFromMemory<T>()));
  }
  /// PDLValue: a kind field followed by a value read as that kind.
  template <typename T>
  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
    switch (read<PDLValue::Kind>()) {
    case PDLValue::Kind::Attribute:
      return read<Attribute>();
    case PDLValue::Kind::Operation:
      return read<Operation *>();
    case PDLValue::Kind::Type:
      return read<Type>();
    case PDLValue::Kind::Value:
      return read<Value>();
    case PDLValue::Kind::TypeRange:
      return read<TypeRange *>();
    case PDLValue::Kind::ValueRange:
      return read<ValueRange *>();
    }
    llvm_unreachable("unhandled PDLValue::Kind");
  }
  /// ByteCodeAddr: stored inline across two fields.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
                  "unexpected ByteCode address size");
    ByteCodeAddr result;
    std::memcpy(dest: &result, src: curCodeIt, n: sizeof(ByteCodeAddr));
    curCodeIt += 2;
    return result;
  }
  /// ByteCodeField: a single raw field.
  template <typename T>
  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
    return *curCodeIt++;
  }
  /// PDLValue::Kind: a single raw field reinterpreted as the enum.
  template <typename T>
  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
  }

  /// Assign the given range to the given memory index. This allocates a new
  /// range object if necessary. The range slot `rangeIndex` receives a view of
  /// the (possibly newly allocated) storage, and memory slot `memIndex` is set
  /// to point at that range slot. Only Type and Value element types are
  /// supported.
  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
                           unsigned rangeIndex) {
    // Utility functor used to type-erase the assignment.
    auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
      // If the input range is empty, we don't need to allocate anything.
      if (range.empty()) {
        rangeMemory[rangeIndex] = {};
      } else {
        // Allocate a buffer for this type range.
        llvm::OwningArrayRef<T> storage(llvm::size(range));
        llvm::copy(range, storage.begin());

        // Assign this to the range slot and use the range as the value for the
        // memory index.
        allocatedRangeMemory.emplace_back(std::move(storage));
        rangeMemory[rangeIndex] = allocatedRangeMemory.back();
      }
      memory[memIndex] = &rangeMemory[rangeIndex];
    };

    // Dispatch based on the concrete range type.
    if constexpr (std::is_same_v<T, Type>) {
      return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
    } else if constexpr (std::is_same_v<T, Value>) {
      return assignRange(allocatedValueRangeMemory, valueRangeMemory);
    } else {
      llvm_unreachable("unhandled range type");
    }
  }

  /// The underlying bytecode buffer.
  const ByteCodeField *curCodeIt;

  /// The stack of bytecode positions at which to resume operation.
  SmallVector<const ByteCodeField *> resumeCodeIt;

  /// The current execution memory. The slots and range slots are views; the
  /// two `allocated*` vectors are caller-owned and receive storage for
  /// ranges created during execution.
  MutableArrayRef<const void *> memory;
  MutableArrayRef<OwningOpRange> opRangeMemory;
  MutableArrayRef<TypeRange> typeRangeMemory;
  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
  MutableArrayRef<ValueRange> valueRangeMemory;
  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;

  /// The current loop indices.
  MutableArrayRef<unsigned> loopIndex;

  /// References to ByteCode data necessary for execution.
  ArrayRef<const void *> uniquedMemory;
  ArrayRef<ByteCodeField> code;
  ArrayRef<PatternBenefit> currentPatternBenefits;
  ArrayRef<PDLByteCodePattern> patterns;
  ArrayRef<PDLConstraintFunction> constraintFunctions;
  ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
1424} // namespace
1425
1426void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1427 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1428 ByteCodeField fun_idx = read();
1429 SmallVector<PDLValue, 16> args;
1430 readList<PDLValue>(list&: args);
1431
1432 LLVM_DEBUG({
1433 llvm::dbgs() << " * Arguments: ";
1434 llvm::interleaveComma(args, llvm::dbgs());
1435 llvm::dbgs() << "\n";
1436 });
1437
1438 ByteCodeField isNegated = read();
1439 LLVM_DEBUG({
1440 llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1441 llvm::interleaveComma(args, llvm::dbgs());
1442 });
1443
1444 ByteCodeField numResults = read();
1445 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1446 ByteCodeRewriteResultList results(numResults);
1447 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1448 [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1449 LLVM_DEBUG({
1450 if (succeeded(rewriteResult)) {
1451 llvm::dbgs() << " * Constraint succeeded\n";
1452 llvm::dbgs() << " * Results: ";
1453 llvm::interleaveComma(constraintResults, llvm::dbgs());
1454 llvm::dbgs() << "\n";
1455 } else {
1456 llvm::dbgs() << " * Constraint failed\n";
1457 }
1458 });
1459 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1460 "native PDL rewrite function succeeded but returned "
1461 "unexpected number of results");
1462 processNativeFunResults(results, numResults, rewriteResult);
1463
1464 // Depending on the constraint jump to the proper destination.
1465 selectJump(isTrue: isNegated != succeeded(Result: rewriteResult));
1466}
1467
1468LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1469 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1470 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1471 SmallVector<PDLValue, 16> args;
1472 readList<PDLValue>(list&: args);
1473
1474 LLVM_DEBUG({
1475 llvm::dbgs() << " * Arguments: ";
1476 llvm::interleaveComma(args, llvm::dbgs());
1477 });
1478
1479 // Execute the rewrite function.
1480 ByteCodeField numResults = read();
1481 ByteCodeRewriteResultList results(numResults);
1482 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1483
1484 assert(results.getResults().size() == numResults &&
1485 "native PDL rewrite function returned unexpected number of results");
1486
1487 processNativeFunResults(results, numResults, rewriteResult);
1488
1489 if (failed(Result: rewriteResult)) {
1490 LLVM_DEBUG(llvm::dbgs() << " - Failed");
1491 return failure();
1492 }
1493 return success();
1494}
1495
1496void ByteCodeExecutor::processNativeFunResults(
1497 ByteCodeRewriteResultList &results, unsigned numResults,
1498 LogicalResult &rewriteResult) {
1499 if (failed(Result: rewriteResult)) {
1500 // Skip the according number of values on the buffer on failure and exit
1501 // early as there are no results to process.
1502 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1503 const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1504 if (resultKind == PDLValue::Kind::TypeRange ||
1505 resultKind == PDLValue::Kind::ValueRange) {
1506 skip(skipN: 2);
1507 } else {
1508 skip(skipN: 1);
1509 }
1510 }
1511 return;
1512 }
1513
1514 // Store the results in the bytecode memory
1515 for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1516 PDLValue::Kind resultKind = read<PDLValue::Kind>();
1517 (void)resultKind;
1518 PDLValue result = results.getResults()[resultIdx];
1519 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1520 assert(result.getKind() == resultKind &&
1521 "native PDL rewrite function returned an unexpected type of "
1522 "result");
1523 // If the result is a range, we need to copy it over to the bytecodes
1524 // range memory.
1525 if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1526 unsigned rangeIndex = read();
1527 typeRangeMemory[rangeIndex] = *typeRange;
1528 memory[read()] = &typeRangeMemory[rangeIndex];
1529 } else if (std::optional<ValueRange> valueRange =
1530 result.dyn_cast<ValueRange>()) {
1531 unsigned rangeIndex = read();
1532 valueRangeMemory[rangeIndex] = *valueRange;
1533 memory[read()] = &valueRangeMemory[rangeIndex];
1534 } else {
1535 memory[read()] = result.getAsOpaquePointer();
1536 }
1537 }
1538
1539 // Copy over any underlying storage allocated for result ranges.
1540 for (auto &it : results.getAllocatedTypeRanges())
1541 allocatedTypeRangeMemory.push_back(x: std::move(it));
1542 for (auto &it : results.getAllocatedValueRanges())
1543 allocatedValueRangeMemory.push_back(x: std::move(it));
1544}
1545
1546void ByteCodeExecutor::executeAreEqual() {
1547 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1548 const void *lhs = read<const void *>();
1549 const void *rhs = read<const void *>();
1550
1551 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1552 selectJump(isTrue: lhs == rhs);
1553}
1554
1555void ByteCodeExecutor::executeAreRangesEqual() {
1556 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1557 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1558 const void *lhs = read<const void *>();
1559 const void *rhs = read<const void *>();
1560
1561 switch (valueKind) {
1562 case PDLValue::Kind::TypeRange: {
1563 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1564 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1565 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1566 selectJump(isTrue: *lhsRange == *rhsRange);
1567 break;
1568 }
1569 case PDLValue::Kind::ValueRange: {
1570 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1571 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1572 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1573 selectJump(isTrue: *lhsRange == *rhsRange);
1574 break;
1575 }
1576 default:
1577 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1578 }
1579}
1580
1581void ByteCodeExecutor::executeBranch() {
1582 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1583 curCodeIt = &code[read<ByteCodeAddr>()];
1584}
1585
1586void ByteCodeExecutor::executeCheckOperandCount() {
1587 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1588 Operation *op = read<Operation *>();
1589 uint32_t expectedCount = read<uint32_t>();
1590 bool compareAtLeast = read();
1591
1592 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1593 << " * Expected: " << expectedCount << "\n"
1594 << " * Comparator: "
1595 << (compareAtLeast ? ">=" : "==") << "\n");
1596 if (compareAtLeast)
1597 selectJump(isTrue: op->getNumOperands() >= expectedCount);
1598 else
1599 selectJump(isTrue: op->getNumOperands() == expectedCount);
1600}
1601
1602void ByteCodeExecutor::executeCheckOperationName() {
1603 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1604 Operation *op = read<Operation *>();
1605 OperationName expectedName = read<OperationName>();
1606
1607 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1608 << " * Expected: \"" << expectedName << "\"\n");
1609 selectJump(isTrue: op->getName() == expectedName);
1610}
1611
1612void ByteCodeExecutor::executeCheckResultCount() {
1613 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1614 Operation *op = read<Operation *>();
1615 uint32_t expectedCount = read<uint32_t>();
1616 bool compareAtLeast = read();
1617
1618 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1619 << " * Expected: " << expectedCount << "\n"
1620 << " * Comparator: "
1621 << (compareAtLeast ? ">=" : "==") << "\n");
1622 if (compareAtLeast)
1623 selectJump(isTrue: op->getNumResults() >= expectedCount);
1624 else
1625 selectJump(isTrue: op->getNumResults() == expectedCount);
1626}
1627
1628void ByteCodeExecutor::executeCheckTypes() {
1629 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1630 TypeRange *lhs = read<TypeRange *>();
1631 Attribute rhs = read<Attribute>();
1632 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1633
1634 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1635}
1636
1637void ByteCodeExecutor::executeContinue() {
1638 ByteCodeField level = read();
1639 LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1640 << " * Level: " << level << "\n");
1641 ++loopIndex[level];
1642 popCodeIt();
1643}
1644
1645void ByteCodeExecutor::executeCreateConstantTypeRange() {
1646 LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1647 unsigned memIndex = read();
1648 unsigned rangeIndex = read();
1649 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1650
1651 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1652 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1653 rangeIndex);
1654}
1655
1656void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1657 Location mainRewriteLoc) {
1658 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1659
1660 unsigned memIndex = read();
1661 OperationState state(mainRewriteLoc, read<OperationName>());
1662 readList(list&: state.operands);
1663 for (unsigned i = 0, e = read(); i != e; ++i) {
1664 StringAttr name = read<StringAttr>();
1665 if (Attribute attr = read<Attribute>())
1666 state.addAttribute(name, attr);
1667 }
1668
1669 // Read in the result types. If the "size" is the sentinel value, this
1670 // indicates that the result types should be inferred.
1671 unsigned numResults = read();
1672 if (numResults == kInferTypesMarker) {
1673 InferTypeOpInterface::Concept *inferInterface =
1674 state.name.getInterface<InferTypeOpInterface>();
1675 assert(inferInterface &&
1676 "expected operation to provide InferTypeOpInterface");
1677
1678 // TODO: Handle failure.
1679 if (failed(inferInterface->inferReturnTypes(
1680 state.getContext(), state.location, state.operands,
1681 state.attributes.getDictionary(state.getContext()),
1682 state.getRawProperties(), state.regions, state.types)))
1683 return;
1684 } else {
1685 // Otherwise, this is a fixed number of results.
1686 for (unsigned i = 0; i != numResults; ++i) {
1687 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1688 state.types.push_back(Elt: read<Type>());
1689 } else {
1690 TypeRange *resultTypes = read<TypeRange *>();
1691 state.types.append(in_start: resultTypes->begin(), in_end: resultTypes->end());
1692 }
1693 }
1694 }
1695
1696 Operation *resultOp = rewriter.create(state);
1697 memory[memIndex] = resultOp;
1698
1699 LLVM_DEBUG({
1700 llvm::dbgs() << " * Attributes: "
1701 << state.attributes.getDictionary(state.getContext())
1702 << "\n * Operands: ";
1703 llvm::interleaveComma(state.operands, llvm::dbgs());
1704 llvm::dbgs() << "\n * Result Types: ";
1705 llvm::interleaveComma(state.types, llvm::dbgs());
1706 llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1707 });
1708}
1709
1710template <typename T>
1711void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1712 LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1713 unsigned memIndex = read();
1714 unsigned rangeIndex = read();
1715 SmallVector<T> values;
1716 readList(values);
1717
1718 LLVM_DEBUG({
1719 llvm::dbgs() << "\n * " << type << "s: ";
1720 llvm::interleaveComma(values, llvm::dbgs());
1721 llvm::dbgs() << "\n";
1722 });
1723
1724 assignRangeToMemory(values, memIndex, rangeIndex);
1725}
1726
1727void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1728 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1729 Operation *op = read<Operation *>();
1730
1731 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1732 rewriter.eraseOp(op);
1733}
1734
1735template <typename T, typename Range, PDLValue::Kind kind>
1736void ByteCodeExecutor::executeExtract() {
1737 LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1738 Range *range = read<Range *>();
1739 unsigned index = read<uint32_t>();
1740 unsigned memIndex = read();
1741
1742 if (!range) {
1743 memory[memIndex] = nullptr;
1744 return;
1745 }
1746
1747 T result = index < range->size() ? (*range)[index] : T();
1748 LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1749 << " * Index: " << index << "\n"
1750 << " * Result: " << result << "\n");
1751 storeToMemory(memIndex, result);
1752}
1753
1754void ByteCodeExecutor::executeFinalize() {
1755 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1756}
1757
1758void ByteCodeExecutor::executeForEach() {
1759 LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1760 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1761 unsigned rangeIndex = read();
1762 unsigned memIndex = read();
1763 const void *value = nullptr;
1764
1765 switch (read<PDLValue::Kind>()) {
1766 case PDLValue::Kind::Operation: {
1767 unsigned &index = loopIndex[read()];
1768 ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1769 assert(index <= array.size() && "iterated past the end");
1770 if (index < array.size()) {
1771 LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1772 value = array[index];
1773 break;
1774 }
1775
1776 LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1777 index = 0;
1778 selectJump(destIndex: size_t(0));
1779 return;
1780 }
1781 default:
1782 llvm_unreachable("unexpected `ForEach` value kind");
1783 }
1784
1785 // Store the iterate value and the stack address.
1786 memory[memIndex] = value;
1787 pushCodeIt(it: prevCodeIt);
1788
1789 // Skip over the successor (we will enter the body of the loop).
1790 read<ByteCodeAddr>();
1791}
1792
1793void ByteCodeExecutor::executeGetAttribute() {
1794 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1795 unsigned memIndex = read();
1796 Operation *op = read<Operation *>();
1797 StringAttr attrName = read<StringAttr>();
1798 Attribute attr = op->getAttr(attrName);
1799
1800 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1801 << " * Attribute: " << attrName << "\n"
1802 << " * Result: " << attr << "\n");
1803 memory[memIndex] = attr.getAsOpaquePointer();
1804}
1805
1806void ByteCodeExecutor::executeGetAttributeType() {
1807 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1808 unsigned memIndex = read();
1809 Attribute attr = read<Attribute>();
1810 Type type;
1811 if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1812 type = typedAttr.getType();
1813
1814 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1815 << " * Result: " << type << "\n");
1816 memory[memIndex] = type.getAsOpaquePointer();
1817}
1818
1819void ByteCodeExecutor::executeGetDefiningOp() {
1820 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1821 unsigned memIndex = read();
1822 Operation *op = nullptr;
1823 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1824 Value value = read<Value>();
1825 if (value)
1826 op = value.getDefiningOp();
1827 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1828 } else {
1829 ValueRange *values = read<ValueRange *>();
1830 if (values && !values->empty()) {
1831 op = values->front().getDefiningOp();
1832 }
1833 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1834 }
1835
1836 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1837 memory[memIndex] = op;
1838}
1839
1840void ByteCodeExecutor::executeGetOperand(unsigned index) {
1841 Operation *op = read<Operation *>();
1842 unsigned memIndex = read();
1843 Value operand =
1844 index < op->getNumOperands() ? op->getOperand(idx: index) : Value();
1845
1846 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1847 << " * Index: " << index << "\n"
1848 << " * Result: " << operand << "\n");
1849 memory[memIndex] = operand.getAsOpaquePointer();
1850}
1851
1852/// This function is the internal implementation of `GetResults` and
1853/// `GetOperands` that provides support for extracting a value range from the
1854/// given operation.
1855template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1856static void *
1857executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1858 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1859 MutableArrayRef<ValueRange> valueRangeMemory) {
1860 // Check for the sentinel index that signals that all values should be
1861 // returned.
1862 if (index == std::numeric_limits<uint32_t>::max()) {
1863 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1864 // `values` is already the full value range.
1865
1866 // Otherwise, check to see if this operation uses AttrSizedSegments.
1867 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1868 LLVM_DEBUG(llvm::dbgs()
1869 << " * Extracting values from `" << attrSizedSegments << "`\n");
1870
1871 auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1872 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1873 return nullptr;
1874
1875 ArrayRef<int32_t> segments = segmentAttr;
1876 unsigned startIndex =
1877 std::accumulate(first: segments.begin(), last: segments.begin() + index, init: 0);
1878 values = values.slice(startIndex, *std::next(x: segments.begin(), n: index));
1879
1880 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1881 << *std::next(segments.begin(), index) << "]\n");
1882
1883 // Otherwise, assume this is the last operand group of the operation.
1884 // FIXME: We currently don't support operations with
1885 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1886 // have a way to detect it's presence.
1887 } else if (values.size() >= index) {
1888 LLVM_DEBUG(llvm::dbgs()
1889 << " * Treating values as trailing variadic range\n");
1890 values = values.drop_front(index);
1891
1892 // If we couldn't detect a way to compute the values, bail out.
1893 } else {
1894 return nullptr;
1895 }
1896
1897 // If the range index is valid, we are returning a range.
1898 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1899 valueRangeMemory[rangeIndex] = values;
1900 return &valueRangeMemory[rangeIndex];
1901 }
1902
1903 // If a range index wasn't provided, the range is required to be non-variadic.
1904 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1905}
1906
1907void ByteCodeExecutor::executeGetOperands() {
1908 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1909 unsigned index = read<uint32_t>();
1910 Operation *op = read<Operation *>();
1911 ByteCodeField rangeIndex = read();
1912
1913 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1914 values: op->getOperands(), op, index, rangeIndex, attrSizedSegments: "operandSegmentSizes",
1915 valueRangeMemory);
1916 if (!result)
1917 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1918 memory[read()] = result;
1919}
1920
1921void ByteCodeExecutor::executeGetResult(unsigned index) {
1922 Operation *op = read<Operation *>();
1923 unsigned memIndex = read();
1924 OpResult result =
1925 index < op->getNumResults() ? op->getResult(idx: index) : OpResult();
1926
1927 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1928 << " * Index: " << index << "\n"
1929 << " * Result: " << result << "\n");
1930 memory[memIndex] = result.getAsOpaquePointer();
1931}
1932
1933void ByteCodeExecutor::executeGetResults() {
1934 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1935 unsigned index = read<uint32_t>();
1936 Operation *op = read<Operation *>();
1937 ByteCodeField rangeIndex = read();
1938
1939 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1940 values: op->getResults(), op, index, rangeIndex, attrSizedSegments: "resultSegmentSizes",
1941 valueRangeMemory);
1942 if (!result)
1943 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1944 memory[read()] = result;
1945}
1946
1947void ByteCodeExecutor::executeGetUsers() {
1948 LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1949 unsigned memIndex = read();
1950 unsigned rangeIndex = read();
1951 OwningOpRange &range = opRangeMemory[rangeIndex];
1952 memory[memIndex] = &range;
1953
1954 range = OwningOpRange();
1955 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1956 // Read the value.
1957 Value value = read<Value>();
1958 if (!value)
1959 return;
1960 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1961
1962 // Extract the users of a single value.
1963 range = OwningOpRange(std::distance(first: value.user_begin(), last: value.user_end()));
1964 llvm::copy(Range: value.getUsers(), Out: range.begin());
1965 } else {
1966 // Read a range of values.
1967 ValueRange *values = read<ValueRange *>();
1968 if (!values)
1969 return;
1970 LLVM_DEBUG({
1971 llvm::dbgs() << " * Values (" << values->size() << "): ";
1972 llvm::interleaveComma(*values, llvm::dbgs());
1973 llvm::dbgs() << "\n";
1974 });
1975
1976 // Extract all the users of a range of values.
1977 SmallVector<Operation *> users;
1978 for (Value value : *values)
1979 users.append(in_start: value.user_begin(), in_end: value.user_end());
1980 range = OwningOpRange(users.size());
1981 llvm::copy(Range&: users, Out: range.begin());
1982 }
1983
1984 LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1985}
1986
1987void ByteCodeExecutor::executeGetValueType() {
1988 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1989 unsigned memIndex = read();
1990 Value value = read<Value>();
1991 Type type = value ? value.getType() : Type();
1992
1993 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1994 << " * Result: " << type << "\n");
1995 memory[memIndex] = type.getAsOpaquePointer();
1996}
1997
1998void ByteCodeExecutor::executeGetValueRangeTypes() {
1999 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
2000 unsigned memIndex = read();
2001 unsigned rangeIndex = read();
2002 ValueRange *values = read<ValueRange *>();
2003 if (!values) {
2004 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2005 memory[memIndex] = nullptr;
2006 return;
2007 }
2008
2009 LLVM_DEBUG({
2010 llvm::dbgs() << " * Values (" << values->size() << "): ";
2011 llvm::interleaveComma(*values, llvm::dbgs());
2012 llvm::dbgs() << "\n * Result: ";
2013 llvm::interleaveComma(values->getType(), llvm::dbgs());
2014 llvm::dbgs() << "\n";
2015 });
2016 typeRangeMemory[rangeIndex] = values->getType();
2017 memory[memIndex] = &typeRangeMemory[rangeIndex];
2018}
2019
2020void ByteCodeExecutor::executeIsNotNull() {
2021 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2022 const void *value = read<const void *>();
2023
2024 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2025 selectJump(isTrue: value != nullptr);
2026}
2027
2028void ByteCodeExecutor::executeRecordMatch(
2029 PatternRewriter &rewriter,
2030 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
2031 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2032 unsigned patternIndex = read();
2033 PatternBenefit benefit = currentPatternBenefits[patternIndex];
2034 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2035
2036 // If the benefit of the pattern is impossible, skip the processing of the
2037 // rest of the pattern.
2038 if (benefit.isImpossibleToMatch()) {
2039 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2040 curCodeIt = dest;
2041 return;
2042 }
2043
2044 // Create a fused location containing the locations of each of the
2045 // operations used in the match. This will be used as the location for
2046 // created operations during the rewrite that don't already have an
2047 // explicit location set.
2048 unsigned numMatchLocs = read();
2049 SmallVector<Location, 4> matchLocs;
2050 matchLocs.reserve(N: numMatchLocs);
2051 for (unsigned i = 0; i != numMatchLocs; ++i)
2052 matchLocs.push_back(Elt: read<Operation *>()->getLoc());
2053 Location matchLoc = rewriter.getFusedLoc(locs: matchLocs);
2054
2055 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2056 << " * Location: " << matchLoc << "\n");
2057 matches.emplace_back(Args&: matchLoc, Args: patterns[patternIndex], Args&: benefit);
2058 PDLByteCode::MatchResult &match = matches.back();
2059
2060 // Record all of the inputs to the match. If any of the inputs are ranges, we
2061 // will also need to remap the range pointer to memory stored in the match
2062 // state.
2063 unsigned numInputs = read();
2064 match.values.reserve(N: numInputs);
2065 match.typeRangeValues.reserve(N: numInputs);
2066 match.valueRangeValues.reserve(N: numInputs);
2067 for (unsigned i = 0; i < numInputs; ++i) {
2068 switch (read<PDLValue::Kind>()) {
2069 case PDLValue::Kind::TypeRange:
2070 match.typeRangeValues.push_back(Elt: *read<TypeRange *>());
2071 match.values.push_back(Elt: &match.typeRangeValues.back());
2072 break;
2073 case PDLValue::Kind::ValueRange:
2074 match.valueRangeValues.push_back(Elt: *read<ValueRange *>());
2075 match.values.push_back(Elt: &match.valueRangeValues.back());
2076 break;
2077 default:
2078 match.values.push_back(Elt: read<const void *>());
2079 break;
2080 }
2081 }
2082 curCodeIt = dest;
2083}
2084
2085void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2086 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2087 Operation *op = read<Operation *>();
2088 SmallVector<Value, 16> args;
2089 readList(list&: args);
2090
2091 LLVM_DEBUG({
2092 llvm::dbgs() << " * Operation: " << *op << "\n"
2093 << " * Values: ";
2094 llvm::interleaveComma(args, llvm::dbgs());
2095 llvm::dbgs() << "\n";
2096 });
2097 rewriter.replaceOp(op, newValues: args);
2098}
2099
2100void ByteCodeExecutor::executeSwitchAttribute() {
2101 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2102 Attribute value = read<Attribute>();
2103 ArrayAttr cases = read<ArrayAttr>();
2104 handleSwitch(value, cases);
2105}
2106
2107void ByteCodeExecutor::executeSwitchOperandCount() {
2108 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2109 Operation *op = read<Operation *>();
2110 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2111
2112 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2113 handleSwitch(op->getNumOperands(), cases);
2114}
2115
2116void ByteCodeExecutor::executeSwitchOperationName() {
2117 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2118 OperationName value = read<Operation *>()->getName();
2119 size_t caseCount = read();
2120
2121 // The operation names are stored in-line, so to print them out for
2122 // debugging purposes we need to read the array before executing the
2123 // switch so that we can display all of the possible values.
2124 LLVM_DEBUG({
2125 const ByteCodeField *prevCodeIt = curCodeIt;
2126 llvm::dbgs() << " * Value: " << value << "\n"
2127 << " * Cases: ";
2128 llvm::interleaveComma(
2129 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2130 [&](size_t) { return read<OperationName>(); }),
2131 llvm::dbgs());
2132 llvm::dbgs() << "\n";
2133 curCodeIt = prevCodeIt;
2134 });
2135
2136 // Try to find the switch value within any of the cases.
2137 for (size_t i = 0; i != caseCount; ++i) {
2138 if (read<OperationName>() == value) {
2139 curCodeIt += (caseCount - i - 1);
2140 return selectJump(destIndex: i + 1);
2141 }
2142 }
2143 selectJump(destIndex: size_t(0));
2144}
2145
2146void ByteCodeExecutor::executeSwitchResultCount() {
2147 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2148 Operation *op = read<Operation *>();
2149 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2150
2151 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2152 handleSwitch(op->getNumResults(), cases);
2153}
2154
2155void ByteCodeExecutor::executeSwitchType() {
2156 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2157 Type value = read<Type>();
2158 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2159 handleSwitch(value, cases);
2160}
2161
2162void ByteCodeExecutor::executeSwitchTypes() {
2163 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2164 TypeRange *value = read<TypeRange *>();
2165 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2166 if (!value) {
2167 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2168 return selectJump(destIndex: size_t(0));
2169 }
2170 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2171 return value == caseValue.getAsValueRange<TypeAttr>();
2172 });
2173}
2174
2175LogicalResult
2176ByteCodeExecutor::execute(PatternRewriter &rewriter,
2177 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
2178 std::optional<Location> mainRewriteLoc) {
2179 while (true) {
2180 // Print the location of the operation being executed.
2181 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2182
2183 OpCode opCode = static_cast<OpCode>(read());
2184 switch (opCode) {
2185 case ApplyConstraint:
2186 executeApplyConstraint(rewriter);
2187 break;
2188 case ApplyRewrite:
2189 if (failed(Result: executeApplyRewrite(rewriter)))
2190 return failure();
2191 break;
2192 case AreEqual:
2193 executeAreEqual();
2194 break;
2195 case AreRangesEqual:
2196 executeAreRangesEqual();
2197 break;
2198 case Branch:
2199 executeBranch();
2200 break;
2201 case CheckOperandCount:
2202 executeCheckOperandCount();
2203 break;
2204 case CheckOperationName:
2205 executeCheckOperationName();
2206 break;
2207 case CheckResultCount:
2208 executeCheckResultCount();
2209 break;
2210 case CheckTypes:
2211 executeCheckTypes();
2212 break;
2213 case Continue:
2214 executeContinue();
2215 break;
2216 case CreateConstantTypeRange:
2217 executeCreateConstantTypeRange();
2218 break;
2219 case CreateOperation:
2220 executeCreateOperation(rewriter, mainRewriteLoc: *mainRewriteLoc);
2221 break;
2222 case CreateDynamicTypeRange:
2223 executeDynamicCreateRange<Type>(type: "Type");
2224 break;
2225 case CreateDynamicValueRange:
2226 executeDynamicCreateRange<Value>(type: "Value");
2227 break;
2228 case EraseOp:
2229 executeEraseOp(rewriter);
2230 break;
2231 case ExtractOp:
2232 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2233 break;
2234 case ExtractType:
2235 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2236 break;
2237 case ExtractValue:
2238 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2239 break;
2240 case Finalize:
2241 executeFinalize();
2242 LLVM_DEBUG(llvm::dbgs() << "\n");
2243 return success();
2244 case ForEach:
2245 executeForEach();
2246 break;
2247 case GetAttribute:
2248 executeGetAttribute();
2249 break;
2250 case GetAttributeType:
2251 executeGetAttributeType();
2252 break;
2253 case GetDefiningOp:
2254 executeGetDefiningOp();
2255 break;
2256 case GetOperand0:
2257 case GetOperand1:
2258 case GetOperand2:
2259 case GetOperand3: {
2260 unsigned index = opCode - GetOperand0;
2261 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2262 executeGetOperand(index);
2263 break;
2264 }
2265 case GetOperandN:
2266 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2267 executeGetOperand(index: read<uint32_t>());
2268 break;
2269 case GetOperands:
2270 executeGetOperands();
2271 break;
2272 case GetResult0:
2273 case GetResult1:
2274 case GetResult2:
2275 case GetResult3: {
2276 unsigned index = opCode - GetResult0;
2277 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2278 executeGetResult(index);
2279 break;
2280 }
2281 case GetResultN:
2282 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2283 executeGetResult(index: read<uint32_t>());
2284 break;
2285 case GetResults:
2286 executeGetResults();
2287 break;
2288 case GetUsers:
2289 executeGetUsers();
2290 break;
2291 case GetValueType:
2292 executeGetValueType();
2293 break;
2294 case GetValueRangeTypes:
2295 executeGetValueRangeTypes();
2296 break;
2297 case IsNotNull:
2298 executeIsNotNull();
2299 break;
2300 case RecordMatch:
2301 assert(matches &&
2302 "expected matches to be provided when executing the matcher");
2303 executeRecordMatch(rewriter, matches&: *matches);
2304 break;
2305 case ReplaceOp:
2306 executeReplaceOp(rewriter);
2307 break;
2308 case SwitchAttribute:
2309 executeSwitchAttribute();
2310 break;
2311 case SwitchOperandCount:
2312 executeSwitchOperandCount();
2313 break;
2314 case SwitchOperationName:
2315 executeSwitchOperationName();
2316 break;
2317 case SwitchResultCount:
2318 executeSwitchResultCount();
2319 break;
2320 case SwitchType:
2321 executeSwitchType();
2322 break;
2323 case SwitchTypes:
2324 executeSwitchTypes();
2325 break;
2326 }
2327 LLVM_DEBUG(llvm::dbgs() << "\n");
2328 }
2329}
2330
2331void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2332 SmallVectorImpl<MatchResult> &matches,
2333 PDLByteCodeMutableState &state) const {
2334 // The first memory slot is always the root operation.
2335 state.memory[0] = op;
2336
2337 // The matcher function always starts at code address 0.
2338 ByteCodeExecutor executor(
2339 matcherByteCode.data(), state.memory, state.opRangeMemory,
2340 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2341 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2342 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2343 constraintFunctions, rewriteFunctions);
2344 LogicalResult executeResult = executor.execute(rewriter, matches: &matches);
2345 (void)executeResult;
2346 assert(succeeded(executeResult) && "unexpected matcher execution failure");
2347
2348 // Order the found matches by benefit.
2349 llvm::stable_sort(Range&: matches,
2350 C: [](const MatchResult &lhs, const MatchResult &rhs) {
2351 return lhs.benefit > rhs.benefit;
2352 });
2353}
2354
2355LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2356 const MatchResult &match,
2357 PDLByteCodeMutableState &state) const {
2358 auto *configSet = match.pattern->getConfigSet();
2359 if (configSet)
2360 configSet->notifyRewriteBegin(rewriter);
2361
2362 // The arguments of the rewrite function are stored at the start of the
2363 // memory buffer.
2364 llvm::copy(Range: match.values, Out: state.memory.begin());
2365
2366 ByteCodeExecutor executor(
2367 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2368 state.opRangeMemory, state.typeRangeMemory,
2369 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2370 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2371 rewriterByteCode, state.currentPatternBenefits, patterns,
2372 constraintFunctions, rewriteFunctions);
2373 LogicalResult result =
2374 executor.execute(rewriter, /*matches=*/nullptr, mainRewriteLoc: match.location);
2375
2376 if (configSet)
2377 configSet->notifyRewriteEnd(rewriter);
2378
2379 // If the rewrite failed, check if the pattern rewriter can recover. If it
2380 // can, we can signal to the pattern applicator to keep trying patterns. If it
2381 // doesn't, we need to bail. Bailing here should be fine, given that we have
2382 // no means to propagate such a failure to the user, and it also indicates a
2383 // bug in the user code (i.e. failable rewrites should not be used with
2384 // pattern rewriters that don't support it).
2385 if (failed(Result: result) && !rewriter.canRecoverFromRewriteFailure()) {
2386 LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2387 llvm::report_fatal_error(
2388 reason: "Native PDL Rewrite failed, but the pattern "
2389 "rewriter doesn't support recovery. Failable pattern rewrites should "
2390 "not be used with pattern rewriters that do not support them.");
2391 }
2392 return result;
2393}
2394

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Rewrite/ByteCode.cpp