1//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert GPU dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/Support/LogicalResult.h"
25#include "mlir/Transforms/DialectConversion.h"
26#include <optional>
27
28using namespace mlir;
29
/// Prefix prepended to the gpu.module symbol name to form the name of the
/// generated spirv.module, avoiding a symbol clash in the parent module.
static constexpr const char kSPIRVModule[] = "__spv__";
31
32namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
///
/// The `builtin` template parameter selects which SPIR-V builtin variable
/// (e.g. WorkgroupId, LocalInvocationId) backs the given GPU `SourceOp`.
/// Implementation is below in the "Builtins" section.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
44
/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
///
/// Unlike LaunchConfigConversion, the targeted builtins are scalar i32 values
/// (no dimension index to extract). Implementation is below in the "Builtins"
/// section.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
56
/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  // Benefit 10 makes this pattern preferred over the generic
  // LaunchConfigConversion for gpu.block_dim when both are registered.
  WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
70
/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // NOTE(review): this member appears unused within this file — confirm
  // against out-of-file users before removing.
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
83
/// Pattern to convert a gpu.module to a spirv.module.
///
/// Deduces addressing/memory models from the `spirv.target_env` attribute and
/// moves the module body into the newly created spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
93
94class GPUModuleEndConversion final
95 : public OpConversionPattern<gpu::ModuleEndOp> {
96public:
97 using OpConversionPattern::OpConversionPattern;
98
99 LogicalResult
100 matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
101 ConversionPatternRewriter &rewriter) const override {
102 rewriter.eraseOp(op: endOp);
103 return success();
104 }
105};
106
/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
117
/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
///
/// Uses Workgroup scope for both execution and memory, with acquire-release
/// semantics on workgroup memory (see the implementation below).
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
127
/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
///
/// Only matches when the shuffle width equals the target's subgroup size; see
/// the implementation below for the rationale.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
137
138} // namespace
139
140//===----------------------------------------------------------------------===//
141// Builtins.
142//===----------------------------------------------------------------------===//
143
144template <typename SourceOp, spirv::BuiltIn builtin>
145LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
146 SourceOp op, typename SourceOp::Adaptor adaptor,
147 ConversionPatternRewriter &rewriter) const {
148 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
149 Type indexType = typeConverter->getIndexType();
150
151 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
152 // type <3xi32> by the spec:
153 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
154 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
155 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
156 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
157 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
158 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
159 //
160 // For OpenCL, it depends on the Physical32/Physical64 addressing model:
161 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
162 bool forShader =
163 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
164 Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
165
166 Value vector =
167 spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
168 Value dim = rewriter.create<spirv::CompositeExtractOp>(
169 op.getLoc(), builtinType, vector,
170 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
171 if (forShader && builtinType != indexType)
172 dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
173 rewriter.replaceOp(op, dim);
174 return success();
175}
176
177template <typename SourceOp, spirv::BuiltIn builtin>
178LogicalResult
179SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
180 SourceOp op, typename SourceOp::Adaptor adaptor,
181 ConversionPatternRewriter &rewriter) const {
182 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
183 Type indexType = typeConverter->getIndexType();
184 Type i32Type = rewriter.getIntegerType(32);
185
186 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
187 // type i32 by the spec:
188 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
189 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
190 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
191 //
192 // For OpenCL, they are also required to be i32:
193 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
194 Value builtinValue =
195 spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
196 if (i32Type != indexType)
197 builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
198 builtinValue);
199 rewriter.replaceOp(op, builtinValue);
200 return success();
201}
202
203LogicalResult WorkGroupSizeConversion::matchAndRewrite(
204 gpu::BlockDimOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const {
206 DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
207 if (!workGroupSizeAttr)
208 return failure();
209
210 int val =
211 workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
212 auto convertedType =
213 getTypeConverter()->convertType(op.getResult().getType());
214 if (!convertedType)
215 return failure();
216 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
217 op, convertedType, IntegerAttr::get(convertedType, val));
218 return success();
219}
220
221//===----------------------------------------------------------------------===//
222// GPUFuncOp
223//===----------------------------------------------------------------------===//
224
225// Legalizes a GPU function as an entry SPIR-V function.
226static spirv::FuncOp
227lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
228 ConversionPatternRewriter &rewriter,
229 spirv::EntryPointABIAttr entryPointInfo,
230 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
231 auto fnType = funcOp.getFunctionType();
232 if (fnType.getNumResults()) {
233 funcOp.emitError("SPIR-V lowering only supports entry functions"
234 "with no return values right now");
235 return nullptr;
236 }
237 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
238 funcOp.emitError(
239 "lowering as entry functions requires ABI info for all arguments "
240 "or none of them");
241 return nullptr;
242 }
243 // Update the signature to valid SPIR-V types and add the ABI
244 // attributes. These will be "materialized" by using the
245 // LowerABIAttributesPass.
246 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
247 {
248 for (const auto &argType :
249 enumerate(funcOp.getFunctionType().getInputs())) {
250 auto convertedType = typeConverter.convertType(argType.value());
251 if (!convertedType)
252 return nullptr;
253 signatureConverter.addInputs(argType.index(), convertedType);
254 }
255 }
256 auto newFuncOp = rewriter.create<spirv::FuncOp>(
257 funcOp.getLoc(), funcOp.getName(),
258 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
259 std::nullopt));
260 for (const auto &namedAttr : funcOp->getAttrs()) {
261 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
262 namedAttr.getName() == SymbolTable::getSymbolAttrName())
263 continue;
264 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
265 }
266
267 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
268 newFuncOp.end());
269 if (failed(rewriter.convertRegionTypes(region: &newFuncOp.getBody(), converter: typeConverter,
270 entryConversion: &signatureConverter)))
271 return nullptr;
272 rewriter.eraseOp(op: funcOp);
273
274 // Set the attributes for argument and the function.
275 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
276 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: argABIInfo.size())) {
277 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
278 }
279 newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
280
281 return newFuncOp;
282}
283
284/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
285/// gpu.func to spirv.func if no arguments have the attributes set
286/// already. Returns failure if any argument has the ABI attribute set already.
287static LogicalResult
288getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
289 SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
290 if (!spirv::needsInterfaceVarABIAttrs(targetAttr: targetEnv))
291 return success();
292
293 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
294 if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
295 argIndex, spirv::getInterfaceVarABIAttrName()))
296 return failure();
297 // Vulkan's interface variable requirements needs scalars to be wrapped in a
298 // struct. The struct held in storage buffer.
299 std::optional<spirv::StorageClass> sc;
300 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
301 sc = spirv::StorageClass::StorageBuffer;
302 argABI.push_back(
303 spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
304 }
305 return success();
306}
307
308LogicalResult GPUFuncOpConversion::matchAndRewrite(
309 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
310 ConversionPatternRewriter &rewriter) const {
311 if (!gpu::GPUDialect::isKernel(funcOp))
312 return failure();
313
314 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
315 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
316 if (failed(
317 getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
318 argABI.clear();
319 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
320 // If the ABI is already specified, use it.
321 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
322 argIndex, spirv::getInterfaceVarABIAttrName());
323 if (!abiAttr) {
324 funcOp.emitRemark(
325 "match failure: missing 'spirv.interface_var_abi' attribute at "
326 "argument ")
327 << argIndex;
328 return failure();
329 }
330 argABI.push_back(abiAttr);
331 }
332 }
333
334 auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
335 if (!entryPointAttr) {
336 funcOp.emitRemark(
337 "match failure: missing 'spirv.entry_point_abi' attribute");
338 return failure();
339 }
340 spirv::FuncOp newFuncOp = lowerAsEntryFunction(
341 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
342 if (!newFuncOp)
343 return failure();
344 newFuncOp->removeAttr(
345 rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
346 return success();
347}
348
349//===----------------------------------------------------------------------===//
350// ModuleOp with gpu.module.
351//===----------------------------------------------------------------------===//
352
353LogicalResult GPUModuleConversion::matchAndRewrite(
354 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
355 ConversionPatternRewriter &rewriter) const {
356 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
357 const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
358 spirv::AddressingModel addressingModel = spirv::getAddressingModel(
359 targetEnv, typeConverter->getOptions().use64bitIndex);
360 FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
361 if (failed(memoryModel))
362 return moduleOp.emitRemark(
363 "cannot deduce memory model from 'spirv.target_env'");
364
365 // Add a keyword to the module name to avoid symbolic conflict.
366 std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
367 auto spvModule = rewriter.create<spirv::ModuleOp>(
368 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
369 StringRef(spvModuleName));
370
371 // Move the region from the module op into the SPIR-V module.
372 Region &spvModuleRegion = spvModule.getRegion();
373 rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
374 spvModuleRegion.begin());
375 // The spirv.module build method adds a block. Remove that.
376 rewriter.eraseBlock(block: &spvModuleRegion.back());
377
378 // Some of the patterns call `lookupTargetEnv` during conversion and they
379 // will fail if called after GPUModuleConversion and we don't preserve
380 // `TargetEnv` attribute.
381 // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
382 if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
383 spirv::getTargetEnvAttrName()))
384 spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
385
386 rewriter.eraseOp(op: moduleOp);
387 return success();
388}
389
390//===----------------------------------------------------------------------===//
391// GPU return inside kernel functions to SPIR-V return.
392//===----------------------------------------------------------------------===//
393
394LogicalResult GPUReturnOpConversion::matchAndRewrite(
395 gpu::ReturnOp returnOp, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter) const {
397 if (!adaptor.getOperands().empty())
398 return failure();
399
400 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
401 return success();
402}
403
404//===----------------------------------------------------------------------===//
405// Barrier.
406//===----------------------------------------------------------------------===//
407
408LogicalResult GPUBarrierConversion::matchAndRewrite(
409 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
410 ConversionPatternRewriter &rewriter) const {
411 MLIRContext *context = getContext();
412 // Both execution and memory scope should be workgroup.
413 auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
414 // Require acquire and release memory semantics for workgroup memory.
415 auto memorySemantics = spirv::MemorySemanticsAttr::get(
416 context, spirv::MemorySemantics::WorkgroupMemory |
417 spirv::MemorySemantics::AcquireRelease);
418 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
419 memorySemantics);
420 return success();
421}
422
423//===----------------------------------------------------------------------===//
424// Shuffle
425//===----------------------------------------------------------------------===//
426
427LogicalResult GPUShuffleConversion::matchAndRewrite(
428 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter) const {
430 // Require the shuffle width to be the same as the target's subgroup size,
431 // given that for SPIR-V non-uniform subgroup ops, we cannot select
432 // participating invocations.
433 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
434 unsigned subgroupSize =
435 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
436 IntegerAttr widthAttr;
437 if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
438 widthAttr.getValue().getZExtValue() != subgroupSize)
439 return rewriter.notifyMatchFailure(
440 shuffleOp, "shuffle width and target subgroup size mismatch");
441
442 Location loc = shuffleOp.getLoc();
443 Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
444 shuffleOp.getLoc(), rewriter);
445 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
446 Value result;
447
448 switch (shuffleOp.getMode()) {
449 case gpu::ShuffleMode::XOR:
450 result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
451 loc, scope, adaptor.getValue(), adaptor.getOffset());
452 break;
453 case gpu::ShuffleMode::IDX:
454 result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
455 loc, scope, adaptor.getValue(), adaptor.getOffset());
456 break;
457 default:
458 return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
459 }
460
461 rewriter.replaceOp(shuffleOp, {result, trueVal});
462 return success();
463}
464
465//===----------------------------------------------------------------------===//
466// Group ops
467//===----------------------------------------------------------------------===//
468
469template <typename UniformOp, typename NonUniformOp>
470static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
471 Value arg, bool isGroup, bool isUniform) {
472 Type type = arg.getType();
473 auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
474 isGroup ? spirv::Scope::Workgroup
475 : spirv::Scope::Subgroup);
476 auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
477 spirv::GroupOperation::Reduce);
478 if (isUniform) {
479 return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
480 .getResult();
481 }
482 return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
483 .getResult();
484}
485
486static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
487 Location loc, Value arg,
488 gpu::AllReduceOperation opType,
489 bool isGroup, bool isUniform) {
490 enum class ElemType { Float, Boolean, Integer };
491 using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
492 struct OpHandler {
493 gpu::AllReduceOperation kind;
494 ElemType elemType;
495 FuncT func;
496 };
497
498 Type type = arg.getType();
499 ElemType elementType;
500 if (isa<FloatType>(Val: type)) {
501 elementType = ElemType::Float;
502 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
503 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
504 : ElemType::Integer;
505 } else {
506 return std::nullopt;
507 }
508
509 // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
510 // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
511 // reduction ops. We should account possible precision requirements in this
512 // conversion.
513
514 using ReduceType = gpu::AllReduceOperation;
515 const OpHandler handlers[] = {
516 {ReduceType::ADD, ElemType::Integer,
517 &createGroupReduceOpImpl<spirv::GroupIAddOp,
518 spirv::GroupNonUniformIAddOp>},
519 {ReduceType::ADD, ElemType::Float,
520 &createGroupReduceOpImpl<spirv::GroupFAddOp,
521 spirv::GroupNonUniformFAddOp>},
522 {ReduceType::MUL, ElemType::Integer,
523 &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
524 spirv::GroupNonUniformIMulOp>},
525 {ReduceType::MUL, ElemType::Float,
526 &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
527 spirv::GroupNonUniformFMulOp>},
528 {ReduceType::MINUI, ElemType::Integer,
529 &createGroupReduceOpImpl<spirv::GroupUMinOp,
530 spirv::GroupNonUniformUMinOp>},
531 {ReduceType::MINSI, ElemType::Integer,
532 &createGroupReduceOpImpl<spirv::GroupSMinOp,
533 spirv::GroupNonUniformSMinOp>},
534 {ReduceType::MINNUMF, ElemType::Float,
535 &createGroupReduceOpImpl<spirv::GroupFMinOp,
536 spirv::GroupNonUniformFMinOp>},
537 {ReduceType::MAXUI, ElemType::Integer,
538 &createGroupReduceOpImpl<spirv::GroupUMaxOp,
539 spirv::GroupNonUniformUMaxOp>},
540 {ReduceType::MAXSI, ElemType::Integer,
541 &createGroupReduceOpImpl<spirv::GroupSMaxOp,
542 spirv::GroupNonUniformSMaxOp>},
543 {ReduceType::MAXNUMF, ElemType::Float,
544 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
545 spirv::GroupNonUniformFMaxOp>},
546 {ReduceType::MINIMUMF, ElemType::Float,
547 &createGroupReduceOpImpl<spirv::GroupFMinOp,
548 spirv::GroupNonUniformFMinOp>},
549 {ReduceType::MAXIMUMF, ElemType::Float,
550 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
551 spirv::GroupNonUniformFMaxOp>}};
552
553 for (const OpHandler &handler : handlers)
554 if (handler.kind == opType && elementType == handler.elemType)
555 return handler.func(builder, loc, arg, isGroup, isUniform);
556
557 return std::nullopt;
558}
559
560/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
561class GPUAllReduceConversion final
562 : public OpConversionPattern<gpu::AllReduceOp> {
563public:
564 using OpConversionPattern::OpConversionPattern;
565
566 LogicalResult
567 matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569 auto opType = op.getOp();
570
571 // gpu.all_reduce can have either reduction op attribute or reduction
572 // region. Only attribute version is supported.
573 if (!opType)
574 return failure();
575
576 auto result =
577 createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
578 /*isGroup*/ true, op.getUniform());
579 if (!result)
580 return failure();
581
582 rewriter.replaceOp(op, *result);
583 return success();
584 }
585};
586
587/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
588class GPUSubgroupReduceConversion final
589 : public OpConversionPattern<gpu::SubgroupReduceOp> {
590public:
591 using OpConversionPattern::OpConversionPattern;
592
593 LogicalResult
594 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
595 ConversionPatternRewriter &rewriter) const override {
596 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
597 return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
598
599 auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
600 adaptor.getOp(),
601 /*isGroup=*/false, adaptor.getUniform());
602 if (!result)
603 return failure();
604
605 rewriter.replaceOp(op, *result);
606 return success();
607 }
608};
609
610//===----------------------------------------------------------------------===//
611// GPU To SPIRV Patterns.
612//===----------------------------------------------------------------------===//
613
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  // Registers all GPU-to-SPIR-V conversion patterns defined in this file.
  // WorkGroupSizeConversion carries benefit 10 (see its constructor), so it
  // is preferred over the WorkgroupSize LaunchConfigConversion for
  // gpu.block_dim when a static workgroup size is available.
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
      GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      LaunchConfigConversion<gpu::GlobalIdOp,
                             spirv::BuiltIn::GlobalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion, GPUAllReduceConversion,
      GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
}
635

// Source: mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp