1 | //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "TestTypes.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | #include "mlir/Transforms/FoldUtils.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | #include "llvm/ADT/ScopeExit.h" |
22 | |
23 | using namespace mlir; |
24 | using namespace test; |
25 | |
26 | // Native function for testing NativeCodeCall |
27 | static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { |
28 | return choice.getValue() ? input1 : input2; |
29 | } |
30 | |
31 | static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { |
32 | rewriter.create<OpI>(loc, input); |
33 | } |
34 | |
35 | static void handleNoResultOp(PatternRewriter &rewriter, |
36 | OpSymbolBindingNoResult op) { |
37 | // Turn the no result op to a one-result op. |
38 | rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), |
39 | op.getOperand()); |
40 | } |
41 | |
42 | static bool getFirstI32Result(Operation *op, Value &value) { |
43 | if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32)) |
44 | return false; |
45 | value = op->getResult(idx: 0); |
46 | return true; |
47 | } |
48 | |
49 | static Value bindNativeCodeCallResult(Value value) { return value; } |
50 | |
51 | static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, |
52 | Value input2) { |
53 | return SmallVector<Value, 2>({input2, input1}); |
54 | } |
55 | |
56 | // Test that natives calls are only called once during rewrites. |
57 | // OpM_Test will return Pi, increased by 1 for each subsequent calls. |
58 | // This let us check the number of times OpM_Test was called by inspecting |
59 | // the returned value in the MLIR output. |
60 | static int64_t opMIncreasingValue = 314159265; |
61 | static Attribute opMTest(PatternRewriter &rewriter, Value val) { |
62 | int64_t i = opMIncreasingValue++; |
63 | return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); |
64 | } |
65 | |
66 | namespace { |
67 | #include "TestPatterns.inc" |
68 | } // namespace |
69 | |
70 | //===----------------------------------------------------------------------===// |
71 | // Test Reduce Pattern Interface |
72 | //===----------------------------------------------------------------------===// |
73 | |
74 | void test::populateTestReductionPatterns(RewritePatternSet &patterns) { |
75 | populateWithGenerated(patterns); |
76 | } |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // Canonicalizer Driver. |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | namespace { |
83 | struct FoldingPattern : public RewritePattern { |
84 | public: |
85 | FoldingPattern(MLIRContext *context) |
86 | : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), |
87 | /*benefit=*/1, context) {} |
88 | |
89 | LogicalResult matchAndRewrite(Operation *op, |
90 | PatternRewriter &rewriter) const override { |
91 | // Exercise createOrFold API for a single-result operation that is folded |
92 | // upon construction. The operation being created has an in-place folder, |
93 | // and it should be still present in the output. Furthermore, the folder |
94 | // should not crash when attempting to recover the (unchanged) operation |
95 | // result. |
96 | Value result = rewriter.createOrFold<TestOpInPlaceFold>( |
97 | op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); |
98 | assert(result); |
99 | rewriter.replaceOp(op, newValues: result); |
100 | return success(); |
101 | } |
102 | }; |
103 | |
104 | /// This pattern creates a foldable operation at the entry point of the block. |
105 | /// This tests the situation where the operation folder will need to replace an |
106 | /// operation with a previously created constant that does not initially |
107 | /// dominate the operation to replace. |
108 | struct FolderInsertBeforePreviouslyFoldedConstantPattern |
109 | : public OpRewritePattern<TestCastOp> { |
110 | public: |
111 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
112 | |
113 | LogicalResult matchAndRewrite(TestCastOp op, |
114 | PatternRewriter &rewriter) const override { |
115 | if (!op->hasAttr("test_fold_before_previously_folded_op" )) |
116 | return failure(); |
117 | rewriter.setInsertionPointToStart(op->getBlock()); |
118 | |
119 | auto constOp = rewriter.create<arith::ConstantOp>( |
120 | op.getLoc(), rewriter.getBoolAttr(true)); |
121 | rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), |
122 | Value(constOp)); |
123 | return success(); |
124 | } |
125 | }; |
126 | |
127 | /// This pattern matches test.op_commutative2 with the first operand being |
128 | /// another test.op_commutative2 with a constant on the right side and fold it |
129 | /// away by propagating it as its result. This is intend to check that patterns |
130 | /// are applied after the commutative property moves constant to the right. |
131 | struct FolderCommutativeOp2WithConstant |
132 | : public OpRewritePattern<TestCommutative2Op> { |
133 | public: |
134 | using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; |
135 | |
136 | LogicalResult matchAndRewrite(TestCommutative2Op op, |
137 | PatternRewriter &rewriter) const override { |
138 | auto operand = |
139 | dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); |
140 | if (!operand) |
141 | return failure(); |
142 | Attribute constInput; |
143 | if (!matchPattern(operand->getOperand(1), m_Constant(bind_value: &constInput))) |
144 | return failure(); |
145 | rewriter.replaceOp(op, operand->getOperand(1)); |
146 | return success(); |
147 | } |
148 | }; |
149 | |
150 | /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer |
151 | /// attribute with value smaller than MaxVal, it increments the value by 1. |
152 | template <int MaxVal> |
153 | struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { |
154 | using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; |
155 | |
156 | LogicalResult matchAndRewrite(AnyAttrOfOp op, |
157 | PatternRewriter &rewriter) const override { |
158 | auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); |
159 | if (!intAttr) |
160 | return failure(); |
161 | int64_t val = intAttr.getInt(); |
162 | if (val >= MaxVal) |
163 | return failure(); |
164 | rewriter.modifyOpInPlace( |
165 | op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); |
166 | return success(); |
167 | } |
168 | }; |
169 | |
170 | /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". |
171 | struct MakeOpEligible : public RewritePattern { |
172 | MakeOpEligible(MLIRContext *context) |
173 | : RewritePattern("foo.maybe_eligible_op" , /*benefit=*/1, context) {} |
174 | |
175 | LogicalResult matchAndRewrite(Operation *op, |
176 | PatternRewriter &rewriter) const override { |
177 | if (op->hasAttr(name: "eligible" )) |
178 | return failure(); |
179 | rewriter.modifyOpInPlace( |
180 | root: op, callable: [&]() { op->setAttr("eligible" , rewriter.getUnitAttr()); }); |
181 | return success(); |
182 | } |
183 | }; |
184 | |
185 | /// This pattern hoists eligible ops out of a "test.one_region_op". |
186 | struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { |
187 | using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; |
188 | |
189 | LogicalResult matchAndRewrite(test::OneRegionOp op, |
190 | PatternRewriter &rewriter) const override { |
191 | Operation *terminator = op.getRegion().front().getTerminator(); |
192 | Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); |
193 | if (toBeHoisted->getParentOp() != op) |
194 | return failure(); |
195 | if (!toBeHoisted->hasAttr(name: "eligible" )) |
196 | return failure(); |
197 | rewriter.moveOpBefore(toBeHoisted, op); |
198 | return success(); |
199 | } |
200 | }; |
201 | |
202 | /// This pattern moves "test.move_before_parent_op" before the parent op. |
203 | struct MoveBeforeParentOp : public RewritePattern { |
204 | MoveBeforeParentOp(MLIRContext *context) |
205 | : RewritePattern("test.move_before_parent_op" , /*benefit=*/1, context) {} |
206 | |
207 | LogicalResult matchAndRewrite(Operation *op, |
208 | PatternRewriter &rewriter) const override { |
209 | // Do not hoist past functions. |
210 | if (isa<FunctionOpInterface>(Val: op->getParentOp())) |
211 | return failure(); |
212 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
213 | return success(); |
214 | } |
215 | }; |
216 | |
217 | /// This pattern inlines blocks that are nested in |
218 | /// "test.inline_blocks_into_parent" into the parent block. |
219 | struct InlineBlocksIntoParent : public RewritePattern { |
220 | InlineBlocksIntoParent(MLIRContext *context) |
221 | : RewritePattern("test.inline_blocks_into_parent" , /*benefit=*/1, |
222 | context) {} |
223 | |
224 | LogicalResult matchAndRewrite(Operation *op, |
225 | PatternRewriter &rewriter) const override { |
226 | bool changed = false; |
227 | for (Region &r : op->getRegions()) { |
228 | while (!r.empty()) { |
229 | rewriter.inlineBlockBefore(source: &r.front(), op); |
230 | changed = true; |
231 | } |
232 | } |
233 | return success(isSuccess: changed); |
234 | } |
235 | }; |
236 | |
237 | /// This pattern splits blocks at "test.split_block_here" and replaces the op |
238 | /// with a new op (to prevent an infinite loop of block splitting). |
239 | struct SplitBlockHere : public RewritePattern { |
240 | SplitBlockHere(MLIRContext *context) |
241 | : RewritePattern("test.split_block_here" , /*benefit=*/1, context) {} |
242 | |
243 | LogicalResult matchAndRewrite(Operation *op, |
244 | PatternRewriter &rewriter) const override { |
245 | rewriter.splitBlock(block: op->getBlock(), before: op->getIterator()); |
246 | Operation *newOp = rewriter.create( |
247 | op->getLoc(), |
248 | OperationName("test.new_op" , op->getContext()).getIdentifier(), |
249 | op->getOperands(), op->getResultTypes()); |
250 | rewriter.replaceOp(op, newOp); |
251 | return success(); |
252 | } |
253 | }; |
254 | |
255 | /// This pattern clones "test.clone_me" ops. |
256 | struct CloneOp : public RewritePattern { |
257 | CloneOp(MLIRContext *context) |
258 | : RewritePattern("test.clone_me" , /*benefit=*/1, context) {} |
259 | |
260 | LogicalResult matchAndRewrite(Operation *op, |
261 | PatternRewriter &rewriter) const override { |
262 | // Do not clone already cloned ops to avoid going into an infinite loop. |
263 | if (op->hasAttr(name: "was_cloned" )) |
264 | return failure(); |
265 | Operation *cloned = rewriter.clone(op&: *op); |
266 | cloned->setAttr("was_cloned" , rewriter.getUnitAttr()); |
267 | return success(); |
268 | } |
269 | }; |
270 | |
271 | /// This pattern clones regions of "test.clone_region_before" ops before the |
272 | /// parent block. |
273 | struct CloneRegionBeforeOp : public RewritePattern { |
274 | CloneRegionBeforeOp(MLIRContext *context) |
275 | : RewritePattern("test.clone_region_before" , /*benefit=*/1, context) {} |
276 | |
277 | LogicalResult matchAndRewrite(Operation *op, |
278 | PatternRewriter &rewriter) const override { |
279 | // Do not clone already cloned ops to avoid going into an infinite loop. |
280 | if (op->hasAttr(name: "was_cloned" )) |
281 | return failure(); |
282 | for (Region &r : op->getRegions()) |
283 | rewriter.cloneRegionBefore(region&: r, before: op->getBlock()); |
284 | op->setAttr("was_cloned" , rewriter.getUnitAttr()); |
285 | return success(); |
286 | } |
287 | }; |
288 | |
289 | struct TestPatternDriver |
290 | : public PassWrapper<TestPatternDriver, OperationPass<>> { |
291 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) |
292 | |
293 | TestPatternDriver() = default; |
294 | TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {} |
295 | |
296 | StringRef getArgument() const final { return "test-patterns" ; } |
297 | StringRef getDescription() const final { return "Run test dialect patterns" ; } |
298 | void runOnOperation() override { |
299 | mlir::RewritePatternSet patterns(&getContext()); |
300 | populateWithGenerated(patterns); |
301 | |
302 | // Verify named pattern is generated with expected name. |
303 | patterns.add<FoldingPattern, TestNamedPatternRule, |
304 | FolderInsertBeforePreviouslyFoldedConstantPattern, |
305 | FolderCommutativeOp2WithConstant, HoistEligibleOps, |
306 | MakeOpEligible>(&getContext()); |
307 | |
308 | // Additional patterns for testing the GreedyPatternRewriteDriver. |
309 | patterns.insert<IncrementIntAttribute<3>>(arg: &getContext()); |
310 | |
311 | GreedyRewriteConfig config; |
312 | config.useTopDownTraversal = this->useTopDownTraversal; |
313 | config.maxIterations = this->maxIterations; |
314 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), |
315 | config); |
316 | } |
317 | |
318 | Option<bool> useTopDownTraversal{ |
319 | *this, "top-down" , |
320 | llvm::cl::desc("Seed the worklist in general top-down order" ), |
321 | llvm::cl::init(Val: GreedyRewriteConfig().useTopDownTraversal)}; |
322 | Option<int> maxIterations{ |
323 | *this, "max-iterations" , |
324 | llvm::cl::desc("Max. iterations in the GreedyRewriteConfig" ), |
325 | llvm::cl::init(Val: GreedyRewriteConfig().maxIterations)}; |
326 | }; |
327 | |
328 | struct DumpNotifications : public RewriterBase::Listener { |
329 | void notifyBlockInserted(Block *block, Region *previous, |
330 | Region::iterator previousIt) override { |
331 | llvm::outs() << "notifyBlockInserted" ; |
332 | if (block->getParentOp()) { |
333 | llvm::outs() << " into " << block->getParentOp()->getName() << ": " ; |
334 | } else { |
335 | llvm::outs() << " into unknown op: " ; |
336 | } |
337 | if (previous == nullptr) { |
338 | llvm::outs() << "was unlinked\n" ; |
339 | } else { |
340 | llvm::outs() << "was linked\n" ; |
341 | } |
342 | } |
343 | void notifyOperationInserted(Operation *op, |
344 | OpBuilder::InsertPoint previous) override { |
345 | llvm::outs() << "notifyOperationInserted: " << op->getName(); |
346 | if (!previous.isSet()) { |
347 | llvm::outs() << ", was unlinked\n" ; |
348 | } else { |
349 | if (!previous.getPoint().getNodePtr()) { |
350 | llvm::outs() << ", was linked, exact position unknown\n" ; |
351 | } else if (previous.getPoint() == previous.getBlock()->end()) { |
352 | llvm::outs() << ", was last in block\n" ; |
353 | } else { |
354 | llvm::outs() << ", previous = " << previous.getPoint()->getName() |
355 | << "\n" ; |
356 | } |
357 | } |
358 | } |
359 | void notifyBlockErased(Block *block) override { |
360 | llvm::outs() << "notifyBlockErased\n" ; |
361 | } |
362 | void notifyOperationErased(Operation *op) override { |
363 | llvm::outs() << "notifyOperationErased: " << op->getName() << "\n" ; |
364 | } |
365 | void notifyOperationModified(Operation *op) override { |
366 | llvm::outs() << "notifyOperationModified: " << op->getName() << "\n" ; |
367 | } |
368 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
369 | llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n" ; |
370 | } |
371 | }; |
372 | |
373 | struct TestStrictPatternDriver |
374 | : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { |
375 | public: |
376 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) |
377 | |
378 | TestStrictPatternDriver() = default; |
379 | TestStrictPatternDriver(const TestStrictPatternDriver &other) |
380 | : PassWrapper(other) { |
381 | strictMode = other.strictMode; |
382 | } |
383 | |
384 | StringRef getArgument() const final { return "test-strict-pattern-driver" ; } |
385 | StringRef getDescription() const final { |
386 | return "Test strict mode of pattern driver" ; |
387 | } |
388 | |
389 | void runOnOperation() override { |
390 | MLIRContext *ctx = &getContext(); |
391 | mlir::RewritePatternSet patterns(ctx); |
392 | patterns.add< |
393 | // clang-format off |
394 | ChangeBlockOp, |
395 | CloneOp, |
396 | CloneRegionBeforeOp, |
397 | EraseOp, |
398 | ImplicitChangeOp, |
399 | InlineBlocksIntoParent, |
400 | InsertSameOp, |
401 | MoveBeforeParentOp, |
402 | ReplaceWithNewOp, |
403 | SplitBlockHere |
404 | // clang-format on |
405 | >(arg&: ctx); |
406 | SmallVector<Operation *> ops; |
407 | getOperation()->walk([&](Operation *op) { |
408 | StringRef opName = op->getName().getStringRef(); |
409 | if (opName == "test.insert_same_op" || opName == "test.change_block_op" || |
410 | opName == "test.replace_with_new_op" || opName == "test.erase_op" || |
411 | opName == "test.move_before_parent_op" || |
412 | opName == "test.inline_blocks_into_parent" || |
413 | opName == "test.split_block_here" || opName == "test.clone_me" || |
414 | opName == "test.clone_region_before" ) { |
415 | ops.push_back(Elt: op); |
416 | } |
417 | }); |
418 | |
419 | DumpNotifications dumpNotifications; |
420 | GreedyRewriteConfig config; |
421 | config.listener = &dumpNotifications; |
422 | if (strictMode == "AnyOp" ) { |
423 | config.strictMode = GreedyRewriteStrictness::AnyOp; |
424 | } else if (strictMode == "ExistingAndNewOps" ) { |
425 | config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; |
426 | } else if (strictMode == "ExistingOps" ) { |
427 | config.strictMode = GreedyRewriteStrictness::ExistingOps; |
428 | } else { |
429 | llvm_unreachable("invalid strictness option" ); |
430 | } |
431 | |
432 | // Check if these transformations introduce visiting of operations that |
433 | // are not in the `ops` set (The new created ops are valid). An invalid |
434 | // operation will trigger the assertion while processing. |
435 | bool changed = false; |
436 | bool allErased = false; |
437 | (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, |
438 | &changed, &allErased); |
439 | Builder b(ctx); |
440 | getOperation()->setAttr("pattern_driver_changed" , b.getBoolAttr(value: changed)); |
441 | getOperation()->setAttr("pattern_driver_all_erased" , |
442 | b.getBoolAttr(value: allErased)); |
443 | } |
444 | |
445 | Option<std::string> strictMode{ |
446 | *this, "strictness" , |
447 | llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}" ), |
448 | llvm::cl::init("AnyOp" )}; |
449 | |
450 | private: |
451 | // New inserted operation is valid for further transformation. |
452 | class InsertSameOp : public RewritePattern { |
453 | public: |
454 | InsertSameOp(MLIRContext *context) |
455 | : RewritePattern("test.insert_same_op" , /*benefit=*/1, context) {} |
456 | |
457 | LogicalResult matchAndRewrite(Operation *op, |
458 | PatternRewriter &rewriter) const override { |
459 | if (op->hasAttr(name: "skip" )) |
460 | return failure(); |
461 | |
462 | Operation *newOp = |
463 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), |
464 | op->getOperands(), op->getResultTypes()); |
465 | rewriter.modifyOpInPlace( |
466 | root: op, callable: [&]() { op->setAttr(name: "skip" , value: rewriter.getBoolAttr(value: true)); }); |
467 | newOp->setAttr(name: "skip" , value: rewriter.getBoolAttr(value: true)); |
468 | |
469 | return success(); |
470 | } |
471 | }; |
472 | |
473 | // Replace an operation may introduce the re-visiting of its users. |
474 | class ReplaceWithNewOp : public RewritePattern { |
475 | public: |
476 | ReplaceWithNewOp(MLIRContext *context) |
477 | : RewritePattern("test.replace_with_new_op" , /*benefit=*/1, context) {} |
478 | |
479 | LogicalResult matchAndRewrite(Operation *op, |
480 | PatternRewriter &rewriter) const override { |
481 | Operation *newOp; |
482 | if (op->hasAttr(name: "create_erase_op" )) { |
483 | newOp = rewriter.create( |
484 | op->getLoc(), |
485 | OperationName("test.erase_op" , op->getContext()).getIdentifier(), |
486 | ValueRange(), TypeRange()); |
487 | } else { |
488 | newOp = rewriter.create( |
489 | op->getLoc(), |
490 | OperationName("test.new_op" , op->getContext()).getIdentifier(), |
491 | op->getOperands(), op->getResultTypes()); |
492 | } |
493 | // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". |
494 | // A "notifyOperationReplaced" callback is triggered in either case. |
495 | rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults()); |
496 | rewriter.eraseOp(op); |
497 | return success(); |
498 | } |
499 | }; |
500 | |
501 | // Remove an operation may introduce the re-visiting of its operands. |
502 | class EraseOp : public RewritePattern { |
503 | public: |
504 | EraseOp(MLIRContext *context) |
505 | : RewritePattern("test.erase_op" , /*benefit=*/1, context) {} |
506 | LogicalResult matchAndRewrite(Operation *op, |
507 | PatternRewriter &rewriter) const override { |
508 | rewriter.eraseOp(op); |
509 | return success(); |
510 | } |
511 | }; |
512 | |
513 | // The following two patterns test RewriterBase::replaceAllUsesWith. |
514 | // |
515 | // That function replaces all usages of a Block (or a Value) with another one |
516 | // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver |
517 | // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its |
518 | // worklist: when an op is modified, it is added to the worklist. The two |
519 | // patterns below make the tracking observable: ChangeBlockOp replaces all |
520 | // usages of a block and that pattern is applied because the corresponding ops |
521 | // are put on the initial worklist (see above). ImplicitChangeOp does an |
522 | // unrelated change but ops of the corresponding type are *not* on the initial |
523 | // worklist, so the effect of the second pattern is only visible if the |
524 | // tracking and subsequent adding to the worklist actually works. |
525 | |
526 | // Replace all usages of the first successor with the second successor. |
527 | class ChangeBlockOp : public RewritePattern { |
528 | public: |
529 | ChangeBlockOp(MLIRContext *context) |
530 | : RewritePattern("test.change_block_op" , /*benefit=*/1, context) {} |
531 | LogicalResult matchAndRewrite(Operation *op, |
532 | PatternRewriter &rewriter) const override { |
533 | if (op->getNumSuccessors() < 2) |
534 | return failure(); |
535 | Block *firstSuccessor = op->getSuccessor(index: 0); |
536 | Block *secondSuccessor = op->getSuccessor(index: 1); |
537 | if (firstSuccessor == secondSuccessor) |
538 | return failure(); |
539 | // This is the function being tested: |
540 | rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor); |
541 | // Using the following line instead would make the test fail: |
542 | // firstSuccessor->replaceAllUsesWith(secondSuccessor); |
543 | return success(); |
544 | } |
545 | }; |
546 | |
547 | // Changes the successor to the parent block. |
548 | class ImplicitChangeOp : public RewritePattern { |
549 | public: |
550 | ImplicitChangeOp(MLIRContext *context) |
551 | : RewritePattern("test.implicit_change_op" , /*benefit=*/1, context) {} |
552 | LogicalResult matchAndRewrite(Operation *op, |
553 | PatternRewriter &rewriter) const override { |
554 | if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock()) |
555 | return failure(); |
556 | rewriter.modifyOpInPlace(root: op, |
557 | callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); }); |
558 | return success(); |
559 | } |
560 | }; |
561 | }; |
562 | |
563 | } // namespace |
564 | |
565 | //===----------------------------------------------------------------------===// |
566 | // ReturnType Driver. |
567 | //===----------------------------------------------------------------------===// |
568 | |
569 | namespace { |
570 | // Generate ops for each instance where the type can be successfully inferred. |
571 | template <typename OpTy> |
572 | static void invokeCreateWithInferredReturnType(Operation *op) { |
573 | auto *context = op->getContext(); |
574 | auto fop = op->getParentOfType<func::FuncOp>(); |
575 | auto location = UnknownLoc::get(context); |
576 | OpBuilder b(op); |
577 | b.setInsertionPointAfter(op); |
578 | |
579 | // Use permutations of 2 args as operands. |
580 | assert(fop.getNumArguments() >= 2); |
581 | for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { |
582 | for (int j = 0; j < e; ++j) { |
583 | std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; |
584 | SmallVector<Type, 2> inferredReturnTypes; |
585 | if (succeeded(OpTy::inferReturnTypes( |
586 | context, std::nullopt, values, op->getDiscardableAttrDictionary(), |
587 | op->getPropertiesStorage(), op->getRegions(), |
588 | inferredReturnTypes))) { |
589 | OperationState state(location, OpTy::getOperationName()); |
590 | // TODO: Expand to regions. |
591 | OpTy::build(b, state, values, op->getAttrs()); |
592 | (void)b.create(state); |
593 | } |
594 | } |
595 | } |
596 | } |
597 | |
598 | static void reifyReturnShape(Operation *op) { |
599 | OpBuilder b(op); |
600 | |
601 | // Use permutations of 2 args as operands. |
602 | auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); |
603 | SmallVector<Value, 2> shapes; |
604 | if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || |
605 | !llvm::hasSingleElement(C&: shapes)) |
606 | return; |
607 | for (const auto &it : llvm::enumerate(First&: shapes)) { |
608 | op->emitRemark() << "value " << it.index() << ": " |
609 | << it.value().getDefiningOp(); |
610 | } |
611 | } |
612 | |
613 | struct TestReturnTypeDriver |
614 | : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { |
615 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) |
616 | |
617 | void getDependentDialects(DialectRegistry ®istry) const override { |
618 | registry.insert<tensor::TensorDialect>(); |
619 | } |
620 | StringRef getArgument() const final { return "test-return-type" ; } |
621 | StringRef getDescription() const final { return "Run return type functions" ; } |
622 | |
623 | void runOnOperation() override { |
624 | if (getOperation().getName() == "testCreateFunctions" ) { |
625 | std::vector<Operation *> ops; |
626 | // Collect ops to avoid triggering on inserted ops. |
627 | for (auto &op : getOperation().getBody().front()) |
628 | ops.push_back(&op); |
629 | // Generate test patterns for each, but skip terminator. |
630 | for (auto *op : llvm::ArrayRef(ops).drop_back()) { |
631 | // Test create method of each of the Op classes below. The resultant |
632 | // output would be in reverse order underneath `op` from which |
633 | // the attributes and regions are used. |
634 | invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); |
635 | invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( |
636 | op); |
637 | invokeCreateWithInferredReturnType< |
638 | OpWithShapedTypeInferTypeInterfaceOp>(op); |
639 | }; |
640 | return; |
641 | } |
642 | if (getOperation().getName() == "testReifyFunctions" ) { |
643 | std::vector<Operation *> ops; |
644 | // Collect ops to avoid triggering on inserted ops. |
645 | for (auto &op : getOperation().getBody().front()) |
646 | if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) |
647 | ops.push_back(&op); |
648 | // Generate test patterns for each, but skip terminator. |
649 | for (auto *op : ops) |
650 | reifyReturnShape(op); |
651 | } |
652 | } |
653 | }; |
654 | } // namespace |
655 | |
656 | namespace { |
657 | struct TestDerivedAttributeDriver |
658 | : public PassWrapper<TestDerivedAttributeDriver, |
659 | OperationPass<func::FuncOp>> { |
660 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) |
661 | |
662 | StringRef getArgument() const final { return "test-derived-attr" ; } |
663 | StringRef getDescription() const final { |
664 | return "Run test derived attributes" ; |
665 | } |
666 | void runOnOperation() override; |
667 | }; |
668 | } // namespace |
669 | |
670 | void TestDerivedAttributeDriver::runOnOperation() { |
671 | getOperation().walk([](DerivedAttributeOpInterface dOp) { |
672 | auto dAttr = dOp.materializeDerivedAttributes(); |
673 | if (!dAttr) |
674 | return; |
675 | for (auto d : dAttr) |
676 | dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); |
677 | }); |
678 | } |
679 | |
680 | //===----------------------------------------------------------------------===// |
681 | // Legalization Driver. |
682 | //===----------------------------------------------------------------------===// |
683 | |
684 | namespace { |
685 | //===----------------------------------------------------------------------===// |
686 | // Region-Block Rewrite Testing |
687 | |
688 | /// This pattern is a simple pattern that inlines the first region of a given |
689 | /// operation into the parent region. |
690 | struct TestRegionRewriteBlockMovement : public ConversionPattern { |
691 | TestRegionRewriteBlockMovement(MLIRContext *ctx) |
692 | : ConversionPattern("test.region" , 1, ctx) {} |
693 | |
694 | LogicalResult |
695 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
696 | ConversionPatternRewriter &rewriter) const final { |
697 | // Inline this region into the parent region. |
698 | auto &parentRegion = *op->getParentRegion(); |
699 | auto &opRegion = op->getRegion(index: 0); |
700 | if (op->getDiscardableAttr(name: "legalizer.should_clone" )) |
701 | rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
702 | else |
703 | rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
704 | |
705 | if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks" )) { |
706 | while (!opRegion.empty()) |
707 | rewriter.eraseBlock(block: &opRegion.front()); |
708 | } |
709 | |
710 | // Drop this operation. |
711 | rewriter.eraseOp(op); |
712 | return success(); |
713 | } |
714 | }; |
715 | /// This pattern is a simple pattern that generates a region containing an |
716 | /// illegal operation. |
717 | struct TestRegionRewriteUndo : public RewritePattern { |
718 | TestRegionRewriteUndo(MLIRContext *ctx) |
719 | : RewritePattern("test.region_builder" , 1, ctx) {} |
720 | |
721 | LogicalResult matchAndRewrite(Operation *op, |
722 | PatternRewriter &rewriter) const final { |
723 | // Create the region operation with an entry block containing arguments. |
724 | OperationState newRegion(op->getLoc(), "test.region" ); |
725 | newRegion.addRegion(); |
726 | auto *regionOp = rewriter.create(state: newRegion); |
727 | auto *entryBlock = rewriter.createBlock(parent: ®ionOp->getRegion(index: 0)); |
728 | entryBlock->addArgument(rewriter.getIntegerType(64), |
729 | rewriter.getUnknownLoc()); |
730 | |
731 | // Add an explicitly illegal operation to ensure the conversion fails. |
732 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); |
733 | rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); |
734 | |
735 | // Drop this operation. |
736 | rewriter.eraseOp(op); |
737 | return success(); |
738 | } |
739 | }; |
740 | /// A simple pattern that creates a block at the end of the parent region of the |
741 | /// matched operation. |
742 | struct TestCreateBlock : public RewritePattern { |
743 | TestCreateBlock(MLIRContext *ctx) |
744 | : RewritePattern("test.create_block" , /*benefit=*/1, ctx) {} |
745 | |
746 | LogicalResult matchAndRewrite(Operation *op, |
747 | PatternRewriter &rewriter) const final { |
748 | Region ®ion = *op->getParentRegion(); |
749 | Type i32Type = rewriter.getIntegerType(32); |
750 | Location loc = op->getLoc(); |
751 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
752 | rewriter.create<TerminatorOp>(loc); |
753 | rewriter.eraseOp(op); |
754 | return success(); |
755 | } |
756 | }; |
757 | |
758 | /// A simple pattern that creates a block containing an invalid operation in |
759 | /// order to trigger the block creation undo mechanism. |
760 | struct TestCreateIllegalBlock : public RewritePattern { |
761 | TestCreateIllegalBlock(MLIRContext *ctx) |
762 | : RewritePattern("test.create_illegal_block" , /*benefit=*/1, ctx) {} |
763 | |
764 | LogicalResult matchAndRewrite(Operation *op, |
765 | PatternRewriter &rewriter) const final { |
766 | Region ®ion = *op->getParentRegion(); |
767 | Type i32Type = rewriter.getIntegerType(32); |
768 | Location loc = op->getLoc(); |
769 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
770 | // Create an illegal op to ensure the conversion fails. |
771 | rewriter.create<ILLegalOpF>(loc, i32Type); |
772 | rewriter.create<TerminatorOp>(loc); |
773 | rewriter.eraseOp(op); |
774 | return success(); |
775 | } |
776 | }; |
777 | |
778 | /// A simple pattern that tests the undo mechanism when replacing the uses of a |
779 | /// block argument. |
780 | struct TestUndoBlockArgReplace : public ConversionPattern { |
781 | TestUndoBlockArgReplace(MLIRContext *ctx) |
782 | : ConversionPattern("test.undo_block_arg_replace" , /*benefit=*/1, ctx) {} |
783 | |
784 | LogicalResult |
785 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
786 | ConversionPatternRewriter &rewriter) const final { |
787 | auto illegalOp = |
788 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
789 | rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0), |
790 | to: illegalOp->getResult(0)); |
791 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
792 | return success(); |
793 | } |
794 | }; |
795 | |
796 | /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. |
797 | /// This is to test the rollback logic. |
798 | struct TestUndoMoveOpBefore : public ConversionPattern { |
799 | TestUndoMoveOpBefore(MLIRContext *ctx) |
800 | : ConversionPattern("test.hoist_me" , /*benefit=*/1, ctx) {} |
801 | |
802 | LogicalResult |
803 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
804 | ConversionPatternRewriter &rewriter) const override { |
805 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
806 | // Replace with an illegal op to ensure the conversion fails. |
807 | rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); |
808 | return success(); |
809 | } |
810 | }; |
811 | |
812 | /// A rewrite pattern that tests the undo mechanism when erasing a block. |
813 | struct TestUndoBlockErase : public ConversionPattern { |
814 | TestUndoBlockErase(MLIRContext *ctx) |
815 | : ConversionPattern("test.undo_block_erase" , /*benefit=*/1, ctx) {} |
816 | |
817 | LogicalResult |
818 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
819 | ConversionPatternRewriter &rewriter) const final { |
820 | Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin()); |
821 | rewriter.setInsertionPointToStart(secondBlock); |
822 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
823 | rewriter.eraseBlock(block: secondBlock); |
824 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
825 | return success(); |
826 | } |
827 | }; |
828 | |
829 | /// A pattern that modifies a property in-place, but keeps the op illegal. |
830 | struct TestUndoPropertiesModification : public ConversionPattern { |
831 | TestUndoPropertiesModification(MLIRContext *ctx) |
832 | : ConversionPattern("test.with_properties" , /*benefit=*/1, ctx) {} |
833 | LogicalResult |
834 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
835 | ConversionPatternRewriter &rewriter) const final { |
836 | if (!op->hasAttr(name: "modify_inplace" )) |
837 | return failure(); |
838 | rewriter.modifyOpInPlace( |
839 | root: op, callable: [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); |
840 | return success(); |
841 | } |
842 | }; |
843 | |
844 | //===----------------------------------------------------------------------===// |
845 | // Type-Conversion Rewrite Testing |
846 | |
847 | /// This patterns erases a region operation that has had a type conversion. |
848 | struct TestDropOpSignatureConversion : public ConversionPattern { |
849 | TestDropOpSignatureConversion(MLIRContext *ctx, |
850 | const TypeConverter &converter) |
851 | : ConversionPattern(converter, "test.drop_region_op" , 1, ctx) {} |
852 | LogicalResult |
853 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
854 | ConversionPatternRewriter &rewriter) const override { |
855 | Region ®ion = op->getRegion(index: 0); |
856 | Block *entry = ®ion.front(); |
857 | |
858 | // Convert the original entry arguments. |
859 | const TypeConverter &converter = *getTypeConverter(); |
860 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
861 | if (failed(result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), |
862 | result)) || |
863 | failed(result: rewriter.convertRegionTypes(region: ®ion, converter, entryConversion: &result))) |
864 | return failure(); |
865 | |
866 | // Convert the region signature and just drop the operation. |
867 | rewriter.eraseOp(op); |
868 | return success(); |
869 | } |
870 | }; |
871 | /// This pattern simply updates the operands of the given operation. |
872 | struct TestPassthroughInvalidOp : public ConversionPattern { |
873 | TestPassthroughInvalidOp(MLIRContext *ctx) |
874 | : ConversionPattern("test.invalid" , 1, ctx) {} |
875 | LogicalResult |
876 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
877 | ConversionPatternRewriter &rewriter) const final { |
878 | rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands, |
879 | std::nullopt); |
880 | return success(); |
881 | } |
882 | }; |
883 | /// This pattern handles the case of a split return value. |
884 | struct TestSplitReturnType : public ConversionPattern { |
885 | TestSplitReturnType(MLIRContext *ctx) |
886 | : ConversionPattern("test.return" , 1, ctx) {} |
887 | LogicalResult |
888 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
889 | ConversionPatternRewriter &rewriter) const final { |
890 | // Check for a return of F32. |
891 | if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32()) |
892 | return failure(); |
893 | |
894 | // Check if the first operation is a cast operation, if it is we use the |
895 | // results directly. |
896 | auto *defOp = operands[0].getDefiningOp(); |
897 | if (auto packerOp = |
898 | llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) { |
899 | rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands()); |
900 | return success(); |
901 | } |
902 | |
903 | // Otherwise, fail to match. |
904 | return failure(); |
905 | } |
906 | }; |
907 | |
908 | //===----------------------------------------------------------------------===// |
909 | // Multi-Level Type-Conversion Rewrite Testing |
910 | struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { |
911 | TestChangeProducerTypeI32ToF32(MLIRContext *ctx) |
912 | : ConversionPattern("test.type_producer" , 1, ctx) {} |
913 | LogicalResult |
914 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
915 | ConversionPatternRewriter &rewriter) const final { |
916 | // If the type is I32, change the type to F32. |
917 | if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32)) |
918 | return failure(); |
919 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); |
920 | return success(); |
921 | } |
922 | }; |
923 | struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { |
924 | TestChangeProducerTypeF32ToF64(MLIRContext *ctx) |
925 | : ConversionPattern("test.type_producer" , 1, ctx) {} |
926 | LogicalResult |
927 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
928 | ConversionPatternRewriter &rewriter) const final { |
929 | // If the type is F32, change the type to F64. |
930 | if (!Type(*op->result_type_begin()).isF32()) |
931 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand" ); |
932 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); |
933 | return success(); |
934 | } |
935 | }; |
936 | struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { |
937 | TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) |
938 | : ConversionPattern("test.type_producer" , 10, ctx) {} |
939 | LogicalResult |
940 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
941 | ConversionPatternRewriter &rewriter) const final { |
942 | // Always convert to B16, even though it is not a legal type. This tests |
943 | // that values are unmapped correctly. |
944 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); |
945 | return success(); |
946 | } |
947 | }; |
948 | struct TestUpdateConsumerType : public ConversionPattern { |
949 | TestUpdateConsumerType(MLIRContext *ctx) |
950 | : ConversionPattern("test.type_consumer" , 1, ctx) {} |
951 | LogicalResult |
952 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
953 | ConversionPatternRewriter &rewriter) const final { |
954 | // Verify that the incoming operand has been successfully remapped to F64. |
955 | if (!operands[0].getType().isF64()) |
956 | return failure(); |
957 | rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); |
958 | return success(); |
959 | } |
960 | }; |
961 | |
962 | //===----------------------------------------------------------------------===// |
963 | // Non-Root Replacement Rewrite Testing |
964 | /// This pattern generates an invalid operation, but replaces it before the |
965 | /// pattern is finished. This checks that we don't need to legalize the |
966 | /// temporary op. |
967 | struct TestNonRootReplacement : public RewritePattern { |
968 | TestNonRootReplacement(MLIRContext *ctx) |
969 | : RewritePattern("test.replace_non_root" , 1, ctx) {} |
970 | |
971 | LogicalResult matchAndRewrite(Operation *op, |
972 | PatternRewriter &rewriter) const final { |
973 | auto resultType = *op->result_type_begin(); |
974 | auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); |
975 | auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); |
976 | |
977 | rewriter.replaceOp(illegalOp, legalOp); |
978 | rewriter.replaceOp(op, illegalOp); |
979 | return success(); |
980 | } |
981 | }; |
982 | |
983 | //===----------------------------------------------------------------------===// |
984 | // Recursive Rewrite Testing |
985 | /// This pattern is applied to the same operation multiple times, but has a |
986 | /// bounded recursion. |
987 | struct TestBoundedRecursiveRewrite |
988 | : public OpRewritePattern<TestRecursiveRewriteOp> { |
989 | using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; |
990 | |
991 | void initialize() { |
992 | // The conversion target handles bounding the recursion of this pattern. |
993 | setHasBoundedRewriteRecursion(); |
994 | } |
995 | |
996 | LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, |
997 | PatternRewriter &rewriter) const final { |
998 | // Decrement the depth of the op in-place. |
999 | rewriter.modifyOpInPlace(op, [&] { |
1000 | op->setAttr("depth" , rewriter.getI64IntegerAttr(value: op.getDepth() - 1)); |
1001 | }); |
1002 | return success(); |
1003 | } |
1004 | }; |
1005 | |
1006 | struct TestNestedOpCreationUndoRewrite |
1007 | : public OpRewritePattern<IllegalOpWithRegionAnchor> { |
1008 | using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; |
1009 | |
1010 | LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, |
1011 | PatternRewriter &rewriter) const final { |
1012 | // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
1013 | rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
1014 | return success(); |
1015 | }; |
1016 | }; |
1017 | |
1018 | // This pattern matches `test.blackhole` and delete this op and its producer. |
1019 | struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { |
1020 | using OpRewritePattern<BlackHoleOp>::OpRewritePattern; |
1021 | |
1022 | LogicalResult matchAndRewrite(BlackHoleOp op, |
1023 | PatternRewriter &rewriter) const final { |
1024 | Operation *producer = op.getOperand().getDefiningOp(); |
1025 | // Always erase the user before the producer, the framework should handle |
1026 | // this correctly. |
1027 | rewriter.eraseOp(op: op); |
1028 | rewriter.eraseOp(op: producer); |
1029 | return success(); |
1030 | }; |
1031 | }; |
1032 | |
1033 | // This pattern replaces explicitly illegal op with explicitly legal op, |
1034 | // but in addition creates unregistered operation. |
1035 | struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { |
1036 | using OpRewritePattern<ILLegalOpG>::OpRewritePattern; |
1037 | |
1038 | LogicalResult matchAndRewrite(ILLegalOpG op, |
1039 | PatternRewriter &rewriter) const final { |
1040 | IntegerAttr attr = rewriter.getI32IntegerAttr(0); |
1041 | Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); |
1042 | rewriter.replaceOpWithNewOp<LegalOpC>(op, val); |
1043 | return success(); |
1044 | }; |
1045 | }; |
1046 | } // namespace |
1047 | |
1048 | namespace { |
1049 | struct TestTypeConverter : public TypeConverter { |
1050 | using TypeConverter::TypeConverter; |
1051 | TestTypeConverter() { |
1052 | addConversion(callback&: convertType); |
1053 | addArgumentMaterialization(callback&: materializeCast); |
1054 | addSourceMaterialization(callback&: materializeCast); |
1055 | } |
1056 | |
1057 | static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { |
1058 | // Drop I16 types. |
1059 | if (t.isSignlessInteger(width: 16)) |
1060 | return success(); |
1061 | |
1062 | // Convert I64 to F64. |
1063 | if (t.isSignlessInteger(width: 64)) { |
1064 | results.push_back(Elt: FloatType::getF64(ctx: t.getContext())); |
1065 | return success(); |
1066 | } |
1067 | |
1068 | // Convert I42 to I43. |
1069 | if (t.isInteger(width: 42)) { |
1070 | results.push_back(IntegerType::get(t.getContext(), 43)); |
1071 | return success(); |
1072 | } |
1073 | |
1074 | // Split F32 into F16,F16. |
1075 | if (t.isF32()) { |
1076 | results.assign(NumElts: 2, Elt: FloatType::getF16(ctx: t.getContext())); |
1077 | return success(); |
1078 | } |
1079 | |
1080 | // Otherwise, convert the type directly. |
1081 | results.push_back(Elt: t); |
1082 | return success(); |
1083 | } |
1084 | |
1085 | /// Hook for materializing a conversion. This is necessary because we generate |
1086 | /// 1->N type mappings. |
1087 | static std::optional<Value> materializeCast(OpBuilder &builder, |
1088 | Type resultType, |
1089 | ValueRange inputs, Location loc) { |
1090 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1091 | } |
1092 | }; |
1093 | |
1094 | struct TestLegalizePatternDriver |
1095 | : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { |
1096 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) |
1097 | |
1098 | StringRef getArgument() const final { return "test-legalize-patterns" ; } |
1099 | StringRef getDescription() const final { |
1100 | return "Run test dialect legalization patterns" ; |
1101 | } |
1102 | /// The mode of conversion to use with the driver. |
1103 | enum class ConversionMode { Analysis, Full, Partial }; |
1104 | |
1105 | TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} |
1106 | |
1107 | void getDependentDialects(DialectRegistry ®istry) const override { |
1108 | registry.insert<func::FuncDialect, test::TestDialect>(); |
1109 | } |
1110 | |
1111 | void runOnOperation() override { |
1112 | TestTypeConverter converter; |
1113 | mlir::RewritePatternSet patterns(&getContext()); |
1114 | populateWithGenerated(patterns); |
1115 | patterns |
1116 | .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, |
1117 | TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, |
1118 | TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, |
1119 | TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, |
1120 | TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, |
1121 | TestNonRootReplacement, TestBoundedRecursiveRewrite, |
1122 | TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, |
1123 | TestCreateUnregisteredOp, TestUndoMoveOpBefore, |
1124 | TestUndoPropertiesModification>(arg: &getContext()); |
1125 | patterns.add<TestDropOpSignatureConversion>(arg: &getContext(), args&: converter); |
1126 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
1127 | converter); |
1128 | mlir::populateCallOpTypeConversionPattern(patterns, converter); |
1129 | |
1130 | // Define the conversion target used for the test. |
1131 | ConversionTarget target(getContext()); |
1132 | target.addLegalOp<ModuleOp>(); |
1133 | target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, |
1134 | TerminatorOp, OneRegionOp>(); |
1135 | target |
1136 | .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); |
1137 | target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { |
1138 | // Don't allow F32 operands. |
1139 | return llvm::none_of(op.getOperandTypes(), |
1140 | [](Type type) { return type.isF32(); }); |
1141 | }); |
1142 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1143 | return converter.isSignatureLegal(op.getFunctionType()) && |
1144 | converter.isLegal(&op.getBody()); |
1145 | }); |
1146 | target.addDynamicallyLegalOp<func::CallOp>( |
1147 | [&](func::CallOp op) { return converter.isLegal(op); }); |
1148 | |
1149 | // TestCreateUnregisteredOp creates `arith.constant` operation, |
1150 | // which was not added to target intentionally to test |
1151 | // correct error code from conversion driver. |
1152 | target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); |
1153 | |
1154 | // Expect the type_producer/type_consumer operations to only operate on f64. |
1155 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
1156 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
1157 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
1158 | return op.getOperand().getType().isF64(); |
1159 | }); |
1160 | |
1161 | // Check support for marking certain operations as recursively legal. |
1162 | target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { |
1163 | return static_cast<bool>( |
1164 | op->getAttrOfType<UnitAttr>("test.recursively_legal" )); |
1165 | }); |
1166 | |
1167 | // Mark the bound recursion operation as dynamically legal. |
1168 | target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( |
1169 | [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); |
1170 | |
1171 | // Create a dynamically legal rule that can only be legalized by folding it. |
1172 | target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( |
1173 | [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); |
1174 | |
1175 | // Handle a partial conversion. |
1176 | if (mode == ConversionMode::Partial) { |
1177 | DenseSet<Operation *> unlegalizedOps; |
1178 | ConversionConfig config; |
1179 | DumpNotifications dumpNotifications; |
1180 | config.listener = &dumpNotifications; |
1181 | config.unlegalizedOps = &unlegalizedOps; |
1182 | if (failed(applyPartialConversion(getOperation(), target, |
1183 | std::move(patterns), config))) { |
1184 | getOperation()->emitRemark() << "applyPartialConversion failed" ; |
1185 | } |
1186 | // Emit remarks for each legalizable operation. |
1187 | for (auto *op : unlegalizedOps) |
1188 | op->emitRemark() << "op '" << op->getName() << "' is not legalizable" ; |
1189 | return; |
1190 | } |
1191 | |
1192 | // Handle a full conversion. |
1193 | if (mode == ConversionMode::Full) { |
1194 | // Check support for marking unknown operations as dynamically legal. |
1195 | target.markUnknownOpDynamicallyLegal([](Operation *op) { |
1196 | return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal" ); |
1197 | }); |
1198 | |
1199 | ConversionConfig config; |
1200 | DumpNotifications dumpNotifications; |
1201 | config.listener = &dumpNotifications; |
1202 | if (failed(applyFullConversion(getOperation(), target, |
1203 | std::move(patterns), config))) { |
1204 | getOperation()->emitRemark() << "applyFullConversion failed" ; |
1205 | } |
1206 | return; |
1207 | } |
1208 | |
1209 | // Otherwise, handle an analysis conversion. |
1210 | assert(mode == ConversionMode::Analysis); |
1211 | |
1212 | // Analyze the convertible operations. |
1213 | DenseSet<Operation *> legalizedOps; |
1214 | ConversionConfig config; |
1215 | config.legalizableOps = &legalizedOps; |
1216 | if (failed(applyAnalysisConversion(getOperation(), target, |
1217 | std::move(patterns), config))) |
1218 | return signalPassFailure(); |
1219 | |
1220 | // Emit remarks for each legalizable operation. |
1221 | for (auto *op : legalizedOps) |
1222 | op->emitRemark() << "op '" << op->getName() << "' is legalizable" ; |
1223 | } |
1224 | |
1225 | /// The mode of conversion to use. |
1226 | ConversionMode mode; |
1227 | }; |
1228 | } // namespace |
1229 | |
1230 | static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> |
1231 | legalizerConversionMode( |
1232 | "test-legalize-mode" , |
1233 | llvm::cl::desc("The legalization mode to use with the test driver" ), |
1234 | llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial), |
1235 | llvm::cl::values( |
1236 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, |
1237 | "analysis" , "Perform an analysis conversion" ), |
1238 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full" , |
1239 | "Perform a full conversion" ), |
1240 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, |
1241 | "partial" , "Perform a partial conversion" ))); |
1242 | |
1243 | //===----------------------------------------------------------------------===// |
1244 | // ConversionPatternRewriter::getRemappedValue testing. This method is used |
1245 | // to get the remapped value of an original value that was replaced using |
1246 | // ConversionPatternRewriter. |
1247 | namespace { |
1248 | struct TestRemapValueTypeConverter : public TypeConverter { |
1249 | using TypeConverter::TypeConverter; |
1250 | |
1251 | TestRemapValueTypeConverter() { |
1252 | addConversion( |
1253 | callback: [](Float32Type type) { return Float64Type::get(type.getContext()); }); |
1254 | addConversion(callback: [](Type type) { return type; }); |
1255 | } |
1256 | }; |
1257 | |
1258 | /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with |
1259 | /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original |
1260 | /// operand twice. |
1261 | /// |
1262 | /// Example: |
1263 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0) |
1264 | /// is replaced with: |
1265 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) |
1266 | struct OneVResOneVOperandOp1Converter |
1267 | : public OpConversionPattern<OneVResOneVOperandOp1> { |
1268 | using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; |
1269 | |
1270 | LogicalResult |
1271 | matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, |
1272 | ConversionPatternRewriter &rewriter) const override { |
1273 | auto origOps = op.getOperands(); |
1274 | assert(std::distance(origOps.begin(), origOps.end()) == 1 && |
1275 | "One operand expected" ); |
1276 | Value origOp = *origOps.begin(); |
1277 | SmallVector<Value, 2> remappedOperands; |
1278 | // Replicate the remapped original operand twice. Note that we don't used |
1279 | // the remapped 'operand' since the goal is testing 'getRemappedValue'. |
1280 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
1281 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
1282 | |
1283 | rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), |
1284 | remappedOperands); |
1285 | return success(); |
1286 | } |
1287 | }; |
1288 | |
1289 | /// A rewriter pattern that tests that blocks can be merged. |
1290 | struct TestRemapValueInRegion |
1291 | : public OpConversionPattern<TestRemappedValueRegionOp> { |
1292 | using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; |
1293 | |
1294 | LogicalResult |
1295 | matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, |
1296 | ConversionPatternRewriter &rewriter) const final { |
1297 | Block &block = op.getBody().front(); |
1298 | Operation *terminator = block.getTerminator(); |
1299 | |
1300 | // Merge the block into the parent region. |
1301 | Block *parentBlock = op->getBlock(); |
1302 | Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator()); |
1303 | rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange()); |
1304 | rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange()); |
1305 | |
1306 | // Replace the results of this operation with the remapped terminator |
1307 | // values. |
1308 | SmallVector<Value> terminatorOperands; |
1309 | if (failed(result: rewriter.getRemappedValues(keys: terminator->getOperands(), |
1310 | results&: terminatorOperands))) |
1311 | return failure(); |
1312 | |
1313 | rewriter.eraseOp(op: terminator); |
1314 | rewriter.replaceOp(op, terminatorOperands); |
1315 | return success(); |
1316 | } |
1317 | }; |
1318 | |
1319 | struct TestRemappedValue |
1320 | : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { |
1321 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) |
1322 | |
1323 | StringRef getArgument() const final { return "test-remapped-value" ; } |
1324 | StringRef getDescription() const final { |
1325 | return "Test public remapped value mechanism in ConversionPatternRewriter" ; |
1326 | } |
1327 | void runOnOperation() override { |
1328 | TestRemapValueTypeConverter typeConverter; |
1329 | |
1330 | mlir::RewritePatternSet patterns(&getContext()); |
1331 | patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext()); |
1332 | patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( |
1333 | arg: &getContext()); |
1334 | patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext()); |
1335 | |
1336 | mlir::ConversionTarget target(getContext()); |
1337 | target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); |
1338 | |
1339 | // Expect the type_producer/type_consumer operations to only operate on f64. |
1340 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
1341 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
1342 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
1343 | return op.getOperand().getType().isF64(); |
1344 | }); |
1345 | |
1346 | // We make OneVResOneVOperandOp1 legal only when it has more that one |
1347 | // operand. This will trigger the conversion that will replace one-operand |
1348 | // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. |
1349 | target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( |
1350 | [](Operation *op) { return op->getNumOperands() > 1; }); |
1351 | |
1352 | if (failed(mlir::applyFullConversion(getOperation(), target, |
1353 | std::move(patterns)))) { |
1354 | signalPassFailure(); |
1355 | } |
1356 | } |
1357 | }; |
1358 | } // namespace |
1359 | |
1360 | //===----------------------------------------------------------------------===// |
1361 | // Test patterns without a specific root operation kind |
1362 | //===----------------------------------------------------------------------===// |
1363 | |
1364 | namespace { |
1365 | /// This pattern matches and removes any operation in the test dialect. |
1366 | struct RemoveTestDialectOps : public RewritePattern { |
1367 | RemoveTestDialectOps(MLIRContext *context) |
1368 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
1369 | |
1370 | LogicalResult matchAndRewrite(Operation *op, |
1371 | PatternRewriter &rewriter) const override { |
1372 | if (!isa<TestDialect>(Val: op->getDialect())) |
1373 | return failure(); |
1374 | rewriter.eraseOp(op); |
1375 | return success(); |
1376 | } |
1377 | }; |
1378 | |
1379 | struct TestUnknownRootOpDriver |
1380 | : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { |
1381 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) |
1382 | |
1383 | StringRef getArgument() const final { |
1384 | return "test-legalize-unknown-root-patterns" ; |
1385 | } |
1386 | StringRef getDescription() const final { |
1387 | return "Test public remapped value mechanism in ConversionPatternRewriter" ; |
1388 | } |
1389 | void runOnOperation() override { |
1390 | mlir::RewritePatternSet patterns(&getContext()); |
1391 | patterns.add<RemoveTestDialectOps>(arg: &getContext()); |
1392 | |
1393 | mlir::ConversionTarget target(getContext()); |
1394 | target.addIllegalDialect<TestDialect>(); |
1395 | if (failed(applyPartialConversion(getOperation(), target, |
1396 | std::move(patterns)))) |
1397 | signalPassFailure(); |
1398 | } |
1399 | }; |
1400 | } // namespace |
1401 | |
1402 | //===----------------------------------------------------------------------===// |
// Test patterns that use operations and types defined at runtime
1404 | //===----------------------------------------------------------------------===// |
1405 | |
1406 | namespace { |
1407 | /// This pattern matches dynamic operations 'test.one_operand_two_results' and |
1408 | /// replace them with dynamic operations 'test.generic_dynamic_op'. |
1409 | struct RewriteDynamicOp : public RewritePattern { |
1410 | RewriteDynamicOp(MLIRContext *context) |
1411 | : RewritePattern("test.dynamic_one_operand_two_results" , /*benefit=*/1, |
1412 | context) {} |
1413 | |
1414 | LogicalResult matchAndRewrite(Operation *op, |
1415 | PatternRewriter &rewriter) const override { |
1416 | assert(op->getName().getStringRef() == |
1417 | "test.dynamic_one_operand_two_results" && |
1418 | "rewrite pattern should only match operations with the right name" ); |
1419 | |
1420 | OperationState state(op->getLoc(), "test.dynamic_generic" , |
1421 | op->getOperands(), op->getResultTypes(), |
1422 | op->getAttrs()); |
1423 | auto *newOp = rewriter.create(state); |
1424 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
1425 | return success(); |
1426 | } |
1427 | }; |
1428 | |
1429 | struct TestRewriteDynamicOpDriver |
1430 | : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { |
1431 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) |
1432 | |
1433 | void getDependentDialects(DialectRegistry ®istry) const override { |
1434 | registry.insert<TestDialect>(); |
1435 | } |
1436 | StringRef getArgument() const final { return "test-rewrite-dynamic-op" ; } |
1437 | StringRef getDescription() const final { |
1438 | return "Test rewritting on dynamic operations" ; |
1439 | } |
1440 | void runOnOperation() override { |
1441 | RewritePatternSet patterns(&getContext()); |
1442 | patterns.add<RewriteDynamicOp>(arg: &getContext()); |
1443 | |
1444 | ConversionTarget target(getContext()); |
1445 | target.addIllegalOp( |
1446 | op: OperationName("test.dynamic_one_operand_two_results" , &getContext())); |
1447 | target.addLegalOp(op: OperationName("test.dynamic_generic" , &getContext())); |
1448 | if (failed(applyPartialConversion(getOperation(), target, |
1449 | std::move(patterns)))) |
1450 | signalPassFailure(); |
1451 | } |
1452 | }; |
1453 | } // end anonymous namespace |
1454 | |
1455 | //===----------------------------------------------------------------------===// |
1456 | // Test type conversions |
1457 | //===----------------------------------------------------------------------===// |
1458 | |
1459 | namespace { |
1460 | struct TestTypeConversionProducer |
1461 | : public OpConversionPattern<TestTypeProducerOp> { |
1462 | using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; |
1463 | LogicalResult |
1464 | matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, |
1465 | ConversionPatternRewriter &rewriter) const final { |
1466 | Type resultType = op.getType(); |
1467 | Type convertedType = getTypeConverter() |
1468 | ? getTypeConverter()->convertType(resultType) |
1469 | : resultType; |
1470 | if (isa<FloatType>(Val: resultType)) |
1471 | resultType = rewriter.getF64Type(); |
1472 | else if (resultType.isInteger(width: 16)) |
1473 | resultType = rewriter.getIntegerType(64); |
1474 | else if (isa<test::TestRecursiveType>(Val: resultType) && |
1475 | convertedType != resultType) |
1476 | resultType = convertedType; |
1477 | else |
1478 | return failure(); |
1479 | |
1480 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); |
1481 | return success(); |
1482 | } |
1483 | }; |
1484 | |
1485 | /// Call signature conversion and then fail the rewrite to trigger the undo |
1486 | /// mechanism. |
1487 | struct TestSignatureConversionUndo |
1488 | : public OpConversionPattern<TestSignatureConversionUndoOp> { |
1489 | using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; |
1490 | |
1491 | LogicalResult |
1492 | matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, |
1493 | ConversionPatternRewriter &rewriter) const final { |
1494 | (void)rewriter.convertRegionTypes(region: &op->getRegion(0), converter: *getTypeConverter()); |
1495 | return failure(); |
1496 | } |
1497 | }; |
1498 | |
1499 | /// Call signature conversion without providing a type converter to handle |
1500 | /// materializations. |
1501 | struct TestTestSignatureConversionNoConverter |
1502 | : public OpConversionPattern<TestSignatureConversionNoConverterOp> { |
1503 | TestTestSignatureConversionNoConverter(const TypeConverter &converter, |
1504 | MLIRContext *context) |
1505 | : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), |
1506 | converter(converter) {} |
1507 | |
1508 | LogicalResult |
1509 | matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, |
1510 | ConversionPatternRewriter &rewriter) const final { |
1511 | Region ®ion = op->getRegion(0); |
1512 | Block *entry = ®ion.front(); |
1513 | |
1514 | // Convert the original entry arguments. |
1515 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
1516 | if (failed( |
1517 | result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result))) |
1518 | return failure(); |
1519 | rewriter.modifyOpInPlace( |
1520 | op, [&] { rewriter.applySignatureConversion(region: ®ion, conversion&: result); }); |
1521 | return success(); |
1522 | } |
1523 | |
1524 | const TypeConverter &converter; |
1525 | }; |
1526 | |
1527 | /// Just forward the operands to the root op. This is essentially a no-op |
1528 | /// pattern that is used to trigger target materialization. |
1529 | struct TestTypeConsumerForward |
1530 | : public OpConversionPattern<TestTypeConsumerOp> { |
1531 | using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; |
1532 | |
1533 | LogicalResult |
1534 | matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, |
1535 | ConversionPatternRewriter &rewriter) const final { |
1536 | rewriter.modifyOpInPlace(op, |
1537 | [&] { op->setOperands(adaptor.getOperands()); }); |
1538 | return success(); |
1539 | } |
1540 | }; |
1541 | |
1542 | struct TestTypeConversionAnotherProducer |
1543 | : public OpRewritePattern<TestAnotherTypeProducerOp> { |
1544 | using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; |
1545 | |
1546 | LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, |
1547 | PatternRewriter &rewriter) const final { |
1548 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); |
1549 | return success(); |
1550 | } |
1551 | }; |
1552 | |
1553 | struct TestTypeConversionDriver |
1554 | : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { |
1555 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) |
1556 | |
1557 | void getDependentDialects(DialectRegistry ®istry) const override { |
1558 | registry.insert<TestDialect>(); |
1559 | } |
1560 | StringRef getArgument() const final { |
1561 | return "test-legalize-type-conversion" ; |
1562 | } |
1563 | StringRef getDescription() const final { |
1564 | return "Test various type conversion functionalities in DialectConversion" ; |
1565 | } |
1566 | |
1567 | void runOnOperation() override { |
1568 | // Initialize the type converter. |
1569 | SmallVector<Type, 2> conversionCallStack; |
1570 | TypeConverter converter; |
1571 | |
1572 | /// Add the legal set of type conversions. |
1573 | converter.addConversion(callback: [](Type type) -> Type { |
1574 | // Treat F64 as legal. |
1575 | if (type.isF64()) |
1576 | return type; |
1577 | // Allow converting BF16/F16/F32 to F64. |
1578 | if (type.isBF16() || type.isF16() || type.isF32()) |
1579 | return FloatType::getF64(ctx: type.getContext()); |
1580 | // Otherwise, the type is illegal. |
1581 | return nullptr; |
1582 | }); |
1583 | converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) { |
1584 | // Drop all integer types. |
1585 | return success(); |
1586 | }); |
1587 | converter.addConversion( |
1588 | // Convert a recursive self-referring type into a non-self-referring |
1589 | // type named "outer_converted_type" that contains a SimpleAType. |
1590 | callback: [&](test::TestRecursiveType type, |
1591 | SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
1592 | // If the type is already converted, return it to indicate that it is |
1593 | // legal. |
1594 | if (type.getName() == "outer_converted_type" ) { |
1595 | results.push_back(type); |
1596 | return success(); |
1597 | } |
1598 | |
1599 | conversionCallStack.push_back(type); |
1600 | auto popConversionCallStack = llvm::make_scope_exit( |
1601 | F: [&conversionCallStack]() { conversionCallStack.pop_back(); }); |
1602 | |
1603 | // If the type is on the call stack more than once (it is there at |
1604 | // least once because of the _current_ call, which is always the last |
1605 | // element on the stack), we've hit the recursive case. Just return |
1606 | // SimpleAType here to create a non-recursive type as a result. |
1607 | if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(), |
1608 | Element: type)) { |
1609 | results.push_back(test::SimpleAType::get(type.getContext())); |
1610 | return success(); |
1611 | } |
1612 | |
1613 | // Convert the body recursively. |
1614 | auto result = test::TestRecursiveType::get(ctx: type.getContext(), |
1615 | name: "outer_converted_type" ); |
1616 | if (failed(result.setBody(converter.convertType(t: type.getBody())))) |
1617 | return failure(); |
1618 | results.push_back(Elt: result); |
1619 | return success(); |
1620 | }); |
1621 | |
1622 | /// Add the legal set of type materializations. |
1623 | converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType, |
1624 | ValueRange inputs, |
1625 | Location loc) -> Value { |
1626 | // Allow casting from F64 back to F32. |
1627 | if (!resultType.isF16() && inputs.size() == 1 && |
1628 | inputs[0].getType().isF64()) |
1629 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1630 | // Allow producing an i32 or i64 from nothing. |
1631 | if ((resultType.isInteger(32) || resultType.isInteger(64)) && |
1632 | inputs.empty()) |
1633 | return builder.create<TestTypeProducerOp>(loc, resultType); |
1634 | // Allow producing an i64 from an integer. |
1635 | if (isa<IntegerType>(resultType) && inputs.size() == 1 && |
1636 | isa<IntegerType>(inputs[0].getType())) |
1637 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1638 | // Otherwise, fail. |
1639 | return nullptr; |
1640 | }); |
1641 | |
1642 | // Initialize the conversion target. |
1643 | mlir::ConversionTarget target(getContext()); |
1644 | target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { |
1645 | auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); |
1646 | return op.getType().isF64() || op.getType().isInteger(64) || |
1647 | (recursiveType && |
1648 | recursiveType.getName() == "outer_converted_type" ); |
1649 | }); |
1650 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1651 | return converter.isSignatureLegal(op.getFunctionType()) && |
1652 | converter.isLegal(&op.getBody()); |
1653 | }); |
1654 | target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { |
1655 | // Allow casts from F64 to F32. |
1656 | return (*op.operand_type_begin()).isF64() && op.getType().isF32(); |
1657 | }); |
1658 | target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( |
1659 | [&](TestSignatureConversionNoConverterOp op) { |
1660 | return converter.isLegal(op.getRegion().front().getArgumentTypes()); |
1661 | }); |
1662 | |
1663 | // Initialize the set of rewrite patterns. |
1664 | RewritePatternSet patterns(&getContext()); |
1665 | patterns.add<TestTypeConsumerForward, TestTypeConversionProducer, |
1666 | TestSignatureConversionUndo, |
1667 | TestTestSignatureConversionNoConverter>(arg&: converter, |
1668 | args: &getContext()); |
1669 | patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext()); |
1670 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
1671 | converter); |
1672 | |
1673 | if (failed(applyPartialConversion(getOperation(), target, |
1674 | std::move(patterns)))) |
1675 | signalPassFailure(); |
1676 | } |
1677 | }; |
1678 | } // namespace |
1679 | |
1680 | //===----------------------------------------------------------------------===// |
1681 | // Test Target Materialization With No Uses |
1682 | //===----------------------------------------------------------------------===// |
1683 | |
1684 | namespace { |
1685 | struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { |
1686 | using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; |
1687 | |
1688 | LogicalResult |
1689 | matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, |
1690 | ConversionPatternRewriter &rewriter) const final { |
1691 | rewriter.replaceOp(op, adaptor.getOperands()); |
1692 | return success(); |
1693 | } |
1694 | }; |
1695 | |
1696 | struct TestTargetMaterializationWithNoUses |
1697 | : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { |
1698 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
1699 | TestTargetMaterializationWithNoUses) |
1700 | |
1701 | StringRef getArgument() const final { |
1702 | return "test-target-materialization-with-no-uses" ; |
1703 | } |
1704 | StringRef getDescription() const final { |
1705 | return "Test a special case of target materialization in DialectConversion" ; |
1706 | } |
1707 | |
1708 | void runOnOperation() override { |
1709 | TypeConverter converter; |
1710 | converter.addConversion(callback: [](Type t) { return t; }); |
1711 | converter.addConversion(callback: [](IntegerType intTy) -> Type { |
1712 | if (intTy.getWidth() == 16) |
1713 | return IntegerType::get(intTy.getContext(), 64); |
1714 | return intTy; |
1715 | }); |
1716 | converter.addTargetMaterialization( |
1717 | callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { |
1718 | return builder.create<TestCastOp>(loc, type, inputs).getResult(); |
1719 | }); |
1720 | |
1721 | ConversionTarget target(getContext()); |
1722 | target.addIllegalOp<TestTypeChangerOp>(); |
1723 | |
1724 | RewritePatternSet patterns(&getContext()); |
1725 | patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext()); |
1726 | |
1727 | if (failed(applyPartialConversion(getOperation(), target, |
1728 | std::move(patterns)))) |
1729 | signalPassFailure(); |
1730 | } |
1731 | }; |
1732 | } // namespace |
1733 | |
1734 | //===----------------------------------------------------------------------===// |
1735 | // Test Block Merging |
1736 | //===----------------------------------------------------------------------===// |
1737 | |
1738 | namespace { |
1739 | /// A rewriter pattern that tests that blocks can be merged. |
1740 | struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { |
1741 | using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; |
1742 | |
1743 | LogicalResult |
1744 | matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, |
1745 | ConversionPatternRewriter &rewriter) const final { |
1746 | Block &firstBlock = op.getBody().front(); |
1747 | Operation *branchOp = firstBlock.getTerminator(); |
1748 | Block *secondBlock = &*(std::next(op.getBody().begin())); |
1749 | auto succOperands = branchOp->getOperands(); |
1750 | SmallVector<Value, 2> replacements(succOperands); |
1751 | rewriter.eraseOp(op: branchOp); |
1752 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
1753 | rewriter.modifyOpInPlace(op, [] {}); |
1754 | return success(); |
1755 | } |
1756 | }; |
1757 | |
1758 | /// A rewrite pattern to tests the undo mechanism of blocks being merged. |
1759 | struct TestUndoBlocksMerge : public ConversionPattern { |
1760 | TestUndoBlocksMerge(MLIRContext *ctx) |
1761 | : ConversionPattern("test.undo_blocks_merge" , /*benefit=*/1, ctx) {} |
1762 | LogicalResult |
1763 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1764 | ConversionPatternRewriter &rewriter) const final { |
1765 | Block &firstBlock = op->getRegion(index: 0).front(); |
1766 | Operation *branchOp = firstBlock.getTerminator(); |
1767 | Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin())); |
1768 | rewriter.setInsertionPointToStart(secondBlock); |
1769 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
1770 | auto succOperands = branchOp->getOperands(); |
1771 | SmallVector<Value, 2> replacements(succOperands); |
1772 | rewriter.eraseOp(op: branchOp); |
1773 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
1774 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
1775 | return success(); |
1776 | } |
1777 | }; |
1778 | |
1779 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
1780 | /// ops can have a single block. |
1781 | struct TestMergeSingleBlockOps |
1782 | : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { |
1783 | using OpConversionPattern< |
1784 | SingleBlockImplicitTerminatorOp>::OpConversionPattern; |
1785 | |
1786 | LogicalResult |
1787 | matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, |
1788 | ConversionPatternRewriter &rewriter) const final { |
1789 | SingleBlockImplicitTerminatorOp parentOp = |
1790 | op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
1791 | if (!parentOp) |
1792 | return failure(); |
1793 | Block &innerBlock = op.getRegion().front(); |
1794 | TerminatorOp innerTerminator = |
1795 | cast<TerminatorOp>(innerBlock.getTerminator()); |
1796 | rewriter.inlineBlockBefore(&innerBlock, op); |
1797 | rewriter.eraseOp(op: innerTerminator); |
1798 | rewriter.eraseOp(op: op); |
1799 | return success(); |
1800 | } |
1801 | }; |
1802 | |
1803 | struct TestMergeBlocksPatternDriver |
1804 | : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { |
1805 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) |
1806 | |
1807 | StringRef getArgument() const final { return "test-merge-blocks" ; } |
1808 | StringRef getDescription() const final { |
1809 | return "Test Merging operation in ConversionPatternRewriter" ; |
1810 | } |
1811 | void runOnOperation() override { |
1812 | MLIRContext *context = &getContext(); |
1813 | mlir::RewritePatternSet patterns(context); |
1814 | patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( |
1815 | arg&: context); |
1816 | ConversionTarget target(*context); |
1817 | target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, |
1818 | TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); |
1819 | target.addIllegalOp<ILLegalOpF>(); |
1820 | |
1821 | /// Expect the op to have a single block after legalization. |
1822 | target.addDynamicallyLegalOp<TestMergeBlocksOp>( |
1823 | [&](TestMergeBlocksOp op) -> bool { |
1824 | return llvm::hasSingleElement(op.getBody()); |
1825 | }); |
1826 | |
1827 | /// Only allow `test.br` within test.merge_blocks op. |
1828 | target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { |
1829 | return op->getParentOfType<TestMergeBlocksOp>(); |
1830 | }); |
1831 | |
1832 | /// Expect that all nested test.SingleBlockImplicitTerminator ops are |
1833 | /// inlined. |
1834 | target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( |
1835 | [&](SingleBlockImplicitTerminatorOp op) -> bool { |
1836 | return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
1837 | }); |
1838 | |
1839 | DenseSet<Operation *> unlegalizedOps; |
1840 | ConversionConfig config; |
1841 | config.unlegalizedOps = &unlegalizedOps; |
1842 | (void)applyPartialConversion(getOperation(), target, std::move(patterns), |
1843 | config); |
1844 | for (auto *op : unlegalizedOps) |
1845 | op->emitRemark() << "op '" << op->getName() << "' is not legalizable" ; |
1846 | } |
1847 | }; |
1848 | } // namespace |
1849 | |
1850 | //===----------------------------------------------------------------------===// |
1851 | // Test Selective Replacement |
1852 | //===----------------------------------------------------------------------===// |
1853 | |
1854 | namespace { |
1855 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
1856 | /// ops can have a single block. |
1857 | struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { |
1858 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
1859 | |
1860 | LogicalResult matchAndRewrite(TestCastOp op, |
1861 | PatternRewriter &rewriter) const final { |
1862 | if (op.getNumOperands() != 2) |
1863 | return failure(); |
1864 | OperandRange operands = op.getOperands(); |
1865 | |
1866 | // Replace non-terminator uses with the first operand. |
1867 | rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { |
1868 | return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); |
1869 | }); |
1870 | // Replace everything else with the second operand if the operation isn't |
1871 | // dead. |
1872 | rewriter.replaceOp(op, op.getOperand(1)); |
1873 | return success(); |
1874 | } |
1875 | }; |
1876 | |
1877 | struct TestSelectiveReplacementPatternDriver |
1878 | : public PassWrapper<TestSelectiveReplacementPatternDriver, |
1879 | OperationPass<>> { |
1880 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
1881 | TestSelectiveReplacementPatternDriver) |
1882 | |
1883 | StringRef getArgument() const final { |
1884 | return "test-pattern-selective-replacement" ; |
1885 | } |
1886 | StringRef getDescription() const final { |
1887 | return "Test selective replacement in the PatternRewriter" ; |
1888 | } |
1889 | void runOnOperation() override { |
1890 | MLIRContext *context = &getContext(); |
1891 | mlir::RewritePatternSet patterns(context); |
1892 | patterns.add<TestSelectiveOpReplacementPattern>(arg&: context); |
1893 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
1894 | } |
1895 | }; |
1896 | } // namespace |
1897 | |
1898 | //===----------------------------------------------------------------------===// |
1899 | // PassRegistration |
1900 | //===----------------------------------------------------------------------===// |
1901 | |
1902 | namespace mlir { |
1903 | namespace test { |
1904 | void registerPatternsTestPass() { |
1905 | PassRegistration<TestReturnTypeDriver>(); |
1906 | |
1907 | PassRegistration<TestDerivedAttributeDriver>(); |
1908 | |
1909 | PassRegistration<TestPatternDriver>(); |
1910 | PassRegistration<TestStrictPatternDriver>(); |
1911 | |
1912 | PassRegistration<TestLegalizePatternDriver>([] { |
1913 | return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode); |
1914 | }); |
1915 | |
1916 | PassRegistration<TestRemappedValue>(); |
1917 | |
1918 | PassRegistration<TestUnknownRootOpDriver>(); |
1919 | |
1920 | PassRegistration<TestTypeConversionDriver>(); |
1921 | PassRegistration<TestTargetMaterializationWithNoUses>(); |
1922 | |
1923 | PassRegistration<TestRewriteDynamicOpDriver>(); |
1924 | |
1925 | PassRegistration<TestMergeBlocksPatternDriver>(); |
1926 | PassRegistration<TestSelectiveReplacementPatternDriver>(); |
1927 | } |
1928 | } // namespace test |
1929 | } // namespace mlir |
1930 | |