//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"
#include <memory>
#include <mutex>
#include <optional>

using namespace mlir;

SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
  {
    // Most of the time, the entry already exists in the map.
    std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
                                                    std::defer_lock);
    if (getContext().isMultithreadingEnabled())
      lock.lock();
    auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
    if (recursiveStack != conversionCallStack.end())
      return *recursiveStack->second;
  }

  // First time this thread gets here: take exclusive access to insert the
  // entry into the map.
  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
      llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
  return *recursiveStackInserted.first->second;
}

/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}

/// Helper function that checks if the given value range is a bare pointer.
static bool isBarePointer(ValueRange values) {
  return values.size() == 1 &&
         isa<LLVM::LLVMPointerType>(values.front().getType());
}

/// Pack SSA values into an unranked memref descriptor struct.
static Value packUnrankedMemRefDesc(OpBuilder &builder,
                                    UnrankedMemRefType resultType,
                                    ValueRange inputs, Location loc,
                                    const LLVMTypeConverter &converter) {
  // Note: Bare pointers are not supported for unranked memrefs because a
  // memref descriptor cannot be built just from a bare pointer.
  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
    return Value();
  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
                                        inputs);
}

/// Pack SSA values into a ranked memref descriptor struct.
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc,
                                  const LLVMTypeConverter &converter) {
  assert(resultType && "expected non-null result type");
  if (isBarePointer(inputs))
    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
                                             resultType, inputs[0]);
  if (TypeRange(inputs) ==
      converter.getMemRefDescriptorFields(resultType,
                                          /*unpackAggregates=*/true))
    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
  // This materialization function cannot be used.
  return Value();
}

/// MemRef descriptor elements -> UnrankedMemRefType
static Value unrankedMemRefMaterialization(OpBuilder &builder,
                                           UnrankedMemRefType resultType,
                                           ValueRange inputs, Location loc,
                                           const LLVMTypeConverter &converter) {
  // A source materialization must return a value of type
  // `resultType`, so insert a cast from the memref descriptor type
  // (!llvm.struct) to the original memref type.
  Value packed =
      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
  if (!packed)
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
      .getResult(0);
}

/// MemRef descriptor elements -> MemRefType
static Value rankedMemRefMaterialization(OpBuilder &builder,
                                         MemRefType resultType,
                                         ValueRange inputs, Location loc,
                                         const LLVMTypeConverter &converter) {
  // A source materialization must return a value of type `resultType`,
  // so insert a cast from the memref descriptor type (!llvm.struct) to the
  // original memref type.
  Value packed =
      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
  if (!packed)
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
      .getResult(0);
}

/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const LowerToLLVMOptions &options,
                                     const DataLayoutAnalysis *analysis)
    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
      dataLayoutAnalysis(analysis) {
  assert(llvmDialect && "LLVM IR dialect is not registered");

  // Register conversions for the builtin types.
  addConversion([&](ComplexType type) { return convertComplexType(type); });
  addConversion([&](FloatType type) { return convertFloatType(type); });
  addConversion([&](FunctionType type) { return convertFunctionType(type); });
  addConversion([&](IndexType type) { return convertIndexType(type); });
  addConversion([&](IntegerType type) { return convertIntegerType(type); });
  addConversion([&](MemRefType type) { return convertMemRefType(type); });
  addConversion(
      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
  addConversion([&](VectorType type) -> std::optional<Type> {
    FailureOr<Type> llvmType = convertVectorType(type);
    if (failed(llvmType))
      return std::nullopt;
    return llvmType;
  });

  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
  // before the conversions below since conversions are attempted in reverse
  // order and those should take priority.
  addConversion([](Type type) {
    return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
                                        : std::nullopt;
  });

  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
                    -> std::optional<LogicalResult> {
    // Fastpath for types that won't be converted by this callback anyway.
    if (LLVM::isCompatibleType(type)) {
      results.push_back(type);
      return success();
    }

    if (type.isIdentified()) {
      auto convertedType = LLVM::LLVMStructType::getIdentified(
          type.getContext(), ("_Converted." + type.getName()).str());

      SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
      if (llvm::count(recursiveStack, type)) {
        results.push_back(convertedType);
        return success();
      }
      recursiveStack.push_back(type);
      auto popConversionCallStack = llvm::make_scope_exit(
          [&recursiveStack]() { recursiveStack.pop_back(); });

      SmallVector<Type> convertedElemTypes;
      convertedElemTypes.reserve(type.getBody().size());
      if (failed(convertTypes(type.getBody(), convertedElemTypes)))
        return std::nullopt;

      // If the converted type has not been initialized yet, just set its body
      // to be the converted arguments and return.
      if (!convertedType.isInitialized()) {
        if (failed(
                convertedType.setBody(convertedElemTypes, type.isPacked()))) {
          return failure();
        }
        results.push_back(convertedType);
        return success();
      }

      // If it has been initialized, has the same body and packed bit, just use
      // it. This ensures that recursive structs keep being recursive rather
      // than including a non-updated name.
      if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
          convertedType.isPacked() == type.isPacked()) {
        results.push_back(convertedType);
        return success();
      }

      return failure();
    }

    SmallVector<Type> convertedSubtypes;
    convertedSubtypes.reserve(type.getBody().size());
    if (failed(convertTypes(type.getBody(), convertedSubtypes)))
      return std::nullopt;

    results.push_back(LLVM::LLVMStructType::getLiteral(
        type.getContext(), convertedSubtypes, type.isPacked()));
    return success();
  });
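  // Illustrative example (assuming the default 64-bit index lowering): an
  // identified struct such as `!llvm.struct<"foo", (index)>` converts to a
  // renamed copy `!llvm.struct<"_Converted.foo", (i64)>`, and a recursive
  // reference to "foo" inside its own body resolves to the renamed struct via
  // the per-thread recursion stack above.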
  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
    if (auto element = convertType(type.getElementType()))
      return LLVM::LLVMArrayType::get(element, type.getNumElements());
    return std::nullopt;
  });
  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
    Type convertedResType = convertType(type.getReturnType());
    if (!convertedResType)
      return std::nullopt;

    SmallVector<Type> convertedArgTypes;
    convertedArgTypes.reserve(type.getNumParams());
    if (failed(convertTypes(type.getParams(), convertedArgTypes)))
      return std::nullopt;

    return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
                                       type.isVarArg());
  });

  // Add generic source and target materializations to handle cases where
  // non-LLVM types persist after an LLVM conversion.
  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });

  // Source materializations convert from the new block argument types
  // (multiple SSA values that make up a memref descriptor) back to the
  // original block argument type.
  addSourceMaterialization([&](OpBuilder &builder,
                               UnrankedMemRefType resultType, ValueRange inputs,
                               Location loc) {
    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
                                         *this);
  });
  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
                               ValueRange inputs, Location loc) {
    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
  });

  // Bare pointer -> Packed MemRef descriptor
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc,
                               Type originalType) -> Value {
    // The original MemRef type is required to build a MemRef descriptor
    // because the sizes/strides of the MemRef cannot be inferred from just the
    // bare pointer.
    if (!originalType)
      return Value();
    if (resultType != convertType(originalType))
      return Value();
    if (auto memrefType = dyn_cast<MemRefType>(originalType))
      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
                                    *this);
    return Value();
  });

  // Integer memory spaces map to themselves.
  addTypeAttributeConversion(
      [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
}
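
// Illustrative usage sketch (not part of this file's logic): a conversion
// pass typically constructs one converter and queries it for lowered types.
//
//   LowerToLLVMOptions options(ctx);
//   LLVMTypeConverter converter(ctx, options);
//   Type loweredIndexTy = converter.convertType(IndexType::get(ctx));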

/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() const {
  return *getDialect()->getContext();
}

Type LLVMTypeConverter::getIndexType() const {
  return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}

unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
  return options.dataLayout.getPointerSizeInBits(addressSpace);
}

Type LLVMTypeConverter::convertIndexType(IndexType type) const {
  return getIndexType();
}

Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
  return IntegerType::get(&getContext(), type.getWidth());
}

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;

  // F4, F6, F8 types are converted to integer types with the same bit width.
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(type))
    return IntegerType::get(&getContext(), type.getWidth());

  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}
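
// For example, with the rules above: `f8E4M3FN` lowers to `i8`, `f6E2M3FN` to
// `i6`, and `f4E2M1FN` to `i4`, while LLVM-compatible floats such as `f32` or
// `bf16` pass through unchanged.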

// Convert a `ComplexType` to an LLVM type. The result is a complex number
// struct with entries for the real part followed by the imaginary part.
Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
  auto elementType = convertType(type.getElementType());
  return LLVM::LLVMStructType::getLiteral(&getContext(),
                                          {elementType, elementType});
}
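
// For example (illustrative): `complex<f32>` converts to
// `!llvm.struct<(f32, f32)>`.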

// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
  return LLVM::LLVMPointerType::get(type.getContext());
}
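
// For example (illustrative): a function-typed value such as
// `(i32, f32) -> f32` lowers to the opaque `!llvm.ptr`; only function
// signatures (handled by `convertFunctionSignature` below) are expanded into
// `!llvm.func<...>` types.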

/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
/// function arguments. Returns an empty container if none of these attributes
/// are found in any of the arguments.
static void
filterByValRefArgAttrs(FunctionOpInterface funcOp,
                       SmallVectorImpl<std::optional<NamedAttribute>> &result) {
  assert(result.empty() && "Unexpected non-empty output");
  result.resize(funcOp.getNumArguments(), std::nullopt);
  bool foundByValByRefAttrs = false;
  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
      if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
           namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
        foundByValByRefAttrs = true;
        result[argIdx] = namedAttr;
        break;
      }
    }
  }

  if (!foundByValByRefAttrs)
    result.clear();
}

// Function types are converted to LLVM function types by recursively
// converting argument and result types. If an MLIR function has zero results,
// the LLVM function has one VoidType result. If an MLIR function has more than
// one result, the results are packed into an LLVM StructType in their order of
// appearance.
// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
// `llvm.byref` function arguments which are not LLVM pointers are overridden
// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
// converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
Type LLVMTypeConverter::convertFunctionSignatureImpl(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result,
    SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
  // Select the argument converter depending on the calling convention.
  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                             : structFuncArgTypeConverter;
  // Convert argument types one by one and check for errors.
  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
    SmallVector<Type, 8> converted;
    if (failed(funcArgConverter(*this, type, converted)))
      return {};

    // Rewrite the converted type of an `llvm.byval` or `llvm.byref` function
    // argument that was not converted to an LLVM pointer type.
    if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
        converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
      // If the argument was already converted to an LLVM pointer type, we stop
      // tracking it as it doesn't need more processing.
      if (isa<LLVM::LLVMPointerType>(converted[0]))
        (*byValRefNonPtrAttrs)[idx] = std::nullopt;
      else
        converted[0] = LLVM::LLVMPointerType::get(&getContext());
    }

    result.addInputs(idx, converted);
  }

  // If the function does not return anything, create the void result type; if
  // it returns one element, convert it; otherwise pack the result types into
  // a struct.
  Type resultType =
      funcTy.getNumResults() == 0
          ? LLVM::LLVMVoidType::get(&getContext())
          : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
  if (!resultType)
    return {};
  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
                                     isVariadic);
}
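
// For example (illustrative, with the default descriptor calling convention
// and a 64-bit index type): the signature `(f32, memref<?xf32>) -> (i32, i32)`
// converts to
//   !llvm.func<struct<(i32, i32)> (f32, ptr, ptr, i64, i64, i64)>
// because the ranked memref argument is expanded into its descriptor fields
// and the two results are packed into a struct.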

Type LLVMTypeConverter::convertFunctionSignature(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result) const {
  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
                                      result,
                                      /*byValRefNonPtrAttrs=*/nullptr);
}

Type LLVMTypeConverter::convertFunctionSignature(
    FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result,
    SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
  // that were not converted to LLVM pointer types will be returned for further
  // processing.
  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
                                      result, &byValRefNonPtrAttrs);
}

/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
  SmallVector<Type, 4> inputs;

  Type resultType = type.getNumResults() == 0
                        ? LLVM::LLVMVoidType::get(&getContext())
                        : packFunctionResults(type.getResults());
  if (!resultType)
    return {};

  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
  if (structType) {
    // Struct types cannot be safely returned via C interface. Make this a
    // pointer argument, instead.
    inputs.push_back(ptrType);
    resultType = LLVM::LLVMVoidType::get(&getContext());
  }

  for (Type t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    if (isa<MemRefType, UnrankedMemRefType>(t))
      converted = ptrType;
    inputs.push_back(converted);
  }

  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
}
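
// For example (illustrative): the C wrapper for `(memref<?xf32>) -> f32` is
// `!llvm.func<f32 (ptr)>`, where the pointer argument refers to the memref
// descriptor. A multi-result signature such as `() -> (f32, f32)` instead
// becomes `!llvm.func<void (ptr)>`, with the packed struct result written
// through the leading pointer argument.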

/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
/// 1. The pointer to the allocated data buffer, followed by
/// 2. The pointer to the aligned data buffer, followed by
/// 3. A lowered `index`-type integer containing the distance between the
///    beginning of the buffer and the first element to be accessed through the
///    view, followed by
/// 4. An array containing as many `index`-type integers as the rank of the
///    MemRef: the array represents the size, in number of elements, of the
///    memref along the given dimension. For constant MemRef dimensions, the
///    corresponding size entry is a constant whose runtime value must match
///    the static value, followed by
/// 5. A second array containing as many `index`-type integers as the rank of
///    the MemRef: the second array represents the "stride" (in tensor
///    abstraction sense), i.e. the number of consecutive elements of the
///    underlying buffer.
/// TODO: add assertions for the static cases.
///
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
/// are expanded into individual index-type elements.
///
///  template <typename Elem, typename Index, size_t Rank>
///  struct {
///    Elem *allocatedPtr;
///    Elem *alignedPtr;
///    Index offset;
///    Index sizes[Rank]; // omitted when rank == 0
///    Index strides[Rank]; // omitted when rank == 0
///  };
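///
/// For example (illustrative, assuming the default 64-bit index type), for
/// `memref<?x8xf32>` this returns
///   `ptr, ptr, i64, i64, i64, i64, i64` with `unpackAggregates=true`, and
///   `ptr, ptr, i64, array<2 x i64>, array<2 x i64>` otherwise.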
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                             bool unpackAggregates) const {
  if (!type.isStrided()) {
    emitError(
        UnknownLoc::get(type.getContext()),
        "conversion to strided form failed either due to non-strided layout "
        "maps (which should have been normalized away) or other reasons");
    return {};
  }

  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};

  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace)) {
    emitError(UnknownLoc::get(type.getContext()),
              "conversion of memref memory space ")
        << type.getMemorySpace()
        << " to integer address space "
           "failed. Consider adding memory space conversions.";
    return {};
  }
  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);

  auto indexTy = getIndexType();

  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
  auto rank = type.getRank();
  if (rank == 0)
    return results;

  if (unpackAggregates)
    results.insert(results.end(), 2 * rank, indexTy);
  else
    results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
  return results;
}

unsigned
LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
                                           const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
         (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
}
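
// Worked example (illustrative): with 64-bit pointers and a 64-bit index
// type, a rank-2 descriptor occupies 2 * 8 + (1 + 2 * 2) * 8 = 56 bytes:
// two pointers, one offset, two sizes, and two strides.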

/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
  // When converting a MemRefType to a struct with descriptor fields, do not
  // unpack the `sizes` and `strides` arrays.
  SmallVector<Type, 5> types =
      getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
  if (types.empty())
    return {};
  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}
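
// For example (illustrative, assuming a 64-bit index type): `memref<4x8xf32>`
// converts to
//   !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>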

/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, the fields
/// for an unranked memref descriptor are:
/// 1. index-typed rank, the dynamic rank of this MemRef
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
///    a stack-allocated (alloca) copy of a MemRef descriptor that got casted
///    to be unranked.
SmallVector<Type, 2>
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
  return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
}

unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
    UnrankedMemRefType type, const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return layout.getTypeSize(getIndexType()) +
         llvm::divideCeil(getPointerBitwidth(space), 8);
}

Type LLVMTypeConverter::convertUnrankedMemRefType(
    UnrankedMemRefType type) const {
  if (!convertType(type.getElementType()))
    return {};
  return LLVM::LLVMStructType::getLiteral(&getContext(),
                                          getUnrankedMemRefDescriptorFields());
}
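
// For example (illustrative, assuming a 64-bit index type): `memref<*xf32>`
// converts to `!llvm.struct<(i64, ptr)>`, i.e. the dynamic rank followed by a
// pointer to the stack-allocated ranked descriptor.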

FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
  if (!type.getMemorySpace()) // Default memory space -> 0.
    return 0;
  std::optional<Attribute> converted =
      convertTypeAttribute(type, type.getMemorySpace());
  if (!converted)
    return failure();
  if (!(*converted)) // Conversion to default is 0.
    return 0;
  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
    if (explicitSpace.getType().isIndex() ||
        explicitSpace.getType().isSignlessInteger())
      return explicitSpace.getInt();
  }
  return failure();
}

// Check if a memref type can be converted to a bare pointer.
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
  if (isa<UnrankedMemRefType>(type))
    // Unranked memref is not supported in the bare pointer calling convention.
    return false;

  // Check that the memref has static shape, strides and offset. Otherwise, it
  // cannot be lowered to a bare pointer.
  auto memrefTy = cast<MemRefType>(type);
  if (!memrefTy.hasStaticShape())
    return false;

  int64_t offset = 0;
  SmallVector<int64_t, 4> strides;
  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
    return false;

  for (int64_t stride : strides)
    if (ShapedType::isDynamic(stride))
      return false;

  return !ShapedType::isDynamic(offset);
}

/// Convert a memref type to a bare pointer to the memref element type.
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
  if (!canConvertToBarePtr(type))
    return {};
  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace))
    return {};
  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}
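
// For example (illustrative): under the bare-pointer convention,
// `memref<8xf32>` converts to `!llvm.ptr` and `memref<8xf32, 3>` converts to
// `!llvm.ptr<3>`, provided the shape, strides, and offset are static.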

/// Convert an n-D vector type to an LLVM vector type:
///  * 0-D `vector<T>` is converted to a 1-D `vector<1xT>`,
///  * 1-D `vector<axT>` remains as is, while
///  * n>1 `vector<ax...xkxT>` converts via an (n-1)-D array type to
///    `!llvm.array<ax...array<jxvector<kxT>>>`.
/// As LLVM supports arrays of scalable vectors, this method will also convert
/// n-D scalable vectors provided that only the trailing dim is scalable.
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
  auto elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  if (type.getShape().empty())
    return VectorType::get({1}, elementType);
  Type vectorType = VectorType::get(type.getShape().back(), elementType,
                                    type.getScalableDims().back());
  assert(LLVM::isCompatibleVectorType(vectorType) &&
         "expected vector type compatible with the LLVM dialect");
  // For n-D vector types for which a _non-trailing_ dim is scalable, return a
  // failure. Supporting such cases would require LLVM to support something
  // akin to "scalable arrays" of vectors.
  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
    return failure();
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
  return vectorType;
}
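
// For example (illustrative): `vector<4x8xf32>` converts to
// `!llvm.array<4 x vector<8xf32>>`, and the scalable `vector<4x[8]xf32>`
// converts to `!llvm.array<4 x vector<[8]xf32>>`; `vector<[4]x8xf32>` is
// rejected because the scalable dim is not trailing.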

/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type LLVMTypeConverter::convertCallingConventionType(
    Type type, bool useBarePtrCallConv) const {
  if (useBarePtrCallConv)
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      return convertMemRefToBarePtr(memrefTy);

  return convertType(type);
}

/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
    SmallVectorImpl<Value> &values) const {
  assert(stdTypes.size() == values.size() &&
         "The number of types and values doesn't match");
  for (unsigned i = 0, end = values.size(); i < end; ++i)
    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                    memrefTy, values[i]);
}

/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is produced,
/// create a literal structure with elements that correspond to each of the
/// types converted with `convertType`.
Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
  assert(!types.empty() && "expected non-empty list of type");
  if (types.size() == 1)
    return convertType(types[0]);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (Type type : types) {
    Type converted = convertType(type);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
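
// For example (illustrative): result types `(i32, f64)` pack into
// `!llvm.struct<(i32, f64)>`, while a single result is simply converted in
// place.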

/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
Type LLVMTypeConverter::packFunctionResults(TypeRange types,
                                            bool useBarePtrCallConv) const {
  assert(!types.empty() && "expected non-empty list of type");

  useBarePtrCallConv |= options.useBarePtrCallConv;
  if (types.size() == 1)
    return convertCallingConventionType(types.front(), useBarePtrCallConv);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (auto t : types) {
    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}

Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
                                                    OpBuilder &builder) const {
  // Alloca with proper alignment. We do not expect optimizations of this
  // alloca op and so we omit allocating at the entry block.
  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                               builder.getIndexAttr(1));
  Value allocated =
      builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
  // Store into the alloca'ed descriptor.
  builder.create<LLVM::StoreOp>(loc, operand, allocated);
  return allocated;
}
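
// The emitted IR is roughly the following (illustrative sketch, for a rank-1
// descriptor; exact printed syntax may differ):
//   %c1 = llvm.mlir.constant(1 : index) : i64
//   %p  = llvm.alloca %c1 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>,
//             array<1 x i64>)> : (i64) -> !llvm.ptr
//   llvm.store %desc, %p : !llvm.struct<(...)>, !llvm.ptr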

SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
                                   ValueRange operands, OpBuilder &builder,
                                   bool useBarePtrCallConv) const {
  SmallVector<Value, 4> promotedOperands;
  promotedOperands.reserve(operands.size());
  useBarePtrCallConv |= options.useBarePtrCallConv;
  for (auto it : llvm::zip(opOperands, operands)) {
    auto operand = std::get<0>(it);
    auto llvmOperand = std::get<1>(it);

    if (useBarePtrCallConv) {
      // For the bare-ptr calling convention, we only have to extract the
      // aligned pointer of a memref.
      if (isa<MemRefType>(operand.getType())) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
      } else if (isa<UnrankedMemRefType>(operand.getType())) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
      if (isa<UnrankedMemRefType>(operand.getType())) {
        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
                                         promotedOperands);
        continue;
      }
      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
                                 promotedOperands);
        continue;
      }
    }

    promotedOperands.push_back(llvmOperand);
  }
  return promotedOperands;
}

/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult
mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                 SmallVectorImpl<Type> &result) {
  if (auto memref = dyn_cast<MemRefType>(type)) {
    // In signatures, MemRef descriptors are expanded into lists of
    // non-aggregate values.
    auto converted =
        converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  if (isa<UnrankedMemRefType>(type)) {
    auto converted = converter.getUnrankedMemRefDescriptorFields();
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  auto converted = converter.convertType(type);
  if (!converted)
    return failure();
  result.push_back(converted);
  return success();
}

/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                  SmallVectorImpl<Type> &result) {
  auto llvmTy = converter.convertCallingConventionType(
      type, /*useBarePtrCallConv=*/true);
  if (!llvmTy)
    return failure();

  result.push_back(llvmTy);
  return success();
}