1//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/PatternMatch.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15#include "mlir/Transforms/RegionUtils.h"
16
17using namespace mlir;
18
19/// Helper function to call the `makeRegionIsolatedFromAbove` to convert
20/// `test.one_region_op` to `test.isolated_one_region_op`.
21static LogicalResult
22makeIsolatedFromAboveImpl(RewriterBase &rewriter,
23 test::OneRegionWithOperandsOp regionOp,
24 llvm::function_ref<bool(Operation *)> callBack) {
25 Region &region = regionOp.getRegion();
26 SmallVector<Value> capturedValues =
27 makeRegionIsolatedFromAbove(rewriter, region, cloneOperationIntoRegion: callBack);
28 SmallVector<Value> operands = regionOp.getOperands();
29 operands.append(RHS: capturedValues);
30 auto isolatedRegionOp =
31 rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
32 rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
33 isolatedRegionOp.getRegion().begin());
34 rewriter.eraseOp(op: regionOp);
35 return success();
36}
37
38namespace {
39
40/// Simple test for making region isolated from above without cloning any
41/// operations.
42struct SimpleMakeIsolatedFromAbove
43 : OpRewritePattern<test::OneRegionWithOperandsOp> {
44 using OpRewritePattern::OpRewritePattern;
45
46 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
47 PatternRewriter &rewriter) const override {
48 return makeIsolatedFromAboveImpl(rewriter, regionOp,
49 [](Operation *) { return false; });
50 }
51};
52
53/// Test for making region isolated from above while clong operations
54/// with no operands.
55struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
56 : OpRewritePattern<test::OneRegionWithOperandsOp> {
57 using OpRewritePattern::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
60 PatternRewriter &rewriter) const override {
61 return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
62 return op->getNumOperands() == 0;
63 });
64 }
65};
66
67/// Test for making region isolated from above while clong operations
68/// with no operands.
69struct MakeIsolatedFromAboveAndCloneOpsWithOperands
70 : OpRewritePattern<test::OneRegionWithOperandsOp> {
71 using OpRewritePattern::OpRewritePattern;
72
73 LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
74 PatternRewriter &rewriter) const override {
75 return makeIsolatedFromAboveImpl(rewriter, regionOp,
76 [](Operation *op) { return true; });
77 }
78};
79
80/// Test pass for testing the `makeIsolatedFromAbove` function.
81struct TestMakeIsolatedFromAbovePass
82 : public PassWrapper<TestMakeIsolatedFromAbovePass,
83 OperationPass<func::FuncOp>> {
84
85 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)
86
87 TestMakeIsolatedFromAbovePass() = default;
88 TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
89 : PassWrapper(pass) {}
90
91 StringRef getArgument() const final {
92 return "test-make-isolated-from-above";
93 }
94
95 StringRef getDescription() const final {
96 return "Test making a region isolated from above";
97 }
98
99 Option<bool> simple{
100 *this, "simple",
101 llvm::cl::desc("Test simple case with no cloning of operations"),
102 llvm::cl::init(false)};
103
104 Option<bool> cloneOpsWithNoOperands{
105 *this, "clone-ops-with-no-operands",
106 llvm::cl::desc("Test case with cloning of operations with no operands"),
107 llvm::cl::init(false)};
108
109 Option<bool> cloneOpsWithOperands{
110 *this, "clone-ops-with-operands",
111 llvm::cl::desc("Test case with cloning of operations with no operands"),
112 llvm::cl::init(false)};
113
114 void runOnOperation() override;
115};
116
117} // namespace
118
119void TestMakeIsolatedFromAbovePass::runOnOperation() {
120 MLIRContext *context = &getContext();
121 func::FuncOp funcOp = getOperation();
122
123 if (simple) {
124 RewritePatternSet patterns(context);
125 patterns.insert<SimpleMakeIsolatedFromAbove>(arg&: context);
126 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
127 return signalPassFailure();
128 }
129 return;
130 }
131
132 if (cloneOpsWithNoOperands) {
133 RewritePatternSet patterns(context);
134 patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(arg&: context);
135 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
136 return signalPassFailure();
137 }
138 return;
139 }
140
141 if (cloneOpsWithOperands) {
142 RewritePatternSet patterns(context);
143 patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(arg&: context);
144 if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
145 return signalPassFailure();
146 }
147 return;
148 }
149}
150
151namespace mlir {
152namespace test {
153void registerTestMakeIsolatedFromAbovePass() {
154 PassRegistration<TestMakeIsolatedFromAbovePass>();
155}
156} // namespace test
157} // namespace mlir
158

source code of mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp