1 | //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "TestTypes.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Matchers.h" |
18 | #include "mlir/IR/Visitors.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | #include "mlir/Transforms/FoldUtils.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | #include "mlir/Transforms/WalkPatternRewriteDriver.h" |
24 | #include "llvm/ADT/ScopeExit.h" |
25 | #include <cstdint> |
26 | |
27 | using namespace mlir; |
28 | using namespace test; |
29 | |
30 | // Native function for testing NativeCodeCall |
31 | static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { |
32 | return choice.getValue() ? input1 : input2; |
33 | } |
34 | |
35 | static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { |
36 | rewriter.create<OpI>(loc, input); |
37 | } |
38 | |
39 | static void handleNoResultOp(PatternRewriter &rewriter, |
40 | OpSymbolBindingNoResult op) { |
41 | // Turn the no result op to a one-result op. |
42 | rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), |
43 | op.getOperand()); |
44 | } |
45 | |
46 | static bool getFirstI32Result(Operation *op, Value &value) { |
47 | if (!Type(op->getResult(idx: 0).getType()).isSignlessInteger(width: 32)) |
48 | return false; |
49 | value = op->getResult(idx: 0); |
50 | return true; |
51 | } |
52 | |
53 | static Value bindNativeCodeCallResult(Value value) { return value; } |
54 | |
55 | static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, |
56 | Value input2) { |
57 | return SmallVector<Value, 2>({input2, input1}); |
58 | } |
59 | |
60 | // Test that natives calls are only called once during rewrites. |
61 | // OpM_Test will return Pi, increased by 1 for each subsequent calls. |
62 | // This let us check the number of times OpM_Test was called by inspecting |
63 | // the returned value in the MLIR output. |
64 | static int64_t opMIncreasingValue = 314159265; |
65 | static Attribute opMTest(PatternRewriter &rewriter, Value val) { |
66 | int64_t i = opMIncreasingValue++; |
67 | return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); |
68 | } |
69 | |
70 | namespace { |
71 | #include "TestPatterns.inc" |
72 | } // namespace |
73 | |
74 | //===----------------------------------------------------------------------===// |
75 | // Test Reduce Pattern Interface |
76 | //===----------------------------------------------------------------------===// |
77 | |
78 | void test::populateTestReductionPatterns(RewritePatternSet &patterns) { |
79 | populateWithGenerated(patterns); |
80 | } |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // Canonicalizer Driver. |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | namespace { |
87 | struct FoldingPattern : public RewritePattern { |
88 | public: |
89 | FoldingPattern(MLIRContext *context) |
90 | : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), |
91 | /*benefit=*/1, context) {} |
92 | |
93 | LogicalResult matchAndRewrite(Operation *op, |
94 | PatternRewriter &rewriter) const override { |
95 | // Exercise createOrFold API for a single-result operation that is folded |
96 | // upon construction. The operation being created has an in-place folder, |
97 | // and it should be still present in the output. Furthermore, the folder |
98 | // should not crash when attempting to recover the (unchanged) operation |
99 | // result. |
100 | Value result = rewriter.createOrFold<TestOpInPlaceFold>( |
101 | op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); |
102 | assert(result); |
103 | rewriter.replaceOp(op, newValues: result); |
104 | return success(); |
105 | } |
106 | }; |
107 | |
108 | /// This pattern creates a foldable operation at the entry point of the block. |
109 | /// This tests the situation where the operation folder will need to replace an |
110 | /// operation with a previously created constant that does not initially |
111 | /// dominate the operation to replace. |
112 | struct FolderInsertBeforePreviouslyFoldedConstantPattern |
113 | : public OpRewritePattern<TestCastOp> { |
114 | public: |
115 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
116 | |
117 | LogicalResult matchAndRewrite(TestCastOp op, |
118 | PatternRewriter &rewriter) const override { |
119 | if (!op->hasAttr("test_fold_before_previously_folded_op")) |
120 | return failure(); |
121 | rewriter.setInsertionPointToStart(op->getBlock()); |
122 | |
123 | auto constOp = rewriter.create<arith::ConstantOp>( |
124 | op.getLoc(), rewriter.getBoolAttr(true)); |
125 | rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), |
126 | Value(constOp)); |
127 | return success(); |
128 | } |
129 | }; |
130 | |
131 | /// This pattern matches test.op_commutative2 with the first operand being |
132 | /// another test.op_commutative2 with a constant on the right side and fold it |
133 | /// away by propagating it as its result. This is intend to check that patterns |
134 | /// are applied after the commutative property moves constant to the right. |
135 | struct FolderCommutativeOp2WithConstant |
136 | : public OpRewritePattern<TestCommutative2Op> { |
137 | public: |
138 | using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; |
139 | |
140 | LogicalResult matchAndRewrite(TestCommutative2Op op, |
141 | PatternRewriter &rewriter) const override { |
142 | auto operand = |
143 | dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); |
144 | if (!operand) |
145 | return failure(); |
146 | Attribute constInput; |
147 | if (!matchPattern(operand->getOperand(1), m_Constant(bind_value: &constInput))) |
148 | return failure(); |
149 | rewriter.replaceOp(op, operand->getOperand(1)); |
150 | return success(); |
151 | } |
152 | }; |
153 | |
154 | /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer |
155 | /// attribute with value smaller than MaxVal, it increments the value by 1. |
156 | template <int MaxVal> |
157 | struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { |
158 | using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; |
159 | |
160 | LogicalResult matchAndRewrite(AnyAttrOfOp op, |
161 | PatternRewriter &rewriter) const override { |
162 | auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); |
163 | if (!intAttr) |
164 | return failure(); |
165 | int64_t val = intAttr.getInt(); |
166 | if (val >= MaxVal) |
167 | return failure(); |
168 | rewriter.modifyOpInPlace( |
169 | op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); |
170 | return success(); |
171 | } |
172 | }; |
173 | |
174 | /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". |
175 | struct MakeOpEligible : public RewritePattern { |
176 | MakeOpEligible(MLIRContext *context) |
177 | : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} |
178 | |
179 | LogicalResult matchAndRewrite(Operation *op, |
180 | PatternRewriter &rewriter) const override { |
181 | if (op->hasAttr(name: "eligible")) |
182 | return failure(); |
183 | rewriter.modifyOpInPlace( |
184 | root: op, callable: [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); |
185 | return success(); |
186 | } |
187 | }; |
188 | |
189 | /// This pattern hoists eligible ops out of a "test.one_region_op". |
190 | struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { |
191 | using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; |
192 | |
193 | LogicalResult matchAndRewrite(test::OneRegionOp op, |
194 | PatternRewriter &rewriter) const override { |
195 | Operation *terminator = op.getRegion().front().getTerminator(); |
196 | Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); |
197 | if (toBeHoisted->getParentOp() != op) |
198 | return failure(); |
199 | if (!toBeHoisted->hasAttr(name: "eligible")) |
200 | return failure(); |
201 | rewriter.moveOpBefore(toBeHoisted, op); |
202 | return success(); |
203 | } |
204 | }; |
205 | |
206 | /// This pattern moves "test.move_before_parent_op" before the parent op. |
207 | struct MoveBeforeParentOp : public RewritePattern { |
208 | MoveBeforeParentOp(MLIRContext *context) |
209 | : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} |
210 | |
211 | LogicalResult matchAndRewrite(Operation *op, |
212 | PatternRewriter &rewriter) const override { |
213 | // Do not hoist past functions. |
214 | if (isa<FunctionOpInterface>(Val: op->getParentOp())) |
215 | return failure(); |
216 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
217 | return success(); |
218 | } |
219 | }; |
220 | |
221 | /// This pattern moves "test.move_after_parent_op" after the parent op. |
222 | struct MoveAfterParentOp : public RewritePattern { |
223 | MoveAfterParentOp(MLIRContext *context) |
224 | : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} |
225 | |
226 | LogicalResult matchAndRewrite(Operation *op, |
227 | PatternRewriter &rewriter) const override { |
228 | // Do not hoist past functions. |
229 | if (isa<FunctionOpInterface>(Val: op->getParentOp())) |
230 | return failure(); |
231 | |
232 | int64_t moveForwardBy = 0; |
233 | if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) |
234 | moveForwardBy = advanceBy.getInt(); |
235 | |
236 | Operation *moveAfter = op->getParentOp(); |
237 | for (int64_t i = 0; i < moveForwardBy; ++i) |
238 | moveAfter = moveAfter->getNextNode(); |
239 | |
240 | rewriter.moveOpAfter(op, existingOp: moveAfter); |
241 | return success(); |
242 | } |
243 | }; |
244 | |
245 | /// This pattern inlines blocks that are nested in |
246 | /// "test.inline_blocks_into_parent" into the parent block. |
247 | struct InlineBlocksIntoParent : public RewritePattern { |
248 | InlineBlocksIntoParent(MLIRContext *context) |
249 | : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, |
250 | context) {} |
251 | |
252 | LogicalResult matchAndRewrite(Operation *op, |
253 | PatternRewriter &rewriter) const override { |
254 | bool changed = false; |
255 | for (Region &r : op->getRegions()) { |
256 | while (!r.empty()) { |
257 | rewriter.inlineBlockBefore(source: &r.front(), op); |
258 | changed = true; |
259 | } |
260 | } |
261 | return success(IsSuccess: changed); |
262 | } |
263 | }; |
264 | |
265 | /// This pattern splits blocks at "test.split_block_here" and replaces the op |
266 | /// with a new op (to prevent an infinite loop of block splitting). |
267 | struct SplitBlockHere : public RewritePattern { |
268 | SplitBlockHere(MLIRContext *context) |
269 | : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} |
270 | |
271 | LogicalResult matchAndRewrite(Operation *op, |
272 | PatternRewriter &rewriter) const override { |
273 | rewriter.splitBlock(block: op->getBlock(), before: op->getIterator()); |
274 | Operation *newOp = rewriter.create( |
275 | op->getLoc(), |
276 | OperationName("test.new_op", op->getContext()).getIdentifier(), |
277 | op->getOperands(), op->getResultTypes()); |
278 | rewriter.replaceOp(op, newOp); |
279 | return success(); |
280 | } |
281 | }; |
282 | |
283 | /// This pattern clones "test.clone_me" ops. |
284 | struct CloneOp : public RewritePattern { |
285 | CloneOp(MLIRContext *context) |
286 | : RewritePattern("test.clone_me", /*benefit=*/1, context) {} |
287 | |
288 | LogicalResult matchAndRewrite(Operation *op, |
289 | PatternRewriter &rewriter) const override { |
290 | // Do not clone already cloned ops to avoid going into an infinite loop. |
291 | if (op->hasAttr(name: "was_cloned")) |
292 | return failure(); |
293 | Operation *cloned = rewriter.clone(op&: *op); |
294 | cloned->setAttr("was_cloned", rewriter.getUnitAttr()); |
295 | return success(); |
296 | } |
297 | }; |
298 | |
299 | /// This pattern clones regions of "test.clone_region_before" ops before the |
300 | /// parent block. |
301 | struct CloneRegionBeforeOp : public RewritePattern { |
302 | CloneRegionBeforeOp(MLIRContext *context) |
303 | : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} |
304 | |
305 | LogicalResult matchAndRewrite(Operation *op, |
306 | PatternRewriter &rewriter) const override { |
307 | // Do not clone already cloned ops to avoid going into an infinite loop. |
308 | if (op->hasAttr(name: "was_cloned")) |
309 | return failure(); |
310 | for (Region &r : op->getRegions()) |
311 | rewriter.cloneRegionBefore(region&: r, before: op->getBlock()); |
312 | op->setAttr("was_cloned", rewriter.getUnitAttr()); |
313 | return success(); |
314 | } |
315 | }; |
316 | |
317 | /// Replace an operation may introduce the re-visiting of its users. |
318 | class ReplaceWithNewOp : public RewritePattern { |
319 | public: |
320 | ReplaceWithNewOp(MLIRContext *context) |
321 | : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} |
322 | |
323 | LogicalResult matchAndRewrite(Operation *op, |
324 | PatternRewriter &rewriter) const override { |
325 | Operation *newOp; |
326 | if (op->hasAttr(name: "create_erase_op")) { |
327 | newOp = rewriter.create( |
328 | op->getLoc(), |
329 | OperationName("test.erase_op", op->getContext()).getIdentifier(), |
330 | ValueRange(), TypeRange()); |
331 | } else { |
332 | newOp = rewriter.create( |
333 | op->getLoc(), |
334 | OperationName("test.new_op", op->getContext()).getIdentifier(), |
335 | op->getOperands(), op->getResultTypes()); |
336 | } |
337 | // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". |
338 | // A "notifyOperationReplaced" callback is triggered in either case. |
339 | rewriter.replaceAllOpUsesWith(from: op, to: newOp->getResults()); |
340 | rewriter.eraseOp(op); |
341 | return success(); |
342 | } |
343 | }; |
344 | |
345 | /// Erases the first child block of the matched "test.erase_first_block" |
346 | /// operation. |
347 | class EraseFirstBlock : public RewritePattern { |
348 | public: |
349 | EraseFirstBlock(MLIRContext *context) |
350 | : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} |
351 | |
352 | LogicalResult matchAndRewrite(Operation *op, |
353 | PatternRewriter &rewriter) const override { |
354 | for (Region &r : op->getRegions()) { |
355 | for (Block &b : r.getBlocks()) { |
356 | rewriter.eraseBlock(block: &b); |
357 | return success(); |
358 | } |
359 | } |
360 | |
361 | return failure(); |
362 | } |
363 | }; |
364 | |
365 | struct TestGreedyPatternDriver |
366 | : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { |
367 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) |
368 | |
369 | TestGreedyPatternDriver() = default; |
370 | TestGreedyPatternDriver(const TestGreedyPatternDriver &other) |
371 | : PassWrapper(other) {} |
372 | |
373 | StringRef getArgument() const final { return "test-greedy-patterns"; } |
374 | StringRef getDescription() const final { return "Run test dialect patterns"; } |
375 | void runOnOperation() override { |
376 | mlir::RewritePatternSet patterns(&getContext()); |
377 | populateWithGenerated(patterns); |
378 | |
379 | // Verify named pattern is generated with expected name. |
380 | patterns.add<FoldingPattern, TestNamedPatternRule, |
381 | FolderInsertBeforePreviouslyFoldedConstantPattern, |
382 | FolderCommutativeOp2WithConstant, HoistEligibleOps, |
383 | MakeOpEligible>(&getContext()); |
384 | |
385 | // Additional patterns for testing the GreedyPatternRewriteDriver. |
386 | patterns.insert<IncrementIntAttribute<3>>(arg: &getContext()); |
387 | |
388 | GreedyRewriteConfig config; |
389 | config.setUseTopDownTraversal(useTopDownTraversal) |
390 | .setMaxIterations(this->maxIterations) |
391 | .enableFolding(enable: this->fold) |
392 | .enableConstantCSE(enable: this->cseConstants); |
393 | (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); |
394 | } |
395 | |
396 | Option<bool> useTopDownTraversal{ |
397 | *this, "top-down", |
398 | llvm::cl::desc("Seed the worklist in general top-down order"), |
399 | llvm::cl::init(Val: GreedyRewriteConfig().getUseTopDownTraversal())}; |
400 | Option<int> maxIterations{ |
401 | *this, "max-iterations", |
402 | llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), |
403 | llvm::cl::init(Val: GreedyRewriteConfig().getMaxIterations())}; |
404 | Option<bool> fold{*this, "fold", llvm::cl::desc( "Whether to fold"), |
405 | llvm::cl::init(Val: GreedyRewriteConfig().isFoldingEnabled())}; |
406 | Option<bool> cseConstants{ |
407 | *this, "cse-constants", llvm::cl::desc( "Whether to CSE constants"), |
408 | llvm::cl::init(Val: GreedyRewriteConfig().isConstantCSEEnabled())}; |
409 | }; |
410 | |
411 | struct DumpNotifications : public RewriterBase::Listener { |
412 | void notifyBlockInserted(Block *block, Region *previous, |
413 | Region::iterator previousIt) override { |
414 | llvm::outs() << "notifyBlockInserted"; |
415 | if (block->getParentOp()) { |
416 | llvm::outs() << " into "<< block->getParentOp()->getName() << ": "; |
417 | } else { |
418 | llvm::outs() << " into unknown op: "; |
419 | } |
420 | if (previous == nullptr) { |
421 | llvm::outs() << "was unlinked\n"; |
422 | } else { |
423 | llvm::outs() << "was linked\n"; |
424 | } |
425 | } |
426 | void notifyOperationInserted(Operation *op, |
427 | OpBuilder::InsertPoint previous) override { |
428 | llvm::outs() << "notifyOperationInserted: "<< op->getName(); |
429 | if (!previous.isSet()) { |
430 | llvm::outs() << ", was unlinked\n"; |
431 | } else { |
432 | if (!previous.getPoint().getNodePtr()) { |
433 | llvm::outs() << ", was linked, exact position unknown\n"; |
434 | } else if (previous.getPoint() == previous.getBlock()->end()) { |
435 | llvm::outs() << ", was last in block\n"; |
436 | } else { |
437 | llvm::outs() << ", previous = "<< previous.getPoint()->getName() |
438 | << "\n"; |
439 | } |
440 | } |
441 | } |
442 | void notifyBlockErased(Block *block) override { |
443 | llvm::outs() << "notifyBlockErased\n"; |
444 | } |
445 | void notifyOperationErased(Operation *op) override { |
446 | llvm::outs() << "notifyOperationErased: "<< op->getName() << "\n"; |
447 | } |
448 | void notifyOperationModified(Operation *op) override { |
449 | llvm::outs() << "notifyOperationModified: "<< op->getName() << "\n"; |
450 | } |
451 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
452 | llvm::outs() << "notifyOperationReplaced: "<< op->getName() << "\n"; |
453 | } |
454 | }; |
455 | |
456 | struct TestStrictPatternDriver |
457 | : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { |
458 | public: |
459 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) |
460 | |
461 | TestStrictPatternDriver() = default; |
462 | TestStrictPatternDriver(const TestStrictPatternDriver &other) |
463 | : PassWrapper(other) { |
464 | strictMode = other.strictMode; |
465 | } |
466 | |
467 | StringRef getArgument() const final { return "test-strict-pattern-driver"; } |
468 | StringRef getDescription() const final { |
469 | return "Test strict mode of pattern driver"; |
470 | } |
471 | |
472 | void runOnOperation() override { |
473 | MLIRContext *ctx = &getContext(); |
474 | mlir::RewritePatternSet patterns(ctx); |
475 | patterns.add< |
476 | // clang-format off |
477 | ChangeBlockOp, |
478 | CloneOp, |
479 | CloneRegionBeforeOp, |
480 | EraseOp, |
481 | ImplicitChangeOp, |
482 | InlineBlocksIntoParent, |
483 | InsertSameOp, |
484 | MoveBeforeParentOp, |
485 | ReplaceWithNewOp, |
486 | SplitBlockHere |
487 | // clang-format on |
488 | >(arg&: ctx); |
489 | SmallVector<Operation *> ops; |
490 | getOperation()->walk([&](Operation *op) { |
491 | StringRef opName = op->getName().getStringRef(); |
492 | if (opName == "test.insert_same_op"|| opName == "test.change_block_op"|| |
493 | opName == "test.replace_with_new_op"|| opName == "test.erase_op"|| |
494 | opName == "test.move_before_parent_op"|| |
495 | opName == "test.inline_blocks_into_parent"|| |
496 | opName == "test.split_block_here"|| opName == "test.clone_me"|| |
497 | opName == "test.clone_region_before") { |
498 | ops.push_back(Elt: op); |
499 | } |
500 | }); |
501 | |
502 | DumpNotifications dumpNotifications; |
503 | GreedyRewriteConfig config; |
504 | config.setListener(&dumpNotifications); |
505 | if (strictMode == "AnyOp") { |
506 | config.setStrictness(GreedyRewriteStrictness::AnyOp); |
507 | } else if (strictMode == "ExistingAndNewOps") { |
508 | config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps); |
509 | } else if (strictMode == "ExistingOps") { |
510 | config.setStrictness(GreedyRewriteStrictness::ExistingOps); |
511 | } else { |
512 | llvm_unreachable("invalid strictness option"); |
513 | } |
514 | |
515 | // Check if these transformations introduce visiting of operations that |
516 | // are not in the `ops` set (The new created ops are valid). An invalid |
517 | // operation will trigger the assertion while processing. |
518 | bool changed = false; |
519 | bool allErased = false; |
520 | (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, |
521 | &changed, &allErased); |
522 | Builder b(ctx); |
523 | getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(value: changed)); |
524 | getOperation()->setAttr("pattern_driver_all_erased", |
525 | b.getBoolAttr(value: allErased)); |
526 | } |
527 | |
528 | Option<std::string> strictMode{ |
529 | *this, "strictness", |
530 | llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), |
531 | llvm::cl::init("AnyOp")}; |
532 | |
533 | private: |
534 | // New inserted operation is valid for further transformation. |
535 | class InsertSameOp : public RewritePattern { |
536 | public: |
537 | InsertSameOp(MLIRContext *context) |
538 | : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} |
539 | |
540 | LogicalResult matchAndRewrite(Operation *op, |
541 | PatternRewriter &rewriter) const override { |
542 | if (op->hasAttr(name: "skip")) |
543 | return failure(); |
544 | |
545 | Operation *newOp = |
546 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), |
547 | op->getOperands(), op->getResultTypes()); |
548 | rewriter.modifyOpInPlace( |
549 | root: op, callable: [&]() { op->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true)); }); |
550 | newOp->setAttr(name: "skip", value: rewriter.getBoolAttr(value: true)); |
551 | |
552 | return success(); |
553 | } |
554 | }; |
555 | |
556 | // Remove an operation may introduce the re-visiting of its operands. |
557 | class EraseOp : public RewritePattern { |
558 | public: |
559 | EraseOp(MLIRContext *context) |
560 | : RewritePattern("test.erase_op", /*benefit=*/1, context) {} |
561 | LogicalResult matchAndRewrite(Operation *op, |
562 | PatternRewriter &rewriter) const override { |
563 | rewriter.eraseOp(op); |
564 | return success(); |
565 | } |
566 | }; |
567 | |
568 | // The following two patterns test RewriterBase::replaceAllUsesWith. |
569 | // |
570 | // That function replaces all usages of a Block (or a Value) with another one |
571 | // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver |
572 | // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its |
573 | // worklist: when an op is modified, it is added to the worklist. The two |
574 | // patterns below make the tracking observable: ChangeBlockOp replaces all |
575 | // usages of a block and that pattern is applied because the corresponding ops |
576 | // are put on the initial worklist (see above). ImplicitChangeOp does an |
577 | // unrelated change but ops of the corresponding type are *not* on the initial |
578 | // worklist, so the effect of the second pattern is only visible if the |
579 | // tracking and subsequent adding to the worklist actually works. |
580 | |
581 | // Replace all usages of the first successor with the second successor. |
582 | class ChangeBlockOp : public RewritePattern { |
583 | public: |
584 | ChangeBlockOp(MLIRContext *context) |
585 | : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} |
586 | LogicalResult matchAndRewrite(Operation *op, |
587 | PatternRewriter &rewriter) const override { |
588 | if (op->getNumSuccessors() < 2) |
589 | return failure(); |
590 | Block *firstSuccessor = op->getSuccessor(index: 0); |
591 | Block *secondSuccessor = op->getSuccessor(index: 1); |
592 | if (firstSuccessor == secondSuccessor) |
593 | return failure(); |
594 | // This is the function being tested: |
595 | rewriter.replaceAllUsesWith(from: firstSuccessor, to: secondSuccessor); |
596 | // Using the following line instead would make the test fail: |
597 | // firstSuccessor->replaceAllUsesWith(secondSuccessor); |
598 | return success(); |
599 | } |
600 | }; |
601 | |
602 | // Changes the successor to the parent block. |
603 | class ImplicitChangeOp : public RewritePattern { |
604 | public: |
605 | ImplicitChangeOp(MLIRContext *context) |
606 | : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} |
607 | LogicalResult matchAndRewrite(Operation *op, |
608 | PatternRewriter &rewriter) const override { |
609 | if (op->getNumSuccessors() < 1 || op->getSuccessor(index: 0) == op->getBlock()) |
610 | return failure(); |
611 | rewriter.modifyOpInPlace(root: op, |
612 | callable: [&]() { op->setSuccessor(block: op->getBlock(), index: 0); }); |
613 | return success(); |
614 | } |
615 | }; |
616 | }; |
617 | |
618 | struct TestWalkPatternDriver final |
619 | : PassWrapper<TestWalkPatternDriver, OperationPass<>> { |
620 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) |
621 | |
622 | TestWalkPatternDriver() = default; |
623 | TestWalkPatternDriver(const TestWalkPatternDriver &other) |
624 | : PassWrapper(other) {} |
625 | |
626 | StringRef getArgument() const override { |
627 | return "test-walk-pattern-rewrite-driver"; |
628 | } |
629 | StringRef getDescription() const override { |
630 | return "Run test walk pattern rewrite driver"; |
631 | } |
632 | void runOnOperation() override { |
633 | mlir::RewritePatternSet patterns(&getContext()); |
634 | |
635 | // Patterns for testing the WalkPatternRewriteDriver. |
636 | patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, |
637 | MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( |
638 | arg: &getContext()); |
639 | |
640 | DumpNotifications dumpListener; |
641 | walkAndApplyPatterns(getOperation(), std::move(patterns), |
642 | dumpNotifications ? &dumpListener : nullptr); |
643 | } |
644 | |
645 | Option<bool> dumpNotifications{ |
646 | *this, "dump-notifications", |
647 | llvm::cl::desc("Print rewrite listener notifications"), |
648 | llvm::cl::init(Val: false)}; |
649 | }; |
650 | |
651 | } // namespace |
652 | |
653 | //===----------------------------------------------------------------------===// |
654 | // ReturnType Driver. |
655 | //===----------------------------------------------------------------------===// |
656 | |
657 | namespace { |
658 | // Generate ops for each instance where the type can be successfully inferred. |
659 | template <typename OpTy> |
660 | static void invokeCreateWithInferredReturnType(Operation *op) { |
661 | auto *context = op->getContext(); |
662 | auto fop = op->getParentOfType<func::FuncOp>(); |
663 | auto location = UnknownLoc::get(context); |
664 | OpBuilder b(op); |
665 | b.setInsertionPointAfter(op); |
666 | |
667 | // Use permutations of 2 args as operands. |
668 | assert(fop.getNumArguments() >= 2); |
669 | for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { |
670 | for (int j = 0; j < e; ++j) { |
671 | std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; |
672 | SmallVector<Type, 2> inferredReturnTypes; |
673 | if (succeeded(OpTy::inferReturnTypes( |
674 | context, std::nullopt, values, op->getDiscardableAttrDictionary(), |
675 | op->getPropertiesStorage(), op->getRegions(), |
676 | inferredReturnTypes))) { |
677 | OperationState state(location, OpTy::getOperationName()); |
678 | // TODO: Expand to regions. |
679 | OpTy::build(b, state, values, op->getAttrs()); |
680 | (void)b.create(state); |
681 | } |
682 | } |
683 | } |
684 | } |
685 | |
686 | static void reifyReturnShape(Operation *op) { |
687 | OpBuilder b(op); |
688 | |
689 | // Use permutations of 2 args as operands. |
690 | auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); |
691 | SmallVector<Value, 2> shapes; |
692 | if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || |
693 | !llvm::hasSingleElement(C&: shapes)) |
694 | return; |
695 | for (const auto &it : llvm::enumerate(First&: shapes)) { |
696 | op->emitRemark() << "value "<< it.index() << ": " |
697 | << it.value().getDefiningOp(); |
698 | } |
699 | } |
700 | |
701 | struct TestReturnTypeDriver |
702 | : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { |
703 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) |
704 | |
705 | void getDependentDialects(DialectRegistry ®istry) const override { |
706 | registry.insert<tensor::TensorDialect>(); |
707 | } |
708 | StringRef getArgument() const final { return "test-return-type"; } |
709 | StringRef getDescription() const final { return "Run return type functions"; } |
710 | |
711 | void runOnOperation() override { |
712 | if (getOperation().getName() == "testCreateFunctions") { |
713 | std::vector<Operation *> ops; |
714 | // Collect ops to avoid triggering on inserted ops. |
715 | for (auto &op : getOperation().getBody().front()) |
716 | ops.push_back(&op); |
717 | // Generate test patterns for each, but skip terminator. |
718 | for (auto *op : llvm::ArrayRef(ops).drop_back()) { |
719 | // Test create method of each of the Op classes below. The resultant |
720 | // output would be in reverse order underneath `op` from which |
721 | // the attributes and regions are used. |
722 | invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); |
723 | invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( |
724 | op); |
725 | invokeCreateWithInferredReturnType< |
726 | OpWithShapedTypeInferTypeInterfaceOp>(op); |
727 | }; |
728 | return; |
729 | } |
730 | if (getOperation().getName() == "testReifyFunctions") { |
731 | std::vector<Operation *> ops; |
732 | // Collect ops to avoid triggering on inserted ops. |
733 | for (auto &op : getOperation().getBody().front()) |
734 | if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) |
735 | ops.push_back(&op); |
736 | // Generate test patterns for each, but skip terminator. |
737 | for (auto *op : ops) |
738 | reifyReturnShape(op); |
739 | } |
740 | } |
741 | }; |
742 | } // namespace |
743 | |
744 | namespace { |
745 | struct TestDerivedAttributeDriver |
746 | : public PassWrapper<TestDerivedAttributeDriver, |
747 | OperationPass<func::FuncOp>> { |
748 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) |
749 | |
750 | StringRef getArgument() const final { return "test-derived-attr"; } |
751 | StringRef getDescription() const final { |
752 | return "Run test derived attributes"; |
753 | } |
754 | void runOnOperation() override; |
755 | }; |
756 | } // namespace |
757 | |
758 | void TestDerivedAttributeDriver::runOnOperation() { |
759 | getOperation().walk([](DerivedAttributeOpInterface dOp) { |
760 | auto dAttr = dOp.materializeDerivedAttributes(); |
761 | if (!dAttr) |
762 | return; |
763 | for (auto d : dAttr) |
764 | dOp.emitRemark() << d.getName().getValue() << " = "<< d.getValue(); |
765 | }); |
766 | } |
767 | |
768 | //===----------------------------------------------------------------------===// |
769 | // Legalization Driver. |
770 | //===----------------------------------------------------------------------===// |
771 | |
772 | namespace { |
773 | //===----------------------------------------------------------------------===// |
774 | // Region-Block Rewrite Testing |
775 | //===----------------------------------------------------------------------===// |
776 | |
777 | /// This pattern applies a signature conversion to a block inside a detached |
778 | /// region. |
779 | struct TestDetachedSignatureConversion : public ConversionPattern { |
780 | TestDetachedSignatureConversion(MLIRContext *ctx) |
781 | : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, |
782 | ctx) {} |
783 | |
784 | LogicalResult |
785 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
786 | ConversionPatternRewriter &rewriter) const final { |
787 | if (op->getNumRegions() != 1) |
788 | return failure(); |
789 | OperationState state(op->getLoc(), "test.legal_op", operands, |
790 | op->getResultTypes(), {}, BlockRange()); |
791 | Region *newRegion = state.addRegion(); |
792 | rewriter.inlineRegionBefore(region&: op->getRegion(index: 0), parent&: *newRegion, |
793 | before: newRegion->begin()); |
794 | TypeConverter::SignatureConversion result(newRegion->getNumArguments()); |
795 | for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) |
796 | result.addInputs(i, rewriter.getF64Type()); |
797 | rewriter.applySignatureConversion(block: &newRegion->front(), conversion&: result); |
798 | Operation *newOp = rewriter.create(state); |
799 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
800 | return success(); |
801 | } |
802 | }; |
803 | |
804 | /// This pattern is a simple pattern that inlines the first region of a given |
805 | /// operation into the parent region. |
806 | struct TestRegionRewriteBlockMovement : public ConversionPattern { |
807 | TestRegionRewriteBlockMovement(MLIRContext *ctx) |
808 | : ConversionPattern("test.region", 1, ctx) {} |
809 | |
810 | LogicalResult |
811 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
812 | ConversionPatternRewriter &rewriter) const final { |
813 | // Inline this region into the parent region. |
814 | auto &parentRegion = *op->getParentRegion(); |
815 | auto &opRegion = op->getRegion(index: 0); |
816 | if (op->getDiscardableAttr(name: "legalizer.should_clone")) |
817 | rewriter.cloneRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
818 | else |
819 | rewriter.inlineRegionBefore(region&: opRegion, parent&: parentRegion, before: parentRegion.end()); |
820 | |
821 | if (op->getDiscardableAttr(name: "legalizer.erase_old_blocks")) { |
822 | while (!opRegion.empty()) |
823 | rewriter.eraseBlock(block: &opRegion.front()); |
824 | } |
825 | |
826 | // Drop this operation. |
827 | rewriter.eraseOp(op); |
828 | return success(); |
829 | } |
830 | }; |
831 | /// This pattern is a simple pattern that generates a region containing an |
832 | /// illegal operation. |
833 | struct TestRegionRewriteUndo : public RewritePattern { |
834 | TestRegionRewriteUndo(MLIRContext *ctx) |
835 | : RewritePattern("test.region_builder", 1, ctx) {} |
836 | |
837 | LogicalResult matchAndRewrite(Operation *op, |
838 | PatternRewriter &rewriter) const final { |
839 | // Create the region operation with an entry block containing arguments. |
840 | OperationState newRegion(op->getLoc(), "test.region"); |
841 | newRegion.addRegion(); |
842 | auto *regionOp = rewriter.create(state: newRegion); |
843 | auto *entryBlock = rewriter.createBlock(parent: ®ionOp->getRegion(index: 0)); |
844 | entryBlock->addArgument(rewriter.getIntegerType(64), |
845 | rewriter.getUnknownLoc()); |
846 | |
847 | // Add an explicitly illegal operation to ensure the conversion fails. |
848 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); |
849 | rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); |
850 | |
851 | // Drop this operation. |
852 | rewriter.eraseOp(op); |
853 | return success(); |
854 | } |
855 | }; |
856 | /// A simple pattern that creates a block at the end of the parent region of the |
857 | /// matched operation. |
858 | struct TestCreateBlock : public RewritePattern { |
859 | TestCreateBlock(MLIRContext *ctx) |
860 | : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} |
861 | |
862 | LogicalResult matchAndRewrite(Operation *op, |
863 | PatternRewriter &rewriter) const final { |
864 | Region ®ion = *op->getParentRegion(); |
865 | Type i32Type = rewriter.getIntegerType(32); |
866 | Location loc = op->getLoc(); |
867 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
868 | rewriter.create<TerminatorOp>(loc); |
869 | rewriter.eraseOp(op); |
870 | return success(); |
871 | } |
872 | }; |
873 | |
874 | /// A simple pattern that creates a block containing an invalid operation in |
875 | /// order to trigger the block creation undo mechanism. |
876 | struct TestCreateIllegalBlock : public RewritePattern { |
877 | TestCreateIllegalBlock(MLIRContext *ctx) |
878 | : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} |
879 | |
880 | LogicalResult matchAndRewrite(Operation *op, |
881 | PatternRewriter &rewriter) const final { |
882 | Region ®ion = *op->getParentRegion(); |
883 | Type i32Type = rewriter.getIntegerType(32); |
884 | Location loc = op->getLoc(); |
885 | rewriter.createBlock(parent: ®ion, insertPt: region.end(), argTypes: {i32Type, i32Type}, locs: {loc, loc}); |
886 | // Create an illegal op to ensure the conversion fails. |
887 | rewriter.create<ILLegalOpF>(loc, i32Type); |
888 | rewriter.create<TerminatorOp>(loc); |
889 | rewriter.eraseOp(op); |
890 | return success(); |
891 | } |
892 | }; |
893 | |
894 | /// A simple pattern that tests the undo mechanism when replacing the uses of a |
895 | /// block argument. |
896 | struct TestUndoBlockArgReplace : public ConversionPattern { |
897 | TestUndoBlockArgReplace(MLIRContext *ctx) |
898 | : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} |
899 | |
900 | LogicalResult |
901 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
902 | ConversionPatternRewriter &rewriter) const final { |
903 | auto illegalOp = |
904 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
905 | rewriter.replaceUsesOfBlockArgument(from: op->getRegion(index: 0).getArgument(i: 0), |
906 | to: illegalOp->getResult(0)); |
907 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
908 | return success(); |
909 | } |
910 | }; |
911 | |
912 | /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. |
913 | /// This is to test the rollback logic. |
914 | struct TestUndoMoveOpBefore : public ConversionPattern { |
915 | TestUndoMoveOpBefore(MLIRContext *ctx) |
916 | : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} |
917 | |
918 | LogicalResult |
919 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
920 | ConversionPatternRewriter &rewriter) const override { |
921 | rewriter.moveOpBefore(op, existingOp: op->getParentOp()); |
922 | // Replace with an illegal op to ensure the conversion fails. |
923 | rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); |
924 | return success(); |
925 | } |
926 | }; |
927 | |
928 | /// A rewrite pattern that tests the undo mechanism when erasing a block. |
929 | struct TestUndoBlockErase : public ConversionPattern { |
930 | TestUndoBlockErase(MLIRContext *ctx) |
931 | : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} |
932 | |
933 | LogicalResult |
934 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
935 | ConversionPatternRewriter &rewriter) const final { |
936 | Block *secondBlock = &*std::next(x: op->getRegion(index: 0).begin()); |
937 | rewriter.setInsertionPointToStart(secondBlock); |
938 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
939 | rewriter.eraseBlock(block: secondBlock); |
940 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
941 | return success(); |
942 | } |
943 | }; |
944 | |
945 | /// A pattern that modifies a property in-place, but keeps the op illegal. |
946 | struct TestUndoPropertiesModification : public ConversionPattern { |
947 | TestUndoPropertiesModification(MLIRContext *ctx) |
948 | : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} |
949 | LogicalResult |
950 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
951 | ConversionPatternRewriter &rewriter) const final { |
952 | if (!op->hasAttr(name: "modify_inplace")) |
953 | return failure(); |
954 | rewriter.modifyOpInPlace( |
955 | root: op, callable: [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); |
956 | return success(); |
957 | } |
958 | }; |
959 | |
960 | //===----------------------------------------------------------------------===// |
961 | // Type-Conversion Rewrite Testing |
962 | //===----------------------------------------------------------------------===// |
963 | |
964 | /// This patterns erases a region operation that has had a type conversion. |
965 | struct TestDropOpSignatureConversion : public ConversionPattern { |
966 | TestDropOpSignatureConversion(MLIRContext *ctx, |
967 | const TypeConverter &converter) |
968 | : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} |
969 | LogicalResult |
970 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
971 | ConversionPatternRewriter &rewriter) const override { |
972 | Region ®ion = op->getRegion(index: 0); |
973 | Block *entry = ®ion.front(); |
974 | |
975 | // Convert the original entry arguments. |
976 | const TypeConverter &converter = *getTypeConverter(); |
977 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
978 | if (failed(Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), |
979 | result)) || |
980 | failed(Result: rewriter.convertRegionTypes(region: ®ion, converter, entryConversion: &result))) |
981 | return failure(); |
982 | |
983 | // Convert the region signature and just drop the operation. |
984 | rewriter.eraseOp(op); |
985 | return success(); |
986 | } |
987 | }; |
988 | /// This pattern simply updates the operands of the given operation. |
989 | struct TestPassthroughInvalidOp : public ConversionPattern { |
990 | TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) |
991 | : ConversionPattern(converter, "test.invalid", 1, ctx) {} |
992 | LogicalResult |
993 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
994 | ConversionPatternRewriter &rewriter) const final { |
995 | SmallVector<Value> flattened; |
996 | for (auto it : llvm::enumerate(First&: operands)) { |
997 | ValueRange range = it.value(); |
998 | if (range.size() == 1) { |
999 | flattened.push_back(Elt: range.front()); |
1000 | continue; |
1001 | } |
1002 | |
1003 | // This is a 1:N replacement. Insert a test.cast op. (That's what the |
1004 | // argument materialization used to do.) |
1005 | flattened.push_back( |
1006 | rewriter |
1007 | .create<TestCastOp>(op->getLoc(), |
1008 | op->getOperand(it.index()).getType(), range) |
1009 | .getResult()); |
1010 | } |
1011 | rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, |
1012 | std::nullopt); |
1013 | return success(); |
1014 | } |
1015 | }; |
1016 | /// Replace with valid op, but simply drop the operands. This is used in a |
1017 | /// regression where we used to generate circular unrealized_conversion_cast |
1018 | /// ops. |
1019 | struct TestDropAndReplaceInvalidOp : public ConversionPattern { |
1020 | TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) |
1021 | : ConversionPattern(converter, |
1022 | "test.drop_operands_and_replace_with_valid", 1, ctx) { |
1023 | } |
1024 | LogicalResult |
1025 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1026 | ConversionPatternRewriter &rewriter) const final { |
1027 | rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), |
1028 | std::nullopt); |
1029 | return success(); |
1030 | } |
1031 | }; |
1032 | /// This pattern handles the case of a split return value. |
1033 | struct TestSplitReturnType : public ConversionPattern { |
1034 | TestSplitReturnType(MLIRContext *ctx) |
1035 | : ConversionPattern("test.return", 1, ctx) {} |
1036 | LogicalResult |
1037 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
1038 | ConversionPatternRewriter &rewriter) const final { |
1039 | // Check for a return of F32. |
1040 | if (op->getNumOperands() != 1 || !op->getOperand(idx: 0).getType().isF32()) |
1041 | return failure(); |
1042 | rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); |
1043 | return success(); |
1044 | } |
1045 | }; |
1046 | |
1047 | //===----------------------------------------------------------------------===// |
1048 | // Multi-Level Type-Conversion Rewrite Testing |
1049 | struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { |
1050 | TestChangeProducerTypeI32ToF32(MLIRContext *ctx) |
1051 | : ConversionPattern("test.type_producer", 1, ctx) {} |
1052 | LogicalResult |
1053 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1054 | ConversionPatternRewriter &rewriter) const final { |
1055 | // If the type is I32, change the type to F32. |
1056 | if (!Type(*op->result_type_begin()).isSignlessInteger(width: 32)) |
1057 | return failure(); |
1058 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); |
1059 | return success(); |
1060 | } |
1061 | }; |
1062 | struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { |
1063 | TestChangeProducerTypeF32ToF64(MLIRContext *ctx) |
1064 | : ConversionPattern("test.type_producer", 1, ctx) {} |
1065 | LogicalResult |
1066 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1067 | ConversionPatternRewriter &rewriter) const final { |
1068 | // If the type is F32, change the type to F64. |
1069 | if (!Type(*op->result_type_begin()).isF32()) |
1070 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected single f32 operand"); |
1071 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); |
1072 | return success(); |
1073 | } |
1074 | }; |
1075 | struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { |
1076 | TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) |
1077 | : ConversionPattern("test.type_producer", 10, ctx) {} |
1078 | LogicalResult |
1079 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1080 | ConversionPatternRewriter &rewriter) const final { |
1081 | // Always convert to B16, even though it is not a legal type. This tests |
1082 | // that values are unmapped correctly. |
1083 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); |
1084 | return success(); |
1085 | } |
1086 | }; |
1087 | struct TestUpdateConsumerType : public ConversionPattern { |
1088 | TestUpdateConsumerType(MLIRContext *ctx) |
1089 | : ConversionPattern("test.type_consumer", 1, ctx) {} |
1090 | LogicalResult |
1091 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1092 | ConversionPatternRewriter &rewriter) const final { |
1093 | // Verify that the incoming operand has been successfully remapped to F64. |
1094 | if (!operands[0].getType().isF64()) |
1095 | return failure(); |
1096 | rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); |
1097 | return success(); |
1098 | } |
1099 | }; |
1100 | |
1101 | //===----------------------------------------------------------------------===// |
1102 | // Non-Root Replacement Rewrite Testing |
1103 | /// This pattern generates an invalid operation, but replaces it before the |
1104 | /// pattern is finished. This checks that we don't need to legalize the |
1105 | /// temporary op. |
1106 | struct TestNonRootReplacement : public RewritePattern { |
1107 | TestNonRootReplacement(MLIRContext *ctx) |
1108 | : RewritePattern("test.replace_non_root", 1, ctx) {} |
1109 | |
1110 | LogicalResult matchAndRewrite(Operation *op, |
1111 | PatternRewriter &rewriter) const final { |
1112 | auto resultType = *op->result_type_begin(); |
1113 | auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); |
1114 | auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); |
1115 | |
1116 | rewriter.replaceOp(illegalOp, legalOp); |
1117 | rewriter.replaceOp(op, illegalOp); |
1118 | return success(); |
1119 | } |
1120 | }; |
1121 | |
1122 | //===----------------------------------------------------------------------===// |
1123 | // Recursive Rewrite Testing |
1124 | /// This pattern is applied to the same operation multiple times, but has a |
1125 | /// bounded recursion. |
1126 | struct TestBoundedRecursiveRewrite |
1127 | : public OpRewritePattern<TestRecursiveRewriteOp> { |
1128 | using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; |
1129 | |
1130 | void initialize() { |
1131 | // The conversion target handles bounding the recursion of this pattern. |
1132 | setHasBoundedRewriteRecursion(); |
1133 | } |
1134 | |
1135 | LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, |
1136 | PatternRewriter &rewriter) const final { |
1137 | // Decrement the depth of the op in-place. |
1138 | rewriter.modifyOpInPlace(op, [&] { |
1139 | op->setAttr("depth", rewriter.getI64IntegerAttr(value: op.getDepth() - 1)); |
1140 | }); |
1141 | return success(); |
1142 | } |
1143 | }; |
1144 | |
1145 | struct TestNestedOpCreationUndoRewrite |
1146 | : public OpRewritePattern<IllegalOpWithRegionAnchor> { |
1147 | using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; |
1148 | |
1149 | LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, |
1150 | PatternRewriter &rewriter) const final { |
1151 | // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
1152 | rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); |
1153 | return success(); |
1154 | }; |
1155 | }; |
1156 | |
1157 | // This pattern matches `test.blackhole` and delete this op and its producer. |
1158 | struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { |
1159 | using OpRewritePattern<BlackHoleOp>::OpRewritePattern; |
1160 | |
1161 | LogicalResult matchAndRewrite(BlackHoleOp op, |
1162 | PatternRewriter &rewriter) const final { |
1163 | Operation *producer = op.getOperand().getDefiningOp(); |
1164 | // Always erase the user before the producer, the framework should handle |
1165 | // this correctly. |
1166 | rewriter.eraseOp(op: op); |
1167 | rewriter.eraseOp(op: producer); |
1168 | return success(); |
1169 | }; |
1170 | }; |
1171 | |
1172 | // This pattern replaces explicitly illegal op with explicitly legal op, |
1173 | // but in addition creates unregistered operation. |
1174 | struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { |
1175 | using OpRewritePattern<ILLegalOpG>::OpRewritePattern; |
1176 | |
1177 | LogicalResult matchAndRewrite(ILLegalOpG op, |
1178 | PatternRewriter &rewriter) const final { |
1179 | IntegerAttr attr = rewriter.getI32IntegerAttr(0); |
1180 | Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); |
1181 | rewriter.replaceOpWithNewOp<LegalOpC>(op, val); |
1182 | return success(); |
1183 | }; |
1184 | }; |
1185 | |
1186 | class TestEraseOp : public ConversionPattern { |
1187 | public: |
1188 | TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} |
1189 | LogicalResult |
1190 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1191 | ConversionPatternRewriter &rewriter) const final { |
1192 | // Erase op without replacements. |
1193 | rewriter.eraseOp(op); |
1194 | return success(); |
1195 | } |
1196 | }; |
1197 | |
1198 | /// This pattern matches a test.convert_block_args op. It either: |
1199 | /// a) Duplicates all block arguments, |
1200 | /// b) or: drops all block arguments and replaces each with 2x the first |
1201 | /// operand. |
1202 | class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> { |
1203 | using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern; |
1204 | |
1205 | LogicalResult |
1206 | matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor, |
1207 | ConversionPatternRewriter &rewriter) const override { |
1208 | if (op.getIsLegal()) |
1209 | return failure(); |
1210 | Block *body = &op.getBody().front(); |
1211 | TypeConverter::SignatureConversion result(body->getNumArguments()); |
1212 | for (auto it : llvm::enumerate(body->getArgumentTypes())) { |
1213 | if (op.getReplaceWithOperand()) { |
1214 | result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()}); |
1215 | } else if (op.getDuplicate()) { |
1216 | result.addInputs(it.index(), {it.value(), it.value()}); |
1217 | } else { |
1218 | // No action specified. Pattern does not apply. |
1219 | return failure(); |
1220 | } |
1221 | } |
1222 | rewriter.startOpModification(op: op); |
1223 | rewriter.applySignatureConversion(block: body, conversion&: result, converter: getTypeConverter()); |
1224 | op.setIsLegal(true); |
1225 | rewriter.finalizeOpModification(op: op); |
1226 | return success(); |
1227 | } |
1228 | }; |
1229 | |
1230 | /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid |
1231 | /// op. The pattern supports 1:N replacements and forwards the replacement |
1232 | /// values of the single operand as test.valid operands. |
1233 | class TestRepetitive1ToNConsumer : public ConversionPattern { |
1234 | public: |
1235 | TestRepetitive1ToNConsumer(MLIRContext *ctx) |
1236 | : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} |
1237 | LogicalResult |
1238 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
1239 | ConversionPatternRewriter &rewriter) const final { |
1240 | // A single operand is expected. |
1241 | if (op->getNumOperands() != 1) |
1242 | return failure(); |
1243 | rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); |
1244 | return success(); |
1245 | } |
1246 | }; |
1247 | |
1248 | /// A pattern that tests two back-to-back 1 -> 2 op replacements. |
1249 | class TestMultiple1ToNReplacement : public ConversionPattern { |
1250 | public: |
1251 | TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) |
1252 | : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, |
1253 | ctx) {} |
1254 | LogicalResult |
1255 | matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, |
1256 | ConversionPatternRewriter &rewriter) const final { |
1257 | // Helper function that replaces the given op with a new op of the given |
1258 | // name and doubles each result (1 -> 2 replacement of each result). |
1259 | auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { |
1260 | SmallVector<Type> types; |
1261 | for (Type t : op->getResultTypes()) { |
1262 | types.push_back(Elt: t); |
1263 | types.push_back(Elt: t); |
1264 | } |
1265 | OperationState state(op->getLoc(), name, |
1266 | /*operands=*/{}, types, op->getAttrs()); |
1267 | auto *newOp = rewriter.create(state); |
1268 | SmallVector<ValueRange> repls; |
1269 | for (size_t i = 0, e = op->getNumResults(); i < e; ++i) |
1270 | repls.push_back(Elt: newOp->getResults().slice(n: 2 * i, m: 2)); |
1271 | rewriter.replaceOpWithMultiple(op, newValues&: repls); |
1272 | return newOp; |
1273 | }; |
1274 | |
1275 | // Replace test.multiple_1_to_n_replacement with test.step_1. |
1276 | Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); |
1277 | // Now replace test.step_1 with test.legal_op. |
1278 | replaceWithDoubleResults(repl1, "test.legal_op"); |
1279 | return success(); |
1280 | } |
1281 | }; |
1282 | |
1283 | /// Test unambiguous overload resolution of replaceOpWithMultiple. This |
1284 | /// function is just to trigger compiler errors. It is never executed. |
1285 | [[maybe_unused]] void testReplaceOpWithMultipleOverloads( |
1286 | ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1, |
1287 | SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3, |
1288 | SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5, |
1289 | SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7, |
1290 | Value v, ValueRange vr, ArrayRef<Value> ar) { |
1291 | rewriter.replaceOpWithMultiple(op, newValues: r1); |
1292 | rewriter.replaceOpWithMultiple(op, newValues&: r2); |
1293 | rewriter.replaceOpWithMultiple(op, newValues: r3); |
1294 | rewriter.replaceOpWithMultiple(op, newValues&: r4); |
1295 | rewriter.replaceOpWithMultiple(op, newValues: r5); |
1296 | rewriter.replaceOpWithMultiple(op, newValues&: r6); |
1297 | rewriter.replaceOpWithMultiple(op, newValues: std::move(r7)); |
1298 | rewriter.replaceOpWithMultiple(op, newValues: {vr}); |
1299 | rewriter.replaceOpWithMultiple(op, newValues: {ar}); |
1300 | rewriter.replaceOpWithMultiple(op, newValues: {{v}}); |
1301 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}}); |
1302 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, vr}); |
1303 | rewriter.replaceOpWithMultiple(op, newValues: {{v, v}, ar}); |
1304 | rewriter.replaceOpWithMultiple(op, newValues: {ar, {v, v}, vr}); |
1305 | } |
1306 | } // namespace |
1307 | |
1308 | namespace { |
1309 | struct TestTypeConverter : public TypeConverter { |
1310 | using TypeConverter::TypeConverter; |
1311 | TestTypeConverter() { |
1312 | addConversion(callback&: convertType); |
1313 | addSourceMaterialization(callback&: materializeCast); |
1314 | } |
1315 | |
1316 | static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { |
1317 | // Drop I16 types. |
1318 | if (t.isSignlessInteger(width: 16)) |
1319 | return success(); |
1320 | |
1321 | // Convert I64 to F64. |
1322 | if (t.isSignlessInteger(width: 64)) { |
1323 | results.push_back(Float64Type::get(t.getContext())); |
1324 | return success(); |
1325 | } |
1326 | |
1327 | // Convert I42 to I43. |
1328 | if (t.isInteger(width: 42)) { |
1329 | results.push_back(IntegerType::get(t.getContext(), 43)); |
1330 | return success(); |
1331 | } |
1332 | |
1333 | // Split F32 into F16,F16. |
1334 | if (t.isF32()) { |
1335 | results.assign(2, Float16Type::get(t.getContext())); |
1336 | return success(); |
1337 | } |
1338 | |
1339 | // Drop I24 types. |
1340 | if (t.isInteger(width: 24)) { |
1341 | return success(); |
1342 | } |
1343 | |
1344 | // Otherwise, convert the type directly. |
1345 | results.push_back(Elt: t); |
1346 | return success(); |
1347 | } |
1348 | |
1349 | /// Hook for materializing a conversion. This is necessary because we generate |
1350 | /// 1->N type mappings. |
1351 | static Value materializeCast(OpBuilder &builder, Type resultType, |
1352 | ValueRange inputs, Location loc) { |
1353 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1354 | } |
1355 | }; |
1356 | |
1357 | struct TestLegalizePatternDriver |
1358 | : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { |
1359 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) |
1360 | |
1361 | StringRef getArgument() const final { return "test-legalize-patterns"; } |
1362 | StringRef getDescription() const final { |
1363 | return "Run test dialect legalization patterns"; |
1364 | } |
1365 | /// The mode of conversion to use with the driver. |
1366 | enum class ConversionMode { Analysis, Full, Partial }; |
1367 | |
1368 | TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} |
1369 | |
1370 | void getDependentDialects(DialectRegistry ®istry) const override { |
1371 | registry.insert<func::FuncDialect, test::TestDialect>(); |
1372 | } |
1373 | |
1374 | void runOnOperation() override { |
1375 | TestTypeConverter converter; |
1376 | mlir::RewritePatternSet patterns(&getContext()); |
1377 | populateWithGenerated(patterns); |
1378 | patterns |
1379 | .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, |
1380 | TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, |
1381 | TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, |
1382 | TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, |
1383 | TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, |
1384 | TestNonRootReplacement, TestBoundedRecursiveRewrite, |
1385 | TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, |
1386 | TestCreateUnregisteredOp, TestUndoMoveOpBefore, |
1387 | TestUndoPropertiesModification, TestEraseOp, |
1388 | TestRepetitive1ToNConsumer>(arg: &getContext()); |
1389 | patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, |
1390 | TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( |
1391 | arg: &getContext(), args&: converter); |
1392 | patterns.add<TestConvertBlockArgs>(arg&: converter, args: &getContext()); |
1393 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
1394 | converter); |
1395 | mlir::populateCallOpTypeConversionPattern(patterns, converter); |
1396 | |
1397 | // Define the conversion target used for the test. |
1398 | ConversionTarget target(getContext()); |
1399 | target.addLegalOp<ModuleOp>(); |
1400 | target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, |
1401 | TerminatorOp, OneRegionOp>(); |
1402 | target.addLegalOp(op: OperationName("test.legal_op", &getContext())); |
1403 | target |
1404 | .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); |
1405 | target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { |
1406 | // Don't allow F32 operands. |
1407 | return llvm::none_of(op.getOperandTypes(), |
1408 | [](Type type) { return type.isF32(); }); |
1409 | }); |
1410 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1411 | return converter.isSignatureLegal(op.getFunctionType()) && |
1412 | converter.isLegal(&op.getBody()); |
1413 | }); |
1414 | target.addDynamicallyLegalOp<func::CallOp>( |
1415 | [&](func::CallOp op) { return converter.isLegal(op); }); |
1416 | |
1417 | // TestCreateUnregisteredOp creates `arith.constant` operation, |
1418 | // which was not added to target intentionally to test |
1419 | // correct error code from conversion driver. |
1420 | target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); |
1421 | |
1422 | // Expect the type_producer/type_consumer operations to only operate on f64. |
1423 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
1424 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
1425 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
1426 | return op.getOperand().getType().isF64(); |
1427 | }); |
1428 | |
1429 | // Check support for marking certain operations as recursively legal. |
1430 | target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { |
1431 | return static_cast<bool>( |
1432 | op->getAttrOfType<UnitAttr>("test.recursively_legal")); |
1433 | }); |
1434 | |
1435 | // Mark the bound recursion operation as dynamically legal. |
1436 | target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( |
1437 | [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); |
1438 | |
1439 | // Create a dynamically legal rule that can only be legalized by folding it. |
1440 | target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( |
1441 | [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); |
1442 | |
1443 | target.addDynamicallyLegalOp<ConvertBlockArgsOp>( |
1444 | [](ConvertBlockArgsOp op) { return op.getIsLegal(); }); |
1445 | |
1446 | // Handle a partial conversion. |
1447 | if (mode == ConversionMode::Partial) { |
1448 | DenseSet<Operation *> unlegalizedOps; |
1449 | ConversionConfig config; |
1450 | DumpNotifications dumpNotifications; |
1451 | config.listener = &dumpNotifications; |
1452 | config.unlegalizedOps = &unlegalizedOps; |
1453 | if (failed(applyPartialConversion(getOperation(), target, |
1454 | std::move(patterns), config))) { |
1455 | getOperation()->emitRemark() << "applyPartialConversion failed"; |
1456 | } |
1457 | // Emit remarks for each legalizable operation. |
1458 | for (auto *op : unlegalizedOps) |
1459 | op->emitRemark() << "op '"<< op->getName() << "' is not legalizable"; |
1460 | return; |
1461 | } |
1462 | |
1463 | // Handle a full conversion. |
1464 | if (mode == ConversionMode::Full) { |
1465 | // Check support for marking unknown operations as dynamically legal. |
1466 | target.markUnknownOpDynamicallyLegal([](Operation *op) { |
1467 | return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); |
1468 | }); |
1469 | |
1470 | ConversionConfig config; |
1471 | DumpNotifications dumpNotifications; |
1472 | config.listener = &dumpNotifications; |
1473 | if (failed(applyFullConversion(getOperation(), target, |
1474 | std::move(patterns), config))) { |
1475 | getOperation()->emitRemark() << "applyFullConversion failed"; |
1476 | } |
1477 | return; |
1478 | } |
1479 | |
1480 | // Otherwise, handle an analysis conversion. |
1481 | assert(mode == ConversionMode::Analysis); |
1482 | |
1483 | // Analyze the convertible operations. |
1484 | DenseSet<Operation *> legalizedOps; |
1485 | ConversionConfig config; |
1486 | config.legalizableOps = &legalizedOps; |
1487 | if (failed(applyAnalysisConversion(getOperation(), target, |
1488 | std::move(patterns), config))) |
1489 | return signalPassFailure(); |
1490 | |
1491 | // Emit remarks for each legalizable operation. |
1492 | for (auto *op : legalizedOps) |
1493 | op->emitRemark() << "op '"<< op->getName() << "' is legalizable"; |
1494 | } |
1495 | |
1496 | /// The mode of conversion to use. |
1497 | ConversionMode mode; |
1498 | }; |
1499 | } // namespace |
1500 | |
1501 | static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> |
1502 | legalizerConversionMode( |
1503 | "test-legalize-mode", |
1504 | llvm::cl::desc("The legalization mode to use with the test driver"), |
1505 | llvm::cl::init(Val: TestLegalizePatternDriver::ConversionMode::Partial), |
1506 | llvm::cl::values( |
1507 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, |
1508 | "analysis", "Perform an analysis conversion"), |
1509 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", |
1510 | "Perform a full conversion"), |
1511 | clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, |
1512 | "partial", "Perform a partial conversion"))); |
1513 | |
1514 | //===----------------------------------------------------------------------===// |
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
// ConversionPatternRewriter.
//===----------------------------------------------------------------------===//
1518 | namespace { |
1519 | struct TestRemapValueTypeConverter : public TypeConverter { |
1520 | using TypeConverter::TypeConverter; |
1521 | |
1522 | TestRemapValueTypeConverter() { |
1523 | addConversion( |
1524 | callback: [](Float32Type type) { return Float64Type::get(type.getContext()); }); |
1525 | addConversion(callback: [](Type type) { return type; }); |
1526 | } |
1527 | }; |
1528 | |
1529 | /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with |
1530 | /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original |
1531 | /// operand twice. |
1532 | /// |
1533 | /// Example: |
1534 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0) |
1535 | /// is replaced with: |
1536 | /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) |
1537 | struct OneVResOneVOperandOp1Converter |
1538 | : public OpConversionPattern<OneVResOneVOperandOp1> { |
1539 | using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; |
1540 | |
1541 | LogicalResult |
1542 | matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, |
1543 | ConversionPatternRewriter &rewriter) const override { |
1544 | auto origOps = op.getOperands(); |
1545 | assert(std::distance(origOps.begin(), origOps.end()) == 1 && |
1546 | "One operand expected"); |
1547 | Value origOp = *origOps.begin(); |
1548 | SmallVector<Value, 2> remappedOperands; |
1549 | // Replicate the remapped original operand twice. Note that we don't used |
1550 | // the remapped 'operand' since the goal is testing 'getRemappedValue'. |
1551 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
1552 | remappedOperands.push_back(Elt: rewriter.getRemappedValue(key: origOp)); |
1553 | |
1554 | rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), |
1555 | remappedOperands); |
1556 | return success(); |
1557 | } |
1558 | }; |
1559 | |
1560 | /// A rewriter pattern that tests that blocks can be merged. |
1561 | struct TestRemapValueInRegion |
1562 | : public OpConversionPattern<TestRemappedValueRegionOp> { |
1563 | using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; |
1564 | |
1565 | LogicalResult |
1566 | matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, |
1567 | ConversionPatternRewriter &rewriter) const final { |
1568 | Block &block = op.getBody().front(); |
1569 | Operation *terminator = block.getTerminator(); |
1570 | |
1571 | // Merge the block into the parent region. |
1572 | Block *parentBlock = op->getBlock(); |
1573 | Block *finalBlock = rewriter.splitBlock(block: parentBlock, before: op->getIterator()); |
1574 | rewriter.mergeBlocks(source: &block, dest: parentBlock, argValues: ValueRange()); |
1575 | rewriter.mergeBlocks(source: finalBlock, dest: parentBlock, argValues: ValueRange()); |
1576 | |
1577 | // Replace the results of this operation with the remapped terminator |
1578 | // values. |
1579 | SmallVector<Value> terminatorOperands; |
1580 | if (failed(Result: rewriter.getRemappedValues(keys: terminator->getOperands(), |
1581 | results&: terminatorOperands))) |
1582 | return failure(); |
1583 | |
1584 | rewriter.eraseOp(op: terminator); |
1585 | rewriter.replaceOp(op, terminatorOperands); |
1586 | return success(); |
1587 | } |
1588 | }; |
1589 | |
1590 | struct TestRemappedValue |
1591 | : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { |
1592 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) |
1593 | |
1594 | StringRef getArgument() const final { return "test-remapped-value"; } |
1595 | StringRef getDescription() const final { |
1596 | return "Test public remapped value mechanism in ConversionPatternRewriter"; |
1597 | } |
1598 | void runOnOperation() override { |
1599 | TestRemapValueTypeConverter typeConverter; |
1600 | |
1601 | mlir::RewritePatternSet patterns(&getContext()); |
1602 | patterns.add<OneVResOneVOperandOp1Converter>(arg: &getContext()); |
1603 | patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( |
1604 | arg: &getContext()); |
1605 | patterns.add<TestRemapValueInRegion>(arg&: typeConverter, args: &getContext()); |
1606 | |
1607 | mlir::ConversionTarget target(getContext()); |
1608 | target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); |
1609 | |
1610 | // Expect the type_producer/type_consumer operations to only operate on f64. |
1611 | target.addDynamicallyLegalOp<TestTypeProducerOp>( |
1612 | [](TestTypeProducerOp op) { return op.getType().isF64(); }); |
1613 | target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { |
1614 | return op.getOperand().getType().isF64(); |
1615 | }); |
1616 | |
1617 | // We make OneVResOneVOperandOp1 legal only when it has more that one |
1618 | // operand. This will trigger the conversion that will replace one-operand |
1619 | // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. |
1620 | target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( |
1621 | [](Operation *op) { return op->getNumOperands() > 1; }); |
1622 | |
1623 | if (failed(mlir::applyFullConversion(getOperation(), target, |
1624 | std::move(patterns)))) { |
1625 | signalPassFailure(); |
1626 | } |
1627 | } |
1628 | }; |
1629 | } // namespace |
1630 | |
1631 | //===----------------------------------------------------------------------===// |
1632 | // Test patterns without a specific root operation kind |
1633 | //===----------------------------------------------------------------------===// |
1634 | |
1635 | namespace { |
1636 | /// This pattern matches and removes any operation in the test dialect. |
1637 | struct RemoveTestDialectOps : public RewritePattern { |
1638 | RemoveTestDialectOps(MLIRContext *context) |
1639 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
1640 | |
1641 | LogicalResult matchAndRewrite(Operation *op, |
1642 | PatternRewriter &rewriter) const override { |
1643 | if (!isa<TestDialect>(Val: op->getDialect())) |
1644 | return failure(); |
1645 | rewriter.eraseOp(op); |
1646 | return success(); |
1647 | } |
1648 | }; |
1649 | |
1650 | struct TestUnknownRootOpDriver |
1651 | : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { |
1652 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) |
1653 | |
1654 | StringRef getArgument() const final { |
1655 | return "test-legalize-unknown-root-patterns"; |
1656 | } |
1657 | StringRef getDescription() const final { |
1658 | return "Test public remapped value mechanism in ConversionPatternRewriter"; |
1659 | } |
1660 | void runOnOperation() override { |
1661 | mlir::RewritePatternSet patterns(&getContext()); |
1662 | patterns.add<RemoveTestDialectOps>(arg: &getContext()); |
1663 | |
1664 | mlir::ConversionTarget target(getContext()); |
1665 | target.addIllegalDialect<TestDialect>(); |
1666 | if (failed(applyPartialConversion(getOperation(), target, |
1667 | std::move(patterns)))) |
1668 | signalPassFailure(); |
1669 | } |
1670 | }; |
1671 | } // namespace |
1672 | |
1673 | //===----------------------------------------------------------------------===// |
1674 | // Test patterns that uses operations and types defined at runtime |
1675 | //===----------------------------------------------------------------------===// |
1676 | |
1677 | namespace { |
1678 | /// This pattern matches dynamic operations 'test.one_operand_two_results' and |
1679 | /// replace them with dynamic operations 'test.generic_dynamic_op'. |
1680 | struct RewriteDynamicOp : public RewritePattern { |
1681 | RewriteDynamicOp(MLIRContext *context) |
1682 | : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, |
1683 | context) {} |
1684 | |
1685 | LogicalResult matchAndRewrite(Operation *op, |
1686 | PatternRewriter &rewriter) const override { |
1687 | assert(op->getName().getStringRef() == |
1688 | "test.dynamic_one_operand_two_results"&& |
1689 | "rewrite pattern should only match operations with the right name"); |
1690 | |
1691 | OperationState state(op->getLoc(), "test.dynamic_generic", |
1692 | op->getOperands(), op->getResultTypes(), |
1693 | op->getAttrs()); |
1694 | auto *newOp = rewriter.create(state); |
1695 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
1696 | return success(); |
1697 | } |
1698 | }; |
1699 | |
1700 | struct TestRewriteDynamicOpDriver |
1701 | : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { |
1702 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) |
1703 | |
1704 | void getDependentDialects(DialectRegistry ®istry) const override { |
1705 | registry.insert<TestDialect>(); |
1706 | } |
1707 | StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } |
1708 | StringRef getDescription() const final { |
1709 | return "Test rewritting on dynamic operations"; |
1710 | } |
1711 | void runOnOperation() override { |
1712 | RewritePatternSet patterns(&getContext()); |
1713 | patterns.add<RewriteDynamicOp>(arg: &getContext()); |
1714 | |
1715 | ConversionTarget target(getContext()); |
1716 | target.addIllegalOp( |
1717 | op: OperationName("test.dynamic_one_operand_two_results", &getContext())); |
1718 | target.addLegalOp(op: OperationName("test.dynamic_generic", &getContext())); |
1719 | if (failed(applyPartialConversion(getOperation(), target, |
1720 | std::move(patterns)))) |
1721 | signalPassFailure(); |
1722 | } |
1723 | }; |
1724 | } // end anonymous namespace |
1725 | |
1726 | //===----------------------------------------------------------------------===// |
1727 | // Test type conversions |
1728 | //===----------------------------------------------------------------------===// |
1729 | |
1730 | namespace { |
1731 | struct TestTypeConversionProducer |
1732 | : public OpConversionPattern<TestTypeProducerOp> { |
1733 | using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; |
1734 | LogicalResult |
1735 | matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, |
1736 | ConversionPatternRewriter &rewriter) const final { |
1737 | Type resultType = op.getType(); |
1738 | Type convertedType = getTypeConverter() |
1739 | ? getTypeConverter()->convertType(resultType) |
1740 | : resultType; |
1741 | if (isa<FloatType>(Val: resultType)) |
1742 | resultType = rewriter.getF64Type(); |
1743 | else if (resultType.isInteger(width: 16)) |
1744 | resultType = rewriter.getIntegerType(64); |
1745 | else if (isa<test::TestRecursiveType>(Val: resultType) && |
1746 | convertedType != resultType) |
1747 | resultType = convertedType; |
1748 | else |
1749 | return failure(); |
1750 | |
1751 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); |
1752 | return success(); |
1753 | } |
1754 | }; |
1755 | |
1756 | /// Call signature conversion and then fail the rewrite to trigger the undo |
1757 | /// mechanism. |
1758 | struct TestSignatureConversionUndo |
1759 | : public OpConversionPattern<TestSignatureConversionUndoOp> { |
1760 | using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; |
1761 | |
1762 | LogicalResult |
1763 | matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, |
1764 | ConversionPatternRewriter &rewriter) const final { |
1765 | (void)rewriter.convertRegionTypes(region: &op->getRegion(0), converter: *getTypeConverter()); |
1766 | return failure(); |
1767 | } |
1768 | }; |
1769 | |
1770 | /// Call signature conversion without providing a type converter to handle |
1771 | /// materializations. |
1772 | struct TestTestSignatureConversionNoConverter |
1773 | : public OpConversionPattern<TestSignatureConversionNoConverterOp> { |
1774 | TestTestSignatureConversionNoConverter(const TypeConverter &converter, |
1775 | MLIRContext *context) |
1776 | : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), |
1777 | converter(converter) {} |
1778 | |
1779 | LogicalResult |
1780 | matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, |
1781 | ConversionPatternRewriter &rewriter) const final { |
1782 | Region ®ion = op->getRegion(0); |
1783 | Block *entry = ®ion.front(); |
1784 | |
1785 | // Convert the original entry arguments. |
1786 | TypeConverter::SignatureConversion result(entry->getNumArguments()); |
1787 | if (failed( |
1788 | Result: converter.convertSignatureArgs(types: entry->getArgumentTypes(), result))) |
1789 | return failure(); |
1790 | rewriter.modifyOpInPlace(op, [&] { |
1791 | rewriter.applySignatureConversion(block: ®ion.front(), conversion&: result); |
1792 | }); |
1793 | return success(); |
1794 | } |
1795 | |
1796 | const TypeConverter &converter; |
1797 | }; |
1798 | |
1799 | /// Just forward the operands to the root op. This is essentially a no-op |
1800 | /// pattern that is used to trigger target materialization. |
1801 | struct TestTypeConsumerForward |
1802 | : public OpConversionPattern<TestTypeConsumerOp> { |
1803 | using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; |
1804 | |
1805 | LogicalResult |
1806 | matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, |
1807 | ConversionPatternRewriter &rewriter) const final { |
1808 | rewriter.modifyOpInPlace(op, |
1809 | [&] { op->setOperands(adaptor.getOperands()); }); |
1810 | return success(); |
1811 | } |
1812 | }; |
1813 | |
1814 | struct TestTypeConversionAnotherProducer |
1815 | : public OpRewritePattern<TestAnotherTypeProducerOp> { |
1816 | using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; |
1817 | |
1818 | LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, |
1819 | PatternRewriter &rewriter) const final { |
1820 | rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); |
1821 | return success(); |
1822 | } |
1823 | }; |
1824 | |
1825 | struct TestReplaceWithLegalOp : public ConversionPattern { |
1826 | TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) |
1827 | : ConversionPattern(converter, "test.replace_with_legal_op", |
1828 | /*benefit=*/1, ctx) {} |
1829 | LogicalResult |
1830 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
1831 | ConversionPatternRewriter &rewriter) const final { |
1832 | rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); |
1833 | return success(); |
1834 | } |
1835 | }; |
1836 | |
1837 | struct TestTypeConversionDriver |
1838 | : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { |
1839 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) |
1840 | |
1841 | void getDependentDialects(DialectRegistry ®istry) const override { |
1842 | registry.insert<TestDialect>(); |
1843 | } |
1844 | StringRef getArgument() const final { |
1845 | return "test-legalize-type-conversion"; |
1846 | } |
1847 | StringRef getDescription() const final { |
1848 | return "Test various type conversion functionalities in DialectConversion"; |
1849 | } |
1850 | |
1851 | void runOnOperation() override { |
1852 | // Initialize the type converter. |
1853 | SmallVector<Type, 2> conversionCallStack; |
1854 | TypeConverter converter; |
1855 | |
1856 | /// Add the legal set of type conversions. |
1857 | converter.addConversion(callback: [](Type type) -> Type { |
1858 | // Treat F64 as legal. |
1859 | if (type.isF64()) |
1860 | return type; |
1861 | // Allow converting BF16/F16/F32 to F64. |
1862 | if (type.isBF16() || type.isF16() || type.isF32()) |
1863 | return Float64Type::get(type.getContext()); |
1864 | // Otherwise, the type is illegal. |
1865 | return nullptr; |
1866 | }); |
1867 | converter.addConversion(callback: [](IntegerType type, SmallVectorImpl<Type> &) { |
1868 | // Drop all integer types. |
1869 | return success(); |
1870 | }); |
1871 | converter.addConversion( |
1872 | // Convert a recursive self-referring type into a non-self-referring |
1873 | // type named "outer_converted_type" that contains a SimpleAType. |
1874 | callback: [&](test::TestRecursiveType type, |
1875 | SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
1876 | // If the type is already converted, return it to indicate that it is |
1877 | // legal. |
1878 | if (type.getName() == "outer_converted_type") { |
1879 | results.push_back(type); |
1880 | return success(); |
1881 | } |
1882 | |
1883 | conversionCallStack.push_back(type); |
1884 | auto popConversionCallStack = llvm::make_scope_exit( |
1885 | F: [&conversionCallStack]() { conversionCallStack.pop_back(); }); |
1886 | |
1887 | // If the type is on the call stack more than once (it is there at |
1888 | // least once because of the _current_ call, which is always the last |
1889 | // element on the stack), we've hit the recursive case. Just return |
1890 | // SimpleAType here to create a non-recursive type as a result. |
1891 | if (llvm::is_contained(Range: ArrayRef(conversionCallStack).drop_back(), |
1892 | Element: type)) { |
1893 | results.push_back(test::SimpleAType::get(type.getContext())); |
1894 | return success(); |
1895 | } |
1896 | |
1897 | // Convert the body recursively. |
1898 | auto result = test::TestRecursiveType::get(ctx: type.getContext(), |
1899 | name: "outer_converted_type"); |
1900 | if (failed(result.setBody(converter.convertType(t: type.getBody())))) |
1901 | return failure(); |
1902 | results.push_back(Elt: result); |
1903 | return success(); |
1904 | }); |
1905 | |
1906 | /// Add the legal set of type materializations. |
1907 | converter.addSourceMaterialization(callback: [](OpBuilder &builder, Type resultType, |
1908 | ValueRange inputs, |
1909 | Location loc) -> Value { |
1910 | // Allow casting from F64 back to F32. |
1911 | if (!resultType.isF16() && inputs.size() == 1 && |
1912 | inputs[0].getType().isF64()) |
1913 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1914 | // Allow producing an i32 or i64 from nothing. |
1915 | if ((resultType.isInteger(32) || resultType.isInteger(64)) && |
1916 | inputs.empty()) |
1917 | return builder.create<TestTypeProducerOp>(loc, resultType); |
1918 | // Allow producing an i64 from an integer. |
1919 | if (isa<IntegerType>(resultType) && inputs.size() == 1 && |
1920 | isa<IntegerType>(inputs[0].getType())) |
1921 | return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); |
1922 | // Otherwise, fail. |
1923 | return nullptr; |
1924 | }); |
1925 | |
1926 | // Initialize the conversion target. |
1927 | mlir::ConversionTarget target(getContext()); |
1928 | target.addLegalOp<LegalOpD>(); |
1929 | target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { |
1930 | auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); |
1931 | return op.getType().isF64() || op.getType().isInteger(64) || |
1932 | (recursiveType && |
1933 | recursiveType.getName() == "outer_converted_type"); |
1934 | }); |
1935 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
1936 | return converter.isSignatureLegal(op.getFunctionType()) && |
1937 | converter.isLegal(&op.getBody()); |
1938 | }); |
1939 | target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { |
1940 | // Allow casts from F64 to F32. |
1941 | return (*op.operand_type_begin()).isF64() && op.getType().isF32(); |
1942 | }); |
1943 | target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( |
1944 | [&](TestSignatureConversionNoConverterOp op) { |
1945 | return converter.isLegal(op.getRegion().front().getArgumentTypes()); |
1946 | }); |
1947 | |
1948 | // Initialize the set of rewrite patterns. |
1949 | RewritePatternSet patterns(&getContext()); |
1950 | patterns |
1951 | .add<TestTypeConsumerForward, TestTypeConversionProducer, |
1952 | TestSignatureConversionUndo, |
1953 | TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( |
1954 | arg&: converter, args: &getContext()); |
1955 | patterns.add<TestTypeConversionAnotherProducer>(arg: &getContext()); |
1956 | mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, |
1957 | converter); |
1958 | |
1959 | if (failed(applyPartialConversion(getOperation(), target, |
1960 | std::move(patterns)))) |
1961 | signalPassFailure(); |
1962 | } |
1963 | }; |
1964 | } // namespace |
1965 | |
1966 | //===----------------------------------------------------------------------===// |
1967 | // Test Target Materialization With No Uses |
1968 | //===----------------------------------------------------------------------===// |
1969 | |
1970 | namespace { |
1971 | struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { |
1972 | using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; |
1973 | |
1974 | LogicalResult |
1975 | matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, |
1976 | ConversionPatternRewriter &rewriter) const final { |
1977 | rewriter.replaceOp(op, adaptor.getOperands()); |
1978 | return success(); |
1979 | } |
1980 | }; |
1981 | |
1982 | struct TestTargetMaterializationWithNoUses |
1983 | : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { |
1984 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
1985 | TestTargetMaterializationWithNoUses) |
1986 | |
1987 | StringRef getArgument() const final { |
1988 | return "test-target-materialization-with-no-uses"; |
1989 | } |
1990 | StringRef getDescription() const final { |
1991 | return "Test a special case of target materialization in DialectConversion"; |
1992 | } |
1993 | |
1994 | void runOnOperation() override { |
1995 | TypeConverter converter; |
1996 | converter.addConversion(callback: [](Type t) { return t; }); |
1997 | converter.addConversion(callback: [](IntegerType intTy) -> Type { |
1998 | if (intTy.getWidth() == 16) |
1999 | return IntegerType::get(intTy.getContext(), 64); |
2000 | return intTy; |
2001 | }); |
2002 | converter.addTargetMaterialization( |
2003 | callback: [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { |
2004 | return builder.create<TestCastOp>(loc, type, inputs).getResult(); |
2005 | }); |
2006 | |
2007 | ConversionTarget target(getContext()); |
2008 | target.addIllegalOp<TestTypeChangerOp>(); |
2009 | |
2010 | RewritePatternSet patterns(&getContext()); |
2011 | patterns.add<ForwardOperandPattern>(arg&: converter, args: &getContext()); |
2012 | |
2013 | if (failed(applyPartialConversion(getOperation(), target, |
2014 | std::move(patterns)))) |
2015 | signalPassFailure(); |
2016 | } |
2017 | }; |
2018 | } // namespace |
2019 | |
2020 | //===----------------------------------------------------------------------===// |
2021 | // Test Block Merging |
2022 | //===----------------------------------------------------------------------===// |
2023 | |
2024 | namespace { |
2025 | /// A rewriter pattern that tests that blocks can be merged. |
2026 | struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { |
2027 | using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; |
2028 | |
2029 | LogicalResult |
2030 | matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, |
2031 | ConversionPatternRewriter &rewriter) const final { |
2032 | Block &firstBlock = op.getBody().front(); |
2033 | Operation *branchOp = firstBlock.getTerminator(); |
2034 | Block *secondBlock = &*(std::next(op.getBody().begin())); |
2035 | auto succOperands = branchOp->getOperands(); |
2036 | SmallVector<Value, 2> replacements(succOperands); |
2037 | rewriter.eraseOp(op: branchOp); |
2038 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
2039 | rewriter.modifyOpInPlace(op, [] {}); |
2040 | return success(); |
2041 | } |
2042 | }; |
2043 | |
2044 | /// A rewrite pattern to tests the undo mechanism of blocks being merged. |
2045 | struct TestUndoBlocksMerge : public ConversionPattern { |
2046 | TestUndoBlocksMerge(MLIRContext *ctx) |
2047 | : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} |
2048 | LogicalResult |
2049 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
2050 | ConversionPatternRewriter &rewriter) const final { |
2051 | Block &firstBlock = op->getRegion(index: 0).front(); |
2052 | Operation *branchOp = firstBlock.getTerminator(); |
2053 | Block *secondBlock = &*(std::next(x: op->getRegion(index: 0).begin())); |
2054 | rewriter.setInsertionPointToStart(secondBlock); |
2055 | rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); |
2056 | auto succOperands = branchOp->getOperands(); |
2057 | SmallVector<Value, 2> replacements(succOperands); |
2058 | rewriter.eraseOp(op: branchOp); |
2059 | rewriter.mergeBlocks(source: secondBlock, dest: &firstBlock, argValues: replacements); |
2060 | rewriter.modifyOpInPlace(root: op, callable: [] {}); |
2061 | return success(); |
2062 | } |
2063 | }; |
2064 | |
2065 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
2066 | /// ops can have a single block. |
2067 | struct TestMergeSingleBlockOps |
2068 | : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { |
2069 | using OpConversionPattern< |
2070 | SingleBlockImplicitTerminatorOp>::OpConversionPattern; |
2071 | |
2072 | LogicalResult |
2073 | matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, |
2074 | ConversionPatternRewriter &rewriter) const final { |
2075 | SingleBlockImplicitTerminatorOp parentOp = |
2076 | op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
2077 | if (!parentOp) |
2078 | return failure(); |
2079 | Block &innerBlock = op.getRegion().front(); |
2080 | TerminatorOp innerTerminator = |
2081 | cast<TerminatorOp>(innerBlock.getTerminator()); |
2082 | rewriter.inlineBlockBefore(&innerBlock, op); |
2083 | rewriter.eraseOp(op: innerTerminator); |
2084 | rewriter.eraseOp(op: op); |
2085 | return success(); |
2086 | } |
2087 | }; |
2088 | |
2089 | struct TestMergeBlocksPatternDriver |
2090 | : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { |
2091 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) |
2092 | |
2093 | StringRef getArgument() const final { return "test-merge-blocks"; } |
2094 | StringRef getDescription() const final { |
2095 | return "Test Merging operation in ConversionPatternRewriter"; |
2096 | } |
2097 | void runOnOperation() override { |
2098 | MLIRContext *context = &getContext(); |
2099 | mlir::RewritePatternSet patterns(context); |
2100 | patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( |
2101 | arg&: context); |
2102 | ConversionTarget target(*context); |
2103 | target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, |
2104 | TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); |
2105 | target.addIllegalOp<ILLegalOpF>(); |
2106 | |
2107 | /// Expect the op to have a single block after legalization. |
2108 | target.addDynamicallyLegalOp<TestMergeBlocksOp>( |
2109 | [&](TestMergeBlocksOp op) -> bool { |
2110 | return llvm::hasSingleElement(op.getBody()); |
2111 | }); |
2112 | |
2113 | /// Only allow `test.br` within test.merge_blocks op. |
2114 | target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { |
2115 | return op->getParentOfType<TestMergeBlocksOp>(); |
2116 | }); |
2117 | |
2118 | /// Expect that all nested test.SingleBlockImplicitTerminator ops are |
2119 | /// inlined. |
2120 | target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( |
2121 | [&](SingleBlockImplicitTerminatorOp op) -> bool { |
2122 | return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); |
2123 | }); |
2124 | |
2125 | DenseSet<Operation *> unlegalizedOps; |
2126 | ConversionConfig config; |
2127 | config.unlegalizedOps = &unlegalizedOps; |
2128 | (void)applyPartialConversion(getOperation(), target, std::move(patterns), |
2129 | config); |
2130 | for (auto *op : unlegalizedOps) |
2131 | op->emitRemark() << "op '"<< op->getName() << "' is not legalizable"; |
2132 | } |
2133 | }; |
2134 | } // namespace |
2135 | |
2136 | //===----------------------------------------------------------------------===// |
2137 | // Test Selective Replacement |
2138 | //===----------------------------------------------------------------------===// |
2139 | |
2140 | namespace { |
2141 | /// A rewrite mechanism to inline the body of the op into its parent, when both |
2142 | /// ops can have a single block. |
2143 | struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { |
2144 | using OpRewritePattern<TestCastOp>::OpRewritePattern; |
2145 | |
2146 | LogicalResult matchAndRewrite(TestCastOp op, |
2147 | PatternRewriter &rewriter) const final { |
2148 | if (op.getNumOperands() != 2) |
2149 | return failure(); |
2150 | OperandRange operands = op.getOperands(); |
2151 | |
2152 | // Replace non-terminator uses with the first operand. |
2153 | rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { |
2154 | return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); |
2155 | }); |
2156 | // Replace everything else with the second operand if the operation isn't |
2157 | // dead. |
2158 | rewriter.replaceOp(op, op.getOperand(1)); |
2159 | return success(); |
2160 | } |
2161 | }; |
2162 | |
2163 | struct TestSelectiveReplacementPatternDriver |
2164 | : public PassWrapper<TestSelectiveReplacementPatternDriver, |
2165 | OperationPass<>> { |
2166 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
2167 | TestSelectiveReplacementPatternDriver) |
2168 | |
2169 | StringRef getArgument() const final { |
2170 | return "test-pattern-selective-replacement"; |
2171 | } |
2172 | StringRef getDescription() const final { |
2173 | return "Test selective replacement in the PatternRewriter"; |
2174 | } |
2175 | void runOnOperation() override { |
2176 | MLIRContext *context = &getContext(); |
2177 | mlir::RewritePatternSet patterns(context); |
2178 | patterns.add<TestSelectiveOpReplacementPattern>(arg&: context); |
2179 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
2180 | } |
2181 | }; |
2182 | } // namespace |
2183 | |
2184 | //===----------------------------------------------------------------------===// |
2185 | // PassRegistration |
2186 | //===----------------------------------------------------------------------===// |
2187 | |
2188 | namespace mlir { |
2189 | namespace test { |
2190 | void registerPatternsTestPass() { |
2191 | PassRegistration<TestReturnTypeDriver>(); |
2192 | |
2193 | PassRegistration<TestDerivedAttributeDriver>(); |
2194 | |
2195 | PassRegistration<TestGreedyPatternDriver>(); |
2196 | PassRegistration<TestStrictPatternDriver>(); |
2197 | PassRegistration<TestWalkPatternDriver>(); |
2198 | |
2199 | PassRegistration<TestLegalizePatternDriver>([] { |
2200 | return std::make_unique<TestLegalizePatternDriver>(args&: legalizerConversionMode); |
2201 | }); |
2202 | |
2203 | PassRegistration<TestRemappedValue>(); |
2204 | |
2205 | PassRegistration<TestUnknownRootOpDriver>(); |
2206 | |
2207 | PassRegistration<TestTypeConversionDriver>(); |
2208 | PassRegistration<TestTargetMaterializationWithNoUses>(); |
2209 | |
2210 | PassRegistration<TestRewriteDynamicOpDriver>(); |
2211 | |
2212 | PassRegistration<TestMergeBlocksPatternDriver>(); |
2213 | PassRegistration<TestSelectiveReplacementPatternDriver>(); |
2214 | } |
2215 | } // namespace test |
2216 | } // namespace mlir |
2217 |
Definitions
- chooseOperand
- createOpI
- handleNoResultOp
- getFirstI32Result
- bindNativeCodeCallResult
- bindMultipleNativeCodeCallResult
- opMIncreasingValue
- opMTest
- populateTestReductionPatterns
- FoldingPattern
- FoldingPattern
- matchAndRewrite
- FolderInsertBeforePreviouslyFoldedConstantPattern
- matchAndRewrite
- FolderCommutativeOp2WithConstant
- matchAndRewrite
- IncrementIntAttribute
- matchAndRewrite
- MakeOpEligible
- MakeOpEligible
- matchAndRewrite
- HoistEligibleOps
- matchAndRewrite
- MoveBeforeParentOp
- MoveBeforeParentOp
- matchAndRewrite
- MoveAfterParentOp
- MoveAfterParentOp
- matchAndRewrite
- InlineBlocksIntoParent
- InlineBlocksIntoParent
- matchAndRewrite
- SplitBlockHere
- SplitBlockHere
- matchAndRewrite
- CloneOp
- CloneOp
- matchAndRewrite
- CloneRegionBeforeOp
- CloneRegionBeforeOp
- matchAndRewrite
- ReplaceWithNewOp
- ReplaceWithNewOp
- matchAndRewrite
- EraseFirstBlock
- EraseFirstBlock
- matchAndRewrite
- TestGreedyPatternDriver
- TestGreedyPatternDriver
- TestGreedyPatternDriver
- getArgument
- getDescription
- runOnOperation
- DumpNotifications
- notifyBlockInserted
- notifyOperationInserted
- notifyBlockErased
- notifyOperationErased
- notifyOperationModified
- notifyOperationReplaced
- TestStrictPatternDriver
- TestStrictPatternDriver
- TestStrictPatternDriver
- getArgument
- getDescription
- runOnOperation
- InsertSameOp
- InsertSameOp
- matchAndRewrite
- EraseOp
- EraseOp
- matchAndRewrite
- ChangeBlockOp
- ChangeBlockOp
- matchAndRewrite
- ImplicitChangeOp
- ImplicitChangeOp
- matchAndRewrite
- TestWalkPatternDriver
- TestWalkPatternDriver
- TestWalkPatternDriver
- getArgument
- getDescription
- runOnOperation
- invokeCreateWithInferredReturnType
- reifyReturnShape
- TestReturnTypeDriver
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- TestDerivedAttributeDriver
- getArgument
- getDescription
- runOnOperation
- TestDetachedSignatureConversion
- TestDetachedSignatureConversion
- matchAndRewrite
- TestRegionRewriteBlockMovement
- TestRegionRewriteBlockMovement
- matchAndRewrite
- TestRegionRewriteUndo
- TestRegionRewriteUndo
- matchAndRewrite
- TestCreateBlock
- TestCreateBlock
- matchAndRewrite
- TestCreateIllegalBlock
- TestCreateIllegalBlock
- matchAndRewrite
- TestUndoBlockArgReplace
- TestUndoBlockArgReplace
- matchAndRewrite
- TestUndoMoveOpBefore
- TestUndoMoveOpBefore
- matchAndRewrite
- TestUndoBlockErase
- TestUndoBlockErase
- matchAndRewrite
- TestUndoPropertiesModification
- TestUndoPropertiesModification
- matchAndRewrite
- TestDropOpSignatureConversion
- TestDropOpSignatureConversion
- matchAndRewrite
- TestPassthroughInvalidOp
- TestPassthroughInvalidOp
- matchAndRewrite
- TestDropAndReplaceInvalidOp
- TestDropAndReplaceInvalidOp
- matchAndRewrite
- TestSplitReturnType
- TestSplitReturnType
- matchAndRewrite
- TestChangeProducerTypeI32ToF32
- TestChangeProducerTypeI32ToF32
- matchAndRewrite
- TestChangeProducerTypeF32ToF64
- TestChangeProducerTypeF32ToF64
- matchAndRewrite
- TestChangeProducerTypeF32ToInvalid
- TestChangeProducerTypeF32ToInvalid
- matchAndRewrite
- TestUpdateConsumerType
- TestUpdateConsumerType
- matchAndRewrite
- TestNonRootReplacement
- TestNonRootReplacement
- matchAndRewrite
- TestBoundedRecursiveRewrite
- initialize
- matchAndRewrite
- TestNestedOpCreationUndoRewrite
- matchAndRewrite
- TestReplaceEraseOp
- matchAndRewrite
- TestCreateUnregisteredOp
- matchAndRewrite
- TestEraseOp
- TestEraseOp
- matchAndRewrite
- TestConvertBlockArgs
- matchAndRewrite
- TestRepetitive1ToNConsumer
- TestRepetitive1ToNConsumer
- matchAndRewrite
- TestMultiple1ToNReplacement
- TestMultiple1ToNReplacement
- matchAndRewrite
- testReplaceOpWithMultipleOverloads
- TestTypeConverter
- TestTypeConverter
- convertType
- materializeCast
- TestLegalizePatternDriver
- getArgument
- getDescription
- ConversionMode
- TestLegalizePatternDriver
- getDependentDialects
- runOnOperation
- legalizerConversionMode
- TestRemapValueTypeConverter
- TestRemapValueTypeConverter
- OneVResOneVOperandOp1Converter
- matchAndRewrite
- TestRemapValueInRegion
- matchAndRewrite
- TestRemappedValue
- getArgument
- getDescription
- runOnOperation
- RemoveTestDialectOps
- RemoveTestDialectOps
- matchAndRewrite
- TestUnknownRootOpDriver
- getArgument
- getDescription
- runOnOperation
- RewriteDynamicOp
- RewriteDynamicOp
- matchAndRewrite
- TestRewriteDynamicOpDriver
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- TestTypeConversionProducer
- matchAndRewrite
- TestSignatureConversionUndo
- matchAndRewrite
- TestTestSignatureConversionNoConverter
- TestTestSignatureConversionNoConverter
- matchAndRewrite
- TestTypeConsumerForward
- matchAndRewrite
- TestTypeConversionAnotherProducer
- matchAndRewrite
- TestReplaceWithLegalOp
- TestReplaceWithLegalOp
- matchAndRewrite
- TestTypeConversionDriver
- getDependentDialects
- getArgument
- getDescription
- runOnOperation
- ForwardOperandPattern
- matchAndRewrite
- TestTargetMaterializationWithNoUses
- getArgument
- getDescription
- runOnOperation
- TestMergeBlock
- matchAndRewrite
- TestUndoBlocksMerge
- TestUndoBlocksMerge
- matchAndRewrite
- TestMergeSingleBlockOps
- matchAndRewrite
- TestMergeBlocksPatternDriver
- getArgument
- getDescription
- runOnOperation
- TestSelectiveOpReplacementPattern
- matchAndRewrite
- TestSelectiveReplacementPatternDriver
- getArgument
- getDescription
- runOnOperation
Learn to use CMake with our Intro Training
Find out more