1//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements Mem2Reg-related interfaces for MemRef dialect
10// operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/IR/BuiltinDialect.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Interfaces/InferTypeOpInterface.h"
22#include "mlir/Interfaces/MemorySlotInterfaces.h"
23#include "llvm/ADT/ArrayRef.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/ErrorHandling.h"
26
27using namespace mlir;
28
29//===----------------------------------------------------------------------===//
30// Utilities
31//===----------------------------------------------------------------------===//
32
33/// Walks over the indices of the elements of a tensor of a given `shape` by
34/// updating `index` in place to the next index. This returns failure if the
35/// provided index was the last index.
36static LogicalResult nextIndex(ArrayRef<int64_t> shape,
37 MutableArrayRef<int64_t> index) {
38 for (size_t i = 0; i < shape.size(); ++i) {
39 index[i]++;
40 if (index[i] < shape[i])
41 return success();
42 index[i] = 0;
43 }
44 return failure();
45}
46
47/// Calls `walker` for each index within a tensor of a given `shape`, providing
48/// the index as an array attribute of the coordinates.
49template <typename CallableT>
50static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
51 CallableT &&walker) {
52 Type indexType = IndexType::get(ctx);
53 SmallVector<int64_t> shapeIter(shape.size(), 0);
54 do {
55 SmallVector<Attribute> indexAsAttr;
56 for (int64_t dim : shapeIter)
57 indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
58 walker(ArrayAttr::get(ctx, indexAsAttr));
59 } while (succeeded(Result: nextIndex(shape, index: shapeIter)));
60}
61
62//===----------------------------------------------------------------------===//
63// Interfaces for AllocaOp
64//===----------------------------------------------------------------------===//
65
66static bool isSupportedElementType(Type type) {
67 return llvm::isa<MemRefType>(Val: type) ||
68 OpBuilder(type.getContext()).getZeroAttr(type);
69}
70
71SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
72 MemRefType type = getType();
73 if (!isSupportedElementType(type.getElementType()))
74 return {};
75 if (!type.hasStaticShape())
76 return {};
77 // Make sure the memref contains only a single element.
78 if (type.getNumElements() != 1)
79 return {};
80
81 return {MemorySlot{getResult(), type.getElementType()}};
82}
83
84Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
85 OpBuilder &builder) {
86 assert(isSupportedElementType(slot.elemType));
87 // TODO: support more types.
88 return TypeSwitch<Type, Value>(slot.elemType)
89 .Case([&](MemRefType t) {
90 return builder.create<memref::AllocaOp>(getLoc(), t);
91 })
92 .Default([&](Type t) {
93 return builder.create<arith::ConstantOp>(getLoc(), t,
94 builder.getZeroAttr(t));
95 });
96}
97
98std::optional<PromotableAllocationOpInterface>
99memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
100 Value defaultValue,
101 OpBuilder &builder) {
102 if (defaultValue.use_empty())
103 defaultValue.getDefiningOp()->erase();
104 this->erase();
105 return std::nullopt;
106}
107
108void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
109 BlockArgument argument,
110 OpBuilder &builder) {}
111
112SmallVector<DestructurableMemorySlot>
113memref::AllocaOp::getDestructurableSlots() {
114 MemRefType memrefType = getType();
115 auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
116 if (!destructurable)
117 return {};
118
119 std::optional<DenseMap<Attribute, Type>> destructuredType =
120 destructurable.getSubelementIndexMap();
121 if (!destructuredType)
122 return {};
123
124 return {
125 DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
126}
127
128DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
129 const DestructurableMemorySlot &slot,
130 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
131 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
132 builder.setInsertionPointAfter(*this);
133
134 DenseMap<Attribute, MemorySlot> slotMap;
135
136 auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
137 for (Attribute usedIndex : usedIndices) {
138 Type elemType = memrefType.getTypeAtIndex(usedIndex);
139 MemRefType elemPtr = MemRefType::get({}, elemType);
140 auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
141 newAllocators.push_back(subAlloca);
142 slotMap.try_emplace<MemorySlot>(usedIndex,
143 {subAlloca.getResult(), elemType});
144 }
145
146 return slotMap;
147}
148
149std::optional<DestructurableAllocationOpInterface>
150memref::AllocaOp::handleDestructuringComplete(
151 const DestructurableMemorySlot &slot, OpBuilder &builder) {
152 assert(slot.ptr == getResult());
153 this->erase();
154 return std::nullopt;
155}
156
157//===----------------------------------------------------------------------===//
158// Interfaces for LoadOp/StoreOp
159//===----------------------------------------------------------------------===//
160
161bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
162 return getMemRef() == slot.ptr;
163}
164
165bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
166
167Value memref::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
168 Value reachingDef,
169 const DataLayout &dataLayout) {
170 llvm_unreachable("getStored should not be called on LoadOp");
171}
172
173bool memref::LoadOp::canUsesBeRemoved(
174 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
175 SmallVectorImpl<OpOperand *> &newBlockingUses,
176 const DataLayout &dataLayout) {
177 if (blockingUses.size() != 1)
178 return false;
179 Value blockingUse = (*blockingUses.begin())->get();
180 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
181 getResult().getType() == slot.elemType;
182}
183
184DeletionKind memref::LoadOp::removeBlockingUses(
185 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
186 OpBuilder &builder, Value reachingDefinition,
187 const DataLayout &dataLayout) {
188 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
189 // pointer.
190 getResult().replaceAllUsesWith(reachingDefinition);
191 return DeletionKind::Delete;
192}
193
194/// Returns the index of a memref in attribute form, given its indices. Returns
195/// a null pointer if whether the indices form a valid index for the provided
196/// MemRefType cannot be computed. The indices must come from a valid memref
197/// StoreOp or LoadOp.
198static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
199 ValueRange indices,
200 MemRefType memrefType) {
201 SmallVector<Attribute> index;
202 for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
203 IntegerAttr coordAttr;
204 if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
205 return {};
206 // MemRefType shape dimensions are always positive (checked by verifier).
207 std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
208 if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
209 return {};
210 index.push_back(coordAttr);
211 }
212 return ArrayAttr::get(ctx, index);
213}
214
215bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
216 SmallPtrSetImpl<Attribute> &usedIndices,
217 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
218 const DataLayout &dataLayout) {
219 if (slot.ptr != getMemRef())
220 return false;
221 Attribute index = getAttributeIndexFromIndexOperands(
222 getContext(), getIndices(), getMemRefType());
223 if (!index)
224 return false;
225 usedIndices.insert(index);
226 return true;
227}
228
229DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
230 DenseMap<Attribute, MemorySlot> &subslots,
231 OpBuilder &builder,
232 const DataLayout &dataLayout) {
233 Attribute index = getAttributeIndexFromIndexOperands(
234 getContext(), getIndices(), getMemRefType());
235 const MemorySlot &memorySlot = subslots.at(index);
236 setMemRef(memorySlot.ptr);
237 getIndicesMutable().clear();
238 return DeletionKind::Keep;
239}
240
241bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
242
243bool memref::StoreOp::storesTo(const MemorySlot &slot) {
244 return getMemRef() == slot.ptr;
245}
246
247Value memref::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
248 Value reachingDef,
249 const DataLayout &dataLayout) {
250 return getValue();
251}
252
253bool memref::StoreOp::canUsesBeRemoved(
254 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
255 SmallVectorImpl<OpOperand *> &newBlockingUses,
256 const DataLayout &dataLayout) {
257 if (blockingUses.size() != 1)
258 return false;
259 Value blockingUse = (*blockingUses.begin())->get();
260 return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
261 getValue() != slot.ptr && getValue().getType() == slot.elemType;
262}
263
264DeletionKind memref::StoreOp::removeBlockingUses(
265 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
266 OpBuilder &builder, Value reachingDefinition,
267 const DataLayout &dataLayout) {
268 return DeletionKind::Delete;
269}
270
271bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
272 SmallPtrSetImpl<Attribute> &usedIndices,
273 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
274 const DataLayout &dataLayout) {
275 if (slot.ptr != getMemRef() || getValue() == slot.ptr)
276 return false;
277 Attribute index = getAttributeIndexFromIndexOperands(
278 getContext(), getIndices(), getMemRefType());
279 if (!index || !slot.subelementTypes.contains(index))
280 return false;
281 usedIndices.insert(index);
282 return true;
283}
284
285DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
286 DenseMap<Attribute, MemorySlot> &subslots,
287 OpBuilder &builder,
288 const DataLayout &dataLayout) {
289 Attribute index = getAttributeIndexFromIndexOperands(
290 getContext(), getIndices(), getMemRefType());
291 const MemorySlot &memorySlot = subslots.at(index);
292 setMemRef(memorySlot.ptr);
293 getIndicesMutable().clear();
294 return DeletionKind::Keep;
295}
296
297//===----------------------------------------------------------------------===//
298// Interfaces for destructurable types
299//===----------------------------------------------------------------------===//
300
301namespace {
302
303struct MemRefDestructurableTypeExternalModel
304 : public DestructurableTypeInterface::ExternalModel<
305 MemRefDestructurableTypeExternalModel, MemRefType> {
306 std::optional<DenseMap<Attribute, Type>>
307 getSubelementIndexMap(Type type) const {
308 auto memrefType = llvm::cast<MemRefType>(type);
309 constexpr int64_t maxMemrefSizeForDestructuring = 16;
310 if (!memrefType.hasStaticShape() ||
311 memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
312 memrefType.getNumElements() == 1)
313 return {};
314
315 DenseMap<Attribute, Type> destructured;
316 walkIndicesAsAttr(
317 memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
318 destructured.insert({index, memrefType.getElementType()});
319 });
320
321 return destructured;
322 }
323
324 Type getTypeAtIndex(Type type, Attribute index) const {
325 auto memrefType = llvm::cast<MemRefType>(type);
326 auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
327 if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
328 return {};
329
330 Type indexType = IndexType::get(memrefType.getContext());
331 for (const auto &[coordAttr, dimSize] :
332 llvm::zip(coordArrAttr, memrefType.getShape())) {
333 auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
334 if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
335 coord.getInt() >= dimSize)
336 return {};
337 }
338
339 return memrefType.getElementType();
340 }
341};
342
343} // namespace
344
345//===----------------------------------------------------------------------===//
346// Register external models
347//===----------------------------------------------------------------------===//
348
349void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
350 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
351 MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
352 });
353}
354

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp