1//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/TopologicalSortUtils.h"
10#include "mlir/IR/Block.h"
11#include "mlir/IR/OpDefinition.h"
12#include "mlir/IR/RegionGraphTraits.h"
13
14#include "llvm/ADT/PostOrderIterator.h"
15#include "llvm/ADT/SetVector.h"
16
17using namespace mlir;
18
19/// Return `true` if the given operation is ready to be scheduled.
20static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21 function_ref<bool(Value, Operation *)> isOperandReady) {
22 // An operation is ready to be scheduled if all its operands are ready. An
23 // operation is ready if:
24 const auto isReady = [&](Value value) {
25 // - the user-provided callback marks it as ready,
26 if (isOperandReady && isOperandReady(value, op))
27 return true;
28 Operation *parent = value.getDefiningOp();
29 // - it is a block argument,
30 if (!parent)
31 return true;
32 // - or it is not defined by an unscheduled op (and also not nested within
33 // an unscheduled op).
34 do {
35 // Stop traversal when op under examination is reached.
36 if (parent == op)
37 return true;
38 if (unscheduledOps.contains(V: parent))
39 return false;
40 } while ((parent = parent->getParentOp()));
41 // No unscheduled op found.
42 return true;
43 };
44
45 // An operation is recursively ready to be scheduled of it and its nested
46 // operations are ready.
47 WalkResult readyToSchedule = op->walk(callback: [&](Operation *nestedOp) {
48 return llvm::all_of(Range: nestedOp->getOperands(),
49 P: [&](Value operand) { return isReady(operand); })
50 ? WalkResult::advance()
51 : WalkResult::interrupt();
52 });
53 return !readyToSchedule.wasInterrupted();
54}
55
56bool mlir::sortTopologically(
57 Block *block, llvm::iterator_range<Block::iterator> ops,
58 function_ref<bool(Value, Operation *)> isOperandReady) {
59 if (ops.empty())
60 return true;
61
62 // The set of operations that have not yet been scheduled.
63 DenseSet<Operation *> unscheduledOps;
64 // Mark all operations as unscheduled.
65 for (Operation &op : ops)
66 unscheduledOps.insert(V: &op);
67
68 Block::iterator nextScheduledOp = ops.begin();
69 Block::iterator end = ops.end();
70
71 bool allOpsScheduled = true;
72 while (!unscheduledOps.empty()) {
73 bool scheduledAtLeastOnce = false;
74
75 // Loop over the ops that are not sorted yet, try to find the ones "ready",
76 // i.e. the ones for which there aren't any operand produced by an op in the
77 // set, and "schedule" it (move it before the `nextScheduledOp`).
78 for (Operation &op :
79 llvm::make_early_inc_range(Range: llvm::make_range(x: nextScheduledOp, y: end))) {
80 if (!isOpReady(op: &op, unscheduledOps, isOperandReady))
81 continue;
82
83 // Schedule the operation by moving it to the start.
84 unscheduledOps.erase(V: &op);
85 op.moveBefore(block, iterator: nextScheduledOp);
86 scheduledAtLeastOnce = true;
87 // Move the iterator forward if we schedule the operation at the front.
88 if (&op == &*nextScheduledOp)
89 ++nextScheduledOp;
90 }
91 // If no operations were scheduled, give up and advance the iterator.
92 if (!scheduledAtLeastOnce) {
93 allOpsScheduled = false;
94 unscheduledOps.erase(V: &*nextScheduledOp);
95 ++nextScheduledOp;
96 }
97 }
98
99 return allOpsScheduled;
100}
101
102bool mlir::sortTopologically(
103 Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104 if (block->empty())
105 return true;
106 if (block->back().hasTrait<OpTrait::IsTerminator>())
107 return sortTopologically(block, ops: block->without_terminator(),
108 isOperandReady);
109 return sortTopologically(block, ops: *block, isOperandReady);
110}
111
112bool mlir::computeTopologicalSorting(
113 MutableArrayRef<Operation *> ops,
114 function_ref<bool(Value, Operation *)> isOperandReady) {
115 if (ops.empty())
116 return true;
117
118 // The set of operations that have not yet been scheduled.
119 // Mark all operations as unscheduled.
120 DenseSet<Operation *> unscheduledOps(llvm::from_range, ops);
121
122 unsigned nextScheduledOp = 0;
123
124 bool allOpsScheduled = true;
125 while (!unscheduledOps.empty()) {
126 bool scheduledAtLeastOnce = false;
127
128 // Loop over the ops that are not sorted yet, try to find the ones "ready",
129 // i.e. the ones for which there aren't any operand produced by an op in the
130 // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
131 for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
132 if (!isOpReady(op: ops[i], unscheduledOps, isOperandReady))
133 continue;
134
135 // Schedule the operation by moving it to the start.
136 unscheduledOps.erase(V: ops[i]);
137 std::swap(a&: ops[i], b&: ops[nextScheduledOp]);
138 scheduledAtLeastOnce = true;
139 ++nextScheduledOp;
140 }
141
142 // If no operations were scheduled, just schedule the first op and continue.
143 if (!scheduledAtLeastOnce) {
144 allOpsScheduled = false;
145 unscheduledOps.erase(V: ops[nextScheduledOp++]);
146 }
147 }
148
149 return allOpsScheduled;
150}
151
152SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
153 // For each block that has not been visited yet (i.e. that has no
154 // predecessors), add it to the list as well as its successors.
155 SetVector<Block *> blocks;
156 for (Block &b : region) {
157 if (blocks.count(key: &b) == 0) {
158 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
159 blocks.insert_range(R&: traversal);
160 }
161 }
162 assert(blocks.size() == region.getBlocks().size() &&
163 "some blocks are not sorted");
164
165 return blocks;
166}
167
168namespace {
169class TopoSortHelper {
170public:
171 explicit TopoSortHelper(const SetVector<Operation *> &toSort)
172 : toSort(toSort) {}
173
174 /// Executes the topological sort of the operations this instance was
175 /// constructed with. This function will destroy the internal state of the
176 /// instance.
177 SetVector<Operation *> sort() {
178 if (toSort.size() <= 1) {
179 // Note: Creates a copy on purpose.
180 return toSort;
181 }
182
183 // First, find the root region to start the traversal through the IR. This
184 // additionally enriches the internal caches with all relevant ancestor
185 // regions and blocks.
186 Region *rootRegion = findCommonAncestorRegion();
187 assert(rootRegion && "expected all ops to have a common ancestor");
188
189 // Sort all elements in `toSort` by traversing the IR in the appropriate
190 // order.
191 SetVector<Operation *> result = topoSortRegion(rootRegion&: *rootRegion);
192 assert(result.size() == toSort.size() &&
193 "expected all operations to be present in the result");
194 return result;
195 }
196
197private:
198 /// Computes the closest common ancestor region of all operations in `toSort`.
199 Region *findCommonAncestorRegion() {
200 // Map to count the number of times a region was encountered.
201 DenseMap<Region *, size_t> regionCounts;
202 size_t expectedCount = toSort.size();
203
204 // Walk the region tree for each operation towards the root and add to the
205 // region count.
206 Region *res = nullptr;
207 for (Operation *op : toSort) {
208 Region *current = op->getParentRegion();
209 // Store the block as an ancestor block.
210 ancestorBlocks.insert(V: op->getBlock());
211 while (current) {
212 // Insert or update the count and compare it.
213 if (++regionCounts[current] == expectedCount) {
214 res = current;
215 break;
216 }
217 ancestorBlocks.insert(V: current->getParentOp()->getBlock());
218 current = current->getParentRegion();
219 }
220 }
221 auto firstRange = llvm::make_first_range(c&: regionCounts);
222 ancestorRegions.insert_range(R&: firstRange);
223 return res;
224 }
225
226 /// Performs the dominance respecting IR walk to collect the topological order
227 /// of the operation to sort.
228 SetVector<Operation *> topoSortRegion(Region &rootRegion) {
229 using StackT = PointerUnion<Region *, Block *, Operation *>;
230
231 SetVector<Operation *> result;
232 // Stack that stores the different IR constructs to traverse.
233 SmallVector<StackT> stack;
234 stack.push_back(Elt: &rootRegion);
235
236 // Traverse the IR in a dominance respecting pre-order walk.
237 while (!stack.empty()) {
238 StackT current = stack.pop_back_val();
239 if (auto *region = dyn_cast<Region *>(Val&: current)) {
240 // A region's blocks need to be traversed in dominance order.
241 SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region&: *region);
242 for (Block *block : llvm::reverse(C&: sortedBlocks)) {
243 // Only add blocks to the stack that are ancestors of the operations
244 // to sort.
245 if (ancestorBlocks.contains(V: block))
246 stack.push_back(Elt: block);
247 }
248 continue;
249 }
250
251 if (auto *block = dyn_cast<Block *>(Val&: current)) {
252 // Add all of the blocks operations to the stack.
253 for (Operation &op : llvm::reverse(C&: *block))
254 stack.push_back(Elt: &op);
255 continue;
256 }
257
258 auto *op = cast<Operation *>(Val&: current);
259 if (toSort.contains(key: op))
260 result.insert(X: op);
261
262 // Add all the subregions that are ancestors of the operations to sort.
263 for (Region &subRegion : op->getRegions())
264 if (ancestorRegions.contains(V: &subRegion))
265 stack.push_back(Elt: &subRegion);
266 }
267 return result;
268 }
269
270 /// Operations to sort.
271 const SetVector<Operation *> &toSort;
272 /// Set containing all the ancestor regions of the operations to sort.
273 DenseSet<Region *> ancestorRegions;
274 /// Set containing all the ancestor blocks of the operations to sort.
275 DenseSet<Block *> ancestorBlocks;
276};
277} // namespace
278
279SetVector<Operation *>
280mlir::topologicalSort(const SetVector<Operation *> &toSort) {
281 return TopoSortHelper(toSort).sort();
282}
283

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Analysis/TopologicalSortUtils.cpp