1 | //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" |
10 | |
11 | #include "PredicateTree.h" |
12 | #include "mlir/Dialect/PDL/IR/PDL.h" |
13 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
14 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | #include "llvm/ADT/MapVector.h" |
17 | #include "llvm/ADT/ScopedHashTable.h" |
18 | #include "llvm/ADT/Sequence.h" |
19 | #include "llvm/ADT/SetVector.h" |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | |
23 | namespace mlir { |
24 | #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP |
25 | #include "mlir/Conversion/Passes.h.inc" |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::pdl_to_pdl_interp; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // PatternLowering |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | namespace { |
36 | /// This class generators operations within the PDL Interpreter dialect from a |
37 | /// given module containing PDL pattern operations. |
38 | struct PatternLowering { |
39 | public: |
40 | PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, |
41 | DenseMap<Operation *, PDLPatternConfigSet *> *configMap); |
42 | |
43 | /// Generate code for matching and rewriting based on the pattern operations |
44 | /// within the module. |
45 | void lower(ModuleOp module); |
46 | |
47 | private: |
48 | using ValueMap = llvm::ScopedHashTable<Position *, Value>; |
49 | using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; |
50 | |
51 | /// Generate interpreter operations for the tree rooted at the given matcher |
52 | /// node, in the specified region. |
53 | Block *generateMatcher(MatcherNode &node, Region ®ion, |
54 | Block *block = nullptr); |
55 | |
56 | /// Get or create an access to the provided positional value in the current |
57 | /// block. This operation may mutate the provided block pointer if nested |
58 | /// regions (i.e., pdl_interp.iterate) are required. |
59 | Value getValueAt(Block *¤tBlock, Position *pos); |
60 | |
61 | /// Create the interpreter predicate operations. This operation may mutate the |
62 | /// provided current block pointer if nested regions (iterates) are required. |
63 | void generate(BoolNode *boolNode, Block *¤tBlock, Value val); |
64 | |
65 | /// Create the interpreter switch / predicate operations, with several case |
66 | /// destinations. This operation never mutates the provided current block |
67 | /// pointer, because the switch operation does not need Values beyond `val`. |
68 | void generate(SwitchNode *switchNode, Block *currentBlock, Value val); |
69 | |
70 | /// Create the interpreter operations to record a successful pattern match |
71 | /// using the contained root operation. This operation may mutate the current |
72 | /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. |
73 | void generate(SuccessNode *successNode, Block *¤tBlock); |
74 | |
75 | /// Generate a rewriter function for the given pattern operation, and returns |
76 | /// a reference to that function. |
77 | SymbolRefAttr generateRewriter(pdl::PatternOp pattern, |
78 | SmallVectorImpl<Position *> &usedMatchValues); |
79 | |
80 | /// Generate the rewriter code for the given operation. |
81 | void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, |
82 | DenseMap<Value, Value> &rewriteValues, |
83 | function_ref<Value(Value)> mapRewriteValue); |
84 | void generateRewriter(pdl::AttributeOp attrOp, |
85 | DenseMap<Value, Value> &rewriteValues, |
86 | function_ref<Value(Value)> mapRewriteValue); |
87 | void generateRewriter(pdl::EraseOp eraseOp, |
88 | DenseMap<Value, Value> &rewriteValues, |
89 | function_ref<Value(Value)> mapRewriteValue); |
90 | void generateRewriter(pdl::OperationOp operationOp, |
91 | DenseMap<Value, Value> &rewriteValues, |
92 | function_ref<Value(Value)> mapRewriteValue); |
93 | void generateRewriter(pdl::RangeOp rangeOp, |
94 | DenseMap<Value, Value> &rewriteValues, |
95 | function_ref<Value(Value)> mapRewriteValue); |
96 | void generateRewriter(pdl::ReplaceOp replaceOp, |
97 | DenseMap<Value, Value> &rewriteValues, |
98 | function_ref<Value(Value)> mapRewriteValue); |
99 | void generateRewriter(pdl::ResultOp resultOp, |
100 | DenseMap<Value, Value> &rewriteValues, |
101 | function_ref<Value(Value)> mapRewriteValue); |
102 | void generateRewriter(pdl::ResultsOp resultOp, |
103 | DenseMap<Value, Value> &rewriteValues, |
104 | function_ref<Value(Value)> mapRewriteValue); |
105 | void generateRewriter(pdl::TypeOp typeOp, |
106 | DenseMap<Value, Value> &rewriteValues, |
107 | function_ref<Value(Value)> mapRewriteValue); |
108 | void generateRewriter(pdl::TypesOp typeOp, |
109 | DenseMap<Value, Value> &rewriteValues, |
110 | function_ref<Value(Value)> mapRewriteValue); |
111 | |
112 | /// Generate the values used for resolving the result types of an operation |
113 | /// created within a dag rewriter region. If the result types of the operation |
114 | /// should be inferred, `hasInferredResultTypes` is set to true. |
115 | void generateOperationResultTypeRewriter( |
116 | pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, |
117 | SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, |
118 | bool &hasInferredResultTypes); |
119 | |
120 | /// A builder to use when generating interpreter operations. |
121 | OpBuilder builder; |
122 | |
123 | /// The matcher function used for all match related logic within PDL patterns. |
124 | pdl_interp::FuncOp matcherFunc; |
125 | |
126 | /// The rewriter module containing the all rewrite related logic within PDL |
127 | /// patterns. |
128 | ModuleOp rewriterModule; |
129 | |
130 | /// The symbol table of the rewriter module used for insertion. |
131 | SymbolTable rewriterSymbolTable; |
132 | |
133 | /// A scoped map connecting a position with the corresponding interpreter |
134 | /// value. |
135 | ValueMap values; |
136 | |
137 | /// A stack of blocks used as the failure destination for matcher nodes that |
138 | /// don't have an explicit failure path. |
139 | SmallVector<Block *, 8> failureBlockStack; |
140 | |
141 | /// A mapping between values defined in a pattern match, and the corresponding |
142 | /// positional value. |
143 | DenseMap<Value, Position *> valueToPosition; |
144 | |
145 | /// The set of operation values whose location will be used for newly |
146 | /// generated operations. |
147 | SetVector<Value> locOps; |
148 | |
149 | /// A mapping between pattern operations and the corresponding configuration |
150 | /// set. |
151 | DenseMap<Operation *, PDLPatternConfigSet *> *configMap; |
152 | |
153 | /// A mapping from a constraint question to the ApplyConstraintOp |
154 | /// that implements it. |
155 | DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap; |
156 | }; |
157 | } // namespace |
158 | |
159 | PatternLowering::PatternLowering( |
160 | pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, |
161 | DenseMap<Operation *, PDLPatternConfigSet *> *configMap) |
162 | : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), |
163 | rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule), |
164 | configMap(configMap) {} |
165 | |
166 | void PatternLowering::lower(ModuleOp module) { |
167 | PredicateUniquer predicateUniquer; |
168 | PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); |
169 | |
170 | // Define top-level scope for the arguments to the matcher function. |
171 | ValueMapScope topLevelValueScope(values); |
172 | |
173 | // Insert the root operation, i.e. argument to the matcher, at the root |
174 | // position. |
175 | Block *matcherEntryBlock = &matcherFunc.front(); |
176 | values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); |
177 | |
178 | // Generate a root matcher node from the provided PDL module. |
179 | std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( |
180 | module, predicateBuilder, valueToPosition); |
181 | Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); |
182 | assert(failureBlockStack.empty() && "failed to empty the stack" ); |
183 | |
184 | // After generation, merged the first matched block into the entry. |
185 | matcherEntryBlock->getOperations().splice(where: matcherEntryBlock->end(), |
186 | L2&: firstMatcherBlock->getOperations()); |
187 | firstMatcherBlock->erase(); |
188 | } |
189 | |
190 | Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, |
191 | Block *block) { |
192 | // Push a new scope for the values used by this matcher. |
193 | if (!block) |
194 | block = ®ion.emplaceBlock(); |
195 | ValueMapScope scope(values); |
196 | |
197 | // If this is the return node, simply insert the corresponding interpreter |
198 | // finalize. |
199 | if (isa<ExitNode>(Val: node)) { |
200 | builder.setInsertionPointToEnd(block); |
201 | builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); |
202 | return block; |
203 | } |
204 | |
205 | // Get the next block in the match sequence. |
206 | // This is intentionally executed first, before we get the value for the |
207 | // position associated with the node, so that we preserve an "there exist" |
208 | // semantics: if getting a value requires an upward traversal (going from a |
209 | // value to its consumers), we want to perform the check on all the consumers |
210 | // before we pass control to the failure node. |
211 | std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); |
212 | Block *failureBlock; |
213 | if (failureNode) { |
214 | failureBlock = generateMatcher(node&: *failureNode, region); |
215 | failureBlockStack.push_back(failureBlock); |
216 | } else { |
217 | assert(!failureBlockStack.empty() && "expected valid failure block" ); |
218 | failureBlock = failureBlockStack.back(); |
219 | } |
220 | |
221 | // If this node contains a position, get the corresponding value for this |
222 | // block. |
223 | Block *currentBlock = block; |
224 | Position *position = node.getPosition(); |
225 | Value val = position ? getValueAt(currentBlock, pos: position) : Value(); |
226 | |
227 | // If this value corresponds to an operation, record that we are going to use |
228 | // its location as part of a fused location. |
229 | bool isOperationValue = val && isa<pdl::OperationType>(val.getType()); |
230 | if (isOperationValue) |
231 | locOps.insert(val); |
232 | |
233 | // Dispatch to the correct method based on derived node type. |
234 | TypeSwitch<MatcherNode *>(&node) |
235 | .Case<BoolNode, SwitchNode>(caseFn: [&](auto *derivedNode) { |
236 | this->generate(derivedNode, currentBlock, val); |
237 | }) |
238 | .Case(caseFn: [&](SuccessNode *successNode) { |
239 | generate(successNode, currentBlock); |
240 | }); |
241 | |
242 | // Pop all the failure blocks that were inserted due to nesting of |
243 | // pdl_interp.iterate. |
244 | while (failureBlockStack.back() != failureBlock) { |
245 | failureBlockStack.pop_back(); |
246 | assert(!failureBlockStack.empty() && "unable to locate failure block" ); |
247 | } |
248 | |
249 | // Pop the new failure block. |
250 | if (failureNode) |
251 | failureBlockStack.pop_back(); |
252 | |
253 | if (isOperationValue) |
254 | locOps.remove(val); |
255 | |
256 | return block; |
257 | } |
258 | |
259 | Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { |
260 | if (Value val = values.lookup(pos)) |
261 | return val; |
262 | |
263 | // Get the value for the parent position. |
264 | Value parentVal; |
265 | if (Position *parent = pos->getParent()) |
266 | parentVal = getValueAt(currentBlock, pos: parent); |
267 | |
268 | // TODO: Use a location from the position. |
269 | Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc(); |
270 | builder.setInsertionPointToEnd(currentBlock); |
271 | Value value; |
272 | switch (pos->getKind()) { |
273 | case Predicates::OperationPos: { |
274 | auto *operationPos = cast<OperationPosition>(Val: pos); |
275 | if (operationPos->isOperandDefiningOp()) |
276 | // Standard (downward) traversal which directly follows the defining op. |
277 | value = builder.create<pdl_interp::GetDefiningOpOp>( |
278 | loc, builder.getType<pdl::OperationType>(), parentVal); |
279 | else |
280 | // A passthrough operation position. |
281 | value = parentVal; |
282 | break; |
283 | } |
284 | case Predicates::UsersPos: { |
285 | auto *usersPos = cast<UsersPosition>(Val: pos); |
286 | |
287 | // The first operation retrieves the representative value of a range. |
288 | // This applies only when the parent is a range of values and we were |
289 | // requested to use a representative value (e.g., upward traversal). |
290 | if (isa<pdl::RangeType>(parentVal.getType()) && |
291 | usersPos->useRepresentative()) |
292 | value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); |
293 | else |
294 | value = parentVal; |
295 | |
296 | // The second operation retrieves the users. |
297 | value = builder.create<pdl_interp::GetUsersOp>(loc, value); |
298 | break; |
299 | } |
300 | case Predicates::ForEachPos: { |
301 | assert(!failureBlockStack.empty() && "expected valid failure block" ); |
302 | auto foreach = builder.create<pdl_interp::ForEachOp>( |
303 | loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); |
304 | value = foreach.getLoopVariable(); |
305 | |
306 | // Create the continuation block. |
307 | Block *continueBlock = builder.createBlock(&foreach.getRegion()); |
308 | builder.create<pdl_interp::ContinueOp>(loc); |
309 | failureBlockStack.push_back(continueBlock); |
310 | |
311 | currentBlock = &foreach.getRegion().front(); |
312 | break; |
313 | } |
314 | case Predicates::OperandPos: { |
315 | auto *operandPos = cast<OperandPosition>(Val: pos); |
316 | value = builder.create<pdl_interp::GetOperandOp>( |
317 | loc, builder.getType<pdl::ValueType>(), parentVal, |
318 | operandPos->getOperandNumber()); |
319 | break; |
320 | } |
321 | case Predicates::OperandGroupPos: { |
322 | auto *operandPos = cast<OperandGroupPosition>(Val: pos); |
323 | Type valueTy = builder.getType<pdl::ValueType>(); |
324 | value = builder.create<pdl_interp::GetOperandsOp>( |
325 | loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, |
326 | parentVal, operandPos->getOperandGroupNumber()); |
327 | break; |
328 | } |
329 | case Predicates::AttributePos: { |
330 | auto *attrPos = cast<AttributePosition>(Val: pos); |
331 | value = builder.create<pdl_interp::GetAttributeOp>( |
332 | loc, builder.getType<pdl::AttributeType>(), parentVal, |
333 | attrPos->getName().strref()); |
334 | break; |
335 | } |
336 | case Predicates::TypePos: { |
337 | if (isa<pdl::AttributeType>(parentVal.getType())) |
338 | value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); |
339 | else |
340 | value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); |
341 | break; |
342 | } |
343 | case Predicates::ResultPos: { |
344 | auto *resPos = cast<ResultPosition>(Val: pos); |
345 | value = builder.create<pdl_interp::GetResultOp>( |
346 | loc, builder.getType<pdl::ValueType>(), parentVal, |
347 | resPos->getResultNumber()); |
348 | break; |
349 | } |
350 | case Predicates::ResultGroupPos: { |
351 | auto *resPos = cast<ResultGroupPosition>(Val: pos); |
352 | Type valueTy = builder.getType<pdl::ValueType>(); |
353 | value = builder.create<pdl_interp::GetResultsOp>( |
354 | loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, |
355 | parentVal, resPos->getResultGroupNumber()); |
356 | break; |
357 | } |
358 | case Predicates::AttributeLiteralPos: { |
359 | auto *attrPos = cast<AttributeLiteralPosition>(Val: pos); |
360 | value = |
361 | builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); |
362 | break; |
363 | } |
364 | case Predicates::TypeLiteralPos: { |
365 | auto *typePos = cast<TypeLiteralPosition>(Val: pos); |
366 | Attribute rawTypeAttr = typePos->getValue(); |
367 | if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr)) |
368 | value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); |
369 | else |
370 | value = builder.create<pdl_interp::CreateTypesOp>( |
371 | loc, cast<ArrayAttr>(rawTypeAttr)); |
372 | break; |
373 | } |
374 | case Predicates::ConstraintResultPos: { |
375 | // Due to the order of traversal, the ApplyConstraintOp has already been |
376 | // created and we can find it in constraintOpMap. |
377 | auto *constrResPos = cast<ConstraintPosition>(Val: pos); |
378 | auto i = constraintOpMap.find(constrResPos->getQuestion()); |
379 | assert(i != constraintOpMap.end()); |
380 | value = i->second->getResult(constrResPos->getIndex()); |
381 | break; |
382 | } |
383 | default: |
384 | llvm_unreachable("Generating unknown Position getter" ); |
385 | break; |
386 | } |
387 | |
388 | values.insert(pos, value); |
389 | return value; |
390 | } |
391 | |
392 | void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, |
393 | Value val) { |
394 | Location loc = val.getLoc(); |
395 | Qualifier *question = boolNode->getQuestion(); |
396 | Qualifier *answer = boolNode->getAnswer(); |
397 | Region *region = currentBlock->getParent(); |
398 | |
399 | // Execute the getValue queries first, so that we create success |
400 | // matcher in the correct (possibly nested) region. |
401 | SmallVector<Value> args; |
402 | if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(Val: question)) { |
403 | args = {getValueAt(currentBlock, pos: equalToQuestion->getValue())}; |
404 | } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(Val: question)) { |
405 | for (Position *position : cstQuestion->getArgs()) |
406 | args.push_back(Elt: getValueAt(currentBlock, pos: position)); |
407 | } |
408 | |
409 | // Generate a new block as success successor and get the failure successor. |
410 | Block *success = ®ion->emplaceBlock(); |
411 | Block *failure = failureBlockStack.back(); |
412 | |
413 | // Create the predicate. |
414 | builder.setInsertionPointToEnd(currentBlock); |
415 | Predicates::Kind kind = question->getKind(); |
416 | switch (kind) { |
417 | case Predicates::IsNotNullQuestion: |
418 | builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); |
419 | break; |
420 | case Predicates::OperationNameQuestion: { |
421 | auto *opNameAnswer = cast<OperationNameAnswer>(Val: answer); |
422 | builder.create<pdl_interp::CheckOperationNameOp>( |
423 | loc, val, opNameAnswer->getValue().getStringRef(), success, failure); |
424 | break; |
425 | } |
426 | case Predicates::TypeQuestion: { |
427 | auto *ans = cast<TypeAnswer>(Val: answer); |
428 | if (isa<pdl::RangeType>(val.getType())) |
429 | builder.create<pdl_interp::CheckTypesOp>( |
430 | loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure); |
431 | else |
432 | builder.create<pdl_interp::CheckTypeOp>( |
433 | loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure); |
434 | break; |
435 | } |
436 | case Predicates::AttributeQuestion: { |
437 | auto *ans = cast<AttributeAnswer>(Val: answer); |
438 | builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), |
439 | success, failure); |
440 | break; |
441 | } |
442 | case Predicates::OperandCountAtLeastQuestion: |
443 | case Predicates::OperandCountQuestion: |
444 | builder.create<pdl_interp::CheckOperandCountOp>( |
445 | loc, val, cast<UnsignedAnswer>(answer)->getValue(), |
446 | /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, |
447 | success, failure); |
448 | break; |
449 | case Predicates::ResultCountAtLeastQuestion: |
450 | case Predicates::ResultCountQuestion: |
451 | builder.create<pdl_interp::CheckResultCountOp>( |
452 | loc, val, cast<UnsignedAnswer>(answer)->getValue(), |
453 | /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, |
454 | success, failure); |
455 | break; |
456 | case Predicates::EqualToQuestion: { |
457 | bool trueAnswer = isa<TrueAnswer>(Val: answer); |
458 | builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), |
459 | trueAnswer ? success : failure, |
460 | trueAnswer ? failure : success); |
461 | break; |
462 | } |
463 | case Predicates::ConstraintQuestion: { |
464 | auto *cstQuestion = cast<ConstraintQuestion>(Val: question); |
465 | auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>( |
466 | loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, |
467 | cstQuestion->getIsNegated(), success, failure); |
468 | |
469 | constraintOpMap.insert({cstQuestion, applyConstraintOp}); |
470 | break; |
471 | } |
472 | default: |
473 | llvm_unreachable("Generating unknown Predicate operation" ); |
474 | } |
475 | |
476 | // Generate the matcher in the current (potentially nested) region. |
477 | // This might use the results of the current predicate. |
478 | generateMatcher(node&: *boolNode->getSuccessNode(), region&: *region, block: success); |
479 | } |
480 | |
481 | template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> |
482 | static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, |
483 | llvm::MapVector<Qualifier *, Block *> &dests) { |
484 | std::vector<ValT> values; |
485 | std::vector<Block *> blocks; |
486 | values.reserve(dests.size()); |
487 | blocks.reserve(n: dests.size()); |
488 | for (const auto &it : dests) { |
489 | blocks.push_back(x: it.second); |
490 | values.push_back(cast<PredT>(it.first)->getValue()); |
491 | } |
492 | builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); |
493 | } |
494 | |
495 | void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, |
496 | Value val) { |
497 | Qualifier *question = switchNode->getQuestion(); |
498 | Region *region = currentBlock->getParent(); |
499 | Block *defaultDest = failureBlockStack.back(); |
500 | |
501 | // If the switch question is not an exact answer, i.e. for the `at_least` |
502 | // cases, we generate a special block sequence. |
503 | Predicates::Kind kind = question->getKind(); |
504 | if (kind == Predicates::OperandCountAtLeastQuestion || |
505 | kind == Predicates::ResultCountAtLeastQuestion) { |
506 | // Order the children such that the cases are in reverse numerical order. |
507 | SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( |
508 | Range: llvm::seq<unsigned>(Begin: 0, End: switchNode->getChildren().size())); |
509 | llvm::sort(C&: sortedChildren, Comp: [&](unsigned lhs, unsigned rhs) { |
510 | return cast<UnsignedAnswer>(Val: switchNode->getChild(i: lhs).first)->getValue() > |
511 | cast<UnsignedAnswer>(Val: switchNode->getChild(i: rhs).first)->getValue(); |
512 | }); |
513 | |
514 | // Build the destination for each child using the next highest child as a |
515 | // a failure destination. This essentially creates the following control |
516 | // flow: |
517 | // |
518 | // if (operand_count < 1) |
519 | // goto failure |
520 | // if (child1.match()) |
521 | // ... |
522 | // |
523 | // if (operand_count < 2) |
524 | // goto failure |
525 | // if (child2.match()) |
526 | // ... |
527 | // |
528 | // failure: |
529 | // ... |
530 | // |
531 | failureBlockStack.push_back(defaultDest); |
532 | Location loc = val.getLoc(); |
533 | for (unsigned idx : sortedChildren) { |
534 | auto &child = switchNode->getChild(i: idx); |
535 | Block *childBlock = generateMatcher(node&: *child.second, region&: *region); |
536 | Block *predicateBlock = builder.createBlock(insertBefore: childBlock); |
537 | builder.setInsertionPointToEnd(predicateBlock); |
538 | unsigned ans = cast<UnsignedAnswer>(Val: child.first)->getValue(); |
539 | switch (kind) { |
540 | case Predicates::OperandCountAtLeastQuestion: |
541 | builder.create<pdl_interp::CheckOperandCountOp>( |
542 | loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); |
543 | break; |
544 | case Predicates::ResultCountAtLeastQuestion: |
545 | builder.create<pdl_interp::CheckResultCountOp>( |
546 | loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); |
547 | break; |
548 | default: |
549 | llvm_unreachable("Generating invalid AtLeast operation" ); |
550 | } |
551 | failureBlockStack.back() = predicateBlock; |
552 | } |
553 | Block *firstPredicateBlock = failureBlockStack.pop_back_val(); |
554 | currentBlock->getOperations().splice(where: currentBlock->end(), |
555 | L2&: firstPredicateBlock->getOperations()); |
556 | firstPredicateBlock->erase(); |
557 | return; |
558 | } |
559 | |
560 | // Otherwise, generate each of the children and generate an interpreter |
561 | // switch. |
562 | llvm::MapVector<Qualifier *, Block *> children; |
563 | for (auto &it : switchNode->getChildren()) |
564 | children.insert(KV: {it.first, generateMatcher(node&: *it.second, region&: *region)}); |
565 | builder.setInsertionPointToEnd(currentBlock); |
566 | |
567 | switch (question->getKind()) { |
568 | case Predicates::OperandCountQuestion: |
569 | return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, |
570 | int32_t>(val, defaultDest, builder, children); |
571 | case Predicates::ResultCountQuestion: |
572 | return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, |
573 | int32_t>(val, defaultDest, builder, children); |
574 | case Predicates::OperationNameQuestion: |
575 | return createSwitchOp<pdl_interp::SwitchOperationNameOp, |
576 | OperationNameAnswer>(val, defaultDest, builder, |
577 | children); |
578 | case Predicates::TypeQuestion: |
579 | if (isa<pdl::RangeType>(val.getType())) { |
580 | return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( |
581 | val, defaultDest, builder, children); |
582 | } |
583 | return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( |
584 | val, defaultDest, builder, children); |
585 | case Predicates::AttributeQuestion: |
586 | return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( |
587 | val, defaultDest, builder, children); |
588 | default: |
589 | llvm_unreachable("Generating unknown switch predicate." ); |
590 | } |
591 | } |
592 | |
593 | void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { |
594 | pdl::PatternOp pattern = successNode->getPattern(); |
595 | Value root = successNode->getRoot(); |
596 | |
597 | // Generate a rewriter for the pattern this success node represents, and track |
598 | // any values used from the match region. |
599 | SmallVector<Position *, 8> usedMatchValues; |
600 | SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); |
601 | |
602 | // Process any values used in the rewrite that are defined in the match. |
603 | std::vector<Value> mappedMatchValues; |
604 | mappedMatchValues.reserve(n: usedMatchValues.size()); |
605 | for (Position *position : usedMatchValues) |
606 | mappedMatchValues.push_back(x: getValueAt(currentBlock, pos: position)); |
607 | |
608 | // Collect the set of operations generated by the rewriter. |
609 | SmallVector<StringRef, 4> generatedOps; |
610 | for (auto op : |
611 | pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>()) |
612 | generatedOps.push_back(*op.getOpName()); |
613 | ArrayAttr generatedOpsAttr; |
614 | if (!generatedOps.empty()) |
615 | generatedOpsAttr = builder.getStrArrayAttr(generatedOps); |
616 | |
617 | // Grab the root kind if present. |
618 | StringAttr rootKindAttr; |
619 | if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) |
620 | if (std::optional<StringRef> rootKind = rootOp.getOpName()) |
621 | rootKindAttr = builder.getStringAttr(*rootKind); |
622 | |
623 | builder.setInsertionPointToEnd(currentBlock); |
624 | auto matchOp = builder.create<pdl_interp::RecordMatchOp>( |
625 | pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), |
626 | rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), |
627 | failureBlockStack.back()); |
628 | |
629 | // Set the config of the lowered match to the parent pattern. |
630 | if (configMap) |
631 | configMap->try_emplace(matchOp, configMap->lookup(pattern)); |
632 | } |
633 | |
634 | SymbolRefAttr PatternLowering::generateRewriter( |
635 | pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { |
636 | builder.setInsertionPointToEnd(rewriterModule.getBody()); |
637 | auto rewriterFunc = builder.create<pdl_interp::FuncOp>( |
638 | pattern.getLoc(), "pdl_generated_rewriter" , |
639 | builder.getFunctionType(std::nullopt, std::nullopt)); |
640 | rewriterSymbolTable.insert(symbol: rewriterFunc); |
641 | |
642 | // Generate the rewriter function body. |
643 | builder.setInsertionPointToEnd(&rewriterFunc.front()); |
644 | |
645 | // Map an input operand of the pattern to a generated interpreter value. |
646 | DenseMap<Value, Value> rewriteValues; |
647 | auto mapRewriteValue = [&](Value oldValue) { |
648 | Value &newValue = rewriteValues[oldValue]; |
649 | if (newValue) |
650 | return newValue; |
651 | |
652 | // Prefer materializing constants directly when possible. |
653 | Operation *oldOp = oldValue.getDefiningOp(); |
654 | if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { |
655 | if (Attribute value = attrOp.getValueAttr()) { |
656 | return newValue = builder.create<pdl_interp::CreateAttributeOp>( |
657 | attrOp.getLoc(), value); |
658 | } |
659 | } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { |
660 | if (TypeAttr type = typeOp.getConstantTypeAttr()) { |
661 | return newValue = builder.create<pdl_interp::CreateTypeOp>( |
662 | typeOp.getLoc(), type); |
663 | } |
664 | } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { |
665 | if (ArrayAttr type = typeOp.getConstantTypesAttr()) { |
666 | return newValue = builder.create<pdl_interp::CreateTypesOp>( |
667 | typeOp.getLoc(), typeOp.getType(), type); |
668 | } |
669 | } |
670 | |
671 | // Otherwise, add this as an input to the rewriter. |
672 | Position *inputPos = valueToPosition.lookup(oldValue); |
673 | assert(inputPos && "expected value to be a pattern input" ); |
674 | usedMatchValues.push_back(Elt: inputPos); |
675 | return newValue = rewriterFunc.front().addArgument(oldValue.getType(), |
676 | oldValue.getLoc()); |
677 | }; |
678 | |
679 | // If this is a custom rewriter, simply dispatch to the registered rewrite |
680 | // method. |
681 | pdl::RewriteOp rewriter = pattern.getRewriter(); |
682 | if (StringAttr rewriteName = rewriter.getNameAttr()) { |
683 | SmallVector<Value> args; |
684 | if (rewriter.getRoot()) |
685 | args.push_back(Elt: mapRewriteValue(rewriter.getRoot())); |
686 | auto mappedArgs = |
687 | llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); |
688 | args.append(mappedArgs.begin(), mappedArgs.end()); |
689 | builder.create<pdl_interp::ApplyRewriteOp>( |
690 | rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); |
691 | } else { |
692 | // Otherwise this is a dag rewriter defined using PDL operations. |
693 | for (Operation &rewriteOp : *rewriter.getBody()) { |
694 | llvm::TypeSwitch<Operation *>(&rewriteOp) |
695 | .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, |
696 | pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, |
697 | pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { |
698 | this->generateRewriter(op, rewriteValues, mapRewriteValue); |
699 | }); |
700 | } |
701 | } |
702 | |
703 | // Update the signature of the rewrite function. |
704 | rewriterFunc.setType(builder.getFunctionType( |
705 | inputs: llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), |
706 | /*results=*/std::nullopt)); |
707 | |
708 | builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); |
709 | return SymbolRefAttr::get( |
710 | builder.getContext(), |
711 | pdl_interp::PDLInterpDialect::getRewriterModuleName(), |
712 | SymbolRefAttr::get(rewriterFunc)); |
713 | } |
714 | |
715 | void PatternLowering::generateRewriter( |
716 | pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, |
717 | function_ref<Value(Value)> mapRewriteValue) { |
718 | SmallVector<Value, 2> arguments; |
719 | for (Value argument : rewriteOp.getArgs()) |
720 | arguments.push_back(mapRewriteValue(argument)); |
721 | auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( |
722 | rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), |
723 | arguments); |
724 | for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) |
725 | rewriteValues[std::get<0>(it)] = std::get<1>(it); |
726 | } |
727 | |
728 | void PatternLowering::generateRewriter( |
729 | pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, |
730 | function_ref<Value(Value)> mapRewriteValue) { |
731 | Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( |
732 | attrOp.getLoc(), attrOp.getValueAttr()); |
733 | rewriteValues[attrOp] = newAttr; |
734 | } |
735 | |
736 | void PatternLowering::generateRewriter( |
737 | pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, |
738 | function_ref<Value(Value)> mapRewriteValue) { |
739 | builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), |
740 | mapRewriteValue(eraseOp.getOpValue())); |
741 | } |
742 | |
743 | void PatternLowering::generateRewriter( |
744 | pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, |
745 | function_ref<Value(Value)> mapRewriteValue) { |
746 | SmallVector<Value, 4> operands; |
747 | for (Value operand : operationOp.getOperandValues()) |
748 | operands.push_back(mapRewriteValue(operand)); |
749 | |
750 | SmallVector<Value, 4> attributes; |
751 | for (Value attr : operationOp.getAttributeValues()) |
752 | attributes.push_back(mapRewriteValue(attr)); |
753 | |
754 | bool hasInferredResultTypes = false; |
755 | SmallVector<Value, 2> types; |
756 | generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types, |
757 | rewriteValues, hasInferredResultTypes); |
758 | |
759 | // Create the new operation. |
760 | Location loc = operationOp.getLoc(); |
761 | Value createdOp = builder.create<pdl_interp::CreateOperationOp>( |
762 | loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, |
763 | attributes, operationOp.getAttributeValueNames()); |
764 | rewriteValues[operationOp.getOp()] = createdOp; |
765 | |
766 | // Generate accesses for any results that have their types constrained. |
767 | // Handle the case where there is a single range representing all of the |
768 | // result types. |
769 | OperandRange resultTys = operationOp.getTypeValues(); |
770 | if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) { |
771 | Value &type = rewriteValues[resultTys[0]]; |
772 | if (!type) { |
773 | auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); |
774 | type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); |
775 | } |
776 | return; |
777 | } |
778 | |
779 | // Otherwise, populate the individual results. |
780 | bool seenVariableLength = false; |
781 | Type valueTy = builder.getType<pdl::ValueType>(); |
782 | Type valueRangeTy = pdl::RangeType::get(valueTy); |
783 | for (const auto &it : llvm::enumerate(resultTys)) { |
784 | Value &type = rewriteValues[it.value()]; |
785 | if (type) |
786 | continue; |
787 | bool isVariadic = isa<pdl::RangeType>(it.value().getType()); |
788 | seenVariableLength |= isVariadic; |
789 | |
790 | // After a variable length result has been seen, we need to use result |
791 | // groups because the exact index of the result is not statically known. |
792 | Value resultVal; |
793 | if (seenVariableLength) |
794 | resultVal = builder.create<pdl_interp::GetResultsOp>( |
795 | loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); |
796 | else |
797 | resultVal = builder.create<pdl_interp::GetResultOp>( |
798 | loc, valueTy, createdOp, it.index()); |
799 | type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); |
800 | } |
801 | } |
802 | |
803 | void PatternLowering::generateRewriter( |
804 | pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues, |
805 | function_ref<Value(Value)> mapRewriteValue) { |
806 | SmallVector<Value, 4> replOperands; |
807 | for (Value operand : rangeOp.getArguments()) |
808 | replOperands.push_back(mapRewriteValue(operand)); |
809 | rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( |
810 | rangeOp.getLoc(), rangeOp.getType(), replOperands); |
811 | } |
812 | |
813 | void PatternLowering::generateRewriter( |
814 | pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, |
815 | function_ref<Value(Value)> mapRewriteValue) { |
816 | SmallVector<Value, 4> replOperands; |
817 | |
818 | // If the replacement was another operation, get its results. `pdl` allows |
819 | // for using an operation for simplicitly, but the interpreter isn't as |
820 | // user facing. |
821 | if (Value replOp = replaceOp.getReplOperation()) { |
822 | // Don't use replace if we know the replaced operation has no results. |
823 | auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); |
824 | if (!opOp || !opOp.getTypeValues().empty()) { |
825 | replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( |
826 | replOp.getLoc(), mapRewriteValue(replOp))); |
827 | } |
828 | } else { |
829 | for (Value operand : replaceOp.getReplValues()) |
830 | replOperands.push_back(mapRewriteValue(operand)); |
831 | } |
832 | |
833 | // If there are no replacement values, just create an erase instead. |
834 | if (replOperands.empty()) { |
835 | builder.create<pdl_interp::EraseOp>( |
836 | replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); |
837 | return; |
838 | } |
839 | |
840 | builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), |
841 | mapRewriteValue(replaceOp.getOpValue()), |
842 | replOperands); |
843 | } |
844 | |
845 | void PatternLowering::generateRewriter( |
846 | pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, |
847 | function_ref<Value(Value)> mapRewriteValue) { |
848 | rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( |
849 | resultOp.getLoc(), builder.getType<pdl::ValueType>(), |
850 | mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); |
851 | } |
852 | |
853 | void PatternLowering::generateRewriter( |
854 | pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, |
855 | function_ref<Value(Value)> mapRewriteValue) { |
856 | rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( |
857 | resultOp.getLoc(), resultOp.getType(), |
858 | mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); |
859 | } |
860 | |
861 | void PatternLowering::generateRewriter( |
862 | pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, |
863 | function_ref<Value(Value)> mapRewriteValue) { |
864 | // If the type isn't constant, the users (e.g. OperationOp) will resolve this |
865 | // type. |
866 | if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { |
867 | rewriteValues[typeOp] = |
868 | builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); |
869 | } |
870 | } |
871 | |
872 | void PatternLowering::generateRewriter( |
873 | pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, |
874 | function_ref<Value(Value)> mapRewriteValue) { |
875 | // If the type isn't constant, the users (e.g. OperationOp) will resolve this |
876 | // type. |
877 | if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { |
878 | rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( |
879 | typeOp.getLoc(), typeOp.getType(), typeAttr); |
880 | } |
881 | } |
882 | |
883 | void PatternLowering::generateOperationResultTypeRewriter( |
884 | pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue, |
885 | SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues, |
886 | bool &hasInferredResultTypes) { |
887 | Block *rewriterBlock = op->getBlock(); |
888 | |
889 | // Try to handle resolution for each of the result types individually. This is |
890 | // preferred over type inferrence because it will allow for us to use existing |
891 | // types directly, as opposed to trying to rebuild the type list. |
892 | OperandRange resultTypeValues = op.getTypeValues(); |
893 | auto tryResolveResultTypes = [&] { |
894 | types.reserve(N: resultTypeValues.size()); |
895 | for (const auto &it : llvm::enumerate(resultTypeValues)) { |
896 | Value resultType = it.value(); |
897 | |
898 | // Check for an already translated value. |
899 | if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { |
900 | types.push_back(existingRewriteValue); |
901 | continue; |
902 | } |
903 | |
904 | // Check for an input from the matcher. |
905 | if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { |
906 | types.push_back(mapRewriteValue(resultType)); |
907 | continue; |
908 | } |
909 | |
910 | // Otherwise, we couldn't infer the result types. Bail out here to see if |
911 | // we can infer the types for this operation from another way. |
912 | types.clear(); |
913 | return failure(); |
914 | } |
915 | return success(); |
916 | }; |
917 | if (!resultTypeValues.empty() && succeeded(result: tryResolveResultTypes())) |
918 | return; |
919 | |
920 | // Otherwise, check if the operation has type inference support itself. |
921 | if (op.hasTypeInference()) { |
922 | hasInferredResultTypes = true; |
923 | return; |
924 | } |
925 | |
926 | // Look for an operation that was replaced by `op`. The result types will be |
927 | // inferred from the results that were replaced. |
928 | for (OpOperand &use : op.getOp().getUses()) { |
929 | // Check that the use corresponds to a ReplaceOp and that it is the |
930 | // replacement value, not the operation being replaced. |
931 | pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); |
932 | if (!replOpUser || use.getOperandNumber() == 0) |
933 | continue; |
934 | // Make sure the replaced operation was defined before this one. PDL |
935 | // rewrites only have single block regions, so if the op isn't in the |
936 | // rewriter block (i.e. the current block of the operation) we already know |
937 | // it dominates (i.e. it's in the matcher). |
938 | Value replOpVal = replOpUser.getOpValue(); |
939 | Operation *replacedOp = replOpVal.getDefiningOp(); |
940 | if (replacedOp->getBlock() == rewriterBlock && |
941 | !replacedOp->isBeforeInBlock(op)) |
942 | continue; |
943 | |
944 | Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( |
945 | replacedOp->getLoc(), mapRewriteValue(replOpVal)); |
946 | types.push_back(builder.create<pdl_interp::GetValueTypeOp>( |
947 | replacedOp->getLoc(), replacedOpResults)); |
948 | return; |
949 | } |
950 | |
951 | // If the types could not be inferred from any context and there weren't any |
952 | // explicit result types, assume the user actually meant for the operation to |
953 | // have no results. |
954 | if (resultTypeValues.empty()) |
955 | return; |
956 | |
957 | // The verifier asserts that the result types of each pdl.getOperation can be |
958 | // inferred. If we reach here, there is a bug either in the logic above or |
959 | // in the verifier for pdl.getOperation. |
960 | op->emitOpError() << "unable to infer result type for operation" ; |
961 | llvm_unreachable("unable to infer result type for operation" ); |
962 | } |
963 | |
964 | //===----------------------------------------------------------------------===// |
965 | // Conversion Pass |
966 | //===----------------------------------------------------------------------===// |
967 | |
968 | namespace { |
969 | struct PDLToPDLInterpPass |
970 | : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { |
971 | PDLToPDLInterpPass() = default; |
972 | PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default; |
973 | PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap) |
974 | : configMap(&configMap) {} |
975 | void runOnOperation() final; |
976 | |
977 | /// A map containing the configuration for each pattern. |
978 | DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr; |
979 | }; |
980 | } // namespace |
981 | |
982 | /// Convert the given module containing PDL pattern operations into a PDL |
983 | /// Interpreter operations. |
984 | void PDLToPDLInterpPass::runOnOperation() { |
985 | ModuleOp module = getOperation(); |
986 | |
987 | // Create the main matcher function This function contains all of the match |
988 | // related functionality from patterns in the module. |
989 | OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody()); |
990 | auto matcherFunc = builder.create<pdl_interp::FuncOp>( |
991 | module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), |
992 | builder.getFunctionType(builder.getType<pdl::OperationType>(), |
993 | /*results=*/std::nullopt), |
994 | /*attrs=*/std::nullopt); |
995 | |
996 | // Create a nested module to hold the functions invoked for rewriting the IR |
997 | // after a successful match. |
998 | ModuleOp rewriterModule = builder.create<ModuleOp>( |
999 | module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); |
1000 | |
1001 | // Generate the code for the patterns within the module. |
1002 | PatternLowering generator(matcherFunc, rewriterModule, configMap); |
1003 | generator.lower(module: module); |
1004 | |
1005 | // After generation, delete all of the pattern operations. |
1006 | for (pdl::PatternOp pattern : |
1007 | llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) { |
1008 | // Drop the now dead config mappings. |
1009 | if (configMap) |
1010 | configMap->erase(pattern); |
1011 | |
1012 | pattern.erase(); |
1013 | } |
1014 | } |
1015 | |
1016 | std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { |
1017 | return std::make_unique<PDLToPDLInterpPass>(); |
1018 | } |
1019 | std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass( |
1020 | DenseMap<Operation *, PDLPatternConfigSet *> &configMap) { |
1021 | return std::make_unique<PDLToPDLInterpPass>(args&: configMap); |
1022 | } |
1023 | |