1//===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/ControlFlowInterfaces.h"
10#include "mlir/IR/BuiltinOps.h"
11#include "mlir/IR/Dialect.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/OpDefinition.h"
14#include "mlir/IR/OpImplementation.h"
15#include "mlir/Parser/Parser.h"
16
17#include <gtest/gtest.h>
18
19using namespace mlir;
20
21/// A dummy op that is also a terminator.
22struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
23 using Op::Op;
24 static ArrayRef<StringRef> getAttributeNames() { return {}; }
25
26 static StringRef getOperationName() { return "cftest.dummy_op"; }
27};
28
29/// All regions of this op are mutually exclusive.
30struct MutuallyExclusiveRegionsOp
31 : public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
32 using Op::Op;
33 static ArrayRef<StringRef> getAttributeNames() { return {}; }
34
35 static StringRef getOperationName() {
36 return "cftest.mutually_exclusive_regions_op";
37 }
38
39 // Regions have no successors.
40 void getSuccessorRegions(RegionBranchPoint point,
41 SmallVectorImpl<RegionSuccessor> &regions) {}
42};
43
44/// All regions of this op call each other in a large circle.
45struct LoopRegionsOp
46 : public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
47 using Op::Op;
48 static const unsigned kNumRegions = 3;
49
50 static ArrayRef<StringRef> getAttributeNames() { return {}; }
51
52 static StringRef getOperationName() { return "cftest.loop_regions_op"; }
53
54 void getSuccessorRegions(RegionBranchPoint point,
55 SmallVectorImpl<RegionSuccessor> &regions) {
56 if (Region *region = point.getRegionOrNull()) {
57 if (point == (*this)->getRegion(1))
58 // This region also branches back to the parent.
59 regions.push_back(Elt: RegionSuccessor());
60 regions.push_back(Elt: RegionSuccessor(region));
61 }
62 }
63};
64
65/// Each region branches back it itself or the parent.
66struct DoubleLoopRegionsOp
67 : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
68 using Op::Op;
69
70 static ArrayRef<StringRef> getAttributeNames() { return {}; }
71
72 static StringRef getOperationName() {
73 return "cftest.double_loop_regions_op";
74 }
75
76 void getSuccessorRegions(RegionBranchPoint point,
77 SmallVectorImpl<RegionSuccessor> &regions) {
78 if (Region *region = point.getRegionOrNull()) {
79 regions.push_back(Elt: RegionSuccessor());
80 regions.push_back(Elt: RegionSuccessor(region));
81 }
82 }
83};
84
85/// Regions are executed sequentially.
86struct SequentialRegionsOp
87 : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
88 using Op::Op;
89 static ArrayRef<StringRef> getAttributeNames() { return {}; }
90
91 static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
92
93 // Region 0 has Region 1 as a successor.
94 void getSuccessorRegions(RegionBranchPoint point,
95 SmallVectorImpl<RegionSuccessor> &regions) {
96 if (point == (*this)->getRegion(0)) {
97 Operation *thisOp = this->getOperation();
98 regions.push_back(Elt: RegionSuccessor(&thisOp->getRegion(index: 1)));
99 }
100 }
101};
102
103/// A dialect putting all the above together.
104struct CFTestDialect : Dialect {
105 explicit CFTestDialect(MLIRContext *ctx)
106 : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
107 addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
108 DoubleLoopRegionsOp, SequentialRegionsOp>();
109 }
110 static StringRef getDialectNamespace() { return "cftest"; }
111};
112
113TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
114 const char *ir = R"MLIR(
115"cftest.mutually_exclusive_regions_op"() (
116 {"cftest.dummy_op"() : () -> ()}, // op1
117 {"cftest.dummy_op"() : () -> ()} // op2
118 ) : () -> ()
119 )MLIR";
120
121 DialectRegistry registry;
122 registry.insert<CFTestDialect>();
123 MLIRContext ctx(registry);
124
125 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
126 Operation *testOp = &module->getBody()->getOperations().front();
127 Operation *op1 = &testOp->getRegion(index: 0).front().front();
128 Operation *op2 = &testOp->getRegion(index: 1).front().front();
129
130 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
131 EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
132}
133
134TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
135 const char *ir = R"MLIR(
136"cftest.double_loop_regions_op"() (
137 {"cftest.dummy_op"() : () -> ()}, // op1
138 {"cftest.dummy_op"() : () -> ()} // op2
139 ) : () -> ()
140 )MLIR";
141
142 DialectRegistry registry;
143 registry.insert<CFTestDialect>();
144 MLIRContext ctx(registry);
145
146 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: ir, config: &ctx);
147 Operation *testOp = &module->getBody()->getOperations().front();
148 Operation *op1 = &testOp->getRegion(index: 0).front().front();
149 Operation *op2 = &testOp->getRegion(index: 1).front().front();
150
151 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
152 EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
153}
154
155TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
156 const char *ir = R"MLIR(
157"cftest.sequential_regions_op"() (
158 {"cftest.dummy_op"() : () -> ()}, // op1
159 {"cftest.dummy_op"() : () -> ()} // op2
160 ) : () -> ()
161 )MLIR";
162
163 DialectRegistry registry;
164 registry.insert<CFTestDialect>();
165 MLIRContext ctx(registry);
166
167 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: ir, config: &ctx);
168 Operation *testOp = &module->getBody()->getOperations().front();
169 Operation *op1 = &testOp->getRegion(index: 0).front().front();
170 Operation *op2 = &testOp->getRegion(index: 1).front().front();
171
172 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
173 EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
174}
175
176TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
177 const char *ir = R"MLIR(
178"cftest.mutually_exclusive_regions_op"() (
179 {
180 "cftest.sequential_regions_op"() (
181 {"cftest.dummy_op"() : () -> ()}, // op1
182 {"cftest.dummy_op"() : () -> ()} // op3
183 ) : () -> ()
184 "cftest.dummy_op"() : () -> ()
185 },
186 {"cftest.dummy_op"() : () -> ()} // op2
187 ) : () -> ()
188 )MLIR";
189
190 DialectRegistry registry;
191 registry.insert<CFTestDialect>();
192 MLIRContext ctx(registry);
193
194 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: ir, config: &ctx);
195 Operation *testOp = &module->getBody()->getOperations().front();
196 Operation *op1 =
197 &testOp->getRegion(index: 0).front().front().getRegion(index: 0).front().front();
198 Operation *op2 = &testOp->getRegion(index: 1).front().front();
199 Operation *op3 =
200 &testOp->getRegion(index: 0).front().front().getRegion(index: 1).front().front();
201
202 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
203 EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
204 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
205}
206
207TEST(RegionBranchOpInterface, RecursiveRegions) {
208 const char *ir = R"MLIR(
209"cftest.loop_regions_op"() (
210 {"cftest.dummy_op"() : () -> ()}, // op1
211 {"cftest.dummy_op"() : () -> ()}, // op2
212 {"cftest.dummy_op"() : () -> ()} // op3
213 ) : () -> ()
214 )MLIR";
215
216 DialectRegistry registry;
217 registry.insert<CFTestDialect>();
218 MLIRContext ctx(registry);
219
220 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: ir, config: &ctx);
221 Operation *testOp = &module->getBody()->getOperations().front();
222 auto regionOp = cast<RegionBranchOpInterface>(testOp);
223 Operation *op1 = &testOp->getRegion(index: 0).front().front();
224 Operation *op2 = &testOp->getRegion(index: 1).front().front();
225 Operation *op3 = &testOp->getRegion(index: 2).front().front();
226
227 EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
228 EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
229 EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
230 EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
231 EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
232 EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
233}
234
235TEST(RegionBranchOpInterface, NotRecursiveRegions) {
236 const char *ir = R"MLIR(
237"cftest.sequential_regions_op"() (
238 {"cftest.dummy_op"() : () -> ()}, // op1
239 {"cftest.dummy_op"() : () -> ()} // op2
240 ) : () -> ()
241 )MLIR";
242
243 DialectRegistry registry;
244 registry.insert<CFTestDialect>();
245 MLIRContext ctx(registry);
246
247 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(sourceStr: ir, config: &ctx);
248 Operation *testOp = &module->getBody()->getOperations().front();
249 Operation *op1 = &testOp->getRegion(index: 0).front().front();
250 Operation *op2 = &testOp->getRegion(index: 1).front().front();
251
252 EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
253 EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
254}
255

source code of mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp