| 1 | //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a commutativity utility pattern and a function to |
| 10 | // populate this pattern. The function is intended to be used inside passes to |
| 11 | // simplify the matching of commutative operations by fixing the order of their |
| 12 | // operands. |
| 13 | // |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "mlir/Transforms/CommutativityUtils.h" |
| 17 | |
| 18 | #include <queue> |
| 19 | |
| 20 | using namespace mlir; |
| 21 | |
namespace {
/// The possible "types" of ancestors. Here, an ancestor is an op or a block
/// argument present in the backward slice of a value.
///
/// The declaration order is significant: `AncestorKey::operator<` compares the
/// enumerator values, so block arguments sort before non-constant ops, which
/// sort before constant-like ops.
enum AncestorType {
  /// Pertains to a block argument.
  BLOCK_ARGUMENT,

  /// Pertains to a non-constant-like op.
  NON_CONSTANT_OP,

  /// Pertains to a constant-like op.
  CONSTANT_OP
};
} // namespace
| 34 | |
| 35 | /// Stores the "key" associated with an ancestor. |
| 36 | struct AncestorKey { |
| 37 | /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on |
| 38 | /// the ancestor. |
| 39 | AncestorType type; |
| 40 | |
| 41 | /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or |
| 42 | /// `CONSTANT_OP`. Else, holds "". |
| 43 | StringRef opName; |
| 44 | |
| 45 | /// Constructor for `AncestorKey`. |
| 46 | AncestorKey(Operation *op) { |
| 47 | if (!op) { |
| 48 | type = BLOCK_ARGUMENT; |
| 49 | } else { |
| 50 | type = |
| 51 | op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP; |
| 52 | opName = op->getName().getStringRef(); |
| 53 | } |
| 54 | } |
| 55 | |
| 56 | /// Overloaded operator `<` for `AncestorKey`. |
| 57 | /// |
| 58 | /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those |
| 59 | /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in |
| 60 | /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller |
| 61 | /// ones are the ones with smaller op names (lexicographically). |
| 62 | /// |
| 63 | /// TODO: Include other information like attributes, value type, etc., to |
| 64 | /// enhance this comparison. For example, currently this comparison doesn't |
| 65 | /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and |
| 66 | /// `addi (in i64)`. Such an enhancement should only be done if the need |
| 67 | /// arises. |
| 68 | bool operator<(const AncestorKey &key) const { |
| 69 | return std::tie(args: type, args: opName) < std::tie(args: key.type, args: key.opName); |
| 70 | } |
| 71 | }; |
| 72 | |
| 73 | /// Stores a commutative operand along with its BFS traversal information. |
| 74 | struct CommutativeOperand { |
| 75 | /// Stores the operand. |
| 76 | Value operand; |
| 77 | |
| 78 | /// Stores the queue of ancestors of the operand's BFS traversal at a |
| 79 | /// particular point in time. |
| 80 | std::queue<Operation *> ancestorQueue; |
| 81 | |
| 82 | /// Stores the list of ancestors that have been visited by the BFS traversal |
| 83 | /// at a particular point in time. |
| 84 | DenseSet<Operation *> visitedAncestors; |
| 85 | |
| 86 | /// Stores the operand's "key". This "key" is defined as a list of the |
| 87 | /// "AncestorKeys" associated with the ancestors of this operand, in a |
| 88 | /// breadth-first order. |
| 89 | /// |
| 90 | /// So, if an operand, say `A`, was produced as follows: |
| 91 | /// |
| 92 | /// `<block argument>` `<block argument>` |
| 93 | /// \ / |
| 94 | /// \ / |
| 95 | /// `arith.subi` `arith.constant` |
| 96 | /// \ / |
| 97 | /// `arith.addi` |
| 98 | /// | |
| 99 | /// returns `A` |
| 100 | /// |
| 101 | /// Then, the ancestors of `A`, in the breadth-first order are: |
| 102 | /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and |
| 103 | /// `<block argument>`. |
| 104 | /// |
| 105 | /// Thus, the "key" associated with operand `A` is: |
| 106 | /// { |
| 107 | /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, |
| 108 | /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, |
| 109 | /// {type: `CONSTANT_OP`, opName: "arith.constant"}, |
| 110 | /// {type: `BLOCK_ARGUMENT`, opName: ""}, |
| 111 | /// {type: `BLOCK_ARGUMENT`, opName: ""} |
| 112 | /// } |
| 113 | SmallVector<AncestorKey, 4> key; |
| 114 | |
| 115 | /// Push an ancestor into the operand's BFS information structure. This |
| 116 | /// entails it being pushed into the queue (always) and inserted into the |
| 117 | /// "visited ancestors" list (iff it is an op rather than a block argument). |
| 118 | void pushAncestor(Operation *op) { |
| 119 | ancestorQueue.push(x: op); |
| 120 | if (op) |
| 121 | visitedAncestors.insert(V: op); |
| 122 | } |
| 123 | |
| 124 | /// Refresh the key. |
| 125 | /// |
| 126 | /// Refreshing a key entails making it up-to-date with the operand's BFS |
| 127 | /// traversal that has happened till that point in time, i.e, appending the |
| 128 | /// existing key with the front ancestor's "AncestorKey". Note that a key |
| 129 | /// directly reflects the BFS and thus needs to be refreshed during the |
| 130 | /// progression of the traversal. |
| 131 | void refreshKey() { |
| 132 | if (ancestorQueue.empty()) |
| 133 | return; |
| 134 | |
| 135 | Operation *frontAncestor = ancestorQueue.front(); |
| 136 | AncestorKey frontAncestorKey(frontAncestor); |
| 137 | key.push_back(Elt: frontAncestorKey); |
| 138 | } |
| 139 | |
| 140 | /// Pop the front ancestor, if any, from the queue and then push its adjacent |
| 141 | /// unvisited ancestors, if any, to the queue (this is the main body of the |
| 142 | /// BFS algorithm). |
| 143 | void popFrontAndPushAdjacentUnvisitedAncestors() { |
| 144 | if (ancestorQueue.empty()) |
| 145 | return; |
| 146 | Operation *frontAncestor = ancestorQueue.front(); |
| 147 | ancestorQueue.pop(); |
| 148 | if (!frontAncestor) |
| 149 | return; |
| 150 | for (Value operand : frontAncestor->getOperands()) { |
| 151 | Operation *operandDefOp = operand.getDefiningOp(); |
| 152 | if (!operandDefOp || !visitedAncestors.contains(V: operandDefOp)) |
| 153 | pushAncestor(op: operandDefOp); |
| 154 | } |
| 155 | } |
| 156 | }; |
| 157 | |
| 158 | /// Sorts the operands of `op` in ascending order of the "key" associated with |
| 159 | /// each operand iff `op` is commutative. This is a stable sort. |
| 160 | /// |
| 161 | /// After the application of this pattern, since the commutative operands now |
| 162 | /// have a deterministic order in which they occur in an op, the matching of |
| 163 | /// large DAGs becomes much simpler, i.e., requires much less number of checks |
| 164 | /// to be written by a user in her/his pattern matching function. |
| 165 | /// |
| 166 | /// Some examples of such a sorting: |
| 167 | /// |
| 168 | /// Assume that the sorting is being applied to `foo.commutative`, which is a |
| 169 | /// commutative op. |
| 170 | /// |
| 171 | /// Example 1: |
| 172 | /// |
| 173 | /// %1 = foo.const 0 |
| 174 | /// %2 = foo.mul <block argument>, <block argument> |
| 175 | /// %3 = foo.commutative %1, %2 |
| 176 | /// |
| 177 | /// Here, |
| 178 | /// 1. The key associated with %1 is: |
| 179 | /// `{ |
| 180 | /// {CONSTANT_OP, "foo.const"} |
| 181 | /// }` |
| 182 | /// 2. The key associated with %2 is: |
| 183 | /// `{ |
| 184 | /// {NON_CONSTANT_OP, "foo.mul"}, |
| 185 | /// {BLOCK_ARGUMENT, ""}, |
| 186 | /// {BLOCK_ARGUMENT, ""} |
| 187 | /// }` |
| 188 | /// |
| 189 | /// The key of %2 < the key of %1 |
| 190 | /// Thus, the sorted `foo.commutative` is: |
| 191 | /// %3 = foo.commutative %2, %1 |
| 192 | /// |
| 193 | /// Example 2: |
| 194 | /// |
| 195 | /// %1 = foo.const 0 |
| 196 | /// %2 = foo.mul <block argument>, <block argument> |
| 197 | /// %3 = foo.mul %2, %1 |
| 198 | /// %4 = foo.add %2, %1 |
| 199 | /// %5 = foo.commutative %1, %2, %3, %4 |
| 200 | /// |
| 201 | /// Here, |
| 202 | /// 1. The key associated with %1 is: |
| 203 | /// `{ |
| 204 | /// {CONSTANT_OP, "foo.const"} |
| 205 | /// }` |
| 206 | /// 2. The key associated with %2 is: |
| 207 | /// `{ |
| 208 | /// {NON_CONSTANT_OP, "foo.mul"}, |
| 209 | /// {BLOCK_ARGUMENT, ""} |
| 210 | /// }` |
| 211 | /// 3. The key associated with %3 is: |
| 212 | /// `{ |
| 213 | /// {NON_CONSTANT_OP, "foo.mul"}, |
| 214 | /// {NON_CONSTANT_OP, "foo.mul"}, |
| 215 | /// {CONSTANT_OP, "foo.const"}, |
| 216 | /// {BLOCK_ARGUMENT, ""}, |
| 217 | /// {BLOCK_ARGUMENT, ""} |
| 218 | /// }` |
| 219 | /// 4. The key associated with %4 is: |
| 220 | /// `{ |
| 221 | /// {NON_CONSTANT_OP, "foo.add"}, |
| 222 | /// {NON_CONSTANT_OP, "foo.mul"}, |
| 223 | /// {CONSTANT_OP, "foo.const"}, |
| 224 | /// {BLOCK_ARGUMENT, ""}, |
| 225 | /// {BLOCK_ARGUMENT, ""} |
| 226 | /// }` |
| 227 | /// |
| 228 | /// Thus, the sorted `foo.commutative` is: |
| 229 | /// %5 = foo.commutative %4, %3, %2, %1 |
| 230 | class SortCommutativeOperands : public RewritePattern { |
| 231 | public: |
| 232 | SortCommutativeOperands(MLIRContext *context) |
| 233 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} |
| 234 | LogicalResult matchAndRewrite(Operation *op, |
| 235 | PatternRewriter &rewriter) const override { |
| 236 | // Custom comparator for two commutative operands, which returns true iff |
| 237 | // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`, |
| 238 | // i.e., |
| 239 | // 1. In the first unequal pair of corresponding AncestorKeys, the |
| 240 | // AncestorKey in `constCommOperandA` is smaller, or, |
| 241 | // 2. Both the AncestorKeys in every pair are the same and the size of |
| 242 | // `constCommOperandA`'s "key" is smaller. |
| 243 | auto commutativeOperandComparator = |
| 244 | [](const std::unique_ptr<CommutativeOperand> &constCommOperandA, |
| 245 | const std::unique_ptr<CommutativeOperand> &constCommOperandB) { |
| 246 | if (constCommOperandA->operand == constCommOperandB->operand) |
| 247 | return false; |
| 248 | |
| 249 | auto &commOperandA = |
| 250 | const_cast<std::unique_ptr<CommutativeOperand> &>( |
| 251 | constCommOperandA); |
| 252 | auto &commOperandB = |
| 253 | const_cast<std::unique_ptr<CommutativeOperand> &>( |
| 254 | constCommOperandB); |
| 255 | |
| 256 | // Iteratively perform the BFS's of both operands until an order among |
| 257 | // them can be determined. |
| 258 | unsigned keyIndex = 0; |
| 259 | while (true) { |
| 260 | if (commOperandA->key.size() <= keyIndex) { |
| 261 | if (commOperandA->ancestorQueue.empty()) |
| 262 | return true; |
| 263 | commOperandA->popFrontAndPushAdjacentUnvisitedAncestors(); |
| 264 | commOperandA->refreshKey(); |
| 265 | } |
| 266 | if (commOperandB->key.size() <= keyIndex) { |
| 267 | if (commOperandB->ancestorQueue.empty()) |
| 268 | return false; |
| 269 | commOperandB->popFrontAndPushAdjacentUnvisitedAncestors(); |
| 270 | commOperandB->refreshKey(); |
| 271 | } |
| 272 | if (commOperandA->ancestorQueue.empty() || |
| 273 | commOperandB->ancestorQueue.empty()) |
| 274 | return commOperandA->key.size() < commOperandB->key.size(); |
| 275 | if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) |
| 276 | return true; |
| 277 | if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) |
| 278 | return false; |
| 279 | keyIndex++; |
| 280 | } |
| 281 | }; |
| 282 | |
| 283 | // If `op` is not commutative, do nothing. |
| 284 | if (!op->hasTrait<OpTrait::IsCommutative>()) |
| 285 | return failure(); |
| 286 | |
| 287 | // Populate the list of commutative operands. |
| 288 | SmallVector<Value, 2> operands = op->getOperands(); |
| 289 | SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands; |
| 290 | for (Value operand : operands) { |
| 291 | std::unique_ptr<CommutativeOperand> commOperand = |
| 292 | std::make_unique<CommutativeOperand>(); |
| 293 | commOperand->operand = operand; |
| 294 | commOperand->pushAncestor(op: operand.getDefiningOp()); |
| 295 | commOperand->refreshKey(); |
| 296 | commOperands.push_back(Elt: std::move(commOperand)); |
| 297 | } |
| 298 | |
| 299 | // Sort the operands. |
| 300 | llvm::stable_sort(Range&: commOperands, C: commutativeOperandComparator); |
| 301 | SmallVector<Value, 2> sortedOperands; |
| 302 | for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands) |
| 303 | sortedOperands.push_back(Elt: commOperand->operand); |
| 304 | if (sortedOperands == operands) |
| 305 | return failure(); |
| 306 | rewriter.modifyOpInPlace(root: op, callable: [&] { op->setOperands(sortedOperands); }); |
| 307 | return success(); |
| 308 | } |
| 309 | }; |
| 310 | |
| 311 | void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { |
| 312 | patterns.add<SortCommutativeOperands>(arg: patterns.getContext()); |
| 313 | } |
| 314 | |