1//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert `bufferization.dealloc` operations
10// to the MemRef dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/Transforms/DialectConversion.h"
22
23namespace mlir {
24namespace bufferization {
25#define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS
26#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31
32namespace {
33/// The DeallocOpConversion transforms all bufferization dealloc operations into
34/// memref dealloc operations potentially guarded by scf if operations.
35/// Additionally, memref extract_aligned_pointer_as_index and arith operations
36/// are inserted to compute the guard conditions. We distinguish multiple cases
37/// to provide an overall more efficient lowering. In the general case, a helper
38/// func is created to avoid quadratic code size explosion (relative to the
39/// number of operands of the dealloc operation). For examples of each case,
40/// refer to the documentation of the member functions of this class.
41class DeallocOpConversion
42 : public OpConversionPattern<bufferization::DeallocOp> {
43
44 /// Lower a simple case without any retained values and a single memref to
45 /// avoiding the helper function. Ideally, static analysis can provide enough
46 /// aliasing information to split the dealloc operations up into this simple
47 /// case as much as possible before running this pass.
48 ///
49 /// Example:
50 /// ```
51 /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
52 /// ```
53 /// is lowered to
54 /// ```
55 /// scf.if %arg1 {
56 /// memref.dealloc %arg0 : memref<2xf32>
57 /// }
58 /// ```
59 LogicalResult
60 rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
61 ConversionPatternRewriter &rewriter) const {
62 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
63 assert(adaptor.getRetained().empty() && "expected no retained memrefs");
64
65 rewriter.replaceOpWithNewOp<scf::IfOp>(
66 op, args: adaptor.getConditions()[0], args: [&](OpBuilder &builder, Location loc) {
67 builder.create<memref::DeallocOp>(location: loc, args: adaptor.getMemrefs()[0]);
68 builder.create<scf::YieldOp>(location: loc);
69 });
70 return success();
71 }
72
73 /// A special case lowering for the deallocation operation with exactly one
74 /// memref, but arbitrary number of retained values. This avoids the helper
75 /// function that the general case needs and thus also avoids storing indices
76 /// to specifically allocated memrefs. The size of the code produced by this
77 /// lowering is linear to the number of retained values.
78 ///
79 /// Example:
80 /// ```mlir
81 /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
82 // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
83 /// return %0#0, %0#1 : i1, i1
84 /// ```
85 /// ```mlir
86 /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
87 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
88 /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
89 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
90 /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
91 /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
92 /// %should_dealloc = arith.andi %not_retained, %cond : i1
93 /// scf.if %should_dealloc {
94 /// memref.dealloc %m : memref<2xf32>
95 /// }
96 /// %true = arith.constant true
97 /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
98 /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
99 /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
100 /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
101 /// return %r0_ownership, %r1_ownership : i1, i1
102 /// ```
103 LogicalResult rewriteOneMemrefMultipleRetainCase(
104 bufferization::DeallocOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter) const {
106 assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
107
108 // Compute the base pointer indices, compare all retained indices to the
109 // memref index to check if they alias.
110 SmallVector<Value> doesNotAliasList;
111 Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
112 location: op->getLoc(), args: adaptor.getMemrefs()[0]);
113 for (Value retained : adaptor.getRetained()) {
114 Value retainedAsIdx =
115 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(location: op->getLoc(),
116 args&: retained);
117 Value doesNotAlias = rewriter.create<arith::CmpIOp>(
118 location: op->getLoc(), args: arith::CmpIPredicate::ne, args&: memrefAsIdx, args&: retainedAsIdx);
119 doesNotAliasList.push_back(Elt: doesNotAlias);
120 }
121
122 // AND-reduce the list of booleans from above.
123 Value prev = doesNotAliasList.front();
124 for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
125 prev = rewriter.create<arith::AndIOp>(location: op->getLoc(), args&: prev, args&: doesNotAlias);
126
127 // Also consider the condition given by the dealloc operation and perform a
128 // conditional deallocation guarded by that value.
129 Value shouldDealloc = rewriter.create<arith::AndIOp>(
130 location: op->getLoc(), args&: prev, args: adaptor.getConditions()[0]);
131
132 rewriter.create<scf::IfOp>(
133 location: op.getLoc(), args&: shouldDealloc, args: [&](OpBuilder &builder, Location loc) {
134 builder.create<memref::DeallocOp>(location: loc, args: adaptor.getMemrefs()[0]);
135 builder.create<scf::YieldOp>(location: loc);
136 });
137
138 // Compute the replacement values for the dealloc operation results. This
139 // inserts an already canonicalized form of
140 // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
141 // value r.
142 SmallVector<Value> replacements;
143 Value trueVal = rewriter.create<arith::ConstantOp>(
144 location: op->getLoc(), args: rewriter.getBoolAttr(value: true));
145 for (Value doesNotAlias : doesNotAliasList) {
146 Value aliases =
147 rewriter.create<arith::XOrIOp>(location: op->getLoc(), args&: doesNotAlias, args&: trueVal);
148 Value result = rewriter.create<arith::AndIOp>(location: op->getLoc(), args&: aliases,
149 args: adaptor.getConditions()[0]);
150 replacements.push_back(Elt: result);
151 }
152
153 rewriter.replaceOp(op, newValues: replacements);
154
155 return success();
156 }
157
158 /// Lowering that supports all features the dealloc operation has to offer. It
159 /// computes the base pointer of each memref (as an index), stores it in a
160 /// new memref helper structure and passes it to the helper function generated
161 /// in 'buildDeallocationHelperFunction'. The results are stored in two lists
162 /// (represented as memrefs) of booleans passed as arguments. The first list
163 /// stores whether the corresponding condition should be deallocated, the
164 /// second list stores the ownership of the retained values which can be used
165 /// to replace the result values of the `bufferization.dealloc` operation.
166 ///
167 /// Example:
168 /// ```
169 /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
170 /// if (%cond0, %cond1)
171 /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
172 /// ```
173 /// lowers to (simplified):
174 /// ```
175 /// %c0 = arith.constant 0 : index
176 /// %c1 = arith.constant 1 : index
177 /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
178 /// %cond_list = memref.alloc() : memref<2xi1>
179 /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
180 /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
181 /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
182 /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
183 /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
184 /// memref.store %cond0, %cond_list[%c0]
185 /// memref.store %cond1, %cond_list[%c1]
186 /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
187 /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
188 /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
189 /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
190 /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
191 /// memref<2xindex> to memref<?xindex>
192 /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
193 /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
194 /// memref<2xindex> to memref<?xindex>
195 /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
196 /// %ownership_out = memref.alloc() : memref<2xi1>
197 /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
198 /// memref<2xi1> to memref<?xi1>
199 /// %dyn_ownership_out = memref.cast %ownership_out :
200 /// memref<2xi1> to memref<?xi1>
201 /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
202 /// %dyn_retain_base_pointer_list,
203 /// %dyn_cond_list,
204 /// %dyn_dealloc_cond_out,
205 /// %dyn_ownership_out) : (...)
206 /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
207 /// scf.if %m0_dealloc_cond {
208 /// memref.dealloc %m0 : memref<2xf32>
209 /// }
210 /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
211 /// scf.if %m1_dealloc_cond {
212 /// memref.dealloc %m1 : memref<5xf32>
213 /// }
214 /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
215 /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
216 /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
217 /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
218 /// memref.dealloc %cond_list : memref<2xi1>
219 /// memref.dealloc %dealloc_cond_out : memref<2xi1>
220 /// memref.dealloc %ownership_out : memref<2xi1>
221 /// // replace %0#0 with %r0_ownership
222 /// // replace %0#1 with %r1_ownership
223 /// ```
224 LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
225 OpAdaptor adaptor,
226 ConversionPatternRewriter &rewriter) const {
227 // Allocate two memrefs holding the base pointer indices of the list of
228 // memrefs to be deallocated and the ones to be retained. These can then be
229 // passed to the helper function and the for-loops can iterate over them.
230 // Without storing them to memrefs, we could not use for-loops but only a
231 // completely unrolled version of it, potentially leading to code-size
232 // blow-up.
233 Value toDeallocMemref = rewriter.create<memref::AllocOp>(
234 location: op.getLoc(), args: MemRefType::get(shape: {(int64_t)adaptor.getMemrefs().size()},
235 elementType: rewriter.getIndexType()));
236 Value conditionMemref = rewriter.create<memref::AllocOp>(
237 location: op.getLoc(), args: MemRefType::get(shape: {(int64_t)adaptor.getConditions().size()},
238 elementType: rewriter.getI1Type()));
239 Value toRetainMemref = rewriter.create<memref::AllocOp>(
240 location: op.getLoc(), args: MemRefType::get(shape: {(int64_t)adaptor.getRetained().size()},
241 elementType: rewriter.getIndexType()));
242
243 auto getConstValue = [&](uint64_t value) -> Value {
244 return rewriter.create<arith::ConstantOp>(location: op.getLoc(),
245 args: rewriter.getIndexAttr(value));
246 };
247
248 // Extract the base pointers of the memrefs as indices to check for aliasing
249 // at runtime.
250 for (auto [i, toDealloc] : llvm::enumerate(First: adaptor.getMemrefs())) {
251 Value memrefAsIdx =
252 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(location: op.getLoc(),
253 args&: toDealloc);
254 rewriter.create<memref::StoreOp>(location: op.getLoc(), args&: memrefAsIdx,
255 args&: toDeallocMemref, args: getConstValue(i));
256 }
257
258 for (auto [i, cond] : llvm::enumerate(First: adaptor.getConditions()))
259 rewriter.create<memref::StoreOp>(location: op.getLoc(), args&: cond, args&: conditionMemref,
260 args: getConstValue(i));
261
262 for (auto [i, toRetain] : llvm::enumerate(First: adaptor.getRetained())) {
263 Value memrefAsIdx =
264 rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(location: op.getLoc(),
265 args&: toRetain);
266 rewriter.create<memref::StoreOp>(location: op.getLoc(), args&: memrefAsIdx, args&: toRetainMemref,
267 args: getConstValue(i));
268 }
269
270 // Cast the allocated memrefs to dynamic shape because we want only one
271 // helper function no matter how many operands the bufferization.dealloc
272 // has.
273 Value castedDeallocMemref = rewriter.create<memref::CastOp>(
274 location: op->getLoc(),
275 args: MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getIndexType()),
276 args&: toDeallocMemref);
277 Value castedCondsMemref = rewriter.create<memref::CastOp>(
278 location: op->getLoc(),
279 args: MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getI1Type()),
280 args&: conditionMemref);
281 Value castedRetainMemref = rewriter.create<memref::CastOp>(
282 location: op->getLoc(),
283 args: MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getIndexType()),
284 args&: toRetainMemref);
285
286 Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
287 location: op.getLoc(), args: MemRefType::get(shape: {(int64_t)adaptor.getMemrefs().size()},
288 elementType: rewriter.getI1Type()));
289 Value retainCondsMemref = rewriter.create<memref::AllocOp>(
290 location: op.getLoc(), args: MemRefType::get(shape: {(int64_t)adaptor.getRetained().size()},
291 elementType: rewriter.getI1Type()));
292
293 Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
294 location: op->getLoc(),
295 args: MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getI1Type()),
296 args&: deallocCondsMemref);
297 Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
298 location: op->getLoc(),
299 args: MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getI1Type()),
300 args&: retainCondsMemref);
301
302 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
303 rewriter.create<func::CallOp>(
304 location: op.getLoc(), args: deallocHelperFuncMap.lookup(Val: symtableOp),
305 args: SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
306 castedCondsMemref, castedDeallocCondsMemref,
307 castedRetainCondsMemref});
308
309 for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
310 Value idxValue = getConstValue(i);
311 Value shouldDealloc = rewriter.create<memref::LoadOp>(
312 location: op.getLoc(), args&: deallocCondsMemref, args&: idxValue);
313 rewriter.create<scf::IfOp>(
314 location: op.getLoc(), args&: shouldDealloc, args: [&](OpBuilder &builder, Location loc) {
315 builder.create<memref::DeallocOp>(location: loc, args: adaptor.getMemrefs()[i]);
316 builder.create<scf::YieldOp>(location: loc);
317 });
318 }
319
320 SmallVector<Value> replacements;
321 for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
322 Value idxValue = getConstValue(i);
323 Value ownership = rewriter.create<memref::LoadOp>(
324 location: op.getLoc(), args&: retainCondsMemref, args&: idxValue);
325 replacements.push_back(Elt: ownership);
326 }
327
328 // Deallocate above allocated memrefs again to avoid memory leaks.
329 // Deallocation will not be run on code after this stage.
330 rewriter.create<memref::DeallocOp>(location: op.getLoc(), args&: toDeallocMemref);
331 rewriter.create<memref::DeallocOp>(location: op.getLoc(), args&: toRetainMemref);
332 rewriter.create<memref::DeallocOp>(location: op.getLoc(), args&: conditionMemref);
333 rewriter.create<memref::DeallocOp>(location: op.getLoc(), args&: deallocCondsMemref);
334 rewriter.create<memref::DeallocOp>(location: op.getLoc(), args&: retainCondsMemref);
335
336 rewriter.replaceOp(op, newValues: replacements);
337 return success();
338 }
339
340public:
341 DeallocOpConversion(
342 MLIRContext *context,
343 const bufferization::DeallocHelperMap &deallocHelperFuncMap)
344 : OpConversionPattern<bufferization::DeallocOp>(context),
345 deallocHelperFuncMap(deallocHelperFuncMap) {}
346
347 LogicalResult
348 matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 // Lower the trivial case.
351 if (adaptor.getMemrefs().empty()) {
352 Value falseVal = rewriter.create<arith::ConstantOp>(
353 location: op.getLoc(), args: rewriter.getBoolAttr(value: false));
354 rewriter.replaceOp(
355 op, newValues: SmallVector<Value>(adaptor.getRetained().size(), falseVal));
356 return success();
357 }
358
359 if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
360 return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
361
362 if (adaptor.getMemrefs().size() == 1)
363 return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
364
365 Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
366 if (!deallocHelperFuncMap.contains(Val: symtableOp))
367 return op->emitError(
368 message: "library function required for generic lowering, but cannot be "
369 "automatically inserted when operating on functions");
370
371 return rewriteGeneralCase(op, adaptor, rewriter);
372 }
373
374private:
375 const bufferization::DeallocHelperMap &deallocHelperFuncMap;
376};
377} // namespace
378
379namespace {
380struct LowerDeallocationsPass
381 : public bufferization::impl::LowerDeallocationsPassBase<
382 LowerDeallocationsPass> {
383 void runOnOperation() override {
384 if (!isa<ModuleOp, FunctionOpInterface>(Val: getOperation())) {
385 emitError(loc: getOperation()->getLoc(),
386 message: "root operation must be a builtin.module or a function");
387 signalPassFailure();
388 return;
389 }
390
391 bufferization::DeallocHelperMap deallocHelperFuncMap;
392 if (auto module = dyn_cast<ModuleOp>(Val: getOperation())) {
393 OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody());
394
395 // Build dealloc helper function if there are deallocs.
396 getOperation()->walk(callback: [&](bufferization::DeallocOp deallocOp) {
397 Operation *symtableOp =
398 deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
399 if (deallocOp.getMemrefs().size() > 1 &&
400 !deallocHelperFuncMap.contains(Val: symtableOp)) {
401 SymbolTable symbolTable(symtableOp);
402 func::FuncOp helperFuncOp =
403 bufferization::buildDeallocationLibraryFunction(
404 builder, loc: getOperation()->getLoc(), symbolTable);
405 deallocHelperFuncMap[symtableOp] = helperFuncOp;
406 }
407 });
408 }
409
410 RewritePatternSet patterns(&getContext());
411 bufferization::populateBufferizationDeallocLoweringPattern(
412 patterns, deallocHelperFuncMap);
413
414 ConversionTarget target(getContext());
415 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
416 scf::SCFDialect, func::FuncDialect>();
417 target.addIllegalOp<bufferization::DeallocOp>();
418
419 if (failed(Result: applyPartialConversion(op: getOperation(), target,
420 patterns: std::move(patterns))))
421 signalPassFailure();
422 }
423};
424} // namespace
425
426func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
427 OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
428 Type indexMemrefType =
429 MemRefType::get(shape: {ShapedType::kDynamic}, elementType: builder.getIndexType());
430 Type boolMemrefType =
431 MemRefType::get(shape: {ShapedType::kDynamic}, elementType: builder.getI1Type());
432 SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
433 boolMemrefType, boolMemrefType};
434 builder.clearInsertionPoint();
435
436 // Generate the func operation itself.
437 auto helperFuncOp = func::FuncOp::create(
438 location: loc, name: "dealloc_helper", type: builder.getFunctionType(inputs: argTypes, results: {}));
439 helperFuncOp.setVisibility(SymbolTable::Visibility::Private);
440 symbolTable.insert(symbol: helperFuncOp);
441 auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
442 block.addArguments(types: argTypes, locs: SmallVector<Location>(argTypes.size(), loc));
443
444 builder.setInsertionPointToStart(&block);
445 Value toDeallocMemref = helperFuncOp.getArguments()[0];
446 Value toRetainMemref = helperFuncOp.getArguments()[1];
447 Value conditionMemref = helperFuncOp.getArguments()[2];
448 Value deallocCondsMemref = helperFuncOp.getArguments()[3];
449 Value retainCondsMemref = helperFuncOp.getArguments()[4];
450
451 // Insert some prerequisites.
452 Value c0 = builder.create<arith::ConstantOp>(location: loc, args: builder.getIndexAttr(value: 0));
453 Value c1 = builder.create<arith::ConstantOp>(location: loc, args: builder.getIndexAttr(value: 1));
454 Value trueValue =
455 builder.create<arith::ConstantOp>(location: loc, args: builder.getBoolAttr(value: true));
456 Value falseValue =
457 builder.create<arith::ConstantOp>(location: loc, args: builder.getBoolAttr(value: false));
458 Value toDeallocSize = builder.create<memref::DimOp>(location: loc, args&: toDeallocMemref, args&: c0);
459 Value toRetainSize = builder.create<memref::DimOp>(location: loc, args&: toRetainMemref, args&: c0);
460
461 builder.create<scf::ForOp>(
462 location: loc, args&: c0, args&: toRetainSize, args&: c1, args: ValueRange(),
463 args: [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
464 builder.create<memref::StoreOp>(location: loc, args&: falseValue, args&: retainCondsMemref, args&: i);
465 builder.create<scf::YieldOp>(location: loc);
466 });
467
468 builder.create<scf::ForOp>(
469 location: loc, args&: c0, args&: toDeallocSize, args&: c1, args: ValueRange(),
470 args: [&](OpBuilder &builder, Location loc, Value outerIter,
471 ValueRange iterArgs) {
472 Value toDealloc =
473 builder.create<memref::LoadOp>(location: loc, args&: toDeallocMemref, args&: outerIter);
474 Value cond =
475 builder.create<memref::LoadOp>(location: loc, args&: conditionMemref, args&: outerIter);
476
477 // Build the first for loop that computes aliasing with retained
478 // memrefs.
479 Value noRetainAlias =
480 builder
481 .create<scf::ForOp>(
482 location: loc, args&: c0, args&: toRetainSize, args&: c1, args&: trueValue,
483 args: [&](OpBuilder &builder, Location loc, Value i,
484 ValueRange iterArgs) {
485 Value retainValue = builder.create<memref::LoadOp>(
486 location: loc, args&: toRetainMemref, args&: i);
487 Value doesAlias = builder.create<arith::CmpIOp>(
488 location: loc, args: arith::CmpIPredicate::eq, args&: retainValue,
489 args&: toDealloc);
490 builder.create<scf::IfOp>(
491 location: loc, args&: doesAlias,
492 args: [&](OpBuilder &builder, Location loc) {
493 Value retainCondValue =
494 builder.create<memref::LoadOp>(
495 location: loc, args&: retainCondsMemref, args&: i);
496 Value aggregatedRetainCond =
497 builder.create<arith::OrIOp>(
498 location: loc, args&: retainCondValue, args&: cond);
499 builder.create<memref::StoreOp>(
500 location: loc, args&: aggregatedRetainCond, args&: retainCondsMemref,
501 args&: i);
502 builder.create<scf::YieldOp>(location: loc);
503 });
504 Value doesntAlias = builder.create<arith::CmpIOp>(
505 location: loc, args: arith::CmpIPredicate::ne, args&: retainValue,
506 args&: toDealloc);
507 Value yieldValue = builder.create<arith::AndIOp>(
508 location: loc, args: iterArgs[0], args&: doesntAlias);
509 builder.create<scf::YieldOp>(location: loc, args&: yieldValue);
510 })
511 .getResult(i: 0);
512
513 // Build the second for loop that adds aliasing with previously
514 // deallocated memrefs.
515 Value noAlias =
516 builder
517 .create<scf::ForOp>(
518 location: loc, args&: c0, args&: outerIter, args&: c1, args&: noRetainAlias,
519 args: [&](OpBuilder &builder, Location loc, Value i,
520 ValueRange iterArgs) {
521 Value prevDeallocValue = builder.create<memref::LoadOp>(
522 location: loc, args&: toDeallocMemref, args&: i);
523 Value doesntAlias = builder.create<arith::CmpIOp>(
524 location: loc, args: arith::CmpIPredicate::ne, args&: prevDeallocValue,
525 args&: toDealloc);
526 Value yieldValue = builder.create<arith::AndIOp>(
527 location: loc, args: iterArgs[0], args&: doesntAlias);
528 builder.create<scf::YieldOp>(location: loc, args&: yieldValue);
529 })
530 .getResult(i: 0);
531
532 Value shouldDealoc = builder.create<arith::AndIOp>(location: loc, args&: noAlias, args&: cond);
533 builder.create<memref::StoreOp>(location: loc, args&: shouldDealoc, args&: deallocCondsMemref,
534 args&: outerIter);
535 builder.create<scf::YieldOp>(location: loc);
536 });
537
538 builder.create<func::ReturnOp>(location: loc);
539 return helperFuncOp;
540}
541
542void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
543 RewritePatternSet &patterns,
544 const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
545 patterns.add<DeallocOpConversion>(arg: patterns.getContext(),
546 args: deallocHelperFuncMap);
547}
548

source code of mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp