1//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "TestTypes.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/FoldUtils.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#include "mlir/Transforms/WalkPatternRewriteDriver.h"
24#include "llvm/ADT/ScopeExit.h"
25#include <cstdint>
26
27using namespace mlir;
28using namespace test;
29
30// Native function for testing NativeCodeCall
31static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32 return choice.getValue() ? input1 : input2;
33}
34
35static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36 rewriter.create<OpI>(loc, input);
37}
38
39static void handleNoResultOp(PatternRewriter &rewriter,
40 OpSymbolBindingNoResult op) {
41 // Turn the no result op to a one-result op.
42 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
43 op.getOperand());
44}
45
46static bool getFirstI32Result(Operation *op, Value &value) {
47 if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32))
48 return false;
49 value = op->getResult(idx: 0);
50 return true;
51}
52
53static Value bindNativeCodeCallResult(Value value) { return value; }
54
55static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56 Value input2) {
57 return SmallVector<Value, 2>({input2, input1});
58}
59
60// Test that natives calls are only called once during rewrites.
61// OpM_Test will return Pi, increased by 1 for each subsequent calls.
62// This let us check the number of times OpM_Test was called by inspecting
63// the returned value in the MLIR output.
64static int64_t opMIncreasingValue = 314159265;
65static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66 int64_t i = opMIncreasingValue++;
67 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
68}
69
70namespace {
71#include "TestPatterns.inc"
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// Test Reduce Pattern Interface
76//===----------------------------------------------------------------------===//
77
78void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79 populateWithGenerated(patterns);
80}
81
82//===----------------------------------------------------------------------===//
83// Canonicalizer Driver.
84//===----------------------------------------------------------------------===//
85
86namespace {
87struct FoldingPattern : public RewritePattern {
88public:
89 FoldingPattern(MLIRContext *context)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context) {}
92
93 LogicalResult matchAndRewrite(Operation *op,
94 PatternRewriter &rewriter) const override {
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
99 // result.
100 Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
102 assert(result);
103 rewriter.replaceOp(op, newValues: result);
104 return success();
105 }
106};
107
108/// This pattern creates a foldable operation at the entry point of the block.
109/// This tests the situation where the operation folder will need to replace an
110/// operation with a previously created constant that does not initially
111/// dominate the operation to replace.
112struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern<TestCastOp> {
114public:
115 using OpRewritePattern<TestCastOp>::OpRewritePattern;
116
117 LogicalResult matchAndRewrite(TestCastOp op,
118 PatternRewriter &rewriter) const override {
119 if (!op->hasAttr("test_fold_before_previously_folded_op"))
120 return failure();
121 rewriter.setInsertionPointToStart(op->getBlock());
122
123 auto constOp = rewriter.create<arith::ConstantOp>(
124 op.getLoc(), rewriter.getBoolAttr(true));
125 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
126 Value(constOp));
127 return success();
128 }
129};
130
131/// This pattern matches test.op_commutative2 with the first operand being
132/// another test.op_commutative2 with a constant on the right side and fold it
133/// away by propagating it as its result. This is intend to check that patterns
134/// are applied after the commutative property moves constant to the right.
135struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern<TestCommutative2Op> {
137public:
138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139
140 LogicalResult matchAndRewrite(TestCommutative2Op op,
141 PatternRewriter &rewriter) const override {
142 auto operand =
143 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
144 if (!operand)
145 return failure();
146 Attribute constInput;
147 if (!matchPattern(operand->getOperand(1), m_Constant(bind_value: &constInput)))
148 return failure();
149 rewriter.replaceOp(op, operand->getOperand(1));
150 return success();
151 }
152};
153
154/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155/// attribute with value smaller than MaxVal, it increments the value by 1.
156template <int MaxVal>
157struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159
160 LogicalResult matchAndRewrite(AnyAttrOfOp op,
161 PatternRewriter &rewriter) const override {
162 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
163 if (!intAttr)
164 return failure();
165 int64_t val = intAttr.getInt();
166 if (val >= MaxVal)
167 return failure();
168 rewriter.modifyOpInPlace(
169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
170 return success();
171 }
172};
173
174/// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175struct MakeOpEligible : public RewritePattern {
176 MakeOpEligible(MLIRContext *context)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178
179 LogicalResult matchAndRewrite(Operation *op,
180 PatternRewriter &rewriter) const override {
181 if (op->hasAttr(name: "eligible"))
182 return failure();
183 rewriter.modifyOpInPlace(
184 root: op, callable: [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
185 return success();
186 }
187};
188
189/// This pattern hoists eligible ops out of a "test.one_region_op".
190struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192
193 LogicalResult matchAndRewrite(test::OneRegionOp op,
194 PatternRewriter &rewriter) const override {
195 Operation *terminator = op.getRegion().front().getTerminator();
196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197 if (toBeHoisted->getParentOp() != op)
198 return failure();
199 if (!toBeHoisted->hasAttr(name: "eligible"))
200 return failure();
201 rewriter.moveOpBefore(toBeHoisted, op);
202 return success();
203 }
204};
205
206/// This pattern moves "test.move_before_parent_op" before the parent op.
207struct MoveBeforeParentOp : public RewritePattern {
208 MoveBeforeParentOp(MLIRContext *context)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210
211 LogicalResult matchAndRewrite(Operation *op,
212 PatternRewriter &rewriter) const override {
213 // Do not hoist past functions.
214 if (isa<FunctionOpInterface>(Val: op->getParentOp()))
215 return failure();
216 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
217 return success();
218 }
219};
220
221/// This pattern moves "test.move_after_parent_op" after the parent op.
222struct MoveAfterParentOp : public RewritePattern {
223 MoveAfterParentOp(MLIRContext *context)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225
226 LogicalResult matchAndRewrite(Operation *op,
227 PatternRewriter &rewriter) const override {
228 // Do not hoist past functions.
229 if (isa<FunctionOpInterface>(Val: op->getParentOp()))
230 return failure();
231
232 int64_t moveForwardBy = 0;
233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
234 moveForwardBy = advanceBy.getInt();
235
236 Operation *moveAfter = op->getParentOp();
237 for (int64_t i = 0; i < moveForwardBy; ++i)
238 moveAfter = moveAfter->getNextNode();
239
240 rewriter.moveOpAfter(op, existingOp: moveAfter);
241 return success();
242 }
243};
244
245/// This pattern inlines blocks that are nested in
246/// "test.inline_blocks_into_parent" into the parent block.
247struct InlineBlocksIntoParent : public RewritePattern {
248 InlineBlocksIntoParent(MLIRContext *context)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250 context) {}
251
252 LogicalResult matchAndRewrite(Operation *op,
253 PatternRewriter &rewriter) const override {
254 bool changed = false;
255 for (Region &r : op->getRegions()) {
256 while (!r.empty()) {
257 rewriter.inlineBlockBefore(source: &r.front(), op);
258 changed = true;
259 }
260 }
261 return success(IsSuccess: changed);
262 }
263};
264
265/// This pattern splits blocks at "test.split_block_here" and replaces the op
266/// with a new op (to prevent an infinite loop of block splitting).
267struct SplitBlockHere : public RewritePattern {
268 SplitBlockHere(MLIRContext *context)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270
271 LogicalResult matchAndRewrite(Operation *op,
272 PatternRewriter &rewriter) const override {
273 rewriter.splitBlock(block: op->getBlock(), before: op->getIterator());
274 Operation *newOp = rewriter.create(
275 op->getLoc(),
276 OperationName("test.new_op", op->getContext()).getIdentifier(),
277 op->getOperands(), op->getResultTypes());
278 rewriter.replaceOp(op, newOp);
279 return success();
280 }
281};
282
283/// This pattern clones "test.clone_me" ops.
284struct CloneOp : public RewritePattern {
285 CloneOp(MLIRContext *context)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287
288 LogicalResult matchAndRewrite(Operation *op,
289 PatternRewriter &rewriter) const override {
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op->hasAttr(name: "was_cloned"))
292 return failure();
293 Operation *cloned = rewriter.clone(op&: *op);
294 cloned->setAttr("was_cloned", rewriter.getUnitAttr());
295 return success();
296 }
297};
298
299/// This pattern clones regions of "test.clone_region_before" ops before the
300/// parent block.
301struct CloneRegionBeforeOp : public RewritePattern {
302 CloneRegionBeforeOp(MLIRContext *context)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304
305 LogicalResult matchAndRewrite(Operation *op,
306 PatternRewriter &rewriter) const override {
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op->hasAttr(name: "was_cloned"))
309 return failure();
310 for (Region &r : op->getRegions())
311 rewriter.cloneRegionBefore(region&: r, before: op->getBlock());
312 op->setAttr("was_cloned", rewriter.getUnitAttr());
313 return success();
314 }
315};
316
317/// Replace an operation may introduce the re-visiting of its users.
318class ReplaceWithNewOp : public RewritePattern {
319public:
320 ReplaceWithNewOp(MLIRContext *context)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322
323 LogicalResult matchAndRewrite(Operation *op,
324 PatternRewriter &rewriter) const override {
325 Operation *newOp;
326 if (op->hasAttr(name: "create_erase_op")) {
327 newOp = rewriter.create(
328 op->getLoc(),
329 OperationName("test.erase_op", op->getContext()).getIdentifier(),
330 ValueRange(), TypeRange());
331 } else {
332 newOp = rewriter.create(
333 op->getLoc(),
334 OperationName("test.new_op", op->getContext()).getIdentifier(),
335 op->getOperands(), op->getResultTypes());
336 }
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults());
340 rewriter.eraseOp(op);
341 return success();
342 }
343};
344
345/// Erases the first child block of the matched "test.erase_first_block"
346/// operation.
347class EraseFirstBlock : public RewritePattern {
348public:
349 EraseFirstBlock(MLIRContext *context)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351
352 LogicalResult matchAndRewrite(Operation *op,
353 PatternRewriter &rewriter) const override {
354 for (Region &r : op->getRegions()) {
355 for (Block &b : r.getBlocks()) {
356 rewriter.eraseBlock(block: &b);
357 return success();
358 }
359 }
360
361 return failure();
362 }
363};
364
365struct TestGreedyPatternDriver
366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371 : PassWrapper(other) {}
372
373 StringRef getArgument() const final { return "test-greedy-patterns"; }
374 StringRef getDescription() const final { return "Run test dialect patterns"; }
375 void runOnOperation() override {
376 mlir::RewritePatternSet patterns(&getContext());
377 populateWithGenerated(patterns);
378
379 // Verify named pattern is generated with expected name.
380 patterns.add<FoldingPattern, TestNamedPatternRule,
381 FolderInsertBeforePreviouslyFoldedConstantPattern,
382 FolderCommutativeOp2WithConstant, HoistEligibleOps,
383 MakeOpEligible>(&getContext());
384
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns.insert<IncrementIntAttribute<3>>(arg: &getContext());
387
388 GreedyRewriteConfig config;
389 config.setUseTopDownTraversal(useTopDownTraversal)
390 .setMaxIterations(this->maxIterations)
391 .enableFolding(enable: this->fold)
392 .enableConstantCSE(enable: this->cseConstants);
393 (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
394 }
395
396 Option<bool> useTopDownTraversal{
397 *this, "top-down",
398 llvm::cl::desc("Seed the worklist in general top-down order"),
399 llvm::cl::init(Val: GreedyRewriteConfig().getUseTopDownTraversal())};
400 Option<int> maxIterations{
401 *this, "max-iterations",
402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403 llvm::cl::init(Val: GreedyRewriteConfig().getMaxIterations())};
404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405 llvm::cl::init(Val: GreedyRewriteConfig().isFoldingEnabled())};
406 Option<bool> cseConstants{
407 *this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
408 llvm::cl::init(Val: GreedyRewriteConfig().isConstantCSEEnabled())};
409};
410
411struct DumpNotifications : public RewriterBase::Listener {
412 void notifyBlockInserted(Block *block, Region *previous,
413 Region::iterator previousIt) override {
414 llvm::outs() << "notifyBlockInserted";
415 if (block->getParentOp()) {
416 llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417 } else {
418 llvm::outs() << " into unknown op: ";
419 }
420 if (previous == nullptr) {
421 llvm::outs() << "was unlinked\n";
422 } else {
423 llvm::outs() << "was linked\n";
424 }
425 }
426 void notifyOperationInserted(Operation *op,
427 OpBuilder::InsertPoint previous) override {
428 llvm::outs() << "notifyOperationInserted: " << op->getName();
429 if (!previous.isSet()) {
430 llvm::outs() << ", was unlinked\n";
431 } else {
432 if (!previous.getPoint().getNodePtr()) {
433 llvm::outs() << ", was linked, exact position unknown\n";
434 } else if (previous.getPoint() == previous.getBlock()->end()) {
435 llvm::outs() << ", was last in block\n";
436 } else {
437 llvm::outs() << ", previous = " << previous.getPoint()->getName()
438 << "\n";
439 }
440 }
441 }
442 void notifyBlockErased(Block *block) override {
443 llvm::outs() << "notifyBlockErased\n";
444 }
445 void notifyOperationErased(Operation *op) override {
446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447 }
448 void notifyOperationModified(Operation *op) override {
449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450 }
451 void notifyOperationReplaced(Operation *op, ValueRange values) override {
452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453 }
454};
455
456struct TestStrictPatternDriver
457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458public:
459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460
461 TestStrictPatternDriver() = default;
462 TestStrictPatternDriver(const TestStrictPatternDriver &other)
463 : PassWrapper(other) {
464 strictMode = other.strictMode;
465 }
466
467 StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468 StringRef getDescription() const final {
469 return "Test strict mode of pattern driver";
470 }
471
472 void runOnOperation() override {
473 MLIRContext *ctx = &getContext();
474 mlir::RewritePatternSet patterns(ctx);
475 patterns.add<
476 // clang-format off
477 ChangeBlockOp,
478 CloneOp,
479 CloneRegionBeforeOp,
480 EraseOp,
481 ImplicitChangeOp,
482 InlineBlocksIntoParent,
483 InsertSameOp,
484 MoveBeforeParentOp,
485 ReplaceWithNewOp,
486 SplitBlockHere
487 // clang-format on
488 >(arg&: ctx);
489 SmallVector<Operation *> ops;
490 getOperation()->walk([&](Operation *op) {
491 StringRef opName = op->getName().getStringRef();
492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493 opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494 opName == "test.move_before_parent_op" ||
495 opName == "test.inline_blocks_into_parent" ||
496 opName == "test.split_block_here" || opName == "test.clone_me" ||
497 opName == "test.clone_region_before") {
498 ops.push_back(Elt: op);
499 }
500 });
501
502 DumpNotifications dumpNotifications;
503 GreedyRewriteConfig config;
504 config.setListener(&dumpNotifications);
505 if (strictMode == "AnyOp") {
506 config.setStrictness(GreedyRewriteStrictness::AnyOp);
507 } else if (strictMode == "ExistingAndNewOps") {
508 config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
509 } else if (strictMode == "ExistingOps") {
510 config.setStrictness(GreedyRewriteStrictness::ExistingOps);
511 } else {
512 llvm_unreachable("invalid strictness option");
513 }
514
515 // Check if these transformations introduce visiting of operations that
516 // are not in the `ops` set (The new created ops are valid). An invalid
517 // operation will trigger the assertion while processing.
518 bool changed = false;
519 bool allErased = false;
520 (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config,
521 &changed, &allErased);
522 Builder b(ctx);
523 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(value: changed));
524 getOperation()->setAttr("pattern_driver_all_erased",
525 b.getBoolAttr(value: allErased));
526 }
527
528 Option<std::string> strictMode{
529 *this, "strictness",
530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531 llvm::cl::init("AnyOp")};
532
533private:
534 // New inserted operation is valid for further transformation.
535 class InsertSameOp : public RewritePattern {
536 public:
537 InsertSameOp(MLIRContext *context)
538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539
540 LogicalResult matchAndRewrite(Operation *op,
541 PatternRewriter &rewriter) const override {
542 if (op->hasAttr(name: "skip"))
543 return failure();
544
545 Operation *newOp =
546 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
547 op->getOperands(), op->getResultTypes());
548 rewriter.modifyOpInPlace(
549 root: op, callable: [&]() { op->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true)); });
550 newOp->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true));
551
552 return success();
553 }
554 };
555
556 // Remove an operation may introduce the re-visiting of its operands.
557 class EraseOp : public RewritePattern {
558 public:
559 EraseOp(MLIRContext *context)
560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561 LogicalResult matchAndRewrite(Operation *op,
562 PatternRewriter &rewriter) const override {
563 rewriter.eraseOp(op);
564 return success();
565 }
566 };
567
568 // The following two patterns test RewriterBase::replaceAllUsesWith.
569 //
570 // That function replaces all usages of a Block (or a Value) with another one
571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573 // worklist: when an op is modified, it is added to the worklist. The two
574 // patterns below make the tracking observable: ChangeBlockOp replaces all
575 // usages of a block and that pattern is applied because the corresponding ops
576 // are put on the initial worklist (see above). ImplicitChangeOp does an
577 // unrelated change but ops of the corresponding type are *not* on the initial
578 // worklist, so the effect of the second pattern is only visible if the
579 // tracking and subsequent adding to the worklist actually works.
580
581 // Replace all usages of the first successor with the second successor.
582 class ChangeBlockOp : public RewritePattern {
583 public:
584 ChangeBlockOp(MLIRContext *context)
585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586 LogicalResult matchAndRewrite(Operation *op,
587 PatternRewriter &rewriter) const override {
588 if (op->getNumSuccessors() < 2)
589 return failure();
590 Block *firstSuccessor = op->getSuccessor(index: 0);
591 Block *secondSuccessor = op->getSuccessor(index: 1);
592 if (firstSuccessor == secondSuccessor)
593 return failure();
594 // This is the function being tested:
595 rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor);
596 // Using the following line instead would make the test fail:
597 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598 return success();
599 }
600 };
601
602 // Changes the successor to the parent block.
603 class ImplicitChangeOp : public RewritePattern {
604 public:
605 ImplicitChangeOp(MLIRContext *context)
606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607 LogicalResult matchAndRewrite(Operation *op,
608 PatternRewriter &rewriter) const override {
609 if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock())
610 return failure();
611 rewriter.modifyOpInPlace(root: op,
612 callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); });
613 return success();
614 }
615 };
616};
617
618struct TestWalkPatternDriver final
619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621
622 TestWalkPatternDriver() = default;
623 TestWalkPatternDriver(const TestWalkPatternDriver &other)
624 : PassWrapper(other) {}
625
626 StringRef getArgument() const override {
627 return "test-walk-pattern-rewrite-driver";
628 }
629 StringRef getDescription() const override {
630 return "Run test walk pattern rewrite driver";
631 }
632 void runOnOperation() override {
633 mlir::RewritePatternSet patterns(&getContext());
634
635 // Patterns for testing the WalkPatternRewriteDriver.
636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638 arg: &getContext());
639
640 DumpNotifications dumpListener;
641 walkAndApplyPatterns(getOperation(), std::move(patterns),
642 dumpNotifications ? &dumpListener : nullptr);
643 }
644
645 Option<bool> dumpNotifications{
646 *this, "dump-notifications",
647 llvm::cl::desc("Print rewrite listener notifications"),
648 llvm::cl::init(Val: false)};
649};
650
651} // namespace
652
653//===----------------------------------------------------------------------===//
654// ReturnType Driver.
655//===----------------------------------------------------------------------===//
656
657namespace {
658// Generate ops for each instance where the type can be successfully inferred.
659template <typename OpTy>
660static void invokeCreateWithInferredReturnType(Operation *op) {
661 auto *context = op->getContext();
662 auto fop = op->getParentOfType<func::FuncOp>();
663 auto location = UnknownLoc::get(context);
664 OpBuilder b(op);
665 b.setInsertionPointAfter(op);
666
667 // Use permutations of 2 args as operands.
668 assert(fop.getNumArguments() >= 2);
669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670 for (int j = 0; j < e; ++j) {
671 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
672 SmallVector<Type, 2> inferredReturnTypes;
673 if (succeeded(OpTy::inferReturnTypes(
674 context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675 op->getPropertiesStorage(), op->getRegions(),
676 inferredReturnTypes))) {
677 OperationState state(location, OpTy::getOperationName());
678 // TODO: Expand to regions.
679 OpTy::build(b, state, values, op->getAttrs());
680 (void)b.create(state);
681 }
682 }
683 }
684}
685
686static void reifyReturnShape(Operation *op) {
687 OpBuilder b(op);
688
689 // Use permutations of 2 args as operands.
690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
691 SmallVector<Value, 2> shapes;
692 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
693 !llvm::hasSingleElement(C&: shapes))
694 return;
695 for (const auto &it : llvm::enumerate(First&: shapes)) {
696 op->emitRemark() << "value " << it.index() << ": "
697 << it.value().getDefiningOp();
698 }
699}
700
701struct TestReturnTypeDriver
702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704
705 void getDependentDialects(DialectRegistry &registry) const override {
706 registry.insert<tensor::TensorDialect>();
707 }
708 StringRef getArgument() const final { return "test-return-type"; }
709 StringRef getDescription() const final { return "Run return type functions"; }
710
711 void runOnOperation() override {
712 if (getOperation().getName() == "testCreateFunctions") {
713 std::vector<Operation *> ops;
714 // Collect ops to avoid triggering on inserted ops.
715 for (auto &op : getOperation().getBody().front())
716 ops.push_back(&op);
717 // Generate test patterns for each, but skip terminator.
718 for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719 // Test create method of each of the Op classes below. The resultant
720 // output would be in reverse order underneath `op` from which
721 // the attributes and regions are used.
722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724 op);
725 invokeCreateWithInferredReturnType<
726 OpWithShapedTypeInferTypeInterfaceOp>(op);
727 };
728 return;
729 }
730 if (getOperation().getName() == "testReifyFunctions") {
731 std::vector<Operation *> ops;
732 // Collect ops to avoid triggering on inserted ops.
733 for (auto &op : getOperation().getBody().front())
734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
735 ops.push_back(&op);
736 // Generate test patterns for each, but skip terminator.
737 for (auto *op : ops)
738 reifyReturnShape(op);
739 }
740 }
741};
742} // namespace
743
744namespace {
745struct TestDerivedAttributeDriver
746 : public PassWrapper<TestDerivedAttributeDriver,
747 OperationPass<func::FuncOp>> {
748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
749
750 StringRef getArgument() const final { return "test-derived-attr"; }
751 StringRef getDescription() const final {
752 return "Run test derived attributes";
753 }
754 void runOnOperation() override;
755};
756} // namespace
757
758void TestDerivedAttributeDriver::runOnOperation() {
759 getOperation().walk([](DerivedAttributeOpInterface dOp) {
760 auto dAttr = dOp.materializeDerivedAttributes();
761 if (!dAttr)
762 return;
763 for (auto d : dAttr)
764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765 });
766}
767
768//===----------------------------------------------------------------------===//
769// Legalization Driver.
770//===----------------------------------------------------------------------===//
771
772namespace {
773//===----------------------------------------------------------------------===//
774// Region-Block Rewrite Testing
775//===----------------------------------------------------------------------===//
776
777/// This pattern applies a signature conversion to a block inside a detached
778/// region.
779struct TestDetachedSignatureConversion : public ConversionPattern {
780 TestDetachedSignatureConversion(MLIRContext *ctx)
781 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
782 ctx) {}
783
784 LogicalResult
785 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
786 ConversionPatternRewriter &rewriter) const final {
787 if (op->getNumRegions() != 1)
788 return failure();
789 OperationState state(op->getLoc(), "test.legal_op", operands,
790 op->getResultTypes(), {}, BlockRange());
791 Region *newRegion = state.addRegion();
792 rewriter.inlineRegionBefore(region&: op->getRegion(index: 0), parent&: *newRegion,
793 before: newRegion->begin());
794 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
795 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
796 result.addInputs(i, rewriter.getF64Type());
797 rewriter.applySignatureConversion(block: &newRegion->front(), conversion&: result);
798 Operation *newOp = rewriter.create(state);
799 rewriter.replaceOp(op, newValues: newOp->getResults());
800 return success();
801 }
802};
803
804/// This pattern is a simple pattern that inlines the first region of a given
805/// operation into the parent region.
806struct TestRegionRewriteBlockMovement : public ConversionPattern {
807 TestRegionRewriteBlockMovement(MLIRContext *ctx)
808 : ConversionPattern("test.region", 1, ctx) {}
809
810 LogicalResult
811 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
812 ConversionPatternRewriter &rewriter) const final {
813 // Inline this region into the parent region.
814 auto &parentRegion = *op->getParentRegion();
815 auto &opRegion = op->getRegion(index: 0);
816 if (op->getDiscardableAttr(name: "legalizer.should_clone"))
817 rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
818 else
819 rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
820
821 if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks")) {
822 while (!opRegion.empty())
823 rewriter.eraseBlock(block: &opRegion.front());
824 }
825
826 // Drop this operation.
827 rewriter.eraseOp(op);
828 return success();
829 }
830};
831/// This pattern is a simple pattern that generates a region containing an
832/// illegal operation.
833struct TestRegionRewriteUndo : public RewritePattern {
834 TestRegionRewriteUndo(MLIRContext *ctx)
835 : RewritePattern("test.region_builder", 1, ctx) {}
836
837 LogicalResult matchAndRewrite(Operation *op,
838 PatternRewriter &rewriter) const final {
839 // Create the region operation with an entry block containing arguments.
840 OperationState newRegion(op->getLoc(), "test.region");
841 newRegion.addRegion();
842 auto *regionOp = rewriter.create(state: newRegion);
843 auto *entryBlock = rewriter.createBlock(parent: &regionOp->getRegion(index: 0));
844 entryBlock->addArgument(rewriter.getIntegerType(64),
845 rewriter.getUnknownLoc());
846
847 // Add an explicitly illegal operation to ensure the conversion fails.
848 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
849 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
850
851 // Drop this operation.
852 rewriter.eraseOp(op);
853 return success();
854 }
855};
856/// A simple pattern that creates a block at the end of the parent region of the
857/// matched operation.
858struct TestCreateBlock : public RewritePattern {
859 TestCreateBlock(MLIRContext *ctx)
860 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
861
862 LogicalResult matchAndRewrite(Operation *op,
863 PatternRewriter &rewriter) const final {
864 Region &region = *op->getParentRegion();
865 Type i32Type = rewriter.getIntegerType(32);
866 Location loc = op->getLoc();
867 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
868 rewriter.create<TerminatorOp>(loc);
869 rewriter.eraseOp(op);
870 return success();
871 }
872};
873
874/// A simple pattern that creates a block containing an invalid operation in
875/// order to trigger the block creation undo mechanism.
876struct TestCreateIllegalBlock : public RewritePattern {
877 TestCreateIllegalBlock(MLIRContext *ctx)
878 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
879
880 LogicalResult matchAndRewrite(Operation *op,
881 PatternRewriter &rewriter) const final {
882 Region &region = *op->getParentRegion();
883 Type i32Type = rewriter.getIntegerType(32);
884 Location loc = op->getLoc();
885 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
886 // Create an illegal op to ensure the conversion fails.
887 rewriter.create<ILLegalOpF>(loc, i32Type);
888 rewriter.create<TerminatorOp>(loc);
889 rewriter.eraseOp(op);
890 return success();
891 }
892};
893
894/// A simple pattern that tests the undo mechanism when replacing the uses of a
895/// block argument.
896struct TestUndoBlockArgReplace : public ConversionPattern {
897 TestUndoBlockArgReplace(MLIRContext *ctx)
898 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
899
900 LogicalResult
901 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
902 ConversionPatternRewriter &rewriter) const final {
903 auto illegalOp =
904 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
905 rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0),
906 to: illegalOp->getResult(0));
907 rewriter.modifyOpInPlace(root: op, callable: [] {});
908 return success();
909 }
910};
911
912/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
913/// This is to test the rollback logic.
914struct TestUndoMoveOpBefore : public ConversionPattern {
915 TestUndoMoveOpBefore(MLIRContext *ctx)
916 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
917
918 LogicalResult
919 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
920 ConversionPatternRewriter &rewriter) const override {
921 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
922 // Replace with an illegal op to ensure the conversion fails.
923 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
924 return success();
925 }
926};
927
928/// A rewrite pattern that tests the undo mechanism when erasing a block.
929struct TestUndoBlockErase : public ConversionPattern {
930 TestUndoBlockErase(MLIRContext *ctx)
931 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
932
933 LogicalResult
934 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
935 ConversionPatternRewriter &rewriter) const final {
936 Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin());
937 rewriter.setInsertionPointToStart(secondBlock);
938 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
939 rewriter.eraseBlock(block: secondBlock);
940 rewriter.modifyOpInPlace(root: op, callable: [] {});
941 return success();
942 }
943};
944
945/// A pattern that modifies a property in-place, but keeps the op illegal.
946struct TestUndoPropertiesModification : public ConversionPattern {
947 TestUndoPropertiesModification(MLIRContext *ctx)
948 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
949 LogicalResult
950 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
951 ConversionPatternRewriter &rewriter) const final {
952 if (!op->hasAttr(name: "modify_inplace"))
953 return failure();
954 rewriter.modifyOpInPlace(
955 root: op, callable: [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
956 return success();
957 }
958};
959
960//===----------------------------------------------------------------------===//
961// Type-Conversion Rewrite Testing
962//===----------------------------------------------------------------------===//
963
964/// This patterns erases a region operation that has had a type conversion.
965struct TestDropOpSignatureConversion : public ConversionPattern {
966 TestDropOpSignatureConversion(MLIRContext *ctx,
967 const TypeConverter &converter)
968 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
969 LogicalResult
970 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
971 ConversionPatternRewriter &rewriter) const override {
972 Region &region = op->getRegion(index: 0);
973 Block *entry = &region.front();
974
975 // Convert the original entry arguments.
976 const TypeConverter &converter = *getTypeConverter();
977 TypeConverter::SignatureConversion result(entry->getNumArguments());
978 if (failed(Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(),
979 result)) ||
980 failed(Result: rewriter.convertRegionTypes(region: &region, converter, entryConversion: &result)))
981 return failure();
982
983 // Convert the region signature and just drop the operation.
984 rewriter.eraseOp(op);
985 return success();
986 }
987};
988/// This pattern simply updates the operands of the given operation.
989struct TestPassthroughInvalidOp : public ConversionPattern {
990 TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
991 : ConversionPattern(converter, "test.invalid", 1, ctx) {}
992 LogicalResult
993 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
994 ConversionPatternRewriter &rewriter) const final {
995 SmallVector<Value> flattened;
996 for (auto it : llvm::enumerate(First&: operands)) {
997 ValueRange range = it.value();
998 if (range.size() == 1) {
999 flattened.push_back(Elt: range.front());
1000 continue;
1001 }
1002
1003 // This is a 1:N replacement. Insert a test.cast op. (That's what the
1004 // argument materialization used to do.)
1005 flattened.push_back(
1006 rewriter
1007 .create<TestCastOp>(op->getLoc(),
1008 op->getOperand(it.index()).getType(), range)
1009 .getResult());
1010 }
1011 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
1012 std::nullopt);
1013 return success();
1014 }
1015};
1016/// Replace with valid op, but simply drop the operands. This is used in a
1017/// regression where we used to generate circular unrealized_conversion_cast
1018/// ops.
1019struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1020 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1021 : ConversionPattern(converter,
1022 "test.drop_operands_and_replace_with_valid", 1, ctx) {
1023 }
1024 LogicalResult
1025 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1026 ConversionPatternRewriter &rewriter) const final {
1027 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(),
1028 std::nullopt);
1029 return success();
1030 }
1031};
1032/// This pattern handles the case of a split return value.
1033struct TestSplitReturnType : public ConversionPattern {
1034 TestSplitReturnType(MLIRContext *ctx)
1035 : ConversionPattern("test.return", 1, ctx) {}
1036 LogicalResult
1037 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1038 ConversionPatternRewriter &rewriter) const final {
1039 // Check for a return of F32.
1040 if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32())
1041 return failure();
1042 rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
1043 return success();
1044 }
1045};
1046
1047//===----------------------------------------------------------------------===//
1048// Multi-Level Type-Conversion Rewrite Testing
1049struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1050 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1051 : ConversionPattern("test.type_producer", 1, ctx) {}
1052 LogicalResult
1053 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1054 ConversionPatternRewriter &rewriter) const final {
1055 // If the type is I32, change the type to F32.
1056 if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32))
1057 return failure();
1058 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
1059 return success();
1060 }
1061};
1062struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1063 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1064 : ConversionPattern("test.type_producer", 1, ctx) {}
1065 LogicalResult
1066 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1067 ConversionPatternRewriter &rewriter) const final {
1068 // If the type is F32, change the type to F64.
1069 if (!Type(*op->result_type_begin()).isF32())
1070 return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand");
1071 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
1072 return success();
1073 }
1074};
1075struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1076 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1077 : ConversionPattern("test.type_producer", 10, ctx) {}
1078 LogicalResult
1079 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1080 ConversionPatternRewriter &rewriter) const final {
1081 // Always convert to B16, even though it is not a legal type. This tests
1082 // that values are unmapped correctly.
1083 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
1084 return success();
1085 }
1086};
1087struct TestUpdateConsumerType : public ConversionPattern {
1088 TestUpdateConsumerType(MLIRContext *ctx)
1089 : ConversionPattern("test.type_consumer", 1, ctx) {}
1090 LogicalResult
1091 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1092 ConversionPatternRewriter &rewriter) const final {
1093 // Verify that the incoming operand has been successfully remapped to F64.
1094 if (!operands[0].getType().isF64())
1095 return failure();
1096 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
1097 return success();
1098 }
1099};
1100
1101//===----------------------------------------------------------------------===//
1102// Non-Root Replacement Rewrite Testing
1103/// This pattern generates an invalid operation, but replaces it before the
1104/// pattern is finished. This checks that we don't need to legalize the
1105/// temporary op.
1106struct TestNonRootReplacement : public RewritePattern {
1107 TestNonRootReplacement(MLIRContext *ctx)
1108 : RewritePattern("test.replace_non_root", 1, ctx) {}
1109
1110 LogicalResult matchAndRewrite(Operation *op,
1111 PatternRewriter &rewriter) const final {
1112 auto resultType = *op->result_type_begin();
1113 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
1114 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
1115
1116 rewriter.replaceOp(illegalOp, legalOp);
1117 rewriter.replaceOp(op, illegalOp);
1118 return success();
1119 }
1120};
1121
1122//===----------------------------------------------------------------------===//
1123// Recursive Rewrite Testing
1124/// This pattern is applied to the same operation multiple times, but has a
1125/// bounded recursion.
1126struct TestBoundedRecursiveRewrite
1127 : public OpRewritePattern<TestRecursiveRewriteOp> {
1128 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1129
1130 void initialize() {
1131 // The conversion target handles bounding the recursion of this pattern.
1132 setHasBoundedRewriteRecursion();
1133 }
1134
1135 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1136 PatternRewriter &rewriter) const final {
1137 // Decrement the depth of the op in-place.
1138 rewriter.modifyOpInPlace(op, [&] {
1139 op->setAttr("depth", rewriter.getI64IntegerAttr(value: op.getDepth() - 1));
1140 });
1141 return success();
1142 }
1143};
1144
1145struct TestNestedOpCreationUndoRewrite
1146 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1147 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1148
1149 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1150 PatternRewriter &rewriter) const final {
1151 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1152 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1153 return success();
1154 };
1155};
1156
1157// This pattern matches `test.blackhole` and delete this op and its producer.
1158struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1159 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1160
1161 LogicalResult matchAndRewrite(BlackHoleOp op,
1162 PatternRewriter &rewriter) const final {
1163 Operation *producer = op.getOperand().getDefiningOp();
1164 // Always erase the user before the producer, the framework should handle
1165 // this correctly.
1166 rewriter.eraseOp(op: op);
1167 rewriter.eraseOp(op: producer);
1168 return success();
1169 };
1170};
1171
1172// This pattern replaces explicitly illegal op with explicitly legal op,
1173// but in addition creates unregistered operation.
1174struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1175 using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1176
1177 LogicalResult matchAndRewrite(ILLegalOpG op,
1178 PatternRewriter &rewriter) const final {
1179 IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1180 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1181 rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1182 return success();
1183 };
1184};
1185
1186class TestEraseOp : public ConversionPattern {
1187public:
1188 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1189 LogicalResult
1190 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1191 ConversionPatternRewriter &rewriter) const final {
1192 // Erase op without replacements.
1193 rewriter.eraseOp(op);
1194 return success();
1195 }
1196};
1197
1198/// This pattern matches a test.convert_block_args op. It either:
1199/// a) Duplicates all block arguments,
1200/// b) or: drops all block arguments and replaces each with 2x the first
1201/// operand.
1202class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
1203 using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
1204
1205 LogicalResult
1206 matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
1207 ConversionPatternRewriter &rewriter) const override {
1208 if (op.getIsLegal())
1209 return failure();
1210 Block *body = &op.getBody().front();
1211 TypeConverter::SignatureConversion result(body->getNumArguments());
1212 for (auto it : llvm::enumerate(body->getArgumentTypes())) {
1213 if (op.getReplaceWithOperand()) {
1214 result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()});
1215 } else if (op.getDuplicate()) {
1216 result.addInputs(it.index(), {it.value(), it.value()});
1217 } else {
1218 // No action specified. Pattern does not apply.
1219 return failure();
1220 }
1221 }
1222 rewriter.startOpModification(op: op);
1223 rewriter.applySignatureConversion(block: body, conversion&: result, converter: getTypeConverter());
1224 op.setIsLegal(true);
1225 rewriter.finalizeOpModification(op: op);
1226 return success();
1227 }
1228};
1229
1230/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1231/// op. The pattern supports 1:N replacements and forwards the replacement
1232/// values of the single operand as test.valid operands.
1233class TestRepetitive1ToNConsumer : public ConversionPattern {
1234public:
1235 TestRepetitive1ToNConsumer(MLIRContext *ctx)
1236 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1237 LogicalResult
1238 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1239 ConversionPatternRewriter &rewriter) const final {
1240 // A single operand is expected.
1241 if (op->getNumOperands() != 1)
1242 return failure();
1243 rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
1244 return success();
1245 }
1246};
1247
1248/// A pattern that tests two back-to-back 1 -> 2 op replacements.
1249class TestMultiple1ToNReplacement : public ConversionPattern {
1250public:
1251 TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1252 : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1253 ctx) {}
1254 LogicalResult
1255 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1256 ConversionPatternRewriter &rewriter) const final {
1257 // Helper function that replaces the given op with a new op of the given
1258 // name and doubles each result (1 -> 2 replacement of each result).
1259 auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1260 SmallVector<Type> types;
1261 for (Type t : op->getResultTypes()) {
1262 types.push_back(Elt: t);
1263 types.push_back(Elt: t);
1264 }
1265 OperationState state(op->getLoc(), name,
1266 /*operands=*/{}, types, op->getAttrs());
1267 auto *newOp = rewriter.create(state);
1268 SmallVector<ValueRange> repls;
1269 for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1270 repls.push_back(Elt: newOp->getResults().slice(n: 2 * i, m: 2));
1271 rewriter.replaceOpWithMultiple(op, newValues&: repls);
1272 return newOp;
1273 };
1274
1275 // Replace test.multiple_1_to_n_replacement with test.step_1.
1276 Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1277 // Now replace test.step_1 with test.legal_op.
1278 replaceWithDoubleResults(repl1, "test.legal_op");
1279 return success();
1280 }
1281};
1282
/// Compile-only check that every call form below resolves to exactly one
/// replaceOpWithMultiple overload. An ambiguous or missing overload shows up as
/// a compiler error. The function is never executed.
///
/// The argument shapes covered are:
///  - ArrayRef / SmallVector (lvalue and rvalue) of ValueRange,
///    SmallVector<Value>, and ArrayRef<Value>;
///  - braced initializer lists mixing ValueRange, ArrayRef<Value>, single
///    Values, and nested lists of Values.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
    SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7,
    Value v, ValueRange vr, ArrayRef<Value> ar) {
  // Container arguments.
  rewriter.replaceOpWithMultiple(op, newValues: r1);
  rewriter.replaceOpWithMultiple(op, newValues&: r2);
  rewriter.replaceOpWithMultiple(op, newValues: r3);
  rewriter.replaceOpWithMultiple(op, newValues&: r4);
  rewriter.replaceOpWithMultiple(op, newValues: r5);
  rewriter.replaceOpWithMultiple(op, newValues&: r6);
  rewriter.replaceOpWithMultiple(op, newValues: std::move(r7));
  // Braced initializer lists.
  rewriter.replaceOpWithMultiple(op, newValues: {vr});
  rewriter.replaceOpWithMultiple(op, newValues: {ar});
  rewriter.replaceOpWithMultiple(op, newValues: {{v}});
  rewriter.replaceOpWithMultiple(op, newValues: {{v, v}});
  rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, vr});
  rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, ar});
  rewriter.replaceOpWithMultiple(op, newValues: {ar, {v, v}, vr});
}
1306} // namespace
1307
1308namespace {
1309struct TestTypeConverter : public TypeConverter {
1310 using TypeConverter::TypeConverter;
1311 TestTypeConverter() {
1312 addConversion(callback&: convertType);
1313 addSourceMaterialization(callback&: materializeCast);
1314 }
1315
1316 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1317 // Drop I16 types.
1318 if (t.isSignlessInteger(width: 16))
1319 return success();
1320
1321 // Convert I64 to F64.
1322 if (t.isSignlessInteger(width: 64)) {
1323 results.push_back(Float64Type::get(t.getContext()));
1324 return success();
1325 }
1326
1327 // Convert I42 to I43.
1328 if (t.isInteger(width: 42)) {
1329 results.push_back(IntegerType::get(t.getContext(), 43));
1330 return success();
1331 }
1332
1333 // Split F32 into F16,F16.
1334 if (t.isF32()) {
1335 results.assign(2, Float16Type::get(t.getContext()));
1336 return success();
1337 }
1338
1339 // Drop I24 types.
1340 if (t.isInteger(width: 24)) {
1341 return success();
1342 }
1343
1344 // Otherwise, convert the type directly.
1345 results.push_back(Elt: t);
1346 return success();
1347 }
1348
1349 /// Hook for materializing a conversion. This is necessary because we generate
1350 /// 1->N type mappings.
1351 static Value materializeCast(OpBuilder &builder, Type resultType,
1352 ValueRange inputs, Location loc) {
1353 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1354 }
1355};
1356
/// Test pass ("test-legalize-patterns") that runs the test-dialect
/// legalization patterns through the dialect conversion driver, in the
/// conversion mode chosen at construction.
struct TestLegalizePatternDriver
    : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)

  StringRef getArgument() const final { return "test-legalize-patterns"; }
  StringRef getDescription() const final {
    return "Run test dialect legalization patterns";
  }
  /// The mode of conversion to use with the driver.
  enum class ConversionMode { Analysis, Full, Partial };

  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect, test::TestDialect>();
  }

  /// Builds the pattern set and the conversion target, then runs the
  /// conversion selected by `mode`:
  ///  - Partial: applyPartialConversion with a notification-dumping listener.
  ///    Emits a remark if the conversion fails, plus one remark per op left
  ///    unlegalized.
  ///  - Full: applyFullConversion with the same listener. Unknown ops tagged
  ///    `test.dynamically_legal` are legal. Emits a remark if the conversion
  ///    fails.
  ///  - Analysis: applyAnalysisConversion. Signals pass failure if it fails;
  ///    otherwise emits one remark per legalizable op.
  /// Partial and Full report failure only through a remark and do not call
  /// signalPassFailure.
  void runOnOperation() override {
    TestTypeConverter converter;
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGenerated(patterns);
    // NOTE(review): several of these patterns share a root op (e.g.
    // test.type_producer); confirm whether registration order matters before
    // reordering this list.
    patterns
        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
             TestNonRootReplacement, TestBoundedRecursiveRewrite,
             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
             TestUndoPropertiesModification, TestEraseOp,
             TestRepetitive1ToNConsumer>(arg: &getContext());
    // Patterns that need the test type converter.
    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
        arg: &getContext(), args&: converter);
    patterns.add<TestConvertBlockArgs>(arg&: converter, args: &getContext());
    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                              converter);
    mlir::populateCallOpTypeConversionPattern(patterns, converter);

    // Define the conversion target used for the test.
    ConversionTarget target(getContext());
    target.addLegalOp<ModuleOp>();
    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                      TerminatorOp, OneRegionOp>();
    target.addLegalOp(op: OperationName("test.legal_op", &getContext()));
    target
        .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
    target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
      // Don't allow F32 operands.
      return llvm::none_of(op.getOperandTypes(),
                           [](Type type) { return type.isF32(); });
    });
    // Functions and calls are legal once the type converter accepts them.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType()) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return converter.isLegal(op); });

    // TestCreateUnregisteredOp creates an `arith.constant` operation, which
    // is intentionally not added to the target, to test that the conversion
    // driver reports the correct error code.
    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });

    // Expect the type_producer/type_consumer operations to only operate on f64.
    target.addDynamicallyLegalOp<TestTypeProducerOp>(
        [](TestTypeProducerOp op) { return op.getType().isF64(); });
    target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
      return op.getOperand().getType().isF64();
    });

    // Check support for marking certain operations as recursively legal.
    target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
      return static_cast<bool>(
          op->getAttrOfType<UnitAttr>("test.recursively_legal"));
    });

    // Mark the bound recursion operation as dynamically legal.
    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
        [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });

    // Create a dynamically legal rule that can only be legalized by folding it.
    target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
        [](TestOpInPlaceSelfFold op) { return op.getFolded(); });

    // ConvertBlockArgsOp becomes legal once TestConvertBlockArgs sets is_legal.
    target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
        [](ConvertBlockArgsOp op) { return op.getIsLegal(); });

    // Handle a partial conversion.
    if (mode == ConversionMode::Partial) {
      DenseSet<Operation *> unlegalizedOps;
      ConversionConfig config;
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      config.unlegalizedOps = &unlegalizedOps;
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyPartialConversion failed";
      }
      // Emit a remark for each operation that could not be legalized.
      for (auto *op : unlegalizedOps)
        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
      return;
    }

    // Handle a full conversion.
    if (mode == ConversionMode::Full) {
      // Check support for marking unknown operations as dynamically legal.
      target.markUnknownOpDynamicallyLegal([](Operation *op) {
        return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
      });

      ConversionConfig config;
      DumpNotifications dumpNotifications;
      config.listener = &dumpNotifications;
      if (failed(applyFullConversion(getOperation(), target,
                                     std::move(patterns), config))) {
        getOperation()->emitRemark() << "applyFullConversion failed";
      }
      return;
    }

    // Otherwise, handle an analysis conversion.
    assert(mode == ConversionMode::Analysis);

    // Analyze the convertible operations.
    DenseSet<Operation *> legalizedOps;
    ConversionConfig config;
    config.legalizableOps = &legalizedOps;
    if (failed(applyAnalysisConversion(getOperation(), target,
                                       std::move(patterns), config)))
      return signalPassFailure();

    // Emit remarks for each legalizable operation.
    for (auto *op : legalizedOps)
      op->emitRemark() << "op '" << op->getName() << "' is legalizable";
  }

  /// The mode of conversion to use.
  ConversionMode mode;
};
1499} // namespace
1500
/// `-test-legalize-mode=<analysis|full|partial>`: selects which
/// TestLegalizePatternDriver::ConversionMode the legalization test driver
/// uses. Defaults to a partial conversion.
/// NOTE(review): the reader of this option is not visible here; presumably it
/// is passed to the TestLegalizePatternDriver constructor at pass
/// registration.
static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
    legalizerConversionMode(
        "test-legalize-mode",
        llvm::cl::desc("The legalization mode to use with the test driver"),
        llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial),
        llvm::cl::values(
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
                       "analysis", "Perform an analysis conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
                       "Perform a full conversion"),
            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
                       "partial", "Perform a partial conversion")));
