//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion on parallel loops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include <algorithm>

25namespace mlir {
26#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::scf;
32
33/// Verify there are no nested ParallelOps.
34static bool hasNestedParallelOp(ParallelOp ploop) {
35 auto walkResult =
36 ploop.getBody()->walk(callback: [](ParallelOp) { return WalkResult::interrupt(); });
37 return walkResult.wasInterrupted();
38}
39
40/// Verify equal iteration spaces.
41static bool equalIterationSpaces(ParallelOp firstPloop,
42 ParallelOp secondPloop) {
43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44 return false;
45
46 auto matchOperands = [&](const OperandRange &lhs,
47 const OperandRange &rhs) -> bool {
48 // TODO: Extend this to support aliases and equal constants.
49 return std::equal(first1: lhs.begin(), last1: lhs.end(), first2: rhs.begin());
50 };
51 return matchOperands(firstPloop.getLowerBound(),
52 secondPloop.getLowerBound()) &&
53 matchOperands(firstPloop.getUpperBound(),
54 secondPloop.getUpperBound()) &&
55 matchOperands(firstPloop.getStep(), secondPloop.getStep());
56}
57
58/// Checks if the parallel loops have mixed access to the same buffers. Returns
59/// `true` if the first parallel loop writes to the same indices that the second
60/// loop reads.
61static bool haveNoReadsAfterWriteExceptSameIndex(
62 ParallelOp firstPloop, ParallelOp secondPloop,
63 const IRMapping &firstToSecondPloopIndices,
64 llvm::function_ref<bool(Value, Value)> mayAlias) {
65 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66 SmallVector<Value> bufferStoresVec;
67 firstPloop.getBody()->walk(callback: [&](memref::StoreOp store) {
68 bufferStores[store.getMemRef()].push_back(Elt: store.getIndices());
69 bufferStoresVec.emplace_back(Args: store.getMemRef());
70 });
71 auto walkResult = secondPloop.getBody()->walk(callback: [&](memref::LoadOp load) {
72 Value loadMem = load.getMemRef();
73 // Stop if the memref is defined in secondPloop body. Careful alias analysis
74 // is needed.
75 auto *memrefDef = loadMem.getDefiningOp();
76 if (memrefDef && memrefDef->getBlock() == load->getBlock())
77 return WalkResult::interrupt();
78
79 for (Value store : bufferStoresVec)
80 if (store != loadMem && mayAlias(store, loadMem))
81 return WalkResult::interrupt();
82
83 auto write = bufferStores.find(Val: loadMem);
84 if (write == bufferStores.end())
85 return WalkResult::advance();
86
87 // Check that at last one store was retrieved
88 if (write->second.empty())
89 return WalkResult::interrupt();
90
91 auto storeIndices = write->second.front();
92
93 // Multiple writes to the same memref are allowed only on the same indices
94 for (const auto &othStoreIndices : write->second) {
95 if (othStoreIndices != storeIndices)
96 return WalkResult::interrupt();
97 }
98
99 // Check that the load indices of secondPloop coincide with store indices of
100 // firstPloop for the same memrefs.
101 auto loadIndices = load.getIndices();
102 if (storeIndices.size() != loadIndices.size())
103 return WalkResult::interrupt();
104 for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105 if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndices[i]) !=
106 loadIndices[i]) {
107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109 if (storeIndexDefOp && loadIndexDefOp) {
110 if (!isMemoryEffectFree(op: storeIndexDefOp))
111 return WalkResult::interrupt();
112 if (!isMemoryEffectFree(op: loadIndexDefOp))
113 return WalkResult::interrupt();
114 if (!OperationEquivalence::isEquivalentTo(
115 lhs: storeIndexDefOp, rhs: loadIndexDefOp,
116 checkEquivalent: [&](Value storeIndex, Value loadIndex) {
117 if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndex) !=
118 firstToSecondPloopIndices.lookupOrDefault(from: loadIndex))
119 return failure();
120 else
121 return success();
122 },
123 /*markEquivalent=*/nullptr,
124 flags: OperationEquivalence::Flags::IgnoreLocations)) {
125 return WalkResult::interrupt();
126 }
127 } else {
128 return WalkResult::interrupt();
129 }
130 }
131 }
132 return WalkResult::advance();
133 });
134 return !walkResult.wasInterrupted();
135}
136
137/// Analyzes dependencies in the most primitive way by checking simple read and
138/// write patterns.
139static LogicalResult
140verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
141 const IRMapping &firstToSecondPloopIndices,
142 llvm::function_ref<bool(Value, Value)> mayAlias) {
143 if (!haveNoReadsAfterWriteExceptSameIndex(
144 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
145 return failure();
146
147 IRMapping secondToFirstPloopIndices;
148 secondToFirstPloopIndices.map(from: secondPloop.getBody()->getArguments(),
149 to: firstPloop.getBody()->getArguments());
150 return success(IsSuccess: haveNoReadsAfterWriteExceptSameIndex(
151 firstPloop: secondPloop, secondPloop: firstPloop, firstToSecondPloopIndices: secondToFirstPloopIndices, mayAlias));
152}
153
154static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
155 const IRMapping &firstToSecondPloopIndices,
156 llvm::function_ref<bool(Value, Value)> mayAlias) {
157 return !hasNestedParallelOp(ploop: firstPloop) &&
158 !hasNestedParallelOp(ploop: secondPloop) &&
159 equalIterationSpaces(firstPloop, secondPloop) &&
160 succeeded(Result: verifyDependencies(firstPloop, secondPloop,
161 firstToSecondPloopIndices, mayAlias));
162}
163
164/// Prepends operations of firstPloop's body into secondPloop's body.
165/// Updates secondPloop with new loop.
166static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
167 OpBuilder builder,
168 llvm::function_ref<bool(Value, Value)> mayAlias) {
169 Block *block1 = firstPloop.getBody();
170 Block *block2 = secondPloop.getBody();
171 IRMapping firstToSecondPloopIndices;
172 firstToSecondPloopIndices.map(from: block1->getArguments(), to: block2->getArguments());
173
174 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
175 mayAlias))
176 return;
177
178 DominanceInfo dom;
179 // We are fusing first loop into second, make sure there are no users of the
180 // first loop results between loops.
181 for (Operation *user : firstPloop->getUsers())
182 if (!dom.properlyDominates(a: secondPloop, b: user, /*enclosingOpOk*/ false))
183 return;
184
185 ValueRange inits1 = firstPloop.getInitVals();
186 ValueRange inits2 = secondPloop.getInitVals();
187
188 SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
189 newInitVars.append(in_start: inits2.begin(), in_end: inits2.end());
190
191 IRRewriter b(builder);
192 b.setInsertionPoint(secondPloop);
193 auto newSecondPloop = b.create<ParallelOp>(
194 location: secondPloop.getLoc(), args: secondPloop.getLowerBound(),
195 args: secondPloop.getUpperBound(), args: secondPloop.getStep(), args&: newInitVars);
196
197 Block *newBlock = newSecondPloop.getBody();
198 auto term1 = cast<ReduceOp>(Val: block1->getTerminator());
199 auto term2 = cast<ReduceOp>(Val: block2->getTerminator());
200
201 b.inlineBlockBefore(source: block2, dest: newBlock, before: newBlock->begin(),
202 argValues: newBlock->getArguments());
203 b.inlineBlockBefore(source: block1, dest: newBlock, before: newBlock->begin(),
204 argValues: newBlock->getArguments());
205
206 ValueRange results = newSecondPloop.getResults();
207 if (!results.empty()) {
208 b.setInsertionPointToEnd(newBlock);
209
210 ValueRange reduceArgs1 = term1.getOperands();
211 ValueRange reduceArgs2 = term2.getOperands();
212 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
213 newReduceArgs.append(in_start: reduceArgs2.begin(), in_end: reduceArgs2.end());
214
215 auto newReduceOp = b.create<scf::ReduceOp>(location: term2.getLoc(), args&: newReduceArgs);
216
217 for (auto &&[i, reg] : llvm::enumerate(First: llvm::concat<Region>(
218 Ranges: term1.getReductions(), Ranges: term2.getReductions()))) {
219 Block &oldRedBlock = reg.front();
220 Block &newRedBlock = newReduceOp.getReductions()[i].front();
221 b.inlineBlockBefore(source: &oldRedBlock, dest: &newRedBlock, before: newRedBlock.begin(),
222 argValues: newRedBlock.getArguments());
223 }
224
225 firstPloop.replaceAllUsesWith(values: results.take_front(n: inits1.size()));
226 secondPloop.replaceAllUsesWith(values: results.take_back(n: inits2.size()));
227 }
228 term1->erase();
229 term2->erase();
230 firstPloop.erase();
231 secondPloop.erase();
232 secondPloop = newSecondPloop;
233}
234
235void mlir::scf::naivelyFuseParallelOps(
236 Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
237 OpBuilder b(region);
238 // Consider every single block and attempt to fuse adjacent loops.
239 SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
240 for (auto &block : region) {
241 ploopChains.clear();
242 ploopChains.push_back(Elt: {});
243
244 // Not using `walk()` to traverse only top-level parallel loops and also
245 // make sure that there are no side-effecting ops between the parallel
246 // loops.
247 bool noSideEffects = true;
248 for (auto &op : block) {
249 if (auto ploop = dyn_cast<ParallelOp>(Val&: op)) {
250 if (noSideEffects) {
251 ploopChains.back().push_back(Elt: ploop);
252 } else {
253 ploopChains.push_back(Elt: {ploop});
254 noSideEffects = true;
255 }
256 continue;
257 }
258 // TODO: Handle region side effects properly.
259 noSideEffects &= isMemoryEffectFree(op: &op) && op.getNumRegions() == 0;
260 }
261 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
262 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
263 fuseIfLegal(firstPloop: ploops[i], secondPloop&: ploops[i + 1], builder: b, mayAlias);
264 }
265 }
266}
267
268namespace {
269struct ParallelLoopFusion
270 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
271 void runOnOperation() override {
272 auto &AA = getAnalysis<AliasAnalysis>();
273
274 auto mayAlias = [&](Value val1, Value val2) -> bool {
275 return !AA.alias(lhs: val1, rhs: val2).isNo();
276 };
277
278 getOperation()->walk(callback: [&](Operation *child) {
279 for (Region &region : child->getRegions())
280 naivelyFuseParallelOps(region, mayAlias);
281 });
282 }
283};
284} // namespace
285
286std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
287 return std::make_unique<ParallelLoopFusion>();
288}
289

// Source: mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp