1//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "TestTypes.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "mlir/Transforms/FoldUtils.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "llvm/ADT/ScopeExit.h"
22
23using namespace mlir;
24using namespace test;
25
26// Native function for testing NativeCodeCall
27static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
28 return choice.getValue() ? input1 : input2;
29}
30
31static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
32 rewriter.create<OpI>(loc, input);
33}
34
35static void handleNoResultOp(PatternRewriter &rewriter,
36 OpSymbolBindingNoResult op) {
37 // Turn the no result op to a one-result op.
38 rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(),
39 op.getOperand());
40}
41
42static bool getFirstI32Result(Operation *op, Value &value) {
43 if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32))
44 return false;
45 value = op->getResult(idx: 0);
46 return true;
47}
48
49static Value bindNativeCodeCallResult(Value value) { return value; }
50
51static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1,
52 Value input2) {
53 return SmallVector<Value, 2>({input2, input1});
54}
55
56// Test that natives calls are only called once during rewrites.
57// OpM_Test will return Pi, increased by 1 for each subsequent calls.
58// This let us check the number of times OpM_Test was called by inspecting
59// the returned value in the MLIR output.
60static int64_t opMIncreasingValue = 314159265;
61static Attribute opMTest(PatternRewriter &rewriter, Value val) {
62 int64_t i = opMIncreasingValue++;
63 return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
64}
65
66namespace {
67#include "TestPatterns.inc"
68} // namespace
69
70//===----------------------------------------------------------------------===//
71// Test Reduce Pattern Interface
72//===----------------------------------------------------------------------===//
73
74void test::populateTestReductionPatterns(RewritePatternSet &patterns) {
75 populateWithGenerated(patterns);
76}
77
78//===----------------------------------------------------------------------===//
79// Canonicalizer Driver.
80//===----------------------------------------------------------------------===//
81
82namespace {
83struct FoldingPattern : public RewritePattern {
84public:
85 FoldingPattern(MLIRContext *context)
86 : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(),
87 /*benefit=*/1, context) {}
88
89 LogicalResult matchAndRewrite(Operation *op,
90 PatternRewriter &rewriter) const override {
91 // Exercise createOrFold API for a single-result operation that is folded
92 // upon construction. The operation being created has an in-place folder,
93 // and it should be still present in the output. Furthermore, the folder
94 // should not crash when attempting to recover the (unchanged) operation
95 // result.
96 Value result = rewriter.createOrFold<TestOpInPlaceFold>(
97 op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
98 assert(result);
99 rewriter.replaceOp(op, newValues: result);
100 return success();
101 }
102};
103
104/// This pattern creates a foldable operation at the entry point of the block.
105/// This tests the situation where the operation folder will need to replace an
106/// operation with a previously created constant that does not initially
107/// dominate the operation to replace.
108struct FolderInsertBeforePreviouslyFoldedConstantPattern
109 : public OpRewritePattern<TestCastOp> {
110public:
111 using OpRewritePattern<TestCastOp>::OpRewritePattern;
112
113 LogicalResult matchAndRewrite(TestCastOp op,
114 PatternRewriter &rewriter) const override {
115 if (!op->hasAttr("test_fold_before_previously_folded_op"))
116 return failure();
117 rewriter.setInsertionPointToStart(op->getBlock());
118
119 auto constOp = rewriter.create<arith::ConstantOp>(
120 op.getLoc(), rewriter.getBoolAttr(true));
121 rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
122 Value(constOp));
123 return success();
124 }
125};
126
127/// This pattern matches test.op_commutative2 with the first operand being
128/// another test.op_commutative2 with a constant on the right side and fold it
129/// away by propagating it as its result. This is intend to check that patterns
130/// are applied after the commutative property moves constant to the right.
131struct FolderCommutativeOp2WithConstant
132 : public OpRewritePattern<TestCommutative2Op> {
133public:
134 using OpRewritePattern<TestCommutative2Op>::OpRewritePattern;
135
136 LogicalResult matchAndRewrite(TestCommutative2Op op,
137 PatternRewriter &rewriter) const override {
138 auto operand =
139 dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
140 if (!operand)
141 return failure();
142 Attribute constInput;
143 if (!matchPattern(operand->getOperand(1), m_Constant(bind_value: &constInput)))
144 return failure();
145 rewriter.replaceOp(op, operand->getOperand(1));
146 return success();
147 }
148};
149
150/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
151/// attribute with value smaller than MaxVal, it increments the value by 1.
152template <int MaxVal>
153struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
154 using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
155
156 LogicalResult matchAndRewrite(AnyAttrOfOp op,
157 PatternRewriter &rewriter) const override {
158 auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
159 if (!intAttr)
160 return failure();
161 int64_t val = intAttr.getInt();
162 if (val >= MaxVal)
163 return failure();
164 rewriter.modifyOpInPlace(
165 op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
166 return success();
167 }
168};
169
170/// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op".
171struct MakeOpEligible : public RewritePattern {
172 MakeOpEligible(MLIRContext *context)
173 : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
174
175 LogicalResult matchAndRewrite(Operation *op,
176 PatternRewriter &rewriter) const override {
177 if (op->hasAttr(name: "eligible"))
178 return failure();
179 rewriter.modifyOpInPlace(
180 root: op, callable: [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
181 return success();
182 }
183};
184
185/// This pattern hoists eligible ops out of a "test.one_region_op".
186struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
187 using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
188
189 LogicalResult matchAndRewrite(test::OneRegionOp op,
190 PatternRewriter &rewriter) const override {
191 Operation *terminator = op.getRegion().front().getTerminator();
192 Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
193 if (toBeHoisted->getParentOp() != op)
194 return failure();
195 if (!toBeHoisted->hasAttr(name: "eligible"))
196 return failure();
197 rewriter.moveOpBefore(toBeHoisted, op);
198 return success();
199 }
200};
201
202/// This pattern moves "test.move_before_parent_op" before the parent op.
203struct MoveBeforeParentOp : public RewritePattern {
204 MoveBeforeParentOp(MLIRContext *context)
205 : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
206
207 LogicalResult matchAndRewrite(Operation *op,
208 PatternRewriter &rewriter) const override {
209 // Do not hoist past functions.
210 if (isa<FunctionOpInterface>(Val: op->getParentOp()))
211 return failure();
212 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
213 return success();
214 }
215};
216
217/// This pattern inlines blocks that are nested in
218/// "test.inline_blocks_into_parent" into the parent block.
219struct InlineBlocksIntoParent : public RewritePattern {
220 InlineBlocksIntoParent(MLIRContext *context)
221 : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
222 context) {}
223
224 LogicalResult matchAndRewrite(Operation *op,
225 PatternRewriter &rewriter) const override {
226 bool changed = false;
227 for (Region &r : op->getRegions()) {
228 while (!r.empty()) {
229 rewriter.inlineBlockBefore(source: &r.front(), op);
230 changed = true;
231 }
232 }
233 return success(isSuccess: changed);
234 }
235};
236
237/// This pattern splits blocks at "test.split_block_here" and replaces the op
238/// with a new op (to prevent an infinite loop of block splitting).
239struct SplitBlockHere : public RewritePattern {
240 SplitBlockHere(MLIRContext *context)
241 : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
242
243 LogicalResult matchAndRewrite(Operation *op,
244 PatternRewriter &rewriter) const override {
245 rewriter.splitBlock(block: op->getBlock(), before: op->getIterator());
246 Operation *newOp = rewriter.create(
247 op->getLoc(),
248 OperationName("test.new_op", op->getContext()).getIdentifier(),
249 op->getOperands(), op->getResultTypes());
250 rewriter.replaceOp(op, newOp);
251 return success();
252 }
253};
254
255/// This pattern clones "test.clone_me" ops.
256struct CloneOp : public RewritePattern {
257 CloneOp(MLIRContext *context)
258 : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
259
260 LogicalResult matchAndRewrite(Operation *op,
261 PatternRewriter &rewriter) const override {
262 // Do not clone already cloned ops to avoid going into an infinite loop.
263 if (op->hasAttr(name: "was_cloned"))
264 return failure();
265 Operation *cloned = rewriter.clone(op&: *op);
266 cloned->setAttr("was_cloned", rewriter.getUnitAttr());
267 return success();
268 }
269};
270
271/// This pattern clones regions of "test.clone_region_before" ops before the
272/// parent block.
273struct CloneRegionBeforeOp : public RewritePattern {
274 CloneRegionBeforeOp(MLIRContext *context)
275 : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
276
277 LogicalResult matchAndRewrite(Operation *op,
278 PatternRewriter &rewriter) const override {
279 // Do not clone already cloned ops to avoid going into an infinite loop.
280 if (op->hasAttr(name: "was_cloned"))
281 return failure();
282 for (Region &r : op->getRegions())
283 rewriter.cloneRegionBefore(region&: r, before: op->getBlock());
284 op->setAttr("was_cloned", rewriter.getUnitAttr());
285 return success();
286 }
287};
288
289struct TestPatternDriver
290 : public PassWrapper<TestPatternDriver, OperationPass<>> {
291 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
292
293 TestPatternDriver() = default;
294 TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
295
296 StringRef getArgument() const final { return "test-patterns"; }
297 StringRef getDescription() const final { return "Run test dialect patterns"; }
298 void runOnOperation() override {
299 mlir::RewritePatternSet patterns(&getContext());
300 populateWithGenerated(patterns);
301
302 // Verify named pattern is generated with expected name.
303 patterns.add<FoldingPattern, TestNamedPatternRule,
304 FolderInsertBeforePreviouslyFoldedConstantPattern,
305 FolderCommutativeOp2WithConstant, HoistEligibleOps,
306 MakeOpEligible>(&getContext());
307
308 // Additional patterns for testing the GreedyPatternRewriteDriver.
309 patterns.insert<IncrementIntAttribute<3>>(arg: &getContext());
310
311 GreedyRewriteConfig config;
312 config.useTopDownTraversal = this->useTopDownTraversal;
313 config.maxIterations = this->maxIterations;
314 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
315 config);
316 }
317
318 Option<bool> useTopDownTraversal{
319 *this, "top-down",
320 llvm::cl::desc("Seed the worklist in general top-down order"),
321 llvm::cl::init(Val: GreedyRewriteConfig().useTopDownTraversal)};
322 Option<int> maxIterations{
323 *this, "max-iterations",
324 llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
325 llvm::cl::init(Val: GreedyRewriteConfig().maxIterations)};
326};
327
328struct DumpNotifications : public RewriterBase::Listener {
329 void notifyBlockInserted(Block *block, Region *previous,
330 Region::iterator previousIt) override {
331 llvm::outs() << "notifyBlockInserted";
332 if (block->getParentOp()) {
333 llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
334 } else {
335 llvm::outs() << " into unknown op: ";
336 }
337 if (previous == nullptr) {
338 llvm::outs() << "was unlinked\n";
339 } else {
340 llvm::outs() << "was linked\n";
341 }
342 }
343 void notifyOperationInserted(Operation *op,
344 OpBuilder::InsertPoint previous) override {
345 llvm::outs() << "notifyOperationInserted: " << op->getName();
346 if (!previous.isSet()) {
347 llvm::outs() << ", was unlinked\n";
348 } else {
349 if (!previous.getPoint().getNodePtr()) {
350 llvm::outs() << ", was linked, exact position unknown\n";
351 } else if (previous.getPoint() == previous.getBlock()->end()) {
352 llvm::outs() << ", was last in block\n";
353 } else {
354 llvm::outs() << ", previous = " << previous.getPoint()->getName()
355 << "\n";
356 }
357 }
358 }
359 void notifyBlockErased(Block *block) override {
360 llvm::outs() << "notifyBlockErased\n";
361 }
362 void notifyOperationErased(Operation *op) override {
363 llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
364 }
365 void notifyOperationModified(Operation *op) override {
366 llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
367 }
368 void notifyOperationReplaced(Operation *op, ValueRange values) override {
369 llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
370 }
371};
372
373struct TestStrictPatternDriver
374 : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
375public:
376 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
377
378 TestStrictPatternDriver() = default;
379 TestStrictPatternDriver(const TestStrictPatternDriver &other)
380 : PassWrapper(other) {
381 strictMode = other.strictMode;
382 }
383
384 StringRef getArgument() const final { return "test-strict-pattern-driver"; }
385 StringRef getDescription() const final {
386 return "Test strict mode of pattern driver";
387 }
388
389 void runOnOperation() override {
390 MLIRContext *ctx = &getContext();
391 mlir::RewritePatternSet patterns(ctx);
392 patterns.add<
393 // clang-format off
394 ChangeBlockOp,
395 CloneOp,
396 CloneRegionBeforeOp,
397 EraseOp,
398 ImplicitChangeOp,
399 InlineBlocksIntoParent,
400 InsertSameOp,
401 MoveBeforeParentOp,
402 ReplaceWithNewOp,
403 SplitBlockHere
404 // clang-format on
405 >(arg&: ctx);
406 SmallVector<Operation *> ops;
407 getOperation()->walk([&](Operation *op) {
408 StringRef opName = op->getName().getStringRef();
409 if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
410 opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
411 opName == "test.move_before_parent_op" ||
412 opName == "test.inline_blocks_into_parent" ||
413 opName == "test.split_block_here" || opName == "test.clone_me" ||
414 opName == "test.clone_region_before") {
415 ops.push_back(Elt: op);
416 }
417 });
418
419 DumpNotifications dumpNotifications;
420 GreedyRewriteConfig config;
421 config.listener = &dumpNotifications;
422 if (strictMode == "AnyOp") {
423 config.strictMode = GreedyRewriteStrictness::AnyOp;
424 } else if (strictMode == "ExistingAndNewOps") {
425 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
426 } else if (strictMode == "ExistingOps") {
427 config.strictMode = GreedyRewriteStrictness::ExistingOps;
428 } else {
429 llvm_unreachable("invalid strictness option");
430 }
431
432 // Check if these transformations introduce visiting of operations that
433 // are not in the `ops` set (The new created ops are valid). An invalid
434 // operation will trigger the assertion while processing.
435 bool changed = false;
436 bool allErased = false;
437 (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
438 &changed, &allErased);
439 Builder b(ctx);
440 getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(value: changed));
441 getOperation()->setAttr("pattern_driver_all_erased",
442 b.getBoolAttr(value: allErased));
443 }
444
445 Option<std::string> strictMode{
446 *this, "strictness",
447 llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"),
448 llvm::cl::init("AnyOp")};
449
450private:
451 // New inserted operation is valid for further transformation.
452 class InsertSameOp : public RewritePattern {
453 public:
454 InsertSameOp(MLIRContext *context)
455 : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
456
457 LogicalResult matchAndRewrite(Operation *op,
458 PatternRewriter &rewriter) const override {
459 if (op->hasAttr(name: "skip"))
460 return failure();
461
462 Operation *newOp =
463 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
464 op->getOperands(), op->getResultTypes());
465 rewriter.modifyOpInPlace(
466 root: op, callable: [&]() { op->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true)); });
467 newOp->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true));
468
469 return success();
470 }
471 };
472
473 // Replace an operation may introduce the re-visiting of its users.
474 class ReplaceWithNewOp : public RewritePattern {
475 public:
476 ReplaceWithNewOp(MLIRContext *context)
477 : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
478
479 LogicalResult matchAndRewrite(Operation *op,
480 PatternRewriter &rewriter) const override {
481 Operation *newOp;
482 if (op->hasAttr(name: "create_erase_op")) {
483 newOp = rewriter.create(
484 op->getLoc(),
485 OperationName("test.erase_op", op->getContext()).getIdentifier(),
486 ValueRange(), TypeRange());
487 } else {
488 newOp = rewriter.create(
489 op->getLoc(),
490 OperationName("test.new_op", op->getContext()).getIdentifier(),
491 op->getOperands(), op->getResultTypes());
492 }
493 // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
494 // A "notifyOperationReplaced" callback is triggered in either case.
495 rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults());
496 rewriter.eraseOp(op);
497 return success();
498 }
499 };
500
501 // Remove an operation may introduce the re-visiting of its operands.
502 class EraseOp : public RewritePattern {
503 public:
504 EraseOp(MLIRContext *context)
505 : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
506 LogicalResult matchAndRewrite(Operation *op,
507 PatternRewriter &rewriter) const override {
508 rewriter.eraseOp(op);
509 return success();
510 }
511 };
512
513 // The following two patterns test RewriterBase::replaceAllUsesWith.
514 //
515 // That function replaces all usages of a Block (or a Value) with another one
516 // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
517 // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
518 // worklist: when an op is modified, it is added to the worklist. The two
519 // patterns below make the tracking observable: ChangeBlockOp replaces all
520 // usages of a block and that pattern is applied because the corresponding ops
521 // are put on the initial worklist (see above). ImplicitChangeOp does an
522 // unrelated change but ops of the corresponding type are *not* on the initial
523 // worklist, so the effect of the second pattern is only visible if the
524 // tracking and subsequent adding to the worklist actually works.
525
526 // Replace all usages of the first successor with the second successor.
527 class ChangeBlockOp : public RewritePattern {
528 public:
529 ChangeBlockOp(MLIRContext *context)
530 : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
531 LogicalResult matchAndRewrite(Operation *op,
532 PatternRewriter &rewriter) const override {
533 if (op->getNumSuccessors() < 2)
534 return failure();
535 Block *firstSuccessor = op->getSuccessor(index: 0);
536 Block *secondSuccessor = op->getSuccessor(index: 1);
537 if (firstSuccessor == secondSuccessor)
538 return failure();
539 // This is the function being tested:
540 rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor);
541 // Using the following line instead would make the test fail:
542 // firstSuccessor->replaceAllUsesWith(secondSuccessor);
543 return success();
544 }
545 };
546
547 // Changes the successor to the parent block.
548 class ImplicitChangeOp : public RewritePattern {
549 public:
550 ImplicitChangeOp(MLIRContext *context)
551 : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
552 LogicalResult matchAndRewrite(Operation *op,
553 PatternRewriter &rewriter) const override {
554 if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock())
555 return failure();
556 rewriter.modifyOpInPlace(root: op,
557 callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); });
558 return success();
559 }
560 };
561};
562
563} // namespace
564
565//===----------------------------------------------------------------------===//
566// ReturnType Driver.
567//===----------------------------------------------------------------------===//
568
569namespace {
570// Generate ops for each instance where the type can be successfully inferred.
571template <typename OpTy>
572static void invokeCreateWithInferredReturnType(Operation *op) {
573 auto *context = op->getContext();
574 auto fop = op->getParentOfType<func::FuncOp>();
575 auto location = UnknownLoc::get(context);
576 OpBuilder b(op);
577 b.setInsertionPointAfter(op);
578
579 // Use permutations of 2 args as operands.
580 assert(fop.getNumArguments() >= 2);
581 for (int i = 0, e = fop.getNumArguments(); i < e; ++i) {
582 for (int j = 0; j < e; ++j) {
583 std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}};
584 SmallVector<Type, 2> inferredReturnTypes;
585 if (succeeded(OpTy::inferReturnTypes(
586 context, std::nullopt, values, op->getDiscardableAttrDictionary(),
587 op->getPropertiesStorage(), op->getRegions(),
588 inferredReturnTypes))) {
589 OperationState state(location, OpTy::getOperationName());
590 // TODO: Expand to regions.
591 OpTy::build(b, state, values, op->getAttrs());
592 (void)b.create(state);
593 }
594 }
595 }
596}
597
598static void reifyReturnShape(Operation *op) {
599 OpBuilder b(op);
600
601 // Use permutations of 2 args as operands.
602 auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
603 SmallVector<Value, 2> shapes;
604 if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) ||
605 !llvm::hasSingleElement(C&: shapes))
606 return;
607 for (const auto &it : llvm::enumerate(First&: shapes)) {
608 op->emitRemark() << "value " << it.index() << ": "
609 << it.value().getDefiningOp();
610 }
611}
612
613struct TestReturnTypeDriver
614 : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> {
615 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver)
616
617 void getDependentDialects(DialectRegistry &registry) const override {
618 registry.insert<tensor::TensorDialect>();
619 }
620 StringRef getArgument() const final { return "test-return-type"; }
621 StringRef getDescription() const final { return "Run return type functions"; }
622
623 void runOnOperation() override {
624 if (getOperation().getName() == "testCreateFunctions") {
625 std::vector<Operation *> ops;
626 // Collect ops to avoid triggering on inserted ops.
627 for (auto &op : getOperation().getBody().front())
628 ops.push_back(&op);
629 // Generate test patterns for each, but skip terminator.
630 for (auto *op : llvm::ArrayRef(ops).drop_back()) {
631 // Test create method of each of the Op classes below. The resultant
632 // output would be in reverse order underneath `op` from which
633 // the attributes and regions are used.
634 invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
635 invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
636 op);
637 invokeCreateWithInferredReturnType<
638 OpWithShapedTypeInferTypeInterfaceOp>(op);
639 };
640 return;
641 }
642 if (getOperation().getName() == "testReifyFunctions") {
643 std::vector<Operation *> ops;
644 // Collect ops to avoid triggering on inserted ops.
645 for (auto &op : getOperation().getBody().front())
646 if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
647 ops.push_back(&op);
648 // Generate test patterns for each, but skip terminator.
649 for (auto *op : ops)
650 reifyReturnShape(op);
651 }
652 }
653};
654} // namespace
655
656namespace {
657struct TestDerivedAttributeDriver
658 : public PassWrapper<TestDerivedAttributeDriver,
659 OperationPass<func::FuncOp>> {
660 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver)
661
662 StringRef getArgument() const final { return "test-derived-attr"; }
663 StringRef getDescription() const final {
664 return "Run test derived attributes";
665 }
666 void runOnOperation() override;
667};
668} // namespace
669
670void TestDerivedAttributeDriver::runOnOperation() {
671 getOperation().walk([](DerivedAttributeOpInterface dOp) {
672 auto dAttr = dOp.materializeDerivedAttributes();
673 if (!dAttr)
674 return;
675 for (auto d : dAttr)
676 dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue();
677 });
678}
679
680//===----------------------------------------------------------------------===//
681// Legalization Driver.
682//===----------------------------------------------------------------------===//
683
684namespace {
685//===----------------------------------------------------------------------===//
686// Region-Block Rewrite Testing
687
688/// This pattern is a simple pattern that inlines the first region of a given
689/// operation into the parent region.
690struct TestRegionRewriteBlockMovement : public ConversionPattern {
691 TestRegionRewriteBlockMovement(MLIRContext *ctx)
692 : ConversionPattern("test.region", 1, ctx) {}
693
694 LogicalResult
695 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
696 ConversionPatternRewriter &rewriter) const final {
697 // Inline this region into the parent region.
698 auto &parentRegion = *op->getParentRegion();
699 auto &opRegion = op->getRegion(index: 0);
700 if (op->getDiscardableAttr(name: "legalizer.should_clone"))
701 rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
702 else
703 rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end());
704
705 if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks")) {
706 while (!opRegion.empty())
707 rewriter.eraseBlock(block: &opRegion.front());
708 }
709
710 // Drop this operation.
711 rewriter.eraseOp(op);
712 return success();
713 }
714};
715/// This pattern is a simple pattern that generates a region containing an
716/// illegal operation.
717struct TestRegionRewriteUndo : public RewritePattern {
718 TestRegionRewriteUndo(MLIRContext *ctx)
719 : RewritePattern("test.region_builder", 1, ctx) {}
720
721 LogicalResult matchAndRewrite(Operation *op,
722 PatternRewriter &rewriter) const final {
723 // Create the region operation with an entry block containing arguments.
724 OperationState newRegion(op->getLoc(), "test.region");
725 newRegion.addRegion();
726 auto *regionOp = rewriter.create(state: newRegion);
727 auto *entryBlock = rewriter.createBlock(parent: &regionOp->getRegion(index: 0));
728 entryBlock->addArgument(rewriter.getIntegerType(64),
729 rewriter.getUnknownLoc());
730
731 // Add an explicitly illegal operation to ensure the conversion fails.
732 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
733 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
734
735 // Drop this operation.
736 rewriter.eraseOp(op);
737 return success();
738 }
739};
740/// A simple pattern that creates a block at the end of the parent region of the
741/// matched operation.
742struct TestCreateBlock : public RewritePattern {
743 TestCreateBlock(MLIRContext *ctx)
744 : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
745
746 LogicalResult matchAndRewrite(Operation *op,
747 PatternRewriter &rewriter) const final {
748 Region &region = *op->getParentRegion();
749 Type i32Type = rewriter.getIntegerType(32);
750 Location loc = op->getLoc();
751 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
752 rewriter.create<TerminatorOp>(loc);
753 rewriter.eraseOp(op);
754 return success();
755 }
756};
757
758/// A simple pattern that creates a block containing an invalid operation in
759/// order to trigger the block creation undo mechanism.
760struct TestCreateIllegalBlock : public RewritePattern {
761 TestCreateIllegalBlock(MLIRContext *ctx)
762 : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
763
764 LogicalResult matchAndRewrite(Operation *op,
765 PatternRewriter &rewriter) const final {
766 Region &region = *op->getParentRegion();
767 Type i32Type = rewriter.getIntegerType(32);
768 Location loc = op->getLoc();
769 rewriter.createBlock(parent: &region, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc});
770 // Create an illegal op to ensure the conversion fails.
771 rewriter.create<ILLegalOpF>(loc, i32Type);
772 rewriter.create<TerminatorOp>(loc);
773 rewriter.eraseOp(op);
774 return success();
775 }
776};
777
778/// A simple pattern that tests the undo mechanism when replacing the uses of a
779/// block argument.
780struct TestUndoBlockArgReplace : public ConversionPattern {
781 TestUndoBlockArgReplace(MLIRContext *ctx)
782 : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
783
784 LogicalResult
785 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
786 ConversionPatternRewriter &rewriter) const final {
787 auto illegalOp =
788 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
789 rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0),
790 to: illegalOp->getResult(0));
791 rewriter.modifyOpInPlace(root: op, callable: [] {});
792 return success();
793 }
794};
795
796/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
797/// This is to test the rollback logic.
798struct TestUndoMoveOpBefore : public ConversionPattern {
799 TestUndoMoveOpBefore(MLIRContext *ctx)
800 : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
801
802 LogicalResult
803 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
804 ConversionPatternRewriter &rewriter) const override {
805 rewriter.moveOpBefore(op, existingOp: op->getParentOp());
806 // Replace with an illegal op to ensure the conversion fails.
807 rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
808 return success();
809 }
810};
811
812/// A rewrite pattern that tests the undo mechanism when erasing a block.
813struct TestUndoBlockErase : public ConversionPattern {
814 TestUndoBlockErase(MLIRContext *ctx)
815 : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
816
817 LogicalResult
818 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
819 ConversionPatternRewriter &rewriter) const final {
820 Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin());
821 rewriter.setInsertionPointToStart(secondBlock);
822 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
823 rewriter.eraseBlock(block: secondBlock);
824 rewriter.modifyOpInPlace(root: op, callable: [] {});
825 return success();
826 }
827};
828
829/// A pattern that modifies a property in-place, but keeps the op illegal.
830struct TestUndoPropertiesModification : public ConversionPattern {
831 TestUndoPropertiesModification(MLIRContext *ctx)
832 : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
833 LogicalResult
834 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
835 ConversionPatternRewriter &rewriter) const final {
836 if (!op->hasAttr(name: "modify_inplace"))
837 return failure();
838 rewriter.modifyOpInPlace(
839 root: op, callable: [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
840 return success();
841 }
842};
843
844//===----------------------------------------------------------------------===//
845// Type-Conversion Rewrite Testing
846
847/// This patterns erases a region operation that has had a type conversion.
848struct TestDropOpSignatureConversion : public ConversionPattern {
849 TestDropOpSignatureConversion(MLIRContext *ctx,
850 const TypeConverter &converter)
851 : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {}
852 LogicalResult
853 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
854 ConversionPatternRewriter &rewriter) const override {
855 Region &region = op->getRegion(index: 0);
856 Block *entry = &region.front();
857
858 // Convert the original entry arguments.
859 const TypeConverter &converter = *getTypeConverter();
860 TypeConverter::SignatureConversion result(entry->getNumArguments());
861 if (failed(result: converter.convertSignatureArgs(types: entry->getArgumentTypes(),
862 result)) ||
863 failed(result: rewriter.convertRegionTypes(region: &region, converter, entryConversion: &result)))
864 return failure();
865
866 // Convert the region signature and just drop the operation.
867 rewriter.eraseOp(op);
868 return success();
869 }
870};
871/// This pattern simply updates the operands of the given operation.
872struct TestPassthroughInvalidOp : public ConversionPattern {
873 TestPassthroughInvalidOp(MLIRContext *ctx)
874 : ConversionPattern("test.invalid", 1, ctx) {}
875 LogicalResult
876 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
877 ConversionPatternRewriter &rewriter) const final {
878 rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
879 std::nullopt);
880 return success();
881 }
882};
883/// This pattern handles the case of a split return value.
884struct TestSplitReturnType : public ConversionPattern {
885 TestSplitReturnType(MLIRContext *ctx)
886 : ConversionPattern("test.return", 1, ctx) {}
887 LogicalResult
888 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
889 ConversionPatternRewriter &rewriter) const final {
890 // Check for a return of F32.
891 if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32())
892 return failure();
893
894 // Check if the first operation is a cast operation, if it is we use the
895 // results directly.
896 auto *defOp = operands[0].getDefiningOp();
897 if (auto packerOp =
898 llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
899 rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
900 return success();
901 }
902
903 // Otherwise, fail to match.
904 return failure();
905 }
906};
907
908//===----------------------------------------------------------------------===//
909// Multi-Level Type-Conversion Rewrite Testing
910struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
911 TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
912 : ConversionPattern("test.type_producer", 1, ctx) {}
913 LogicalResult
914 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
915 ConversionPatternRewriter &rewriter) const final {
916 // If the type is I32, change the type to F32.
917 if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32))
918 return failure();
919 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
920 return success();
921 }
922};
923struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
924 TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
925 : ConversionPattern("test.type_producer", 1, ctx) {}
926 LogicalResult
927 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
928 ConversionPatternRewriter &rewriter) const final {
929 // If the type is F32, change the type to F64.
930 if (!Type(*op->result_type_begin()).isF32())
931 return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand");
932 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
933 return success();
934 }
935};
936struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
937 TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
938 : ConversionPattern("test.type_producer", 10, ctx) {}
939 LogicalResult
940 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
941 ConversionPatternRewriter &rewriter) const final {
942 // Always convert to B16, even though it is not a legal type. This tests
943 // that values are unmapped correctly.
944 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
945 return success();
946 }
947};
948struct TestUpdateConsumerType : public ConversionPattern {
949 TestUpdateConsumerType(MLIRContext *ctx)
950 : ConversionPattern("test.type_consumer", 1, ctx) {}
951 LogicalResult
952 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
953 ConversionPatternRewriter &rewriter) const final {
954 // Verify that the incoming operand has been successfully remapped to F64.
955 if (!operands[0].getType().isF64())
956 return failure();
957 rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
958 return success();
959 }
960};
961
962//===----------------------------------------------------------------------===//
963// Non-Root Replacement Rewrite Testing
964/// This pattern generates an invalid operation, but replaces it before the
965/// pattern is finished. This checks that we don't need to legalize the
966/// temporary op.
967struct TestNonRootReplacement : public RewritePattern {
968 TestNonRootReplacement(MLIRContext *ctx)
969 : RewritePattern("test.replace_non_root", 1, ctx) {}
970
971 LogicalResult matchAndRewrite(Operation *op,
972 PatternRewriter &rewriter) const final {
973 auto resultType = *op->result_type_begin();
974 auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
975 auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
976
977 rewriter.replaceOp(illegalOp, legalOp);
978 rewriter.replaceOp(op, illegalOp);
979 return success();
980 }
981};
982
983//===----------------------------------------------------------------------===//
984// Recursive Rewrite Testing
985/// This pattern is applied to the same operation multiple times, but has a
986/// bounded recursion.
987struct TestBoundedRecursiveRewrite
988 : public OpRewritePattern<TestRecursiveRewriteOp> {
989 using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
990
991 void initialize() {
992 // The conversion target handles bounding the recursion of this pattern.
993 setHasBoundedRewriteRecursion();
994 }
995
996 LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
997 PatternRewriter &rewriter) const final {
998 // Decrement the depth of the op in-place.
999 rewriter.modifyOpInPlace(op, [&] {
1000 op->setAttr("depth", rewriter.getI64IntegerAttr(value: op.getDepth() - 1));
1001 });
1002 return success();
1003 }
1004};
1005
1006struct TestNestedOpCreationUndoRewrite
1007 : public OpRewritePattern<IllegalOpWithRegionAnchor> {
1008 using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern;
1009
1010 LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op,
1011 PatternRewriter &rewriter) const final {
1012 // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1013 rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op);
1014 return success();
1015 };
1016};
1017
1018// This pattern matches `test.blackhole` and delete this op and its producer.
1019struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
1020 using OpRewritePattern<BlackHoleOp>::OpRewritePattern;
1021
1022 LogicalResult matchAndRewrite(BlackHoleOp op,
1023 PatternRewriter &rewriter) const final {
1024 Operation *producer = op.getOperand().getDefiningOp();
1025 // Always erase the user before the producer, the framework should handle
1026 // this correctly.
1027 rewriter.eraseOp(op: op);
1028 rewriter.eraseOp(op: producer);
1029 return success();
1030 };
1031};
1032
1033// This pattern replaces explicitly illegal op with explicitly legal op,
1034// but in addition creates unregistered operation.
1035struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
1036 using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
1037
1038 LogicalResult matchAndRewrite(ILLegalOpG op,
1039 PatternRewriter &rewriter) const final {
1040 IntegerAttr attr = rewriter.getI32IntegerAttr(0);
1041 Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
1042 rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
1043 return success();
1044 };
1045};
1046} // namespace
1047
1048namespace {
1049struct TestTypeConverter : public TypeConverter {
1050 using TypeConverter::TypeConverter;
1051 TestTypeConverter() {
1052 addConversion(callback&: convertType);
1053 addArgumentMaterialization(callback&: materializeCast);
1054 addSourceMaterialization(callback&: materializeCast);
1055 }
1056
1057 static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
1058 // Drop I16 types.
1059 if (t.isSignlessInteger(width: 16))
1060 return success();
1061
1062 // Convert I64 to F64.
1063 if (t.isSignlessInteger(width: 64)) {
1064 results.push_back(Elt: FloatType::getF64(ctx: t.getContext()));
1065 return success();
1066 }
1067
1068 // Convert I42 to I43.
1069 if (t.isInteger(width: 42)) {
1070 results.push_back(IntegerType::get(t.getContext(), 43));
1071 return success();
1072 }
1073
1074 // Split F32 into F16,F16.
1075 if (t.isF32()) {
1076 results.assign(NumElts: 2, Elt: FloatType::getF16(ctx: t.getContext()));
1077 return success();
1078 }
1079
1080 // Otherwise, convert the type directly.
1081 results.push_back(Elt: t);
1082 return success();
1083 }
1084
1085 /// Hook for materializing a conversion. This is necessary because we generate
1086 /// 1->N type mappings.
1087 static std::optional<Value> materializeCast(OpBuilder &builder,
1088 Type resultType,
1089 ValueRange inputs, Location loc) {
1090 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1091 }
1092};
1093
1094struct TestLegalizePatternDriver
1095 : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> {
1096 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver)
1097
1098 StringRef getArgument() const final { return "test-legalize-patterns"; }
1099 StringRef getDescription() const final {
1100 return "Run test dialect legalization patterns";
1101 }
1102 /// The mode of conversion to use with the driver.
1103 enum class ConversionMode { Analysis, Full, Partial };
1104
1105 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
1106
1107 void getDependentDialects(DialectRegistry &registry) const override {
1108 registry.insert<func::FuncDialect, test::TestDialect>();
1109 }
1110
1111 void runOnOperation() override {
1112 TestTypeConverter converter;
1113 mlir::RewritePatternSet patterns(&getContext());
1114 populateWithGenerated(patterns);
1115 patterns
1116 .add<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
1117 TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace,
1118 TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType,
1119 TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1120 TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1121 TestNonRootReplacement, TestBoundedRecursiveRewrite,
1122 TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1123 TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1124 TestUndoPropertiesModification>(arg: &getContext());
1125 patterns.add<TestDropOpSignatureConversion>(arg: &getContext(), args&: converter);
1126 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1127 converter);
1128 mlir::populateCallOpTypeConversionPattern(patterns, converter);
1129
1130 // Define the conversion target used for the test.
1131 ConversionTarget target(getContext());
1132 target.addLegalOp<ModuleOp>();
1133 target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
1134 TerminatorOp, OneRegionOp>();
1135 target
1136 .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
1137 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
1138 // Don't allow F32 operands.
1139 return llvm::none_of(op.getOperandTypes(),
1140 [](Type type) { return type.isF32(); });
1141 });
1142 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1143 return converter.isSignatureLegal(op.getFunctionType()) &&
1144 converter.isLegal(&op.getBody());
1145 });
1146 target.addDynamicallyLegalOp<func::CallOp>(
1147 [&](func::CallOp op) { return converter.isLegal(op); });
1148
1149 // TestCreateUnregisteredOp creates `arith.constant` operation,
1150 // which was not added to target intentionally to test
1151 // correct error code from conversion driver.
1152 target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
1153
1154 // Expect the type_producer/type_consumer operations to only operate on f64.
1155 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1156 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1157 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1158 return op.getOperand().getType().isF64();
1159 });
1160
1161 // Check support for marking certain operations as recursively legal.
1162 target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) {
1163 return static_cast<bool>(
1164 op->getAttrOfType<UnitAttr>("test.recursively_legal"));
1165 });
1166
1167 // Mark the bound recursion operation as dynamically legal.
1168 target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
1169 [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
1170
1171 // Create a dynamically legal rule that can only be legalized by folding it.
1172 target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
1173 [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
1174
1175 // Handle a partial conversion.
1176 if (mode == ConversionMode::Partial) {
1177 DenseSet<Operation *> unlegalizedOps;
1178 ConversionConfig config;
1179 DumpNotifications dumpNotifications;
1180 config.listener = &dumpNotifications;
1181 config.unlegalizedOps = &unlegalizedOps;
1182 if (failed(applyPartialConversion(getOperation(), target,
1183 std::move(patterns), config))) {
1184 getOperation()->emitRemark() << "applyPartialConversion failed";
1185 }
1186 // Emit remarks for each legalizable operation.
1187 for (auto *op : unlegalizedOps)
1188 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1189 return;
1190 }
1191
1192 // Handle a full conversion.
1193 if (mode == ConversionMode::Full) {
1194 // Check support for marking unknown operations as dynamically legal.
1195 target.markUnknownOpDynamicallyLegal([](Operation *op) {
1196 return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
1197 });
1198
1199 ConversionConfig config;
1200 DumpNotifications dumpNotifications;
1201 config.listener = &dumpNotifications;
1202 if (failed(applyFullConversion(getOperation(), target,
1203 std::move(patterns), config))) {
1204 getOperation()->emitRemark() << "applyFullConversion failed";
1205 }
1206 return;
1207 }
1208
1209 // Otherwise, handle an analysis conversion.
1210 assert(mode == ConversionMode::Analysis);
1211
1212 // Analyze the convertible operations.
1213 DenseSet<Operation *> legalizedOps;
1214 ConversionConfig config;
1215 config.legalizableOps = &legalizedOps;
1216 if (failed(applyAnalysisConversion(getOperation(), target,
1217 std::move(patterns), config)))
1218 return signalPassFailure();
1219
1220 // Emit remarks for each legalizable operation.
1221 for (auto *op : legalizedOps)
1222 op->emitRemark() << "op '" << op->getName() << "' is legalizable";
1223 }
1224
1225 /// The mode of conversion to use.
1226 ConversionMode mode;
1227};
1228} // namespace
1229
1230static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
1231 legalizerConversionMode(
1232 "test-legalize-mode",
1233 llvm::cl::desc("The legalization mode to use with the test driver"),
1234 llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial),
1235 llvm::cl::values(
1236 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
1237 "analysis", "Perform an analysis conversion"),
1238 clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
1239 "Perform a full conversion"),
1240 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
1241 "partial", "Perform a partial conversion")));
1242
1243//===----------------------------------------------------------------------===//
1244// ConversionPatternRewriter::getRemappedValue testing. This method is used
1245// to get the remapped value of an original value that was replaced using
1246// ConversionPatternRewriter.
1247namespace {
1248struct TestRemapValueTypeConverter : public TypeConverter {
1249 using TypeConverter::TypeConverter;
1250
1251 TestRemapValueTypeConverter() {
1252 addConversion(
1253 callback: [](Float32Type type) { return Float64Type::get(type.getContext()); });
1254 addConversion(callback: [](Type type) { return type; });
1255 }
1256};
1257
1258/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
1259/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
1260/// operand twice.
1261///
1262/// Example:
1263/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
1264/// is replaced with:
1265/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
1266struct OneVResOneVOperandOp1Converter
1267 : public OpConversionPattern<OneVResOneVOperandOp1> {
1268 using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
1269
1270 LogicalResult
1271 matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor,
1272 ConversionPatternRewriter &rewriter) const override {
1273 auto origOps = op.getOperands();
1274 assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
1275 "One operand expected");
1276 Value origOp = *origOps.begin();
1277 SmallVector<Value, 2> remappedOperands;
1278 // Replicate the remapped original operand twice. Note that we don't used
1279 // the remapped 'operand' since the goal is testing 'getRemappedValue'.
1280 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1281 remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp));
1282
1283 rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
1284 remappedOperands);
1285 return success();
1286 }
1287};
1288
1289/// A rewriter pattern that tests that blocks can be merged.
1290struct TestRemapValueInRegion
1291 : public OpConversionPattern<TestRemappedValueRegionOp> {
1292 using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
1293
1294 LogicalResult
1295 matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
1296 ConversionPatternRewriter &rewriter) const final {
1297 Block &block = op.getBody().front();
1298 Operation *terminator = block.getTerminator();
1299
1300 // Merge the block into the parent region.
1301 Block *parentBlock = op->getBlock();
1302 Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator());
1303 rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange());
1304 rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange());
1305
1306 // Replace the results of this operation with the remapped terminator
1307 // values.
1308 SmallVector<Value> terminatorOperands;
1309 if (failed(result: rewriter.getRemappedValues(keys: terminator->getOperands(),
1310 results&: terminatorOperands)))
1311 return failure();
1312
1313 rewriter.eraseOp(op: terminator);
1314 rewriter.replaceOp(op, terminatorOperands);
1315 return success();
1316 }
1317};
1318
1319struct TestRemappedValue
1320 : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> {
1321 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue)
1322
1323 StringRef getArgument() const final { return "test-remapped-value"; }
1324 StringRef getDescription() const final {
1325 return "Test public remapped value mechanism in ConversionPatternRewriter";
1326 }
1327 void runOnOperation() override {
1328 TestRemapValueTypeConverter typeConverter;
1329
1330 mlir::RewritePatternSet patterns(&getContext());
1331 patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext());
1332 patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
1333 arg: &getContext());
1334 patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext());
1335
1336 mlir::ConversionTarget target(getContext());
1337 target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>();
1338
1339 // Expect the type_producer/type_consumer operations to only operate on f64.
1340 target.addDynamicallyLegalOp<TestTypeProducerOp>(
1341 [](TestTypeProducerOp op) { return op.getType().isF64(); });
1342 target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
1343 return op.getOperand().getType().isF64();
1344 });
1345
1346 // We make OneVResOneVOperandOp1 legal only when it has more that one
1347 // operand. This will trigger the conversion that will replace one-operand
1348 // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
1349 target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
1350 [](Operation *op) { return op->getNumOperands() > 1; });
1351
1352 if (failed(mlir::applyFullConversion(getOperation(), target,
1353 std::move(patterns)))) {
1354 signalPassFailure();
1355 }
1356 }
1357};
1358} // namespace
1359
1360//===----------------------------------------------------------------------===//
1361// Test patterns without a specific root operation kind
1362//===----------------------------------------------------------------------===//
1363
1364namespace {
1365/// This pattern matches and removes any operation in the test dialect.
1366struct RemoveTestDialectOps : public RewritePattern {
1367 RemoveTestDialectOps(MLIRContext *context)
1368 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1369
1370 LogicalResult matchAndRewrite(Operation *op,
1371 PatternRewriter &rewriter) const override {
1372 if (!isa<TestDialect>(Val: op->getDialect()))
1373 return failure();
1374 rewriter.eraseOp(op);
1375 return success();
1376 }
1377};
1378
1379struct TestUnknownRootOpDriver
1380 : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> {
1381 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver)
1382
1383 StringRef getArgument() const final {
1384 return "test-legalize-unknown-root-patterns";
1385 }
1386 StringRef getDescription() const final {
1387 return "Test public remapped value mechanism in ConversionPatternRewriter";
1388 }
1389 void runOnOperation() override {
1390 mlir::RewritePatternSet patterns(&getContext());
1391 patterns.add<RemoveTestDialectOps>(arg: &getContext());
1392
1393 mlir::ConversionTarget target(getContext());
1394 target.addIllegalDialect<TestDialect>();
1395 if (failed(applyPartialConversion(getOperation(), target,
1396 std::move(patterns))))
1397 signalPassFailure();
1398 }
1399};
1400} // namespace
1401
1402//===----------------------------------------------------------------------===//
1403// Test patterns that uses operations and types defined at runtime
1404//===----------------------------------------------------------------------===//
1405
1406namespace {
1407/// This pattern matches dynamic operations 'test.one_operand_two_results' and
1408/// replace them with dynamic operations 'test.generic_dynamic_op'.
1409struct RewriteDynamicOp : public RewritePattern {
1410 RewriteDynamicOp(MLIRContext *context)
1411 : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
1412 context) {}
1413
1414 LogicalResult matchAndRewrite(Operation *op,
1415 PatternRewriter &rewriter) const override {
1416 assert(op->getName().getStringRef() ==
1417 "test.dynamic_one_operand_two_results" &&
1418 "rewrite pattern should only match operations with the right name");
1419
1420 OperationState state(op->getLoc(), "test.dynamic_generic",
1421 op->getOperands(), op->getResultTypes(),
1422 op->getAttrs());
1423 auto *newOp = rewriter.create(state);
1424 rewriter.replaceOp(op, newValues: newOp->getResults());
1425 return success();
1426 }
1427};
1428
1429struct TestRewriteDynamicOpDriver
1430 : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> {
1431 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
1432
1433 void getDependentDialects(DialectRegistry &registry) const override {
1434 registry.insert<TestDialect>();
1435 }
1436 StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
1437 StringRef getDescription() const final {
1438 return "Test rewritting on dynamic operations";
1439 }
1440 void runOnOperation() override {
1441 RewritePatternSet patterns(&getContext());
1442 patterns.add<RewriteDynamicOp>(arg: &getContext());
1443
1444 ConversionTarget target(getContext());
1445 target.addIllegalOp(
1446 op: OperationName("test.dynamic_one_operand_two_results", &getContext()));
1447 target.addLegalOp(op: OperationName("test.dynamic_generic", &getContext()));
1448 if (failed(applyPartialConversion(getOperation(), target,
1449 std::move(patterns))))
1450 signalPassFailure();
1451 }
1452};
1453} // end anonymous namespace
1454
1455//===----------------------------------------------------------------------===//
1456// Test type conversions
1457//===----------------------------------------------------------------------===//
1458
1459namespace {
1460struct TestTypeConversionProducer
1461 : public OpConversionPattern<TestTypeProducerOp> {
1462 using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
1463 LogicalResult
1464 matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
1465 ConversionPatternRewriter &rewriter) const final {
1466 Type resultType = op.getType();
1467 Type convertedType = getTypeConverter()
1468 ? getTypeConverter()->convertType(resultType)
1469 : resultType;
1470 if (isa<FloatType>(Val: resultType))
1471 resultType = rewriter.getF64Type();
1472 else if (resultType.isInteger(width: 16))
1473 resultType = rewriter.getIntegerType(64);
1474 else if (isa<test::TestRecursiveType>(Val: resultType) &&
1475 convertedType != resultType)
1476 resultType = convertedType;
1477 else
1478 return failure();
1479
1480 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
1481 return success();
1482 }
1483};
1484
1485/// Call signature conversion and then fail the rewrite to trigger the undo
1486/// mechanism.
1487struct TestSignatureConversionUndo
1488 : public OpConversionPattern<TestSignatureConversionUndoOp> {
1489 using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern;
1490
1491 LogicalResult
1492 matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor,
1493 ConversionPatternRewriter &rewriter) const final {
1494 (void)rewriter.convertRegionTypes(region: &op->getRegion(0), converter: *getTypeConverter());
1495 return failure();
1496 }
1497};
1498
1499/// Call signature conversion without providing a type converter to handle
1500/// materializations.
1501struct TestTestSignatureConversionNoConverter
1502 : public OpConversionPattern<TestSignatureConversionNoConverterOp> {
1503 TestTestSignatureConversionNoConverter(const TypeConverter &converter,
1504 MLIRContext *context)
1505 : OpConversionPattern<TestSignatureConversionNoConverterOp>(context),
1506 converter(converter) {}
1507
1508 LogicalResult
1509 matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor,
1510 ConversionPatternRewriter &rewriter) const final {
1511 Region &region = op->getRegion(0);
1512 Block *entry = &region.front();
1513
1514 // Convert the original entry arguments.
1515 TypeConverter::SignatureConversion result(entry->getNumArguments());
1516 if (failed(
1517 result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result)))
1518 return failure();
1519 rewriter.modifyOpInPlace(
1520 op, [&] { rewriter.applySignatureConversion(region: &region, conversion&: result); });
1521 return success();
1522 }
1523
1524 const TypeConverter &converter;
1525};
1526
1527/// Just forward the operands to the root op. This is essentially a no-op
1528/// pattern that is used to trigger target materialization.
1529struct TestTypeConsumerForward
1530 : public OpConversionPattern<TestTypeConsumerOp> {
1531 using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern;
1532
1533 LogicalResult
1534 matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter) const final {
1536 rewriter.modifyOpInPlace(op,
1537 [&] { op->setOperands(adaptor.getOperands()); });
1538 return success();
1539 }
1540};
1541
1542struct TestTypeConversionAnotherProducer
1543 : public OpRewritePattern<TestAnotherTypeProducerOp> {
1544 using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
1545
1546 LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
1547 PatternRewriter &rewriter) const final {
1548 rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
1549 return success();
1550 }
1551};
1552
1553struct TestTypeConversionDriver
1554 : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
1555 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
1556
1557 void getDependentDialects(DialectRegistry &registry) const override {
1558 registry.insert<TestDialect>();
1559 }
1560 StringRef getArgument() const final {
1561 return "test-legalize-type-conversion";
1562 }
1563 StringRef getDescription() const final {
1564 return "Test various type conversion functionalities in DialectConversion";
1565 }
1566
1567 void runOnOperation() override {
1568 // Initialize the type converter.
1569 SmallVector<Type, 2> conversionCallStack;
1570 TypeConverter converter;
1571
1572 /// Add the legal set of type conversions.
1573 converter.addConversion(callback: [](Type type) -> Type {
1574 // Treat F64 as legal.
1575 if (type.isF64())
1576 return type;
1577 // Allow converting BF16/F16/F32 to F64.
1578 if (type.isBF16() || type.isF16() || type.isF32())
1579 return FloatType::getF64(ctx: type.getContext());
1580 // Otherwise, the type is illegal.
1581 return nullptr;
1582 });
1583 converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) {
1584 // Drop all integer types.
1585 return success();
1586 });
1587 converter.addConversion(
1588 // Convert a recursive self-referring type into a non-self-referring
1589 // type named "outer_converted_type" that contains a SimpleAType.
1590 callback: [&](test::TestRecursiveType type,
1591 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1592 // If the type is already converted, return it to indicate that it is
1593 // legal.
1594 if (type.getName() == "outer_converted_type") {
1595 results.push_back(type);
1596 return success();
1597 }
1598
1599 conversionCallStack.push_back(type);
1600 auto popConversionCallStack = llvm::make_scope_exit(
1601 F: [&conversionCallStack]() { conversionCallStack.pop_back(); });
1602
1603 // If the type is on the call stack more than once (it is there at
1604 // least once because of the _current_ call, which is always the last
1605 // element on the stack), we've hit the recursive case. Just return
1606 // SimpleAType here to create a non-recursive type as a result.
1607 if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(),
1608 Element: type)) {
1609 results.push_back(test::SimpleAType::get(type.getContext()));
1610 return success();
1611 }
1612
1613 // Convert the body recursively.
1614 auto result = test::TestRecursiveType::get(ctx: type.getContext(),
1615 name: "outer_converted_type");
1616 if (failed(result.setBody(converter.convertType(t: type.getBody()))))
1617 return failure();
1618 results.push_back(Elt: result);
1619 return success();
1620 });
1621
1622 /// Add the legal set of type materializations.
1623 converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType,
1624 ValueRange inputs,
1625 Location loc) -> Value {
1626 // Allow casting from F64 back to F32.
1627 if (!resultType.isF16() && inputs.size() == 1 &&
1628 inputs[0].getType().isF64())
1629 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1630 // Allow producing an i32 or i64 from nothing.
1631 if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
1632 inputs.empty())
1633 return builder.create<TestTypeProducerOp>(loc, resultType);
1634 // Allow producing an i64 from an integer.
1635 if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
1636 isa<IntegerType>(inputs[0].getType()))
1637 return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
1638 // Otherwise, fail.
1639 return nullptr;
1640 });
1641
1642 // Initialize the conversion target.
1643 mlir::ConversionTarget target(getContext());
1644 target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1645 auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
1646 return op.getType().isF64() || op.getType().isInteger(64) ||
1647 (recursiveType &&
1648 recursiveType.getName() == "outer_converted_type");
1649 });
1650 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1651 return converter.isSignatureLegal(op.getFunctionType()) &&
1652 converter.isLegal(&op.getBody());
1653 });
1654 target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
1655 // Allow casts from F64 to F32.
1656 return (*op.operand_type_begin()).isF64() && op.getType().isF32();
1657 });
1658 target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>(
1659 [&](TestSignatureConversionNoConverterOp op) {
1660 return converter.isLegal(op.getRegion().front().getArgumentTypes());
1661 });
1662
1663 // Initialize the set of rewrite patterns.
1664 RewritePatternSet patterns(&getContext());
1665 patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1666 TestSignatureConversionUndo,
1667 TestTestSignatureConversionNoConverter>(arg&: converter,
1668 args: &getContext());
1669 patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext());
1670 mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
1671 converter);
1672
1673 if (failed(applyPartialConversion(getOperation(), target,
1674 std::move(patterns))))
1675 signalPassFailure();
1676 }
1677};
1678} // namespace
1679
1680//===----------------------------------------------------------------------===//
1681// Test Target Materialization With No Uses
1682//===----------------------------------------------------------------------===//
1683
1684namespace {
1685struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> {
1686 using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern;
1687
1688 LogicalResult
1689 matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor,
1690 ConversionPatternRewriter &rewriter) const final {
1691 rewriter.replaceOp(op, adaptor.getOperands());
1692 return success();
1693 }
1694};
1695
1696struct TestTargetMaterializationWithNoUses
1697 : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> {
1698 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1699 TestTargetMaterializationWithNoUses)
1700
1701 StringRef getArgument() const final {
1702 return "test-target-materialization-with-no-uses";
1703 }
1704 StringRef getDescription() const final {
1705 return "Test a special case of target materialization in DialectConversion";
1706 }
1707
1708 void runOnOperation() override {
1709 TypeConverter converter;
1710 converter.addConversion(callback: [](Type t) { return t; });
1711 converter.addConversion(callback: [](IntegerType intTy) -> Type {
1712 if (intTy.getWidth() == 16)
1713 return IntegerType::get(intTy.getContext(), 64);
1714 return intTy;
1715 });
1716 converter.addTargetMaterialization(
1717 callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1718 return builder.create<TestCastOp>(loc, type, inputs).getResult();
1719 });
1720
1721 ConversionTarget target(getContext());
1722 target.addIllegalOp<TestTypeChangerOp>();
1723
1724 RewritePatternSet patterns(&getContext());
1725 patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext());
1726
1727 if (failed(applyPartialConversion(getOperation(), target,
1728 std::move(patterns))))
1729 signalPassFailure();
1730 }
1731};
1732} // namespace
1733
1734//===----------------------------------------------------------------------===//
1735// Test Block Merging
1736//===----------------------------------------------------------------------===//
1737
1738namespace {
1739/// A rewriter pattern that tests that blocks can be merged.
1740struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
1741 using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
1742
1743 LogicalResult
1744 matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor,
1745 ConversionPatternRewriter &rewriter) const final {
1746 Block &firstBlock = op.getBody().front();
1747 Operation *branchOp = firstBlock.getTerminator();
1748 Block *secondBlock = &*(std::next(op.getBody().begin()));
1749 auto succOperands = branchOp->getOperands();
1750 SmallVector<Value, 2> replacements(succOperands);
1751 rewriter.eraseOp(op: branchOp);
1752 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
1753 rewriter.modifyOpInPlace(op, [] {});
1754 return success();
1755 }
1756};
1757
1758/// A rewrite pattern to tests the undo mechanism of blocks being merged.
1759struct TestUndoBlocksMerge : public ConversionPattern {
1760 TestUndoBlocksMerge(MLIRContext *ctx)
1761 : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
1762 LogicalResult
1763 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1764 ConversionPatternRewriter &rewriter) const final {
1765 Block &firstBlock = op->getRegion(index: 0).front();
1766 Operation *branchOp = firstBlock.getTerminator();
1767 Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin()));
1768 rewriter.setInsertionPointToStart(secondBlock);
1769 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
1770 auto succOperands = branchOp->getOperands();
1771 SmallVector<Value, 2> replacements(succOperands);
1772 rewriter.eraseOp(op: branchOp);
1773 rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements);
1774 rewriter.modifyOpInPlace(root: op, callable: [] {});
1775 return success();
1776 }
1777};
1778
1779/// A rewrite mechanism to inline the body of the op into its parent, when both
1780/// ops can have a single block.
1781struct TestMergeSingleBlockOps
1782 : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
1783 using OpConversionPattern<
1784 SingleBlockImplicitTerminatorOp>::OpConversionPattern;
1785
1786 LogicalResult
1787 matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor,
1788 ConversionPatternRewriter &rewriter) const final {
1789 SingleBlockImplicitTerminatorOp parentOp =
1790 op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1791 if (!parentOp)
1792 return failure();
1793 Block &innerBlock = op.getRegion().front();
1794 TerminatorOp innerTerminator =
1795 cast<TerminatorOp>(innerBlock.getTerminator());
1796 rewriter.inlineBlockBefore(&innerBlock, op);
1797 rewriter.eraseOp(op: innerTerminator);
1798 rewriter.eraseOp(op: op);
1799 return success();
1800 }
1801};
1802
1803struct TestMergeBlocksPatternDriver
1804 : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> {
1805 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver)
1806
1807 StringRef getArgument() const final { return "test-merge-blocks"; }
1808 StringRef getDescription() const final {
1809 return "Test Merging operation in ConversionPatternRewriter";
1810 }
1811 void runOnOperation() override {
1812 MLIRContext *context = &getContext();
1813 mlir::RewritePatternSet patterns(context);
1814 patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
1815 arg&: context);
1816 ConversionTarget target(*context);
1817 target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
1818 TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
1819 target.addIllegalOp<ILLegalOpF>();
1820
1821 /// Expect the op to have a single block after legalization.
1822 target.addDynamicallyLegalOp<TestMergeBlocksOp>(
1823 [&](TestMergeBlocksOp op) -> bool {
1824 return llvm::hasSingleElement(op.getBody());
1825 });
1826
1827 /// Only allow `test.br` within test.merge_blocks op.
1828 target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
1829 return op->getParentOfType<TestMergeBlocksOp>();
1830 });
1831
1832 /// Expect that all nested test.SingleBlockImplicitTerminator ops are
1833 /// inlined.
1834 target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
1835 [&](SingleBlockImplicitTerminatorOp op) -> bool {
1836 return !op->getParentOfType<SingleBlockImplicitTerminatorOp>();
1837 });
1838
1839 DenseSet<Operation *> unlegalizedOps;
1840 ConversionConfig config;
1841 config.unlegalizedOps = &unlegalizedOps;
1842 (void)applyPartialConversion(getOperation(), target, std::move(patterns),
1843 config);
1844 for (auto *op : unlegalizedOps)
1845 op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
1846 }
1847};
1848} // namespace
1849
1850//===----------------------------------------------------------------------===//
1851// Test Selective Replacement
1852//===----------------------------------------------------------------------===//
1853
1854namespace {
1855/// A rewrite mechanism to inline the body of the op into its parent, when both
1856/// ops can have a single block.
1857struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
1858 using OpRewritePattern<TestCastOp>::OpRewritePattern;
1859
1860 LogicalResult matchAndRewrite(TestCastOp op,
1861 PatternRewriter &rewriter) const final {
1862 if (op.getNumOperands() != 2)
1863 return failure();
1864 OperandRange operands = op.getOperands();
1865
1866 // Replace non-terminator uses with the first operand.
1867 rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
1868 return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
1869 });
1870 // Replace everything else with the second operand if the operation isn't
1871 // dead.
1872 rewriter.replaceOp(op, op.getOperand(1));
1873 return success();
1874 }
1875};
1876
1877struct TestSelectiveReplacementPatternDriver
1878 : public PassWrapper<TestSelectiveReplacementPatternDriver,
1879 OperationPass<>> {
1880 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
1881 TestSelectiveReplacementPatternDriver)
1882
1883 StringRef getArgument() const final {
1884 return "test-pattern-selective-replacement";
1885 }
1886 StringRef getDescription() const final {
1887 return "Test selective replacement in the PatternRewriter";
1888 }
1889 void runOnOperation() override {
1890 MLIRContext *context = &getContext();
1891 mlir::RewritePatternSet patterns(context);
1892 patterns.add<TestSelectiveOpReplacementPattern>(arg&: context);
1893 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1894 }
1895};
1896} // namespace
1897
1898//===----------------------------------------------------------------------===//
1899// PassRegistration
1900//===----------------------------------------------------------------------===//
1901
1902namespace mlir {
1903namespace test {
1904void registerPatternsTestPass() {
1905 PassRegistration<TestReturnTypeDriver>();
1906
1907 PassRegistration<TestDerivedAttributeDriver>();
1908
1909 PassRegistration<TestPatternDriver>();
1910 PassRegistration<TestStrictPatternDriver>();
1911
1912 PassRegistration<TestLegalizePatternDriver>([] {
1913 return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode);
1914 });
1915
1916 PassRegistration<TestRemappedValue>();
1917
1918 PassRegistration<TestUnknownRootOpDriver>();
1919
1920 PassRegistration<TestTypeConversionDriver>();
1921 PassRegistration<TestTargetMaterializationWithNoUses>();
1922
1923 PassRegistration<TestRewriteDynamicOpDriver>();
1924
1925 PassRegistration<TestMergeBlocksPatternDriver>();
1926 PassRegistration<TestSelectiveReplacementPatternDriver>();
1927}
1928} // namespace test
1929} // namespace mlir
1930

source code of mlir/test/lib/Dialect/Test/TestPatterns.cpp