1 | //===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements Mem2Reg-related interfaces for MemRef dialect |
10 | // operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/IR/BuiltinDialect.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Matchers.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/IR/Value.h" |
21 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
22 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
23 | #include "mlir/Support/LogicalResult.h" |
24 | #include "llvm/ADT/ArrayRef.h" |
25 | #include "llvm/ADT/TypeSwitch.h" |
26 | #include "llvm/Support/ErrorHandling.h" |
27 | |
28 | using namespace mlir; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // Utilities |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Walks over the indices of the elements of a tensor of a given `shape` by |
35 | /// updating `index` in place to the next index. This returns failure if the |
36 | /// provided index was the last index. |
37 | static LogicalResult nextIndex(ArrayRef<int64_t> shape, |
38 | MutableArrayRef<int64_t> index) { |
39 | for (size_t i = 0; i < shape.size(); ++i) { |
40 | index[i]++; |
41 | if (index[i] < shape[i]) |
42 | return success(); |
43 | index[i] = 0; |
44 | } |
45 | return failure(); |
46 | } |
47 | |
48 | /// Calls `walker` for each index within a tensor of a given `shape`, providing |
49 | /// the index as an array attribute of the coordinates. |
50 | template <typename CallableT> |
51 | static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, |
52 | CallableT &&walker) { |
53 | Type indexType = IndexType::get(ctx); |
54 | SmallVector<int64_t> shapeIter(shape.size(), 0); |
55 | do { |
56 | SmallVector<Attribute> indexAsAttr; |
57 | for (int64_t dim : shapeIter) |
58 | indexAsAttr.push_back(IntegerAttr::get(indexType, dim)); |
59 | walker(ArrayAttr::get(ctx, indexAsAttr)); |
60 | } while (succeeded(result: nextIndex(shape, index: shapeIter))); |
61 | } |
62 | |
63 | //===----------------------------------------------------------------------===// |
64 | // Interfaces for AllocaOp |
65 | //===----------------------------------------------------------------------===// |
66 | |
67 | static bool isSupportedElementType(Type type) { |
68 | return llvm::isa<MemRefType>(Val: type) || |
69 | OpBuilder(type.getContext()).getZeroAttr(type); |
70 | } |
71 | |
72 | SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { |
73 | MemRefType type = getType(); |
74 | if (!isSupportedElementType(type.getElementType())) |
75 | return {}; |
76 | if (!type.hasStaticShape()) |
77 | return {}; |
78 | // Make sure the memref contains only a single element. |
79 | if (type.getNumElements() != 1) |
80 | return {}; |
81 | |
82 | return {MemorySlot{getResult(), type.getElementType()}}; |
83 | } |
84 | |
85 | Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, |
86 | RewriterBase &rewriter) { |
87 | assert(isSupportedElementType(slot.elemType)); |
88 | // TODO: support more types. |
89 | return TypeSwitch<Type, Value>(slot.elemType) |
90 | .Case([&](MemRefType t) { |
91 | return rewriter.create<memref::AllocaOp>(getLoc(), t); |
92 | }) |
93 | .Default([&](Type t) { |
94 | return rewriter.create<arith::ConstantOp>(getLoc(), t, |
95 | rewriter.getZeroAttr(t)); |
96 | }); |
97 | } |
98 | |
99 | void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot, |
100 | Value defaultValue, |
101 | RewriterBase &rewriter) { |
102 | if (defaultValue.use_empty()) |
103 | rewriter.eraseOp(defaultValue.getDefiningOp()); |
104 | rewriter.eraseOp(*this); |
105 | } |
106 | |
107 | void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot, |
108 | BlockArgument argument, |
109 | RewriterBase &rewriter) {} |
110 | |
111 | SmallVector<DestructurableMemorySlot> |
112 | memref::AllocaOp::getDestructurableSlots() { |
113 | MemRefType memrefType = getType(); |
114 | auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType); |
115 | if (!destructurable) |
116 | return {}; |
117 | |
118 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
119 | destructurable.getSubelementIndexMap(); |
120 | if (!destructuredType) |
121 | return {}; |
122 | |
123 | return { |
124 | DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}}; |
125 | } |
126 | |
127 | DenseMap<Attribute, MemorySlot> |
128 | memref::AllocaOp::destructure(const DestructurableMemorySlot &slot, |
129 | const SmallPtrSetImpl<Attribute> &usedIndices, |
130 | RewriterBase &rewriter) { |
131 | rewriter.setInsertionPointAfter(*this); |
132 | |
133 | DenseMap<Attribute, MemorySlot> slotMap; |
134 | |
135 | auto memrefType = llvm::cast<DestructurableTypeInterface>(getType()); |
136 | for (Attribute usedIndex : usedIndices) { |
137 | Type elemType = memrefType.getTypeAtIndex(usedIndex); |
138 | MemRefType elemPtr = MemRefType::get({}, elemType); |
139 | auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr); |
140 | slotMap.try_emplace<MemorySlot>(usedIndex, |
141 | {subAlloca.getResult(), elemType}); |
142 | } |
143 | |
144 | return slotMap; |
145 | } |
146 | |
147 | void memref::AllocaOp::handleDestructuringComplete( |
148 | const DestructurableMemorySlot &slot, RewriterBase &rewriter) { |
149 | assert(slot.ptr == getResult()); |
150 | rewriter.eraseOp(*this); |
151 | } |
152 | |
153 | //===----------------------------------------------------------------------===// |
154 | // Interfaces for LoadOp/StoreOp |
155 | //===----------------------------------------------------------------------===// |
156 | |
157 | bool memref::LoadOp::loadsFrom(const MemorySlot &slot) { |
158 | return getMemRef() == slot.ptr; |
159 | } |
160 | |
161 | bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; } |
162 | |
163 | Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter, |
164 | Value reachingDef, |
165 | const DataLayout &dataLayout) { |
166 | llvm_unreachable("getStored should not be called on LoadOp" ); |
167 | } |
168 | |
169 | bool memref::LoadOp::canUsesBeRemoved( |
170 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
171 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
172 | const DataLayout &dataLayout) { |
173 | if (blockingUses.size() != 1) |
174 | return false; |
175 | Value blockingUse = (*blockingUses.begin())->get(); |
176 | return blockingUse == slot.ptr && getMemRef() == slot.ptr && |
177 | getResult().getType() == slot.elemType; |
178 | } |
179 | |
180 | DeletionKind memref::LoadOp::removeBlockingUses( |
181 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
182 | RewriterBase &rewriter, Value reachingDefinition, |
183 | const DataLayout &dataLayout) { |
184 | // `canUsesBeRemoved` checked this blocking use must be the loaded slot |
185 | // pointer. |
186 | rewriter.replaceAllUsesWith(getResult(), reachingDefinition); |
187 | return DeletionKind::Delete; |
188 | } |
189 | |
190 | /// Returns the index of a memref in attribute form, given its indices. Returns |
191 | /// a null pointer if whether the indices form a valid index for the provided |
192 | /// MemRefType cannot be computed. The indices must come from a valid memref |
193 | /// StoreOp or LoadOp. |
194 | static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx, |
195 | ValueRange indices, |
196 | MemRefType memrefType) { |
197 | SmallVector<Attribute> index; |
198 | for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) { |
199 | IntegerAttr coordAttr; |
200 | if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr))) |
201 | return {}; |
202 | // MemRefType shape dimensions are always positive (checked by verifier). |
203 | std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue(); |
204 | if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize)) |
205 | return {}; |
206 | index.push_back(coordAttr); |
207 | } |
208 | return ArrayAttr::get(ctx, index); |
209 | } |
210 | |
211 | bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot, |
212 | SmallPtrSetImpl<Attribute> &usedIndices, |
213 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
214 | const DataLayout &dataLayout) { |
215 | if (slot.ptr != getMemRef()) |
216 | return false; |
217 | Attribute index = getAttributeIndexFromIndexOperands( |
218 | getContext(), getIndices(), getMemRefType()); |
219 | if (!index) |
220 | return false; |
221 | usedIndices.insert(index); |
222 | return true; |
223 | } |
224 | |
225 | DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot, |
226 | DenseMap<Attribute, MemorySlot> &subslots, |
227 | RewriterBase &rewriter, |
228 | const DataLayout &dataLayout) { |
229 | Attribute index = getAttributeIndexFromIndexOperands( |
230 | getContext(), getIndices(), getMemRefType()); |
231 | const MemorySlot &memorySlot = subslots.at(index); |
232 | rewriter.modifyOpInPlace(*this, [&]() { |
233 | setMemRef(memorySlot.ptr); |
234 | getIndicesMutable().clear(); |
235 | }); |
236 | return DeletionKind::Keep; |
237 | } |
238 | |
239 | bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; } |
240 | |
241 | bool memref::StoreOp::storesTo(const MemorySlot &slot) { |
242 | return getMemRef() == slot.ptr; |
243 | } |
244 | |
245 | Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter, |
246 | Value reachingDef, |
247 | const DataLayout &dataLayout) { |
248 | return getValue(); |
249 | } |
250 | |
251 | bool memref::StoreOp::canUsesBeRemoved( |
252 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
253 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
254 | const DataLayout &dataLayout) { |
255 | if (blockingUses.size() != 1) |
256 | return false; |
257 | Value blockingUse = (*blockingUses.begin())->get(); |
258 | return blockingUse == slot.ptr && getMemRef() == slot.ptr && |
259 | getValue() != slot.ptr && getValue().getType() == slot.elemType; |
260 | } |
261 | |
262 | DeletionKind memref::StoreOp::removeBlockingUses( |
263 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
264 | RewriterBase &rewriter, Value reachingDefinition, |
265 | const DataLayout &dataLayout) { |
266 | return DeletionKind::Delete; |
267 | } |
268 | |
269 | bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot, |
270 | SmallPtrSetImpl<Attribute> &usedIndices, |
271 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
272 | const DataLayout &dataLayout) { |
273 | if (slot.ptr != getMemRef() || getValue() == slot.ptr) |
274 | return false; |
275 | Attribute index = getAttributeIndexFromIndexOperands( |
276 | getContext(), getIndices(), getMemRefType()); |
277 | if (!index || !slot.elementPtrs.contains(index)) |
278 | return false; |
279 | usedIndices.insert(index); |
280 | return true; |
281 | } |
282 | |
283 | DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot, |
284 | DenseMap<Attribute, MemorySlot> &subslots, |
285 | RewriterBase &rewriter, |
286 | const DataLayout &dataLayout) { |
287 | Attribute index = getAttributeIndexFromIndexOperands( |
288 | getContext(), getIndices(), getMemRefType()); |
289 | const MemorySlot &memorySlot = subslots.at(index); |
290 | rewriter.modifyOpInPlace(*this, [&]() { |
291 | setMemRef(memorySlot.ptr); |
292 | getIndicesMutable().clear(); |
293 | }); |
294 | return DeletionKind::Keep; |
295 | } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // Interfaces for destructurable types |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | namespace { |
302 | |
303 | struct MemRefDestructurableTypeExternalModel |
304 | : public DestructurableTypeInterface::ExternalModel< |
305 | MemRefDestructurableTypeExternalModel, MemRefType> { |
306 | std::optional<DenseMap<Attribute, Type>> |
307 | getSubelementIndexMap(Type type) const { |
308 | auto memrefType = llvm::cast<MemRefType>(type); |
309 | constexpr int64_t maxMemrefSizeForDestructuring = 16; |
310 | if (!memrefType.hasStaticShape() || |
311 | memrefType.getNumElements() > maxMemrefSizeForDestructuring || |
312 | memrefType.getNumElements() == 1) |
313 | return {}; |
314 | |
315 | DenseMap<Attribute, Type> destructured; |
316 | walkIndicesAsAttr( |
317 | memrefType.getContext(), memrefType.getShape(), [&](Attribute index) { |
318 | destructured.insert({index, memrefType.getElementType()}); |
319 | }); |
320 | |
321 | return destructured; |
322 | } |
323 | |
324 | Type getTypeAtIndex(Type type, Attribute index) const { |
325 | auto memrefType = llvm::cast<MemRefType>(type); |
326 | auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index); |
327 | if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size()) |
328 | return {}; |
329 | |
330 | Type indexType = IndexType::get(memrefType.getContext()); |
331 | for (const auto &[coordAttr, dimSize] : |
332 | llvm::zip(coordArrAttr, memrefType.getShape())) { |
333 | auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr); |
334 | if (!coord || coord.getType() != indexType || coord.getInt() < 0 || |
335 | coord.getInt() >= dimSize) |
336 | return {}; |
337 | } |
338 | |
339 | return memrefType.getElementType(); |
340 | } |
341 | }; |
342 | |
343 | } // namespace |
344 | |
345 | //===----------------------------------------------------------------------===// |
346 | // Register external models |
347 | //===----------------------------------------------------------------------===// |
348 | |
349 | void mlir::memref::registerMemorySlotExternalModels(DialectRegistry ®istry) { |
350 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
351 | MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx); |
352 | }); |
353 | } |
354 | |