//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

/// Return `true` if the given operation is ready to be scheduled.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, op))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    // - or it is not defined by an unscheduled op (and also not nested within
    //   an unscheduled op).
    do {
      // Stop traversal when op under examination is reached.
      if (parent == op)
        return true;
      if (unscheduledOps.contains(parent))
        return false;
    } while ((parent = parent->getParentOp()));
    // No unscheduled op found.
    return true;
  };

  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
    return llvm::all_of(nestedOp->getOperands(),
                        [&](Value operand) { return isReady(operand); })
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !readyToSchedule.wasInterrupted();
}
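
// Illustrative example (a sketch; the `arith` ops are only for illustration):
// in the block below, the `arith.addi` is not ready while `arith.constant` is
// still in `unscheduledOps`, because its operand `%c1` is produced by an
// unscheduled op; the constant itself is ready immediately since it has no
// operands:
//
//   %sum = arith.addi %arg0, %c1 : i32
//   %c1 = arith.constant 1 : i32
//
// Scheduling the constant first makes the add ready.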

bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet, try to find the ones that are
    // "ready", i.e. the ones none of whose operands is produced by an op still
    // in the set, and "schedule" them (move them before `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}
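
// Usage sketch (illustrative; `sortAllBlocks` is a hypothetical helper, not
// part of this file): a `false` return value means a cycle was detected and
// the sort is only partial.
//
//   static void sortAllBlocks(Operation *root) {
//     root->walk([](Block *block) {
//       if (!sortTopologically(block))
//         block->getParentOp()->emitWarning("cycle detected; partial sort");
//     });
//   }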

bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled; initially, all
  // operations are unscheduled.
  DenseSet<Operation *> unscheduledOps(llvm::from_range, ops);

  unsigned nextScheduledOp = 0;

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet, try to find the ones that are
    // "ready", i.e. the ones none of whose operands is produced by an op still
    // in the set, and "schedule" them (swap them with the op at
    // `nextScheduledOp`).
    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by swapping it into the next scheduled
      // position.
      unscheduledOps.erase(ops[i]);
      std::swap(ops[i], ops[nextScheduledOp]);
      scheduledAtLeastOnce = true;
      ++nextScheduledOp;
    }

    // If no operations were scheduled, give up on the next op and continue.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(ops[nextScheduledOp++]);
    }
  }

  return allOpsScheduled;
}
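
// Usage sketch (illustrative; `collectOps` is a hypothetical helper): this
// variant only reorders the `ops` array in place and does not move any IR, so
// callers can inspect the computed order before committing to it:
//
//   SmallVector<Operation *> opsToSort = collectOps();
//   bool acyclic = computeTopologicalSorting(opsToSort);
//   // `opsToSort` is now topologically ordered; `acyclic` is false if a
//   // cycle forced a partial order.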

SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
  // Add the blocks to the result in reverse post-order, starting a new
  // traversal from every block that has not been reached yet (the entry block
  // first, then any blocks unreachable from it).
  SetVector<Block *> blocks;
  for (Block &b : region) {
    if (blocks.count(&b) == 0) {
      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
      blocks.insert_range(traversal);
    }
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");

  return blocks;
}
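
// Illustrative example (sketch): for a CFG `^a -> ^b -> ^c` with a back edge
// `^c -> ^b`, the reverse post-order traversal rooted at `^a` yields
// `^a, ^b, ^c`, i.e. every block appears after the blocks that dominate it.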

namespace {
class TopoSortHelper {
public:
  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
      : toSort(toSort) {}

  /// Executes the topological sort of the operations this instance was
  /// constructed with. This function will destroy the internal state of the
  /// instance.
  SetVector<Operation *> sort() {
    if (toSort.size() <= 1) {
      // Note: Creates a copy on purpose.
      return toSort;
    }

    // First, find the root region to start the traversal through the IR. This
    // additionally enriches the internal caches with all relevant ancestor
    // regions and blocks.
    Region *rootRegion = findCommonAncestorRegion();
    assert(rootRegion && "expected all ops to have a common ancestor");

    // Sort all elements in `toSort` by traversing the IR in the appropriate
    // order.
    SetVector<Operation *> result = topoSortRegion(*rootRegion);
    assert(result.size() == toSort.size() &&
           "expected all operations to be present in the result");
    return result;
  }

private:
  /// Computes the closest common ancestor region of all operations in
  /// `toSort`.
  Region *findCommonAncestorRegion() {
    // Map to count the number of times a region was encountered.
    DenseMap<Region *, size_t> regionCounts;
    size_t expectedCount = toSort.size();

    // Walk the region tree for each operation towards the root and add to the
    // region count.
    Region *res = nullptr;
    for (Operation *op : toSort) {
      Region *current = op->getParentRegion();
      // Store the block as an ancestor block.
      ancestorBlocks.insert(op->getBlock());
      while (current) {
        // Insert or update the count and compare it.
        if (++regionCounts[current] == expectedCount) {
          res = current;
          break;
        }
        ancestorBlocks.insert(current->getParentOp()->getBlock());
        current = current->getParentRegion();
      }
    }
    auto firstRange = llvm::make_first_range(regionCounts);
    ancestorRegions.insert_range(firstRange);
    return res;
  }
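
  // Illustrative example (sketch, assuming the nesting described): if `toSort`
  // holds one op inside an `scf.if` within a `func.func` and one op directly
  // in the `func.func` body, the upward walk counts the function's body region
  // twice (== toSort.size()), which makes it the closest common ancestor,
  // while the `scf.if` region is only counted once.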

  /// Performs a dominance-respecting IR walk to collect the topological order
  /// of the operations to sort.
  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
    using StackT = PointerUnion<Region *, Block *, Operation *>;

    SetVector<Operation *> result;
    // Stack that stores the different IR constructs to traverse.
    SmallVector<StackT> stack;
    stack.push_back(&rootRegion);

    // Traverse the IR in a dominance-respecting pre-order walk.
    while (!stack.empty()) {
      StackT current = stack.pop_back_val();
      if (auto *region = dyn_cast<Region *>(current)) {
        // A region's blocks need to be traversed in dominance order.
        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
        for (Block *block : llvm::reverse(sortedBlocks)) {
          // Only add blocks to the stack that are ancestors of the operations
          // to sort.
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
        continue;
      }

      if (auto *block = dyn_cast<Block *>(current)) {
        // Add all of the block's operations to the stack.
        for (Operation &op : llvm::reverse(*block))
          stack.push_back(&op);
        continue;
      }

      auto *op = cast<Operation *>(current);
      if (toSort.contains(op))
        result.insert(op);

      // Add all the subregions that are ancestors of the operations to sort.
      for (Region &subRegion : op->getRegions())
        if (ancestorRegions.contains(&subRegion))
          stack.push_back(&subRegion);
    }
    return result;
  }

  /// Operations to sort.
  const SetVector<Operation *> &toSort;
  /// Set containing all the ancestor regions of the operations to sort.
  DenseSet<Region *> ancestorRegions;
  /// Set containing all the ancestor blocks of the operations to sort.
  DenseSet<Block *> ancestorBlocks;
};
} // namespace

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  return TopoSortHelper(toSort).sort();
}
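
// Usage sketch (illustrative; `computeMySlice` is a hypothetical helper):
// unlike the block-local sorts above, this entry point accepts ops from
// different blocks and regions and returns them in a dominance-respecting
// topological order without modifying the IR:
//
//   SetVector<Operation *> slice = computeMySlice();
//   SetVector<Operation *> ordered = topologicalSort(slice);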