1//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/Interfaces/CallInterfaces.h"
13#include "mlir/Interfaces/ControlFlowInterfaces.h"
14#include "llvm/ADT/SmallPtrSet.h"
15
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// ControlFlowInterfaces
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23
24SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
25 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
26}
27
28SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
29 MutableOperandRange forwardedOperands)
30 : producedOperandCount(producedOperandCount),
31 forwardedOperands(std::move(forwardedOperands)) {}
32
33//===----------------------------------------------------------------------===//
34// BranchOpInterface
35//===----------------------------------------------------------------------===//
36
37/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
38/// successor if 'operandIndex' is within the range of 'operands', or
39/// std::nullopt if `operandIndex` isn't a successor operand index.
40std::optional<BlockArgument>
41detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
42 unsigned operandIndex, Block *successor) {
43 OperandRange forwardedOperands = operands.getForwardedOperands();
44 // Check that the operands are valid.
45 if (forwardedOperands.empty())
46 return std::nullopt;
47
48 // Check to ensure that this operand is within the range.
49 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
50 if (operandIndex < operandsStart ||
51 operandIndex >= (operandsStart + forwardedOperands.size()))
52 return std::nullopt;
53
54 // Index the successor.
55 unsigned argIndex =
56 operands.getProducedOperandCount() + operandIndex - operandsStart;
57 return successor->getArgument(i: argIndex);
58}
59
60/// Verify that the given operands match those of the given successor block.
61LogicalResult
62detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
63 const SuccessorOperands &operands) {
64 // Check the count.
65 unsigned operandCount = operands.size();
66 Block *destBB = op->getSuccessor(index: succNo);
67 if (operandCount != destBB->getNumArguments())
68 return op->emitError() << "branch has " << operandCount
69 << " operands for successor #" << succNo
70 << ", but target block has "
71 << destBB->getNumArguments();
72
73 // Check the types.
74 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
75 ++i) {
76 if (!cast<BranchOpInterface>(Val: op).areTypesCompatible(
77 lhs: operands[i].getType(), rhs: destBB->getArgument(i).getType()))
78 return op->emitError() << "type mismatch for bb argument #" << i
79 << " of successor #" << succNo;
80 }
81 return success();
82}
83
84//===----------------------------------------------------------------------===//
85// WeightedBranchOpInterface
86//===----------------------------------------------------------------------===//
87
88static LogicalResult verifyWeights(Operation *op,
89 llvm::ArrayRef<int32_t> weights,
90 std::size_t expectedWeightsNum,
91 llvm::StringRef weightAnchorName,
92 llvm::StringRef weightRefName) {
93 if (weights.empty())
94 return success();
95
96 if (weights.size() != expectedWeightsNum)
97 return op->emitError() << "expects number of " << weightAnchorName
98 << " weights to match number of " << weightRefName
99 << ": " << weights.size() << " vs "
100 << expectedWeightsNum;
101
102 if (llvm::all_of(Range&: weights, P: [](int32_t value) { return value == 0; }))
103 return op->emitError() << "branch weights cannot all be zero";
104
105 return success();
106}
107
108LogicalResult detail::verifyBranchWeights(Operation *op) {
109 llvm::ArrayRef<int32_t> weights =
110 cast<WeightedBranchOpInterface>(Val: op).getWeights();
111 return verifyWeights(op, weights, expectedWeightsNum: op->getNumSuccessors(), weightAnchorName: "branch",
112 weightRefName: "successors");
113}
114
115//===----------------------------------------------------------------------===//
116// WeightedRegionBranchOpInterface
117//===----------------------------------------------------------------------===//
118
119LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
120 llvm::ArrayRef<int32_t> weights =
121 cast<WeightedRegionBranchOpInterface>(Val: op).getWeights();
122 return verifyWeights(op, weights, expectedWeightsNum: op->getNumRegions(), weightAnchorName: "region", weightRefName: "regions");
123}
124
125//===----------------------------------------------------------------------===//
126// RegionBranchOpInterface
127//===----------------------------------------------------------------------===//
128
129static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
130 RegionBranchPoint sourceNo,
131 RegionBranchPoint succRegionNo) {
132 diag << "from ";
133 if (Region *region = sourceNo.getRegionOrNull())
134 diag << "Region #" << region->getRegionNumber();
135 else
136 diag << "parent operands";
137
138 diag << " to ";
139 if (Region *region = succRegionNo.getRegionOrNull())
140 diag << "Region #" << region->getRegionNumber();
141 else
142 diag << "parent results";
143 return diag;
144}
145
146/// Verify that types match along all region control flow edges originating from
147/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
148/// types of the inputs that flow to a successor region.
149static LogicalResult
150verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
151 function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
152 getInputsTypesForRegion) {
153 auto regionInterface = cast<RegionBranchOpInterface>(Val: op);
154
155 SmallVector<RegionSuccessor, 2> successors;
156 regionInterface.getSuccessorRegions(point: sourcePoint, regions&: successors);
157
158 for (RegionSuccessor &succ : successors) {
159 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
160 if (failed(Result: sourceTypes))
161 return failure();
162
163 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
164 if (sourceTypes->size() != succInputsTypes.size()) {
165 InFlightDiagnostic diag = op->emitOpError(message: "region control flow edge ");
166 return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ)
167 << ": source has " << sourceTypes->size()
168 << " operands, but target successor needs "
169 << succInputsTypes.size();
170 }
171
172 for (const auto &typesIdx :
173 llvm::enumerate(First: llvm::zip(t&: *sourceTypes, u&: succInputsTypes))) {
174 Type sourceType = std::get<0>(t&: typesIdx.value());
175 Type inputType = std::get<1>(t&: typesIdx.value());
176 if (!regionInterface.areTypesCompatible(lhs: sourceType, rhs: inputType)) {
177 InFlightDiagnostic diag = op->emitOpError(message: "along control flow edge ");
178 return printRegionEdgeName(diag, sourceNo: sourcePoint, succRegionNo: succ)
179 << ": source type #" << typesIdx.index() << " " << sourceType
180 << " should match input type #" << typesIdx.index() << " "
181 << inputType;
182 }
183 }
184 }
185 return success();
186}
187
188/// Verify that types match along control flow edges described the given op.
189LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
190 auto regionInterface = cast<RegionBranchOpInterface>(Val: op);
191
192 auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
193 return regionInterface.getEntrySuccessorOperands(point).getTypes();
194 };
195
196 // Verify types along control flow edges originating from the parent.
197 if (failed(Result: verifyTypesAlongAllEdges(op, sourcePoint: RegionBranchPoint::parent(),
198 getInputsTypesForRegion: inputTypesFromParent)))
199 return failure();
200
201 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
202 if (lhs.size() != rhs.size())
203 return false;
204 for (auto types : llvm::zip(t&: lhs, u&: rhs)) {
205 if (!regionInterface.areTypesCompatible(lhs: std::get<0>(t&: types),
206 rhs: std::get<1>(t&: types))) {
207 return false;
208 }
209 }
210 return true;
211 };
212
213 // Verify types along control flow edges originating from each region.
214 for (Region &region : op->getRegions()) {
215
216 // Since there can be multiple terminators implementing the
217 // `RegionBranchTerminatorOpInterface`, all should have the same operand
218 // types when passing them to the same region.
219
220 SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
221 for (Block &block : region)
222 if (!block.empty())
223 if (auto terminator =
224 dyn_cast<RegionBranchTerminatorOpInterface>(Val&: block.back()))
225 regionReturnOps.push_back(Elt: terminator);
226
227 // If there is no return-like terminator, the op itself should verify
228 // type consistency.
229 if (regionReturnOps.empty())
230 continue;
231
232 auto inputTypesForRegion =
233 [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
234 std::optional<OperandRange> regionReturnOperands;
235 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
236 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
237
238 if (!regionReturnOperands) {
239 regionReturnOperands = terminatorOperands;
240 continue;
241 }
242
243 // Found more than one ReturnLike terminator. Make sure the operand
244 // types match with the first one.
245 if (!areTypesCompatible(regionReturnOperands->getTypes(),
246 terminatorOperands.getTypes())) {
247 InFlightDiagnostic diag = op->emitOpError(message: "along control flow edge");
248 return printRegionEdgeName(diag, sourceNo: region, succRegionNo: point)
249 << " operands mismatch between return-like terminators";
250 }
251 }
252
253 // All successors get the same set of operand types.
254 return TypeRange(regionReturnOperands->getTypes());
255 };
256
257 if (failed(Result: verifyTypesAlongAllEdges(op, sourcePoint: region, getInputsTypesForRegion: inputTypesForRegion)))
258 return failure();
259 }
260
261 return success();
262}
263
264/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
265/// this function returns "true" for a successor region. The first parameter is
266/// the successor region. The second parameter indicates all already visited
267/// regions.
268using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
269
270/// Traverse the region graph starting at `begin`. The traversal is interrupted
271/// if `stopCondition` evaluates to "true" for a successor region. In that case,
272/// this function returns "true". Otherwise, if the traversal was not
273/// interrupted, this function returns "false".
274static bool traverseRegionGraph(Region *begin,
275 StopConditionFn stopConditionFn) {
276 auto op = cast<RegionBranchOpInterface>(Val: begin->getParentOp());
277 SmallVector<bool> visited(op->getNumRegions(), false);
278 visited[begin->getRegionNumber()] = true;
279
280 // Retrieve all successors of the region and enqueue them in the worklist.
281 SmallVector<Region *> worklist;
282 auto enqueueAllSuccessors = [&](Region *region) {
283 SmallVector<RegionSuccessor> successors;
284 op.getSuccessorRegions(point: region, regions&: successors);
285 for (RegionSuccessor successor : successors)
286 if (!successor.isParent())
287 worklist.push_back(Elt: successor.getSuccessor());
288 };
289 enqueueAllSuccessors(begin);
290
291 // Process all regions in the worklist via DFS.
292 while (!worklist.empty()) {
293 Region *nextRegion = worklist.pop_back_val();
294 if (stopConditionFn(nextRegion, visited))
295 return true;
296 if (visited[nextRegion->getRegionNumber()])
297 continue;
298 visited[nextRegion->getRegionNumber()] = true;
299 enqueueAllSuccessors(nextRegion);
300 }
301
302 return false;
303}
304
305/// Return `true` if region `r` is reachable from region `begin` according to
306/// the RegionBranchOpInterface (by taking a branch).
307static bool isRegionReachable(Region *begin, Region *r) {
308 assert(begin->getParentOp() == r->getParentOp() &&
309 "expected that both regions belong to the same op");
310 return traverseRegionGraph(begin,
311 stopConditionFn: [&](Region *nextRegion, ArrayRef<bool> visited) {
312 // Interrupt traversal if `r` was reached.
313 return nextRegion == r;
314 });
315}
316
317/// Return `true` if `a` and `b` are in mutually exclusive regions.
318///
319/// 1. Find the first common of `a` and `b` (ancestor) that implements
320/// RegionBranchOpInterface.
321/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
322/// contained.
323/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
324/// mutually exclusive if they are not reachable from each other as per
325/// RegionBranchOpInterface::getSuccessorRegions.
326bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
327 assert(a && "expected non-empty operation");
328 assert(b && "expected non-empty operation");
329
330 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
331 while (branchOp) {
332 // Check if b is inside branchOp. (We already know that a is.)
333 if (!branchOp->isProperAncestor(other: b)) {
334 // Check next enclosing RegionBranchOpInterface.
335 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
336 continue;
337 }
338
339 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
340 // are contained.
341 Region *regionA = nullptr, *regionB = nullptr;
342 for (Region &r : branchOp->getRegions()) {
343 if (r.findAncestorOpInRegion(op&: *a)) {
344 assert(!regionA && "already found a region for a");
345 regionA = &r;
346 }
347 if (r.findAncestorOpInRegion(op&: *b)) {
348 assert(!regionB && "already found a region for b");
349 regionB = &r;
350 }
351 }
352 assert(regionA && regionB && "could not find region of op");
353
354 // `a` and `b` are in mutually exclusive regions if both regions are
355 // distinct and neither region is reachable from the other region.
356 return regionA != regionB && !isRegionReachable(begin: regionA, r: regionB) &&
357 !isRegionReachable(begin: regionB, r: regionA);
358 }
359
360 // Could not find a common RegionBranchOpInterface among a's and b's
361 // ancestors.
362 return false;
363}
364
365bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
366 Region *region = &getOperation()->getRegion(index);
367 return isRegionReachable(begin: region, r: region);
368}
369
370bool RegionBranchOpInterface::hasLoop() {
371 SmallVector<RegionSuccessor> entryRegions;
372 getSuccessorRegions(point: RegionBranchPoint::parent(), regions&: entryRegions);
373 for (RegionSuccessor successor : entryRegions)
374 if (!successor.isParent() &&
375 traverseRegionGraph(begin: successor.getSuccessor(),
376 stopConditionFn: [](Region *nextRegion, ArrayRef<bool> visited) {
377 // Interrupt traversal if the region was already
378 // visited.
379 return visited[nextRegion->getRegionNumber()];
380 }))
381 return true;
382 return false;
383}
384
385Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
386 while (Region *region = op->getParentRegion()) {
387 op = region->getParentOp();
388 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(Val: op))
389 if (branchOp.isRepetitiveRegion(index: region->getRegionNumber()))
390 return region;
391 }
392 return nullptr;
393}
394
395Region *mlir::getEnclosingRepetitiveRegion(Value value) {
396 Region *region = value.getParentRegion();
397 while (region) {
398 Operation *op = region->getParentOp();
399 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(Val: op))
400 if (branchOp.isRepetitiveRegion(index: region->getRegionNumber()))
401 return region;
402 region = op->getParentRegion();
403 }
404 return nullptr;
405}
406

source code of mlir/lib/Interfaces/ControlFlowInterfaces.cpp