1 | //===- BufferDeallocationSimplification.cpp -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic for optimizing `bufferization.dealloc` operations |
10 | // that requires more analysis than what can be supported by regular |
11 | // canonicalization patterns. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
16 | #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" |
17 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
20 | #include "mlir/IR/Matchers.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | |
23 | namespace mlir { |
24 | namespace bufferization { |
25 | #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION |
26 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
27 | } // namespace bufferization |
28 | } // namespace mlir |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::bufferization; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Helpers |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | /// Given a memref value, return the "base" value by skipping over all |
38 | /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. |
39 | static Value getViewBase(Value value) { |
40 | while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) |
41 | value = viewLikeOp.getViewSource(); |
42 | return value; |
43 | } |
44 | |
45 | static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, |
46 | ValueRange memrefs, |
47 | ValueRange conditions, |
48 | PatternRewriter &rewriter) { |
49 | if (deallocOp.getMemrefs() == memrefs && |
50 | deallocOp.getConditions() == conditions) |
51 | return failure(); |
52 | |
53 | rewriter.modifyOpInPlace(deallocOp, [&]() { |
54 | deallocOp.getMemrefsMutable().assign(memrefs); |
55 | deallocOp.getConditionsMutable().assign(conditions); |
56 | }); |
57 | return success(); |
58 | } |
59 | |
60 | /// Return "true" if the given values are guaranteed to be different (and |
61 | /// non-aliasing) allocations based on the fact that one value is the result |
62 | /// of an allocation and the other value is a block argument of a parent block. |
63 | /// Note: This is a best-effort analysis that will eventually be replaced by a |
64 | /// proper "is same allocation" analysis. This function may return "false" even |
65 | /// though the two values are distinct allocations. |
66 | static bool distinctAllocAndBlockArgument(Value v1, Value v2) { |
67 | Value v1Base = getViewBase(value: v1); |
68 | Value v2Base = getViewBase(value: v2); |
69 | auto areDistinct = [](Value v1, Value v2) { |
70 | if (Operation *op = v1.getDefiningOp()) |
71 | if (hasEffect<MemoryEffects::Allocate>(op, value: v1)) |
72 | if (auto bbArg = dyn_cast<BlockArgument>(Val&: v2)) |
73 | if (bbArg.getOwner()->findAncestorOpInBlock(op&: *op)) |
74 | return true; |
75 | return false; |
76 | }; |
77 | return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base); |
78 | } |
79 | |
80 | /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is |
81 | /// often a requirement of optimization patterns that there cannot be any |
82 | /// aliasing memref in order to perform the desired simplification. |
83 | static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis, |
84 | ValueRange otherList, Value memref) { |
85 | for (auto other : otherList) { |
86 | if (distinctAllocAndBlockArgument(v1: other, v2: memref)) |
87 | continue; |
88 | std::optional<bool> analysisResult = |
89 | analysis.isSameAllocation(v1: other, v2: memref); |
90 | if (!analysisResult.has_value() || analysisResult == true) |
91 | return true; |
92 | } |
93 | return false; |
94 | } |
95 | |
96 | //===----------------------------------------------------------------------===// |
97 | // Patterns |
98 | //===----------------------------------------------------------------------===// |
99 | |
100 | namespace { |
101 | |
102 | /// Remove values from the `memref` operand list that are also present in the |
103 | /// `retained` list (or a guaranteed alias of it) because they will never |
104 | /// actually be deallocated. However, we also need to be certain about which |
105 | /// other memrefs in the `retained` list can alias, i.e., there must not by any |
106 | /// may-aliasing memref. This is necessary because the `dealloc` operation is |
107 | /// defined to return one `i1` value per memref in the `retained` list which |
108 | /// represents the disjunction of the condition values corresponding to all |
109 | /// aliasing values in the `memref` list. In particular, this means that if |
110 | /// there is some value R in the `retained` list which aliases with a value M in |
111 | /// the `memref` list (but can only be staticaly determined to may-alias) and M |
112 | /// is also present in the `retained` list, then it would be illegal to remove M |
113 | /// because the result corresponding to R would be computed incorrectly |
114 | /// afterwards. Because we require an alias analysis, this pattern cannot be |
115 | /// applied as a regular canonicalization pattern. |
116 | /// |
117 | /// Example: |
118 | /// ```mlir |
119 | /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0) |
120 | /// retain (%m0, %r0, %r1 : ...) |
121 | /// ``` |
122 | /// is canonicalized to |
123 | /// ```mlir |
124 | /// // bufferization.dealloc without memrefs and conditions returns %false for |
125 | /// // every retained value |
126 | /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...) |
127 | /// %1 = arith.ori %0#0, %cond0 : i1 |
128 | /// // replace %0#0 with %1 |
129 | /// ``` |
130 | /// given that `%r0` and `%r1` may not alias with `%m0`. |
131 | struct RemoveDeallocMemrefsContainedInRetained |
132 | : public OpRewritePattern<DeallocOp> { |
133 | RemoveDeallocMemrefsContainedInRetained(MLIRContext *context, |
134 | BufferOriginAnalysis &analysis) |
135 | : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} |
136 | |
137 | /// The passed 'memref' must not have a may-alias relation to any retained |
138 | /// memref, and at least one must-alias relation. If there is no must-aliasing |
139 | /// memref in the retain list, we cannot simply remove the memref as there |
140 | /// could be situations in which it actually has to be deallocated. If it's |
141 | /// no-alias, then just proceed, if it's must-alias we need to update the |
142 | /// updated condition returned by the dealloc operation for that alias. |
143 | LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond, |
144 | PatternRewriter &rewriter) const { |
145 | rewriter.setInsertionPointAfter(deallocOp); |
146 | |
147 | // Check that there is no may-aliasing memref and that at least one memref |
148 | // in the retain list aliases (because otherwise it might have to be |
149 | // deallocated in some situations and can thus not be dropped). |
150 | bool atLeastOneMustAlias = false; |
151 | for (Value retained : deallocOp.getRetained()) { |
152 | std::optional<bool> analysisResult = |
153 | analysis.isSameAllocation(retained, memref); |
154 | if (!analysisResult.has_value()) |
155 | return failure(); |
156 | if (analysisResult == true) |
157 | atLeastOneMustAlias = true; |
158 | } |
159 | if (!atLeastOneMustAlias) |
160 | return failure(); |
161 | |
162 | // Insert arith.ori operations to update the corresponding dealloc result |
163 | // values to incorporate the condition of the must-aliasing memref such that |
164 | // we can remove that operand later on. |
165 | for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) { |
166 | Value updatedCondition = deallocOp.getUpdatedConditions()[i]; |
167 | std::optional<bool> analysisResult = |
168 | analysis.isSameAllocation(retained, memref); |
169 | if (analysisResult == true) { |
170 | auto disjunction = rewriter.create<arith::OrIOp>( |
171 | deallocOp.getLoc(), updatedCondition, cond); |
172 | rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), |
173 | disjunction); |
174 | } |
175 | } |
176 | |
177 | return success(); |
178 | } |
179 | |
180 | LogicalResult matchAndRewrite(DeallocOp deallocOp, |
181 | PatternRewriter &rewriter) const override { |
182 | // There must not be any duplicates in the retain list anymore because we |
183 | // would miss updating one of the result values otherwise. |
184 | DenseSet<Value> retained(deallocOp.getRetained().begin(), |
185 | deallocOp.getRetained().end()); |
186 | if (retained.size() != deallocOp.getRetained().size()) |
187 | return failure(); |
188 | |
189 | SmallVector<Value> newMemrefs, newConditions; |
190 | for (auto [memref, cond] : |
191 | llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { |
192 | |
193 | if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter))) |
194 | continue; |
195 | |
196 | if (auto extractOp = |
197 | memref.getDefiningOp<memref::ExtractStridedMetadataOp>()) |
198 | if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond, |
199 | rewriter))) |
200 | continue; |
201 | |
202 | newMemrefs.push_back(memref); |
203 | newConditions.push_back(cond); |
204 | } |
205 | |
206 | // Return failure if we don't change anything such that we don't run into an |
207 | // infinite loop of pattern applications. |
208 | return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, |
209 | rewriter); |
210 | } |
211 | |
212 | private: |
213 | BufferOriginAnalysis &analysis; |
214 | }; |
215 | |
216 | /// Remove memrefs from the `retained` list which are guaranteed to not alias |
217 | /// any memref in the `memrefs` list. The corresponding result value can be |
218 | /// replaced with `false` in that case according to the operation description. |
219 | /// |
220 | /// Example: |
221 | /// ```mlir |
222 | /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond) |
223 | /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) |
224 | /// return %0#0, %0#1 |
225 | /// ``` |
226 | /// can be canonicalized to the following given that `%r0` and `%r1` do not |
227 | /// alias `%m`: |
228 | /// ```mlir |
229 | /// bufferization.dealloc (%m : memref<2xi32>) if (%cond) |
230 | /// return %false, %false |
231 | /// ``` |
232 | struct RemoveRetainedMemrefsGuaranteedToNotAlias |
233 | : public OpRewritePattern<DeallocOp> { |
234 | RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context, |
235 | BufferOriginAnalysis &analysis) |
236 | : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} |
237 | |
238 | LogicalResult matchAndRewrite(DeallocOp deallocOp, |
239 | PatternRewriter &rewriter) const override { |
240 | SmallVector<Value> newRetainedMemrefs, replacements; |
241 | |
242 | for (auto retainedMemref : deallocOp.getRetained()) { |
243 | if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(), |
244 | retainedMemref)) { |
245 | newRetainedMemrefs.push_back(retainedMemref); |
246 | replacements.push_back({}); |
247 | continue; |
248 | } |
249 | |
250 | replacements.push_back(rewriter.create<arith::ConstantOp>( |
251 | deallocOp.getLoc(), rewriter.getBoolAttr(false))); |
252 | } |
253 | |
254 | if (newRetainedMemrefs.size() == deallocOp.getRetained().size()) |
255 | return failure(); |
256 | |
257 | auto newDeallocOp = rewriter.create<DeallocOp>( |
258 | deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(), |
259 | newRetainedMemrefs); |
260 | int i = 0; |
261 | for (auto &repl : replacements) { |
262 | if (!repl) |
263 | repl = newDeallocOp.getUpdatedConditions()[i++]; |
264 | } |
265 | |
266 | rewriter.replaceOp(deallocOp, replacements); |
267 | return success(); |
268 | } |
269 | |
270 | private: |
271 | BufferOriginAnalysis &analysis; |
272 | }; |
273 | |
274 | /// Split off memrefs to separate dealloc operations to reduce the number of |
275 | /// runtime checks required and enable further canonicalization of the new and |
276 | /// simpler dealloc operations. A memref can be split off if it is guaranteed to |
277 | /// not alias with any other memref in the `memref` operand list. The results |
278 | /// of the old and the new dealloc operation have to be combined by computing |
279 | /// the element-wise disjunction of them. |
280 | /// |
281 | /// Example: |
282 | /// ```mlir |
283 | /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>) |
284 | /// if (%cond0, %cond1) |
285 | /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) |
286 | /// return %0#0, %0#1 |
287 | /// ``` |
288 | /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is |
289 | /// canonicalized to the following, thus reducing the number of runtime alias |
290 | /// checks by 1 and potentially enabling further canonicalization of the new |
291 | /// split-up dealloc operations. |
292 | /// ```mlir |
293 | /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0) |
294 | /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) |
295 | /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1) |
296 | /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) |
297 | /// %2 = arith.ori %0#0, %1#0 |
298 | /// %3 = arith.ori %0#1, %1#1 |
299 | /// return %2, %3 |
300 | /// ``` |
301 | struct SplitDeallocWhenNotAliasingAnyOther |
302 | : public OpRewritePattern<DeallocOp> { |
303 | SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context, |
304 | BufferOriginAnalysis &analysis) |
305 | : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} |
306 | |
307 | LogicalResult matchAndRewrite(DeallocOp deallocOp, |
308 | PatternRewriter &rewriter) const override { |
309 | Location loc = deallocOp.getLoc(); |
310 | if (deallocOp.getMemrefs().size() <= 1) |
311 | return failure(); |
312 | |
313 | SmallVector<Value> remainingMemrefs, remainingConditions; |
314 | SmallVector<SmallVector<Value>> updatedConditions; |
315 | for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) { |
316 | Value memref = deallocOp.getMemrefs()[i]; |
317 | Value cond = deallocOp.getConditions()[i]; |
318 | SmallVector<Value> otherMemrefs(deallocOp.getMemrefs()); |
319 | otherMemrefs.erase(CI: otherMemrefs.begin() + i); |
320 | // Check if `memref` can split off into a separate bufferization.dealloc. |
321 | if (potentiallyAliasesMemref(analysis, otherList: otherMemrefs, memref)) { |
322 | // `memref` alias with other memrefs, do not split off. |
323 | remainingMemrefs.push_back(Elt: memref); |
324 | remainingConditions.push_back(Elt: cond); |
325 | continue; |
326 | } |
327 | |
328 | // Create new bufferization.dealloc op for `memref`. |
329 | auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond, |
330 | deallocOp.getRetained()); |
331 | updatedConditions.push_back( |
332 | Elt: llvm::to_vector(Range: ValueRange(newDeallocOp.getUpdatedConditions()))); |
333 | } |
334 | |
335 | // Fail if no memref was split off. |
336 | if (remainingMemrefs.size() == deallocOp.getMemrefs().size()) |
337 | return failure(); |
338 | |
339 | // Create bufferization.dealloc op for all remaining memrefs. |
340 | auto newDeallocOp = rewriter.create<DeallocOp>( |
341 | loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); |
342 | |
343 | // Bit-or all conditions. |
344 | SmallVector<Value> replacements = |
345 | llvm::to_vector(Range: ValueRange(newDeallocOp.getUpdatedConditions())); |
346 | for (auto additionalConditions : updatedConditions) { |
347 | assert(replacements.size() == additionalConditions.size() && |
348 | "expected same number of updated conditions" ); |
349 | for (int64_t i = 0, e = replacements.size(); i < e; ++i) { |
350 | replacements[i] = rewriter.create<arith::OrIOp>( |
351 | loc, replacements[i], additionalConditions[i]); |
352 | } |
353 | } |
354 | rewriter.replaceOp(deallocOp, replacements); |
355 | return success(); |
356 | } |
357 | |
358 | private: |
359 | BufferOriginAnalysis &analysis; |
360 | }; |
361 | |
362 | /// Check for every retained memref if a must-aliasing memref exists in the |
363 | /// 'memref' operand list with constant 'true' condition. If so, we can replace |
364 | /// the operation result corresponding to that retained memref with 'true'. If |
365 | /// this condition holds for all retained memrefs we can also remove the |
366 | /// aliasing memrefs and their conditions since they will never be deallocated |
367 | /// due to the must-alias and we don't need them to compute the result value |
368 | /// anymore since it got replaced with 'true'. |
369 | /// |
370 | /// Example: |
371 | /// ```mlir |
372 | /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...) |
373 | /// if (%true, %true, %true) |
374 | /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) |
375 | /// ``` |
376 | /// becomes |
377 | /// ```mlir |
378 | /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true) |
379 | /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) |
380 | /// // replace %0#0 with %true |
381 | /// // replace %0#1 with %true |
382 | /// ``` |
383 | /// Note that the dealloc operation will still have the result values, but they |
384 | /// don't have uses anymore. |
385 | struct RetainedMemrefAliasingAlwaysDeallocatedMemref |
386 | : public OpRewritePattern<DeallocOp> { |
387 | RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context, |
388 | BufferOriginAnalysis &analysis) |
389 | : OpRewritePattern<DeallocOp>(context), analysis(analysis) {} |
390 | |
391 | LogicalResult matchAndRewrite(DeallocOp deallocOp, |
392 | PatternRewriter &rewriter) const override { |
393 | BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size()); |
394 | SmallVector<Value> newMemrefs, newConditions; |
395 | for (auto [memref, cond] : |
396 | llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { |
397 | bool canDropMemref = false; |
398 | for (auto [i, retained, res] : llvm::enumerate( |
399 | deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { |
400 | if (!matchPattern(cond, m_One())) |
401 | continue; |
402 | |
403 | std::optional<bool> analysisResult = |
404 | analysis.isSameAllocation(retained, memref); |
405 | if (analysisResult == true) { |
406 | rewriter.replaceAllUsesWith(res, cond); |
407 | aliasesWithConstTrueMemref[i] = true; |
408 | canDropMemref = true; |
409 | continue; |
410 | } |
411 | |
412 | // TODO: once our alias analysis is powerful enough we can remove the |
413 | // rest of this loop body |
414 | auto extractOp = |
415 | memref.getDefiningOp<memref::ExtractStridedMetadataOp>(); |
416 | if (!extractOp) |
417 | continue; |
418 | |
419 | std::optional<bool> extractAnalysisResult = |
420 | analysis.isSameAllocation(retained, extractOp.getOperand()); |
421 | if (extractAnalysisResult == true) { |
422 | rewriter.replaceAllUsesWith(res, cond); |
423 | aliasesWithConstTrueMemref[i] = true; |
424 | canDropMemref = true; |
425 | } |
426 | } |
427 | |
428 | if (!canDropMemref) { |
429 | newMemrefs.push_back(memref); |
430 | newConditions.push_back(cond); |
431 | } |
432 | } |
433 | if (!aliasesWithConstTrueMemref.all()) |
434 | return failure(); |
435 | |
436 | return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, |
437 | rewriter); |
438 | } |
439 | |
440 | private: |
441 | BufferOriginAnalysis &analysis; |
442 | }; |
443 | |
444 | } // namespace |
445 | |
446 | //===----------------------------------------------------------------------===// |
447 | // BufferDeallocationSimplificationPass |
448 | //===----------------------------------------------------------------------===// |
449 | |
450 | namespace { |
451 | |
452 | /// The actual buffer deallocation pass that inserts and moves dealloc nodes |
453 | /// into the right positions. Furthermore, it inserts additional clones if |
454 | /// necessary. It uses the algorithm described at the top of the file. |
455 | struct BufferDeallocationSimplificationPass |
456 | : public bufferization::impl::BufferDeallocationSimplificationBase< |
457 | BufferDeallocationSimplificationPass> { |
458 | void runOnOperation() override { |
459 | BufferOriginAnalysis analysis(getOperation()); |
460 | RewritePatternSet patterns(&getContext()); |
461 | patterns.add<RemoveDeallocMemrefsContainedInRetained, |
462 | RemoveRetainedMemrefsGuaranteedToNotAlias, |
463 | SplitDeallocWhenNotAliasingAnyOther, |
464 | RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(), |
465 | analysis); |
466 | populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); |
467 | |
468 | if (failed( |
469 | applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
470 | signalPassFailure(); |
471 | } |
472 | }; |
473 | |
474 | } // namespace |
475 | |
476 | std::unique_ptr<Pass> |
477 | mlir::bufferization::createBufferDeallocationSimplificationPass() { |
478 | return std::make_unique<BufferDeallocationSimplificationPass>(); |
479 | } |
480 | |