1//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements affine fusion.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Passes.h"
14
15#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
16#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/Utils.h"
18#include "mlir/Dialect/Affine/LoopFusionUtils.h"
19#include "mlir/Dialect/Affine/LoopUtils.h"
20#include "mlir/Dialect/Affine/Utils.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/Transforms/Passes.h"
26#include "llvm/ADT/DenseMap.h"
27#include "llvm/ADT/DenseSet.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/CommandLine.h"
30#include "llvm/Support/Debug.h"
31#include "llvm/Support/raw_ostream.h"
32#include <iomanip>
33#include <optional>
34#include <sstream>
35
36namespace mlir {
37namespace affine {
38#define GEN_PASS_DEF_AFFINELOOPFUSION
39#include "mlir/Dialect/Affine/Passes.h.inc"
40} // namespace affine
41} // namespace mlir
42
43#define DEBUG_TYPE "affine-fusion"
44
45using namespace mlir;
46using namespace mlir::affine;
47
48namespace {
49/// Loop fusion pass. This pass currently supports a greedy fusion policy,
50/// which fuses loop nests with single-writer/single-reader memref dependences
51/// with the goal of improving locality.
52// TODO: Support fusion of source loop nests which write to multiple
53// memrefs, where each memref can have multiple users (if profitable).
54struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
55 LoopFusion() = default;
56 LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
57 bool maximalFusion, enum FusionMode affineFusionMode) {
58 this->fastMemorySpace = fastMemorySpace;
59 this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024;
60 this->maximalFusion = maximalFusion;
61 this->affineFusionMode = affineFusionMode;
62 }
63
64 void runOnBlock(Block *block);
65 void runOnOperation() override;
66};
67
68} // namespace
69
70/// Returns true if node 'srcId' can be removed after fusing it with node
71/// 'dstId'. The node can be removed if any of the following conditions are met:
72/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs.
73/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs
74/// and the fusion slice is maximal.
75/// 3. 'srcId' has output dependences after fusion, the fusion slice is
76/// maximal and the fusion insertion point dominates all the dependences.
77static bool canRemoveSrcNodeAfterFusion(
78 unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
79 Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
80 const MemRefDependenceGraph &mdg) {
81
82 Operation *dstNodeOp = mdg.getNode(id: dstId)->op;
83 bool hasOutDepsAfterFusion = false;
84
85 for (auto &outEdge : mdg.outEdges.lookup(Val: srcId)) {
86 Operation *depNodeOp = mdg.getNode(id: outEdge.id)->op;
87 // Skip dependence with dstOp since it will be removed after fusion.
88 if (depNodeOp == dstNodeOp)
89 continue;
90
91 // Only fusion within the same block is supported. Use domination analysis
92 // when needed.
93 if (depNodeOp->getBlock() != dstNodeOp->getBlock())
94 return false;
95
96 // Check if the insertion point of the fused loop dominates the dependence.
97 // Otherwise, the src loop can't be removed.
98 if (fusedLoopInsPoint != depNodeOp &&
99 !fusedLoopInsPoint->isBeforeInBlock(other: depNodeOp)) {
100 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
101 "dominate dependence\n");
102 return false;
103 }
104
105 hasOutDepsAfterFusion = true;
106 }
107
108 // If src loop has dependences after fusion or it writes to an live-out or
109 // escaping memref, we can only remove it if the fusion slice is maximal so
110 // that all the dependences are preserved.
111 if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
112 std::optional<bool> isMaximal = fusionSlice.isMaximal();
113 if (!isMaximal) {
114 LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
115 "if fusion is maximal\n");
116 return false;
117 }
118
119 if (!*isMaximal) {
120 LLVM_DEBUG(llvm::dbgs()
121 << "Src loop can't be removed: fusion is not maximal\n");
122 return false;
123 }
124 }
125
126 return true;
127}
128
129/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
130/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
131/// the program order when the 'mdg' is created. However, program order is not
132/// guaranteed and must not be required by the client. Program order won't be
133/// held if the 'mdg' is reused from a previous fusion step or if the node
134/// creation order changes in the future to support more advance cases.
135// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
136static void getProducerCandidates(unsigned dstId,
137 const MemRefDependenceGraph &mdg,
138 SmallVectorImpl<unsigned> &srcIdCandidates) {
139 // Skip if no input edges along which to fuse.
140 if (mdg.inEdges.count(Val: dstId) == 0)
141 return;
142
143 // Gather memrefs from loads in 'dstId'.
144 auto *dstNode = mdg.getNode(id: dstId);
145 DenseSet<Value> consumedMemrefs;
146 for (Operation *load : dstNode->loads)
147 consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148
149 // Traverse 'dstId' incoming edges and gather the nodes that contain a store
150 // to one of the consumed memrefs.
151 for (const auto &srcEdge : mdg.inEdges.lookup(Val: dstId)) {
152 const auto *srcNode = mdg.getNode(id: srcEdge.id);
153 // Skip if 'srcNode' is not a loop nest.
154 if (!isa<AffineForOp>(Val: srcNode->op))
155 continue;
156
157 if (any_of(Range: srcNode->stores, P: [&](Operation *op) {
158 auto storeOp = cast<AffineWriteOpInterface>(op);
159 return consumedMemrefs.count(storeOp.getMemRef()) > 0;
160 }))
161 srcIdCandidates.push_back(Elt: srcNode->id);
162 }
163
164 llvm::sort(C&: srcIdCandidates);
165 srcIdCandidates.erase(CS: llvm::unique(R&: srcIdCandidates), CE: srcIdCandidates.end());
166}
167
168/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
169/// producer-consumer dependence between 'srcId' and 'dstId'.
170static void
171gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
172 const MemRefDependenceGraph &mdg,
173 DenseSet<Value> &producerConsumerMemrefs) {
174 auto *dstNode = mdg.getNode(id: dstId);
175 auto *srcNode = mdg.getNode(id: srcId);
176 gatherProducerConsumerMemrefs(srcOps: srcNode->stores, dstOps: dstNode->loads,
177 producerConsumerMemrefs);
178}
179
180/// A memref escapes in the context of the fusion pass if either:
181/// 1. it (or its alias) is a block argument, or
182/// 2. created by an op not known to guarantee alias freedom,
183/// 3. it (or its alias) are used by ops other than affine dereferencing ops
184/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
185/// terminator ops, etc.); such ops do not deference the memref in an affine
186/// way.
187static bool isEscapingMemref(Value memref, Block *block) {
188 Operation *defOp = memref.getDefiningOp();
189 // Check if 'memref' is a block argument.
190 if (!defOp)
191 return true;
192
193 // Check if this is defined to be an alias of another memref.
194 if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
195 if (isEscapingMemref(viewOp.getViewSource(), block))
196 return true;
197
198 // Any op besides allocating ops wouldn't guarantee alias freedom
199 if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(op: defOp, value: memref))
200 return true;
201
202 // Check if 'memref' is used by a non-deferencing op (including unknown ones)
203 // (e.g., call ops, alias creating ops, etc.).
204 return llvm::any_of(Range: memref.getUsers(), P: [&](Operation *user) {
205 // Ignore users outside of `block`.
206 Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(op&: *user);
207 if (!ancestorOp)
208 return true;
209 if (ancestorOp->getBlock() != block)
210 return false;
211 return !isa<AffineMapAccessInterface>(*user);
212 });
213}
214
215/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
216/// that escape the block or are accessed in a non-affine way.
217static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
218 DenseSet<Value> &escapingMemRefs) {
219 auto *node = mdg.getNode(id);
220 for (Operation *storeOp : node->stores) {
221 auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222 if (escapingMemRefs.count(V: memref))
223 continue;
224 if (isEscapingMemref(memref, &mdg.block))
225 escapingMemRefs.insert(memref);
226 }
227}
228
229// Sinks all sequential loops to the innermost levels (while preserving
230// relative order among them) and moves all parallel loops to the
231// outermost (while again preserving relative order among them).
232// This can increase the loop depth at which we can fuse a slice, since we are
233// pushing loop carried dependence to a greater depth in the loop nest.
234static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
235 assert(isa<AffineForOp>(node->op));
236 AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
237 node->op = newRootForOp;
238}
239
240/// Get the operation that should act as a dominance filter while replacing
241/// memref uses with a private memref for which `producerStores` and
242/// `sliceInsertionBlock` are provided. This effectively determines in what
243/// part of the IR we should be performing the replacement.
244static Operation *
245getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
246 ArrayRef<Operation *> producerStores) {
247 assert(!producerStores.empty() && "expected producer store");
248
249 // We first find the common block that contains the producer stores and
250 // the slice computation. The first ancestor among the ancestors of the
251 // producer stores in that common block is the dominance filter to use for
252 // replacement.
253 Block *commonBlock = nullptr;
254 // Find the common block of all relevant operations.
255 for (Operation *store : producerStores) {
256 Operation *otherOp =
257 !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
258 commonBlock = findInnermostCommonBlockInScope(a: store, b: otherOp);
259 }
260 assert(commonBlock &&
261 "common block of producer stores and slice should exist");
262
263 // Find the first ancestor among the ancestors of `producerStores` in
264 // `commonBlock`.
265 Operation *firstAncestor = nullptr;
266 for (Operation *store : producerStores) {
267 Operation *ancestor = commonBlock->findAncestorOpInBlock(op&: *store);
268 assert(ancestor && "producer store should be contained in common block");
269 firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(other: firstAncestor)
270 ? ancestor
271 : firstAncestor;
272 }
273 return firstAncestor;
274}
275
276/// Returns the amount of additional (redundant) computation that will be done
277/// as a fraction of the total computation if `srcForOp` is fused into
278/// `dstForOp` at depth `depth`. The method returns the compute cost of the
279/// slice and the fused nest's compute cost in the trailing output arguments.
280static std::optional<double> getAdditionalComputeFraction(
281 AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282 ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
283 int64_t &fusedLoopNestComputeCost) {
284 LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
285 // Compute cost of sliced and unsliced src loop nest.
286 // Walk src loop nest and collect stats.
287 LoopNestStats srcLoopNestStats;
288 if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
289 LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
290 return std::nullopt;
291 }
292
293 // Compute cost of dst loop nest.
294 LoopNestStats dstLoopNestStats;
295 if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
296 LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
297 return std::nullopt;
298 }
299
300 // Compute op instance count for the src loop nest without iteration slicing.
301 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
302
303 // Compute op cost for the dst loop nest.
304 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
305
306 const ComputationSliceState &slice = depthSliceUnions[depth - 1];
307 // Skip slice union if it wasn't computed for this depth.
308 if (slice.isEmpty()) {
309 LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
310 return std::nullopt;
311 }
312
313 if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
314 dstLoopNestStats, slice,
315 &fusedLoopNestComputeCost)) {
316 LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
317 return std::nullopt;
318 }
319
320 double additionalComputeFraction =
321 fusedLoopNestComputeCost /
322 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
323 1;
324
325 return additionalComputeFraction;
326}
327
328// Creates and returns a private (single-user) memref for fused loop rooted at
329// 'forOp', with (potentially reduced) memref size based on the memref region
330// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
331// specifies the block in which the slice was/will be inserted. The method
332// expects that all stores ops to the memref have the same access function.
333// Returns nullptr if the creation failed.
334static Value createPrivateMemRef(AffineForOp forOp,
335 ArrayRef<Operation *> storeOps,
336 unsigned dstLoopDepth,
337 std::optional<unsigned> fastMemorySpace,
338 Block *sliceInsertionBlock,
339 uint64_t localBufSizeThreshold) {
340 assert(!storeOps.empty() && "no source stores supplied");
341
342 // Check if all stores have the same access function; we only support this
343 // case.
344 // TODO: Use union of memref write regions to compute private memref footprint
345 // for store ops with different access functions.
346 if (storeOps.size() > 1 &&
347 !std::equal(first1: std::next(x: storeOps.begin()), last1: storeOps.end(), first2: storeOps.begin(),
348 binary_pred: [](Operation *a, Operation *b) {
349 MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350 MemRefAccess bM(cast<AffineWriteOpInterface>(b));
351 return aM == bM;
352 })) {
353 LLVM_DEBUG(llvm::dbgs()
354 << "Private memref creation unsupported for multiple producer "
355 "stores with different access functions.\n");
356 return nullptr;
357 }
358
359 Operation *srcStoreOp = storeOps[0];
360
361 // Create builder to insert alloc op just before 'forOp'.
362 OpBuilder b(forOp);
363 // Builder to create constants at the top level.
364 OpBuilder top(forOp->getParentRegion());
365 // Create new memref type based on slice bounds.
366 auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
367 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
368 unsigned rank = oldMemRefType.getRank();
369
370 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
371 MemRefRegion region(srcStoreOp->getLoc());
372 bool validRegion = succeeded(
373 Result: region.compute(op: srcStoreOp, loopDepth: dstLoopDepth, /*sliceState=*/nullptr,
374 /*addMemRefDimBounds=*/true, /*dropLocalVars=*/false));
375
376 (void)validRegion;
377 assert(validRegion && "unexpected memref region failure");
378 SmallVector<int64_t, 4> newShape;
379 SmallVector<AffineMap, 4> lbs;
380 lbs.reserve(N: rank);
381 // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
382 // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
383 std::optional<int64_t> numElements =
384 region.getConstantBoundingSizeAndShape(shape: &newShape, lbs: &lbs);
385 assert(numElements && "non-constant number of elts in local buffer");
386
387 const FlatAffineValueConstraints *cst = region.getConstraints();
388 // 'outerIVs' holds the values that this memory region is symbolic/parametric
389 // on; this would correspond to loop IVs surrounding the level at which the
390 // slice is being materialized.
391 SmallVector<Value, 8> outerIVs;
392 cst->getValues(start: rank, end: cst->getNumDimAndSymbolVars(), values: &outerIVs);
393
394 // Build 'rank' AffineExprs from MemRefRegion 'lbs'
395 SmallVector<AffineExpr, 4> offsets;
396 offsets.reserve(N: rank);
397
398 // Outer IVs are considered symbols during memref region computation. Replace
399 // them uniformly with dims so that valid IR is guaranteed.
400 SmallVector<AffineExpr> replacements;
401 for (unsigned j = 0, e = lbs[0].getNumSymbols(); j < e; ++j)
402 replacements.push_back(Elt: mlir::getAffineDimExpr(position: j, context: forOp.getContext()));
403 for (unsigned d = 0; d < rank; ++d) {
404 assert(lbs[d].getNumResults() == 1 &&
405 "invalid private memref bound calculation");
406 offsets.push_back(Elt: lbs[d].getResult(idx: 0).replaceSymbols(symReplacements: replacements));
407 }
408
409 // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
410 // by 'srcStoreOpInst'.
411 auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
412 assert(eltSize && "memrefs with size elt types expected");
413 uint64_t bufSize = *eltSize * *numElements;
414 Attribute newMemSpace;
415 if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
416 newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
417 } else {
418 newMemSpace = oldMemRefType.getMemorySpace();
419 }
420 auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
421 /*map=*/AffineMap(), newMemSpace);
422
423 // Create new private memref for fused loop 'forOp'. 'newShape' is always
424 // a constant shape.
425 // TODO: Create/move alloc ops for private memrefs closer to their
426 // consumer loop nests to reduce their live range. Currently they are added
427 // at the beginning of the block, because loop nests can be reordered
428 // during the fusion pass.
429 Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
430
431 // Build an AffineMap to remap access functions based on lower bound offsets.
432 SmallVector<AffineExpr, 4> remapExprs;
433 remapExprs.reserve(N: rank);
434 for (unsigned i = 0; i < rank; i++) {
435 auto dimExpr = b.getAffineDimExpr(position: outerIVs.size() + i);
436
437 auto remapExpr =
438 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
439 remapExprs.push_back(Elt: remapExpr);
440 }
441
442 auto indexRemap =
443 AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
444
445 // Replace all users of 'oldMemRef' with 'newMemRef'.
446 Operation *domFilter =
447 getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, producerStores: storeOps);
448 LogicalResult res = replaceAllMemRefUsesWith(
449 oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
450 /*extraOperands=*/outerIVs,
451 /*symbolOperands=*/{}, domFilter);
452 assert(succeeded(res) &&
453 "replaceAllMemrefUsesWith should always succeed here");
454 (void)res;
455 LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
456 << '\n');
457 return newMemRef;
458}
459
460// Checks the profitability of fusing a backwards slice of the loop nest
461// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
462// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
463// being produced and consumed, which is an input to the cost model. For
464// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
465// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
466// will be the src loop nest LoadOp which reads from the same memref as dst loop
467// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
468// node, which will be used to check that the write region is the same after
469// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
470// each legal fusion depth. The maximal depth at which fusion is legal is
471// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
472// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
473// the most profitable depth at which to materialize the source loop nest slice.
474// The profitability model executes the following steps:
475// *) Computes the backward computation slice at 'srcOpInst'. This
476// computation slice of the loop nest surrounding 'srcOpInst' is
477// represented by modified src loop bounds in 'sliceState', which are
478// functions of loop IVs in the loop nest surrounding 'srcOpInst'.
479// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
480// loop nest is the total number of dynamic operation instances in the loop
481// nest).
482// *) Computes the cost of fusing a slice of the src loop nest into the dst
483// loop nest at various values of dst loop depth, attempting to fuse
484// the largest computation slice at the maximal dst loop depth (closest to
485// the load) to minimize reuse distance and potentially enable subsequent
486// load/store forwarding.
487// NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
488// nest, at which the src computation slice is inserted/fused.
489// NOTE: We attempt to maximize the dst loop depth, but there are cases
490// where a particular setting for 'dstLoopNest' might fuse an unsliced
491// loop (within the src computation slice) at a depth which results in
492// excessive recomputation (see unit tests for examples).
493// *) Compares the total cost of the unfused loop nests to the min cost fused
494// loop nest computed in the previous step, and returns true if the latter
495// is lower.
496// TODO: Extend profitability analysis to support scenarios with multiple
497// stores.
498static bool isFusionProfitable(AffineForOp srcForOp,
499 ArrayRef<Operation *> producerStores,
500 AffineForOp dstForOp,
501 ArrayRef<ComputationSliceState> depthSliceUnions,
502 unsigned maxLegalFusionDepth,
503 unsigned *dstLoopDepth,
504 double computeToleranceThreshold) {
505 LLVM_DEBUG({
506 llvm::dbgs()
507 << "Checking whether fusion is profitable between source nest:\n";
508 llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
509 llvm::dbgs() << dstForOp << "\n";
510 });
511
512 if (maxLegalFusionDepth == 0) {
513 LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
514 return false;
515 }
516
517 // Compute cost of sliced and unsliced src loop nest.
518
519 // Walk src loop nest and collect stats.
520 LoopNestStats srcLoopNestStats;
521 if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
522 return false;
523
524 // Compute cost of dst loop nest.
525 LoopNestStats dstLoopNestStats;
526 if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
527 return false;
528
529 // We limit profitability analysis to only scenarios with
530 // a single producer store for now. Note that some multi-store
531 // producer scenarios will still go through profitability analysis
532 // if only one of the stores is involved in the producer-consumer
533 // relationship of the candidate loops.
534 // TODO: Suppport multiple producer stores in profitability
535 // analysis.
536 if (producerStores.size() > 1) {
537 LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
538 "supported for multiple producer store case.\n");
539 int64_t sliceCost;
540 int64_t fusedLoopNestComputeCost;
541 // We will still fuse if fusion obeys the specified compute
542 // tolerance at the max legal depth.
543 auto fraction = getAdditionalComputeFraction(
544 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
545 fusedLoopNestComputeCost);
546 if (!fraction || fraction > computeToleranceThreshold) {
547 LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
548 "compute tolerance. Not fusing.\n");
549 return false;
550 }
551 LLVM_DEBUG(llvm::dbgs()
552 << "Considering fusion profitable at max legal depth.\n");
553 return true;
554 }
555
556 Operation *srcStoreOp = producerStores.front();
557
558 // Search for min cost value for 'dstLoopDepth'. At each value of
559 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
560 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
561 // of these bounds). Next the union slice bounds are used to calculate
562 // the cost of the slice and the cost of the slice inserted into the dst
563 // loop nest at 'dstLoopDepth'.
564 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
565 double maxStorageReduction = 0.0;
566 std::optional<uint64_t> sliceMemEstimate;
567
568 // The best loop depth at which to materialize the slice.
569 std::optional<unsigned> bestDstLoopDepth;
570
571 // Compute src loop nest write region size.
572 MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
573 if (failed(Result: srcWriteRegion.compute(op: srcStoreOp, /*loopDepth=*/0))) {
574 LLVM_DEBUG(llvm::dbgs()
575 << "Unable to compute MemRefRegion for source operation\n");
576 return false;
577 }
578
579 std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
580 srcWriteRegion.getRegionSize();
581 if (!maybeSrcWriteRegionSizeBytes.has_value())
582 return false;
583 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
584
585 // Compute op instance count for the src loop nest without iteration slicing.
586 uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
587
588 // Compute op instance count for the destination loop nest.
589 uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
590
591 // Evaluate all depth choices for materializing the slice in the destination
592 // loop nest.
593 for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
594 const ComputationSliceState &slice = depthSliceUnions[i - 1];
595 // Skip slice union if it wasn't computed for this depth.
596 if (slice.isEmpty())
597 continue;
598
599 // Compute cost of the slice separately, i.e, the compute cost of the slice
600 // if all outer trip counts are one.
601 int64_t sliceCost;
602
603 int64_t fusedLoopNestComputeCost;
604
605 auto mayAdditionalComputeFraction =
606 getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
607 sliceCost, fusedLoopNestComputeCost);
608 if (!mayAdditionalComputeFraction) {
609 LLVM_DEBUG(llvm::dbgs()
610 << "Can't determine additional compute fraction.\n");
611 continue;
612 }
613 double additionalComputeFraction = *mayAdditionalComputeFraction;
614
615 // Determine what the slice write MemRefRegion would be, if the src loop
616 // nest slice 'slice' were to be inserted into the dst loop nest at loop
617 // depth 'i'.
618 MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
619 if (failed(Result: sliceWriteRegion.compute(op: srcStoreOp, /*loopDepth=*/0, sliceState: &slice))) {
620 LLVM_DEBUG(llvm::dbgs()
621 << "Failed to compute slice write region at loopDepth: " << i
622 << "\n");
623 continue;
624 }
625
626 std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
627 sliceWriteRegion.getRegionSize();
628 if (!maybeSliceWriteRegionSizeBytes.has_value() ||
629 *maybeSliceWriteRegionSizeBytes == 0) {
630 LLVM_DEBUG(llvm::dbgs()
631 << "Failed to get slice write region size at loopDepth: " << i
632 << "\n");
633 continue;
634 }
635 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
636
637 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
638 static_cast<double>(sliceWriteRegionSizeBytes);
639
640 LLVM_DEBUG({
641 std::stringstream msg;
642 msg << " evaluating fusion profitability at depth : " << i << "\n"
643 << std::fixed << std::setprecision(2)
644 << " additional compute fraction: "
645 << 100.0 * additionalComputeFraction << "%\n"
646 << " storage reduction factor: " << storageReduction << "x\n"
647 << " fused nest cost: " << fusedLoopNestComputeCost << "\n"
648 << " src write region size: " << srcWriteRegionSizeBytes << "\n"
649 << " slice write region size: " << sliceWriteRegionSizeBytes
650 << "\n";
651 llvm::dbgs() << msg.str();
652 });
653
654 // TODO: This is a placeholder cost model.
655 // Among all choices that add an acceptable amount of redundant computation
656 // (as per computeToleranceThreshold), we will simply pick the one that
657 // reduces the intermediary size the most.
658 if ((storageReduction > maxStorageReduction) &&
659 (additionalComputeFraction <= computeToleranceThreshold)) {
660 maxStorageReduction = storageReduction;
661 bestDstLoopDepth = i;
662 minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
663 sliceMemEstimate = sliceWriteRegionSizeBytes;
664 }
665 }
666
667 // A simple cost model: fuse if it reduces the memory footprint.
668
669 if (!bestDstLoopDepth) {
670 LLVM_DEBUG(
671 llvm::dbgs()
672 << "All fusion choices involve more than the threshold amount of "
673 "redundant computation; NOT fusing.\n");
674 return false;
675 }
676
677 if (!bestDstLoopDepth) {
678 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
679 return false;
680 }
681
682 // Set dstLoopDepth based on best values from search.
683 *dstLoopDepth = *bestDstLoopDepth;
684
685 LLVM_DEBUG(
686 llvm::dbgs() << " LoopFusion fusion stats:"
687 << "\n best loop depth: " << bestDstLoopDepth
688 << "\n src loop nest compute cost: " << srcLoopNestCost
689 << "\n dst loop nest compute cost: " << dstLoopNestCost
690 << "\n fused loop nest compute cost: "
691 << minFusedLoopNestComputeCost << "\n");
692
693 auto dstMemSize = getMemoryFootprintBytes(dstForOp);
694 auto srcMemSize = getMemoryFootprintBytes(srcForOp);
695
696 std::optional<double> storageReduction;
697
698 if (!dstMemSize || !srcMemSize) {
699 LLVM_DEBUG(llvm::dbgs()
700 << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
701 return false;
702 }
703
704 auto srcMemSizeVal = *srcMemSize;
705 auto dstMemSizeVal = *dstMemSize;
706
707 assert(sliceMemEstimate && "expected value");
708 auto fusedMem = dstMemSizeVal + *sliceMemEstimate;
709
710 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
711 << " dst mem: " << dstMemSizeVal << "\n"
712 << " fused mem: " << fusedMem << "\n"
713 << " slice mem: " << sliceMemEstimate << "\n");
714
715 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
716 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
717 return false;
718 }
719 storageReduction =
720 100.0 *
721 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
722
723 double additionalComputeFraction =
724 100.0 * (minFusedLoopNestComputeCost /
725 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
726 1);
727 (void)additionalComputeFraction;
728 LLVM_DEBUG({
729 std::stringstream msg;
730 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
731 << std::setprecision(2) << additionalComputeFraction
732 << "% redundant computation and a ";
733 msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
734 msg << "% storage reduction.\n";
735 llvm::dbgs() << msg.str();
736 });
737
738 return true;
739}
740
741namespace {
742
743// GreedyFusion greedily fuses loop nests which have a producer/consumer or
744// input-reuse relationship on a memref, with the goal of improving locality.
745//
746// The steps of the producer-consumer fusion algorithm are as follows:
747//
748// *) A worklist is initialized with node ids from the dependence graph.
749// *) For each node id in the worklist:
// *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
// candidate destination AffineForOp into which fusion will be attempted.
752// *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
753// *) For each LoadOp in 'dstLoadOps' do:
754// *) Look up dependent loop nests which have a single store op to the same
755// memref.
756// *) Check if dependences would be violated by the fusion.
757// *) Get a computation slice of 'srcLoopNest', which adjusts its loop
758// bounds to be functions of 'dstLoopNest' IVs and symbols.
759// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
760// at a loop depth determined by the cost model in 'isFusionProfitable'.
// *) Add the newly fused load/store operations to the state,
// and also add newly fused load ops to 'dstLoadOps' to be considered
// as fusion dst load ops in another iteration.
764// *) Remove old src loop nest and its associated state.
765//
766// The steps of the input-reuse fusion algorithm are as follows:
767//
768// *) Initialize 'worklist' with node ids from the dependence graph.
769// *) For each 'dstNode' in the worklist:
770// *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
771// loads from the same memref, but which has no dependence paths to/from.
772// *) Get a computation slice of 'sibLoopNest', which adjusts its loop
773// bounds to be functions of 'dstLoopNest' IVs and symbols.
774// *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
775// at a loop depth determined by the cost model in 'isFusionProfitable'.
776// This function also checks that the memref write region of 'sibLoopNest',
777// is preserved in the fused loop nest.
778// *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
779//
780// Given a graph where top-level operations are vertices in the set 'V' and
781// edges in the set 'E' are dependences between vertices, this algorithm
782// takes O(V) time for initialization, and has runtime O(V + E).
783//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges.
// TODO: Lift this restriction.
787//
788// TODO: Experiment with other fusion policies.
789struct GreedyFusion {
790public:
791 // The data dependence graph to traverse during fusion.
792 MemRefDependenceGraph *mdg;
793 // Worklist of graph nodes visited during the fusion pass.
794 SmallVector<unsigned, 8> worklist;
795 // Parameter for local buffer size threshold.
796 unsigned localBufSizeThreshold;
797 // Parameter for fast memory space.
798 std::optional<unsigned> fastMemorySpace;
799 // If true, ignore any additional (redundant) computation tolerance threshold
800 // that would have prevented fusion.
801 bool maximalFusion;
802 // The amount of additional computation that is tolerated while fusing
803 // pair-wise as a fraction of the total computation.
804 double computeToleranceThreshold;
805
806 using Node = MemRefDependenceGraph::Node;
807
808 GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
809 std::optional<unsigned> fastMemorySpace, bool maximalFusion,
810 double computeToleranceThreshold)
811 : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
812 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
813 computeToleranceThreshold(computeToleranceThreshold) {}
814
815 /// Initializes 'worklist' with nodes from 'mdg'.
816 void init() {
817 // TODO: Add a priority queue for prioritizing nodes by different
818 // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
819 worklist.clear();
820 for (auto &idAndNode : mdg->nodes) {
821 const Node &node = idAndNode.second;
822 worklist.push_back(Elt: node.id);
823 }
824 }
825 /// Run only sibling fusion on the `mdg`.
826 void runSiblingFusionOnly() {
827 fuseSiblingNodes();
828 eraseUnusedMemRefAllocations();
829 }
830
831 /// Run only producer/consumer fusion on the `mdg`.
832 void runProducerConsumerFusionOnly() {
833 fuseProducerConsumerNodes(
834 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
835 eraseUnusedMemRefAllocations();
836 }
837
838 // Run the GreedyFusion pass.
839 // *) First pass through the nodes fuses single-use producer nodes into their
840 // unique consumer.
841 // *) Second pass fuses sibling nodes which share no dependence edges.
842 // *) Third pass fuses any remaining producer nodes into their users.
843 void runGreedyFusion() {
844 // TODO: Run this repeatedly until a fixed-point is reached.
845 fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
846 fuseSiblingNodes();
847 fuseProducerConsumerNodes(
848 /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
849 eraseUnusedMemRefAllocations();
850 }
851
852 /// Returns true if a private memref can be created for `memref` given
853 /// the fusion scenario reflected by the other arguments.
854 bool canCreatePrivateMemRef(Value memref,
855 const DenseSet<Value> &srcEscapingMemRefs,
856 unsigned producerId, unsigned consumerId,
857 bool removeSrcNode) {
858 // We can't generate private memrefs if their size can't be computed.
859 if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
860 return false;
861 const Node *consumerNode = mdg->getNode(id: consumerId);
862 // If `memref` is an escaping one, do not create a private memref
863 // for the below scenarios, since doing so will leave the escaping
864 // memref unmodified as all the writes originally meant for the
865 // escaping memref would be performed on the private memref:
866 // 1. The source is to be removed after fusion,
867 // OR
868 // 2. The destination writes to `memref`.
869 if (srcEscapingMemRefs.count(V: memref) > 0 &&
870 (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0))
871 return false;
872
873 // Don't create a private memref if 'srcNode' has in edges on
874 // 'memref' or 'dstNode' has out edges on 'memref'.
875 if (mdg->getIncomingMemRefAccesses(id: producerId, memref) > 0 ||
876 mdg->getOutEdgeCount(id: consumerId, memref) > 0)
877 return false;
878
879 // If 'srcNode' will be removed but it has out edges on 'memref' to
880 // nodes other than 'dstNode', we have to preserve dependences and
881 // cannot create a private memref.
882 if (removeSrcNode &&
883 any_of(Range&: mdg->outEdges[producerId], P: [&](const auto &edge) {
884 return edge.value == memref && edge.id != consumerId;
885 }))
886 return false;
887
888 return true;
889 }
890
891 /// Perform fusions with node `dstId` as the destination of fusion, with
892 /// No fusion is performed when producers with a user count greater than
893 /// `maxSrcUserCount` for any of the memrefs involved.
894 void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
895 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
896 // Skip if this node was removed (fused into another node).
897 if (mdg->nodes.count(Val: dstId) == 0)
898 return;
899 // Get 'dstNode' into which to attempt fusion.
900 auto *dstNode = mdg->getNode(id: dstId);
901 // Skip if 'dstNode' is not a loop nest.
902 if (!isa<AffineForOp>(Val: dstNode->op))
903 return;
904 // Skip if 'dstNode' is a loop nest returning values.
905 // TODO: support loop nests that return values.
906 if (dstNode->op->getNumResults() > 0)
907 return;
908
909 LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
910
911 // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
912 // while preserving relative order. This can increase the maximum loop
913 // depth at which we can fuse a slice of a producer loop nest into a
914 // consumer loop nest.
915 sinkSequentialLoops(node: dstNode);
916 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
917
918 // Try to fuse 'dstNode' with candidate producer loops until a fixed point
919 // is reached. Fusing two loops may expose new fusion opportunities.
920 bool dstNodeChanged;
921 do {
922 // Gather src loop candidates for 'dstNode' and visit them in "quasi"
923 // reverse program order to minimize the number of iterations needed to
924 // reach the fixed point. Note that this is a best effort approach since
925 // 'getProducerCandidates' does not always guarantee that program order
926 // in 'srcIdCandidates'.
927 dstNodeChanged = false;
928 SmallVector<unsigned, 16> srcIdCandidates;
929 getProducerCandidates(dstId, mdg: *mdg, srcIdCandidates);
930
931 for (unsigned srcId : llvm::reverse(C&: srcIdCandidates)) {
932 // Get 'srcNode' from which to attempt fusion into 'dstNode'.
933 auto *srcNode = mdg->getNode(id: srcId);
934 auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
935
936 LLVM_DEBUG(llvm::dbgs()
937 << "Trying to fuse producer loop nest " << srcId
938 << " with consumer loop nest " << dstId << "\n");
939 LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
940 << computeToleranceThreshold << '\n');
941 LLVM_DEBUG(llvm::dbgs()
942 << "Producer loop nest:\n"
943 << *srcNode->op << "\n and consumer loop nest:\n"
944 << *dstNode->op << '\n');
945
946 LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
947 << " for dst loop " << dstId << "\n");
948
949 // Skip if 'srcNode' is a loop nest returning values.
950 // TODO: support loop nests that return values.
951 if (isa<AffineForOp>(Val: srcNode->op) && srcNode->op->getNumResults() > 0)
952 continue;
953
954 DenseSet<Value> producerConsumerMemrefs;
955 gatherProducerConsumerMemrefs(srcId, dstId, mdg: *mdg,
956 producerConsumerMemrefs);
957
958 // Skip if 'srcNode' out edge count on any memref is greater than
959 // 'maxSrcUserCount'.
960 if (any_of(Range&: producerConsumerMemrefs, P: [&](Value memref) {
961 return mdg->getOutEdgeCount(id: srcNode->id, memref) >
962 maxSrcUserCount;
963 }))
964 continue;
965
966 // Gather memrefs in 'srcNode' that are written and escape out of the
967 // block (e.g., memref block arguments, returned memrefs,
968 // memrefs passed to function calls, etc.).
969 DenseSet<Value> srcEscapingMemRefs;
970 gatherEscapingMemrefs(id: srcNode->id, mdg: *mdg, escapingMemRefs&: srcEscapingMemRefs);
971
972 // Compute an operation list insertion point for the fused loop
973 // nest which preserves dependences.
974 Operation *fusedLoopInsPoint =
975 mdg->getFusedLoopNestInsertionPoint(srcId: srcNode->id, dstId: dstNode->id);
976 if (fusedLoopInsPoint == nullptr)
977 continue;
978
979 // It's possible this fusion is at an inner depth (i.e., there are
980 // common surrounding affine loops for the source and destination for
981 // ops). We need to get this number because the call to canFuseLoops
982 // needs to be passed the absolute depth. The max legal depth and the
983 // depths we try below are however *relative* and as such don't include
984 // the common depth.
985 SmallVector<AffineForOp, 4> surroundingLoops;
986 getAffineForIVs(*dstAffineForOp, &surroundingLoops);
987 unsigned numSurroundingLoops = surroundingLoops.size();
988
989 // Compute the innermost common loop depth for dstNode
990 // producer-consumer loads/stores.
991 SmallVector<Operation *, 2> dstMemrefOps;
992 for (Operation *op : dstNode->loads)
993 if (producerConsumerMemrefs.count(
994 cast<AffineReadOpInterface>(op).getMemRef()) > 0)
995 dstMemrefOps.push_back(Elt: op);
996 for (Operation *op : dstNode->stores)
997 if (producerConsumerMemrefs.count(
998 cast<AffineWriteOpInterface>(op).getMemRef()))
999 dstMemrefOps.push_back(Elt: op);
1000 unsigned dstLoopDepthTest =
1001 getInnermostCommonLoopDepth(ops: dstMemrefOps) - numSurroundingLoops;
1002
1003 // Check the feasibility of fusing src loop nest into dst loop nest
1004 // at loop depths in range [1, dstLoopDepthTest].
1005 unsigned maxLegalFusionDepth = 0;
1006 SmallVector<ComputationSliceState, 8> depthSliceUnions;
1007 depthSliceUnions.resize(N: dstLoopDepthTest);
1008 FusionStrategy strategy(FusionStrategy::ProducerConsumer);
1009 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1010 FusionResult result =
1011 affine::canFuseLoops(srcForOp: srcAffineForOp, dstForOp: dstAffineForOp,
1012 /*dstLoopDepth=*/i + numSurroundingLoops,
1013 srcSlice: &depthSliceUnions[i - 1], fusionStrategy: strategy);
1014 if (result.value == FusionResult::Success) {
1015 maxLegalFusionDepth = i;
1016 LLVM_DEBUG(llvm::dbgs()
1017 << "Found valid slice for depth: " << i << '\n');
1018 }
1019 }
1020
1021 if (maxLegalFusionDepth == 0) {
1022 LLVM_DEBUG(llvm::dbgs()
1023 << "Can't fuse: fusion is not legal at any depth\n");
1024 continue;
1025 }
1026
1027 LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1028 << maxLegalFusionDepth << '\n');
1029
1030 double computeToleranceThresholdToUse = computeToleranceThreshold;
1031
1032 // Cyclic dependences in the source nest may be violated when performing
1033 // slicing-based fusion. They aren't actually violated in cases where no
1034 // redundant execution of the source happens (1:1 pointwise dep on the
1035 // producer-consumer memref access for example). Check this and allow
1036 // fusion accordingly.
1037 if (hasCyclicDependence(srcAffineForOp)) {
1038 LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
1039 // Maximal fusion does not check for compute tolerance threshold; so
1040 // perform the maximal fusion only when the redundanation computation
1041 // is zero.
1042 if (maximalFusion) {
1043 auto srcForOp = cast<AffineForOp>(srcNode->op);
1044 auto dstForOp = cast<AffineForOp>(dstNode->op);
1045 int64_t sliceCost;
1046 int64_t fusedLoopNestComputeCost;
1047 auto fraction = getAdditionalComputeFraction(
1048 srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1049 sliceCost, fusedLoopNestComputeCost);
1050 if (!fraction || fraction > 0) {
1051 LLVM_DEBUG(
1052 llvm::dbgs()
1053 << "Can't perform maximal fusion with a cyclic dependence "
1054 "and non-zero additional compute.\n");
1055 return;
1056 }
1057 } else {
1058 // Set redundant computation tolerance to zero regardless of what
1059 // the user specified. Without this, fusion would be invalid.
1060 LLVM_DEBUG(llvm::dbgs()
1061 << "Setting compute tolerance to zero since "
1062 "source has a cylic dependence.\n");
1063 computeToleranceThresholdToUse = 0;
1064 }
1065 }
1066
1067 // Check if fusion would be profitable. We skip profitability analysis
1068 // for maximal fusion since we already know the maximal legal depth to
1069 // fuse.
1070 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1071 if (!maximalFusion) {
1072 // Retrieve producer stores from the src loop.
1073 SmallVector<Operation *, 2> producerStores;
1074 for (Operation *op : srcNode->stores)
1075 if (producerConsumerMemrefs.count(
1076 cast<AffineWriteOpInterface>(op).getMemRef()))
1077 producerStores.push_back(Elt: op);
1078
1079 assert(!producerStores.empty() && "Expected producer store");
1080 if (!isFusionProfitable(srcAffineForOp, producerStores,
1081 dstAffineForOp, depthSliceUnions,
1082 maxLegalFusionDepth, &bestDstLoopDepth,
1083 computeToleranceThresholdToUse)) {
1084 continue;
1085 }
1086 }
1087
1088 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1089 ComputationSliceState &bestSlice =
1090 depthSliceUnions[bestDstLoopDepth - 1];
1091 assert(!bestSlice.isEmpty() && "Missing slice union for depth");
1092
1093 // Determine if 'srcId' can be removed after fusion, taking into
1094 // account remaining dependences, escaping memrefs and the fusion
1095 // insertion point.
1096 bool removeSrcNode = canRemoveSrcNodeAfterFusion(
1097 srcId, dstId, fusionSlice: bestSlice, fusedLoopInsPoint, escapingMemRefs: srcEscapingMemRefs,
1098 mdg: *mdg);
1099
1100 DenseSet<Value> privateMemrefs;
1101 for (Value memref : producerConsumerMemrefs) {
1102 if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, producerId: srcId, consumerId: dstId,
1103 removeSrcNode)) {
1104 // Create a private version of this memref.
1105 LLVM_DEBUG(llvm::dbgs()
1106 << "Creating private memref for " << memref << '\n');
1107 // Create a private version of this memref.
1108 privateMemrefs.insert(V: memref);
1109 }
1110 }
1111
1112 // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
1113 fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
1114 dstNodeChanged = true;
1115
1116 LLVM_DEBUG(llvm::dbgs()
1117 << "Fused src loop " << srcId << " into dst loop " << dstId
1118 << " at depth " << bestDstLoopDepth << ":\n"
1119 << dstAffineForOp << "\n");
1120
1121 // Move 'dstAffineForOp' before 'insertPointInst' if needed.
1122 if (fusedLoopInsPoint != dstAffineForOp)
1123 dstAffineForOp->moveBefore(fusedLoopInsPoint);
1124
1125 // Update edges between 'srcNode' and 'dstNode'.
1126 mdg->updateEdges(srcId: srcNode->id, dstId: dstNode->id, privateMemRefs: privateMemrefs,
1127 removeSrcId: removeSrcNode);
1128
1129 // Create private memrefs.
1130 if (!privateMemrefs.empty()) {
1131 // Note the block into which fusion was performed. This can be used to
1132 // place `alloc`s that create private memrefs.
1133 Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
1134
1135 // Gather stores for all the private-to-be memrefs.
1136 DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
1137 dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
1138 Value storeMemRef = storeOp.getMemRef();
1139 if (privateMemrefs.count(V: storeMemRef) > 0)
1140 privateMemRefToStores[storeMemRef].push_back(Elt: storeOp);
1141 });
1142
1143 // Replace original memrefs with private memrefs. Note that all the
1144 // loads and stores on these memrefs will be replaced with a new
1145 // loads and stores. Any reference to the original ones becomes
1146 // invalid after this point.
1147 for (auto &memrefToStoresPair : privateMemRefToStores) {
1148 ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
1149 Value newMemRef = createPrivateMemRef(
1150 dstAffineForOp, storesForMemref, bestDstLoopDepth,
1151 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1152 if (!newMemRef)
1153 continue;
1154 // Create new node in dependence graph for 'newMemRef' alloc op.
1155 unsigned newMemRefNodeId = mdg->addNode(op: newMemRef.getDefiningOp());
1156 // Add edge from 'newMemRef' node to dstNode.
1157 mdg->addEdge(srcId: newMemRefNodeId, dstId, value: newMemRef);
1158 }
1159 // One or more entries for 'newMemRef' alloc op are inserted into
1160 // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
1161 // reallocate, update dstNode.
1162 dstNode = mdg->getNode(id: dstId);
1163 }
1164
1165 // Collect dst loop stats after memref privatization transformation.
1166 LoopNestStateCollector dstLoopCollector;
1167 dstLoopCollector.collect(opToWalk: dstAffineForOp);
1168
1169 // Clear and add back loads and stores.
1170 mdg->clearNodeLoadAndStores(id: dstNode->id);
1171 mdg->addToNode(
1172 id: dstId, loads: dstLoopCollector.loadOpInsts, stores: dstLoopCollector.storeOpInsts,
1173 memrefLoads: dstLoopCollector.memrefLoads, memrefStores: dstLoopCollector.memrefStores,
1174 memrefFrees: dstLoopCollector.memrefFrees);
1175
1176 if (removeSrcNode) {
1177 LLVM_DEBUG(llvm::dbgs()
1178 << "Removing src loop " << srcId << " after fusion\n");
1179 // srcNode is no longer valid after it is removed from mdg.
1180 srcAffineForOp.erase();
1181 mdg->removeNode(id: srcId);
1182 srcNode = nullptr;
1183 }
1184 }
1185 } while (dstNodeChanged);
1186 }
1187
1188 /// Visit each node in the graph, and for each node, attempt to fuse it with
1189 /// producer-consumer candidates. No fusion is performed when producers with a
1190 /// user count greater than `maxSrcUserCount` for any of the memrefs involved
1191 /// are encountered.
1192 void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
1193 LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
1194 init();
1195 while (!worklist.empty()) {
1196 unsigned dstId = worklist.back();
1197 worklist.pop_back();
1198 performFusionsIntoDest(dstId, maxSrcUserCount);
1199 }
1200 }
1201
1202 // Visits each node in the graph, and for each node, attempts to fuse it with
1203 // its sibling nodes (nodes which share a parent, but no dependence edges).
1204 void fuseSiblingNodes() {
1205 LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
1206 init();
1207 while (!worklist.empty()) {
1208 unsigned dstId = worklist.back();
1209 worklist.pop_back();
1210
1211 // Skip if this node was removed (fused into another node).
1212 if (mdg->nodes.count(Val: dstId) == 0)
1213 continue;
1214 // Get 'dstNode' into which to attempt fusion.
1215 auto *dstNode = mdg->getNode(id: dstId);
1216 // Skip if 'dstNode' is not a loop nest.
1217 if (!isa<AffineForOp>(Val: dstNode->op))
1218 continue;
1219 // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
1220 fuseWithSiblingNodes(dstNode);
1221 }
1222 }
1223
1224 // Attempt to fuse 'dstNode' with sibling nodes in the graph.
1225 void fuseWithSiblingNodes(Node *dstNode) {
1226 DenseSet<unsigned> visitedSibNodeIds;
1227 std::pair<unsigned, Value> idAndMemref;
1228 auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
1229
1230 while (findSiblingNodeToFuse(dstNode, visitedSibNodeIds: &visitedSibNodeIds, idAndMemrefToFuse: &idAndMemref)) {
1231 unsigned sibId = idAndMemref.first;
1232 Value memref = idAndMemref.second;
1233 // TODO: Check that 'sibStoreOpInst' post-dominates all other
1234 // stores to the same memref in 'sibNode' loop nest.
1235 auto *sibNode = mdg->getNode(id: sibId);
1236 // Compute an operation list insertion point for the fused loop
1237 // nest which preserves dependences.
1238 assert(sibNode->op->getBlock() == dstNode->op->getBlock());
1239 Operation *insertPointInst =
1240 sibNode->op->isBeforeInBlock(other: dstNode->op)
1241 ? mdg->getFusedLoopNestInsertionPoint(srcId: sibNode->id, dstId: dstNode->id)
1242 : mdg->getFusedLoopNestInsertionPoint(srcId: dstNode->id, dstId: sibNode->id);
1243 if (insertPointInst == nullptr)
1244 continue;
1245
1246 // Check if fusion would be profitable and at what depth.
1247
1248 // Get unique 'sibNode' load op to 'memref'.
1249 SmallVector<Operation *, 2> sibLoadOpInsts;
1250 sibNode->getLoadOpsForMemref(memref, loadOps: &sibLoadOpInsts);
1251 // Currently findSiblingNodeToFuse searches for siblings with one load.
1252 Operation *sibLoadOpInst = llvm::getSingleElement(C&: sibLoadOpInsts);
1253
1254 // Gather 'dstNode' load ops to 'memref'.
1255 SmallVector<Operation *, 2> dstLoadOpInsts;
1256 dstNode->getLoadOpsForMemref(memref, loadOps: &dstLoadOpInsts);
1257
1258 // It's possible this fusion is at an inner depth (i.e., there are common
1259 // surrounding affine loops for the source and destination for ops). We
1260 // need to get this number because the call to canFuseLoops needs to be
1261 // passed the absolute depth. The max legal depth and the depths we try
1262 // below are however *relative* and as such don't include the common
1263 // depth.
1264 SmallVector<AffineForOp, 4> surroundingLoops;
1265 getAffineForIVs(*dstAffineForOp, &surroundingLoops);
1266 unsigned numSurroundingLoops = surroundingLoops.size();
1267 SmallVector<AffineForOp, 4> dstLoopIVs;
1268 getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
1269 unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
1270 auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
1271
1272 // Compute loop depth and slice union for fusion.
1273 SmallVector<ComputationSliceState, 8> depthSliceUnions;
1274 depthSliceUnions.resize(N: dstLoopDepthTest);
1275 unsigned maxLegalFusionDepth = 0;
1276 FusionStrategy strategy(memref);
1277 for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
1278 FusionResult result =
1279 affine::canFuseLoops(srcForOp: sibAffineForOp, dstForOp: dstAffineForOp,
1280 /*dstLoopDepth=*/i + numSurroundingLoops,
1281 srcSlice: &depthSliceUnions[i - 1], fusionStrategy: strategy);
1282
1283 if (result.value == FusionResult::Success)
1284 maxLegalFusionDepth = i;
1285 }
1286
1287 LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
1288 << maxLegalFusionDepth << '\n');
1289
1290 // Skip if fusion is not feasible at any loop depths.
1291 if (maxLegalFusionDepth == 0)
1292 continue;
1293
1294 double computeToleranceThresholdToUse = computeToleranceThreshold;
1295
1296 // Cyclic dependences in the source nest may be violated when performing
1297 // slicing-based fusion. They aren't actually violated in cases where no
1298 // redundant execution of the source happens (1:1 pointwise dep on the
1299 // producer-consumer memref access for example). Check this and allow
1300 // fusion accordingly.
1301 if (hasCyclicDependence(sibAffineForOp)) {
1302 LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
1303 // Maximal fusion does not check for compute tolerance threshold; so
1304 // perform the maximal fusion only when the redundanation computation is
1305 // zero.
1306 if (maximalFusion) {
1307 auto dstForOp = cast<AffineForOp>(dstNode->op);
1308 int64_t sliceCost;
1309 int64_t fusedLoopNestComputeCost;
1310 auto fraction = getAdditionalComputeFraction(
1311 sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
1312 sliceCost, fusedLoopNestComputeCost);
1313 if (!fraction || fraction > 0) {
1314 LLVM_DEBUG(
1315 llvm::dbgs()
1316 << "Can't perform maximal fusion with a cyclic dependence "
1317 "and non-zero additional compute.\n");
1318 return;
1319 }
1320 } else {
1321 // Set redundant computation tolerance to zero regardless of what the
1322 // user specified. Without this, fusion would be invalid.
1323 LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since "
1324 "source has a cyclic dependence.\n");
1325 computeToleranceThresholdToUse = 0.0;
1326 }
1327 }
1328
1329 unsigned bestDstLoopDepth = maxLegalFusionDepth;
1330 if (!maximalFusion) {
1331 // Check if fusion would be profitable. For sibling fusion, the sibling
1332 // load op is treated as the src "store" op for fusion profitability
1333 // purposes. The footprint of the load in the slice relative to the
1334 // unfused source's determines reuse.
1335 if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
1336 depthSliceUnions, maxLegalFusionDepth,
1337 &bestDstLoopDepth,
1338 computeToleranceThresholdToUse))
1339 continue;
1340 }
1341
1342 assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1343
1344 const ComputationSliceState &bestSlice =
1345 depthSliceUnions[bestDstLoopDepth - 1];
1346 assert(!bestSlice.isEmpty() &&
1347 "Fusion depth has no computed slice union");
1348
1349 // Do not perform sibling fusion if it isn't maximal. We always remove the
1350 // sibling node and as such fusion shouldn't be performed if a part of the
1351 // slice is used in the destination.
1352 auto isMaximal = bestSlice.isMaximal();
1353 if (!isMaximal.value_or(u: false)) {
1354 LLVM_DEBUG(llvm::dbgs()
1355 << "Slice isn't maximal; not performing sibling fusion.\n");
1356 continue;
1357 }
1358
1359 // Check if source loop is being inserted in the innermost
1360 // destination loop. Based on this, the fused loop may be optimized
1361 // further inside `fuseLoops`.
1362 bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
1363 // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1364 affine::fuseLoops(srcForOp: sibAffineForOp, dstForOp: dstAffineForOp, srcSlice: bestSlice,
1365 isInnermostSiblingInsertionFusion: isInnermostInsertion);
1366
1367 auto dstForInst = cast<AffineForOp>(dstNode->op);
1368 // Update operation position of fused loop nest (if needed).
1369 if (insertPointInst != dstForInst)
1370 dstForInst->moveBefore(insertPointInst);
1371
1372 LLVM_DEBUG(llvm::dbgs()
1373 << "Fused sibling nest " << sibId << " into destination nest "
1374 << dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
1375 << dstAffineForOp << "\n");
1376
1377 // Update data dependence graph state post fusion.
1378 updateStateAfterSiblingFusion(sibNode, dstNode);
1379
1380 // Remove old sibling loop nest.
1381 // Get op before we invalidate the MDG node.
1382 Operation *op = sibNode->op;
1383 mdg->removeNode(id: sibNode->id);
1384 op->erase();
1385 }
1386 }
1387
1388 // Searches block argument uses and the graph from 'dstNode' looking for a
1389 // fusion candidate sibling node which shares no dependences with 'dstNode'
1390 // but which loads from the same memref. Returns true and sets
1391 // 'idAndMemrefToFuse' on success. Returns false otherwise.
1392 bool findSiblingNodeToFuse(Node *dstNode,
1393 DenseSet<unsigned> *visitedSibNodeIds,
1394 std::pair<unsigned, Value> *idAndMemrefToFuse) {
1395 // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
1396 // on 'memref'.
1397 auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
1398 // Skip if 'outEdge' is not a read-after-write dependence.
1399 // TODO: Remove restrict to single load op restriction.
1400 if (sibNode->getLoadOpCount(memref) != 1)
1401 return false;
1402 // Skip if there exists a path of dependent edges between
1403 // 'sibNode' and 'dstNode'.
1404 if (mdg->hasDependencePath(srcId: sibNode->id, dstId: dstNode->id) ||
1405 mdg->hasDependencePath(srcId: dstNode->id, dstId: sibNode->id))
1406 return false;
1407 // Skip sib node if it loads to (and stores from) the same memref on
1408 // which it also has an input dependence edge.
1409 DenseSet<Value> loadAndStoreMemrefSet;
1410 sibNode->getLoadAndStoreMemrefSet(loadAndStoreMemrefSet: &loadAndStoreMemrefSet);
1411 if (llvm::any_of(Range&: loadAndStoreMemrefSet, P: [=](Value memref) {
1412 return mdg->getIncomingMemRefAccesses(id: sibNode->id, memref) > 0;
1413 }))
1414 return false;
1415
1416 // Check that all stores are to the same memref if any.
1417 DenseSet<Value> storeMemrefs;
1418 for (auto *storeOpInst : sibNode->stores) {
1419 storeMemrefs.insert(
1420 cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
1421 }
1422 return storeMemrefs.size() <= 1;
1423 };
1424
1425 // Search for siblings which load the same memref block argument.
1426 Block *block = dstNode->op->getBlock();
1427 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
1428 for (Operation *user : block->getArgument(i).getUsers()) {
1429 auto loadOp = dyn_cast<AffineReadOpInterface>(user);
1430 if (!loadOp)
1431 continue;
1432 // Gather loops surrounding 'use'.
1433 SmallVector<AffineForOp, 4> loops;
1434 getAffineForIVs(*user, &loops);
1435 // Skip 'use' if it is not within a loop nest.
1436 // Find the surrounding affine.for nested immediately within the
1437 // block.
1438 auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
1439 return loop->getBlock() == &mdg->block;
1440 });
1441 // Skip 'use' if it is not within a loop nest in `block`.
1442 if (it == loops.end())
1443 continue;
1444 Node *sibNode = mdg->getForOpNode(*it);
1445 assert(sibNode != nullptr);
1446 // Skip 'use' if it not a sibling to 'dstNode'.
1447 if (sibNode->id == dstNode->id)
1448 continue;
1449 // Skip 'use' if it has been visited.
1450 if (visitedSibNodeIds->count(V: sibNode->id) > 0)
1451 continue;
1452 // Skip 'use' if it does not load from the same memref as 'dstNode'.
1453 auto memref = loadOp.getMemRef();
1454 if (dstNode->getLoadOpCount(memref: memref) == 0)
1455 continue;
1456 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1457 if (canFuseWithSibNode(sibNode, memref)) {
1458 visitedSibNodeIds->insert(V: sibNode->id);
1459 idAndMemrefToFuse->first = sibNode->id;
1460 idAndMemrefToFuse->second = memref;
1461 return true;
1462 }
1463 }
1464 }
1465
1466 // Search for siblings by following edges through an intermediate src node.
1467 // Collect candidate 'dstNode' input edges in 'inEdges'.
1468 SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
1469 mdg->forEachMemRefInputEdge(
1470 id: dstNode->id, callback: [&](MemRefDependenceGraph::Edge inEdge) {
1471 // Add 'inEdge' if it is a read-after-write dependence.
1472 if (dstNode->getLoadOpCount(memref: inEdge.value) > 0 &&
1473 mdg->getNode(id: inEdge.id)->getStoreOpCount(memref: inEdge.value) > 0)
1474 inEdges.push_back(Elt: inEdge);
1475 });
1476
1477 // Search for sibling nodes to fuse by visiting output edges from each input
1478 // edge in 'inEdges'.
1479 for (auto &inEdge : inEdges) {
1480 // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
1481 SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
1482 mdg->forEachMemRefOutputEdge(
1483 id: inEdge.id, callback: [&](MemRefDependenceGraph::Edge outEdge) {
1484 unsigned sibNodeId = outEdge.id;
1485 if (visitedSibNodeIds->count(V: sibNodeId) > 0)
1486 return;
1487 // Skip output edge if not a sibling using the same memref.
1488 if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
1489 return;
1490 auto *sibNode = mdg->getNode(id: sibNodeId);
1491 if (!isa<AffineForOp>(Val: sibNode->op))
1492 return;
1493 // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
1494 if (canFuseWithSibNode(sibNode, outEdge.value)) {
1495 // Add candidate 'outEdge' to sibling node.
1496 outEdges.push_back(Elt: outEdge);
1497 }
1498 });
1499
1500 // Add first candidate if any were returned.
1501 if (!outEdges.empty()) {
1502 visitedSibNodeIds->insert(V: outEdges[0].id);
1503 idAndMemrefToFuse->first = outEdges[0].id;
1504 idAndMemrefToFuse->second = outEdges[0].value;
1505 return true;
1506 }
1507 }
1508 return false;
1509 }
1510
1511 /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
1512 /// into 'dstNode'.
1513 void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
1514 // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
1515 mdg->updateEdges(sibId: sibNode->id, dstId: dstNode->id);
1516
1517 // Collect dst loop stats after memref privatization transformation.
1518 auto dstForInst = cast<AffineForOp>(dstNode->op);
1519 LoopNestStateCollector dstLoopCollector;
1520 dstLoopCollector.collect(opToWalk: dstForInst);
1521 // Clear and add back loads and stores
1522 mdg->clearNodeLoadAndStores(id: dstNode->id);
1523 mdg->addToNode(id: dstNode->id, loads: dstLoopCollector.loadOpInsts,
1524 stores: dstLoopCollector.storeOpInsts, memrefLoads: dstLoopCollector.memrefLoads,
1525 memrefStores: dstLoopCollector.memrefStores, memrefFrees: dstLoopCollector.memrefFrees);
1526 }
1527
1528 // Clean up any allocs with no users.
1529 void eraseUnusedMemRefAllocations() {
1530 for (auto &pair : mdg->memrefEdgeCount) {
1531 if (pair.second > 0)
1532 continue;
1533 auto memref = pair.first;
1534 // Skip if there exist other uses (return operation or function calls).
1535 if (!memref.use_empty())
1536 continue;
1537 // Use list expected to match the dep graph info.
1538 auto *op = memref.getDefiningOp();
1539 if (isa_and_nonnull<memref::AllocOp>(Val: op))
1540 op->erase();
1541 }
1542 }
1543};
1544
1545} // namespace
1546
1547/// Run fusion on `block`.
1548void LoopFusion::runOnBlock(Block *block) {
1549 MemRefDependenceGraph g(*block);
1550 if (!g.init()) {
1551 LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
1552 return;
1553 }
1554
1555 std::optional<unsigned> fastMemorySpaceOpt;
1556 if (fastMemorySpace.hasValue())
1557 fastMemorySpaceOpt = fastMemorySpace;
1558 unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024;
1559 GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt,
1560 maximalFusion, computeToleranceThreshold);
1561
1562 if (affineFusionMode == FusionMode::ProducerConsumer)
1563 fusion.runProducerConsumerFusionOnly();
1564 else if (affineFusionMode == FusionMode::Sibling)
1565 fusion.runSiblingFusionOnly();
1566 else
1567 fusion.runGreedyFusion();
1568}
1569
1570void LoopFusion::runOnOperation() {
1571 // Call fusion on every op that has at least two affine.for nests (in post
1572 // order).
1573 getOperation()->walk([&](Operation *op) {
1574 for (Region &region : op->getRegions()) {
1575 for (Block &block : region.getBlocks()) {
1576 auto affineFors = block.getOps<AffineForOp>();
1577 if (!affineFors.empty() && !llvm::hasSingleElement(C&: affineFors))
1578 runOnBlock(block: &block);
1579 }
1580 }
1581 });
1582}
1583
1584std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
1585 unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
1586 bool maximalFusion, enum FusionMode affineFusionMode) {
1587 return std::make_unique<LoopFusion>(args&: fastMemorySpace, args&: localBufSizeThreshold,
1588 args&: maximalFusion, args&: affineFusionMode);
1589}
1590

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp