1//===- BufferDeallocationOpInterface.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
10#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/IR/AsmState.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/Operation.h"
15#include "mlir/IR/TypeUtilities.h"
16#include "mlir/IR/Value.h"
17#include "llvm/ADT/SetOperations.h"
18
19//===----------------------------------------------------------------------===//
20// BufferDeallocationOpInterface
21//===----------------------------------------------------------------------===//
22
23namespace mlir {
24namespace bufferization {
25
26#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc"
27
28} // namespace bufferization
29} // namespace mlir
30
31using namespace mlir;
32using namespace bufferization;
33
34//===----------------------------------------------------------------------===//
35// Helpers
36//===----------------------------------------------------------------------===//
37
38static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
39 return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
40}
41
42static bool isMemref(Value v) { return isa<BaseMemRefType>(Val: v.getType()); }
43
44//===----------------------------------------------------------------------===//
45// Ownership
46//===----------------------------------------------------------------------===//
47
48Ownership::Ownership(Value indicator)
49 : indicator(indicator), state(State::Unique) {}
50
51Ownership Ownership::getUnknown() {
52 Ownership unknown;
53 unknown.indicator = Value();
54 unknown.state = State::Unknown;
55 return unknown;
56}
57Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); }
58Ownership Ownership::getUninitialized() { return Ownership(); }
59
60bool Ownership::isUninitialized() const {
61 return state == State::Uninitialized;
62}
63bool Ownership::isUnique() const { return state == State::Unique; }
64bool Ownership::isUnknown() const { return state == State::Unknown; }
65
66Value Ownership::getIndicator() const {
67 assert(isUnique() && "must have unique ownership to get the indicator");
68 return indicator;
69}
70
71Ownership Ownership::getCombined(Ownership other) const {
72 if (other.isUninitialized())
73 return *this;
74 if (isUninitialized())
75 return other;
76
77 if (!isUnique() || !other.isUnique())
78 return getUnknown();
79
80 // Since we create a new constant i1 value for (almost) each use-site, we
81 // should compare the actual value rather than just the SSA Value to avoid
82 // unnecessary invalidations.
83 if (isEqualConstantIntOrValue(ofr1: indicator, ofr2: other.indicator))
84 return *this;
85
86 // Return the join of the lattice if the indicator of both ownerships cannot
87 // be merged.
88 return getUnknown();
89}
90
91void Ownership::combine(Ownership other) { *this = getCombined(other); }
92
93//===----------------------------------------------------------------------===//
94// DeallocationState
95//===----------------------------------------------------------------------===//
96
97DeallocationState::DeallocationState(Operation *op,
98 SymbolTableCollection &symbolTables)
99 : symbolTable(symbolTables), liveness(op) {}
100
101void DeallocationState::updateOwnership(Value memref, Ownership ownership,
102 Block *block) {
103 // In most cases we care about the block where the value is defined.
104 if (block == nullptr)
105 block = memref.getParentBlock();
106
107 // Update ownership of current memref itself.
108 ownershipMap[{memref, block}].combine(other: ownership);
109}
110
111void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) {
112 for (Value val : memrefs)
113 ownershipMap[{val, block}] = Ownership::getUninitialized();
114}
115
116Ownership DeallocationState::getOwnership(Value memref, Block *block) const {
117 return ownershipMap.lookup(Val: {memref, block});
118}
119
120void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) {
121 memrefsToDeallocatePerBlock[block].push_back(Elt: memref);
122}
123
124void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) {
125 llvm::erase(C&: memrefsToDeallocatePerBlock[block], V: memref);
126}
127
128void DeallocationState::getLiveMemrefsIn(Block *block,
129 SmallVectorImpl<Value> &memrefs) {
130 SmallVector<Value> liveMemrefs(
131 llvm::make_filter_range(Range: liveness.getLiveIn(block), Pred: isMemref));
132 llvm::sort(C&: liveMemrefs, Comp: ValueComparator());
133 memrefs.append(RHS: liveMemrefs);
134}
135
136std::pair<Value, Value>
137DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
138 Value memref, Block *block) {
139 auto iter = ownershipMap.find(Val: {memref, block});
140 assert(iter != ownershipMap.end() &&
141 "Value must already have been registered in the ownership map");
142
143 Ownership ownership = iter->second;
144 if (ownership.isUnique())
145 return {memref, ownership.getIndicator()};
146
147 // Instead of inserting a clone operation we could also insert a dealloc
148 // operation earlier in the block and use the updated ownerships returned by
149 // the op for the retained values. Alternatively, we could insert code to
150 // check aliasing at runtime and use this information to combine two unique
151 // ownerships more intelligently to not end up with an 'Unknown' ownership in
152 // the first place.
153 auto cloneOp =
154 builder.create<bufferization::CloneOp>(memref.getLoc(), memref);
155 Value condition = buildBoolValue(builder, loc: memref.getLoc(), value: true);
156 Value newMemref = cloneOp.getResult();
157 updateOwnership(memref: newMemref, ownership: condition);
158 memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(Elt: newMemref);
159 return {newMemref, condition};
160}
161
162void DeallocationState::getMemrefsToRetain(
163 Block *fromBlock, Block *toBlock, ValueRange destOperands,
164 SmallVectorImpl<Value> &toRetain) const {
165 for (Value operand : destOperands) {
166 if (!isMemref(v: operand))
167 continue;
168 toRetain.push_back(Elt: operand);
169 }
170
171 SmallPtrSet<Value, 16> liveOut;
172 for (auto val : liveness.getLiveOut(block: fromBlock))
173 if (isMemref(v: val))
174 liveOut.insert(Ptr: val);
175
176 if (toBlock)
177 llvm::set_intersect(S1&: liveOut, S2: liveness.getLiveIn(block: toBlock));
178
179 // liveOut has non-deterministic order because it was constructed by iterating
180 // over a hash-set.
181 SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end());
182 llvm::sort(C&: retainedByLiveness, Comp: ValueComparator());
183 toRetain.append(RHS: retainedByLiveness);
184}
185
186LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
187 OpBuilder &builder, Location loc, Block *block,
188 SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const {
189
190 for (auto [i, memref] :
191 llvm::enumerate(First: memrefsToDeallocatePerBlock.lookup(Val: block))) {
192 Ownership ownership = ownershipMap.lookup(Val: {memref, block});
193 if (!ownership.isUnique())
194 return emitError(loc: memref.getLoc(),
195 message: "MemRef value does not have valid ownership");
196
197 // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such
198 // that we can call extract_strided_metadata on it.
199 if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
200 memref = builder.create<memref::ReinterpretCastOp>(
201 loc, memref,
202 /*offset=*/builder.getIndexAttr(0),
203 /*sizes=*/ArrayRef<OpFoldResult>{},
204 /*strides=*/ArrayRef<OpFoldResult>{});
205
206 // Use the `memref.extract_strided_metadata` operation to get the base
207 // memref. This is needed because the same MemRef that was produced by the
208 // alloc operation has to be passed to the dealloc operation. Passing
209 // subviews, etc. to a dealloc operation is not allowed.
210 memrefs.push_back(
211 builder.create<memref::ExtractStridedMetadataOp>(loc, memref)
212 .getResult(0));
213 conditions.push_back(Elt: ownership.getIndicator());
214 }
215
216 return success();
217}
218
219//===----------------------------------------------------------------------===//
220// ValueComparator
221//===----------------------------------------------------------------------===//
222
223bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
224 if (lhs == rhs)
225 return false;
226
227 // Block arguments are less than results.
228 bool lhsIsBBArg = isa<BlockArgument>(Val: lhs);
229 if (lhsIsBBArg != isa<BlockArgument>(Val: rhs)) {
230 return lhsIsBBArg;
231 }
232
233 Region *lhsRegion;
234 Region *rhsRegion;
235 if (lhsIsBBArg) {
236 auto lhsBBArg = llvm::cast<BlockArgument>(Val: lhs);
237 auto rhsBBArg = llvm::cast<BlockArgument>(Val: rhs);
238 if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) {
239 return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber();
240 }
241 lhsRegion = lhsBBArg.getParentRegion();
242 rhsRegion = rhsBBArg.getParentRegion();
243 assert(lhsRegion != rhsRegion &&
244 "lhsRegion == rhsRegion implies lhs == rhs");
245 } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) {
246 return llvm::cast<OpResult>(Val: lhs).getResultNumber() <
247 llvm::cast<OpResult>(Val: rhs).getResultNumber();
248 } else {
249 lhsRegion = lhs.getDefiningOp()->getParentRegion();
250 rhsRegion = rhs.getDefiningOp()->getParentRegion();
251 if (lhsRegion == rhsRegion) {
252 return lhs.getDefiningOp()->isBeforeInBlock(other: rhs.getDefiningOp());
253 }
254 }
255
256 // lhsRegion != rhsRegion, so if we look at their ancestor chain, they
257 // - have different heights
258 // - or there's a spot where their region numbers differ
259 // - or their parent regions are the same and their parent ops are
260 // different.
261 while (lhsRegion && rhsRegion) {
262 if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) {
263 return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
264 }
265 if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
266 return lhsRegion->getParentOp()->isBeforeInBlock(
267 other: rhsRegion->getParentOp());
268 }
269 lhsRegion = lhsRegion->getParentRegion();
270 rhsRegion = rhsRegion->getParentRegion();
271 }
272 if (rhsRegion)
273 return true;
274 assert(lhsRegion && "this should only happen if lhs == rhs");
275 return false;
276}
277
278//===----------------------------------------------------------------------===//
279// Implementation utilities
280//===----------------------------------------------------------------------===//
281
282FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike(
283 DeallocationState &state, Operation *op, ValueRange operands,
284 SmallVectorImpl<Value> &updatedOperandOwnerships) {
285 assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator");
286 assert(!op->hasSuccessors() && "must not have any successors");
287 // Collect the values to deallocate and retain and use them to create the
288 // dealloc operation.
289 OpBuilder builder(op);
290 Block *block = op->getBlock();
291 SmallVector<Value> memrefs, conditions, toRetain;
292 if (failed(Result: state.getMemrefsAndConditionsToDeallocate(
293 builder, loc: op->getLoc(), block, memrefs, conditions)))
294 return failure();
295
296 state.getMemrefsToRetain(fromBlock: block, /*toBlock=*/nullptr, destOperands: operands, toRetain);
297 if (memrefs.empty() && toRetain.empty())
298 return op;
299
300 auto deallocOp = builder.create<bufferization::DeallocOp>(
301 op->getLoc(), memrefs, conditions, toRetain);
302
303 // We want to replace the current ownership of the retained values with the
304 // result values of the dealloc operation as they are always unique.
305 state.resetOwnerships(memrefs: deallocOp.getRetained(), block);
306 for (auto [retained, ownership] :
307 llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
308 state.updateOwnership(retained, ownership, block);
309
310 unsigned numMemrefOperands = llvm::count_if(Range&: operands, P: isMemref);
311 auto newOperandOwnerships =
312 deallocOp.getUpdatedConditions().take_front(numMemrefOperands);
313 updatedOperandOwnerships.append(newOperandOwnerships.begin(),
314 newOperandOwnerships.end());
315
316 return op;
317}
318

// source code of mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp