1 | //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" |
10 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | #include "mlir/IR/Dialect.h" |
15 | #include "mlir/IR/Operation.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::bufferization; |
19 | |
20 | static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); } |
21 | |
22 | namespace { |
23 | /// While CondBranchOp also implement the BranchOpInterface, we add a |
24 | /// special-case implementation here because the BranchOpInterface does not |
25 | /// offer all of the functionallity we need to insert dealloc oeprations in an |
26 | /// efficient way. More precisely, there is no way to extract the branch |
27 | /// condition without casting to CondBranchOp specifically. It is still |
28 | /// possible to implement deallocation for cases where we don't know to which |
29 | /// successor the terminator branches before the actual branch happens by |
30 | /// inserting auxiliary blocks and putting the dealloc op there, however, this |
31 | /// can lead to less efficient code. |
32 | /// This function inserts two dealloc operations (one for each successor) and |
33 | /// adjusts the dealloc conditions according to the branch condition, then the |
34 | /// ownerships of the retained MemRefs are updated by combining the result |
35 | /// values of the two dealloc operations. |
36 | /// |
37 | /// Example: |
38 | /// ``` |
39 | /// ^bb1: |
40 | /// <more ops...> |
41 | /// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>) |
42 | /// ``` |
43 | /// becomes |
44 | /// ``` |
45 | /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1) |
46 | /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>) |
47 | /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>) |
48 | /// ^bb1: |
49 | /// <more ops...> |
50 | /// let thenCond = map(c, (c) -> arith.andi cond, c) |
51 | /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c) |
52 | /// o0 = bufferization.dealloc m if thenCond retain r0 |
53 | /// o1 = bufferization.dealloc m if elseCond retain r1 |
54 | /// // replace ownership(r0) with o0 element-wise |
55 | /// // replace ownership(r1) with o1 element-wise |
56 | /// // let ownership0 := (r) -> o in o0 corresponding to r |
57 | /// // let ownership1 := (r) -> o in o1 corresponding to r |
58 | /// // let cmn := intersection(r0, r1) |
59 | /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)): |
60 | /// forall r in r0: replace ownership0(r) with arith.select cond, a, b) |
61 | /// forall r in r1: replace ownership1(r) with arith.select cond, a, b) |
62 | /// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1) |
63 | /// ``` |
64 | struct CondBranchOpInterface |
65 | : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface, |
66 | cf::CondBranchOp> { |
67 | FailureOr<Operation *> process(Operation *op, DeallocationState &state, |
68 | const DeallocationOptions &options) const { |
69 | OpBuilder builder(op); |
70 | auto condBr = cast<cf::CondBranchOp>(op); |
71 | |
72 | // The list of memrefs to deallocate in this block is independent of which |
73 | // branch is taken. |
74 | SmallVector<Value> memrefs, conditions; |
75 | if (failed(state.getMemrefsAndConditionsToDeallocate( |
76 | builder, loc: condBr.getLoc(), block: condBr->getBlock(), memrefs, conditions))) |
77 | return failure(); |
78 | |
79 | // Helper lambda to factor out common logic for inserting the dealloc |
80 | // operations for each successor. |
81 | auto insertDeallocForBranch = |
82 | [&](Block *target, MutableOperandRange destOperands, |
83 | const std::function<Value(Value)> &conditionModifier, |
84 | DenseMap<Value, Value> &mapping) -> DeallocOp { |
85 | SmallVector<Value> toRetain; |
86 | state.getMemrefsToRetain(condBr->getBlock(), target, |
87 | destOperands.getAsOperandRange(), toRetain); |
88 | SmallVector<Value> adaptedConditions( |
89 | llvm::map_range(conditions, conditionModifier)); |
90 | auto deallocOp = builder.create<bufferization::DeallocOp>( |
91 | condBr.getLoc(), memrefs, adaptedConditions, toRetain); |
92 | state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock()); |
93 | for (auto [retained, ownership] : llvm::zip( |
94 | deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { |
95 | state.updateOwnership(retained, ownership, condBr->getBlock()); |
96 | mapping[retained] = ownership; |
97 | } |
98 | SmallVector<Value> replacements, ownerships; |
99 | for (OpOperand &operand : destOperands) { |
100 | replacements.push_back(operand.get()); |
101 | if (isMemref(operand.get())) { |
102 | assert(mapping.contains(operand.get()) && |
103 | "Should be contained at this point" ); |
104 | ownerships.push_back(mapping[operand.get()]); |
105 | } |
106 | } |
107 | replacements.append(ownerships); |
108 | destOperands.assign(replacements); |
109 | return deallocOp; |
110 | }; |
111 | |
112 | // Call the helper lambda and make sure the dealloc conditions are properly |
113 | // modified to reflect the branch condition as well. |
114 | DenseMap<Value, Value> thenMapping, elseMapping; |
115 | DeallocOp thenTakenDeallocOp = insertDeallocForBranch( |
116 | condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(), |
117 | [&](Value cond) { |
118 | return builder.create<arith::AndIOp>(condBr.getLoc(), cond, |
119 | condBr.getCondition()); |
120 | }, |
121 | thenMapping); |
122 | DeallocOp elseTakenDeallocOp = insertDeallocForBranch( |
123 | condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(), |
124 | [&](Value cond) { |
125 | Value trueVal = builder.create<arith::ConstantOp>( |
126 | condBr.getLoc(), builder.getBoolAttr(true)); |
127 | Value negation = builder.create<arith::XOrIOp>( |
128 | condBr.getLoc(), trueVal, condBr.getCondition()); |
129 | return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation); |
130 | }, |
131 | elseMapping); |
132 | |
133 | // We specifically need to update the ownerships of values that are retained |
134 | // in both dealloc operations again to get a combined 'Unique' ownership |
135 | // instead of an 'Unknown' ownership. |
136 | SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(), |
137 | thenTakenDeallocOp.getRetained().end()); |
138 | SetVector<Value> commonValues; |
139 | for (Value val : elseTakenDeallocOp.getRetained()) { |
140 | if (thenValues.contains(val)) |
141 | commonValues.insert(val); |
142 | } |
143 | |
144 | for (Value retained : commonValues) { |
145 | state.resetOwnerships(memrefs: retained, block: condBr->getBlock()); |
146 | Value combinedOwnership = builder.create<arith::SelectOp>( |
147 | condBr.getLoc(), condBr.getCondition(), thenMapping[retained], |
148 | elseMapping[retained]); |
149 | state.updateOwnership(memref: retained, ownership: combinedOwnership, block: condBr->getBlock()); |
150 | } |
151 | |
152 | return condBr.getOperation(); |
153 | } |
154 | }; |
155 | |
156 | } // namespace |
157 | |
158 | void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels( |
159 | DialectRegistry ®istry) { |
160 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, ControlFlowDialect *dialect) { |
161 | CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx); |
162 | }); |
163 | } |
164 | |