1//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop fusion transformation utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/LoopFusionUtils.h"
14#include "mlir/Analysis/SliceAnalysis.h"
15#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/Utils.h"
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/Affine/LoopUtils.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Operation.h"
22#include "mlir/IR/PatternMatch.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/raw_ostream.h"
25#include <optional>
26
27#define DEBUG_TYPE "loop-fusion-utils"
28
29using namespace mlir;
30using namespace mlir::affine;
31
32// Gathers all load and store memref accesses in 'opA' into 'values', where
33// 'values[memref] == true' for each store operation.
34static void getLoadAndStoreMemRefAccesses(Operation *opA,
35 DenseMap<Value, bool> &values) {
36 opA->walk(callback: [&](Operation *op) {
37 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
38 if (values.count(Val: loadOp.getMemRef()) == 0)
39 values[loadOp.getMemRef()] = false;
40 } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
41 values[storeOp.getMemRef()] = true;
42 }
43 });
44}
45
46/// Returns true if 'op' is a load or store operation which access a memref
47/// accessed 'values' and at least one of the access is a store operation.
48/// Returns false otherwise.
49static bool isDependentLoadOrStoreOp(Operation *op,
50 DenseMap<Value, bool> &values) {
51 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
52 return values.count(Val: loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
53 }
54 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
55 return values.count(Val: storeOp.getMemRef()) > 0;
56 }
57 return false;
58}
59
60// Returns the first operation in range ('opA', 'opB') which has a data
61// dependence on 'opA'. Returns 'nullptr' of no dependence exists.
62static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
63 // Record memref values from all loads/store in loop nest rooted at 'opA'.
64 // Map from memref value to bool which is true if store, false otherwise.
65 DenseMap<Value, bool> values;
66 getLoadAndStoreMemRefAccesses(opA, values);
67
68 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
69 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
70 // and at least one of the accesses is a store).
71 Operation *firstDepOp = nullptr;
72 for (Block::iterator it = std::next(x: Block::iterator(opA));
73 it != Block::iterator(opB); ++it) {
74 Operation *opX = &(*it);
75 opX->walk(callback: [&](Operation *op) {
76 if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
77 firstDepOp = opX;
78 });
79 if (firstDepOp)
80 break;
81 }
82 return firstDepOp;
83}
84
85// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
86// exists a data dependence from 'opX' to 'opB'.
87// Returns 'nullptr' of no dependence exists.
88static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
89 // Record memref values from all loads/store in loop nest rooted at 'opB'.
90 // Map from memref value to bool which is true if store, false otherwise.
91 DenseMap<Value, bool> values;
92 getLoadAndStoreMemRefAccesses(opA: opB, values);
93
94 // For each 'opX' in block in range ('opA', 'opB') in reverse order,
95 // check if there is a data dependence from 'opX' to 'opB':
96 // *) 'opX' and 'opB' access the same memref and at least one of the accesses
97 // is a store.
98 // *) 'opX' produces an SSA Value which is used by 'opB'.
99 Operation *lastDepOp = nullptr;
100 for (Block::reverse_iterator it = std::next(x: Block::reverse_iterator(opB));
101 it != Block::reverse_iterator(opA); ++it) {
102 Operation *opX = &(*it);
103 opX->walk(callback: [&](Operation *op) {
104 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
105 if (isDependentLoadOrStoreOp(op, values)) {
106 lastDepOp = opX;
107 return WalkResult::interrupt();
108 }
109 return WalkResult::advance();
110 }
111 for (Value value : op->getResults()) {
112 for (Operation *user : value.getUsers()) {
113 SmallVector<AffineForOp, 4> loops;
114 // Check if any loop in loop nest surrounding 'user' is 'opB'.
115 getAffineForIVs(*user, &loops);
116 if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
117 lastDepOp = opX;
118 return WalkResult::interrupt();
119 }
120 }
121 }
122 return WalkResult::advance();
123 });
124 if (lastDepOp)
125 break;
126 }
127 return lastDepOp;
128}
129
130// Computes and returns an insertion point operation, before which the
131// the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
132// dependences. Returns nullptr if no such insertion point is found.
133static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
134 AffineForOp dstForOp) {
135 bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
136 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
137 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
138
139 Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
140 Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
141 // Block:
142 // ...
143 // |-- opA
144 // | ...
145 // | lastDepOpB --|
146 // | ... |
147 // |-> firstDepOpA |
148 // ... |
149 // opB <---------
150 //
151 // Valid insertion point range: (lastDepOpB, firstDepOpA)
152 //
153 if (firstDepOpA) {
154 if (lastDepOpB) {
155 if (firstDepOpA->isBeforeInBlock(other: lastDepOpB) || firstDepOpA == lastDepOpB)
156 // No valid insertion point exists which preserves dependences.
157 return nullptr;
158 }
159 // Return insertion point in valid range closest to 'opB'.
160 // TODO: Consider other insertion points in valid range.
161 return firstDepOpA;
162 }
163 // No dependences from 'opA' to operation in range ('opA', 'opB'), return
164 // 'opB' insertion point.
165 return forOpB;
166}
167
168// Gathers all load and store ops in loop nest rooted at 'forOp' into
169// 'loadAndStoreOps'.
170static bool
171gatherLoadsAndStores(AffineForOp forOp,
172 SmallVectorImpl<Operation *> &loadAndStoreOps) {
173 bool hasIfOp = false;
174 forOp.walk([&](Operation *op) {
175 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
176 loadAndStoreOps.push_back(Elt: op);
177 else if (isa<AffineIfOp>(Val: op))
178 hasIfOp = true;
179 });
180 return !hasIfOp;
181}
182
183/// Returns the maximum loop depth at which we could fuse producer loop
184/// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences.
185// TODO: Generalize this check for sibling and more generic fusion scenarios.
186// TODO: Support forward slice fusion.
187static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
188 ArrayRef<Operation *> dstOps) {
189 if (dstOps.empty())
190 // Expected at least one memory operation.
191 // TODO: Revisit this case with a specific example.
192 return 0;
193
194 // Filter out ops in 'dstOps' that do not use the producer-consumer memref so
195 // that they are not considered for analysis.
196 DenseSet<Value> producerConsumerMemrefs;
197 gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
198 SmallVector<Operation *, 4> targetDstOps;
199 for (Operation *dstOp : dstOps) {
200 auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
201 Value memref = loadOp ? loadOp.getMemRef()
202 : cast<AffineWriteOpInterface>(dstOp).getMemRef();
203 if (producerConsumerMemrefs.count(V: memref) > 0)
204 targetDstOps.push_back(Elt: dstOp);
205 }
206
207 assert(!targetDstOps.empty() &&
208 "No dependences between 'srcForOp' and 'dstForOp'?");
209
210 // Compute the innermost common loop depth for loads and stores.
211 unsigned loopDepth = getInnermostCommonLoopDepth(ops: targetDstOps);
212
213 // Return common loop depth for loads if there are no store ops.
214 if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
215 return loopDepth;
216
217 // Check dependences on all pairs of ops in 'targetDstOps' and store the
218 // minimum loop depth at which a dependence is satisfied.
219 for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
220 Operation *srcOpInst = targetDstOps[i];
221 MemRefAccess srcAccess(srcOpInst);
222 for (unsigned j = 0; j < e; ++j) {
223 auto *dstOpInst = targetDstOps[j];
224 MemRefAccess dstAccess(dstOpInst);
225
226 unsigned numCommonLoops =
227 getNumCommonSurroundingLoops(a&: *srcOpInst, b&: *dstOpInst);
228 for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
229 // TODO: Cache dependence analysis results, check cache here.
230 DependenceResult result =
231 checkMemrefAccessDependence(srcAccess, dstAccess, loopDepth: d);
232 if (hasDependence(result)) {
233 // Store minimum loop depth and break because we want the min 'd' at
234 // which there is a dependence.
235 loopDepth = std::min(a: loopDepth, b: d - 1);
236 break;
237 }
238 }
239 }
240 }
241
242 return loopDepth;
243}
244
245// TODO: This pass performs some computation that is the same for all the depths
246// (e.g., getMaxLoopDepth). Implement a version of this utility that processes
247// all the depths at once or only the legal maximal depth for maximal fusion.
248FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
249 AffineForOp dstForOp,
250 unsigned dstLoopDepth,
251 ComputationSliceState *srcSlice,
252 FusionStrategy fusionStrategy) {
253 // Return 'failure' if 'dstLoopDepth == 0'.
254 if (dstLoopDepth == 0) {
255 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
256 return FusionResult::FailPrecondition;
257 }
258 // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
259 auto *block = srcForOp->getBlock();
260 if (block != dstForOp->getBlock()) {
261 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
262 return FusionResult::FailPrecondition;
263 }
264
265 // Return 'failure' if no valid insertion point for fused loop nest in 'block'
266 // exists which would preserve dependences.
267 if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
268 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
269 return FusionResult::FailBlockDependence;
270 }
271
272 // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
273 bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
274 // 'forOpA' executes before 'forOpB' in 'block'.
275 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
276 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
277
278 // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
279 SmallVector<Operation *, 4> opsA;
280 if (!gatherLoadsAndStores(forOpA, opsA)) {
281 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
282 return FusionResult::FailPrecondition;
283 }
284
285 // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
286 SmallVector<Operation *, 4> opsB;
287 if (!gatherLoadsAndStores(forOpB, opsB)) {
288 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
289 return FusionResult::FailPrecondition;
290 }
291
292 // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
293 // loop dependences.
294 // TODO: Enable this check for sibling and more generic loop fusion
295 // strategies.
296 if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
297 // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
298 assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
299 if (getMaxLoopDepth(srcOps: opsA, dstOps: opsB) < dstLoopDepth) {
300 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
301 return FusionResult::FailFusionDependence;
302 }
303 }
304
305 // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
306 unsigned numCommonLoops =
307 affine::getNumCommonSurroundingLoops(a&: *srcForOp, b&: *dstForOp);
308
309 // Filter out ops in 'opsA' to compute the slice union based on the
310 // assumptions made by the fusion strategy.
311 SmallVector<Operation *, 4> strategyOpsA;
312 switch (fusionStrategy.getStrategy()) {
313 case FusionStrategy::Generic:
314 // Generic fusion. Take into account all the memory operations to compute
315 // the slice union.
316 strategyOpsA.append(in_start: opsA.begin(), in_end: opsA.end());
317 break;
318 case FusionStrategy::ProducerConsumer:
319 // Producer-consumer fusion (AffineLoopFusion pass) only takes into
320 // account stores in 'srcForOp' to compute the slice union.
321 for (Operation *op : opsA) {
322 if (isa<AffineWriteOpInterface>(op))
323 strategyOpsA.push_back(Elt: op);
324 }
325 break;
326 case FusionStrategy::Sibling:
327 // Sibling fusion (AffineLoopFusion pass) only takes into account the loads
328 // to 'memref' in 'srcForOp' to compute the slice union.
329 for (Operation *op : opsA) {
330 auto load = dyn_cast<AffineReadOpInterface>(op);
331 if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
332 strategyOpsA.push_back(Elt: op);
333 }
334 break;
335 }
336
337 // Compute union of computation slices computed between all pairs of ops
338 // from 'forOpA' and 'forOpB'.
339 SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
340 opsA: strategyOpsA, opsB, loopDepth: dstLoopDepth, numCommonLoops,
341 isBackwardSlice: isSrcForOpBeforeDstForOp, sliceUnion: srcSlice);
342 if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
343 LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
344 return FusionResult::FailPrecondition;
345 }
346 if (sliceComputationResult.value ==
347 SliceComputationResult::IncorrectSliceFailure) {
348 LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
349 return FusionResult::FailIncorrectSlice;
350 }
351
352 return FusionResult::Success;
353}
354
355/// Patch the loop body of a forOp that is a single iteration reduction loop
356/// into its containing block.
357static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
358 bool siblingFusionUser) {
359 // Check if the reduction loop is a single iteration loop.
360 std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
361 if (!tripCount || *tripCount != 1)
362 return failure();
363 auto *parentOp = forOp->getParentOp();
364 if (!isa<AffineForOp>(parentOp))
365 return failure();
366 SmallVector<Value> newOperands;
367 llvm::append_range(newOperands,
368 forOp.getBody()->getTerminator()->getOperands());
369 IRRewriter rewriter(parentOp->getContext());
370 int64_t parentOpNumResults = parentOp->getNumResults();
371 // Replace the parent loop and add iteroperands and results from the `forOp`.
372 AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
373 AffineForOp newLoop =
374 cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
375 rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
376 [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
377 return newOperands;
378 }));
379
380 // For sibling-fusion users, collect operations that use the results of the
381 // `forOp` outside the new parent loop that has absorbed all its iter args
382 // and operands. These operations will be moved later after the results
383 // have been replaced.
384 SetVector<Operation *> forwardSlice;
385 if (siblingFusionUser) {
386 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
387 SetVector<Operation *> tmpForwardSlice;
388 getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
389 forwardSlice.set_union(tmpForwardSlice);
390 }
391 }
392 // Update the results of the `forOp` in the new loop.
393 for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
394 forOp.getResult(i).replaceAllUsesWith(
395 newLoop.getResult(i + parentOpNumResults));
396 }
397 // For sibling-fusion users, move operations that use the results of the
398 // `forOp` outside the new parent loop
399 if (siblingFusionUser) {
400 topologicalSort(toSort: forwardSlice);
401 for (Operation *op : llvm::reverse(C&: forwardSlice))
402 op->moveAfter(newLoop);
403 }
404 // Replace the induction variable.
405 auto iv = forOp.getInductionVar();
406 iv.replaceAllUsesWith(newLoop.getInductionVar());
407 // Replace the iter args.
408 auto forOpIterArgs = forOp.getRegionIterArgs();
409 for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
410 forOpIterArgs.size()))) {
411 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
412 }
413 // Move the loop body operations, except for its terminator, to the loop's
414 // containing block.
415 forOp.getBody()->back().erase();
416 auto *parentBlock = forOp->getBlock();
417 parentBlock->getOperations().splice(Block::iterator(forOp),
418 forOp.getBody()->getOperations());
419 forOp.erase();
420 return success();
421}
422
423/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
424/// and source slice loop bounds specified in 'srcSlice'.
425void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
426 const ComputationSliceState &srcSlice,
427 bool isInnermostSiblingInsertion) {
428 // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
429 OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
430 IRMapping mapper;
431 b.clone(*srcForOp, mapper);
432
433 // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'.
434 SmallVector<AffineForOp, 4> sliceLoops;
435 for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
436 auto loopIV = mapper.lookupOrNull(from: srcSlice.ivs[i]);
437 if (!loopIV)
438 continue;
439 auto forOp = getForInductionVarOwner(loopIV);
440 sliceLoops.push_back(forOp);
441 if (AffineMap lbMap = srcSlice.lbs[i]) {
442 auto lbOperands = srcSlice.lbOperands[i];
443 canonicalizeMapAndOperands(map: &lbMap, operands: &lbOperands);
444 forOp.setLowerBound(lbOperands, lbMap);
445 }
446 if (AffineMap ubMap = srcSlice.ubs[i]) {
447 auto ubOperands = srcSlice.ubOperands[i];
448 canonicalizeMapAndOperands(map: &ubMap, operands: &ubOperands);
449 forOp.setUpperBound(ubOperands, ubMap);
450 }
451 }
452
453 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
454 auto srcIsUnitSlice = [&]() {
455 return (buildSliceTripCountMap(slice: srcSlice, tripCountMap: &sliceTripCountMap) &&
456 (getSliceIterationCount(sliceTripCountMap) == 1));
457 };
458 // Fix up and if possible, eliminate single iteration loops.
459 for (AffineForOp forOp : sliceLoops) {
460 if (isLoopParallelAndContainsReduction(forOp) &&
461 isInnermostSiblingInsertion && srcIsUnitSlice())
462 // Patch reduction loop - only ones that are sibling-fused with the
463 // destination loop - into the parent loop.
464 (void)promoteSingleIterReductionLoop(forOp, true);
465 else
466 // Promote any single iteration slice loops.
467 (void)promoteIfSingleIteration(forOp);
468 }
469}
470
471/// Collect loop nest statistics (eg. loop trip count and operation count)
472/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
473/// returns false otherwise.
474bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
475 LoopNestStats *stats) {
476 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
477 auto *childForOp = forOp.getOperation();
478 auto *parentForOp = forOp->getParentOp();
479 if (forOp != forOpRoot) {
480 if (!isa<AffineForOp>(parentForOp)) {
481 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
482 return WalkResult::interrupt();
483 }
484 // Add mapping to 'forOp' from its parent AffineForOp.
485 stats->loopMap[parentForOp].push_back(forOp);
486 }
487
488 // Record the number of op operations in the body of 'forOp'.
489 unsigned count = 0;
490 stats->opCountMap[childForOp] = 0;
491 for (auto &op : *forOp.getBody()) {
492 if (!isa<AffineForOp, AffineIfOp>(op))
493 ++count;
494 }
495 stats->opCountMap[childForOp] = count;
496
497 // Record trip count for 'forOp'. Set flag if trip count is not
498 // constant.
499 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
500 if (!maybeConstTripCount) {
501 // Currently only constant trip count loop nests are supported.
502 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
503 return WalkResult::interrupt();
504 }
505
506 stats->tripCountMap[childForOp] = *maybeConstTripCount;
507 return WalkResult::advance();
508 });
509 return !walkResult.wasInterrupted();
510}
511
512// Computes the total cost of the loop nest rooted at 'forOp'.
513// Currently, the total cost is computed by counting the total operation
514// instance count (i.e. total number of operations in the loop bodyloop
515// operation count * loop trip count) for the entire loop nest.
516// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
517// specified in the map when computing the total op instance count.
518// NOTEs: 1) This is used to compute the cost of computation slices, which are
519// sliced along the iteration dimension, and thus reduce the trip count.
520// If 'computeCostMap' is non-null, the total op count for forOps specified
521// in the map is increased (not overridden) by adding the op count from the
522// map to the existing op count for the for loop. This is done before
523// multiplying by the loop's trip count, and is used to model the cost of
524// inserting a sliced loop nest of known cost into the loop's body.
525// 2) This is also used to compute the cost of fusing a slice of some loop nest
526// within another loop.
527static int64_t getComputeCostHelper(
528 Operation *forOp, LoopNestStats &stats,
529 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
530 DenseMap<Operation *, int64_t> *computeCostMap) {
531 // 'opCount' is the total number operations in one iteration of 'forOp' body,
532 // minus terminator op which is a no-op.
533 int64_t opCount = stats.opCountMap[forOp] - 1;
534 if (stats.loopMap.count(Val: forOp) > 0) {
535 for (auto childForOp : stats.loopMap[forOp]) {
536 opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
537 computeCostMap);
538 }
539 }
540 // Add in additional op instances from slice (if specified in map).
541 if (computeCostMap) {
542 auto it = computeCostMap->find(Val: forOp);
543 if (it != computeCostMap->end()) {
544 opCount += it->second;
545 }
546 }
547 // Override trip count (if specified in map).
548 int64_t tripCount = stats.tripCountMap[forOp];
549 if (tripCountOverrideMap) {
550 auto it = tripCountOverrideMap->find(Val: forOp);
551 if (it != tripCountOverrideMap->end()) {
552 tripCount = it->second;
553 }
554 }
555 // Returns the total number of dynamic instances of operations in loop body.
556 return tripCount * opCount;
557}
558
559/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
560/// Currently, the total cost is computed by counting the total operation
561/// instance count (i.e. total number of operations in the loop body * loop
562/// trip count) for the entire loop nest.
563int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
564 return getComputeCostHelper(forOp, stats,
565 /*tripCountOverrideMap=*/nullptr,
566 /*computeCostMap=*/nullptr);
567}
568
569/// Computes and returns in 'computeCost', the total compute cost of fusing the
570/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
571/// the total cost is computed by counting the total operation instance count
572/// (i.e. total number of operations in the loop body * loop trip count) for
573/// the entire loop nest.
574bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
575 LoopNestStats &srcStats,
576 AffineForOp dstForOp,
577 LoopNestStats &dstStats,
578 const ComputationSliceState &slice,
579 int64_t *computeCost) {
580 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
581 DenseMap<Operation *, int64_t> computeCostMap;
582
583 // Build trip count map for computation slice.
584 if (!buildSliceTripCountMap(slice, tripCountMap: &sliceTripCountMap))
585 return false;
586 // Checks whether a store to load forwarding will happen.
587 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
588 assert(sliceIterationCount > 0);
589 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
590 auto *insertPointParent = slice.insertPoint->getParentOp();
591
592 // The store and loads to this memref will disappear.
593 if (storeLoadFwdGuaranteed) {
594 // Subtract from operation count the loads/store we expect load/store
595 // forwarding to remove.
596 unsigned storeCount = 0;
597 llvm::SmallDenseSet<Value, 4> storeMemrefs;
598 srcForOp.walk([&](AffineWriteOpInterface storeOp) {
599 storeMemrefs.insert(storeOp.getMemRef());
600 ++storeCount;
601 });
602 // Subtract out any store ops in single-iteration src slice loop nest.
603 if (storeCount > 0)
604 computeCostMap[insertPointParent] = -storeCount;
605 // Subtract out any load users of 'storeMemrefs' nested below
606 // 'insertPointParent'.
607 for (Value memref : storeMemrefs) {
608 for (Operation *user : memref.getUsers()) {
609 if (!isa<AffineReadOpInterface>(user))
610 continue;
611 SmallVector<AffineForOp, 4> loops;
612 // Check if any loop in loop nest surrounding 'user' is
613 // 'insertPointParent'.
614 getAffineForIVs(*user, &loops);
615 if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
616 if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
617 if (computeCostMap.count(Val: forOp) == 0)
618 computeCostMap[forOp] = 0;
619 computeCostMap[forOp] -= 1;
620 }
621 }
622 }
623 }
624 }
625
626 // Compute op instance count for the src loop nest with iteration slicing.
627 int64_t sliceComputeCost = getComputeCostHelper(
628 srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);
629
630 // Compute cost of fusion for this depth.
631 computeCostMap[insertPointParent] = sliceComputeCost;
632
633 *computeCost =
634 getComputeCostHelper(dstForOp, dstStats,
635 /*tripCountOverrideMap=*/nullptr, &computeCostMap);
636 return true;
637}
638
639/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
640/// producer-consumer dependence between write ops in 'srcOps' and read ops in
641/// 'dstOps'.
642void mlir::affine::gatherProducerConsumerMemrefs(
643 ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
644 DenseSet<Value> &producerConsumerMemrefs) {
645 // Gather memrefs from stores in 'srcOps'.
646 DenseSet<Value> srcStoreMemRefs;
647 for (Operation *op : srcOps)
648 if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
649 srcStoreMemRefs.insert(storeOp.getMemRef());
650
651 // Compute the intersection between memrefs from stores in 'srcOps' and
652 // memrefs from loads in 'dstOps'.
653 for (Operation *op : dstOps)
654 if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
655 if (srcStoreMemRefs.count(V: loadOp.getMemRef()) > 0)
656 producerConsumerMemrefs.insert(loadOp.getMemRef());
657}
658

source code of mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp