1//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10
11#include "PredicateTree.h"
12#include "mlir/Dialect/PDL/IR/PDLTypes.h"
13#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
14#include "mlir/Pass/Pass.h"
15#include "llvm/ADT/MapVector.h"
16#include "llvm/ADT/ScopedHashTable.h"
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/TypeSwitch.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::pdl_to_pdl_interp;
28
29//===----------------------------------------------------------------------===//
30// PatternLowering
31//===----------------------------------------------------------------------===//
32
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations. Match logic is emitted into
/// a single matcher function, and the rewrite logic of each pattern is emitted
/// as a separate function inside a rewriter module.
struct PatternLowering {
public:
  /// Create a lowering that emits matcher code into `matcherFunc` and rewriter
  /// functions into `rewriterModule`. `configMap` may be null; when it is not,
  /// every generated pdl_interp.record_match is mapped to the configuration
  /// set of the pattern it records.
  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
                  DenseMap<Operation *, PDLPatternConfigSet *> *configMap);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, in the specified region. If `block` is null, a new block is
  /// appended to `region`. Returns the block holding the start of the code
  /// generated for `node`.
  Block *generateMatcher(MatcherNode &node, Region &region,
                         Block *block = nullptr);

  /// Get or create an access to the provided positional value in the current
  /// block. This operation may mutate the provided block pointer if nested
  /// regions (i.e., pdl_interp.iterate) are required. Created values are
  /// cached in `values`, so a position is materialized at most once per scope.
  Value getValueAt(Block *&currentBlock, Position *pos);

  /// Create the interpreter predicate operations. This operation may mutate the
  /// provided current block pointer if nested regions (iterates) are required.
  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);

  /// Create the interpreter switch / predicate operations, with several case
  /// destinations. This operation never mutates the provided current block
  /// pointer, because the switch operation does not need Values beyond `val`.
  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);

  /// Create the interpreter operations to record a successful pattern match
  /// using the contained root operation. This operation may mutate the current
  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
  void generate(SuccessNode *successNode, Block *&currentBlock);

  /// Generate a rewriter function for the given pattern operation, and returns
  /// a reference to that function. The positions of match values consumed by
  /// the rewriter are appended to `usedMatchValues`, one per argument of the
  /// generated function.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. `rewriteValues` maps
  /// PDL values to the interpreter values that replace them, and
  /// `mapRewriteValue` resolves (and, if needed, materializes) such a mapping.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::RangeOp rangeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region. If the result types of the operation
  /// should be inferred, `hasInferredResultTypes` is set to true.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
      SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
      bool &hasInferredResultTypes);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  pdl_interp::FuncOp matcherFunc;

  /// The rewriter module containing all of the rewrite related logic within
  /// PDL patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value. A new scope is pushed for every matcher node being generated.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path. The top of the stack is the failure
  /// destination of the code currently being generated; pdl_interp.foreach
  /// loops push their continuation block here.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value. It is handed to MatcherNode::generateMatcherTree and
  /// queried when a rewriter needs an input value from the match.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations.
  SetVector<Value> locOps;

  /// A mapping between pattern operations and the corresponding configuration
  /// set. May be null, in which case no configuration is propagated.
  DenseMap<Operation *, PDLPatternConfigSet *> *configMap;

  /// A mapping from a constraint question to the ApplyConstraintOp
  /// that implements it. Filled in when a constraint predicate is generated
  /// and read when a ConstraintResultPos value is requested.
  DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
} // namespace
156
157PatternLowering::PatternLowering(
158 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
159 DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
160 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
161 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162 configMap(configMap) {}
163
164void PatternLowering::lower(ModuleOp module) {
165 PredicateUniquer predicateUniquer;
166 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
167
168 // Define top-level scope for the arguments to the matcher function.
169 ValueMapScope topLevelValueScope(values);
170
171 // Insert the root operation, i.e. argument to the matcher, at the root
172 // position.
173 Block *matcherEntryBlock = &matcherFunc.front();
174 values.insert(Key: predicateBuilder.getRoot(), Val: matcherEntryBlock->getArgument(i: 0));
175
176 // Generate a root matcher node from the provided PDL module.
177 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
178 module, builder&: predicateBuilder, valueToPosition);
179 Block *firstMatcherBlock = generateMatcher(node&: *root, region&: matcherFunc.getBody());
180 assert(failureBlockStack.empty() && "failed to empty the stack");
181
182 // After generation, merged the first matched block into the entry.
183 matcherEntryBlock->getOperations().splice(where: matcherEntryBlock->end(),
184 L2&: firstMatcherBlock->getOperations());
185 firstMatcherBlock->erase();
186}
187
188Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
189 Block *block) {
190 // Push a new scope for the values used by this matcher.
191 if (!block)
192 block = &region.emplaceBlock();
193 ValueMapScope scope(values);
194
195 // If this is the return node, simply insert the corresponding interpreter
196 // finalize.
197 if (isa<ExitNode>(Val: node)) {
198 builder.setInsertionPointToEnd(block);
199 builder.create<pdl_interp::FinalizeOp>(location: matcherFunc.getLoc());
200 return block;
201 }
202
203 // Get the next block in the match sequence.
204 // This is intentionally executed first, before we get the value for the
205 // position associated with the node, so that we preserve an "there exist"
206 // semantics: if getting a value requires an upward traversal (going from a
207 // value to its consumers), we want to perform the check on all the consumers
208 // before we pass control to the failure node.
209 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
210 Block *failureBlock;
211 if (failureNode) {
212 failureBlock = generateMatcher(node&: *failureNode, region);
213 failureBlockStack.push_back(Elt: failureBlock);
214 } else {
215 assert(!failureBlockStack.empty() && "expected valid failure block");
216 failureBlock = failureBlockStack.back();
217 }
218
219 // If this node contains a position, get the corresponding value for this
220 // block.
221 Block *currentBlock = block;
222 Position *position = node.getPosition();
223 Value val = position ? getValueAt(currentBlock, pos: position) : Value();
224
225 // If this value corresponds to an operation, record that we are going to use
226 // its location as part of a fused location.
227 bool isOperationValue = val && isa<pdl::OperationType>(Val: val.getType());
228 if (isOperationValue)
229 locOps.insert(X: val);
230
231 // Dispatch to the correct method based on derived node type.
232 TypeSwitch<MatcherNode *>(&node)
233 .Case<BoolNode, SwitchNode>(caseFn: [&](auto *derivedNode) {
234 this->generate(derivedNode, currentBlock, val);
235 })
236 .Case(caseFn: [&](SuccessNode *successNode) {
237 generate(successNode, currentBlock);
238 });
239
240 // Pop all the failure blocks that were inserted due to nesting of
241 // pdl_interp.iterate.
242 while (failureBlockStack.back() != failureBlock) {
243 failureBlockStack.pop_back();
244 assert(!failureBlockStack.empty() && "unable to locate failure block");
245 }
246
247 // Pop the new failure block.
248 if (failureNode)
249 failureBlockStack.pop_back();
250
251 if (isOperationValue)
252 locOps.remove(X: val);
253
254 return block;
255}
256
257Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
258 if (Value val = values.lookup(Key: pos))
259 return val;
260
261 // Get the value for the parent position.
262 Value parentVal;
263 if (Position *parent = pos->getParent())
264 parentVal = getValueAt(currentBlock, pos: parent);
265
266 // TODO: Use a location from the position.
267 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
268 builder.setInsertionPointToEnd(currentBlock);
269 Value value;
270 switch (pos->getKind()) {
271 case Predicates::OperationPos: {
272 auto *operationPos = cast<OperationPosition>(Val: pos);
273 if (operationPos->isOperandDefiningOp())
274 // Standard (downward) traversal which directly follows the defining op.
275 value = builder.create<pdl_interp::GetDefiningOpOp>(
276 location: loc, args: builder.getType<pdl::OperationType>(), args&: parentVal);
277 else
278 // A passthrough operation position.
279 value = parentVal;
280 break;
281 }
282 case Predicates::UsersPos: {
283 auto *usersPos = cast<UsersPosition>(Val: pos);
284
285 // The first operation retrieves the representative value of a range.
286 // This applies only when the parent is a range of values and we were
287 // requested to use a representative value (e.g., upward traversal).
288 if (isa<pdl::RangeType>(Val: parentVal.getType()) &&
289 usersPos->useRepresentative())
290 value = builder.create<pdl_interp::ExtractOp>(location: loc, args&: parentVal, args: 0);
291 else
292 value = parentVal;
293
294 // The second operation retrieves the users.
295 value = builder.create<pdl_interp::GetUsersOp>(location: loc, args&: value);
296 break;
297 }
298 case Predicates::ForEachPos: {
299 assert(!failureBlockStack.empty() && "expected valid failure block");
300 auto foreach = builder.create<pdl_interp::ForEachOp>(
301 location: loc, args&: parentVal, args&: failureBlockStack.back(), /*initLoop=*/args: true);
302 value = foreach.getLoopVariable();
303
304 // Create the continuation block.
305 Block *continueBlock = builder.createBlock(parent: &foreach.getRegion());
306 builder.create<pdl_interp::ContinueOp>(location: loc);
307 failureBlockStack.push_back(Elt: continueBlock);
308
309 currentBlock = &foreach.getRegion().front();
310 break;
311 }
312 case Predicates::OperandPos: {
313 auto *operandPos = cast<OperandPosition>(Val: pos);
314 value = builder.create<pdl_interp::GetOperandOp>(
315 location: loc, args: builder.getType<pdl::ValueType>(), args&: parentVal,
316 args: operandPos->getOperandNumber());
317 break;
318 }
319 case Predicates::OperandGroupPos: {
320 auto *operandPos = cast<OperandGroupPosition>(Val: pos);
321 Type valueTy = builder.getType<pdl::ValueType>();
322 value = builder.create<pdl_interp::GetOperandsOp>(
323 location: loc, args: operandPos->isVariadic() ? pdl::RangeType::get(elementType: valueTy) : valueTy,
324 args&: parentVal, args: operandPos->getOperandGroupNumber());
325 break;
326 }
327 case Predicates::AttributePos: {
328 auto *attrPos = cast<AttributePosition>(Val: pos);
329 value = builder.create<pdl_interp::GetAttributeOp>(
330 location: loc, args: builder.getType<pdl::AttributeType>(), args&: parentVal,
331 args: attrPos->getName().strref());
332 break;
333 }
334 case Predicates::TypePos: {
335 if (isa<pdl::AttributeType>(Val: parentVal.getType()))
336 value = builder.create<pdl_interp::GetAttributeTypeOp>(location: loc, args&: parentVal);
337 else
338 value = builder.create<pdl_interp::GetValueTypeOp>(location: loc, args&: parentVal);
339 break;
340 }
341 case Predicates::ResultPos: {
342 auto *resPos = cast<ResultPosition>(Val: pos);
343 value = builder.create<pdl_interp::GetResultOp>(
344 location: loc, args: builder.getType<pdl::ValueType>(), args&: parentVal,
345 args: resPos->getResultNumber());
346 break;
347 }
348 case Predicates::ResultGroupPos: {
349 auto *resPos = cast<ResultGroupPosition>(Val: pos);
350 Type valueTy = builder.getType<pdl::ValueType>();
351 value = builder.create<pdl_interp::GetResultsOp>(
352 location: loc, args: resPos->isVariadic() ? pdl::RangeType::get(elementType: valueTy) : valueTy,
353 args&: parentVal, args: resPos->getResultGroupNumber());
354 break;
355 }
356 case Predicates::AttributeLiteralPos: {
357 auto *attrPos = cast<AttributeLiteralPosition>(Val: pos);
358 value =
359 builder.create<pdl_interp::CreateAttributeOp>(location: loc, args: attrPos->getValue());
360 break;
361 }
362 case Predicates::TypeLiteralPos: {
363 auto *typePos = cast<TypeLiteralPosition>(Val: pos);
364 Attribute rawTypeAttr = typePos->getValue();
365 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(Val&: rawTypeAttr))
366 value = builder.create<pdl_interp::CreateTypeOp>(location: loc, args&: typeAttr);
367 else
368 value = builder.create<pdl_interp::CreateTypesOp>(
369 location: loc, args: cast<ArrayAttr>(Val&: rawTypeAttr));
370 break;
371 }
372 case Predicates::ConstraintResultPos: {
373 // Due to the order of traversal, the ApplyConstraintOp has already been
374 // created and we can find it in constraintOpMap.
375 auto *constrResPos = cast<ConstraintPosition>(Val: pos);
376 auto i = constraintOpMap.find(Val: constrResPos->getQuestion());
377 assert(i != constraintOpMap.end());
378 value = i->second->getResult(idx: constrResPos->getIndex());
379 break;
380 }
381 default:
382 llvm_unreachable("Generating unknown Position getter");
383 break;
384 }
385
386 values.insert(Key: pos, Val: value);
387 return value;
388}
389
390void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
391 Value val) {
392 Location loc = val.getLoc();
393 Qualifier *question = boolNode->getQuestion();
394 Qualifier *answer = boolNode->getAnswer();
395 Region *region = currentBlock->getParent();
396
397 // Execute the getValue queries first, so that we create success
398 // matcher in the correct (possibly nested) region.
399 SmallVector<Value> args;
400 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(Val: question)) {
401 args = {getValueAt(currentBlock, pos: equalToQuestion->getValue())};
402 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(Val: question)) {
403 for (Position *position : cstQuestion->getArgs())
404 args.push_back(Elt: getValueAt(currentBlock, pos: position));
405 }
406
407 // Generate a new block as success successor and get the failure successor.
408 Block *success = &region->emplaceBlock();
409 Block *failure = failureBlockStack.back();
410
411 // Create the predicate.
412 builder.setInsertionPointToEnd(currentBlock);
413 Predicates::Kind kind = question->getKind();
414 switch (kind) {
415 case Predicates::IsNotNullQuestion:
416 builder.create<pdl_interp::IsNotNullOp>(location: loc, args&: val, args&: success, args&: failure);
417 break;
418 case Predicates::OperationNameQuestion: {
419 auto *opNameAnswer = cast<OperationNameAnswer>(Val: answer);
420 builder.create<pdl_interp::CheckOperationNameOp>(
421 location: loc, args&: val, args: opNameAnswer->getValue().getStringRef(), args&: success, args&: failure);
422 break;
423 }
424 case Predicates::TypeQuestion: {
425 auto *ans = cast<TypeAnswer>(Val: answer);
426 if (isa<pdl::RangeType>(Val: val.getType()))
427 builder.create<pdl_interp::CheckTypesOp>(
428 location: loc, args&: val, args: llvm::cast<ArrayAttr>(Val: ans->getValue()), args&: success, args&: failure);
429 else
430 builder.create<pdl_interp::CheckTypeOp>(
431 location: loc, args&: val, args: llvm::cast<TypeAttr>(Val: ans->getValue()), args&: success, args&: failure);
432 break;
433 }
434 case Predicates::AttributeQuestion: {
435 auto *ans = cast<AttributeAnswer>(Val: answer);
436 builder.create<pdl_interp::CheckAttributeOp>(location: loc, args&: val, args: ans->getValue(),
437 args&: success, args&: failure);
438 break;
439 }
440 case Predicates::OperandCountAtLeastQuestion:
441 case Predicates::OperandCountQuestion:
442 builder.create<pdl_interp::CheckOperandCountOp>(
443 location: loc, args&: val, args: cast<UnsignedAnswer>(Val: answer)->getValue(),
444 /*compareAtLeast=*/args: kind == Predicates::OperandCountAtLeastQuestion,
445 args&: success, args&: failure);
446 break;
447 case Predicates::ResultCountAtLeastQuestion:
448 case Predicates::ResultCountQuestion:
449 builder.create<pdl_interp::CheckResultCountOp>(
450 location: loc, args&: val, args: cast<UnsignedAnswer>(Val: answer)->getValue(),
451 /*compareAtLeast=*/args: kind == Predicates::ResultCountAtLeastQuestion,
452 args&: success, args&: failure);
453 break;
454 case Predicates::EqualToQuestion: {
455 bool trueAnswer = isa<TrueAnswer>(Val: answer);
456 builder.create<pdl_interp::AreEqualOp>(location: loc, args&: val, args&: args.front(),
457 args&: trueAnswer ? success : failure,
458 args&: trueAnswer ? failure : success);
459 break;
460 }
461 case Predicates::ConstraintQuestion: {
462 auto *cstQuestion = cast<ConstraintQuestion>(Val: question);
463 auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
464 location: loc, args: cstQuestion->getResultTypes(), args: cstQuestion->getName(), args,
465 args: cstQuestion->getIsNegated(), args&: success, args&: failure);
466
467 constraintOpMap.insert(KV: {cstQuestion, applyConstraintOp});
468 break;
469 }
470 default:
471 llvm_unreachable("Generating unknown Predicate operation");
472 }
473
474 // Generate the matcher in the current (potentially nested) region.
475 // This might use the results of the current predicate.
476 generateMatcher(node&: *boolNode->getSuccessNode(), region&: *region, block: success);
477}
478
479template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
480static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
481 llvm::MapVector<Qualifier *, Block *> &dests) {
482 std::vector<ValT> values;
483 std::vector<Block *> blocks;
484 values.reserve(dests.size());
485 blocks.reserve(n: dests.size());
486 for (const auto &it : dests) {
487 blocks.push_back(x: it.second);
488 values.push_back(cast<PredT>(it.first)->getValue());
489 }
490 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
491}
492
493void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
494 Value val) {
495 Qualifier *question = switchNode->getQuestion();
496 Region *region = currentBlock->getParent();
497 Block *defaultDest = failureBlockStack.back();
498
499 // If the switch question is not an exact answer, i.e. for the `at_least`
500 // cases, we generate a special block sequence.
501 Predicates::Kind kind = question->getKind();
502 if (kind == Predicates::OperandCountAtLeastQuestion ||
503 kind == Predicates::ResultCountAtLeastQuestion) {
504 // Order the children such that the cases are in reverse numerical order.
505 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
506 Range: llvm::seq<unsigned>(Begin: 0, End: switchNode->getChildren().size()));
507 llvm::sort(C&: sortedChildren, Comp: [&](unsigned lhs, unsigned rhs) {
508 return cast<UnsignedAnswer>(Val: switchNode->getChild(i: lhs).first)->getValue() >
509 cast<UnsignedAnswer>(Val: switchNode->getChild(i: rhs).first)->getValue();
510 });
511
512 // Build the destination for each child using the next highest child as a
513 // a failure destination. This essentially creates the following control
514 // flow:
515 //
516 // if (operand_count < 1)
517 // goto failure
518 // if (child1.match())
519 // ...
520 //
521 // if (operand_count < 2)
522 // goto failure
523 // if (child2.match())
524 // ...
525 //
526 // failure:
527 // ...
528 //
529 failureBlockStack.push_back(Elt: defaultDest);
530 Location loc = val.getLoc();
531 for (unsigned idx : sortedChildren) {
532 auto &child = switchNode->getChild(i: idx);
533 Block *childBlock = generateMatcher(node&: *child.second, region&: *region);
534 Block *predicateBlock = builder.createBlock(insertBefore: childBlock);
535 builder.setInsertionPointToEnd(predicateBlock);
536 unsigned ans = cast<UnsignedAnswer>(Val: child.first)->getValue();
537 switch (kind) {
538 case Predicates::OperandCountAtLeastQuestion:
539 builder.create<pdl_interp::CheckOperandCountOp>(
540 location: loc, args&: val, args&: ans, /*compareAtLeast=*/args: true, args&: childBlock, args&: defaultDest);
541 break;
542 case Predicates::ResultCountAtLeastQuestion:
543 builder.create<pdl_interp::CheckResultCountOp>(
544 location: loc, args&: val, args&: ans, /*compareAtLeast=*/args: true, args&: childBlock, args&: defaultDest);
545 break;
546 default:
547 llvm_unreachable("Generating invalid AtLeast operation");
548 }
549 failureBlockStack.back() = predicateBlock;
550 }
551 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
552 currentBlock->getOperations().splice(where: currentBlock->end(),
553 L2&: firstPredicateBlock->getOperations());
554 firstPredicateBlock->erase();
555 return;
556 }
557
558 // Otherwise, generate each of the children and generate an interpreter
559 // switch.
560 llvm::MapVector<Qualifier *, Block *> children;
561 for (auto &it : switchNode->getChildren())
562 children.insert(KV: {it.first, generateMatcher(node&: *it.second, region&: *region)});
563 builder.setInsertionPointToEnd(currentBlock);
564
565 switch (question->getKind()) {
566 case Predicates::OperandCountQuestion:
567 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
568 int32_t>(val, defaultDest, builder, dests&: children);
569 case Predicates::ResultCountQuestion:
570 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
571 int32_t>(val, defaultDest, builder, dests&: children);
572 case Predicates::OperationNameQuestion:
573 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
574 OperationNameAnswer>(val, defaultDest, builder,
575 dests&: children);
576 case Predicates::TypeQuestion:
577 if (isa<pdl::RangeType>(Val: val.getType())) {
578 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
579 val, defaultDest, builder, dests&: children);
580 }
581 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
582 val, defaultDest, builder, dests&: children);
583 case Predicates::AttributeQuestion:
584 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
585 val, defaultDest, builder, dests&: children);
586 default:
587 llvm_unreachable("Generating unknown switch predicate.");
588 }
589}
590
591void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
592 pdl::PatternOp pattern = successNode->getPattern();
593 Value root = successNode->getRoot();
594
595 // Generate a rewriter for the pattern this success node represents, and track
596 // any values used from the match region.
597 SmallVector<Position *, 8> usedMatchValues;
598 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
599
600 // Process any values used in the rewrite that are defined in the match.
601 std::vector<Value> mappedMatchValues;
602 mappedMatchValues.reserve(n: usedMatchValues.size());
603 for (Position *position : usedMatchValues)
604 mappedMatchValues.push_back(x: getValueAt(currentBlock, pos: position));
605
606 // Collect the set of operations generated by the rewriter.
607 SmallVector<StringRef, 4> generatedOps;
608 for (auto op :
609 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
610 generatedOps.push_back(Elt: *op.getOpName());
611 ArrayAttr generatedOpsAttr;
612 if (!generatedOps.empty())
613 generatedOpsAttr = builder.getStrArrayAttr(values: generatedOps);
614
615 // Grab the root kind if present.
616 StringAttr rootKindAttr;
617 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
618 if (std::optional<StringRef> rootKind = rootOp.getOpName())
619 rootKindAttr = builder.getStringAttr(bytes: *rootKind);
620
621 builder.setInsertionPointToEnd(currentBlock);
622 auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
623 location: pattern.getLoc(), args&: mappedMatchValues, args: locOps.getArrayRef(),
624 args&: rewriterFuncRef, args&: rootKindAttr, args&: generatedOpsAttr, args: pattern.getBenefitAttr(),
625 args&: failureBlockStack.back());
626
627 // Set the config of the lowered match to the parent pattern.
628 if (configMap)
629 configMap->try_emplace(Key: matchOp, Args: configMap->lookup(Val: pattern));
630}
631
632SymbolRefAttr PatternLowering::generateRewriter(
633 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
634 builder.setInsertionPointToEnd(rewriterModule.getBody());
635 auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
636 location: pattern.getLoc(), args: "pdl_generated_rewriter",
637 args: builder.getFunctionType(inputs: {}, results: {}));
638 rewriterSymbolTable.insert(symbol: rewriterFunc);
639
640 // Generate the rewriter function body.
641 builder.setInsertionPointToEnd(&rewriterFunc.front());
642
643 // Map an input operand of the pattern to a generated interpreter value.
644 DenseMap<Value, Value> rewriteValues;
645 auto mapRewriteValue = [&](Value oldValue) {
646 Value &newValue = rewriteValues[oldValue];
647 if (newValue)
648 return newValue;
649
650 // Prefer materializing constants directly when possible.
651 Operation *oldOp = oldValue.getDefiningOp();
652 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(Val: oldOp)) {
653 if (Attribute value = attrOp.getValueAttr()) {
654 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
655 location: attrOp.getLoc(), args&: value);
656 }
657 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(Val: oldOp)) {
658 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
659 return newValue = builder.create<pdl_interp::CreateTypeOp>(
660 location: typeOp.getLoc(), args&: type);
661 }
662 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(Val: oldOp)) {
663 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
664 return newValue = builder.create<pdl_interp::CreateTypesOp>(
665 location: typeOp.getLoc(), args: typeOp.getType(), args&: type);
666 }
667 }
668
669 // Otherwise, add this as an input to the rewriter.
670 Position *inputPos = valueToPosition.lookup(Val: oldValue);
671 assert(inputPos && "expected value to be a pattern input");
672 usedMatchValues.push_back(Elt: inputPos);
673 return newValue = rewriterFunc.front().addArgument(type: oldValue.getType(),
674 loc: oldValue.getLoc());
675 };
676
677 // If this is a custom rewriter, simply dispatch to the registered rewrite
678 // method.
679 pdl::RewriteOp rewriter = pattern.getRewriter();
680 if (StringAttr rewriteName = rewriter.getNameAttr()) {
681 SmallVector<Value> args;
682 if (rewriter.getRoot())
683 args.push_back(Elt: mapRewriteValue(rewriter.getRoot()));
684 auto mappedArgs =
685 llvm::map_range(C: rewriter.getExternalArgs(), F: mapRewriteValue);
686 args.append(in_start: mappedArgs.begin(), in_end: mappedArgs.end());
687 builder.create<pdl_interp::ApplyRewriteOp>(
688 location: rewriter.getLoc(), /*resultTypes=*/args: TypeRange(), args&: rewriteName, args);
689 } else {
690 // Otherwise this is a dag rewriter defined using PDL operations.
691 for (Operation &rewriteOp : *rewriter.getBody()) {
692 llvm::TypeSwitch<Operation *>(&rewriteOp)
693 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
694 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
695 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>(caseFn: [&](auto op) {
696 this->generateRewriter(op, rewriteValues, mapRewriteValue);
697 });
698 }
699 }
700
701 // Update the signature of the rewrite function.
702 rewriterFunc.setType(builder.getFunctionType(
703 inputs: llvm::to_vector<8>(Range: rewriterFunc.front().getArgumentTypes()),
704 /*results=*/{}));
705
706 builder.create<pdl_interp::FinalizeOp>(location: rewriter.getLoc());
707 return SymbolRefAttr::get(
708 ctx: builder.getContext(),
709 value: pdl_interp::PDLInterpDialect::getRewriterModuleName(),
710 nestedRefs: SymbolRefAttr::get(symbol: rewriterFunc));
711}
712
713void PatternLowering::generateRewriter(
714 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
715 function_ref<Value(Value)> mapRewriteValue) {
716 SmallVector<Value, 2> arguments;
717 for (Value argument : rewriteOp.getArgs())
718 arguments.push_back(Elt: mapRewriteValue(argument));
719 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
720 location: rewriteOp.getLoc(), args: rewriteOp.getResultTypes(), args: rewriteOp.getNameAttr(),
721 args&: arguments);
722 for (auto it : llvm::zip(t: rewriteOp.getResults(), u: interpOp.getResults()))
723 rewriteValues[std::get<0>(t&: it)] = std::get<1>(t&: it);
724}
725
726void PatternLowering::generateRewriter(
727 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
728 function_ref<Value(Value)> mapRewriteValue) {
729 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
730 location: attrOp.getLoc(), args: attrOp.getValueAttr());
731 rewriteValues[attrOp] = newAttr;
732}
733
734void PatternLowering::generateRewriter(
735 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
736 function_ref<Value(Value)> mapRewriteValue) {
737 builder.create<pdl_interp::EraseOp>(location: eraseOp.getLoc(),
738 args: mapRewriteValue(eraseOp.getOpValue()));
739}
740
741void PatternLowering::generateRewriter(
742 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
743 function_ref<Value(Value)> mapRewriteValue) {
744 SmallVector<Value, 4> operands;
745 for (Value operand : operationOp.getOperandValues())
746 operands.push_back(Elt: mapRewriteValue(operand));
747
748 SmallVector<Value, 4> attributes;
749 for (Value attr : operationOp.getAttributeValues())
750 attributes.push_back(Elt: mapRewriteValue(attr));
751
752 bool hasInferredResultTypes = false;
753 SmallVector<Value, 2> types;
754 generateOperationResultTypeRewriter(op: operationOp, mapRewriteValue, types,
755 rewriteValues, hasInferredResultTypes);
756
757 // Create the new operation.
758 Location loc = operationOp.getLoc();
759 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
760 location: loc, args: *operationOp.getOpName(), args&: types, args&: hasInferredResultTypes, args&: operands,
761 args&: attributes, args: operationOp.getAttributeValueNames());
762 rewriteValues[operationOp.getOp()] = createdOp;
763
764 // Generate accesses for any results that have their types constrained.
765 // Handle the case where there is a single range representing all of the
766 // result types.
767 OperandRange resultTys = operationOp.getTypeValues();
768 if (resultTys.size() == 1 && isa<pdl::RangeType>(Val: resultTys[0].getType())) {
769 Value &type = rewriteValues[resultTys[0]];
770 if (!type) {
771 auto results = builder.create<pdl_interp::GetResultsOp>(location: loc, args&: createdOp);
772 type = builder.create<pdl_interp::GetValueTypeOp>(location: loc, args&: results);
773 }
774 return;
775 }
776
777 // Otherwise, populate the individual results.
778 bool seenVariableLength = false;
779 Type valueTy = builder.getType<pdl::ValueType>();
780 Type valueRangeTy = pdl::RangeType::get(elementType: valueTy);
781 for (const auto &it : llvm::enumerate(First&: resultTys)) {
782 Value &type = rewriteValues[it.value()];
783 if (type)
784 continue;
785 bool isVariadic = isa<pdl::RangeType>(Val: it.value().getType());
786 seenVariableLength |= isVariadic;
787
788 // After a variable length result has been seen, we need to use result
789 // groups because the exact index of the result is not statically known.
790 Value resultVal;
791 if (seenVariableLength)
792 resultVal = builder.create<pdl_interp::GetResultsOp>(
793 location: loc, args&: isVariadic ? valueRangeTy : valueTy, args&: createdOp, args: it.index());
794 else
795 resultVal = builder.create<pdl_interp::GetResultOp>(
796 location: loc, args&: valueTy, args&: createdOp, args: it.index());
797 type = builder.create<pdl_interp::GetValueTypeOp>(location: loc, args&: resultVal);
798 }
799}
800
801void PatternLowering::generateRewriter(
802 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
803 function_ref<Value(Value)> mapRewriteValue) {
804 SmallVector<Value, 4> replOperands;
805 for (Value operand : rangeOp.getArguments())
806 replOperands.push_back(Elt: mapRewriteValue(operand));
807 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
808 location: rangeOp.getLoc(), args: rangeOp.getType(), args&: replOperands);
809}
810
811void PatternLowering::generateRewriter(
812 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
813 function_ref<Value(Value)> mapRewriteValue) {
814 SmallVector<Value, 4> replOperands;
815
816 // If the replacement was another operation, get its results. `pdl` allows
817 // for using an operation for simplicitly, but the interpreter isn't as
818 // user facing.
819 if (Value replOp = replaceOp.getReplOperation()) {
820 // Don't use replace if we know the replaced operation has no results.
821 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
822 if (!opOp || !opOp.getTypeValues().empty()) {
823 replOperands.push_back(Elt: builder.create<pdl_interp::GetResultsOp>(
824 location: replOp.getLoc(), args: mapRewriteValue(replOp)));
825 }
826 } else {
827 for (Value operand : replaceOp.getReplValues())
828 replOperands.push_back(Elt: mapRewriteValue(operand));
829 }
830
831 // If there are no replacement values, just create an erase instead.
832 if (replOperands.empty()) {
833 builder.create<pdl_interp::EraseOp>(
834 location: replaceOp.getLoc(), args: mapRewriteValue(replaceOp.getOpValue()));
835 return;
836 }
837
838 builder.create<pdl_interp::ReplaceOp>(location: replaceOp.getLoc(),
839 args: mapRewriteValue(replaceOp.getOpValue()),
840 args&: replOperands);
841}
842
843void PatternLowering::generateRewriter(
844 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
845 function_ref<Value(Value)> mapRewriteValue) {
846 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
847 location: resultOp.getLoc(), args: builder.getType<pdl::ValueType>(),
848 args: mapRewriteValue(resultOp.getParent()), args: resultOp.getIndex());
849}
850
851void PatternLowering::generateRewriter(
852 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
853 function_ref<Value(Value)> mapRewriteValue) {
854 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
855 location: resultOp.getLoc(), args: resultOp.getType(),
856 args: mapRewriteValue(resultOp.getParent()), args: resultOp.getIndex());
857}
858
859void PatternLowering::generateRewriter(
860 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
861 function_ref<Value(Value)> mapRewriteValue) {
862 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
863 // type.
864 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
865 rewriteValues[typeOp] =
866 builder.create<pdl_interp::CreateTypeOp>(location: typeOp.getLoc(), args&: typeAttr);
867 }
868}
869
870void PatternLowering::generateRewriter(
871 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
872 function_ref<Value(Value)> mapRewriteValue) {
873 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
874 // type.
875 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
876 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
877 location: typeOp.getLoc(), args: typeOp.getType(), args&: typeAttr);
878 }
879}
880
881void PatternLowering::generateOperationResultTypeRewriter(
882 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
883 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
884 bool &hasInferredResultTypes) {
885 Block *rewriterBlock = op->getBlock();
886
887 // Try to handle resolution for each of the result types individually. This is
888 // preferred over type inferrence because it will allow for us to use existing
889 // types directly, as opposed to trying to rebuild the type list.
890 OperandRange resultTypeValues = op.getTypeValues();
891 auto tryResolveResultTypes = [&] {
892 types.reserve(N: resultTypeValues.size());
893 for (const auto &it : llvm::enumerate(First&: resultTypeValues)) {
894 Value resultType = it.value();
895
896 // Check for an already translated value.
897 if (Value existingRewriteValue = rewriteValues.lookup(Val: resultType)) {
898 types.push_back(Elt: existingRewriteValue);
899 continue;
900 }
901
902 // Check for an input from the matcher.
903 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
904 types.push_back(Elt: mapRewriteValue(resultType));
905 continue;
906 }
907
908 // Otherwise, we couldn't infer the result types. Bail out here to see if
909 // we can infer the types for this operation from another way.
910 types.clear();
911 return failure();
912 }
913 return success();
914 };
915 if (!resultTypeValues.empty() && succeeded(Result: tryResolveResultTypes()))
916 return;
917
918 // Otherwise, check if the operation has type inference support itself.
919 if (op.hasTypeInference()) {
920 hasInferredResultTypes = true;
921 return;
922 }
923
924 // Look for an operation that was replaced by `op`. The result types will be
925 // inferred from the results that were replaced.
926 for (OpOperand &use : op.getOp().getUses()) {
927 // Check that the use corresponds to a ReplaceOp and that it is the
928 // replacement value, not the operation being replaced.
929 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(Val: use.getOwner());
930 if (!replOpUser || use.getOperandNumber() == 0)
931 continue;
932 // Make sure the replaced operation was defined before this one. PDL
933 // rewrites only have single block regions, so if the op isn't in the
934 // rewriter block (i.e. the current block of the operation) we already know
935 // it dominates (i.e. it's in the matcher).
936 Value replOpVal = replOpUser.getOpValue();
937 Operation *replacedOp = replOpVal.getDefiningOp();
938 if (replacedOp->getBlock() == rewriterBlock &&
939 !replacedOp->isBeforeInBlock(other: op))
940 continue;
941
942 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
943 location: replacedOp->getLoc(), args: mapRewriteValue(replOpVal));
944 types.push_back(Elt: builder.create<pdl_interp::GetValueTypeOp>(
945 location: replacedOp->getLoc(), args&: replacedOpResults));
946 return;
947 }
948
949 // If the types could not be inferred from any context and there weren't any
950 // explicit result types, assume the user actually meant for the operation to
951 // have no results.
952 if (resultTypeValues.empty())
953 return;
954
955 // The verifier asserts that the result types of each pdl.getOperation can be
956 // inferred. If we reach here, there is a bug either in the logic above or
957 // in the verifier for pdl.getOperation.
958 op->emitOpError() << "unable to infer result type for operation";
959 llvm_unreachable("unable to infer result type for operation");
960}
961
962//===----------------------------------------------------------------------===//
963// Conversion Pass
964//===----------------------------------------------------------------------===//
965
966namespace {
967struct PDLToPDLInterpPass
968 : public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
969 PDLToPDLInterpPass() = default;
970 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
971 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
972 : configMap(&configMap) {}
973 void runOnOperation() final;
974
975 /// A map containing the configuration for each pattern.
976 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
977};
978} // namespace
979
980/// Convert the given module containing PDL pattern operations into a PDL
981/// Interpreter operations.
982void PDLToPDLInterpPass::runOnOperation() {
983 ModuleOp module = getOperation();
984
985 // Create the main matcher function This function contains all of the match
986 // related functionality from patterns in the module.
987 OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody());
988 auto matcherFunc = builder.create<pdl_interp::FuncOp>(
989 location: module.getLoc(), args: pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
990 args: builder.getFunctionType(inputs: builder.getType<pdl::OperationType>(),
991 /*results=*/{}),
992 /*attrs=*/args: ArrayRef<NamedAttribute>());
993
994 // Create a nested module to hold the functions invoked for rewriting the IR
995 // after a successful match.
996 ModuleOp rewriterModule = builder.create<ModuleOp>(
997 location: module.getLoc(), args: pdl_interp::PDLInterpDialect::getRewriterModuleName());
998
999 // Generate the code for the patterns within the module.
1000 PatternLowering generator(matcherFunc, rewriterModule, configMap);
1001 generator.lower(module);
1002
1003 // After generation, delete all of the pattern operations.
1004 for (pdl::PatternOp pattern :
1005 llvm::make_early_inc_range(Range: module.getOps<pdl::PatternOp>())) {
1006 // Drop the now dead config mappings.
1007 if (configMap)
1008 configMap->erase(Val: pattern);
1009
1010 pattern.erase();
1011 }
1012}
1013
1014std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertPDLToPDLInterpPass(
1015 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
1016 return std::make_unique<PDLToPDLInterpPass>(args&: configMap);
1017}
1018

source code of mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp