1 | //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to test various loop fusion utility functions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/Analysis/Utils.h" |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/LoopFusionUtils.h" |
16 | #include "mlir/Dialect/Affine/LoopUtils.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | |
20 | #define DEBUG_TYPE "test-loop-fusion" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::affine; |
24 | |
25 | namespace { |
26 | |
27 | struct TestLoopFusion |
28 | : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> { |
29 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion) |
30 | |
31 | StringRef getArgument() const final { return "test-loop-fusion" ; } |
32 | StringRef getDescription() const final { |
33 | return "Tests loop fusion utility functions." ; |
34 | } |
35 | void runOnOperation() override; |
36 | |
37 | TestLoopFusion() = default; |
38 | TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){}; |
39 | |
40 | Option<bool> clTestDependenceCheck{ |
41 | *this, "test-loop-fusion-dependence-check" , |
42 | llvm::cl::desc("Enable testing of loop fusion dependence check" ), |
43 | llvm::cl::init(Val: false)}; |
44 | |
45 | Option<bool> clTestSliceComputation{ |
46 | *this, "test-loop-fusion-slice-computation" , |
47 | llvm::cl::desc("Enable testing of loop fusion slice computation" ), |
48 | llvm::cl::init(Val: false)}; |
49 | |
50 | Option<bool> clTestLoopFusionTransformation{ |
51 | *this, "test-loop-fusion-transformation" , |
52 | llvm::cl::desc("Enable testing of loop fusion transformation" ), |
53 | llvm::cl::init(Val: false)}; |
54 | }; |
55 | |
56 | } // namespace |
57 | |
58 | // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths |
59 | // in range ['loopDepth' + 1, 'maxLoopDepth']. |
60 | // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. |
61 | // Returns false as IR is not transformed. |
62 | static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, |
63 | unsigned i, unsigned j, unsigned loopDepth, |
64 | unsigned maxLoopDepth) { |
65 | affine::ComputationSliceState sliceUnion; |
66 | for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { |
67 | FusionResult result = |
68 | affine::canFuseLoops(srcForOp: srcForOp, dstForOp: dstForOp, dstLoopDepth: d, srcSlice: &sliceUnion); |
69 | if (result.value == FusionResult::FailBlockDependence) { |
70 | srcForOp->emitRemark("block-level dependence preventing" |
71 | " fusion of loop nest " ) |
72 | << i << " into loop nest " << j << " at depth " << loopDepth; |
73 | } |
74 | } |
75 | return false; |
76 | } |
77 | |
78 | // Returns the index of 'op' in its block. |
79 | static unsigned getBlockIndex(Operation &op) { |
80 | unsigned index = 0; |
81 | for (auto &opX : *op.getBlock()) { |
82 | if (&op == &opX) |
83 | break; |
84 | ++index; |
85 | } |
86 | return index; |
87 | } |
88 | |
89 | // Returns a string representation of 'sliceUnion'. |
90 | static std::string |
91 | getSliceStr(const affine::ComputationSliceState &sliceUnion) { |
92 | std::string result; |
93 | llvm::raw_string_ostream os(result); |
94 | // Slice insertion point format [loop-depth, operation-block-index] |
95 | unsigned ipd = getNestingDepth(op: &*sliceUnion.insertPoint); |
96 | unsigned ipb = getBlockIndex(op&: *sliceUnion.insertPoint); |
97 | os << "insert point: (" << std::to_string(val: ipd) << ", " << std::to_string(val: ipb) |
98 | << ")" ; |
99 | assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); |
100 | os << " loop bounds: " ; |
101 | for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { |
102 | os << '['; |
103 | sliceUnion.lbs[k].print(os); |
104 | os << ", " ; |
105 | sliceUnion.ubs[k].print(os); |
106 | os << "] " ; |
107 | } |
108 | return os.str(); |
109 | } |
110 | |
111 | /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths |
112 | /// in range ['loopDepth' + 1, 'maxLoopDepth']. |
113 | /// Emits a string representation of the slice union as a remark on 'loops[j]' |
114 | /// and marks this as incorrect slice if the slice is invalid. Returns false as |
115 | /// IR is not transformed. |
116 | static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, |
117 | unsigned i, unsigned j, unsigned loopDepth, |
118 | unsigned maxLoopDepth) { |
119 | for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { |
120 | affine::ComputationSliceState sliceUnion; |
121 | FusionResult result = affine::canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion); |
122 | if (result.value == FusionResult::Success) { |
123 | forOpB->emitRemark("slice (" ) |
124 | << " src loop: " << i << ", dst loop: " << j << ", depth: " << d |
125 | << " : " << getSliceStr(sliceUnion) << ")" ; |
126 | } else if (result.value == FusionResult::FailIncorrectSlice) { |
127 | forOpB->emitRemark("Incorrect slice (" ) |
128 | << " src loop: " << i << ", dst loop: " << j << ", depth: " << d |
129 | << " : " << getSliceStr(sliceUnion) << ")" ; |
130 | } |
131 | } |
132 | return false; |
133 | } |
134 | |
135 | // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range |
136 | // ['loopDepth' + 1, 'maxLoopDepth']. |
137 | // Returns true if loops were successfully fused, false otherwise. |
138 | static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, |
139 | unsigned i, unsigned j, |
140 | unsigned loopDepth, |
141 | unsigned maxLoopDepth) { |
142 | for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { |
143 | affine::ComputationSliceState sliceUnion; |
144 | FusionResult result = affine::canFuseLoops(srcForOp: forOpA, dstForOp: forOpB, dstLoopDepth: d, srcSlice: &sliceUnion); |
145 | if (result.value == FusionResult::Success) { |
146 | affine::fuseLoops(srcForOp: forOpA, dstForOp: forOpB, srcSlice: sliceUnion); |
147 | // Note: 'forOpA' is removed to simplify test output. A proper loop |
148 | // fusion pass should check the data dependence graph and run memref |
149 | // region analysis to ensure removing 'forOpA' is safe. |
150 | forOpA.erase(); |
151 | return true; |
152 | } |
153 | } |
154 | return false; |
155 | } |
156 | |
157 | using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned, |
158 | unsigned, unsigned)>; |
159 | |
160 | // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. |
161 | // If 'return_on_change' is true, returns on first invocation of 'fn' which |
162 | // returns true. |
163 | static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops, |
164 | LoopFunc fn, bool returnOnChange = false) { |
165 | bool changed = false; |
166 | for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; |
167 | ++loopDepth) { |
168 | auto &loops = depthToLoops[loopDepth]; |
169 | unsigned numLoops = loops.size(); |
170 | for (unsigned j = 0; j < numLoops; ++j) { |
171 | for (unsigned k = 0; k < numLoops; ++k) { |
172 | if (j != k) |
173 | changed |= |
174 | fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); |
175 | if (changed && returnOnChange) |
176 | return true; |
177 | } |
178 | } |
179 | } |
180 | return changed; |
181 | } |
182 | |
183 | void TestLoopFusion::runOnOperation() { |
184 | std::vector<SmallVector<AffineForOp, 2>> depthToLoops; |
185 | if (clTestLoopFusionTransformation) { |
186 | // Run loop fusion until a fixed point is reached. |
187 | do { |
188 | depthToLoops.clear(); |
189 | // Gather all AffineForOps by loop depth. |
190 | gatherLoops(getOperation(), depthToLoops); |
191 | |
192 | // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. |
193 | } while (iterateLoops(depthToLoops, testLoopFusionTransformation, |
194 | /*returnOnChange=*/true)); |
195 | return; |
196 | } |
197 | |
198 | // Gather all AffineForOps by loop depth. |
199 | gatherLoops(getOperation(), depthToLoops); |
200 | |
201 | // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. |
202 | if (clTestDependenceCheck) |
203 | iterateLoops(depthToLoops, testDependenceCheck); |
204 | if (clTestSliceComputation) |
205 | iterateLoops(depthToLoops, testSliceComputation); |
206 | } |
207 | |
208 | namespace mlir { |
209 | namespace test { |
210 | void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); } |
211 | } // namespace test |
212 | } // namespace mlir |
213 | |