| 1 | //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements functions concerned with hoisting invariant operations |
| 10 | // in the context of Linalg transformations. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
| 15 | #include "mlir/Analysis/SliceAnalysis.h" |
| 16 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Affine/Utils.h" |
| 19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 20 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 21 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 22 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 23 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 24 | #include "mlir/IR/Dominance.h" |
| 25 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
| 26 | #include "llvm/Support/Debug.h" |
| 27 | |
| 28 | using llvm::dbgs; |
| 29 | |
| 30 | #define DEBUG_TYPE "linalg-hoisting" |
| 31 | |
| 32 | #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") |
| 33 | |
| 34 | using namespace mlir; |
| 35 | using namespace mlir::linalg; |
| 36 | |
| 37 | /// Replace `loop` with a new loop that has a different init operand at |
| 38 | /// position `index`. The body of this loop is moved over to the new loop. |
| 39 | /// |
| 40 | /// `newInitOperands` specifies the replacement "init" operands. |
| 41 | /// `newYieldValue` is the replacement yield value of the loop at position |
| 42 | /// `index`. |
| 43 | static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, |
| 44 | scf::ForOp loop, |
| 45 | Value newInitOperand, |
| 46 | unsigned index, |
| 47 | Value newYieldValue) { |
| 48 | OpBuilder::InsertionGuard g(rewriter); |
| 49 | rewriter.setInsertionPoint(loop.getOperation()); |
| 50 | auto inits = llvm::to_vector(Range: loop.getInits()); |
| 51 | |
| 52 | // Replace the init value with the new operand. |
| 53 | assert(index < inits.size()); |
| 54 | inits[index] = newInitOperand; |
| 55 | |
| 56 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
| 57 | location: loop.getLoc(), args: loop.getLowerBound(), args: loop.getUpperBound(), args: loop.getStep(), |
| 58 | args&: inits, args: [](OpBuilder &, Location, Value, ValueRange) {}); |
| 59 | |
| 60 | // Generate the new yield with the replaced operand. |
| 61 | auto yieldOp = cast<scf::YieldOp>(Val: loop.getBody()->getTerminator()); |
| 62 | yieldOp.setOperand(i: index, value: newYieldValue); |
| 63 | |
| 64 | // Move the loop body to the new op. |
| 65 | rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(), |
| 66 | argValues: newLoop.getBody()->getArguments()); |
| 67 | |
| 68 | // Replace the old loop. |
| 69 | rewriter.replaceOp(op: loop.getOperation(), newValues: newLoop->getResults()); |
| 70 | return newLoop; |
| 71 | } |
| 72 | |
| 73 | // Hoist out a pair of corresponding vector.extract+vector.broadcast |
| 74 | // operations. This function transforms a loop like this: |
| 75 | // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { |
| 76 | // %e = vector.extract %iarg : t1 to t2 |
| 77 | // %u = "some_use"(%e) : (t2) -> t2 |
| 78 | // %b = vector.broadcast %u : t2 to t1 |
| 79 | // scf.yield %b : t1 |
| 80 | // } |
| 81 | // into the following: |
| 82 | // %e = vector.extract %v: t1 to t2 |
| 83 | // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { |
| 84 | // %u' = "some_use"(%iarg) : (t2) -> t2 |
| 85 | // scf.yield %u' : t2 |
| 86 | // } |
| 87 | // %res = vector.broadcast %res' : t2 to t1 |
| 88 | void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, |
| 89 | Operation *root) { |
| 90 | bool changed = true; |
| 91 | while (changed) { |
| 92 | changed = false; |
| 93 | // First move loop invariant ops outside of their loop. This needs to be |
| 94 | // done before as we cannot move ops without interrupting the function walk. |
| 95 | root->walk( |
| 96 | callback: [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
| 97 | |
| 98 | root->walk(callback: [&](vector::ExtractOp ) { |
| 99 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
| 100 | << *extractOp.getOperation() << "\n" ); |
| 101 | |
| 102 | auto loop = dyn_cast<scf::ForOp>(Val: extractOp->getParentOp()); |
| 103 | if (!loop) |
| 104 | return WalkResult::advance(); |
| 105 | |
| 106 | // Check that the vector to extract from is a BlockArgument. |
| 107 | auto blockArg = dyn_cast<BlockArgument>(Val: extractOp.getVector()); |
| 108 | if (!blockArg) |
| 109 | return WalkResult::advance(); |
| 110 | |
| 111 | // Check that the blockArg is an iter_arg of the loop. |
| 112 | OpOperand *initArg = loop.getTiedLoopInit(bbArg: blockArg); |
| 113 | if (!initArg) |
| 114 | return WalkResult::advance(); |
| 115 | |
| 116 | // If the iter_arg does not have only one use, it won't be possible to |
| 117 | // hoist the extractOp out. |
| 118 | if (!blockArg.hasOneUse()) |
| 119 | return WalkResult::advance(); |
| 120 | |
| 121 | unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars(); |
| 122 | |
| 123 | // Check that the loop yields a broadcast that has just one use. |
| 124 | Operation *yieldedVal = |
| 125 | loop.getTiedLoopYieldedValue(bbArg: blockArg)->get().getDefiningOp(); |
| 126 | auto broadcast = dyn_cast<vector::BroadcastOp>(Val: yieldedVal); |
| 127 | if (!broadcast || !broadcast.getResult().hasOneUse()) |
| 128 | return WalkResult::advance(); |
| 129 | |
| 130 | LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n" ); |
| 131 | |
| 132 | Type broadcastInputType = broadcast.getSourceType(); |
| 133 | if (broadcastInputType != extractOp.getType()) |
| 134 | return WalkResult::advance(); |
| 135 | |
| 136 | // The position of the extract must be defined outside of the loop if |
| 137 | // it is dynamic. |
| 138 | for (auto operand : extractOp.getDynamicPosition()) |
| 139 | if (!loop.isDefinedOutsideOfLoop(value: operand)) |
| 140 | return WalkResult::advance(); |
| 141 | |
| 142 | rewriter.modifyOpInPlace(root: broadcast, callable: [&] { |
| 143 | extractOp.getVectorMutable().assign(value: initArg->get()); |
| 144 | }); |
| 145 | loop.moveOutOfLoop(op: extractOp); |
| 146 | rewriter.moveOpAfter(op: broadcast, existingOp: loop); |
| 147 | |
| 148 | scf::ForOp newLoop = replaceWithDifferentYield( |
| 149 | rewriter, loop, newInitOperand: extractOp.getResult(), index, newYieldValue: broadcast.getSource()); |
| 150 | |
| 151 | LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n" ); |
| 152 | |
| 153 | rewriter.replaceAllUsesWith(from: newLoop.getResult(i: index), to: broadcast); |
| 154 | rewriter.modifyOpInPlace( |
| 155 | root: broadcast, callable: [&] { broadcast.setOperand(newLoop.getResult(i: index)); }); |
| 156 | |
| 157 | changed = true; |
| 158 | return WalkResult::interrupt(); |
| 159 | }); |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, |
| 164 | LoopLikeOpInterface loop) { |
| 165 | Value source = transferRead.getBase(); |
| 166 | |
| 167 | // Skip view-like Ops and retrive the actual soruce Operation |
| 168 | while (auto srcOp = |
| 169 | dyn_cast_or_null<ViewLikeOpInterface>(Val: source.getDefiningOp())) |
| 170 | source = srcOp.getViewSource(); |
| 171 | |
| 172 | llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), |
| 173 | source.getUsers().end()); |
| 174 | llvm::SmallDenseSet<Operation *, 32> processed; |
| 175 | while (!users.empty()) { |
| 176 | Operation *user = users.pop_back_val(); |
| 177 | // If the user has already been processed skip. |
| 178 | if (!processed.insert(V: user).second) |
| 179 | continue; |
| 180 | if (auto viewLike = dyn_cast<ViewLikeOpInterface>(Val: user)) { |
| 181 | users.append(in_start: viewLike->getUsers().begin(), in_end: viewLike->getUsers().end()); |
| 182 | continue; |
| 183 | } |
| 184 | if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user)) |
| 185 | continue; |
| 186 | if (!loop->isAncestor(other: user)) |
| 187 | continue; |
| 188 | return false; |
| 189 | } |
| 190 | return true; |
| 191 | } |
| 192 | |
| 193 | void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, |
| 194 | bool verifyNonZeroTrip) { |
| 195 | bool changed = true; |
| 196 | while (changed) { |
| 197 | changed = false; |
| 198 | // First move loop invariant ops outside of their loop. This needs to be |
| 199 | // done before as we cannot move ops without interrupting the function walk. |
| 200 | root->walk( |
| 201 | callback: [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
| 202 | |
| 203 | // Find all loops that are certain to have non zero trip count. Any loops |
| 204 | // that are not part of this set cannot be hoisted from, since hoisting from |
| 205 | // a potentially zero trip count loop may cause a vector transfer to be |
| 206 | // executed when it shouldn't be. |
| 207 | llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops; |
| 208 | if (verifyNonZeroTrip) { |
| 209 | root->walk(callback: [&](LoopLikeOpInterface loopLike) { |
| 210 | std::optional<SmallVector<OpFoldResult>> lbs = |
| 211 | loopLike.getLoopLowerBounds(); |
| 212 | std::optional<SmallVector<OpFoldResult>> ubs = |
| 213 | loopLike.getLoopUpperBounds(); |
| 214 | // If loop bounds cannot be found, assume possibly zero trip count. |
| 215 | if (!lbs || !ubs) |
| 216 | return; |
| 217 | |
| 218 | // Otherwise, use ValueBounds to find the maximum lower bound and |
| 219 | // minimum upper bound. If the bounds are found, and maxLb is less |
| 220 | // than the minUb, then the loop will not have zero trip count. |
| 221 | for (auto [lb, ub] : llvm::zip_equal(t&: lbs.value(), u&: ubs.value())) { |
| 222 | FailureOr<int64_t> maxLb = |
| 223 | ValueBoundsConstraintSet::computeConstantBound( |
| 224 | type: presburger::BoundType::UB, var: lb, |
| 225 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
| 226 | if (failed(Result: maxLb)) |
| 227 | return; |
| 228 | FailureOr<int64_t> minUb = |
| 229 | ValueBoundsConstraintSet::computeConstantBound( |
| 230 | type: presburger::BoundType::LB, var: ub); |
| 231 | if (failed(Result: minUb)) |
| 232 | return; |
| 233 | if (minUb.value() <= maxLb.value()) |
| 234 | return; |
| 235 | definiteNonZeroTripCountLoops.insert(V: loopLike); |
| 236 | } |
| 237 | }); |
| 238 | } |
| 239 | |
| 240 | root->walk(callback: [&](vector::TransferReadOp transferRead) { |
| 241 | if (!isa<MemRefType>(Val: transferRead.getShapedType())) |
| 242 | return WalkResult::advance(); |
| 243 | |
| 244 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
| 245 | << *transferRead.getOperation() << "\n" ); |
| 246 | auto loop = dyn_cast<LoopLikeOpInterface>(Val: transferRead->getParentOp()); |
| 247 | LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() |
| 248 | << "\n" ); |
| 249 | if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(Val: loop)) |
| 250 | return WalkResult::advance(); |
| 251 | |
| 252 | if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(V: loop)) { |
| 253 | LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop |
| 254 | << "\n" ); |
| 255 | return WalkResult::advance(); |
| 256 | } |
| 257 | |
| 258 | LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() |
| 259 | << "\n" ); |
| 260 | |
| 261 | SetVector<Operation *> forwardSlice; |
| 262 | getForwardSlice(op: transferRead.getOperation(), forwardSlice: &forwardSlice); |
| 263 | |
| 264 | // Look for the last TransferWriteOp in the forwardSlice of |
| 265 | // `transferRead` that operates on the same memref. |
| 266 | vector::TransferWriteOp transferWrite; |
| 267 | for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) { |
| 268 | auto candidateWrite = dyn_cast<vector::TransferWriteOp>(Val: sliceOp); |
| 269 | if (!candidateWrite || |
| 270 | candidateWrite.getBase() != transferRead.getBase()) |
| 271 | continue; |
| 272 | transferWrite = candidateWrite; |
| 273 | } |
| 274 | |
| 275 | // All operands of the TransferRead must be defined outside of the loop. |
| 276 | for (auto operand : transferRead.getOperands()) |
| 277 | if (!loop.isDefinedOutsideOfLoop(value: operand)) |
| 278 | return WalkResult::advance(); |
| 279 | |
| 280 | // Only hoist transfer_read / transfer_write pairs and singleton |
| 281 | // transfer_reads for now. |
| 282 | if (!transferWrite) { |
| 283 | // Make sure there are no other accesses to the memref before |
| 284 | // hoisting transfer_read. |
| 285 | if (noAliasingUseInLoop(transferRead, loop)) |
| 286 | loop.moveOutOfLoop(op: transferRead); |
| 287 | return WalkResult::advance(); |
| 288 | } |
| 289 | |
| 290 | LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() |
| 291 | << "\n" ); |
| 292 | |
| 293 | // Approximate aliasing by checking that: |
| 294 | // 1. indices, vector type and permutation map are the same (i.e., the |
| 295 | // transfer_read/transfer_write ops are matching), |
| 296 | // 2. source operands for transfer.{read|write} do not originate from |
| 297 | // nor have users that are Ops implementing ViewLikeOpInterface. |
| 298 | // 3. no other operations in the loop access the same memref except |
| 299 | // for transfer_read/transfer_write accessing statically disjoint |
| 300 | // slices. |
| 301 | |
| 302 | // Check 1. |
| 303 | if (transferRead.getIndices() != transferWrite.getIndices() || |
| 304 | transferRead.getVectorType() != transferWrite.getVectorType() || |
| 305 | transferRead.getPermutationMap() != transferWrite.getPermutationMap()) |
| 306 | return WalkResult::advance(); |
| 307 | |
| 308 | // Check 2. Note, since both xfer Ops share the source, we only need to |
| 309 | // look at one of them. |
| 310 | auto base = transferRead.getBase(); |
| 311 | auto *source = base.getDefiningOp(); |
| 312 | if (source) { |
| 313 | // NOTE: We treat `memref.assume_alignment` as a special case. |
| 314 | // |
| 315 | // The idea is that it is safe to look past AssumeAlignmemtOp (i.e. |
| 316 | // MemRef _before_ alignment) iff: |
| 317 | // 1. It has exactly two uses (these have to be the xfer Ops |
| 318 | // being looked at). |
| 319 | // 2. The original MemRef has only one use (i.e. |
| 320 | // AssumeAlignmentOp). |
| 321 | // |
| 322 | // Relaxing these conditions will most likely require proper alias |
| 323 | // analysis. |
| 324 | if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(Val: source)) { |
| 325 | Value memPreAlignment = assume.getMemref(); |
| 326 | auto numInLoopUses = |
| 327 | llvm::count_if(Range: base.getUses(), P: [&loop](OpOperand &use) { |
| 328 | return loop->isAncestor(other: use.getOwner()); |
| 329 | }); |
| 330 | |
| 331 | if (numInLoopUses && memPreAlignment.hasOneUse()) |
| 332 | source = memPreAlignment.getDefiningOp(); |
| 333 | } |
| 334 | if (isa_and_nonnull<ViewLikeOpInterface>(Val: source)) |
| 335 | return WalkResult::advance(); |
| 336 | } |
| 337 | |
| 338 | if (llvm::any_of(Range: base.getUsers(), P: llvm::IsaPred<ViewLikeOpInterface>)) |
| 339 | return WalkResult::advance(); |
| 340 | |
| 341 | // Check 3. |
| 342 | // TODO: may want to memoize this information for performance but it |
| 343 | // likely gets invalidated often. |
| 344 | DominanceInfo dom(loop); |
| 345 | if (!dom.properlyDominates(a: transferRead.getOperation(), b: transferWrite)) |
| 346 | return WalkResult::advance(); |
| 347 | for (auto &use : transferRead.getBase().getUses()) { |
| 348 | if (!loop->isAncestor(other: use.getOwner())) |
| 349 | continue; |
| 350 | if (use.getOwner() == transferRead.getOperation() || |
| 351 | use.getOwner() == transferWrite.getOperation()) |
| 352 | continue; |
| 353 | if (auto transferWriteUse = |
| 354 | dyn_cast<vector::TransferWriteOp>(Val: use.getOwner())) { |
| 355 | if (!vector::isDisjointTransferSet( |
| 356 | transferA: cast<VectorTransferOpInterface>(Val&: *transferWrite), |
| 357 | transferB: cast<VectorTransferOpInterface>(Val&: *transferWriteUse), |
| 358 | /*testDynamicValueUsingBounds=*/true)) |
| 359 | return WalkResult::advance(); |
| 360 | } else if (auto transferReadUse = |
| 361 | dyn_cast<vector::TransferReadOp>(Val: use.getOwner())) { |
| 362 | if (!vector::isDisjointTransferSet( |
| 363 | transferA: cast<VectorTransferOpInterface>(Val&: *transferWrite), |
| 364 | transferB: cast<VectorTransferOpInterface>(Val&: *transferReadUse), |
| 365 | /*testDynamicValueUsingBounds=*/true)) |
| 366 | return WalkResult::advance(); |
| 367 | } else { |
| 368 | // Unknown use, we cannot prove that it doesn't alias with the |
| 369 | // transferRead/transferWrite operations. |
| 370 | return WalkResult::advance(); |
| 371 | } |
| 372 | } |
| 373 | |
| 374 | // Hoist read before. |
| 375 | loop.moveOutOfLoop(op: transferRead); |
| 376 | |
| 377 | // Hoist write after. |
| 378 | transferWrite->moveAfter(existingOp: loop); |
| 379 | |
| 380 | // Rewrite `loop` with new yields by cloning and erase the original |
| 381 | // loop. |
| 382 | IRRewriter rewriter(transferRead.getContext()); |
| 383 | NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, |
| 384 | ArrayRef<BlockArgument> newBBArgs) { |
| 385 | return SmallVector<Value>{transferWrite.getVector()}; |
| 386 | }; |
| 387 | |
| 388 | auto maybeNewLoop = loop.replaceWithAdditionalYields( |
| 389 | rewriter, newInitOperands: transferRead.getVector(), |
| 390 | /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn: yieldFn); |
| 391 | if (failed(Result: maybeNewLoop)) |
| 392 | return WalkResult::interrupt(); |
| 393 | |
| 394 | transferWrite.getValueToStoreMutable().assign( |
| 395 | value: maybeNewLoop->getOperation()->getResults().back()); |
| 396 | changed = true; |
| 397 | // Need to interrupt and restart because erasing the loop messes up |
| 398 | // the walk. |
| 399 | return WalkResult::interrupt(); |
| 400 | }); |
| 401 | } |
| 402 | } |
| 403 | |