1 | //===- TopologicalSortUtils.cpp - Topological sort utilities --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/TopologicalSortUtils.h" |
10 | #include "mlir/IR/Block.h" |
11 | #include "mlir/IR/OpDefinition.h" |
12 | #include "mlir/IR/RegionGraphTraits.h" |
13 | |
14 | #include "llvm/ADT/PostOrderIterator.h" |
15 | #include "llvm/ADT/SetVector.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | /// Return `true` if the given operation is ready to be scheduled. |
20 | static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps, |
21 | function_ref<bool(Value, Operation *)> isOperandReady) { |
22 | // An operation is ready to be scheduled if all its operands are ready. An |
23 | // operation is ready if: |
24 | const auto isReady = [&](Value value) { |
25 | // - the user-provided callback marks it as ready, |
26 | if (isOperandReady && isOperandReady(value, op)) |
27 | return true; |
28 | Operation *parent = value.getDefiningOp(); |
29 | // - it is a block argument, |
30 | if (!parent) |
31 | return true; |
32 | // - or it is not defined by an unscheduled op (and also not nested within |
33 | // an unscheduled op). |
34 | do { |
35 | // Stop traversal when op under examination is reached. |
36 | if (parent == op) |
37 | return true; |
38 | if (unscheduledOps.contains(V: parent)) |
39 | return false; |
40 | } while ((parent = parent->getParentOp())); |
41 | // No unscheduled op found. |
42 | return true; |
43 | }; |
44 | |
45 | // An operation is recursively ready to be scheduled of it and its nested |
46 | // operations are ready. |
47 | WalkResult readyToSchedule = op->walk(callback: [&](Operation *nestedOp) { |
48 | return llvm::all_of(Range: nestedOp->getOperands(), |
49 | P: [&](Value operand) { return isReady(operand); }) |
50 | ? WalkResult::advance() |
51 | : WalkResult::interrupt(); |
52 | }); |
53 | return !readyToSchedule.wasInterrupted(); |
54 | } |
55 | |
56 | bool mlir::sortTopologically( |
57 | Block *block, llvm::iterator_range<Block::iterator> ops, |
58 | function_ref<bool(Value, Operation *)> isOperandReady) { |
59 | if (ops.empty()) |
60 | return true; |
61 | |
62 | // The set of operations that have not yet been scheduled. |
63 | DenseSet<Operation *> unscheduledOps; |
64 | // Mark all operations as unscheduled. |
65 | for (Operation &op : ops) |
66 | unscheduledOps.insert(V: &op); |
67 | |
68 | Block::iterator nextScheduledOp = ops.begin(); |
69 | Block::iterator end = ops.end(); |
70 | |
71 | bool allOpsScheduled = true; |
72 | while (!unscheduledOps.empty()) { |
73 | bool scheduledAtLeastOnce = false; |
74 | |
75 | // Loop over the ops that are not sorted yet, try to find the ones "ready", |
76 | // i.e. the ones for which there aren't any operand produced by an op in the |
77 | // set, and "schedule" it (move it before the `nextScheduledOp`). |
78 | for (Operation &op : |
79 | llvm::make_early_inc_range(Range: llvm::make_range(x: nextScheduledOp, y: end))) { |
80 | if (!isOpReady(op: &op, unscheduledOps, isOperandReady)) |
81 | continue; |
82 | |
83 | // Schedule the operation by moving it to the start. |
84 | unscheduledOps.erase(V: &op); |
85 | op.moveBefore(block, iterator: nextScheduledOp); |
86 | scheduledAtLeastOnce = true; |
87 | // Move the iterator forward if we schedule the operation at the front. |
88 | if (&op == &*nextScheduledOp) |
89 | ++nextScheduledOp; |
90 | } |
91 | // If no operations were scheduled, give up and advance the iterator. |
92 | if (!scheduledAtLeastOnce) { |
93 | allOpsScheduled = false; |
94 | unscheduledOps.erase(V: &*nextScheduledOp); |
95 | ++nextScheduledOp; |
96 | } |
97 | } |
98 | |
99 | return allOpsScheduled; |
100 | } |
101 | |
102 | bool mlir::sortTopologically( |
103 | Block *block, function_ref<bool(Value, Operation *)> isOperandReady) { |
104 | if (block->empty()) |
105 | return true; |
106 | if (block->back().hasTrait<OpTrait::IsTerminator>()) |
107 | return sortTopologically(block, ops: block->without_terminator(), |
108 | isOperandReady); |
109 | return sortTopologically(block, ops: *block, isOperandReady); |
110 | } |
111 | |
112 | bool mlir::computeTopologicalSorting( |
113 | MutableArrayRef<Operation *> ops, |
114 | function_ref<bool(Value, Operation *)> isOperandReady) { |
115 | if (ops.empty()) |
116 | return true; |
117 | |
118 | // The set of operations that have not yet been scheduled. |
119 | // Mark all operations as unscheduled. |
120 | DenseSet<Operation *> unscheduledOps(llvm::from_range, ops); |
121 | |
122 | unsigned nextScheduledOp = 0; |
123 | |
124 | bool allOpsScheduled = true; |
125 | while (!unscheduledOps.empty()) { |
126 | bool scheduledAtLeastOnce = false; |
127 | |
128 | // Loop over the ops that are not sorted yet, try to find the ones "ready", |
129 | // i.e. the ones for which there aren't any operand produced by an op in the |
130 | // set, and "schedule" it (swap it with the op at `nextScheduledOp`). |
131 | for (unsigned i = nextScheduledOp; i < ops.size(); ++i) { |
132 | if (!isOpReady(op: ops[i], unscheduledOps, isOperandReady)) |
133 | continue; |
134 | |
135 | // Schedule the operation by moving it to the start. |
136 | unscheduledOps.erase(V: ops[i]); |
137 | std::swap(a&: ops[i], b&: ops[nextScheduledOp]); |
138 | scheduledAtLeastOnce = true; |
139 | ++nextScheduledOp; |
140 | } |
141 | |
142 | // If no operations were scheduled, just schedule the first op and continue. |
143 | if (!scheduledAtLeastOnce) { |
144 | allOpsScheduled = false; |
145 | unscheduledOps.erase(V: ops[nextScheduledOp++]); |
146 | } |
147 | } |
148 | |
149 | return allOpsScheduled; |
150 | } |
151 | |
152 | SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) { |
153 | // For each block that has not been visited yet (i.e. that has no |
154 | // predecessors), add it to the list as well as its successors. |
155 | SetVector<Block *> blocks; |
156 | for (Block &b : region) { |
157 | if (blocks.count(key: &b) == 0) { |
158 | llvm::ReversePostOrderTraversal<Block *> traversal(&b); |
159 | blocks.insert_range(R&: traversal); |
160 | } |
161 | } |
162 | assert(blocks.size() == region.getBlocks().size() && |
163 | "some blocks are not sorted" ); |
164 | |
165 | return blocks; |
166 | } |
167 | |
168 | namespace { |
169 | class TopoSortHelper { |
170 | public: |
171 | explicit TopoSortHelper(const SetVector<Operation *> &toSort) |
172 | : toSort(toSort) {} |
173 | |
174 | /// Executes the topological sort of the operations this instance was |
175 | /// constructed with. This function will destroy the internal state of the |
176 | /// instance. |
177 | SetVector<Operation *> sort() { |
178 | if (toSort.size() <= 1) { |
179 | // Note: Creates a copy on purpose. |
180 | return toSort; |
181 | } |
182 | |
183 | // First, find the root region to start the traversal through the IR. This |
184 | // additionally enriches the internal caches with all relevant ancestor |
185 | // regions and blocks. |
186 | Region *rootRegion = findCommonAncestorRegion(); |
187 | assert(rootRegion && "expected all ops to have a common ancestor" ); |
188 | |
189 | // Sort all elements in `toSort` by traversing the IR in the appropriate |
190 | // order. |
191 | SetVector<Operation *> result = topoSortRegion(rootRegion&: *rootRegion); |
192 | assert(result.size() == toSort.size() && |
193 | "expected all operations to be present in the result" ); |
194 | return result; |
195 | } |
196 | |
197 | private: |
198 | /// Computes the closest common ancestor region of all operations in `toSort`. |
199 | Region *findCommonAncestorRegion() { |
200 | // Map to count the number of times a region was encountered. |
201 | DenseMap<Region *, size_t> regionCounts; |
202 | size_t expectedCount = toSort.size(); |
203 | |
204 | // Walk the region tree for each operation towards the root and add to the |
205 | // region count. |
206 | Region *res = nullptr; |
207 | for (Operation *op : toSort) { |
208 | Region *current = op->getParentRegion(); |
209 | // Store the block as an ancestor block. |
210 | ancestorBlocks.insert(V: op->getBlock()); |
211 | while (current) { |
212 | // Insert or update the count and compare it. |
213 | if (++regionCounts[current] == expectedCount) { |
214 | res = current; |
215 | break; |
216 | } |
217 | ancestorBlocks.insert(V: current->getParentOp()->getBlock()); |
218 | current = current->getParentRegion(); |
219 | } |
220 | } |
221 | auto firstRange = llvm::make_first_range(c&: regionCounts); |
222 | ancestorRegions.insert_range(R&: firstRange); |
223 | return res; |
224 | } |
225 | |
226 | /// Performs the dominance respecting IR walk to collect the topological order |
227 | /// of the operation to sort. |
228 | SetVector<Operation *> topoSortRegion(Region &rootRegion) { |
229 | using StackT = PointerUnion<Region *, Block *, Operation *>; |
230 | |
231 | SetVector<Operation *> result; |
232 | // Stack that stores the different IR constructs to traverse. |
233 | SmallVector<StackT> stack; |
234 | stack.push_back(Elt: &rootRegion); |
235 | |
236 | // Traverse the IR in a dominance respecting pre-order walk. |
237 | while (!stack.empty()) { |
238 | StackT current = stack.pop_back_val(); |
239 | if (auto *region = dyn_cast<Region *>(Val&: current)) { |
240 | // A region's blocks need to be traversed in dominance order. |
241 | SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region&: *region); |
242 | for (Block *block : llvm::reverse(C&: sortedBlocks)) { |
243 | // Only add blocks to the stack that are ancestors of the operations |
244 | // to sort. |
245 | if (ancestorBlocks.contains(V: block)) |
246 | stack.push_back(Elt: block); |
247 | } |
248 | continue; |
249 | } |
250 | |
251 | if (auto *block = dyn_cast<Block *>(Val&: current)) { |
252 | // Add all of the blocks operations to the stack. |
253 | for (Operation &op : llvm::reverse(C&: *block)) |
254 | stack.push_back(Elt: &op); |
255 | continue; |
256 | } |
257 | |
258 | auto *op = cast<Operation *>(Val&: current); |
259 | if (toSort.contains(key: op)) |
260 | result.insert(X: op); |
261 | |
262 | // Add all the subregions that are ancestors of the operations to sort. |
263 | for (Region &subRegion : op->getRegions()) |
264 | if (ancestorRegions.contains(V: &subRegion)) |
265 | stack.push_back(Elt: &subRegion); |
266 | } |
267 | return result; |
268 | } |
269 | |
270 | /// Operations to sort. |
271 | const SetVector<Operation *> &toSort; |
272 | /// Set containing all the ancestor regions of the operations to sort. |
273 | DenseSet<Region *> ancestorRegions; |
274 | /// Set containing all the ancestor blocks of the operations to sort. |
275 | DenseSet<Block *> ancestorBlocks; |
276 | }; |
277 | } // namespace |
278 | |
279 | SetVector<Operation *> |
280 | mlir::topologicalSort(const SetVector<Operation *> &toSort) { |
281 | return TopoSortHelper(toSort).sort(); |
282 | } |
283 | |