1//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10
11#include "PredicateTree.h"
12#include "mlir/Dialect/PDL/IR/PDL.h"
13#include "mlir/Dialect/PDL/IR/PDLTypes.h"
14#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15#include "mlir/Pass/Pass.h"
16#include "llvm/ADT/MapVector.h"
17#include "llvm/ADT/ScopedHashTable.h"
18#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::pdl_to_pdl_interp;
30
31//===----------------------------------------------------------------------===//
32// PatternLowering
33//===----------------------------------------------------------------------===//
34
35namespace {
36/// This class generators operations within the PDL Interpreter dialect from a
37/// given module containing PDL pattern operations.
38struct PatternLowering {
39public:
40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
41 DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
42
43 /// Generate code for matching and rewriting based on the pattern operations
44 /// within the module.
45 void lower(ModuleOp module);
46
47private:
48 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
50
51 /// Generate interpreter operations for the tree rooted at the given matcher
52 /// node, in the specified region.
53 Block *generateMatcher(MatcherNode &node, Region &region,
54 Block *block = nullptr);
55
56 /// Get or create an access to the provided positional value in the current
57 /// block. This operation may mutate the provided block pointer if nested
58 /// regions (i.e., pdl_interp.iterate) are required.
59 Value getValueAt(Block *&currentBlock, Position *pos);
60
61 /// Create the interpreter predicate operations. This operation may mutate the
62 /// provided current block pointer if nested regions (iterates) are required.
63 void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
64
65 /// Create the interpreter switch / predicate operations, with several case
66 /// destinations. This operation never mutates the provided current block
67 /// pointer, because the switch operation does not need Values beyond `val`.
68 void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
69
70 /// Create the interpreter operations to record a successful pattern match
71 /// using the contained root operation. This operation may mutate the current
72 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73 void generate(SuccessNode *successNode, Block *&currentBlock);
74
75 /// Generate a rewriter function for the given pattern operation, and returns
76 /// a reference to that function.
77 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
78 SmallVectorImpl<Position *> &usedMatchValues);
79
80 /// Generate the rewriter code for the given operation.
81 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
82 DenseMap<Value, Value> &rewriteValues,
83 function_ref<Value(Value)> mapRewriteValue);
84 void generateRewriter(pdl::AttributeOp attrOp,
85 DenseMap<Value, Value> &rewriteValues,
86 function_ref<Value(Value)> mapRewriteValue);
87 void generateRewriter(pdl::EraseOp eraseOp,
88 DenseMap<Value, Value> &rewriteValues,
89 function_ref<Value(Value)> mapRewriteValue);
90 void generateRewriter(pdl::OperationOp operationOp,
91 DenseMap<Value, Value> &rewriteValues,
92 function_ref<Value(Value)> mapRewriteValue);
93 void generateRewriter(pdl::RangeOp rangeOp,
94 DenseMap<Value, Value> &rewriteValues,
95 function_ref<Value(Value)> mapRewriteValue);
96 void generateRewriter(pdl::ReplaceOp replaceOp,
97 DenseMap<Value, Value> &rewriteValues,
98 function_ref<Value(Value)> mapRewriteValue);
99 void generateRewriter(pdl::ResultOp resultOp,
100 DenseMap<Value, Value> &rewriteValues,
101 function_ref<Value(Value)> mapRewriteValue);
102 void generateRewriter(pdl::ResultsOp resultOp,
103 DenseMap<Value, Value> &rewriteValues,
104 function_ref<Value(Value)> mapRewriteValue);
105 void generateRewriter(pdl::TypeOp typeOp,
106 DenseMap<Value, Value> &rewriteValues,
107 function_ref<Value(Value)> mapRewriteValue);
108 void generateRewriter(pdl::TypesOp typeOp,
109 DenseMap<Value, Value> &rewriteValues,
110 function_ref<Value(Value)> mapRewriteValue);
111
112 /// Generate the values used for resolving the result types of an operation
113 /// created within a dag rewriter region. If the result types of the operation
114 /// should be inferred, `hasInferredResultTypes` is set to true.
115 void generateOperationResultTypeRewriter(
116 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
117 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
118 bool &hasInferredResultTypes);
119
120 /// A builder to use when generating interpreter operations.
121 OpBuilder builder;
122
123 /// The matcher function used for all match related logic within PDL patterns.
124 pdl_interp::FuncOp matcherFunc;
125
126 /// The rewriter module containing the all rewrite related logic within PDL
127 /// patterns.
128 ModuleOp rewriterModule;
129
130 /// The symbol table of the rewriter module used for insertion.
131 SymbolTable rewriterSymbolTable;
132
133 /// A scoped map connecting a position with the corresponding interpreter
134 /// value.
135 ValueMap values;
136
137 /// A stack of blocks used as the failure destination for matcher nodes that
138 /// don't have an explicit failure path.
139 SmallVector<Block *, 8> failureBlockStack;
140
141 /// A mapping between values defined in a pattern match, and the corresponding
142 /// positional value.
143 DenseMap<Value, Position *> valueToPosition;
144
145 /// The set of operation values whose location will be used for newly
146 /// generated operations.
147 SetVector<Value> locOps;
148
149 /// A mapping between pattern operations and the corresponding configuration
150 /// set.
151 DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152
153 /// A mapping from a constraint question to the ApplyConstraintOp
154 /// that implements it.
155 DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
156};
157} // namespace
158
159PatternLowering::PatternLowering(
160 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
161 DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
162 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
163 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
164 configMap(configMap) {}
165
166void PatternLowering::lower(ModuleOp module) {
167 PredicateUniquer predicateUniquer;
168 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
169
170 // Define top-level scope for the arguments to the matcher function.
171 ValueMapScope topLevelValueScope(values);
172
173 // Insert the root operation, i.e. argument to the matcher, at the root
174 // position.
175 Block *matcherEntryBlock = &matcherFunc.front();
176 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
177
178 // Generate a root matcher node from the provided PDL module.
179 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
180 module, predicateBuilder, valueToPosition);
181 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
182 assert(failureBlockStack.empty() && "failed to empty the stack");
183
184 // After generation, merged the first matched block into the entry.
185 matcherEntryBlock->getOperations().splice(where: matcherEntryBlock->end(),
186 L2&: firstMatcherBlock->getOperations());
187 firstMatcherBlock->erase();
188}
189
190Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191 Block *block) {
192 // Push a new scope for the values used by this matcher.
193 if (!block)
194 block = &region.emplaceBlock();
195 ValueMapScope scope(values);
196
197 // If this is the return node, simply insert the corresponding interpreter
198 // finalize.
199 if (isa<ExitNode>(Val: node)) {
200 builder.setInsertionPointToEnd(block);
201 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
202 return block;
203 }
204
205 // Get the next block in the match sequence.
206 // This is intentionally executed first, before we get the value for the
207 // position associated with the node, so that we preserve an "there exist"
208 // semantics: if getting a value requires an upward traversal (going from a
209 // value to its consumers), we want to perform the check on all the consumers
210 // before we pass control to the failure node.
211 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
212 Block *failureBlock;
213 if (failureNode) {
214 failureBlock = generateMatcher(node&: *failureNode, region);
215 failureBlockStack.push_back(failureBlock);
216 } else {
217 assert(!failureBlockStack.empty() && "expected valid failure block");
218 failureBlock = failureBlockStack.back();
219 }
220
221 // If this node contains a position, get the corresponding value for this
222 // block.
223 Block *currentBlock = block;
224 Position *position = node.getPosition();
225 Value val = position ? getValueAt(currentBlock, pos: position) : Value();
226
227 // If this value corresponds to an operation, record that we are going to use
228 // its location as part of a fused location.
229 bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
230 if (isOperationValue)
231 locOps.insert(val);
232
233 // Dispatch to the correct method based on derived node type.
234 TypeSwitch<MatcherNode *>(&node)
235 .Case<BoolNode, SwitchNode>(caseFn: [&](auto *derivedNode) {
236 this->generate(derivedNode, currentBlock, val);
237 })
238 .Case(caseFn: [&](SuccessNode *successNode) {
239 generate(successNode, currentBlock);
240 });
241
242 // Pop all the failure blocks that were inserted due to nesting of
243 // pdl_interp.iterate.
244 while (failureBlockStack.back() != failureBlock) {
245 failureBlockStack.pop_back();
246 assert(!failureBlockStack.empty() && "unable to locate failure block");
247 }
248
249 // Pop the new failure block.
250 if (failureNode)
251 failureBlockStack.pop_back();
252
253 if (isOperationValue)
254 locOps.remove(val);
255
256 return block;
257}
258
259Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
260 if (Value val = values.lookup(pos))
261 return val;
262
263 // Get the value for the parent position.
264 Value parentVal;
265 if (Position *parent = pos->getParent())
266 parentVal = getValueAt(currentBlock, pos: parent);
267
268 // TODO: Use a location from the position.
269 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
270 builder.setInsertionPointToEnd(currentBlock);
271 Value value;
272 switch (pos->getKind()) {
273 case Predicates::OperationPos: {
274 auto *operationPos = cast<OperationPosition>(Val: pos);
275 if (operationPos->isOperandDefiningOp())
276 // Standard (downward) traversal which directly follows the defining op.
277 value = builder.create<pdl_interp::GetDefiningOpOp>(
278 loc, builder.getType<pdl::OperationType>(), parentVal);
279 else
280 // A passthrough operation position.
281 value = parentVal;
282 break;
283 }
284 case Predicates::UsersPos: {
285 auto *usersPos = cast<UsersPosition>(Val: pos);
286
287 // The first operation retrieves the representative value of a range.
288 // This applies only when the parent is a range of values and we were
289 // requested to use a representative value (e.g., upward traversal).
290 if (isa<pdl::RangeType>(parentVal.getType()) &&
291 usersPos->useRepresentative())
292 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
293 else
294 value = parentVal;
295
296 // The second operation retrieves the users.
297 value = builder.create<pdl_interp::GetUsersOp>(loc, value);
298 break;
299 }
300 case Predicates::ForEachPos: {
301 assert(!failureBlockStack.empty() && "expected valid failure block");
302 auto foreach = builder.create<pdl_interp::ForEachOp>(
303 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
304 value = foreach.getLoopVariable();
305
306 // Create the continuation block.
307 Block *continueBlock = builder.createBlock(&foreach.getRegion());
308 builder.create<pdl_interp::ContinueOp>(loc);
309 failureBlockStack.push_back(continueBlock);
310
311 currentBlock = &foreach.getRegion().front();
312 break;
313 }
314 case Predicates::OperandPos: {
315 auto *operandPos = cast<OperandPosition>(Val: pos);
316 value = builder.create<pdl_interp::GetOperandOp>(
317 loc, builder.getType<pdl::ValueType>(), parentVal,
318 operandPos->getOperandNumber());
319 break;
320 }
321 case Predicates::OperandGroupPos: {
322 auto *operandPos = cast<OperandGroupPosition>(Val: pos);
323 Type valueTy = builder.getType<pdl::ValueType>();
324 value = builder.create<pdl_interp::GetOperandsOp>(
325 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
326 parentVal, operandPos->getOperandGroupNumber());
327 break;
328 }
329 case Predicates::AttributePos: {
330 auto *attrPos = cast<AttributePosition>(Val: pos);
331 value = builder.create<pdl_interp::GetAttributeOp>(
332 loc, builder.getType<pdl::AttributeType>(), parentVal,
333 attrPos->getName().strref());
334 break;
335 }
336 case Predicates::TypePos: {
337 if (isa<pdl::AttributeType>(parentVal.getType()))
338 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
339 else
340 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
341 break;
342 }
343 case Predicates::ResultPos: {
344 auto *resPos = cast<ResultPosition>(Val: pos);
345 value = builder.create<pdl_interp::GetResultOp>(
346 loc, builder.getType<pdl::ValueType>(), parentVal,
347 resPos->getResultNumber());
348 break;
349 }
350 case Predicates::ResultGroupPos: {
351 auto *resPos = cast<ResultGroupPosition>(Val: pos);
352 Type valueTy = builder.getType<pdl::ValueType>();
353 value = builder.create<pdl_interp::GetResultsOp>(
354 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
356 break;
357 }
358 case Predicates::AttributeLiteralPos: {
359 auto *attrPos = cast<AttributeLiteralPosition>(Val: pos);
360 value =
361 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
362 break;
363 }
364 case Predicates::TypeLiteralPos: {
365 auto *typePos = cast<TypeLiteralPosition>(Val: pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
369 else
370 value = builder.create<pdl_interp::CreateTypesOp>(
371 loc, cast<ArrayAttr>(rawTypeAttr));
372 break;
373 }
374 case Predicates::ConstraintResultPos: {
375 // Due to the order of traversal, the ApplyConstraintOp has already been
376 // created and we can find it in constraintOpMap.
377 auto *constrResPos = cast<ConstraintPosition>(Val: pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
381 break;
382 }
383 default:
384 llvm_unreachable("Generating unknown Position getter");
385 break;
386 }
387
388 values.insert(pos, value);
389 return value;
390}
391
392void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393 Value val) {
394 Location loc = val.getLoc();
395 Qualifier *question = boolNode->getQuestion();
396 Qualifier *answer = boolNode->getAnswer();
397 Region *region = currentBlock->getParent();
398
399 // Execute the getValue queries first, so that we create success
400 // matcher in the correct (possibly nested) region.
401 SmallVector<Value> args;
402 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(Val: question)) {
403 args = {getValueAt(currentBlock, pos: equalToQuestion->getValue())};
404 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(Val: question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(Elt: getValueAt(currentBlock, pos: position));
407 }
408
409 // Generate a new block as success successor and get the failure successor.
410 Block *success = &region->emplaceBlock();
411 Block *failure = failureBlockStack.back();
412
413 // Create the predicate.
414 builder.setInsertionPointToEnd(currentBlock);
415 Predicates::Kind kind = question->getKind();
416 switch (kind) {
417 case Predicates::IsNotNullQuestion:
418 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
419 break;
420 case Predicates::OperationNameQuestion: {
421 auto *opNameAnswer = cast<OperationNameAnswer>(Val: answer);
422 builder.create<pdl_interp::CheckOperationNameOp>(
423 loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
424 break;
425 }
426 case Predicates::TypeQuestion: {
427 auto *ans = cast<TypeAnswer>(Val: answer);
428 if (isa<pdl::RangeType>(val.getType()))
429 builder.create<pdl_interp::CheckTypesOp>(
430 loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
431 else
432 builder.create<pdl_interp::CheckTypeOp>(
433 loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
434 break;
435 }
436 case Predicates::AttributeQuestion: {
437 auto *ans = cast<AttributeAnswer>(Val: answer);
438 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
439 success, failure);
440 break;
441 }
442 case Predicates::OperandCountAtLeastQuestion:
443 case Predicates::OperandCountQuestion:
444 builder.create<pdl_interp::CheckOperandCountOp>(
445 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
446 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
447 success, failure);
448 break;
449 case Predicates::ResultCountAtLeastQuestion:
450 case Predicates::ResultCountQuestion:
451 builder.create<pdl_interp::CheckResultCountOp>(
452 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
453 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
454 success, failure);
455 break;
456 case Predicates::EqualToQuestion: {
457 bool trueAnswer = isa<TrueAnswer>(Val: answer);
458 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
459 trueAnswer ? success : failure,
460 trueAnswer ? failure : success);
461 break;
462 }
463 case Predicates::ConstraintQuestion: {
464 auto *cstQuestion = cast<ConstraintQuestion>(Val: question);
465 auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466 loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467 cstQuestion->getIsNegated(), success, failure);
468
469 constraintOpMap.insert({cstQuestion, applyConstraintOp});
470 break;
471 }
472 default:
473 llvm_unreachable("Generating unknown Predicate operation");
474 }
475
476 // Generate the matcher in the current (potentially nested) region.
477 // This might use the results of the current predicate.
478 generateMatcher(node&: *boolNode->getSuccessNode(), region&: *region, block: success);
479}
480
481template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
482static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
483 llvm::MapVector<Qualifier *, Block *> &dests) {
484 std::vector<ValT> values;
485 std::vector<Block *> blocks;
486 values.reserve(dests.size());
487 blocks.reserve(n: dests.size());
488 for (const auto &it : dests) {
489 blocks.push_back(x: it.second);
490 values.push_back(cast<PredT>(it.first)->getValue());
491 }
492 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
493}
494
495void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
496 Value val) {
497 Qualifier *question = switchNode->getQuestion();
498 Region *region = currentBlock->getParent();
499 Block *defaultDest = failureBlockStack.back();
500
501 // If the switch question is not an exact answer, i.e. for the `at_least`
502 // cases, we generate a special block sequence.
503 Predicates::Kind kind = question->getKind();
504 if (kind == Predicates::OperandCountAtLeastQuestion ||
505 kind == Predicates::ResultCountAtLeastQuestion) {
506 // Order the children such that the cases are in reverse numerical order.
507 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
508 Range: llvm::seq<unsigned>(Begin: 0, End: switchNode->getChildren().size()));
509 llvm::sort(C&: sortedChildren, Comp: [&](unsigned lhs, unsigned rhs) {
510 return cast<UnsignedAnswer>(Val: switchNode->getChild(i: lhs).first)->getValue() >
511 cast<UnsignedAnswer>(Val: switchNode->getChild(i: rhs).first)->getValue();
512 });
513
514 // Build the destination for each child using the next highest child as a
515 // a failure destination. This essentially creates the following control
516 // flow:
517 //
518 // if (operand_count < 1)
519 // goto failure
520 // if (child1.match())
521 // ...
522 //
523 // if (operand_count < 2)
524 // goto failure
525 // if (child2.match())
526 // ...
527 //
528 // failure:
529 // ...
530 //
531 failureBlockStack.push_back(defaultDest);
532 Location loc = val.getLoc();
533 for (unsigned idx : sortedChildren) {
534 auto &child = switchNode->getChild(i: idx);
535 Block *childBlock = generateMatcher(node&: *child.second, region&: *region);
536 Block *predicateBlock = builder.createBlock(insertBefore: childBlock);
537 builder.setInsertionPointToEnd(predicateBlock);
538 unsigned ans = cast<UnsignedAnswer>(Val: child.first)->getValue();
539 switch (kind) {
540 case Predicates::OperandCountAtLeastQuestion:
541 builder.create<pdl_interp::CheckOperandCountOp>(
542 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
543 break;
544 case Predicates::ResultCountAtLeastQuestion:
545 builder.create<pdl_interp::CheckResultCountOp>(
546 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
547 break;
548 default:
549 llvm_unreachable("Generating invalid AtLeast operation");
550 }
551 failureBlockStack.back() = predicateBlock;
552 }
553 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
554 currentBlock->getOperations().splice(where: currentBlock->end(),
555 L2&: firstPredicateBlock->getOperations());
556 firstPredicateBlock->erase();
557 return;
558 }
559
560 // Otherwise, generate each of the children and generate an interpreter
561 // switch.
562 llvm::MapVector<Qualifier *, Block *> children;
563 for (auto &it : switchNode->getChildren())
564 children.insert(KV: {it.first, generateMatcher(node&: *it.second, region&: *region)});
565 builder.setInsertionPointToEnd(currentBlock);
566
567 switch (question->getKind()) {
568 case Predicates::OperandCountQuestion:
569 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
570 int32_t>(val, defaultDest, builder, children);
571 case Predicates::ResultCountQuestion:
572 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
573 int32_t>(val, defaultDest, builder, children);
574 case Predicates::OperationNameQuestion:
575 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
576 OperationNameAnswer>(val, defaultDest, builder,
577 children);
578 case Predicates::TypeQuestion:
579 if (isa<pdl::RangeType>(val.getType())) {
580 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
581 val, defaultDest, builder, children);
582 }
583 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
584 val, defaultDest, builder, children);
585 case Predicates::AttributeQuestion:
586 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
587 val, defaultDest, builder, children);
588 default:
589 llvm_unreachable("Generating unknown switch predicate.");
590 }
591}
592
593void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
594 pdl::PatternOp pattern = successNode->getPattern();
595 Value root = successNode->getRoot();
596
597 // Generate a rewriter for the pattern this success node represents, and track
598 // any values used from the match region.
599 SmallVector<Position *, 8> usedMatchValues;
600 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
601
602 // Process any values used in the rewrite that are defined in the match.
603 std::vector<Value> mappedMatchValues;
604 mappedMatchValues.reserve(n: usedMatchValues.size());
605 for (Position *position : usedMatchValues)
606 mappedMatchValues.push_back(x: getValueAt(currentBlock, pos: position));
607
608 // Collect the set of operations generated by the rewriter.
609 SmallVector<StringRef, 4> generatedOps;
610 for (auto op :
611 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
612 generatedOps.push_back(*op.getOpName());
613 ArrayAttr generatedOpsAttr;
614 if (!generatedOps.empty())
615 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
616
617 // Grab the root kind if present.
618 StringAttr rootKindAttr;
619 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
620 if (std::optional<StringRef> rootKind = rootOp.getOpName())
621 rootKindAttr = builder.getStringAttr(*rootKind);
622
623 builder.setInsertionPointToEnd(currentBlock);
624 auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
625 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
626 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
627 failureBlockStack.back());
628
629 // Set the config of the lowered match to the parent pattern.
630 if (configMap)
631 configMap->try_emplace(matchOp, configMap->lookup(pattern));
632}
633
634SymbolRefAttr PatternLowering::generateRewriter(
635 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
636 builder.setInsertionPointToEnd(rewriterModule.getBody());
637 auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
638 pattern.getLoc(), "pdl_generated_rewriter",
639 builder.getFunctionType(std::nullopt, std::nullopt));
640 rewriterSymbolTable.insert(symbol: rewriterFunc);
641
642 // Generate the rewriter function body.
643 builder.setInsertionPointToEnd(&rewriterFunc.front());
644
645 // Map an input operand of the pattern to a generated interpreter value.
646 DenseMap<Value, Value> rewriteValues;
647 auto mapRewriteValue = [&](Value oldValue) {
648 Value &newValue = rewriteValues[oldValue];
649 if (newValue)
650 return newValue;
651
652 // Prefer materializing constants directly when possible.
653 Operation *oldOp = oldValue.getDefiningOp();
654 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
655 if (Attribute value = attrOp.getValueAttr()) {
656 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
657 attrOp.getLoc(), value);
658 }
659 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
660 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
661 return newValue = builder.create<pdl_interp::CreateTypeOp>(
662 typeOp.getLoc(), type);
663 }
664 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
665 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
666 return newValue = builder.create<pdl_interp::CreateTypesOp>(
667 typeOp.getLoc(), typeOp.getType(), type);
668 }
669 }
670
671 // Otherwise, add this as an input to the rewriter.
672 Position *inputPos = valueToPosition.lookup(oldValue);
673 assert(inputPos && "expected value to be a pattern input");
674 usedMatchValues.push_back(Elt: inputPos);
675 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
676 oldValue.getLoc());
677 };
678
679 // If this is a custom rewriter, simply dispatch to the registered rewrite
680 // method.
681 pdl::RewriteOp rewriter = pattern.getRewriter();
682 if (StringAttr rewriteName = rewriter.getNameAttr()) {
683 SmallVector<Value> args;
684 if (rewriter.getRoot())
685 args.push_back(Elt: mapRewriteValue(rewriter.getRoot()));
686 auto mappedArgs =
687 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
688 args.append(mappedArgs.begin(), mappedArgs.end());
689 builder.create<pdl_interp::ApplyRewriteOp>(
690 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
691 } else {
692 // Otherwise this is a dag rewriter defined using PDL operations.
693 for (Operation &rewriteOp : *rewriter.getBody()) {
694 llvm::TypeSwitch<Operation *>(&rewriteOp)
695 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
696 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
697 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
698 this->generateRewriter(op, rewriteValues, mapRewriteValue);
699 });
700 }
701 }
702
703 // Update the signature of the rewrite function.
704 rewriterFunc.setType(builder.getFunctionType(
705 inputs: llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
706 /*results=*/std::nullopt));
707
708 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
709 return SymbolRefAttr::get(
710 builder.getContext(),
711 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712 SymbolRefAttr::get(rewriterFunc));
713}
714
715void PatternLowering::generateRewriter(
716 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
717 function_ref<Value(Value)> mapRewriteValue) {
718 SmallVector<Value, 2> arguments;
719 for (Value argument : rewriteOp.getArgs())
720 arguments.push_back(mapRewriteValue(argument));
721 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
722 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
723 arguments);
724 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
725 rewriteValues[std::get<0>(it)] = std::get<1>(it);
726}
727
728void PatternLowering::generateRewriter(
729 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
730 function_ref<Value(Value)> mapRewriteValue) {
731 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
732 attrOp.getLoc(), attrOp.getValueAttr());
733 rewriteValues[attrOp] = newAttr;
734}
735
736void PatternLowering::generateRewriter(
737 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
738 function_ref<Value(Value)> mapRewriteValue) {
739 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
740 mapRewriteValue(eraseOp.getOpValue()));
741}
742
743void PatternLowering::generateRewriter(
744 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
745 function_ref<Value(Value)> mapRewriteValue) {
746 SmallVector<Value, 4> operands;
747 for (Value operand : operationOp.getOperandValues())
748 operands.push_back(mapRewriteValue(operand));
749
750 SmallVector<Value, 4> attributes;
751 for (Value attr : operationOp.getAttributeValues())
752 attributes.push_back(mapRewriteValue(attr));
753
754 bool hasInferredResultTypes = false;
755 SmallVector<Value, 2> types;
756 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
757 rewriteValues, hasInferredResultTypes);
758
759 // Create the new operation.
760 Location loc = operationOp.getLoc();
761 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
762 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
763 attributes, operationOp.getAttributeValueNames());
764 rewriteValues[operationOp.getOp()] = createdOp;
765
766 // Generate accesses for any results that have their types constrained.
767 // Handle the case where there is a single range representing all of the
768 // result types.
769 OperandRange resultTys = operationOp.getTypeValues();
770 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
771 Value &type = rewriteValues[resultTys[0]];
772 if (!type) {
773 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
774 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
775 }
776 return;
777 }
778
779 // Otherwise, populate the individual results.
780 bool seenVariableLength = false;
781 Type valueTy = builder.getType<pdl::ValueType>();
782 Type valueRangeTy = pdl::RangeType::get(valueTy);
783 for (const auto &it : llvm::enumerate(resultTys)) {
784 Value &type = rewriteValues[it.value()];
785 if (type)
786 continue;
787 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
788 seenVariableLength |= isVariadic;
789
790 // After a variable length result has been seen, we need to use result
791 // groups because the exact index of the result is not statically known.
792 Value resultVal;
793 if (seenVariableLength)
794 resultVal = builder.create<pdl_interp::GetResultsOp>(
795 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
796 else
797 resultVal = builder.create<pdl_interp::GetResultOp>(
798 loc, valueTy, createdOp, it.index());
799 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
800 }
801}
802
803void PatternLowering::generateRewriter(
804 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
805 function_ref<Value(Value)> mapRewriteValue) {
806 SmallVector<Value, 4> replOperands;
807 for (Value operand : rangeOp.getArguments())
808 replOperands.push_back(mapRewriteValue(operand));
809 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
810 rangeOp.getLoc(), rangeOp.getType(), replOperands);
811}
812
813void PatternLowering::generateRewriter(
814 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
815 function_ref<Value(Value)> mapRewriteValue) {
816 SmallVector<Value, 4> replOperands;
817
818 // If the replacement was another operation, get its results. `pdl` allows
819 // for using an operation for simplicitly, but the interpreter isn't as
820 // user facing.
821 if (Value replOp = replaceOp.getReplOperation()) {
822 // Don't use replace if we know the replaced operation has no results.
823 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
824 if (!opOp || !opOp.getTypeValues().empty()) {
825 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
826 replOp.getLoc(), mapRewriteValue(replOp)));
827 }
828 } else {
829 for (Value operand : replaceOp.getReplValues())
830 replOperands.push_back(mapRewriteValue(operand));
831 }
832
833 // If there are no replacement values, just create an erase instead.
834 if (replOperands.empty()) {
835 builder.create<pdl_interp::EraseOp>(
836 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
837 return;
838 }
839
840 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
841 mapRewriteValue(replaceOp.getOpValue()),
842 replOperands);
843}
844
845void PatternLowering::generateRewriter(
846 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
847 function_ref<Value(Value)> mapRewriteValue) {
848 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
849 resultOp.getLoc(), builder.getType<pdl::ValueType>(),
850 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
851}
852
853void PatternLowering::generateRewriter(
854 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
855 function_ref<Value(Value)> mapRewriteValue) {
856 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
857 resultOp.getLoc(), resultOp.getType(),
858 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
859}
860
861void PatternLowering::generateRewriter(
862 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
863 function_ref<Value(Value)> mapRewriteValue) {
864 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
865 // type.
866 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
867 rewriteValues[typeOp] =
868 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
869 }
870}
871
872void PatternLowering::generateRewriter(
873 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
874 function_ref<Value(Value)> mapRewriteValue) {
875 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
876 // type.
877 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
878 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
879 typeOp.getLoc(), typeOp.getType(), typeAttr);
880 }
881}
882
883void PatternLowering::generateOperationResultTypeRewriter(
884 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
885 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
886 bool &hasInferredResultTypes) {
887 Block *rewriterBlock = op->getBlock();
888
889 // Try to handle resolution for each of the result types individually. This is
890 // preferred over type inferrence because it will allow for us to use existing
891 // types directly, as opposed to trying to rebuild the type list.
892 OperandRange resultTypeValues = op.getTypeValues();
893 auto tryResolveResultTypes = [&] {
894 types.reserve(N: resultTypeValues.size());
895 for (const auto &it : llvm::enumerate(resultTypeValues)) {
896 Value resultType = it.value();
897
898 // Check for an already translated value.
899 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
900 types.push_back(existingRewriteValue);
901 continue;
902 }
903
904 // Check for an input from the matcher.
905 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
906 types.push_back(mapRewriteValue(resultType));
907 continue;
908 }
909
910 // Otherwise, we couldn't infer the result types. Bail out here to see if
911 // we can infer the types for this operation from another way.
912 types.clear();
913 return failure();
914 }
915 return success();
916 };
917 if (!resultTypeValues.empty() && succeeded(result: tryResolveResultTypes()))
918 return;
919
920 // Otherwise, check if the operation has type inference support itself.
921 if (op.hasTypeInference()) {
922 hasInferredResultTypes = true;
923 return;
924 }
925
926 // Look for an operation that was replaced by `op`. The result types will be
927 // inferred from the results that were replaced.
928 for (OpOperand &use : op.getOp().getUses()) {
929 // Check that the use corresponds to a ReplaceOp and that it is the
930 // replacement value, not the operation being replaced.
931 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
932 if (!replOpUser || use.getOperandNumber() == 0)
933 continue;
934 // Make sure the replaced operation was defined before this one. PDL
935 // rewrites only have single block regions, so if the op isn't in the
936 // rewriter block (i.e. the current block of the operation) we already know
937 // it dominates (i.e. it's in the matcher).
938 Value replOpVal = replOpUser.getOpValue();
939 Operation *replacedOp = replOpVal.getDefiningOp();
940 if (replacedOp->getBlock() == rewriterBlock &&
941 !replacedOp->isBeforeInBlock(op))
942 continue;
943
944 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
945 replacedOp->getLoc(), mapRewriteValue(replOpVal));
946 types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
947 replacedOp->getLoc(), replacedOpResults));
948 return;
949 }
950
951 // If the types could not be inferred from any context and there weren't any
952 // explicit result types, assume the user actually meant for the operation to
953 // have no results.
954 if (resultTypeValues.empty())
955 return;
956
957 // The verifier asserts that the result types of each pdl.getOperation can be
958 // inferred. If we reach here, there is a bug either in the logic above or
959 // in the verifier for pdl.getOperation.
960 op->emitOpError() << "unable to infer result type for operation";
961 llvm_unreachable("unable to infer result type for operation");
962}
963
964//===----------------------------------------------------------------------===//
965// Conversion Pass
966//===----------------------------------------------------------------------===//
967
968namespace {
969struct PDLToPDLInterpPass
970 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
971 PDLToPDLInterpPass() = default;
972 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
973 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
974 : configMap(&configMap) {}
975 void runOnOperation() final;
976
977 /// A map containing the configuration for each pattern.
978 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
979};
980} // namespace
981
982/// Convert the given module containing PDL pattern operations into a PDL
983/// Interpreter operations.
984void PDLToPDLInterpPass::runOnOperation() {
985 ModuleOp module = getOperation();
986
987 // Create the main matcher function This function contains all of the match
988 // related functionality from patterns in the module.
989 OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody());
990 auto matcherFunc = builder.create<pdl_interp::FuncOp>(
991 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
992 builder.getFunctionType(builder.getType<pdl::OperationType>(),
993 /*results=*/std::nullopt),
994 /*attrs=*/std::nullopt);
995
996 // Create a nested module to hold the functions invoked for rewriting the IR
997 // after a successful match.
998 ModuleOp rewriterModule = builder.create<ModuleOp>(
999 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
1000
1001 // Generate the code for the patterns within the module.
1002 PatternLowering generator(matcherFunc, rewriterModule, configMap);
1003 generator.lower(module: module);
1004
1005 // After generation, delete all of the pattern operations.
1006 for (pdl::PatternOp pattern :
1007 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1008 // Drop the now dead config mappings.
1009 if (configMap)
1010 configMap->erase(pattern);
1011
1012 pattern.erase();
1013 }
1014}
1015
1016std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
1017 return std::make_unique<PDLToPDLInterpPass>();
1018}
1019std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
1020 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
1021 return std::make_unique<PDLToPDLInterpPass>(args&: configMap);
1022}
1023

// Source: mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp