1//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements affine fusion.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Passes.h"
14
15#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
16#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/Utils.h"
18#include "mlir/Dialect/Affine/LoopFusionUtils.h"
19#include "mlir/Dialect/Affine/LoopUtils.h"
20#include "mlir/Dialect/Affine/Utils.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/Builders.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/Support/CommandLine.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/raw_ostream.h"
30#include <iomanip>
31#include <optional>
32#include <sstream>
33
34namespace mlir {
35namespace affine {
36#define GEN_PASS_DEF_AFFINELOOPFUSION
37#include "mlir/Dialect/Affine/Passes.h.inc"
38} // namespace affine
39} // namespace mlir
40
41#define DEBUG_TYPE "affine-fusion"
42
43using namespace mlir;
44using namespace mlir::affine;
45
46namespace {
/// Loop fusion pass. This pass currently supports a greedy fusion policy,
/// which fuses loop nests with single-writer/single-reader memref dependences
/// with the goal of improving locality.
// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
  LoopFusion() = default;
  /// Constructs the pass with explicit option values.
  /// `localBufSizeThresholdBytes` is given in bytes; the inherited
  /// `localBufSizeThreshold` option is presumably stored in KiB — hence the
  /// division by 1024 (TODO confirm against the generated pass options).
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  /// Runs the fusion policy on the given block.
  void runOnBlock(Block *block);
  void runOnOperation() override;
};
65
66} // namespace
67
68/// Returns true if node 'srcId' can be removed after fusing it with node
69/// 'dstId'. The node can be removed if any of the following conditions are met:
70/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
71/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
72/// and the fusion slice is maximal.
73/// 3. 'srcId' has output dependences after fusion, the fusion slice is
74/// maximal and the fusion insertion point dominates all the dependences.
75static bool canRemoveSrcNodeAfterFusion(
76 unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
77 Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
78 const MemRefDependenceGraph &mdg) {
79
80 Operation *dstNodeOp = mdg.getNode(id: dstId)->op;
81 bool hasOutDepsAfterFusion = false;
82
83 for (auto &outEdge : mdg.outEdges.lookup(Val: srcId)) {
84 Operation *depNodeOp = mdg.getNode(id: outEdge.id)->op;
85 // Skip dependence with dstOp since it will be removed after fusion.
86 if (depNodeOp == dstNodeOp)
87 continue;
88
89 // Only fusion within the same block is supported. Use domination analysis
90 // when needed.
91 if (depNodeOp->getBlock() != dstNodeOp->getBlock())
92 return false;
93
94 // Check if the insertion point of the fused loop dominates the dependence.
95 // Otherwise, the src loop can't be removed.
96 if (fusedLoopInsPoint != depNodeOp &&
97 !fusedLoopInsPoint->isBeforeInBlock(other: depNodeOp)) {
98 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
99 "dominate dependence\n");
100 return false;
101 }
102
103 hasOutDepsAfterFusion = true;
104 }
105
106 // If src loop has dependences after fusion or it writes to an live-out or
107 // escaping memref, we can only remove it if the fusion slice is maximal so
108 // that all the dependences are preserved.
109 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
110 std::optional<bool> isMaximal = fusionSlice.isMaximal();
111 if (!isMaximal) {
112 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
113 "if fusion is maximal\n");
114 return false;
115 }
116
117 if (!*isMaximal) {
118 LLVM_DEBUG(llvm::dbgs()
119 << "Src loop can't be removed: fusion is not maximal\n");
120 return false;
121 }
122 }
123
124 return true;
125}
126
127/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
128/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
129/// the program order when the 'mdg' is created. However, program order is not
130/// guaranteed and must not be required by the client. Program order won't be
131/// held if the 'mdg' is reused from a previous fusion step or if the node
132/// creation order changes in the future to support more advance cases.
133// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
134static void getProducerCandidates(unsigned dstId,
135 const MemRefDependenceGraph &mdg,
136 SmallVectorImpl<unsigned> &srcIdCandidates) {
137 // Skip if no input edges along which to fuse.
138 if (mdg.inEdges.count(Val: dstId) == 0)
139 return;
140
141 // Gather memrefs from loads in 'dstId'.
142 auto *dstNode = mdg.getNode(id: dstId);
143 DenseSet<Value> consumedMemrefs;
144 for (Operation *load : dstNode->loads)
145 consumedMemrefs.insert(V: cast<AffineReadOpInterface>(Val: load).getMemRef());
146
147 // Traverse 'dstId' incoming edges and gather the nodes that contain a store
148 // to one of the consumed memrefs.
149 for (const auto &srcEdge : mdg.inEdges.lookup(Val: dstId)) {
150 const auto *srcNode = mdg.getNode(id: srcEdge.id);
151 // Skip if 'srcNode' is not a loop nest.
152 if (!isa<AffineForOp>(Val: srcNode->op))
153 continue;
154
155 if (any_of(Range: srcNode->stores, P: [&](Operation *op) {
156 auto storeOp = cast<AffineWriteOpInterface>(Val: op);
157 return consumedMemrefs.count(V: storeOp.getMemRef()) > 0;
158 }))
159 srcIdCandidates.push_back(Elt: srcNode->id);
160 }
161
162 llvm::sort(C&: srcIdCandidates);
163 srcIdCandidates.erase(CS: llvm::unique(R&: srcIdCandidates), CE: srcIdCandidates.end());
164}
165
166/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
167/// producer-consumer dependence between 'srcId' and 'dstId'.
168static void
169gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
170 const MemRefDependenceGraph &mdg,
171 DenseSet<Value> &producerConsumerMemrefs) {
172 auto *dstNode = mdg.getNode(id: dstId);
173 auto *srcNode = mdg.getNode(id: srcId);
174 gatherProducerConsumerMemrefs(srcOps: srcNode->stores, dstOps: dstNode->loads,
175 producerConsumerMemrefs);
176}
177
178/// A memref escapes in the context of the fusion pass if either:
179/// 1. it (or its alias) is a block argument, or
180/// 2. created by an op not known to guarantee alias freedom,
181/// 3. it (or its alias) are used by ops other than affine dereferencing ops
182/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
183/// terminator ops, etc.); such ops do not deference the memref in an affine
184/// way.
185static bool isEscapingMemref(Value memref, Block *block) {
186 Operation *defOp = memref.getDefiningOp();
187 // Check if 'memref' is a block argument.
188 if (!defOp)
189 return true;
190
191 // Check if this is defined to be an alias of another memref.
192 if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(Val: defOp))
193 if (isEscapingMemref(memref: viewOp.getViewSource(), block))
194 return true;
195
196 // Any op besides allocating ops wouldn't guarantee alias freedom
197 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(op: defOp, value: memref))
198 return true;
199
200 // Check if 'memref' is used by a non-deferencing op (including unknown ones)
201 // (e.g., call ops, alias creating ops, etc.).
202 return llvm::any_of(Range: memref.getUsers(), P: [&](Operation *user) {
203 // Ignore users outside of `block`.
204 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(op&: *user);
205 if (!ancestorOp)
206 return true;
207 if (ancestorOp->getBlock() != block)
208 return false;
209 return !isa<AffineMapAccessInterface>(Val: *user);
210 });
211}
212
213/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
214/// that escape the block or are accessed in a non-affine way.
215static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
216 DenseSet<Value> &escapingMemRefs) {
217 auto *node = mdg.getNode(id);
218 for (Operation *storeOp : node->stores) {
219 auto memref = cast<AffineWriteOpInterface>(Val: storeOp).getMemRef();
220 if (escapingMemRefs.count(V: memref))
221 continue;
222 if (isEscapingMemref(memref, block: &mdg.block))
223 escapingMemRefs.insert(V: memref);
224 }
225}
226
227// Sinks all sequential loops to the innermost levels (while preserving
228// relative order among them) and moves all parallel loops to the
229// outermost (while again preserving relative order among them).
230// This can increase the loop depth at which we can fuse a slice, since we are
231// pushing loop carried dependence to a greater depth in the loop nest.
232static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
233 assert(isa<AffineForOp>(node->op));
234 AffineForOp newRootForOp = sinkSequentialLoops(forOp: cast<AffineForOp>(Val: node->op));
235 node->op = newRootForOp;
236}
237
238/// Get the operation that should act as a dominance filter while replacing
239/// memref uses with a private memref for which `producerStores` and
240/// `sliceInsertionBlock` are provided. This effectively determines in what
241/// part of the IR we should be performing the replacement.
242static Operation *
243getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
244 ArrayRef<Operation *> producerStores) {
245 assert(!producerStores.empty() && "expected producer store");
246
247 // We first find the common block that contains the producer stores and
248 // the slice computation. The first ancestor among the ancestors of the
249 // producer stores in that common block is the dominance filter to use for
250 // replacement.
251 Block *commonBlock = nullptr;
252 // Find the common block of all relevant operations.
253 for (Operation *store : producerStores) {
254 Operation *otherOp =
255 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
256 commonBlock = findInnermostCommonBlockInScope(a: store, b: otherOp);
257 }
258 assert(commonBlock &&
259 "common block of producer stores and slice should exist");
260
261 // Find the first ancestor among the ancestors of `producerStores` in
262 // `commonBlock`.
263 Operation *firstAncestor = nullptr;
264 for (Operation *store : producerStores) {
265 Operation *ancestor = commonBlock->findAncestorOpInBlock(op&: *store);
266 assert(ancestor && "producer store should be contained in common block");
267 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(other: firstAncestor)
268 ? ancestor
269 : firstAncestor;
270 }
271 return firstAncestor;
272}
273
274/// Returns the amount of additional (redundant) computation that will be done
275/// as a fraction of the total computation if `srcForOp` is fused into
276/// `dstForOp` at depth `depth`. The method returns the compute cost of the
277/// slice and the fused nest's compute cost in the trailing output arguments.
278static std::optional<double> getAdditionalComputeFraction(
279 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
280 ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
281 int64_t &fusedLoopNestComputeCost) {
282 LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
283 // Compute cost of sliced and unsliced src loop nest.
284 // Walk src loop nest and collect stats.
285 LoopNestStats srcLoopNestStats;
286 if (!getLoopNestStats(forOp: srcForOp, stats: &srcLoopNestStats)) {
287 LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
288 return std::nullopt;
289 }
290
291 // Compute cost of dst loop nest.
292 LoopNestStats dstLoopNestStats;
293 if (!getLoopNestStats(forOp: dstForOp, stats: &dstLoopNestStats)) {
294 LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
295 return std::nullopt;
296 }
297
298 // Compute op instance count for the src loop nest without iteration slicing.
299 uint64_t srcLoopNestCost = getComputeCost(forOp: srcForOp, stats&: srcLoopNestStats);
300
301 // Compute op cost for the dst loop nest.
302 uint64_t dstLoopNestCost = getComputeCost(forOp: dstForOp, stats&: dstLoopNestStats);
303
304 const ComputationSliceState &slice = depthSliceUnions[depth - 1];
305 // Skip slice union if it wasn't computed for this depth.
306 if (slice.isEmpty()) {
307 LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
308 return std::nullopt;
309 }
310
311 if (!getFusionComputeCost(srcForOp, srcStats&: srcLoopNestStats, dstForOp,
312 dstStats&: dstLoopNestStats, slice,
313 computeCost: &fusedLoopNestComputeCost)) {
314 LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
315 return std::nullopt;
316 }
317
318 double additionalComputeFraction =
319 fusedLoopNestComputeCost /
320 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
321 1;
322
323 return additionalComputeFraction;
324}
325
// Creates and returns a private (single-user) memref for fused loop rooted at
// 'forOp', with (potentially reduced) memref size based on the memref region
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
// specifies the block in which the slice was/will be inserted. The method
// expects that all stores ops to the memref have the same access function.
// Returns nullptr if the creation failed.
static Value createPrivateMemRef(AffineForOp forOp,
                                 ArrayRef<Operation *> storeOps,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 Block *sliceInsertionBlock,
                                 uint64_t localBufSizeThreshold) {
  assert(!storeOps.empty() && "no source stores supplied");

  // Check if all stores have the same access function; we only support this
  // case.
  // TODO: Use union of memref write regions to compute private memref footprint
  // for store ops with different access functions.
  if (storeOps.size() > 1 &&
      !std::equal(first1: std::next(x: storeOps.begin()), last1: storeOps.end(), first2: storeOps.begin(),
                  binary_pred: [](Operation *a, Operation *b) {
                    MemRefAccess aM(cast<AffineWriteOpInterface>(Val: a));
                    MemRefAccess bM(cast<AffineWriteOpInterface>(Val: b));
                    return aM == bM;
                  })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Private memref creation unsupported for multiple producer "
                  "stores with different access functions.\n");
    return nullptr;
  }

  // All stores share one access function, so any store is representative.
  Operation *srcStoreOp = storeOps[0];

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forOp);
  // Builder to create constants at the top level.
  OpBuilder top(forOp->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(Val: srcStoreOp).getMemRef();
  auto oldMemRefType = cast<MemRefType>(Val: oldMemRef.getType());
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOp->getLoc());
  bool validRegion = succeeded(
      Result: region.compute(op: srcStoreOp, loopDepth: dstLoopDepth, /*sliceState=*/nullptr,
                     /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false));

  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
  SmallVector<AffineMap, 4> lbs;
  lbs.reserve(N: rank);
  // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
  // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(shape: &newShape, lbs: &lbs);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(start: rank, end: cst->getNumDimAndSymbolVars(), values: &outerIVs);

  // Build 'rank' AffineExprs from MemRefRegion 'lbs'
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(N: rank);

  // Outer IVs are considered symbols during memref region computation. Replace
  // them uniformly with dims so that valid IR is guaranteed.
  // NOTE(review): this assumes every map in 'lbs' has the same symbol count as
  // lbs[0] (they come from the same region) — confirm if 'lbs' can be
  // heterogeneous.
  SmallVector<AffineExpr> replacements;
  for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
    replacements.push_back(Elt: mlir::getAffineDimExpr(position: j, context: forOp.getContext()));
  for (unsigned d = 0; d < rank; ++d) {
    // Each lower bound must be a single-result map for the offset to be
    // well-defined.
    assert(lbs[d].getNumResults() == 1 &&
           "invalid private memref bound calculation");
    offsets.push_back(Elt: lbs[d].getResult(idx: 0).replaceSymbols(symReplacements: replacements));
  }

  // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
  // by 'srcStoreOpInst'.
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType: oldMemRefType);
  assert(eltSize && "memrefs with size elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  // Place the buffer in the fast memory space only if it fits under the
  // configured size threshold; otherwise keep the original memory space.
  Attribute newMemSpace;
  if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
    newMemSpace = b.getI64IntegerAttr(value: *fastMemorySpace);
  } else {
    newMemSpace = oldMemRefType.getMemorySpace();
  }
  auto newMemRefType = MemRefType::get(shape: newShape, elementType: oldMemRefType.getElementType(),
                                       /*map=*/AffineMap(), memorySpace: newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(location: forOp.getLoc(), args&: newMemRefType);

  // Build an AffineMap to remap access functions based on lower bound offsets.
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(N: rank);
  for (unsigned i = 0; i < rank; i++) {
    // Dims [0, outerIVs.size()) are the outer IVs; the memref dims follow.
    auto dimExpr = b.getAffineDimExpr(position: outerIVs.size() + i);

    // Subtract the region's lower bound so accesses are zero-based in the
    // (smaller) private buffer.
    auto remapExpr =
        simplifyAffineExpr(expr: dimExpr - offsets[i], numDims: outerIVs.size() + rank, numSymbols: 0);
    remapExprs.push_back(Elt: remapExpr);
  }

  auto indexRemap =
      AffineMap::get(dimCount: outerIVs.size() + rank, symbolCount: 0, results: remapExprs, context: forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef', restricted to users
  // dominated by 'domFilter'.
  Operation *domFilter =
      getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, producerStores: storeOps);
  // NOTE(review): a fresh DominanceInfo is constructed for every user checked;
  // potentially costly — confirm before hoisting it out of the lambda.
  auto userFilterFn = [&](Operation *user) {
    auto domInfo = std::make_unique<DominanceInfo>(
        args: domFilter->getParentOfType<FunctionOpInterface>());
    return domInfo->dominates(a: domFilter, b: user);
  };
  LogicalResult res = replaceAllMemRefUsesWith(
      oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
      /*extraOperands=*/outerIVs,
      /*symbolOperands=*/{}, userFilterFn);
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
                          << '\n');
  return newMemRef;
}
462
463// Checks the profitability of fusing a backwards slice of the loop nest
464// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
465// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
466// being produced and consumed, which is an input to the cost model. For
467// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
468// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
469// will be the src loop nest LoadOp which reads from the same memref as dst loop
470// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
471// node, which will be used to check that the write region is the same after
472// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
473// each legal fusion depth. The maximal depth at which fusion is legal is
474// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
475// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
476// the most profitable depth at which to materialize the source loop nest slice.
477// The profitability model executes the following steps:
478// *) Computes the backward computation slice at 'srcOpInst'. This
479// computation slice of the loop nest surrounding 'srcOpInst' is
480// represented by modified src loop bounds in 'sliceState', which are
481// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
482// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
483// loop nest is the total number of dynamic operation instances in the loop
484// nest).
485// *) Computes the cost of fusing a slice of the src loop nest into the dst
486// loop nest at various values of dst loop depth, attempting to fuse
487// the largest computation slice at the maximal dst loop depth (closest to
488// the load) to minimize reuse distance and potentially enable subsequent
489// load/store forwarding.
490// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
491// nest, at which the src computation slice is inserted/fused.
492// NOTE: We attempt to maximize the dst loop depth, but there are cases
493// where a particular setting for 'dstLoopNest' might fuse an unsliced
494// loop (within the src computation slice) at a depth which results in
495// excessive recomputation (see unit tests for examples).
496// *) Compares the total cost of the unfused loop nests to the min cost fused
497// loop nest computed in the previous step, and returns true if the latter
498// is lower.
499// TODO: Extend profitability analysis to support scenarios with multiple
500// stores.
501static bool isFusionProfitable(AffineForOp srcForOp,
502 ArrayRef<Operation *> producerStores,
503 AffineForOp dstForOp,
504 ArrayRef<ComputationSliceState> depthSliceUnions,
505 unsigned maxLegalFusionDepth,
506 unsigned *dstLoopDepth,
507 double computeToleranceThreshold) {
508 LLVM_DEBUG({
509 llvm::dbgs()
510 << "Checking whether fusion is profitable between source nest:\n";
511 llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
512 llvm::dbgs() << dstForOp << "\n";
513 });
514
515 if (maxLegalFusionDepth == 0) {
516 LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
517 return false;
518 }
519
520 // Compute cost of sliced and unsliced src loop nest.
521
522 // Walk src loop nest and collect stats.
523 LoopNestStats srcLoopNestStats;
524 if (!getLoopNestStats(forOp: srcForOp, stats: &srcLoopNestStats))
525 return false;
526
527 // Compute cost of dst loop nest.
528 LoopNestStats dstLoopNestStats;
529 if (!getLoopNestStats(forOp: dstForOp, stats: &dstLoopNestStats))
530 return false;
531
532 // We limit profitability analysis to only scenarios with
533 // a single producer store for now. Note that some multi-store
534 // producer scenarios will still go through profitability analysis
535 // if only one of the stores is involved in the producer-consumer
536 // relationship of the candidate loops.
537 // TODO: Suppport multiple producer stores in profitability
538 // analysis.
539 if (producerStores.size() > 1) {
540 LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
541 "supported for multiple producer store case.\n");
542 int64_t sliceCost;
543 int64_t fusedLoopNestComputeCost;
544 // We will still fuse if fusion obeys the specified compute
545 // tolerance at the max legal depth.
546 auto fraction = getAdditionalComputeFraction(
547 srcForOp, dstForOp, depth: maxLegalFusionDepth, depthSliceUnions, sliceCost,
548 fusedLoopNestComputeCost);
549 if (!fraction || fraction > computeToleranceThreshold) {
550 LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
551 "compute tolerance. Not fusing.\n");
552 return false;
553 }
554 LLVM_DEBUG(llvm::dbgs()
555 << "Considering fusion profitable at max legal depth.\n");
556 return true;
557 }
558
559 Operation *srcStoreOp = producerStores.front();
560
561 // Search for min cost value for 'dstLoopDepth'. At each value of
562 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
563 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
564 // of these bounds). Next the union slice bounds are used to calculate
565 // the cost of the slice and the cost of the slice inserted into the dst
566 // loop nest at 'dstLoopDepth'.
567 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
568 double maxStorageReduction = 0.0;
569 std::optional<uint64_t> sliceMemEstimate;
570
571 // The best loop depth at which to materialize the slice.
572 std::optional<unsigned> bestDstLoopDepth;
573
574 // Compute src loop nest write region size.
575 MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
576 if (failed(Result: srcWriteRegion.compute(op: srcStoreOp, /*loopDepth=*/0))) {
577 LLVM_DEBUG(llvm::dbgs()
578 << "Unable to compute MemRefRegion for source operation\n");
579 return false;
580 }
581
582 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
583 srcWriteRegion.getRegionSize();
584 if (!maybeSrcWriteRegionSizeBytes.has_value())
585 return false;
586 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
587
588 // Compute op instance count for the src loop nest without iteration slicing.
589 uint64_t srcLoopNestCost = getComputeCost(forOp: srcForOp, stats&: srcLoopNestStats);
590
591 // Compute op instance count for the destination loop nest.
592 uint64_t dstLoopNestCost = getComputeCost(forOp: dstForOp, stats&: dstLoopNestStats);
593
594 // Evaluate all depth choices for materializing the slice in the destination
595 // loop nest.
596 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
597 const ComputationSliceState &slice = depthSliceUnions[i - 1];
598 // Skip slice union if it wasn't computed for this depth.
599 if (slice.isEmpty())
600 continue;
601
602 // Compute cost of the slice separately, i.e, the compute cost of the slice
603 // if all outer trip counts are one.
604 int64_t sliceCost;
605
606 int64_t fusedLoopNestComputeCost;
607
608 auto mayAdditionalComputeFraction =
609 getAdditionalComputeFraction(srcForOp, dstForOp, depth: i, depthSliceUnions,
610 sliceCost, fusedLoopNestComputeCost);
611 if (!mayAdditionalComputeFraction) {
612 LLVM_DEBUG(llvm::dbgs()
613 << "Can't determine additional compute fraction.\n");
614 continue;
615 }
616 double additionalComputeFraction = *mayAdditionalComputeFraction;
617
618 // Determine what the slice write MemRefRegion would be, if the src loop
619 // nest slice 'slice' were to be inserted into the dst loop nest at loop
620 // depth 'i'.
621 MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
622 if (failed(Result: sliceWriteRegion.compute(op: srcStoreOp, /*loopDepth=*/0, sliceState: &slice))) {
623 LLVM_DEBUG(llvm::dbgs()
624 << "Failed to compute slice write region at loopDepth: " << i
625 << "\n");
626 continue;
627 }
628
629 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
630 sliceWriteRegion.getRegionSize();
631 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
632 *maybeSliceWriteRegionSizeBytes == 0) {
633 LLVM_DEBUG(llvm::dbgs()
634 << "Failed to get slice write region size at loopDepth: " << i
635 << "\n");
636 continue;
637 }
638 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
639
640 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
641 static_cast<double>(sliceWriteRegionSizeBytes);
642
643 LLVM_DEBUG({
644 std::stringstream msg;
645 msg << " evaluating fusion profitability at depth : " << i << "\n"
646 << std::fixed << std::setprecision(2)
647 << " additional compute fraction: "
648 << 100.0 * additionalComputeFraction << "%\n"
649 << " storage reduction factor: " << storageReduction << "x\n"
650 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
651 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
652 << " slice write region size: " << sliceWriteRegionSizeBytes
653 << "\n";
654 llvm::dbgs() << msg.str();
655 });
656
657 // TODO: This is a placeholder cost model.
658 // Among all choices that add an acceptable amount of redundant computation
659 // (as per computeToleranceThreshold), we will simply pick the one that
660 // reduces the intermediary size the most.
661 if ((storageReduction > maxStorageReduction) &&
662 (additionalComputeFraction <= computeToleranceThreshold)) {
663 maxStorageReduction = storageReduction;
664 bestDstLoopDepth = i;
665 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
666 sliceMemEstimate = sliceWriteRegionSizeBytes;
667 }
668 }
669
670 // A simple cost model: fuse if it reduces the memory footprint.
671
672 if (!bestDstLoopDepth) {
673 LLVM_DEBUG(
674 llvm::dbgs()
675 << "All fusion choices involve more than the threshold amount of "
676 "redundant computation; NOT fusing.\n");
677 return false;
678 }
679
680 if (!bestDstLoopDepth) {
681 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
682 return false;
683 }
684
685 // Set dstLoopDepth based on best values from search.
686 *dstLoopDepth = *bestDstLoopDepth;
687
688 LLVM_DEBUG(
689 llvm::dbgs() << " LoopFusion fusion stats:"
690 << "\n best loop depth: " << bestDstLoopDepth
691 << "\n src loop nest compute cost: " << srcLoopNestCost
692 << "\n dst loop nest compute cost: " << dstLoopNestCost
693 << "\n fused loop nest compute cost: "
694 << minFusedLoopNestComputeCost << "\n");
695
696 auto dstMemSize = getMemoryFootprintBytes(forOp: dstForOp);
697 auto srcMemSize = getMemoryFootprintBytes(forOp: srcForOp);
698
699 std::optional<double> storageReduction;
700
701 if (!dstMemSize || !srcMemSize) {
702 LLVM_DEBUG(llvm::dbgs()
703 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
704 return false;
705 }
706
707 auto srcMemSizeVal = *srcMemSize;
708 auto dstMemSizeVal = *dstMemSize;
709
710 assert(sliceMemEstimate && "expected value");
711 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
712
713 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
714 << " dst mem: " << dstMemSizeVal << "\n"
715 << " fused mem: " << fusedMem << "\n"
716 << " slice mem: " << sliceMemEstimate << "\n");
717
718 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
719 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
720 return false;
721 }
722 storageReduction =
723 100.0 *
724 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
725
726 double additionalComputeFraction =
727 100.0 * (minFusedLoopNestComputeCost /
728 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
729 1);
730 (void)additionalComputeFraction;
731 LLVM_DEBUG({
732 std::stringstream msg;
733 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
734 << std::setprecision(2) << additionalComputeFraction
735 << "% redundant computation and a ";
736 msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
737 msg << "% storage reduction.\n";
738 llvm::dbgs() << msg.str();
739 });
740
741 return true;
742}
743
744namespace {
745
746// GreedyFusion greedily fuses loop nests which have a producer/consumer or
747// input-reuse relationship on a memref, with the goal of improving locality.
748//
749// The steps of the producer-consumer fusion algorithm are as follows:
750//
751// *) A worklist is initialized with node ids from the dependence graph.
752// *) For each node id in the worklist:
753// *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
754// candidate destination AffineForOp into which fusion will be attempted.
755// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
756// *) For each LoadOp in 'dstLoadOps' do:
757// *) Look up dependent loop nests which have a single store op to the same
758// memref.
759// *) Check if dependences would be violated by the fusion.
760// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
761// bounds to be functions of 'dstLoopNest' IVs and symbols.
762// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
763// at a loop depth determined by the cost model in 'isFusionProfitable'.
764// *) Add the newly fused load/store operations to the state,
765// and also add newly fused load ops to 'dstLoopOps' to be considered
766// as fusion dst load ops in another iteration.
767// *) Remove old src loop nest and its associated state.
768//
769// The steps of the input-reuse fusion algorithm are as follows:
770//
771// *) Initialize 'worklist' with node ids from the dependence graph.
772// *) For each 'dstNode' in the worklist:
773// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
774// loads from the same memref, but which has no dependence paths to/from.
775// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
776// bounds to be functions of 'dstLoopNest' IVs and symbols.
777// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
778// at a loop depth determined by the cost model in 'isFusionProfitable'.
779// This function also checks that the memref write region of 'sibLoopNest',
780// is preserved in the fused loop nest.
781// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
782//
783// Given a graph where top-level operations are vertices in the set 'V' and
784// edges in the set 'E' are dependences between vertices, this algorithm
785// takes O(V) time for initialization, and has runtime O(V + E).
786//
787// This greedy algorithm is not 'maximal' due to the current restriction of
788// fusing along single producer consumer edges, but there is a TODO: to fix
789// this.
790//
791// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph node ids still to be visited by a fusion pass; refilled
  // by init() before each traversal.
  SmallVector<unsigned, 8> worklist;
  // Size threshold for local buffers (in bytes, per the pass constructor);
  // forwarded to createPrivateMemRef when privatizing memrefs.
  unsigned localBufSizeThreshold;
  // Optional fast memory space in which private memrefs are placed; forwarded
  // to createPrivateMemRef.
  std::optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}
