| 1 | //===----------- MultiBuffering.cpp ---------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the multi-buffering transformation. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 14 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| 17 | #include "mlir/IR/AffineExpr.h" |
| 18 | #include "mlir/IR/BuiltinAttributes.h" |
| 19 | #include "mlir/IR/Dominance.h" |
| 20 | #include "mlir/IR/PatternMatch.h" |
| 21 | #include "mlir/IR/ValueRange.h" |
| 22 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 23 | #include "llvm/ADT/STLExtras.h" |
| 24 | #include "llvm/Support/Debug.h" |
| 25 | |
| 26 | using namespace mlir; |
| 27 | |
| 28 | #define DEBUG_TYPE "memref-transforms" |
| 29 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 30 | #define DBGSNL() (llvm::dbgs() << "\n") |
| 31 | |
| 32 | /// Return true if the op fully overwrite the given `buffer` value. |
| 33 | static bool overrideBuffer(Operation *op, Value buffer) { |
| 34 | auto copyOp = dyn_cast<memref::CopyOp>(Val: op); |
| 35 | if (!copyOp) |
| 36 | return false; |
| 37 | return copyOp.getTarget() == buffer; |
| 38 | } |
| 39 | |
| 40 | /// Replace the uses of `oldOp` with the given `val` and for subview uses |
| 41 | /// propagate the type change. Changing the memref type may require propagating |
| 42 | /// it through subview ops so we cannot just do a replaceAllUse but need to |
| 43 | /// propagate the type change and erase old subview ops. |
| 44 | static void replaceUsesAndPropagateType(RewriterBase &rewriter, |
| 45 | Operation *oldOp, Value val) { |
| 46 | SmallVector<Operation *> opsToDelete; |
| 47 | SmallVector<OpOperand *> operandsToReplace; |
| 48 | |
| 49 | // Save the operand to replace / delete later (avoid iterator invalidation). |
| 50 | // TODO: can we use an early_inc iterator? |
| 51 | for (OpOperand &use : oldOp->getUses()) { |
| 52 | // Non-subview ops will be replaced by `val`. |
| 53 | auto subviewUse = dyn_cast<memref::SubViewOp>(Val: use.getOwner()); |
| 54 | if (!subviewUse) { |
| 55 | operandsToReplace.push_back(Elt: &use); |
| 56 | continue; |
| 57 | } |
| 58 | |
| 59 | // `subview(old_op)` is replaced by a new `subview(val)`. |
| 60 | OpBuilder::InsertionGuard g(rewriter); |
| 61 | rewriter.setInsertionPoint(subviewUse); |
| 62 | MemRefType newType = memref::SubViewOp::inferRankReducedResultType( |
| 63 | resultShape: subviewUse.getType().getShape(), sourceMemRefType: cast<MemRefType>(Val: val.getType()), |
| 64 | staticOffsets: subviewUse.getStaticOffsets(), staticSizes: subviewUse.getStaticSizes(), |
| 65 | staticStrides: subviewUse.getStaticStrides()); |
| 66 | Value newSubview = rewriter.create<memref::SubViewOp>( |
| 67 | location: subviewUse->getLoc(), args&: newType, args&: val, args: subviewUse.getMixedOffsets(), |
| 68 | args: subviewUse.getMixedSizes(), args: subviewUse.getMixedStrides()); |
| 69 | |
| 70 | // Ouch recursion ... is this really necessary? |
| 71 | replaceUsesAndPropagateType(rewriter, oldOp: subviewUse, val: newSubview); |
| 72 | |
| 73 | opsToDelete.push_back(Elt: use.getOwner()); |
| 74 | } |
| 75 | |
| 76 | // Perform late replacement. |
| 77 | // TODO: can we use an early_inc iterator? |
| 78 | for (OpOperand *operand : operandsToReplace) { |
| 79 | Operation *op = operand->getOwner(); |
| 80 | rewriter.startOpModification(op); |
| 81 | operand->set(val); |
| 82 | rewriter.finalizeOpModification(op); |
| 83 | } |
| 84 | |
| 85 | // Perform late op erasure. |
| 86 | // TODO: can we use an early_inc iterator? |
| 87 | for (Operation *op : opsToDelete) |
| 88 | rewriter.eraseOp(op); |
| 89 | } |
| 90 | |
| 91 | // Transformation to do multi-buffering/array expansion to remove dependencies |
| 92 | // on the temporary allocation between consecutive loop iterations. |
| 93 | // Returns success if the transformation happened and failure otherwise. |
| 94 | // This is not a pattern as it requires propagating the new memref type to its |
| 95 | // uses and requires updating subview ops. |
| 96 | FailureOr<memref::AllocOp> |
| 97 | mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, |
| 98 | unsigned multiBufferingFactor, |
| 99 | bool skipOverrideAnalysis) { |
| 100 | LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n" ); |
| 101 | DominanceInfo dom(allocOp->getParentOp()); |
| 102 | LoopLikeOpInterface candidateLoop; |
| 103 | for (Operation *user : allocOp->getUsers()) { |
| 104 | auto parentLoop = user->getParentOfType<LoopLikeOpInterface>(); |
| 105 | if (!parentLoop) { |
| 106 | if (isa<memref::DeallocOp>(Val: user)) { |
| 107 | // Allow dealloc outside of any loop. |
| 108 | // TODO: The whole precondition function here is very brittle and will |
| 109 | // need to rethought an isolated into a cleaner analysis. |
| 110 | continue; |
| 111 | } |
| 112 | LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n" ); |
| 113 | LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n" ); |
| 114 | return failure(); |
| 115 | } |
| 116 | if (!skipOverrideAnalysis) { |
| 117 | /// Make sure there is no loop-carried dependency on the allocation. |
| 118 | if (!overrideBuffer(op: user, buffer: allocOp.getResult())) { |
| 119 | LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n" ); |
| 120 | continue; |
| 121 | } |
| 122 | // If this user doesn't dominate all the other users keep looking. |
| 123 | if (llvm::any_of(Range: allocOp->getUsers(), P: [&](Operation *otherUser) { |
| 124 | return !dom.dominates(a: user, b: otherUser); |
| 125 | })) { |
| 126 | LLVM_DEBUG( |
| 127 | DBGS() << "--Skip user: does not dominate all other users\n" ); |
| 128 | continue; |
| 129 | } |
| 130 | } else { |
| 131 | if (llvm::any_of(Range: allocOp->getUsers(), P: [&](Operation *otherUser) { |
| 132 | return !isa<memref::DeallocOp>(Val: otherUser) && |
| 133 | !parentLoop->isProperAncestor(other: otherUser); |
| 134 | })) { |
| 135 | LLVM_DEBUG( |
| 136 | DBGS() |
| 137 | << "--Skip user: not all other users are in the parent loop\n" ); |
| 138 | continue; |
| 139 | } |
| 140 | } |
| 141 | candidateLoop = parentLoop; |
| 142 | break; |
| 143 | } |
| 144 | |
| 145 | if (!candidateLoop) { |
| 146 | LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n" ); |
| 147 | return failure(); |
| 148 | } |
| 149 | |
| 150 | std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar(); |
| 151 | std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound(); |
| 152 | std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep(); |
| 153 | if (!inductionVar || !lowerBound || !singleStep || |
| 154 | !llvm::hasSingleElement(C: candidateLoop.getLoopRegions())) { |
| 155 | LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n" ); |
| 156 | return failure(); |
| 157 | } |
| 158 | |
| 159 | if (!dom.dominates(a: allocOp.getOperation(), b: candidateLoop)) { |
| 160 | LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n" ); |
| 161 | return failure(); |
| 162 | } |
| 163 | |
| 164 | LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n" ); |
| 165 | |
| 166 | // 1. Construct the multi-buffered memref type. |
| 167 | ArrayRef<int64_t> originalShape = allocOp.getType().getShape(); |
| 168 | SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor}; |
| 169 | llvm::append_range(C&: multiBufferedShape, R&: originalShape); |
| 170 | LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n" ); |
| 171 | MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) |
| 172 | .setShape(multiBufferedShape) |
| 173 | .setLayout(MemRefLayoutAttrInterface()); |
| 174 | LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n" ); |
| 175 | |
| 176 | // 2. Create the multi-buffered alloc. |
| 177 | Location loc = allocOp->getLoc(); |
| 178 | OpBuilder::InsertionGuard g(rewriter); |
| 179 | rewriter.setInsertionPoint(allocOp); |
| 180 | auto mbAlloc = rewriter.create<memref::AllocOp>( |
| 181 | location: loc, args&: mbMemRefType, args: ValueRange{}, args: allocOp->getAttrs()); |
| 182 | LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n" ); |
| 183 | |
| 184 | // 3. Within the loop, build the modular leading index (i.e. each loop |
| 185 | // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). |
| 186 | rewriter.setInsertionPointToStart( |
| 187 | &candidateLoop.getLoopRegions().front()->front()); |
| 188 | Value ivVal = *inductionVar; |
| 189 | Value lbVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *lowerBound); |
| 190 | Value stepVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *singleStep); |
| 191 | AffineExpr iv, lb, step; |
| 192 | bindDims(ctx: rewriter.getContext(), exprs&: iv, exprs&: lb, exprs&: step); |
| 193 | Value bufferIndex = affine::makeComposedAffineApply( |
| 194 | b&: rewriter, loc, e: ((iv - lb).floorDiv(other: step)) % multiBufferingFactor, |
| 195 | operands: {ivVal, lbVal, stepVal}); |
| 196 | LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n" ); |
| 197 | |
| 198 | // 4. Build the subview accessing the particular slice, taking modular |
| 199 | // rotation into account. |
| 200 | int64_t mbMemRefTypeRank = mbMemRefType.getRank(); |
| 201 | IntegerAttr zero = rewriter.getIndexAttr(value: 0); |
| 202 | IntegerAttr one = rewriter.getIndexAttr(value: 1); |
| 203 | SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero); |
| 204 | SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one); |
| 205 | SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one); |
| 206 | // Offset is [bufferIndex, 0 ... 0 ]. |
| 207 | offsets.front() = bufferIndex; |
| 208 | // Sizes is [1, original_size_0 ... original_size_n ]. |
| 209 | for (int64_t i = 0, e = originalShape.size(); i != e; ++i) |
| 210 | sizes[1 + i] = rewriter.getIndexAttr(value: originalShape[i]); |
| 211 | // Strides is [1, 1 ... 1 ]. |
| 212 | MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( |
| 213 | resultShape: originalShape, sourceMemRefType: mbMemRefType, staticOffsets: offsets, staticSizes: sizes, staticStrides: strides); |
| 214 | Value subview = rewriter.create<memref::SubViewOp>(location: loc, args&: dstMemref, args&: mbAlloc, |
| 215 | args&: offsets, args&: sizes, args&: strides); |
| 216 | LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n" ); |
| 217 | |
| 218 | // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to |
| 219 | // handle dealloc uses separately.. |
| 220 | for (OpOperand &use : llvm::make_early_inc_range(Range: allocOp->getUses())) { |
| 221 | auto deallocOp = dyn_cast<memref::DeallocOp>(Val: use.getOwner()); |
| 222 | if (!deallocOp) |
| 223 | continue; |
| 224 | OpBuilder::InsertionGuard g(rewriter); |
| 225 | rewriter.setInsertionPoint(deallocOp); |
| 226 | auto newDeallocOp = |
| 227 | rewriter.create<memref::DeallocOp>(location: deallocOp->getLoc(), args&: mbAlloc); |
| 228 | (void)newDeallocOp; |
| 229 | LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n" ); |
| 230 | rewriter.eraseOp(op: deallocOp); |
| 231 | } |
| 232 | |
| 233 | // 6. RAUW with the particular slice, taking modular rotation into account. |
| 234 | replaceUsesAndPropagateType(rewriter, oldOp: allocOp, val: subview); |
| 235 | |
| 236 | // 7. Finally, erase the old allocOp. |
| 237 | rewriter.eraseOp(op: allocOp); |
| 238 | |
| 239 | return mbAlloc; |
| 240 | } |
| 241 | |
| 242 | FailureOr<memref::AllocOp> |
| 243 | mlir::memref::multiBuffer(memref::AllocOp allocOp, |
| 244 | unsigned multiBufferingFactor, |
| 245 | bool skipOverrideAnalysis) { |
| 246 | IRRewriter rewriter(allocOp->getContext()); |
| 247 | return multiBuffer(rewriter, allocOp, multiBufferingFactor, |
| 248 | skipOverrideAnalysis); |
| 249 | } |
| 250 | |