1//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/Interfaces/ControlFlowInterfaces.h"
13#include "llvm/ADT/SmallPtrSet.h"
14
15using namespace mlir;
16
17//===----------------------------------------------------------------------===//
18// ControlFlowInterfaces
19//===----------------------------------------------------------------------===//
20
21#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
22
23SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
25}
26
27SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28 MutableOperandRange forwardedOperands)
29 : producedOperandCount(producedOperandCount),
30 forwardedOperands(std::move(forwardedOperands)) {}
31
32//===----------------------------------------------------------------------===//
33// BranchOpInterface
34//===----------------------------------------------------------------------===//
35
36/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37/// successor if 'operandIndex' is within the range of 'operands', or
38/// std::nullopt if `operandIndex` isn't a successor operand index.
39std::optional<BlockArgument>
40detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41 unsigned operandIndex, Block *successor) {
42 OperandRange forwardedOperands = operands.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands.empty())
45 return std::nullopt;
46
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49 if (operandIndex < operandsStart ||
50 operandIndex >= (operandsStart + forwardedOperands.size()))
51 return std::nullopt;
52
53 // Index the successor.
54 unsigned argIndex =
55 operands.getProducedOperandCount() + operandIndex - operandsStart;
56 return successor->getArgument(i: argIndex);
57}
58
59/// Verify that the given operands match those of the given successor block.
60LogicalResult
61detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62 const SuccessorOperands &operands) {
63 // Check the count.
64 unsigned operandCount = operands.size();
65 Block *destBB = op->getSuccessor(index: succNo);
66 if (operandCount != destBB->getNumArguments())
67 return op->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB->getNumArguments();
71
72 // Check the types.
73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74 ++i) {
75 if (!cast<BranchOpInterface>(op).areTypesCompatible(
76 operands[i].getType(), destBB->getArgument(i).getType()))
77 return op->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo;
79 }
80 return success();
81}
82
83//===----------------------------------------------------------------------===//
84// RegionBranchOpInterface
85//===----------------------------------------------------------------------===//
86
87static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
88 RegionBranchPoint sourceNo,
89 RegionBranchPoint succRegionNo) {
90 diag << "from ";
91 if (Region *region = sourceNo.getRegionOrNull())
92 diag << "Region #" << region->getRegionNumber();
93 else
94 diag << "parent operands";
95
96 diag << " to ";
97 if (Region *region = succRegionNo.getRegionOrNull())
98 diag << "Region #" << region->getRegionNumber();
99 else
100 diag << "parent results";
101 return diag;
102}
103
104/// Verify that types match along all region control flow edges originating from
105/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106/// types of the inputs that flow to a successor region.
107static LogicalResult
108verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
109 function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
110 getInputsTypesForRegion) {
111 auto regionInterface = cast<RegionBranchOpInterface>(op);
112
113 SmallVector<RegionSuccessor, 2> successors;
114 regionInterface.getSuccessorRegions(sourcePoint, successors);
115
116 for (RegionSuccessor &succ : successors) {
117 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118 if (failed(result: sourceTypes))
119 return failure();
120
121 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122 if (sourceTypes->size() != succInputsTypes.size()) {
123 InFlightDiagnostic diag = op->emitOpError(message: " region control flow edge ");
124 return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ)
125 << ": source has " << sourceTypes->size()
126 << " operands, but target successor needs "
127 << succInputsTypes.size();
128 }
129
130 for (const auto &typesIdx :
131 llvm::enumerate(First: llvm::zip(t&: *sourceTypes, u&: succInputsTypes))) {
132 Type sourceType = std::get<0>(t&: typesIdx.value());
133 Type inputType = std::get<1>(t&: typesIdx.value());
134 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135 InFlightDiagnostic diag = op->emitOpError(message: " along control flow edge ");
136 return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ)
137 << ": source type #" << typesIdx.index() << " " << sourceType
138 << " should match input type #" << typesIdx.index() << " "
139 << inputType;
140 }
141 }
142 }
143 return success();
144}
145
146/// Verify that types match along control flow edges described the given op.
147LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
148 auto regionInterface = cast<RegionBranchOpInterface>(op);
149
150 auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151 return regionInterface.getEntrySuccessorOperands(point).getTypes();
152 };
153
154 // Verify types along control flow edges originating from the parent.
155 if (failed(result: verifyTypesAlongAllEdges(op, sourcePoint: RegionBranchPoint::parent(),
156 getInputsTypesForRegion: inputTypesFromParent)))
157 return failure();
158
159 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160 if (lhs.size() != rhs.size())
161 return false;
162 for (auto types : llvm::zip(t&: lhs, u&: rhs)) {
163 if (!regionInterface.areTypesCompatible(std::get<0>(t&: types),
164 std::get<1>(t&: types))) {
165 return false;
166 }
167 }
168 return true;
169 };
170
171 // Verify types along control flow edges originating from each region.
172 for (Region &region : op->getRegions()) {
173
174 // Since there can be multiple terminators implementing the
175 // `RegionBranchTerminatorOpInterface`, all should have the same operand
176 // types when passing them to the same region.
177
178 SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
179 for (Block &block : region)
180 if (!block.empty())
181 if (auto terminator =
182 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
183 regionReturnOps.push_back(terminator);
184
185 // If there is no return-like terminator, the op itself should verify
186 // type consistency.
187 if (regionReturnOps.empty())
188 continue;
189
190 auto inputTypesForRegion =
191 [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
192 std::optional<OperandRange> regionReturnOperands;
193 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
194 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
195
196 if (!regionReturnOperands) {
197 regionReturnOperands = terminatorOperands;
198 continue;
199 }
200
201 // Found more than one ReturnLike terminator. Make sure the operand
202 // types match with the first one.
203 if (!areTypesCompatible(regionReturnOperands->getTypes(),
204 terminatorOperands.getTypes())) {
205 InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
206 return printRegionEdgeName(diag, region, point)
207 << " operands mismatch between return-like terminators";
208 }
209 }
210
211 // All successors get the same set of operand types.
212 return TypeRange(regionReturnOperands->getTypes());
213 };
214
215 if (failed(result: verifyTypesAlongAllEdges(op, sourcePoint: region, getInputsTypesForRegion: inputTypesForRegion)))
216 return failure();
217 }
218
219 return success();
220}
221
222/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223/// this function returns "true" for a successor region. The first parameter is
224/// the successor region. The second parameter indicates all already visited
225/// regions.
226using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
227
228/// Traverse the region graph starting at `begin`. The traversal is interrupted
229/// if `stopCondition` evaluates to "true" for a successor region. In that case,
230/// this function returns "true". Otherwise, if the traversal was not
231/// interrupted, this function returns "false".
232static bool traverseRegionGraph(Region *begin,
233 StopConditionFn stopConditionFn) {
234 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
235 SmallVector<bool> visited(op->getNumRegions(), false);
236 visited[begin->getRegionNumber()] = true;
237
238 // Retrieve all successors of the region and enqueue them in the worklist.
239 SmallVector<Region *> worklist;
240 auto enqueueAllSuccessors = [&](Region *region) {
241 SmallVector<RegionSuccessor> successors;
242 op.getSuccessorRegions(region, successors);
243 for (RegionSuccessor successor : successors)
244 if (!successor.isParent())
245 worklist.push_back(Elt: successor.getSuccessor());
246 };
247 enqueueAllSuccessors(begin);
248
249 // Process all regions in the worklist via DFS.
250 while (!worklist.empty()) {
251 Region *nextRegion = worklist.pop_back_val();
252 if (stopConditionFn(nextRegion, visited))
253 return true;
254 if (visited[nextRegion->getRegionNumber()])
255 continue;
256 visited[nextRegion->getRegionNumber()] = true;
257 enqueueAllSuccessors(nextRegion);
258 }
259
260 return false;
261}
262
263/// Return `true` if region `r` is reachable from region `begin` according to
264/// the RegionBranchOpInterface (by taking a branch).
265static bool isRegionReachable(Region *begin, Region *r) {
266 assert(begin->getParentOp() == r->getParentOp() &&
267 "expected that both regions belong to the same op");
268 return traverseRegionGraph(begin,
269 stopConditionFn: [&](Region *nextRegion, ArrayRef<bool> visited) {
270 // Interrupt traversal if `r` was reached.
271 return nextRegion == r;
272 });
273}
274
275/// Return `true` if `a` and `b` are in mutually exclusive regions.
276///
277/// 1. Find the first common of `a` and `b` (ancestor) that implements
278/// RegionBranchOpInterface.
279/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
280/// contained.
281/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
282/// mutually exclusive if they are not reachable from each other as per
283/// RegionBranchOpInterface::getSuccessorRegions.
284bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
285 assert(a && "expected non-empty operation");
286 assert(b && "expected non-empty operation");
287
288 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
289 while (branchOp) {
290 // Check if b is inside branchOp. (We already know that a is.)
291 if (!branchOp->isProperAncestor(b)) {
292 // Check next enclosing RegionBranchOpInterface.
293 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
294 continue;
295 }
296
297 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
298 // are contained.
299 Region *regionA = nullptr, *regionB = nullptr;
300 for (Region &r : branchOp->getRegions()) {
301 if (r.findAncestorOpInRegion(*a)) {
302 assert(!regionA && "already found a region for a");
303 regionA = &r;
304 }
305 if (r.findAncestorOpInRegion(*b)) {
306 assert(!regionB && "already found a region for b");
307 regionB = &r;
308 }
309 }
310 assert(regionA && regionB && "could not find region of op");
311
312 // `a` and `b` are in mutually exclusive regions if both regions are
313 // distinct and neither region is reachable from the other region.
314 return regionA != regionB && !isRegionReachable(begin: regionA, r: regionB) &&
315 !isRegionReachable(begin: regionB, r: regionA);
316 }
317
318 // Could not find a common RegionBranchOpInterface among a's and b's
319 // ancestors.
320 return false;
321}
322
323bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
324 Region *region = &getOperation()->getRegion(index);
325 return isRegionReachable(region, region);
326}
327
328bool RegionBranchOpInterface::hasLoop() {
329 SmallVector<RegionSuccessor> entryRegions;
330 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
331 for (RegionSuccessor successor : entryRegions)
332 if (!successor.isParent() &&
333 traverseRegionGraph(successor.getSuccessor(),
334 [](Region *nextRegion, ArrayRef<bool> visited) {
335 // Interrupt traversal if the region was already
336 // visited.
337 return visited[nextRegion->getRegionNumber()];
338 }))
339 return true;
340 return false;
341}
342
343Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
344 while (Region *region = op->getParentRegion()) {
345 op = region->getParentOp();
346 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
347 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
348 return region;
349 }
350 return nullptr;
351}
352
353Region *mlir::getEnclosingRepetitiveRegion(Value value) {
354 Region *region = value.getParentRegion();
355 while (region) {
356 Operation *op = region->getParentOp();
357 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
358 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
359 return region;
360 region = op->getParentRegion();
361 }
362 return nullptr;
363}
364

source code of mlir/lib/Interfaces/ControlFlowInterfaces.cpp