1//===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10#include "mlir/Analysis/DataLayoutAnalysis.h"
11#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13
14using namespace mlir;
15
16namespace {
17// TODO: Fix the LLVM utilities for looking up functions to take Operation*
18// with SymbolTable trait instead of ModuleOp and make similar change here. This
19// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20// of getParentOfType<ModuleOp> to pass down the operation.
21LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
22 ModuleOp module, Type indexType) {
23 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
24
25 if (useGenericFn)
26 return LLVM::lookupOrCreateGenericAllocFn(moduleOp: module, indexType);
27
28 return LLVM::lookupOrCreateMallocFn(moduleOp: module, indexType);
29}
30
31LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
32 ModuleOp module, Type indexType) {
33 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
34
35 if (useGenericFn)
36 return LLVM::lookupOrCreateGenericAlignedAllocFn(moduleOp: module, indexType);
37
38 return LLVM::lookupOrCreateAlignedAllocFn(moduleOp: module, indexType);
39}
40
41} // end namespace
42
43Value AllocationOpLLVMLowering::createAligned(
44 ConversionPatternRewriter &rewriter, Location loc, Value input,
45 Value alignment) {
46 Value one = createIndexAttrConstant(builder&: rewriter, loc, resultType: alignment.getType(), value: 1);
47 Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
48 Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
49 Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
50 return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
51}
52
53static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
54 Location loc, Value allocatedPtr,
55 MemRefType memRefType, Type elementPtrType,
56 const LLVMTypeConverter &typeConverter) {
57 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
58 FailureOr<unsigned> maybeMemrefAddrSpace =
59 typeConverter.getMemRefAddressSpace(type: memRefType);
60 if (failed(result: maybeMemrefAddrSpace))
61 return Value();
62 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
63 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
64 allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
65 loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
66 allocatedPtr);
67 return allocatedPtr;
68}
69
70std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
71 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
72 Operation *op, Value alignment) const {
73 if (alignment) {
74 // Adjust the allocation size to consider alignment.
75 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
76 }
77
78 MemRefType memRefType = getMemRefResultType(op);
79 // Allocate the underlying buffer.
80 Type elementPtrType = this->getElementPtrType(type: memRefType);
81 LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
82 getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
83 auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
84
85 Value allocatedPtr =
86 castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
87 elementPtrType, *getTypeConverter());
88 if (!allocatedPtr)
89 return std::make_tuple(args: Value(), args: Value());
90 Value alignedPtr = allocatedPtr;
91 if (alignment) {
92 // Compute the aligned pointer.
93 Value allocatedInt =
94 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
95 Value alignmentInt = createAligned(rewriter, loc, input: allocatedInt, alignment);
96 alignedPtr =
97 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
98 }
99
100 return std::make_tuple(args&: allocatedPtr, args&: alignedPtr);
101}
102
103unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
104 MemRefType memRefType, Operation *op,
105 const DataLayout *defaultLayout) const {
106 const DataLayout *layout = defaultLayout;
107 if (const DataLayoutAnalysis *analysis =
108 getTypeConverter()->getDataLayoutAnalysis()) {
109 layout = &analysis->getAbove(operation: op);
110 }
111 Type elementType = memRefType.getElementType();
112 if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
113 return getTypeConverter()->getMemRefDescriptorSize(type: memRefElementType,
114 layout: *layout);
115 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
116 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
117 type: memRefElementType, layout: *layout);
118 return layout->getTypeSize(t: elementType);
119}
120
121bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
122 MemRefType type, uint64_t factor, Operation *op,
123 const DataLayout *defaultLayout) const {
124 uint64_t sizeDivisor = getMemRefEltSizeInBytes(memRefType: type, op, defaultLayout);
125 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
126 if (type.isDynamicDim(i))
127 continue;
128 sizeDivisor = sizeDivisor * type.getDimSize(i);
129 }
130 return sizeDivisor % factor == 0;
131}
132
133Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
134 ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
135 Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
136 Value allocAlignment =
137 createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: alignment);
138
139 MemRefType memRefType = getMemRefResultType(op);
140 // Function aligned_alloc requires size to be a multiple of alignment; we pad
141 // the size to the next multiple if necessary.
142 if (!isMemRefSizeMultipleOf(type: memRefType, factor: alignment, op, defaultLayout))
143 sizeBytes = createAligned(rewriter, loc, input: sizeBytes, alignment: allocAlignment);
144
145 Type elementPtrType = this->getElementPtrType(type: memRefType);
146 LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
147 getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
148 auto results = rewriter.create<LLVM::CallOp>(
149 loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
150
151 return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
152 elementPtrType, *getTypeConverter());
153}
154
155void AllocLikeOpLLVMLowering::setRequiresNumElements() {
156 requiresNumElements = true;
157}
158
159LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
160 Operation *op, ArrayRef<Value> operands,
161 ConversionPatternRewriter &rewriter) const {
162 MemRefType memRefType = getMemRefResultType(op);
163 if (!isConvertibleAndHasIdentityMaps(type: memRefType))
164 return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type");
165 auto loc = op->getLoc();
166
167 // Get actual sizes of the memref as values: static sizes are constant
168 // values and dynamic sizes are passed to 'alloc' as operands. In case of
169 // zero-dimensional memref, assume a scalar (size 1).
170 SmallVector<Value, 4> sizes;
171 SmallVector<Value, 4> strides;
172 Value size;
173
174 this->getMemRefDescriptorSizes(loc, memRefType: memRefType, dynamicSizes: operands, rewriter, sizes,
175 strides, size, sizeInBytes: !requiresNumElements);
176
177 // Allocate the underlying buffer.
178 auto [allocatedPtr, alignedPtr] =
179 this->allocateBuffer(rewriter, loc, size, op);
180
181 if (!allocatedPtr || !alignedPtr)
182 return rewriter.notifyMatchFailure(arg&: loc,
183 msg: "underlying buffer allocation failed");
184
185 // Create the MemRef descriptor.
186 auto memRefDescriptor = this->createMemRefDescriptor(
187 loc, memRefType: memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
188
189 // Return the final value of the descriptor.
190 rewriter.replaceOp(op, {memRefDescriptor});
191 return success();
192}
193

source code of mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp