1//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to test various loop fusion utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Analysis/Utils.h"
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Affine/LoopFusionUtils.h"
16#include "mlir/Dialect/Affine/LoopUtils.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Pass/Pass.h"
19
20#define DEBUG_TYPE "test-loop-fusion"
21
22using namespace mlir;
23using namespace mlir::affine;
24
25namespace {
26
27struct TestLoopFusion
28 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
30
31 StringRef getArgument() const final { return "test-loop-fusion"; }
32 StringRef getDescription() const final {
33 return "Tests loop fusion utility functions.";
34 }
35 void runOnOperation() override;
36
37 TestLoopFusion() = default;
38 TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
39
40 Option<bool> clTestDependenceCheck{
41 *this, "test-loop-fusion-dependence-check",
42 llvm::cl::desc("Enable testing of loop fusion dependence check"),
43 llvm::cl::init(Val: false)};
44
45 Option<bool> clTestSliceComputation{
46 *this, "test-loop-fusion-slice-computation",
47 llvm::cl::desc("Enable testing of loop fusion slice computation"),
48 llvm::cl::init(Val: false)};
49
50 Option<bool> clTestLoopFusionTransformation{
51 *this, "test-loop-fusion-transformation",
52 llvm::cl::desc("Enable testing of loop fusion transformation"),
53 llvm::cl::init(Val: false)};
54};
55
56} // namespace
57
58// Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59// in range ['loopDepth' + 1, 'maxLoopDepth'].
60// Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61// Returns false as IR is not transformed.
62static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63 unsigned i, unsigned j, unsigned loopDepth,
64 unsigned maxLoopDepth) {
65 affine::ComputationSliceState sliceUnion;
66 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67 FusionResult result =
68 affine::canFuseLoops(srcForOp: srcForOp, dstForOp: dstForOp, dstLoopDepth: d, srcSlice: &sliceUnion);
69 if (result.value == FusionResult::FailBlockDependence) {
70 srcForOp->emitRemark("block-level dependence preventing"
71 " fusion of loop nest ")
72 << i << " into loop nest " << j << " at depth " << loopDepth;
73 }
74 }
75 return false;
76}
77
78// Returns the index of 'op' in its block.
79static unsigned getBlockIndex(Operation &op) {
80 unsigned index = 0;
81 for (auto &opX : *op.getBlock()) {
82 if (&op == &opX)
83 break;
84 ++index;
85 }
86 return index;
87}
88
89// Returns a string representation of 'sliceUnion'.
90static std::string
91getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92 std::string result;
93 llvm::raw_string_ostream os(result);
94 // Slice insertion point format [loop-depth, operation-block-index]
95 unsigned ipd = getNestingDepth(op: &*sliceUnion.insertPoint);
96 unsigned ipb = getBlockIndex(op&: *sliceUnion.insertPoint);
97 os << "insert point: (" << std::to_string(val: ipd) << ", " << std::to_string(val: ipb)
98 << ")";
99 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100 os << " loop bounds: ";
101 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102 os << '[';
103 sliceUnion.lbs[k].print(os);
104 os << ", ";
105 sliceUnion.ubs[k].print(os);
106 os << "] ";
107 }
108 return os.str();
109}
110
111/// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112/// in range ['loopDepth' + 1, 'maxLoopDepth'].
113/// Emits a string representation of the slice union as a remark on 'loops[j]'
114/// and marks this as incorrect slice if the slice is invalid. Returns false as
115/// IR is not transformed.
116static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117 unsigned i, unsigned j, unsigned loopDepth,
118 unsigned maxLoopDepth) {
119 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
120 affine::ComputationSliceState sliceUnion;
121 FusionResult result = affine::canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion);
122 if (result.value == FusionResult::Success) {
123 forOpB->emitRemark("slice (")
124 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125 << " : " << getSliceStr(sliceUnion) << ")";
126 } else if (result.value == FusionResult::FailIncorrectSlice) {
127 forOpB->emitRemark("Incorrect slice (")
128 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129 << " : " << getSliceStr(sliceUnion) << ")";
130 }
131 }
132 return false;
133}
134
135// Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136// ['loopDepth' + 1, 'maxLoopDepth'].
137// Returns true if loops were successfully fused, false otherwise.
138static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139 unsigned i, unsigned j,
140 unsigned loopDepth,
141 unsigned maxLoopDepth) {
142 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
143 affine::ComputationSliceState sliceUnion;
144 FusionResult result = affine::canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion);
145 if (result.value == FusionResult::Success) {
146 affine::fuseLoops(srcForOp: forOpA, dstForOp: forOpB, srcSlice: sliceUnion);
147 // Note: 'forOpA' is removed to simplify test output. A proper loop
148 // fusion pass should check the data dependence graph and run memref
149 // region analysis to ensure removing 'forOpA' is safe.
150 forOpA.erase();
151 return true;
152 }
153 }
154 return false;
155}
156
157using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158 unsigned, unsigned)>;
159
160// Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161// If 'return_on_change' is true, returns on first invocation of 'fn' which
162// returns true.
163static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164 LoopFunc fn, bool returnOnChange = false) {
165 bool changed = false;
166 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167 ++loopDepth) {
168 auto &loops = depthToLoops[loopDepth];
169 unsigned numLoops = loops.size();
170 for (unsigned j = 0; j < numLoops; ++j) {
171 for (unsigned k = 0; k < numLoops; ++k) {
172 if (j != k)
173 changed |=
174 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175 if (changed && returnOnChange)
176 return true;
177 }
178 }
179 }
180 return changed;
181}
182
183void TestLoopFusion::runOnOperation() {
184 std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185 if (clTestLoopFusionTransformation) {
186 // Run loop fusion until a fixed point is reached.
187 do {
188 depthToLoops.clear();
189 // Gather all AffineForOps by loop depth.
190 gatherLoops(getOperation(), depthToLoops);
191
192 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193 } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194 /*returnOnChange=*/true));
195 return;
196 }
197
198 // Gather all AffineForOps by loop depth.
199 gatherLoops(getOperation(), depthToLoops);
200
201 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202 if (clTestDependenceCheck)
203 iterateLoops(depthToLoops, testDependenceCheck);
204 if (clTestSliceComputation)
205 iterateLoops(depthToLoops, testSliceComputation);
206}
207
208namespace mlir {
209namespace test {
210void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211} // namespace test
212} // namespace mlir
213

// Source: mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp