1 | //===- LoopLikeInterface.cpp - Loop-like operations in MLIR ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/LoopLikeInterface.h" |
10 | |
11 | #include "mlir/Interfaces/FunctionInterfaces.h" |
12 | #include "llvm/ADT/DenseSet.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | /// Include the definitions of the loop-like interfaces. |
17 | #include "mlir/Interfaces/LoopLikeInterface.cpp.inc" |
18 | |
19 | bool LoopLikeOpInterface::blockIsInLoop(Block *block) { |
20 | Operation *parent = block->getParentOp(); |
21 | |
22 | // The block could be inside a loop-like operation |
23 | if (isa<LoopLikeOpInterface>(parent) || |
24 | parent->getParentOfType<LoopLikeOpInterface>()) |
25 | return true; |
26 | |
27 | // This block might be nested inside another block, which is in a loop |
28 | if (!isa<FunctionOpInterface>(parent)) |
29 | if (mlir::Block *parentBlock = parent->getBlock()) |
30 | if (blockIsInLoop(parentBlock)) |
31 | return true; |
32 | |
33 | // Or the block could be inside a control flow graph loop: |
34 | // A block is in a control flow graph loop if it can reach itself in a graph |
35 | // traversal |
36 | DenseSet<Block *> visited; |
37 | SmallVector<Block *> stack; |
38 | stack.push_back(block); |
39 | while (!stack.empty()) { |
40 | Block *current = stack.pop_back_val(); |
41 | auto [it, inserted] = visited.insert(current); |
42 | if (!inserted) { |
43 | // loop detected |
44 | if (current == block) |
45 | return true; |
46 | continue; |
47 | } |
48 | |
49 | stack.reserve(stack.size() + current->getNumSuccessors()); |
50 | for (Block *successor : current->getSuccessors()) |
51 | stack.push_back(successor); |
52 | } |
53 | return false; |
54 | } |
55 | |
56 | LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { |
57 | // Note: These invariants are also verified by the RegionBranchOpInterface, |
58 | // but the LoopLikeOpInterface provides better error messages. |
59 | auto loopLikeOp = cast<LoopLikeOpInterface>(op); |
60 | |
61 | // Verify number of inits/iter_args/yielded values/loop results. |
62 | if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size()) |
63 | return op->emitOpError(message: "different number of inits and region iter_args: " ) |
64 | << loopLikeOp.getInits().size() |
65 | << " != " << loopLikeOp.getRegionIterArgs().size(); |
66 | if (!loopLikeOp.getYieldedValues().empty() && |
67 | loopLikeOp.getRegionIterArgs().size() != |
68 | loopLikeOp.getYieldedValues().size()) |
69 | return op->emitOpError( |
70 | message: "different number of region iter_args and yielded values: " ) |
71 | << loopLikeOp.getRegionIterArgs().size() |
72 | << " != " << loopLikeOp.getYieldedValues().size(); |
73 | if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() != |
74 | loopLikeOp.getRegionIterArgs().size()) |
75 | return op->emitOpError( |
76 | message: "different number of loop results and region iter_args: " ) |
77 | << loopLikeOp.getLoopResults()->size() |
78 | << " != " << loopLikeOp.getRegionIterArgs().size(); |
79 | |
80 | // Verify types of inits/iter_args/yielded values/loop results. |
81 | int64_t i = 0; |
82 | auto yieldedValues = loopLikeOp.getYieldedValues(); |
83 | for (const auto [index, init, regionIterArg] : |
84 | llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) { |
85 | if (init.getType() != regionIterArg.getType()) |
86 | return op->emitOpError(std::to_string(index)) |
87 | << "-th init and " << index |
88 | << "-th region iter_arg have different type: " << init.getType() |
89 | << " != " << regionIterArg.getType(); |
90 | if (!yieldedValues.empty()) { |
91 | if (regionIterArg.getType() != yieldedValues[index].getType()) |
92 | return op->emitOpError(std::to_string(index)) |
93 | << "-th region iter_arg and " << index |
94 | << "-th yielded value have different type: " |
95 | << regionIterArg.getType() |
96 | << " != " << yieldedValues[index].getType(); |
97 | } |
98 | ++i; |
99 | } |
100 | i = 0; |
101 | if (loopLikeOp.getLoopResults()) { |
102 | for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(), |
103 | *loopLikeOp.getLoopResults())) { |
104 | if (std::get<0>(it).getType() != std::get<1>(it).getType()) |
105 | return op->emitOpError(std::to_string(i)) |
106 | << "-th region iter_arg and " << i |
107 | << "-th loop result have different type: " |
108 | << std::get<0>(it).getType() |
109 | << " != " << std::get<1>(it).getType(); |
110 | } |
111 | ++i; |
112 | } |
113 | |
114 | return success(); |
115 | } |
116 | |