//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements Mem2Reg-related interfaces for MemRef dialect
// operations.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/IR/BuiltinDialect.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/Value.h"
20#include "mlir/Interfaces/MemorySlotInterfaces.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/ErrorHandling.h"
24
25using namespace mlir;
26
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
30
31/// Walks over the indices of the elements of a tensor of a given `shape` by
32/// updating `index` in place to the next index. This returns failure if the
33/// provided index was the last index.
34static LogicalResult nextIndex(ArrayRef<int64_t> shape,
35 MutableArrayRef<int64_t> index) {
36 for (size_t i = 0; i < shape.size(); ++i) {
37 index[i]++;
38 if (index[i] < shape[i])
39 return success();
40 index[i] = 0;
41 }
42 return failure();
43}
44
45/// Calls `walker` for each index within a tensor of a given `shape`, providing
46/// the index as an array attribute of the coordinates.
47template <typename CallableT>
48static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
49 CallableT &&walker) {
50 Type indexType = IndexType::get(context: ctx);
51 SmallVector<int64_t> shapeIter(shape.size(), 0);
52 do {
53 SmallVector<Attribute> indexAsAttr;
54 for (int64_t dim : shapeIter)
55 indexAsAttr.push_back(Elt: IntegerAttr::get(type: indexType, value: dim));
56 walker(ArrayAttr::get(context: ctx, value: indexAsAttr));
57 } while (succeeded(Result: nextIndex(shape, index: shapeIter)));
58}
59
//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
63
64static bool isSupportedElementType(Type type) {
65 return llvm::isa<MemRefType>(Val: type) ||
66 OpBuilder(type.getContext()).getZeroAttr(type);
67}
68
69SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
70 MemRefType type = getType();
71 if (!isSupportedElementType(type: type.getElementType()))
72 return {};
73 if (!type.hasStaticShape())
74 return {};
75 // Make sure the memref contains only a single element.
76 if (type.getNumElements() != 1)
77 return {};
78
79 return {MemorySlot{.ptr: getResult(), .elemType: type.getElementType()}};
80}
81
82Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
83 OpBuilder &builder) {
84 assert(isSupportedElementType(slot.elemType));
85 // TODO: support more types.
86 return TypeSwitch<Type, Value>(slot.elemType)
87 .Case(caseFn: [&](MemRefType t) {
88 return builder.create<memref::AllocaOp>(location: getLoc(), args&: t);
89 })
90 .Default(defaultFn: [&](Type t) {
91 return builder.create<arith::ConstantOp>(location: getLoc(), args&: t,
92 args: builder.getZeroAttr(type: t));
93 });
94}
95
96std::optional<PromotableAllocationOpInterface>
97memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
98 Value defaultValue,
99 OpBuilder &builder) {
100 if (defaultValue.use_empty())
101 defaultValue.getDefiningOp()->erase();
102 this->erase();
103 return std::nullopt;
104}
105
106void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
107 BlockArgument argument,
108 OpBuilder &builder) {}
109
110SmallVector<DestructurableMemorySlot>
111memref::AllocaOp::getDestructurableSlots() {
112 MemRefType memrefType = getType();
113 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(Val&: memrefType);
114 if (!destructurable)
115 return {};
116
117 std::optional<DenseMap<Attribute, Type>> destructuredType =
118 destructurable.getSubelementIndexMap();
119 if (!destructuredType)
120 return {};
121
122 return {
123 DestructurableMemorySlot{{.ptr: getMemref(), .elemType: memrefType}, .subelementTypes: *destructuredType}};
124}
125
126DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
127 const DestructurableMemorySlot &slot,
128 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
129 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
130 builder.setInsertionPointAfter(*this);
131
132 DenseMap<Attribute, MemorySlot> slotMap;
133
134 auto memrefType = llvm::cast<DestructurableTypeInterface>(Val: getType());
135 for (Attribute usedIndex : usedIndices) {
136 Type elemType = memrefType.getTypeAtIndex(index: usedIndex);
137 MemRefType elemPtr = MemRefType::get(shape: {}, elementType: elemType);
138 auto subAlloca = builder.create<memref::AllocaOp>(location: getLoc(), args&: elemPtr);
139 newAllocators.push_back(Elt: subAlloca);
140 slotMap.try_emplace<MemorySlot>(Key: usedIndex,
141 Args: {.ptr: subAlloca.getResult(), .elemType: elemType});
142 }
143
144 return slotMap;
145}
146
147std::optional<DestructurableAllocationOpInterface>
148memref::AllocaOp::handleDestructuringComplete(
149 const DestructurableMemorySlot &slot, OpBuilder &builder) {
150 assert(slot.ptr == getResult());
151 this->erase();
152 return std::nullopt;
153}
154
//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//
158
159bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
160 return getMemRef() == slot.ptr;
161}
162
163bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
164
165Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
166 Value reachingDef,
167 const DataLayout &dataLayout) {
168 llvm_unreachable("getStored should not be called on LoadOp");
169}
170
171bool memref::LoadOp::canUsesBeRemoved(
172 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
173 SmallVectorImpl<OpOperand *> &newBlockingUses,
174 const DataLayout &dataLayout) {
175 if (blockingUses.size() != 1)
176 return false;
177 Value blockingUse = (*blockingUses.begin())->get();
178 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
179 getResult().getType() == slot.elemType;
180}
181
182DeletionKind memref::LoadOp::removeBlockingUses(
183 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
184 OpBuilder &builder, Value reachingDefinition,
185 const DataLayout &dataLayout) {
186 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
187 // pointer.
188 getResult().replaceAllUsesWith(newValue: reachingDefinition);
189 return DeletionKind::Delete;
190}
191
192/// Returns the index of a memref in attribute form, given its indices. Returns
193/// a null pointer if whether the indices form a valid index for the provided
194/// MemRefType cannot be computed. The indices must come from a valid memref
195/// StoreOp or LoadOp.
196static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
197 ValueRange indices,
198 MemRefType memrefType) {
199 SmallVector<Attribute> index;
200 for (auto [coord, dimSize] : llvm::zip(t&: indices, u: memrefType.getShape())) {
201 IntegerAttr coordAttr;
202 if (!matchPattern(value: coord, pattern: m_Constant<IntegerAttr>(bind_value: &coordAttr)))
203 return {};
204 // MemRefType shape dimensions are always positive (checked by verifier).
205 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
206 if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
207 return {};
208 index.push_back(Elt: coordAttr);
209 }
210 return ArrayAttr::get(context: ctx, value: index);
211}
212
213bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
214 SmallPtrSetImpl<Attribute> &usedIndices,
215 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
216 const DataLayout &dataLayout) {
217 if (slot.ptr != getMemRef())
218 return false;
219 Attribute index = getAttributeIndexFromIndexOperands(
220 ctx: getContext(), indices: getIndices(), memrefType: getMemRefType());
221 if (!index)
222 return false;
223 usedIndices.insert(Ptr: index);
224 return true;
225}
226
227DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
228 DenseMap<Attribute, MemorySlot> &subslots,
229 OpBuilder &builder,
230 const DataLayout &dataLayout) {
231 Attribute index = getAttributeIndexFromIndexOperands(
232 ctx: getContext(), indices: getIndices(), memrefType: getMemRefType());
233 const MemorySlot &memorySlot = subslots.at(Val: index);
234 setMemRef(memorySlot.ptr);
235 getIndicesMutable().clear();
236 return DeletionKind::Keep;
237}
238
239bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
240
241bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242 return getMemRef() == slot.ptr;
243}
244
245Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
246 Value reachingDef,
247 const DataLayout &dataLayout) {
248 return getValue();
249}
250
251bool memref::StoreOp::canUsesBeRemoved(
252 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
253 SmallVectorImpl<OpOperand *> &newBlockingUses,
254 const DataLayout &dataLayout) {
255 if (blockingUses.size() != 1)
256 return false;
257 Value blockingUse = (*blockingUses.begin())->get();
258 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
259 getValue() != slot.ptr && getValue().getType() == slot.elemType;
260}
261
262DeletionKind memref::StoreOp::removeBlockingUses(
263 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
264 OpBuilder &builder, Value reachingDefinition,
265 const DataLayout &dataLayout) {
266 return DeletionKind::Delete;
267}
268
269bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
270 SmallPtrSetImpl<Attribute> &usedIndices,
271 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
272 const DataLayout &dataLayout) {
273 if (slot.ptr != getMemRef() || getValue() == slot.ptr)
274 return false;
275 Attribute index = getAttributeIndexFromIndexOperands(
276 ctx: getContext(), indices: getIndices(), memrefType: getMemRefType());
277 if (!index || !slot.subelementTypes.contains(Val: index))
278 return false;
279 usedIndices.insert(Ptr: index);
280 return true;
281}
282
283DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
284 DenseMap<Attribute, MemorySlot> &subslots,
285 OpBuilder &builder,
286 const DataLayout &dataLayout) {
287 Attribute index = getAttributeIndexFromIndexOperands(
288 ctx: getContext(), indices: getIndices(), memrefType: getMemRefType());
289 const MemorySlot &memorySlot = subslots.at(Val: index);
290 setMemRef(memorySlot.ptr);
291 getIndicesMutable().clear();
292 return DeletionKind::Keep;
293}
294
//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//
298
299namespace {
300
301struct MemRefDestructurableTypeExternalModel
302 : public DestructurableTypeInterface::ExternalModel<
303 MemRefDestructurableTypeExternalModel, MemRefType> {
304 std::optional<DenseMap<Attribute, Type>>
305 getSubelementIndexMap(Type type) const {
306 auto memrefType = llvm::cast<MemRefType>(Val&: type);
307 constexpr int64_t maxMemrefSizeForDestructuring = 16;
308 if (!memrefType.hasStaticShape() ||
309 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
310 memrefType.getNumElements() == 1)
311 return {};
312
313 DenseMap<Attribute, Type> destructured;
314 walkIndicesAsAttr(
315 ctx: memrefType.getContext(), shape: memrefType.getShape(), walker: [&](Attribute index) {
316 destructured.insert(KV: {index, memrefType.getElementType()});
317 });
318
319 return destructured;
320 }
321
322 Type getTypeAtIndex(Type type, Attribute index) const {
323 auto memrefType = llvm::cast<MemRefType>(Val&: type);
324 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(Val&: index);
325 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
326 return {};
327
328 Type indexType = IndexType::get(context: memrefType.getContext());
329 for (const auto &[coordAttr, dimSize] :
330 llvm::zip(t&: coordArrAttr, u: memrefType.getShape())) {
331 auto coord = llvm::dyn_cast<IntegerAttr>(Val: coordAttr);
332 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
333 coord.getInt() >= dimSize)
334 return {};
335 }
336
337 return memrefType.getElementType();
338 }
339};
340
341} // namespace
342
//===----------------------------------------------------------------------===//
// Register external models
//===----------------------------------------------------------------------===//
346
347void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
348 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
349 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(context&: *ctx);
350 });
351}
352

source code of mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp