| 1 | //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements functions concerned with hoisting invariant operations |
| 10 | // in the context of Linalg transformations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
| 15 | #include "mlir/Analysis/SliceAnalysis.h" |
| 16 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
| 19 | #include "mlir/Dialect/Affine/Utils.h" |
| 20 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 22 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 23 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 26 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 28 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 29 | #include "mlir/IR/BuiltinOps.h" |
| 30 | #include "mlir/IR/Dominance.h" |
| 31 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 32 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
| 33 | #include "llvm/ADT/StringRef.h" |
| 34 | #include "llvm/ADT/TypeSwitch.h" |
| 35 | #include "llvm/Support/Debug.h" |
| 36 | |
| 37 | using llvm::dbgs; |
| 38 | |
| 39 | #define DEBUG_TYPE "linalg-hoisting" |
| 40 | |
| 41 | #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") |
| 42 | |
| 43 | using namespace mlir; |
| 44 | using namespace mlir::linalg; |
| 45 | |
| 46 | /// Replace `loop` with a new loop that has a different init operand at |
| 47 | /// position `index`. The body of this loop is moved over to the new loop. |
| 48 | /// |
| 49 | /// `newInitOperands` specifies the replacement "init" operands. |
| 50 | /// `newYieldValue` is the replacement yield value of the loop at position |
| 51 | /// `index`. |
| 52 | static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, |
| 53 | scf::ForOp loop, |
| 54 | Value newInitOperand, |
| 55 | unsigned index, |
| 56 | Value newYieldValue) { |
| 57 | OpBuilder::InsertionGuard g(rewriter); |
| 58 | rewriter.setInsertionPoint(loop.getOperation()); |
| 59 | auto inits = llvm::to_vector(loop.getInits()); |
| 60 | |
| 61 | // Replace the init value with the new operand. |
| 62 | assert(index < inits.size()); |
| 63 | inits[index] = newInitOperand; |
| 64 | |
| 65 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
| 66 | loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), |
| 67 | inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
| 68 | |
| 69 | // Generate the new yield with the replaced operand. |
| 70 | auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); |
| 71 | yieldOp.setOperand(index, newYieldValue); |
| 72 | |
| 73 | // Move the loop body to the new op. |
| 74 | rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(), |
| 75 | argValues: newLoop.getBody()->getArguments()); |
| 76 | |
| 77 | // Replace the old loop. |
| 78 | rewriter.replaceOp(loop.getOperation(), newLoop->getResults()); |
| 79 | return newLoop; |
| 80 | } |
| 81 | |
| 82 | // Hoist out a pair of corresponding vector.extract+vector.broadcast |
| 83 | // operations. This function transforms a loop like this: |
| 84 | // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { |
| 85 | // %e = vector.extract %iarg : t1 to t2 |
| 86 | // %u = "some_use"(%e) : (t2) -> t2 |
| 87 | // %b = vector.broadcast %u : t2 to t1 |
| 88 | // scf.yield %b : t1 |
| 89 | // } |
| 90 | // into the following: |
| 91 | // %e = vector.extract %v: t1 to t2 |
| 92 | // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { |
| 93 | // %u' = "some_use"(%iarg) : (t2) -> t2 |
| 94 | // scf.yield %u' : t2 |
| 95 | // } |
| 96 | // %res = vector.broadcast %res' : t2 to t1 |
| 97 | void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, |
| 98 | Operation *root) { |
| 99 | bool changed = true; |
| 100 | while (changed) { |
| 101 | changed = false; |
| 102 | // First move loop invariant ops outside of their loop. This needs to be |
| 103 | // done before as we cannot move ops without interrupting the function walk. |
| 104 | root->walk( |
| 105 | [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
| 106 | |
| 107 | root->walk(callback: [&](vector::ExtractOp ) { |
| 108 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
| 109 | << *extractOp.getOperation() << "\n" ); |
| 110 | |
| 111 | auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp()); |
| 112 | if (!loop) |
| 113 | return WalkResult::advance(); |
| 114 | |
| 115 | // Check that the vector to extract from is a BlockArgument. |
| 116 | auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector()); |
| 117 | if (!blockArg) |
| 118 | return WalkResult::advance(); |
| 119 | |
| 120 | // Check that the blockArg is an iter_arg of the loop. |
| 121 | OpOperand *initArg = loop.getTiedLoopInit(blockArg); |
| 122 | if (!initArg) |
| 123 | return WalkResult::advance(); |
| 124 | |
| 125 | // If the iter_arg does not have only one use, it won't be possible to |
| 126 | // hoist the extractOp out. |
| 127 | if (!blockArg.hasOneUse()) |
| 128 | return WalkResult::advance(); |
| 129 | |
| 130 | unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars(); |
| 131 | |
| 132 | // Check that the loop yields a broadcast that has just one use. |
| 133 | Operation *yieldedVal = |
| 134 | loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp(); |
| 135 | auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal); |
| 136 | if (!broadcast || !broadcast.getResult().hasOneUse()) |
| 137 | return WalkResult::advance(); |
| 138 | |
| 139 | LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n" ); |
| 140 | |
| 141 | Type broadcastInputType = broadcast.getSourceType(); |
| 142 | if (broadcastInputType != extractOp.getType()) |
| 143 | return WalkResult::advance(); |
| 144 | |
| 145 | // The position of the extract must be defined outside of the loop if |
| 146 | // it is dynamic. |
| 147 | for (auto operand : extractOp.getDynamicPosition()) |
| 148 | if (!loop.isDefinedOutsideOfLoop(operand)) |
| 149 | return WalkResult::advance(); |
| 150 | |
| 151 | rewriter.modifyOpInPlace(broadcast, [&] { |
| 152 | extractOp.getVectorMutable().assign(initArg->get()); |
| 153 | }); |
| 154 | loop.moveOutOfLoop(extractOp); |
| 155 | rewriter.moveOpAfter(broadcast, loop); |
| 156 | |
| 157 | scf::ForOp newLoop = replaceWithDifferentYield( |
| 158 | rewriter, loop, extractOp.getResult(), index, broadcast.getSource()); |
| 159 | |
| 160 | LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n" ); |
| 161 | |
| 162 | rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast); |
| 163 | rewriter.modifyOpInPlace( |
| 164 | broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); }); |
| 165 | |
| 166 | changed = true; |
| 167 | return WalkResult::interrupt(); |
| 168 | }); |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, |
| 173 | LoopLikeOpInterface loop) { |
| 174 | Value source = transferRead.getBase(); |
| 175 | |
| 176 | // Skip view-like Ops and retrive the actual soruce Operation |
| 177 | while (auto srcOp = |
| 178 | dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp())) |
| 179 | source = srcOp.getViewSource(); |
| 180 | |
| 181 | llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), |
| 182 | source.getUsers().end()); |
| 183 | llvm::SmallDenseSet<Operation *, 32> processed; |
| 184 | while (!users.empty()) { |
| 185 | Operation *user = users.pop_back_val(); |
| 186 | // If the user has already been processed skip. |
| 187 | if (!processed.insert(V: user).second) |
| 188 | continue; |
| 189 | if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { |
| 190 | users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); |
| 191 | continue; |
| 192 | } |
| 193 | if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user)) |
| 194 | continue; |
| 195 | if (!loop->isAncestor(user)) |
| 196 | continue; |
| 197 | return false; |
| 198 | } |
| 199 | return true; |
| 200 | } |
| 201 | |
| 202 | void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, |
| 203 | bool verifyNonZeroTrip) { |
| 204 | bool changed = true; |
| 205 | while (changed) { |
| 206 | changed = false; |
| 207 | // First move loop invariant ops outside of their loop. This needs to be |
| 208 | // done before as we cannot move ops without interrupting the function walk. |
| 209 | root->walk( |
| 210 | [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
| 211 | |
| 212 | // Find all loops that are certain to have non zero trip count. Any loops |
| 213 | // that are not part of this set cannot be hoisted from, since hoisting from |
| 214 | // a potentially zero trip count loop may cause a vector transfer to be |
| 215 | // executed when it shouldn't be. |
| 216 | llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops; |
| 217 | if (verifyNonZeroTrip) { |
| 218 | root->walk([&](LoopLikeOpInterface loopLike) { |
| 219 | std::optional<SmallVector<OpFoldResult>> lbs = |
| 220 | loopLike.getLoopLowerBounds(); |
| 221 | std::optional<SmallVector<OpFoldResult>> ubs = |
| 222 | loopLike.getLoopUpperBounds(); |
| 223 | // If loop bounds cannot be found, assume possibly zero trip count. |
| 224 | if (!lbs || !ubs) |
| 225 | return; |
| 226 | |
| 227 | // Otherwise, use ValueBounds to find the maximum lower bound and |
| 228 | // minimum upper bound. If the bounds are found, and maxLb is less |
| 229 | // than the minUb, then the loop will not have zero trip count. |
| 230 | for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) { |
| 231 | FailureOr<int64_t> maxLb = |
| 232 | ValueBoundsConstraintSet::computeConstantBound( |
| 233 | presburger::BoundType::UB, lb, |
| 234 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
| 235 | if (failed(maxLb)) |
| 236 | return; |
| 237 | FailureOr<int64_t> minUb = |
| 238 | ValueBoundsConstraintSet::computeConstantBound( |
| 239 | presburger::BoundType::LB, ub); |
| 240 | if (failed(minUb)) |
| 241 | return; |
| 242 | if (minUb.value() <= maxLb.value()) |
| 243 | return; |
| 244 | definiteNonZeroTripCountLoops.insert(loopLike); |
| 245 | } |
| 246 | }); |
| 247 | } |
| 248 | |
| 249 | root->walk([&](vector::TransferReadOp transferRead) { |
| 250 | if (!isa<MemRefType>(transferRead.getShapedType())) |
| 251 | return WalkResult::advance(); |
| 252 | |
| 253 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
| 254 | << *transferRead.getOperation() << "\n" ); |
| 255 | auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp()); |
| 256 | LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() |
| 257 | << "\n" ); |
| 258 | if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop)) |
| 259 | return WalkResult::advance(); |
| 260 | |
| 261 | if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(V: loop)) { |
| 262 | LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop |
| 263 | << "\n" ); |
| 264 | return WalkResult::advance(); |
| 265 | } |
| 266 | |
| 267 | LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() |
| 268 | << "\n" ); |
| 269 | |
| 270 | SetVector<Operation *> forwardSlice; |
| 271 | getForwardSlice(transferRead.getOperation(), &forwardSlice); |
| 272 | |
| 273 | // Look for the last TransferWriteOp in the forwardSlice of |
| 274 | // `transferRead` that operates on the same memref. |
| 275 | vector::TransferWriteOp transferWrite; |
| 276 | for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) { |
| 277 | auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); |
| 278 | if (!candidateWrite || |
| 279 | candidateWrite.getBase() != transferRead.getBase()) |
| 280 | continue; |
| 281 | transferWrite = candidateWrite; |
| 282 | } |
| 283 | |
| 284 | // All operands of the TransferRead must be defined outside of the loop. |
| 285 | for (auto operand : transferRead.getOperands()) |
| 286 | if (!loop.isDefinedOutsideOfLoop(operand)) |
| 287 | return WalkResult::advance(); |
| 288 | |
| 289 | // Only hoist transfer_read / transfer_write pairs and singleton |
| 290 | // transfer_reads for now. |
| 291 | if (!transferWrite) { |
| 292 | // Make sure there are no other accesses to the memref before |
| 293 | // hoisting transfer_read. |
| 294 | if (noAliasingUseInLoop(transferRead, loop)) |
| 295 | loop.moveOutOfLoop(transferRead); |
| 296 | return WalkResult::advance(); |
| 297 | } |
| 298 | |
| 299 | LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() |
| 300 | << "\n" ); |
| 301 | |
| 302 | // Approximate aliasing by checking that: |
| 303 | // 1. indices, vector type and permutation map are the same (i.e., the |
| 304 | // transfer_read/transfer_write ops are matching), |
| 305 | // 2. source operands for transfer.{read|write} do not originate from |
| 306 | // Ops implementing ViewLikeOpInterface. |
| 307 | // 3. no other operations in the loop access the same memref except |
| 308 | // for transfer_read/transfer_write accessing statically disjoint |
| 309 | // slices. |
| 310 | if (transferRead.getIndices() != transferWrite.getIndices() || |
| 311 | transferRead.getVectorType() != transferWrite.getVectorType() || |
| 312 | transferRead.getPermutationMap() != transferWrite.getPermutationMap()) |
| 313 | return WalkResult::advance(); |
| 314 | |
| 315 | auto *source = transferRead.getBase().getDefiningOp(); |
| 316 | if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) |
| 317 | return WalkResult::advance(); |
| 318 | |
| 319 | source = transferWrite.getBase().getDefiningOp(); |
| 320 | if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) |
| 321 | return WalkResult::advance(); |
| 322 | |
| 323 | // TODO: may want to memoize this information for performance but it |
| 324 | // likely gets invalidated often. |
| 325 | DominanceInfo dom(loop); |
| 326 | if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) |
| 327 | return WalkResult::advance(); |
| 328 | for (auto &use : transferRead.getBase().getUses()) { |
| 329 | if (!loop->isAncestor(use.getOwner())) |
| 330 | continue; |
| 331 | if (use.getOwner() == transferRead.getOperation() || |
| 332 | use.getOwner() == transferWrite.getOperation()) |
| 333 | continue; |
| 334 | if (auto transferWriteUse = |
| 335 | dyn_cast<vector::TransferWriteOp>(use.getOwner())) { |
| 336 | if (!vector::isDisjointTransferSet( |
| 337 | cast<VectorTransferOpInterface>(*transferWrite), |
| 338 | cast<VectorTransferOpInterface>(*transferWriteUse), |
| 339 | /*testDynamicValueUsingBounds=*/true)) |
| 340 | return WalkResult::advance(); |
| 341 | } else if (auto transferReadUse = |
| 342 | dyn_cast<vector::TransferReadOp>(use.getOwner())) { |
| 343 | if (!vector::isDisjointTransferSet( |
| 344 | cast<VectorTransferOpInterface>(*transferWrite), |
| 345 | cast<VectorTransferOpInterface>(*transferReadUse), |
| 346 | /*testDynamicValueUsingBounds=*/true)) |
| 347 | return WalkResult::advance(); |
| 348 | } else { |
| 349 | // Unknown use, we cannot prove that it doesn't alias with the |
| 350 | // transferRead/transferWrite operations. |
| 351 | return WalkResult::advance(); |
| 352 | } |
| 353 | } |
| 354 | |
| 355 | // Hoist read before. |
| 356 | loop.moveOutOfLoop(transferRead); |
| 357 | |
| 358 | // Hoist write after. |
| 359 | transferWrite->moveAfter(loop); |
| 360 | |
| 361 | // Rewrite `loop` with new yields by cloning and erase the original loop. |
| 362 | IRRewriter rewriter(transferRead.getContext()); |
| 363 | NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, |
| 364 | ArrayRef<BlockArgument> newBBArgs) { |
| 365 | return SmallVector<Value>{transferWrite.getVector()}; |
| 366 | }; |
| 367 | |
| 368 | auto maybeNewLoop = loop.replaceWithAdditionalYields( |
| 369 | rewriter, transferRead.getVector(), |
| 370 | /*replaceInitOperandUsesInLoop=*/true, yieldFn); |
| 371 | if (failed(maybeNewLoop)) |
| 372 | return WalkResult::interrupt(); |
| 373 | |
| 374 | transferWrite.getValueToStoreMutable().assign( |
| 375 | maybeNewLoop->getOperation()->getResults().back()); |
| 376 | changed = true; |
| 377 | // Need to interrupt and restart because erasing the loop messes up |
| 378 | // the walk. |
| 379 | return WalkResult::interrupt(); |
| 380 | }); |
| 381 | } |
| 382 | } |
| 383 | |