//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;
| 17 | |
//===----------------------------------------------------------------------===//
// ConvertToLLVMPattern
//===----------------------------------------------------------------------===//
| 21 | |
| 22 | ConvertToLLVMPattern::ConvertToLLVMPattern( |
| 23 | StringRef rootOpName, MLIRContext *context, |
| 24 | const LLVMTypeConverter &typeConverter, PatternBenefit benefit) |
| 25 | : ConversionPattern(typeConverter, rootOpName, benefit, context) {} |
| 26 | |
| 27 | const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { |
| 28 | return static_cast<const LLVMTypeConverter *>( |
| 29 | ConversionPattern::getTypeConverter()); |
| 30 | } |
| 31 | |
| 32 | LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { |
| 33 | return *getTypeConverter()->getDialect(); |
| 34 | } |
| 35 | |
| 36 | Type ConvertToLLVMPattern::getIndexType() const { |
| 37 | return getTypeConverter()->getIndexType(); |
| 38 | } |
| 39 | |
| 40 | Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { |
| 41 | return IntegerType::get(context: &getTypeConverter()->getContext(), |
| 42 | width: getTypeConverter()->getPointerBitwidth(addressSpace)); |
| 43 | } |
| 44 | |
| 45 | Type ConvertToLLVMPattern::getVoidType() const { |
| 46 | return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext()); |
| 47 | } |
| 48 | |
| 49 | Type ConvertToLLVMPattern::getPtrType(unsigned addressSpace) const { |
| 50 | return LLVM::LLVMPointerType::get(context: &getTypeConverter()->getContext(), |
| 51 | addressSpace); |
| 52 | } |
| 53 | |
| 54 | Type ConvertToLLVMPattern::getVoidPtrType() const { return getPtrType(); } |
| 55 | |
| 56 | Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, |
| 57 | Location loc, |
| 58 | Type resultType, |
| 59 | int64_t value) { |
| 60 | return builder.create<LLVM::ConstantOp>(location: loc, args&: resultType, |
| 61 | args: builder.getIndexAttr(value)); |
| 62 | } |
| 63 | |
| 64 | Value ConvertToLLVMPattern::getStridedElementPtr( |
| 65 | ConversionPatternRewriter &rewriter, Location loc, MemRefType type, |
| 66 | Value memRefDesc, ValueRange indices, |
| 67 | LLVM::GEPNoWrapFlags noWrapFlags) const { |
| 68 | return LLVM::getStridedElementPtr(builder&: rewriter, loc, converter: *getTypeConverter(), type, |
| 69 | memRefDesc, indices, noWrapFlags); |
| 70 | } |
| 71 | |
| 72 | // Check if the MemRefType `type` is supported by the lowering. We currently |
| 73 | // only support memrefs with identity maps. |
| 74 | bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( |
| 75 | MemRefType type) const { |
| 76 | if (!type.getLayout().isIdentity()) |
| 77 | return false; |
| 78 | return static_cast<bool>(typeConverter->convertType(t: type)); |
| 79 | } |
| 80 | |
| 81 | Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { |
| 82 | auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type); |
| 83 | if (failed(Result: addressSpace)) |
| 84 | return {}; |
| 85 | return LLVM::LLVMPointerType::get(context: type.getContext(), addressSpace: *addressSpace); |
| 86 | } |
| 87 | |
| 88 | void ConvertToLLVMPattern::getMemRefDescriptorSizes( |
| 89 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
| 90 | ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, |
| 91 | SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const { |
| 92 | assert(isConvertibleAndHasIdentityMaps(memRefType) && |
| 93 | "layout maps must have been normalized away" ); |
| 94 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
| 95 | static_cast<ssize_t>(dynamicSizes.size()) && |
| 96 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
| 97 | |
| 98 | sizes.reserve(N: memRefType.getRank()); |
| 99 | unsigned dynamicIndex = 0; |
| 100 | Type indexType = getIndexType(); |
| 101 | for (int64_t size : memRefType.getShape()) { |
| 102 | sizes.push_back( |
| 103 | Elt: size == ShapedType::kDynamic |
| 104 | ? dynamicSizes[dynamicIndex++] |
| 105 | : createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: size)); |
| 106 | } |
| 107 | |
| 108 | // Strides: iterate sizes in reverse order and multiply. |
| 109 | int64_t stride = 1; |
| 110 | Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1); |
| 111 | strides.resize(N: memRefType.getRank()); |
| 112 | for (auto i = memRefType.getRank(); i-- > 0;) { |
| 113 | strides[i] = runningStride; |
| 114 | |
| 115 | int64_t staticSize = memRefType.getShape()[i]; |
| 116 | bool useSizeAsStride = stride == 1; |
| 117 | if (staticSize == ShapedType::kDynamic) |
| 118 | stride = ShapedType::kDynamic; |
| 119 | if (stride != ShapedType::kDynamic) |
| 120 | stride *= staticSize; |
| 121 | |
| 122 | if (useSizeAsStride) |
| 123 | runningStride = sizes[i]; |
| 124 | else if (stride == ShapedType::kDynamic) |
| 125 | runningStride = |
| 126 | rewriter.create<LLVM::MulOp>(location: loc, args&: runningStride, args&: sizes[i]); |
| 127 | else |
| 128 | runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride); |
| 129 | } |
| 130 | if (sizeInBytes) { |
| 131 | // Buffer size in bytes. |
| 132 | Type elementType = typeConverter->convertType(t: memRefType.getElementType()); |
| 133 | auto elementPtrType = LLVM::LLVMPointerType::get(context: rewriter.getContext()); |
| 134 | Value nullPtr = rewriter.create<LLVM::ZeroOp>(location: loc, args&: elementPtrType); |
| 135 | Value gepPtr = rewriter.create<LLVM::GEPOp>( |
| 136 | location: loc, args&: elementPtrType, args&: elementType, args&: nullPtr, args&: runningStride); |
| 137 | size = rewriter.create<LLVM::PtrToIntOp>(location: loc, args: getIndexType(), args&: gepPtr); |
| 138 | } else { |
| 139 | size = runningStride; |
| 140 | } |
| 141 | } |
| 142 | |
| 143 | Value ConvertToLLVMPattern::getSizeInBytes( |
| 144 | Location loc, Type type, ConversionPatternRewriter &rewriter) const { |
| 145 | // Compute the size of an individual element. This emits the MLIR equivalent |
| 146 | // of the following sizeof(...) implementation in LLVM IR: |
| 147 | // %0 = getelementptr %elementType* null, %indexType 1 |
| 148 | // %1 = ptrtoint %elementType* %0 to %indexType |
| 149 | // which is a common pattern of getting the size of a type in bytes. |
| 150 | Type llvmType = typeConverter->convertType(t: type); |
| 151 | auto convertedPtrType = LLVM::LLVMPointerType::get(context: rewriter.getContext()); |
| 152 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(location: loc, args&: convertedPtrType); |
| 153 | auto gep = rewriter.create<LLVM::GEPOp>(location: loc, args&: convertedPtrType, args&: llvmType, |
| 154 | args&: nullPtr, args: ArrayRef<LLVM::GEPArg>{1}); |
| 155 | return rewriter.create<LLVM::PtrToIntOp>(location: loc, args: getIndexType(), args&: gep); |
| 156 | } |
| 157 | |
| 158 | Value ConvertToLLVMPattern::getNumElements( |
| 159 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
| 160 | ConversionPatternRewriter &rewriter) const { |
| 161 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
| 162 | static_cast<ssize_t>(dynamicSizes.size()) && |
| 163 | "dynamicSizes size doesn't match dynamic sizes count in memref shape" ); |
| 164 | |
| 165 | Type indexType = getIndexType(); |
| 166 | Value numElements = memRefType.getRank() == 0 |
| 167 | ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1) |
| 168 | : nullptr; |
| 169 | unsigned dynamicIndex = 0; |
| 170 | |
| 171 | // Compute the total number of memref elements. |
| 172 | for (int64_t staticSize : memRefType.getShape()) { |
| 173 | if (numElements) { |
| 174 | Value size = |
| 175 | staticSize == ShapedType::kDynamic |
| 176 | ? dynamicSizes[dynamicIndex++] |
| 177 | : createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: staticSize); |
| 178 | numElements = rewriter.create<LLVM::MulOp>(location: loc, args&: numElements, args&: size); |
| 179 | } else { |
| 180 | numElements = |
| 181 | staticSize == ShapedType::kDynamic |
| 182 | ? dynamicSizes[dynamicIndex++] |
| 183 | : createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: staticSize); |
| 184 | } |
| 185 | } |
| 186 | return numElements; |
| 187 | } |
| 188 | |
| 189 | /// Creates and populates the memref descriptor struct given all its fields. |
| 190 | MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( |
| 191 | Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, |
| 192 | ArrayRef<Value> sizes, ArrayRef<Value> strides, |
| 193 | ConversionPatternRewriter &rewriter) const { |
| 194 | auto structType = typeConverter->convertType(t: memRefType); |
| 195 | auto memRefDescriptor = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: structType); |
| 196 | |
| 197 | // Field 1: Allocated pointer, used for malloc/free. |
| 198 | memRefDescriptor.setAllocatedPtr(builder&: rewriter, loc, ptr: allocatedPtr); |
| 199 | |
| 200 | // Field 2: Actual aligned pointer to payload. |
| 201 | memRefDescriptor.setAlignedPtr(builder&: rewriter, loc, ptr: alignedPtr); |
| 202 | |
| 203 | // Field 3: Offset in aligned pointer. |
| 204 | Type indexType = getIndexType(); |
| 205 | memRefDescriptor.setOffset( |
| 206 | builder&: rewriter, loc, offset: createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0)); |
| 207 | |
| 208 | // Fields 4: Sizes. |
| 209 | for (const auto &en : llvm::enumerate(First&: sizes)) |
| 210 | memRefDescriptor.setSize(builder&: rewriter, loc, pos: en.index(), size: en.value()); |
| 211 | |
| 212 | // Field 5: Strides. |
| 213 | for (const auto &en : llvm::enumerate(First&: strides)) |
| 214 | memRefDescriptor.setStride(builder&: rewriter, loc, pos: en.index(), stride: en.value()); |
| 215 | |
| 216 | return memRefDescriptor; |
| 217 | } |
| 218 | |
| 219 | LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( |
| 220 | OpBuilder &builder, Location loc, TypeRange origTypes, |
| 221 | SmallVectorImpl<Value> &operands, bool toDynamic) const { |
| 222 | assert(origTypes.size() == operands.size() && |
| 223 | "expected as may original types as operands" ); |
| 224 | |
| 225 | // Find operands of unranked memref type and store them. |
| 226 | SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; |
| 227 | SmallVector<unsigned> unrankedAddressSpaces; |
| 228 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
| 229 | if (auto memRefType = dyn_cast<UnrankedMemRefType>(Val: origTypes[i])) { |
| 230 | unrankedMemrefs.emplace_back(Args&: operands[i]); |
| 231 | FailureOr<unsigned> addressSpace = |
| 232 | getTypeConverter()->getMemRefAddressSpace(type: memRefType); |
| 233 | if (failed(Result: addressSpace)) |
| 234 | return failure(); |
| 235 | unrankedAddressSpaces.emplace_back(Args&: *addressSpace); |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | if (unrankedMemrefs.empty()) |
| 240 | return success(); |
| 241 | |
| 242 | // Compute allocation sizes. |
| 243 | SmallVector<Value> sizes; |
| 244 | UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(), |
| 245 | values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces, |
| 246 | sizes); |
| 247 | |
| 248 | // Get frequently used types. |
| 249 | Type indexType = getTypeConverter()->getIndexType(); |
| 250 | |
| 251 | // Find the malloc and free, or declare them if necessary. |
| 252 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); |
| 253 | FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc; |
| 254 | if (toDynamic) { |
| 255 | mallocFunc = LLVM::lookupOrCreateMallocFn(b&: builder, moduleOp: module, indexType); |
| 256 | if (failed(Result: mallocFunc)) |
| 257 | return failure(); |
| 258 | } |
| 259 | if (!toDynamic) { |
| 260 | freeFunc = LLVM::lookupOrCreateFreeFn(b&: builder, moduleOp: module); |
| 261 | if (failed(Result: freeFunc)) |
| 262 | return failure(); |
| 263 | } |
| 264 | |
| 265 | unsigned unrankedMemrefPos = 0; |
| 266 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
| 267 | Type type = origTypes[i]; |
| 268 | if (!isa<UnrankedMemRefType>(Val: type)) |
| 269 | continue; |
| 270 | Value allocationSize = sizes[unrankedMemrefPos++]; |
| 271 | UnrankedMemRefDescriptor desc(operands[i]); |
| 272 | |
| 273 | // Allocate memory, copy, and free the source if necessary. |
| 274 | Value memory = |
| 275 | toDynamic |
| 276 | ? builder |
| 277 | .create<LLVM::CallOp>(location: loc, args&: mallocFunc.value(), args&: allocationSize) |
| 278 | .getResult() |
| 279 | : builder.create<LLVM::AllocaOp>(location: loc, args: getPtrType(), |
| 280 | args: IntegerType::get(context: getContext(), width: 8), |
| 281 | args&: allocationSize, |
| 282 | /*alignment=*/args: 0); |
| 283 | Value source = desc.memRefDescPtr(builder, loc); |
| 284 | builder.create<LLVM::MemcpyOp>(location: loc, args&: memory, args&: source, args&: allocationSize, args: false); |
| 285 | if (!toDynamic) |
| 286 | builder.create<LLVM::CallOp>(location: loc, args&: freeFunc.value(), args&: source); |
| 287 | |
| 288 | // Create a new descriptor. The same descriptor can be returned multiple |
| 289 | // times, attempting to modify its pointer can lead to memory leaks |
| 290 | // (allocated twice and overwritten) or double frees (the caller does not |
| 291 | // know if the descriptor points to the same memory). |
| 292 | Type descriptorType = getTypeConverter()->convertType(t: type); |
| 293 | if (!descriptorType) |
| 294 | return failure(); |
| 295 | auto updatedDesc = |
| 296 | UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); |
| 297 | Value rank = desc.rank(builder, loc); |
| 298 | updatedDesc.setRank(builder, loc, value: rank); |
| 299 | updatedDesc.setMemRefDescPtr(builder, loc, value: memory); |
| 300 | |
| 301 | operands[i] = updatedDesc; |
| 302 | } |
| 303 | |
| 304 | return success(); |
| 305 | } |
| 306 | |
//===----------------------------------------------------------------------===//
// Detail methods
//===----------------------------------------------------------------------===//
| 310 | |
| 311 | void LLVM::detail::setNativeProperties(Operation *op, |
| 312 | IntegerOverflowFlags overflowFlags) { |
| 313 | if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(Val: op)) |
| 314 | iface.setOverflowFlags(overflowFlags); |
| 315 | } |
| 316 | |
| 317 | /// Replaces the given operation "op" with a new operation of type "targetOp" |
| 318 | /// and given operands. |
| 319 | LogicalResult LLVM::detail::oneToOneRewrite( |
| 320 | Operation *op, StringRef targetOp, ValueRange operands, |
| 321 | ArrayRef<NamedAttribute> targetAttrs, |
| 322 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
| 323 | IntegerOverflowFlags overflowFlags) { |
| 324 | unsigned numResults = op->getNumResults(); |
| 325 | |
| 326 | SmallVector<Type> resultTypes; |
| 327 | if (numResults != 0) { |
| 328 | resultTypes.push_back( |
| 329 | Elt: typeConverter.packOperationResults(types: op->getResultTypes())); |
| 330 | if (!resultTypes.back()) |
| 331 | return failure(); |
| 332 | } |
| 333 | |
| 334 | // Create the operation through state since we don't know its C++ type. |
| 335 | Operation *newOp = |
| 336 | rewriter.create(loc: op->getLoc(), opName: rewriter.getStringAttr(bytes: targetOp), operands, |
| 337 | types: resultTypes, attributes: targetAttrs); |
| 338 | |
| 339 | setNativeProperties(op: newOp, overflowFlags); |
| 340 | |
| 341 | // If the operation produced 0 or 1 result, return them immediately. |
| 342 | if (numResults == 0) |
| 343 | return rewriter.eraseOp(op), success(); |
| 344 | if (numResults == 1) |
| 345 | return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success(); |
| 346 | |
| 347 | // Otherwise, it had been converted to an operation producing a structure. |
| 348 | // Extract individual results from the structure and return them as list. |
| 349 | SmallVector<Value, 4> results; |
| 350 | results.reserve(N: numResults); |
| 351 | for (unsigned i = 0; i < numResults; ++i) { |
| 352 | results.push_back(Elt: rewriter.create<LLVM::ExtractValueOp>( |
| 353 | location: op->getLoc(), args: newOp->getResult(idx: 0), args&: i)); |
| 354 | } |
| 355 | rewriter.replaceOp(op, newValues: results); |
| 356 | return success(); |
| 357 | } |
| 358 | |
| 359 | LogicalResult LLVM::detail::intrinsicRewrite( |
| 360 | Operation *op, StringRef intrinsic, ValueRange operands, |
| 361 | const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { |
| 362 | auto loc = op->getLoc(); |
| 363 | |
| 364 | if (!llvm::all_of(Range&: operands, P: [](Value value) { |
| 365 | return LLVM::isCompatibleType(type: value.getType()); |
| 366 | })) |
| 367 | return failure(); |
| 368 | |
| 369 | unsigned numResults = op->getNumResults(); |
| 370 | Type resType; |
| 371 | if (numResults != 0) |
| 372 | resType = typeConverter.packOperationResults(types: op->getResultTypes()); |
| 373 | |
| 374 | auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>( |
| 375 | location: loc, args&: resType, args: rewriter.getStringAttr(bytes: intrinsic), args&: operands); |
| 376 | // Propagate attributes. |
| 377 | callIntrOp->setAttrs(op->getAttrDictionary()); |
| 378 | |
| 379 | if (numResults <= 1) { |
| 380 | // Directly replace the original op. |
| 381 | rewriter.replaceOp(op, newOp: callIntrOp); |
| 382 | return success(); |
| 383 | } |
| 384 | |
| 385 | // Extract individual results from packed structure and use them as |
| 386 | // replacements. |
| 387 | SmallVector<Value, 4> results; |
| 388 | results.reserve(N: numResults); |
| 389 | Value intrRes = callIntrOp.getResults(); |
| 390 | for (unsigned i = 0; i < numResults; ++i) |
| 391 | results.push_back(Elt: rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: intrRes, args&: i)); |
| 392 | rewriter.replaceOp(op, newValues: results); |
| 393 | |
| 394 | return success(); |
| 395 | } |
| 396 | |
| 397 | static unsigned getBitWidth(Type type) { |
| 398 | if (type.isIntOrFloat()) |
| 399 | return type.getIntOrFloatBitWidth(); |
| 400 | |
| 401 | auto vec = cast<VectorType>(Val&: type); |
| 402 | assert(!vec.isScalable() && "scalable vectors are not supported" ); |
| 403 | return vec.getNumElements() * getBitWidth(type: vec.getElementType()); |
| 404 | } |
| 405 | |
| 406 | static Value createI32Constant(OpBuilder &builder, Location loc, |
| 407 | int32_t value) { |
| 408 | Type i32 = builder.getI32Type(); |
| 409 | return builder.create<LLVM::ConstantOp>(location: loc, args&: i32, args&: value); |
| 410 | } |
| 411 | |
| 412 | SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, |
| 413 | Value src, Type dstType) { |
| 414 | Type srcType = src.getType(); |
| 415 | if (srcType == dstType) |
| 416 | return {src}; |
| 417 | |
| 418 | unsigned srcBitWidth = getBitWidth(type: srcType); |
| 419 | unsigned dstBitWidth = getBitWidth(type: dstType); |
| 420 | if (srcBitWidth == dstBitWidth) { |
| 421 | Value cast = builder.create<LLVM::BitcastOp>(location: loc, args&: dstType, args&: src); |
| 422 | return {cast}; |
| 423 | } |
| 424 | |
| 425 | if (dstBitWidth > srcBitWidth) { |
| 426 | auto smallerInt = builder.getIntegerType(width: srcBitWidth); |
| 427 | if (srcType != smallerInt) |
| 428 | src = builder.create<LLVM::BitcastOp>(location: loc, args&: smallerInt, args&: src); |
| 429 | |
| 430 | auto largerInt = builder.getIntegerType(width: dstBitWidth); |
| 431 | Value res = builder.create<LLVM::ZExtOp>(location: loc, args&: largerInt, args&: src); |
| 432 | return {res}; |
| 433 | } |
| 434 | assert(srcBitWidth % dstBitWidth == 0 && |
| 435 | "src bit width must be a multiple of dst bit width" ); |
| 436 | int64_t numElements = srcBitWidth / dstBitWidth; |
| 437 | auto vecType = VectorType::get(shape: numElements, elementType: dstType); |
| 438 | |
| 439 | src = builder.create<LLVM::BitcastOp>(location: loc, args&: vecType, args&: src); |
| 440 | |
| 441 | SmallVector<Value> res; |
| 442 | for (auto i : llvm::seq(Size: numElements)) { |
| 443 | Value idx = createI32Constant(builder, loc, value: i); |
| 444 | Value elem = builder.create<LLVM::ExtractElementOp>(location: loc, args&: src, args&: idx); |
| 445 | res.emplace_back(Args&: elem); |
| 446 | } |
| 447 | |
| 448 | return res; |
| 449 | } |
| 450 | |
| 451 | Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, |
| 452 | Type dstType) { |
| 453 | assert(!src.empty() && "src range must not be empty" ); |
| 454 | if (src.size() == 1) { |
| 455 | Value res = src.front(); |
| 456 | if (res.getType() == dstType) |
| 457 | return res; |
| 458 | |
| 459 | unsigned srcBitWidth = getBitWidth(type: res.getType()); |
| 460 | unsigned dstBitWidth = getBitWidth(type: dstType); |
| 461 | if (dstBitWidth < srcBitWidth) { |
| 462 | auto largerInt = builder.getIntegerType(width: srcBitWidth); |
| 463 | if (res.getType() != largerInt) |
| 464 | res = builder.create<LLVM::BitcastOp>(location: loc, args&: largerInt, args&: res); |
| 465 | |
| 466 | auto smallerInt = builder.getIntegerType(width: dstBitWidth); |
| 467 | res = builder.create<LLVM::TruncOp>(location: loc, args&: smallerInt, args&: res); |
| 468 | } |
| 469 | |
| 470 | if (res.getType() != dstType) |
| 471 | res = builder.create<LLVM::BitcastOp>(location: loc, args&: dstType, args&: res); |
| 472 | |
| 473 | return res; |
| 474 | } |
| 475 | |
| 476 | int64_t numElements = src.size(); |
| 477 | auto srcType = VectorType::get(shape: numElements, elementType: src.front().getType()); |
| 478 | Value res = builder.create<LLVM::PoisonOp>(location: loc, args&: srcType); |
| 479 | for (auto &&[i, elem] : llvm::enumerate(First&: src)) { |
| 480 | Value idx = createI32Constant(builder, loc, value: i); |
| 481 | res = builder.create<LLVM::InsertElementOp>(location: loc, args&: srcType, args&: res, args&: elem, args&: idx); |
| 482 | } |
| 483 | |
| 484 | if (res.getType() != dstType) |
| 485 | res = builder.create<LLVM::BitcastOp>(location: loc, args&: dstType, args&: res); |
| 486 | |
| 487 | return res; |
| 488 | } |
| 489 | |
| 490 | Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, |
| 491 | const LLVMTypeConverter &converter, |
| 492 | MemRefType type, Value memRefDesc, |
| 493 | ValueRange indices, |
| 494 | LLVM::GEPNoWrapFlags noWrapFlags) { |
| 495 | auto [strides, offset] = type.getStridesAndOffset(); |
| 496 | |
| 497 | MemRefDescriptor memRefDescriptor(memRefDesc); |
| 498 | // Use a canonical representation of the start address so that later |
| 499 | // optimizations have a longer sequence of instructions to CSE. |
| 500 | // If we don't do that we would sprinkle the memref.offset in various |
| 501 | // position of the different address computations. |
| 502 | Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type); |
| 503 | |
| 504 | LLVM::IntegerOverflowFlags intOverflowFlags = |
| 505 | LLVM::IntegerOverflowFlags::none; |
| 506 | if (LLVM::bitEnumContainsAny(bits: noWrapFlags, bit: LLVM::GEPNoWrapFlags::nusw)) { |
| 507 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; |
| 508 | } |
| 509 | if (LLVM::bitEnumContainsAny(bits: noWrapFlags, bit: LLVM::GEPNoWrapFlags::nuw)) { |
| 510 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; |
| 511 | } |
| 512 | |
| 513 | Type indexType = converter.getIndexType(); |
| 514 | Value index; |
| 515 | for (int i = 0, e = indices.size(); i < e; ++i) { |
| 516 | Value increment = indices[i]; |
| 517 | if (strides[i] != 1) { // Skip if stride is 1. |
| 518 | Value stride = |
| 519 | ShapedType::isDynamic(dValue: strides[i]) |
| 520 | ? memRefDescriptor.stride(builder, loc, pos: i) |
| 521 | : builder.create<LLVM::ConstantOp>( |
| 522 | location: loc, args&: indexType, args: builder.getIndexAttr(value: strides[i])); |
| 523 | increment = |
| 524 | builder.create<LLVM::MulOp>(location: loc, args&: increment, args&: stride, args&: intOverflowFlags); |
| 525 | } |
| 526 | index = index ? builder.create<LLVM::AddOp>(location: loc, args&: index, args&: increment, |
| 527 | args&: intOverflowFlags) |
| 528 | : increment; |
| 529 | } |
| 530 | |
| 531 | Type elementPtrType = memRefDescriptor.getElementPtrType(); |
| 532 | return index ? builder.create<LLVM::GEPOp>( |
| 533 | location: loc, args&: elementPtrType, |
| 534 | args: converter.convertType(t: type.getElementType()), args&: base, args&: index, |
| 535 | args&: noWrapFlags) |
| 536 | : base; |
| 537 | } |
| 538 | |