1//===----------- MultiBuffering.cpp ---------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the multi-buffering transformation.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/Dominance.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/ValueRange.h"
23#include "mlir/Interfaces/LoopLikeInterface.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/Support/Debug.h"
26
27using namespace mlir;
28
29#define DEBUG_TYPE "memref-transforms"
30#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
31#define DBGSNL() (llvm::dbgs() << "\n")
32
33/// Return true if the op fully overwrite the given `buffer` value.
34static bool overrideBuffer(Operation *op, Value buffer) {
35 auto copyOp = dyn_cast<memref::CopyOp>(op);
36 if (!copyOp)
37 return false;
38 return copyOp.getTarget() == buffer;
39}
40
41/// Replace the uses of `oldOp` with the given `val` and for subview uses
42/// propagate the type change. Changing the memref type may require propagating
43/// it through subview ops so we cannot just do a replaceAllUse but need to
44/// propagate the type change and erase old subview ops.
45static void replaceUsesAndPropagateType(RewriterBase &rewriter,
46 Operation *oldOp, Value val) {
47 SmallVector<Operation *> opsToDelete;
48 SmallVector<OpOperand *> operandsToReplace;
49
50 // Save the operand to replace / delete later (avoid iterator invalidation).
51 // TODO: can we use an early_inc iterator?
52 for (OpOperand &use : oldOp->getUses()) {
53 // Non-subview ops will be replaced by `val`.
54 auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
55 if (!subviewUse) {
56 operandsToReplace.push_back(Elt: &use);
57 continue;
58 }
59
60 // `subview(old_op)` is replaced by a new `subview(val)`.
61 OpBuilder::InsertionGuard g(rewriter);
62 rewriter.setInsertionPoint(subviewUse);
63 MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
64 subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
65 subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
66 subviewUse.getStaticStrides());
67 Value newSubview = rewriter.create<memref::SubViewOp>(
68 subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
69 subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
70
71 // Ouch recursion ... is this really necessary?
72 replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
73
74 opsToDelete.push_back(Elt: use.getOwner());
75 }
76
77 // Perform late replacement.
78 // TODO: can we use an early_inc iterator?
79 for (OpOperand *operand : operandsToReplace) {
80 Operation *op = operand->getOwner();
81 rewriter.startOpModification(op);
82 operand->set(val);
83 rewriter.finalizeOpModification(op);
84 }
85
86 // Perform late op erasure.
87 // TODO: can we use an early_inc iterator?
88 for (Operation *op : opsToDelete)
89 rewriter.eraseOp(op);
90}
91
92// Transformation to do multi-buffering/array expansion to remove dependencies
93// on the temporary allocation between consecutive loop iterations.
94// Returns success if the transformation happened and failure otherwise.
95// This is not a pattern as it requires propagating the new memref type to its
96// uses and requires updating subview ops.
97FailureOr<memref::AllocOp>
98mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
99 unsigned multiBufferingFactor,
100 bool skipOverrideAnalysis) {
101 LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");
102 DominanceInfo dom(allocOp->getParentOp());
103 LoopLikeOpInterface candidateLoop;
104 for (Operation *user : allocOp->getUsers()) {
105 auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
106 if (!parentLoop) {
107 if (isa<memref::DeallocOp>(user)) {
108 // Allow dealloc outside of any loop.
109 // TODO: The whole precondition function here is very brittle and will
110 // need to rethought an isolated into a cleaner analysis.
111 continue;
112 }
113 LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
114 LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
115 return failure();
116 }
117 if (!skipOverrideAnalysis) {
118 /// Make sure there is no loop-carried dependency on the allocation.
119 if (!overrideBuffer(user, allocOp.getResult())) {
120 LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
121 continue;
122 }
123 // If this user doesn't dominate all the other users keep looking.
124 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
125 return !dom.dominates(user, otherUser);
126 })) {
127 LLVM_DEBUG(
128 DBGS() << "--Skip user: does not dominate all other users\n");
129 continue;
130 }
131 } else {
132 if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
133 return !isa<memref::DeallocOp>(otherUser) &&
134 !parentLoop->isProperAncestor(otherUser);
135 })) {
136 LLVM_DEBUG(
137 DBGS()
138 << "--Skip user: not all other users are in the parent loop\n");
139 continue;
140 }
141 }
142 candidateLoop = parentLoop;
143 break;
144 }
145
146 if (!candidateLoop) {
147 LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
148 return failure();
149 }
150
151 std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
152 std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
153 std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
154 if (!inductionVar || !lowerBound || !singleStep ||
155 !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
156 LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
157 return failure();
158 }
159
160 if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
161 LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
162 return failure();
163 }
164
165 LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");
166
167 // 1. Construct the multi-buffered memref type.
168 ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
169 SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
170 llvm::append_range(C&: multiBufferedShape, R&: originalShape);
171 LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
172 MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
173 .setShape(multiBufferedShape)
174 .setLayout(MemRefLayoutAttrInterface());
175 LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
176
177 // 2. Create the multi-buffered alloc.
178 Location loc = allocOp->getLoc();
179 OpBuilder::InsertionGuard g(rewriter);
180 rewriter.setInsertionPoint(allocOp);
181 auto mbAlloc = rewriter.create<memref::AllocOp>(
182 loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
183 LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
184
185 // 3. Within the loop, build the modular leading index (i.e. each loop
186 // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
187 rewriter.setInsertionPointToStart(
188 &candidateLoop.getLoopRegions().front()->front());
189 Value ivVal = *inductionVar;
190 Value lbVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *lowerBound);
191 Value stepVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: *singleStep);
192 AffineExpr iv, lb, step;
193 bindDims(ctx: rewriter.getContext(), exprs&: iv, exprs&: lb, exprs&: step);
194 Value bufferIndex = affine::makeComposedAffineApply(
195 rewriter, loc, ((iv - lb).floorDiv(other: step)) % multiBufferingFactor,
196 {ivVal, lbVal, stepVal});
197 LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
198
199 // 4. Build the subview accessing the particular slice, taking modular
200 // rotation into account.
201 int64_t mbMemRefTypeRank = mbMemRefType.getRank();
202 IntegerAttr zero = rewriter.getIndexAttr(0);
203 IntegerAttr one = rewriter.getIndexAttr(1);
204 SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
205 SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
206 SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
207 // Offset is [bufferIndex, 0 ... 0 ].
208 offsets.front() = bufferIndex;
209 // Sizes is [1, original_size_0 ... original_size_n ].
210 for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
211 sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
212 // Strides is [1, 1 ... 1 ].
213 MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
214 originalShape, mbMemRefType, offsets, sizes, strides);
215 Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
216 offsets, sizes, strides);
217 LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
218
219 // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
220 // handle dealloc uses separately..
221 for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
222 auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
223 if (!deallocOp)
224 continue;
225 OpBuilder::InsertionGuard g(rewriter);
226 rewriter.setInsertionPoint(deallocOp);
227 auto newDeallocOp =
228 rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
229 (void)newDeallocOp;
230 LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
231 rewriter.eraseOp(deallocOp);
232 }
233
234 // 6. RAUW with the particular slice, taking modular rotation into account.
235 replaceUsesAndPropagateType(rewriter, allocOp, subview);
236
237 // 7. Finally, erase the old allocOp.
238 rewriter.eraseOp(op: allocOp);
239
240 return mbAlloc;
241}
242
243FailureOr<memref::AllocOp>
244mlir::memref::multiBuffer(memref::AllocOp allocOp,
245 unsigned multiBufferingFactor,
246 bool skipOverrideAnalysis) {
247 IRRewriter rewriter(allocOp->getContext());
248 return multiBuffer(rewriter, allocOp, multiBufferingFactor,
249 skipOverrideAnalysis);
250}
251

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp