//===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert `bufferization.dealloc` operations
// to the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_LOWERDEALLOCATIONS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;

namespace {
/// The DeallocOpConversion transforms all bufferization dealloc operations into
/// memref dealloc operations potentially guarded by scf if operations.
/// Additionally, memref extract_aligned_pointer_as_index and arith operations
/// are inserted to compute the guard conditions. We distinguish multiple cases
/// to provide an overall more efficient lowering. In the general case, a
/// helper function is created to avoid a quadratic blow-up in code size
/// (relative to the number of operands of the dealloc operation). For examples
/// of each case, refer to the documentation of the member functions of this
/// class.
class DeallocOpConversion
    : public OpConversionPattern<bufferization::DeallocOp> {

  /// Lower the simple case of a single memref and no retained values, which
  /// avoids the helper function. Ideally, static analysis can provide enough
  /// aliasing information to split the dealloc operations up into this simple
  /// case as much as possible before running this pass.
  ///
  /// Example:
  /// ```
  /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
  /// ```
  /// is lowered to
  /// ```
  /// scf.if %arg1 {
  ///   memref.dealloc %arg0 : memref<2xf32>
  /// }
  /// ```
  LogicalResult
  rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
    assert(adaptor.getRetained().empty() && "expected no retained memrefs");

    rewriter.replaceOpWithNewOp<scf::IfOp>(
        op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) {
          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
          builder.create<scf::YieldOp>(loc);
        });
    return success();
  }

  /// A special case lowering for the deallocation operation with exactly one
  /// memref, but an arbitrary number of retained values. This avoids the
  /// helper function that the general case needs and thus also avoids storing
  /// indices in specially allocated memrefs. The size of the code produced by
  /// this lowering is linear in the number of retained values.
  ///
  /// Example:
  /// ```mlir
  /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
  ///            retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
  /// return %0#0, %0#1 : i1, i1
  /// ```
  /// is lowered to
  /// ```mlir
  /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
  /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
  /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
  /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
  /// %should_dealloc = arith.andi %not_retained, %cond : i1
  /// scf.if %should_dealloc {
  ///   memref.dealloc %m : memref<2xf32>
  /// }
  /// %true = arith.constant true
  /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
  /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
  /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
  /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
  /// return %r0_ownership, %r1_ownership : i1, i1
  /// ```
  LogicalResult rewriteOneMemrefMultipleRetainCase(
      bufferization::DeallocOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const {
    assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");

    // Compute the base pointer indices, compare all retained indices to the
    // memref index to check if they alias.
    SmallVector<Value> doesNotAliasList;
    Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
        op->getLoc(), adaptor.getMemrefs()[0]);
    for (Value retained : adaptor.getRetained()) {
      Value retainedAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
                                                                  retained);
      Value doesNotAlias = rewriter.create<arith::CmpIOp>(
          op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
      doesNotAliasList.push_back(doesNotAlias);
    }

    // AND-reduce the list of booleans from above.
    Value prev = doesNotAliasList.front();
    for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
      prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);

    // Also consider the condition given by the dealloc operation and perform a
    // conditional deallocation guarded by that value.
    Value shouldDealloc = rewriter.create<arith::AndIOp>(
        op->getLoc(), prev, adaptor.getConditions()[0]);

    rewriter.create<scf::IfOp>(
        op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
          builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
          builder.create<scf::YieldOp>(loc);
        });

    // Compute the replacement values for the dealloc operation results. This
    // inserts an already canonicalized form of
    // `select(does_alias_with_memref(r), memref_cond, false)` for each
    // retained value r.
    SmallVector<Value> replacements;
    Value trueVal = rewriter.create<arith::ConstantOp>(
        op->getLoc(), rewriter.getBoolAttr(true));
    for (Value doesNotAlias : doesNotAliasList) {
      Value aliases =
          rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
      Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
                                                    adaptor.getConditions()[0]);
      replacements.push_back(result);
    }

    rewriter.replaceOp(op, replacements);

    return success();
  }

  /// Lowering that supports all features the dealloc operation has to offer.
  /// It computes the base pointer of each memref (as an index), stores it in a
  /// new memref helper structure and passes it to the helper function
  /// generated in 'buildDeallocationLibraryFunction'. The results are stored
  /// in two lists (represented as memrefs) of booleans passed as arguments.
  /// The first list stores whether the corresponding memref should be
  /// deallocated, the second list stores the ownership of the retained values
  /// which can be used to replace the result values of the
  /// `bufferization.dealloc` operation.
  ///
  /// Example:
  /// ```
  /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>)
  ///                           if (%cond0, %cond1)
  ///                       retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
  /// ```
  /// lowers to (simplified):
  /// ```
  /// %c0 = arith.constant 0 : index
  /// %c1 = arith.constant 1 : index
  /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex>
  /// %cond_list = memref.alloc() : memref<2xi1>
  /// %retain_base_pointer_list = memref.alloc() : memref<2xindex>
  /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0
  /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0]
  /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1
  /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1]
  /// memref.store %cond0, %cond_list[%c0]
  /// memref.store %cond1, %cond_list[%c1]
  /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
  /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0]
  /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
  /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1]
  /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list :
  ///   memref<2xindex> to memref<?xindex>
  /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1>
  /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list :
  ///   memref<2xindex> to memref<?xindex>
  /// %dealloc_cond_out = memref.alloc() : memref<2xi1>
  /// %ownership_out = memref.alloc() : memref<2xi1>
  /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out :
  ///   memref<2xi1> to memref<?xi1>
  /// %dyn_ownership_out = memref.cast %ownership_out :
  ///   memref<2xi1> to memref<?xi1>
  /// call @dealloc_helper(%dyn_dealloc_base_pointer_list,
  ///                      %dyn_retain_base_pointer_list,
  ///                      %dyn_cond_list,
  ///                      %dyn_dealloc_cond_out,
  ///                      %dyn_ownership_out) : (...)
  /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1>
  /// scf.if %m0_dealloc_cond {
  ///   memref.dealloc %m0 : memref<2xf32>
  /// }
  /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1>
  /// scf.if %m1_dealloc_cond {
  ///   memref.dealloc %m1 : memref<5xf32>
  /// }
  /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1>
  /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1>
  /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex>
  /// memref.dealloc %retain_base_pointer_list : memref<2xindex>
  /// memref.dealloc %cond_list : memref<2xi1>
  /// memref.dealloc %dealloc_cond_out : memref<2xi1>
  /// memref.dealloc %ownership_out : memref<2xi1>
  /// // replace %0#0 with %r0_ownership
  /// // replace %0#1 with %r1_ownership
  /// ```
  LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
                                   OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
    // Allocate two memrefs holding the base pointer indices of the list of
    // memrefs to be deallocated and the ones to be retained. These can then be
    // passed to the helper function and the for-loops can iterate over them.
    // Without storing them in memrefs, we could not use loops and would have
    // to fully unroll the aliasing checks, potentially leading to a code-size
    // blow-up.
    Value toDeallocMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                     rewriter.getIndexType()));
    Value conditionMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()},
                                     rewriter.getI1Type()));
    Value toRetainMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                     rewriter.getIndexType()));

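    // Helper to materialize index constants used for indexing into the helper
    // memrefs allocated above.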
    auto getConstValue = [&](uint64_t value) -> Value {
      return rewriter.create<arith::ConstantOp>(op.getLoc(),
                                                rewriter.getIndexAttr(value));
    };

    // Extract the base pointers of the memrefs as indices to check for
    // aliasing at runtime.
    for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
      Value memrefAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
                                                                  toDealloc);
      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
                                       toDeallocMemref, getConstValue(i));
    }

    for (auto [i, cond] : llvm::enumerate(adaptor.getConditions()))
      rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref,
                                       getConstValue(i));

    for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
      Value memrefAsIdx =
          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
                                                                  toRetain);
      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
                                       toRetainMemref, getConstValue(i));
    }

    // Cast the allocated memrefs to dynamic shape because we want only one
    // helper function no matter how many operands the bufferization.dealloc
    // has.
    Value castedDeallocMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
        toDeallocMemref);
    Value castedCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        conditionMemref);
    Value castedRetainMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
        toRetainMemref);

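    // Allocate the output memrefs for the helper call: one deallocation
    // condition per memref operand and one ownership indicator per retained
    // value.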
    Value deallocCondsMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
                                     rewriter.getI1Type()));
    Value retainCondsMemref = rewriter.create<memref::AllocOp>(
        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
                                     rewriter.getI1Type()));

    Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        deallocCondsMemref);
    Value castedRetainCondsMemref = rewriter.create<memref::CastOp>(
        op->getLoc(),
        MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
        retainCondsMemref);

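    // Call the helper function to compute which memrefs should actually be
    // deallocated and the resulting ownership of the retained values.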
    rewriter.create<func::CallOp>(
        op.getLoc(), deallocHelperFunc,
        SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                           castedCondsMemref, castedDeallocCondsMemref,
                           castedRetainCondsMemref});

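    // Load the computed deallocation conditions and deallocate each memref
    // operand only if its condition is true.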
    for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
      Value idxValue = getConstValue(i);
      Value shouldDealloc = rewriter.create<memref::LoadOp>(
          op.getLoc(), deallocCondsMemref, idxValue);
      rewriter.create<scf::IfOp>(
          op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
            builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
            builder.create<scf::YieldOp>(loc);
          });
    }

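    // The ownership values computed by the helper become the replacement
    // values for the results of the dealloc operation.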
    SmallVector<Value> replacements;
    for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) {
      Value idxValue = getConstValue(i);
      Value ownership = rewriter.create<memref::LoadOp>(
          op.getLoc(), retainCondsMemref, idxValue);
      replacements.push_back(ownership);
    }

    // Deallocate the memrefs allocated above again to avoid leaking them; the
    // buffer deallocation pass is not run again after this lowering.
    rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref);
    rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref);

    rewriter.replaceOp(op, replacements);
    return success();
  }

public:
  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
      : OpConversionPattern<bufferization::DeallocOp>(context),
        deallocHelperFunc(deallocHelperFunc) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Lower the trivial case.
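    // Without memref operands there is nothing to deallocate, so every
    // retained value simply gets 'false' ownership.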
    if (adaptor.getMemrefs().empty()) {
      Value falseVal = rewriter.create<arith::ConstantOp>(
          op.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
      return success();
    }

    if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
      return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);

    if (adaptor.getMemrefs().size() == 1)
      return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);

    if (!deallocHelperFunc)
      return op->emitError(
          "library function required for generic lowering, but cannot be "
          "automatically inserted when operating on functions");

    return rewriteGeneralCase(op, adaptor, rewriter);
  }

private:
  func::FuncOp deallocHelperFunc;
};
} // namespace

namespace {
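/// The pass that applies the above pattern to all `bufferization.dealloc`
/// operations, creating the shared helper function up front if any dealloc
/// with more than one memref operand requires the general lowering.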
struct LowerDeallocationsPass
    : public bufferization::impl::LowerDeallocationsBase<
          LowerDeallocationsPass> {
  void runOnOperation() override {
    if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
      emitError(getOperation()->getLoc(),
                "root operation must be a builtin.module or a function");
      signalPassFailure();
      return;
    }

    func::FuncOp helperFuncOp;
    if (auto module = dyn_cast<ModuleOp>(getOperation())) {
      OpBuilder builder =
          OpBuilder::atBlockBegin(&module.getBodyRegion().front());
      SymbolTable symbolTable(module);

      // Build the dealloc helper function if there is at least one dealloc
      // with more than one memref operand (the only case that needs it).
      getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
        if (deallocOp.getMemrefs().size() > 1) {
          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
              builder, getOperation()->getLoc(), symbolTable);
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      });
    }

    RewritePatternSet patterns(&getContext());
    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
                                                               helperFuncOp);

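    // Mark everything except `bufferization.dealloc` as legal so that only
    // the dealloc ops are rewritten by the partial conversion.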
    ConversionTarget target(getContext());
    target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
                           scf::SCFDialect, func::FuncDialect>();
    target.addIllegalOp<bufferization::DeallocOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

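/// Build the helper function called by the general-case lowering above. In
/// pseudocode (a rough sketch of the IR built below, not its exact form), the
/// generated function computes:
///
///   for (i in [0, |retained|))
///     retainConds[i] = false;
///   for (i in [0, |memrefs|)) {
///     bool noAlias = true;
///     for (j in [0, |retained|)) {
///       if (memrefs[i] == retained[j])
///         retainConds[j] |= conds[i];
///       noAlias &= (memrefs[i] != retained[j]);
///     }
///     for (j in [0, i))
///       noAlias &= (memrefs[i] != memrefs[j]);
///     deallocConds[i] = noAlias && conds[i];
///   }
///
/// where `memrefs` and `retained` hold base-pointer indices, `conds` holds the
/// dealloc conditions, and `deallocConds` and `retainConds` are the two output
/// lists consumed by the caller.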
func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
    OpBuilder &builder, Location loc, SymbolTable &symbolTable) {
  Type indexMemrefType =
      MemRefType::get({ShapedType::kDynamic}, builder.getIndexType());
  Type boolMemrefType =
      MemRefType::get({ShapedType::kDynamic}, builder.getI1Type());
  SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType,
                             boolMemrefType, boolMemrefType};
  builder.clearInsertionPoint();

  // Generate the func operation itself.
  auto helperFuncOp = func::FuncOp::create(
      loc, "dealloc_helper", builder.getFunctionType(argTypes, {}));
  helperFuncOp.setVisibility(SymbolTable::Visibility::Private);
  symbolTable.insert(helperFuncOp);
  auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
  block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));

  builder.setInsertionPointToStart(&block);
  Value toDeallocMemref = helperFuncOp.getArguments()[0];
  Value toRetainMemref = helperFuncOp.getArguments()[1];
  Value conditionMemref = helperFuncOp.getArguments()[2];
  Value deallocCondsMemref = helperFuncOp.getArguments()[3];
  Value retainCondsMemref = helperFuncOp.getArguments()[4];

  // Insert some prerequisites.
  Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
  Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
  Value trueValue =
      builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
  Value falseValue =
      builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
  Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0);
  Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);

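  // Initialize the ownership of all retained values to false.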
  builder.create<scf::ForOp>(
      loc, c0, toRetainSize, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
        builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i);
        builder.create<scf::YieldOp>(loc);
      });

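  // For each memref to deallocate: transfer ownership to any retained value it
  // aliases, and only deallocate it if it aliases neither a retained value nor
  // a previously processed memref (to avoid double frees).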
  builder.create<scf::ForOp>(
      loc, c0, toDeallocSize, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value outerIter,
          ValueRange iterArgs) {
        Value toDealloc =
            builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter);
        Value cond =
            builder.create<memref::LoadOp>(loc, conditionMemref, outerIter);

        // Build the first for loop that computes aliasing with retained
        // memrefs.
        Value noRetainAlias =
            builder
                .create<scf::ForOp>(
                    loc, c0, toRetainSize, c1, trueValue,
                    [&](OpBuilder &builder, Location loc, Value i,
                        ValueRange iterArgs) {
                      Value retainValue = builder.create<memref::LoadOp>(
                          loc, toRetainMemref, i);
                      Value doesAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::eq, retainValue,
                          toDealloc);
                      builder.create<scf::IfOp>(
                          loc, doesAlias,
                          [&](OpBuilder &builder, Location loc) {
                            Value retainCondValue =
                                builder.create<memref::LoadOp>(
                                    loc, retainCondsMemref, i);
                            Value aggregatedRetainCond =
                                builder.create<arith::OrIOp>(
                                    loc, retainCondValue, cond);
                            builder.create<memref::StoreOp>(
                                loc, aggregatedRetainCond, retainCondsMemref,
                                i);
                            builder.create<scf::YieldOp>(loc);
                          });
                      Value doesntAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::ne, retainValue,
                          toDealloc);
                      Value yieldValue = builder.create<arith::AndIOp>(
                          loc, iterArgs[0], doesntAlias);
                      builder.create<scf::YieldOp>(loc, yieldValue);
                    })
                .getResult(0);

        // Build the second for loop that adds aliasing with previously
        // deallocated memrefs.
        Value noAlias =
            builder
                .create<scf::ForOp>(
                    loc, c0, outerIter, c1, noRetainAlias,
                    [&](OpBuilder &builder, Location loc, Value i,
                        ValueRange iterArgs) {
                      Value prevDeallocValue = builder.create<memref::LoadOp>(
                          loc, toDeallocMemref, i);
                      Value doesntAlias = builder.create<arith::CmpIOp>(
                          loc, arith::CmpIPredicate::ne, prevDeallocValue,
                          toDealloc);
                      Value yieldValue = builder.create<arith::AndIOp>(
                          loc, iterArgs[0], doesntAlias);
                      builder.create<scf::YieldOp>(loc, yieldValue);
                    })
                .getResult(0);

        Value shouldDealloc =
            builder.create<arith::AndIOp>(loc, noAlias, cond);
        builder.create<memref::StoreOp>(loc, shouldDealloc, deallocCondsMemref,
                                        outerIter);
        builder.create<scf::YieldOp>(loc);
      });

  builder.create<func::ReturnOp>(loc);
  return helperFuncOp;
}

void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
}

std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
  return std::make_unique<LowerDeallocationsPass>();
}