1 | //===- AllocLikeConversion.cpp - LLVM conversion for alloc operations -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" |
10 | #include "mlir/Analysis/DataLayoutAnalysis.h" |
11 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace { |
17 | // TODO: Fix the LLVM utilities for looking up functions to take Operation* |
18 | // with SymbolTable trait instead of ModuleOp and make similar change here. This |
19 | // allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead |
20 | // of getParentOfType<ModuleOp> to pass down the operation. |
21 | LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, |
22 | ModuleOp module, Type indexType) { |
23 | bool useGenericFn = typeConverter->getOptions().useGenericFunctions; |
24 | |
25 | if (useGenericFn) |
26 | return LLVM::lookupOrCreateGenericAllocFn(moduleOp: module, indexType); |
27 | |
28 | return LLVM::lookupOrCreateMallocFn(moduleOp: module, indexType); |
29 | } |
30 | |
31 | LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter, |
32 | ModuleOp module, Type indexType) { |
33 | bool useGenericFn = typeConverter->getOptions().useGenericFunctions; |
34 | |
35 | if (useGenericFn) |
36 | return LLVM::lookupOrCreateGenericAlignedAllocFn(moduleOp: module, indexType); |
37 | |
38 | return LLVM::lookupOrCreateAlignedAllocFn(moduleOp: module, indexType); |
39 | } |
40 | |
41 | } // end namespace |
42 | |
43 | Value AllocationOpLLVMLowering::createAligned( |
44 | ConversionPatternRewriter &rewriter, Location loc, Value input, |
45 | Value alignment) { |
46 | Value one = createIndexAttrConstant(builder&: rewriter, loc, resultType: alignment.getType(), value: 1); |
47 | Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one); |
48 | Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump); |
49 | Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment); |
50 | return rewriter.create<LLVM::SubOp>(loc, bumped, mod); |
51 | } |
52 | |
53 | static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, |
54 | Location loc, Value allocatedPtr, |
55 | MemRefType memRefType, Type elementPtrType, |
56 | const LLVMTypeConverter &typeConverter) { |
57 | auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType()); |
58 | FailureOr<unsigned> maybeMemrefAddrSpace = |
59 | typeConverter.getMemRefAddressSpace(type: memRefType); |
60 | if (failed(result: maybeMemrefAddrSpace)) |
61 | return Value(); |
62 | unsigned memrefAddrSpace = *maybeMemrefAddrSpace; |
63 | if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) |
64 | allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( |
65 | loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), |
66 | allocatedPtr); |
67 | return allocatedPtr; |
68 | } |
69 | |
70 | std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign( |
71 | ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, |
72 | Operation *op, Value alignment) const { |
73 | if (alignment) { |
74 | // Adjust the allocation size to consider alignment. |
75 | sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); |
76 | } |
77 | |
78 | MemRefType memRefType = getMemRefResultType(op); |
79 | // Allocate the underlying buffer. |
80 | Type elementPtrType = this->getElementPtrType(type: memRefType); |
81 | LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( |
82 | getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); |
83 | auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes); |
84 | |
85 | Value allocatedPtr = |
86 | castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, |
87 | elementPtrType, *getTypeConverter()); |
88 | if (!allocatedPtr) |
89 | return std::make_tuple(args: Value(), args: Value()); |
90 | Value alignedPtr = allocatedPtr; |
91 | if (alignment) { |
92 | // Compute the aligned pointer. |
93 | Value allocatedInt = |
94 | rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); |
95 | Value alignmentInt = createAligned(rewriter, loc, input: allocatedInt, alignment); |
96 | alignedPtr = |
97 | rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt); |
98 | } |
99 | |
100 | return std::make_tuple(args&: allocatedPtr, args&: alignedPtr); |
101 | } |
102 | |
103 | unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes( |
104 | MemRefType memRefType, Operation *op, |
105 | const DataLayout *defaultLayout) const { |
106 | const DataLayout *layout = defaultLayout; |
107 | if (const DataLayoutAnalysis *analysis = |
108 | getTypeConverter()->getDataLayoutAnalysis()) { |
109 | layout = &analysis->getAbove(operation: op); |
110 | } |
111 | Type elementType = memRefType.getElementType(); |
112 | if (auto memRefElementType = dyn_cast<MemRefType>(elementType)) |
113 | return getTypeConverter()->getMemRefDescriptorSize(type: memRefElementType, |
114 | layout: *layout); |
115 | if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType)) |
116 | return getTypeConverter()->getUnrankedMemRefDescriptorSize( |
117 | type: memRefElementType, layout: *layout); |
118 | return layout->getTypeSize(t: elementType); |
119 | } |
120 | |
121 | bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf( |
122 | MemRefType type, uint64_t factor, Operation *op, |
123 | const DataLayout *defaultLayout) const { |
124 | uint64_t sizeDivisor = getMemRefEltSizeInBytes(memRefType: type, op, defaultLayout); |
125 | for (unsigned i = 0, e = type.getRank(); i < e; i++) { |
126 | if (type.isDynamicDim(i)) |
127 | continue; |
128 | sizeDivisor = sizeDivisor * type.getDimSize(i); |
129 | } |
130 | return sizeDivisor % factor == 0; |
131 | } |
132 | |
133 | Value AllocationOpLLVMLowering::allocateBufferAutoAlign( |
134 | ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, |
135 | Operation *op, const DataLayout *defaultLayout, int64_t alignment) const { |
136 | Value allocAlignment = |
137 | createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: alignment); |
138 | |
139 | MemRefType memRefType = getMemRefResultType(op); |
140 | // Function aligned_alloc requires size to be a multiple of alignment; we pad |
141 | // the size to the next multiple if necessary. |
142 | if (!isMemRefSizeMultipleOf(type: memRefType, factor: alignment, op, defaultLayout)) |
143 | sizeBytes = createAligned(rewriter, loc, input: sizeBytes, alignment: allocAlignment); |
144 | |
145 | Type elementPtrType = this->getElementPtrType(type: memRefType); |
146 | LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( |
147 | getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); |
148 | auto results = rewriter.create<LLVM::CallOp>( |
149 | loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); |
150 | |
151 | return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, |
152 | elementPtrType, *getTypeConverter()); |
153 | } |
154 | |
155 | void AllocLikeOpLLVMLowering::setRequiresNumElements() { |
156 | requiresNumElements = true; |
157 | } |
158 | |
159 | LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( |
160 | Operation *op, ArrayRef<Value> operands, |
161 | ConversionPatternRewriter &rewriter) const { |
162 | MemRefType memRefType = getMemRefResultType(op); |
163 | if (!isConvertibleAndHasIdentityMaps(type: memRefType)) |
164 | return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type" ); |
165 | auto loc = op->getLoc(); |
166 | |
167 | // Get actual sizes of the memref as values: static sizes are constant |
168 | // values and dynamic sizes are passed to 'alloc' as operands. In case of |
169 | // zero-dimensional memref, assume a scalar (size 1). |
170 | SmallVector<Value, 4> sizes; |
171 | SmallVector<Value, 4> strides; |
172 | Value size; |
173 | |
174 | this->getMemRefDescriptorSizes(loc, memRefType: memRefType, dynamicSizes: operands, rewriter, sizes, |
175 | strides, size, sizeInBytes: !requiresNumElements); |
176 | |
177 | // Allocate the underlying buffer. |
178 | auto [allocatedPtr, alignedPtr] = |
179 | this->allocateBuffer(rewriter, loc, size, op); |
180 | |
181 | if (!allocatedPtr || !alignedPtr) |
182 | return rewriter.notifyMatchFailure(arg&: loc, |
183 | msg: "underlying buffer allocation failed" ); |
184 | |
185 | // Create the MemRef descriptor. |
186 | auto memRefDescriptor = this->createMemRefDescriptor( |
187 | loc, memRefType: memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); |
188 | |
189 | // Return the final value of the descriptor. |
190 | rewriter.replaceOp(op, {memRefDescriptor}); |
191 | return success(); |
192 | } |
193 | |