1//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to test various loop fusion utilities. It is not
10// meant to be a pass to perform valid fusion.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/Analysis/Utils.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Affine/LoopFusionUtils.h"
17#include "mlir/Dialect/Affine/LoopUtils.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Pass/Pass.h"
20
21#define DEBUG_TYPE "test-loop-fusion"
22
23using namespace mlir;
24using namespace mlir::affine;
25
26namespace {
27
28struct TestLoopFusion
29 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
31
32 StringRef getArgument() const final { return "test-loop-fusion"; }
33 StringRef getDescription() const final {
34 return "Tests loop fusion utility functions.";
35 }
36 void runOnOperation() override;
37
38 TestLoopFusion() = default;
39 TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
40
41 Option<bool> clTestDependenceCheck{
42 *this, "test-loop-fusion-dependence-check",
43 llvm::cl::desc("Enable testing of loop fusion dependence check"),
44 llvm::cl::init(Val: false)};
45
46 Option<bool> clTestSliceComputation{
47 *this, "test-loop-fusion-slice-computation",
48 llvm::cl::desc("Enable testing of loop fusion slice computation"),
49 llvm::cl::init(Val: false)};
50
51 Option<bool> clTestLoopFusionUtilities{
52 *this, "test-loop-fusion-utilities",
53 llvm::cl::desc("Enable testing of loop fusion transformation utilities"),
54 llvm::cl::init(Val: false)};
55};
56
57} // namespace
58
59// Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
60// in range ['loopDepth' + 1, 'maxLoopDepth'].
61// Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
62// Returns false as IR is not transformed.
63static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
64 unsigned i, unsigned j, unsigned loopDepth,
65 unsigned maxLoopDepth) {
66 ComputationSliceState sliceUnion;
67 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
68 FusionResult result = canFuseLoops(srcForOp, dstForOp, dstLoopDepth: d, srcSlice: &sliceUnion);
69 if (result.value == FusionResult::FailBlockDependence) {
70 srcForOp->emitRemark(message: "block-level dependence preventing"
71 " fusion of loop nest ")
72 << i << " into loop nest " << j << " at depth " << loopDepth;
73 }
74 }
75 return false;
76}
77
78// Returns the index of 'op' in its block.
79static unsigned getBlockIndex(Operation &op) {
80 unsigned index = 0;
81 for (auto &opX : *op.getBlock()) {
82 if (&op == &opX)
83 break;
84 ++index;
85 }
86 return index;
87}
88
89// Returns a string representation of 'sliceUnion'.
90static std::string getSliceStr(const ComputationSliceState &sliceUnion) {
91 std::string result;
92 llvm::raw_string_ostream os(result);
93 // Slice insertion point format [loop-depth, operation-block-index]
94 unsigned ipd = getNestingDepth(op: &*sliceUnion.insertPoint);
95 unsigned ipb = getBlockIndex(op&: *sliceUnion.insertPoint);
96 os << "insert point: (" << std::to_string(val: ipd) << ", " << std::to_string(val: ipb)
97 << ")";
98 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
99 os << " loop bounds: ";
100 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
101 os << '[';
102 sliceUnion.lbs[k].print(os);
103 os << ", ";
104 sliceUnion.ubs[k].print(os);
105 os << "] ";
106 }
107 return os.str();
108}
109
110/// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
111/// in range ['loopDepth' + 1, 'maxLoopDepth'].
112/// Emits a string representation of the slice union as a remark on 'loops[j]'
113/// and marks this as incorrect slice if the slice is invalid. Returns false as
114/// IR is not transformed.
115static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
116 unsigned i, unsigned j, unsigned loopDepth,
117 unsigned maxLoopDepth) {
118 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
119 ComputationSliceState sliceUnion;
120 FusionResult result = canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion);
121 if (result.value == FusionResult::Success) {
122 forOpB->emitRemark(message: "slice (")
123 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
124 << " : " << getSliceStr(sliceUnion) << ")";
125 } else if (result.value == FusionResult::FailIncorrectSlice) {
126 forOpB->emitRemark(message: "Incorrect slice (")
127 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
128 << " : " << getSliceStr(sliceUnion) << ")";
129 }
130 }
131 return false;
132}
133
134// Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
135// ['loopDepth' + 1, 'maxLoopDepth'].
136// Returns true if loops were successfully fused, false otherwise. This tests
137// `fuseLoops` and `canFuseLoops` utilities.
138static bool testLoopFusionUtilities(AffineForOp forOpA, AffineForOp forOpB,
139 unsigned i, unsigned j, unsigned loopDepth,
140 unsigned maxLoopDepth) {
141 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
142 ComputationSliceState sliceUnion;
143 // This check isn't a sufficient one, but necessary.
144 FusionResult result = canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion);
145 if (result.value != FusionResult::Success)
146 continue;
147 fuseLoops(srcForOp: forOpA, dstForOp: forOpB, srcSlice: sliceUnion);
148 // Note: 'forOpA' is removed to simplify test output. A proper loop
149 // fusion pass should perform additional checks to check safe removal.
150 if (forOpA.use_empty())
151 forOpA.erase();
152 return true;
153 }
154 return false;
155}
156
157using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158 unsigned, unsigned)>;
159
160// Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161// If 'return_on_change' is true, returns on first invocation of 'fn' which
162// returns true.
163static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164 LoopFunc fn, bool returnOnChange = false) {
165 bool changed = false;
166 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167 ++loopDepth) {
168 auto &loops = depthToLoops[loopDepth];
169 unsigned numLoops = loops.size();
170 for (unsigned j = 0; j < numLoops; ++j) {
171 for (unsigned k = 0; k < numLoops; ++k) {
172 if (j != k)
173 changed |=
174 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175 if (changed && returnOnChange)
176 return true;
177 }
178 }
179 }
180 return changed;
181}
182
183void TestLoopFusion::runOnOperation() {
184 std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185 if (clTestLoopFusionUtilities) {
186 // Run loop fusion until a fixed point is reached.
187 do {
188 depthToLoops.clear();
189 // Gather all AffineForOps by loop depth.
190 gatherLoops(func: getOperation(), depthToLoops);
191
192 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193 } while (iterateLoops(depthToLoops, fn: testLoopFusionUtilities,
194 /*returnOnChange=*/true));
195 return;
196 }
197
198 // Gather all AffineForOps by loop depth.
199 gatherLoops(func: getOperation(), depthToLoops);
200
201 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202 if (clTestDependenceCheck)
203 iterateLoops(depthToLoops, fn: testDependenceCheck);
204 if (clTestSliceComputation)
205 iterateLoops(depthToLoops, fn: testSliceComputation);
206}
207
208namespace mlir {
209namespace test {
210void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211} // namespace test
212} // namespace mlir
213

// Source code of mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp