//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "affine-fusion-utils"

using namespace mlir;
using namespace mlir::affine;

// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      if (values.count(loadOp.getMemRef()) == 0)
        values[loadOp.getMemRef()] = false;
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      values[storeOp.getMemRef()] = true;
    }
  });
}
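// Illustrative example (not part of the original source), with hypothetical
// memrefs %A and %B: after walking the nest below, 'values' would contain
// {%A -> true, %B -> false}, since %A is stored to while %B is only loaded.
//
//   affine.for %i = 0 to 10 {
//     %v = affine.load %B[%i] : memref<10xf32>
//     affine.store %v, %A[%i] : memref<10xf32>
//   }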

/// Returns true if 'op' is a load or store operation that accesses a memref
/// recorded in 'values' and at least one of the accesses to that memref is a
/// store. Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
    return values.count(loadOp.getMemRef()) > 0 && values[loadOp.getMemRef()];
  if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
    return values.count(storeOp.getMemRef()) > 0;
  return false;
}

// Returns the first operation in range ('opA', 'opB') which has a data
// dependence on 'opA'. Returns 'nullptr' if no dependence exists.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opA'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
  // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
  // and at least one of the accesses is a store).
  Operation *firstDepOp = nullptr;
  for (Block::iterator it = std::next(Block::iterator(opA));
       it != Block::iterator(opB); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
        firstDepOp = opX;
    });
    if (firstDepOp)
      break;
  }
  return firstDepOp;
}

// Returns the last operation 'opX' in range ('opA', 'opB'), for which there
// exists a data dependence from 'opX' to 'opB'.
// Returns 'nullptr' if no dependence exists.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // Record memref values from all loads/stores in the loop nest rooted at
  // 'opB'. Map from memref value to bool which is true if store, false
  // otherwise.
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // For each 'opX' in block in range ('opA', 'opB') in reverse order,
  // check if there is a data dependence from 'opX' to 'opB':
  // *) 'opX' and 'opB' access the same memref and at least one of the accesses
  //    is a store.
  // *) 'opX' produces an SSA Value which is used by 'opB'.
  Operation *lastDepOp = nullptr;
  for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
       it != Block::reverse_iterator(opA); ++it) {
    Operation *opX = &(*it);
    opX->walk([&](Operation *op) {
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        if (isDependentLoadOrStoreOp(op, values)) {
          lastDepOp = opX;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      }
      for (Value value : op->getResults()) {
        for (Operation *user : value.getUsers()) {
          SmallVector<AffineForOp, 4> loops;
          // Check if any loop in loop nest surrounding 'user' is 'opB'.
          getAffineForIVs(*user, &loops);
          if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
            lastDepOp = opX;
            return WalkResult::interrupt();
          }
        }
      }
      return WalkResult::advance();
    });
    if (lastDepOp)
      break;
  }
  return lastDepOp;
}

// Computes and returns an insertion point operation, before which the
// fused <srcForOp, dstForOp> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
  // Block:
  //      ...
  //   |-- opA
  //   |   ...
  //   |   lastDepOpB --|
  //   |   ...          |
  //   |-> firstDepOpA  |
  //       ...          |
  //       opB <---------
  //
  // Valid insertion point range: (lastDepOpB, firstDepOpA)
  //
  if (firstDepOpA) {
    if (lastDepOpB) {
      if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
        // No valid insertion point exists which preserves dependences.
        return nullptr;
    }
    // Return insertion point in valid range closest to 'opB'.
    // TODO: Consider other insertion points in valid range.
    return firstDepOpA;
  }
  // No dependences from 'opA' to operation in range ('opA', 'opB'), return
  // 'opB' insertion point.
  return forOpB;
}

// Gathers all load and store ops in loop nest rooted at 'forOp' into
// 'loadAndStoreOps'. Returns false if the nest contains an 'affine.if' op
// (currently unsupported for fusion), true otherwise.
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool hasIfOp = false;
  forOp.walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
    else if (isa<AffineIfOp>(op))
      hasIfOp = true;
  });
  return !hasIfOp;
}

/// Returns the maximum loop depth at which we could fuse the producer loop
/// nest (whose memory ops are given in 'srcOps') into the consumer loop nest
/// (whose memory ops are given in 'dstOps') without violating data
/// dependences.
// TODO: Generalize this check for sibling and more generic fusion scenarios.
// TODO: Support forward slice fusion.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    // Expected at least one memory operation.
    // TODO: Revisit this case with a specific example.
    return 0;

  // Filter out ops in 'dstOps' that do not use a producer-consumer memref so
  // that they are not considered for analysis.
  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp);
    Value memref = loadOp ? loadOp.getMemRef()
                          : cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(memref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Compute the innermost common loop depth for loads and stores.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // Return common loop depth for loads if there are no store ops.
  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
    return loopDepth;

  // Check dependences on all pairs of ops in 'targetDstOps' and store the
  // minimum loop depth at which a dependence is satisfied.
  for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) {
    Operation *srcOpInst = targetDstOps[i];
    MemRefAccess srcAccess(srcOpInst);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = targetDstOps[j];
      MemRefAccess dstAccess(dstOpInst);

      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        // TODO: Cache dependence analysis results, check cache here.
        DependenceResult result =
            checkMemrefAccessDependence(srcAccess, dstAccess, d);
        if (hasDependence(result)) {
          // Store the minimum loop depth and break because we want the min
          // 'd' at which there is a dependence.
          loopDepth = std::min(loopDepth, d - 1);
          break;
        }
      }
    }
  }

  return loopDepth;
}
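// Illustrative example (not part of the original source), with a hypothetical
// producer nest (not shown) storing to %A: both consumer ops below are loads
// from %A, so getMaxLoopDepth would return their innermost common loop depth,
// 2, allowing fusion at depth 1 or 2.
//
//   affine.for %i = 0 to 16 {
//     affine.for %j = 0 to 16 {
//       %0 = affine.load %A[%i] : memref<16xf32>
//       %1 = affine.load %A[%j] : memref<16xf32>
//     }
//   }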

// TODO: This utility performs some computation that is the same for all the
// depths (e.g., getMaxLoopDepth). Implement a version of this utility that
// processes all the depths at once or only the legal maximal depth for
// maximal fusion.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  // Return 'failure' if 'dstLoopDepth == 0'.
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }
  // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
  auto *block = srcForOp->getBlock();
  if (block != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if no valid insertion point for the fused loop nest
  // exists in 'block' which would preserve dependences.
  if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
  // 'forOpA' executes before 'forOpB' in 'block'.
  auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
  auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;

  // Gather all loads and stores from 'forOpA', which precedes 'forOpB' in
  // 'block'.
  SmallVector<Operation *, 4> opsA;
  if (!gatherLoadsAndStores(forOpA, opsA)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Gather all loads and stores from 'forOpB', which follows 'forOpA' in
  // 'block'.
  SmallVector<Operation *, 4> opsB;
  if (!gatherLoadsAndStores(forOpB, opsB)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve
  // loop dependences.
  // TODO: Enable this check for sibling and more generic loop fusion
  // strategies.
  if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) {
    // TODO: 'getMaxLoopDepth' does not support forward slice fusion.
    assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion");
    if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  // Calculate the number of common loops surrounding 'srcForOp' and
  // 'dstForOp'.
  unsigned numCommonLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Filter out ops in 'opsA' to compute the slice union based on the
  // assumptions made by the fusion strategy.
  SmallVector<Operation *, 4> strategyOpsA;
  switch (fusionStrategy.getStrategy()) {
  case FusionStrategy::Generic:
    // Generic fusion. Take into account all the memory operations to compute
    // the slice union.
    strategyOpsA.append(opsA.begin(), opsA.end());
    break;
  case FusionStrategy::ProducerConsumer:
    // Producer-consumer fusion (AffineLoopFusion pass) only takes into
    // account stores in 'srcForOp' to compute the slice union.
    for (Operation *op : opsA) {
      if (isa<AffineWriteOpInterface>(op))
        strategyOpsA.push_back(op);
    }
    break;
  case FusionStrategy::Sibling:
    // Sibling fusion (AffineLoopFusion pass) only takes into account the
    // loads from the sibling fusion memref in 'srcForOp' to compute the
    // slice union.
    for (Operation *op : opsA) {
      auto load = dyn_cast<AffineReadOpInterface>(op);
      if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef())
        strategyOpsA.push_back(op);
    }
    break;
  }

  // Compute union of computation slices computed between all pairs of ops
  // from 'forOpA' and 'forOpB'.
  SliceComputationResult sliceComputationResult = affine::computeSliceUnion(
      strategyOpsA, opsB, dstLoopDepth, numCommonLoops,
      isSrcForOpBeforeDstForOp, srcSlice);
  if (sliceComputationResult.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceComputationResult.value ==
      SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }

  return FusionResult::Success;
}
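// Minimal usage sketch (illustrative, not part of the original source),
// assuming 'srcForOp' and 'dstForOp' are affine.for ops in the same block
// with a producer-consumer relationship at depth 1:
//
//   ComputationSliceState slice;
//   FusionResult result =
//       canFuseLoops(srcForOp, dstForOp, /*dstLoopDepth=*/1, &slice);
//   if (result.value == FusionResult::Success)
//     fuseLoops(srcForOp, dstForOp, slice);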

/// Patch the loop body of a forOp that is a single iteration reduction loop
/// into its containing block.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Check if the reduction loop is a single iteration loop.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  SmallVector<Value> newOperands;
  llvm::append_range(newOperands,
                     forOp.getBody()->getTerminator()->getOperands());
  IRRewriter rewriter(parentOp->getContext());
  int64_t parentOpNumResults = parentOp->getNumResults();
  // Replace the parent loop and add iter operands and results from the
  // `forOp`.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop =
      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            return newOperands;
          }));

  // For sibling-fusion users, collect operations that use the results of the
  // `forOp` outside the new parent loop that has absorbed all its iter args
  // and operands. These operations will be moved later after the results
  // have been replaced.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Update the results of the `forOp` in the new loop.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOpNumResults));
  }
  // For sibling-fusion users, move operations that use the results of the
  // `forOp` outside the new parent loop.
  if (siblingFusionUser) {
    topologicalSort(forwardSlice);
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace the iter args.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
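// Illustrative sketch (not part of the original source): given a hypothetical
// nest where the inner reduction loop has a constant trip count of 1,
//
//   affine.for %i = 0 to 8 {
//     %r = affine.for %j = 0 to 1 iter_args(%acc = %cst) -> (f32) {
//       ...
//       affine.yield %v : f32
//     }
//   }
//
// the utility rewrites the outer loop to also carry the inner loop's iter
// args and yielded value, then splices the single-iteration body directly
// into the outer loop's block.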

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion
/// point and source slice loop bounds specified in 'srcSlice'.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'.
  OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping mapper;
  b.clone(*srcForOp, mapper);

  // Update the cloned loops' upper and lower bounds from the computed
  // 'srcSlice'.
  SmallVector<AffineForOp, 4> sliceLoops;
  for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) {
    auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]);
    if (!loopIV)
      continue;
    auto forOp = getForInductionVarOwner(loopIV);
    sliceLoops.push_back(forOp);
    if (AffineMap lbMap = srcSlice.lbs[i]) {
      auto lbOperands = srcSlice.lbOperands[i];
      canonicalizeMapAndOperands(&lbMap, &lbOperands);
      forOp.setLowerBound(lbOperands, lbMap);
    }
    if (AffineMap ubMap = srcSlice.ubs[i]) {
      auto ubOperands = srcSlice.ubOperands[i];
      canonicalizeMapAndOperands(&ubMap, &ubOperands);
      forOp.setUpperBound(ubOperands, ubMap);
    }
  }

  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Fix up and, if possible, eliminate single iteration loops.
  for (AffineForOp forOp : sliceLoops) {
    if (isLoopParallelAndContainsReduction(forOp) &&
        isInnermostSiblingInsertion && srcIsUnitSlice())
      // Patch reduction loops - only those that are sibling-fused with the
      // destination loop - into the parent loop.
      (void)promoteSingleIterReductionLoop(forOp, true);
    else
      // Promote any single iteration slice loops.
      (void)promoteIfSingleIteration(forOp);
  }
}
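// Illustrative producer-consumer fusion (not part of the original source),
// with hypothetical memref %A: given
//
//   affine.for %i = 0 to 10 {            // srcForOp (producer)
//     affine.store %cst, %A[%i] : memref<10xf32>
//   }
//   affine.for %j = 0 to 10 {            // dstForOp (consumer)
//     %v = affine.load %A[%j] : memref<10xf32>
//   }
//
// fusing at depth 1 clones a single-point slice of the producer into the
// consumer's body:
//
//   affine.for %j = 0 to 10 {
//     affine.store %cst, %A[%j] : memref<10xf32>
//     %v = affine.load %A[%j] : memref<10xf32>
//   }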

/// Collect loop nest statistics (e.g. loop trip count and operation count)
/// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
/// returns false otherwise.
bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
                                    LoopNestStats *stats) {
  auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
    auto *childForOp = forOp.getOperation();
    auto *parentForOp = forOp->getParentOp();
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(parentForOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentForOp].push_back(forOp);
    }

    // Record the number of operations in the body of 'forOp', not counting
    // nested AffineForOp and AffineIfOp ops.
    unsigned count = 0;
    stats->opCountMap[childForOp] = 0;
    for (auto &op : *forOp.getBody()) {
      if (!isa<AffineForOp, AffineIfOp>(op))
        ++count;
    }
    stats->opCountMap[childForOp] = count;

    // Record the trip count for 'forOp'; bail out if it is not a constant.
    std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount) {
      // Currently only constant trip count loop nests are supported.
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }

    stats->tripCountMap[childForOp] = *maybeConstTripCount;
    return WalkResult::advance();
  });
  return !walkResult.wasInterrupted();
}
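// Illustrative example (not part of the original source): for the
// hypothetical 2-deep nest
//
//   affine.for %i = 0 to 4 {       // 'outer'
//     affine.for %j = 0 to 8 {     // 'inner'
//       %v = affine.load %A[%j] : memref<8xf32>
//       affine.store %v, %B[%j] : memref<8xf32>
//     }
//   }
//
// getLoopNestStats would record tripCountMap = {outer: 4, inner: 8},
// opCountMap = {outer: 1, inner: 3} (terminators included, nested loops
// excluded), and loopMap = {outer -> [inner]}.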

// Computes the total cost of the loop nest rooted at 'forOp'.
// Currently, the total cost is computed by counting the total operation
// instance count (i.e. total number of operations in the loop body * loop
// trip count) for the entire loop nest.
// If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
// specified in the map when computing the total op instance count.
// NOTEs: 1) This is used to compute the cost of computation slices, which are
// sliced along the iteration dimension, and thus reduce the trip count.
// If 'computeCostMap' is non-null, the total op count for forOps specified
// in the map is increased (not overridden) by adding the op count from the
// map to the existing op count for the for loop. This is done before
// multiplying by the loop's trip count, and is used to model the cost of
// inserting a sliced loop nest of known cost into the loop's body.
// 2) This is also used to compute the cost of fusing a slice of some loop nest
// within another loop.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // 'opCount' is the total number of operations in one iteration of 'forOp'
  // body, minus the terminator op, which is a no-op.
  int64_t opCount = stats.opCountMap[forOp] - 1;
  if (stats.loopMap.count(forOp) > 0) {
    for (auto childForOp : stats.loopMap[forOp]) {
      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
                                      computeCostMap);
    }
  }
  // Add in additional op instances from slice (if specified in map).
  if (computeCostMap) {
    auto it = computeCostMap->find(forOp);
    if (it != computeCostMap->end()) {
      opCount += it->second;
    }
  }
  // Override trip count (if specified in map).
  int64_t tripCount = stats.tripCountMap[forOp];
  if (tripCountOverrideMap) {
    auto it = tripCountOverrideMap->find(forOp);
    if (it != tripCountOverrideMap->end()) {
      tripCount = it->second;
    }
  }
  // Returns the total number of dynamic instances of operations in loop body.
  return tripCount * opCount;
}
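// Illustrative cost calculation (not part of the original source), reusing
// the hypothetical nest above: the inner loop contributes 8 * (3 - 1) = 16
// op instances (its terminator is discounted), and the outer loop then
// contributes 4 * ((1 - 1) + 16) = 64, the total returned for the nest.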

/// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
/// Currently, the total cost is computed by counting the total operation
/// instance count (i.e. total number of operations in the loop body * loop
/// trip count) for the entire loop nest.
int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  return getComputeCostHelper(forOp, stats,
                              /*tripCountOverrideMap=*/nullptr,
                              /*computeCostMap=*/nullptr);
}

/// Computes and returns in 'computeCost' the total compute cost of fusing the
/// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
/// the total cost is computed by counting the total operation instance count
/// (i.e. total number of operations in the loop body * loop trip count) for
/// the entire loop nest.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build trip count map for computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // Check whether store-to-load forwarding is guaranteed to happen.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  // The stores and loads to this memref will disappear.
  if (storeLoadFwdGuaranteed) {
    // Subtract from the operation count the loads/stores we expect
    // store-to-load forwarding to remove.
    unsigned storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](AffineWriteOpInterface storeOp) {
      storeMemrefs.insert(storeOp.getMemRef());
      ++storeCount;
    });
    // Subtract out any store ops in the single-iteration src slice loop nest.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Subtract out any load users of 'storeMemrefs' nested below
    // 'insertPointParent'.
    for (Value memref : storeMemrefs) {
      for (Operation *user : memref.getUsers()) {
        if (!isa<AffineReadOpInterface>(user))
          continue;
        SmallVector<AffineForOp, 4> loops;
        // Check if any loop in the loop nest surrounding 'user' is
        // 'insertPointParent'.
        getAffineForIVs(*user, &loops);
        if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
          if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp()))
            --computeCostMap[forOp];
        }
      }
    }
  }

  // Compute op instance count for the src loop nest with iteration slicing.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Compute cost of fusion for this depth.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
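// Illustrative composition (not part of the original source): if the sliced
// src nest costs 16 op instances per dst iteration and is inserted directly
// in the body of a dst loop with trip count 100 and 3 non-terminator body
// ops, the fused cost reported here would be 100 * (3 + 16) = 1900 op
// instances, before any store-to-load forwarding credit is subtracted.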

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops in
/// 'dstOps'.
void mlir::affine::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Gather memrefs from stores in 'srcOps'.
  DenseSet<Value> srcStoreMemRefs;
  for (Operation *op : srcOps)
    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op))
      srcStoreMemRefs.insert(storeOp.getMemRef());

  // Compute the intersection between memrefs from stores in 'srcOps' and
  // memrefs from loads in 'dstOps'.
  for (Operation *op : dstOps)
    if (auto loadOp = dyn_cast<AffineReadOpInterface>(op))
      if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0)
        producerConsumerMemrefs.insert(loadOp.getMemRef());
}
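// Illustrative example (not part of the original source), with hypothetical
// memrefs: if 'srcOps' contains stores to %A and %B while 'dstOps' contains
// loads from %B and %C, only %B is added to 'producerConsumerMemrefs'.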