//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

27using namespace mlir;
28
29static constexpr const char kSPIRVModule[] = "__spv__";
30
31namespace {
32/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
33/// builtin variables.
34template <typename SourceOp, spirv::BuiltIn builtin>
35class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
36public:
37 using OpConversionPattern<SourceOp>::OpConversionPattern;
38
39 LogicalResult
40 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
41 ConversionPatternRewriter &rewriter) const override;
42};
43
44/// Pattern lowering subgroup size/id to loading SPIR-V invocation
45/// builtin variables.
46template <typename SourceOp, spirv::BuiltIn builtin>
47class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
48public:
49 using OpConversionPattern<SourceOp>::OpConversionPattern;
50
51 LogicalResult
52 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override;
54};
55
56/// This is separate because in Vulkan workgroup size is exposed to shaders via
57/// a constant with WorkgroupSize decoration. So here we cannot generate a
58/// builtin variable; instead the information in the `spirv.entry_point_abi`
59/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
60class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
61public:
62 WorkGroupSizeConversion(const TypeConverter &typeConverter,
63 MLIRContext *context)
64 : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}
65
66 LogicalResult
67 matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const override;
69};
70
71/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
72class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
73public:
74 using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
75
76 LogicalResult
77 matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
78 ConversionPatternRewriter &rewriter) const override;
79
80private:
81 SmallVector<int32_t, 3> workGroupSizeAsInt32;
82};
83
84/// Pattern to convert a gpu.module to a spirv.module.
85class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
86public:
87 using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
88
89 LogicalResult
90 matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
91 ConversionPatternRewriter &rewriter) const override;
92};
93
94/// Pattern to convert a gpu.return into a SPIR-V return.
95// TODO: This can go to DRR when GPU return has operands.
96class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
97public:
98 using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
99
100 LogicalResult
101 matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override;
103};
104
105/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
106class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
107public:
108 using OpConversionPattern::OpConversionPattern;
109
110 LogicalResult
111 matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override;
113};
114
115/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
116class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
117public:
118 using OpConversionPattern::OpConversionPattern;
119
120 LogicalResult
121 matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override;
123};
124
125class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
126public:
127 using OpConversionPattern::OpConversionPattern;
128
129 LogicalResult
130 matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
131 ConversionPatternRewriter &rewriter) const override;
132};
133
134} // namespace
135
136//===----------------------------------------------------------------------===//
137// Builtins.
138//===----------------------------------------------------------------------===//
139
140template <typename SourceOp, spirv::BuiltIn builtin>
141LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
142 SourceOp op, typename SourceOp::Adaptor adaptor,
143 ConversionPatternRewriter &rewriter) const {
144 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
145 Type indexType = typeConverter->getIndexType();
146
147 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
148 // type <3xi32> by the spec:
149 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
150 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
151 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
152 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
153 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
154 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
155 //
156 // For OpenCL, it depends on the Physical32/Physical64 addressing model:
157 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
158 bool forShader =
159 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
160 Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
161
162 Value vector =
163 spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
164 Value dim = rewriter.create<spirv::CompositeExtractOp>(
165 op.getLoc(), builtinType, vector,
166 rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
167 if (forShader && builtinType != indexType)
168 dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
169 rewriter.replaceOp(op, dim);
170 return success();
171}
172
173template <typename SourceOp, spirv::BuiltIn builtin>
174LogicalResult
175SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
176 SourceOp op, typename SourceOp::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter) const {
178 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
179 Type indexType = typeConverter->getIndexType();
180 Type i32Type = rewriter.getIntegerType(32);
181
182 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
183 // type i32 by the spec:
184 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
185 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
186 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
187 //
188 // For OpenCL, they are also required to be i32:
189 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
190 Value builtinValue =
191 spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
192 if (i32Type != indexType)
193 builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
194 builtinValue);
195 rewriter.replaceOp(op, builtinValue);
196 return success();
197}
198
199LogicalResult WorkGroupSizeConversion::matchAndRewrite(
200 gpu::BlockDimOp op, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter) const {
202 DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
203 if (!workGroupSizeAttr)
204 return failure();
205
206 int val =
207 workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
208 auto convertedType =
209 getTypeConverter()->convertType(op.getResult().getType());
210 if (!convertedType)
211 return failure();
212 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
213 op, convertedType, IntegerAttr::get(convertedType, val));
214 return success();
215}
216
217//===----------------------------------------------------------------------===//
218// GPUFuncOp
219//===----------------------------------------------------------------------===//
220
221// Legalizes a GPU function as an entry SPIR-V function.
222static spirv::FuncOp
223lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
224 ConversionPatternRewriter &rewriter,
225 spirv::EntryPointABIAttr entryPointInfo,
226 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
227 auto fnType = funcOp.getFunctionType();
228 if (fnType.getNumResults()) {
229 funcOp.emitError("SPIR-V lowering only supports entry functions"
230 "with no return values right now");
231 return nullptr;
232 }
233 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
234 funcOp.emitError(
235 "lowering as entry functions requires ABI info for all arguments "
236 "or none of them");
237 return nullptr;
238 }
239 // Update the signature to valid SPIR-V types and add the ABI
240 // attributes. These will be "materialized" by using the
241 // LowerABIAttributesPass.
242 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
243 {
244 for (const auto &argType :
245 enumerate(funcOp.getFunctionType().getInputs())) {
246 auto convertedType = typeConverter.convertType(argType.value());
247 if (!convertedType)
248 return nullptr;
249 signatureConverter.addInputs(argType.index(), convertedType);
250 }
251 }
252 auto newFuncOp = rewriter.create<spirv::FuncOp>(
253 funcOp.getLoc(), funcOp.getName(),
254 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
255 std::nullopt));
256 for (const auto &namedAttr : funcOp->getAttrs()) {
257 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
258 namedAttr.getName() == SymbolTable::getSymbolAttrName())
259 continue;
260 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
261 }
262
263 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
264 newFuncOp.end());
265 if (failed(rewriter.convertRegionTypes(region: &newFuncOp.getBody(), converter: typeConverter,
266 entryConversion: &signatureConverter)))
267 return nullptr;
268 rewriter.eraseOp(op: funcOp);
269
270 // Set the attributes for argument and the function.
271 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
272 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: argABIInfo.size())) {
273 newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
274 }
275 newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
276
277 return newFuncOp;
278}
279
280/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
281/// gpu.func to spirv.func if no arguments have the attributes set
282/// already. Returns failure if any argument has the ABI attribute set already.
283static LogicalResult
284getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
285 SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
286 if (!spirv::needsInterfaceVarABIAttrs(targetAttr: targetEnv))
287 return success();
288
289 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
290 if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
291 argIndex, spirv::getInterfaceVarABIAttrName()))
292 return failure();
293 // Vulkan's interface variable requirements needs scalars to be wrapped in a
294 // struct. The struct held in storage buffer.
295 std::optional<spirv::StorageClass> sc;
296 if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
297 sc = spirv::StorageClass::StorageBuffer;
298 argABI.push_back(
299 spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
300 }
301 return success();
302}
303
304LogicalResult GPUFuncOpConversion::matchAndRewrite(
305 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter) const {
307 if (!gpu::GPUDialect::isKernel(funcOp))
308 return failure();
309
310 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
311 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
312 if (failed(
313 getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
314 argABI.clear();
315 for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
316 // If the ABI is already specified, use it.
317 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
318 argIndex, spirv::getInterfaceVarABIAttrName());
319 if (!abiAttr) {
320 funcOp.emitRemark(
321 "match failure: missing 'spirv.interface_var_abi' attribute at "
322 "argument ")
323 << argIndex;
324 return failure();
325 }
326 argABI.push_back(abiAttr);
327 }
328 }
329
330 auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
331 if (!entryPointAttr) {
332 funcOp.emitRemark(
333 "match failure: missing 'spirv.entry_point_abi' attribute");
334 return failure();
335 }
336 spirv::FuncOp newFuncOp = lowerAsEntryFunction(
337 funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
338 if (!newFuncOp)
339 return failure();
340 newFuncOp->removeAttr(
341 rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
342 return success();
343}
344
345//===----------------------------------------------------------------------===//
346// ModuleOp with gpu.module.
347//===----------------------------------------------------------------------===//
348
349LogicalResult GPUModuleConversion::matchAndRewrite(
350 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter) const {
352 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
353 const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
354 spirv::AddressingModel addressingModel = spirv::getAddressingModel(
355 targetEnv, typeConverter->getOptions().use64bitIndex);
356 FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
357 if (failed(memoryModel))
358 return moduleOp.emitRemark(
359 "cannot deduce memory model from 'spirv.target_env'");
360
361 // Add a keyword to the module name to avoid symbolic conflict.
362 std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
363 auto spvModule = rewriter.create<spirv::ModuleOp>(
364 moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
365 StringRef(spvModuleName));
366
367 // Move the region from the module op into the SPIR-V module.
368 Region &spvModuleRegion = spvModule.getRegion();
369 rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
370 spvModuleRegion.begin());
371 // The spirv.module build method adds a block. Remove that.
372 rewriter.eraseBlock(block: &spvModuleRegion.back());
373
374 // Some of the patterns call `lookupTargetEnv` during conversion and they
375 // will fail if called after GPUModuleConversion and we don't preserve
376 // `TargetEnv` attribute.
377 // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
378 if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
379 spirv::getTargetEnvAttrName()))
380 spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
381
382 rewriter.eraseOp(op: moduleOp);
383 return success();
384}
385
386//===----------------------------------------------------------------------===//
387// GPU return inside kernel functions to SPIR-V return.
388//===----------------------------------------------------------------------===//
389
390LogicalResult GPUReturnOpConversion::matchAndRewrite(
391 gpu::ReturnOp returnOp, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter) const {
393 if (!adaptor.getOperands().empty())
394 return failure();
395
396 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
397 return success();
398}
399
400//===----------------------------------------------------------------------===//
401// Barrier.
402//===----------------------------------------------------------------------===//
403
404LogicalResult GPUBarrierConversion::matchAndRewrite(
405 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const {
407 MLIRContext *context = getContext();
408 // Both execution and memory scope should be workgroup.
409 auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
410 // Require acquire and release memory semantics for workgroup memory.
411 auto memorySemantics = spirv::MemorySemanticsAttr::get(
412 context, spirv::MemorySemantics::WorkgroupMemory |
413 spirv::MemorySemantics::AcquireRelease);
414 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
415 memorySemantics);
416 return success();
417}
418
419//===----------------------------------------------------------------------===//
420// Shuffle
421//===----------------------------------------------------------------------===//
422
423LogicalResult GPUShuffleConversion::matchAndRewrite(
424 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
425 ConversionPatternRewriter &rewriter) const {
426 // Require the shuffle width to be the same as the target's subgroup size,
427 // given that for SPIR-V non-uniform subgroup ops, we cannot select
428 // participating invocations.
429 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
430 unsigned subgroupSize =
431 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
432 IntegerAttr widthAttr;
433 if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
434 widthAttr.getValue().getZExtValue() != subgroupSize)
435 return rewriter.notifyMatchFailure(
436 shuffleOp, "shuffle width and target subgroup size mismatch");
437
438 Location loc = shuffleOp.getLoc();
439 Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
440 shuffleOp.getLoc(), rewriter);
441 auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
442 Value result;
443
444 switch (shuffleOp.getMode()) {
445 case gpu::ShuffleMode::XOR:
446 result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
447 loc, scope, adaptor.getValue(), adaptor.getOffset());
448 break;
449 case gpu::ShuffleMode::IDX:
450 result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
451 loc, scope, adaptor.getValue(), adaptor.getOffset());
452 break;
453 default:
454 return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
455 }
456
457 rewriter.replaceOp(shuffleOp, {result, trueVal});
458 return success();
459}
460
461//===----------------------------------------------------------------------===//
462// Group ops
463//===----------------------------------------------------------------------===//
464
465template <typename UniformOp, typename NonUniformOp>
466static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
467 Value arg, bool isGroup, bool isUniform,
468 std::optional<uint32_t> clusterSize) {
469 Type type = arg.getType();
470 auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
471 isGroup ? spirv::Scope::Workgroup
472 : spirv::Scope::Subgroup);
473 auto groupOp = spirv::GroupOperationAttr::get(
474 builder.getContext(), clusterSize.has_value()
475 ? spirv::GroupOperation::ClusteredReduce
476 : spirv::GroupOperation::Reduce);
477 if (isUniform) {
478 return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
479 .getResult();
480 }
481
482 Value clusterSizeValue;
483 if (clusterSize.has_value())
484 clusterSizeValue = builder.create<spirv::ConstantOp>(
485 loc, builder.getI32Type(),
486 builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
487
488 return builder
489 .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
490 .getResult();
491}
492
493static std::optional<Value>
494createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
495 gpu::AllReduceOperation opType, bool isGroup,
496 bool isUniform, std::optional<uint32_t> clusterSize) {
497 enum class ElemType { Float, Boolean, Integer };
498 using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
499 std::optional<uint32_t>);
500 struct OpHandler {
501 gpu::AllReduceOperation kind;
502 ElemType elemType;
503 FuncT func;
504 };
505
506 Type type = arg.getType();
507 ElemType elementType;
508 if (isa<FloatType>(Val: type)) {
509 elementType = ElemType::Float;
510 } else if (auto intTy = dyn_cast<IntegerType>(type)) {
511 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
512 : ElemType::Integer;
513 } else {
514 return std::nullopt;
515 }
516
517 // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
518 // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
519 // reduction ops. We should account possible precision requirements in this
520 // conversion.
521
522 using ReduceType = gpu::AllReduceOperation;
523 const OpHandler handlers[] = {
524 {ReduceType::ADD, ElemType::Integer,
525 &createGroupReduceOpImpl<spirv::GroupIAddOp,
526 spirv::GroupNonUniformIAddOp>},
527 {ReduceType::ADD, ElemType::Float,
528 &createGroupReduceOpImpl<spirv::GroupFAddOp,
529 spirv::GroupNonUniformFAddOp>},
530 {ReduceType::MUL, ElemType::Integer,
531 &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
532 spirv::GroupNonUniformIMulOp>},
533 {ReduceType::MUL, ElemType::Float,
534 &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
535 spirv::GroupNonUniformFMulOp>},
536 {ReduceType::MINUI, ElemType::Integer,
537 &createGroupReduceOpImpl<spirv::GroupUMinOp,
538 spirv::GroupNonUniformUMinOp>},
539 {ReduceType::MINSI, ElemType::Integer,
540 &createGroupReduceOpImpl<spirv::GroupSMinOp,
541 spirv::GroupNonUniformSMinOp>},
542 {ReduceType::MINNUMF, ElemType::Float,
543 &createGroupReduceOpImpl<spirv::GroupFMinOp,
544 spirv::GroupNonUniformFMinOp>},
545 {ReduceType::MAXUI, ElemType::Integer,
546 &createGroupReduceOpImpl<spirv::GroupUMaxOp,
547 spirv::GroupNonUniformUMaxOp>},
548 {ReduceType::MAXSI, ElemType::Integer,
549 &createGroupReduceOpImpl<spirv::GroupSMaxOp,
550 spirv::GroupNonUniformSMaxOp>},
551 {ReduceType::MAXNUMF, ElemType::Float,
552 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
553 spirv::GroupNonUniformFMaxOp>},
554 {ReduceType::MINIMUMF, ElemType::Float,
555 &createGroupReduceOpImpl<spirv::GroupFMinOp,
556 spirv::GroupNonUniformFMinOp>},
557 {ReduceType::MAXIMUMF, ElemType::Float,
558 &createGroupReduceOpImpl<spirv::GroupFMaxOp,
559 spirv::GroupNonUniformFMaxOp>}};
560
561 for (const OpHandler &handler : handlers)
562 if (handler.kind == opType && elementType == handler.elemType)
563 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
564
565 return std::nullopt;
566}
567
568/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
569class GPUAllReduceConversion final
570 : public OpConversionPattern<gpu::AllReduceOp> {
571public:
572 using OpConversionPattern::OpConversionPattern;
573
574 LogicalResult
575 matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
576 ConversionPatternRewriter &rewriter) const override {
577 auto opType = op.getOp();
578
579 // gpu.all_reduce can have either reduction op attribute or reduction
580 // region. Only attribute version is supported.
581 if (!opType)
582 return failure();
583
584 auto result =
585 createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
586 /*isGroup*/ true, op.getUniform(), std::nullopt);
587 if (!result)
588 return failure();
589
590 rewriter.replaceOp(op, *result);
591 return success();
592 }
593};
594
595/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
596class GPUSubgroupReduceConversion final
597 : public OpConversionPattern<gpu::SubgroupReduceOp> {
598public:
599 using OpConversionPattern::OpConversionPattern;
600
601 LogicalResult
602 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
603 ConversionPatternRewriter &rewriter) const override {
604 if (op.getClusterStride() > 1) {
605 return rewriter.notifyMatchFailure(
606 op, "lowering for cluster stride > 1 is not implemented");
607 }
608
609 if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
610 return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
611
612 auto result = createGroupReduceOp(
613 rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
614 /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
615 if (!result)
616 return failure();
617
618 rewriter.replaceOp(op, *result);
619 return success();
620 }
621};
622
623// Formulate a unique variable/constant name after
624// searching in the module for existing variable/constant names.
625// This is to avoid name collision with existing variables.
626// Example: printfMsg0, printfMsg1, printfMsg2, ...
627static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
628 std::string name;
629 unsigned number = 0;
630
631 do {
632 name.clear();
633 name = (prefix + llvm::Twine(number++)).str();
634 } while (moduleOp.lookupSymbol(name));
635
636 return name;
637}
638
639/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
640
641LogicalResult GPUPrintfConversion::matchAndRewrite(
642 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
643 ConversionPatternRewriter &rewriter) const {
644
645 Location loc = gpuPrintfOp.getLoc();
646
647 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
648 if (!moduleOp)
649 return failure();
650
651 // SPIR-V global variable is used to initialize printf
652 // format string value, if there are multiple printf messages,
653 // each global var needs to be created with a unique name.
654 std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
655 spirv::GlobalVariableOp globalVar;
656
657 IntegerType i8Type = rewriter.getI8Type();
658 IntegerType i32Type = rewriter.getI32Type();
659
660 // Each character of printf format string is
661 // stored as a spec constant. We need to create
662 // unique name for this spec constant like
663 // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
664 // for existing spec constant names.
665 auto createSpecConstant = [&](unsigned value) {
666 auto attr = rewriter.getI8IntegerAttr(value);
667 std::string specCstName =
668 makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");
669
670 return rewriter.create<spirv::SpecConstantOp>(
671 loc, rewriter.getStringAttr(specCstName), attr);
672 };
673 {
674 Operation *parent =
675 SymbolTable::getNearestSymbolTable(from: gpuPrintfOp->getParentOp());
676
677 ConversionPatternRewriter::InsertionGuard guard(rewriter);
678
679 Block &entryBlock = *parent->getRegion(index: 0).begin();
680 rewriter.setInsertionPointToStart(
681 &entryBlock); // insertion point at module level
682
683 // Create Constituents with SpecConstant by scanning format string
684 // Each character of format string is stored as a spec constant
685 // and then these spec constants are used to create a
686 // SpecConstantCompositeOp.
687 llvm::SmallString<20> formatString(adaptor.getFormat());
688 formatString.push_back(Elt: '\0'); // Null terminate for C.
689 SmallVector<Attribute, 4> constituents;
690 for (char c : formatString) {
691 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
692 constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
693 }
694
695 // Create SpecConstantCompositeOp to initialize the global variable
696 size_t contentSize = constituents.size();
697 auto globalType = spirv::ArrayType::get(i8Type, contentSize);
698 spirv::SpecConstantCompositeOp specCstComposite;
699 // There will be one SpecConstantCompositeOp per printf message/global var,
700 // so no need do lookup for existing ones.
701 std::string specCstCompositeName =
702 (llvm::Twine(globalVarName) + "_scc").str();
703
704 specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
705 loc, TypeAttr::get(globalType),
706 rewriter.getStringAttr(specCstCompositeName),
707 rewriter.getArrayAttr(constituents));
708
709 auto ptrType = spirv::PointerType::get(
710 globalType, spirv::StorageClass::UniformConstant);
711
712 // Define a GlobalVarOp initialized using specialized constants
713 // that is used to specify the printf format string
714 // to be passed to the SPIRV CLPrintfOp.
715 globalVar = rewriter.create<spirv::GlobalVariableOp>(
716 loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));
717
718 globalVar->setAttr("Constant", rewriter.getUnitAttr());
719 }
720 // Get SSA value of Global variable and create pointer to i8 to point to
721 // the format string.
722 Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
723 Value fmtStr = rewriter.create<spirv::BitcastOp>(
724 loc,
725 spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
726 globalPtr);
727
728 // Get printf arguments.
729 auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());
730
731 rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
732
733 // Need to erase the gpu.printf op as gpu.printf does not use result vs
734 // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
735 // printf op.
736 rewriter.eraseOp(op: gpuPrintfOp);
737
738 return success();
739}
740
741//===----------------------------------------------------------------------===//
742// GPU To SPIRV Patterns.
743//===----------------------------------------------------------------------===//
744
745void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
746 RewritePatternSet &patterns) {
747 patterns.add<
748 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
749 GPUReturnOpConversion, GPUShuffleConversion,
750 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
751 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
752 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
753 LaunchConfigConversion<gpu::ThreadIdOp,
754 spirv::BuiltIn::LocalInvocationId>,
755 LaunchConfigConversion<gpu::GlobalIdOp,
756 spirv::BuiltIn::GlobalInvocationId>,
757 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
758 spirv::BuiltIn::SubgroupId>,
759 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
760 spirv::BuiltIn::NumSubgroups>,
761 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
762 spirv::BuiltIn::SubgroupSize>,
763 SingleDimLaunchConfigConversion<
764 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
765 WorkGroupSizeConversion, GPUAllReduceConversion,
766 GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
767 patterns.getContext());
768}
769

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp