1//===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
10#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/Operation.h"
14#include "mlir/IR/TypeUtilities.h"
15#include "mlir/IR/Value.h"
16#include "llvm/ADT/SetOperations.h"
17
18//===----------------------------------------------------------------------===//
19// BufferDeallocationOpInterface
20//===----------------------------------------------------------------------===//
21
22namespace mlir {
23namespace bufferization {
24
25#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
26
27} // namespace bufferization
28} // namespace mlir
29
30using namespace mlir;
31using namespace bufferization;
32
33//===----------------------------------------------------------------------===//
34// Helpers
35//===----------------------------------------------------------------------===//
36
37static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
38 return builder.create<arith::ConstantOp>(location: loc, args: builder.getBoolAttr(value));
39}
40
41static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); }
42
43//===----------------------------------------------------------------------===//
44// Ownership
45//===----------------------------------------------------------------------===//
46
47Ownership::Ownership(Value indicator)
48 : indicator(indicator), state(State::Unique) {}
49
50Ownership Ownership::getUnknown() {
51 Ownership unknown;
52 unknown.indicator = Value();
53 unknown.state = State::Unknown;
54 return unknown;
55}
56Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
57Ownership Ownership::getUninitialized() { return Ownership(); }
58
59bool Ownership::isUninitialized() const {
60 return state == State::Uninitialized;
61}
62bool Ownership::isUnique() const { return state == State::Unique; }
63bool Ownership::isUnknown() const { return state == State::Unknown; }
64
65Value Ownership::getIndicator() const {
66 assert(isUnique() && "must have unique ownership to get the indicator");
67 return indicator;
68}
69
70Ownership Ownership::getCombined(Ownership other) const {
71 if (other.isUninitialized())
72 return *this;
73 if (isUninitialized())
74 return other;
75
76 if (!isUnique() || !other.isUnique())
77 return getUnknown();
78
79 // Since we create a new constant i1 value for (almost) each use-site, we
80 // should compare the actual value rather than just the SSA Value to avoid
81 // unnecessary invalidations.
82 if (isEqualConstantIntOrValue(ofr1: indicator, ofr2: other.indicator))
83 return *this;
84
85 // Return the join of the lattice if the indicator of both ownerships cannot
86 // be merged.
87 return getUnknown();
88}
89
90void Ownership::combine(Ownership other) { *this = getCombined(other); }
91
92//===----------------------------------------------------------------------===//
93// DeallocationState
94//===----------------------------------------------------------------------===//
95
96DeallocationState::DeallocationState(Operation *op,
97 SymbolTableCollection &symbolTables)
98 : symbolTable(symbolTables), liveness(op) {}
99
100void DeallocationState::updateOwnership(Value memref, Ownership ownership,
101 Block *block) {
102 // In most cases we care about the block where the value is defined.
103 if (block == nullptr)
104 block = memref.getParentBlock();
105
106 // Update ownership of current memref itself.
107 ownershipMap[{memref, block}].combine(other: ownership);
108}
109
110void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) {
111 for (Value val : memrefs)
112 ownershipMap[{val, block}] = Ownership::getUninitialized();
113}
114
115Ownership DeallocationState::getOwnership(Value memref, Block *block) const {
116 return ownershipMap.lookup(Val: {memref, block});
117}
118
119void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) {
120 memrefsToDeallocatePerBlock[block].push_back(Elt: memref);
121}
122
123void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) {
124 llvm::erase(C&: memrefsToDeallocatePerBlock[block], V: memref);
125}
126
127void DeallocationState::getLiveMemrefsIn(Block *block,
128 SmallVectorImpl<Value> &memrefs) {
129 SmallVector<Value> liveMemrefs(
130 llvm::make_filter_range(Range: liveness.getLiveIn(block), Pred: isMemref));
131 llvm::sort(C&: liveMemrefs, Comp: ValueComparator());
132 memrefs.append(RHS: liveMemrefs);
133}
134
135std::pair<Value, Value>
136DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
137 Value memref, Block *block) {
138 auto iter = ownershipMap.find(Val: {memref, block});
139 assert(iter != ownershipMap.end() &&
140 "Value must already have been registered in the ownership map");
141
142 Ownership ownership = iter->second;
143 if (ownership.isUnique())
144 return {memref, ownership.getIndicator()};
145
146 // Instead of inserting a clone operation we could also insert a dealloc
147 // operation earlier in the block and use the updated ownerships returned by
148 // the op for the retained values. Alternatively, we could insert code to
149 // check aliasing at runtime and use this information to combine two unique
150 // ownerships more intelligently to not end up with an 'Unknown' ownership in
151 // the first place.
152 auto cloneOp =
153 builder.create<bufferization::CloneOp>(location: memref.getLoc(), args&: memref);
154 Value condition = buildBoolValue(builder, loc: memref.getLoc(), value: true);
155 Value newMemref = cloneOp.getResult();
156 updateOwnership(memref: newMemref, ownership: condition);
157 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(Elt: newMemref);
158 return {newMemref, condition};
159}
160
161void DeallocationState::getMemrefsToRetain(
162 Block *fromBlock, Block *toBlock, ValueRange destOperands,
163 SmallVectorImpl<Value> &toRetain) const {
164 for (Value operand : destOperands) {
165 if (!isMemref(v: operand))
166 continue;
167 toRetain.push_back(Elt: operand);
168 }
169
170 SmallPtrSet<Value, 16> liveOut;
171 for (auto val : liveness.getLiveOut(block: fromBlock))
172 if (isMemref(v: val))
173 liveOut.insert(Ptr: val);
174
175 if (toBlock)
176 llvm::set_intersect(S1&: liveOut, S2: liveness.getLiveIn(block: toBlock));
177
178 // liveOut has non-deterministic order because it was constructed by iterating
179 // over a hash-set.
180 SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
181 llvm::sort(C&: retainedByLiveness, Comp: ValueComparator());
182 toRetain.append(RHS: retainedByLiveness);
183}
184
185LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
186 OpBuilder &builder, Location loc, Block *block,
187 SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
188
189 for (auto [i, memref] :
190 llvm::enumerate(First: memrefsToDeallocatePerBlock.lookup(Val: block))) {
191 Ownership ownership = ownershipMap.lookup(Val: {memref, block});
192 if (!ownership.isUnique())
193 return emitError(loc: memref.getLoc(),
194 message: "MemRef value does not have valid ownership");
195
196 // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
197 // that we can call extract_strided_metadata on it.
198 if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(Val: memref.getType()))
199 memref = builder.create<memref::ReinterpretCastOp>(
200 location: loc, args&: memref,
201 /*offset=*/args: builder.getIndexAttr(value: 0),
202 /*sizes=*/args: ArrayRef<OpFoldResult>{},
203 /*strides=*/args: ArrayRef<OpFoldResult>{});
204
205 // Use the `memref.extract_strided_metadata` operation to get the base
206 // memref. This is needed because the same MemRef that was produced by the
207 // alloc operation has to be passed to the dealloc operation. Passing
208 // subviews, etc. to a dealloc operation is not allowed.
209 memrefs.push_back(
210 Elt: builder.create<memref::ExtractStridedMetadataOp>(location: loc, args&: memref)
211 .getResult(i: 0));
212 conditions.push_back(Elt: ownership.getIndicator());
213 }
214
215 return success();
216}
217
218//===----------------------------------------------------------------------===//
219// ValueComparator
220//===----------------------------------------------------------------------===//
221
222bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
223 if (lhs == rhs)
224 return false;
225
226 // Block arguments are less than results.
227 bool lhsIsBBArg = isa<BlockArgument>(Val: lhs);
228 if (lhsIsBBArg != isa<BlockArgument>(Val: rhs)) {
229 return lhsIsBBArg;
230 }
231
232 Region *lhsRegion;
233 Region *rhsRegion;
234 if (lhsIsBBArg) {
235 auto lhsBBArg = llvm::cast<BlockArgument>(Val: lhs);
236 auto rhsBBArg = llvm::cast<BlockArgument>(Val: rhs);
237 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
238 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
239 }
240 lhsRegion = lhsBBArg.getParentRegion();
241 rhsRegion = rhsBBArg.getParentRegion();
242 assert(lhsRegion != rhsRegion &&
243 "lhsRegion == rhsRegion implies lhs == rhs");
244 } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
245 return llvm::cast<OpResult>(Val: lhs).getResultNumber() <
246 llvm::cast<OpResult>(Val: rhs).getResultNumber();
247 } else {
248 lhsRegion = lhs.getDefiningOp()->getParentRegion();
249 rhsRegion = rhs.getDefiningOp()->getParentRegion();
250 if (lhsRegion == rhsRegion) {
251 return lhs.getDefiningOp()->isBeforeInBlock(other: rhs.getDefiningOp());
252 }
253 }
254
255 // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
256 // - have different heights
257 // - or there's a spot where their region numbers differ
258 // - or their parent regions are the same and their parent ops are
259 // different.
260 while (lhsRegion && rhsRegion) {
261 if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
262 return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
263 }
264 if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
265 return lhsRegion->getParentOp()->isBeforeInBlock(
266 other: rhsRegion->getParentOp());
267 }
268 lhsRegion = lhsRegion->getParentRegion();
269 rhsRegion = rhsRegion->getParentRegion();
270 }
271 if (rhsRegion)
272 return true;
273 assert(lhsRegion && "this should only happen if lhs == rhs");
274 return false;
275}
276
277//===----------------------------------------------------------------------===//
278// Implementation utilities
279//===----------------------------------------------------------------------===//
280
281FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike(
282 DeallocationState &state, Operation *op, ValueRange operands,
283 SmallVectorImpl<Value> &updatedOperandOwnerships) {
284 assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
285 assert(!op->hasSuccessors() && "must not have any successors");
286 // Collect the values to deallocate and retain and use them to create the
287 // dealloc operation.
288 OpBuilder builder(op);
289 Block *block = op->getBlock();
290 SmallVector<Value> memrefs, conditions, toRetain;
291 if (failed(Result: state.getMemrefsAndConditionsToDeallocate(
292 builder, loc: op->getLoc(), block, memrefs, conditions)))
293 return failure();
294
295 state.getMemrefsToRetain(fromBlock: block, /*toBlock=*/nullptr, destOperands: operands, toRetain);
296 if (memrefs.empty() && toRetain.empty())
297 return op;
298
299 auto deallocOp = builder.create<bufferization::DeallocOp>(
300 location: op->getLoc(), args&: memrefs, args&: conditions, args&: toRetain);
301
302 // We want to replace the current ownership of the retained values with the
303 // result values of the dealloc operation as they are always unique.
304 state.resetOwnerships(memrefs: deallocOp.getRetained(), block);
305 for (auto [retained, ownership] :
306 llvm::zip(t: deallocOp.getRetained(), u: deallocOp.getUpdatedConditions()))
307 state.updateOwnership(memref: retained, ownership, block);
308
309 unsigned numMemrefOperands = llvm::count_if(Range&: operands, P: isMemref);
310 auto newOperandOwnerships =
311 deallocOp.getUpdatedConditions().take_front(n: numMemrefOperands);
312 updatedOperandOwnerships.append(in_start: newOperandOwnerships.begin(),
313 in_end: newOperandOwnerships.end());
314
315 return op;
316}
317

source code of mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp