1//===- BufferDeallocationSimplification.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic for optimizing `bufferization.dealloc` operations
10// that requires more analysis than what can be supported by regular
11// canonicalization patterns.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
17#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23namespace mlir {
24namespace bufferization {
25#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::bufferization;
32
33//===----------------------------------------------------------------------===//
34// Helpers
35//===----------------------------------------------------------------------===//
36
37/// Given a memref value, return the "base" value by skipping over all
38/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
39static Value getViewBase(Value value) {
40 while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
41 value = viewLikeOp.getViewSource();
42 return value;
43}
44
45static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
46 ValueRange memrefs,
47 ValueRange conditions,
48 PatternRewriter &rewriter) {
49 if (deallocOp.getMemrefs() == memrefs &&
50 deallocOp.getConditions() == conditions)
51 return failure();
52
53 rewriter.modifyOpInPlace(deallocOp, [&]() {
54 deallocOp.getMemrefsMutable().assign(memrefs);
55 deallocOp.getConditionsMutable().assign(conditions);
56 });
57 return success();
58}
59
60/// Return "true" if the given values are guaranteed to be different (and
61/// non-aliasing) allocations based on the fact that one value is the result
62/// of an allocation and the other value is a block argument of a parent block.
63/// Note: This is a best-effort analysis that will eventually be replaced by a
64/// proper "is same allocation" analysis. This function may return "false" even
65/// though the two values are distinct allocations.
66static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
67 Value v1Base = getViewBase(value: v1);
68 Value v2Base = getViewBase(value: v2);
69 auto areDistinct = [](Value v1, Value v2) {
70 if (Operation *op = v1.getDefiningOp())
71 if (hasEffect<MemoryEffects::Allocate>(op, value: v1))
72 if (auto bbArg = dyn_cast<BlockArgument>(Val&: v2))
73 if (bbArg.getOwner()->findAncestorOpInBlock(op&: *op))
74 return true;
75 return false;
76 };
77 return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
78}
79
80/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
81/// often a requirement of optimization patterns that there cannot be any
82/// aliasing memref in order to perform the desired simplification.
83static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
84 ValueRange otherList, Value memref) {
85 for (auto other : otherList) {
86 if (distinctAllocAndBlockArgument(v1: other, v2: memref))
87 continue;
88 std::optional<bool> analysisResult =
89 analysis.isSameAllocation(v1: other, v2: memref);
90 if (!analysisResult.has_value() || analysisResult == true)
91 return true;
92 }
93 return false;
94}
95
96//===----------------------------------------------------------------------===//
97// Patterns
98//===----------------------------------------------------------------------===//
99
100namespace {
101
102/// Remove values from the `memref` operand list that are also present in the
103/// `retained` list (or a guaranteed alias of it) because they will never
104/// actually be deallocated. However, we also need to be certain about which
105/// other memrefs in the `retained` list can alias, i.e., there must not by any
106/// may-aliasing memref. This is necessary because the `dealloc` operation is
107/// defined to return one `i1` value per memref in the `retained` list which
108/// represents the disjunction of the condition values corresponding to all
109/// aliasing values in the `memref` list. In particular, this means that if
110/// there is some value R in the `retained` list which aliases with a value M in
111/// the `memref` list (but can only be staticaly determined to may-alias) and M
112/// is also present in the `retained` list, then it would be illegal to remove M
113/// because the result corresponding to R would be computed incorrectly
114/// afterwards. Because we require an alias analysis, this pattern cannot be
115/// applied as a regular canonicalization pattern.
116///
117/// Example:
118/// ```mlir
119/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
120/// retain (%m0, %r0, %r1 : ...)
121/// ```
122/// is canonicalized to
123/// ```mlir
124/// // bufferization.dealloc without memrefs and conditions returns %false for
125/// // every retained value
126/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
127/// %1 = arith.ori %0#0, %cond0 : i1
128/// // replace %0#0 with %1
129/// ```
130/// given that `%r0` and `%r1` may not alias with `%m0`.
131struct RemoveDeallocMemrefsContainedInRetained
132 : public OpRewritePattern<DeallocOp> {
133 RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
134 BufferOriginAnalysis &analysis)
135 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
136
137 /// The passed 'memref' must not have a may-alias relation to any retained
138 /// memref, and at least one must-alias relation. If there is no must-aliasing
139 /// memref in the retain list, we cannot simply remove the memref as there
140 /// could be situations in which it actually has to be deallocated. If it's
141 /// no-alias, then just proceed, if it's must-alias we need to update the
142 /// updated condition returned by the dealloc operation for that alias.
143 LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond,
144 PatternRewriter &rewriter) const {
145 rewriter.setInsertionPointAfter(deallocOp);
146
147 // Check that there is no may-aliasing memref and that at least one memref
148 // in the retain list aliases (because otherwise it might have to be
149 // deallocated in some situations and can thus not be dropped).
150 bool atLeastOneMustAlias = false;
151 for (Value retained : deallocOp.getRetained()) {
152 std::optional<bool> analysisResult =
153 analysis.isSameAllocation(retained, memref);
154 if (!analysisResult.has_value())
155 return failure();
156 if (analysisResult == true)
157 atLeastOneMustAlias = true;
158 }
159 if (!atLeastOneMustAlias)
160 return failure();
161
162 // Insert arith.ori operations to update the corresponding dealloc result
163 // values to incorporate the condition of the must-aliasing memref such that
164 // we can remove that operand later on.
165 for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
166 Value updatedCondition = deallocOp.getUpdatedConditions()[i];
167 std::optional<bool> analysisResult =
168 analysis.isSameAllocation(retained, memref);
169 if (analysisResult == true) {
170 auto disjunction = rewriter.create<arith::OrIOp>(
171 deallocOp.getLoc(), updatedCondition, cond);
172 rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
173 disjunction);
174 }
175 }
176
177 return success();
178 }
179
180 LogicalResult matchAndRewrite(DeallocOp deallocOp,
181 PatternRewriter &rewriter) const override {
182 // There must not be any duplicates in the retain list anymore because we
183 // would miss updating one of the result values otherwise.
184 DenseSet<Value> retained(deallocOp.getRetained().begin(),
185 deallocOp.getRetained().end());
186 if (retained.size() != deallocOp.getRetained().size())
187 return failure();
188
189 SmallVector<Value> newMemrefs, newConditions;
190 for (auto [memref, cond] :
191 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
192
193 if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter)))
194 continue;
195
196 if (auto extractOp =
197 memref.getDefiningOp<memref::ExtractStridedMetadataOp>())
198 if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond,
199 rewriter)))
200 continue;
201
202 newMemrefs.push_back(memref);
203 newConditions.push_back(cond);
204 }
205
206 // Return failure if we don't change anything such that we don't run into an
207 // infinite loop of pattern applications.
208 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
209 rewriter);
210 }
211
212private:
213 BufferOriginAnalysis &analysis;
214};
215
216/// Remove memrefs from the `retained` list which are guaranteed to not alias
217/// any memref in the `memrefs` list. The corresponding result value can be
218/// replaced with `false` in that case according to the operation description.
219///
220/// Example:
221/// ```mlir
222/// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
223/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
224/// return %0#0, %0#1
225/// ```
226/// can be canonicalized to the following given that `%r0` and `%r1` do not
227/// alias `%m`:
228/// ```mlir
229/// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
230/// return %false, %false
231/// ```
232struct RemoveRetainedMemrefsGuaranteedToNotAlias
233 : public OpRewritePattern<DeallocOp> {
234 RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
235 BufferOriginAnalysis &analysis)
236 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
237
238 LogicalResult matchAndRewrite(DeallocOp deallocOp,
239 PatternRewriter &rewriter) const override {
240 SmallVector<Value> newRetainedMemrefs, replacements;
241
242 for (auto retainedMemref : deallocOp.getRetained()) {
243 if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
244 retainedMemref)) {
245 newRetainedMemrefs.push_back(retainedMemref);
246 replacements.push_back({});
247 continue;
248 }
249
250 replacements.push_back(rewriter.create<arith::ConstantOp>(
251 deallocOp.getLoc(), rewriter.getBoolAttr(false)));
252 }
253
254 if (newRetainedMemrefs.size() == deallocOp.getRetained().size())
255 return failure();
256
257 auto newDeallocOp = rewriter.create<DeallocOp>(
258 deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(),
259 newRetainedMemrefs);
260 int i = 0;
261 for (auto &repl : replacements) {
262 if (!repl)
263 repl = newDeallocOp.getUpdatedConditions()[i++];
264 }
265
266 rewriter.replaceOp(deallocOp, replacements);
267 return success();
268 }
269
270private:
271 BufferOriginAnalysis &analysis;
272};
273
274/// Split off memrefs to separate dealloc operations to reduce the number of
275/// runtime checks required and enable further canonicalization of the new and
276/// simpler dealloc operations. A memref can be split off if it is guaranteed to
277/// not alias with any other memref in the `memref` operand list. The results
278/// of the old and the new dealloc operation have to be combined by computing
279/// the element-wise disjunction of them.
280///
281/// Example:
282/// ```mlir
283/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
284/// if (%cond0, %cond1)
285/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
286/// return %0#0, %0#1
287/// ```
288/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
289/// canonicalized to the following, thus reducing the number of runtime alias
290/// checks by 1 and potentially enabling further canonicalization of the new
291/// split-up dealloc operations.
292/// ```mlir
293/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
294/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
295/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
296/// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
297/// %2 = arith.ori %0#0, %1#0
298/// %3 = arith.ori %0#1, %1#1
299/// return %2, %3
300/// ```
301struct SplitDeallocWhenNotAliasingAnyOther
302 : public OpRewritePattern<DeallocOp> {
303 SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
304 BufferOriginAnalysis &analysis)
305 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
306
307 LogicalResult matchAndRewrite(DeallocOp deallocOp,
308 PatternRewriter &rewriter) const override {
309 Location loc = deallocOp.getLoc();
310 if (deallocOp.getMemrefs().size() <= 1)
311 return failure();
312
313 SmallVector<Value> remainingMemrefs, remainingConditions;
314 SmallVector<SmallVector<Value>> updatedConditions;
315 for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) {
316 Value memref = deallocOp.getMemrefs()[i];
317 Value cond = deallocOp.getConditions()[i];
318 SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
319 otherMemrefs.erase(CI: otherMemrefs.begin() + i);
320 // Check if `memref` can split off into a separate bufferization.dealloc.
321 if (potentiallyAliasesMemref(analysis, otherList: otherMemrefs, memref)) {
322 // `memref` alias with other memrefs, do not split off.
323 remainingMemrefs.push_back(Elt: memref);
324 remainingConditions.push_back(Elt: cond);
325 continue;
326 }
327
328 // Create new bufferization.dealloc op for `memref`.
329 auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond,
330 deallocOp.getRetained());
331 updatedConditions.push_back(
332 Elt: llvm::to_vector(Range: ValueRange(newDeallocOp.getUpdatedConditions())));
333 }
334
335 // Fail if no memref was split off.
336 if (remainingMemrefs.size() == deallocOp.getMemrefs().size())
337 return failure();
338
339 // Create bufferization.dealloc op for all remaining memrefs.
340 auto newDeallocOp = rewriter.create<DeallocOp>(
341 loc, remainingMemrefs, remainingConditions, deallocOp.getRetained());
342
343 // Bit-or all conditions.
344 SmallVector<Value> replacements =
345 llvm::to_vector(Range: ValueRange(newDeallocOp.getUpdatedConditions()));
346 for (auto additionalConditions : updatedConditions) {
347 assert(replacements.size() == additionalConditions.size() &&
348 "expected same number of updated conditions");
349 for (int64_t i = 0, e = replacements.size(); i < e; ++i) {
350 replacements[i] = rewriter.create<arith::OrIOp>(
351 loc, replacements[i], additionalConditions[i]);
352 }
353 }
354 rewriter.replaceOp(deallocOp, replacements);
355 return success();
356 }
357
358private:
359 BufferOriginAnalysis &analysis;
360};
361
362/// Check for every retained memref if a must-aliasing memref exists in the
363/// 'memref' operand list with constant 'true' condition. If so, we can replace
364/// the operation result corresponding to that retained memref with 'true'. If
365/// this condition holds for all retained memrefs we can also remove the
366/// aliasing memrefs and their conditions since they will never be deallocated
367/// due to the must-alias and we don't need them to compute the result value
368/// anymore since it got replaced with 'true'.
369///
370/// Example:
371/// ```mlir
372/// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
373/// if (%true, %true, %true)
374/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
375/// ```
376/// becomes
377/// ```mlir
378/// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
379/// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
380/// // replace %0#0 with %true
381/// // replace %0#1 with %true
382/// ```
383/// Note that the dealloc operation will still have the result values, but they
384/// don't have uses anymore.
385struct RetainedMemrefAliasingAlwaysDeallocatedMemref
386 : public OpRewritePattern<DeallocOp> {
387 RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
388 BufferOriginAnalysis &analysis)
389 : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
390
391 LogicalResult matchAndRewrite(DeallocOp deallocOp,
392 PatternRewriter &rewriter) const override {
393 BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size());
394 SmallVector<Value> newMemrefs, newConditions;
395 for (auto [memref, cond] :
396 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
397 bool canDropMemref = false;
398 for (auto [i, retained, res] : llvm::enumerate(
399 deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
400 if (!matchPattern(cond, m_One()))
401 continue;
402
403 std::optional<bool> analysisResult =
404 analysis.isSameAllocation(retained, memref);
405 if (analysisResult == true) {
406 rewriter.replaceAllUsesWith(res, cond);
407 aliasesWithConstTrueMemref[i] = true;
408 canDropMemref = true;
409 continue;
410 }
411
412 // TODO: once our alias analysis is powerful enough we can remove the
413 // rest of this loop body
414 auto extractOp =
415 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
416 if (!extractOp)
417 continue;
418
419 std::optional<bool> extractAnalysisResult =
420 analysis.isSameAllocation(retained, extractOp.getOperand());
421 if (extractAnalysisResult == true) {
422 rewriter.replaceAllUsesWith(res, cond);
423 aliasesWithConstTrueMemref[i] = true;
424 canDropMemref = true;
425 }
426 }
427
428 if (!canDropMemref) {
429 newMemrefs.push_back(memref);
430 newConditions.push_back(cond);
431 }
432 }
433 if (!aliasesWithConstTrueMemref.all())
434 return failure();
435
436 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
437 rewriter);
438 }
439
440private:
441 BufferOriginAnalysis &analysis;
442};
443
444} // namespace
445
446//===----------------------------------------------------------------------===//
447// BufferDeallocationSimplificationPass
448//===----------------------------------------------------------------------===//
449
450namespace {
451
452/// The actual buffer deallocation pass that inserts and moves dealloc nodes
453/// into the right positions. Furthermore, it inserts additional clones if
454/// necessary. It uses the algorithm described at the top of the file.
455struct BufferDeallocationSimplificationPass
456 : public bufferization::impl::BufferDeallocationSimplificationBase<
457 BufferDeallocationSimplificationPass> {
458 void runOnOperation() override {
459 BufferOriginAnalysis analysis(getOperation());
460 RewritePatternSet patterns(&getContext());
461 patterns.add<RemoveDeallocMemrefsContainedInRetained,
462 RemoveRetainedMemrefsGuaranteedToNotAlias,
463 SplitDeallocWhenNotAliasingAnyOther,
464 RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465 analysis);
466 populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
467
468 if (failed(
469 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
470 signalPassFailure();
471 }
472};
473
474} // namespace
475
476std::unique_ptr<Pass>
477mlir::bufferization::createBufferDeallocationSimplificationPass() {
478 return std::make_unique<BufferDeallocationSimplificationPass>();
479}
480

source code of mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp