1 | //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
13 | #include "llvm/ADT/SmallPtrSet.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | // ControlFlowInterfaces |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" |
22 | |
23 | SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) |
24 | : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { |
25 | } |
26 | |
27 | SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, |
28 | MutableOperandRange forwardedOperands) |
29 | : producedOperandCount(producedOperandCount), |
30 | forwardedOperands(std::move(forwardedOperands)) {} |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // BranchOpInterface |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some |
37 | /// successor if 'operandIndex' is within the range of 'operands', or |
38 | /// std::nullopt if `operandIndex` isn't a successor operand index. |
39 | std::optional<BlockArgument> |
40 | detail::getBranchSuccessorArgument(const SuccessorOperands &operands, |
41 | unsigned operandIndex, Block *successor) { |
42 | OperandRange forwardedOperands = operands.getForwardedOperands(); |
43 | // Check that the operands are valid. |
44 | if (forwardedOperands.empty()) |
45 | return std::nullopt; |
46 | |
47 | // Check to ensure that this operand is within the range. |
48 | unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); |
49 | if (operandIndex < operandsStart || |
50 | operandIndex >= (operandsStart + forwardedOperands.size())) |
51 | return std::nullopt; |
52 | |
53 | // Index the successor. |
54 | unsigned argIndex = |
55 | operands.getProducedOperandCount() + operandIndex - operandsStart; |
56 | return successor->getArgument(i: argIndex); |
57 | } |
58 | |
59 | /// Verify that the given operands match those of the given successor block. |
60 | LogicalResult |
61 | detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, |
62 | const SuccessorOperands &operands) { |
63 | // Check the count. |
64 | unsigned operandCount = operands.size(); |
65 | Block *destBB = op->getSuccessor(index: succNo); |
66 | if (operandCount != destBB->getNumArguments()) |
67 | return op->emitError() << "branch has " << operandCount |
68 | << " operands for successor #" << succNo |
69 | << ", but target block has " |
70 | << destBB->getNumArguments(); |
71 | |
72 | // Check the types. |
73 | for (unsigned i = operands.getProducedOperandCount(); i != operandCount; |
74 | ++i) { |
75 | if (!cast<BranchOpInterface>(op).areTypesCompatible( |
76 | operands[i].getType(), destBB->getArgument(i).getType())) |
77 | return op->emitError() << "type mismatch for bb argument #" << i |
78 | << " of successor #" << succNo; |
79 | } |
80 | return success(); |
81 | } |
82 | |
83 | //===----------------------------------------------------------------------===// |
84 | // RegionBranchOpInterface |
85 | //===----------------------------------------------------------------------===// |
86 | |
87 | static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, |
88 | RegionBranchPoint sourceNo, |
89 | RegionBranchPoint succRegionNo) { |
90 | diag << "from " ; |
91 | if (Region *region = sourceNo.getRegionOrNull()) |
92 | diag << "Region #" << region->getRegionNumber(); |
93 | else |
94 | diag << "parent operands" ; |
95 | |
96 | diag << " to " ; |
97 | if (Region *region = succRegionNo.getRegionOrNull()) |
98 | diag << "Region #" << region->getRegionNumber(); |
99 | else |
100 | diag << "parent results" ; |
101 | return diag; |
102 | } |
103 | |
104 | /// Verify that types match along all region control flow edges originating from |
105 | /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the |
106 | /// types of the inputs that flow to a successor region. |
107 | static LogicalResult |
108 | verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, |
109 | function_ref<FailureOr<TypeRange>(RegionBranchPoint)> |
110 | getInputsTypesForRegion) { |
111 | auto regionInterface = cast<RegionBranchOpInterface>(op); |
112 | |
113 | SmallVector<RegionSuccessor, 2> successors; |
114 | regionInterface.getSuccessorRegions(sourcePoint, successors); |
115 | |
116 | for (RegionSuccessor &succ : successors) { |
117 | FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ); |
118 | if (failed(result: sourceTypes)) |
119 | return failure(); |
120 | |
121 | TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); |
122 | if (sourceTypes->size() != succInputsTypes.size()) { |
123 | InFlightDiagnostic diag = op->emitOpError(message: " region control flow edge " ); |
124 | return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ) |
125 | << ": source has " << sourceTypes->size() |
126 | << " operands, but target successor needs " |
127 | << succInputsTypes.size(); |
128 | } |
129 | |
130 | for (const auto &typesIdx : |
131 | llvm::enumerate(First: llvm::zip(t&: *sourceTypes, u&: succInputsTypes))) { |
132 | Type sourceType = std::get<0>(t&: typesIdx.value()); |
133 | Type inputType = std::get<1>(t&: typesIdx.value()); |
134 | if (!regionInterface.areTypesCompatible(sourceType, inputType)) { |
135 | InFlightDiagnostic diag = op->emitOpError(message: " along control flow edge " ); |
136 | return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ) |
137 | << ": source type #" << typesIdx.index() << " " << sourceType |
138 | << " should match input type #" << typesIdx.index() << " " |
139 | << inputType; |
140 | } |
141 | } |
142 | } |
143 | return success(); |
144 | } |
145 | |
146 | /// Verify that types match along control flow edges described the given op. |
147 | LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { |
148 | auto regionInterface = cast<RegionBranchOpInterface>(op); |
149 | |
150 | auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { |
151 | return regionInterface.getEntrySuccessorOperands(point).getTypes(); |
152 | }; |
153 | |
154 | // Verify types along control flow edges originating from the parent. |
155 | if (failed(result: verifyTypesAlongAllEdges(op, sourcePoint: RegionBranchPoint::parent(), |
156 | getInputsTypesForRegion: inputTypesFromParent))) |
157 | return failure(); |
158 | |
159 | auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { |
160 | if (lhs.size() != rhs.size()) |
161 | return false; |
162 | for (auto types : llvm::zip(t&: lhs, u&: rhs)) { |
163 | if (!regionInterface.areTypesCompatible(std::get<0>(t&: types), |
164 | std::get<1>(t&: types))) { |
165 | return false; |
166 | } |
167 | } |
168 | return true; |
169 | }; |
170 | |
171 | // Verify types along control flow edges originating from each region. |
172 | for (Region ®ion : op->getRegions()) { |
173 | |
174 | // Since there can be multiple terminators implementing the |
175 | // `RegionBranchTerminatorOpInterface`, all should have the same operand |
176 | // types when passing them to the same region. |
177 | |
178 | SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps; |
179 | for (Block &block : region) |
180 | if (!block.empty()) |
181 | if (auto terminator = |
182 | dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) |
183 | regionReturnOps.push_back(terminator); |
184 | |
185 | // If there is no return-like terminator, the op itself should verify |
186 | // type consistency. |
187 | if (regionReturnOps.empty()) |
188 | continue; |
189 | |
190 | auto inputTypesForRegion = |
191 | [&](RegionBranchPoint point) -> FailureOr<TypeRange> { |
192 | std::optional<OperandRange> regionReturnOperands; |
193 | for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { |
194 | auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); |
195 | |
196 | if (!regionReturnOperands) { |
197 | regionReturnOperands = terminatorOperands; |
198 | continue; |
199 | } |
200 | |
201 | // Found more than one ReturnLike terminator. Make sure the operand |
202 | // types match with the first one. |
203 | if (!areTypesCompatible(regionReturnOperands->getTypes(), |
204 | terminatorOperands.getTypes())) { |
205 | InFlightDiagnostic diag = op->emitOpError(" along control flow edge" ); |
206 | return printRegionEdgeName(diag, region, point) |
207 | << " operands mismatch between return-like terminators" ; |
208 | } |
209 | } |
210 | |
211 | // All successors get the same set of operand types. |
212 | return TypeRange(regionReturnOperands->getTypes()); |
213 | }; |
214 | |
215 | if (failed(result: verifyTypesAlongAllEdges(op, sourcePoint: region, getInputsTypesForRegion: inputTypesForRegion))) |
216 | return failure(); |
217 | } |
218 | |
219 | return success(); |
220 | } |
221 | |
222 | /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if |
223 | /// this function returns "true" for a successor region. The first parameter is |
224 | /// the successor region. The second parameter indicates all already visited |
225 | /// regions. |
226 | using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>; |
227 | |
228 | /// Traverse the region graph starting at `begin`. The traversal is interrupted |
229 | /// if `stopCondition` evaluates to "true" for a successor region. In that case, |
230 | /// this function returns "true". Otherwise, if the traversal was not |
231 | /// interrupted, this function returns "false". |
232 | static bool traverseRegionGraph(Region *begin, |
233 | StopConditionFn stopConditionFn) { |
234 | auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); |
235 | SmallVector<bool> visited(op->getNumRegions(), false); |
236 | visited[begin->getRegionNumber()] = true; |
237 | |
238 | // Retrieve all successors of the region and enqueue them in the worklist. |
239 | SmallVector<Region *> worklist; |
240 | auto enqueueAllSuccessors = [&](Region *region) { |
241 | SmallVector<RegionSuccessor> successors; |
242 | op.getSuccessorRegions(region, successors); |
243 | for (RegionSuccessor successor : successors) |
244 | if (!successor.isParent()) |
245 | worklist.push_back(Elt: successor.getSuccessor()); |
246 | }; |
247 | enqueueAllSuccessors(begin); |
248 | |
249 | // Process all regions in the worklist via DFS. |
250 | while (!worklist.empty()) { |
251 | Region *nextRegion = worklist.pop_back_val(); |
252 | if (stopConditionFn(nextRegion, visited)) |
253 | return true; |
254 | if (visited[nextRegion->getRegionNumber()]) |
255 | continue; |
256 | visited[nextRegion->getRegionNumber()] = true; |
257 | enqueueAllSuccessors(nextRegion); |
258 | } |
259 | |
260 | return false; |
261 | } |
262 | |
263 | /// Return `true` if region `r` is reachable from region `begin` according to |
264 | /// the RegionBranchOpInterface (by taking a branch). |
265 | static bool isRegionReachable(Region *begin, Region *r) { |
266 | assert(begin->getParentOp() == r->getParentOp() && |
267 | "expected that both regions belong to the same op" ); |
268 | return traverseRegionGraph(begin, |
269 | stopConditionFn: [&](Region *nextRegion, ArrayRef<bool> visited) { |
270 | // Interrupt traversal if `r` was reached. |
271 | return nextRegion == r; |
272 | }); |
273 | } |
274 | |
275 | /// Return `true` if `a` and `b` are in mutually exclusive regions. |
276 | /// |
277 | /// 1. Find the first common of `a` and `b` (ancestor) that implements |
278 | /// RegionBranchOpInterface. |
279 | /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are |
280 | /// contained. |
281 | /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are |
282 | /// mutually exclusive if they are not reachable from each other as per |
283 | /// RegionBranchOpInterface::getSuccessorRegions. |
284 | bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { |
285 | assert(a && "expected non-empty operation" ); |
286 | assert(b && "expected non-empty operation" ); |
287 | |
288 | auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); |
289 | while (branchOp) { |
290 | // Check if b is inside branchOp. (We already know that a is.) |
291 | if (!branchOp->isProperAncestor(b)) { |
292 | // Check next enclosing RegionBranchOpInterface. |
293 | branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); |
294 | continue; |
295 | } |
296 | |
297 | // b is contained in branchOp. Retrieve the regions in which `a` and `b` |
298 | // are contained. |
299 | Region *regionA = nullptr, *regionB = nullptr; |
300 | for (Region &r : branchOp->getRegions()) { |
301 | if (r.findAncestorOpInRegion(*a)) { |
302 | assert(!regionA && "already found a region for a" ); |
303 | regionA = &r; |
304 | } |
305 | if (r.findAncestorOpInRegion(*b)) { |
306 | assert(!regionB && "already found a region for b" ); |
307 | regionB = &r; |
308 | } |
309 | } |
310 | assert(regionA && regionB && "could not find region of op" ); |
311 | |
312 | // `a` and `b` are in mutually exclusive regions if both regions are |
313 | // distinct and neither region is reachable from the other region. |
314 | return regionA != regionB && !isRegionReachable(begin: regionA, r: regionB) && |
315 | !isRegionReachable(begin: regionB, r: regionA); |
316 | } |
317 | |
318 | // Could not find a common RegionBranchOpInterface among a's and b's |
319 | // ancestors. |
320 | return false; |
321 | } |
322 | |
323 | bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { |
324 | Region *region = &getOperation()->getRegion(index); |
325 | return isRegionReachable(region, region); |
326 | } |
327 | |
328 | bool RegionBranchOpInterface::hasLoop() { |
329 | SmallVector<RegionSuccessor> entryRegions; |
330 | getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); |
331 | for (RegionSuccessor successor : entryRegions) |
332 | if (!successor.isParent() && |
333 | traverseRegionGraph(successor.getSuccessor(), |
334 | [](Region *nextRegion, ArrayRef<bool> visited) { |
335 | // Interrupt traversal if the region was already |
336 | // visited. |
337 | return visited[nextRegion->getRegionNumber()]; |
338 | })) |
339 | return true; |
340 | return false; |
341 | } |
342 | |
343 | Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { |
344 | while (Region *region = op->getParentRegion()) { |
345 | op = region->getParentOp(); |
346 | if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) |
347 | if (branchOp.isRepetitiveRegion(region->getRegionNumber())) |
348 | return region; |
349 | } |
350 | return nullptr; |
351 | } |
352 | |
353 | Region *mlir::getEnclosingRepetitiveRegion(Value value) { |
354 | Region *region = value.getParentRegion(); |
355 | while (region) { |
356 | Operation *op = region->getParentOp(); |
357 | if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) |
358 | if (branchOp.isRepetitiveRegion(region->getRegionNumber())) |
359 | return region; |
360 | region = op->getParentRegion(); |
361 | } |
362 | return nullptr; |
363 | } |
364 | |