//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Mem2Reg- and SROA-related (promotion and
// destructuring) memory slot interfaces for MemRef dialect operations.
//
//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
29
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
33
34/// Walks over the indices of the elements of a tensor of a given `shape` by
35/// updating `index` in place to the next index. This returns failure if the
36/// provided index was the last index.
37static LogicalResult nextIndex(ArrayRef<int64_t> shape,
38 MutableArrayRef<int64_t> index) {
39 for (size_t i = 0; i < shape.size(); ++i) {
40 index[i]++;
41 if (index[i] < shape[i])
42 return success();
43 index[i] = 0;
44 }
45 return failure();
46}
47
48/// Calls `walker` for each index within a tensor of a given `shape`, providing
49/// the index as an array attribute of the coordinates.
50template <typename CallableT>
51static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
52 CallableT &&walker) {
53 Type indexType = IndexType::get(ctx);
54 SmallVector<int64_t> shapeIter(shape.size(), 0);
55 do {
56 SmallVector<Attribute> indexAsAttr;
57 for (int64_t dim : shapeIter)
58 indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
59 walker(ArrayAttr::get(ctx, indexAsAttr));
60 } while (succeeded(result: nextIndex(shape, index: shapeIter)));
61}
62
//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
66
67static bool isSupportedElementType(Type type) {
68 return llvm::isa<MemRefType>(Val: type) ||
69 OpBuilder(type.getContext()).getZeroAttr(type);
70}
71
72SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
73 MemRefType type = getType();
74 if (!isSupportedElementType(type.getElementType()))
75 return {};
76 if (!type.hasStaticShape())
77 return {};
78 // Make sure the memref contains only a single element.
79 if (type.getNumElements() != 1)
80 return {};
81
82 return {MemorySlot{getResult(), type.getElementType()}};
83}
84
85Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
86 RewriterBase &rewriter) {
87 assert(isSupportedElementType(slot.elemType));
88 // TODO: support more types.
89 return TypeSwitch<Type, Value>(slot.elemType)
90 .Case([&](MemRefType t) {
91 return rewriter.create<memref::AllocaOp>(getLoc(), t);
92 })
93 .Default([&](Type t) {
94 return rewriter.create<arith::ConstantOp>(getLoc(), t,
95 rewriter.getZeroAttr(t));
96 });
97}
98
99void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100 Value defaultValue,
101 RewriterBase &rewriter) {
102 if (defaultValue.use_empty())
103 rewriter.eraseOp(defaultValue.getDefiningOp());
104 rewriter.eraseOp(*this);
105}
106
107void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
108 BlockArgument argument,
109 RewriterBase &rewriter) {}
110
111SmallVector<DestructurableMemorySlot>
112memref::AllocaOp::getDestructurableSlots() {
113 MemRefType memrefType = getType();
114 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
115 if (!destructurable)
116 return {};
117
118 std::optional<DenseMap<Attribute, Type>> destructuredType =
119 destructurable.getSubelementIndexMap();
120 if (!destructuredType)
121 return {};
122
123 return {
124 DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
125}
126
127DenseMap<Attribute, MemorySlot>
128memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
129 const SmallPtrSetImpl<Attribute> &usedIndices,
130 RewriterBase &rewriter) {
131 rewriter.setInsertionPointAfter(*this);
132
133 DenseMap<Attribute, MemorySlot> slotMap;
134
135 auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
136 for (Attribute usedIndex : usedIndices) {
137 Type elemType = memrefType.getTypeAtIndex(usedIndex);
138 MemRefType elemPtr = MemRefType::get({}, elemType);
139 auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
140 slotMap.try_emplace<MemorySlot>(usedIndex,
141 {subAlloca.getResult(), elemType});
142 }
143
144 return slotMap;
145}
146
147void memref::AllocaOp::handleDestructuringComplete(
148 const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
149 assert(slot.ptr == getResult());
150 rewriter.eraseOp(*this);
151}
152
//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//
156
157bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
158 return getMemRef() == slot.ptr;
159}
160
161bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162
163Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
164 Value reachingDef,
165 const DataLayout &dataLayout) {
166 llvm_unreachable("getStored should not be called on LoadOp");
167}
168
169bool memref::LoadOp::canUsesBeRemoved(
170 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
171 SmallVectorImpl<OpOperand *> &newBlockingUses,
172 const DataLayout &dataLayout) {
173 if (blockingUses.size() != 1)
174 return false;
175 Value blockingUse = (*blockingUses.begin())->get();
176 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
177 getResult().getType() == slot.elemType;
178}
179
180DeletionKind memref::LoadOp::removeBlockingUses(
181 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
182 RewriterBase &rewriter, Value reachingDefinition,
183 const DataLayout &dataLayout) {
184 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
185 // pointer.
186 rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
187 return DeletionKind::Delete;
188}
189
190/// Returns the index of a memref in attribute form, given its indices. Returns
191/// a null pointer if whether the indices form a valid index for the provided
192/// MemRefType cannot be computed. The indices must come from a valid memref
193/// StoreOp or LoadOp.
194static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
195 ValueRange indices,
196 MemRefType memrefType) {
197 SmallVector<Attribute> index;
198 for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
199 IntegerAttr coordAttr;
200 if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
201 return {};
202 // MemRefType shape dimensions are always positive (checked by verifier).
203 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204 if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
205 return {};
206 index.push_back(coordAttr);
207 }
208 return ArrayAttr::get(ctx, index);
209}
210
211bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
212 SmallPtrSetImpl<Attribute> &usedIndices,
213 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
214 const DataLayout &dataLayout) {
215 if (slot.ptr != getMemRef())
216 return false;
217 Attribute index = getAttributeIndexFromIndexOperands(
218 getContext(), getIndices(), getMemRefType());
219 if (!index)
220 return false;
221 usedIndices.insert(index);
222 return true;
223}
224
225DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
226 DenseMap<Attribute, MemorySlot> &subslots,
227 RewriterBase &rewriter,
228 const DataLayout &dataLayout) {
229 Attribute index = getAttributeIndexFromIndexOperands(
230 getContext(), getIndices(), getMemRefType());
231 const MemorySlot &memorySlot = subslots.at(index);
232 rewriter.modifyOpInPlace(*this, [&]() {
233 setMemRef(memorySlot.ptr);
234 getIndicesMutable().clear();
235 });
236 return DeletionKind::Keep;
237}
238
239bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
240
241bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242 return getMemRef() == slot.ptr;
243}
244
245Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
246 Value reachingDef,
247 const DataLayout &dataLayout) {
248 return getValue();
249}
250
251bool memref::StoreOp::canUsesBeRemoved(
252 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
253 SmallVectorImpl<OpOperand *> &newBlockingUses,
254 const DataLayout &dataLayout) {
255 if (blockingUses.size() != 1)
256 return false;
257 Value blockingUse = (*blockingUses.begin())->get();
258 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
259 getValue() != slot.ptr && getValue().getType() == slot.elemType;
260}
261
262DeletionKind memref::StoreOp::removeBlockingUses(
263 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
264 RewriterBase &rewriter, Value reachingDefinition,
265 const DataLayout &dataLayout) {
266 return DeletionKind::Delete;
267}
268
269bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
270 SmallPtrSetImpl<Attribute> &usedIndices,
271 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
272 const DataLayout &dataLayout) {
273 if (slot.ptr != getMemRef() || getValue() == slot.ptr)
274 return false;
275 Attribute index = getAttributeIndexFromIndexOperands(
276 getContext(), getIndices(), getMemRefType());
277 if (!index || !slot.elementPtrs.contains(index))
278 return false;
279 usedIndices.insert(index);
280 return true;
281}
282
283DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
284 DenseMap<Attribute, MemorySlot> &subslots,
285 RewriterBase &rewriter,
286 const DataLayout &dataLayout) {
287 Attribute index = getAttributeIndexFromIndexOperands(
288 getContext(), getIndices(), getMemRefType());
289 const MemorySlot &memorySlot = subslots.at(index);
290 rewriter.modifyOpInPlace(*this, [&]() {
291 setMemRef(memorySlot.ptr);
292 getIndicesMutable().clear();
293 });
294 return DeletionKind::Keep;
295}
296
//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//
300
301namespace {
302
303struct MemRefDestructurableTypeExternalModel
304 : public DestructurableTypeInterface::ExternalModel<
305 MemRefDestructurableTypeExternalModel, MemRefType> {
306 std::optional<DenseMap<Attribute, Type>>
307 getSubelementIndexMap(Type type) const {
308 auto memrefType = llvm::cast<MemRefType>(type);
309 constexpr int64_t maxMemrefSizeForDestructuring = 16;
310 if (!memrefType.hasStaticShape() ||
311 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312 memrefType.getNumElements() == 1)
313 return {};
314
315 DenseMap<Attribute, Type> destructured;
316 walkIndicesAsAttr(
317 memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
318 destructured.insert({index, memrefType.getElementType()});
319 });
320
321 return destructured;
322 }
323
324 Type getTypeAtIndex(Type type, Attribute index) const {
325 auto memrefType = llvm::cast<MemRefType>(type);
326 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
328 return {};
329
330 Type indexType = IndexType::get(memrefType.getContext());
331 for (const auto &[coordAttr, dimSize] :
332 llvm::zip(coordArrAttr, memrefType.getShape())) {
333 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335 coord.getInt() >= dimSize)
336 return {};
337 }
338
339 return memrefType.getElementType();
340 }
341};
342
343} // namespace
344
//===----------------------------------------------------------------------===//
// Register external models
//===----------------------------------------------------------------------===//
348
349void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
350 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
351 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
352 });
353}
354

// Source: mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp