| 1 | //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements patterns to convert SPIR-V dialect to LLVM dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" |
| 14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 15 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 16 | #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" |
| 17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| 19 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 20 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 21 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
| 22 | #include "mlir/IR/BuiltinOps.h" |
| 23 | #include "mlir/IR/PatternMatch.h" |
| 24 | #include "mlir/Transforms/DialectConversion.h" |
| 25 | #include "llvm/ADT/TypeSwitch.h" |
| 26 | #include "llvm/Support/Debug.h" |
| 27 | #include "llvm/Support/FormatVariadic.h" |
| 28 | |
| 29 | #define DEBUG_TYPE "spirv-to-llvm-pattern" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Utility functions |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | /// Returns true if the given type is a signed integer or vector type. |
| 38 | static bool isSignedIntegerOrVector(Type type) { |
| 39 | if (type.isSignedInteger()) |
| 40 | return true; |
| 41 | if (auto vecType = dyn_cast<VectorType>(type)) |
| 42 | return vecType.getElementType().isSignedInteger(); |
| 43 | return false; |
| 44 | } |
| 45 | |
| 46 | /// Returns true if the given type is an unsigned integer or vector type |
| 47 | static bool isUnsignedIntegerOrVector(Type type) { |
| 48 | if (type.isUnsignedInteger()) |
| 49 | return true; |
| 50 | if (auto vecType = dyn_cast<VectorType>(type)) |
| 51 | return vecType.getElementType().isUnsignedInteger(); |
| 52 | return false; |
| 53 | } |
| 54 | |
| 55 | /// Returns the width of an integer or of the element type of an integer vector, |
| 56 | /// if applicable. |
| 57 | static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { |
| 58 | if (auto intType = dyn_cast<IntegerType>(type)) |
| 59 | return intType.getWidth(); |
| 60 | if (auto vecType = dyn_cast<VectorType>(type)) |
| 61 | if (auto intType = dyn_cast<IntegerType>(vecType.getElementType())) |
| 62 | return intType.getWidth(); |
| 63 | return std::nullopt; |
| 64 | } |
| 65 | |
| 66 | /// Returns the bit width of integer, float or vector of float or integer values |
| 67 | static unsigned getBitWidth(Type type) { |
| 68 | assert((type.isIntOrFloat() || isa<VectorType>(type)) && |
| 69 | "bitwidth is not supported for this type" ); |
| 70 | if (type.isIntOrFloat()) |
| 71 | return type.getIntOrFloatBitWidth(); |
| 72 | auto vecType = dyn_cast<VectorType>(type); |
| 73 | auto elementType = vecType.getElementType(); |
| 74 | assert(elementType.isIntOrFloat() && |
| 75 | "only integers and floats have a bitwidth" ); |
| 76 | return elementType.getIntOrFloatBitWidth(); |
| 77 | } |
| 78 | |
| 79 | /// Returns the bit width of LLVMType integer or vector. |
| 80 | static unsigned getLLVMTypeBitWidth(Type type) { |
| 81 | if (auto vecTy = dyn_cast<VectorType>(type)) |
| 82 | type = vecTy.getElementType(); |
| 83 | return cast<IntegerType>(type).getWidth(); |
| 84 | } |
| 85 | |
| 86 | /// Creates `IntegerAttribute` with all bits set for given type |
| 87 | static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { |
| 88 | if (auto vecType = dyn_cast<VectorType>(type)) { |
| 89 | auto integerType = cast<IntegerType>(vecType.getElementType()); |
| 90 | return builder.getIntegerAttr(integerType, -1); |
| 91 | } |
| 92 | auto integerType = cast<IntegerType>(type); |
| 93 | return builder.getIntegerAttr(integerType, -1); |
| 94 | } |
| 95 | |
| 96 | /// Creates `llvm.mlir.constant` with all bits set for the given type. |
| 97 | static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, |
| 98 | PatternRewriter &rewriter) { |
| 99 | if (isa<VectorType>(Val: srcType)) { |
| 100 | return rewriter.create<LLVM::ConstantOp>( |
| 101 | loc, dstType, |
| 102 | SplatElementsAttr::get(cast<ShapedType>(srcType), |
| 103 | minusOneIntegerAttribute(srcType, rewriter))); |
| 104 | } |
| 105 | return rewriter.create<LLVM::ConstantOp>( |
| 106 | loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); |
| 107 | } |
| 108 | |
| 109 | /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. |
| 110 | static Value createFPConstant(Location loc, Type srcType, Type dstType, |
| 111 | PatternRewriter &rewriter, double value) { |
| 112 | if (auto vecType = dyn_cast<VectorType>(srcType)) { |
| 113 | auto floatType = cast<FloatType>(vecType.getElementType()); |
| 114 | return rewriter.create<LLVM::ConstantOp>( |
| 115 | loc, dstType, |
| 116 | SplatElementsAttr::get(vecType, |
| 117 | rewriter.getFloatAttr(floatType, value))); |
| 118 | } |
| 119 | auto floatType = cast<FloatType>(srcType); |
| 120 | return rewriter.create<LLVM::ConstantOp>( |
| 121 | loc, dstType, rewriter.getFloatAttr(floatType, value)); |
| 122 | } |
| 123 | |
| 124 | /// Utility function for bitfield ops: |
| 125 | /// - `BitFieldInsert` |
| 126 | /// - `BitFieldSExtract` |
| 127 | /// - `BitFieldUExtract` |
| 128 | /// Truncates or extends the value. If the bitwidth of the value is the same as |
| 129 | /// `llvmType` bitwidth, the value remains unchanged. |
| 130 | static Value optionallyTruncateOrExtend(Location loc, Value value, |
| 131 | Type llvmType, |
| 132 | PatternRewriter &rewriter) { |
| 133 | auto srcType = value.getType(); |
| 134 | unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType); |
| 135 | unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType) |
| 136 | ? getLLVMTypeBitWidth(type: srcType) |
| 137 | : getBitWidth(type: srcType); |
| 138 | |
| 139 | if (valueBitWidth < targetBitWidth) |
| 140 | return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); |
| 141 | // If the bit widths of `Count` and `Offset` are greater than the bit width |
| 142 | // of the target type, they are truncated. Truncation is safe since `Count` |
| 143 | // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, |
| 144 | // both values can be expressed in 8 bits. |
| 145 | if (valueBitWidth > targetBitWidth) |
| 146 | return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); |
| 147 | return value; |
| 148 | } |
| 149 | |
| 150 | /// Broadcasts the value to vector with `numElements` number of elements. |
| 151 | static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, |
| 152 | const TypeConverter &typeConverter, |
| 153 | ConversionPatternRewriter &rewriter) { |
| 154 | auto vectorType = VectorType::get(numElements, toBroadcast.getType()); |
| 155 | auto llvmVectorType = typeConverter.convertType(vectorType); |
| 156 | auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); |
| 157 | Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType); |
| 158 | for (unsigned i = 0; i < numElements; ++i) { |
| 159 | auto index = rewriter.create<LLVM::ConstantOp>( |
| 160 | loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); |
| 161 | broadcasted = rewriter.create<LLVM::InsertElementOp>( |
| 162 | loc, llvmVectorType, broadcasted, toBroadcast, index); |
| 163 | } |
| 164 | return broadcasted; |
| 165 | } |
| 166 | |
| 167 | /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. |
| 168 | static Value optionallyBroadcast(Location loc, Value value, Type srcType, |
| 169 | const TypeConverter &typeConverter, |
| 170 | ConversionPatternRewriter &rewriter) { |
| 171 | if (auto vectorType = dyn_cast<VectorType>(srcType)) { |
| 172 | unsigned numElements = vectorType.getNumElements(); |
| 173 | return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter); |
| 174 | } |
| 175 | return value; |
| 176 | } |
| 177 | |
| 178 | /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and |
| 179 | /// `BitFieldUExtract`. |
| 180 | /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of |
| 181 | /// a vector type, construct a vector that has: |
| 182 | /// - same number of elements as `Base` |
| 183 | /// - each element has the type that is the same as the type of `Offset` or |
| 184 | /// `Count` |
| 185 | /// - each element has the same value as `Offset` or `Count` |
| 186 | /// Then cast `Offset` and `Count` if their bit width is different |
| 187 | /// from `Base` bit width. |
| 188 | static Value processCountOrOffset(Location loc, Value value, Type srcType, |
| 189 | Type dstType, const TypeConverter &converter, |
| 190 | ConversionPatternRewriter &rewriter) { |
| 191 | Value broadcasted = |
| 192 | optionallyBroadcast(loc, value, srcType, typeConverter: converter, rewriter); |
| 193 | return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter); |
| 194 | } |
| 195 | |
| 196 | /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) |
| 197 | /// offset to LLVM struct. Otherwise, the conversion is not supported. |
| 198 | static Type convertStructTypeWithOffset(spirv::StructType type, |
| 199 | const TypeConverter &converter) { |
| 200 | if (type != VulkanLayoutUtils::decorateType(structType: type)) |
| 201 | return nullptr; |
| 202 | |
| 203 | SmallVector<Type> elementsVector; |
| 204 | if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
| 205 | return nullptr; |
| 206 | return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, |
| 207 | /*isPacked=*/false); |
| 208 | } |
| 209 | |
| 210 | /// Converts SPIR-V struct with no offset to packed LLVM struct. |
| 211 | static Type convertStructTypePacked(spirv::StructType type, |
| 212 | const TypeConverter &converter) { |
| 213 | SmallVector<Type> elementsVector; |
| 214 | if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
| 215 | return nullptr; |
| 216 | return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, |
| 217 | /*isPacked=*/true); |
| 218 | } |
| 219 | |
| 220 | /// Creates LLVM dialect constant with the given value. |
| 221 | static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, |
| 222 | unsigned value) { |
| 223 | return rewriter.create<LLVM::ConstantOp>( |
| 224 | loc, IntegerType::get(rewriter.getContext(), 32), |
| 225 | rewriter.getIntegerAttr(rewriter.getI32Type(), value)); |
| 226 | } |
| 227 | |
| 228 | /// Utility for `spirv.Load` and `spirv.Store` conversion. |
| 229 | static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, |
| 230 | ConversionPatternRewriter &rewriter, |
| 231 | const TypeConverter &typeConverter, |
| 232 | unsigned alignment, bool isVolatile, |
| 233 | bool isNonTemporal) { |
| 234 | if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { |
| 235 | auto dstType = typeConverter.convertType(loadOp.getType()); |
| 236 | if (!dstType) |
| 237 | return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed" ); |
| 238 | rewriter.replaceOpWithNewOp<LLVM::LoadOp>( |
| 239 | loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, |
| 240 | isVolatile, isNonTemporal); |
| 241 | return success(); |
| 242 | } |
| 243 | auto storeOp = cast<spirv::StoreOp>(op); |
| 244 | spirv::StoreOpAdaptor adaptor(operands); |
| 245 | rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(), |
| 246 | adaptor.getPtr(), alignment, |
| 247 | isVolatile, isNonTemporal); |
| 248 | return success(); |
| 249 | } |
| 250 | |
| 251 | //===----------------------------------------------------------------------===// |
| 252 | // Type conversion |
| 253 | //===----------------------------------------------------------------------===// |
| 254 | |
| 255 | /// Converts SPIR-V array type to LLVM array. Natural stride (according to |
| 256 | /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected |
| 257 | /// when converting ops that manipulate array types. |
| 258 | static std::optional<Type> convertArrayType(spirv::ArrayType type, |
| 259 | TypeConverter &converter) { |
| 260 | unsigned stride = type.getArrayStride(); |
| 261 | Type elementType = type.getElementType(); |
| 262 | auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes(); |
| 263 | if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) |
| 264 | return std::nullopt; |
| 265 | |
| 266 | auto llvmElementType = converter.convertType(t: elementType); |
| 267 | unsigned numElements = type.getNumElements(); |
| 268 | return LLVM::LLVMArrayType::get(llvmElementType, numElements); |
| 269 | } |
| 270 | |
| 271 | /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not |
| 272 | /// modelled at the moment. |
| 273 | static Type convertPointerType(spirv::PointerType type, |
| 274 | const TypeConverter &converter, |
| 275 | spirv::ClientAPI clientAPI) { |
| 276 | unsigned addressSpace = |
| 277 | storageClassToAddressSpace(clientAPI, type.getStorageClass()); |
| 278 | return LLVM::LLVMPointerType::get(type.getContext(), addressSpace); |
| 279 | } |
| 280 | |
| 281 | /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over |
| 282 | /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is |
| 283 | /// no modelling of array stride at the moment. |
| 284 | static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, |
| 285 | TypeConverter &converter) { |
| 286 | if (type.getArrayStride() != 0) |
| 287 | return std::nullopt; |
| 288 | auto elementType = converter.convertType(t: type.getElementType()); |
| 289 | return LLVM::LLVMArrayType::get(elementType, 0); |
| 290 | } |
| 291 | |
| 292 | /// Converts SPIR-V struct to LLVM struct. There is no support of structs with |
| 293 | /// member decorations. Also, only natural offset is supported. |
| 294 | static Type convertStructType(spirv::StructType type, |
| 295 | const TypeConverter &converter) { |
| 296 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
| 297 | type.getMemberDecorations(memberDecorations); |
| 298 | if (!memberDecorations.empty()) |
| 299 | return nullptr; |
| 300 | if (type.hasOffset()) |
| 301 | return convertStructTypeWithOffset(type, converter); |
| 302 | return convertStructTypePacked(type, converter); |
| 303 | } |
| 304 | |
| 305 | //===----------------------------------------------------------------------===// |
| 306 | // Operation conversion |
| 307 | //===----------------------------------------------------------------------===// |
| 308 | |
| 309 | namespace { |
| 310 | |
| 311 | class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { |
| 312 | public: |
| 313 | using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; |
| 314 | |
| 315 | LogicalResult |
| 316 | matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, |
| 317 | ConversionPatternRewriter &rewriter) const override { |
| 318 | auto dstType = |
| 319 | getTypeConverter()->convertType(op.getComponentPtr().getType()); |
| 320 | if (!dstType) |
| 321 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 322 | // To use GEP we need to add a first 0 index to go through the pointer. |
| 323 | auto indices = llvm::to_vector<4>(adaptor.getIndices()); |
| 324 | Type indexType = op.getIndices().front().getType(); |
| 325 | auto llvmIndexType = getTypeConverter()->convertType(indexType); |
| 326 | if (!llvmIndexType) |
| 327 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 328 | Value zero = rewriter.create<LLVM::ConstantOp>( |
| 329 | op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); |
| 330 | indices.insert(indices.begin(), zero); |
| 331 | |
| 332 | auto elementType = getTypeConverter()->convertType( |
| 333 | cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType()); |
| 334 | if (!elementType) |
| 335 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 336 | rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType, |
| 337 | adaptor.getBasePtr(), indices); |
| 338 | return success(); |
| 339 | } |
| 340 | }; |
| 341 | |
| 342 | class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { |
| 343 | public: |
| 344 | using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; |
| 345 | |
| 346 | LogicalResult |
| 347 | matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, |
| 348 | ConversionPatternRewriter &rewriter) const override { |
| 349 | auto dstType = getTypeConverter()->convertType(op.getPointer().getType()); |
| 350 | if (!dstType) |
| 351 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 352 | rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, |
| 353 | op.getVariable()); |
| 354 | return success(); |
| 355 | } |
| 356 | }; |
| 357 | |
| 358 | class BitFieldInsertPattern |
| 359 | : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { |
| 360 | public: |
| 361 | using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; |
| 362 | |
| 363 | LogicalResult |
| 364 | matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, |
| 365 | ConversionPatternRewriter &rewriter) const override { |
| 366 | auto srcType = op.getType(); |
| 367 | auto dstType = getTypeConverter()->convertType(srcType); |
| 368 | if (!dstType) |
| 369 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 370 | Location loc = op.getLoc(); |
| 371 | |
| 372 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
| 373 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
| 374 | *getTypeConverter(), rewriter); |
| 375 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
| 376 | *getTypeConverter(), rewriter); |
| 377 | |
| 378 | // Create a mask with bits set outside [Offset, Offset + Count - 1]. |
| 379 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
| 380 | Value maskShiftedByCount = |
| 381 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
| 382 | Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, |
| 383 | maskShiftedByCount, minusOne); |
| 384 | Value maskShiftedByCountAndOffset = |
| 385 | rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); |
| 386 | Value mask = rewriter.create<LLVM::XOrOp>( |
| 387 | loc, dstType, maskShiftedByCountAndOffset, minusOne); |
| 388 | |
| 389 | // Extract unchanged bits from the `Base` that are outside of |
| 390 | // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. |
| 391 | Value baseAndMask = |
| 392 | rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); |
| 393 | Value insertShiftedByOffset = |
| 394 | rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); |
| 395 | rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, |
| 396 | insertShiftedByOffset); |
| 397 | return success(); |
| 398 | } |
| 399 | }; |
| 400 | |
| 401 | /// Converts SPIR-V ConstantOp with scalar or vector type. |
| 402 | class ConstantScalarAndVectorPattern |
| 403 | : public SPIRVToLLVMConversion<spirv::ConstantOp> { |
| 404 | public: |
| 405 | using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; |
| 406 | |
| 407 | LogicalResult |
| 408 | matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, |
| 409 | ConversionPatternRewriter &rewriter) const override { |
| 410 | auto srcType = constOp.getType(); |
| 411 | if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat()) |
| 412 | return failure(); |
| 413 | |
| 414 | auto dstType = getTypeConverter()->convertType(srcType); |
| 415 | if (!dstType) |
| 416 | return rewriter.notifyMatchFailure(constOp, "type conversion failed" ); |
| 417 | |
| 418 | // SPIR-V constant can be a signed/unsigned integer, which has to be |
| 419 | // casted to signless integer when converting to LLVM dialect. Removing the |
| 420 | // sign bit may have unexpected behaviour. However, it is better to handle |
| 421 | // it case-by-case, given that the purpose of the conversion is not to |
| 422 | // cover all possible corner cases. |
| 423 | if (isSignedIntegerOrVector(srcType) || |
| 424 | isUnsignedIntegerOrVector(srcType)) { |
| 425 | auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); |
| 426 | |
| 427 | if (isa<VectorType>(srcType)) { |
| 428 | auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue()); |
| 429 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
| 430 | constOp, dstType, |
| 431 | dstElementsAttr.mapValues( |
| 432 | signlessType, [&](const APInt &value) { return value; })); |
| 433 | return success(); |
| 434 | } |
| 435 | auto srcAttr = cast<IntegerAttr>(constOp.getValue()); |
| 436 | auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); |
| 437 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); |
| 438 | return success(); |
| 439 | } |
| 440 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
| 441 | constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); |
| 442 | return success(); |
| 443 | } |
| 444 | }; |
| 445 | |
| 446 | class |
| 447 | : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { |
| 448 | public: |
| 449 | using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; |
| 450 | |
| 451 | LogicalResult |
| 452 | matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, |
| 453 | ConversionPatternRewriter &rewriter) const override { |
| 454 | auto srcType = op.getType(); |
| 455 | auto dstType = getTypeConverter()->convertType(srcType); |
| 456 | if (!dstType) |
| 457 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 458 | Location loc = op.getLoc(); |
| 459 | |
| 460 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
| 461 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
| 462 | *getTypeConverter(), rewriter); |
| 463 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
| 464 | *getTypeConverter(), rewriter); |
| 465 | |
| 466 | // Create a constant that holds the size of the `Base`. |
| 467 | IntegerType integerType; |
| 468 | if (auto vecType = dyn_cast<VectorType>(srcType)) |
| 469 | integerType = cast<IntegerType>(vecType.getElementType()); |
| 470 | else |
| 471 | integerType = cast<IntegerType>(srcType); |
| 472 | |
| 473 | auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); |
| 474 | Value size = |
| 475 | isa<VectorType>(srcType) |
| 476 | ? rewriter.create<LLVM::ConstantOp>( |
| 477 | loc, dstType, |
| 478 | SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) |
| 479 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); |
| 480 | |
| 481 | // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit |
| 482 | // at Offset + Count - 1 is the most significant bit now. |
| 483 | Value countPlusOffset = |
| 484 | rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); |
| 485 | Value amountToShiftLeft = |
| 486 | rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); |
| 487 | Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( |
| 488 | loc, dstType, op.getBase(), amountToShiftLeft); |
| 489 | |
| 490 | // Shift the result right, filling the bits with the sign bit. |
| 491 | Value amountToShiftRight = |
| 492 | rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); |
| 493 | rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, |
| 494 | amountToShiftRight); |
| 495 | return success(); |
| 496 | } |
| 497 | }; |
| 498 | |
| 499 | class |
| 500 | : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { |
| 501 | public: |
| 502 | using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; |
| 503 | |
| 504 | LogicalResult |
| 505 | matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, |
| 506 | ConversionPatternRewriter &rewriter) const override { |
| 507 | auto srcType = op.getType(); |
| 508 | auto dstType = getTypeConverter()->convertType(srcType); |
| 509 | if (!dstType) |
| 510 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 511 | Location loc = op.getLoc(); |
| 512 | |
| 513 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
| 514 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
| 515 | *getTypeConverter(), rewriter); |
| 516 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
| 517 | *getTypeConverter(), rewriter); |
| 518 | |
| 519 | // Create a mask with bits set at [0, Count - 1]. |
| 520 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
| 521 | Value maskShiftedByCount = |
| 522 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
| 523 | Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, |
| 524 | minusOne); |
| 525 | |
| 526 | // Shift `Base` by `Offset` and apply the mask on it. |
| 527 | Value shiftedBase = |
| 528 | rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); |
| 529 | rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); |
| 530 | return success(); |
| 531 | } |
| 532 | }; |
| 533 | |
| 534 | class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { |
| 535 | public: |
| 536 | using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; |
| 537 | |
| 538 | LogicalResult |
| 539 | matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, |
| 540 | ConversionPatternRewriter &rewriter) const override { |
| 541 | rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), |
| 542 | branchOp.getTarget()); |
| 543 | return success(); |
| 544 | } |
| 545 | }; |
| 546 | |
| 547 | class BranchConditionalConversionPattern |
| 548 | : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { |
| 549 | public: |
| 550 | using SPIRVToLLVMConversion< |
| 551 | spirv::BranchConditionalOp>::SPIRVToLLVMConversion; |
| 552 | |
| 553 | LogicalResult |
| 554 | matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, |
| 555 | ConversionPatternRewriter &rewriter) const override { |
| 556 | // If branch weights exist, map them to 32-bit integer vector. |
| 557 | DenseI32ArrayAttr branchWeights = nullptr; |
| 558 | if (auto weights = op.getBranchWeights()) { |
| 559 | SmallVector<int32_t> weightValues; |
| 560 | for (auto weight : weights->getAsRange<IntegerAttr>()) |
| 561 | weightValues.push_back(weight.getInt()); |
| 562 | branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); |
| 563 | } |
| 564 | |
| 565 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
| 566 | op, op.getCondition(), op.getTrueBlockArguments(), |
| 567 | op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), |
| 568 | op.getFalseBlock()); |
| 569 | return success(); |
| 570 | } |
| 571 | }; |
| 572 | |
| 573 | /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container |
| 574 | /// type is an aggregate type (struct or array). Otherwise, converts to |
| 575 | /// `llvm.extractelement` that operates on vectors. |
| 576 | class |
| 577 | : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { |
| 578 | public: |
| 579 | using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; |
| 580 | |
| 581 | LogicalResult |
| 582 | matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, |
| 583 | ConversionPatternRewriter &rewriter) const override { |
| 584 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 585 | if (!dstType) |
| 586 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 587 | |
| 588 | Type containerType = op.getComposite().getType(); |
| 589 | if (isa<VectorType>(Val: containerType)) { |
| 590 | Location loc = op.getLoc(); |
| 591 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
| 592 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
| 593 | rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( |
| 594 | op, dstType, adaptor.getComposite(), index); |
| 595 | return success(); |
| 596 | } |
| 597 | |
| 598 | rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( |
| 599 | op, adaptor.getComposite(), |
| 600 | LLVM::convertArrayToIndices(op.getIndices())); |
| 601 | return success(); |
| 602 | } |
| 603 | }; |
| 604 | |
| 605 | /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container |
| 606 | /// type is an aggregate type (struct or array). Otherwise, converts to |
| 607 | /// `llvm.insertelement` that operates on vectors. |
| 608 | class CompositeInsertPattern |
| 609 | : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { |
| 610 | public: |
| 611 | using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; |
| 612 | |
| 613 | LogicalResult |
| 614 | matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, |
| 615 | ConversionPatternRewriter &rewriter) const override { |
| 616 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 617 | if (!dstType) |
| 618 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 619 | |
| 620 | Type containerType = op.getComposite().getType(); |
| 621 | if (isa<VectorType>(Val: containerType)) { |
| 622 | Location loc = op.getLoc(); |
| 623 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
| 624 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
| 625 | rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( |
| 626 | op, dstType, adaptor.getComposite(), adaptor.getObject(), index); |
| 627 | return success(); |
| 628 | } |
| 629 | |
| 630 | rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( |
| 631 | op, adaptor.getComposite(), adaptor.getObject(), |
| 632 | LLVM::convertArrayToIndices(op.getIndices())); |
| 633 | return success(); |
| 634 | } |
| 635 | }; |
| 636 | |
| 637 | /// Converts SPIR-V operations that have straightforward LLVM equivalent |
| 638 | /// into LLVM dialect operations. |
| 639 | template <typename SPIRVOp, typename LLVMOp> |
| 640 | class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 641 | public: |
| 642 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 643 | |
| 644 | LogicalResult |
| 645 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 646 | ConversionPatternRewriter &rewriter) const override { |
| 647 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 648 | if (!dstType) |
| 649 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 650 | rewriter.template replaceOpWithNewOp<LLVMOp>( |
| 651 | op, dstType, adaptor.getOperands(), op->getAttrs()); |
| 652 | return success(); |
| 653 | } |
| 654 | }; |
| 655 | |
| 656 | /// Converts `spirv.ExecutionMode` into a global struct constant that holds |
| 657 | /// execution mode information. |
| 658 | class ExecutionModePattern |
| 659 | : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { |
| 660 | public: |
| 661 | using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; |
| 662 | |
| 663 | LogicalResult |
| 664 | matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, |
| 665 | ConversionPatternRewriter &rewriter) const override { |
| 666 | // First, create the global struct's name that would be associated with |
| 667 | // this entry point's execution mode. We set it to be: |
| 668 | // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} |
| 669 | ModuleOp module = op->getParentOfType<ModuleOp>(); |
| 670 | spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); |
| 671 | std::string moduleName; |
| 672 | if (module.getName().has_value()) |
| 673 | moduleName = "_" + module.getName()->str(); |
| 674 | else |
| 675 | moduleName = "" ; |
| 676 | std::string executionModeInfoName = llvm::formatv( |
| 677 | "__spv_{0}_{1}_execution_mode_info_{2}" , moduleName, op.getFn().str(), |
| 678 | static_cast<uint32_t>(executionModeAttr.getValue())); |
| 679 | |
| 680 | MLIRContext *context = rewriter.getContext(); |
| 681 | OpBuilder::InsertionGuard guard(rewriter); |
| 682 | rewriter.setInsertionPointToStart(module.getBody()); |
| 683 | |
| 684 | // Create a struct type, corresponding to the C struct below. |
| 685 | // struct { |
| 686 | // int32_t executionMode; |
| 687 | // int32_t values[]; // optional values |
| 688 | // }; |
| 689 | auto llvmI32Type = IntegerType::get(context, 32); |
| 690 | SmallVector<Type, 2> fields; |
| 691 | fields.push_back(Elt: llvmI32Type); |
| 692 | ArrayAttr values = op.getValues(); |
| 693 | if (!values.empty()) { |
| 694 | auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); |
| 695 | fields.push_back(Elt: arrayType); |
| 696 | } |
| 697 | auto structType = LLVM::LLVMStructType::getLiteral(context, fields); |
| 698 | |
| 699 | // Create `llvm.mlir.global` with initializer region containing one block. |
| 700 | auto global = rewriter.create<LLVM::GlobalOp>( |
| 701 | UnknownLoc::get(context), structType, /*isConstant=*/true, |
| 702 | LLVM::Linkage::External, executionModeInfoName, Attribute(), |
| 703 | /*alignment=*/0); |
| 704 | Location loc = global.getLoc(); |
| 705 | Region ®ion = global.getInitializerRegion(); |
| 706 | Block *block = rewriter.createBlock(parent: ®ion); |
| 707 | |
| 708 | // Initialize the struct and set the execution mode value. |
| 709 | rewriter.setInsertionPointToStart(block); |
| 710 | Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType); |
| 711 | Value executionMode = rewriter.create<LLVM::ConstantOp>( |
| 712 | loc, llvmI32Type, |
| 713 | rewriter.getI32IntegerAttr( |
| 714 | static_cast<uint32_t>(executionModeAttr.getValue()))); |
| 715 | structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, |
| 716 | executionMode, 0); |
| 717 | |
| 718 | // Insert extra operands if they exist into execution mode info struct. |
| 719 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
| 720 | auto attr = values.getValue()[i]; |
| 721 | Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); |
| 722 | structValue = rewriter.create<LLVM::InsertValueOp>( |
| 723 | loc, structValue, entry, ArrayRef<int64_t>({1, i})); |
| 724 | } |
| 725 | rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); |
| 726 | rewriter.eraseOp(op: op); |
| 727 | return success(); |
| 728 | } |
| 729 | }; |
| 730 | |
| 731 | /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V |
| 732 | /// global returns a pointer, whereas in LLVM dialect the global holds an actual |
| 733 | /// value. This difference is handled by `spirv.mlir.addressof` and |
| 734 | /// `llvm.mlir.addressof`ops that both return a pointer. |
| 735 | class GlobalVariablePattern |
| 736 | : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { |
| 737 | public: |
| 738 | template <typename... Args> |
| 739 | GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args) |
| 740 | : SPIRVToLLVMConversion<spirv::GlobalVariableOp>( |
| 741 | std::forward<Args>(args)...), |
| 742 | clientAPI(clientAPI) {} |
| 743 | |
| 744 | LogicalResult |
| 745 | matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, |
| 746 | ConversionPatternRewriter &rewriter) const override { |
| 747 | // Currently, there is no support of initialization with a constant value in |
| 748 | // SPIR-V dialect. Specialization constants are not considered as well. |
| 749 | if (op.getInitializer()) |
| 750 | return failure(); |
| 751 | |
| 752 | auto srcType = cast<spirv::PointerType>(op.getType()); |
| 753 | auto dstType = getTypeConverter()->convertType(srcType.getPointeeType()); |
| 754 | if (!dstType) |
| 755 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 756 | |
| 757 | // Limit conversion to the current invocation only or `StorageBuffer` |
| 758 | // required by SPIR-V runner. |
| 759 | // This is okay because multiple invocations are not supported yet. |
| 760 | auto storageClass = srcType.getStorageClass(); |
| 761 | switch (storageClass) { |
| 762 | case spirv::StorageClass::Input: |
| 763 | case spirv::StorageClass::Private: |
| 764 | case spirv::StorageClass::Output: |
| 765 | case spirv::StorageClass::StorageBuffer: |
| 766 | case spirv::StorageClass::UniformConstant: |
| 767 | break; |
| 768 | default: |
| 769 | return failure(); |
| 770 | } |
| 771 | |
| 772 | // LLVM dialect spec: "If the global value is a constant, storing into it is |
| 773 | // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' |
| 774 | // storage class that is read-only. |
| 775 | bool isConstant = (storageClass == spirv::StorageClass::Input) || |
| 776 | (storageClass == spirv::StorageClass::UniformConstant); |
| 777 | // SPIR-V spec: "By default, functions and global variables are private to a |
| 778 | // module and cannot be accessed by other modules. However, a module may be |
| 779 | // written to export or import functions and global (module scope) |
| 780 | // variables.". Therefore, map 'Private' storage class to private linkage, |
| 781 | // 'Input' and 'Output' to external linkage. |
| 782 | auto linkage = storageClass == spirv::StorageClass::Private |
| 783 | ? LLVM::Linkage::Private |
| 784 | : LLVM::Linkage::External; |
| 785 | auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( |
| 786 | op, dstType, isConstant, linkage, op.getSymName(), Attribute(), |
| 787 | /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass)); |
| 788 | |
| 789 | // Attach location attribute if applicable |
| 790 | if (op.getLocationAttr()) |
| 791 | newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); |
| 792 | |
| 793 | return success(); |
| 794 | } |
| 795 | |
| 796 | private: |
| 797 | spirv::ClientAPI clientAPI; |
| 798 | }; |
| 799 | |
| 800 | /// Converts SPIR-V cast ops that do not have straightforward LLVM |
| 801 | /// equivalent in LLVM dialect. |
| 802 | template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> |
| 803 | class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 804 | public: |
| 805 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 806 | |
| 807 | LogicalResult |
| 808 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 809 | ConversionPatternRewriter &rewriter) const override { |
| 810 | |
| 811 | Type fromType = op.getOperand().getType(); |
| 812 | Type toType = op.getType(); |
| 813 | |
| 814 | auto dstType = this->getTypeConverter()->convertType(toType); |
| 815 | if (!dstType) |
| 816 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 817 | |
| 818 | if (getBitWidth(type: fromType) < getBitWidth(type: toType)) { |
| 819 | rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType, |
| 820 | adaptor.getOperands()); |
| 821 | return success(); |
| 822 | } |
| 823 | if (getBitWidth(type: fromType) > getBitWidth(type: toType)) { |
| 824 | rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType, |
| 825 | adaptor.getOperands()); |
| 826 | return success(); |
| 827 | } |
| 828 | return failure(); |
| 829 | } |
| 830 | }; |
| 831 | |
| 832 | class FunctionCallPattern |
| 833 | : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { |
| 834 | public: |
| 835 | using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; |
| 836 | |
| 837 | LogicalResult |
| 838 | matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, |
| 839 | ConversionPatternRewriter &rewriter) const override { |
| 840 | if (callOp.getNumResults() == 0) { |
| 841 | auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
| 842 | callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); |
| 843 | newOp.getProperties().operandSegmentSizes = { |
| 844 | static_cast<int32_t>(adaptor.getOperands().size()), 0}; |
| 845 | newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); |
| 846 | return success(); |
| 847 | } |
| 848 | |
| 849 | // Function returns a single result. |
| 850 | auto dstType = getTypeConverter()->convertType(callOp.getType(0)); |
| 851 | if (!dstType) |
| 852 | return rewriter.notifyMatchFailure(callOp, "type conversion failed" ); |
| 853 | auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
| 854 | callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); |
| 855 | newOp.getProperties().operandSegmentSizes = { |
| 856 | static_cast<int32_t>(adaptor.getOperands().size()), 0}; |
| 857 | newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); |
| 858 | return success(); |
| 859 | } |
| 860 | }; |
| 861 | |
| 862 | /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" |
| 863 | template <typename SPIRVOp, LLVM::FCmpPredicate predicate> |
| 864 | class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 865 | public: |
| 866 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 867 | |
| 868 | LogicalResult |
| 869 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 870 | ConversionPatternRewriter &rewriter) const override { |
| 871 | |
| 872 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 873 | if (!dstType) |
| 874 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 875 | |
| 876 | rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( |
| 877 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
| 878 | return success(); |
| 879 | } |
| 880 | }; |
| 881 | |
| 882 | /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" |
| 883 | template <typename SPIRVOp, LLVM::ICmpPredicate predicate> |
| 884 | class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 885 | public: |
| 886 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 887 | |
| 888 | LogicalResult |
| 889 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 890 | ConversionPatternRewriter &rewriter) const override { |
| 891 | |
| 892 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 893 | if (!dstType) |
| 894 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 895 | |
| 896 | rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( |
| 897 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
| 898 | return success(); |
| 899 | } |
| 900 | }; |
| 901 | |
| 902 | class InverseSqrtPattern |
| 903 | : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { |
| 904 | public: |
| 905 | using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion; |
| 906 | |
| 907 | LogicalResult |
| 908 | matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor, |
| 909 | ConversionPatternRewriter &rewriter) const override { |
| 910 | auto srcType = op.getType(); |
| 911 | auto dstType = getTypeConverter()->convertType(srcType); |
| 912 | if (!dstType) |
| 913 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 914 | |
| 915 | Location loc = op.getLoc(); |
| 916 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
| 917 | Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); |
| 918 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); |
| 919 | return success(); |
| 920 | } |
| 921 | }; |
| 922 | |
| 923 | /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect. |
| 924 | template <typename SPIRVOp> |
| 925 | class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 926 | public: |
| 927 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 928 | |
| 929 | LogicalResult |
| 930 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 931 | ConversionPatternRewriter &rewriter) const override { |
| 932 | if (!op.getMemoryAccess()) { |
| 933 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
| 934 | *this->getTypeConverter(), /*alignment=*/0, |
| 935 | /*isVolatile=*/false, |
| 936 | /*isNonTemporal=*/false); |
| 937 | } |
| 938 | auto memoryAccess = *op.getMemoryAccess(); |
| 939 | switch (memoryAccess) { |
| 940 | case spirv::MemoryAccess::Aligned: |
| 941 | case spirv::MemoryAccess::None: |
| 942 | case spirv::MemoryAccess::Nontemporal: |
| 943 | case spirv::MemoryAccess::Volatile: { |
| 944 | unsigned alignment = |
| 945 | memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; |
| 946 | bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; |
| 947 | bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; |
| 948 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
| 949 | *this->getTypeConverter(), alignment, |
| 950 | isVolatile, isNonTemporal); |
| 951 | } |
| 952 | default: |
| 953 | // There is no support of other memory access attributes. |
| 954 | return failure(); |
| 955 | } |
| 956 | } |
| 957 | }; |
| 958 | |
| 959 | /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect. |
| 960 | template <typename SPIRVOp> |
| 961 | class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 962 | public: |
| 963 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 964 | |
| 965 | LogicalResult |
| 966 | matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, |
| 967 | ConversionPatternRewriter &rewriter) const override { |
| 968 | auto srcType = notOp.getType(); |
| 969 | auto dstType = this->getTypeConverter()->convertType(srcType); |
| 970 | if (!dstType) |
| 971 | return rewriter.notifyMatchFailure(notOp, "type conversion failed" ); |
| 972 | |
| 973 | Location loc = notOp.getLoc(); |
| 974 | IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); |
| 975 | auto mask = |
| 976 | isa<VectorType>(srcType) |
| 977 | ? rewriter.create<LLVM::ConstantOp>( |
| 978 | loc, dstType, |
| 979 | SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) |
| 980 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); |
| 981 | rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, |
| 982 | notOp.getOperand(), mask); |
| 983 | return success(); |
| 984 | } |
| 985 | }; |
| 986 | |
| 987 | /// A template pattern that erases the given `SPIRVOp`. |
| 988 | template <typename SPIRVOp> |
| 989 | class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 990 | public: |
| 991 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 992 | |
| 993 | LogicalResult |
| 994 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 995 | ConversionPatternRewriter &rewriter) const override { |
| 996 | rewriter.eraseOp(op); |
| 997 | return success(); |
| 998 | } |
| 999 | }; |
| 1000 | |
| 1001 | class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { |
| 1002 | public: |
| 1003 | using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; |
| 1004 | |
| 1005 | LogicalResult |
| 1006 | matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, |
| 1007 | ConversionPatternRewriter &rewriter) const override { |
| 1008 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), |
| 1009 | ArrayRef<Value>()); |
| 1010 | return success(); |
| 1011 | } |
| 1012 | }; |
| 1013 | |
| 1014 | class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { |
| 1015 | public: |
| 1016 | using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; |
| 1017 | |
| 1018 | LogicalResult |
| 1019 | matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, |
| 1020 | ConversionPatternRewriter &rewriter) const override { |
| 1021 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), |
| 1022 | adaptor.getOperands()); |
| 1023 | return success(); |
| 1024 | } |
| 1025 | }; |
| 1026 | |
| 1027 | static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, |
| 1028 | StringRef name, |
| 1029 | ArrayRef<Type> paramTypes, |
| 1030 | Type resultType, |
| 1031 | bool convergent = true) { |
| 1032 | auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
| 1033 | SymbolTable::lookupSymbolIn(symbolTable, name)); |
| 1034 | if (func) |
| 1035 | return func; |
| 1036 | |
| 1037 | OpBuilder b(symbolTable->getRegion(index: 0)); |
| 1038 | func = b.create<LLVM::LLVMFuncOp>( |
| 1039 | symbolTable->getLoc(), name, |
| 1040 | LLVM::LLVMFunctionType::get(resultType, paramTypes)); |
| 1041 | func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); |
| 1042 | func.setConvergent(convergent); |
| 1043 | func.setNoUnwind(true); |
| 1044 | func.setWillReturn(true); |
| 1045 | return func; |
| 1046 | } |
| 1047 | |
| 1048 | static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, |
| 1049 | LLVM::LLVMFuncOp func, |
| 1050 | ValueRange args) { |
| 1051 | auto call = builder.create<LLVM::CallOp>(loc, func, args); |
| 1052 | call.setCConv(func.getCConv()); |
| 1053 | call.setConvergentAttr(func.getConvergentAttr()); |
| 1054 | call.setNoUnwindAttr(func.getNoUnwindAttr()); |
| 1055 | call.setWillReturnAttr(func.getWillReturnAttr()); |
| 1056 | return call; |
| 1057 | } |
| 1058 | |
| 1059 | template <typename BarrierOpTy> |
| 1060 | class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> { |
| 1061 | public: |
| 1062 | using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor; |
| 1063 | |
| 1064 | using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion; |
| 1065 | |
| 1066 | static constexpr StringRef getFuncName(); |
| 1067 | |
| 1068 | LogicalResult |
| 1069 | matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, |
| 1070 | ConversionPatternRewriter &rewriter) const override { |
| 1071 | constexpr StringRef funcName = getFuncName(); |
| 1072 | Operation *symbolTable = |
| 1073 | controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>(); |
| 1074 | |
| 1075 | Type i32 = rewriter.getI32Type(); |
| 1076 | |
| 1077 | Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); |
| 1078 | LLVM::LLVMFuncOp func = |
| 1079 | lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); |
| 1080 | |
| 1081 | Location loc = controlBarrierOp->getLoc(); |
| 1082 | Value execution = rewriter.create<LLVM::ConstantOp>( |
| 1083 | loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); |
| 1084 | Value memory = rewriter.create<LLVM::ConstantOp>( |
| 1085 | loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); |
| 1086 | Value semantics = rewriter.create<LLVM::ConstantOp>( |
| 1087 | loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); |
| 1088 | |
| 1089 | auto call = createSPIRVBuiltinCall(loc, rewriter, func, |
| 1090 | {execution, memory, semantics}); |
| 1091 | |
| 1092 | rewriter.replaceOp(controlBarrierOp, call); |
| 1093 | return success(); |
| 1094 | } |
| 1095 | }; |
| 1096 | |
| 1097 | namespace { |
| 1098 | |
| 1099 | StringRef getTypeMangling(Type type, bool isSigned) { |
| 1100 | return llvm::TypeSwitch<Type, StringRef>(type) |
| 1101 | .Case<Float16Type>([](auto) { return "Dh" ; }) |
| 1102 | .Case<Float32Type>([](auto) { return "f" ; }) |
| 1103 | .Case<Float64Type>([](auto) { return "d" ; }) |
| 1104 | .Case<IntegerType>([isSigned](IntegerType intTy) { |
| 1105 | switch (intTy.getWidth()) { |
| 1106 | case 1: |
| 1107 | return "b" ; |
| 1108 | case 8: |
| 1109 | return (isSigned) ? "a" : "c" ; |
| 1110 | case 16: |
| 1111 | return (isSigned) ? "s" : "t" ; |
| 1112 | case 32: |
| 1113 | return (isSigned) ? "i" : "j" ; |
| 1114 | case 64: |
| 1115 | return (isSigned) ? "l" : "m" ; |
| 1116 | default: |
| 1117 | llvm_unreachable("Unsupported integer width" ); |
| 1118 | } |
| 1119 | }) |
| 1120 | .Default([](auto) { |
| 1121 | llvm_unreachable("No mangling defined" ); |
| 1122 | return "" ; |
| 1123 | }); |
| 1124 | } |
| 1125 | |
| 1126 | template <typename ReduceOp> |
| 1127 | constexpr StringLiteral getGroupFuncName(); |
| 1128 | |
| 1129 | template <> |
| 1130 | constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() { |
| 1131 | return "_Z17__spirv_GroupIAddii" ; |
| 1132 | } |
| 1133 | template <> |
| 1134 | constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() { |
| 1135 | return "_Z17__spirv_GroupFAddii" ; |
| 1136 | } |
| 1137 | template <> |
| 1138 | constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() { |
| 1139 | return "_Z17__spirv_GroupSMinii" ; |
| 1140 | } |
| 1141 | template <> |
| 1142 | constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() { |
| 1143 | return "_Z17__spirv_GroupUMinii" ; |
| 1144 | } |
| 1145 | template <> |
| 1146 | constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() { |
| 1147 | return "_Z17__spirv_GroupFMinii" ; |
| 1148 | } |
| 1149 | template <> |
| 1150 | constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() { |
| 1151 | return "_Z17__spirv_GroupSMaxii" ; |
| 1152 | } |
| 1153 | template <> |
| 1154 | constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() { |
| 1155 | return "_Z17__spirv_GroupUMaxii" ; |
| 1156 | } |
| 1157 | template <> |
| 1158 | constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() { |
| 1159 | return "_Z17__spirv_GroupFMaxii" ; |
| 1160 | } |
| 1161 | template <> |
| 1162 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() { |
| 1163 | return "_Z27__spirv_GroupNonUniformIAddii" ; |
| 1164 | } |
| 1165 | template <> |
| 1166 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() { |
| 1167 | return "_Z27__spirv_GroupNonUniformFAddii" ; |
| 1168 | } |
| 1169 | template <> |
| 1170 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() { |
| 1171 | return "_Z27__spirv_GroupNonUniformIMulii" ; |
| 1172 | } |
| 1173 | template <> |
| 1174 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() { |
| 1175 | return "_Z27__spirv_GroupNonUniformFMulii" ; |
| 1176 | } |
| 1177 | template <> |
| 1178 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() { |
| 1179 | return "_Z27__spirv_GroupNonUniformSMinii" ; |
| 1180 | } |
| 1181 | template <> |
| 1182 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() { |
| 1183 | return "_Z27__spirv_GroupNonUniformUMinii" ; |
| 1184 | } |
| 1185 | template <> |
| 1186 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() { |
| 1187 | return "_Z27__spirv_GroupNonUniformFMinii" ; |
| 1188 | } |
| 1189 | template <> |
| 1190 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() { |
| 1191 | return "_Z27__spirv_GroupNonUniformSMaxii" ; |
| 1192 | } |
| 1193 | template <> |
| 1194 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() { |
| 1195 | return "_Z27__spirv_GroupNonUniformUMaxii" ; |
| 1196 | } |
| 1197 | template <> |
| 1198 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() { |
| 1199 | return "_Z27__spirv_GroupNonUniformFMaxii" ; |
| 1200 | } |
| 1201 | template <> |
| 1202 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() { |
| 1203 | return "_Z33__spirv_GroupNonUniformBitwiseAndii" ; |
| 1204 | } |
| 1205 | template <> |
| 1206 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() { |
| 1207 | return "_Z32__spirv_GroupNonUniformBitwiseOrii" ; |
| 1208 | } |
| 1209 | template <> |
| 1210 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() { |
| 1211 | return "_Z33__spirv_GroupNonUniformBitwiseXorii" ; |
| 1212 | } |
| 1213 | template <> |
| 1214 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() { |
| 1215 | return "_Z33__spirv_GroupNonUniformLogicalAndii" ; |
| 1216 | } |
| 1217 | template <> |
| 1218 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() { |
| 1219 | return "_Z32__spirv_GroupNonUniformLogicalOrii" ; |
| 1220 | } |
| 1221 | template <> |
| 1222 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() { |
| 1223 | return "_Z33__spirv_GroupNonUniformLogicalXorii" ; |
| 1224 | } |
| 1225 | } // namespace |
| 1226 | |
| 1227 | template <typename ReduceOp, bool Signed = false, bool NonUniform = false> |
| 1228 | class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> { |
| 1229 | public: |
| 1230 | using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion; |
| 1231 | |
| 1232 | LogicalResult |
| 1233 | matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor, |
| 1234 | ConversionPatternRewriter &rewriter) const override { |
| 1235 | |
| 1236 | Type retTy = op.getResult().getType(); |
| 1237 | if (!retTy.isIntOrFloat()) { |
| 1238 | return failure(); |
| 1239 | } |
| 1240 | SmallString<36> funcName = getGroupFuncName<ReduceOp>(); |
| 1241 | funcName += getTypeMangling(type: retTy, isSigned: false); |
| 1242 | |
| 1243 | Type i32Ty = rewriter.getI32Type(); |
| 1244 | SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy}; |
| 1245 | if constexpr (NonUniform) { |
| 1246 | if (adaptor.getClusterSize()) { |
| 1247 | funcName += "j" ; |
| 1248 | paramTypes.push_back(Elt: i32Ty); |
| 1249 | } |
| 1250 | } |
| 1251 | |
| 1252 | Operation *symbolTable = |
| 1253 | op->template getParentWithTrait<OpTrait::SymbolTable>(); |
| 1254 | |
| 1255 | LLVM::LLVMFuncOp func = |
| 1256 | lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); |
| 1257 | |
| 1258 | Location loc = op.getLoc(); |
| 1259 | Value scope = rewriter.create<LLVM::ConstantOp>( |
| 1260 | loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); |
| 1261 | Value groupOp = rewriter.create<LLVM::ConstantOp>( |
| 1262 | loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); |
| 1263 | SmallVector<Value> operands{scope, groupOp}; |
| 1264 | operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); |
| 1265 | |
| 1266 | auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands); |
| 1267 | rewriter.replaceOp(op, call); |
| 1268 | return success(); |
| 1269 | } |
| 1270 | }; |
| 1271 | |
| 1272 | template <> |
| 1273 | constexpr StringRef |
| 1274 | ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() { |
| 1275 | return "_Z22__spirv_ControlBarrieriii" ; |
| 1276 | } |
| 1277 | |
| 1278 | template <> |
| 1279 | constexpr StringRef |
| 1280 | ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() { |
| 1281 | return "_Z33__spirv_ControlBarrierArriveINTELiii" ; |
| 1282 | } |
| 1283 | |
| 1284 | template <> |
| 1285 | constexpr StringRef |
| 1286 | ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() { |
| 1287 | return "_Z31__spirv_ControlBarrierWaitINTELiii" ; |
| 1288 | } |
| 1289 | |
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
| 1293 | /// |
| 1294 | /// +------------------------------------+ |
| 1295 | /// | <code before spirv.mlir.loop> | |
| 1296 | /// | llvm.br ^header | |
| 1297 | /// +------------------------------------+ |
| 1298 | /// | |
| 1299 | /// +----------------+ | |
| 1300 | /// | | | |
| 1301 | /// | V V |
| 1302 | /// | +------------------------------------+ |
| 1303 | /// | | ^header: | |
| 1304 | /// | | <header code> | |
| 1305 | /// | | llvm.cond_br %cond, ^body, ^exit | |
| 1306 | /// | +------------------------------------+ |
| 1307 | /// | | |
| 1308 | /// | |----------------------+ |
| 1309 | /// | | | |
| 1310 | /// | V | |
| 1311 | /// | +------------------------------------+ | |
| 1312 | /// | | ^body: | | |
| 1313 | /// | | <body code> | | |
| 1314 | /// | | llvm.br ^continue | | |
| 1315 | /// | +------------------------------------+ | |
| 1316 | /// | | | |
| 1317 | /// | V | |
| 1318 | /// | +------------------------------------+ | |
| 1319 | /// | | ^continue: | | |
| 1320 | /// | | <continue code> | | |
| 1321 | /// | | llvm.br ^header | | |
| 1322 | /// | +------------------------------------+ | |
| 1323 | /// | | | |
| 1324 | /// +---------------+ +----------------------+ |
| 1325 | /// | |
| 1326 | /// V |
| 1327 | /// +------------------------------------+ |
| 1328 | /// | ^exit: | |
| 1329 | /// | llvm.br ^remaining | |
| 1330 | /// +------------------------------------+ |
| 1331 | /// | |
| 1332 | /// V |
| 1333 | /// +------------------------------------+ |
| 1334 | /// | ^remaining: | |
| 1335 | /// | <code after spirv.mlir.loop> | |
| 1336 | /// +------------------------------------+ |
| 1337 | /// |
| 1338 | class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { |
| 1339 | public: |
| 1340 | using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; |
| 1341 | |
| 1342 | LogicalResult |
| 1343 | matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, |
| 1344 | ConversionPatternRewriter &rewriter) const override { |
| 1345 | // There is no support of loop control at the moment. |
| 1346 | if (loopOp.getLoopControl() != spirv::LoopControl::None) |
| 1347 | return failure(); |
| 1348 | |
| 1349 | // `spirv.mlir.loop` with empty region is redundant and should be erased. |
| 1350 | if (loopOp.getBody().empty()) { |
| 1351 | rewriter.eraseOp(op: loopOp); |
| 1352 | return success(); |
| 1353 | } |
| 1354 | |
| 1355 | Location loc = loopOp.getLoc(); |
| 1356 | |
| 1357 | // Split the current block after `spirv.mlir.loop`. The remaining ops will |
| 1358 | // be used in `endBlock`. |
| 1359 | Block *currentBlock = rewriter.getBlock(); |
| 1360 | auto position = Block::iterator(loopOp); |
| 1361 | Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position); |
| 1362 | |
| 1363 | // Remove entry block and create a branch in the current block going to the |
| 1364 | // header block. |
| 1365 | Block *entryBlock = loopOp.getEntryBlock(); |
| 1366 | assert(entryBlock->getOperations().size() == 1); |
| 1367 | auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); |
| 1368 | if (!brOp) |
| 1369 | return failure(); |
| 1370 | Block * = loopOp.getHeaderBlock(); |
| 1371 | rewriter.setInsertionPointToEnd(currentBlock); |
| 1372 | rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); |
| 1373 | rewriter.eraseBlock(block: entryBlock); |
| 1374 | |
| 1375 | // Branch from merge block to end block. |
| 1376 | Block *mergeBlock = loopOp.getMergeBlock(); |
| 1377 | Operation *terminator = mergeBlock->getTerminator(); |
| 1378 | ValueRange terminatorOperands = terminator->getOperands(); |
| 1379 | rewriter.setInsertionPointToEnd(mergeBlock); |
| 1380 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); |
| 1381 | |
| 1382 | rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); |
| 1383 | rewriter.replaceOp(loopOp, endBlock->getArguments()); |
| 1384 | return success(); |
| 1385 | } |
| 1386 | }; |
| 1387 | |
| 1388 | /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header |
| 1389 | /// block. All blocks within selection should be reachable for conversion to |
| 1390 | /// succeed. |
| 1391 | class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { |
| 1392 | public: |
| 1393 | using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; |
| 1394 | |
| 1395 | LogicalResult |
| 1396 | matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, |
| 1397 | ConversionPatternRewriter &rewriter) const override { |
| 1398 | // There is no support for `Flatten` or `DontFlatten` selection control at |
| 1399 | // the moment. This are just compiler hints and can be performed during the |
| 1400 | // optimization passes. |
| 1401 | if (op.getSelectionControl() != spirv::SelectionControl::None) |
| 1402 | return failure(); |
| 1403 | |
| 1404 | // `spirv.mlir.selection` should have at least two blocks: one selection |
| 1405 | // header block and one merge block. If no blocks are present, or control |
| 1406 | // flow branches straight to merge block (two blocks are present), the op is |
| 1407 | // redundant and it is erased. |
| 1408 | if (op.getBody().getBlocks().size() <= 2) { |
| 1409 | rewriter.eraseOp(op: op); |
| 1410 | return success(); |
| 1411 | } |
| 1412 | |
| 1413 | Location loc = op.getLoc(); |
| 1414 | |
| 1415 | // Split the current block after `spirv.mlir.selection`. The remaining ops |
| 1416 | // will be used in `continueBlock`. |
| 1417 | auto *currentBlock = rewriter.getInsertionBlock(); |
| 1418 | rewriter.setInsertionPointAfter(op); |
| 1419 | auto position = rewriter.getInsertionPoint(); |
| 1420 | auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position); |
| 1421 | |
| 1422 | // Extract conditional branch information from the header block. By SPIR-V |
| 1423 | // dialect spec, it should contain `spirv.BranchConditional` or |
| 1424 | // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the |
| 1425 | // moment in the SPIR-V dialect. Remove this block when finished. |
| 1426 | auto * = op.getHeaderBlock(); |
| 1427 | assert(headerBlock->getOperations().size() == 1); |
| 1428 | auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( |
| 1429 | headerBlock->getOperations().front()); |
| 1430 | if (!condBrOp) |
| 1431 | return failure(); |
| 1432 | rewriter.eraseBlock(block: headerBlock); |
| 1433 | |
| 1434 | // Branch from merge block to continue block. |
| 1435 | auto *mergeBlock = op.getMergeBlock(); |
| 1436 | Operation *terminator = mergeBlock->getTerminator(); |
| 1437 | ValueRange terminatorOperands = terminator->getOperands(); |
| 1438 | rewriter.setInsertionPointToEnd(mergeBlock); |
| 1439 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); |
| 1440 | |
| 1441 | // Link current block to `true` and `false` blocks within the selection. |
| 1442 | Block *trueBlock = condBrOp.getTrueBlock(); |
| 1443 | Block *falseBlock = condBrOp.getFalseBlock(); |
| 1444 | rewriter.setInsertionPointToEnd(currentBlock); |
| 1445 | rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, |
| 1446 | condBrOp.getTrueTargetOperands(), |
| 1447 | falseBlock, |
| 1448 | condBrOp.getFalseTargetOperands()); |
| 1449 | |
| 1450 | rewriter.inlineRegionBefore(op.getBody(), continueBlock); |
| 1451 | rewriter.replaceOp(op, continueBlock->getArguments()); |
| 1452 | return success(); |
| 1453 | } |
| 1454 | }; |
| 1455 | |
| 1456 | /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect |
| 1457 | /// puts a restriction on `Shift` and `Base` to have the same bit width, |
| 1458 | /// `Shift` is zero or sign extended to match this specification. Cases when |
| 1459 | /// `Shift` bit width > `Base` bit width are considered to be illegal. |
| 1460 | template <typename SPIRVOp, typename LLVMOp> |
| 1461 | class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
| 1462 | public: |
| 1463 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
| 1464 | |
| 1465 | LogicalResult |
| 1466 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
| 1467 | ConversionPatternRewriter &rewriter) const override { |
| 1468 | |
| 1469 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
| 1470 | if (!dstType) |
| 1471 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 1472 | |
| 1473 | Type op1Type = op.getOperand1().getType(); |
| 1474 | Type op2Type = op.getOperand2().getType(); |
| 1475 | |
| 1476 | if (op1Type == op2Type) { |
| 1477 | rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType, |
| 1478 | adaptor.getOperands()); |
| 1479 | return success(); |
| 1480 | } |
| 1481 | |
| 1482 | std::optional<uint64_t> dstTypeWidth = |
| 1483 | getIntegerOrVectorElementWidth(dstType); |
| 1484 | std::optional<uint64_t> op2TypeWidth = |
| 1485 | getIntegerOrVectorElementWidth(type: op2Type); |
| 1486 | |
| 1487 | if (!dstTypeWidth || !op2TypeWidth) |
| 1488 | return failure(); |
| 1489 | |
| 1490 | Location loc = op.getLoc(); |
| 1491 | Value extended; |
| 1492 | if (op2TypeWidth < dstTypeWidth) { |
| 1493 | if (isUnsignedIntegerOrVector(type: op2Type)) { |
| 1494 | extended = rewriter.template create<LLVM::ZExtOp>( |
| 1495 | loc, dstType, adaptor.getOperand2()); |
| 1496 | } else { |
| 1497 | extended = rewriter.template create<LLVM::SExtOp>( |
| 1498 | loc, dstType, adaptor.getOperand2()); |
| 1499 | } |
| 1500 | } else if (op2TypeWidth == dstTypeWidth) { |
| 1501 | extended = adaptor.getOperand2(); |
| 1502 | } else { |
| 1503 | return failure(); |
| 1504 | } |
| 1505 | |
| 1506 | Value result = rewriter.template create<LLVMOp>( |
| 1507 | loc, dstType, adaptor.getOperand1(), extended); |
| 1508 | rewriter.replaceOp(op, result); |
| 1509 | return success(); |
| 1510 | } |
| 1511 | }; |
| 1512 | |
| 1513 | class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { |
| 1514 | public: |
| 1515 | using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion; |
| 1516 | |
| 1517 | LogicalResult |
| 1518 | matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor, |
| 1519 | ConversionPatternRewriter &rewriter) const override { |
| 1520 | auto dstType = getTypeConverter()->convertType(tanOp.getType()); |
| 1521 | if (!dstType) |
| 1522 | return rewriter.notifyMatchFailure(tanOp, "type conversion failed" ); |
| 1523 | |
| 1524 | Location loc = tanOp.getLoc(); |
| 1525 | Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); |
| 1526 | Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); |
| 1527 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); |
| 1528 | return success(); |
| 1529 | } |
| 1530 | }; |
| 1531 | |
| 1532 | /// Convert `spirv.Tanh` to |
| 1533 | /// |
| 1534 | /// exp(2x) - 1 |
| 1535 | /// ----------- |
| 1536 | /// exp(2x) + 1 |
| 1537 | /// |
| 1538 | class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { |
| 1539 | public: |
| 1540 | using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; |
| 1541 | |
| 1542 | LogicalResult |
| 1543 | matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor, |
| 1544 | ConversionPatternRewriter &rewriter) const override { |
| 1545 | auto srcType = tanhOp.getType(); |
| 1546 | auto dstType = getTypeConverter()->convertType(srcType); |
| 1547 | if (!dstType) |
| 1548 | return rewriter.notifyMatchFailure(tanhOp, "type conversion failed" ); |
| 1549 | |
| 1550 | Location loc = tanhOp.getLoc(); |
| 1551 | Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); |
| 1552 | Value multiplied = |
| 1553 | rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); |
| 1554 | Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); |
| 1555 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
| 1556 | Value numerator = |
| 1557 | rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); |
| 1558 | Value denominator = |
| 1559 | rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); |
| 1560 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, |
| 1561 | denominator); |
| 1562 | return success(); |
| 1563 | } |
| 1564 | }; |
| 1565 | |
| 1566 | class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { |
| 1567 | public: |
| 1568 | using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; |
| 1569 | |
| 1570 | LogicalResult |
| 1571 | matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, |
| 1572 | ConversionPatternRewriter &rewriter) const override { |
| 1573 | auto srcType = varOp.getType(); |
| 1574 | // Initialization is supported for scalars and vectors only. |
| 1575 | auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType(); |
| 1576 | auto init = varOp.getInitializer(); |
| 1577 | if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo)) |
| 1578 | return failure(); |
| 1579 | |
| 1580 | auto dstType = getTypeConverter()->convertType(srcType); |
| 1581 | if (!dstType) |
| 1582 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
| 1583 | |
| 1584 | Location loc = varOp.getLoc(); |
| 1585 | Value size = createI32ConstantOf(loc, rewriter, value: 1); |
| 1586 | if (!init) { |
| 1587 | auto elementType = getTypeConverter()->convertType(pointerTo); |
| 1588 | if (!elementType) |
| 1589 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
| 1590 | rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType, |
| 1591 | size); |
| 1592 | return success(); |
| 1593 | } |
| 1594 | auto elementType = getTypeConverter()->convertType(pointerTo); |
| 1595 | if (!elementType) |
| 1596 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
| 1597 | Value allocated = |
| 1598 | rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); |
| 1599 | rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); |
| 1600 | rewriter.replaceOp(varOp, allocated); |
| 1601 | return success(); |
| 1602 | } |
| 1603 | }; |
| 1604 | |
| 1605 | //===----------------------------------------------------------------------===// |
| 1606 | // BitcastOp conversion |
| 1607 | //===----------------------------------------------------------------------===// |
| 1608 | |
| 1609 | class BitcastConversionPattern |
| 1610 | : public SPIRVToLLVMConversion<spirv::BitcastOp> { |
| 1611 | public: |
| 1612 | using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion; |
| 1613 | |
| 1614 | LogicalResult |
| 1615 | matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor, |
| 1616 | ConversionPatternRewriter &rewriter) const override { |
| 1617 | auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); |
| 1618 | if (!dstType) |
| 1619 | return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed" ); |
| 1620 | |
| 1621 | // LLVM's opaque pointers do not require bitcasts. |
| 1622 | if (isa<LLVM::LLVMPointerType>(dstType)) { |
| 1623 | rewriter.replaceOp(bitcastOp, adaptor.getOperand()); |
| 1624 | return success(); |
| 1625 | } |
| 1626 | |
| 1627 | rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
| 1628 | bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs()); |
| 1629 | return success(); |
| 1630 | } |
| 1631 | }; |
| 1632 | |
| 1633 | //===----------------------------------------------------------------------===// |
| 1634 | // FuncOp conversion |
| 1635 | //===----------------------------------------------------------------------===// |
| 1636 | |
| 1637 | class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { |
| 1638 | public: |
| 1639 | using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; |
| 1640 | |
| 1641 | LogicalResult |
| 1642 | matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, |
| 1643 | ConversionPatternRewriter &rewriter) const override { |
| 1644 | |
| 1645 | // Convert function signature. At the moment LLVMType converter is enough |
| 1646 | // for currently supported types. |
| 1647 | auto funcType = funcOp.getFunctionType(); |
| 1648 | TypeConverter::SignatureConversion signatureConverter( |
| 1649 | funcType.getNumInputs()); |
| 1650 | auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter()) |
| 1651 | ->convertFunctionSignature( |
| 1652 | funcType, /*isVariadic=*/false, |
| 1653 | /*useBarePtrCallConv=*/false, signatureConverter); |
| 1654 | if (!llvmType) |
| 1655 | return failure(); |
| 1656 | |
| 1657 | // Create a new `LLVMFuncOp` |
| 1658 | Location loc = funcOp.getLoc(); |
| 1659 | StringRef name = funcOp.getName(); |
| 1660 | auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); |
| 1661 | |
| 1662 | // Convert SPIR-V Function Control to equivalent LLVM function attribute |
| 1663 | MLIRContext *context = funcOp.getContext(); |
| 1664 | switch (funcOp.getFunctionControl()) { |
| 1665 | case spirv::FunctionControl::Inline: |
| 1666 | newFuncOp.setAlwaysInline(true); |
| 1667 | break; |
| 1668 | case spirv::FunctionControl::DontInline: |
| 1669 | newFuncOp.setNoInline(true); |
| 1670 | break; |
| 1671 | |
| 1672 | #define DISPATCH(functionControl, llvmAttr) \ |
| 1673 | case functionControl: \ |
| 1674 | newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ |
| 1675 | break; |
| 1676 | |
| 1677 | DISPATCH(spirv::FunctionControl::Pure, |
| 1678 | StringAttr::get(context, "readonly" )); |
| 1679 | DISPATCH(spirv::FunctionControl::Const, |
| 1680 | StringAttr::get(context, "readnone" )); |
| 1681 | |
| 1682 | #undef DISPATCH |
| 1683 | |
| 1684 | // Default: if `spirv::FunctionControl::None`, then no attributes are |
| 1685 | // needed. |
| 1686 | default: |
| 1687 | break; |
| 1688 | } |
| 1689 | |
| 1690 | rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), |
| 1691 | newFuncOp.end()); |
| 1692 | if (failed(rewriter.convertRegionTypes( |
| 1693 | region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter))) { |
| 1694 | return failure(); |
| 1695 | } |
| 1696 | rewriter.eraseOp(op: funcOp); |
| 1697 | return success(); |
| 1698 | } |
| 1699 | }; |
| 1700 | |
| 1701 | //===----------------------------------------------------------------------===// |
| 1702 | // ModuleOp conversion |
| 1703 | //===----------------------------------------------------------------------===// |
| 1704 | |
| 1705 | class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { |
| 1706 | public: |
| 1707 | using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; |
| 1708 | |
| 1709 | LogicalResult |
| 1710 | matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, |
| 1711 | ConversionPatternRewriter &rewriter) const override { |
| 1712 | |
| 1713 | auto newModuleOp = |
| 1714 | rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); |
| 1715 | rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); |
| 1716 | |
| 1717 | // Remove the terminator block that was automatically added by builder |
| 1718 | rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back()); |
| 1719 | rewriter.eraseOp(op: spvModuleOp); |
| 1720 | return success(); |
| 1721 | } |
| 1722 | }; |
| 1723 | |
| 1724 | //===----------------------------------------------------------------------===// |
| 1725 | // VectorShuffleOp conversion |
| 1726 | //===----------------------------------------------------------------------===// |
| 1727 | |
| 1728 | class VectorShufflePattern |
| 1729 | : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { |
| 1730 | public: |
| 1731 | using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; |
| 1732 | LogicalResult |
| 1733 | matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, |
| 1734 | ConversionPatternRewriter &rewriter) const override { |
| 1735 | Location loc = op.getLoc(); |
| 1736 | auto components = adaptor.getComponents(); |
| 1737 | auto vector1 = adaptor.getVector1(); |
| 1738 | auto vector2 = adaptor.getVector2(); |
| 1739 | int vector1Size = cast<VectorType>(vector1.getType()).getNumElements(); |
| 1740 | int vector2Size = cast<VectorType>(vector2.getType()).getNumElements(); |
| 1741 | if (vector1Size == vector2Size) { |
| 1742 | rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( |
| 1743 | op, vector1, vector2, |
| 1744 | LLVM::convertArrayToIndices<int32_t>(components)); |
| 1745 | return success(); |
| 1746 | } |
| 1747 | |
| 1748 | auto dstType = getTypeConverter()->convertType(op.getType()); |
| 1749 | if (!dstType) |
| 1750 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
| 1751 | auto scalarType = cast<VectorType>(dstType).getElementType(); |
| 1752 | auto componentsArray = components.getValue(); |
| 1753 | auto *context = rewriter.getContext(); |
| 1754 | auto llvmI32Type = IntegerType::get(context, 32); |
| 1755 | Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType); |
| 1756 | for (unsigned i = 0; i < componentsArray.size(); i++) { |
| 1757 | if (!isa<IntegerAttr>(componentsArray[i])) |
| 1758 | return op.emitError("unable to support non-constant component" ); |
| 1759 | |
| 1760 | int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt(); |
| 1761 | if (indexVal == -1) |
| 1762 | continue; |
| 1763 | |
| 1764 | int offsetVal = 0; |
| 1765 | Value baseVector = vector1; |
| 1766 | if (indexVal >= vector1Size) { |
| 1767 | offsetVal = vector1Size; |
| 1768 | baseVector = vector2; |
| 1769 | } |
| 1770 | |
| 1771 | Value dstIndex = rewriter.create<LLVM::ConstantOp>( |
| 1772 | loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); |
| 1773 | Value index = rewriter.create<LLVM::ConstantOp>( |
| 1774 | loc, llvmI32Type, |
| 1775 | rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); |
| 1776 | |
| 1777 | auto = rewriter.create<LLVM::ExtractElementOp>( |
| 1778 | loc, scalarType, baseVector, index); |
| 1779 | targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, |
| 1780 | extractOp, dstIndex); |
| 1781 | } |
| 1782 | rewriter.replaceOp(op, targetOp); |
| 1783 | return success(); |
| 1784 | } |
| 1785 | }; |
| 1786 | } // namespace |
| 1787 | |
| 1788 | //===----------------------------------------------------------------------===// |
| 1789 | // Pattern population |
| 1790 | //===----------------------------------------------------------------------===// |
| 1791 | |
| 1792 | void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, |
| 1793 | spirv::ClientAPI clientAPI) { |
| 1794 | typeConverter.addConversion(callback: [&](spirv::ArrayType type) { |
| 1795 | return convertArrayType(type, converter&: typeConverter); |
| 1796 | }); |
| 1797 | typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) { |
| 1798 | return convertPointerType(type, typeConverter, clientAPI); |
| 1799 | }); |
| 1800 | typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) { |
| 1801 | return convertRuntimeArrayType(type, converter&: typeConverter); |
| 1802 | }); |
| 1803 | typeConverter.addConversion(callback: [&](spirv::StructType type) { |
| 1804 | return convertStructType(type, converter: typeConverter); |
| 1805 | }); |
| 1806 | } |
| 1807 | |
/// Registers the patterns that lower individual SPIR-V ops to the LLVM
/// dialect, all sharing `typeConverter`. The `spirv.func` and `spirv.module`
/// patterns are not registered here; they are added by
/// populateSPIRVToLLVMFunctionConversionPatterns and
/// populateSPIRVToLLVMModuleConversionPatterns respectively.
void mlir::populateSPIRVToLLVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    spirv::ClientAPI clientAPI) {
  patterns.add<
      // Arithmetic ops
      DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
      DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
      DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
      DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
      DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
      DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
      DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
      DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
      DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
      DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
      DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
      DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
      DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,

      // Bitwise ops
      BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
      DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
      DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
      DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
      DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
      DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
      NotPattern<spirv::NotOp>,

      // Cast ops
      BitcastConversionPattern,
      DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
      DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
      DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
      DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
      IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
      IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
      IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,

      // Comparison ops
      IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
      IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
      FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
      FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
      FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
      FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
      FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
      FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
      FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
      FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
      FComparePattern<spirv::FUnordGreaterThanEqualOp,
                      LLVM::FCmpPredicate::uge>,
      FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
      FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
      FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
      IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
      IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
      IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
      IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
      IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
      IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
      IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
      IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,

      // Constant op
      ConstantScalarAndVectorPattern,

      // Control Flow ops
      BranchConversionPattern, BranchConditionalConversionPattern,
      FunctionCallPattern, LoopPattern, SelectionPattern,
      ErasePattern<spirv::MergeOp>,

      // Entry points and execution mode are handled separately.
      ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,

      // GLSL extended instruction set ops
      DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
      DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
      DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
      DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
      DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
      DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
      DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
      DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
      DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
      DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
      DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
      DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
      InverseSqrtPattern, TanPattern, TanhPattern,

      // Logical ops
      DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
      DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
      IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
      IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
      NotPattern<spirv::LogicalNotOp>,

      // Memory ops
      AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
      LoadStorePattern<spirv::StoreOp>, VariablePattern,

      // Miscellaneous ops
      CompositeExtractPattern, CompositeInsertPattern,
      DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
      DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
      VectorShufflePattern,

      // Shift ops
      ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
      ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
      ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,

      // Return ops
      ReturnPattern, ReturnValuePattern,

      // Barrier ops
      ControlBarrierPattern<spirv::ControlBarrierOp>,
      ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
      ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,

      // Group reduction operations
      GroupReducePattern<spirv::GroupIAddOp>,
      GroupReducePattern<spirv::GroupFAddOp>,
      GroupReducePattern<spirv::GroupFMinOp>,
      GroupReducePattern<spirv::GroupUMinOp>,
      GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
      GroupReducePattern<spirv::GroupFMaxOp>,
      GroupReducePattern<spirv::GroupUMaxOp>,
      GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
      GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
                         /*NonUniform=*/true>,
      GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
                         /*NonUniform=*/true>>(patterns.getContext(),
                                               typeConverter);

  // GlobalVariablePattern additionally takes the client API, so it is
  // registered separately from the list above.
  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
                                      typeConverter);
}
| 1973 | |
| 1974 | void mlir::populateSPIRVToLLVMFunctionConversionPatterns( |
| 1975 | const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 1976 | patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args: typeConverter); |
| 1977 | } |
| 1978 | |
| 1979 | void mlir::populateSPIRVToLLVMModuleConversionPatterns( |
| 1980 | const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 1981 | patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args: typeConverter); |
| 1982 | } |
| 1983 | |
| 1984 | //===----------------------------------------------------------------------===// |
| 1985 | // Pre-conversion hooks |
| 1986 | //===----------------------------------------------------------------------===// |
| 1987 | |
| 1988 | /// Hook for descriptor set and binding number encoding. |
| 1989 | static constexpr StringRef kBinding = "binding" ; |
| 1990 | static constexpr StringRef kDescriptorSet = "descriptor_set" ; |
| 1991 | void mlir::encodeBindAttribute(ModuleOp module) { |
| 1992 | auto spvModules = module.getOps<spirv::ModuleOp>(); |
| 1993 | for (auto spvModule : spvModules) { |
| 1994 | spvModule.walk([&](spirv::GlobalVariableOp op) { |
| 1995 | IntegerAttr descriptorSet = |
| 1996 | op->getAttrOfType<IntegerAttr>(kDescriptorSet); |
| 1997 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); |
| 1998 | // For every global variable in the module, get the ones with descriptor |
| 1999 | // set and binding numbers. |
| 2000 | if (descriptorSet && binding) { |
| 2001 | // Encode these numbers into the variable's symbolic name. If the |
| 2002 | // SPIR-V module has a name, add it at the beginning. |
| 2003 | auto moduleAndName = |
| 2004 | spvModule.getName().has_value() |
| 2005 | ? spvModule.getName()->str() + "_" + op.getSymName().str() |
| 2006 | : op.getSymName().str(); |
| 2007 | std::string name = |
| 2008 | llvm::formatv("{0}_descriptor_set{1}_binding{2}" , moduleAndName, |
| 2009 | std::to_string(descriptorSet.getInt()), |
| 2010 | std::to_string(binding.getInt())); |
| 2011 | auto nameAttr = StringAttr::get(op->getContext(), name); |
| 2012 | |
| 2013 | // Replace all symbol uses and set the new symbol name. Finally, remove |
| 2014 | // descriptor set and binding attributes. |
| 2015 | if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) |
| 2016 | op.emitError("unable to replace all symbol uses for " ) << name; |
| 2017 | SymbolTable::setSymbolName(op, nameAttr); |
| 2018 | op->removeAttr(kDescriptorSet); |
| 2019 | op->removeAttr(kBinding); |
| 2020 | } |
| 2021 | }); |
| 2022 | } |
| 2023 | } |
| 2024 | |