1513
1514//===----------------------------------------------------------------------===//
1515// ConversionPatternRewriter::getRemappedValue testing. This method is used
1516// to get the remapped value of an original value that was replaced using
1517// ConversionPatternRewriter.
1518namespace {
1519struct TestRemapValueTypeConverter : public TypeConverter {
1520 using TypeConverter::TypeConverter;
1521
1522 TestRemapValueTypeConverter() {
1523 addConversion(
1524 callback: [](Float32Type type) { return Float64Type::get(type.getContext()); });
1525 addConversion(callback: [](Type type) { return type; });
1526 }
1527};
1528
1529/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1530/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1531/// operand twice.
1532///
1533/// Example:
1534/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1535/// is replaced with:
1536/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1537struct OneVResOneVOperandOp1Converter
1538 : public OpConversionPattern<OneVResOneVOperandOp1> {
1539 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1540
1541 LogicalResult
1542 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1543 ConversionPatternRewriter &rewriter) const override {
1544 auto origOps = op.getOperands();
1545 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1546 "One operand expected");
1547 Value origOp = *origOps.begin();
1548 SmallVector<Value, 2> remappedOperands;
1549 // Replicate the remapped original operand twice. Note that we don't used
1550 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1551 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1552 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1553
1554 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1555 remappedOperands);
1556 return success();
1557 }
1558};
1559
1560/// A rewriter pattern that tests that blocks can be merged.
1561struct TestRemapValueInRegion
1562 : public OpConversionPattern<TestRemappedValueRegionOp> {
1563 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1564
1565 LogicalResult
1566 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1567 ConversionPatternRewriter &rewriter) const final {
1568 Block &block = op.getBody().front();
1569 Operation *terminator = block.getTerminator();
1570
1571 // Merge the block into the parent region.
1572 Block *parentBlock = op->getBlock();
1573 Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator());
1574 rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange());
1575 rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange());
1576
1577 // Replace the results of this operation with the remapped terminator
1578 // values.
1579 SmallVector<Value> terminatorOperands;
1580 if (failed(Result: rewriter.getRemappedValues(keys: terminator->getOperands(),
1581 results&: terminatorOperands)))
1582 return failure();
1583
1584 rewriter.eraseOp(op: terminator);
1585 rewriter.replaceOp(op, terminatorOperands);
1586 return success();
1587 }
1588};
1589
1590struct TestRemappedValue
1591 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1592 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1593
1594 StringRef getArgument() const final { return "test-remapped-value"; }
1595 StringRef getDescription() const final {
1596 return "Test public remapped value mechanism in ConversionPatternRewriter";
1597 }
1598 void runOnOperation() override {
1599 TestRemapValueTypeConverter typeConverter;
1600
1601 mlir::RewritePatternSet patterns(&getContext());
1602 patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext());
1603 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1604 arg: &getContext());
1605 patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext());
1606
1607 mlir::ConversionTarget target(getContext());
1608 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1609
1610 // Expect the type_producer/type_consumer operations to only operate on f64.
1611 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1612 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1613 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1614 return op.getOperand().getType().isF64();
1615 });
1616
1617 // We make OneVResOneVOperandOp1 legal only when it has more that one
1618 // operand. This will trigger the conversion that will replace one-operand
1619 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1620 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1621 [](Operation *op) { return op->getNumOperands() > 1; });
1622
1623 if (failed(mlir::applyFullConversion(getOperation(), target,
1624 std::move(patterns)))) {
1625 signalPassFailure();
1626 }
1627 }
1628};
1629} // namespace
1630
1631//===----------------------------------------------------------------------===//
1632// Test patterns without a specific root operation kind
1633//===----------------------------------------------------------------------===//
1634
1635namespace {
1636/// This pattern matches and removes any operation in the test dialect.
1637struct RemoveTestDialectOps : public RewritePattern {
1638 RemoveTestDialectOps(MLIRContext *context)
1639 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1640
1641 LogicalResult matchAndRewrite(Operation *op,
1642 PatternRewriter &rewriter) const override {
1643 if (!isa<TestDialect>(Val: op->getDialect()))
1644 return failure();
1645 rewriter.eraseOp(op);
1646 return success();
1647 }
1648};
1649
1650struct TestUnknownRootOpDriver
1651 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1652 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1653
1654 StringRef getArgument() const final {
1655 return "test-legalize-unknown-root-patterns";
1656 }
1657 StringRef getDescription() const final {
1658 return "Test public remapped value mechanism in ConversionPatternRewriter";
1659 }
1660 void runOnOperation() override {
1661 mlir::RewritePatternSet patterns(&getContext());
1662 patterns.add<RemoveTestDialectOps>(arg: &getContext());
1663
1664 mlir::ConversionTarget target(getContext());
1665 target.addIllegalDialect<TestDialect>();
1666 if (failed(applyPartialConversion(getOperation(), target,
1667 std::move(patterns))))
1668 signalPassFailure();
1669 }
1670};
1671} // namespace
1672
1673//===----------------------------------------------------------------------===//
// Test patterns that use operations and types defined at runtime
1675//===----------------------------------------------------------------------===//
1676
1677namespace {
1678/// This pattern matches dynamic operations 'test.one_operand_two_results' and
1679/// replace them with dynamic operations 'test.generic_dynamic_op'.
1680struct RewriteDynamicOp : public RewritePattern {
1681 RewriteDynamicOp(MLIRContext *context)
1682 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1683 context) {}
1684
1685 LogicalResult matchAndRewrite(Operation *op,
1686 PatternRewriter &rewriter) const override {
1687 assert(op->getName().getStringRef() ==
1688 "test.dynamic_one_operand_two_results" &&
1689 "rewrite pattern should only match operations with the right name");
1690
1691 OperationState state(op->getLoc(), "test.dynamic_generic",
1692 op->getOperands(), op->getResultTypes(),
1693 op->getAttrs());
1694 auto *newOp = rewriter.create(state);
1695 rewriter.replaceOp(op, newValues: newOp->getResults());
1696 return success();
1697 }
1698};
1699
1700struct TestRewriteDynamicOpDriver
1701 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1702 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1703
1704 void getDependentDialects(DialectRegistry &registry) const override {
1705 registry.insert<TestDialect>();
1706 }
1707 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1708 StringRef getDescription() const final {
1709 return "Test rewritting on dynamic operations";
1710 }
1711 void runOnOperation() override {
1712 RewritePatternSet patterns(&getContext());
1713 patterns.add<RewriteDynamicOp>(arg: &getContext());
1714
1715 ConversionTarget target(getContext());
1716 target.addIllegalOp(
1717 op: OperationName("test.dynamic_one_operand_two_results", &getContext()));
1718 target.addLegalOp(op: OperationName("test.dynamic_generic", &getContext()));
1719 if (failed(applyPartialConversion(getOperation(), target,
1720 std::move(patterns))))
1721 signalPassFailure();
1722 }
1723};
1724} // end anonymous namespace
1725
1726//===----------------------------------------------------------------------===//
1727// Test type conversions
1728//===----------------------------------------------------------------------===//
1729
1730namespace {
1731struct TestTypeConversionProducer
1732 : public OpConversionPattern<TestTypeProducerOp> {
1733 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1734 LogicalResult
1735 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1736 ConversionPatternRewriter &rewriter) const final {
1737 Type resultType = op.getType();
1738 Type convertedType = getTypeConverter()
1739 ? getTypeConverter()->convertType(resultType)
1740 : resultType;
1741 if (isa<FloatType>(Val: resultType))
1742 resultType = rewriter.getF64Type();
1743 else if (resultType.isInteger(width: 16))
1744 resultType = rewriter.getIntegerType(64);
1745 else if (isa<test::TestRecursiveType>(Val: resultType) &&
1746 convertedType != resultType)
1747 resultType = convertedType;
1748 else
1749 return failure();
1750
1751 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1752 return success();
1753 }
1754};
1755
1756/// Call signature conversion and then fail the rewrite to trigger the undo
1757/// mechanism.
1758struct TestSignatureConversionUndo
1759 : public OpConversionPattern<TestSignatureConversionUndoOp> {
1760 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1761
1762 LogicalResult
1763 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1764 ConversionPatternRewriter &rewriter) const final {
1765 (void)rewriter.convertRegionTypes(region: &op->getRegion(0), converter: *getTypeConverter());
1766 return failure();
1767 }
1768};
1769
1770/// Call signature conversion without providing a type converter to handle
1771/// materializations.
1772struct TestTestSignatureConversionNoConverter
1773 : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1774 TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1775 MLIRContext *context)
1776 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1777 converter(converter) {}
1778
1779 LogicalResult
1780 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1781 ConversionPatternRewriter &rewriter) const final {
1782 Region &region = op->getRegion(0);
1783 Block *entry = &region.front();
1784
1785 // Convert the original entry arguments.
1786 TypeConverter::SignatureConversion result(entry->getNumArguments());
1787 if (failed(
1788 Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result)))
1789 return failure();
1790 rewriter.modifyOpInPlace(op, [&] {
1791 rewriter.applySignatureConversion(block: &region.front(), conversion&: result);
1792 });
1793 return success();
1794 }
1795
1796 const TypeConverter &converter;
1797};
1798
1799/// Just forward the operands to the root op. This is essentially a no-op
1800/// pattern that is used to trigger target materialization.
1801struct TestTypeConsumerForward
1802 : public OpConversionPattern<TestTypeConsumerOp> {
1803 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1804
1805 LogicalResult
1806 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1807 ConversionPatternRewriter &rewriter) const final {
1808 rewriter.modifyOpInPlace(op,
1809 [&] { op->setOperands(adaptor.getOperands()); });
1810 return success();
1811 }
1812};
1813
1814struct TestTypeConversionAnotherProducer
1815 : public OpRewritePattern<TestAnotherTypeProducerOp> {
1816 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1817
1818 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1819 PatternRewriter &rewriter) const final {
1820 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1821 return success();
1822 }
1823};
1824
1825struct TestReplaceWithLegalOp : public ConversionPattern {
1826 TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1827 : ConversionPattern(converter, "test.replace_with_legal_op",
1828 /*benefit=*/1, ctx) {}
1829 LogicalResult
1830 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1831 ConversionPatternRewriter &rewriter) const final {
1832 rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
1833 return success();
1834 }
1835};
1836
1837struct TestTypeConversionDriver
1838 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1839 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1840
1841 void getDependentDialects(DialectRegistry &registry) const override {
1842 registry.insert<TestDialect>();
1843 }
1844 StringRef getArgument() const final {
1845 return "test-legalize-type-conversion";
1846 }
1847 StringRef getDescription() const final {
1848 return "Test various type conversion functionalities in DialectConversion";
1849 }
1850
1851 void runOnOperation() override {
1852 // Initialize the type converter.
1853 SmallVector<Type, 2> conversionCallStack;
1854 TypeConverter converter;
1855
1856 /// Add the legal set of type conversions.
1857 converter.addConversion(callback: [](Type type) -> Type {
1858 // Treat F64 as legal.
1859 if (type.isF64())
1860 return type;
1861 // Allow converting BF16/F16/F32 to F64.
1862 if (type.isBF16() || type.isF16() || type.isF32())
1863 return Float64Type::get(type.getContext());
1864 // Otherwise, the type is illegal.
1865 return nullptr;
1866 });
1867 converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) {
1868 // Drop all integer types.
1869 return success();
1870 });
1871 converter.addConversion(
1872 // Convert a recursive self-referring type into a non-self-referring
1873 // type named "outer_converted_type" that contains a SimpleAType.
1874 callback: [&](test::TestRecursiveType type,
1875 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1876 // If the type is already converted, return it to indicate that it is
1877 // legal.
1878 if (type.getName() == "outer_converted_type") {
1879 results.push_back(type);
1880 return success();
1881 }
1882
1883 conversionCallStack.push_back(type);
1884 auto popConversionCallStack = llvm::make_scope_exit(
1885 F: [&conversionCallStack]() { conversionCallStack.pop_back(); });
1886
1887 // If the type is on the call stack more than once (it is there at
1888 // least once because of the _current_ call, which is always the last
1889 // element on the stack), we've hit the recursive case. Just return
1890 // SimpleAType here to create a non-recursive type as a result.
1891 if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(),
1892 Element: type)) {
1893 results.push_back(test::SimpleAType::get(type.getContext()));
1894 return success();
1895 }
1896
1897 // Convert the body recursively.
1898 auto result = test::TestRecursiveType::get(ctx: type.getContext(),
1899 name: "outer_converted_type");
1900 if (failed(result.setBody(converter.convertType(t: type.getBody()))))
1901 return failure();
1902 results.push_back(Elt: result);
1903 return success();
1904 });
1905
1906 /// Add the legal set of type materializations.
1907 converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType,
1908 ValueRange inputs,
1909 Location loc) -> Value {
1910 // Allow casting from F64 back to F32.
1911 if (!resultType.isF16() && inputs.size() == 1 &&
1912 inputs[0].getType().isF64())
1913 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1914 // Allow producing an i32 or i64 from nothing.
1915 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1916 inputs.empty())
1917 return builder.create<TestTypeProducerOp>(loc, resultType);
1918 // Allow producing an i64 from an integer.
1919 if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1920 isa<IntegerType>(inputs[0].getType()))
1921 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1922 // Otherwise, fail.
1923 return nullptr;
1924 });
1925
1926 // Initialize the conversion target.
1927 mlir::ConversionTarget target(getContext());
1928 target.addLegalOp<LegalOpD>();
1929 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1930 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1931 return op.getType().isF64() || op.getType().isInteger(64) ||
1932 (recursiveType &&
1933 recursiveType.getName() == "outer_converted_type");
1934 });
1935 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1936 return converter.isSignatureLegal(op.getFunctionType()) &&
1937 converter.isLegal(&op.getBody());
1938 });
1939 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1940 // Allow casts from F64 to F32.
1941 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1942 });
1943 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1944 [&](TestSignatureConversionNoConverterOp op) {
1945 return converter.isLegal(op.getRegion().front().getArgumentTypes());
1946 });
1947
1948 // Initialize the set of rewrite patterns.
1949 RewritePatternSet patterns(&getContext());
1950 patterns
1951 .add<TestTypeConsumerForward, TestTypeConversionProducer,
1952 TestSignatureConversionUndo,
1953 TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1954 arg&: converter, args: &getContext());
1955 patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext());
1956 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1957 converter);
1958
1959 if (failed(applyPartialConversion(getOperation(), target,
1960 std::move(patterns))))
1961 signalPassFailure();
1962 }
1963};
1964} // namespace
1965
1966//===----------------------------------------------------------------------===//
1967// Test Target Materialization With No Uses
1968//===----------------------------------------------------------------------===//
1969
1970namespace {
1971struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1972 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1973
1974 LogicalResult
1975 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1976 ConversionPatternRewriter &rewriter) const final {
1977 rewriter.replaceOp(op, adaptor.getOperands());
1978 return success();
1979 }
1980};
1981
1982struct TestTargetMaterializationWithNoUses
1983 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1984 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1985 TestTargetMaterializationWithNoUses)
1986
1987 StringRef getArgument() const final {
1988 return "test-target-materialization-with-no-uses";
1989 }
1990 StringRef getDescription() const final {
1991 return "Test a special case of target materialization in DialectConversion";
1992 }
1993
1994 void runOnOperation() override {
1995 TypeConverter converter;
1996 converter.addConversion(callback: [](Type t) { return t; });
1997 converter.addConversion(callback: [](IntegerType intTy) -> Type {
1998 if (intTy.getWidth() == 16)
1999 return IntegerType::get(intTy.getContext(), 64);
2000 return intTy;
2001 });
2002 converter.addTargetMaterialization(
2003 callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
2004 return builder.create<TestCastOp>(loc, type, inputs).getResult();
2005 });
2006
2007 ConversionTarget target(getContext());
2008 target.addIllegalOp<TestTypeChangerOp>();
2009
2010 RewritePatternSet patterns(&getContext());
2011 patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext());
2012
2013 if (failed(applyPartialConversion(getOperation(), target,
2014 std::move(patterns))))
2015 signalPassFailure();
2016 }
2017};
2018} // namespace
2019
2020//===----------------------------------------------------------------------===//
2021// Test Block Merging
2022//===----------------------------------------------------------------------===//
2023
2024namespace {
2025/// A rewriter pattern that tests that blocks can be merged.
2026struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
2027 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
2028
2029 LogicalResult
2030 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
2031 ConversionPatternRewriter &rewriter) const final {
2032 Block &firstBlock = op.getBody().front();
2033 Operation *branchOp = firstBlock.getTerminator();
2034 Block *secondBlock = &*(std::next(op.getBody().begin()));
2035 auto succOperands = branchOp->getOperands();
2036 SmallVector<Value, 2> replacements(succOperands);
2037 rewriter.eraseOp(op: branchOp);
2038 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
2039 rewriter.modifyOpInPlace(op, [] {});
2040 return success();
2041 }
2042};
2043
2044/// A rewrite pattern to tests the undo mechanism of blocks being merged.
2045struct TestUndoBlocksMerge : public ConversionPattern {
2046 TestUndoBlocksMerge(MLIRContext *ctx)
2047 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
2048 LogicalResult
2049 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2050 ConversionPatternRewriter &rewriter) const final {
2051 Block &firstBlock = op->getRegion(index: 0).front();
2052 Operation *branchOp = firstBlock.getTerminator();
2053 Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin()));
2054 rewriter.setInsertionPointToStart(secondBlock);
2055 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
2056 auto succOperands = branchOp->getOperands();
2057 SmallVector<Value, 2> replacements(succOperands);
2058 rewriter.eraseOp(op: branchOp);
2059 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
2060 rewriter.modifyOpInPlace(root: op, callable: [] {});
2061 return success();
2062 }
2063};
2064
2065/// A rewrite mechanism to inline the body of the op into its parent, when both
2066/// ops can have a single block.
2067struct TestMergeSingleBlockOps
2068 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2069 using OpConversionPattern<
2070 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2071
2072 LogicalResult
2073 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2074 ConversionPatternRewriter &rewriter) const final {
2075 SingleBlockImplicitTerminatorOp parentOp =
2076 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2077 if (!parentOp)
2078 return failure();
2079 Block &innerBlock = op.getRegion().front();
2080 TerminatorOp innerTerminator =
2081 cast<TerminatorOp>(innerBlock.getTerminator());
2082 rewriter.inlineBlockBefore(&innerBlock, op);
2083 rewriter.eraseOp(op: innerTerminator);
2084 rewriter.eraseOp(op: op);
2085 return success();
2086 }
2087};
2088
2089struct TestMergeBlocksPatternDriver
2090 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2091 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2092
2093 StringRef getArgument() const final { return "test-merge-blocks"; }
2094 StringRef getDescription() const final {
2095 return "Test Merging operation in ConversionPatternRewriter";
2096 }
2097 void runOnOperation() override {
2098 MLIRContext *context = &getContext();
2099 mlir::RewritePatternSet patterns(context);
2100 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2101 arg&: context);
2102 ConversionTarget target(*context);
2103 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2104 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2105 target.addIllegalOp<ILLegalOpF>();
2106
2107 /// Expect the op to have a single block after legalization.
2108 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2109 [&](TestMergeBlocksOp op) -> bool {
2110 return llvm::hasSingleElement(op.getBody());
2111 });
2112
2113 /// Only allow `test.br` within test.merge_blocks op.
2114 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
2115 return op->getParentOfType<TestMergeBlocksOp>();
2116 });
2117
2118 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2119 /// inlined.
2120 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2121 [&](SingleBlockImplicitTerminatorOp op) -> bool {
2122 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2123 });
2124
2125 DenseSet<Operation *> unlegalizedOps;
2126 ConversionConfig config;
2127 config.unlegalizedOps = &unlegalizedOps;
2128 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
2129 config);
2130 for (auto *op : unlegalizedOps)
2131 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2132 }
2133};
2134} // namespace
2135
2136//===----------------------------------------------------------------------===//
2137// Test Selective Replacement
2138//===----------------------------------------------------------------------===//
2139
2140namespace {
2141/// A rewrite mechanism to inline the body of the op into its parent, when both
2142/// ops can have a single block.
2143struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2144 using OpRewritePattern<TestCastOp>::OpRewritePattern;
2145
2146 LogicalResult matchAndRewrite(TestCastOp op,
2147 PatternRewriter &rewriter) const final {
2148 if (op.getNumOperands() != 2)
2149 return failure();
2150 OperandRange operands = op.getOperands();
2151
2152 // Replace non-terminator uses with the first operand.
2153 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
2154 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2155 });
2156 // Replace everything else with the second operand if the operation isn't
2157 // dead.
2158 rewriter.replaceOp(op, op.getOperand(1));
2159 return success();
2160 }
2161};
2162
2163struct TestSelectiveReplacementPatternDriver
2164 : public PassWrapper<TestSelectiveReplacementPatternDriver,
2165 OperationPass<>> {
2166 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2167 TestSelectiveReplacementPatternDriver)
2168
2169 StringRef getArgument() const final {
2170 return "test-pattern-selective-replacement";
2171 }
2172 StringRef getDescription() const final {
2173 return "Test selective replacement in the PatternRewriter";
2174 }
2175 void runOnOperation() override {
2176 MLIRContext *context = &getContext();
2177 mlir::RewritePatternSet patterns(context);
2178 patterns.add<TestSelectiveOpReplacementPattern>(arg&: context);
2179 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2180 }
2181};
2182} // namespace
2183
2184//===----------------------------------------------------------------------===//
2185// PassRegistration
2186//===----------------------------------------------------------------------===//
2187
2188namespace mlir {
2189namespace test {
2190void registerPatternsTestPass() {
2191 PassRegistration<TestReturnTypeDriver>();
2192
2193 PassRegistration<TestDerivedAttributeDriver>();
2194
2195 PassRegistration<TestGreedyPatternDriver>();
2196 PassRegistration<TestStrictPatternDriver>();
2197 PassRegistration<TestWalkPatternDriver>();
2198
2199 PassRegistration<TestLegalizePatternDriver>([] {
2200 return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode);
2201 });
2202
2203 PassRegistration<TestRemappedValue>();
2204
2205 PassRegistration<TestUnknownRootOpDriver>();
2206
2207 PassRegistration<TestTypeConversionDriver>();
2208 PassRegistration<TestTargetMaterializationWithNoUses>();
2209
2210 PassRegistration<TestRewriteDynamicOpDriver>();
2211
2212 PassRegistration<TestMergeBlocksPatternDriver>();
2213 PassRegistration<TestSelectiveReplacementPatternDriver>();
2214}
2215} // namespace test
2216} // namespace mlir
2217

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/lib/Dialect/Test/TestPatterns.cpp