//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

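/// Look up an LLVM function declaration for the given OpenCL builtin in the
/// surrounding symbol table, creating it if it does not exist yet. `name` is
/// expected to be the Itanium-mangled name of the builtin, e.g. `_Z7barrierj`
/// for `barrier(uint)`.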
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // No externally observable effects.
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }
  return func;
}

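/// Create a call to `func`, mirroring the callee's calling convention and
/// attributes on the call site; in LLVM IR, a call whose calling convention
/// does not match the callee's is undefined behavior.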
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}
| 90 | |
| 91 | namespace { |
| 92 | //===----------------------------------------------------------------------===// |
| 93 | // Barriers |
| 94 | //===----------------------------------------------------------------------===// |
| 95 | |
| 96 | /// Replace `gpu.barrier` with an `llvm.call` to `barrier` with |
| 97 | /// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope: |
| 98 | /// ``` |
| 99 | /// // gpu.barrier |
| 100 | /// %c1 = llvm.mlir.constant(1: i32) : i32 |
| 101 | /// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> () |
| 102 | /// ``` |
| 103 | struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> { |
| 104 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 105 | |
| 106 | LogicalResult |
| 107 | matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor, |
| 108 | ConversionPatternRewriter &rewriter) const final { |
| 109 | constexpr StringLiteral funcName = "_Z7barrierj" ; |
| 110 | |
| 111 | Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); |
| 112 | assert(moduleOp && "Expecting module" ); |
| 113 | Type flagTy = rewriter.getI32Type(); |
| 114 | Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); |
| 115 | LLVM::LLVMFuncOp func = |
| 116 | lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy, |
| 117 | /*isMemNone=*/false, /*isConvergent=*/true); |
| 118 | |
| 119 | // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`. |
| 120 | // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`. |
| 121 | constexpr int64_t localMemFenceFlag = 1; |
| 122 | Location loc = op->getLoc(); |
| 123 | Value flag = |
| 124 | rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag); |
| 125 | rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag)); |
| 126 | return success(); |
| 127 | } |
| 128 | }; |
| 129 | |
| 130 | //===----------------------------------------------------------------------===// |
| 131 | // SPIR-V Builtins |
| 132 | //===----------------------------------------------------------------------===// |
| 133 | |
| 134 | /// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin with |
| 135 | /// a constant argument for the `dimension` attribute. Return type will depend |
| 136 | /// on index width option: |
| 137 | /// ``` |
| 138 | /// // %thread_id_y = gpu.thread_id y |
| 139 | /// %c1 = llvm.mlir.constant(1: i32) : i32 |
| 140 | /// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64 |
| 141 | /// ``` |
| 142 | struct LaunchConfigConversion : ConvertToLLVMPattern { |
| 143 | LaunchConfigConversion(StringRef funcName, StringRef rootOpName, |
| 144 | MLIRContext *context, |
| 145 | const LLVMTypeConverter &typeConverter, |
| 146 | PatternBenefit benefit) |
| 147 | : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit), |
| 148 | funcName(funcName) {} |
| 149 | |
| 150 | virtual gpu::Dimension getDimension(Operation *op) const = 0; |
| 151 | |
| 152 | LogicalResult |
| 153 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 154 | ConversionPatternRewriter &rewriter) const final { |
| 155 | Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); |
| 156 | assert(moduleOp && "Expecting module" ); |
| 157 | Type dimTy = rewriter.getI32Type(); |
| 158 | Type indexTy = getTypeConverter()->getIndexType(); |
| 159 | LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, |
| 160 | indexTy, /*isMemNone=*/true, |
| 161 | /*isConvergent=*/false); |
| 162 | |
| 163 | Location loc = op->getLoc(); |
| 164 | gpu::Dimension dim = getDimension(op); |
| 165 | Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy, |
| 166 | static_cast<int64_t>(dim)); |
| 167 | rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal)); |
| 168 | return success(); |
| 169 | } |
| 170 | |
| 171 | StringRef funcName; |
| 172 | }; |
| 173 | |
| 174 | template <typename SourceOp> |
| 175 | struct LaunchConfigOpConversion final : LaunchConfigConversion { |
| 176 | static StringRef getFuncName(); |
| 177 | |
| 178 | explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter, |
| 179 | PatternBenefit benefit = 1) |
| 180 | : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(), |
| 181 | &typeConverter.getContext(), typeConverter, |
| 182 | benefit) {} |
| 183 | |
| 184 | gpu::Dimension getDimension(Operation *op) const final { |
| 185 | return cast<SourceOp>(op).getDimension(); |
| 186 | } |
| 187 | }; |
| 188 | |
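// Itanium-mangled names of the OpenCL work-item builtins; the trailing `j`
// mangles the `uint` dimension argument.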
template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}

//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//

/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
/// `true` constant for the `valid` result. Conversion only takes place if
/// `width` is a compile-time constant equal to the subgroup size:
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

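  /// Itanium type mangling for the shuffled value (`Dh` = half, `f` = float,
  /// `d` = double, `c`/`s`/`i`/`l` = i8/i16/i32/i64), followed by `j` for the
  /// `uint` offset argument.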
  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
                                                Type type) {
    StringRef baseName = getBaseName(mode);
    std::optional<StringRef> typeMangling = getTypeMangling(type);
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the required subgroup size from the `intel_reqd_sub_group_size`
  /// attribute of the enclosing `llvm.func`, if any.
  static std::optional<int> getSubgroupSize(Operation *op) {
    auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!parentFunc)
      return std::nullopt;
    return parentFunc.getIntelReqdSubGroupSize();
  }

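  /// The OpenCL `sub_group_shuffle*` builtins implicitly operate on the whole
  /// subgroup, so the lowering is only valid when the shuffle width is a
  /// compile-time constant equal to the subgroup size.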
  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

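  /// The OpenCL shuffle builtins have no `bf16` or `i1` overloads, so bitcast
  /// `bf16` values to `i16` and zero-extend `i1` values to `i8` before the
  /// call; `bitcastOrTruncAfterShuffle` undoes the conversion afterwards.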
  static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
                                         ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(oldVal.getType())
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
                                                  oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
                                                 oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(newTy)
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    Location loc = op->getLoc();
    Value inValue =
        bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
    std::optional<std::string> funcName =
        getFuncName(op.getMode(), inValue.getType());
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = inValue.getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    std::array<Value, 2> args{inValue, adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value resultOrConversion =
        bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {resultOrConversion, trueVal});
    return success();
  }
};

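/// Type converter assigning the OpenCL global (`CrossWorkgroup`) address
/// space to memrefs that do not carry a memory space attribute.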
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach the global address space attribute to memrefs with no address
      // space attribute.
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};

//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//

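/// Replace `gpu.subgroup_id`, `gpu.lane_id`, `gpu.num_subgroups` and
/// `gpu.subgroup_size` with an `llvm.call` to the corresponding OpenCL
/// builtin, zero-extending the `i32` result when a wider index type is in
/// use, e.g. with a 64-bit index type:
/// ```
/// // %sg_id = gpu.subgroup_id : index
/// %0 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %1 = llvm.zext %0 : i32 to i64
/// ```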
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//

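/// Pass lowering device-side `gpu` dialect operations to LLVM dialect calls
/// to OpenCL builtins, as expected by the LLVM SPIR-V backend.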
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present.
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace {
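/// Map a `gpu` dialect address space to the corresponding OpenCL address
/// space, going through the equivalent SPIR-V storage class.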
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace

void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir