1 | //===----------- MultiBuffering.cpp ---------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements multi buffering transformation. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
17 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
18 | #include "mlir/IR/AffineExpr.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/Dominance.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/IR/ValueRange.h" |
23 | #include "mlir/Interfaces/LoopLikeInterface.h" |
24 | #include "llvm/ADT/STLExtras.h" |
25 | #include "llvm/Support/Debug.h" |
26 | |
27 | using namespace mlir; |
28 | |
// Debug category for this file; the LLVM_DEBUG statements below are emitted
// under this name.
#define DEBUG_TYPE "memref-transforms"
// Debug stream with every message prefixed by "[memref-transforms]: ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Writes a bare newline to the debug stream (not referenced in the visible
// code).
#define DBGSNL() (llvm::dbgs() << "\n")
32 | |
33 | /// Return true if the op fully overwrite the given `buffer` value. |
34 | static bool overrideBuffer(Operation *op, Value buffer) { |
35 | auto copyOp = dyn_cast<memref::CopyOp>(op); |
36 | if (!copyOp) |
37 | return false; |
38 | return copyOp.getTarget() == buffer; |
39 | } |
40 | |
41 | /// Replace the uses of `oldOp` with the given `val` and for subview uses |
42 | /// propagate the type change. Changing the memref type may require propagating |
43 | /// it through subview ops so we cannot just do a replaceAllUse but need to |
44 | /// propagate the type change and erase old subview ops. |
45 | static void replaceUsesAndPropagateType(RewriterBase &rewriter, |
46 | Operation *oldOp, Value val) { |
47 | SmallVector<Operation *> opsToDelete; |
48 | SmallVector<OpOperand *> operandsToReplace; |
49 | |
50 | // Save the operand to replace / delete later (avoid iterator invalidation). |
51 | // TODO: can we use an early_inc iterator? |
52 | for (OpOperand &use : oldOp->getUses()) { |
53 | // Non-subview ops will be replaced by `val`. |
54 | auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner()); |
55 | if (!subviewUse) { |
56 | operandsToReplace.push_back(Elt: &use); |
57 | continue; |
58 | } |
59 | |
60 | // `subview(old_op)` is replaced by a new `subview(val)`. |
61 | OpBuilder::InsertionGuard g(rewriter); |
62 | rewriter.setInsertionPoint(subviewUse); |
63 | Type newType = memref::SubViewOp::inferRankReducedResultType( |
64 | subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), |
65 | subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), |
66 | subviewUse.getStaticStrides()); |
67 | Value newSubview = rewriter.create<memref::SubViewOp>( |
68 | subviewUse->getLoc(), cast<MemRefType>(newType), val, |
69 | subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), |
70 | subviewUse.getMixedStrides()); |
71 | |
72 | // Ouch recursion ... is this really necessary? |
73 | replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); |
74 | |
75 | opsToDelete.push_back(Elt: use.getOwner()); |
76 | } |
77 | |
78 | // Perform late replacement. |
79 | // TODO: can we use an early_inc iterator? |
80 | for (OpOperand *operand : operandsToReplace) { |
81 | Operation *op = operand->getOwner(); |
82 | rewriter.startOpModification(op); |
83 | operand->set(val); |
84 | rewriter.finalizeOpModification(op); |
85 | } |
86 | |
87 | // Perform late op erasure. |
88 | // TODO: can we use an early_inc iterator? |
89 | for (Operation *op : opsToDelete) |
90 | rewriter.eraseOp(op); |
91 | } |
92 | |
93 | // Transformation to do multi-buffering/array expansion to remove dependencies |
94 | // on the temporary allocation between consecutive loop iterations. |
95 | // Returns success if the transformation happened and failure otherwise. |
96 | // This is not a pattern as it requires propagating the new memref type to its |
97 | // uses and requires updating subview ops. |
98 | FailureOr<memref::AllocOp> |
99 | mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, |
100 | unsigned multiBufferingFactor, |
101 | bool skipOverrideAnalysis) { |
102 | LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n" ); |
103 | DominanceInfo dom(allocOp->getParentOp()); |
104 | LoopLikeOpInterface candidateLoop; |
105 | for (Operation *user : allocOp->getUsers()) { |
106 | auto parentLoop = user->getParentOfType<LoopLikeOpInterface>(); |
107 | if (!parentLoop) { |
108 | if (isa<memref::DeallocOp>(user)) { |
109 | // Allow dealloc outside of any loop. |
110 | // TODO: The whole precondition function here is very brittle and will |
111 | // need to rethought an isolated into a cleaner analysis. |
112 | continue; |
113 | } |
114 | LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n" ); |
115 | LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n" ); |
116 | return failure(); |
117 | } |
118 | if (!skipOverrideAnalysis) { |
119 | /// Make sure there is no loop-carried dependency on the allocation. |
120 | if (!overrideBuffer(user, allocOp.getResult())) { |
121 | LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n" ); |
122 | continue; |
123 | } |
124 | // If this user doesn't dominate all the other users keep looking. |
125 | if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { |
126 | return !dom.dominates(user, otherUser); |
127 | })) { |
128 | LLVM_DEBUG( |
129 | DBGS() << "--Skip user: does not dominate all other users\n" ); |
130 | continue; |
131 | } |
132 | } else { |
133 | if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { |
134 | return !isa<memref::DeallocOp>(otherUser) && |
135 | !parentLoop->isProperAncestor(otherUser); |
136 | })) { |
137 | LLVM_DEBUG( |
138 | DBGS() |
139 | << "--Skip user: not all other users are in the parent loop\n" ); |
140 | continue; |
141 | } |
142 | } |
143 | candidateLoop = parentLoop; |
144 | break; |
145 | } |
146 | |
147 | if (!candidateLoop) { |
148 | LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n" ); |
149 | return failure(); |
150 | } |
151 | |
152 | std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar(); |
153 | std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound(); |
154 | std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep(); |
155 | if (!inductionVar || !lowerBound || !singleStep || |
156 | !llvm::hasSingleElement(candidateLoop.getLoopRegions())) { |
157 | LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n" ); |
158 | return failure(); |
159 | } |
160 | |
161 | if (!dom.dominates(allocOp.getOperation(), candidateLoop)) { |
162 | LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n" ); |
163 | return failure(); |
164 | } |
165 | |
166 | LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n" ); |
167 | |
168 | // 1. Construct the multi-buffered memref type. |
169 | ArrayRef<int64_t> originalShape = allocOp.getType().getShape(); |
170 | SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor}; |
171 | llvm::append_range(C&: multiBufferedShape, R&: originalShape); |
172 | LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n" ); |
173 | MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) |
174 | .setShape(multiBufferedShape) |
175 | .setLayout(MemRefLayoutAttrInterface()); |
176 | LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n" ); |
177 | |
178 | // 2. Create the multi-buffered alloc. |
179 | Location loc = allocOp->getLoc(); |
180 | OpBuilder::InsertionGuard g(rewriter); |
181 | rewriter.setInsertionPoint(allocOp); |
182 | auto mbAlloc = rewriter.create<memref::AllocOp>( |
183 | loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); |
184 | LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n" ); |
185 | |
186 | // 3. Within the loop, build the modular leading index (i.e. each loop |
187 | // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). |
188 | rewriter.setInsertionPointToStart( |
189 | &candidateLoop.getLoopRegions().front()->front()); |
190 | Value ivVal = *inductionVar; |
191 | Value lbVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *lowerBound); |
192 | Value stepVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *singleStep); |
193 | AffineExpr iv, lb, step; |
194 | bindDims(ctx: rewriter.getContext(), exprs&: iv, exprs&: lb, exprs&: step); |
195 | Value bufferIndex = affine::makeComposedAffineApply( |
196 | rewriter, loc, ((iv - lb).floorDiv(other: step)) % multiBufferingFactor, |
197 | {ivVal, lbVal, stepVal}); |
198 | LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n" ); |
199 | |
200 | // 4. Build the subview accessing the particular slice, taking modular |
201 | // rotation into account. |
202 | int64_t mbMemRefTypeRank = mbMemRefType.getRank(); |
203 | IntegerAttr zero = rewriter.getIndexAttr(0); |
204 | IntegerAttr one = rewriter.getIndexAttr(1); |
205 | SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero); |
206 | SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one); |
207 | SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one); |
208 | // Offset is [bufferIndex, 0 ... 0 ]. |
209 | offsets.front() = bufferIndex; |
210 | // Sizes is [1, original_size_0 ... original_size_n ]. |
211 | for (int64_t i = 0, e = originalShape.size(); i != e; ++i) |
212 | sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); |
213 | // Strides is [1, 1 ... 1 ]. |
214 | auto dstMemref = |
215 | cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( |
216 | originalShape, mbMemRefType, offsets, sizes, strides)); |
217 | Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc, |
218 | offsets, sizes, strides); |
219 | LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n" ); |
220 | |
221 | // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to |
222 | // handle dealloc uses separately.. |
223 | for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { |
224 | auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner()); |
225 | if (!deallocOp) |
226 | continue; |
227 | OpBuilder::InsertionGuard g(rewriter); |
228 | rewriter.setInsertionPoint(deallocOp); |
229 | auto newDeallocOp = |
230 | rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc); |
231 | (void)newDeallocOp; |
232 | LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n" ); |
233 | rewriter.eraseOp(deallocOp); |
234 | } |
235 | |
236 | // 6. RAUW with the particular slice, taking modular rotation into account. |
237 | replaceUsesAndPropagateType(rewriter, allocOp, subview); |
238 | |
239 | // 7. Finally, erase the old allocOp. |
240 | rewriter.eraseOp(op: allocOp); |
241 | |
242 | return mbAlloc; |
243 | } |
244 | |
245 | FailureOr<memref::AllocOp> |
246 | mlir::memref::multiBuffer(memref::AllocOp allocOp, |
247 | unsigned multiBufferingFactor, |
248 | bool skipOverrideAnalysis) { |
249 | IRRewriter rewriter(allocOp->getContext()); |
250 | return multiBuffer(rewriter, allocOp, multiBufferingFactor, |
251 | skipOverrideAnalysis); |
252 | } |
253 | |