//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert `bufferization.dealloc` operations
// to the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;

namespace {
/// The DeallocOpConversion transforms all bufferization dealloc operations
/// into memref dealloc operations potentially guarded by scf if operations.
/// Additionally, memref extract_aligned_pointer_as_index and arith operations
/// are inserted to compute the guard conditions. We distinguish multiple cases
/// to provide an overall more efficient lowering. In the general case, a
/// helper function is created to avoid quadratic code size explosion (relative
/// to the number of operands of the dealloc operation). For examples of each
/// case, refer to the documentation of the member functions of this class.
class DeallocOpConversion
    : public OpConversionPattern<bufferization::DeallocOp> {

  /// Lower a simple case without any retained values and a single memref to
  /// avoid the helper function. Ideally, static analysis can provide enough
  /// aliasing information to split the dealloc operations up into this simple
  /// case as much as possible before running this pass.
  ///
  /// Example:
  /// ```
  /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
  /// ```
  /// is lowered to
  /// ```
  /// scf.if %arg1 {
  ///   memref.dealloc %arg0 : memref<2xf32>
  /// }
  /// ```
  LogicalResult
  rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
    assert(adaptor.getRetained().empty() && "expected no retained memrefs");

    rewriter.replaceOpWithNewOp<scf::IfOp>(
        op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
          builder.create<scf::YieldOp>(loc);
        });
    return success();
  }

  /// A special case lowering for the deallocation operation with exactly one
  /// memref, but an arbitrary number of retained values. This avoids the
  /// helper function that the general case needs and thus also avoids storing
  /// indices in specifically allocated memrefs. The size of the code produced
  /// by this lowering is linear in the number of retained values.
  ///
  /// Example:
  /// ```mlir
  /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
  ///            retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
  /// return %0#0, %0#1 : i1, i1
  /// ```
  /// is lowered to
  /// ```mlir
  /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
  /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
  /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
  /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
  /// %should_dealloc = arith.andi %not_retained, %cond : i1
  /// scf.if %should_dealloc {
  ///   memref.dealloc %m : memref<2xf32>
  /// }
  /// %true = arith.constant true
  /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
  /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
  /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
  /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
  /// return %r0_ownership, %r1_ownership : i1, i1
  /// ```
  LogicalResult rewriteOneMemrefMultipleRetainCase(
      bufferization::DeallocOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const {
    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");

    // Compute the base pointer indices, then compare each retained index to
    // the memref index to check whether they alias.
    SmallVector<Value> doesNotAliasList;
    Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
        op->getLoc(), adaptor.getMemrefs()[0]);
    for (Value retained : adaptor.getRetained()) {
      Value retainedAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
                                                                  retained);
      Value doesNotAlias = rewriter.create<arith::CmpIOp>(
          op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
      doesNotAliasList.push_back(doesNotAlias);
    }

    // AND-reduce the list of booleans from above.
    Value prev = doesNotAliasList.front();
    for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
      prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);

    // Also consider the condition given by the dealloc operation and perform a
    // conditional deallocation guarded by that value.
    Value shouldDealloc = rewriter.create<arith::AndIOp>(
        op->getLoc(), prev, adaptor.getConditions()[0]);

    rewriter.create<scf::IfOp>(
        op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
          builder.create<scf::YieldOp>(loc);
        });

    // Compute the replacement values for the dealloc operation results. This
    // inserts an already canonicalized form of
    // `select(does_alias_with_memref(r), memref_cond, false)` for each
    // retained value r.
    SmallVector<Value> replacements;
    Value trueVal = rewriter.create<arith::ConstantOp>(
        op->getLoc(), rewriter.getBoolAttr(true));
    for (Value doesNotAlias : doesNotAliasList) {
      Value aliases =
          rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
      Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
                                                    adaptor.getConditions()[0]);
      replacements.push_back(result);
    }

    rewriter.replaceOp(op, replacements);

    return success();
  }

  /// Lowering that supports all features the dealloc operation has to offer.
  /// It computes the base pointer of each memref (as an index), stores it in a
  /// new memref helper structure and passes it to the helper function
  /// generated in 'buildDeallocationLibraryFunction'. The results are stored
  /// in two lists (represented as memrefs) of booleans passed as arguments.
  /// The first list stores whether the corresponding memref should be
  /// deallocated, the second list stores the ownership of the retained values
  /// which can be used to replace the result values of the
  /// `bufferization.dealloc` operation.
  ///
  /// Example:
  /// ```
  /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
  ///            if (%cond0, %cond1)
  ///            retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
  /// ```
  /// lowers to (simplified):
  /// ```
  /// %c0 = arith.constant 0 : index
  /// %c1 = arith.constant 1 : index
  /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
  /// %cond_list = memref.alloc() : memref<2xi1>
  /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
  /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
  /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
  /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
  /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
  /// memref.store %cond0, %cond_list[%c0]
  /// memref.store %cond1, %cond_list[%c1]
  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
  /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
  /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
  /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
  ///     memref<2xindex> to memref<?xindex>
  /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
  /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
  ///     memref<2xindex> to memref<?xindex>
  /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
  /// %ownership_out = memref.alloc() : memref<2xi1>
  /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
  ///     memref<2xi1> to memref<?xi1>
  /// %dyn_ownership_out = memref.cast %ownership_out :
  ///     memref<2xi1> to memref<?xi1>
  /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
  ///                      %dyn_retain_base_pointer_list,
  ///                      %dyn_cond_list,
  ///                      %dyn_dealloc_cond_out,
  ///                      %dyn_ownership_out) : (...)
  /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
  /// scf.if %m0_dealloc_cond {
  ///   memref.dealloc %m0 : memref<2xf32>
  /// }
  /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
  /// scf.if %m1_dealloc_cond {
  ///   memref.dealloc %m1 : memref<5xf32>
  /// }
  /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
  /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
  /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
  /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
  /// memref.dealloc %cond_list : memref<2xi1>
  /// memref.dealloc %dealloc_cond_out : memref<2xi1>
  /// memref.dealloc %ownership_out : memref<2xi1>
  /// // replace %0#0 with %r0_ownership
  /// // replace %0#1 with %r1_ownership
  /// ```
  LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
                                   OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
    // Allocate two memrefs holding the base pointer indices of the list of
    // memrefs to be deallocated and the ones to be retained. These can then be
    // passed to the helper function and the for-loops can iterate over them.
    // Without storing them in memrefs, we could not use for-loops but would
    // have to fully unroll them, potentially leading to code-size blow-up.
    Value toDeallocMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                     rewriter.getIndexType()));
    Value conditionMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
                                     rewriter.getI1Type()));
    Value toRetainMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                     rewriter.getIndexType()));

    auto getConstValue = [&](uint64_t value) -> Value {
      return rewriter.create<arith::ConstantOp>(op.getLoc(),
                                                rewriter.getIndexAttr(value));
    };

    // Extract the base pointers of the memrefs as indices to check for
    // aliasing at runtime.
    for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
      Value memrefAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
                                                                  toDealloc);
      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
                                       toDeallocMemref, getConstValue(i));
    }

    for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
      rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
                                       getConstValue(i));

    for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
      Value memrefAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
                                                                  toRetain);
      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
                                       toRetainMemref, getConstValue(i));
    }

    // Cast the allocated memrefs to dynamic shape because we want only one
    // helper function no matter how many operands the bufferization.dealloc
    // has.
    Value castedDeallocMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
        toDeallocMemref);
    Value castedCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        conditionMemref);
    Value castedRetainMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
        toRetainMemref);

    Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                     rewriter.getI1Type()));
    Value retainCondsMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                     rewriter.getI1Type()));

    Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        deallocCondsMemref);
    Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        retainCondsMemref);
    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    rewriter.create<func::CallOp>(
        op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
        SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                           castedCondsMemref, castedDeallocCondsMemref,
                           castedRetainCondsMemref});

    for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
      Value idxValue = getConstValue(i);
      Value shouldDealloc = rewriter.create<memref::LoadOp>(
          op.getLoc(), deallocCondsMemref, idxValue);
      rewriter.create<scf::IfOp>(
          op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
            builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
            builder.create<scf::YieldOp>(loc);
          });
    }

    SmallVector<Value> replacements;
    for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
      Value idxValue = getConstValue(i);
      Value ownership = rewriter.create<memref::LoadOp>(
          op.getLoc(), retainCondsMemref, idxValue);
      replacements.push_back(ownership);
    }

    // Deallocate the memrefs allocated above to avoid leaking them; the buffer
    // deallocation pass is not run again after this lowering.
    rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);

    rewriter.replaceOp(op, replacements);
    return success();
  }

public:
  DeallocOpConversion(
      MLIRContext *context,
      const bufferization::DeallocHelperMap &deallocHelperFuncMap)
      : OpConversionPattern<bufferization::DeallocOp>(context),
        deallocHelperFuncMap(deallocHelperFuncMap) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Lower the trivial case.
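    // For illustration, roughly (illustrative IR, not verbatim output):
    //   %0 = bufferization.dealloc retain (%r : memref<2xf32>)
    // becomes
    //   %false = arith.constant false
    // with every result of the dealloc op replaced by %false.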
    if (adaptor.getMemrefs().empty()) {
      Value falseVal = rewriter.create<arith::ConstantOp>(
          op.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
      return success();
    }

    if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
      return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);

    if (adaptor.getMemrefs().size() == 1)
      return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);

    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    if (!deallocHelperFuncMap.contains(symtableOp))
      return op->emitError(
          "library function required for generic lowering, but cannot be "
          "automatically inserted when operating on functions");

    return rewriteGeneralCase(op, adaptor, rewriter);
  }

private:
  const bufferization::DeallocHelperMap &deallocHelperFuncMap;
};
} // namespace

namespace {
struct LowerDeallocationsPass
    : public bufferization::impl::LowerDeallocationsPassBase<
          LowerDeallocationsPass> {
  void runOnOperation() override {
    if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
      emitError(getOperation()->getLoc(),
                "root operation must be a builtin.module or a function");
      signalPassFailure();
      return;
    }

    bufferization::DeallocHelperMap deallocHelperFuncMap;
    if (auto module = dyn_cast<ModuleOp>(getOperation())) {
      OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());

      // Build dealloc helper function if there are deallocs.
      getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
        Operation *symtableOp =
            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
        if (deallocOp.getMemrefs().size() > 1 &&
            !deallocHelperFuncMap.contains(symtableOp)) {
          SymbolTable symbolTable(symtableOp);
          func::FuncOp helperFuncOp =
              bufferization::buildDeallocationLibraryFunction(
                  builder, getOperation()->getLoc(), symbolTable);
          deallocHelperFuncMap[symtableOp] = helperFuncOp;
        }
      });
    }

    RewritePatternSet patterns(&getContext());
    bufferization::populateBufferizationDeallocLoweringPattern(
        patterns, deallocHelperFuncMap);

    ConversionTarget target(getContext());
    target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
                           scf::SCFDialect, func::FuncDialect>();
    target.addIllegalOp<bufferization::DeallocOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

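// For reference, the library function constructed below has roughly the
// following shape (a simplified sketch; the value names are illustrative and
// the inner loop bodies are abbreviated):
//
//   func.func private @dealloc_helper(
//       %dealloc_base_ptrs: memref<?xindex>,
//       %retain_base_ptrs: memref<?xindex>,
//       %conditions: memref<?xi1>,
//       %dealloc_conds_out: memref<?xi1>,
//       %ownership_out: memref<?xi1>) {
//     // Initialize the ownership results to false.
//     scf.for %i = %c0 to %num_retained step %c1 {
//       memref.store %false, %ownership_out[%i]
//     }
//     scf.for %i = %c0 to %num_deallocs step %c1 {
//       %ptr = memref.load %dealloc_base_ptrs[%i]
//       %cond = memref.load %conditions[%i]
//       // OR %cond into the ownership of every retained value whose base
//       // pointer equals %ptr, and AND-reduce "aliases no retained value".
//       %no_retain_alias = scf.for %j = %c0 to %num_retained ... { ... }
//       // AND-reduce "aliases no previously processed dealloc memref".
//       %no_alias = scf.for %j = %c0 to %i ... { ... }
//       %should_dealloc = arith.andi %no_alias, %cond : i1
//       memref.store %should_dealloc, %dealloc_conds_out[%i]
//     }
//     return
//   }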
func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
    OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
  Type indexMemrefType =
      MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
  Type boolMemrefType =
      MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
  SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
                             boolMemrefType, boolMemrefType};
  builder.clearInsertionPoint();

  // Generate the func operation itself.
  auto helperFuncOp = func::FuncOp::create(
      loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
  helperFuncOp.setVisibility(SymbolTable::Visibility::Private);
  symbolTable.insert(helperFuncOp);
  auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
  block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));

  builder.setInsertionPointToStart(&block);
  Value toDeallocMemref = helperFuncOp.getArguments()[0];
  Value toRetainMemref = helperFuncOp.getArguments()[1];
  Value conditionMemref = helperFuncOp.getArguments()[2];
  Value deallocCondsMemref = helperFuncOp.getArguments()[3];
  Value retainCondsMemref = helperFuncOp.getArguments()[4];

  // Insert some prerequisites.
  Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
  Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
  Value trueValue =
      builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
  Value falseValue =
      builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
  Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
  Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);

  builder.create<scf::ForOp>(
      loc, c0, toRetainSize, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
        builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
        builder.create<scf::YieldOp>(loc);
      });

  builder.create<scf::ForOp>(
      loc, c0, toDeallocSize, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value outerIter,
          ValueRange iterArgs) {
        Value toDealloc =
            builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
        Value cond =
            builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);

        // Build the first for loop that computes aliasing with retained
        // memrefs.
        Value noRetainAlias =
            builder
                .create<scf::ForOp>(
                    loc, c0, toRetainSize, c1, trueValue,
                    [&](OpBuilder &builder, Location loc, Value i,
                        ValueRange iterArgs) {
                      Value retainValue = builder.create<memref::LoadOp>(
                          loc, toRetainMemref, i);
                      Value doesAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::eq, retainValue,
                          toDealloc);
                      builder.create<scf::IfOp>(
                          loc, doesAlias,
                          [&](OpBuilder &builder, Location loc) {
                            Value retainCondValue =
                                builder.create<memref::LoadOp>(
                                    loc, retainCondsMemref, i);
                            Value aggregatedRetainCond =
                                builder.create<arith::OrIOp>(
                                    loc, retainCondValue, cond);
                            builder.create<memref::StoreOp>(
                                loc, aggregatedRetainCond, retainCondsMemref,
                                i);
                            builder.create<scf::YieldOp>(loc);
                          });
                      Value doesntAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::ne, retainValue,
                          toDealloc);
                      Value yieldValue = builder.create<arith::AndIOp>(
                          loc, iterArgs[0], doesntAlias);
                      builder.create<scf::YieldOp>(loc, yieldValue);
                    })
                .getResult(0);

        // Build the second for loop that adds aliasing with previously
        // deallocated memrefs.
        Value noAlias =
            builder
                .create<scf::ForOp>(
                    loc, c0, outerIter, c1, noRetainAlias,
                    [&](OpBuilder &builder, Location loc, Value i,
                        ValueRange iterArgs) {
                      Value prevDeallocValue = builder.create<memref::LoadOp>(
                          loc, toDeallocMemref, i);
                      Value doesntAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::ne, prevDeallocValue,
                          toDealloc);
                      Value yieldValue = builder.create<arith::AndIOp>(
                          loc, iterArgs[0], doesntAlias);
                      builder.create<scf::YieldOp>(loc, yieldValue);
                    })
                .getResult(0);

        Value shouldDealloc =
            builder.create<arith::AndIOp>(loc, noAlias, cond);
        builder.create<memref::StoreOp>(loc, shouldDealloc, deallocCondsMemref,
                                        outerIter);
        builder.create<scf::YieldOp>(loc);
      });

  builder.create<func::ReturnOp>(loc);
  return helperFuncOp;
}

void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
    RewritePatternSet &patterns,
    const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
  patterns.add<DeallocOpConversion>(patterns.getContext(),
                                    deallocHelperFuncMap);
}
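
// A minimal usage sketch mirroring LowerDeallocationsPass above (`ctx` and
// `op` stand for the caller's MLIRContext and root operation). The helper map
// may stay empty as long as every `bufferization.dealloc` in the input has at
// most one memref operand; otherwise a helper has to be created with
// buildDeallocationLibraryFunction and registered for the surrounding symbol
// table first.
//
//   bufferization::DeallocHelperMap helperMap;
//   RewritePatternSet patterns(ctx);
//   bufferization::populateBufferizationDeallocLoweringPattern(patterns,
//                                                              helperMap);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
//                          scf::SCFDialect, func::FuncDialect>();
//   target.addIllegalOp<bufferization::DeallocOp>();
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     return failure();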