1 | //===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/IR/PatternMatch.h" |
13 | #include "mlir/Pass/Pass.h" |
14 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
15 | #include "mlir/Transforms/RegionUtils.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | /// Helper function to call the `makeRegionIsolatedFromAbove` to convert |
20 | /// `test.one_region_op` to `test.isolated_one_region_op`. |
21 | static LogicalResult |
22 | makeIsolatedFromAboveImpl(RewriterBase &rewriter, |
23 | test::OneRegionWithOperandsOp regionOp, |
24 | llvm::function_ref<bool(Operation *)> callBack) { |
25 | Region ®ion = regionOp.getRegion(); |
26 | SmallVector<Value> capturedValues = |
27 | makeRegionIsolatedFromAbove(rewriter, region, cloneOperationIntoRegion: callBack); |
28 | SmallVector<Value> operands = regionOp.getOperands(); |
29 | operands.append(RHS: capturedValues); |
30 | auto isolatedRegionOp = |
31 | rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands); |
32 | rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(), |
33 | isolatedRegionOp.getRegion().begin()); |
34 | rewriter.eraseOp(op: regionOp); |
35 | return success(); |
36 | } |
37 | |
38 | namespace { |
39 | |
40 | /// Simple test for making region isolated from above without cloning any |
41 | /// operations. |
42 | struct SimpleMakeIsolatedFromAbove |
43 | : OpRewritePattern<test::OneRegionWithOperandsOp> { |
44 | using OpRewritePattern::OpRewritePattern; |
45 | |
46 | LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, |
47 | PatternRewriter &rewriter) const override { |
48 | return makeIsolatedFromAboveImpl(rewriter, regionOp, |
49 | [](Operation *) { return false; }); |
50 | } |
51 | }; |
52 | |
53 | /// Test for making region isolated from above while clong operations |
54 | /// with no operands. |
55 | struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands |
56 | : OpRewritePattern<test::OneRegionWithOperandsOp> { |
57 | using OpRewritePattern::OpRewritePattern; |
58 | |
59 | LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, |
60 | PatternRewriter &rewriter) const override { |
61 | return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) { |
62 | return op->getNumOperands() == 0; |
63 | }); |
64 | } |
65 | }; |
66 | |
67 | /// Test for making region isolated from above while clong operations |
68 | /// with no operands. |
69 | struct MakeIsolatedFromAboveAndCloneOpsWithOperands |
70 | : OpRewritePattern<test::OneRegionWithOperandsOp> { |
71 | using OpRewritePattern::OpRewritePattern; |
72 | |
73 | LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, |
74 | PatternRewriter &rewriter) const override { |
75 | return makeIsolatedFromAboveImpl(rewriter, regionOp, |
76 | [](Operation *op) { return true; }); |
77 | } |
78 | }; |
79 | |
80 | /// Test pass for testing the `makeIsolatedFromAbove` function. |
81 | struct TestMakeIsolatedFromAbovePass |
82 | : public PassWrapper<TestMakeIsolatedFromAbovePass, |
83 | OperationPass<func::FuncOp>> { |
84 | |
85 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass) |
86 | |
87 | TestMakeIsolatedFromAbovePass() = default; |
88 | TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass) |
89 | : PassWrapper(pass) {} |
90 | |
91 | StringRef getArgument() const final { |
92 | return "test-make-isolated-from-above" ; |
93 | } |
94 | |
95 | StringRef getDescription() const final { |
96 | return "Test making a region isolated from above" ; |
97 | } |
98 | |
99 | Option<bool> simple{ |
100 | *this, "simple" , |
101 | llvm::cl::desc("Test simple case with no cloning of operations" ), |
102 | llvm::cl::init(false)}; |
103 | |
104 | Option<bool> cloneOpsWithNoOperands{ |
105 | *this, "clone-ops-with-no-operands" , |
106 | llvm::cl::desc("Test case with cloning of operations with no operands" ), |
107 | llvm::cl::init(false)}; |
108 | |
109 | Option<bool> cloneOpsWithOperands{ |
110 | *this, "clone-ops-with-operands" , |
111 | llvm::cl::desc("Test case with cloning of operations with no operands" ), |
112 | llvm::cl::init(false)}; |
113 | |
114 | void runOnOperation() override; |
115 | }; |
116 | |
117 | } // namespace |
118 | |
119 | void TestMakeIsolatedFromAbovePass::runOnOperation() { |
120 | MLIRContext *context = &getContext(); |
121 | func::FuncOp funcOp = getOperation(); |
122 | |
123 | if (simple) { |
124 | RewritePatternSet patterns(context); |
125 | patterns.insert<SimpleMakeIsolatedFromAbove>(arg&: context); |
126 | if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { |
127 | return signalPassFailure(); |
128 | } |
129 | return; |
130 | } |
131 | |
132 | if (cloneOpsWithNoOperands) { |
133 | RewritePatternSet patterns(context); |
134 | patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(arg&: context); |
135 | if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { |
136 | return signalPassFailure(); |
137 | } |
138 | return; |
139 | } |
140 | |
141 | if (cloneOpsWithOperands) { |
142 | RewritePatternSet patterns(context); |
143 | patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(arg&: context); |
144 | if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { |
145 | return signalPassFailure(); |
146 | } |
147 | return; |
148 | } |
149 | } |
150 | |
151 | namespace mlir { |
152 | namespace test { |
153 | void registerTestMakeIsolatedFromAbovePass() { |
154 | PassRegistration<TestMakeIsolatedFromAbovePass>(); |
155 | } |
156 | } // namespace test |
157 | } // namespace mlir |
158 | |