1//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/Pattern.h"
10#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13#include "mlir/IR/AffineMap.h"
14#include "mlir/IR/BuiltinAttributes.h"
15
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// ConvertToLLVMPattern
20//===----------------------------------------------------------------------===//
21
22ConvertToLLVMPattern::ConvertToLLVMPattern(
23 StringRef rootOpName, MLIRContext *context,
24 const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26
27const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28 return static_cast<const LLVMTypeConverter *>(
29 ConversionPattern::getTypeConverter());
30}
31
32LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33 return *getTypeConverter()->getDialect();
34}
35
36Type ConvertToLLVMPattern::getIndexType() const {
37 return getTypeConverter()->getIndexType();
38}
39
40Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41 return IntegerType::get(&getTypeConverter()->getContext(),
42 getTypeConverter()->getPointerBitwidth(addressSpace));
43}
44
45Type ConvertToLLVMPattern::getVoidType() const {
46 return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext());
47}
48
49Type ConvertToLLVMPattern::getVoidPtrType() const {
50 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
51}
52
53Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
54 Location loc,
55 Type resultType,
56 int64_t value) {
57 return builder.create<LLVM::ConstantOp>(loc, resultType,
58 builder.getIndexAttr(value));
59}
60
61Value ConvertToLLVMPattern::getStridedElementPtr(
62 Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
63 ConversionPatternRewriter &rewriter) const {
64
65 auto [strides, offset] = getStridesAndOffset(type);
66
67 MemRefDescriptor memRefDescriptor(memRefDesc);
68 // Use a canonical representation of the start address so that later
69 // optimizations have a longer sequence of instructions to CSE.
70 // If we don't do that we would sprinkle the memref.offset in various
71 // position of the different address computations.
72 Value base =
73 memRefDescriptor.bufferPtr(builder&: rewriter, loc, converter: *getTypeConverter(), type: type);
74
75 Type indexType = getIndexType();
76 Value index;
77 for (int i = 0, e = indices.size(); i < e; ++i) {
78 Value increment = indices[i];
79 if (strides[i] != 1) { // Skip if stride is 1.
80 Value stride =
81 ShapedType::isDynamic(strides[i])
82 ? memRefDescriptor.stride(rewriter, loc, i)
83 : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
84 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
85 }
86 index =
87 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
88 }
89
90 Type elementPtrType = memRefDescriptor.getElementPtrType();
91 return index ? rewriter.create<LLVM::GEPOp>(
92 loc, elementPtrType,
93 getTypeConverter()->convertType(type.getElementType()),
94 base, index)
95 : base;
96}
97
98// Check if the MemRefType `type` is supported by the lowering. We currently
99// only support memrefs with identity maps.
100bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
101 MemRefType type) const {
102 if (!typeConverter->convertType(type.getElementType()))
103 return false;
104 return type.getLayout().isIdentity();
105}
106
107Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
108 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type: type);
109 if (failed(addressSpace))
110 return {};
111 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
112}
113
114void ConvertToLLVMPattern::getMemRefDescriptorSizes(
115 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
116 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
117 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
118 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
119 "layout maps must have been normalized away");
120 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
121 static_cast<ssize_t>(dynamicSizes.size()) &&
122 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
123
124 sizes.reserve(N: memRefType.getRank());
125 unsigned dynamicIndex = 0;
126 Type indexType = getIndexType();
127 for (int64_t size : memRefType.getShape()) {
128 sizes.push_back(
129 size == ShapedType::kDynamic
130 ? dynamicSizes[dynamicIndex++]
131 : createIndexAttrConstant(rewriter, loc, indexType, size));
132 }
133
134 // Strides: iterate sizes in reverse order and multiply.
135 int64_t stride = 1;
136 Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1);
137 strides.resize(memRefType.getRank());
138 for (auto i = memRefType.getRank(); i-- > 0;) {
139 strides[i] = runningStride;
140
141 int64_t staticSize = memRefType.getShape()[i];
142 if (staticSize == 0)
143 continue;
144 bool useSizeAsStride = stride == 1;
145 if (staticSize == ShapedType::kDynamic)
146 stride = ShapedType::kDynamic;
147 if (stride != ShapedType::kDynamic)
148 stride *= staticSize;
149
150 if (useSizeAsStride)
151 runningStride = sizes[i];
152 else if (stride == ShapedType::kDynamic)
153 runningStride =
154 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
155 else
156 runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride);
157 }
158 if (sizeInBytes) {
159 // Buffer size in bytes.
160 Type elementType = typeConverter->convertType(memRefType.getElementType());
161 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
162 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
163 Value gepPtr = rewriter.create<LLVM::GEPOp>(
164 loc, elementPtrType, elementType, nullPtr, runningStride);
165 size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
166 } else {
167 size = runningStride;
168 }
169}
170
171Value ConvertToLLVMPattern::getSizeInBytes(
172 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
173 // Compute the size of an individual element. This emits the MLIR equivalent
174 // of the following sizeof(...) implementation in LLVM IR:
175 // %0 = getelementptr %elementType* null, %indexType 1
176 // %1 = ptrtoint %elementType* %0 to %indexType
177 // which is a common pattern of getting the size of a type in bytes.
178 Type llvmType = typeConverter->convertType(t: type);
179 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
180 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
181 auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
182 nullPtr, ArrayRef<LLVM::GEPArg>{1});
183 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
184}
185
186Value ConvertToLLVMPattern::getNumElements(
187 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
188 ConversionPatternRewriter &rewriter) const {
189 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
190 static_cast<ssize_t>(dynamicSizes.size()) &&
191 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
192
193 Type indexType = getIndexType();
194 Value numElements = memRefType.getRank() == 0
195 ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1)
196 : nullptr;
197 unsigned dynamicIndex = 0;
198
199 // Compute the total number of memref elements.
200 for (int64_t staticSize : memRefType.getShape()) {
201 if (numElements) {
202 Value size =
203 staticSize == ShapedType::kDynamic
204 ? dynamicSizes[dynamicIndex++]
205 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
206 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
207 } else {
208 numElements =
209 staticSize == ShapedType::kDynamic
210 ? dynamicSizes[dynamicIndex++]
211 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
212 }
213 }
214 return numElements;
215}
216
217/// Creates and populates the memref descriptor struct given all its fields.
218MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
219 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
220 ArrayRef<Value> sizes, ArrayRef<Value> strides,
221 ConversionPatternRewriter &rewriter) const {
222 auto structType = typeConverter->convertType(memRefType);
223 auto memRefDescriptor = MemRefDescriptor::undef(builder&: rewriter, loc, descriptorType: structType);
224
225 // Field 1: Allocated pointer, used for malloc/free.
226 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
227
228 // Field 2: Actual aligned pointer to payload.
229 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
230
231 // Field 3: Offset in aligned pointer.
232 Type indexType = getIndexType();
233 memRefDescriptor.setOffset(
234 rewriter, loc, createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0));
235
236 // Fields 4: Sizes.
237 for (const auto &en : llvm::enumerate(First&: sizes))
238 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
239
240 // Field 5: Strides.
241 for (const auto &en : llvm::enumerate(First&: strides))
242 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
243
244 return memRefDescriptor;
245}
246
247LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
248 OpBuilder &builder, Location loc, TypeRange origTypes,
249 SmallVectorImpl<Value> &operands, bool toDynamic) const {
250 assert(origTypes.size() == operands.size() &&
251 "expected as may original types as operands");
252
253 // Find operands of unranked memref type and store them.
254 SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
255 SmallVector<unsigned> unrankedAddressSpaces;
256 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
257 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
258 unrankedMemrefs.emplace_back(Args&: operands[i]);
259 FailureOr<unsigned> addressSpace =
260 getTypeConverter()->getMemRefAddressSpace(type: memRefType);
261 if (failed(result: addressSpace))
262 return failure();
263 unrankedAddressSpaces.emplace_back(Args&: *addressSpace);
264 }
265 }
266
267 if (unrankedMemrefs.empty())
268 return success();
269
270 // Compute allocation sizes.
271 SmallVector<Value> sizes;
272 UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(),
273 values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces,
274 sizes);
275
276 // Get frequently used types.
277 Type indexType = getTypeConverter()->getIndexType();
278
279 // Find the malloc and free, or declare them if necessary.
280 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
281 LLVM::LLVMFuncOp freeFunc, mallocFunc;
282 if (toDynamic)
283 mallocFunc = LLVM::lookupOrCreateMallocFn(moduleOp: module, indexType);
284 if (!toDynamic)
285 freeFunc = LLVM::lookupOrCreateFreeFn(moduleOp: module);
286
287 unsigned unrankedMemrefPos = 0;
288 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
289 Type type = origTypes[i];
290 if (!isa<UnrankedMemRefType>(Val: type))
291 continue;
292 Value allocationSize = sizes[unrankedMemrefPos++];
293 UnrankedMemRefDescriptor desc(operands[i]);
294
295 // Allocate memory, copy, and free the source if necessary.
296 Value memory =
297 toDynamic
298 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
299 .getResult()
300 : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
301 IntegerType::get(getContext(), 8),
302 allocationSize,
303 /*alignment=*/0);
304 Value source = desc.memRefDescPtr(builder, loc);
305 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
306 if (!toDynamic)
307 builder.create<LLVM::CallOp>(loc, freeFunc, source);
308
309 // Create a new descriptor. The same descriptor can be returned multiple
310 // times, attempting to modify its pointer can lead to memory leaks
311 // (allocated twice and overwritten) or double frees (the caller does not
312 // know if the descriptor points to the same memory).
313 Type descriptorType = getTypeConverter()->convertType(t: type);
314 if (!descriptorType)
315 return failure();
316 auto updatedDesc =
317 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
318 Value rank = desc.rank(builder, loc);
319 updatedDesc.setRank(builder, loc, value: rank);
320 updatedDesc.setMemRefDescPtr(builder, loc, value: memory);
321
322 operands[i] = updatedDesc;
323 }
324
325 return success();
326}
327
328//===----------------------------------------------------------------------===//
329// Detail methods
330//===----------------------------------------------------------------------===//
331
332void LLVM::detail::setNativeProperties(Operation *op,
333 IntegerOverflowFlags overflowFlags) {
334 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
335 iface.setOverflowFlags(overflowFlags);
336}
337
338/// Replaces the given operation "op" with a new operation of type "targetOp"
339/// and given operands.
340LogicalResult LLVM::detail::oneToOneRewrite(
341 Operation *op, StringRef targetOp, ValueRange operands,
342 ArrayRef<NamedAttribute> targetAttrs,
343 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
344 IntegerOverflowFlags overflowFlags) {
345 unsigned numResults = op->getNumResults();
346
347 SmallVector<Type> resultTypes;
348 if (numResults != 0) {
349 resultTypes.push_back(
350 Elt: typeConverter.packOperationResults(types: op->getResultTypes()));
351 if (!resultTypes.back())
352 return failure();
353 }
354
355 // Create the operation through state since we don't know its C++ type.
356 Operation *newOp =
357 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
358 resultTypes, targetAttrs);
359
360 setNativeProperties(newOp, overflowFlags);
361
362 // If the operation produced 0 or 1 result, return them immediately.
363 if (numResults == 0)
364 return rewriter.eraseOp(op), success();
365 if (numResults == 1)
366 return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success();
367
368 // Otherwise, it had been converted to an operation producing a structure.
369 // Extract individual results from the structure and return them as list.
370 SmallVector<Value, 4> results;
371 results.reserve(N: numResults);
372 for (unsigned i = 0; i < numResults; ++i) {
373 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
374 op->getLoc(), newOp->getResult(0), i));
375 }
376 rewriter.replaceOp(op, newValues: results);
377 return success();
378}
379

source code of mlir/lib/Conversion/LLVMCommon/Pattern.cpp