//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/Pattern.h"
10#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13#include "mlir/IR/AffineMap.h"
14#include "mlir/IR/BuiltinAttributes.h"
15
16using namespace mlir;
17
//===----------------------------------------------------------------------===//
// ConvertToLLVMPattern
//===----------------------------------------------------------------------===//
21
22ConvertToLLVMPattern::ConvertToLLVMPattern(
23 StringRef rootOpName, MLIRContext *context,
24 const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26
27const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28 return static_cast<const LLVMTypeConverter *>(
29 ConversionPattern::getTypeConverter());
30}
31
32LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33 return *getTypeConverter()->getDialect();
34}
35
36Type ConvertToLLVMPattern::getIndexType() const {
37 return getTypeConverter()->getIndexType();
38}
39
40Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41 return IntegerType::get(&getTypeConverter()->getContext(),
42 getTypeConverter()->getPointerBitwidth(addressSpace));
43}
44
45Type ConvertToLLVMPattern::getVoidType() const {
46 return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext());
47}
48
49Type ConvertToLLVMPattern::getVoidPtrType() const {
50 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
51}
52
53Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
54 Location loc,
55 Type resultType,
56 int64_t value) {
57 return builder.create<LLVM::ConstantOp>(loc, resultType,
58 builder.getIndexAttr(value));
59}
60
61Value ConvertToLLVMPattern::getStridedElementPtr(
62 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63 Value memRefDesc, ValueRange indices,
64 LLVM::GEPNoWrapFlags noWrapFlags) const {
65 return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
66 memRefDesc, indices, noWrapFlags);
67}
68
69// Check if the MemRefType `type` is supported by the lowering. We currently
70// only support memrefs with identity maps.
71bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
72 MemRefType type) const {
73 if (!type.getLayout().isIdentity())
74 return false;
75 return static_cast<bool>(typeConverter->convertType(type));
76}
77
78Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
79 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type: type);
80 if (failed(addressSpace))
81 return {};
82 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
83}
84
85void ConvertToLLVMPattern::getMemRefDescriptorSizes(
86 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
87 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
88 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
89 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
90 "layout maps must have been normalized away");
91 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
92 static_cast<ssize_t>(dynamicSizes.size()) &&
93 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
94
95 sizes.reserve(N: memRefType.getRank());
96 unsigned dynamicIndex = 0;
97 Type indexType = getIndexType();
98 for (int64_t size : memRefType.getShape()) {
99 sizes.push_back(
100 size == ShapedType::kDynamic
101 ? dynamicSizes[dynamicIndex++]
102 : createIndexAttrConstant(rewriter, loc, indexType, size));
103 }
104
105 // Strides: iterate sizes in reverse order and multiply.
106 int64_t stride = 1;
107 Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1);
108 strides.resize(memRefType.getRank());
109 for (auto i = memRefType.getRank(); i-- > 0;) {
110 strides[i] = runningStride;
111
112 int64_t staticSize = memRefType.getShape()[i];
113 bool useSizeAsStride = stride == 1;
114 if (staticSize == ShapedType::kDynamic)
115 stride = ShapedType::kDynamic;
116 if (stride != ShapedType::kDynamic)
117 stride *= staticSize;
118
119 if (useSizeAsStride)
120 runningStride = sizes[i];
121 else if (stride == ShapedType::kDynamic)
122 runningStride =
123 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
124 else
125 runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride);
126 }
127 if (sizeInBytes) {
128 // Buffer size in bytes.
129 Type elementType = typeConverter->convertType(memRefType.getElementType());
130 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
131 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
132 Value gepPtr = rewriter.create<LLVM::GEPOp>(
133 loc, elementPtrType, elementType, nullPtr, runningStride);
134 size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
135 } else {
136 size = runningStride;
137 }
138}
139
140Value ConvertToLLVMPattern::getSizeInBytes(
141 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
142 // Compute the size of an individual element. This emits the MLIR equivalent
143 // of the following sizeof(...) implementation in LLVM IR:
144 // %0 = getelementptr %elementType* null, %indexType 1
145 // %1 = ptrtoint %elementType* %0 to %indexType
146 // which is a common pattern of getting the size of a type in bytes.
147 Type llvmType = typeConverter->convertType(t: type);
148 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
149 auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
150 auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
151 nullPtr, ArrayRef<LLVM::GEPArg>{1});
152 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
153}
154
155Value ConvertToLLVMPattern::getNumElements(
156 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
157 ConversionPatternRewriter &rewriter) const {
158 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
159 static_cast<ssize_t>(dynamicSizes.size()) &&
160 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
161
162 Type indexType = getIndexType();
163 Value numElements = memRefType.getRank() == 0
164 ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1)
165 : nullptr;
166 unsigned dynamicIndex = 0;
167
168 // Compute the total number of memref elements.
169 for (int64_t staticSize : memRefType.getShape()) {
170 if (numElements) {
171 Value size =
172 staticSize == ShapedType::kDynamic
173 ? dynamicSizes[dynamicIndex++]
174 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
175 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
176 } else {
177 numElements =
178 staticSize == ShapedType::kDynamic
179 ? dynamicSizes[dynamicIndex++]
180 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
181 }
182 }
183 return numElements;
184}
185
186/// Creates and populates the memref descriptor struct given all its fields.
187MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
188 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
189 ArrayRef<Value> sizes, ArrayRef<Value> strides,
190 ConversionPatternRewriter &rewriter) const {
191 auto structType = typeConverter->convertType(memRefType);
192 auto memRefDescriptor = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: structType);
193
194 // Field 1: Allocated pointer, used for malloc/free.
195 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
196
197 // Field 2: Actual aligned pointer to payload.
198 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
199
200 // Field 3: Offset in aligned pointer.
201 Type indexType = getIndexType();
202 memRefDescriptor.setOffset(
203 rewriter, loc, createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0));
204
205 // Fields 4: Sizes.
206 for (const auto &en : llvm::enumerate(First&: sizes))
207 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
208
209 // Field 5: Strides.
210 for (const auto &en : llvm::enumerate(First&: strides))
211 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
212
213 return memRefDescriptor;
214}
215
216LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
217 OpBuilder &builder, Location loc, TypeRange origTypes,
218 SmallVectorImpl<Value> &operands, bool toDynamic) const {
219 assert(origTypes.size() == operands.size() &&
220 "expected as may original types as operands");
221
222 // Find operands of unranked memref type and store them.
223 SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
224 SmallVector<unsigned> unrankedAddressSpaces;
225 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
226 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
227 unrankedMemrefs.emplace_back(Args&: operands[i]);
228 FailureOr<unsigned> addressSpace =
229 getTypeConverter()->getMemRefAddressSpace(type: memRefType);
230 if (failed(Result: addressSpace))
231 return failure();
232 unrankedAddressSpaces.emplace_back(Args&: *addressSpace);
233 }
234 }
235
236 if (unrankedMemrefs.empty())
237 return success();
238
239 // Compute allocation sizes.
240 SmallVector<Value> sizes;
241 UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(),
242 values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces,
243 sizes);
244
245 // Get frequently used types.
246 Type indexType = getTypeConverter()->getIndexType();
247
248 // Find the malloc and free, or declare them if necessary.
249 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
250 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
251 if (toDynamic) {
252 mallocFunc = LLVM::lookupOrCreateMallocFn(b&: builder, moduleOp: module, indexType);
253 if (failed(Result: mallocFunc))
254 return failure();
255 }
256 if (!toDynamic) {
257 freeFunc = LLVM::lookupOrCreateFreeFn(b&: builder, moduleOp: module);
258 if (failed(freeFunc))
259 return failure();
260 }
261
262 unsigned unrankedMemrefPos = 0;
263 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
264 Type type = origTypes[i];
265 if (!isa<UnrankedMemRefType>(Val: type))
266 continue;
267 Value allocationSize = sizes[unrankedMemrefPos++];
268 UnrankedMemRefDescriptor desc(operands[i]);
269
270 // Allocate memory, copy, and free the source if necessary.
271 Value memory =
272 toDynamic
273 ? builder
274 .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
275 .getResult()
276 : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
277 IntegerType::get(getContext(), 8),
278 allocationSize,
279 /*alignment=*/0);
280 Value source = desc.memRefDescPtr(builder, loc);
281 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
282 if (!toDynamic)
283 builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
284
285 // Create a new descriptor. The same descriptor can be returned multiple
286 // times, attempting to modify its pointer can lead to memory leaks
287 // (allocated twice and overwritten) or double frees (the caller does not
288 // know if the descriptor points to the same memory).
289 Type descriptorType = getTypeConverter()->convertType(t: type);
290 if (!descriptorType)
291 return failure();
292 auto updatedDesc =
293 UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
294 Value rank = desc.rank(builder, loc);
295 updatedDesc.setRank(builder, loc, value: rank);
296 updatedDesc.setMemRefDescPtr(builder, loc, value: memory);
297
298 operands[i] = updatedDesc;
299 }
300
301 return success();
302}
303
//===----------------------------------------------------------------------===//
// Detail methods
//===----------------------------------------------------------------------===//
307
308void LLVM::detail::setNativeProperties(Operation *op,
309 IntegerOverflowFlags overflowFlags) {
310 if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
311 iface.setOverflowFlags(overflowFlags);
312}
313
314/// Replaces the given operation "op" with a new operation of type "targetOp"
315/// and given operands.
316LogicalResult LLVM::detail::oneToOneRewrite(
317 Operation *op, StringRef targetOp, ValueRange operands,
318 ArrayRef<NamedAttribute> targetAttrs,
319 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
320 IntegerOverflowFlags overflowFlags) {
321 unsigned numResults = op->getNumResults();
322
323 SmallVector<Type> resultTypes;
324 if (numResults != 0) {
325 resultTypes.push_back(
326 Elt: typeConverter.packOperationResults(types: op->getResultTypes()));
327 if (!resultTypes.back())
328 return failure();
329 }
330
331 // Create the operation through state since we don't know its C++ type.
332 Operation *newOp =
333 rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
334 resultTypes, targetAttrs);
335
336 setNativeProperties(newOp, overflowFlags);
337
338 // If the operation produced 0 or 1 result, return them immediately.
339 if (numResults == 0)
340 return rewriter.eraseOp(op), success();
341 if (numResults == 1)
342 return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success();
343
344 // Otherwise, it had been converted to an operation producing a structure.
345 // Extract individual results from the structure and return them as list.
346 SmallVector<Value, 4> results;
347 results.reserve(N: numResults);
348 for (unsigned i = 0; i < numResults; ++i) {
349 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
350 op->getLoc(), newOp->getResult(0), i));
351 }
352 rewriter.replaceOp(op, newValues: results);
353 return success();
354}
355
356LogicalResult LLVM::detail::intrinsicRewrite(
357 Operation *op, StringRef intrinsic, ValueRange operands,
358 const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
359 auto loc = op->getLoc();
360
361 if (!llvm::all_of(Range&: operands, P: [](Value value) {
362 return LLVM::isCompatibleType(type: value.getType());
363 }))
364 return failure();
365
366 unsigned numResults = op->getNumResults();
367 Type resType;
368 if (numResults != 0)
369 resType = typeConverter.packOperationResults(types: op->getResultTypes());
370
371 auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
372 loc, resType, rewriter.getStringAttr(intrinsic), operands);
373 // Propagate attributes.
374 callIntrOp->setAttrs(op->getAttrDictionary());
375
376 if (numResults <= 1) {
377 // Directly replace the original op.
378 rewriter.replaceOp(op, callIntrOp);
379 return success();
380 }
381
382 // Extract individual results from packed structure and use them as
383 // replacements.
384 SmallVector<Value, 4> results;
385 results.reserve(N: numResults);
386 Value intrRes = callIntrOp.getResults();
387 for (unsigned i = 0; i < numResults; ++i)
388 results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
389 rewriter.replaceOp(op, newValues: results);
390
391 return success();
392}
393
394static unsigned getBitWidth(Type type) {
395 if (type.isIntOrFloat())
396 return type.getIntOrFloatBitWidth();
397
398 auto vec = cast<VectorType>(type);
399 assert(!vec.isScalable() && "scalable vectors are not supported");
400 return vec.getNumElements() * getBitWidth(vec.getElementType());
401}
402
403static Value createI32Constant(OpBuilder &builder, Location loc,
404 int32_t value) {
405 Type i32 = builder.getI32Type();
406 return builder.create<LLVM::ConstantOp>(loc, i32, value);
407}
408
409SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
410 Value src, Type dstType) {
411 Type srcType = src.getType();
412 if (srcType == dstType)
413 return {src};
414
415 unsigned srcBitWidth = getBitWidth(type: srcType);
416 unsigned dstBitWidth = getBitWidth(type: dstType);
417 if (srcBitWidth == dstBitWidth) {
418 Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
419 return {cast};
420 }
421
422 if (dstBitWidth > srcBitWidth) {
423 auto smallerInt = builder.getIntegerType(srcBitWidth);
424 if (srcType != smallerInt)
425 src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
426
427 auto largerInt = builder.getIntegerType(dstBitWidth);
428 Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
429 return {res};
430 }
431 assert(srcBitWidth % dstBitWidth == 0 &&
432 "src bit width must be a multiple of dst bit width");
433 int64_t numElements = srcBitWidth / dstBitWidth;
434 auto vecType = VectorType::get(numElements, dstType);
435
436 src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
437
438 SmallVector<Value> res;
439 for (auto i : llvm::seq(Size: numElements)) {
440 Value idx = createI32Constant(builder, loc, value: i);
441 Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
442 res.emplace_back(Args&: elem);
443 }
444
445 return res;
446}
447
448Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
449 Type dstType) {
450 assert(!src.empty() && "src range must not be empty");
451 if (src.size() == 1) {
452 Value res = src.front();
453 if (res.getType() == dstType)
454 return res;
455
456 unsigned srcBitWidth = getBitWidth(type: res.getType());
457 unsigned dstBitWidth = getBitWidth(type: dstType);
458 if (dstBitWidth < srcBitWidth) {
459 auto largerInt = builder.getIntegerType(srcBitWidth);
460 if (res.getType() != largerInt)
461 res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
462
463 auto smallerInt = builder.getIntegerType(dstBitWidth);
464 res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
465 }
466
467 if (res.getType() != dstType)
468 res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
469
470 return res;
471 }
472
473 int64_t numElements = src.size();
474 auto srcType = VectorType::get(numElements, src.front().getType());
475 Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
476 for (auto &&[i, elem] : llvm::enumerate(First&: src)) {
477 Value idx = createI32Constant(builder, loc, value: i);
478 res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
479 }
480
481 if (res.getType() != dstType)
482 res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
483
484 return res;
485}
486
487Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
488 const LLVMTypeConverter &converter,
489 MemRefType type, Value memRefDesc,
490 ValueRange indices,
491 LLVM::GEPNoWrapFlags noWrapFlags) {
492 auto [strides, offset] = type.getStridesAndOffset();
493
494 MemRefDescriptor memRefDescriptor(memRefDesc);
495 // Use a canonical representation of the start address so that later
496 // optimizations have a longer sequence of instructions to CSE.
497 // If we don't do that we would sprinkle the memref.offset in various
498 // position of the different address computations.
499 Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type: type);
500
501 LLVM::IntegerOverflowFlags intOverflowFlags =
502 LLVM::IntegerOverflowFlags::none;
503 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
504 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
505 }
506 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
507 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
508 }
509
510 Type indexType = converter.getIndexType();
511 Value index;
512 for (int i = 0, e = indices.size(); i < e; ++i) {
513 Value increment = indices[i];
514 if (strides[i] != 1) { // Skip if stride is 1.
515 Value stride =
516 ShapedType::isDynamic(strides[i])
517 ? memRefDescriptor.stride(builder, loc, i)
518 : builder.create<LLVM::ConstantOp>(
519 loc, indexType, builder.getIndexAttr(strides[i]));
520 increment =
521 builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
522 }
523 index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
524 intOverflowFlags)
525 : increment;
526 }
527
528 Type elementPtrType = memRefDescriptor.getElementPtrType();
529 return index ? builder.create<LLVM::GEPOp>(
530 loc, elementPtrType,
531 converter.convertType(type.getElementType()), base, index,
532 noWrapFlags)
533 : base;
534}
535

// Source code of mlir/lib/Conversion/LLVMCommon/Pattern.cpp
// (listing provided by KDAB; website navigation text omitted).