1 | //===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements loop fusion on parallel loops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Analysis/AliasAnalysis.h" |
16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
18 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/OpDefinition.h" |
22 | #include "mlir/IR/OperationSupport.h" |
23 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
24 | |
25 | namespace mlir { |
26 | #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION |
27 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
28 | } // namespace mlir |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::scf; |
32 | |
33 | /// Verify there are no nested ParallelOps. |
34 | static bool hasNestedParallelOp(ParallelOp ploop) { |
35 | auto walkResult = |
36 | ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); }); |
37 | return walkResult.wasInterrupted(); |
38 | } |
39 | |
40 | /// Verify equal iteration spaces. |
41 | static bool equalIterationSpaces(ParallelOp firstPloop, |
42 | ParallelOp secondPloop) { |
43 | if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) |
44 | return false; |
45 | |
46 | auto matchOperands = [&](const OperandRange &lhs, |
47 | const OperandRange &rhs) -> bool { |
48 | // TODO: Extend this to support aliases and equal constants. |
49 | return std::equal(first1: lhs.begin(), last1: lhs.end(), first2: rhs.begin()); |
50 | }; |
51 | return matchOperands(firstPloop.getLowerBound(), |
52 | secondPloop.getLowerBound()) && |
53 | matchOperands(firstPloop.getUpperBound(), |
54 | secondPloop.getUpperBound()) && |
55 | matchOperands(firstPloop.getStep(), secondPloop.getStep()); |
56 | } |
57 | |
58 | /// Checks if the parallel loops have mixed access to the same buffers. Returns |
59 | /// `true` if the first parallel loop writes to the same indices that the second |
60 | /// loop reads. |
61 | static bool haveNoReadsAfterWriteExceptSameIndex( |
62 | ParallelOp firstPloop, ParallelOp secondPloop, |
63 | const IRMapping &firstToSecondPloopIndices, |
64 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
65 | DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; |
66 | SmallVector<Value> bufferStoresVec; |
67 | firstPloop.getBody()->walk([&](memref::StoreOp store) { |
68 | bufferStores[store.getMemRef()].push_back(store.getIndices()); |
69 | bufferStoresVec.emplace_back(store.getMemRef()); |
70 | }); |
71 | auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) { |
72 | Value loadMem = load.getMemRef(); |
73 | // Stop if the memref is defined in secondPloop body. Careful alias analysis |
74 | // is needed. |
75 | auto *memrefDef = loadMem.getDefiningOp(); |
76 | if (memrefDef && memrefDef->getBlock() == load->getBlock()) |
77 | return WalkResult::interrupt(); |
78 | |
79 | for (Value store : bufferStoresVec) |
80 | if (store != loadMem && mayAlias(store, loadMem)) |
81 | return WalkResult::interrupt(); |
82 | |
83 | auto write = bufferStores.find(Val: loadMem); |
84 | if (write == bufferStores.end()) |
85 | return WalkResult::advance(); |
86 | |
87 | // Check that at last one store was retrieved |
88 | if (!write->second.size()) |
89 | return WalkResult::interrupt(); |
90 | |
91 | auto storeIndices = write->second.front(); |
92 | |
93 | // Multiple writes to the same memref are allowed only on the same indices |
94 | for (const auto &othStoreIndices : write->second) { |
95 | if (othStoreIndices != storeIndices) |
96 | return WalkResult::interrupt(); |
97 | } |
98 | |
99 | // Check that the load indices of secondPloop coincide with store indices of |
100 | // firstPloop for the same memrefs. |
101 | auto loadIndices = load.getIndices(); |
102 | if (storeIndices.size() != loadIndices.size()) |
103 | return WalkResult::interrupt(); |
104 | for (int i = 0, e = storeIndices.size(); i < e; ++i) { |
105 | if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != |
106 | loadIndices[i]) { |
107 | auto *storeIndexDefOp = storeIndices[i].getDefiningOp(); |
108 | auto *loadIndexDefOp = loadIndices[i].getDefiningOp(); |
109 | if (storeIndexDefOp && loadIndexDefOp) { |
110 | if (!isMemoryEffectFree(storeIndexDefOp)) |
111 | return WalkResult::interrupt(); |
112 | if (!isMemoryEffectFree(loadIndexDefOp)) |
113 | return WalkResult::interrupt(); |
114 | if (!OperationEquivalence::isEquivalentTo( |
115 | storeIndexDefOp, loadIndexDefOp, |
116 | [&](Value storeIndex, Value loadIndex) { |
117 | if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndex) != |
118 | firstToSecondPloopIndices.lookupOrDefault(from: loadIndex)) |
119 | return failure(); |
120 | else |
121 | return success(); |
122 | }, |
123 | /*markEquivalent=*/nullptr, |
124 | OperationEquivalence::Flags::IgnoreLocations)) { |
125 | return WalkResult::interrupt(); |
126 | } |
127 | } else |
128 | return WalkResult::interrupt(); |
129 | } |
130 | } |
131 | return WalkResult::advance(); |
132 | }); |
133 | return !walkResult.wasInterrupted(); |
134 | } |
135 | |
136 | /// Analyzes dependencies in the most primitive way by checking simple read and |
137 | /// write patterns. |
138 | static LogicalResult |
139 | verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, |
140 | const IRMapping &firstToSecondPloopIndices, |
141 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
142 | if (!haveNoReadsAfterWriteExceptSameIndex( |
143 | firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)) |
144 | return failure(); |
145 | |
146 | IRMapping secondToFirstPloopIndices; |
147 | secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), |
148 | firstPloop.getBody()->getArguments()); |
149 | return success(haveNoReadsAfterWriteExceptSameIndex( |
150 | secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias)); |
151 | } |
152 | |
153 | static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, |
154 | const IRMapping &firstToSecondPloopIndices, |
155 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
156 | return !hasNestedParallelOp(firstPloop) && |
157 | !hasNestedParallelOp(secondPloop) && |
158 | equalIterationSpaces(firstPloop, secondPloop) && |
159 | succeeded(verifyDependencies(firstPloop, secondPloop, |
160 | firstToSecondPloopIndices, mayAlias)); |
161 | } |
162 | |
163 | /// Prepends operations of firstPloop's body into secondPloop's body. |
164 | /// Updates secondPloop with new loop. |
165 | static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, |
166 | OpBuilder builder, |
167 | llvm::function_ref<bool(Value, Value)> mayAlias) { |
168 | Block *block1 = firstPloop.getBody(); |
169 | Block *block2 = secondPloop.getBody(); |
170 | IRMapping firstToSecondPloopIndices; |
171 | firstToSecondPloopIndices.map(from: block1->getArguments(), to: block2->getArguments()); |
172 | |
173 | if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices, |
174 | mayAlias)) |
175 | return; |
176 | |
177 | DominanceInfo dom; |
178 | // We are fusing first loop into second, make sure there are no users of the |
179 | // first loop results between loops. |
180 | for (Operation *user : firstPloop->getUsers()) |
181 | if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) |
182 | return; |
183 | |
184 | ValueRange inits1 = firstPloop.getInitVals(); |
185 | ValueRange inits2 = secondPloop.getInitVals(); |
186 | |
187 | SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); |
188 | newInitVars.append(in_start: inits2.begin(), in_end: inits2.end()); |
189 | |
190 | IRRewriter b(builder); |
191 | b.setInsertionPoint(secondPloop); |
192 | auto newSecondPloop = b.create<ParallelOp>( |
193 | secondPloop.getLoc(), secondPloop.getLowerBound(), |
194 | secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); |
195 | |
196 | Block *newBlock = newSecondPloop.getBody(); |
197 | auto term1 = cast<ReduceOp>(block1->getTerminator()); |
198 | auto term2 = cast<ReduceOp>(block2->getTerminator()); |
199 | |
200 | b.inlineBlockBefore(source: block2, dest: newBlock, before: newBlock->begin(), |
201 | argValues: newBlock->getArguments()); |
202 | b.inlineBlockBefore(source: block1, dest: newBlock, before: newBlock->begin(), |
203 | argValues: newBlock->getArguments()); |
204 | |
205 | ValueRange results = newSecondPloop.getResults(); |
206 | if (!results.empty()) { |
207 | b.setInsertionPointToEnd(newBlock); |
208 | |
209 | ValueRange reduceArgs1 = term1.getOperands(); |
210 | ValueRange reduceArgs2 = term2.getOperands(); |
211 | SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); |
212 | newReduceArgs.append(in_start: reduceArgs2.begin(), in_end: reduceArgs2.end()); |
213 | |
214 | auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); |
215 | |
216 | for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( |
217 | term1.getReductions(), term2.getReductions()))) { |
218 | Block &oldRedBlock = reg.front(); |
219 | Block &newRedBlock = newReduceOp.getReductions()[i].front(); |
220 | b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), |
221 | newRedBlock.getArguments()); |
222 | } |
223 | |
224 | firstPloop.replaceAllUsesWith(results.take_front(n: inits1.size())); |
225 | secondPloop.replaceAllUsesWith(results.take_back(n: inits2.size())); |
226 | } |
227 | term1->erase(); |
228 | term2->erase(); |
229 | firstPloop.erase(); |
230 | secondPloop.erase(); |
231 | secondPloop = newSecondPloop; |
232 | } |
233 | |
234 | void mlir::scf::naivelyFuseParallelOps( |
235 | Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) { |
236 | OpBuilder b(region); |
237 | // Consider every single block and attempt to fuse adjacent loops. |
238 | SmallVector<SmallVector<ParallelOp>, 1> ploopChains; |
239 | for (auto &block : region) { |
240 | ploopChains.clear(); |
241 | ploopChains.push_back({}); |
242 | |
243 | // Not using `walk()` to traverse only top-level parallel loops and also |
244 | // make sure that there are no side-effecting ops between the parallel |
245 | // loops. |
246 | bool noSideEffects = true; |
247 | for (auto &op : block) { |
248 | if (auto ploop = dyn_cast<ParallelOp>(op)) { |
249 | if (noSideEffects) { |
250 | ploopChains.back().push_back(ploop); |
251 | } else { |
252 | ploopChains.push_back({ploop}); |
253 | noSideEffects = true; |
254 | } |
255 | continue; |
256 | } |
257 | // TODO: Handle region side effects properly. |
258 | noSideEffects &= isMemoryEffectFree(op: &op) && op.getNumRegions() == 0; |
259 | } |
260 | for (MutableArrayRef<ParallelOp> ploops : ploopChains) { |
261 | for (int i = 0, e = ploops.size(); i + 1 < e; ++i) |
262 | fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias); |
263 | } |
264 | } |
265 | } |
266 | |
267 | namespace { |
268 | struct ParallelLoopFusion |
269 | : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { |
270 | void runOnOperation() override { |
271 | auto &AA = getAnalysis<AliasAnalysis>(); |
272 | |
273 | auto mayAlias = [&](Value val1, Value val2) -> bool { |
274 | return !AA.alias(val1, val2).isNo(); |
275 | }; |
276 | |
277 | getOperation()->walk([&](Operation *child) { |
278 | for (Region ®ion : child->getRegions()) |
279 | naivelyFuseParallelOps(region, mayAlias); |
280 | }); |
281 | } |
282 | }; |
283 | } // namespace |
284 | |
285 | std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() { |
286 | return std::make_unique<ParallelLoopFusion>(); |
287 | } |
288 | |