1//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with hoisting invariant operations
10// in the context of Linalg transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19#include "mlir/Dialect/Affine/Utils.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/Linalg/IR/Linalg.h"
23#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SCF/Utils/Utils.h"
26#include "mlir/Dialect/Tensor/IR/Tensor.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/Dominance.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/Support/Debug.h"
36
37using llvm::dbgs;
38
39#define DEBUG_TYPE "linalg-hoisting"
40
41#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46/// Replace `loop` with a new loop that has a different init operand at
47/// position `index`. The body of this loop is moved over to the new loop.
48///
49/// `newInitOperands` specifies the replacement "init" operands.
50/// `newYieldValue` is the replacement yield value of the loop at position
51/// `index`.
52static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53 scf::ForOp loop,
54 Value newInitOperand,
55 unsigned index,
56 Value newYieldValue) {
57 OpBuilder::InsertionGuard g(rewriter);
58 rewriter.setInsertionPoint(loop.getOperation());
59 auto inits = llvm::to_vector(loop.getInits());
60
61 // Replace the init value with the new operand.
62 assert(index < inits.size());
63 inits[index] = newInitOperand;
64
65 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67 inits, [](OpBuilder &, Location, Value, ValueRange) {});
68
69 // Generate the new yield with the replaced operand.
70 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71 yieldOp.setOperand(index, newYieldValue);
72
73 // Move the loop body to the new op.
74 rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(),
75 argValues: newLoop.getBody()->getArguments());
76
77 // Replace the old loop.
78 rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79 return newLoop;
80}
81
82// Hoist out a pair of corresponding vector.extract+vector.broadcast
83// operations. This function transforms a loop like this:
84// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85// %e = vector.extract %iarg : t1 to t2
86// %u = "some_use"(%e) : (t2) -> t2
87// %b = vector.broadcast %u : t2 to t1
88// scf.yield %b : t1
89// }
90// into the following:
91// %e = vector.extract %v: t1 to t2
92// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93// %u' = "some_use"(%iarg) : (t2) -> t2
94// scf.yield %u' : t2
95// }
96// %res = vector.broadcast %res' : t2 to t1
97void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
98 Operation *root) {
99 bool changed = true;
100 while (changed) {
101 changed = false;
102 // First move loop invariant ops outside of their loop. This needs to be
103 // done before as we cannot move ops without interrupting the function walk.
104 root->walk(
105 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106
107 root->walk(callback: [&](vector::ExtractOp extractOp) {
108 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109 << *extractOp.getOperation() << "\n");
110
111 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112 if (!loop)
113 return WalkResult::advance();
114
115 // Check that the vector to extract from is a BlockArgument.
116 auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117 if (!blockArg)
118 return WalkResult::advance();
119
120 // Check that the blockArg is an iter_arg of the loop.
121 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122 if (!initArg)
123 return WalkResult::advance();
124
125 // If the iter_arg does not have only one use, it won't be possible to
126 // hoist the extractOp out.
127 if (!blockArg.hasOneUse())
128 return WalkResult::advance();
129
130 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131
132 // Check that the loop yields a broadcast that has just one use.
133 Operation *yieldedVal =
134 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136 if (!broadcast || !broadcast.getResult().hasOneUse())
137 return WalkResult::advance();
138
139 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140
141 Type broadcastInputType = broadcast.getSourceType();
142 if (broadcastInputType != extractOp.getType())
143 return WalkResult::advance();
144
145 // The position of the extract must be defined outside of the loop if
146 // it is dynamic.
147 for (auto operand : extractOp.getDynamicPosition())
148 if (!loop.isDefinedOutsideOfLoop(operand))
149 return WalkResult::advance();
150
151 rewriter.modifyOpInPlace(broadcast, [&] {
152 extractOp.getVectorMutable().assign(initArg->get());
153 });
154 loop.moveOutOfLoop(extractOp);
155 rewriter.moveOpAfter(broadcast, loop);
156
157 scf::ForOp newLoop = replaceWithDifferentYield(
158 rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159
160 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161
162 rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163 rewriter.modifyOpInPlace(
164 broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165
166 changed = true;
167 return WalkResult::interrupt();
168 });
169 }
170}
171
172static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
173 LoopLikeOpInterface loop) {
174 Value source = transferRead.getSource();
175
176 // Skip view-like Ops and retrive the actual soruce Operation
177 while (auto srcOp =
178 dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
179 source = srcOp.getViewSource();
180
181 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
182 source.getUsers().end());
183 llvm::SmallDenseSet<Operation *, 32> processed;
184 while (!users.empty()) {
185 Operation *user = users.pop_back_val();
186 // If the user has already been processed skip.
187 if (!processed.insert(V: user).second)
188 continue;
189 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
191 continue;
192 }
193 if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user))
194 continue;
195 if (!loop->isAncestor(user))
196 continue;
197 return false;
198 }
199 return true;
200}
201
202void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
203 bool changed = true;
204 while (changed) {
205 changed = false;
206 // First move loop invariant ops outside of their loop. This needs to be
207 // done before as we cannot move ops without interrupting the function walk.
208 root->walk(
209 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
210
211 root->walk([&](vector::TransferReadOp transferRead) {
212 if (!isa<MemRefType>(transferRead.getShapedType()))
213 return WalkResult::advance();
214
215 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
216 << *transferRead.getOperation() << "\n");
217 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
218 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
219 << "\n");
220 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
221 return WalkResult::advance();
222
223 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
224 << "\n");
225
226 SetVector<Operation *> forwardSlice;
227 getForwardSlice(transferRead.getOperation(), &forwardSlice);
228
229 // Look for the last TransferWriteOp in the forwardSlice of
230 // `transferRead` that operates on the same memref.
231 vector::TransferWriteOp transferWrite;
232 for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) {
233 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
234 if (!candidateWrite ||
235 candidateWrite.getSource() != transferRead.getSource())
236 continue;
237 transferWrite = candidateWrite;
238 }
239
240 // All operands of the TransferRead must be defined outside of the loop.
241 for (auto operand : transferRead.getOperands())
242 if (!loop.isDefinedOutsideOfLoop(operand))
243 return WalkResult::advance();
244
245 // Only hoist transfer_read / transfer_write pairs and singleton
246 // transfer_reads for now.
247 if (!transferWrite) {
248 // Make sure there are no other accesses to the memref before
249 // hoisting transfer_read.
250 if (noAliasingUseInLoop(transferRead, loop))
251 loop.moveOutOfLoop(transferRead);
252 return WalkResult::advance();
253 }
254
255 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
256 << "\n");
257
258 // Approximate aliasing by checking that:
259 // 1. indices, vector type and permutation map are the same (i.e., the
260 // transfer_read/transfer_write ops are matching),
261 // 2. source operands for transfer.{read|write} do not originate from
262 // Ops implementing ViewLikeOpInterface.
263 // 3. no other operations in the loop access the same memref except
264 // for transfer_read/transfer_write accessing statically disjoint
265 // slices.
266 if (transferRead.getIndices() != transferWrite.getIndices() ||
267 transferRead.getVectorType() != transferWrite.getVectorType() ||
268 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
269 return WalkResult::advance();
270
271 auto *source = transferRead.getSource().getDefiningOp();
272 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
273 return WalkResult::advance();
274
275 source = transferWrite.getSource().getDefiningOp();
276 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
277 return WalkResult::advance();
278
279 // TODO: may want to memoize this information for performance but it
280 // likely gets invalidated often.
281 DominanceInfo dom(loop);
282 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
283 return WalkResult::advance();
284 for (auto &use : transferRead.getSource().getUses()) {
285 if (!loop->isAncestor(use.getOwner()))
286 continue;
287 if (use.getOwner() == transferRead.getOperation() ||
288 use.getOwner() == transferWrite.getOperation())
289 continue;
290 if (auto transferWriteUse =
291 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
292 if (!vector::isDisjointTransferSet(
293 cast<VectorTransferOpInterface>(*transferWrite),
294 cast<VectorTransferOpInterface>(*transferWriteUse),
295 /*testDynamicValueUsingBounds=*/true))
296 return WalkResult::advance();
297 } else if (auto transferReadUse =
298 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
299 if (!vector::isDisjointTransferSet(
300 cast<VectorTransferOpInterface>(*transferWrite),
301 cast<VectorTransferOpInterface>(*transferReadUse),
302 /*testDynamicValueUsingBounds=*/true))
303 return WalkResult::advance();
304 } else {
305 // Unknown use, we cannot prove that it doesn't alias with the
306 // transferRead/transferWrite operations.
307 return WalkResult::advance();
308 }
309 }
310
311 // Hoist read before.
312 loop.moveOutOfLoop(transferRead);
313
314 // Hoist write after.
315 transferWrite->moveAfter(loop);
316
317 // Rewrite `loop` with new yields by cloning and erase the original loop.
318 IRRewriter rewriter(transferRead.getContext());
319 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
320 ArrayRef<BlockArgument> newBBArgs) {
321 return SmallVector<Value>{transferWrite.getVector()};
322 };
323
324 auto maybeNewLoop = loop.replaceWithAdditionalYields(
325 rewriter, transferRead.getVector(),
326 /*replaceInitOperandUsesInLoop=*/true, yieldFn);
327 if (failed(maybeNewLoop))
328 return WalkResult::interrupt();
329
330 transferWrite.getVectorMutable().assign(
331 maybeNewLoop->getOperation()->getResults().back());
332 changed = true;
333 // Need to interrupt and restart because erasing the loop messes up
334 // the walk.
335 return WalkResult::interrupt();
336 });
337 }
338}
339

source code of mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp