1//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert `bufferization.dealloc` operations
10// to the MemRef dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24namespace mlir {
25namespace bufferization {
26#define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS
27#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
28} // namespace bufferization
29} // namespace mlir
30
31using namespace mlir;
32
33namespace {
34/// The DeallocOpConversion transforms all bufferization dealloc operations into
35/// memref dealloc operations potentially guarded by scf if operations.
36/// Additionally, memref extract_aligned_pointer_as_index and arith operations
37/// are inserted to compute the guard conditions. We distinguish multiple cases
38/// to provide an overall more efficient lowering. In the general case, a helper
39/// func is created to avoid quadratic code size explosion (relative to the
40/// number of operands of the dealloc operation). For examples of each case,
41/// refer to the documentation of the member functions of this class.
42class DeallocOpConversion
43 : public OpConversionPattern<bufferization::DeallocOp> {
44
45 /// Lower a simple case without any retained values and a single memref to
46 /// avoiding the helper function. Ideally, static analysis can provide enough
47 /// aliasing information to split the dealloc operations up into this simple
48 /// case as much as possible before running this pass.
49 ///
50 /// Example:
51 /// ```
52 /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
53 /// ```
54 /// is lowered to
55 /// ```
56 /// scf.if %arg1 {
57 /// memref.dealloc %arg0 : memref<2xf32>
58 /// }
59 /// ```
60 LogicalResult
61 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter) const {
63 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
64 assert(adaptor.getRetained().empty() && "expected no retained memrefs");
65
66 rewriter.replaceOpWithNewOp<scf::IfOp>(
67 op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
68 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
69 builder.create<scf::YieldOp>(loc);
70 });
71 return success();
72 }
73
74 /// A special case lowering for the deallocation operation with exactly one
75 /// memref, but arbitrary number of retained values. This avoids the helper
76 /// function that the general case needs and thus also avoids storing indices
77 /// to specifically allocated memrefs. The size of the code produced by this
78 /// lowering is linear to the number of retained values.
79 ///
80 /// Example:
81 /// ```mlir
82 /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
83 // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
84 /// return %0#0, %0#1 : i1, i1
85 /// ```
86 /// ```mlir
87 /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
88 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
89 /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
90 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
91 /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
92 /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
93 /// %should_dealloc = arith.andi %not_retained, %cond : i1
94 /// scf.if %should_dealloc {
95 /// memref.dealloc %m : memref<2xf32>
96 /// }
97 /// %true = arith.constant true
98 /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
99 /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
100 /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
101 /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
102 /// return %r0_ownership, %r1_ownership : i1, i1
103 /// ```
104 LogicalResult rewriteOneMemrefMultipleRetainCase(
105 bufferization::DeallocOp op, OpAdaptor adaptor,
106 ConversionPatternRewriter &rewriter) const {
107 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
108
109 // Compute the base pointer indices, compare all retained indices to the
110 // memref index to check if they alias.
111 SmallVector<Value> doesNotAliasList;
112 Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
113 op->getLoc(), adaptor.getMemrefs()[0]);
114 for (Value retained : adaptor.getRetained()) {
115 Value retainedAsIdx =
116 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
117 retained);
118 Value doesNotAlias = rewriter.create<arith::CmpIOp>(
119 op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
120 doesNotAliasList.push_back(doesNotAlias);
121 }
122
123 // AND-reduce the list of booleans from above.
124 Value prev = doesNotAliasList.front();
125 for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
126 prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
127
128 // Also consider the condition given by the dealloc operation and perform a
129 // conditional deallocation guarded by that value.
130 Value shouldDealloc = rewriter.create<arith::AndIOp>(
131 op->getLoc(), prev, adaptor.getConditions()[0]);
132
133 rewriter.create<scf::IfOp>(
134 op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
135 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
136 builder.create<scf::YieldOp>(loc);
137 });
138
139 // Compute the replacement values for the dealloc operation results. This
140 // inserts an already canonicalized form of
141 // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
142 // value r.
143 SmallVector<Value> replacements;
144 Value trueVal = rewriter.create<arith::ConstantOp>(
145 op->getLoc(), rewriter.getBoolAttr(true));
146 for (Value doesNotAlias : doesNotAliasList) {
147 Value aliases =
148 rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
149 Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
150 adaptor.getConditions()[0]);
151 replacements.push_back(result);
152 }
153
154 rewriter.replaceOp(op, replacements);
155
156 return success();
157 }
158
159 /// Lowering that supports all features the dealloc operation has to offer. It
160 /// computes the base pointer of each memref (as an index), stores it in a
161 /// new memref helper structure and passes it to the helper function generated
162 /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
163 /// (represented as memrefs) of booleans passed as arguments. The first list
164 /// stores whether the corresponding condition should be deallocated, the
165 /// second list stores the ownership of the retained values which can be used
166 /// to replace the result values of the `bufferization.dealloc` operation.
167 ///
168 /// Example:
169 /// ```
170 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
171 /// if (%cond0, %cond1)
172 /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
173 /// ```
174 /// lowers to (simplified):
175 /// ```
176 /// %c0 = arith.constant 0 : index
177 /// %c1 = arith.constant 1 : index
178 /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
179 /// %cond_list = memref.alloc() : memref<2xi1>
180 /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
181 /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
182 /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
183 /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
184 /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
185 /// memref.store %cond0, %cond_list[%c0]
186 /// memref.store %cond1, %cond_list[%c1]
187 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
188 /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
189 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
190 /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
191 /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
192 /// memref<2xindex> to memref<?xindex>
193 /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
194 /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
195 /// memref<2xindex> to memref<?xindex>
196 /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
197 /// %ownership_out = memref.alloc() : memref<2xi1>
198 /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
199 /// memref<2xi1> to memref<?xi1>
200 /// %dyn_ownership_out = memref.cast %ownership_out :
201 /// memref<2xi1> to memref<?xi1>
202 /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
203 /// %dyn_retain_base_pointer_list,
204 /// %dyn_cond_list,
205 /// %dyn_dealloc_cond_out,
206 /// %dyn_ownership_out) : (...)
207 /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
208 /// scf.if %m0_dealloc_cond {
209 /// memref.dealloc %m0 : memref<2xf32>
210 /// }
211 /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
212 /// scf.if %m1_dealloc_cond {
213 /// memref.dealloc %m1 : memref<5xf32>
214 /// }
215 /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
216 /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
217 /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
218 /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
219 /// memref.dealloc %cond_list : memref<2xi1>
220 /// memref.dealloc %dealloc_cond_out : memref<2xi1>
221 /// memref.dealloc %ownership_out : memref<2xi1>
222 /// // replace %0#0 with %r0_ownership
223 /// // replace %0#1 with %r1_ownership
224 /// ```
225 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
226 OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const {
228 // Allocate two memrefs holding the base pointer indices of the list of
229 // memrefs to be deallocated and the ones to be retained. These can then be
230 // passed to the helper function and the for-loops can iterate over them.
231 // Without storing them to memrefs, we could not use for-loops but only a
232 // completely unrolled version of it, potentially leading to code-size
233 // blow-up.
234 Value toDeallocMemref = rewriter.create<memref::AllocOp>(
235 op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
236 rewriter.getIndexType()));
237 Value conditionMemref = rewriter.create<memref::AllocOp>(
238 op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
239 rewriter.getI1Type()));
240 Value toRetainMemref = rewriter.create<memref::AllocOp>(
241 op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
242 rewriter.getIndexType()));
243
244 auto getConstValue = [&](uint64_t value) -> Value {
245 return rewriter.create<arith::ConstantOp>(op.getLoc(),
246 rewriter.getIndexAttr(value));
247 };
248
249 // Extract the base pointers of the memrefs as indices to check for aliasing
250 // at runtime.
251 for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
252 Value memrefAsIdx =
253 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
254 toDealloc);
255 rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
256 toDeallocMemref, getConstValue(i));
257 }
258
259 for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
260 rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
261 getConstValue(i));
262
263 for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
264 Value memrefAsIdx =
265 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
266 toRetain);
267 rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
268 getConstValue(i));
269 }
270
271 // Cast the allocated memrefs to dynamic shape because we want only one
272 // helper function no matter how many operands the bufferization.dealloc
273 // has.
274 Value castedDeallocMemref = rewriter.create<memref::CastOp>(
275 op->getLoc(),
276 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
277 toDeallocMemref);
278 Value castedCondsMemref = rewriter.create<memref::CastOp>(
279 op->getLoc(),
280 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
281 conditionMemref);
282 Value castedRetainMemref = rewriter.create<memref::CastOp>(
283 op->getLoc(),
284 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
285 toRetainMemref);
286
287 Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
288 op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
289 rewriter.getI1Type()));
290 Value retainCondsMemref = rewriter.create<memref::AllocOp>(
291 op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
292 rewriter.getI1Type()));
293
294 Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
295 op->getLoc(),
296 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
297 deallocCondsMemref);
298 Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
299 op->getLoc(),
300 MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
301 retainCondsMemref);
302
303 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
304 rewriter.create<func::CallOp>(
305 op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
306 SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
307 castedCondsMemref, castedDeallocCondsMemref,
308 castedRetainCondsMemref});
309
310 for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
311 Value idxValue = getConstValue(i);
312 Value shouldDealloc = rewriter.create<memref::LoadOp>(
313 op.getLoc(), deallocCondsMemref, idxValue);
314 rewriter.create<scf::IfOp>(
315 op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
316 builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
317 builder.create<scf::YieldOp>(loc);
318 });
319 }
320
321 SmallVector<Value> replacements;
322 for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
323 Value idxValue = getConstValue(i);
324 Value ownership = rewriter.create<memref::LoadOp>(
325 op.getLoc(), retainCondsMemref, idxValue);
326 replacements.push_back(ownership);
327 }
328
329 // Deallocate above allocated memrefs again to avoid memory leaks.
330 // Deallocation will not be run on code after this stage.
331 rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
332 rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
333 rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
334 rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
335 rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);
336
337 rewriter.replaceOp(op, replacements);
338 return success();
339 }
340
341public:
342 DeallocOpConversion(
343 MLIRContext *context,
344 const bufferization::DeallocHelperMap &deallocHelperFuncMap)
345 : OpConversionPattern<bufferization::DeallocOp>(context),
346 deallocHelperFuncMap(deallocHelperFuncMap) {}
347
348 LogicalResult
349 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351 // Lower the trivial case.
352 if (adaptor.getMemrefs().empty()) {
353 Value falseVal = rewriter.create<arith::ConstantOp>(
354 op.getLoc(), rewriter.getBoolAttr(false));
355 rewriter.replaceOp(
356 op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
357 return success();
358 }
359
360 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
361 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
362
363 if (adaptor.getMemrefs().size() == 1)
364 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
365
366 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
367 if (!deallocHelperFuncMap.contains(symtableOp))
368 return op->emitError(
369 "library function required for generic lowering, but cannot be "
370 "automatically inserted when operating on functions");
371
372 return rewriteGeneralCase(op, adaptor, rewriter);
373 }
374
375private:
376 const bufferization::DeallocHelperMap &deallocHelperFuncMap;
377};
378} // namespace
379
380namespace {
381struct LowerDeallocationsPass
382 : public bufferization::impl::LowerDeallocationsPassBase<
383 LowerDeallocationsPass> {
384 void runOnOperation() override {
385 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
386 emitError(getOperation()->getLoc(),
387 "root operation must be a builtin.module or a function");
388 signalPassFailure();
389 return;
390 }
391
392 bufferization::DeallocHelperMap deallocHelperFuncMap;
393 if (auto module = dyn_cast<ModuleOp>(getOperation())) {
394 OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody());
395
396 // Build dealloc helper function if there are deallocs.
397 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
398 Operation *symtableOp =
399 deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
400 if (deallocOp.getMemrefs().size() > 1 &&
401 !deallocHelperFuncMap.contains(Val: symtableOp)) {
402 SymbolTable symbolTable(symtableOp);
403 func::FuncOp helperFuncOp =
404 bufferization::buildDeallocationLibraryFunction(
405 builder, getOperation()->getLoc(), symbolTable);
406 deallocHelperFuncMap[symtableOp] = helperFuncOp;
407 }
408 });
409 }
410
411 RewritePatternSet patterns(&getContext());
412 bufferization::populateBufferizationDeallocLoweringPattern(
413 patterns, deallocHelperFuncMap);
414
415 ConversionTarget target(getContext());
416 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
417 scf::SCFDialect, func::FuncDialect>();
418 target.addIllegalOp<bufferization::DeallocOp>();
419
420 if (failed(applyPartialConversion(getOperation(), target,
421 std::move(patterns))))
422 signalPassFailure();
423 }
424};
425} // namespace
426
427func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
428 OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
429 Type indexMemrefType =
430 MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
431 Type boolMemrefType =
432 MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
433 SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
434 boolMemrefType, boolMemrefType};
435 builder.clearInsertionPoint();
436
437 // Generate the func operation itself.
438 auto helperFuncOp = func::FuncOp::create(
439 loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
440 helperFuncOp.setVisibility(SymbolTable::Visibility::Private);
441 symbolTable.insert(symbol: helperFuncOp);
442 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
443 block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
444
445 builder.setInsertionPointToStart(&block);
446 Value toDeallocMemref = helperFuncOp.getArguments()[0];
447 Value toRetainMemref = helperFuncOp.getArguments()[1];
448 Value conditionMemref = helperFuncOp.getArguments()[2];
449 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
450 Value retainCondsMemref = helperFuncOp.getArguments()[4];
451
452 // Insert some prerequisites.
453 Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
454 Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
455 Value trueValue =
456 builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
457 Value falseValue =
458 builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
459 Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
460 Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
461
462 builder.create<scf::ForOp>(
463 loc, c0, toRetainSize, c1, std::nullopt,
464 [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
465 builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
466 builder.create<scf::YieldOp>(loc);
467 });
468
469 builder.create<scf::ForOp>(
470 loc, c0, toDeallocSize, c1, std::nullopt,
471 [&](OpBuilder &builder, Location loc, Value outerIter,
472 ValueRange iterArgs) {
473 Value toDealloc =
474 builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
475 Value cond =
476 builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);
477
478 // Build the first for loop that computes aliasing with retained
479 // memrefs.
480 Value noRetainAlias =
481 builder
482 .create<scf::ForOp>(
483 loc, c0, toRetainSize, c1, trueValue,
484 [&](OpBuilder &builder, Location loc, Value i,
485 ValueRange iterArgs) {
486 Value retainValue = builder.create<memref::LoadOp>(
487 loc, toRetainMemref, i);
488 Value doesAlias = builder.create<arith::CmpIOp>(
489 loc, arith::CmpIPredicate::eq, retainValue,
490 toDealloc);
491 builder.create<scf::IfOp>(
492 loc, doesAlias,
493 [&](OpBuilder &builder, Location loc) {
494 Value retainCondValue =
495 builder.create<memref::LoadOp>(
496 loc, retainCondsMemref, i);
497 Value aggregatedRetainCond =
498 builder.create<arith::OrIOp>(
499 loc, retainCondValue, cond);
500 builder.create<memref::StoreOp>(
501 loc, aggregatedRetainCond, retainCondsMemref,
502 i);
503 builder.create<scf::YieldOp>(loc);
504 });
505 Value doesntAlias = builder.create<arith::CmpIOp>(
506 loc, arith::CmpIPredicate::ne, retainValue,
507 toDealloc);
508 Value yieldValue = builder.create<arith::AndIOp>(
509 loc, iterArgs[0], doesntAlias);
510 builder.create<scf::YieldOp>(loc, yieldValue);
511 })
512 .getResult(0);
513
514 // Build the second for loop that adds aliasing with previously
515 // deallocated memrefs.
516 Value noAlias =
517 builder
518 .create<scf::ForOp>(
519 loc, c0, outerIter, c1, noRetainAlias,
520 [&](OpBuilder &builder, Location loc, Value i,
521 ValueRange iterArgs) {
522 Value prevDeallocValue = builder.create<memref::LoadOp>(
523 loc, toDeallocMemref, i);
524 Value doesntAlias = builder.create<arith::CmpIOp>(
525 loc, arith::CmpIPredicate::ne, prevDeallocValue,
526 toDealloc);
527 Value yieldValue = builder.create<arith::AndIOp>(
528 loc, iterArgs[0], doesntAlias);
529 builder.create<scf::YieldOp>(loc, yieldValue);
530 })
531 .getResult(0);
532
533 Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond);
534 builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref,
535 outerIter);
536 builder.create<scf::YieldOp>(loc);
537 });
538
539 builder.create<func::ReturnOp>(loc);
540 return helperFuncOp;
541}
542
543void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
544 RewritePatternSet &patterns,
545 const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
546 patterns.add<DeallocOpConversion>(arg: patterns.getContext(),
547 args: deallocHelperFuncMap);
548}
549

// Source: mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp