1//===----------- MultiBuffering.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements multi buffering transformation.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Dominance.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/ValueRange.h"
23#include "mlir/Interfaces/LoopLikeInterface.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/Support/Debug.h"
26
27using namespace mlir;
28
29#define DEBUG_TYPE "memref-transforms"
30#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31#define DBGSNL() (llvm::dbgs() << "\n")
32
33/// Return true if the op fully overwrite the given `buffer` value.
34static bool overrideBuffer(Operation *op, Value buffer) {
35 auto copyOp = dyn_cast<memref::CopyOp>(op);
36 if (!copyOp)
37 return false;
38 return copyOp.getTarget() == buffer;
39}
40
41/// Replace the uses of `oldOp` with the given `val` and for subview uses
42/// propagate the type change. Changing the memref type may require propagating
43/// it through subview ops so we cannot just do a replaceAllUse but need to
44/// propagate the type change and erase old subview ops.
45static void replaceUsesAndPropagateType(RewriterBase &rewriter,
46 Operation *oldOp, Value val) {
47 SmallVector<Operation *> opsToDelete;
48 SmallVector<OpOperand *> operandsToReplace;
49
50 // Save the operand to replace / delete later (avoid iterator invalidation).
51 // TODO: can we use an early_inc iterator?
52 for (OpOperand &use : oldOp->getUses()) {
53 // Non-subview ops will be replaced by `val`.
54 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
55 if (!subviewUse) {
56 operandsToReplace.push_back(Elt: &use);
57 continue;
58 }
59
60 // `subview(old_op)` is replaced by a new `subview(val)`.
61 OpBuilder::InsertionGuard g(rewriter);
62 rewriter.setInsertionPoint(subviewUse);
63 Type newType = memref::SubViewOp::inferRankReducedResultType(
64 subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
65 subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66 subviewUse.getStaticStrides());
67 Value newSubview = rewriter.create<memref::SubViewOp>(
68 subviewUse->getLoc(), cast<MemRefType>(newType), val,
69 subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
70 subviewUse.getMixedStrides());
71
72 // Ouch recursion ... is this really necessary?
73 replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
74
75 opsToDelete.push_back(Elt: use.getOwner());
76 }
77
78 // Perform late replacement.
79 // TODO: can we use an early_inc iterator?
80 for (OpOperand *operand : operandsToReplace) {
81 Operation *op = operand->getOwner();
82 rewriter.startOpModification(op);
83 operand->set(val);
84 rewriter.finalizeOpModification(op);
85 }
86
87 // Perform late op erasure.
88 // TODO: can we use an early_inc iterator?
89 for (Operation *op : opsToDelete)
90 rewriter.eraseOp(op);
91}
92
93// Transformation to do multi-buffering/array expansion to remove dependencies
94// on the temporary allocation between consecutive loop iterations.
95// Returns success if the transformation happened and failure otherwise.
96// This is not a pattern as it requires propagating the new memref type to its
97// uses and requires updating subview ops.
98FailureOr<memref::AllocOp>
99mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
100 unsigned multiBufferingFactor,
101 bool skipOverrideAnalysis) {
102 LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
103 DominanceInfo dom(allocOp->getParentOp());
104 LoopLikeOpInterface candidateLoop;
105 for (Operation *user : allocOp->getUsers()) {
106 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
107 if (!parentLoop) {
108 if (isa<memref::DeallocOp>(user)) {
109 // Allow dealloc outside of any loop.
110 // TODO: The whole precondition function here is very brittle and will
111 // need to rethought an isolated into a cleaner analysis.
112 continue;
113 }
114 LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
115 LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
116 return failure();
117 }
118 if (!skipOverrideAnalysis) {
119 /// Make sure there is no loop-carried dependency on the allocation.
120 if (!overrideBuffer(user, allocOp.getResult())) {
121 LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
122 continue;
123 }
124 // If this user doesn't dominate all the other users keep looking.
125 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
126 return !dom.dominates(user, otherUser);
127 })) {
128 LLVM_DEBUG(
129 DBGS() << "--Skip user: does not dominate all other users\n");
130 continue;
131 }
132 } else {
133 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
134 return !isa<memref::DeallocOp>(otherUser) &&
135 !parentLoop->isProperAncestor(otherUser);
136 })) {
137 LLVM_DEBUG(
138 DBGS()
139 << "--Skip user: not all other users are in the parent loop\n");
140 continue;
141 }
142 }
143 candidateLoop = parentLoop;
144 break;
145 }
146
147 if (!candidateLoop) {
148 LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
149 return failure();
150 }
151
152 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
153 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
154 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
155 if (!inductionVar || !lowerBound || !singleStep ||
156 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
157 LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
158 return failure();
159 }
160
161 if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
162 LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
163 return failure();
164 }
165
166 LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
167
168 // 1. Construct the multi-buffered memref type.
169 ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
170 SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
171 llvm::append_range(C&: multiBufferedShape, R&: originalShape);
172 LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
173 MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
174 .setShape(multiBufferedShape)
175 .setLayout(MemRefLayoutAttrInterface());
176 LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
177
178 // 2. Create the multi-buffered alloc.
179 Location loc = allocOp->getLoc();
180 OpBuilder::InsertionGuard g(rewriter);
181 rewriter.setInsertionPoint(allocOp);
182 auto mbAlloc = rewriter.create<memref::AllocOp>(
183 loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
184 LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
185
186 // 3. Within the loop, build the modular leading index (i.e. each loop
187 // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
188 rewriter.setInsertionPointToStart(
189 &candidateLoop.getLoopRegions().front()->front());
190 Value ivVal = *inductionVar;
191 Value lbVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *lowerBound);
192 Value stepVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *singleStep);
193 AffineExpr iv, lb, step;
194 bindDims(ctx: rewriter.getContext(), exprs&: iv, exprs&: lb, exprs&: step);
195 Value bufferIndex = affine::makeComposedAffineApply(
196 rewriter, loc, ((iv - lb).floorDiv(other: step)) % multiBufferingFactor,
197 {ivVal, lbVal, stepVal});
198 LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
199
200 // 4. Build the subview accessing the particular slice, taking modular
201 // rotation into account.
202 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
203 IntegerAttr zero = rewriter.getIndexAttr(0);
204 IntegerAttr one = rewriter.getIndexAttr(1);
205 SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
206 SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
207 SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
208 // Offset is [bufferIndex, 0 ... 0 ].
209 offsets.front() = bufferIndex;
210 // Sizes is [1, original_size_0 ... original_size_n ].
211 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
212 sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
213 // Strides is [1, 1 ... 1 ].
214 auto dstMemref =
215 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
216 originalShape, mbMemRefType, offsets, sizes, strides));
217 Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
218 offsets, sizes, strides);
219 LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
220
221 // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
222 // handle dealloc uses separately..
223 for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
224 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
225 if (!deallocOp)
226 continue;
227 OpBuilder::InsertionGuard g(rewriter);
228 rewriter.setInsertionPoint(deallocOp);
229 auto newDeallocOp =
230 rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
231 (void)newDeallocOp;
232 LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
233 rewriter.eraseOp(deallocOp);
234 }
235
236 // 6. RAUW with the particular slice, taking modular rotation into account.
237 replaceUsesAndPropagateType(rewriter, allocOp, subview);
238
239 // 7. Finally, erase the old allocOp.
240 rewriter.eraseOp(op: allocOp);
241
242 return mbAlloc;
243}
244
245FailureOr<memref::AllocOp>
246mlir::memref::multiBuffer(memref::AllocOp allocOp,
247 unsigned multiBufferingFactor,
248 bool skipOverrideAnalysis) {
249 IRRewriter rewriter(allocOp->getContext());
250 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
251 skipOverrideAnalysis);
252}
253

// Source code of mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp