| 1 | //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 10 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
| 11 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 12 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 13 | #include "mlir/IR/AffineMap.h" |
| 14 | #include "mlir/IR/BuiltinAttributes.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | |
| 18 | //===----------------------------------------------------------------------===// |
| 19 | // ConvertToLLVMPattern |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | |
/// Constructs a conversion pattern that matches operations named
/// `rootOpName`, converting types via `typeConverter` and ordered against
/// other patterns by `benefit`.
ConvertToLLVMPattern::ConvertToLLVMPattern(
    StringRef rootOpName, MLIRContext *context,
    const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
    : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
| 26 | |
/// Returns the type converter as an LLVMTypeConverter. The static downcast is
/// safe because the constructor only accepts an LLVMTypeConverter.
const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
  return static_cast<const LLVMTypeConverter *>(
      ConversionPattern::getTypeConverter());
}
| 31 | |
/// Returns the LLVM dialect registered with the type converter.
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
  return *getTypeConverter()->getDialect();
}
| 35 | |
/// Returns the LLVM representation of the index type, as configured on the
/// type converter.
Type ConvertToLLVMPattern::getIndexType() const {
  return getTypeConverter()->getIndexType();
}
| 39 | |
/// Returns an integer type whose bitwidth equals the pointer bitwidth of the
/// given address space, as reported by the type converter.
Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
  return IntegerType::get(&getTypeConverter()->getContext(),
                          getTypeConverter()->getPointerBitwidth(addressSpace));
}
| 44 | |
| 45 | Type ConvertToLLVMPattern::getVoidType() const { |
| 46 | return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext()); |
| 47 | } |
| 48 | |
/// Returns an opaque LLVM pointer type in the default address space.
Type ConvertToLLVMPattern::getVoidPtrType() const {
  return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext());
}
| 52 | |
/// Creates an LLVM constant op of type `resultType` holding `value` wrapped
/// in an index attribute.
Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
                                                    Location loc,
                                                    Type resultType,
                                                    int64_t value) {
  return builder.create<LLVM::ConstantOp>(loc, resultType,
                                          builder.getIndexAttr(value));
}
| 60 | |
/// Computes the pointer to the element of `memRefDesc` addressed by `indices`.
/// Thin wrapper forwarding to the free function with this pattern's type
/// converter.
Value ConvertToLLVMPattern::getStridedElementPtr(
    ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
    Value memRefDesc, ValueRange indices,
    LLVM::GEPNoWrapFlags noWrapFlags) const {
  return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
                                    memRefDesc, indices, noWrapFlags);
}
| 68 | |
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps. Non-identity layouts are rejected,
// and the type must additionally be convertible by the type converter.
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
    MemRefType type) const {
  if (!type.getLayout().isIdentity())
    return false;
  // convertType returns a null Type on failure.
  return static_cast<bool>(typeConverter->convertType(type));
}
| 77 | |
| 78 | Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { |
| 79 | auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type: type); |
| 80 | if (failed(addressSpace)) |
| 81 | return {}; |
| 82 | return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); |
| 83 | } |
| 84 | |
| 85 | void ConvertToLLVMPattern::getMemRefDescriptorSizes( |
| 86 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
| 87 | ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, |
| 88 | SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const { |
| 89 | assert(isConvertibleAndHasIdentityMaps(memRefType) && |
| 90 | "layout maps must have been normalized away" ); |
| 91 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
| 92 | static_cast<ssize_t>(dynamicSizes.size()) && |
| 93 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
| 94 | |
| 95 | sizes.reserve(N: memRefType.getRank()); |
| 96 | unsigned dynamicIndex = 0; |
| 97 | Type indexType = getIndexType(); |
| 98 | for (int64_t size : memRefType.getShape()) { |
| 99 | sizes.push_back( |
| 100 | size == ShapedType::kDynamic |
| 101 | ? dynamicSizes[dynamicIndex++] |
| 102 | : createIndexAttrConstant(rewriter, loc, indexType, size)); |
| 103 | } |
| 104 | |
| 105 | // Strides: iterate sizes in reverse order and multiply. |
| 106 | int64_t stride = 1; |
| 107 | Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1); |
| 108 | strides.resize(memRefType.getRank()); |
| 109 | for (auto i = memRefType.getRank(); i-- > 0;) { |
| 110 | strides[i] = runningStride; |
| 111 | |
| 112 | int64_t staticSize = memRefType.getShape()[i]; |
| 113 | bool useSizeAsStride = stride == 1; |
| 114 | if (staticSize == ShapedType::kDynamic) |
| 115 | stride = ShapedType::kDynamic; |
| 116 | if (stride != ShapedType::kDynamic) |
| 117 | stride *= staticSize; |
| 118 | |
| 119 | if (useSizeAsStride) |
| 120 | runningStride = sizes[i]; |
| 121 | else if (stride == ShapedType::kDynamic) |
| 122 | runningStride = |
| 123 | rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); |
| 124 | else |
| 125 | runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride); |
| 126 | } |
| 127 | if (sizeInBytes) { |
| 128 | // Buffer size in bytes. |
| 129 | Type elementType = typeConverter->convertType(memRefType.getElementType()); |
| 130 | auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 131 | Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); |
| 132 | Value gepPtr = rewriter.create<LLVM::GEPOp>( |
| 133 | loc, elementPtrType, elementType, nullPtr, runningStride); |
| 134 | size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); |
| 135 | } else { |
| 136 | size = runningStride; |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | Value ConvertToLLVMPattern::getSizeInBytes( |
| 141 | Location loc, Type type, ConversionPatternRewriter &rewriter) const { |
| 142 | // Compute the size of an individual element. This emits the MLIR equivalent |
| 143 | // of the following sizeof(...) implementation in LLVM IR: |
| 144 | // %0 = getelementptr %elementType* null, %indexType 1 |
| 145 | // %1 = ptrtoint %elementType* %0 to %indexType |
| 146 | // which is a common pattern of getting the size of a type in bytes. |
| 147 | Type llvmType = typeConverter->convertType(t: type); |
| 148 | auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 149 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType); |
| 150 | auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, |
| 151 | nullPtr, ArrayRef<LLVM::GEPArg>{1}); |
| 152 | return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); |
| 153 | } |
| 154 | |
| 155 | Value ConvertToLLVMPattern::getNumElements( |
| 156 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
| 157 | ConversionPatternRewriter &rewriter) const { |
| 158 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
| 159 | static_cast<ssize_t>(dynamicSizes.size()) && |
| 160 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
| 161 | |
| 162 | Type indexType = getIndexType(); |
| 163 | Value numElements = memRefType.getRank() == 0 |
| 164 | ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1) |
| 165 | : nullptr; |
| 166 | unsigned dynamicIndex = 0; |
| 167 | |
| 168 | // Compute the total number of memref elements. |
| 169 | for (int64_t staticSize : memRefType.getShape()) { |
| 170 | if (numElements) { |
| 171 | Value size = |
| 172 | staticSize == ShapedType::kDynamic |
| 173 | ? dynamicSizes[dynamicIndex++] |
| 174 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
| 175 | numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); |
| 176 | } else { |
| 177 | numElements = |
| 178 | staticSize == ShapedType::kDynamic |
| 179 | ? dynamicSizes[dynamicIndex++] |
| 180 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
| 181 | } |
| 182 | } |
| 183 | return numElements; |
| 184 | } |
| 185 | |
| 186 | /// Creates and populates the memref descriptor struct given all its fields. |
| 187 | MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( |
| 188 | Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, |
| 189 | ArrayRef<Value> sizes, ArrayRef<Value> strides, |
| 190 | ConversionPatternRewriter &rewriter) const { |
| 191 | auto structType = typeConverter->convertType(memRefType); |
| 192 | auto memRefDescriptor = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: structType); |
| 193 | |
| 194 | // Field 1: Allocated pointer, used for malloc/free. |
| 195 | memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); |
| 196 | |
| 197 | // Field 2: Actual aligned pointer to payload. |
| 198 | memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); |
| 199 | |
| 200 | // Field 3: Offset in aligned pointer. |
| 201 | Type indexType = getIndexType(); |
| 202 | memRefDescriptor.setOffset( |
| 203 | rewriter, loc, createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0)); |
| 204 | |
| 205 | // Fields 4: Sizes. |
| 206 | for (const auto &en : llvm::enumerate(First&: sizes)) |
| 207 | memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); |
| 208 | |
| 209 | // Field 5: Strides. |
| 210 | for (const auto &en : llvm::enumerate(First&: strides)) |
| 211 | memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); |
| 212 | |
| 213 | return memRefDescriptor; |
| 214 | } |
| 215 | |
| 216 | LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( |
| 217 | OpBuilder &builder, Location loc, TypeRange origTypes, |
| 218 | SmallVectorImpl<Value> &operands, bool toDynamic) const { |
| 219 | assert(origTypes.size() == operands.size() && |
| 220 | "expected as may original types as operands" ); |
| 221 | |
| 222 | // Find operands of unranked memref type and store them. |
| 223 | SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; |
| 224 | SmallVector<unsigned> unrankedAddressSpaces; |
| 225 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
| 226 | if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { |
| 227 | unrankedMemrefs.emplace_back(Args&: operands[i]); |
| 228 | FailureOr<unsigned> addressSpace = |
| 229 | getTypeConverter()->getMemRefAddressSpace(type: memRefType); |
| 230 | if (failed(Result: addressSpace)) |
| 231 | return failure(); |
| 232 | unrankedAddressSpaces.emplace_back(Args&: *addressSpace); |
| 233 | } |
| 234 | } |
| 235 | |
| 236 | if (unrankedMemrefs.empty()) |
| 237 | return success(); |
| 238 | |
| 239 | // Compute allocation sizes. |
| 240 | SmallVector<Value> sizes; |
| 241 | UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(), |
| 242 | values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces, |
| 243 | sizes); |
| 244 | |
| 245 | // Get frequently used types. |
| 246 | Type indexType = getTypeConverter()->getIndexType(); |
| 247 | |
| 248 | // Find the malloc and free, or declare them if necessary. |
| 249 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); |
| 250 | FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc; |
| 251 | if (toDynamic) { |
| 252 | mallocFunc = LLVM::lookupOrCreateMallocFn(b&: builder, moduleOp: module, indexType); |
| 253 | if (failed(Result: mallocFunc)) |
| 254 | return failure(); |
| 255 | } |
| 256 | if (!toDynamic) { |
| 257 | freeFunc = LLVM::lookupOrCreateFreeFn(b&: builder, moduleOp: module); |
| 258 | if (failed(freeFunc)) |
| 259 | return failure(); |
| 260 | } |
| 261 | |
| 262 | unsigned unrankedMemrefPos = 0; |
| 263 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
| 264 | Type type = origTypes[i]; |
| 265 | if (!isa<UnrankedMemRefType>(Val: type)) |
| 266 | continue; |
| 267 | Value allocationSize = sizes[unrankedMemrefPos++]; |
| 268 | UnrankedMemRefDescriptor desc(operands[i]); |
| 269 | |
| 270 | // Allocate memory, copy, and free the source if necessary. |
| 271 | Value memory = |
| 272 | toDynamic |
| 273 | ? builder |
| 274 | .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) |
| 275 | .getResult() |
| 276 | : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(), |
| 277 | IntegerType::get(getContext(), 8), |
| 278 | allocationSize, |
| 279 | /*alignment=*/0); |
| 280 | Value source = desc.memRefDescPtr(builder, loc); |
| 281 | builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false); |
| 282 | if (!toDynamic) |
| 283 | builder.create<LLVM::CallOp>(loc, freeFunc.value(), source); |
| 284 | |
| 285 | // Create a new descriptor. The same descriptor can be returned multiple |
| 286 | // times, attempting to modify its pointer can lead to memory leaks |
| 287 | // (allocated twice and overwritten) or double frees (the caller does not |
| 288 | // know if the descriptor points to the same memory). |
| 289 | Type descriptorType = getTypeConverter()->convertType(t: type); |
| 290 | if (!descriptorType) |
| 291 | return failure(); |
| 292 | auto updatedDesc = |
| 293 | UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); |
| 294 | Value rank = desc.rank(builder, loc); |
| 295 | updatedDesc.setRank(builder, loc, value: rank); |
| 296 | updatedDesc.setMemRefDescPtr(builder, loc, value: memory); |
| 297 | |
| 298 | operands[i] = updatedDesc; |
| 299 | } |
| 300 | |
| 301 | return success(); |
| 302 | } |
| 303 | |
| 304 | //===----------------------------------------------------------------------===// |
| 305 | // Detail methods |
| 306 | //===----------------------------------------------------------------------===// |
| 307 | |
| 308 | void LLVM::detail::setNativeProperties(Operation *op, |
| 309 | IntegerOverflowFlags overflowFlags) { |
| 310 | if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) |
| 311 | iface.setOverflowFlags(overflowFlags); |
| 312 | } |
| 313 | |
| 314 | /// Replaces the given operation "op" with a new operation of type "targetOp" |
| 315 | /// and given operands. |
| 316 | LogicalResult LLVM::detail::oneToOneRewrite( |
| 317 | Operation *op, StringRef targetOp, ValueRange operands, |
| 318 | ArrayRef<NamedAttribute> targetAttrs, |
| 319 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
| 320 | IntegerOverflowFlags overflowFlags) { |
| 321 | unsigned numResults = op->getNumResults(); |
| 322 | |
| 323 | SmallVector<Type> resultTypes; |
| 324 | if (numResults != 0) { |
| 325 | resultTypes.push_back( |
| 326 | Elt: typeConverter.packOperationResults(types: op->getResultTypes())); |
| 327 | if (!resultTypes.back()) |
| 328 | return failure(); |
| 329 | } |
| 330 | |
| 331 | // Create the operation through state since we don't know its C++ type. |
| 332 | Operation *newOp = |
| 333 | rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, |
| 334 | resultTypes, targetAttrs); |
| 335 | |
| 336 | setNativeProperties(newOp, overflowFlags); |
| 337 | |
| 338 | // If the operation produced 0 or 1 result, return them immediately. |
| 339 | if (numResults == 0) |
| 340 | return rewriter.eraseOp(op), success(); |
| 341 | if (numResults == 1) |
| 342 | return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success(); |
| 343 | |
| 344 | // Otherwise, it had been converted to an operation producing a structure. |
| 345 | // Extract individual results from the structure and return them as list. |
| 346 | SmallVector<Value, 4> results; |
| 347 | results.reserve(N: numResults); |
| 348 | for (unsigned i = 0; i < numResults; ++i) { |
| 349 | results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| 350 | op->getLoc(), newOp->getResult(0), i)); |
| 351 | } |
| 352 | rewriter.replaceOp(op, newValues: results); |
| 353 | return success(); |
| 354 | } |
| 355 | |
| 356 | LogicalResult LLVM::detail::intrinsicRewrite( |
| 357 | Operation *op, StringRef intrinsic, ValueRange operands, |
| 358 | const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { |
| 359 | auto loc = op->getLoc(); |
| 360 | |
| 361 | if (!llvm::all_of(Range&: operands, P: [](Value value) { |
| 362 | return LLVM::isCompatibleType(type: value.getType()); |
| 363 | })) |
| 364 | return failure(); |
| 365 | |
| 366 | unsigned numResults = op->getNumResults(); |
| 367 | Type resType; |
| 368 | if (numResults != 0) |
| 369 | resType = typeConverter.packOperationResults(types: op->getResultTypes()); |
| 370 | |
| 371 | auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>( |
| 372 | loc, resType, rewriter.getStringAttr(intrinsic), operands); |
| 373 | // Propagate attributes. |
| 374 | callIntrOp->setAttrs(op->getAttrDictionary()); |
| 375 | |
| 376 | if (numResults <= 1) { |
| 377 | // Directly replace the original op. |
| 378 | rewriter.replaceOp(op, callIntrOp); |
| 379 | return success(); |
| 380 | } |
| 381 | |
| 382 | // Extract individual results from packed structure and use them as |
| 383 | // replacements. |
| 384 | SmallVector<Value, 4> results; |
| 385 | results.reserve(N: numResults); |
| 386 | Value intrRes = callIntrOp.getResults(); |
| 387 | for (unsigned i = 0; i < numResults; ++i) |
| 388 | results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i)); |
| 389 | rewriter.replaceOp(op, newValues: results); |
| 390 | |
| 391 | return success(); |
| 392 | } |
| 393 | |
| 394 | static unsigned getBitWidth(Type type) { |
| 395 | if (type.isIntOrFloat()) |
| 396 | return type.getIntOrFloatBitWidth(); |
| 397 | |
| 398 | auto vec = cast<VectorType>(type); |
| 399 | assert(!vec.isScalable() && "scalable vectors are not supported" ); |
| 400 | return vec.getNumElements() * getBitWidth(vec.getElementType()); |
| 401 | } |
| 402 | |
| 403 | static Value createI32Constant(OpBuilder &builder, Location loc, |
| 404 | int32_t value) { |
| 405 | Type i32 = builder.getI32Type(); |
| 406 | return builder.create<LLVM::ConstantOp>(loc, i32, value); |
| 407 | } |
| 408 | |
| 409 | SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, |
| 410 | Value src, Type dstType) { |
| 411 | Type srcType = src.getType(); |
| 412 | if (srcType == dstType) |
| 413 | return {src}; |
| 414 | |
| 415 | unsigned srcBitWidth = getBitWidth(type: srcType); |
| 416 | unsigned dstBitWidth = getBitWidth(type: dstType); |
| 417 | if (srcBitWidth == dstBitWidth) { |
| 418 | Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src); |
| 419 | return {cast}; |
| 420 | } |
| 421 | |
| 422 | if (dstBitWidth > srcBitWidth) { |
| 423 | auto smallerInt = builder.getIntegerType(srcBitWidth); |
| 424 | if (srcType != smallerInt) |
| 425 | src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src); |
| 426 | |
| 427 | auto largerInt = builder.getIntegerType(dstBitWidth); |
| 428 | Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src); |
| 429 | return {res}; |
| 430 | } |
| 431 | assert(srcBitWidth % dstBitWidth == 0 && |
| 432 | "src bit width must be a multiple of dst bit width" ); |
| 433 | int64_t numElements = srcBitWidth / dstBitWidth; |
| 434 | auto vecType = VectorType::get(numElements, dstType); |
| 435 | |
| 436 | src = builder.create<LLVM::BitcastOp>(loc, vecType, src); |
| 437 | |
| 438 | SmallVector<Value> res; |
| 439 | for (auto i : llvm::seq(Size: numElements)) { |
| 440 | Value idx = createI32Constant(builder, loc, value: i); |
| 441 | Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx); |
| 442 | res.emplace_back(Args&: elem); |
| 443 | } |
| 444 | |
| 445 | return res; |
| 446 | } |
| 447 | |
| 448 | Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, |
| 449 | Type dstType) { |
| 450 | assert(!src.empty() && "src range must not be empty" ); |
| 451 | if (src.size() == 1) { |
| 452 | Value res = src.front(); |
| 453 | if (res.getType() == dstType) |
| 454 | return res; |
| 455 | |
| 456 | unsigned srcBitWidth = getBitWidth(type: res.getType()); |
| 457 | unsigned dstBitWidth = getBitWidth(type: dstType); |
| 458 | if (dstBitWidth < srcBitWidth) { |
| 459 | auto largerInt = builder.getIntegerType(srcBitWidth); |
| 460 | if (res.getType() != largerInt) |
| 461 | res = builder.create<LLVM::BitcastOp>(loc, largerInt, res); |
| 462 | |
| 463 | auto smallerInt = builder.getIntegerType(dstBitWidth); |
| 464 | res = builder.create<LLVM::TruncOp>(loc, smallerInt, res); |
| 465 | } |
| 466 | |
| 467 | if (res.getType() != dstType) |
| 468 | res = builder.create<LLVM::BitcastOp>(loc, dstType, res); |
| 469 | |
| 470 | return res; |
| 471 | } |
| 472 | |
| 473 | int64_t numElements = src.size(); |
| 474 | auto srcType = VectorType::get(numElements, src.front().getType()); |
| 475 | Value res = builder.create<LLVM::PoisonOp>(loc, srcType); |
| 476 | for (auto &&[i, elem] : llvm::enumerate(First&: src)) { |
| 477 | Value idx = createI32Constant(builder, loc, value: i); |
| 478 | res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx); |
| 479 | } |
| 480 | |
| 481 | if (res.getType() != dstType) |
| 482 | res = builder.create<LLVM::BitcastOp>(loc, dstType, res); |
| 483 | |
| 484 | return res; |
| 485 | } |
| 486 | |
| 487 | Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, |
| 488 | const LLVMTypeConverter &converter, |
| 489 | MemRefType type, Value memRefDesc, |
| 490 | ValueRange indices, |
| 491 | LLVM::GEPNoWrapFlags noWrapFlags) { |
| 492 | auto [strides, offset] = type.getStridesAndOffset(); |
| 493 | |
| 494 | MemRefDescriptor memRefDescriptor(memRefDesc); |
| 495 | // Use a canonical representation of the start address so that later |
| 496 | // optimizations have a longer sequence of instructions to CSE. |
| 497 | // If we don't do that we would sprinkle the memref.offset in various |
| 498 | // position of the different address computations. |
| 499 | Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type: type); |
| 500 | |
| 501 | LLVM::IntegerOverflowFlags intOverflowFlags = |
| 502 | LLVM::IntegerOverflowFlags::none; |
| 503 | if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { |
| 504 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; |
| 505 | } |
| 506 | if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { |
| 507 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; |
| 508 | } |
| 509 | |
| 510 | Type indexType = converter.getIndexType(); |
| 511 | Value index; |
| 512 | for (int i = 0, e = indices.size(); i < e; ++i) { |
| 513 | Value increment = indices[i]; |
| 514 | if (strides[i] != 1) { // Skip if stride is 1. |
| 515 | Value stride = |
| 516 | ShapedType::isDynamic(strides[i]) |
| 517 | ? memRefDescriptor.stride(builder, loc, i) |
| 518 | : builder.create<LLVM::ConstantOp>( |
| 519 | loc, indexType, builder.getIndexAttr(strides[i])); |
| 520 | increment = |
| 521 | builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags); |
| 522 | } |
| 523 | index = index ? builder.create<LLVM::AddOp>(loc, index, increment, |
| 524 | intOverflowFlags) |
| 525 | : increment; |
| 526 | } |
| 527 | |
| 528 | Type elementPtrType = memRefDescriptor.getElementPtrType(); |
| 529 | return index ? builder.create<LLVM::GEPOp>( |
| 530 | loc, elementPtrType, |
| 531 | converter.convertType(type.getElementType()), base, index, |
| 532 | noWrapFlags) |
| 533 | : base; |
| 534 | } |
| 535 | |