1 | //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements functions concerned with hoisting invariant operations |
10 | // in the context of Linalg transformations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
15 | #include "mlir/Analysis/SliceAnalysis.h" |
16 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
18 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
19 | #include "mlir/Dialect/Affine/Utils.h" |
20 | #include "mlir/Dialect/Arith/IR/Arith.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
22 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
23 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
26 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
28 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
29 | #include "mlir/IR/BuiltinOps.h" |
30 | #include "mlir/IR/Dominance.h" |
31 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
32 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
33 | #include "llvm/ADT/StringRef.h" |
34 | #include "llvm/ADT/TypeSwitch.h" |
35 | #include "llvm/Support/Debug.h" |
36 | |
37 | using llvm::dbgs; |
38 | |
39 | #define DEBUG_TYPE "linalg-hoisting" |
40 | |
41 | #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") |
42 | |
43 | using namespace mlir; |
44 | using namespace mlir::linalg; |
45 | |
46 | /// Replace `loop` with a new loop that has a different init operand at |
47 | /// position `index`. The body of this loop is moved over to the new loop. |
48 | /// |
49 | /// `newInitOperands` specifies the replacement "init" operands. |
50 | /// `newYieldValue` is the replacement yield value of the loop at position |
51 | /// `index`. |
52 | static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, |
53 | scf::ForOp loop, |
54 | Value newInitOperand, |
55 | unsigned index, |
56 | Value newYieldValue) { |
57 | OpBuilder::InsertionGuard g(rewriter); |
58 | rewriter.setInsertionPoint(loop.getOperation()); |
59 | auto inits = llvm::to_vector(loop.getInits()); |
60 | |
61 | // Replace the init value with the new operand. |
62 | assert(index < inits.size()); |
63 | inits[index] = newInitOperand; |
64 | |
65 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
66 | loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), |
67 | inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
68 | |
69 | // Generate the new yield with the replaced operand. |
70 | auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); |
71 | yieldOp.setOperand(index, newYieldValue); |
72 | |
73 | // Move the loop body to the new op. |
74 | rewriter.mergeBlocks(source: loop.getBody(), dest: newLoop.getBody(), |
75 | argValues: newLoop.getBody()->getArguments()); |
76 | |
77 | // Replace the old loop. |
78 | rewriter.replaceOp(loop.getOperation(), newLoop->getResults()); |
79 | return newLoop; |
80 | } |
81 | |
82 | // Hoist out a pair of corresponding vector.extract+vector.broadcast |
83 | // operations. This function transforms a loop like this: |
84 | // %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { |
85 | // %e = vector.extract %iarg : t1 to t2 |
86 | // %u = "some_use"(%e) : (t2) -> t2 |
87 | // %b = vector.broadcast %u : t2 to t1 |
88 | // scf.yield %b : t1 |
89 | // } |
90 | // into the following: |
91 | // %e = vector.extract %v: t1 to t2 |
92 | // %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { |
93 | // %u' = "some_use"(%iarg) : (t2) -> t2 |
94 | // scf.yield %u' : t2 |
95 | // } |
96 | // %res = vector.broadcast %res' : t2 to t1 |
97 | void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter, |
98 | Operation *root) { |
99 | bool changed = true; |
100 | while (changed) { |
101 | changed = false; |
102 | // First move loop invariant ops outside of their loop. This needs to be |
103 | // done before as we cannot move ops without interrupting the function walk. |
104 | root->walk( |
105 | [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
106 | |
107 | root->walk(callback: [&](vector::ExtractOp ) { |
108 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
109 | << *extractOp.getOperation() << "\n" ); |
110 | |
111 | auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp()); |
112 | if (!loop) |
113 | return WalkResult::advance(); |
114 | |
115 | // Check that the vector to extract from is a BlockArgument. |
116 | auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector()); |
117 | if (!blockArg) |
118 | return WalkResult::advance(); |
119 | |
120 | // Check that the blockArg is an iter_arg of the loop. |
121 | OpOperand *initArg = loop.getTiedLoopInit(blockArg); |
122 | if (!initArg) |
123 | return WalkResult::advance(); |
124 | |
125 | // If the iter_arg does not have only one use, it won't be possible to |
126 | // hoist the extractOp out. |
127 | if (!blockArg.hasOneUse()) |
128 | return WalkResult::advance(); |
129 | |
130 | unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars(); |
131 | |
132 | // Check that the loop yields a broadcast that has just one use. |
133 | Operation *yieldedVal = |
134 | loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp(); |
135 | auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal); |
136 | if (!broadcast || !broadcast.getResult().hasOneUse()) |
137 | return WalkResult::advance(); |
138 | |
139 | LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n" ); |
140 | |
141 | Type broadcastInputType = broadcast.getSourceType(); |
142 | if (broadcastInputType != extractOp.getType()) |
143 | return WalkResult::advance(); |
144 | |
145 | // The position of the extract must be defined outside of the loop if |
146 | // it is dynamic. |
147 | for (auto operand : extractOp.getDynamicPosition()) |
148 | if (!loop.isDefinedOutsideOfLoop(operand)) |
149 | return WalkResult::advance(); |
150 | |
151 | rewriter.modifyOpInPlace(broadcast, [&] { |
152 | extractOp.getVectorMutable().assign(initArg->get()); |
153 | }); |
154 | loop.moveOutOfLoop(extractOp); |
155 | rewriter.moveOpAfter(broadcast, loop); |
156 | |
157 | scf::ForOp newLoop = replaceWithDifferentYield( |
158 | rewriter, loop, extractOp.getResult(), index, broadcast.getSource()); |
159 | |
160 | LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n" ); |
161 | |
162 | rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast); |
163 | rewriter.modifyOpInPlace( |
164 | broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); }); |
165 | |
166 | changed = true; |
167 | return WalkResult::interrupt(); |
168 | }); |
169 | } |
170 | } |
171 | |
172 | static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, |
173 | LoopLikeOpInterface loop) { |
174 | Value source = transferRead.getSource(); |
175 | |
176 | // Skip view-like Ops and retrive the actual soruce Operation |
177 | while (auto srcOp = |
178 | dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp())) |
179 | source = srcOp.getViewSource(); |
180 | |
181 | llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), |
182 | source.getUsers().end()); |
183 | llvm::SmallDenseSet<Operation *, 32> processed; |
184 | while (!users.empty()) { |
185 | Operation *user = users.pop_back_val(); |
186 | // If the user has already been processed skip. |
187 | if (!processed.insert(V: user).second) |
188 | continue; |
189 | if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { |
190 | users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); |
191 | continue; |
192 | } |
193 | if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user)) |
194 | continue; |
195 | if (!loop->isAncestor(user)) |
196 | continue; |
197 | return false; |
198 | } |
199 | return true; |
200 | } |
201 | |
202 | void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) { |
203 | bool changed = true; |
204 | while (changed) { |
205 | changed = false; |
206 | // First move loop invariant ops outside of their loop. This needs to be |
207 | // done before as we cannot move ops without interrupting the function walk. |
208 | root->walk( |
209 | [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); |
210 | |
211 | root->walk([&](vector::TransferReadOp transferRead) { |
212 | if (!isa<MemRefType>(transferRead.getShapedType())) |
213 | return WalkResult::advance(); |
214 | |
215 | LLVM_DEBUG(DBGS() << "Candidate for hoisting: " |
216 | << *transferRead.getOperation() << "\n" ); |
217 | auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp()); |
218 | LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() |
219 | << "\n" ); |
220 | if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop)) |
221 | return WalkResult::advance(); |
222 | |
223 | LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() |
224 | << "\n" ); |
225 | |
226 | SetVector<Operation *> forwardSlice; |
227 | getForwardSlice(transferRead.getOperation(), &forwardSlice); |
228 | |
229 | // Look for the last TransferWriteOp in the forwardSlice of |
230 | // `transferRead` that operates on the same memref. |
231 | vector::TransferWriteOp transferWrite; |
232 | for (auto *sliceOp : llvm::reverse(C&: forwardSlice)) { |
233 | auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); |
234 | if (!candidateWrite || |
235 | candidateWrite.getSource() != transferRead.getSource()) |
236 | continue; |
237 | transferWrite = candidateWrite; |
238 | } |
239 | |
240 | // All operands of the TransferRead must be defined outside of the loop. |
241 | for (auto operand : transferRead.getOperands()) |
242 | if (!loop.isDefinedOutsideOfLoop(operand)) |
243 | return WalkResult::advance(); |
244 | |
245 | // Only hoist transfer_read / transfer_write pairs and singleton |
246 | // transfer_reads for now. |
247 | if (!transferWrite) { |
248 | // Make sure there are no other accesses to the memref before |
249 | // hoisting transfer_read. |
250 | if (noAliasingUseInLoop(transferRead, loop)) |
251 | loop.moveOutOfLoop(transferRead); |
252 | return WalkResult::advance(); |
253 | } |
254 | |
255 | LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() |
256 | << "\n" ); |
257 | |
258 | // Approximate aliasing by checking that: |
259 | // 1. indices, vector type and permutation map are the same (i.e., the |
260 | // transfer_read/transfer_write ops are matching), |
261 | // 2. source operands for transfer.{read|write} do not originate from |
262 | // Ops implementing ViewLikeOpInterface. |
263 | // 3. no other operations in the loop access the same memref except |
264 | // for transfer_read/transfer_write accessing statically disjoint |
265 | // slices. |
266 | if (transferRead.getIndices() != transferWrite.getIndices() || |
267 | transferRead.getVectorType() != transferWrite.getVectorType() || |
268 | transferRead.getPermutationMap() != transferWrite.getPermutationMap()) |
269 | return WalkResult::advance(); |
270 | |
271 | auto *source = transferRead.getSource().getDefiningOp(); |
272 | if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) |
273 | return WalkResult::advance(); |
274 | |
275 | source = transferWrite.getSource().getDefiningOp(); |
276 | if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) |
277 | return WalkResult::advance(); |
278 | |
279 | // TODO: may want to memoize this information for performance but it |
280 | // likely gets invalidated often. |
281 | DominanceInfo dom(loop); |
282 | if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) |
283 | return WalkResult::advance(); |
284 | for (auto &use : transferRead.getSource().getUses()) { |
285 | if (!loop->isAncestor(use.getOwner())) |
286 | continue; |
287 | if (use.getOwner() == transferRead.getOperation() || |
288 | use.getOwner() == transferWrite.getOperation()) |
289 | continue; |
290 | if (auto transferWriteUse = |
291 | dyn_cast<vector::TransferWriteOp>(use.getOwner())) { |
292 | if (!vector::isDisjointTransferSet( |
293 | cast<VectorTransferOpInterface>(*transferWrite), |
294 | cast<VectorTransferOpInterface>(*transferWriteUse), |
295 | /*testDynamicValueUsingBounds=*/true)) |
296 | return WalkResult::advance(); |
297 | } else if (auto transferReadUse = |
298 | dyn_cast<vector::TransferReadOp>(use.getOwner())) { |
299 | if (!vector::isDisjointTransferSet( |
300 | cast<VectorTransferOpInterface>(*transferWrite), |
301 | cast<VectorTransferOpInterface>(*transferReadUse), |
302 | /*testDynamicValueUsingBounds=*/true)) |
303 | return WalkResult::advance(); |
304 | } else { |
305 | // Unknown use, we cannot prove that it doesn't alias with the |
306 | // transferRead/transferWrite operations. |
307 | return WalkResult::advance(); |
308 | } |
309 | } |
310 | |
311 | // Hoist read before. |
312 | loop.moveOutOfLoop(transferRead); |
313 | |
314 | // Hoist write after. |
315 | transferWrite->moveAfter(loop); |
316 | |
317 | // Rewrite `loop` with new yields by cloning and erase the original loop. |
318 | IRRewriter rewriter(transferRead.getContext()); |
319 | NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, |
320 | ArrayRef<BlockArgument> newBBArgs) { |
321 | return SmallVector<Value>{transferWrite.getVector()}; |
322 | }; |
323 | |
324 | auto maybeNewLoop = loop.replaceWithAdditionalYields( |
325 | rewriter, transferRead.getVector(), |
326 | /*replaceInitOperandUsesInLoop=*/true, yieldFn); |
327 | if (failed(maybeNewLoop)) |
328 | return WalkResult::interrupt(); |
329 | |
330 | transferWrite.getVectorMutable().assign( |
331 | maybeNewLoop->getOperation()->getResults().back()); |
332 | changed = true; |
333 | // Need to interrupt and restart because erasing the loop messes up |
334 | // the walk. |
335 | return WalkResult::interrupt(); |
336 | }); |
337 | } |
338 | } |
339 | |