1 | //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
10 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
11 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
13 | #include "mlir/IR/AffineMap.h" |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // ConvertToLLVMPattern |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | ConvertToLLVMPattern::ConvertToLLVMPattern( |
23 | StringRef rootOpName, MLIRContext *context, |
24 | const LLVMTypeConverter &typeConverter, PatternBenefit benefit) |
25 | : ConversionPattern(typeConverter, rootOpName, benefit, context) {} |
26 | |
27 | const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { |
28 | return static_cast<const LLVMTypeConverter *>( |
29 | ConversionPattern::getTypeConverter()); |
30 | } |
31 | |
32 | LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { |
33 | return *getTypeConverter()->getDialect(); |
34 | } |
35 | |
36 | Type ConvertToLLVMPattern::getIndexType() const { |
37 | return getTypeConverter()->getIndexType(); |
38 | } |
39 | |
40 | Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { |
41 | return IntegerType::get(&getTypeConverter()->getContext(), |
42 | getTypeConverter()->getPointerBitwidth(addressSpace)); |
43 | } |
44 | |
45 | Type ConvertToLLVMPattern::getVoidType() const { |
46 | return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext()); |
47 | } |
48 | |
49 | Type ConvertToLLVMPattern::getVoidPtrType() const { |
50 | return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext()); |
51 | } |
52 | |
53 | Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, |
54 | Location loc, |
55 | Type resultType, |
56 | int64_t value) { |
57 | return builder.create<LLVM::ConstantOp>(loc, resultType, |
58 | builder.getIndexAttr(value)); |
59 | } |
60 | |
61 | Value ConvertToLLVMPattern::getStridedElementPtr( |
62 | Location loc, MemRefType type, Value memRefDesc, ValueRange indices, |
63 | ConversionPatternRewriter &rewriter) const { |
64 | |
65 | auto [strides, offset] = getStridesAndOffset(type); |
66 | |
67 | MemRefDescriptor memRefDescriptor(memRefDesc); |
68 | // Use a canonical representation of the start address so that later |
69 | // optimizations have a longer sequence of instructions to CSE. |
70 | // If we don't do that we would sprinkle the memref.offset in various |
71 | // position of the different address computations. |
72 | Value base = |
73 | memRefDescriptor.bufferPtr(builder&: rewriter, loc, converter: *getTypeConverter(), type: type); |
74 | |
75 | Type indexType = getIndexType(); |
76 | Value index; |
77 | for (int i = 0, e = indices.size(); i < e; ++i) { |
78 | Value increment = indices[i]; |
79 | if (strides[i] != 1) { // Skip if stride is 1. |
80 | Value stride = |
81 | ShapedType::isDynamic(strides[i]) |
82 | ? memRefDescriptor.stride(rewriter, loc, i) |
83 | : createIndexAttrConstant(rewriter, loc, indexType, strides[i]); |
84 | increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); |
85 | } |
86 | index = |
87 | index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; |
88 | } |
89 | |
90 | Type elementPtrType = memRefDescriptor.getElementPtrType(); |
91 | return index ? rewriter.create<LLVM::GEPOp>( |
92 | loc, elementPtrType, |
93 | getTypeConverter()->convertType(type.getElementType()), |
94 | base, index) |
95 | : base; |
96 | } |
97 | |
98 | // Check if the MemRefType `type` is supported by the lowering. We currently |
99 | // only support memrefs with identity maps. |
100 | bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( |
101 | MemRefType type) const { |
102 | if (!typeConverter->convertType(type.getElementType())) |
103 | return false; |
104 | return type.getLayout().isIdentity(); |
105 | } |
106 | |
107 | Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { |
108 | auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type: type); |
109 | if (failed(addressSpace)) |
110 | return {}; |
111 | return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); |
112 | } |
113 | |
114 | void ConvertToLLVMPattern::getMemRefDescriptorSizes( |
115 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
116 | ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, |
117 | SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const { |
118 | assert(isConvertibleAndHasIdentityMaps(memRefType) && |
119 | "layout maps must have been normalized away" ); |
120 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
121 | static_cast<ssize_t>(dynamicSizes.size()) && |
122 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
123 | |
124 | sizes.reserve(N: memRefType.getRank()); |
125 | unsigned dynamicIndex = 0; |
126 | Type indexType = getIndexType(); |
127 | for (int64_t size : memRefType.getShape()) { |
128 | sizes.push_back( |
129 | size == ShapedType::kDynamic |
130 | ? dynamicSizes[dynamicIndex++] |
131 | : createIndexAttrConstant(rewriter, loc, indexType, size)); |
132 | } |
133 | |
134 | // Strides: iterate sizes in reverse order and multiply. |
135 | int64_t stride = 1; |
136 | Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1); |
137 | strides.resize(memRefType.getRank()); |
138 | for (auto i = memRefType.getRank(); i-- > 0;) { |
139 | strides[i] = runningStride; |
140 | |
141 | int64_t staticSize = memRefType.getShape()[i]; |
142 | if (staticSize == 0) |
143 | continue; |
144 | bool useSizeAsStride = stride == 1; |
145 | if (staticSize == ShapedType::kDynamic) |
146 | stride = ShapedType::kDynamic; |
147 | if (stride != ShapedType::kDynamic) |
148 | stride *= staticSize; |
149 | |
150 | if (useSizeAsStride) |
151 | runningStride = sizes[i]; |
152 | else if (stride == ShapedType::kDynamic) |
153 | runningStride = |
154 | rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); |
155 | else |
156 | runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride); |
157 | } |
158 | if (sizeInBytes) { |
159 | // Buffer size in bytes. |
160 | Type elementType = typeConverter->convertType(memRefType.getElementType()); |
161 | auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
162 | Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); |
163 | Value gepPtr = rewriter.create<LLVM::GEPOp>( |
164 | loc, elementPtrType, elementType, nullPtr, runningStride); |
165 | size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); |
166 | } else { |
167 | size = runningStride; |
168 | } |
169 | } |
170 | |
171 | Value ConvertToLLVMPattern::getSizeInBytes( |
172 | Location loc, Type type, ConversionPatternRewriter &rewriter) const { |
173 | // Compute the size of an individual element. This emits the MLIR equivalent |
174 | // of the following sizeof(...) implementation in LLVM IR: |
175 | // %0 = getelementptr %elementType* null, %indexType 1 |
176 | // %1 = ptrtoint %elementType* %0 to %indexType |
177 | // which is a common pattern of getting the size of a type in bytes. |
178 | Type llvmType = typeConverter->convertType(t: type); |
179 | auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
180 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType); |
181 | auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, |
182 | nullPtr, ArrayRef<LLVM::GEPArg>{1}); |
183 | return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); |
184 | } |
185 | |
186 | Value ConvertToLLVMPattern::getNumElements( |
187 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
188 | ConversionPatternRewriter &rewriter) const { |
189 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
190 | static_cast<ssize_t>(dynamicSizes.size()) && |
191 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
192 | |
193 | Type indexType = getIndexType(); |
194 | Value numElements = memRefType.getRank() == 0 |
195 | ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1) |
196 | : nullptr; |
197 | unsigned dynamicIndex = 0; |
198 | |
199 | // Compute the total number of memref elements. |
200 | for (int64_t staticSize : memRefType.getShape()) { |
201 | if (numElements) { |
202 | Value size = |
203 | staticSize == ShapedType::kDynamic |
204 | ? dynamicSizes[dynamicIndex++] |
205 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
206 | numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); |
207 | } else { |
208 | numElements = |
209 | staticSize == ShapedType::kDynamic |
210 | ? dynamicSizes[dynamicIndex++] |
211 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
212 | } |
213 | } |
214 | return numElements; |
215 | } |
216 | |
217 | /// Creates and populates the memref descriptor struct given all its fields. |
218 | MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( |
219 | Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, |
220 | ArrayRef<Value> sizes, ArrayRef<Value> strides, |
221 | ConversionPatternRewriter &rewriter) const { |
222 | auto structType = typeConverter->convertType(memRefType); |
223 | auto memRefDescriptor = MemRefDescriptor::undef(builder&: rewriter, loc, descriptorType: structType); |
224 | |
225 | // Field 1: Allocated pointer, used for malloc/free. |
226 | memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); |
227 | |
228 | // Field 2: Actual aligned pointer to payload. |
229 | memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); |
230 | |
231 | // Field 3: Offset in aligned pointer. |
232 | Type indexType = getIndexType(); |
233 | memRefDescriptor.setOffset( |
234 | rewriter, loc, createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0)); |
235 | |
236 | // Fields 4: Sizes. |
237 | for (const auto &en : llvm::enumerate(First&: sizes)) |
238 | memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); |
239 | |
240 | // Field 5: Strides. |
241 | for (const auto &en : llvm::enumerate(First&: strides)) |
242 | memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); |
243 | |
244 | return memRefDescriptor; |
245 | } |
246 | |
247 | LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( |
248 | OpBuilder &builder, Location loc, TypeRange origTypes, |
249 | SmallVectorImpl<Value> &operands, bool toDynamic) const { |
250 | assert(origTypes.size() == operands.size() && |
251 | "expected as may original types as operands" ); |
252 | |
253 | // Find operands of unranked memref type and store them. |
254 | SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; |
255 | SmallVector<unsigned> unrankedAddressSpaces; |
256 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
257 | if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { |
258 | unrankedMemrefs.emplace_back(Args&: operands[i]); |
259 | FailureOr<unsigned> addressSpace = |
260 | getTypeConverter()->getMemRefAddressSpace(type: memRefType); |
261 | if (failed(result: addressSpace)) |
262 | return failure(); |
263 | unrankedAddressSpaces.emplace_back(Args&: *addressSpace); |
264 | } |
265 | } |
266 | |
267 | if (unrankedMemrefs.empty()) |
268 | return success(); |
269 | |
270 | // Compute allocation sizes. |
271 | SmallVector<Value> sizes; |
272 | UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(), |
273 | values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces, |
274 | sizes); |
275 | |
276 | // Get frequently used types. |
277 | Type indexType = getTypeConverter()->getIndexType(); |
278 | |
279 | // Find the malloc and free, or declare them if necessary. |
280 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); |
281 | LLVM::LLVMFuncOp freeFunc, mallocFunc; |
282 | if (toDynamic) |
283 | mallocFunc = LLVM::lookupOrCreateMallocFn(moduleOp: module, indexType); |
284 | if (!toDynamic) |
285 | freeFunc = LLVM::lookupOrCreateFreeFn(moduleOp: module); |
286 | |
287 | unsigned unrankedMemrefPos = 0; |
288 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
289 | Type type = origTypes[i]; |
290 | if (!isa<UnrankedMemRefType>(Val: type)) |
291 | continue; |
292 | Value allocationSize = sizes[unrankedMemrefPos++]; |
293 | UnrankedMemRefDescriptor desc(operands[i]); |
294 | |
295 | // Allocate memory, copy, and free the source if necessary. |
296 | Value memory = |
297 | toDynamic |
298 | ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) |
299 | .getResult() |
300 | : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(), |
301 | IntegerType::get(getContext(), 8), |
302 | allocationSize, |
303 | /*alignment=*/0); |
304 | Value source = desc.memRefDescPtr(builder, loc); |
305 | builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false); |
306 | if (!toDynamic) |
307 | builder.create<LLVM::CallOp>(loc, freeFunc, source); |
308 | |
309 | // Create a new descriptor. The same descriptor can be returned multiple |
310 | // times, attempting to modify its pointer can lead to memory leaks |
311 | // (allocated twice and overwritten) or double frees (the caller does not |
312 | // know if the descriptor points to the same memory). |
313 | Type descriptorType = getTypeConverter()->convertType(t: type); |
314 | if (!descriptorType) |
315 | return failure(); |
316 | auto updatedDesc = |
317 | UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); |
318 | Value rank = desc.rank(builder, loc); |
319 | updatedDesc.setRank(builder, loc, value: rank); |
320 | updatedDesc.setMemRefDescPtr(builder, loc, value: memory); |
321 | |
322 | operands[i] = updatedDesc; |
323 | } |
324 | |
325 | return success(); |
326 | } |
327 | |
328 | //===----------------------------------------------------------------------===// |
329 | // Detail methods |
330 | //===----------------------------------------------------------------------===// |
331 | |
332 | void LLVM::detail::setNativeProperties(Operation *op, |
333 | IntegerOverflowFlags overflowFlags) { |
334 | if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) |
335 | iface.setOverflowFlags(overflowFlags); |
336 | } |
337 | |
338 | /// Replaces the given operation "op" with a new operation of type "targetOp" |
339 | /// and given operands. |
340 | LogicalResult LLVM::detail::oneToOneRewrite( |
341 | Operation *op, StringRef targetOp, ValueRange operands, |
342 | ArrayRef<NamedAttribute> targetAttrs, |
343 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
344 | IntegerOverflowFlags overflowFlags) { |
345 | unsigned numResults = op->getNumResults(); |
346 | |
347 | SmallVector<Type> resultTypes; |
348 | if (numResults != 0) { |
349 | resultTypes.push_back( |
350 | Elt: typeConverter.packOperationResults(types: op->getResultTypes())); |
351 | if (!resultTypes.back()) |
352 | return failure(); |
353 | } |
354 | |
355 | // Create the operation through state since we don't know its C++ type. |
356 | Operation *newOp = |
357 | rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, |
358 | resultTypes, targetAttrs); |
359 | |
360 | setNativeProperties(newOp, overflowFlags); |
361 | |
362 | // If the operation produced 0 or 1 result, return them immediately. |
363 | if (numResults == 0) |
364 | return rewriter.eraseOp(op), success(); |
365 | if (numResults == 1) |
366 | return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success(); |
367 | |
368 | // Otherwise, it had been converted to an operation producing a structure. |
369 | // Extract individual results from the structure and return them as list. |
370 | SmallVector<Value, 4> results; |
371 | results.reserve(N: numResults); |
372 | for (unsigned i = 0; i < numResults; ++i) { |
373 | results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
374 | op->getLoc(), newOp->getResult(0), i)); |
375 | } |
376 | rewriter.replaceOp(op, newValues: results); |
377 | return success(); |
378 | } |
379 | |