1 | //===- BufferDeallocationOpInterface.cpp ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
10 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/IR/AsmState.h" |
13 | #include "mlir/IR/Matchers.h" |
14 | #include "mlir/IR/Operation.h" |
15 | #include "mlir/IR/TypeUtilities.h" |
16 | #include "mlir/IR/Value.h" |
17 | #include "llvm/ADT/SetOperations.h" |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // BufferDeallocationOpInterface |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | namespace mlir { |
24 | namespace bufferization { |
25 | |
26 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc" |
27 | |
28 | } // namespace bufferization |
29 | } // namespace mlir |
30 | |
31 | using namespace mlir; |
32 | using namespace bufferization; |
33 | |
34 | //===----------------------------------------------------------------------===// |
35 | // Helpers |
36 | //===----------------------------------------------------------------------===// |
37 | |
38 | static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { |
39 | return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); |
40 | } |
41 | |
42 | static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); } |
43 | |
44 | //===----------------------------------------------------------------------===// |
45 | // Ownership |
46 | //===----------------------------------------------------------------------===// |
47 | |
48 | Ownership::Ownership(Value indicator) |
49 | : indicator(indicator), state(State::Unique) {} |
50 | |
51 | Ownership Ownership::getUnknown() { |
52 | Ownership unknown; |
53 | unknown.indicator = Value(); |
54 | unknown.state = State::Unknown; |
55 | return unknown; |
56 | } |
57 | Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); } |
58 | Ownership Ownership::getUninitialized() { return Ownership(); } |
59 | |
60 | bool Ownership::isUninitialized() const { |
61 | return state == State::Uninitialized; |
62 | } |
63 | bool Ownership::isUnique() const { return state == State::Unique; } |
64 | bool Ownership::isUnknown() const { return state == State::Unknown; } |
65 | |
66 | Value Ownership::getIndicator() const { |
67 | assert(isUnique() && "must have unique ownership to get the indicator" ); |
68 | return indicator; |
69 | } |
70 | |
71 | Ownership Ownership::getCombined(Ownership other) const { |
72 | if (other.isUninitialized()) |
73 | return *this; |
74 | if (isUninitialized()) |
75 | return other; |
76 | |
77 | if (!isUnique() || !other.isUnique()) |
78 | return getUnknown(); |
79 | |
80 | // Since we create a new constant i1 value for (almost) each use-site, we |
81 | // should compare the actual value rather than just the SSA Value to avoid |
82 | // unnecessary invalidations. |
83 | if (isEqualConstantIntOrValue(ofr1: indicator, ofr2: other.indicator)) |
84 | return *this; |
85 | |
86 | // Return the join of the lattice if the indicator of both ownerships cannot |
87 | // be merged. |
88 | return getUnknown(); |
89 | } |
90 | |
91 | void Ownership::combine(Ownership other) { *this = getCombined(other); } |
92 | |
93 | //===----------------------------------------------------------------------===// |
94 | // DeallocationState |
95 | //===----------------------------------------------------------------------===// |
96 | |
97 | DeallocationState::DeallocationState(Operation *op, |
98 | SymbolTableCollection &symbolTables) |
99 | : symbolTable(symbolTables), liveness(op) {} |
100 | |
101 | void DeallocationState::updateOwnership(Value memref, Ownership ownership, |
102 | Block *block) { |
103 | // In most cases we care about the block where the value is defined. |
104 | if (block == nullptr) |
105 | block = memref.getParentBlock(); |
106 | |
107 | // Update ownership of current memref itself. |
108 | ownershipMap[{memref, block}].combine(other: ownership); |
109 | } |
110 | |
111 | void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) { |
112 | for (Value val : memrefs) |
113 | ownershipMap[{val, block}] = Ownership::getUninitialized(); |
114 | } |
115 | |
116 | Ownership DeallocationState::getOwnership(Value memref, Block *block) const { |
117 | return ownershipMap.lookup(Val: {memref, block}); |
118 | } |
119 | |
120 | void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) { |
121 | memrefsToDeallocatePerBlock[block].push_back(Elt: memref); |
122 | } |
123 | |
124 | void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) { |
125 | llvm::erase(C&: memrefsToDeallocatePerBlock[block], V: memref); |
126 | } |
127 | |
128 | void DeallocationState::getLiveMemrefsIn(Block *block, |
129 | SmallVectorImpl<Value> &memrefs) { |
130 | SmallVector<Value> liveMemrefs( |
131 | llvm::make_filter_range(Range: liveness.getLiveIn(block), Pred: isMemref)); |
132 | llvm::sort(C&: liveMemrefs, Comp: ValueComparator()); |
133 | memrefs.append(RHS: liveMemrefs); |
134 | } |
135 | |
136 | std::pair<Value, Value> |
137 | DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, |
138 | Value memref, Block *block) { |
139 | auto iter = ownershipMap.find(Val: {memref, block}); |
140 | assert(iter != ownershipMap.end() && |
141 | "Value must already have been registered in the ownership map" ); |
142 | |
143 | Ownership ownership = iter->second; |
144 | if (ownership.isUnique()) |
145 | return {memref, ownership.getIndicator()}; |
146 | |
147 | // Instead of inserting a clone operation we could also insert a dealloc |
148 | // operation earlier in the block and use the updated ownerships returned by |
149 | // the op for the retained values. Alternatively, we could insert code to |
150 | // check aliasing at runtime and use this information to combine two unique |
151 | // ownerships more intelligently to not end up with an 'Unknown' ownership in |
152 | // the first place. |
153 | auto cloneOp = |
154 | builder.create<bufferization::CloneOp>(memref.getLoc(), memref); |
155 | Value condition = buildBoolValue(builder, loc: memref.getLoc(), value: true); |
156 | Value newMemref = cloneOp.getResult(); |
157 | updateOwnership(memref: newMemref, ownership: condition); |
158 | memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(Elt: newMemref); |
159 | return {newMemref, condition}; |
160 | } |
161 | |
162 | void DeallocationState::getMemrefsToRetain( |
163 | Block *fromBlock, Block *toBlock, ValueRange destOperands, |
164 | SmallVectorImpl<Value> &toRetain) const { |
165 | for (Value operand : destOperands) { |
166 | if (!isMemref(v: operand)) |
167 | continue; |
168 | toRetain.push_back(Elt: operand); |
169 | } |
170 | |
171 | SmallPtrSet<Value, 16> liveOut; |
172 | for (auto val : liveness.getLiveOut(block: fromBlock)) |
173 | if (isMemref(v: val)) |
174 | liveOut.insert(Ptr: val); |
175 | |
176 | if (toBlock) |
177 | llvm::set_intersect(S1&: liveOut, S2: liveness.getLiveIn(block: toBlock)); |
178 | |
179 | // liveOut has non-deterministic order because it was constructed by iterating |
180 | // over a hash-set. |
181 | SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end()); |
182 | llvm::sort(C&: retainedByLiveness, Comp: ValueComparator()); |
183 | toRetain.append(RHS: retainedByLiveness); |
184 | } |
185 | |
186 | LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( |
187 | OpBuilder &builder, Location loc, Block *block, |
188 | SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const { |
189 | |
190 | for (auto [i, memref] : |
191 | llvm::enumerate(First: memrefsToDeallocatePerBlock.lookup(Val: block))) { |
192 | Ownership ownership = ownershipMap.lookup(Val: {memref, block}); |
193 | if (!ownership.isUnique()) |
194 | return emitError(loc: memref.getLoc(), |
195 | message: "MemRef value does not have valid ownership" ); |
196 | |
197 | // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such |
198 | // that we can call extract_strided_metadata on it. |
199 | if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType())) |
200 | memref = builder.create<memref::ReinterpretCastOp>( |
201 | loc, memref, |
202 | /*offset=*/builder.getIndexAttr(0), |
203 | /*sizes=*/ArrayRef<OpFoldResult>{}, |
204 | /*strides=*/ArrayRef<OpFoldResult>{}); |
205 | |
206 | // Use the `memref.extract_strided_metadata` operation to get the base |
207 | // memref. This is needed because the same MemRef that was produced by the |
208 | // alloc operation has to be passed to the dealloc operation. Passing |
209 | // subviews, etc. to a dealloc operation is not allowed. |
210 | memrefs.push_back( |
211 | builder.create<memref::ExtractStridedMetadataOp>(loc, memref) |
212 | .getResult(0)); |
213 | conditions.push_back(Elt: ownership.getIndicator()); |
214 | } |
215 | |
216 | return success(); |
217 | } |
218 | |
219 | //===----------------------------------------------------------------------===// |
220 | // ValueComparator |
221 | //===----------------------------------------------------------------------===// |
222 | |
223 | bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { |
224 | if (lhs == rhs) |
225 | return false; |
226 | |
227 | // Block arguments are less than results. |
228 | bool lhsIsBBArg = isa<BlockArgument>(Val: lhs); |
229 | if (lhsIsBBArg != isa<BlockArgument>(Val: rhs)) { |
230 | return lhsIsBBArg; |
231 | } |
232 | |
233 | Region *lhsRegion; |
234 | Region *rhsRegion; |
235 | if (lhsIsBBArg) { |
236 | auto lhsBBArg = llvm::cast<BlockArgument>(Val: lhs); |
237 | auto rhsBBArg = llvm::cast<BlockArgument>(Val: rhs); |
238 | if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { |
239 | return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); |
240 | } |
241 | lhsRegion = lhsBBArg.getParentRegion(); |
242 | rhsRegion = rhsBBArg.getParentRegion(); |
243 | assert(lhsRegion != rhsRegion && |
244 | "lhsRegion == rhsRegion implies lhs == rhs" ); |
245 | } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { |
246 | return llvm::cast<OpResult>(Val: lhs).getResultNumber() < |
247 | llvm::cast<OpResult>(Val: rhs).getResultNumber(); |
248 | } else { |
249 | lhsRegion = lhs.getDefiningOp()->getParentRegion(); |
250 | rhsRegion = rhs.getDefiningOp()->getParentRegion(); |
251 | if (lhsRegion == rhsRegion) { |
252 | return lhs.getDefiningOp()->isBeforeInBlock(other: rhs.getDefiningOp()); |
253 | } |
254 | } |
255 | |
256 | // lhsRegion != rhsRegion, so if we look at their ancestor chain, they |
257 | // - have different heights |
258 | // - or there's a spot where their region numbers differ |
259 | // - or their parent regions are the same and their parent ops are |
260 | // different. |
261 | while (lhsRegion && rhsRegion) { |
262 | if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { |
263 | return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); |
264 | } |
265 | if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { |
266 | return lhsRegion->getParentOp()->isBeforeInBlock( |
267 | other: rhsRegion->getParentOp()); |
268 | } |
269 | lhsRegion = lhsRegion->getParentRegion(); |
270 | rhsRegion = rhsRegion->getParentRegion(); |
271 | } |
272 | if (rhsRegion) |
273 | return true; |
274 | assert(lhsRegion && "this should only happen if lhs == rhs" ); |
275 | return false; |
276 | } |
277 | |
278 | //===----------------------------------------------------------------------===// |
279 | // Implementation utilities |
280 | //===----------------------------------------------------------------------===// |
281 | |
282 | FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike( |
283 | DeallocationState &state, Operation *op, ValueRange operands, |
284 | SmallVectorImpl<Value> &updatedOperandOwnerships) { |
285 | assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator" ); |
286 | assert(!op->hasSuccessors() && "must not have any successors" ); |
287 | // Collect the values to deallocate and retain and use them to create the |
288 | // dealloc operation. |
289 | OpBuilder builder(op); |
290 | Block *block = op->getBlock(); |
291 | SmallVector<Value> memrefs, conditions, toRetain; |
292 | if (failed(Result: state.getMemrefsAndConditionsToDeallocate( |
293 | builder, loc: op->getLoc(), block, memrefs, conditions))) |
294 | return failure(); |
295 | |
296 | state.getMemrefsToRetain(fromBlock: block, /*toBlock=*/nullptr, destOperands: operands, toRetain); |
297 | if (memrefs.empty() && toRetain.empty()) |
298 | return op; |
299 | |
300 | auto deallocOp = builder.create<bufferization::DeallocOp>( |
301 | op->getLoc(), memrefs, conditions, toRetain); |
302 | |
303 | // We want to replace the current ownership of the retained values with the |
304 | // result values of the dealloc operation as they are always unique. |
305 | state.resetOwnerships(memrefs: deallocOp.getRetained(), block); |
306 | for (auto [retained, ownership] : |
307 | llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) |
308 | state.updateOwnership(retained, ownership, block); |
309 | |
310 | unsigned numMemrefOperands = llvm::count_if(Range&: operands, P: isMemref); |
311 | auto newOperandOwnerships = |
312 | deallocOp.getUpdatedConditions().take_front(numMemrefOperands); |
313 | updatedOperandOwnerships.append(newOperandOwnerships.begin(), |
314 | newOperandOwnerships.end()); |
315 | |
316 | return op; |
317 | } |
318 | |