1//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with hoisting invariant operations
10// in the context of Linalg transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Affine/Utils.h"
19#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/SCF/Utils/Utils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
26#include "llvm/Support/Debug.h"
27
28using llvm::dbgs;
29
30#define DEBUG_TYPE "linalg-hoisting"
31
32#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
33
34using namespace mlir;
35using namespace mlir::linalg;
36
37/// Replace `loop` with a new loop that has a different init operand at
38/// position `index`. The body of this loop is moved over to the new loop.
39///
40/// `newInitOperands` specifies the replacement "init" operands.
41/// `newYieldValue` is the replacement yield value of the loop at position
42/// `index`.
43static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
44 scf::ForOp loop,
45 Value newInitOperand,
46 unsigned index,
47 Value newYieldValue) {
48 OpBuilder::InsertionGuard g(rewriter);
49 rewriter.setInsertionPoint(loop.getOperation());
50 auto inits = llvm::to_vector(Range: loop.getInits());
51
52 // Replace the init value with the new operand.
53 assert(index < inits.size());
54 inits[index] = newInitOperand;
55
56 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
57 location: loop.getLoc(), args: loop.getLowerBound(), args: loop.getUpperBound(), args: loop.getStep(),
58 args&: inits, args: [](OpBuilder &, Location, Value, ValueRange) {});
59
60 // Generate the new yield with the replaced operand.
61 auto yieldOp = cast<scf::YieldOp>(Val: loop.getBody()->getTerminator());
62 yieldOp.setOperand(i: index, value: newYieldValue);
63
64 // Move the loop body to the new op.
65 rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(),
66 argValues: newLoop.getBody()->getArguments());
67
68 // Replace the old loop.
69 rewriter.replaceOp(op: loop.getOperation(), newValues: newLoop->getResults());
70 return newLoop;
71}
72
73// Hoist out a pair of corresponding vector.extract+vector.broadcast
74// operations. This function transforms a loop like this:
75// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
76// %e = vector.extract %iarg : t1 to t2
77// %u = "some_use"(%e) : (t2) -> t2
78// %b = vector.broadcast %u : t2 to t1
79// scf.yield %b : t1
80// }
81// into the following:
82// %e = vector.extract %v: t1 to t2
83// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
84// %u' = "some_use"(%iarg) : (t2) -> t2
85// scf.yield %u' : t2
86// }
87// %res = vector.broadcast %res' : t2 to t1
88void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
89 Operation *root) {
90 bool changed = true;
91 while (changed) {
92 changed = false;
93 // First move loop invariant ops outside of their loop. This needs to be
94 // done before as we cannot move ops without interrupting the function walk.
95 root->walk(
96 callback: [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
97
98 root->walk(callback: [&](vector::ExtractOp extractOp) {
99 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
100 << *extractOp.getOperation() << "\n");
101
102 auto loop = dyn_cast<scf::ForOp>(Val: extractOp->getParentOp());
103 if (!loop)
104 return WalkResult::advance();
105
106 // Check that the vector to extract from is a BlockArgument.
107 auto blockArg = dyn_cast<BlockArgument>(Val: extractOp.getVector());
108 if (!blockArg)
109 return WalkResult::advance();
110
111 // Check that the blockArg is an iter_arg of the loop.
112 OpOperand *initArg = loop.getTiedLoopInit(bbArg: blockArg);
113 if (!initArg)
114 return WalkResult::advance();
115
116 // If the iter_arg does not have only one use, it won't be possible to
117 // hoist the extractOp out.
118 if (!blockArg.hasOneUse())
119 return WalkResult::advance();
120
121 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
122
123 // Check that the loop yields a broadcast that has just one use.
124 Operation *yieldedVal =
125 loop.getTiedLoopYieldedValue(bbArg: blockArg)->get().getDefiningOp();
126 auto broadcast = dyn_cast<vector::BroadcastOp>(Val: yieldedVal);
127 if (!broadcast || !broadcast.getResult().hasOneUse())
128 return WalkResult::advance();
129
130 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
131
132 Type broadcastInputType = broadcast.getSourceType();
133 if (broadcastInputType != extractOp.getType())
134 return WalkResult::advance();
135
136 // The position of the extract must be defined outside of the loop if
137 // it is dynamic.
138 for (auto operand : extractOp.getDynamicPosition())
139 if (!loop.isDefinedOutsideOfLoop(value: operand))
140 return WalkResult::advance();
141
142 rewriter.modifyOpInPlace(root: broadcast, callable: [&] {
143 extractOp.getVectorMutable().assign(value: initArg->get());
144 });
145 loop.moveOutOfLoop(op: extractOp);
146 rewriter.moveOpAfter(op: broadcast, existingOp: loop);
147
148 scf::ForOp newLoop = replaceWithDifferentYield(
149 rewriter, loop, newInitOperand: extractOp.getResult(), index, newYieldValue: broadcast.getSource());
150
151 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
152
153 rewriter.replaceAllUsesWith(from: newLoop.getResult(i: index), to: broadcast);
154 rewriter.modifyOpInPlace(
155 root: broadcast, callable: [&] { broadcast.setOperand(newLoop.getResult(i: index)); });
156
157 changed = true;
158 return WalkResult::interrupt();
159 });
160 }
161}
162
163static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
164 LoopLikeOpInterface loop) {
165 Value source = transferRead.getBase();
166
167 // Skip view-like Ops and retrive the actual soruce Operation
168 while (auto srcOp =
169 dyn_cast_or_null<ViewLikeOpInterface>(Val: source.getDefiningOp()))
170 source = srcOp.getViewSource();
171
172 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
173 source.getUsers().end());
174 llvm::SmallDenseSet<Operation *, 32> processed;
175 while (!users.empty()) {
176 Operation *user = users.pop_back_val();
177 // If the user has already been processed skip.
178 if (!processed.insert(V: user).second)
179 continue;
180 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(Val: user)) {
181 users.append(in_start: viewLike->getUsers().begin(), in_end: viewLike->getUsers().end());
182 continue;
183 }
184 if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user))
185 continue;
186 if (!loop->isAncestor(other: user))
187 continue;
188 return false;
189 }
190 return true;
191}
192
193void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
194 bool verifyNonZeroTrip) {
195 bool changed = true;
196 while (changed) {
197 changed = false;
198 // First move loop invariant ops outside of their loop. This needs to be
199 // done before as we cannot move ops without interrupting the function walk.
200 root->walk(
201 callback: [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
202
203 // Find all loops that are certain to have non zero trip count. Any loops
204 // that are not part of this set cannot be hoisted from, since hoisting from
205 // a potentially zero trip count loop may cause a vector transfer to be
206 // executed when it shouldn't be.
207 llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
208 if (verifyNonZeroTrip) {
209 root->walk(callback: [&](LoopLikeOpInterface loopLike) {
210 std::optional<SmallVector<OpFoldResult>> lbs =
211 loopLike.getLoopLowerBounds();
212 std::optional<SmallVector<OpFoldResult>> ubs =
213 loopLike.getLoopUpperBounds();
214 // If loop bounds cannot be found, assume possibly zero trip count.
215 if (!lbs || !ubs)
216 return;
217
218 // Otherwise, use ValueBounds to find the maximum lower bound and
219 // minimum upper bound. If the bounds are found, and maxLb is less
220 // than the minUb, then the loop will not have zero trip count.
221 for (auto [lb, ub] : llvm::zip_equal(t&: lbs.value(), u&: ubs.value())) {
222 FailureOr<int64_t> maxLb =
223 ValueBoundsConstraintSet::computeConstantBound(
224 type: presburger::BoundType::UB, var: lb,
225 /*stopCondition=*/nullptr, /*closedUB=*/true);
226 if (failed(Result: maxLb))
227 return;
228 FailureOr<int64_t> minUb =
229 ValueBoundsConstraintSet::computeConstantBound(
230 type: presburger::BoundType::LB, var: ub);
231 if (failed(Result: minUb))
232 return;
233 if (minUb.value() <= maxLb.value())
234 return;
235 definiteNonZeroTripCountLoops.insert(V: loopLike);
236 }
237 });
238 }
239
240 root->walk(callback: [&](vector::TransferReadOp transferRead) {
241 if (!isa<MemRefType>(Val: transferRead.getShapedType()))
242 return WalkResult::advance();
243
244 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
245 << *transferRead.getOperation() << "\n");
246 auto loop = dyn_cast<LoopLikeOpInterface>(Val: transferRead->getParentOp());
247 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
248 << "\n");
249 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(Val: loop))
250 return WalkResult::advance();
251
252 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(V: loop)) {
253 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
254 << "\n");
255 return WalkResult::advance();
256 }
257
258 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
259 << "\n");
260
261 SetVector<Operation *> forwardSlice;
262 getForwardSlice(op: transferRead.getOperation(), forwardSlice: &forwardSlice);
263
264 // Look for the last TransferWriteOp in the forwardSlice of
265 // `transferRead` that operates on the same memref.
266 vector::TransferWriteOp transferWrite;
267 for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) {
268 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(Val: sliceOp);
269 if (!candidateWrite ||
270 candidateWrite.getBase() != transferRead.getBase())
271 continue;
272 transferWrite = candidateWrite;
273 }
274
275 // All operands of the TransferRead must be defined outside of the loop.
276 for (auto operand : transferRead.getOperands())
277 if (!loop.isDefinedOutsideOfLoop(value: operand))
278 return WalkResult::advance();
279
280 // Only hoist transfer_read / transfer_write pairs and singleton
281 // transfer_reads for now.
282 if (!transferWrite) {
283 // Make sure there are no other accesses to the memref before
284 // hoisting transfer_read.
285 if (noAliasingUseInLoop(transferRead, loop))
286 loop.moveOutOfLoop(op: transferRead);
287 return WalkResult::advance();
288 }
289
290 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
291 << "\n");
292
293 // Approximate aliasing by checking that:
294 // 1. indices, vector type and permutation map are the same (i.e., the
295 // transfer_read/transfer_write ops are matching),
296 // 2. source operands for transfer.{read|write} do not originate from
297 // nor have users that are Ops implementing ViewLikeOpInterface.
298 // 3. no other operations in the loop access the same memref except
299 // for transfer_read/transfer_write accessing statically disjoint
300 // slices.
301
302 // Check 1.
303 if (transferRead.getIndices() != transferWrite.getIndices() ||
304 transferRead.getVectorType() != transferWrite.getVectorType() ||
305 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
306 return WalkResult::advance();
307
308 // Check 2. Note, since both xfer Ops share the source, we only need to
309 // look at one of them.
310 auto base = transferRead.getBase();
311 auto *source = base.getDefiningOp();
312 if (source) {
313 // NOTE: We treat `memref.assume_alignment` as a special case.
314 //
315 // The idea is that it is safe to look past AssumeAlignmemtOp (i.e.
316 // MemRef _before_ alignment) iff:
317 // 1. It has exactly two uses (these have to be the xfer Ops
318 // being looked at).
319 // 2. The original MemRef has only one use (i.e.
320 // AssumeAlignmentOp).
321 //
322 // Relaxing these conditions will most likely require proper alias
323 // analysis.
324 if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(Val: source)) {
325 Value memPreAlignment = assume.getMemref();
326 auto numInLoopUses =
327 llvm::count_if(Range: base.getUses(), P: [&loop](OpOperand &use) {
328 return loop->isAncestor(other: use.getOwner());
329 });
330
331 if (numInLoopUses && memPreAlignment.hasOneUse())
332 source = memPreAlignment.getDefiningOp();
333 }
334 if (isa_and_nonnull<ViewLikeOpInterface>(Val: source))
335 return WalkResult::advance();
336 }
337
338 if (llvm::any_of(Range: base.getUsers(), P: llvm::IsaPred<ViewLikeOpInterface>))
339 return WalkResult::advance();
340
341 // Check 3.
342 // TODO: may want to memoize this information for performance but it
343 // likely gets invalidated often.
344 DominanceInfo dom(loop);
345 if (!dom.properlyDominates(a: transferRead.getOperation(), b: transferWrite))
346 return WalkResult::advance();
347 for (auto &use : transferRead.getBase().getUses()) {
348 if (!loop->isAncestor(other: use.getOwner()))
349 continue;
350 if (use.getOwner() == transferRead.getOperation() ||
351 use.getOwner() == transferWrite.getOperation())
352 continue;
353 if (auto transferWriteUse =
354 dyn_cast<vector::TransferWriteOp>(Val: use.getOwner())) {
355 if (!vector::isDisjointTransferSet(
356 transferA: cast<VectorTransferOpInterface>(Val&: *transferWrite),
357 transferB: cast<VectorTransferOpInterface>(Val&: *transferWriteUse),
358 /*testDynamicValueUsingBounds=*/true))
359 return WalkResult::advance();
360 } else if (auto transferReadUse =
361 dyn_cast<vector::TransferReadOp>(Val: use.getOwner())) {
362 if (!vector::isDisjointTransferSet(
363 transferA: cast<VectorTransferOpInterface>(Val&: *transferWrite),
364 transferB: cast<VectorTransferOpInterface>(Val&: *transferReadUse),
365 /*testDynamicValueUsingBounds=*/true))
366 return WalkResult::advance();
367 } else {
368 // Unknown use, we cannot prove that it doesn't alias with the
369 // transferRead/transferWrite operations.
370 return WalkResult::advance();
371 }
372 }
373
374 // Hoist read before.
375 loop.moveOutOfLoop(op: transferRead);
376
377 // Hoist write after.
378 transferWrite->moveAfter(existingOp: loop);
379
380 // Rewrite `loop` with new yields by cloning and erase the original
381 // loop.
382 IRRewriter rewriter(transferRead.getContext());
383 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
384 ArrayRef<BlockArgument> newBBArgs) {
385 return SmallVector<Value>{transferWrite.getVector()};
386 };
387
388 auto maybeNewLoop = loop.replaceWithAdditionalYields(
389 rewriter, newInitOperands: transferRead.getVector(),
390 /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn: yieldFn);
391 if (failed(Result: maybeNewLoop))
392 return WalkResult::interrupt();
393
394 transferWrite.getValueToStoreMutable().assign(
395 value: maybeNewLoop->getOperation()->getResults().back());
396 changed = true;
397 // Need to interrupt and restart because erasing the loop messes up
398 // the walk.
399 return WalkResult::interrupt();
400 });
401 }
402}
403

source code of mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp