1//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop fusion on parallel loops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Analysis/AliasAnalysis.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
27#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::scf;
32
33/// Verify there are no nested ParallelOps.
34static bool hasNestedParallelOp(ParallelOp ploop) {
35 auto walkResult =
36 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37 return walkResult.wasInterrupted();
38}
39
40/// Verify equal iteration spaces.
41static bool equalIterationSpaces(ParallelOp firstPloop,
42 ParallelOp secondPloop) {
43 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44 return false;
45
46 auto matchOperands = [&](const OperandRange &lhs,
47 const OperandRange &rhs) -> bool {
48 // TODO: Extend this to support aliases and equal constants.
49 return std::equal(first1: lhs.begin(), last1: lhs.end(), first2: rhs.begin());
50 };
51 return matchOperands(firstPloop.getLowerBound(),
52 secondPloop.getLowerBound()) &&
53 matchOperands(firstPloop.getUpperBound(),
54 secondPloop.getUpperBound()) &&
55 matchOperands(firstPloop.getStep(), secondPloop.getStep());
56}
57
58/// Checks if the parallel loops have mixed access to the same buffers. Returns
59/// `true` if the first parallel loop writes to the same indices that the second
60/// loop reads.
61static bool haveNoReadsAfterWriteExceptSameIndex(
62 ParallelOp firstPloop, ParallelOp secondPloop,
63 const IRMapping &firstToSecondPloopIndices,
64 llvm::function_ref<bool(Value, Value)> mayAlias) {
65 DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66 SmallVector<Value> bufferStoresVec;
67 firstPloop.getBody()->walk([&](memref::StoreOp store) {
68 bufferStores[store.getMemRef()].push_back(store.getIndices());
69 bufferStoresVec.emplace_back(store.getMemRef());
70 });
71 auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72 Value loadMem = load.getMemRef();
73 // Stop if the memref is defined in secondPloop body. Careful alias analysis
74 // is needed.
75 auto *memrefDef = loadMem.getDefiningOp();
76 if (memrefDef && memrefDef->getBlock() == load->getBlock())
77 return WalkResult::interrupt();
78
79 for (Value store : bufferStoresVec)
80 if (store != loadMem && mayAlias(store, loadMem))
81 return WalkResult::interrupt();
82
83 auto write = bufferStores.find(Val: loadMem);
84 if (write == bufferStores.end())
85 return WalkResult::advance();
86
87 // Check that at last one store was retrieved
88 if (write->second.empty())
89 return WalkResult::interrupt();
90
91 auto storeIndices = write->second.front();
92
93 // Multiple writes to the same memref are allowed only on the same indices
94 for (const auto &othStoreIndices : write->second) {
95 if (othStoreIndices != storeIndices)
96 return WalkResult::interrupt();
97 }
98
99 // Check that the load indices of secondPloop coincide with store indices of
100 // firstPloop for the same memrefs.
101 auto loadIndices = load.getIndices();
102 if (storeIndices.size() != loadIndices.size())
103 return WalkResult::interrupt();
104 for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105 if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106 loadIndices[i]) {
107 auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108 auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109 if (storeIndexDefOp && loadIndexDefOp) {
110 if (!isMemoryEffectFree(storeIndexDefOp))
111 return WalkResult::interrupt();
112 if (!isMemoryEffectFree(loadIndexDefOp))
113 return WalkResult::interrupt();
114 if (!OperationEquivalence::isEquivalentTo(
115 storeIndexDefOp, loadIndexDefOp,
116 [&](Value storeIndex, Value loadIndex) {
117 if (firstToSecondPloopIndices.lookupOrDefault(from: storeIndex) !=
118 firstToSecondPloopIndices.lookupOrDefault(from: loadIndex))
119 return failure();
120 else
121 return success();
122 },
123 /*markEquivalent=*/nullptr,
124 OperationEquivalence::Flags::IgnoreLocations)) {
125 return WalkResult::interrupt();
126 }
127 } else {
128 return WalkResult::interrupt();
129 }
130 }
131 }
132 return WalkResult::advance();
133 });
134 return !walkResult.wasInterrupted();
135}
136
137/// Analyzes dependencies in the most primitive way by checking simple read and
138/// write patterns.
139static LogicalResult
140verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
141 const IRMapping &firstToSecondPloopIndices,
142 llvm::function_ref<bool(Value, Value)> mayAlias) {
143 if (!haveNoReadsAfterWriteExceptSameIndex(
144 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
145 return failure();
146
147 IRMapping secondToFirstPloopIndices;
148 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
149 firstPloop.getBody()->getArguments());
150 return success(haveNoReadsAfterWriteExceptSameIndex(
151 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
152}
153
154static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
155 const IRMapping &firstToSecondPloopIndices,
156 llvm::function_ref<bool(Value, Value)> mayAlias) {
157 return !hasNestedParallelOp(firstPloop) &&
158 !hasNestedParallelOp(secondPloop) &&
159 equalIterationSpaces(firstPloop, secondPloop) &&
160 succeeded(verifyDependencies(firstPloop, secondPloop,
161 firstToSecondPloopIndices, mayAlias));
162}
163
164/// Prepends operations of firstPloop's body into secondPloop's body.
165/// Updates secondPloop with new loop.
166static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
167 OpBuilder builder,
168 llvm::function_ref<bool(Value, Value)> mayAlias) {
169 Block *block1 = firstPloop.getBody();
170 Block *block2 = secondPloop.getBody();
171 IRMapping firstToSecondPloopIndices;
172 firstToSecondPloopIndices.map(from: block1->getArguments(), to: block2->getArguments());
173
174 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
175 mayAlias))
176 return;
177
178 DominanceInfo dom;
179 // We are fusing first loop into second, make sure there are no users of the
180 // first loop results between loops.
181 for (Operation *user : firstPloop->getUsers())
182 if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
183 return;
184
185 ValueRange inits1 = firstPloop.getInitVals();
186 ValueRange inits2 = secondPloop.getInitVals();
187
188 SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
189 newInitVars.append(in_start: inits2.begin(), in_end: inits2.end());
190
191 IRRewriter b(builder);
192 b.setInsertionPoint(secondPloop);
193 auto newSecondPloop = b.create<ParallelOp>(
194 secondPloop.getLoc(), secondPloop.getLowerBound(),
195 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
196
197 Block *newBlock = newSecondPloop.getBody();
198 auto term1 = cast<ReduceOp>(block1->getTerminator());
199 auto term2 = cast<ReduceOp>(block2->getTerminator());
200
201 b.inlineBlockBefore(source: block2, dest: newBlock, before: newBlock->begin(),
202 argValues: newBlock->getArguments());
203 b.inlineBlockBefore(source: block1, dest: newBlock, before: newBlock->begin(),
204 argValues: newBlock->getArguments());
205
206 ValueRange results = newSecondPloop.getResults();
207 if (!results.empty()) {
208 b.setInsertionPointToEnd(newBlock);
209
210 ValueRange reduceArgs1 = term1.getOperands();
211 ValueRange reduceArgs2 = term2.getOperands();
212 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
213 newReduceArgs.append(in_start: reduceArgs2.begin(), in_end: reduceArgs2.end());
214
215 auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
216
217 for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
218 term1.getReductions(), term2.getReductions()))) {
219 Block &oldRedBlock = reg.front();
220 Block &newRedBlock = newReduceOp.getReductions()[i].front();
221 b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
222 newRedBlock.getArguments());
223 }
224
225 firstPloop.replaceAllUsesWith(results.take_front(n: inits1.size()));
226 secondPloop.replaceAllUsesWith(results.take_back(n: inits2.size()));
227 }
228 term1->erase();
229 term2->erase();
230 firstPloop.erase();
231 secondPloop.erase();
232 secondPloop = newSecondPloop;
233}
234
235void mlir::scf::naivelyFuseParallelOps(
236 Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
237 OpBuilder b(region);
238 // Consider every single block and attempt to fuse adjacent loops.
239 SmallVector<SmallVector<ParallelOp>, 1> ploopChains;
240 for (auto &block : region) {
241 ploopChains.clear();
242 ploopChains.push_back({});
243
244 // Not using `walk()` to traverse only top-level parallel loops and also
245 // make sure that there are no side-effecting ops between the parallel
246 // loops.
247 bool noSideEffects = true;
248 for (auto &op : block) {
249 if (auto ploop = dyn_cast<ParallelOp>(op)) {
250 if (noSideEffects) {
251 ploopChains.back().push_back(ploop);
252 } else {
253 ploopChains.push_back({ploop});
254 noSideEffects = true;
255 }
256 continue;
257 }
258 // TODO: Handle region side effects properly.
259 noSideEffects &= isMemoryEffectFree(op: &op) && op.getNumRegions() == 0;
260 }
261 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
262 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
263 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
264 }
265 }
266}
267
268namespace {
269struct ParallelLoopFusion
270 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
271 void runOnOperation() override {
272 auto &AA = getAnalysis<AliasAnalysis>();
273
274 auto mayAlias = [&](Value val1, Value val2) -> bool {
275 return !AA.alias(val1, val2).isNo();
276 };
277
278 getOperation()->walk([&](Operation *child) {
279 for (Region &region : child->getRegions())
280 naivelyFuseParallelOps(region, mayAlias);
281 });
282 }
283};
284} // namespace
285
286std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
287 return std::make_unique<ParallelLoopFusion>();
288}
289

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp