1//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a commutativity utility pattern and a function to
10// populate this pattern. The function is intended to be used inside passes to
11// simplify the matching of commutative operations by fixing the order of their
12// operands.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Transforms/CommutativityUtils.h"
17
18#include <queue>
19
20using namespace mlir;
21
namespace {
/// The possible "types" of ancestors. Here, an ancestor is an op or a block
/// argument present in the backward slice of a value.
///
/// The declaration order below is significant: `AncestorKey::operator<`
/// compares these enumerators by value. Block arguments therefore sort before
/// non-constant-like ops, which sort before constant-like ops.
enum AncestorType {
  /// Pertains to a block argument.
  BLOCK_ARGUMENT,

  /// Pertains to a non-constant-like op.
  NON_CONSTANT_OP,

  /// Pertains to a constant-like op.
  CONSTANT_OP
};
} // namespace
34
35/// Stores the "key" associated with an ancestor.
36struct AncestorKey {
37 /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38 /// the ancestor.
39 AncestorType type;
40
41 /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42 /// `CONSTANT_OP`. Else, holds "".
43 StringRef opName;
44
45 /// Constructor for `AncestorKey`.
46 AncestorKey(Operation *op) {
47 if (!op) {
48 type = BLOCK_ARGUMENT;
49 } else {
50 type =
51 op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
52 opName = op->getName().getStringRef();
53 }
54 }
55
56 /// Overloaded operator `<` for `AncestorKey`.
57 ///
58 /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59 /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60 /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61 /// ones are the ones with smaller op names (lexicographically).
62 ///
63 /// TODO: Include other information like attributes, value type, etc., to
64 /// enhance this comparison. For example, currently this comparison doesn't
65 /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66 /// `addi (in i64)`. Such an enhancement should only be done if the need
67 /// arises.
68 bool operator<(const AncestorKey &key) const {
69 return std::tie(args: type, args: opName) < std::tie(args: key.type, args: key.opName);
70 }
71};
72
73/// Stores a commutative operand along with its BFS traversal information.
74struct CommutativeOperand {
75 /// Stores the operand.
76 Value operand;
77
78 /// Stores the queue of ancestors of the operand's BFS traversal at a
79 /// particular point in time.
80 std::queue<Operation *> ancestorQueue;
81
82 /// Stores the list of ancestors that have been visited by the BFS traversal
83 /// at a particular point in time.
84 DenseSet<Operation *> visitedAncestors;
85
86 /// Stores the operand's "key". This "key" is defined as a list of the
87 /// "AncestorKeys" associated with the ancestors of this operand, in a
88 /// breadth-first order.
89 ///
90 /// So, if an operand, say `A`, was produced as follows:
91 ///
92 /// `<block argument>` `<block argument>`
93 /// \ /
94 /// \ /
95 /// `arith.subi` `arith.constant`
96 /// \ /
97 /// `arith.addi`
98 /// |
99 /// returns `A`
100 ///
101 /// Then, the ancestors of `A`, in the breadth-first order are:
102 /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103 /// `<block argument>`.
104 ///
105 /// Thus, the "key" associated with operand `A` is:
106 /// {
107 /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108 /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109 /// {type: `CONSTANT_OP`, opName: "arith.constant"},
110 /// {type: `BLOCK_ARGUMENT`, opName: ""},
111 /// {type: `BLOCK_ARGUMENT`, opName: ""}
112 /// }
113 SmallVector<AncestorKey, 4> key;
114
115 /// Push an ancestor into the operand's BFS information structure. This
116 /// entails it being pushed into the queue (always) and inserted into the
117 /// "visited ancestors" list (iff it is an op rather than a block argument).
118 void pushAncestor(Operation *op) {
119 ancestorQueue.push(x: op);
120 if (op)
121 visitedAncestors.insert(V: op);
122 }
123
124 /// Refresh the key.
125 ///
126 /// Refreshing a key entails making it up-to-date with the operand's BFS
127 /// traversal that has happened till that point in time, i.e, appending the
128 /// existing key with the front ancestor's "AncestorKey". Note that a key
129 /// directly reflects the BFS and thus needs to be refreshed during the
130 /// progression of the traversal.
131 void refreshKey() {
132 if (ancestorQueue.empty())
133 return;
134
135 Operation *frontAncestor = ancestorQueue.front();
136 AncestorKey frontAncestorKey(frontAncestor);
137 key.push_back(Elt: frontAncestorKey);
138 }
139
140 /// Pop the front ancestor, if any, from the queue and then push its adjacent
141 /// unvisited ancestors, if any, to the queue (this is the main body of the
142 /// BFS algorithm).
143 void popFrontAndPushAdjacentUnvisitedAncestors() {
144 if (ancestorQueue.empty())
145 return;
146 Operation *frontAncestor = ancestorQueue.front();
147 ancestorQueue.pop();
148 if (!frontAncestor)
149 return;
150 for (Value operand : frontAncestor->getOperands()) {
151 Operation *operandDefOp = operand.getDefiningOp();
152 if (!operandDefOp || !visitedAncestors.contains(V: operandDefOp))
153 pushAncestor(op: operandDefOp);
154 }
155 }
156};
157
158/// Sorts the operands of `op` in ascending order of the "key" associated with
159/// each operand iff `op` is commutative. This is a stable sort.
160///
161/// After the application of this pattern, since the commutative operands now
162/// have a deterministic order in which they occur in an op, the matching of
163/// large DAGs becomes much simpler, i.e., requires much less number of checks
164/// to be written by a user in her/his pattern matching function.
165///
166/// Some examples of such a sorting:
167///
168/// Assume that the sorting is being applied to `foo.commutative`, which is a
169/// commutative op.
170///
171/// Example 1:
172///
173/// %1 = foo.const 0
174/// %2 = foo.mul <block argument>, <block argument>
175/// %3 = foo.commutative %1, %2
176///
177/// Here,
178/// 1. The key associated with %1 is:
179/// `{
180/// {CONSTANT_OP, "foo.const"}
181/// }`
182/// 2. The key associated with %2 is:
183/// `{
184/// {NON_CONSTANT_OP, "foo.mul"},
185/// {BLOCK_ARGUMENT, ""},
186/// {BLOCK_ARGUMENT, ""}
187/// }`
188///
189/// The key of %2 < the key of %1
190/// Thus, the sorted `foo.commutative` is:
191/// %3 = foo.commutative %2, %1
192///
193/// Example 2:
194///
195/// %1 = foo.const 0
196/// %2 = foo.mul <block argument>, <block argument>
197/// %3 = foo.mul %2, %1
198/// %4 = foo.add %2, %1
199/// %5 = foo.commutative %1, %2, %3, %4
200///
201/// Here,
202/// 1. The key associated with %1 is:
203/// `{
204/// {CONSTANT_OP, "foo.const"}
205/// }`
206/// 2. The key associated with %2 is:
207/// `{
208/// {NON_CONSTANT_OP, "foo.mul"},
209/// {BLOCK_ARGUMENT, ""}
210/// }`
211/// 3. The key associated with %3 is:
212/// `{
213/// {NON_CONSTANT_OP, "foo.mul"},
214/// {NON_CONSTANT_OP, "foo.mul"},
215/// {CONSTANT_OP, "foo.const"},
216/// {BLOCK_ARGUMENT, ""},
217/// {BLOCK_ARGUMENT, ""}
218/// }`
219/// 4. The key associated with %4 is:
220/// `{
221/// {NON_CONSTANT_OP, "foo.add"},
222/// {NON_CONSTANT_OP, "foo.mul"},
223/// {CONSTANT_OP, "foo.const"},
224/// {BLOCK_ARGUMENT, ""},
225/// {BLOCK_ARGUMENT, ""}
226/// }`
227///
228/// Thus, the sorted `foo.commutative` is:
229/// %5 = foo.commutative %4, %3, %2, %1
230class SortCommutativeOperands : public RewritePattern {
231public:
232 SortCommutativeOperands(MLIRContext *context)
233 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
234 LogicalResult matchAndRewrite(Operation *op,
235 PatternRewriter &rewriter) const override {
236 // Custom comparator for two commutative operands, which returns true iff
237 // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238 // i.e.,
239 // 1. In the first unequal pair of corresponding AncestorKeys, the
240 // AncestorKey in `constCommOperandA` is smaller, or,
241 // 2. Both the AncestorKeys in every pair are the same and the size of
242 // `constCommOperandA`'s "key" is smaller.
243 auto commutativeOperandComparator =
244 [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245 const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246 if (constCommOperandA->operand == constCommOperandB->operand)
247 return false;
248
249 auto &commOperandA =
250 const_cast<std::unique_ptr<CommutativeOperand> &>(
251 constCommOperandA);
252 auto &commOperandB =
253 const_cast<std::unique_ptr<CommutativeOperand> &>(
254 constCommOperandB);
255
256 // Iteratively perform the BFS's of both operands until an order among
257 // them can be determined.
258 unsigned keyIndex = 0;
259 while (true) {
260 if (commOperandA->key.size() <= keyIndex) {
261 if (commOperandA->ancestorQueue.empty())
262 return true;
263 commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264 commOperandA->refreshKey();
265 }
266 if (commOperandB->key.size() <= keyIndex) {
267 if (commOperandB->ancestorQueue.empty())
268 return false;
269 commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270 commOperandB->refreshKey();
271 }
272 if (commOperandA->ancestorQueue.empty() ||
273 commOperandB->ancestorQueue.empty())
274 return commOperandA->key.size() < commOperandB->key.size();
275 if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276 return true;
277 if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278 return false;
279 keyIndex++;
280 }
281 };
282
283 // If `op` is not commutative, do nothing.
284 if (!op->hasTrait<OpTrait::IsCommutative>())
285 return failure();
286
287 // Populate the list of commutative operands.
288 SmallVector<Value, 2> operands = op->getOperands();
289 SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
290 for (Value operand : operands) {
291 std::unique_ptr<CommutativeOperand> commOperand =
292 std::make_unique<CommutativeOperand>();
293 commOperand->operand = operand;
294 commOperand->pushAncestor(op: operand.getDefiningOp());
295 commOperand->refreshKey();
296 commOperands.push_back(Elt: std::move(commOperand));
297 }
298
299 // Sort the operands.
300 std::stable_sort(first: commOperands.begin(), last: commOperands.end(),
301 comp: commutativeOperandComparator);
302 SmallVector<Value, 2> sortedOperands;
303 for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304 sortedOperands.push_back(Elt: commOperand->operand);
305 if (sortedOperands == operands)
306 return failure();
307 rewriter.modifyOpInPlace(root: op, callable: [&] { op->setOperands(sortedOperands); });
308 return success();
309 }
310};
311
312void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
313 patterns.add<SortCommutativeOperands>(arg: patterns.getContext());
314}
315

source code of mlir/lib/Transforms/Utils/CommutativityUtils.cpp