//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

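/// Look up the SPIR-V builtin `name` in the module's symbol table, declaring
/// it on first use with the `spir_funccc` calling convention and the function
/// attributes describing its behavior (`nounwind`, `willreturn`, and
/// optionally `memory(none)` and `convergent`).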
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // No externally observable effects.
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }
  return func;
}

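/// Create a call to the given SPIR-V builtin declaration, copying the callee's
/// calling convention and function attributes onto the call site.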
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}

namespace {
//===----------------------------------------------------------------------===//
// Barriers
//===----------------------------------------------------------------------===//

/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
/// ```
/// // gpu.barrier
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
/// ```
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
    constexpr int64_t localMemFenceFlag = 1;
    Location loc = op->getLoc();
    Value flag =
        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// SPIR-V Builtins
//===----------------------------------------------------------------------===//

/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin,
/// passing a constant argument for the `dimension` attribute. The return type
/// depends on the index bitwidth option:
/// ```
/// // %thread_id_y = gpu.thread_id y
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
/// ```
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
                                                  indexTy, /*isMemNone=*/true,
                                                  /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = rewriter.create<LLVM::ConstantOp>(
        loc, dimTy, static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};

template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};

template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}

//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//

/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping the `value` and `offset` arguments,
/// and a `true` constant for the `valid` result. Conversion only takes place
/// if `width` is a compile-time constant equal to the parent function's
/// required subgroup size (`intel_reqd_sub_group_size`):
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

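  /// Get the OpenCL builtin name corresponding to a shuffle mode.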
  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

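  /// Get the Itanium mangling suffix for a `(value, offset)` argument pair
  /// with the given value type, or `std::nullopt` if the type has no
  /// `sub_group_shuffle*` overload. The trailing `j` mangles the `i32` offset.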
  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

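  /// Get the mangled name of the `sub_group_shuffle*` builtin for the given
  /// mode and value type, e.g. `_Z17sub_group_shuffledj` for an `idx` shuffle
  /// of `f64`.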
  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
                                                Type type) {
    StringRef baseName = getBaseName(mode);
    std::optional<StringRef> typeMangling = getTypeMangling(type);
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the required subgroup size from the parent function's
  /// `intel_reqd_sub_group_size` attribute, if present.
  static std::optional<int> getSubgroupSize(Operation *op) {
    auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!parentFunc)
      return std::nullopt;
    return parentFunc.getIntelReqdSubGroupSize();
  }

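  /// Check that the shuffle width is a compile-time constant matching the
  /// subgroup size, as the `sub_group_shuffle*` builtins take no width
  /// argument.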
  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

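  /// Cast values without a `sub_group_shuffle*` overload to a supported type:
  /// bitcast `bf16` to `i16` and zero-extend `i1` to `i8` before the call.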
  static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
                                         ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(oldVal.getType())
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
                                                  oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
                                                 oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

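  /// Undo the cast introduced by `bitcastOrExtBeforeShuffle`, converting the
  /// shuffle result back to the original `bf16` or `i1` type.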
  static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(newTy)
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    Location loc = op->getLoc();
    Value inValue =
        bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
    std::optional<std::string> funcName =
        getFuncName(op.getMode(), inValue.getType());
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = inValue.getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    std::array<Value, 2> args{inValue, adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value resultOrConversion =
        bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {resultOrConversion, trueVal});
    return success();
  }
};

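/// Type converter that assigns the OpenCL global (`CrossWorkgroup`) address
/// space to memref types carrying no memory space attribute, and rewrites
/// function types accordingly.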
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach global addr space attribute to memrefs with no addr space attr.
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};

//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//

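/// Replace subgroup query ops (`gpu.subgroup_id`, `gpu.lane_id`,
/// `gpu.num_subgroups` and `gpu.subgroup_size`) with an `llvm.call` to the
/// corresponding SPIR-V builtin. The builtins return `i32`, so the result is
/// zero-extended when the index type is wider:
/// ```
/// // %0 = gpu.subgroup_id : index
/// %1 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %0 = llvm.zext %1 : i32 to i64
/// ```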
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};


//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//

struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present.
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace {
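/// Map a GPU dialect address space to the corresponding OpenCL address space,
/// going through the SPIR-V storage class.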
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace

void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir