| 1 | //===- LowerDeallocations.cpp - Bufferization Deallocs to MemRef pass -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements patterns to convert `bufferization.dealloc` operations |
| 10 | // to the MemRef dialect. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 16 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
| 17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 20 | #include "mlir/IR/BuiltinTypes.h" |
| 21 | #include "mlir/Pass/Pass.h" |
| 22 | #include "mlir/Transforms/DialectConversion.h" |
| 23 | |
| 24 | namespace mlir { |
| 25 | namespace bufferization { |
| 26 | #define GEN_PASS_DEF_LOWERDEALLOCATIONSPASS |
| 27 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" |
| 28 | } // namespace bufferization |
| 29 | } // namespace mlir |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | namespace { |
| 34 | /// The DeallocOpConversion transforms all bufferization dealloc operations into |
| 35 | /// memref dealloc operations potentially guarded by scf if operations. |
| 36 | /// Additionally, memref extract_aligned_pointer_as_index and arith operations |
| 37 | /// are inserted to compute the guard conditions. We distinguish multiple cases |
| 38 | /// to provide an overall more efficient lowering. In the general case, a helper |
| 39 | /// func is created to avoid quadratic code size explosion (relative to the |
| 40 | /// number of operands of the dealloc operation). For examples of each case, |
| 41 | /// refer to the documentation of the member functions of this class. |
| 42 | class DeallocOpConversion |
| 43 | : public OpConversionPattern<bufferization::DeallocOp> { |
| 44 | |
| 45 | /// Lower a simple case without any retained values and a single memref to |
| 46 | /// avoiding the helper function. Ideally, static analysis can provide enough |
| 47 | /// aliasing information to split the dealloc operations up into this simple |
| 48 | /// case as much as possible before running this pass. |
| 49 | /// |
| 50 | /// Example: |
| 51 | /// ``` |
| 52 | /// bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1) |
| 53 | /// ``` |
| 54 | /// is lowered to |
| 55 | /// ``` |
| 56 | /// scf.if %arg1 { |
| 57 | /// memref.dealloc %arg0 : memref<2xf32> |
| 58 | /// } |
| 59 | /// ``` |
| 60 | LogicalResult |
| 61 | rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor, |
| 62 | ConversionPatternRewriter &rewriter) const { |
| 63 | assert(adaptor.getMemrefs().size() == 1 && "expected only one memref" ); |
| 64 | assert(adaptor.getRetained().empty() && "expected no retained memrefs" ); |
| 65 | |
| 66 | rewriter.replaceOpWithNewOp<scf::IfOp>( |
| 67 | op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { |
| 68 | builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); |
| 69 | builder.create<scf::YieldOp>(loc); |
| 70 | }); |
| 71 | return success(); |
| 72 | } |
| 73 | |
| 74 | /// A special case lowering for the deallocation operation with exactly one |
| 75 | /// memref, but arbitrary number of retained values. This avoids the helper |
| 76 | /// function that the general case needs and thus also avoids storing indices |
| 77 | /// to specifically allocated memrefs. The size of the code produced by this |
| 78 | /// lowering is linear to the number of retained values. |
| 79 | /// |
| 80 | /// Example: |
| 81 | /// ```mlir |
| 82 | /// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond) |
| 83 | // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) |
| 84 | /// return %0#0, %0#1 : i1, i1 |
| 85 | /// ``` |
| 86 | /// ```mlir |
| 87 | /// %m_base_pointer = memref.extract_aligned_pointer_as_index %m |
| 88 | /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 |
| 89 | /// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer |
| 90 | /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 |
| 91 | /// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer |
| 92 | /// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1 |
| 93 | /// %should_dealloc = arith.andi %not_retained, %cond : i1 |
| 94 | /// scf.if %should_dealloc { |
| 95 | /// memref.dealloc %m : memref<2xf32> |
| 96 | /// } |
| 97 | /// %true = arith.constant true |
| 98 | /// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1 |
| 99 | /// %r0_ownership = arith.andi %r0_does_alias, %cond : i1 |
| 100 | /// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1 |
| 101 | /// %r1_ownership = arith.andi %r1_does_alias, %cond : i1 |
| 102 | /// return %r0_ownership, %r1_ownership : i1, i1 |
| 103 | /// ``` |
| 104 | LogicalResult rewriteOneMemrefMultipleRetainCase( |
| 105 | bufferization::DeallocOp op, OpAdaptor adaptor, |
| 106 | ConversionPatternRewriter &rewriter) const { |
| 107 | assert(adaptor.getMemrefs().size() == 1 && "expected only one memref" ); |
| 108 | |
| 109 | // Compute the base pointer indices, compare all retained indices to the |
| 110 | // memref index to check if they alias. |
| 111 | SmallVector<Value> doesNotAliasList; |
| 112 | Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>( |
| 113 | op->getLoc(), adaptor.getMemrefs()[0]); |
| 114 | for (Value retained : adaptor.getRetained()) { |
| 115 | Value retainedAsIdx = |
| 116 | rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(), |
| 117 | retained); |
| 118 | Value doesNotAlias = rewriter.create<arith::CmpIOp>( |
| 119 | op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx); |
| 120 | doesNotAliasList.push_back(doesNotAlias); |
| 121 | } |
| 122 | |
| 123 | // AND-reduce the list of booleans from above. |
| 124 | Value prev = doesNotAliasList.front(); |
| 125 | for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front()) |
| 126 | prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias); |
| 127 | |
| 128 | // Also consider the condition given by the dealloc operation and perform a |
| 129 | // conditional deallocation guarded by that value. |
| 130 | Value shouldDealloc = rewriter.create<arith::AndIOp>( |
| 131 | op->getLoc(), prev, adaptor.getConditions()[0]); |
| 132 | |
| 133 | rewriter.create<scf::IfOp>( |
| 134 | op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { |
| 135 | builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); |
| 136 | builder.create<scf::YieldOp>(loc); |
| 137 | }); |
| 138 | |
| 139 | // Compute the replacement values for the dealloc operation results. This |
| 140 | // inserts an already canonicalized form of |
| 141 | // `select(does_alias_with_memref(r), memref_cond, false)` for each retained |
| 142 | // value r. |
| 143 | SmallVector<Value> replacements; |
| 144 | Value trueVal = rewriter.create<arith::ConstantOp>( |
| 145 | op->getLoc(), rewriter.getBoolAttr(true)); |
| 146 | for (Value doesNotAlias : doesNotAliasList) { |
| 147 | Value aliases = |
| 148 | rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal); |
| 149 | Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases, |
| 150 | adaptor.getConditions()[0]); |
| 151 | replacements.push_back(result); |
| 152 | } |
| 153 | |
| 154 | rewriter.replaceOp(op, replacements); |
| 155 | |
| 156 | return success(); |
| 157 | } |
| 158 | |
| 159 | /// Lowering that supports all features the dealloc operation has to offer. It |
| 160 | /// computes the base pointer of each memref (as an index), stores it in a |
| 161 | /// new memref helper structure and passes it to the helper function generated |
| 162 | /// in 'buildDeallocationHelperFunction'. The results are stored in two lists |
| 163 | /// (represented as memrefs) of booleans passed as arguments. The first list |
| 164 | /// stores whether the corresponding condition should be deallocated, the |
| 165 | /// second list stores the ownership of the retained values which can be used |
| 166 | /// to replace the result values of the `bufferization.dealloc` operation. |
| 167 | /// |
| 168 | /// Example: |
| 169 | /// ``` |
| 170 | /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xf32>, memref<5xf32>) |
| 171 | /// if (%cond0, %cond1) |
| 172 | /// retain (%r0, %r1 : memref<1xf32>, memref<2xf32>) |
| 173 | /// ``` |
| 174 | /// lowers to (simplified): |
| 175 | /// ``` |
| 176 | /// %c0 = arith.constant 0 : index |
| 177 | /// %c1 = arith.constant 1 : index |
| 178 | /// %dealloc_base_pointer_list = memref.alloc() : memref<2xindex> |
| 179 | /// %cond_list = memref.alloc() : memref<2xi1> |
| 180 | /// %retain_base_pointer_list = memref.alloc() : memref<2xindex> |
| 181 | /// %m0_base_pointer = memref.extract_aligned_pointer_as_index %m0 |
| 182 | /// memref.store %m0_base_pointer, %dealloc_base_pointer_list[%c0] |
| 183 | /// %m1_base_pointer = memref.extract_aligned_pointer_as_index %m1 |
| 184 | /// memref.store %m1_base_pointer, %dealloc_base_pointer_list[%c1] |
| 185 | /// memref.store %cond0, %cond_list[%c0] |
| 186 | /// memref.store %cond1, %cond_list[%c1] |
| 187 | /// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0 |
| 188 | /// memref.store %r0_base_pointer, %retain_base_pointer_list[%c0] |
| 189 | /// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1 |
| 190 | /// memref.store %r1_base_pointer, %retain_base_pointer_list[%c1] |
| 191 | /// %dyn_dealloc_base_pointer_list = memref.cast %dealloc_base_pointer_list : |
| 192 | /// memref<2xindex> to memref<?xindex> |
| 193 | /// %dyn_cond_list = memref.cast %cond_list : memref<2xi1> to memref<?xi1> |
| 194 | /// %dyn_retain_base_pointer_list = memref.cast %retain_base_pointer_list : |
| 195 | /// memref<2xindex> to memref<?xindex> |
| 196 | /// %dealloc_cond_out = memref.alloc() : memref<2xi1> |
| 197 | /// %ownership_out = memref.alloc() : memref<2xi1> |
| 198 | /// %dyn_dealloc_cond_out = memref.cast %dealloc_cond_out : |
| 199 | /// memref<2xi1> to memref<?xi1> |
| 200 | /// %dyn_ownership_out = memref.cast %ownership_out : |
| 201 | /// memref<2xi1> to memref<?xi1> |
| 202 | /// call @dealloc_helper(%dyn_dealloc_base_pointer_list, |
| 203 | /// %dyn_retain_base_pointer_list, |
| 204 | /// %dyn_cond_list, |
| 205 | /// %dyn_dealloc_cond_out, |
| 206 | /// %dyn_ownership_out) : (...) |
| 207 | /// %m0_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c0] : memref<2xi1> |
| 208 | /// scf.if %m0_dealloc_cond { |
| 209 | /// memref.dealloc %m0 : memref<2xf32> |
| 210 | /// } |
| 211 | /// %m1_dealloc_cond = memref.load %dyn_dealloc_cond_out[%c1] : memref<2xi1> |
| 212 | /// scf.if %m1_dealloc_cond { |
| 213 | /// memref.dealloc %m1 : memref<5xf32> |
| 214 | /// } |
| 215 | /// %r0_ownership = memref.load %dyn_ownership_out[%c0] : memref<2xi1> |
| 216 | /// %r1_ownership = memref.load %dyn_ownership_out[%c1] : memref<2xi1> |
| 217 | /// memref.dealloc %dealloc_base_pointer_list : memref<2xindex> |
| 218 | /// memref.dealloc %retain_base_pointer_list : memref<2xindex> |
| 219 | /// memref.dealloc %cond_list : memref<2xi1> |
| 220 | /// memref.dealloc %dealloc_cond_out : memref<2xi1> |
| 221 | /// memref.dealloc %ownership_out : memref<2xi1> |
| 222 | /// // replace %0#0 with %r0_ownership |
| 223 | /// // replace %0#1 with %r1_ownership |
| 224 | /// ``` |
| 225 | LogicalResult rewriteGeneralCase(bufferization::DeallocOp op, |
| 226 | OpAdaptor adaptor, |
| 227 | ConversionPatternRewriter &rewriter) const { |
| 228 | // Allocate two memrefs holding the base pointer indices of the list of |
| 229 | // memrefs to be deallocated and the ones to be retained. These can then be |
| 230 | // passed to the helper function and the for-loops can iterate over them. |
| 231 | // Without storing them to memrefs, we could not use for-loops but only a |
| 232 | // completely unrolled version of it, potentially leading to code-size |
| 233 | // blow-up. |
| 234 | Value toDeallocMemref = rewriter.create<memref::AllocOp>( |
| 235 | op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, |
| 236 | rewriter.getIndexType())); |
| 237 | Value conditionMemref = rewriter.create<memref::AllocOp>( |
| 238 | op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, |
| 239 | rewriter.getI1Type())); |
| 240 | Value toRetainMemref = rewriter.create<memref::AllocOp>( |
| 241 | op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, |
| 242 | rewriter.getIndexType())); |
| 243 | |
| 244 | auto getConstValue = [&](uint64_t value) -> Value { |
| 245 | return rewriter.create<arith::ConstantOp>(op.getLoc(), |
| 246 | rewriter.getIndexAttr(value)); |
| 247 | }; |
| 248 | |
| 249 | // Extract the base pointers of the memrefs as indices to check for aliasing |
| 250 | // at runtime. |
| 251 | for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) { |
| 252 | Value memrefAsIdx = |
| 253 | rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), |
| 254 | toDealloc); |
| 255 | rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, |
| 256 | toDeallocMemref, getConstValue(i)); |
| 257 | } |
| 258 | |
| 259 | for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) |
| 260 | rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref, |
| 261 | getConstValue(i)); |
| 262 | |
| 263 | for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { |
| 264 | Value memrefAsIdx = |
| 265 | rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), |
| 266 | toRetain); |
| 267 | rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref, |
| 268 | getConstValue(i)); |
| 269 | } |
| 270 | |
| 271 | // Cast the allocated memrefs to dynamic shape because we want only one |
| 272 | // helper function no matter how many operands the bufferization.dealloc |
| 273 | // has. |
| 274 | Value castedDeallocMemref = rewriter.create<memref::CastOp>( |
| 275 | op->getLoc(), |
| 276 | MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), |
| 277 | toDeallocMemref); |
| 278 | Value castedCondsMemref = rewriter.create<memref::CastOp>( |
| 279 | op->getLoc(), |
| 280 | MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), |
| 281 | conditionMemref); |
| 282 | Value castedRetainMemref = rewriter.create<memref::CastOp>( |
| 283 | op->getLoc(), |
| 284 | MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), |
| 285 | toRetainMemref); |
| 286 | |
| 287 | Value deallocCondsMemref = rewriter.create<memref::AllocOp>( |
| 288 | op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, |
| 289 | rewriter.getI1Type())); |
| 290 | Value retainCondsMemref = rewriter.create<memref::AllocOp>( |
| 291 | op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, |
| 292 | rewriter.getI1Type())); |
| 293 | |
| 294 | Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>( |
| 295 | op->getLoc(), |
| 296 | MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), |
| 297 | deallocCondsMemref); |
| 298 | Value castedRetainCondsMemref = rewriter.create<memref::CastOp>( |
| 299 | op->getLoc(), |
| 300 | MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), |
| 301 | retainCondsMemref); |
| 302 | |
| 303 | Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>(); |
| 304 | rewriter.create<func::CallOp>( |
| 305 | op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), |
| 306 | SmallVector<Value>{castedDeallocMemref, castedRetainMemref, |
| 307 | castedCondsMemref, castedDeallocCondsMemref, |
| 308 | castedRetainCondsMemref}); |
| 309 | |
| 310 | for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { |
| 311 | Value idxValue = getConstValue(i); |
| 312 | Value shouldDealloc = rewriter.create<memref::LoadOp>( |
| 313 | op.getLoc(), deallocCondsMemref, idxValue); |
| 314 | rewriter.create<scf::IfOp>( |
| 315 | op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { |
| 316 | builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]); |
| 317 | builder.create<scf::YieldOp>(loc); |
| 318 | }); |
| 319 | } |
| 320 | |
| 321 | SmallVector<Value> replacements; |
| 322 | for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { |
| 323 | Value idxValue = getConstValue(i); |
| 324 | Value ownership = rewriter.create<memref::LoadOp>( |
| 325 | op.getLoc(), retainCondsMemref, idxValue); |
| 326 | replacements.push_back(ownership); |
| 327 | } |
| 328 | |
| 329 | // Deallocate above allocated memrefs again to avoid memory leaks. |
| 330 | // Deallocation will not be run on code after this stage. |
| 331 | rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref); |
| 332 | rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref); |
| 333 | rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref); |
| 334 | rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref); |
| 335 | rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref); |
| 336 | |
| 337 | rewriter.replaceOp(op, replacements); |
| 338 | return success(); |
| 339 | } |
| 340 | |
| 341 | public: |
| 342 | DeallocOpConversion( |
| 343 | MLIRContext *context, |
| 344 | const bufferization::DeallocHelperMap &deallocHelperFuncMap) |
| 345 | : OpConversionPattern<bufferization::DeallocOp>(context), |
| 346 | deallocHelperFuncMap(deallocHelperFuncMap) {} |
| 347 | |
| 348 | LogicalResult |
| 349 | matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor, |
| 350 | ConversionPatternRewriter &rewriter) const override { |
| 351 | // Lower the trivial case. |
| 352 | if (adaptor.getMemrefs().empty()) { |
| 353 | Value falseVal = rewriter.create<arith::ConstantOp>( |
| 354 | op.getLoc(), rewriter.getBoolAttr(false)); |
| 355 | rewriter.replaceOp( |
| 356 | op, SmallVector<Value>(adaptor.getRetained().size(), falseVal)); |
| 357 | return success(); |
| 358 | } |
| 359 | |
| 360 | if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty()) |
| 361 | return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter); |
| 362 | |
| 363 | if (adaptor.getMemrefs().size() == 1) |
| 364 | return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter); |
| 365 | |
| 366 | Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>(); |
| 367 | if (!deallocHelperFuncMap.contains(symtableOp)) |
| 368 | return op->emitError( |
| 369 | "library function required for generic lowering, but cannot be " |
| 370 | "automatically inserted when operating on functions" ); |
| 371 | |
| 372 | return rewriteGeneralCase(op, adaptor, rewriter); |
| 373 | } |
| 374 | |
| 375 | private: |
| 376 | const bufferization::DeallocHelperMap &deallocHelperFuncMap; |
| 377 | }; |
| 378 | } // namespace |
| 379 | |
| 380 | namespace { |
| 381 | struct LowerDeallocationsPass |
| 382 | : public bufferization::impl::LowerDeallocationsPassBase< |
| 383 | LowerDeallocationsPass> { |
| 384 | void runOnOperation() override { |
| 385 | if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) { |
| 386 | emitError(getOperation()->getLoc(), |
| 387 | "root operation must be a builtin.module or a function" ); |
| 388 | signalPassFailure(); |
| 389 | return; |
| 390 | } |
| 391 | |
| 392 | bufferization::DeallocHelperMap deallocHelperFuncMap; |
| 393 | if (auto module = dyn_cast<ModuleOp>(getOperation())) { |
| 394 | OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody()); |
| 395 | |
| 396 | // Build dealloc helper function if there are deallocs. |
| 397 | getOperation()->walk([&](bufferization::DeallocOp deallocOp) { |
| 398 | Operation *symtableOp = |
| 399 | deallocOp->getParentWithTrait<OpTrait::SymbolTable>(); |
| 400 | if (deallocOp.getMemrefs().size() > 1 && |
| 401 | !deallocHelperFuncMap.contains(Val: symtableOp)) { |
| 402 | SymbolTable symbolTable(symtableOp); |
| 403 | func::FuncOp helperFuncOp = |
| 404 | bufferization::buildDeallocationLibraryFunction( |
| 405 | builder, getOperation()->getLoc(), symbolTable); |
| 406 | deallocHelperFuncMap[symtableOp] = helperFuncOp; |
| 407 | } |
| 408 | }); |
| 409 | } |
| 410 | |
| 411 | RewritePatternSet patterns(&getContext()); |
| 412 | bufferization::populateBufferizationDeallocLoweringPattern( |
| 413 | patterns, deallocHelperFuncMap); |
| 414 | |
| 415 | ConversionTarget target(getContext()); |
| 416 | target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect, |
| 417 | scf::SCFDialect, func::FuncDialect>(); |
| 418 | target.addIllegalOp<bufferization::DeallocOp>(); |
| 419 | |
| 420 | if (failed(applyPartialConversion(getOperation(), target, |
| 421 | std::move(patterns)))) |
| 422 | signalPassFailure(); |
| 423 | } |
| 424 | }; |
| 425 | } // namespace |
| 426 | |
| 427 | func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( |
| 428 | OpBuilder &builder, Location loc, SymbolTable &symbolTable) { |
| 429 | Type indexMemrefType = |
| 430 | MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); |
| 431 | Type boolMemrefType = |
| 432 | MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); |
| 433 | SmallVector<Type> argTypes{indexMemrefType, indexMemrefType, boolMemrefType, |
| 434 | boolMemrefType, boolMemrefType}; |
| 435 | builder.clearInsertionPoint(); |
| 436 | |
| 437 | // Generate the func operation itself. |
| 438 | auto helperFuncOp = func::FuncOp::create( |
| 439 | loc, "dealloc_helper" , builder.getFunctionType(argTypes, {})); |
| 440 | helperFuncOp.setVisibility(SymbolTable::Visibility::Private); |
| 441 | symbolTable.insert(symbol: helperFuncOp); |
| 442 | auto &block = helperFuncOp.getFunctionBody().emplaceBlock(); |
| 443 | block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc)); |
| 444 | |
| 445 | builder.setInsertionPointToStart(&block); |
| 446 | Value toDeallocMemref = helperFuncOp.getArguments()[0]; |
| 447 | Value toRetainMemref = helperFuncOp.getArguments()[1]; |
| 448 | Value conditionMemref = helperFuncOp.getArguments()[2]; |
| 449 | Value deallocCondsMemref = helperFuncOp.getArguments()[3]; |
| 450 | Value retainCondsMemref = helperFuncOp.getArguments()[4]; |
| 451 | |
| 452 | // Insert some prerequisites. |
| 453 | Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0)); |
| 454 | Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1)); |
| 455 | Value trueValue = |
| 456 | builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true)); |
| 457 | Value falseValue = |
| 458 | builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false)); |
| 459 | Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0); |
| 460 | Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0); |
| 461 | |
| 462 | builder.create<scf::ForOp>( |
| 463 | loc, c0, toRetainSize, c1, std::nullopt, |
| 464 | [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { |
| 465 | builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i); |
| 466 | builder.create<scf::YieldOp>(loc); |
| 467 | }); |
| 468 | |
| 469 | builder.create<scf::ForOp>( |
| 470 | loc, c0, toDeallocSize, c1, std::nullopt, |
| 471 | [&](OpBuilder &builder, Location loc, Value outerIter, |
| 472 | ValueRange iterArgs) { |
| 473 | Value toDealloc = |
| 474 | builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter); |
| 475 | Value cond = |
| 476 | builder.create<memref::LoadOp>(loc, conditionMemref, outerIter); |
| 477 | |
| 478 | // Build the first for loop that computes aliasing with retained |
| 479 | // memrefs. |
| 480 | Value noRetainAlias = |
| 481 | builder |
| 482 | .create<scf::ForOp>( |
| 483 | loc, c0, toRetainSize, c1, trueValue, |
| 484 | [&](OpBuilder &builder, Location loc, Value i, |
| 485 | ValueRange iterArgs) { |
| 486 | Value retainValue = builder.create<memref::LoadOp>( |
| 487 | loc, toRetainMemref, i); |
| 488 | Value doesAlias = builder.create<arith::CmpIOp>( |
| 489 | loc, arith::CmpIPredicate::eq, retainValue, |
| 490 | toDealloc); |
| 491 | builder.create<scf::IfOp>( |
| 492 | loc, doesAlias, |
| 493 | [&](OpBuilder &builder, Location loc) { |
| 494 | Value retainCondValue = |
| 495 | builder.create<memref::LoadOp>( |
| 496 | loc, retainCondsMemref, i); |
| 497 | Value aggregatedRetainCond = |
| 498 | builder.create<arith::OrIOp>( |
| 499 | loc, retainCondValue, cond); |
| 500 | builder.create<memref::StoreOp>( |
| 501 | loc, aggregatedRetainCond, retainCondsMemref, |
| 502 | i); |
| 503 | builder.create<scf::YieldOp>(loc); |
| 504 | }); |
| 505 | Value doesntAlias = builder.create<arith::CmpIOp>( |
| 506 | loc, arith::CmpIPredicate::ne, retainValue, |
| 507 | toDealloc); |
| 508 | Value yieldValue = builder.create<arith::AndIOp>( |
| 509 | loc, iterArgs[0], doesntAlias); |
| 510 | builder.create<scf::YieldOp>(loc, yieldValue); |
| 511 | }) |
| 512 | .getResult(0); |
| 513 | |
| 514 | // Build the second for loop that adds aliasing with previously |
| 515 | // deallocated memrefs. |
| 516 | Value noAlias = |
| 517 | builder |
| 518 | .create<scf::ForOp>( |
| 519 | loc, c0, outerIter, c1, noRetainAlias, |
| 520 | [&](OpBuilder &builder, Location loc, Value i, |
| 521 | ValueRange iterArgs) { |
| 522 | Value prevDeallocValue = builder.create<memref::LoadOp>( |
| 523 | loc, toDeallocMemref, i); |
| 524 | Value doesntAlias = builder.create<arith::CmpIOp>( |
| 525 | loc, arith::CmpIPredicate::ne, prevDeallocValue, |
| 526 | toDealloc); |
| 527 | Value yieldValue = builder.create<arith::AndIOp>( |
| 528 | loc, iterArgs[0], doesntAlias); |
| 529 | builder.create<scf::YieldOp>(loc, yieldValue); |
| 530 | }) |
| 531 | .getResult(0); |
| 532 | |
| 533 | Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond); |
| 534 | builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref, |
| 535 | outerIter); |
| 536 | builder.create<scf::YieldOp>(loc); |
| 537 | }); |
| 538 | |
| 539 | builder.create<func::ReturnOp>(loc); |
| 540 | return helperFuncOp; |
| 541 | } |
| 542 | |
| 543 | void mlir::bufferization::populateBufferizationDeallocLoweringPattern( |
| 544 | RewritePatternSet &patterns, |
| 545 | const bufferization::DeallocHelperMap &deallocHelperFuncMap) { |
| 546 | patterns.add<DeallocOpConversion>(arg: patterns.getContext(), |
| 547 | args: deallocHelperFuncMap); |
| 548 | } |
| 549 | |