| 1 | //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" |
| 10 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
| 11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 12 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 14 | #include "mlir/IR/Dialect.h" |
| 15 | #include "mlir/IR/Operation.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | using namespace mlir::bufferization; |
| 19 | |
| 20 | static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); } |
| 21 | |
| 22 | namespace { |
| 23 | /// While CondBranchOp also implement the BranchOpInterface, we add a |
| 24 | /// special-case implementation here because the BranchOpInterface does not |
| 25 | /// offer all of the functionallity we need to insert dealloc oeprations in an |
| 26 | /// efficient way. More precisely, there is no way to extract the branch |
| 27 | /// condition without casting to CondBranchOp specifically. It is still |
| 28 | /// possible to implement deallocation for cases where we don't know to which |
| 29 | /// successor the terminator branches before the actual branch happens by |
| 30 | /// inserting auxiliary blocks and putting the dealloc op there, however, this |
| 31 | /// can lead to less efficient code. |
| 32 | /// This function inserts two dealloc operations (one for each successor) and |
| 33 | /// adjusts the dealloc conditions according to the branch condition, then the |
| 34 | /// ownerships of the retained MemRefs are updated by combining the result |
| 35 | /// values of the two dealloc operations. |
| 36 | /// |
| 37 | /// Example: |
| 38 | /// ``` |
| 39 | /// ^bb1: |
| 40 | /// <more ops...> |
| 41 | /// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>) |
| 42 | /// ``` |
| 43 | /// becomes |
| 44 | /// ``` |
| 45 | /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1) |
| 46 | /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>) |
| 47 | /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>) |
| 48 | /// ^bb1: |
| 49 | /// <more ops...> |
| 50 | /// let thenCond = map(c, (c) -> arith.andi cond, c) |
| 51 | /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c) |
| 52 | /// o0 = bufferization.dealloc m if thenCond retain r0 |
| 53 | /// o1 = bufferization.dealloc m if elseCond retain r1 |
| 54 | /// // replace ownership(r0) with o0 element-wise |
| 55 | /// // replace ownership(r1) with o1 element-wise |
| 56 | /// // let ownership0 := (r) -> o in o0 corresponding to r |
| 57 | /// // let ownership1 := (r) -> o in o1 corresponding to r |
| 58 | /// // let cmn := intersection(r0, r1) |
| 59 | /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)): |
| 60 | /// forall r in r0: replace ownership0(r) with arith.select cond, a, b) |
| 61 | /// forall r in r1: replace ownership1(r) with arith.select cond, a, b) |
| 62 | /// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1) |
| 63 | /// ``` |
| 64 | struct CondBranchOpInterface |
| 65 | : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface, |
| 66 | cf::CondBranchOp> { |
| 67 | FailureOr<Operation *> process(Operation *op, DeallocationState &state, |
| 68 | const DeallocationOptions &options) const { |
| 69 | OpBuilder builder(op); |
| 70 | auto condBr = cast<cf::CondBranchOp>(op); |
| 71 | |
| 72 | // The list of memrefs to deallocate in this block is independent of which |
| 73 | // branch is taken. |
| 74 | SmallVector<Value> memrefs, conditions; |
| 75 | if (failed(state.getMemrefsAndConditionsToDeallocate( |
| 76 | builder, loc: condBr.getLoc(), block: condBr->getBlock(), memrefs, conditions))) |
| 77 | return failure(); |
| 78 | |
| 79 | // Helper lambda to factor out common logic for inserting the dealloc |
| 80 | // operations for each successor. |
| 81 | auto insertDeallocForBranch = |
| 82 | [&](Block *target, MutableOperandRange destOperands, |
| 83 | const std::function<Value(Value)> &conditionModifier, |
| 84 | DenseMap<Value, Value> &mapping) -> DeallocOp { |
| 85 | SmallVector<Value> toRetain; |
| 86 | state.getMemrefsToRetain(condBr->getBlock(), target, |
| 87 | destOperands.getAsOperandRange(), toRetain); |
| 88 | SmallVector<Value> adaptedConditions( |
| 89 | llvm::map_range(conditions, conditionModifier)); |
| 90 | auto deallocOp = builder.create<bufferization::DeallocOp>( |
| 91 | condBr.getLoc(), memrefs, adaptedConditions, toRetain); |
| 92 | state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock()); |
| 93 | for (auto [retained, ownership] : llvm::zip( |
| 94 | deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { |
| 95 | state.updateOwnership(retained, ownership, condBr->getBlock()); |
| 96 | mapping[retained] = ownership; |
| 97 | } |
| 98 | SmallVector<Value> replacements, ownerships; |
| 99 | for (OpOperand &operand : destOperands) { |
| 100 | replacements.push_back(operand.get()); |
| 101 | if (isMemref(operand.get())) { |
| 102 | assert(mapping.contains(operand.get()) && |
| 103 | "Should be contained at this point" ); |
| 104 | ownerships.push_back(mapping[operand.get()]); |
| 105 | } |
| 106 | } |
| 107 | replacements.append(ownerships); |
| 108 | destOperands.assign(replacements); |
| 109 | return deallocOp; |
| 110 | }; |
| 111 | |
| 112 | // Call the helper lambda and make sure the dealloc conditions are properly |
| 113 | // modified to reflect the branch condition as well. |
| 114 | DenseMap<Value, Value> thenMapping, elseMapping; |
| 115 | DeallocOp thenTakenDeallocOp = insertDeallocForBranch( |
| 116 | condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(), |
| 117 | [&](Value cond) { |
| 118 | return builder.create<arith::AndIOp>(condBr.getLoc(), cond, |
| 119 | condBr.getCondition()); |
| 120 | }, |
| 121 | thenMapping); |
| 122 | DeallocOp elseTakenDeallocOp = insertDeallocForBranch( |
| 123 | condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(), |
| 124 | [&](Value cond) { |
| 125 | Value trueVal = builder.create<arith::ConstantOp>( |
| 126 | condBr.getLoc(), builder.getBoolAttr(true)); |
| 127 | Value negation = builder.create<arith::XOrIOp>( |
| 128 | condBr.getLoc(), trueVal, condBr.getCondition()); |
| 129 | return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation); |
| 130 | }, |
| 131 | elseMapping); |
| 132 | |
| 133 | // We specifically need to update the ownerships of values that are retained |
| 134 | // in both dealloc operations again to get a combined 'Unique' ownership |
| 135 | // instead of an 'Unknown' ownership. |
| 136 | SmallPtrSet<Value, 16> thenValues(llvm::from_range, |
| 137 | thenTakenDeallocOp.getRetained()); |
| 138 | SetVector<Value> commonValues; |
| 139 | for (Value val : elseTakenDeallocOp.getRetained()) { |
| 140 | if (thenValues.contains(val)) |
| 141 | commonValues.insert(val); |
| 142 | } |
| 143 | |
| 144 | for (Value retained : commonValues) { |
| 145 | state.resetOwnerships(memrefs: retained, block: condBr->getBlock()); |
| 146 | Value combinedOwnership = builder.create<arith::SelectOp>( |
| 147 | condBr.getLoc(), condBr.getCondition(), thenMapping[retained], |
| 148 | elseMapping[retained]); |
| 149 | state.updateOwnership(memref: retained, ownership: combinedOwnership, block: condBr->getBlock()); |
| 150 | } |
| 151 | |
| 152 | return condBr.getOperation(); |
| 153 | } |
| 154 | }; |
| 155 | |
| 156 | } // namespace |
| 157 | |
| 158 | void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels( |
| 159 | DialectRegistry ®istry) { |
| 160 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, ControlFlowDialect *dialect) { |
| 161 | CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx); |
| 162 | }); |
| 163 | } |
| 164 | |