| 1 | //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains the implementation of the core LICM algorithm. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
| 14 | |
| 15 | #include "mlir/IR/Operation.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 18 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 19 | #include "mlir/Interfaces/SubsetOpInterface.h" |
| 20 | #include "llvm/Support/Debug.h" |
| 21 | #include <queue> |
| 22 | |
| 23 | #define DEBUG_TYPE "licm" |
| 24 | |
| 25 | using namespace mlir; |
| 26 | |
| 27 | /// Checks whether the given op can be hoisted by checking that |
| 28 | /// - the op and none of its contained operations depend on values inside of the |
| 29 | /// loop (by means of calling definedOutside). |
| 30 | /// - the op has no side-effects. |
| 31 | static bool canBeHoisted(Operation *op, |
| 32 | function_ref<bool(OpOperand &)> condition) { |
| 33 | // Do not move terminators. |
| 34 | if (op->hasTrait<OpTrait::IsTerminator>()) |
| 35 | return false; |
| 36 | |
| 37 | // Walk the nested operations and check that all used values are either |
| 38 | // defined outside of the loop or in a nested region, but not at the level of |
| 39 | // the loop body. |
| 40 | auto walkFn = [&](Operation *child) { |
| 41 | for (OpOperand &operand : child->getOpOperands()) { |
| 42 | // Ignore values defined in a nested region. |
| 43 | if (op->isAncestor(other: operand.get().getParentRegion()->getParentOp())) |
| 44 | continue; |
| 45 | if (!condition(operand)) |
| 46 | return WalkResult::interrupt(); |
| 47 | } |
| 48 | return WalkResult::advance(); |
| 49 | }; |
| 50 | return !op->walk(callback&: walkFn).wasInterrupted(); |
| 51 | } |
| 52 | |
| 53 | static bool canBeHoisted(Operation *op, |
| 54 | function_ref<bool(Value)> definedOutside) { |
| 55 | return canBeHoisted( |
| 56 | op, condition: [&](OpOperand &operand) { return definedOutside(operand.get()); }); |
| 57 | } |
| 58 | |
| 59 | size_t mlir::moveLoopInvariantCode( |
| 60 | ArrayRef<Region *> regions, |
| 61 | function_ref<bool(Value, Region *)> isDefinedOutsideRegion, |
| 62 | function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, |
| 63 | function_ref<void(Operation *, Region *)> moveOutOfRegion) { |
| 64 | size_t numMoved = 0; |
| 65 | |
| 66 | for (Region *region : regions) { |
| 67 | LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" |
| 68 | << *region->getParentOp() << "\n" ); |
| 69 | |
| 70 | std::queue<Operation *> worklist; |
| 71 | // Add top-level operations in the loop body to the worklist. |
| 72 | for (Operation &op : region->getOps()) |
| 73 | worklist.push(x: &op); |
| 74 | |
| 75 | auto definedOutside = [&](Value value) { |
| 76 | return isDefinedOutsideRegion(value, region); |
| 77 | }; |
| 78 | |
| 79 | while (!worklist.empty()) { |
| 80 | Operation *op = worklist.front(); |
| 81 | worklist.pop(); |
| 82 | // Skip ops that have already been moved. Check if the op can be hoisted. |
| 83 | if (op->getParentRegion() != region) |
| 84 | continue; |
| 85 | |
| 86 | LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n" ); |
| 87 | if (!shouldMoveOutOfRegion(op, region) || |
| 88 | !canBeHoisted(op, definedOutside)) |
| 89 | continue; |
| 90 | |
| 91 | LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n" ); |
| 92 | moveOutOfRegion(op, region); |
| 93 | ++numMoved; |
| 94 | |
| 95 | // Since the op has been moved, we need to check its users within the |
| 96 | // top-level of the loop body. |
| 97 | for (Operation *user : op->getUsers()) |
| 98 | if (user->getParentRegion() == region) |
| 99 | worklist.push(x: user); |
| 100 | } |
| 101 | } |
| 102 | |
| 103 | return numMoved; |
| 104 | } |
| 105 | |
| 106 | size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { |
| 107 | return moveLoopInvariantCode( |
| 108 | loopLike.getLoopRegions(), |
| 109 | [&](Value value, Region *) { |
| 110 | return loopLike.isDefinedOutsideOfLoop(value); |
| 111 | }, |
| 112 | [&](Operation *op, Region *) { |
| 113 | return isMemoryEffectFree(op) && isSpeculatable(op); |
| 114 | }, |
| 115 | [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); |
| 116 | } |
| 117 | |
| 118 | namespace { |
| 119 | /// Helper data structure that keeps track of equivalent/disjoint subset ops. |
| 120 | class MatchingSubsets { |
| 121 | public: |
| 122 | /// Insert a subset op. |
| 123 | void insert(SubsetOpInterface op, bool collectHoistableOps = true) { |
| 124 | allSubsetOps.push_back(op); |
| 125 | if (!collectHoistableOps) |
| 126 | return; |
| 127 | if (auto extractionOp = |
| 128 | dyn_cast<SubsetExtractionOpInterface>(op.getOperation())) |
| 129 | insertExtractionOp(extractionOp: extractionOp); |
| 130 | if (auto insertionOp = |
| 131 | dyn_cast<SubsetInsertionOpInterface>(op.getOperation())) |
| 132 | insertInsertionOp(insertionOp: insertionOp); |
| 133 | } |
| 134 | |
| 135 | /// Return a range of matching extraction-insertion subset ops. If there is no |
| 136 | /// matching extraction/insertion op, the respective value is empty. Ops are |
| 137 | /// skipped if there are other subset ops that are not guaranteed to operate |
| 138 | /// on disjoint subsets. |
| 139 | auto getHoistableSubsetOps() { |
| 140 | return llvm::make_filter_range( |
| 141 | llvm::zip(extractions, insertions), [&](auto pair) { |
| 142 | auto [extractionOp, insertionOp] = pair; |
| 143 | // Hoist only if the extracted and inserted values have the same type. |
| 144 | if (extractionOp && insertionOp && |
| 145 | extractionOp->getResult(0).getType() != |
| 146 | insertionOp.getSourceOperand().get().getType()) |
| 147 | return false; |
| 148 | // Hoist only if there are no conflicting subset ops. |
| 149 | return allDisjoint(extractionOp, insertionOp); |
| 150 | }); |
| 151 | } |
| 152 | |
| 153 | /// Populate subset ops starting from the given region iter_arg. Return |
| 154 | /// "failure" if non-subset ops are found along the path to the loop yielding |
| 155 | /// op or if there is no single path to the tied yielded operand. If |
| 156 | /// `collectHoistableOps` is set to "false", subset ops are gathered |
| 157 | /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`. |
| 158 | LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
| 159 | BlockArgument iterArg, |
| 160 | bool collectHoistableOps = true); |
| 161 | |
| 162 | private: |
| 163 | /// Helper function for equivalence of tensor values. Since only insertion |
| 164 | /// subset ops (that are also destination style ops) are followed when |
| 165 | /// traversing the SSA use-def chain, all tensor values are equivalent. |
| 166 | static bool isEquivalent(Value v1, Value v2) { return true; } |
| 167 | |
| 168 | /// Return "true" if the subsets of the given extraction and insertion ops |
| 169 | /// are operating disjoint from the subsets that all other known subset ops |
| 170 | /// are operating on. |
| 171 | bool (SubsetExtractionOpInterface , |
| 172 | SubsetInsertionOpInterface insertionOp) const { |
| 173 | for (SubsetOpInterface other : allSubsetOps) { |
| 174 | if (other == extractionOp || other == insertionOp) |
| 175 | continue; |
| 176 | if (extractionOp && |
| 177 | !other.operatesOnDisjointSubset(extractionOp, isEquivalent)) |
| 178 | return false; |
| 179 | if (insertionOp && |
| 180 | !other.operatesOnDisjointSubset(insertionOp, isEquivalent)) |
| 181 | return false; |
| 182 | } |
| 183 | return true; |
| 184 | } |
| 185 | |
| 186 | /// Insert a subset extraction op. If the subset is equivalent to an existing |
| 187 | /// subset insertion op, pair them up. (If there is already a paired up subset |
| 188 | /// extraction op, overwrite the subset extraction op.) |
| 189 | void (SubsetExtractionOpInterface ) { |
| 190 | for (auto it : llvm::enumerate(insertions)) { |
| 191 | if (!it.value()) |
| 192 | continue; |
| 193 | auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
| 194 | if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) { |
| 195 | extractions[it.index()] = extractionOp; |
| 196 | return; |
| 197 | } |
| 198 | } |
| 199 | // There is no known equivalent insertion op. Create a new entry. |
| 200 | extractions.push_back(extractionOp); |
| 201 | insertions.push_back({}); |
| 202 | } |
| 203 | |
| 204 | /// Insert a subset insertion op. If the subset is equivalent to an existing |
| 205 | /// subset extraction op, pair them up. (If there is already a paired up |
| 206 | /// subset insertion op, overwrite the subset insertion op.) |
| 207 | void insertInsertionOp(SubsetInsertionOpInterface insertionOp) { |
| 208 | for (auto it : llvm::enumerate(extractions)) { |
| 209 | if (!it.value()) |
| 210 | continue; |
| 211 | auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
| 212 | if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) { |
| 213 | insertions[it.index()] = insertionOp; |
| 214 | return; |
| 215 | } |
| 216 | } |
| 217 | // There is no known equivalent extraction op. Create a new entry. |
| 218 | extractions.push_back({}); |
| 219 | insertions.push_back(insertionOp); |
| 220 | } |
| 221 | |
| 222 | SmallVector<SubsetExtractionOpInterface> ; |
| 223 | SmallVector<SubsetInsertionOpInterface> insertions; |
| 224 | SmallVector<SubsetOpInterface> allSubsetOps; |
| 225 | }; |
| 226 | } // namespace |
| 227 | |
| 228 | /// If the given value has a single use by an op that is a terminator, return |
| 229 | /// that use. Otherwise, return nullptr. |
| 230 | static OpOperand *getSingleTerminatorUse(Value value) { |
| 231 | if (!value.hasOneUse()) |
| 232 | return nullptr; |
| 233 | OpOperand &use = *value.getUses().begin(); |
| 234 | if (use.getOwner()->hasTrait<OpTrait::IsTerminator>()) |
| 235 | return &use; |
| 236 | return nullptr; |
| 237 | } |
| 238 | |
| 239 | LogicalResult |
| 240 | MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
| 241 | BlockArgument iterArg, |
| 242 | bool collectHoistableOps) { |
| 243 | assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg" ); |
| 244 | Value value = iterArg; |
| 245 | |
| 246 | // Traverse use-def chain. Subset ops can be hoisted only if all ops along the |
| 247 | // use-def chain starting from the region iter_arg are subset extraction or |
| 248 | // subset insertion ops. The chain must terminate at the corresponding yield |
| 249 | // operand (e.g., no swapping of iter_args). |
| 250 | OpOperand *yieldedOperand = nullptr; |
| 251 | // Iterate until the single use of the current SSA value is a terminator, |
| 252 | // which is expected to be the yielding operation of the loop. |
| 253 | while (!(yieldedOperand = getSingleTerminatorUse(value))) { |
| 254 | Value nextValue = {}; |
| 255 | |
| 256 | for (OpOperand &use : value.getUses()) { |
| 257 | if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) { |
| 258 | // Subset ops in nested loops are collected to check if there are only |
| 259 | // disjoint subset ops, but such subset ops are not subject to hoisting. |
| 260 | // To hoist subset ops from nested loops, the hoisting transformation |
| 261 | // should be run on the nested loop. |
| 262 | auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use); |
| 263 | if (!nestedIterArg) |
| 264 | return failure(); |
| 265 | // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA |
| 266 | // use-def chain starting at `nestedIterArg` and terminating in the |
| 267 | // tied, yielding operand. |
| 268 | if (failed(populateSubsetOpsAtIterArg(loopLike: nestedLoop, iterArg: nestedIterArg, |
| 269 | /*collectHoistableOps=*/false))) |
| 270 | return failure(); |
| 271 | nextValue = nestedLoop.getTiedLoopResult(&use); |
| 272 | continue; |
| 273 | } |
| 274 | |
| 275 | auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner()); |
| 276 | if (!subsetOp) |
| 277 | return failure(); |
| 278 | insert(op: subsetOp); |
| 279 | |
| 280 | if (auto insertionOp = |
| 281 | dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) { |
| 282 | // Current implementation expects that the insertionOp implement |
| 283 | // the DestinationStyleOpInterface and with pure tensor semantics |
| 284 | // as well. Abort if that is not the case. |
| 285 | auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner()); |
| 286 | if (!dstOp || !dstOp.hasPureTensorSemantics()) |
| 287 | return failure(); |
| 288 | |
| 289 | // The value must be used as a destination. (In case of a source, the |
| 290 | // entire tensor would be read, which would prevent any hoisting.) |
| 291 | if (&use != &insertionOp.getDestinationOperand()) |
| 292 | return failure(); |
| 293 | // There must be a single use-def chain from the region iter_arg to the |
| 294 | // terminator. I.e., only one insertion op. Branches are not supported. |
| 295 | if (nextValue) |
| 296 | return failure(); |
| 297 | nextValue = insertionOp.getUpdatedDestination(); |
| 298 | } |
| 299 | } |
| 300 | |
| 301 | // Nothing can be hoisted if the chain does not continue with loop yielding |
| 302 | // op or a subset insertion op. |
| 303 | if (!nextValue) |
| 304 | return failure(); |
| 305 | value = nextValue; |
| 306 | } |
| 307 | |
| 308 | // Hoist only if the SSA use-def chain ends in the yielding terminator of the |
| 309 | // loop and the yielded value is the `idx`-th operand. (I.e., there is no |
| 310 | // swapping yield.) |
| 311 | if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand) |
| 312 | return failure(); |
| 313 | |
| 314 | return success(); |
| 315 | } |
| 316 | |
| 317 | /// Hoist all subset ops that operate on the idx-th region iter_arg of the given |
| 318 | /// loop-like op and index into loop-invariant subset locations. Return the |
| 319 | /// newly created loop op (that has extra iter_args) or the original loop op if |
| 320 | /// nothing was hoisted. |
| 321 | static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, |
| 322 | LoopLikeOpInterface loopLike, |
| 323 | BlockArgument iterArg) { |
| 324 | assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg" ); |
| 325 | auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); |
| 326 | int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); |
| 327 | MatchingSubsets subsets; |
| 328 | if (failed(subsets.populateSubsetOpsAtIterArg(loopLike: loopLike, iterArg))) |
| 329 | return loopLike; |
| 330 | |
| 331 | // Hoist all matching extraction-insertion pairs one-by-one. |
| 332 | for (auto it : subsets.getHoistableSubsetOps()) { |
| 333 | auto extractionOp = std::get<0>(it); |
| 334 | auto insertionOp = std::get<1>(it); |
| 335 | |
| 336 | // Ops cannot be hoisted if they depend on loop-variant values. |
| 337 | if (extractionOp) { |
| 338 | if (!canBeHoisted(extractionOp, [&](OpOperand &operand) { |
| 339 | return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
| 340 | &operand == &extractionOp.getSourceOperand(); |
| 341 | })) |
| 342 | extractionOp = {}; |
| 343 | } |
| 344 | if (insertionOp) { |
| 345 | if (!canBeHoisted(insertionOp, [&](OpOperand &operand) { |
| 346 | return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
| 347 | &operand == &insertionOp.getSourceOperand() || |
| 348 | &operand == &insertionOp.getDestinationOperand(); |
| 349 | })) |
| 350 | insertionOp = {}; |
| 351 | } |
| 352 | |
| 353 | // Only hoist extraction-insertion pairs for now. Standalone extractions/ |
| 354 | // insertions that are loop-invariant could be hoisted, but there may be |
| 355 | // easier ways to canonicalize the IR. |
| 356 | if (extractionOp && insertionOp) { |
| 357 | // Create a new loop with an additional iter_arg. |
| 358 | NewYieldValuesFn newYieldValuesFn = |
| 359 | [&](OpBuilder &b, Location loc, |
| 360 | ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> { |
| 361 | return {insertionOp.getSourceOperand().get()}; |
| 362 | }; |
| 363 | FailureOr<LoopLikeOpInterface> newLoop = |
| 364 | loopLike.replaceWithAdditionalYields( |
| 365 | rewriter, extractionOp.getResult(), |
| 366 | /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn); |
| 367 | if (failed(newLoop)) |
| 368 | return loopLike; |
| 369 | loopLike = *newLoop; |
| 370 | |
| 371 | // Hoist the extraction/insertion ops. |
| 372 | iterArg = loopLike.getRegionIterArgs()[iterArgIdx]; |
| 373 | OpResult loopResult = loopLike.getTiedLoopResult(iterArg); |
| 374 | OpResult newLoopResult = loopLike.getLoopResults()->back(); |
| 375 | rewriter.moveOpBefore(extractionOp, loopLike); |
| 376 | rewriter.moveOpAfter(insertionOp, loopLike); |
| 377 | rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), |
| 378 | insertionOp.getDestinationOperand().get()); |
| 379 | extractionOp.getSourceOperand().set( |
| 380 | loopLike.getTiedLoopInit(iterArg)->get()); |
| 381 | rewriter.replaceAllUsesWith(loopResult, |
| 382 | insertionOp.getUpdatedDestination()); |
| 383 | insertionOp.getSourceOperand().set(newLoopResult); |
| 384 | insertionOp.getDestinationOperand().set(loopResult); |
| 385 | } |
| 386 | } |
| 387 | |
| 388 | return loopLike; |
| 389 | } |
| 390 | |
| 391 | LoopLikeOpInterface |
| 392 | mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter, |
| 393 | LoopLikeOpInterface loopLike) { |
| 394 | // Note: As subset ops are getting hoisted, the number of region iter_args |
| 395 | // increases. This can enable further hoisting opportunities on the new |
| 396 | // iter_args. |
| 397 | for (int64_t i = 0; |
| 398 | i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) { |
| 399 | loopLike = hoistSubsetAtIterArg(rewriter, loopLike, |
| 400 | loopLike.getRegionIterArgs()[i]); |
| 401 | } |
| 402 | return loopLike; |
| 403 | } |
| 404 | |