817
818 /// Initializes 'worklist' with nodes from 'mdg'.
819 void init() {
820 // TODO: Add a priority queue for prioritizing nodes by different
821 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
822 worklist.clear();
823 for (auto &idAndNode : mdg->nodes) {
824 const Node &node = idAndNode.second;
825 worklist.push_back(Elt: node.id);
826 }
827 }
  /// Run only sibling (input-reuse) fusion on the `mdg`, then erase memref
  /// allocations left without users.
  void runSiblingFusionOnly() {
    fuseSiblingNodes();
    eraseUnusedMemRefAllocations();
  }
833
  /// Run only producer/consumer fusion on the `mdg`, then erase memref
  /// allocations left without users.
  void runProducerConsumerFusionOnly() {
    // An unbounded user count disables the single-use-producer restriction.
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }
840
  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    // Pass 1: only producers whose out-edge count on every involved memref is
    // at most one are considered.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    // Pass 3: no restriction on the producer's user count.
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }
854
855 /// Returns true if a private memref can be created for `memref` given
856 /// the fusion scenario reflected by the other arguments.
857 bool canCreatePrivateMemRef(Value memref,
858 const DenseSet<Value> &srcEscapingMemRefs,
859 unsigned producerId, unsigned consumerId,
860 bool removeSrcNode) {
861 // We can't generate private memrefs if their size can't be computed.
862 if (!getMemRefIntOrFloatEltSizeInBytes(memRefType: cast<MemRefType>(Val: memref.getType())))
863 return false;
864 const Node *consumerNode = mdg->getNode(id: consumerId);
865 // If `memref` is an escaping one, do not create a private memref
866 // for the below scenarios, since doing so will leave the escaping
867 // memref unmodified as all the writes originally meant for the
868 // escaping memref would be performed on the private memref:
869 // 1. The source is to be removed after fusion,
870 // OR
871 // 2. The destination writes to `memref`.
872 if (srcEscapingMemRefs.count(V: memref) > 0 &&
873 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
874 return false;
875
876 // Don't create a private memref if 'srcNode' has in edges on
877 // 'memref' or 'dstNode' has out edges on 'memref'.
878 if (mdg->getIncomingMemRefAccesses(id: producerId, memref) > 0 ||
879 mdg->getOutEdgeCount(id: consumerId, memref) > 0)
880 return false;
881
882 // If 'srcNode' will be removed but it has out edges on 'memref' to
883 // nodes other than 'dstNode', we have to preserve dependences and
884 // cannot create a private memref.
885 if (removeSrcNode &&
886 any_of(Range&: mdg->outEdges[producerId], P: [&](const auto &edge) {
887 return edge.value == memref && edge.id != consumerId;
888 }))
889 return false;
890
891 return true;
892 }
893
  /// Perform fusions with node `dstId` as the destination of fusion. No
  /// fusion is performed with a producer whose user count on any of the
  /// memrefs involved exceeds `maxSrcUserCount`.
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(Val: dstId) == 0)
      return;
    // Get 'dstNode' into which to attempt fusion.
    auto *dstNode = mdg->getNode(id: dstId);
    // Skip if 'dstNode' is not a loop nest.
    if (!isa<AffineForOp>(Val: dstNode->op))
      return;
    // Skip if 'dstNode' is a loop nest returning values.
    // TODO: support loop nests that return values.
    if (dstNode->op->getNumResults() > 0)
      return;

    // NOTE(review): duplicates the debug print at function entry.
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");

    // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
    // while preserving relative order. This can increase the maximum loop
    // depth at which we can fuse a slice of a producer loop nest into a
    // consumer loop nest.
    sinkSequentialLoops(node: dstNode);
    auto dstAffineForOp = cast<AffineForOp>(Val: dstNode->op);

    // Try to fuse 'dstNode' with candidate producer loops until a fixed point
    // is reached. Fusing two loops may expose new fusion opportunities.
    bool dstNodeChanged;
    do {
      // Gather src loop candidates for 'dstNode' and visit them in "quasi"
      // reverse program order to minimize the number of iterations needed to
      // reach the fixed point. Note that this is a best-effort approach since
      // 'getProducerCandidates' does not always guarantee program order in
      // 'srcIdCandidates'.
      dstNodeChanged = false;
      SmallVector<unsigned, 16> srcIdCandidates;
      getProducerCandidates(dstId, mdg: *mdg, srcIdCandidates);

      for (unsigned srcId : llvm::reverse(C&: srcIdCandidates)) {
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(id: srcId);
        auto srcAffineForOp = cast<AffineForOp>(Val: srcNode->op);

        LLVM_DEBUG(llvm::dbgs()
                   << "Trying to fuse producer loop nest " << srcId
                   << " with consumer loop nest " << dstId << "\n");
        LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
                                << computeToleranceThreshold << '\n');
        LLVM_DEBUG(llvm::dbgs()
                   << "Producer loop nest:\n"
                   << *srcNode->op << "\n and consumer loop nest:\n"
                   << *dstNode->op << '\n');

        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                << " for dst loop " << dstId << "\n");

        // Skip if 'srcNode' is a loop nest returning values.
        // TODO: support loop nests that return values.
        if (isa<AffineForOp>(Val: srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

        DenseSet<Value> producerConsumerMemrefs;
        gatherProducerConsumerMemrefs(srcId, dstId, mdg: *mdg,
                                      producerConsumerMemrefs);

        // Skip if 'srcNode' out edge count on any memref is greater than
        // 'maxSrcUserCount'.
        if (any_of(Range&: producerConsumerMemrefs, P: [&](Value memref) {
              return mdg->getOutEdgeCount(id: srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        // Gather memrefs in 'srcNode' that are written and escape out of the
        // block (e.g., memref block arguments, returned memrefs,
        // memrefs passed to function calls, etc.).
        DenseSet<Value> srcEscapingMemRefs;
        gatherEscapingMemrefs(id: srcNode->id, mdg: *mdg, escapingMemRefs&: srcEscapingMemRefs);

        // Compute an operation list insertion point for the fused loop
        // nest which preserves dependences.
        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcId: srcNode->id, dstId: dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;

        // It's possible this fusion is at an inner depth (i.e., there are
        // common surrounding affine loops for the source and destination for
        // ops). We need to get this number because the call to canFuseLoops
        // needs to be passed the absolute depth. The max legal depth and the
        // depths we try below are however *relative* and as such don't include
        // the common depth.
        SmallVector<AffineForOp, 4> surroundingLoops;
        getAffineForIVs(op&: *dstAffineForOp, loops: &surroundingLoops);
        unsigned numSurroundingLoops = surroundingLoops.size();

        // Compute the innermost common loop depth for dstNode
        // producer-consumer loads/stores.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  V: cast<AffineReadOpInterface>(Val: op).getMemRef()) > 0)
            dstMemrefOps.push_back(Elt: op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  V: cast<AffineWriteOpInterface>(Val: op).getMemRef()))
            dstMemrefOps.push_back(Elt: op);
        if (dstMemrefOps.empty())
          continue;
        unsigned dstLoopDepthTest =
            getInnermostCommonLoopDepth(ops: dstMemrefOps) - numSurroundingLoops;

        // Check the feasibility of fusing src loop nest into dst loop nest
        // at loop depths in range [1, dstLoopDepthTest].
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(N: dstLoopDepthTest);
        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          FusionResult result =
              affine::canFuseLoops(srcForOp: srcAffineForOp, dstForOp: dstAffineForOp,
                                   /*dstLoopDepth=*/i + numSurroundingLoops,
                                   srcSlice: &depthSliceUnions[i - 1], fusionStrategy: strategy);
          if (result.value == FusionResult::Success) {
            maxLegalFusionDepth = i;
            LLVM_DEBUG(llvm::dbgs()
                       << "Found valid slice for depth: " << i << '\n');
          }
        }

        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          continue;
        }

        LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
                                << maxLegalFusionDepth << '\n');

        double computeToleranceThresholdToUse = computeToleranceThreshold;

        // Cyclic dependences in the source nest may be violated when performing
        // slicing-based fusion. They aren't actually violated in cases where no
        // redundant execution of the source happens (1:1 pointwise dep on the
        // producer-consumer memref access for example). Check this and allow
        // fusion accordingly.
        if (hasCyclicDependence(root: srcAffineForOp)) {
          LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
          // Maximal fusion does not check for compute tolerance threshold; so
          // perform the maximal fusion only when the redundant computation
          // is zero.
          if (maximalFusion) {
            auto srcForOp = cast<AffineForOp>(Val: srcNode->op);
            auto dstForOp = cast<AffineForOp>(Val: dstNode->op);
            int64_t sliceCost;
            int64_t fusedLoopNestComputeCost;
            auto fraction = getAdditionalComputeFraction(
                srcForOp, dstForOp, depth: maxLegalFusionDepth, depthSliceUnions,
                sliceCost, fusedLoopNestComputeCost);
            if (!fraction || fraction > 0) {
              LLVM_DEBUG(
                  llvm::dbgs()
                  << "Can't perform maximal fusion with a cyclic dependence "
                     "and non-zero additional compute.\n");
              return;
            }
          } else {
            // Set redundant computation tolerance to zero regardless of what
            // the user specified. Without this, fusion would be invalid.
            // NOTE(review): "cylic" typo in the debug string below (runtime
            // text, intentionally left untouched here).
            LLVM_DEBUG(llvm::dbgs()
                       << "Setting compute tolerance to zero since "
                          "source has a cylic dependence.\n");
            computeToleranceThresholdToUse = 0;
          }
        }

        // Check if fusion would be profitable. We skip profitability analysis
        // for maximal fusion since we already know the maximal legal depth to
        // fuse.
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    V: cast<AffineWriteOpInterface>(Val: op).getMemRef()))
              producerStores.push_back(Elt: op);

          assert(!producerStores.empty() && "Expected producer store");
          if (!isFusionProfitable(srcForOp: srcAffineForOp, producerStores,
                                  dstForOp: dstAffineForOp, depthSliceUnions,
                                  maxLegalFusionDepth, dstLoopDepth: &bestDstLoopDepth,
                                  computeToleranceThreshold: computeToleranceThresholdToUse)) {
            continue;
          }
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");

        // Determine if 'srcId' can be removed after fusion, taking into
        // account remaining dependences, escaping memrefs and the fusion
        // insertion point.
        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
            srcId, dstId, fusionSlice: bestSlice, fusedLoopInsPoint, escapingMemRefs: srcEscapingMemRefs,
            mdg: *mdg);

        DenseSet<Value> privateMemrefs;
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, producerId: srcId, consumerId: dstId,
                                     removeSrcNode)) {
            // Create a private version of this memref.
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            privateMemrefs.insert(V: memref);
          }
        }

        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        fuseLoops(srcForOp: srcAffineForOp, dstForOp: dstAffineForOp, srcSlice: bestSlice);
        dstNodeChanged = true;

        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");

        // Move 'dstAffineForOp' before 'fusedLoopInsPoint' if needed.
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(existingOp: fusedLoopInsPoint);

        // Update edges between 'srcNode' and 'dstNode'.
        mdg->updateEdges(srcId: srcNode->id, dstId: dstNode->id, privateMemRefs: privateMemrefs,
                         removeSrcId: removeSrcNode);

        // Create private memrefs.
        if (!privateMemrefs.empty()) {
          // Note the block into which fusion was performed. This can be used to
          // place `alloc`s that create private memrefs.
          Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();

          // Gather stores for all the private-to-be memrefs.
          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
          dstAffineForOp.walk(callback: [&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(V: storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(Elt: storeOp);
          });

          // Replace original memrefs with private memrefs. Note that all the
          // loads and stores on these memrefs will be replaced with new
          // loads and stores. Any reference to the original ones becomes
          // invalid after this point.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
            Value newMemRef = createPrivateMemRef(
                forOp: dstAffineForOp, storeOps: storesForMemref, dstLoopDepth: bestDstLoopDepth,
                fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
            if (!newMemRef)
              continue;
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId = mdg->addNode(op: newMemRef.getDefiningOp());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(srcId: newMemRefNodeId, dstId, value: newMemRef);
          }
          // One or more entries for 'newMemRef' alloc op are inserted into
          // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
          // reallocate, update dstNode.
          dstNode = mdg->getNode(id: dstId);
        }

        // Collect dst loop stats after memref privatization transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(opToWalk: dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(id: dstNode->id);
        mdg->addToNode(
            id: dstId, loads: dstLoopCollector.loadOpInsts, stores: dstLoopCollector.storeOpInsts,
            memrefLoads: dstLoopCollector.memrefLoads, memrefStores: dstLoopCollector.memrefStores,
            memrefFrees: dstLoopCollector.memrefFrees);

        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // srcNode is no longer valid after it is removed from mdg.
          srcAffineForOp.erase();
          mdg->removeNode(id: srcId);
          srcNode = nullptr;
        }
      }
    } while (dstNodeChanged);
  }
