1 | //===- BufferDeallocationOpInterface.cpp ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
10 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/IR/AsmState.h" |
13 | #include "mlir/IR/Matchers.h" |
14 | #include "mlir/IR/Operation.h" |
15 | #include "mlir/IR/TypeUtilities.h" |
16 | #include "mlir/IR/Value.h" |
17 | #include "llvm/ADT/SetOperations.h" |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // BufferDeallocationOpInterface |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | namespace mlir { |
24 | namespace bufferization { |
25 | |
26 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc" |
27 | |
28 | } // namespace bufferization |
29 | } // namespace mlir |
30 | |
31 | using namespace mlir; |
32 | using namespace bufferization; |
33 | |
34 | //===----------------------------------------------------------------------===// |
35 | // Helpers |
36 | //===----------------------------------------------------------------------===// |
37 | |
38 | static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { |
39 | return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); |
40 | } |
41 | |
42 | static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); } |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // Ownership |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | Ownership::Ownership(Value indicator) |
49 | : indicator(indicator), state(State::Unique) {} |
50 | |
51 | Ownership Ownership::getUnknown() { |
52 | Ownership unknown; |
53 | unknown.indicator = Value(); |
54 | unknown.state = State::Unknown; |
55 | return unknown; |
56 | } |
57 | Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); } |
58 | Ownership Ownership::getUninitialized() { return Ownership(); } |
59 | |
60 | bool Ownership::isUninitialized() const { |
61 | return state == State::Uninitialized; |
62 | } |
63 | bool Ownership::isUnique() const { return state == State::Unique; } |
64 | bool Ownership::isUnknown() const { return state == State::Unknown; } |
65 | |
66 | Value Ownership::getIndicator() const { |
67 | assert(isUnique() && "must have unique ownership to get the indicator" ); |
68 | return indicator; |
69 | } |
70 | |
71 | Ownership Ownership::getCombined(Ownership other) const { |
72 | if (other.isUninitialized()) |
73 | return *this; |
74 | if (isUninitialized()) |
75 | return other; |
76 | |
77 | if (!isUnique() || !other.isUnique()) |
78 | return getUnknown(); |
79 | |
80 | // Since we create a new constant i1 value for (almost) each use-site, we |
81 | // should compare the actual value rather than just the SSA Value to avoid |
82 | // unnecessary invalidations. |
83 | if (isEqualConstantIntOrValue(ofr1: indicator, ofr2: other.indicator)) |
84 | return *this; |
85 | |
86 | // Return the join of the lattice if the indicator of both ownerships cannot |
87 | // be merged. |
88 | return getUnknown(); |
89 | } |
90 | |
91 | void Ownership::combine(Ownership other) { *this = getCombined(other); } |
92 | |
93 | //===----------------------------------------------------------------------===// |
94 | // DeallocationState |
95 | //===----------------------------------------------------------------------===// |
96 | |
97 | DeallocationState::DeallocationState(Operation *op) : liveness(op) {} |
98 | |
99 | void DeallocationState::updateOwnership(Value memref, Ownership ownership, |
100 | Block *block) { |
101 | // In most cases we care about the block where the value is defined. |
102 | if (block == nullptr) |
103 | block = memref.getParentBlock(); |
104 | |
105 | // Update ownership of current memref itself. |
106 | ownershipMap[{memref, block}].combine(other: ownership); |
107 | } |
108 | |
109 | void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) { |
110 | for (Value val : memrefs) |
111 | ownershipMap[{val, block}] = Ownership::getUninitialized(); |
112 | } |
113 | |
114 | Ownership DeallocationState::getOwnership(Value memref, Block *block) const { |
115 | return ownershipMap.lookup(Val: {memref, block}); |
116 | } |
117 | |
118 | void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) { |
119 | memrefsToDeallocatePerBlock[block].push_back(Elt: memref); |
120 | } |
121 | |
122 | void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) { |
123 | llvm::erase(C&: memrefsToDeallocatePerBlock[block], V: memref); |
124 | } |
125 | |
126 | void DeallocationState::getLiveMemrefsIn(Block *block, |
127 | SmallVectorImpl<Value> &memrefs) { |
128 | SmallVector<Value> liveMemrefs( |
129 | llvm::make_filter_range(Range: liveness.getLiveIn(block), Pred: isMemref)); |
130 | llvm::sort(C&: liveMemrefs, Comp: ValueComparator()); |
131 | memrefs.append(RHS: liveMemrefs); |
132 | } |
133 | |
134 | std::pair<Value, Value> |
135 | DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, |
136 | Value memref, Block *block) { |
137 | auto iter = ownershipMap.find(Val: {memref, block}); |
138 | assert(iter != ownershipMap.end() && |
139 | "Value must already have been registered in the ownership map" ); |
140 | |
141 | Ownership ownership = iter->second; |
142 | if (ownership.isUnique()) |
143 | return {memref, ownership.getIndicator()}; |
144 | |
145 | // Instead of inserting a clone operation we could also insert a dealloc |
146 | // operation earlier in the block and use the updated ownerships returned by |
147 | // the op for the retained values. Alternatively, we could insert code to |
148 | // check aliasing at runtime and use this information to combine two unique |
149 | // ownerships more intelligently to not end up with an 'Unknown' ownership in |
150 | // the first place. |
151 | auto cloneOp = |
152 | builder.create<bufferization::CloneOp>(memref.getLoc(), memref); |
153 | Value condition = buildBoolValue(builder, loc: memref.getLoc(), value: true); |
154 | Value newMemref = cloneOp.getResult(); |
155 | updateOwnership(memref: newMemref, ownership: condition); |
156 | memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(Elt: newMemref); |
157 | return {newMemref, condition}; |
158 | } |
159 | |
160 | void DeallocationState::getMemrefsToRetain( |
161 | Block *fromBlock, Block *toBlock, ValueRange destOperands, |
162 | SmallVectorImpl<Value> &toRetain) const { |
163 | for (Value operand : destOperands) { |
164 | if (!isMemref(v: operand)) |
165 | continue; |
166 | toRetain.push_back(Elt: operand); |
167 | } |
168 | |
169 | SmallPtrSet<Value, 16> liveOut; |
170 | for (auto val : liveness.getLiveOut(block: fromBlock)) |
171 | if (isMemref(v: val)) |
172 | liveOut.insert(Ptr: val); |
173 | |
174 | if (toBlock) |
175 | llvm::set_intersect(S1&: liveOut, S2: liveness.getLiveIn(block: toBlock)); |
176 | |
177 | // liveOut has non-deterministic order because it was constructed by iterating |
178 | // over a hash-set. |
179 | SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end()); |
180 | std::sort(first: retainedByLiveness.begin(), last: retainedByLiveness.end(), |
181 | comp: ValueComparator()); |
182 | toRetain.append(RHS: retainedByLiveness); |
183 | } |
184 | |
185 | LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( |
186 | OpBuilder &builder, Location loc, Block *block, |
187 | SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const { |
188 | |
189 | for (auto [i, memref] : |
190 | llvm::enumerate(First: memrefsToDeallocatePerBlock.lookup(Val: block))) { |
191 | Ownership ownership = ownershipMap.lookup(Val: {memref, block}); |
192 | if (!ownership.isUnique()) |
193 | return emitError(loc: memref.getLoc(), |
194 | message: "MemRef value does not have valid ownership" ); |
195 | |
196 | // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such |
197 | // that we can call extract_strided_metadata on it. |
198 | if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType())) |
199 | memref = builder.create<memref::ReinterpretCastOp>( |
200 | loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref, |
201 | 0, SmallVector<int64_t>{}, SmallVector<int64_t>{}); |
202 | |
203 | // Use the `memref.extract_strided_metadata` operation to get the base |
204 | // memref. This is needed because the same MemRef that was produced by the |
205 | // alloc operation has to be passed to the dealloc operation. Passing |
206 | // subviews, etc. to a dealloc operation is not allowed. |
207 | memrefs.push_back( |
208 | builder.create<memref::ExtractStridedMetadataOp>(loc, memref) |
209 | .getResult(0)); |
210 | conditions.push_back(Elt: ownership.getIndicator()); |
211 | } |
212 | |
213 | return success(); |
214 | } |
215 | |
216 | //===----------------------------------------------------------------------===// |
217 | // ValueComparator |
218 | //===----------------------------------------------------------------------===// |
219 | |
220 | bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { |
221 | if (lhs == rhs) |
222 | return false; |
223 | |
224 | // Block arguments are less than results. |
225 | bool lhsIsBBArg = isa<BlockArgument>(Val: lhs); |
226 | if (lhsIsBBArg != isa<BlockArgument>(Val: rhs)) { |
227 | return lhsIsBBArg; |
228 | } |
229 | |
230 | Region *lhsRegion; |
231 | Region *rhsRegion; |
232 | if (lhsIsBBArg) { |
233 | auto lhsBBArg = llvm::cast<BlockArgument>(Val: lhs); |
234 | auto rhsBBArg = llvm::cast<BlockArgument>(Val: rhs); |
235 | if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { |
236 | return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); |
237 | } |
238 | lhsRegion = lhsBBArg.getParentRegion(); |
239 | rhsRegion = rhsBBArg.getParentRegion(); |
240 | assert(lhsRegion != rhsRegion && |
241 | "lhsRegion == rhsRegion implies lhs == rhs" ); |
242 | } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { |
243 | return llvm::cast<OpResult>(Val: lhs).getResultNumber() < |
244 | llvm::cast<OpResult>(Val: rhs).getResultNumber(); |
245 | } else { |
246 | lhsRegion = lhs.getDefiningOp()->getParentRegion(); |
247 | rhsRegion = rhs.getDefiningOp()->getParentRegion(); |
248 | if (lhsRegion == rhsRegion) { |
249 | return lhs.getDefiningOp()->isBeforeInBlock(other: rhs.getDefiningOp()); |
250 | } |
251 | } |
252 | |
253 | // lhsRegion != rhsRegion, so if we look at their ancestor chain, they |
254 | // - have different heights |
255 | // - or there's a spot where their region numbers differ |
256 | // - or their parent regions are the same and their parent ops are |
257 | // different. |
258 | while (lhsRegion && rhsRegion) { |
259 | if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { |
260 | return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); |
261 | } |
262 | if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { |
263 | return lhsRegion->getParentOp()->isBeforeInBlock( |
264 | other: rhsRegion->getParentOp()); |
265 | } |
266 | lhsRegion = lhsRegion->getParentRegion(); |
267 | rhsRegion = rhsRegion->getParentRegion(); |
268 | } |
269 | if (rhsRegion) |
270 | return true; |
271 | assert(lhsRegion && "this should only happen if lhs == rhs" ); |
272 | return false; |
273 | } |
274 | |
275 | //===----------------------------------------------------------------------===// |
276 | // Implementation utilities |
277 | //===----------------------------------------------------------------------===// |
278 | |
279 | FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike( |
280 | DeallocationState &state, Operation *op, ValueRange operands, |
281 | SmallVectorImpl<Value> &updatedOperandOwnerships) { |
282 | assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator" ); |
283 | assert(!op->hasSuccessors() && "must not have any successors" ); |
284 | // Collect the values to deallocate and retain and use them to create the |
285 | // dealloc operation. |
286 | OpBuilder builder(op); |
287 | Block *block = op->getBlock(); |
288 | SmallVector<Value> memrefs, conditions, toRetain; |
289 | if (failed(result: state.getMemrefsAndConditionsToDeallocate( |
290 | builder, loc: op->getLoc(), block, memrefs, conditions))) |
291 | return failure(); |
292 | |
293 | state.getMemrefsToRetain(fromBlock: block, /*toBlock=*/nullptr, destOperands: operands, toRetain); |
294 | if (memrefs.empty() && toRetain.empty()) |
295 | return op; |
296 | |
297 | auto deallocOp = builder.create<bufferization::DeallocOp>( |
298 | op->getLoc(), memrefs, conditions, toRetain); |
299 | |
300 | // We want to replace the current ownership of the retained values with the |
301 | // result values of the dealloc operation as they are always unique. |
302 | state.resetOwnerships(memrefs: deallocOp.getRetained(), block); |
303 | for (auto [retained, ownership] : |
304 | llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) |
305 | state.updateOwnership(retained, ownership, block); |
306 | |
307 | unsigned numMemrefOperands = llvm::count_if(Range&: operands, P: isMemref); |
308 | auto newOperandOwnerships = |
309 | deallocOp.getUpdatedConditions().take_front(numMemrefOperands); |
310 | updatedOperandOwnerships.append(newOperandOwnerships.begin(), |
311 | newOperandOwnerships.end()); |
312 | |
313 | return op; |
314 | } |
315 | |