1 | //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains the implementation of the core LICM algorithm. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
14 | |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/Interfaces/LoopLikeInterface.h" |
18 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
19 | #include "mlir/Interfaces/SubsetOpInterface.h" |
20 | #include "llvm/Support/Debug.h" |
21 | #include <queue> |
22 | |
23 | #define DEBUG_TYPE "licm" |
24 | |
25 | using namespace mlir; |
26 | |
27 | /// Checks whether the given op can be hoisted by checking that |
28 | /// - the op and none of its contained operations depend on values inside of the |
29 | /// loop (by means of calling definedOutside). |
30 | /// - the op has no side-effects. |
31 | static bool canBeHoisted(Operation *op, |
32 | function_ref<bool(OpOperand &)> condition) { |
33 | // Do not move terminators. |
34 | if (op->hasTrait<OpTrait::IsTerminator>()) |
35 | return false; |
36 | |
37 | // Walk the nested operations and check that all used values are either |
38 | // defined outside of the loop or in a nested region, but not at the level of |
39 | // the loop body. |
40 | auto walkFn = [&](Operation *child) { |
41 | for (OpOperand &operand : child->getOpOperands()) { |
42 | // Ignore values defined in a nested region. |
43 | if (op->isAncestor(other: operand.get().getParentRegion()->getParentOp())) |
44 | continue; |
45 | if (!condition(operand)) |
46 | return WalkResult::interrupt(); |
47 | } |
48 | return WalkResult::advance(); |
49 | }; |
50 | return !op->walk(callback&: walkFn).wasInterrupted(); |
51 | } |
52 | |
53 | static bool canBeHoisted(Operation *op, |
54 | function_ref<bool(Value)> definedOutside) { |
55 | return canBeHoisted( |
56 | op, condition: [&](OpOperand &operand) { return definedOutside(operand.get()); }); |
57 | } |
58 | |
59 | size_t mlir::moveLoopInvariantCode( |
60 | ArrayRef<Region *> regions, |
61 | function_ref<bool(Value, Region *)> isDefinedOutsideRegion, |
62 | function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, |
63 | function_ref<void(Operation *, Region *)> moveOutOfRegion) { |
64 | size_t numMoved = 0; |
65 | |
66 | for (Region *region : regions) { |
67 | LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" |
68 | << *region->getParentOp() << "\n" ); |
69 | |
70 | std::queue<Operation *> worklist; |
71 | // Add top-level operations in the loop body to the worklist. |
72 | for (Operation &op : region->getOps()) |
73 | worklist.push(x: &op); |
74 | |
75 | auto definedOutside = [&](Value value) { |
76 | return isDefinedOutsideRegion(value, region); |
77 | }; |
78 | |
79 | while (!worklist.empty()) { |
80 | Operation *op = worklist.front(); |
81 | worklist.pop(); |
82 | // Skip ops that have already been moved. Check if the op can be hoisted. |
83 | if (op->getParentRegion() != region) |
84 | continue; |
85 | |
86 | LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n" ); |
87 | if (!shouldMoveOutOfRegion(op, region) || |
88 | !canBeHoisted(op, definedOutside)) |
89 | continue; |
90 | |
91 | LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n" ); |
92 | moveOutOfRegion(op, region); |
93 | ++numMoved; |
94 | |
95 | // Since the op has been moved, we need to check its users within the |
96 | // top-level of the loop body. |
97 | for (Operation *user : op->getUsers()) |
98 | if (user->getParentRegion() == region) |
99 | worklist.push(x: user); |
100 | } |
101 | } |
102 | |
103 | return numMoved; |
104 | } |
105 | |
106 | size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { |
107 | return moveLoopInvariantCode( |
108 | loopLike.getLoopRegions(), |
109 | [&](Value value, Region *) { |
110 | return loopLike.isDefinedOutsideOfLoop(value); |
111 | }, |
112 | [&](Operation *op, Region *) { |
113 | return isMemoryEffectFree(op) && isSpeculatable(op); |
114 | }, |
115 | [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); |
116 | } |
117 | |
118 | namespace { |
119 | /// Helper data structure that keeps track of equivalent/disjoint subset ops. |
120 | class MatchingSubsets { |
121 | public: |
122 | /// Insert a subset op. |
123 | void insert(SubsetOpInterface op, bool collectHoistableOps = true) { |
124 | allSubsetOps.push_back(op); |
125 | if (!collectHoistableOps) |
126 | return; |
127 | if (auto extractionOp = |
128 | dyn_cast<SubsetExtractionOpInterface>(op.getOperation())) |
129 | insertExtractionOp(extractionOp: extractionOp); |
130 | if (auto insertionOp = |
131 | dyn_cast<SubsetInsertionOpInterface>(op.getOperation())) |
132 | insertInsertionOp(insertionOp: insertionOp); |
133 | } |
134 | |
135 | /// Return a range of matching extraction-insertion subset ops. If there is no |
136 | /// matching extraction/insertion op, the respective value is empty. Ops are |
137 | /// skipped if there are other subset ops that are not guaranteed to operate |
138 | /// on disjoint subsets. |
139 | auto getHoistableSubsetOps() { |
140 | return llvm::make_filter_range( |
141 | llvm::zip(extractions, insertions), [&](auto pair) { |
142 | auto [extractionOp, insertionOp] = pair; |
143 | // Hoist only if the extracted and inserted values have the same type. |
144 | if (extractionOp && insertionOp && |
145 | extractionOp->getResult(0).getType() != |
146 | insertionOp.getSourceOperand().get().getType()) |
147 | return false; |
148 | // Hoist only if there are no conflicting subset ops. |
149 | return allDisjoint(extractionOp, insertionOp); |
150 | }); |
151 | } |
152 | |
153 | /// Populate subset ops starting from the given region iter_arg. Return |
154 | /// "failure" if non-subset ops are found along the path to the loop yielding |
155 | /// op or if there is no single path to the tied yielded operand. If |
156 | /// `collectHoistableOps` is set to "false", subset ops are gathered |
157 | /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`. |
158 | LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
159 | BlockArgument iterArg, |
160 | bool collectHoistableOps = true); |
161 | |
162 | private: |
163 | /// Helper function for equivalence of tensor values. Since only insertion |
164 | /// subset ops (that are also destination style ops) are followed when |
165 | /// traversing the SSA use-def chain, all tensor values are equivalent. |
166 | static bool isEquivalent(Value v1, Value v2) { return true; } |
167 | |
168 | /// Return "true" if the subsets of the given extraction and insertion ops |
169 | /// are operating disjoint from the subsets that all other known subset ops |
170 | /// are operating on. |
171 | bool (SubsetExtractionOpInterface , |
172 | SubsetInsertionOpInterface insertionOp) const { |
173 | for (SubsetOpInterface other : allSubsetOps) { |
174 | if (other == extractionOp || other == insertionOp) |
175 | continue; |
176 | if (extractionOp && |
177 | !other.operatesOnDisjointSubset(extractionOp, isEquivalent)) |
178 | return false; |
179 | if (insertionOp && |
180 | !other.operatesOnDisjointSubset(insertionOp, isEquivalent)) |
181 | return false; |
182 | } |
183 | return true; |
184 | } |
185 | |
186 | /// Insert a subset extraction op. If the subset is equivalent to an existing |
187 | /// subset insertion op, pair them up. (If there is already a paired up subset |
188 | /// extraction op, overwrite the subset extraction op.) |
189 | void (SubsetExtractionOpInterface ) { |
190 | for (auto it : llvm::enumerate(insertions)) { |
191 | if (!it.value()) |
192 | continue; |
193 | auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
194 | if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) { |
195 | extractions[it.index()] = extractionOp; |
196 | return; |
197 | } |
198 | } |
199 | // There is no known equivalent insertion op. Create a new entry. |
200 | extractions.push_back(extractionOp); |
201 | insertions.push_back({}); |
202 | } |
203 | |
204 | /// Insert a subset insertion op. If the subset is equivalent to an existing |
205 | /// subset extraction op, pair them up. (If there is already a paired up |
206 | /// subset insertion op, overwrite the subset insertion op.) |
207 | void insertInsertionOp(SubsetInsertionOpInterface insertionOp) { |
208 | for (auto it : llvm::enumerate(extractions)) { |
209 | if (!it.value()) |
210 | continue; |
211 | auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
212 | if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) { |
213 | insertions[it.index()] = insertionOp; |
214 | return; |
215 | } |
216 | } |
217 | // There is no known equivalent extraction op. Create a new entry. |
218 | extractions.push_back({}); |
219 | insertions.push_back(insertionOp); |
220 | } |
221 | |
222 | SmallVector<SubsetExtractionOpInterface> ; |
223 | SmallVector<SubsetInsertionOpInterface> insertions; |
224 | SmallVector<SubsetOpInterface> allSubsetOps; |
225 | }; |
226 | } // namespace |
227 | |
228 | /// If the given value has a single use by an op that is a terminator, return |
229 | /// that use. Otherwise, return nullptr. |
230 | static OpOperand *getSingleTerminatorUse(Value value) { |
231 | if (!value.hasOneUse()) |
232 | return nullptr; |
233 | OpOperand &use = *value.getUses().begin(); |
234 | if (use.getOwner()->hasTrait<OpTrait::IsTerminator>()) |
235 | return &use; |
236 | return nullptr; |
237 | } |
238 | |
239 | LogicalResult |
240 | MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
241 | BlockArgument iterArg, |
242 | bool collectHoistableOps) { |
243 | assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg" ); |
244 | Value value = iterArg; |
245 | |
246 | // Traverse use-def chain. Subset ops can be hoisted only if all ops along the |
247 | // use-def chain starting from the region iter_arg are subset extraction or |
248 | // subset insertion ops. The chain must terminate at the corresponding yield |
249 | // operand (e.g., no swapping of iter_args). |
250 | OpOperand *yieldedOperand = nullptr; |
251 | // Iterate until the single use of the current SSA value is a terminator, |
252 | // which is expected to be the yielding operation of the loop. |
253 | while (!(yieldedOperand = getSingleTerminatorUse(value))) { |
254 | Value nextValue = {}; |
255 | |
256 | for (OpOperand &use : value.getUses()) { |
257 | if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) { |
258 | // Subset ops in nested loops are collected to check if there are only |
259 | // disjoint subset ops, but such subset ops are not subject to hoisting. |
260 | // To hoist subset ops from nested loops, the hoisting transformation |
261 | // should be run on the nested loop. |
262 | auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use); |
263 | if (!nestedIterArg) |
264 | return failure(); |
265 | // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA |
266 | // use-def chain starting at `nestedIterArg` and terminating in the |
267 | // tied, yielding operand. |
268 | if (failed(populateSubsetOpsAtIterArg(loopLike: nestedLoop, iterArg: nestedIterArg, |
269 | /*collectHoistableOps=*/false))) |
270 | return failure(); |
271 | nextValue = nestedLoop.getTiedLoopResult(&use); |
272 | continue; |
273 | } |
274 | |
275 | auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner()); |
276 | if (!subsetOp) |
277 | return failure(); |
278 | insert(op: subsetOp); |
279 | |
280 | if (auto insertionOp = |
281 | dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) { |
282 | // Current implementation expects that the insertionOp implement |
283 | // the destinationStyleOpInterface as well. Abort if that tha is not |
284 | // the case |
285 | if (!isa<DestinationStyleOpInterface>(use.getOwner())) { |
286 | return failure(); |
287 | } |
288 | |
289 | // The value must be used as a destination. (In case of a source, the |
290 | // entire tensor would be read, which would prevent any hoisting.) |
291 | if (&use != &insertionOp.getDestinationOperand()) |
292 | return failure(); |
293 | // There must be a single use-def chain from the region iter_arg to the |
294 | // terminator. I.e., only one insertion op. Branches are not supported. |
295 | if (nextValue) |
296 | return failure(); |
297 | nextValue = insertionOp.getUpdatedDestination(); |
298 | } |
299 | } |
300 | |
301 | // Nothing can be hoisted if the chain does not continue with loop yielding |
302 | // op or a subset insertion op. |
303 | if (!nextValue) |
304 | return failure(); |
305 | value = nextValue; |
306 | } |
307 | |
308 | // Hoist only if the SSA use-def chain ends in the yielding terminator of the |
309 | // loop and the yielded value is the `idx`-th operand. (I.e., there is no |
310 | // swapping yield.) |
311 | if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand) |
312 | return failure(); |
313 | |
314 | return success(); |
315 | } |
316 | |
317 | /// Hoist all subset ops that operate on the idx-th region iter_arg of the given |
318 | /// loop-like op and index into loop-invariant subset locations. Return the |
319 | /// newly created loop op (that has extra iter_args) or the original loop op if |
320 | /// nothing was hoisted. |
321 | static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, |
322 | LoopLikeOpInterface loopLike, |
323 | BlockArgument iterArg) { |
324 | assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg" ); |
325 | auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); |
326 | int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); |
327 | MatchingSubsets subsets; |
328 | if (failed(subsets.populateSubsetOpsAtIterArg(loopLike: loopLike, iterArg))) |
329 | return loopLike; |
330 | |
331 | // Hoist all matching extraction-insertion pairs one-by-one. |
332 | for (auto it : subsets.getHoistableSubsetOps()) { |
333 | auto extractionOp = std::get<0>(it); |
334 | auto insertionOp = std::get<1>(it); |
335 | |
336 | // Ops cannot be hoisted if they depend on loop-variant values. |
337 | if (extractionOp) { |
338 | if (!canBeHoisted(extractionOp, [&](OpOperand &operand) { |
339 | return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
340 | &operand == &extractionOp.getSourceOperand(); |
341 | })) |
342 | extractionOp = {}; |
343 | } |
344 | if (insertionOp) { |
345 | if (!canBeHoisted(insertionOp, [&](OpOperand &operand) { |
346 | return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
347 | &operand == &insertionOp.getSourceOperand() || |
348 | &operand == &insertionOp.getDestinationOperand(); |
349 | })) |
350 | insertionOp = {}; |
351 | } |
352 | |
353 | // Only hoist extraction-insertion pairs for now. Standalone extractions/ |
354 | // insertions that are loop-invariant could be hoisted, but there may be |
355 | // easier ways to canonicalize the IR. |
356 | if (extractionOp && insertionOp) { |
357 | // Create a new loop with an additional iter_arg. |
358 | NewYieldValuesFn newYieldValuesFn = |
359 | [&](OpBuilder &b, Location loc, |
360 | ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> { |
361 | return {insertionOp.getSourceOperand().get()}; |
362 | }; |
363 | FailureOr<LoopLikeOpInterface> newLoop = |
364 | loopLike.replaceWithAdditionalYields( |
365 | rewriter, extractionOp.getResult(), |
366 | /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn); |
367 | if (failed(newLoop)) |
368 | return loopLike; |
369 | loopLike = *newLoop; |
370 | |
371 | // Hoist the extraction/insertion ops. |
372 | iterArg = loopLike.getRegionIterArgs()[iterArgIdx]; |
373 | OpResult loopResult = loopLike.getTiedLoopResult(iterArg); |
374 | OpResult newLoopResult = loopLike.getLoopResults()->back(); |
375 | rewriter.moveOpBefore(extractionOp, loopLike); |
376 | rewriter.moveOpAfter(insertionOp, loopLike); |
377 | rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), |
378 | insertionOp.getDestinationOperand().get()); |
379 | extractionOp.getSourceOperand().set( |
380 | loopLike.getTiedLoopInit(iterArg)->get()); |
381 | rewriter.replaceAllUsesWith(loopResult, |
382 | insertionOp.getUpdatedDestination()); |
383 | insertionOp.getSourceOperand().set(newLoopResult); |
384 | insertionOp.getDestinationOperand().set(loopResult); |
385 | } |
386 | } |
387 | |
388 | return loopLike; |
389 | } |
390 | |
391 | LoopLikeOpInterface |
392 | mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter, |
393 | LoopLikeOpInterface loopLike) { |
394 | // Note: As subset ops are getting hoisted, the number of region iter_args |
395 | // increases. This can enable further hoisting opportunities on the new |
396 | // iter_args. |
397 | for (int64_t i = 0; |
398 | i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) { |
399 | loopLike = hoistSubsetAtIterArg(rewriter, loopLike, |
400 | loopLike.getRegionIterArgs()[i]); |
401 | } |
402 | return loopLike; |
403 | } |
404 | |