//===- TopologicalSortUtils.cpp - Topological sort utilities ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/TopologicalSortUtils.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

/// Return `true` if the given operation is ready to be scheduled.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, op))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    // - or it is not defined by an unscheduled op (and also not nested within
    //   an unscheduled op).
    do {
      // Stop traversal when op under examination is reached.
      if (parent == op)
        return true;
      if (unscheduledOps.contains(parent))
        return false;
    } while ((parent = parent->getParentOp()));
    // No unscheduled op found.
    return true;
  };

  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
    return llvm::all_of(nestedOp->getOperands(),
                        [&](Value operand) { return isReady(operand); })
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !readyToSchedule.wasInterrupted();
}
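
// Example `isOperandReady` callback (an illustrative sketch, not part of the
// upstream API): a caller that only cares about def-use ordering within one
// block can treat every value defined outside that block as already ready.
//
//   Block *blockToSort = ...;  // hypothetical block being sorted
//   auto isOperandReady = [&](Value value, Operation *) {
//     return value.getParentBlock() != blockToSort;
//   };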

bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet; try to find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op in the
    // set, and "schedule" them (move them before `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}
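
// Example usage (a minimal sketch; `region` stands for a hypothetical
// mlir::Region obtained by the caller): sort every block of the region,
// relying on the overload above to leave a trailing terminator in place.
//
//   for (Block &block : region) {
//     if (!sortTopologically(&block))
//       region.getParentOp()->emitWarning("block contains a use-def cycle");
//   }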

bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;

  // Mark all operations as unscheduled.
  for (Operation *op : ops)
    unscheduledOps.insert(op);

  unsigned nextScheduledOp = 0;

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet; try to find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op in the
    // set, and "schedule" them (swap them with the op at `nextScheduledOp`).
    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(ops[i]);
      std::swap(ops[i], ops[nextScheduledOp]);
      scheduledAtLeastOnce = true;
      ++nextScheduledOp;
    }

    // If no operations were scheduled, mark the next op as scheduled anyway
    // and continue past it.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(ops[nextScheduledOp++]);
    }
  }

  return allOpsScheduled;
}
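
// Example usage (an illustrative sketch; `collectOpsOfInterest` is a
// hypothetical helper gathering the operations to sort): compute a
// topological order without moving any op in the IR. The vector is permuted
// in place, and the return value is false if a cycle prevented a full sort.
//
//   SmallVector<Operation *> opsToSort = collectOpsOfInterest();
//   bool isAcyclic = computeTopologicalSorting(opsToSort);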