1 | //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements Mem2Reg-related interfaces for MemRef dialect |
10 | // operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/IR/BuiltinDialect.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Matchers.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/IR/Value.h" |
21 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
22 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
23 | #include "llvm/ADT/ArrayRef.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include "llvm/Support/ErrorHandling.h" |
26 | |
27 | using namespace mlir; |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // Utilities |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | /// Walks over the indices of the elements of a tensor of a given `shape` by |
34 | /// updating `index` in place to the next index. This returns failure if the |
35 | /// provided index was the last index. |
36 | static LogicalResult nextIndex(ArrayRef<int64_t> shape, |
37 | MutableArrayRef<int64_t> index) { |
38 | for (size_t i = 0; i < shape.size(); ++i) { |
39 | index[i]++; |
40 | if (index[i] < shape[i]) |
41 | return success(); |
42 | index[i] = 0; |
43 | } |
44 | return failure(); |
45 | } |
46 | |
47 | /// Calls `walker` for each index within a tensor of a given `shape`, providing |
48 | /// the index as an array attribute of the coordinates. |
49 | template <typename CallableT> |
50 | static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, |
51 | CallableT &&walker) { |
52 | Type indexType = IndexType::get(ctx); |
53 | SmallVector<int64_t> shapeIter(shape.size(), 0); |
54 | do { |
55 | SmallVector<Attribute> indexAsAttr; |
56 | for (int64_t dim : shapeIter) |
57 | indexAsAttr.push_back(IntegerAttr::get(indexType, dim)); |
58 | walker(ArrayAttr::get(ctx, indexAsAttr)); |
59 | } while (succeeded(Result: nextIndex(shape, index: shapeIter))); |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // Interfaces for AllocaOp |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | static bool isSupportedElementType(Type type) { |
67 | return llvm::isa<MemRefType>(Val: type) || |
68 | OpBuilder(type.getContext()).getZeroAttr(type); |
69 | } |
70 | |
71 | SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { |
72 | MemRefType type = getType(); |
73 | if (!isSupportedElementType(type.getElementType())) |
74 | return {}; |
75 | if (!type.hasStaticShape()) |
76 | return {}; |
77 | // Make sure the memref contains only a single element. |
78 | if (type.getNumElements() != 1) |
79 | return {}; |
80 | |
81 | return {MemorySlot{getResult(), type.getElementType()}}; |
82 | } |
83 | |
84 | Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, |
85 | OpBuilder &builder) { |
86 | assert(isSupportedElementType(slot.elemType)); |
87 | // TODO: support more types. |
88 | return TypeSwitch<Type, Value>(slot.elemType) |
89 | .Case([&](MemRefType t) { |
90 | return builder.create<memref::AllocaOp>(getLoc(), t); |
91 | }) |
92 | .Default([&](Type t) { |
93 | return builder.create<arith::ConstantOp>(getLoc(), t, |
94 | builder.getZeroAttr(t)); |
95 | }); |
96 | } |
97 | |
98 | std::optional<PromotableAllocationOpInterface> |
99 | memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot, |
100 | Value defaultValue, |
101 | OpBuilder &builder) { |
102 | if (defaultValue.use_empty()) |
103 | defaultValue.getDefiningOp()->erase(); |
104 | this->erase(); |
105 | return std::nullopt; |
106 | } |
107 | |
108 | void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot, |
109 | BlockArgument argument, |
110 | OpBuilder &builder) {} |
111 | |
112 | SmallVector<DestructurableMemorySlot> |
113 | memref::AllocaOp::getDestructurableSlots() { |
114 | MemRefType memrefType = getType(); |
115 | auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType); |
116 | if (!destructurable) |
117 | return {}; |
118 | |
119 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
120 | destructurable.getSubelementIndexMap(); |
121 | if (!destructuredType) |
122 | return {}; |
123 | |
124 | return { |
125 | DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}}; |
126 | } |
127 | |
128 | DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure( |
129 | const DestructurableMemorySlot &slot, |
130 | const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, |
131 | SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { |
132 | builder.setInsertionPointAfter(*this); |
133 | |
134 | DenseMap<Attribute, MemorySlot> slotMap; |
135 | |
136 | auto memrefType = llvm::cast<DestructurableTypeInterface>(getType()); |
137 | for (Attribute usedIndex : usedIndices) { |
138 | Type elemType = memrefType.getTypeAtIndex(usedIndex); |
139 | MemRefType elemPtr = MemRefType::get({}, elemType); |
140 | auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr); |
141 | newAllocators.push_back(subAlloca); |
142 | slotMap.try_emplace<MemorySlot>(usedIndex, |
143 | {subAlloca.getResult(), elemType}); |
144 | } |
145 | |
146 | return slotMap; |
147 | } |
148 | |
149 | std::optional<DestructurableAllocationOpInterface> |
150 | memref::AllocaOp::handleDestructuringComplete( |
151 | const DestructurableMemorySlot &slot, OpBuilder &builder) { |
152 | assert(slot.ptr == getResult()); |
153 | this->erase(); |
154 | return std::nullopt; |
155 | } |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // Interfaces for LoadOp/StoreOp |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | bool memref::LoadOp::loadsFrom(const MemorySlot &slot) { |
162 | return getMemRef() == slot.ptr; |
163 | } |
164 | |
165 | bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; } |
166 | |
167 | Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder, |
168 | Value reachingDef, |
169 | const DataLayout &dataLayout) { |
170 | llvm_unreachable("getStored should not be called on LoadOp" ); |
171 | } |
172 | |
173 | bool memref::LoadOp::canUsesBeRemoved( |
174 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
175 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
176 | const DataLayout &dataLayout) { |
177 | if (blockingUses.size() != 1) |
178 | return false; |
179 | Value blockingUse = (*blockingUses.begin())->get(); |
180 | return blockingUse == slot.ptr && getMemRef() == slot.ptr && |
181 | getResult().getType() == slot.elemType; |
182 | } |
183 | |
184 | DeletionKind memref::LoadOp::removeBlockingUses( |
185 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
186 | OpBuilder &builder, Value reachingDefinition, |
187 | const DataLayout &dataLayout) { |
188 | // `canUsesBeRemoved` checked this blocking use must be the loaded slot |
189 | // pointer. |
190 | getResult().replaceAllUsesWith(reachingDefinition); |
191 | return DeletionKind::Delete; |
192 | } |
193 | |
194 | /// Returns the index of a memref in attribute form, given its indices. Returns |
195 | /// a null pointer if whether the indices form a valid index for the provided |
196 | /// MemRefType cannot be computed. The indices must come from a valid memref |
197 | /// StoreOp or LoadOp. |
198 | static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, |
199 | ValueRange indices, |
200 | MemRefType memrefType) { |
201 | SmallVector<Attribute> index; |
202 | for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) { |
203 | IntegerAttr coordAttr; |
204 | if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr))) |
205 | return {}; |
206 | // MemRefType shape dimensions are always positive (checked by verifier). |
207 | std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue(); |
208 | if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize)) |
209 | return {}; |
210 | index.push_back(coordAttr); |
211 | } |
212 | return ArrayAttr::get(ctx, index); |
213 | } |
214 | |
215 | bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot, |
216 | SmallPtrSetImpl<Attribute> &usedIndices, |
217 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
218 | const DataLayout &dataLayout) { |
219 | if (slot.ptr != getMemRef()) |
220 | return false; |
221 | Attribute index = getAttributeIndexFromIndexOperands( |
222 | getContext(), getIndices(), getMemRefType()); |
223 | if (!index) |
224 | return false; |
225 | usedIndices.insert(index); |
226 | return true; |
227 | } |
228 | |
229 | DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot, |
230 | DenseMap<Attribute, MemorySlot> &subslots, |
231 | OpBuilder &builder, |
232 | const DataLayout &dataLayout) { |
233 | Attribute index = getAttributeIndexFromIndexOperands( |
234 | getContext(), getIndices(), getMemRefType()); |
235 | const MemorySlot &memorySlot = subslots.at(index); |
236 | setMemRef(memorySlot.ptr); |
237 | getIndicesMutable().clear(); |
238 | return DeletionKind::Keep; |
239 | } |
240 | |
241 | bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; } |
242 | |
243 | bool memref::StoreOp::storesTo(const MemorySlot &slot) { |
244 | return getMemRef() == slot.ptr; |
245 | } |
246 | |
247 | Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder, |
248 | Value reachingDef, |
249 | const DataLayout &dataLayout) { |
250 | return getValue(); |
251 | } |
252 | |
253 | bool memref::StoreOp::canUsesBeRemoved( |
254 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
255 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
256 | const DataLayout &dataLayout) { |
257 | if (blockingUses.size() != 1) |
258 | return false; |
259 | Value blockingUse = (*blockingUses.begin())->get(); |
260 | return blockingUse == slot.ptr && getMemRef() == slot.ptr && |
261 | getValue() != slot.ptr && getValue().getType() == slot.elemType; |
262 | } |
263 | |
264 | DeletionKind memref::StoreOp::removeBlockingUses( |
265 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
266 | OpBuilder &builder, Value reachingDefinition, |
267 | const DataLayout &dataLayout) { |
268 | return DeletionKind::Delete; |
269 | } |
270 | |
271 | bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot, |
272 | SmallPtrSetImpl<Attribute> &usedIndices, |
273 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
274 | const DataLayout &dataLayout) { |
275 | if (slot.ptr != getMemRef() || getValue() == slot.ptr) |
276 | return false; |
277 | Attribute index = getAttributeIndexFromIndexOperands( |
278 | getContext(), getIndices(), getMemRefType()); |
279 | if (!index || !slot.subelementTypes.contains(index)) |
280 | return false; |
281 | usedIndices.insert(index); |
282 | return true; |
283 | } |
284 | |
285 | DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot, |
286 | DenseMap<Attribute, MemorySlot> &subslots, |
287 | OpBuilder &builder, |
288 | const DataLayout &dataLayout) { |
289 | Attribute index = getAttributeIndexFromIndexOperands( |
290 | getContext(), getIndices(), getMemRefType()); |
291 | const MemorySlot &memorySlot = subslots.at(index); |
292 | setMemRef(memorySlot.ptr); |
293 | getIndicesMutable().clear(); |
294 | return DeletionKind::Keep; |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // Interfaces for destructurable types |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | namespace { |
302 | |
303 | struct MemRefDestructurableTypeExternalModel |
304 | : public DestructurableTypeInterface::ExternalModel< |
305 | MemRefDestructurableTypeExternalModel, MemRefType> { |
306 | std::optional<DenseMap<Attribute, Type>> |
307 | getSubelementIndexMap(Type type) const { |
308 | auto memrefType = llvm::cast<MemRefType>(type); |
309 | constexpr int64_t maxMemrefSizeForDestructuring = 16; |
310 | if (!memrefType.hasStaticShape() || |
311 | memrefType.getNumElements() > maxMemrefSizeForDestructuring || |
312 | memrefType.getNumElements() == 1) |
313 | return {}; |
314 | |
315 | DenseMap<Attribute, Type> destructured; |
316 | walkIndicesAsAttr( |
317 | memrefType.getContext(), memrefType.getShape(), [&](Attribute index) { |
318 | destructured.insert({index, memrefType.getElementType()}); |
319 | }); |
320 | |
321 | return destructured; |
322 | } |
323 | |
324 | Type getTypeAtIndex(Type type, Attribute index) const { |
325 | auto memrefType = llvm::cast<MemRefType>(type); |
326 | auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index); |
327 | if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size()) |
328 | return {}; |
329 | |
330 | Type indexType = IndexType::get(memrefType.getContext()); |
331 | for (const auto &[coordAttr, dimSize] : |
332 | llvm::zip(coordArrAttr, memrefType.getShape())) { |
333 | auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr); |
334 | if (!coord || coord.getType() != indexType || coord.getInt() < 0 || |
335 | coord.getInt() >= dimSize) |
336 | return {}; |
337 | } |
338 | |
339 | return memrefType.getElementType(); |
340 | } |
341 | }; |
342 | |
343 | } // namespace |
344 | |
345 | //===----------------------------------------------------------------------===// |
346 | // Register external models |
347 | //===----------------------------------------------------------------------===// |
348 | |
349 | void mlir::memref::registerMemorySlotExternalModels(DialectRegistry ®istry) { |
350 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
351 | MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx); |
352 | }); |
353 | } |
354 | |