//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;
using namespace mlir::affine;

namespace {
struct PipelineDataTransfer
    : public affine::impl::AffinePipelineDataTransferBase<
          PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}
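
// A minimal sketch of scheduling this pass from C++ (hypothetical driver
// code; `ctx` and `module` are assumed to be set up elsewhere):
//
//   PassManager pm(ctx);
//   pm.addNestedPass<func::FuncOp>(
//       affine::createPipelineDataTransferPass());
//   if (failed(pm.run(module)))
//     /* handle failure */;
//
// The same transform can be exercised from the command line via
// `mlir-opt -affine-pipeline-data-transfer`.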

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's
// are added.
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
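///
/// As an illustrative sketch (names and shapes are hypothetical, not taken
/// from this file): a buffer `%buf = memref.alloc() : memref<32xf32, 1>` used
/// inside `affine.for %i = 0 to 8` becomes
///
///   %buf2 = memref.alloc() : memref<2x32xf32, 1>
///   affine.for %i = 0 to 8 {
///     %parity = affine.apply affine_map<(d0) -> (d0 mod 2)>(%i)
///     ... each use %buf[%idx] is rewritten to %buf2[%parity, %idx] ...
///   }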
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStepAsInt();
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());
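  // E.g., with a non-unit step s the map is (d0) -> ((d0 floordiv s) mod 2),
  // so successive iterations still alternate between the two buffer halves;
  // with a unit step it folds to (d0) -> (d0 mod 2).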

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}
/// Collects all 'affine.for' ops in the function and pipelines the data
/// transfers in each, innermost loops first.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined one
  // gets deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
    // TODO: this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough for
    // now since the DMA generation pass or the input for it will always have
    // start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
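// For instance (a sketch; operand names are hypothetical), the following pair
// matches because both ops use tag memref %tag at index %c0:
//
//   affine.dma_start %src[%i], %dst[%i], %tag[%c0], %num_elts
//       : memref<256xf32>, memref<32xf32, 1>, memref<1xi32>
//   affine.dma_wait %tag[%c0], %num_elts : memref<1xi32>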
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO: use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
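///
/// As a sketch of the overall effect (a simplified, hypothetical example;
/// names and bounds are illustrative only): a loop of the form
///
///   affine.for %i = 0 to 8 {
///     affine.dma_start %src[%i], %buf[%i], %tag[%c0], %num_elts : ...
///     affine.dma_wait %tag[%c0], %num_elts : ...
///     "compute"(%buf, %i) : ...
///   }
///
/// is rewritten so that iteration %i's compute overlaps with iteration
/// %i + 1's transfer: the first dma_start is issued before the loop
/// (prologue), the steady-state loop waits on the previous iteration's
/// transfer while starting the next one into the other half of a double
/// buffer, and the final wait plus compute are peeled off after the loop
/// (epilogue).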
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that
  // the dimension we are adding here for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << *dmaStartOp << "\n";);
      // IR still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef =
        dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (canonicalization handles more complex cases).
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift of each operation for later lookup when handling
  // AffineApplyOp's.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    if (!instShiftMap.contains(&op))
      instShiftMap[&op] = 1;
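
  // With this assignment (DMA starts and their index computations at shift 0,
  // everything else at shift 1), the body skew below issues iteration i + 1's
  // dma_start alongside iteration i's dma_wait and compute. Schematically, a
  // body of {dma_start; dma_wait; compute} runs in steady state as:
  //   dma_start(i + 1); dma_wait(i); compute(i);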

  // Get shifts stored in map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.contains(&op));
    shifts[s++] = instShiftMap[&op];

    // Tagging operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}
