1//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with hoisting invariant operations
10// in the context of Linalg transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19#include "mlir/Dialect/Affine/Utils.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SCF/Utils/Utils.h"
26#include "mlir/Dialect/Tensor/IR/Tensor.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/Dominance.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/Support/Debug.h"
36
37using llvm::dbgs;
38
39#define DEBUG_TYPE "linalg-hoisting"
40
41#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46/// Replace `loop` with a new loop that has a different init operand at
47/// position `index`. The body of this loop is moved over to the new loop.
48///
49/// `newInitOperands` specifies the replacement "init" operands.
50/// `newYieldValue` is the replacement yield value of the loop at position
51/// `index`.
52static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53 scf::ForOp loop,
54 Value newInitOperand,
55 unsigned index,
56 Value newYieldValue) {
57 OpBuilder::InsertionGuard g(rewriter);
58 rewriter.setInsertionPoint(loop.getOperation());
59 auto inits = llvm::to_vector(loop.getInits());
60
61 // Replace the init value with the new operand.
62 assert(index < inits.size());
63 inits[index] = newInitOperand;
64
65 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67 inits, [](OpBuilder &, Location, Value, ValueRange) {});
68
69 // Generate the new yield with the replaced operand.
70 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71 yieldOp.setOperand(index, newYieldValue);
72
73 // Move the loop body to the new op.
74 rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(),
75 argValues: newLoop.getBody()->getArguments());
76
77 // Replace the old loop.
78 rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79 return newLoop;
80}
81
82// Hoist out a pair of corresponding vector.extract+vector.broadcast
83// operations. This function transforms a loop like this:
84// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85// %e = vector.extract %iarg : t1 to t2
86// %u = "some_use"(%e) : (t2) -> t2
87// %b = vector.broadcast %u : t2 to t1
88// scf.yield %b : t1
89// }
90// into the following:
91// %e = vector.extract %v: t1 to t2
92// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93// %u' = "some_use"(%iarg) : (t2) -> t2
94// scf.yield %u' : t2
95// }
96// %res = vector.broadcast %res' : t2 to t1
97void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
98 Operation *root) {
99 bool changed = true;
100 while (changed) {
101 changed = false;
102 // First move loop invariant ops outside of their loop. This needs to be
103 // done before as we cannot move ops without interrupting the function walk.
104 root->walk(
105 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106
107 root->walk(callback: [&](vector::ExtractOp extractOp) {
108 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109 << *extractOp.getOperation() << "\n");
110
111 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112 if (!loop)
113 return WalkResult::advance();
114
115 // Check that the vector to extract from is a BlockArgument.
116 auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117 if (!blockArg)
118 return WalkResult::advance();
119
120 // Check that the blockArg is an iter_arg of the loop.
121 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122 if (!initArg)
123 return WalkResult::advance();
124
125 // If the iter_arg does not have only one use, it won't be possible to
126 // hoist the extractOp out.
127 if (!blockArg.hasOneUse())
128 return WalkResult::advance();
129
130 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131
132 // Check that the loop yields a broadcast that has just one use.
133 Operation *yieldedVal =
134 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136 if (!broadcast || !broadcast.getResult().hasOneUse())
137 return WalkResult::advance();
138
139 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140
141 Type broadcastInputType = broadcast.getSourceType();
142 if (broadcastInputType != extractOp.getType())
143 return WalkResult::advance();
144
145 // The position of the extract must be defined outside of the loop if
146 // it is dynamic.
147 for (auto operand : extractOp.getDynamicPosition())
148 if (!loop.isDefinedOutsideOfLoop(operand))
149 return WalkResult::advance();
150
151 rewriter.modifyOpInPlace(broadcast, [&] {
152 extractOp.getVectorMutable().assign(initArg->get());
153 });
154 loop.moveOutOfLoop(extractOp);
155 rewriter.moveOpAfter(broadcast, loop);
156
157 scf::ForOp newLoop = replaceWithDifferentYield(
158 rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159
160 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161
162 rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163 rewriter.modifyOpInPlace(
164 broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165
166 changed = true;
167 return WalkResult::interrupt();
168 });
169 }
170}
171
172static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
173 LoopLikeOpInterface loop) {
174 Value source = transferRead.getBase();
175
176 // Skip view-like Ops and retrive the actual soruce Operation
177 while (auto srcOp =
178 dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
179 source = srcOp.getViewSource();
180
181 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
182 source.getUsers().end());
183 llvm::SmallDenseSet<Operation *, 32> processed;
184 while (!users.empty()) {
185 Operation *user = users.pop_back_val();
186 // If the user has already been processed skip.
187 if (!processed.insert(V: user).second)
188 continue;
189 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191 continue;
192 }
193 if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user))
194 continue;
195 if (!loop->isAncestor(user))
196 continue;
197 return false;
198 }
199 return true;
200}
201
202void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
203 bool verifyNonZeroTrip) {
204 bool changed = true;
205 while (changed) {
206 changed = false;
207 // First move loop invariant ops outside of their loop. This needs to be
208 // done before as we cannot move ops without interrupting the function walk.
209 root->walk(
210 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
211
212 // Find all loops that are certain to have non zero trip count. Any loops
213 // that are not part of this set cannot be hoisted from, since hoisting from
214 // a potentially zero trip count loop may cause a vector transfer to be
215 // executed when it shouldn't be.
216 llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
217 if (verifyNonZeroTrip) {
218 root->walk([&](LoopLikeOpInterface loopLike) {
219 std::optional<SmallVector<OpFoldResult>> lbs =
220 loopLike.getLoopLowerBounds();
221 std::optional<SmallVector<OpFoldResult>> ubs =
222 loopLike.getLoopUpperBounds();
223 // If loop bounds cannot be found, assume possibly zero trip count.
224 if (!lbs || !ubs)
225 return;
226
227 // Otherwise, use ValueBounds to find the maximum lower bound and
228 // minimum upper bound. If the bounds are found, and maxLb is less
229 // than the minUb, then the loop will not have zero trip count.
230 for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231 FailureOr<int64_t> maxLb =
232 ValueBoundsConstraintSet::computeConstantBound(
233 presburger::BoundType::UB, lb,
234 /*stopCondition=*/nullptr, /*closedUB=*/true);
235 if (failed(maxLb))
236 return;
237 FailureOr<int64_t> minUb =
238 ValueBoundsConstraintSet::computeConstantBound(
239 presburger::BoundType::LB, ub);
240 if (failed(minUb))
241 return;
242 if (minUb.value() <= maxLb.value())
243 return;
244 definiteNonZeroTripCountLoops.insert(loopLike);
245 }
246 });
247 }
248
249 root->walk([&](vector::TransferReadOp transferRead) {
250 if (!isa<MemRefType>(transferRead.getShapedType()))
251 return WalkResult::advance();
252
253 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
254 << *transferRead.getOperation() << "\n");
255 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
256 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
257 << "\n");
258 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
259 return WalkResult::advance();
260
261 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(V: loop)) {
262 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
263 << "\n");
264 return WalkResult::advance();
265 }
266
267 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
268 << "\n");
269
270 SetVector<Operation *> forwardSlice;
271 getForwardSlice(transferRead.getOperation(), &forwardSlice);
272
273 // Look for the last TransferWriteOp in the forwardSlice of
274 // `transferRead` that operates on the same memref.
275 vector::TransferWriteOp transferWrite;
276 for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) {
277 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
278 if (!candidateWrite ||
279 candidateWrite.getBase() != transferRead.getBase())
280 continue;
281 transferWrite = candidateWrite;
282 }
283
284 // All operands of the TransferRead must be defined outside of the loop.
285 for (auto operand : transferRead.getOperands())
286 if (!loop.isDefinedOutsideOfLoop(operand))
287 return WalkResult::advance();
288
289 // Only hoist transfer_read / transfer_write pairs and singleton
290 // transfer_reads for now.
291 if (!transferWrite) {
292 // Make sure there are no other accesses to the memref before
293 // hoisting transfer_read.
294 if (noAliasingUseInLoop(transferRead, loop))
295 loop.moveOutOfLoop(transferRead);
296 return WalkResult::advance();
297 }
298
299 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
300 << "\n");
301
302 // Approximate aliasing by checking that:
303 // 1. indices, vector type and permutation map are the same (i.e., the
304 // transfer_read/transfer_write ops are matching),
305 // 2. source operands for transfer.{read|write} do not originate from
306 // Ops implementing ViewLikeOpInterface.
307 // 3. no other operations in the loop access the same memref except
308 // for transfer_read/transfer_write accessing statically disjoint
309 // slices.
310 if (transferRead.getIndices() != transferWrite.getIndices() ||
311 transferRead.getVectorType() != transferWrite.getVectorType() ||
312 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313 return WalkResult::advance();
314
315 auto *source = transferRead.getBase().getDefiningOp();
316 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317 return WalkResult::advance();
318
319 source = transferWrite.getBase().getDefiningOp();
320 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
321 return WalkResult::advance();
322
323 // TODO: may want to memoize this information for performance but it
324 // likely gets invalidated often.
325 DominanceInfo dom(loop);
326 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
327 return WalkResult::advance();
328 for (auto &use : transferRead.getBase().getUses()) {
329 if (!loop->isAncestor(use.getOwner()))
330 continue;
331 if (use.getOwner() == transferRead.getOperation() ||
332 use.getOwner() == transferWrite.getOperation())
333 continue;
334 if (auto transferWriteUse =
335 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
336 if (!vector::isDisjointTransferSet(
337 cast<VectorTransferOpInterface>(*transferWrite),
338 cast<VectorTransferOpInterface>(*transferWriteUse),
339 /*testDynamicValueUsingBounds=*/true))
340 return WalkResult::advance();
341 } else if (auto transferReadUse =
342 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
343 if (!vector::isDisjointTransferSet(
344 cast<VectorTransferOpInterface>(*transferWrite),
345 cast<VectorTransferOpInterface>(*transferReadUse),
346 /*testDynamicValueUsingBounds=*/true))
347 return WalkResult::advance();
348 } else {
349 // Unknown use, we cannot prove that it doesn't alias with the
350 // transferRead/transferWrite operations.
351 return WalkResult::advance();
352 }
353 }
354
355 // Hoist read before.
356 loop.moveOutOfLoop(transferRead);
357
358 // Hoist write after.
359 transferWrite->moveAfter(loop);
360
361 // Rewrite `loop` with new yields by cloning and erase the original loop.
362 IRRewriter rewriter(transferRead.getContext());
363 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
364 ArrayRef<BlockArgument> newBBArgs) {
365 return SmallVector<Value>{transferWrite.getVector()};
366 };
367
368 auto maybeNewLoop = loop.replaceWithAdditionalYields(
369 rewriter, transferRead.getVector(),
370 /*replaceInitOperandUsesInLoop=*/true, yieldFn);
371 if (failed(maybeNewLoop))
372 return WalkResult::interrupt();
373
374 transferWrite.getValueToStoreMutable().assign(
375 maybeNewLoop->getOperation()->getResults().back());
376 changed = true;
377 // Need to interrupt and restart because erasing the loop messes up
378 // the walk.
379 return WalkResult::interrupt();
380 });
381 }
382}
383

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp