1 | //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a commutativity utility pattern and a function to |
10 | // populate this pattern. The function is intended to be used inside passes to |
11 | // simplify the matching of commutative operations by fixing the order of their |
12 | // operands. |
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Transforms/CommutativityUtils.h" |
17 | |
18 | #include <queue> |
19 | |
20 | using namespace mlir; |
21 | |
/// Classifies an "ancestor" of a value, where an ancestor is any op or block
/// argument that appears in the value's backward slice.
///
/// The numeric order of the enumerators is significant: `AncestorKey` compares
/// this field first, so block arguments order before non-constant ops, which in
/// turn order before constant-like ops.
enum AncestorType {
  /// The ancestor is a block argument (it has no defining op).
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that is not constant-like.
  NON_CONSTANT_OP = 1,

  /// The ancestor is a constant-like op.
  CONSTANT_OP = 2
};
34 | |
35 | /// Stores the "key" associated with an ancestor. |
36 | struct AncestorKey { |
37 | /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on |
38 | /// the ancestor. |
39 | AncestorType type; |
40 | |
41 | /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or |
42 | /// `CONSTANT_OP`. Else, holds "". |
43 | StringRef opName; |
44 | |
45 | /// Constructor for `AncestorKey`. |
46 | AncestorKey(Operation *op) { |
47 | if (!op) { |
48 | type = BLOCK_ARGUMENT; |
49 | } else { |
50 | type = |
51 | op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP; |
52 | opName = op->getName().getStringRef(); |
53 | } |
54 | } |
55 | |
56 | /// Overloaded operator `<` for `AncestorKey`. |
57 | /// |
58 | /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those |
59 | /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in |
60 | /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller |
61 | /// ones are the ones with smaller op names (lexicographically). |
62 | /// |
63 | /// TODO: Include other information like attributes, value type, etc., to |
64 | /// enhance this comparison. For example, currently this comparison doesn't |
65 | /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and |
66 | /// `addi (in i64)`. Such an enhancement should only be done if the need |
67 | /// arises. |
68 | bool operator<(const AncestorKey &key) const { |
69 | return std::tie(args: type, args: opName) < std::tie(args: key.type, args: key.opName); |
70 | } |
71 | }; |
72 | |
73 | /// Stores a commutative operand along with its BFS traversal information. |
74 | struct CommutativeOperand { |
75 | /// Stores the operand. |
76 | Value operand; |
77 | |
78 | /// Stores the queue of ancestors of the operand's BFS traversal at a |
79 | /// particular point in time. |
80 | std::queue<Operation *> ancestorQueue; |
81 | |
82 | /// Stores the list of ancestors that have been visited by the BFS traversal |
83 | /// at a particular point in time. |
84 | DenseSet<Operation *> visitedAncestors; |
85 | |
86 | /// Stores the operand's "key". This "key" is defined as a list of the |
87 | /// "AncestorKeys" associated with the ancestors of this operand, in a |
88 | /// breadth-first order. |
89 | /// |
90 | /// So, if an operand, say `A`, was produced as follows: |
91 | /// |
92 | /// `<block argument>` `<block argument>` |
93 | /// \ / |
94 | /// \ / |
95 | /// `arith.subi` `arith.constant` |
96 | /// \ / |
97 | /// `arith.addi` |
98 | /// | |
99 | /// returns `A` |
100 | /// |
101 | /// Then, the ancestors of `A`, in the breadth-first order are: |
102 | /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and |
103 | /// `<block argument>`. |
104 | /// |
105 | /// Thus, the "key" associated with operand `A` is: |
106 | /// { |
107 | /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, |
108 | /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, |
109 | /// {type: `CONSTANT_OP`, opName: "arith.constant"}, |
110 | /// {type: `BLOCK_ARGUMENT`, opName: ""}, |
111 | /// {type: `BLOCK_ARGUMENT`, opName: ""} |
112 | /// } |
113 | SmallVector<AncestorKey, 4> key; |
114 | |
115 | /// Push an ancestor into the operand's BFS information structure. This |
116 | /// entails it being pushed into the queue (always) and inserted into the |
117 | /// "visited ancestors" list (iff it is an op rather than a block argument). |
118 | void pushAncestor(Operation *op) { |
119 | ancestorQueue.push(x: op); |
120 | if (op) |
121 | visitedAncestors.insert(V: op); |
122 | } |
123 | |
124 | /// Refresh the key. |
125 | /// |
126 | /// Refreshing a key entails making it up-to-date with the operand's BFS |
127 | /// traversal that has happened till that point in time, i.e, appending the |
128 | /// existing key with the front ancestor's "AncestorKey". Note that a key |
129 | /// directly reflects the BFS and thus needs to be refreshed during the |
130 | /// progression of the traversal. |
131 | void refreshKey() { |
132 | if (ancestorQueue.empty()) |
133 | return; |
134 | |
135 | Operation *frontAncestor = ancestorQueue.front(); |
136 | AncestorKey frontAncestorKey(frontAncestor); |
137 | key.push_back(Elt: frontAncestorKey); |
138 | } |
139 | |
140 | /// Pop the front ancestor, if any, from the queue and then push its adjacent |
141 | /// unvisited ancestors, if any, to the queue (this is the main body of the |
142 | /// BFS algorithm). |
143 | void popFrontAndPushAdjacentUnvisitedAncestors() { |
144 | if (ancestorQueue.empty()) |
145 | return; |
146 | Operation *frontAncestor = ancestorQueue.front(); |
147 | ancestorQueue.pop(); |
148 | if (!frontAncestor) |
149 | return; |
150 | for (Value operand : frontAncestor->getOperands()) { |
151 | Operation *operandDefOp = operand.getDefiningOp(); |
152 | if (!operandDefOp || !visitedAncestors.contains(V: operandDefOp)) |
153 | pushAncestor(op: operandDefOp); |
154 | } |
155 | } |
156 | }; |
157 | |
158 | /// Sorts the operands of `op` in ascending order of the "key" associated with |
159 | /// each operand iff `op` is commutative. This is a stable sort. |
160 | /// |
161 | /// After the application of this pattern, since the commutative operands now |
162 | /// have a deterministic order in which they occur in an op, the matching of |
163 | /// large DAGs becomes much simpler, i.e., requires much less number of checks |
164 | /// to be written by a user in her/his pattern matching function. |
165 | /// |
166 | /// Some examples of such a sorting: |
167 | /// |
168 | /// Assume that the sorting is being applied to `foo.commutative`, which is a |
169 | /// commutative op. |
170 | /// |
171 | /// Example 1: |
172 | /// |
173 | /// %1 = foo.const 0 |
174 | /// %2 = foo.mul <block argument>, <block argument> |
175 | /// %3 = foo.commutative %1, %2 |
176 | /// |
177 | /// Here, |
178 | /// 1. The key associated with %1 is: |
179 | /// `{ |
180 | /// {CONSTANT_OP, "foo.const"} |
181 | /// }` |
182 | /// 2. The key associated with %2 is: |
183 | /// `{ |
184 | /// {NON_CONSTANT_OP, "foo.mul"}, |
185 | /// {BLOCK_ARGUMENT, ""}, |
186 | /// {BLOCK_ARGUMENT, ""} |
187 | /// }` |
188 | /// |
189 | /// The key of %2 < the key of %1 |
190 | /// Thus, the sorted `foo.commutative` is: |
191 | /// %3 = foo.commutative %2, %1 |
192 | /// |
193 | /// Example 2: |
194 | /// |
195 | /// %1 = foo.const 0 |
196 | /// %2 = foo.mul <block argument>, <block argument> |
197 | /// %3 = foo.mul %2, %1 |
198 | /// %4 = foo.add %2, %1 |
199 | /// %5 = foo.commutative %1, %2, %3, %4 |
200 | /// |
201 | /// Here, |
202 | /// 1. The key associated with %1 is: |
203 | /// `{ |
204 | /// {CONSTANT_OP, "foo.const"} |
205 | /// }` |
206 | /// 2. The key associated with %2 is: |
207 | /// `{ |
208 | /// {NON_CONSTANT_OP, "foo.mul"}, |
209 | /// {BLOCK_ARGUMENT, ""} |
210 | /// }` |
211 | /// 3. The key associated with %3 is: |
212 | /// `{ |
213 | /// {NON_CONSTANT_OP, "foo.mul"}, |
214 | /// {NON_CONSTANT_OP, "foo.mul"}, |
215 | /// {CONSTANT_OP, "foo.const"}, |
216 | /// {BLOCK_ARGUMENT, ""}, |
217 | /// {BLOCK_ARGUMENT, ""} |
218 | /// }` |
219 | /// 4. The key associated with %4 is: |
220 | /// `{ |
221 | /// {NON_CONSTANT_OP, "foo.add"}, |
222 | /// {NON_CONSTANT_OP, "foo.mul"}, |
223 | /// {CONSTANT_OP, "foo.const"}, |
224 | /// {BLOCK_ARGUMENT, ""}, |
225 | /// {BLOCK_ARGUMENT, ""} |
226 | /// }` |
227 | /// |
228 | /// Thus, the sorted `foo.commutative` is: |
229 | /// %5 = foo.commutative %4, %3, %2, %1 |
230 | class SortCommutativeOperands : public RewritePattern { |
231 | public: |
232 | SortCommutativeOperands(MLIRContext *context) |
233 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} |
234 | LogicalResult matchAndRewrite(Operation *op, |
235 | PatternRewriter &rewriter) const override { |
236 | // Custom comparator for two commutative operands, which returns true iff |
237 | // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`, |
238 | // i.e., |
239 | // 1. In the first unequal pair of corresponding AncestorKeys, the |
240 | // AncestorKey in `constCommOperandA` is smaller, or, |
241 | // 2. Both the AncestorKeys in every pair are the same and the size of |
242 | // `constCommOperandA`'s "key" is smaller. |
243 | auto commutativeOperandComparator = |
244 | [](const std::unique_ptr<CommutativeOperand> &constCommOperandA, |
245 | const std::unique_ptr<CommutativeOperand> &constCommOperandB) { |
246 | if (constCommOperandA->operand == constCommOperandB->operand) |
247 | return false; |
248 | |
249 | auto &commOperandA = |
250 | const_cast<std::unique_ptr<CommutativeOperand> &>( |
251 | constCommOperandA); |
252 | auto &commOperandB = |
253 | const_cast<std::unique_ptr<CommutativeOperand> &>( |
254 | constCommOperandB); |
255 | |
256 | // Iteratively perform the BFS's of both operands until an order among |
257 | // them can be determined. |
258 | unsigned keyIndex = 0; |
259 | while (true) { |
260 | if (commOperandA->key.size() <= keyIndex) { |
261 | if (commOperandA->ancestorQueue.empty()) |
262 | return true; |
263 | commOperandA->popFrontAndPushAdjacentUnvisitedAncestors(); |
264 | commOperandA->refreshKey(); |
265 | } |
266 | if (commOperandB->key.size() <= keyIndex) { |
267 | if (commOperandB->ancestorQueue.empty()) |
268 | return false; |
269 | commOperandB->popFrontAndPushAdjacentUnvisitedAncestors(); |
270 | commOperandB->refreshKey(); |
271 | } |
272 | if (commOperandA->ancestorQueue.empty() || |
273 | commOperandB->ancestorQueue.empty()) |
274 | return commOperandA->key.size() < commOperandB->key.size(); |
275 | if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) |
276 | return true; |
277 | if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) |
278 | return false; |
279 | keyIndex++; |
280 | } |
281 | }; |
282 | |
283 | // If `op` is not commutative, do nothing. |
284 | if (!op->hasTrait<OpTrait::IsCommutative>()) |
285 | return failure(); |
286 | |
287 | // Populate the list of commutative operands. |
288 | SmallVector<Value, 2> operands = op->getOperands(); |
289 | SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands; |
290 | for (Value operand : operands) { |
291 | std::unique_ptr<CommutativeOperand> commOperand = |
292 | std::make_unique<CommutativeOperand>(); |
293 | commOperand->operand = operand; |
294 | commOperand->pushAncestor(op: operand.getDefiningOp()); |
295 | commOperand->refreshKey(); |
296 | commOperands.push_back(Elt: std::move(commOperand)); |
297 | } |
298 | |
299 | // Sort the operands. |
300 | std::stable_sort(first: commOperands.begin(), last: commOperands.end(), |
301 | comp: commutativeOperandComparator); |
302 | SmallVector<Value, 2> sortedOperands; |
303 | for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands) |
304 | sortedOperands.push_back(Elt: commOperand->operand); |
305 | if (sortedOperands == operands) |
306 | return failure(); |
307 | rewriter.modifyOpInPlace(root: op, callable: [&] { op->setOperands(sortedOperands); }); |
308 | return success(); |
309 | } |
310 | }; |
311 | |
312 | void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { |
313 | patterns.add<SortCommutativeOperands>(arg: patterns.getContext()); |
314 | } |
315 | |