1192
1193 /// Visit each node in the graph, and for each node, attempt to fuse it with
1194 /// producer-consumer candidates. No fusion is performed when producers with a
1195 /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1196 /// are encountered.
1197 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1198 LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1199 init();
1200 while (!worklist.empty()) {
1201 unsigned dstId = worklist.back();
1202 worklist.pop_back();
1203 performFusionsIntoDest(dstId, maxSrcUserCount);
1204 }
1205 }
1206
1207 // Visits each node in the graph, and for each node, attempts to fuse it with
1208 // its sibling nodes (nodes which share a parent, but no dependence edges).
1209 void fuseSiblingNodes() {
1210 LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1211 init();
1212 while (!worklist.empty()) {
1213 unsigned dstId = worklist.back();
1214 worklist.pop_back();
1215
1216 // Skip if this node was removed (fused into another node).
1217 if (mdg->nodes.count(Val: dstId) == 0)
1218 continue;
1219 // Get 'dstNode' into which to attempt fusion.
1220 auto *dstNode = mdg->getNode(id: dstId);
1221 // Skip if 'dstNode' is not a loop nest.
1222 if (!isa<AffineForOp>(Val: dstNode->op))
1223 continue;
1224 // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1225 fuseWithSiblingNodes(dstNode);
1226 }
1227 }
1228
  // Attempt to fuse 'dstNode' with sibling nodes in the graph. A sibling loads
  // from the same memref but has no dependence path to/from 'dstNode'; the
  // sibling nest is erased after each successful fusion.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(Val: dstNode->op);

    while (findSiblingNodeToFuse(dstNode, visitedSibNodeIds: &visitedSibNodeIds, idAndMemrefToFuse: &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(id: sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(other: dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(srcId: sibNode->id, dstId: dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(srcId: dstNode->id, dstId: sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, loadOps: &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      Operation *sibLoadOpInst = llvm::getSingleElement(C&: sibLoadOpInsts);

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, loadOps: &dstLoadOpInsts);

      // It's possible this fusion is at an inner depth (i.e., there are common
      // surrounding affine loops for the source and destination for ops). We
      // need to get this number because the call to canFuseLoops needs to be
      // passed the absolute depth. The max legal depth and the depths we try
      // below are however *relative* and as such don't include the common
      // depth.
      SmallVector<AffineForOp, 4> surroundingLoops;
      getAffineForIVs(op&: *dstAffineForOp, loops: &surroundingLoops);
      unsigned numSurroundingLoops = surroundingLoops.size();
      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(op&: *dstLoadOpInsts[0], loops: &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
      auto sibAffineForOp = cast<AffineForOp>(Val: sibNode->op);

      // Compute loop depth and slice union for fusion.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(N: dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result =
            affine::canFuseLoops(srcForOp: sibAffineForOp, dstForOp: dstAffineForOp,
                                 /*dstLoopDepth=*/i + numSurroundingLoops,
                                 srcSlice: &depthSliceUnions[i - 1], fusionStrategy: strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
                              << maxLegalFusionDepth << '\n');

      // Skip if fusion is not feasible at any loop depths.
      if (maxLegalFusionDepth == 0)
        continue;

      double computeToleranceThresholdToUse = computeToleranceThreshold;

      // Cyclic dependences in the source nest may be violated when performing
      // slicing-based fusion. They aren't actually violated in cases where no
      // redundant execution of the source happens (1:1 pointwise dep on the
      // producer-consumer memref access for example). Check this and allow
      // fusion accordingly.
      if (hasCyclicDependence(root: sibAffineForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
        // Maximal fusion does not check for compute tolerance threshold; so
        // perform the maximal fusion only when the redundant computation is
        // zero.
        if (maximalFusion) {
          auto dstForOp = cast<AffineForOp>(Val: dstNode->op);
          int64_t sliceCost;
          int64_t fusedLoopNestComputeCost;
          auto fraction = getAdditionalComputeFraction(
              srcForOp: sibAffineForOp, dstForOp, depth: maxLegalFusionDepth, depthSliceUnions,
              sliceCost, fusedLoopNestComputeCost);
          if (!fraction || fraction > 0) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "Can't perform maximal fusion with a cyclic dependence "
                   "and non-zero additional compute.\n");
            return;
          }
        } else {
          // Set redundant computation tolerance to zero regardless of what the
          // user specified. Without this, fusion would be invalid.
          LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
                                     "source has a cyclic dependence.\n");
          computeToleranceThresholdToUse = 0.0;
        }
      }

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable. For sibling fusion, the sibling
        // load op is treated as the src "store" op for fusion profitability
        // purposes. The footprint of the load in the slice relative to the
        // unfused source's determines reuse.
        if (!isFusionProfitable(srcForOp: sibAffineForOp, producerStores: sibLoadOpInst, dstForOp: dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                dstLoopDepth: &bestDstLoopDepth,
                                computeToleranceThreshold: computeToleranceThresholdToUse))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");

      const ComputationSliceState &bestSlice =
          depthSliceUnions[bestDstLoopDepth - 1];
      assert(!bestSlice.isEmpty() &&
             "Fusion depth has no computed slice union");

      // Do not perform sibling fusion if it isn't maximal. We always remove the
      // sibling node and as such fusion shouldn't be performed if a part of the
      // slice is used in the destination.
      auto isMaximal = bestSlice.isMaximal();
      if (!isMaximal.value_or(u: false)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Slice isn't maximal; not performing sibling fusion.\n");
        continue;
      }

      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      affine::fuseLoops(srcForOp: sibAffineForOp, dstForOp: dstAffineForOp, srcSlice: bestSlice,
                        isInnermostSiblingInsertionFusion: isInnermostInsertion);

      auto dstForInst = cast<AffineForOp>(Val: dstNode->op);
      // Update operation position of fused loop nest (if needed).
      if (insertPointInst != dstForInst)
        dstForInst->moveBefore(existingOp: insertPointInst);

      LLVM_DEBUG(llvm::dbgs()
                 << "Fused sibling nest " << sibId << " into destination nest "
                 << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
                 << dstAffineForOp << "\n");

      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);

      // Remove old sibling loop nest.
      // Get op before we invalidate the MDG node.
      Operation *op = sibNode->op;
      mdg->removeNode(id: sibNode->id);
      op->erase();
    }
  }
1392
1393 // Searches block argument uses and the graph from 'dstNode' looking for a
1394 // fusion candidate sibling node which shares no dependences with 'dstNode'
1395 // but which loads from the same memref. Returns true and sets
1396 // 'idAndMemrefToFuse' on success. Returns false otherwise.
1397 bool findSiblingNodeToFuse(Node *dstNode,
1398 DenseSet<unsigned> *visitedSibNodeIds,
1399 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1400 // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1401 // on 'memref'.
1402 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1403 // Skip if 'outEdge' is not a read-after-write dependence.
1404 // TODO: Remove restrict to single load op restriction.
1405 if (sibNode->getLoadOpCount(memref) != 1)
1406 return false;
1407 // Skip if there exists a path of dependent edges between
1408 // 'sibNode' and 'dstNode'.
1409 if (mdg->hasDependencePath(srcId: sibNode->id, dstId: dstNode->id) ||
1410 mdg->hasDependencePath(srcId: dstNode->id, dstId: sibNode->id))
1411 return false;
1412 // Skip sib node if it loads to (and stores from) the same memref on
1413 // which it also has an input dependence edge.
1414 DenseSet<Value> loadAndStoreMemrefSet;
1415 sibNode->getLoadAndStoreMemrefSet(loadAndStoreMemrefSet: &loadAndStoreMemrefSet);
1416 if (llvm::any_of(Range&: loadAndStoreMemrefSet, P: [=](Value memref) {
1417 return mdg->getIncomingMemRefAccesses(id: sibNode->id, memref) > 0;
1418 }))
1419 return false;
1420
1421 // Check that all stores are to the same memref if any.
1422 DenseSet<Value> storeMemrefs;
1423 for (auto *storeOpInst : sibNode->stores) {
1424 storeMemrefs.insert(
1425 V: cast<AffineWriteOpInterface>(Val: storeOpInst).getMemRef());
1426 }
1427 return storeMemrefs.size() <= 1;
1428 };
1429
1430 // Search for siblings which load the same memref block argument.
1431 Block *block = dstNode->op->getBlock();
1432 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1433 for (Operation *user : block->getArgument(i).getUsers()) {
1434 auto loadOp = dyn_cast<AffineReadOpInterface>(Val: user);
1435 if (!loadOp)
1436 continue;
1437 // Gather loops surrounding 'use'.
1438 SmallVector<AffineForOp, 4> loops;
1439 getAffineForIVs(op&: *user, loops: &loops);
1440 // Skip 'use' if it is not within a loop nest.
1441 // Find the surrounding affine.for nested immediately within the
1442 // block.
1443 auto *it = llvm::find_if(Range&: loops, P: [&](AffineForOp loop) {
1444 return loop->getBlock() == &mdg->block;
1445 });
1446 // Skip 'use' if it is not within a loop nest in `block`.
1447 if (it == loops.end())
1448 continue;
1449 Node *sibNode = mdg->getForOpNode(forOp: *it);
1450 assert(sibNode != nullptr);
1451 // Skip 'use' if it not a sibling to 'dstNode'.
1452 if (sibNode->id == dstNode->id)
1453 continue;
1454 // Skip 'use' if it has been visited.
1455 if (visitedSibNodeIds->count(V: sibNode->id) > 0)
1456 continue;
1457 // Skip 'use' if it does not load from the same memref as 'dstNode'.
1458 auto memref = loadOp.getMemRef();
1459 if (dstNode->getLoadOpCount(memref) == 0)
1460 continue;
1461 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1462 if (canFuseWithSibNode(sibNode, memref)) {
1463 visitedSibNodeIds->insert(V: sibNode->id);
1464 idAndMemrefToFuse->first = sibNode->id;
1465 idAndMemrefToFuse->second = memref;
1466 return true;
1467 }
1468 }
1469 }
1470
1471 // Search for siblings by following edges through an intermediate src node.
1472 // Collect candidate 'dstNode' input edges in 'inEdges'.
1473 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1474 mdg->forEachMemRefInputEdge(
1475 id: dstNode->id, callback: [&](MemRefDependenceGraph::Edge inEdge) {
1476 // Add 'inEdge' if it is a read-after-write dependence.
1477 if (dstNode->getLoadOpCount(memref: inEdge.value) > 0 &&
1478 mdg->getNode(id: inEdge.id)->getStoreOpCount(memref: inEdge.value) > 0)
1479 inEdges.push_back(Elt: inEdge);
1480 });
1481
1482 // Search for sibling nodes to fuse by visiting output edges from each input
1483 // edge in 'inEdges'.
1484 for (auto &inEdge : inEdges) {
1485 // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1486 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1487 mdg->forEachMemRefOutputEdge(
1488 id: inEdge.id, callback: [&](MemRefDependenceGraph::Edge outEdge) {
1489 unsigned sibNodeId = outEdge.id;
1490 if (visitedSibNodeIds->count(V: sibNodeId) > 0)
1491 return;
1492 // Skip output edge if not a sibling using the same memref.
1493 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1494 return;
1495 auto *sibNode = mdg->getNode(id: sibNodeId);
1496 if (!isa<AffineForOp>(Val: sibNode->op))
1497 return;
1498 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1499 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1500 // Add candidate 'outEdge' to sibling node.
1501 outEdges.push_back(Elt: outEdge);
1502 }
1503 });
1504
1505 // Add first candidate if any were returned.
1506 if (!outEdges.empty()) {
1507 visitedSibNodeIds->insert(V: outEdges[0].id);
1508 idAndMemrefToFuse->first = outEdges[0].id;
1509 idAndMemrefToFuse->second = outEdges[0].value;
1510 return true;
1511 }
1512 }
1513 return false;
1514 }
1515
1516 /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1517 /// into 'dstNode'.
1518 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1519 // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1520 mdg->updateEdges(sibId: sibNode->id, dstId: dstNode->id);
1521
1522 // Collect dst loop stats after memref privatization transformation.
1523 auto dstForInst = cast<AffineForOp>(Val: dstNode->op);
1524 LoopNestStateCollector dstLoopCollector;
1525 dstLoopCollector.collect(opToWalk: dstForInst);
1526 // Clear and add back loads and stores
1527 mdg->clearNodeLoadAndStores(id: dstNode->id);
1528 mdg->addToNode(id: dstNode->id, loads: dstLoopCollector.loadOpInsts,
1529 stores: dstLoopCollector.storeOpInsts, memrefLoads: dstLoopCollector.memrefLoads,
1530 memrefStores: dstLoopCollector.memrefStores, memrefFrees: dstLoopCollector.memrefFrees);
1531 }
1532
1533 // Clean up any allocs with no users.
1534 void eraseUnusedMemRefAllocations() {
1535 for (auto &pair : mdg->memrefEdgeCount) {
1536 if (pair.second > 0)
1537 continue;
1538 auto memref = pair.first;
1539 // Skip if there exist other uses (return operation or function calls).
1540 if (!memref.use_empty())
1541 continue;
1542 // Use list expected to match the dep graph info.
1543 auto *op = memref.getDefiningOp();
1544 if (isa_and_nonnull<memref::AllocOp>(Val: op))
1545 op->erase();
1546 }
1547 }
1548};
1549
1550} // namespace
1551
1552/// Run fusion on `block`.
1553void LoopFusion::runOnBlock(Block *block) {
1554 MemRefDependenceGraph g(*block);
1555 if (!g.init()) {
1556 LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1557 return;
1558 }
1559
1560 std::optional<unsigned> fastMemorySpaceOpt;
1561 if (fastMemorySpace.hasValue())
1562 fastMemorySpaceOpt = fastMemorySpace;
1563 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1564 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1565 maximalFusion, computeToleranceThreshold);
1566
1567 if (affineFusionMode == FusionMode::ProducerConsumer)
1568 fusion.runProducerConsumerFusionOnly();
1569 else if (affineFusionMode == FusionMode::Sibling)
1570 fusion.runSiblingFusionOnly();
1571 else
1572 fusion.runGreedyFusion();
1573}
1574
1575void LoopFusion::runOnOperation() {
1576 // Call fusion on every op that has at least two affine.for nests (in post
1577 // order).
1578 getOperation()->walk(callback: [&](Operation *op) {
1579 for (Region &region : op->getRegions()) {
1580 for (Block &block : region.getBlocks()) {
1581 auto affineFors = block.getOps<AffineForOp>();
1582 if (!affineFors.empty() && !llvm::hasSingleElement(C&: affineFors))
1583 runOnBlock(block: &block);
1584 }
1585 }
1586 });
1587}
1588
1589std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
1590 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1591 bool maximalFusion, enum FusionMode affineFusionMode) {
1592 return std::make_unique<LoopFusion>(args&: fastMemorySpace, args&: localBufSizeThreshold,
1593 args&: maximalFusion, args&: affineFusionMode);
1594}
1595

// Source: mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp