1//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "TestTypes.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/FoldUtils.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#include "mlir/Transforms/WalkPatternRewriteDriver.h"
24#include "llvm/ADT/ScopeExit.h"
25#include <cstdint>
26
27using namespace mlir;
28using namespace test;
29
30// Native function for testing NativeCodeCall
31static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
32 return choice.getValue() ? input1 : input2;
33}
34
35static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
36 rewriter.create<OpI>(location: loc, args&: input);
37}
38
39static void handleNoResultOp(PatternRewriter &rewriter,
40 OpSymbolBindingNoResult op) {
41 // Turn the no result op to a one-result op.
42 rewriter.create<OpSymbolBindingB>(location: op.getLoc(), args: op.getOperand().getType(),
43 args: op.getOperand());
44}
45
46static bool getFirstI32Result(Operation *op, Value &value) {
47 if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32))
48 return false;
49 value = op->getResult(idx: 0);
50 return true;
51}
52
53static Value bindNativeCodeCallResult(Value value) { return value; }
54
55static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
56 Value input2) {
57 return SmallVector<Value, 2>({input2, input1});
58}
59
60// Test that natives calls are only called once during rewrites.
61// OpM_Test will return Pi, increased by 1 for each subsequent calls.
62// This let us check the number of times OpM_Test was called by inspecting
63// the returned value in the MLIR output.
64static int64_t opMIncreasingValue = 314159265;
65static Attribute opMTest(PatternRewriter &rewriter, Value val) {
66 int64_t i = opMIncreasingValue++;
67 return rewriter.getIntegerAttr(type: rewriter.getIntegerType(width: 32), value: i);
68}
69
70namespace {
71#include "TestPatterns.inc"
72} // namespace
73
74//===----------------------------------------------------------------------===//
75// Test Reduce Pattern Interface
76//===----------------------------------------------------------------------===//
77
78void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
79 populateWithGenerated(patterns);
80}
81
82//===----------------------------------------------------------------------===//
83// Canonicalizer Driver.
84//===----------------------------------------------------------------------===//
85
86namespace {
87struct FoldingPattern : public RewritePattern {
88public:
89 FoldingPattern(MLIRContext *context)
90 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
91 /*benefit=*/1, context) {}
92
93 LogicalResult matchAndRewrite(Operation *op,
94 PatternRewriter &rewriter) const override {
95 // Exercise createOrFold API for a single-result operation that is folded
96 // upon construction. The operation being created has an in-place folder,
97 // and it should be still present in the output. Furthermore, the folder
98 // should not crash when attempting to recover the (unchanged) operation
99 // result.
100 Value result = rewriter.createOrFold<TestOpInPlaceFold>(
101 location: op->getLoc(), args: rewriter.getIntegerType(width: 32), args: op->getOperand(idx: 0));
102 assert(result);
103 rewriter.replaceOp(op, newValues: result);
104 return success();
105 }
106};
107
108/// This pattern creates a foldable operation at the entry point of the block.
109/// This tests the situation where the operation folder will need to replace an
110/// operation with a previously created constant that does not initially
111/// dominate the operation to replace.
112struct FolderInsertBeforePreviouslyFoldedConstantPattern
113 : public OpRewritePattern<TestCastOp> {
114public:
115 using OpRewritePattern<TestCastOp>::OpRewritePattern;
116
117 LogicalResult matchAndRewrite(TestCastOp op,
118 PatternRewriter &rewriter) const override {
119 if (!op->hasAttr(name: "test_fold_before_previously_folded_op"))
120 return failure();
121 rewriter.setInsertionPointToStart(op->getBlock());
122
123 auto constOp = rewriter.create<arith::ConstantOp>(
124 location: op.getLoc(), args: rewriter.getBoolAttr(value: true));
125 rewriter.replaceOpWithNewOp<TestCastOp>(op, args: rewriter.getI32Type(),
126 args: Value(constOp));
127 return success();
128 }
129};
130
131/// This pattern matches test.op_commutative2 with the first operand being
132/// another test.op_commutative2 with a constant on the right side and fold it
133/// away by propagating it as its result. This is intend to check that patterns
134/// are applied after the commutative property moves constant to the right.
135struct FolderCommutativeOp2WithConstant
136 : public OpRewritePattern<TestCommutative2Op> {
137public:
138 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
139
140 LogicalResult matchAndRewrite(TestCommutative2Op op,
141 PatternRewriter &rewriter) const override {
142 auto operand =
143 dyn_cast_or_null<TestCommutative2Op>(Val: op->getOperand(idx: 0).getDefiningOp());
144 if (!operand)
145 return failure();
146 Attribute constInput;
147 if (!matchPattern(value: operand->getOperand(idx: 1), pattern: m_Constant(bind_value: &constInput)))
148 return failure();
149 rewriter.replaceOp(op, newValues: operand->getOperand(idx: 1));
150 return success();
151 }
152};
153
154/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
155/// attribute with value smaller than MaxVal, it increments the value by 1.
156template <int MaxVal>
157struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
158 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
159
160 LogicalResult matchAndRewrite(AnyAttrOfOp op,
161 PatternRewriter &rewriter) const override {
162 auto intAttr = dyn_cast<IntegerAttr>(Val: op.getAttr());
163 if (!intAttr)
164 return failure();
165 int64_t val = intAttr.getInt();
166 if (val >= MaxVal)
167 return failure();
168 rewriter.modifyOpInPlace(
169 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(value: val + 1)); });
170 return success();
171 }
172};
173
174/// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
175struct MakeOpEligible : public RewritePattern {
176 MakeOpEligible(MLIRContext *context)
177 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
178
179 LogicalResult matchAndRewrite(Operation *op,
180 PatternRewriter &rewriter) const override {
181 if (op->hasAttr(name: "eligible"))
182 return failure();
183 rewriter.modifyOpInPlace(
184 root: op, callable: [&]() { op->setAttr(name: "eligible", value: rewriter.getUnitAttr()); });
185 return success();
186 }
187};
188
189/// This pattern hoists eligible ops out of a "test.one_region_op".
190struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
191 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
192
193 LogicalResult matchAndRewrite(test::OneRegionOp op,
194 PatternRewriter &rewriter) const override {
195 Operation *terminator = op.getRegion().front().getTerminator();
196 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
197 if (toBeHoisted->getParentOp() != op)
198 return failure();
199 if (!toBeHoisted->hasAttr(name: "eligible"))
200 return failure();
201 rewriter.moveOpBefore(op: toBeHoisted, existingOp: op);
202 return success();
203 }
204};
205
206/// This pattern moves "test.move_before_parent_op" before the parent op.
207struct MoveBeforeParentOp : public RewritePattern {
208 MoveBeforeParentOp(MLIRContext *context)
209 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
210
211 LogicalResult matchAndRewrite(Operation *op,
212 PatternRewriter &rewriter) const override {
213 // Do not hoist past functions.
214 if (isa<FunctionOpInterface>(Val: op->getParentOp()))
215 return failure();
216 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
217 return success();
218 }
219};
220
221/// This pattern moves "test.move_after_parent_op" after the parent op.
222struct MoveAfterParentOp : public RewritePattern {
223 MoveAfterParentOp(MLIRContext *context)
224 : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
225
226 LogicalResult matchAndRewrite(Operation *op,
227 PatternRewriter &rewriter) const override {
228 // Do not hoist past functions.
229 if (isa<FunctionOpInterface>(Val: op->getParentOp()))
230 return failure();
231
232 int64_t moveForwardBy = 0;
233 if (auto advanceBy = op->getAttrOfType<IntegerAttr>(name: "advance"))
234 moveForwardBy = advanceBy.getInt();
235
236 Operation *moveAfter = op->getParentOp();
237 for (int64_t i = 0; i < moveForwardBy; ++i)
238 moveAfter = moveAfter->getNextNode();
239
240 rewriter.moveOpAfter(op, existingOp: moveAfter);
241 return success();
242 }
243};
244
245/// This pattern inlines blocks that are nested in
246/// "test.inline_blocks_into_parent" into the parent block.
247struct InlineBlocksIntoParent : public RewritePattern {
248 InlineBlocksIntoParent(MLIRContext *context)
249 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
250 context) {}
251
252 LogicalResult matchAndRewrite(Operation *op,
253 PatternRewriter &rewriter) const override {
254 bool changed = false;
255 for (Region &r : op->getRegions()) {
256 while (!r.empty()) {
257 rewriter.inlineBlockBefore(source: &r.front(), op);
258 changed = true;
259 }
260 }
261 return success(IsSuccess: changed);
262 }
263};
264
265/// This pattern splits blocks at "test.split_block_here" and replaces the op
266/// with a new op (to prevent an infinite loop of block splitting).
267struct SplitBlockHere : public RewritePattern {
268 SplitBlockHere(MLIRContext *context)
269 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
270
271 LogicalResult matchAndRewrite(Operation *op,
272 PatternRewriter &rewriter) const override {
273 rewriter.splitBlock(block: op->getBlock(), before: op->getIterator());
274 Operation *newOp = rewriter.create(
275 loc: op->getLoc(),
276 opName: OperationName("test.new_op", op->getContext()).getIdentifier(),
277 operands: op->getOperands(), types: op->getResultTypes());
278 rewriter.replaceOp(op, newOp);
279 return success();
280 }
281};
282
283/// This pattern clones "test.clone_me" ops.
284struct CloneOp : public RewritePattern {
285 CloneOp(MLIRContext *context)
286 : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
287
288 LogicalResult matchAndRewrite(Operation *op,
289 PatternRewriter &rewriter) const override {
290 // Do not clone already cloned ops to avoid going into an infinite loop.
291 if (op->hasAttr(name: "was_cloned"))
292 return failure();
293 Operation *cloned = rewriter.clone(op&: *op);
294 cloned->setAttr(name: "was_cloned", value: rewriter.getUnitAttr());
295 return success();
296 }
297};
298
299/// This pattern clones regions of "test.clone_region_before" ops before the
300/// parent block.
301struct CloneRegionBeforeOp : public RewritePattern {
302 CloneRegionBeforeOp(MLIRContext *context)
303 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
304
305 LogicalResult matchAndRewrite(Operation *op,
306 PatternRewriter &rewriter) const override {
307 // Do not clone already cloned ops to avoid going into an infinite loop.
308 if (op->hasAttr(name: "was_cloned"))
309 return failure();
310 for (Region &r : op->getRegions())
311 rewriter.cloneRegionBefore(region&: r, before: op->getBlock());
312 op->setAttr(name: "was_cloned", value: rewriter.getUnitAttr());
313 return success();
314 }
315};
316
317/// Replace an operation may introduce the re-visiting of its users.
318class ReplaceWithNewOp : public RewritePattern {
319public:
320 ReplaceWithNewOp(MLIRContext *context)
321 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
322
323 LogicalResult matchAndRewrite(Operation *op,
324 PatternRewriter &rewriter) const override {
325 Operation *newOp;
326 if (op->hasAttr(name: "create_erase_op")) {
327 newOp = rewriter.create(
328 loc: op->getLoc(),
329 opName: OperationName("test.erase_op", op->getContext()).getIdentifier(),
330 operands: ValueRange(), types: TypeRange());
331 } else {
332 newOp = rewriter.create(
333 loc: op->getLoc(),
334 opName: OperationName("test.new_op", op->getContext()).getIdentifier(),
335 operands: op->getOperands(), types: op->getResultTypes());
336 }
337 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
338 // A "notifyOperationReplaced" callback is triggered in either case.
339 rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults());
340 rewriter.eraseOp(op);
341 return success();
342 }
343};
344
345/// Erases the first child block of the matched "test.erase_first_block"
346/// operation.
347class EraseFirstBlock : public RewritePattern {
348public:
349 EraseFirstBlock(MLIRContext *context)
350 : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351
352 LogicalResult matchAndRewrite(Operation *op,
353 PatternRewriter &rewriter) const override {
354 for (Region &r : op->getRegions()) {
355 for (Block &b : r.getBlocks()) {
356 rewriter.eraseBlock(block: &b);
357 return success();
358 }
359 }
360
361 return failure();
362 }
363};
364
365struct TestGreedyPatternDriver
366 : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
367 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
368
369 TestGreedyPatternDriver() = default;
370 TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
371 : PassWrapper(other) {}
372
373 StringRef getArgument() const final { return "test-greedy-patterns"; }
374 StringRef getDescription() const final { return "Run test dialect patterns"; }
375 void runOnOperation() override {
376 mlir::RewritePatternSet patterns(&getContext());
377 populateWithGenerated(patterns);
378
379 // Verify named pattern is generated with expected name.
380 patterns.add<FoldingPattern, TestNamedPatternRule,
381 FolderInsertBeforePreviouslyFoldedConstantPattern,
382 FolderCommutativeOp2WithConstant, HoistEligibleOps,
383 MakeOpEligible>(arg: &getContext());
384
385 // Additional patterns for testing the GreedyPatternRewriteDriver.
386 patterns.insert<IncrementIntAttribute<3>>(arg: &getContext());
387
388 GreedyRewriteConfig config;
389 config.setUseTopDownTraversal(useTopDownTraversal)
390 .setMaxIterations(this->maxIterations)
391 .enableFolding(enable: this->fold)
392 .enableConstantCSE(enable: this->cseConstants);
393 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns), config);
394 }
395
396 Option<bool> useTopDownTraversal{
397 *this, "top-down",
398 llvm::cl::desc("Seed the worklist in general top-down order"),
399 llvm::cl::init(Val: GreedyRewriteConfig().getUseTopDownTraversal())};
400 Option<int> maxIterations{
401 *this, "max-iterations",
402 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
403 llvm::cl::init(Val: GreedyRewriteConfig().getMaxIterations())};
404 Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"),
405 llvm::cl::init(Val: GreedyRewriteConfig().isFoldingEnabled())};
406 Option<bool> cseConstants{
407 *this, "cse-constants", llvm::cl::desc("Whether to CSE constants"),
408 llvm::cl::init(Val: GreedyRewriteConfig().isConstantCSEEnabled())};
409};
410
411struct DumpNotifications : public RewriterBase::Listener {
412 void notifyBlockInserted(Block *block, Region *previous,
413 Region::iterator previousIt) override {
414 llvm::outs() << "notifyBlockInserted";
415 if (block->getParentOp()) {
416 llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
417 } else {
418 llvm::outs() << " into unknown op: ";
419 }
420 if (previous == nullptr) {
421 llvm::outs() << "was unlinked\n";
422 } else {
423 llvm::outs() << "was linked\n";
424 }
425 }
426 void notifyOperationInserted(Operation *op,
427 OpBuilder::InsertPoint previous) override {
428 llvm::outs() << "notifyOperationInserted: " << op->getName();
429 if (!previous.isSet()) {
430 llvm::outs() << ", was unlinked\n";
431 } else {
432 if (!previous.getPoint().getNodePtr()) {
433 llvm::outs() << ", was linked, exact position unknown\n";
434 } else if (previous.getPoint() == previous.getBlock()->end()) {
435 llvm::outs() << ", was last in block\n";
436 } else {
437 llvm::outs() << ", previous = " << previous.getPoint()->getName()
438 << "\n";
439 }
440 }
441 }
442 void notifyBlockErased(Block *block) override {
443 llvm::outs() << "notifyBlockErased\n";
444 }
445 void notifyOperationErased(Operation *op) override {
446 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
447 }
448 void notifyOperationModified(Operation *op) override {
449 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
450 }
451 void notifyOperationReplaced(Operation *op, ValueRange values) override {
452 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
453 }
454};
455
456struct TestStrictPatternDriver
457 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
458public:
459 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
460
461 TestStrictPatternDriver() = default;
462 TestStrictPatternDriver(const TestStrictPatternDriver &other)
463 : PassWrapper(other) {
464 strictMode = other.strictMode;
465 }
466
467 StringRef getArgument() const final { return "test-strict-pattern-driver"; }
468 StringRef getDescription() const final {
469 return "Test strict mode of pattern driver";
470 }
471
472 void runOnOperation() override {
473 MLIRContext *ctx = &getContext();
474 mlir::RewritePatternSet patterns(ctx);
475 patterns.add<
476 // clang-format off
477 ChangeBlockOp,
478 CloneOp,
479 CloneRegionBeforeOp,
480 EraseOp,
481 ImplicitChangeOp,
482 InlineBlocksIntoParent,
483 InsertSameOp,
484 MoveBeforeParentOp,
485 ReplaceWithNewOp,
486 SplitBlockHere
487 // clang-format on
488 >(arg&: ctx);
489 SmallVector<Operation *> ops;
490 getOperation()->walk(callback: [&](Operation *op) {
491 StringRef opName = op->getName().getStringRef();
492 if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
493 opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
494 opName == "test.move_before_parent_op" ||
495 opName == "test.inline_blocks_into_parent" ||
496 opName == "test.split_block_here" || opName == "test.clone_me" ||
497 opName == "test.clone_region_before") {
498 ops.push_back(Elt: op);
499 }
500 });
501
502 DumpNotifications dumpNotifications;
503 GreedyRewriteConfig config;
504 config.setListener(&dumpNotifications);
505 if (strictMode == "AnyOp") {
506 config.setStrictness(GreedyRewriteStrictness::AnyOp);
507 } else if (strictMode == "ExistingAndNewOps") {
508 config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
509 } else if (strictMode == "ExistingOps") {
510 config.setStrictness(GreedyRewriteStrictness::ExistingOps);
511 } else {
512 llvm_unreachable("invalid strictness option");
513 }
514
515 // Check if these transformations introduce visiting of operations that
516 // are not in the `ops` set (The new created ops are valid). An invalid
517 // operation will trigger the assertion while processing.
518 bool changed = false;
519 bool allErased = false;
520 (void)applyOpPatternsGreedily(ops: ArrayRef(ops), patterns: std::move(patterns), config,
521 changed: &changed, allErased: &allErased);
522 Builder b(ctx);
523 getOperation()->setAttr(name: "pattern_driver_changed", value: b.getBoolAttr(value: changed));
524 getOperation()->setAttr(name: "pattern_driver_all_erased",
525 value: b.getBoolAttr(value: allErased));
526 }
527
528 Option<std::string> strictMode{
529 *this, "strictness",
530 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
531 llvm::cl::init(Val: "AnyOp")};
532
533private:
534 // New inserted operation is valid for further transformation.
535 class InsertSameOp : public RewritePattern {
536 public:
537 InsertSameOp(MLIRContext *context)
538 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
539
540 LogicalResult matchAndRewrite(Operation *op,
541 PatternRewriter &rewriter) const override {
542 if (op->hasAttr(name: "skip"))
543 return failure();
544
545 Operation *newOp =
546 rewriter.create(loc: op->getLoc(), opName: op->getName().getIdentifier(),
547 operands: op->getOperands(), types: op->getResultTypes());
548 rewriter.modifyOpInPlace(
549 root: op, callable: [&]() { op->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true)); });
550 newOp->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true));
551
552 return success();
553 }
554 };
555
556 // Remove an operation may introduce the re-visiting of its operands.
557 class EraseOp : public RewritePattern {
558 public:
559 EraseOp(MLIRContext *context)
560 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
561 LogicalResult matchAndRewrite(Operation *op,
562 PatternRewriter &rewriter) const override {
563 rewriter.eraseOp(op);
564 return success();
565 }
566 };
567
568 // The following two patterns test RewriterBase::replaceAllUsesWith.
569 //
570 // That function replaces all usages of a Block (or a Value) with another one
571 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
572 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
573 // worklist: when an op is modified, it is added to the worklist. The two
574 // patterns below make the tracking observable: ChangeBlockOp replaces all
575 // usages of a block and that pattern is applied because the corresponding ops
576 // are put on the initial worklist (see above). ImplicitChangeOp does an
577 // unrelated change but ops of the corresponding type are *not* on the initial
578 // worklist, so the effect of the second pattern is only visible if the
579 // tracking and subsequent adding to the worklist actually works.
580
581 // Replace all usages of the first successor with the second successor.
582 class ChangeBlockOp : public RewritePattern {
583 public:
584 ChangeBlockOp(MLIRContext *context)
585 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
586 LogicalResult matchAndRewrite(Operation *op,
587 PatternRewriter &rewriter) const override {
588 if (op->getNumSuccessors() < 2)
589 return failure();
590 Block *firstSuccessor = op->getSuccessor(index: 0);
591 Block *secondSuccessor = op->getSuccessor(index: 1);
592 if (firstSuccessor == secondSuccessor)
593 return failure();
594 // This is the function being tested:
595 rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor);
596 // Using the following line instead would make the test fail:
597 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
598 return success();
599 }
600 };
601
602 // Changes the successor to the parent block.
603 class ImplicitChangeOp : public RewritePattern {
604 public:
605 ImplicitChangeOp(MLIRContext *context)
606 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
607 LogicalResult matchAndRewrite(Operation *op,
608 PatternRewriter &rewriter) const override {
609 if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock())
610 return failure();
611 rewriter.modifyOpInPlace(root: op,
612 callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); });
613 return success();
614 }
615 };
616};
617
618struct TestWalkPatternDriver final
619 : PassWrapper<TestWalkPatternDriver, OperationPass<>> {
620 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
621
622 TestWalkPatternDriver() = default;
623 TestWalkPatternDriver(const TestWalkPatternDriver &other)
624 : PassWrapper(other) {}
625
626 StringRef getArgument() const override {
627 return "test-walk-pattern-rewrite-driver";
628 }
629 StringRef getDescription() const override {
630 return "Run test walk pattern rewrite driver";
631 }
632 void runOnOperation() override {
633 mlir::RewritePatternSet patterns(&getContext());
634
635 // Patterns for testing the WalkPatternRewriteDriver.
636 patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
637 MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
638 arg: &getContext());
639
640 DumpNotifications dumpListener;
641 walkAndApplyPatterns(op: getOperation(), patterns: std::move(patterns),
642 listener: dumpNotifications ? &dumpListener : nullptr);
643 }
644
645 Option<bool> dumpNotifications{
646 *this, "dump-notifications",
647 llvm::cl::desc("Print rewrite listener notifications"),
648 llvm::cl::init(Val: false)};
649};
650
651} // namespace
652
653//===----------------------------------------------------------------------===//
654// ReturnType Driver.
655//===----------------------------------------------------------------------===//
656
657namespace {
658// Generate ops for each instance where the type can be successfully inferred.
659template <typename OpTy>
660static void invokeCreateWithInferredReturnType(Operation *op) {
661 auto *context = op->getContext();
662 auto fop = op->getParentOfType<func::FuncOp>();
663 auto location = UnknownLoc::get(context);
664 OpBuilder b(op);
665 b.setInsertionPointAfter(op);
666
667 // Use permutations of 2 args as operands.
668 assert(fop.getNumArguments() >= 2);
669 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
670 for (int j = 0; j < e; ++j) {
671 std::array<Value, 2> values = {._M_elems: {fop.getArgument(idx: i), fop.getArgument(idx: j)}};
672 SmallVector<Type, 2> inferredReturnTypes;
673 if (succeeded(OpTy::inferReturnTypes(
674 context, std::nullopt, values, op->getDiscardableAttrDictionary(),
675 op->getPropertiesStorage(), op->getRegions(),
676 inferredReturnTypes))) {
677 OperationState state(location, OpTy::getOperationName());
678 // TODO: Expand to regions.
679 OpTy::build(b, state, values, op->getAttrs());
680 (void)b.create(state);
681 }
682 }
683 }
684}
685
686static void reifyReturnShape(Operation *op) {
687 OpBuilder b(op);
688
689 // Use permutations of 2 args as operands.
690 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(Val: op);
691 SmallVector<Value, 2> shapes;
692 if (failed(Result: shapedOp.reifyReturnTypeShapes(builder&: b, operands: op->getOperands(), reifiedReturnShapes&: shapes)) ||
693 !llvm::hasSingleElement(C&: shapes))
694 return;
695 for (const auto &it : llvm::enumerate(First&: shapes)) {
696 op->emitRemark() << "value " << it.index() << ": "
697 << it.value().getDefiningOp();
698 }
699}
700
701struct TestReturnTypeDriver
702 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
703 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
704
705 void getDependentDialects(DialectRegistry &registry) const override {
706 registry.insert<tensor::TensorDialect>();
707 }
708 StringRef getArgument() const final { return "test-return-type"; }
709 StringRef getDescription() const final { return "Run return type functions"; }
710
711 void runOnOperation() override {
712 if (getOperation().getName() == "testCreateFunctions") {
713 std::vector<Operation *> ops;
714 // Collect ops to avoid triggering on inserted ops.
715 for (auto &op : getOperation().getBody().front())
716 ops.push_back(x: &op);
717 // Generate test patterns for each, but skip terminator.
718 for (auto *op : llvm::ArrayRef(ops).drop_back()) {
719 // Test create method of each of the Op classes below. The resultant
720 // output would be in reverse order underneath `op` from which
721 // the attributes and regions are used.
722 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
723 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
724 op);
725 invokeCreateWithInferredReturnType<
726 OpWithShapedTypeInferTypeInterfaceOp>(op);
727 };
728 return;
729 }
730 if (getOperation().getName() == "testReifyFunctions") {
731 std::vector<Operation *> ops;
732 // Collect ops to avoid triggering on inserted ops.
733 for (auto &op : getOperation().getBody().front())
734 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(Val: op))
735 ops.push_back(x: &op);
736 // Generate test patterns for each, but skip terminator.
737 for (auto *op : ops)
738 reifyReturnShape(op);
739 }
740 }
741};
742} // namespace
743
744namespace {
745struct TestDerivedAttributeDriver
746 : public PassWrapper<TestDerivedAttributeDriver,
747 OperationPass<func::FuncOp>> {
748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
749
750 StringRef getArgument() const final { return "test-derived-attr"; }
751 StringRef getDescription() const final {
752 return "Run test derived attributes";
753 }
754 void runOnOperation() override;
755};
756} // namespace
757
758void TestDerivedAttributeDriver::runOnOperation() {
759 getOperation().walk(callback: [](DerivedAttributeOpInterface dOp) {
760 auto dAttr = dOp.materializeDerivedAttributes();
761 if (!dAttr)
762 return;
763 for (auto d : dAttr)
764 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
765 });
766}
767
768//===----------------------------------------------------------------------===//
769// Legalization Driver.
770//===----------------------------------------------------------------------===//
771
772namespace {
773//===----------------------------------------------------------------------===//
774// Region-Block Rewrite Testing
775//===----------------------------------------------------------------------===//
776
777/// This pattern applies a signature conversion to a block inside a detached
778/// region.
779struct TestDetachedSignatureConversion : public ConversionPattern {
780 TestDetachedSignatureConversion(MLIRContext *ctx)
781 : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1,
782 ctx) {}
783
784 LogicalResult
785 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
786 ConversionPatternRewriter &rewriter) const final {
787 if (op->getNumRegions() != 1)
788 return failure();
789 OperationState state(op->getLoc(), "test.legal_op", operands,
790 op->getResultTypes(), {}, BlockRange());
791 Region *newRegion = state.addRegion();
792 rewriter.inlineRegionBefore(region&: op->getRegion(index: 0), parent&: *newRegion,
793 before: newRegion->begin());
794 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
795 for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i)
796 result.addInputs(origInputNo: i, types: rewriter.getF64Type());
797 rewriter.applySignatureConversion(block: &newRegion->front(), conversion&: result);
798 Operation *newOp = rewriter.create(state);
799 rewriter.replaceOp(op, newValues: newOp->getResults());
800 return success();
801 }
802};
803
804/// This pattern is a simple pattern that inlines the first region of a given
805/// operation into the parent region.
806struct TestRegionRewriteBlockMovement : public ConversionPattern {
807 TestRegionRewriteBlockMovement(MLIRContext *ctx)
808 : ConversionPattern("test.region", 1, ctx) {}
809
810 LogicalResult
811 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
812 ConversionPatternRewriter &rewriter) const final {
813 // Inline this region into the parent region.
814 auto &parentRegion = *op->getParentRegion();
815 auto &opRegion = op->getRegion(index: 0);
816 if (op->getDiscardableAttr(name: "legalizer.should_clone"))
817 rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
818 else
819 rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
820
821 if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks")) {
822 while (!opRegion.empty())
823 rewriter.eraseBlock(block: &opRegion.front());
824 }
825
826 // Drop this operation.
827 rewriter.eraseOp(op);
828 return success();
829 }
830};
831/// This pattern is a simple pattern that generates a region containing an
832/// illegal operation.
833struct TestRegionRewriteUndo : public RewritePattern {
834 TestRegionRewriteUndo(MLIRContext *ctx)
835 : RewritePattern("test.region_builder", 1, ctx) {}
836
837 LogicalResult matchAndRewrite(Operation *op,
838 PatternRewriter &rewriter) const final {
839 // Create the region operation with an entry block containing arguments.
840 OperationState newRegion(op->getLoc(), "test.region");
841 newRegion.addRegion();
842 auto *regionOp = rewriter.create(state: newRegion);
843 auto *entryBlock = rewriter.createBlock(parent: &regionOp->getRegion(index: 0));
844 entryBlock->addArgument(type: rewriter.getIntegerType(width: 64),
845 loc: rewriter.getUnknownLoc());
846
847 // Add an explicitly illegal operation to ensure the conversion fails.
848 rewriter.create<ILLegalOpF>(location: op->getLoc(), args: rewriter.getIntegerType(width: 32));
849 rewriter.create<TestValidOp>(location: op->getLoc(), args: ArrayRef<Value>());
850
851 // Drop this operation.
852 rewriter.eraseOp(op);
853 return success();
854 }
855};
856/// A simple pattern that creates a block at the end of the parent region of the
857/// matched operation.
858struct TestCreateBlock : public RewritePattern {
859 TestCreateBlock(MLIRContext *ctx)
860 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
861
862 LogicalResult matchAndRewrite(Operation *op,
863 PatternRewriter &rewriter) const final {
864 Region &region = *op->getParentRegion();
865 Type i32Type = rewriter.getIntegerType(width: 32);
866 Location loc = op->getLoc();
867 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
868 rewriter.create<TerminatorOp>(location: loc);
869 rewriter.eraseOp(op);
870 return success();
871 }
872};
873
874/// A simple pattern that creates a block containing an invalid operation in
875/// order to trigger the block creation undo mechanism.
876struct TestCreateIllegalBlock : public RewritePattern {
877 TestCreateIllegalBlock(MLIRContext *ctx)
878 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
879
880 LogicalResult matchAndRewrite(Operation *op,
881 PatternRewriter &rewriter) const final {
882 Region &region = *op->getParentRegion();
883 Type i32Type = rewriter.getIntegerType(width: 32);
884 Location loc = op->getLoc();
885 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
886 // Create an illegal op to ensure the conversion fails.
887 rewriter.create<ILLegalOpF>(location: loc, args&: i32Type);
888 rewriter.create<TerminatorOp>(location: loc);
889 rewriter.eraseOp(op);
890 return success();
891 }
892};
893
894/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
895struct TestBlockArgReplace : public ConversionPattern {
896 TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
897 : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
898 ctx) {}
899
900 LogicalResult
901 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
902 ConversionPatternRewriter &rewriter) const final {
903 // Replace the first block argument with 2x the second block argument.
904 Value repl = op->getRegion(index: 0).getArgument(i: 1);
905 rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0),
906 to: {repl, repl});
907 rewriter.modifyOpInPlace(root: op, callable: [&] {
908 // If the "trigger_rollback" attribute is set, keep the op illegal, so
909 // that a rollback is triggered.
910 if (!op->hasAttr(name: "trigger_rollback"))
911 op->setAttr(name: "is_legal", value: rewriter.getUnitAttr());
912 });
913 return success();
914 }
915};
916
917/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
918/// This is to test the rollback logic.
919struct TestUndoMoveOpBefore : public ConversionPattern {
920 TestUndoMoveOpBefore(MLIRContext *ctx)
921 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
922
923 LogicalResult
924 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
925 ConversionPatternRewriter &rewriter) const override {
926 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
927 // Replace with an illegal op to ensure the conversion fails.
928 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, args: rewriter.getF32Type());
929 return success();
930 }
931};
932
933/// A rewrite pattern that tests the undo mechanism when erasing a block.
934struct TestUndoBlockErase : public ConversionPattern {
935 TestUndoBlockErase(MLIRContext *ctx)
936 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
937
938 LogicalResult
939 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
940 ConversionPatternRewriter &rewriter) const final {
941 Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin());
942 rewriter.setInsertionPointToStart(secondBlock);
943 rewriter.create<ILLegalOpF>(location: op->getLoc(), args: rewriter.getF32Type());
944 rewriter.eraseBlock(block: secondBlock);
945 rewriter.modifyOpInPlace(root: op, callable: [] {});
946 return success();
947 }
948};
949
950/// A pattern that modifies a property in-place, but keeps the op illegal.
951struct TestUndoPropertiesModification : public ConversionPattern {
952 TestUndoPropertiesModification(MLIRContext *ctx)
953 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
954 LogicalResult
955 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
956 ConversionPatternRewriter &rewriter) const final {
957 if (!op->hasAttr(name: "modify_inplace"))
958 return failure();
959 rewriter.modifyOpInPlace(
960 root: op, callable: [&]() { cast<TestOpWithProperties>(Val: op).getProperties().setA(42); });
961 return success();
962 }
963};
964
965//===----------------------------------------------------------------------===//
966// Type-Conversion Rewrite Testing
967//===----------------------------------------------------------------------===//
968
969/// This patterns erases a region operation that has had a type conversion.
970struct TestDropOpSignatureConversion : public ConversionPattern {
971 TestDropOpSignatureConversion(MLIRContext *ctx,
972 const TypeConverter &converter)
973 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
974 LogicalResult
975 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
976 ConversionPatternRewriter &rewriter) const override {
977 Region &region = op->getRegion(index: 0);
978 Block *entry = &region.front();
979
980 // Convert the original entry arguments.
981 const TypeConverter &converter = *getTypeConverter();
982 TypeConverter::SignatureConversion result(entry->getNumArguments());
983 if (failed(Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(),
984 result)) ||
985 failed(Result: rewriter.convertRegionTypes(region: &region, converter, entryConversion: &result)))
986 return failure();
987
988 // Convert the region signature and just drop the operation.
989 rewriter.eraseOp(op);
990 return success();
991 }
992};
993/// This pattern simply updates the operands of the given operation.
994struct TestPassthroughInvalidOp : public ConversionPattern {
995 TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
996 : ConversionPattern(converter, "test.invalid", 1, ctx) {}
997 LogicalResult
998 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
999 ConversionPatternRewriter &rewriter) const final {
1000 SmallVector<Value> flattened;
1001 for (auto it : llvm::enumerate(First&: operands)) {
1002 ValueRange range = it.value();
1003 if (range.size() == 1) {
1004 flattened.push_back(Elt: range.front());
1005 continue;
1006 }
1007
1008 // This is a 1:N replacement. Insert a test.cast op. (That's what the
1009 // argument materialization used to do.)
1010 flattened.push_back(
1011 Elt: rewriter
1012 .create<TestCastOp>(location: op->getLoc(),
1013 args: op->getOperand(idx: it.index()).getType(), args&: range)
1014 .getResult());
1015 }
1016 rewriter.replaceOpWithNewOp<TestValidOp>(op, args: TypeRange(), args&: flattened,
1017 args: ArrayRef<NamedAttribute>());
1018 return success();
1019 }
1020};
1021/// Replace with valid op, but simply drop the operands. This is used in a
1022/// regression where we used to generate circular unrealized_conversion_cast
1023/// ops.
1024struct TestDropAndReplaceInvalidOp : public ConversionPattern {
1025 TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
1026 : ConversionPattern(converter,
1027 "test.drop_operands_and_replace_with_valid", 1, ctx) {
1028 }
1029 LogicalResult
1030 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1031 ConversionPatternRewriter &rewriter) const final {
1032 rewriter.replaceOpWithNewOp<TestValidOp>(op, args: TypeRange(), args: ValueRange(),
1033 args: ArrayRef<NamedAttribute>());
1034 return success();
1035 }
1036};
1037/// This pattern handles the case of a split return value.
1038struct TestSplitReturnType : public ConversionPattern {
1039 TestSplitReturnType(MLIRContext *ctx)
1040 : ConversionPattern("test.return", 1, ctx) {}
1041 LogicalResult
1042 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1043 ConversionPatternRewriter &rewriter) const final {
1044 // Check for a return of F32.
1045 if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32())
1046 return failure();
1047 rewriter.replaceOpWithNewOp<TestReturnOp>(op, args: operands[0]);
1048 return success();
1049 }
1050};
1051
//===----------------------------------------------------------------------===//
// Multi-Level Type-Conversion Rewrite Testing
//===----------------------------------------------------------------------===//
1054struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
1055 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
1056 : ConversionPattern("test.type_producer", 1, ctx) {}
1057 LogicalResult
1058 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1059 ConversionPatternRewriter &rewriter) const final {
1060 // If the type is I32, change the type to F32.
1061 if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32))
1062 return failure();
1063 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, args: rewriter.getF32Type());
1064 return success();
1065 }
1066};
1067struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
1068 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
1069 : ConversionPattern("test.type_producer", 1, ctx) {}
1070 LogicalResult
1071 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1072 ConversionPatternRewriter &rewriter) const final {
1073 // If the type is F32, change the type to F64.
1074 if (!Type(*op->result_type_begin()).isF32())
1075 return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand");
1076 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, args: rewriter.getF64Type());
1077 return success();
1078 }
1079};
1080struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
1081 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
1082 : ConversionPattern("test.type_producer", 10, ctx) {}
1083 LogicalResult
1084 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1085 ConversionPatternRewriter &rewriter) const final {
1086 // Always convert to B16, even though it is not a legal type. This tests
1087 // that values are unmapped correctly.
1088 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, args: rewriter.getBF16Type());
1089 return success();
1090 }
1091};
1092struct TestUpdateConsumerType : public ConversionPattern {
1093 TestUpdateConsumerType(MLIRContext *ctx)
1094 : ConversionPattern("test.type_consumer", 1, ctx) {}
1095 LogicalResult
1096 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1097 ConversionPatternRewriter &rewriter) const final {
1098 // Verify that the incoming operand has been successfully remapped to F64.
1099 if (!operands[0].getType().isF64())
1100 return failure();
1101 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, args: operands[0]);
1102 return success();
1103 }
1104};
1105
//===----------------------------------------------------------------------===//
// Non-Root Replacement Rewrite Testing
//===----------------------------------------------------------------------===//
1108/// This pattern generates an invalid operation, but replaces it before the
1109/// pattern is finished. This checks that we don't need to legalize the
1110/// temporary op.
1111struct TestNonRootReplacement : public RewritePattern {
1112 TestNonRootReplacement(MLIRContext *ctx)
1113 : RewritePattern("test.replace_non_root", 1, ctx) {}
1114
1115 LogicalResult matchAndRewrite(Operation *op,
1116 PatternRewriter &rewriter) const final {
1117 auto resultType = *op->result_type_begin();
1118 auto illegalOp = rewriter.create<ILLegalOpF>(location: op->getLoc(), args&: resultType);
1119 auto legalOp = rewriter.create<LegalOpB>(location: op->getLoc(), args&: resultType);
1120
1121 rewriter.replaceOp(op: illegalOp, newOp: legalOp);
1122 rewriter.replaceOp(op, newOp: illegalOp);
1123 return success();
1124 }
1125};
1126
//===----------------------------------------------------------------------===//
// Recursive Rewrite Testing
//===----------------------------------------------------------------------===//
1129/// This pattern is applied to the same operation multiple times, but has a
1130/// bounded recursion.
1131struct TestBoundedRecursiveRewrite
1132 : public OpRewritePattern<TestRecursiveRewriteOp> {
1133 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
1134
1135 void initialize() {
1136 // The conversion target handles bounding the recursion of this pattern.
1137 setHasBoundedRewriteRecursion();
1138 }
1139
1140 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
1141 PatternRewriter &rewriter) const final {
1142 // Decrement the depth of the op in-place.
1143 rewriter.modifyOpInPlace(root: op, callable: [&] {
1144 op->setAttr(name: "depth", value: rewriter.getI64IntegerAttr(value: op.getDepth() - 1));
1145 });
1146 return success();
1147 }
1148};
1149
1150struct TestNestedOpCreationUndoRewrite
1151 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1152 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1153
1154 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1155 PatternRewriter &rewriter) const final {
1156 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1157 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1158 return success();
1159 };
1160};
1161
1162// This pattern matches `test.blackhole` and delete this op and its producer.
1163struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1164 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1165
1166 LogicalResult matchAndRewrite(BlackHoleOp op,
1167 PatternRewriter &rewriter) const final {
1168 Operation *producer = op.getOperand().getDefiningOp();
1169 // Always erase the user before the producer, the framework should handle
1170 // this correctly.
1171 rewriter.eraseOp(op);
1172 rewriter.eraseOp(op: producer);
1173 return success();
1174 };
1175};
1176
1177// This pattern replaces explicitly illegal op with explicitly legal op,
1178// but in addition creates unregistered operation.
1179struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1180 using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1181
1182 LogicalResult matchAndRewrite(ILLegalOpG op,
1183 PatternRewriter &rewriter) const final {
1184 IntegerAttr attr = rewriter.getI32IntegerAttr(value: 0);
1185 Value val = rewriter.create<arith::ConstantOp>(location: op->getLoc(), args&: attr);
1186 rewriter.replaceOpWithNewOp<LegalOpC>(op, args&: val);
1187 return success();
1188 };
1189};
1190
1191class TestEraseOp : public ConversionPattern {
1192public:
1193 TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {}
1194 LogicalResult
1195 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1196 ConversionPatternRewriter &rewriter) const final {
1197 // Erase op without replacements.
1198 rewriter.eraseOp(op);
1199 return success();
1200 }
1201};
1202
1203/// This pattern matches a test.convert_block_args op. It either:
1204/// a) Duplicates all block arguments,
1205/// b) or: drops all block arguments and replaces each with 2x the first
1206/// operand.
1207class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
1208 using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
1209
1210 LogicalResult
1211 matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
1212 ConversionPatternRewriter &rewriter) const override {
1213 if (op.getIsLegal())
1214 return failure();
1215 Block *body = &op.getBody().front();
1216 TypeConverter::SignatureConversion result(body->getNumArguments());
1217 for (auto it : llvm::enumerate(First: body->getArgumentTypes())) {
1218 if (op.getReplaceWithOperand()) {
1219 result.remapInput(origInputNo: it.index(), replacements: {adaptor.getVal(), adaptor.getVal()});
1220 } else if (op.getDuplicate()) {
1221 result.addInputs(origInputNo: it.index(), types: {it.value(), it.value()});
1222 } else {
1223 // No action specified. Pattern does not apply.
1224 return failure();
1225 }
1226 }
1227 rewriter.startOpModification(op);
1228 rewriter.applySignatureConversion(block: body, conversion&: result, converter: getTypeConverter());
1229 op.setIsLegal(true);
1230 rewriter.finalizeOpModification(op);
1231 return success();
1232 }
1233};
1234
1235/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
1236/// op. The pattern supports 1:N replacements and forwards the replacement
1237/// values of the single operand as test.valid operands.
1238class TestRepetitive1ToNConsumer : public ConversionPattern {
1239public:
1240 TestRepetitive1ToNConsumer(MLIRContext *ctx)
1241 : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
1242 LogicalResult
1243 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1244 ConversionPatternRewriter &rewriter) const final {
1245 // A single operand is expected.
1246 if (op->getNumOperands() != 1)
1247 return failure();
1248 rewriter.replaceOpWithNewOp<TestValidOp>(op, args: operands.front());
1249 return success();
1250 }
1251};
1252
1253/// A pattern that tests two back-to-back 1 -> 2 op replacements.
1254class TestMultiple1ToNReplacement : public ConversionPattern {
1255public:
1256 TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
1257 : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
1258 ctx) {}
1259 LogicalResult
1260 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
1261 ConversionPatternRewriter &rewriter) const final {
1262 // Helper function that replaces the given op with a new op of the given
1263 // name and doubles each result (1 -> 2 replacement of each result).
1264 auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
1265 SmallVector<Type> types;
1266 for (Type t : op->getResultTypes()) {
1267 types.push_back(Elt: t);
1268 types.push_back(Elt: t);
1269 }
1270 OperationState state(op->getLoc(), name,
1271 /*operands=*/{}, types, op->getAttrs());
1272 auto *newOp = rewriter.create(state);
1273 SmallVector<ValueRange> repls;
1274 for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
1275 repls.push_back(Elt: newOp->getResults().slice(n: 2 * i, m: 2));
1276 rewriter.replaceOpWithMultiple(op, newValues&: repls);
1277 return newOp;
1278 };
1279
1280 // Replace test.multiple_1_to_n_replacement with test.step_1.
1281 Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
1282 // Now replace test.step_1 with test.legal_op.
1283 replaceWithDoubleResults(repl1, "test.legal_op");
1284 return success();
1285 }
1286};
1287
1288/// Test unambiguous overload resolution of replaceOpWithMultiple. This
1289/// function is just to trigger compiler errors. It is never executed.
1290[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
1291 ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
1292 SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
1293 SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
1294 SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7,
1295 Value v, ValueRange vr, ArrayRef<Value> ar) {
1296 rewriter.replaceOpWithMultiple(op, newValues: r1);
1297 rewriter.replaceOpWithMultiple(op, newValues&: r2);
1298 rewriter.replaceOpWithMultiple(op, newValues: r3);
1299 rewriter.replaceOpWithMultiple(op, newValues&: r4);
1300 rewriter.replaceOpWithMultiple(op, newValues: r5);
1301 rewriter.replaceOpWithMultiple(op, newValues&: r6);
1302 rewriter.replaceOpWithMultiple(op, newValues: std::move(r7));
1303 rewriter.replaceOpWithMultiple(op, newValues: {vr});
1304 rewriter.replaceOpWithMultiple(op, newValues: {ar});
1305 rewriter.replaceOpWithMultiple(op, newValues: {{v}});
1306 rewriter.replaceOpWithMultiple(op, newValues: {{v, v}});
1307 rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, vr});
1308 rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, ar});
1309 rewriter.replaceOpWithMultiple(op, newValues: {ar, {v, v}, vr});
1310}
1311} // namespace
1312
1313namespace {
1314struct TestTypeConverter : public TypeConverter {
1315 using TypeConverter::TypeConverter;
1316 TestTypeConverter() {
1317 addConversion(callback&: convertType);
1318 addSourceMaterialization(callback&: materializeCast);
1319 }
1320
1321 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1322 // Drop I16 types.
1323 if (t.isSignlessInteger(width: 16))
1324 return success();
1325
1326 // Convert I64 to F64.
1327 if (t.isSignlessInteger(width: 64)) {
1328 results.push_back(Elt: Float64Type::get(context: t.getContext()));
1329 return success();
1330 }
1331
1332 // Convert I42 to I43.
1333 if (t.isInteger(width: 42)) {
1334 results.push_back(Elt: IntegerType::get(context: t.getContext(), width: 43));
1335 return success();
1336 }
1337
1338 // Split F32 into F16,F16.
1339 if (t.isF32()) {
1340 results.assign(NumElts: 2, Elt: Float16Type::get(context: t.getContext()));
1341 return success();
1342 }
1343
1344 // Drop I24 types.
1345 if (t.isInteger(width: 24)) {
1346 return success();
1347 }
1348
1349 // Otherwise, convert the type directly.
1350 results.push_back(Elt: t);
1351 return success();
1352 }
1353
1354 /// Hook for materializing a conversion. This is necessary because we generate
1355 /// 1->N type mappings.
1356 static Value materializeCast(OpBuilder &builder, Type resultType,
1357 ValueRange inputs, Location loc) {
1358 return builder.create<TestCastOp>(location: loc, args&: resultType, args&: inputs).getResult();
1359 }
1360};
1361
1362struct TestLegalizePatternDriver
1363 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1364 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1365
1366 StringRef getArgument() const final { return "test-legalize-patterns"; }
1367 StringRef getDescription() const final {
1368 return "Run test dialect legalization patterns";
1369 }
1370 /// The mode of conversion to use with the driver.
1371 enum class ConversionMode { Analysis, Full, Partial };
1372
1373 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1374
1375 void getDependentDialects(DialectRegistry &registry) const override {
1376 registry.insert<func::FuncDialect, test::TestDialect>();
1377 }
1378
1379 void runOnOperation() override {
1380 TestTypeConverter converter;
1381 mlir::RewritePatternSet patterns(&getContext());
1382 populateWithGenerated(patterns);
1383 patterns.add<
1384 TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1385 TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1386 TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1387 TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1388 TestUpdateConsumerType, TestNonRootReplacement,
1389 TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1390 TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1391 TestUndoPropertiesModification, TestEraseOp,
1392 TestRepetitive1ToNConsumer>(arg: &getContext());
1393 patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1394 TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
1395 TestBlockArgReplace>(arg: &getContext(), args&: converter);
1396 patterns.add<TestConvertBlockArgs>(arg&: converter, args: &getContext());
1397 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1398 converter);
1399 mlir::populateCallOpTypeConversionPattern(patterns, converter);
1400
1401 // Define the conversion target used for the test.
1402 ConversionTarget target(getContext());
1403 target.addLegalOp<ModuleOp>();
1404 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1405 TerminatorOp, OneRegionOp>();
1406 target.addLegalOp(op: OperationName("test.legal_op", &getContext()));
1407 target
1408 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1409 target.addDynamicallyLegalOp<TestReturnOp>(callback: [](TestReturnOp op) {
1410 // Don't allow F32 operands.
1411 return llvm::none_of(Range: op.getOperandTypes(),
1412 P: [](Type type) { return type.isF32(); });
1413 });
1414 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
1415 return converter.isSignatureLegal(ty: op.getFunctionType()) &&
1416 converter.isLegal(region: &op.getBody());
1417 });
1418 target.addDynamicallyLegalOp<func::CallOp>(
1419 callback: [&](func::CallOp op) { return converter.isLegal(op); });
1420 target.addDynamicallyLegalOp(
1421 op: OperationName("test.block_arg_replace", &getContext()),
1422 callback: [](Operation *op) { return op->hasAttr(name: "is_legal"); });
1423
1424 // TestCreateUnregisteredOp creates `arith.constant` operation,
1425 // which was not added to target intentionally to test
1426 // correct error code from conversion driver.
1427 target.addDynamicallyLegalOp<ILLegalOpG>(callback: [](ILLegalOpG) { return false; });
1428
1429 // Expect the type_producer/type_consumer operations to only operate on f64.
1430 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1431 callback: [](TestTypeProducerOp op) { return op.getType().isF64(); });
1432 target.addDynamicallyLegalOp<TestTypeConsumerOp>(callback: [](TestTypeConsumerOp op) {
1433 return op.getOperand().getType().isF64();
1434 });
1435
1436 // Check support for marking certain operations as recursively legal.
1437 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>(callback: [](Operation *op) {
1438 return static_cast<bool>(
1439 op->getAttrOfType<UnitAttr>(name: "test.recursively_legal"));
1440 });
1441
1442 // Mark the bound recursion operation as dynamically legal.
1443 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1444 callback: [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1445
1446 // Create a dynamically legal rule that can only be legalized by folding it.
1447 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1448 callback: [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1449
1450 target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
1451 callback: [](ConvertBlockArgsOp op) { return op.getIsLegal(); });
1452
1453 // Handle a partial conversion.
1454 if (mode == ConversionMode::Partial) {
1455 DenseSet<Operation *> unlegalizedOps;
1456 ConversionConfig config;
1457 DumpNotifications dumpNotifications;
1458 config.listener = &dumpNotifications;
1459 config.unlegalizedOps = &unlegalizedOps;
1460 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1461 patterns: std::move(patterns), config))) {
1462 getOperation()->emitRemark() << "applyPartialConversion failed";
1463 }
1464 // Emit remarks for each legalizable operation.
1465 for (auto *op : unlegalizedOps)
1466 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1467 return;
1468 }
1469
1470 // Handle a full conversion.
1471 if (mode == ConversionMode::Full) {
1472 // Check support for marking unknown operations as dynamically legal.
1473 target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) {
1474 return (bool)op->getAttrOfType<UnitAttr>(name: "test.dynamically_legal");
1475 });
1476
1477 ConversionConfig config;
1478 DumpNotifications dumpNotifications;
1479 config.listener = &dumpNotifications;
1480 if (failed(Result: applyFullConversion(op: getOperation(), target,
1481 patterns: std::move(patterns), config))) {
1482 getOperation()->emitRemark() << "applyFullConversion failed";
1483 }
1484 return;
1485 }
1486
1487 // Otherwise, handle an analysis conversion.
1488 assert(mode == ConversionMode::Analysis);
1489
1490 // Analyze the convertible operations.
1491 DenseSet<Operation *> legalizedOps;
1492 ConversionConfig config;
1493 config.legalizableOps = &legalizedOps;
1494 if (failed(Result: applyAnalysisConversion(op: getOperation(), target,
1495 patterns: std::move(patterns), config)))
1496 return signalPassFailure();
1497
1498 // Emit remarks for each legalizable operation.
1499 for (auto *op : legalizedOps)
1500 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1501 }
1502
1503 /// The mode of conversion to use.
1504 ConversionMode mode;
1505};
1506} // namespace
1507
1508static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1509 legalizerConversionMode(
1510 "test-legalize-mode",
1511 llvm::cl::desc("The legalization mode to use with the test driver"),
1512 llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial),
1513 llvm::cl::values(
1514 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1515 "analysis", "Perform an analysis conversion"),
1516 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1517 "Perform a full conversion"),
1518 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1519 "partial", "Perform a partial conversion")));
1520
1521//===----------------------------------------------------------------------===//
1522// ConversionPatternRewriter::getRemappedValue testing. This method is used
1523// to get the remapped value of an original value that was replaced using
1524// ConversionPatternRewriter.
1525namespace {
1526struct TestRemapValueTypeConverter : public TypeConverter {
1527 using TypeConverter::TypeConverter;
1528
1529 TestRemapValueTypeConverter() {
1530 addConversion(
1531 callback: [](Float32Type type) { return Float64Type::get(context: type.getContext()); });
1532 addConversion(callback: [](Type type) { return type; });
1533 }
1534};
1535
1536/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1537/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1538/// operand twice.
1539///
1540/// Example:
1541/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1542/// is replaced with:
1543/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1544struct OneVResOneVOperandOp1Converter
1545 : public OpConversionPattern<OneVResOneVOperandOp1> {
1546 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1547
1548 LogicalResult
1549 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1550 ConversionPatternRewriter &rewriter) const override {
1551 auto origOps = op.getOperands();
1552 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1553 "One operand expected");
1554 Value origOp = *origOps.begin();
1555 SmallVector<Value, 2> remappedOperands;
1556 // Replicate the remapped original operand twice. Note that we don't used
1557 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1558 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1559 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1560
1561 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, args: op.getResultTypes(),
1562 args&: remappedOperands);
1563 return success();
1564 }
1565};
1566
1567/// A rewriter pattern that tests that blocks can be merged.
1568struct TestRemapValueInRegion
1569 : public OpConversionPattern<TestRemappedValueRegionOp> {
1570 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1571
1572 LogicalResult
1573 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1574 ConversionPatternRewriter &rewriter) const final {
1575 Block &block = op.getBody().front();
1576 Operation *terminator = block.getTerminator();
1577
1578 // Merge the block into the parent region.
1579 Block *parentBlock = op->getBlock();
1580 Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator());
1581 rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange());
1582 rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange());
1583
1584 // Replace the results of this operation with the remapped terminator
1585 // values.
1586 SmallVector<Value> terminatorOperands;
1587 if (failed(Result: rewriter.getRemappedValues(keys: terminator->getOperands(),
1588 results&: terminatorOperands)))
1589 return failure();
1590
1591 rewriter.eraseOp(op: terminator);
1592 rewriter.replaceOp(op, newValues: terminatorOperands);
1593 return success();
1594 }
1595};
1596
1597struct TestRemappedValue
1598 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1599 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1600
1601 StringRef getArgument() const final { return "test-remapped-value"; }
1602 StringRef getDescription() const final {
1603 return "Test public remapped value mechanism in ConversionPatternRewriter";
1604 }
1605 void runOnOperation() override {
1606 TestRemapValueTypeConverter typeConverter;
1607
1608 mlir::RewritePatternSet patterns(&getContext());
1609 patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext());
1610 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1611 arg: &getContext());
1612 patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext());
1613
1614 mlir::ConversionTarget target(getContext());
1615 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1616
1617 // Expect the type_producer/type_consumer operations to only operate on f64.
1618 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1619 callback: [](TestTypeProducerOp op) { return op.getType().isF64(); });
1620 target.addDynamicallyLegalOp<TestTypeConsumerOp>(callback: [](TestTypeConsumerOp op) {
1621 return op.getOperand().getType().isF64();
1622 });
1623
1624 // We make OneVResOneVOperandOp1 legal only when it has more that one
1625 // operand. This will trigger the conversion that will replace one-operand
1626 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1627 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1628 callback: [](Operation *op) { return op->getNumOperands() > 1; });
1629
1630 if (failed(Result: mlir::applyFullConversion(op: getOperation(), target,
1631 patterns: std::move(patterns)))) {
1632 signalPassFailure();
1633 }
1634 }
1635};
1636} // namespace
1637
1638//===----------------------------------------------------------------------===//
1639// Test patterns without a specific root operation kind
1640//===----------------------------------------------------------------------===//
1641
1642namespace {
1643/// This pattern matches and removes any operation in the test dialect.
1644struct RemoveTestDialectOps : public RewritePattern {
1645 RemoveTestDialectOps(MLIRContext *context)
1646 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1647
1648 LogicalResult matchAndRewrite(Operation *op,
1649 PatternRewriter &rewriter) const override {
1650 if (!isa<TestDialect>(Val: op->getDialect()))
1651 return failure();
1652 rewriter.eraseOp(op);
1653 return success();
1654 }
1655};
1656
1657struct TestUnknownRootOpDriver
1658 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1659 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1660
1661 StringRef getArgument() const final {
1662 return "test-legalize-unknown-root-patterns";
1663 }
1664 StringRef getDescription() const final {
1665 return "Test public remapped value mechanism in ConversionPatternRewriter";
1666 }
1667 void runOnOperation() override {
1668 mlir::RewritePatternSet patterns(&getContext());
1669 patterns.add<RemoveTestDialectOps>(arg: &getContext());
1670
1671 mlir::ConversionTarget target(getContext());
1672 target.addIllegalDialect<TestDialect>();
1673 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1674 patterns: std::move(patterns))))
1675 signalPassFailure();
1676 }
1677};
1678} // namespace
1679
1680//===----------------------------------------------------------------------===//
1681// Test patterns that uses operations and types defined at runtime
1682//===----------------------------------------------------------------------===//
1683
1684namespace {
1685/// This pattern matches dynamic operations 'test.one_operand_two_results' and
1686/// replace them with dynamic operations 'test.generic_dynamic_op'.
1687struct RewriteDynamicOp : public RewritePattern {
1688 RewriteDynamicOp(MLIRContext *context)
1689 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1690 context) {}
1691
1692 LogicalResult matchAndRewrite(Operation *op,
1693 PatternRewriter &rewriter) const override {
1694 assert(op->getName().getStringRef() ==
1695 "test.dynamic_one_operand_two_results" &&
1696 "rewrite pattern should only match operations with the right name");
1697
1698 OperationState state(op->getLoc(), "test.dynamic_generic",
1699 op->getOperands(), op->getResultTypes(),
1700 op->getAttrs());
1701 auto *newOp = rewriter.create(state);
1702 rewriter.replaceOp(op, newValues: newOp->getResults());
1703 return success();
1704 }
1705};
1706
1707struct TestRewriteDynamicOpDriver
1708 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1709 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1710
1711 void getDependentDialects(DialectRegistry &registry) const override {
1712 registry.insert<TestDialect>();
1713 }
1714 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1715 StringRef getDescription() const final {
1716 return "Test rewritting on dynamic operations";
1717 }
1718 void runOnOperation() override {
1719 RewritePatternSet patterns(&getContext());
1720 patterns.add<RewriteDynamicOp>(arg: &getContext());
1721
1722 ConversionTarget target(getContext());
1723 target.addIllegalOp(
1724 op: OperationName("test.dynamic_one_operand_two_results", &getContext()));
1725 target.addLegalOp(op: OperationName("test.dynamic_generic", &getContext()));
1726 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1727 patterns: std::move(patterns))))
1728 signalPassFailure();
1729 }
1730};
1731} // end anonymous namespace
1732
1733//===----------------------------------------------------------------------===//
1734// Test type conversions
1735//===----------------------------------------------------------------------===//
1736
1737namespace {
1738struct TestTypeConversionProducer
1739 : public OpConversionPattern<TestTypeProducerOp> {
1740 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1741 LogicalResult
1742 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1743 ConversionPatternRewriter &rewriter) const final {
1744 Type resultType = op.getType();
1745 Type convertedType = getTypeConverter()
1746 ? getTypeConverter()->convertType(t: resultType)
1747 : resultType;
1748 if (isa<FloatType>(Val: resultType))
1749 resultType = rewriter.getF64Type();
1750 else if (resultType.isInteger(width: 16))
1751 resultType = rewriter.getIntegerType(width: 64);
1752 else if (isa<test::TestRecursiveType>(Val: resultType) &&
1753 convertedType != resultType)
1754 resultType = convertedType;
1755 else
1756 return failure();
1757
1758 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, args&: resultType);
1759 return success();
1760 }
1761};
1762
1763/// Call signature conversion and then fail the rewrite to trigger the undo
1764/// mechanism.
1765struct TestSignatureConversionUndo
1766 : public OpConversionPattern<TestSignatureConversionUndoOp> {
1767 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1768
1769 LogicalResult
1770 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1771 ConversionPatternRewriter &rewriter) const final {
1772 (void)rewriter.convertRegionTypes(region: &op->getRegion(index: 0), converter: *getTypeConverter());
1773 return failure();
1774 }
1775};
1776
1777/// Call signature conversion without providing a type converter to handle
1778/// materializations.
1779struct TestTestSignatureConversionNoConverter
1780 : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1781 TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1782 MLIRContext *context)
1783 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1784 converter(converter) {}
1785
1786 LogicalResult
1787 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1788 ConversionPatternRewriter &rewriter) const final {
1789 Region &region = op->getRegion(index: 0);
1790 Block *entry = &region.front();
1791
1792 // Convert the original entry arguments.
1793 TypeConverter::SignatureConversion result(entry->getNumArguments());
1794 if (failed(
1795 Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result)))
1796 return failure();
1797 rewriter.modifyOpInPlace(root: op, callable: [&] {
1798 rewriter.applySignatureConversion(block: &region.front(), conversion&: result);
1799 });
1800 return success();
1801 }
1802
1803 const TypeConverter &converter;
1804};
1805
1806/// Just forward the operands to the root op. This is essentially a no-op
1807/// pattern that is used to trigger target materialization.
1808struct TestTypeConsumerForward
1809 : public OpConversionPattern<TestTypeConsumerOp> {
1810 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1811
1812 LogicalResult
1813 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1814 ConversionPatternRewriter &rewriter) const final {
1815 rewriter.modifyOpInPlace(root: op,
1816 callable: [&] { op->setOperands(adaptor.getOperands()); });
1817 return success();
1818 }
1819};
1820
1821struct TestTypeConversionAnotherProducer
1822 : public OpRewritePattern<TestAnotherTypeProducerOp> {
1823 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1824
1825 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1826 PatternRewriter &rewriter) const final {
1827 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, args: op.getType());
1828 return success();
1829 }
1830};
1831
1832struct TestReplaceWithLegalOp : public ConversionPattern {
1833 TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1834 : ConversionPattern(converter, "test.replace_with_legal_op",
1835 /*benefit=*/1, ctx) {}
1836 LogicalResult
1837 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1838 ConversionPatternRewriter &rewriter) const final {
1839 rewriter.replaceOpWithNewOp<LegalOpD>(op, args: operands[0]);
1840 return success();
1841 }
1842};
1843
1844struct TestTypeConversionDriver
1845 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1846 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1847
1848 void getDependentDialects(DialectRegistry &registry) const override {
1849 registry.insert<TestDialect>();
1850 }
1851 StringRef getArgument() const final {
1852 return "test-legalize-type-conversion";
1853 }
1854 StringRef getDescription() const final {
1855 return "Test various type conversion functionalities in DialectConversion";
1856 }
1857
1858 void runOnOperation() override {
1859 // Initialize the type converter.
1860 SmallVector<Type, 2> conversionCallStack;
1861 TypeConverter converter;
1862
1863 /// Add the legal set of type conversions.
1864 converter.addConversion(callback: [](Type type) -> Type {
1865 // Treat F64 as legal.
1866 if (type.isF64())
1867 return type;
1868 // Allow converting BF16/F16/F32 to F64.
1869 if (type.isBF16() || type.isF16() || type.isF32())
1870 return Float64Type::get(context: type.getContext());
1871 // Otherwise, the type is illegal.
1872 return nullptr;
1873 });
1874 converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) {
1875 // Drop all integer types.
1876 return success();
1877 });
1878 converter.addConversion(
1879 // Convert a recursive self-referring type into a non-self-referring
1880 // type named "outer_converted_type" that contains a SimpleAType.
1881 callback: [&](test::TestRecursiveType type,
1882 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1883 // If the type is already converted, return it to indicate that it is
1884 // legal.
1885 if (type.getName() == "outer_converted_type") {
1886 results.push_back(Elt: type);
1887 return success();
1888 }
1889
1890 conversionCallStack.push_back(Elt: type);
1891 auto popConversionCallStack = llvm::make_scope_exit(
1892 F: [&conversionCallStack]() { conversionCallStack.pop_back(); });
1893
1894 // If the type is on the call stack more than once (it is there at
1895 // least once because of the _current_ call, which is always the last
1896 // element on the stack), we've hit the recursive case. Just return
1897 // SimpleAType here to create a non-recursive type as a result.
1898 if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(),
1899 Element: type)) {
1900 results.push_back(Elt: test::SimpleAType::get(ctx: type.getContext()));
1901 return success();
1902 }
1903
1904 // Convert the body recursively.
1905 auto result = test::TestRecursiveType::get(ctx: type.getContext(),
1906 name: "outer_converted_type");
1907 if (failed(Result: result.setBody(converter.convertType(t: type.getBody()))))
1908 return failure();
1909 results.push_back(Elt: result);
1910 return success();
1911 });
1912
1913 /// Add the legal set of type materializations.
1914 converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType,
1915 ValueRange inputs,
1916 Location loc) -> Value {
1917 // Allow casting from F64 back to F32.
1918 if (!resultType.isF16() && inputs.size() == 1 &&
1919 inputs[0].getType().isF64())
1920 return builder.create<TestCastOp>(location: loc, args&: resultType, args&: inputs).getResult();
1921 // Allow producing an i32 or i64 from nothing.
1922 if ((resultType.isInteger(width: 32) || resultType.isInteger(width: 64)) &&
1923 inputs.empty())
1924 return builder.create<TestTypeProducerOp>(location: loc, args&: resultType);
1925 // Allow producing an i64 from an integer.
1926 if (isa<IntegerType>(Val: resultType) && inputs.size() == 1 &&
1927 isa<IntegerType>(Val: inputs[0].getType()))
1928 return builder.create<TestCastOp>(location: loc, args&: resultType, args&: inputs).getResult();
1929 // Otherwise, fail.
1930 return nullptr;
1931 });
1932
1933 // Initialize the conversion target.
1934 mlir::ConversionTarget target(getContext());
1935 target.addLegalOp<LegalOpD>();
1936 target.addDynamicallyLegalOp<TestTypeProducerOp>(callback: [](TestTypeProducerOp op) {
1937 auto recursiveType = dyn_cast<test::TestRecursiveType>(Val: op.getType());
1938 return op.getType().isF64() || op.getType().isInteger(width: 64) ||
1939 (recursiveType &&
1940 recursiveType.getName() == "outer_converted_type");
1941 });
1942 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
1943 return converter.isSignatureLegal(ty: op.getFunctionType()) &&
1944 converter.isLegal(region: &op.getBody());
1945 });
1946 target.addDynamicallyLegalOp<TestCastOp>(callback: [&](TestCastOp op) {
1947 // Allow casts from F64 to F32.
1948 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1949 });
1950 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1951 callback: [&](TestSignatureConversionNoConverterOp op) {
1952 return converter.isLegal(range: op.getRegion().front().getArgumentTypes());
1953 });
1954
1955 // Initialize the set of rewrite patterns.
1956 RewritePatternSet patterns(&getContext());
1957 patterns
1958 .add<TestTypeConsumerForward, TestTypeConversionProducer,
1959 TestSignatureConversionUndo,
1960 TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1961 arg&: converter, args: &getContext());
1962 patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext());
1963 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1964 converter);
1965
1966 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1967 patterns: std::move(patterns))))
1968 signalPassFailure();
1969 }
1970};
1971} // namespace
1972
1973//===----------------------------------------------------------------------===//
1974// Test Target Materialization With No Uses
1975//===----------------------------------------------------------------------===//
1976
1977namespace {
1978struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1979 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1980
1981 LogicalResult
1982 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1983 ConversionPatternRewriter &rewriter) const final {
1984 rewriter.replaceOp(op, newValues: adaptor.getOperands());
1985 return success();
1986 }
1987};
1988
1989struct TestTargetMaterializationWithNoUses
1990 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1991 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1992 TestTargetMaterializationWithNoUses)
1993
1994 StringRef getArgument() const final {
1995 return "test-target-materialization-with-no-uses";
1996 }
1997 StringRef getDescription() const final {
1998 return "Test a special case of target materialization in DialectConversion";
1999 }
2000
2001 void runOnOperation() override {
2002 TypeConverter converter;
2003 converter.addConversion(callback: [](Type t) { return t; });
2004 converter.addConversion(callback: [](IntegerType intTy) -> Type {
2005 if (intTy.getWidth() == 16)
2006 return IntegerType::get(context: intTy.getContext(), width: 64);
2007 return intTy;
2008 });
2009 converter.addTargetMaterialization(
2010 callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
2011 return builder.create<TestCastOp>(location: loc, args&: type, args&: inputs).getResult();
2012 });
2013
2014 ConversionTarget target(getContext());
2015 target.addIllegalOp<TestTypeChangerOp>();
2016
2017 RewritePatternSet patterns(&getContext());
2018 patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext());
2019
2020 if (failed(Result: applyPartialConversion(op: getOperation(), target,
2021 patterns: std::move(patterns))))
2022 signalPassFailure();
2023 }
2024};
2025} // namespace
2026
2027//===----------------------------------------------------------------------===//
2028// Test Block Merging
2029//===----------------------------------------------------------------------===//
2030
2031namespace {
2032/// A rewriter pattern that tests that blocks can be merged.
2033struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
2034 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
2035
2036 LogicalResult
2037 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
2038 ConversionPatternRewriter &rewriter) const final {
2039 Block &firstBlock = op.getBody().front();
2040 Operation *branchOp = firstBlock.getTerminator();
2041 Block *secondBlock = &*(std::next(x: op.getBody().begin()));
2042 auto succOperands = branchOp->getOperands();
2043 SmallVector<Value, 2> replacements(succOperands);
2044 rewriter.eraseOp(op: branchOp);
2045 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
2046 rewriter.modifyOpInPlace(root: op, callable: [] {});
2047 return success();
2048 }
2049};
2050
2051/// A rewrite pattern to tests the undo mechanism of blocks being merged.
2052struct TestUndoBlocksMerge : public ConversionPattern {
2053 TestUndoBlocksMerge(MLIRContext *ctx)
2054 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
2055 LogicalResult
2056 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
2057 ConversionPatternRewriter &rewriter) const final {
2058 Block &firstBlock = op->getRegion(index: 0).front();
2059 Operation *branchOp = firstBlock.getTerminator();
2060 Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin()));
2061 rewriter.setInsertionPointToStart(secondBlock);
2062 rewriter.create<ILLegalOpF>(location: op->getLoc(), args: rewriter.getF32Type());
2063 auto succOperands = branchOp->getOperands();
2064 SmallVector<Value, 2> replacements(succOperands);
2065 rewriter.eraseOp(op: branchOp);
2066 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
2067 rewriter.modifyOpInPlace(root: op, callable: [] {});
2068 return success();
2069 }
2070};
2071
2072/// A rewrite mechanism to inline the body of the op into its parent, when both
2073/// ops can have a single block.
2074struct TestMergeSingleBlockOps
2075 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
2076 using OpConversionPattern<
2077 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
2078
2079 LogicalResult
2080 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
2081 ConversionPatternRewriter &rewriter) const final {
2082 SingleBlockImplicitTerminatorOp parentOp =
2083 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2084 if (!parentOp)
2085 return failure();
2086 Block &innerBlock = op.getRegion().front();
2087 TerminatorOp innerTerminator =
2088 cast<TerminatorOp>(Val: innerBlock.getTerminator());
2089 rewriter.inlineBlockBefore(source: &innerBlock, op);
2090 rewriter.eraseOp(op: innerTerminator);
2091 rewriter.eraseOp(op);
2092 return success();
2093 }
2094};
2095
2096struct TestMergeBlocksPatternDriver
2097 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
2098 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
2099
2100 StringRef getArgument() const final { return "test-merge-blocks"; }
2101 StringRef getDescription() const final {
2102 return "Test Merging operation in ConversionPatternRewriter";
2103 }
2104 void runOnOperation() override {
2105 MLIRContext *context = &getContext();
2106 mlir::RewritePatternSet patterns(context);
2107 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
2108 arg&: context);
2109 ConversionTarget target(*context);
2110 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
2111 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
2112 target.addIllegalOp<ILLegalOpF>();
2113
2114 /// Expect the op to have a single block after legalization.
2115 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
2116 callback: [&](TestMergeBlocksOp op) -> bool {
2117 return llvm::hasSingleElement(C&: op.getBody());
2118 });
2119
2120 /// Only allow `test.br` within test.merge_blocks op.
2121 target.addDynamicallyLegalOp<TestBranchOp>(callback: [&](TestBranchOp op) -> bool {
2122 return op->getParentOfType<TestMergeBlocksOp>();
2123 });
2124
2125 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
2126 /// inlined.
2127 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
2128 callback: [&](SingleBlockImplicitTerminatorOp op) -> bool {
2129 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
2130 });
2131
2132 DenseSet<Operation *> unlegalizedOps;
2133 ConversionConfig config;
2134 config.unlegalizedOps = &unlegalizedOps;
2135 (void)applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns),
2136 config);
2137 for (auto *op : unlegalizedOps)
2138 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
2139 }
2140};
2141} // namespace
2142
2143//===----------------------------------------------------------------------===//
2144// Test Selective Replacement
2145//===----------------------------------------------------------------------===//
2146
2147namespace {
2148/// A rewrite mechanism to inline the body of the op into its parent, when both
2149/// ops can have a single block.
2150struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
2151 using OpRewritePattern<TestCastOp>::OpRewritePattern;
2152
2153 LogicalResult matchAndRewrite(TestCastOp op,
2154 PatternRewriter &rewriter) const final {
2155 if (op.getNumOperands() != 2)
2156 return failure();
2157 OperandRange operands = op.getOperands();
2158
2159 // Replace non-terminator uses with the first operand.
2160 rewriter.replaceUsesWithIf(from: op, to: operands[0], functor: [](OpOperand &operand) {
2161 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
2162 });
2163 // Replace everything else with the second operand if the operation isn't
2164 // dead.
2165 rewriter.replaceOp(op, newValues: op.getOperand(i: 1));
2166 return success();
2167 }
2168};
2169
2170struct TestSelectiveReplacementPatternDriver
2171 : public PassWrapper<TestSelectiveReplacementPatternDriver,
2172 OperationPass<>> {
2173 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2174 TestSelectiveReplacementPatternDriver)
2175
2176 StringRef getArgument() const final {
2177 return "test-pattern-selective-replacement";
2178 }
2179 StringRef getDescription() const final {
2180 return "Test selective replacement in the PatternRewriter";
2181 }
2182 void runOnOperation() override {
2183 MLIRContext *context = &getContext();
2184 mlir::RewritePatternSet patterns(context);
2185 patterns.add<TestSelectiveOpReplacementPattern>(arg&: context);
2186 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
2187 }
2188};
2189} // namespace
2190
2191//===----------------------------------------------------------------------===//
2192// PassRegistration
2193//===----------------------------------------------------------------------===//
2194
2195namespace mlir {
2196namespace test {
2197void registerPatternsTestPass() {
2198 PassRegistration<TestReturnTypeDriver>();
2199
2200 PassRegistration<TestDerivedAttributeDriver>();
2201
2202 PassRegistration<TestGreedyPatternDriver>();
2203 PassRegistration<TestStrictPatternDriver>();
2204 PassRegistration<TestWalkPatternDriver>();
2205
2206 PassRegistration<TestLegalizePatternDriver>([] {
2207 return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode);
2208 });
2209
2210 PassRegistration<TestRemappedValue>();
2211
2212 PassRegistration<TestUnknownRootOpDriver>();
2213
2214 PassRegistration<TestTypeConversionDriver>();
2215 PassRegistration<TestTargetMaterializationWithNoUses>();
2216
2217 PassRegistration<TestRewriteDynamicOpDriver>();
2218
2219 PassRegistration<TestMergeBlocksPatternDriver>();
2220 PassRegistration<TestSelectiveReplacementPatternDriver>();
2221}
2222} // namespace test
2223} // namespace mlir
2224

source code of mlir/test/lib/Dialect/Test/TestPatterns.cpp