1//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop fusion on parallel loops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Analysis/AliasAnalysis.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::scf;
32
33/// Verify there are no nested ParallelOps.
34static bool hasNestedParallelOp(ParallelOp ploop) {
35 auto walkResult =
36 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37 return walkResult.wasInterrupted();
38}
39
40/// Verify equal iteration spaces.
41static bool equalIterationSpaces(ParallelOp firstPloop,
42 ParallelOp secondPloop) {
43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44 return false;
45
46 auto matchOperands = [&](const OperandRange &lhs,
47 const OperandRange &rhs) -> bool {
48 // TODO: Extend this to support aliases and equal constants.
49 return std::equal(first1: lhs.begin(), last1: lhs.end(), first2: rhs.begin());
50 };
51 return matchOperands(firstPloop.getLowerBound(),
52 secondPloop.getLowerBound()) &&
53 matchOperands(firstPloop.getUpperBound(),
54 secondPloop.getUpperBound()) &&
55 matchOperands(firstPloop.getStep(), secondPloop.getStep());
56}
57
58/// Checks if the parallel loops have mixed access to the same buffers. Returns
59/// `true` if the first parallel loop writes to the same indices that the second
60/// loop reads.
61static bool haveNoReadsAfterWriteExceptSameIndex(
62 ParallelOp firstPloop, ParallelOp secondPloop,
63 const IRMapping &firstToSecondPloopIndices,
64 llvm::function_ref<bool(Value, Value)> mayAlias) {
65 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66 SmallVector<Value> bufferStoresVec;
67 firstPloop.getBody()->walk([&](memref::StoreOp store) {
68 bufferStores[store.getMemRef()].push_back(store.getIndices());
69 bufferStoresVec.emplace_back(store.getMemRef());
70 });
71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72 Value loadMem = load.getMemRef();
73 // Stop if the memref is defined in secondPloop body. Careful alias analysis
74 // is needed.
75 auto *memrefDef = loadMem.getDefiningOp();
76 if (memrefDef && memrefDef->getBlock() == load->getBlock())
77 return WalkResult::interrupt();
78
79 for (Value store : bufferStoresVec)
80 if (store != loadMem && mayAlias(store, loadMem))
81 return WalkResult::interrupt();
82
83 auto write = bufferStores.find(Val: loadMem);
84 if (write == bufferStores.end())
85 return WalkResult::advance();
86
87 // Check that at last one store was retrieved
88 if (!write->second.size())
89 return WalkResult::interrupt();
90
91 auto storeIndices = write->second.front();
92
93 // Multiple writes to the same memref are allowed only on the same indices
94 for (const auto &othStoreIndices : write->second) {
95 if (othStoreIndices != storeIndices)
96 return WalkResult::interrupt();
97 }
98
99 // Check that the load indices of secondPloop coincide with store indices of
100 // firstPloop for the same memrefs.
101 auto loadIndices = load.getIndices();
102 if (storeIndices.size() != loadIndices.size())
103 return WalkResult::interrupt();
104 for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106 loadIndices[i]) {
107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109 if (storeIndexDefOp && loadIndexDefOp) {
110 if (!isMemoryEffectFree(storeIndexDefOp))
111 return WalkResult::interrupt();
112 if (!isMemoryEffectFree(loadIndexDefOp))
113 return WalkResult::interrupt();
114 if (!OperationEquivalence::isEquivalentTo(
115 storeIndexDefOp, loadIndexDefOp,
116 [&](Value storeIndex, Value loadIndex) {
117 if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndex) !=
118 firstToSecondPloopIndices.lookupOrDefault(from: loadIndex))
119 return failure();
120 else
121 return success();
122 },
123 /*markEquivalent=*/nullptr,
124 OperationEquivalence::Flags::IgnoreLocations)) {
125 return WalkResult::interrupt();
126 }
127 } else
128 return WalkResult::interrupt();
129 }
130 }
131 return WalkResult::advance();
132 });
133 return !walkResult.wasInterrupted();
134}
135
136/// Analyzes dependencies in the most primitive way by checking simple read and
137/// write patterns.
138static LogicalResult
139verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140 const IRMapping &firstToSecondPloopIndices,
141 llvm::function_ref<bool(Value, Value)> mayAlias) {
142 if (!haveNoReadsAfterWriteExceptSameIndex(
143 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144 return failure();
145
146 IRMapping secondToFirstPloopIndices;
147 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148 firstPloop.getBody()->getArguments());
149 return success(haveNoReadsAfterWriteExceptSameIndex(
150 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151}
152
153static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154 const IRMapping &firstToSecondPloopIndices,
155 llvm::function_ref<bool(Value, Value)> mayAlias) {
156 return !hasNestedParallelOp(firstPloop) &&
157 !hasNestedParallelOp(secondPloop) &&
158 equalIterationSpaces(firstPloop, secondPloop) &&
159 succeeded(verifyDependencies(firstPloop, secondPloop,
160 firstToSecondPloopIndices, mayAlias));
161}
162
163/// Prepends operations of firstPloop's body into secondPloop's body.
164/// Updates secondPloop with new loop.
165static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
166 OpBuilder builder,
167 llvm::function_ref<bool(Value, Value)> mayAlias) {
168 Block *block1 = firstPloop.getBody();
169 Block *block2 = secondPloop.getBody();
170 IRMapping firstToSecondPloopIndices;
171 firstToSecondPloopIndices.map(from: block1->getArguments(), to: block2->getArguments());
172
173 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
174 mayAlias))
175 return;
176
177 DominanceInfo dom;
178 // We are fusing first loop into second, make sure there are no users of the
179 // first loop results between loops.
180 for (Operation *user : firstPloop->getUsers())
181 if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182 return;
183
184 ValueRange inits1 = firstPloop.getInitVals();
185 ValueRange inits2 = secondPloop.getInitVals();
186
187 SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188 newInitVars.append(in_start: inits2.begin(), in_end: inits2.end());
189
190 IRRewriter b(builder);
191 b.setInsertionPoint(secondPloop);
192 auto newSecondPloop = b.create<ParallelOp>(
193 secondPloop.getLoc(), secondPloop.getLowerBound(),
194 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195
196 Block *newBlock = newSecondPloop.getBody();
197 auto term1 = cast<ReduceOp>(block1->getTerminator());
198 auto term2 = cast<ReduceOp>(block2->getTerminator());
199
200 b.inlineBlockBefore(source: block2, dest: newBlock, before: newBlock->begin(),
201 argValues: newBlock->getArguments());
202 b.inlineBlockBefore(source: block1, dest: newBlock, before: newBlock->begin(),
203 argValues: newBlock->getArguments());
204
205 ValueRange results = newSecondPloop.getResults();
206 if (!results.empty()) {
207 b.setInsertionPointToEnd(newBlock);
208
209 ValueRange reduceArgs1 = term1.getOperands();
210 ValueRange reduceArgs2 = term2.getOperands();
211 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212 newReduceArgs.append(in_start: reduceArgs2.begin(), in_end: reduceArgs2.end());
213
214 auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215
216 for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217 term1.getReductions(), term2.getReductions()))) {
218 Block &oldRedBlock = reg.front();
219 Block &newRedBlock = newReduceOp.getReductions()[i].front();
220 b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221 newRedBlock.getArguments());
222 }
223
224 firstPloop.replaceAllUsesWith(results.take_front(n: inits1.size()));
225 secondPloop.replaceAllUsesWith(results.take_back(n: inits2.size()));
226 }
227 term1->erase();
228 term2->erase();
229 firstPloop.erase();
230 secondPloop.erase();
231 secondPloop = newSecondPloop;
232}
233
234void mlir::scf::naivelyFuseParallelOps(
235 Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
236 OpBuilder b(region);
237 // Consider every single block and attempt to fuse adjacent loops.
238 SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
239 for (auto &block : region) {
240 ploopChains.clear();
241 ploopChains.push_back({});
242
243 // Not using `walk()` to traverse only top-level parallel loops and also
244 // make sure that there are no side-effecting ops between the parallel
245 // loops.
246 bool noSideEffects = true;
247 for (auto &op : block) {
248 if (auto ploop = dyn_cast<ParallelOp>(op)) {
249 if (noSideEffects) {
250 ploopChains.back().push_back(ploop);
251 } else {
252 ploopChains.push_back({ploop});
253 noSideEffects = true;
254 }
255 continue;
256 }
257 // TODO: Handle region side effects properly.
258 noSideEffects &= isMemoryEffectFree(op: &op) && op.getNumRegions() == 0;
259 }
260 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
261 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
262 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
263 }
264 }
265}
266
267namespace {
268struct ParallelLoopFusion
269 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
270 void runOnOperation() override {
271 auto &AA = getAnalysis<AliasAnalysis>();
272
273 auto mayAlias = [&](Value val1, Value val2) -> bool {
274 return !AA.alias(val1, val2).isNo();
275 };
276
277 getOperation()->walk([&](Operation *child) {
278 for (Region &region : child->getRegions())
279 naivelyFuseParallelOps(region, mayAlias);
280 });
281 }
282};
283} // namespace
284
285std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
286 return std::make_unique<ParallelLoopFusion>();
287}
288

// Source: mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp