1//===----------- MultiBuffering.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements multi buffering transformation.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/Dominance.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/IR/ValueRange.h"
22#include "mlir/Interfaces/LoopLikeInterface.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25
26using namespace mlir;
27
28#define DEBUG_TYPE "memref-transforms"
29#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
30#define DBGSNL() (llvm::dbgs() << "\n")
31
32/// Return true if the op fully overwrite the given `buffer` value.
33static bool overrideBuffer(Operation *op, Value buffer) {
34 auto copyOp = dyn_cast<memref::CopyOp>(Val: op);
35 if (!copyOp)
36 return false;
37 return copyOp.getTarget() == buffer;
38}
39
40/// Replace the uses of `oldOp` with the given `val` and for subview uses
41/// propagate the type change. Changing the memref type may require propagating
42/// it through subview ops so we cannot just do a replaceAllUse but need to
43/// propagate the type change and erase old subview ops.
44static void replaceUsesAndPropagateType(RewriterBase &rewriter,
45 Operation *oldOp, Value val) {
46 SmallVector<Operation *> opsToDelete;
47 SmallVector<OpOperand *> operandsToReplace;
48
49 // Save the operand to replace / delete later (avoid iterator invalidation).
50 // TODO: can we use an early_inc iterator?
51 for (OpOperand &use : oldOp->getUses()) {
52 // Non-subview ops will be replaced by `val`.
53 auto subviewUse = dyn_cast<memref::SubViewOp>(Val: use.getOwner());
54 if (!subviewUse) {
55 operandsToReplace.push_back(Elt: &use);
56 continue;
57 }
58
59 // `subview(old_op)` is replaced by a new `subview(val)`.
60 OpBuilder::InsertionGuard g(rewriter);
61 rewriter.setInsertionPoint(subviewUse);
62 MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
63 resultShape: subviewUse.getType().getShape(), sourceMemRefType: cast<MemRefType>(Val: val.getType()),
64 staticOffsets: subviewUse.getStaticOffsets(), staticSizes: subviewUse.getStaticSizes(),
65 staticStrides: subviewUse.getStaticStrides());
66 Value newSubview = rewriter.create<memref::SubViewOp>(
67 location: subviewUse->getLoc(), args&: newType, args&: val, args: subviewUse.getMixedOffsets(),
68 args: subviewUse.getMixedSizes(), args: subviewUse.getMixedStrides());
69
70 // Ouch recursion ... is this really necessary?
71 replaceUsesAndPropagateType(rewriter, oldOp: subviewUse, val: newSubview);
72
73 opsToDelete.push_back(Elt: use.getOwner());
74 }
75
76 // Perform late replacement.
77 // TODO: can we use an early_inc iterator?
78 for (OpOperand *operand : operandsToReplace) {
79 Operation *op = operand->getOwner();
80 rewriter.startOpModification(op);
81 operand->set(val);
82 rewriter.finalizeOpModification(op);
83 }
84
85 // Perform late op erasure.
86 // TODO: can we use an early_inc iterator?
87 for (Operation *op : opsToDelete)
88 rewriter.eraseOp(op);
89}
90
91// Transformation to do multi-buffering/array expansion to remove dependencies
92// on the temporary allocation between consecutive loop iterations.
93// Returns success if the transformation happened and failure otherwise.
94// This is not a pattern as it requires propagating the new memref type to its
95// uses and requires updating subview ops.
96FailureOr<memref::AllocOp>
97mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
98 unsigned multiBufferingFactor,
99 bool skipOverrideAnalysis) {
100 LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
101 DominanceInfo dom(allocOp->getParentOp());
102 LoopLikeOpInterface candidateLoop;
103 for (Operation *user : allocOp->getUsers()) {
104 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
105 if (!parentLoop) {
106 if (isa<memref::DeallocOp>(Val: user)) {
107 // Allow dealloc outside of any loop.
108 // TODO: The whole precondition function here is very brittle and will
109 // need to rethought an isolated into a cleaner analysis.
110 continue;
111 }
112 LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
113 LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
114 return failure();
115 }
116 if (!skipOverrideAnalysis) {
117 /// Make sure there is no loop-carried dependency on the allocation.
118 if (!overrideBuffer(op: user, buffer: allocOp.getResult())) {
119 LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
120 continue;
121 }
122 // If this user doesn't dominate all the other users keep looking.
123 if (llvm::any_of(Range: allocOp->getUsers(), P: [&](Operation *otherUser) {
124 return !dom.dominates(a: user, b: otherUser);
125 })) {
126 LLVM_DEBUG(
127 DBGS() << "--Skip user: does not dominate all other users\n");
128 continue;
129 }
130 } else {
131 if (llvm::any_of(Range: allocOp->getUsers(), P: [&](Operation *otherUser) {
132 return !isa<memref::DeallocOp>(Val: otherUser) &&
133 !parentLoop->isProperAncestor(other: otherUser);
134 })) {
135 LLVM_DEBUG(
136 DBGS()
137 << "--Skip user: not all other users are in the parent loop\n");
138 continue;
139 }
140 }
141 candidateLoop = parentLoop;
142 break;
143 }
144
145 if (!candidateLoop) {
146 LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
147 return failure();
148 }
149
150 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
151 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
152 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
153 if (!inductionVar || !lowerBound || !singleStep ||
154 !llvm::hasSingleElement(C: candidateLoop.getLoopRegions())) {
155 LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
156 return failure();
157 }
158
159 if (!dom.dominates(a: allocOp.getOperation(), b: candidateLoop)) {
160 LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
161 return failure();
162 }
163
164 LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
165
166 // 1. Construct the multi-buffered memref type.
167 ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
168 SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
169 llvm::append_range(C&: multiBufferedShape, R&: originalShape);
170 LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
171 MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
172 .setShape(multiBufferedShape)
173 .setLayout(MemRefLayoutAttrInterface());
174 LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
175
176 // 2. Create the multi-buffered alloc.
177 Location loc = allocOp->getLoc();
178 OpBuilder::InsertionGuard g(rewriter);
179 rewriter.setInsertionPoint(allocOp);
180 auto mbAlloc = rewriter.create<memref::AllocOp>(
181 location: loc, args&: mbMemRefType, args: ValueRange{}, args: allocOp->getAttrs());
182 LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
183
184 // 3. Within the loop, build the modular leading index (i.e. each loop
185 // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
186 rewriter.setInsertionPointToStart(
187 &candidateLoop.getLoopRegions().front()->front());
188 Value ivVal = *inductionVar;
189 Value lbVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *lowerBound);
190 Value stepVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *singleStep);
191 AffineExpr iv, lb, step;
192 bindDims(ctx: rewriter.getContext(), exprs&: iv, exprs&: lb, exprs&: step);
193 Value bufferIndex = affine::makeComposedAffineApply(
194 b&: rewriter, loc, e: ((iv - lb).floorDiv(other: step)) % multiBufferingFactor,
195 operands: {ivVal, lbVal, stepVal});
196 LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
197
198 // 4. Build the subview accessing the particular slice, taking modular
199 // rotation into account.
200 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
201 IntegerAttr zero = rewriter.getIndexAttr(value: 0);
202 IntegerAttr one = rewriter.getIndexAttr(value: 1);
203 SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
204 SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
205 SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
206 // Offset is [bufferIndex, 0 ... 0 ].
207 offsets.front() = bufferIndex;
208 // Sizes is [1, original_size_0 ... original_size_n ].
209 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
210 sizes[1 + i] = rewriter.getIndexAttr(value: originalShape[i]);
211 // Strides is [1, 1 ... 1 ].
212 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
213 resultShape: originalShape, sourceMemRefType: mbMemRefType, staticOffsets: offsets, staticSizes: sizes, staticStrides: strides);
214 Value subview = rewriter.create<memref::SubViewOp>(location: loc, args&: dstMemref, args&: mbAlloc,
215 args&: offsets, args&: sizes, args&: strides);
216 LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
217
218 // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
219 // handle dealloc uses separately..
220 for (OpOperand &use : llvm::make_early_inc_range(Range: allocOp->getUses())) {
221 auto deallocOp = dyn_cast<memref::DeallocOp>(Val: use.getOwner());
222 if (!deallocOp)
223 continue;
224 OpBuilder::InsertionGuard g(rewriter);
225 rewriter.setInsertionPoint(deallocOp);
226 auto newDeallocOp =
227 rewriter.create<memref::DeallocOp>(location: deallocOp->getLoc(), args&: mbAlloc);
228 (void)newDeallocOp;
229 LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
230 rewriter.eraseOp(op: deallocOp);
231 }
232
233 // 6. RAUW with the particular slice, taking modular rotation into account.
234 replaceUsesAndPropagateType(rewriter, oldOp: allocOp, val: subview);
235
236 // 7. Finally, erase the old allocOp.
237 rewriter.eraseOp(op: allocOp);
238
239 return mbAlloc;
240}
241
242FailureOr<memref::AllocOp>
243mlir::memref::multiBuffer(memref::AllocOp allocOp,
244 unsigned multiBufferingFactor,
245 bool skipOverrideAnalysis) {
246 IRRewriter rewriter(allocOp->getContext());
247 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
248 skipOverrideAnalysis);
249}
250

// Source: mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp