1//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert GPU dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include <optional>
24
25using namespace mlir;
26
27static constexpr const char kSPIRVModule[] = "__spv__";
28
namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables. The builtin is selected via the `builtin` template
/// parameter; the op's dimension attribute picks the vector component.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables. Unlike LaunchConfigConversion above, these builtins are
/// scalars, so no component extraction is needed.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  // Benefit 10 makes this pattern preferred over the generic
  // LaunchConfigConversion for gpu::BlockDimOp when both are registered.
  WorkGroupSizeConversion(const TypeConverter &typeConverter,
                          MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit*/ 10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // NOTE(review): not referenced anywhere in this file — confirm whether it
  // can be removed.
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op, backed by a
/// global format-string variable built from spec constants.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace
143
144//===----------------------------------------------------------------------===//
145// Builtins.
146//===----------------------------------------------------------------------===//
147
148template <typename SourceOp, spirv::BuiltIn builtin>
149LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
150 SourceOp op, typename SourceOp::Adaptor adaptor,
151 ConversionPatternRewriter &rewriter) const {
152 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
153 Type indexType = typeConverter->getIndexType();
154
155 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
156 // type <3xi32> by the spec:
157 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
158 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
159 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
160 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
161 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
162 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
163 //
164 // For OpenCL, it depends on the Physical32/Physical64 addressing model:
165 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
166 bool forShader =
167 typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
168 Type builtinType = forShader ? rewriter.getIntegerType(width: 32) : indexType;
169
170 Value vector =
171 spirv::getBuiltinVariableValue(op, builtin, integerType: builtinType, builder&: rewriter);
172 Value dim = rewriter.create<spirv::CompositeExtractOp>(
173 op.getLoc(), builtinType, vector,
174 rewriter.getI32ArrayAttr(values: {static_cast<int32_t>(op.getDimension())}));
175 if (forShader && builtinType != indexType)
176 dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
177 rewriter.replaceOp(op, dim);
178 return success();
179}
180
181template <typename SourceOp, spirv::BuiltIn builtin>
182LogicalResult
183SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
184 SourceOp op, typename SourceOp::Adaptor adaptor,
185 ConversionPatternRewriter &rewriter) const {
186 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
187 Type indexType = typeConverter->getIndexType();
188 Type i32Type = rewriter.getIntegerType(width: 32);
189
190 // For Vulkan, these SPIR-V builtin variables are required to be a vector of
191 // type i32 by the spec:
192 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
193 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
194 // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
195 //
196 // For OpenCL, they are also required to be i32:
197 // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
198 Value builtinValue =
199 spirv::getBuiltinVariableValue(op, builtin, integerType: i32Type, builder&: rewriter);
200 if (i32Type != indexType)
201 builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
202 builtinValue);
203 rewriter.replaceOp(op, builtinValue);
204 return success();
205}
206
207LogicalResult WorkGroupSizeConversion::matchAndRewrite(
208 gpu::BlockDimOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const {
210 DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
211 if (!workGroupSizeAttr)
212 return failure();
213
214 int val =
215 workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
216 auto convertedType =
217 getTypeConverter()->convertType(t: op.getResult().getType());
218 if (!convertedType)
219 return failure();
220 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
221 op, args&: convertedType, args: IntegerAttr::get(type: convertedType, value: val));
222 return success();
223}
224
225//===----------------------------------------------------------------------===//
226// GPUFuncOp
227//===----------------------------------------------------------------------===//
228
229// Legalizes a GPU function as an entry SPIR-V function.
230static spirv::FuncOp
231lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
232 ConversionPatternRewriter &rewriter,
233 spirv::EntryPointABIAttr entryPointInfo,
234 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
235 auto fnType = funcOp.getFunctionType();
236 if (fnType.getNumResults()) {
237 funcOp.emitError(message: "SPIR-V lowering only supports entry functions"
238 "with no return values right now");
239 return nullptr;
240 }
241 if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
242 funcOp.emitError(
243 message: "lowering as entry functions requires ABI info for all arguments "
244 "or none of them");
245 return nullptr;
246 }
247 // Update the signature to valid SPIR-V types and add the ABI
248 // attributes. These will be "materialized" by using the
249 // LowerABIAttributesPass.
250 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
251 {
252 for (const auto &argType :
253 enumerate(First: funcOp.getFunctionType().getInputs())) {
254 auto convertedType = typeConverter.convertType(t: argType.value());
255 if (!convertedType)
256 return nullptr;
257 signatureConverter.addInputs(origInputNo: argType.index(), types: convertedType);
258 }
259 }
260 auto newFuncOp = rewriter.create<spirv::FuncOp>(
261 location: funcOp.getLoc(), args: funcOp.getName(),
262 args: rewriter.getFunctionType(inputs: signatureConverter.getConvertedTypes(), results: {}));
263 for (const auto &namedAttr : funcOp->getAttrs()) {
264 if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
265 namedAttr.getName() == SymbolTable::getSymbolAttrName())
266 continue;
267 newFuncOp->setAttr(name: namedAttr.getName(), value: namedAttr.getValue());
268 }
269
270 rewriter.inlineRegionBefore(region&: funcOp.getBody(), parent&: newFuncOp.getBody(),
271 before: newFuncOp.end());
272 if (failed(Result: rewriter.convertRegionTypes(region: &newFuncOp.getBody(), converter: typeConverter,
273 entryConversion: &signatureConverter)))
274 return nullptr;
275 rewriter.eraseOp(op: funcOp);
276
277 // Set the attributes for argument and the function.
278 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
279 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: argABIInfo.size())) {
280 newFuncOp.setArgAttr(index: argIndex, name: argABIAttrName, value: argABIInfo[argIndex]);
281 }
282 newFuncOp->setAttr(name: spirv::getEntryPointABIAttrName(), value: entryPointInfo);
283
284 return newFuncOp;
285}
286
287/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
288/// gpu.func to spirv.func if no arguments have the attributes set
289/// already. Returns failure if any argument has the ABI attribute set already.
290static LogicalResult
291getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
292 SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
293 if (!spirv::needsInterfaceVarABIAttrs(targetAttr: targetEnv))
294 return success();
295
296 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: funcOp.getNumArguments())) {
297 if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
298 index: argIndex, name: spirv::getInterfaceVarABIAttrName()))
299 return failure();
300 // Vulkan's interface variable requirements needs scalars to be wrapped in a
301 // struct. The struct held in storage buffer.
302 std::optional<spirv::StorageClass> sc;
303 if (funcOp.getArgument(idx: argIndex).getType().isIntOrIndexOrFloat())
304 sc = spirv::StorageClass::StorageBuffer;
305 argABI.push_back(
306 Elt: spirv::getInterfaceVarABIAttr(descriptorSet: 0, binding: argIndex, storageClass: sc, context: funcOp.getContext()));
307 }
308 return success();
309}
310
311LogicalResult GPUFuncOpConversion::matchAndRewrite(
312 gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
313 ConversionPatternRewriter &rewriter) const {
314 if (!gpu::GPUDialect::isKernel(op: funcOp))
315 return failure();
316
317 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
318 SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
319 if (failed(
320 Result: getDefaultABIAttrs(targetEnv: typeConverter->getTargetEnv(), funcOp, argABI))) {
321 argABI.clear();
322 for (auto argIndex : llvm::seq<unsigned>(Begin: 0, End: funcOp.getNumArguments())) {
323 // If the ABI is already specified, use it.
324 auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
325 index: argIndex, name: spirv::getInterfaceVarABIAttrName());
326 if (!abiAttr) {
327 funcOp.emitRemark(
328 message: "match failure: missing 'spirv.interface_var_abi' attribute at "
329 "argument ")
330 << argIndex;
331 return failure();
332 }
333 argABI.push_back(Elt: abiAttr);
334 }
335 }
336
337 auto entryPointAttr = spirv::lookupEntryPointABI(op: funcOp);
338 if (!entryPointAttr) {
339 funcOp.emitRemark(
340 message: "match failure: missing 'spirv.entry_point_abi' attribute");
341 return failure();
342 }
343 spirv::FuncOp newFuncOp = lowerAsEntryFunction(
344 funcOp, typeConverter: *getTypeConverter(), rewriter, entryPointInfo: entryPointAttr, argABIInfo: argABI);
345 if (!newFuncOp)
346 return failure();
347 newFuncOp->removeAttr(
348 name: rewriter.getStringAttr(bytes: gpu::GPUDialect::getKernelFuncAttrName()));
349 return success();
350}
351
352//===----------------------------------------------------------------------===//
353// ModuleOp with gpu.module.
354//===----------------------------------------------------------------------===//
355
356LogicalResult GPUModuleConversion::matchAndRewrite(
357 gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const {
359 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
360 const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
361 spirv::AddressingModel addressingModel = spirv::getAddressingModel(
362 targetAttr: targetEnv, use64bitAddress: typeConverter->getOptions().use64bitIndex);
363 FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetAttr: targetEnv);
364 if (failed(Result: memoryModel))
365 return moduleOp.emitRemark(
366 message: "cannot deduce memory model from 'spirv.target_env'");
367
368 // Add a keyword to the module name to avoid symbolic conflict.
369 std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
370 auto spvModule = rewriter.create<spirv::ModuleOp>(
371 location: moduleOp.getLoc(), args&: addressingModel, args&: *memoryModel, args: std::nullopt,
372 args: StringRef(spvModuleName));
373
374 // Move the region from the module op into the SPIR-V module.
375 Region &spvModuleRegion = spvModule.getRegion();
376 rewriter.inlineRegionBefore(region&: moduleOp.getBodyRegion(), parent&: spvModuleRegion,
377 before: spvModuleRegion.begin());
378 // The spirv.module build method adds a block. Remove that.
379 rewriter.eraseBlock(block: &spvModuleRegion.back());
380
381 // Some of the patterns call `lookupTargetEnv` during conversion and they
382 // will fail if called after GPUModuleConversion and we don't preserve
383 // `TargetEnv` attribute.
384 // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
385 if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
386 name: spirv::getTargetEnvAttrName()))
387 spvModule->setAttr(name: spirv::getTargetEnvAttrName(), value: attr);
388
389 rewriter.eraseOp(op: moduleOp);
390 return success();
391}
392
393//===----------------------------------------------------------------------===//
394// GPU return inside kernel functions to SPIR-V return.
395//===----------------------------------------------------------------------===//
396
397LogicalResult GPUReturnOpConversion::matchAndRewrite(
398 gpu::ReturnOp returnOp, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter) const {
400 if (!adaptor.getOperands().empty())
401 return failure();
402
403 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op: returnOp);
404 return success();
405}
406
407//===----------------------------------------------------------------------===//
408// Barrier.
409//===----------------------------------------------------------------------===//
410
411LogicalResult GPUBarrierConversion::matchAndRewrite(
412 gpu::BarrierOp barrierOp, OpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter) const {
414 MLIRContext *context = getContext();
415 // Both execution and memory scope should be workgroup.
416 auto scope = spirv::ScopeAttr::get(context, value: spirv::Scope::Workgroup);
417 // Require acquire and release memory semantics for workgroup memory.
418 auto memorySemantics = spirv::MemorySemanticsAttr::get(
419 context, value: spirv::MemorySemantics::WorkgroupMemory |
420 spirv::MemorySemantics::AcquireRelease);
421 rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(op: barrierOp, args&: scope, args&: scope,
422 args&: memorySemantics);
423 return success();
424}
425
426//===----------------------------------------------------------------------===//
427// Shuffle
428//===----------------------------------------------------------------------===//
429
430LogicalResult GPUShuffleConversion::matchAndRewrite(
431 gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter) const {
433 // Require the shuffle width to be the same as the target's subgroup size,
434 // given that for SPIR-V non-uniform subgroup ops, we cannot select
435 // participating invocations.
436 auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
437 unsigned subgroupSize =
438 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
439 IntegerAttr widthAttr;
440 if (!matchPattern(value: shuffleOp.getWidth(), pattern: m_Constant(bind_value: &widthAttr)) ||
441 widthAttr.getValue().getZExtValue() != subgroupSize)
442 return rewriter.notifyMatchFailure(
443 arg&: shuffleOp, msg: "shuffle width and target subgroup size mismatch");
444
445 assert(!adaptor.getOffset().getType().isSignedInteger() &&
446 "shuffle offset must be a signless/unsigned integer");
447
448 Location loc = shuffleOp.getLoc();
449 auto scope = rewriter.getAttr<spirv::ScopeAttr>(args: spirv::Scope::Subgroup);
450 Value result;
451 Value validVal;
452
453 switch (shuffleOp.getMode()) {
454 case gpu::ShuffleMode::XOR: {
455 result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
456 location: loc, args&: scope, args: adaptor.getValue(), args: adaptor.getOffset());
457 validVal = spirv::ConstantOp::getOne(type: rewriter.getI1Type(),
458 loc: shuffleOp.getLoc(), builder&: rewriter);
459 break;
460 }
461 case gpu::ShuffleMode::IDX: {
462 result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
463 location: loc, args&: scope, args: adaptor.getValue(), args: adaptor.getOffset());
464 validVal = spirv::ConstantOp::getOne(type: rewriter.getI1Type(),
465 loc: shuffleOp.getLoc(), builder&: rewriter);
466 break;
467 }
468 case gpu::ShuffleMode::DOWN: {
469 result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
470 location: loc, args&: scope, args: adaptor.getValue(), args: adaptor.getOffset());
471
472 Value laneId = rewriter.create<gpu::LaneIdOp>(location: loc, args&: widthAttr);
473 Value resultLaneId =
474 rewriter.create<arith::AddIOp>(location: loc, args&: laneId, args: adaptor.getOffset());
475 validVal = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ult,
476 args&: resultLaneId, args: adaptor.getWidth());
477 break;
478 }
479 case gpu::ShuffleMode::UP: {
480 result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
481 location: loc, args&: scope, args: adaptor.getValue(), args: adaptor.getOffset());
482
483 Value laneId = rewriter.create<gpu::LaneIdOp>(location: loc, args&: widthAttr);
484 Value resultLaneId =
485 rewriter.create<arith::SubIOp>(location: loc, args&: laneId, args: adaptor.getOffset());
486 auto i32Type = rewriter.getIntegerType(width: 32);
487 validVal = rewriter.create<arith::CmpIOp>(
488 location: loc, args: arith::CmpIPredicate::sge, args&: resultLaneId,
489 args: rewriter.create<arith::ConstantOp>(
490 location: loc, args&: i32Type, args: rewriter.getIntegerAttr(type: i32Type, value: 0)));
491 break;
492 }
493 }
494
495 rewriter.replaceOp(op: shuffleOp, newValues: {result, validVal});
496 return success();
497}
498
499//===----------------------------------------------------------------------===//
500// Rotate
501//===----------------------------------------------------------------------===//
502
503LogicalResult GPURotateConversion::matchAndRewrite(
504 gpu::RotateOp rotateOp, OpAdaptor adaptor,
505 ConversionPatternRewriter &rewriter) const {
506 const spirv::TargetEnv &targetEnv =
507 getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
508 unsigned subgroupSize =
509 targetEnv.getAttr().getResourceLimits().getSubgroupSize();
510 IntegerAttr widthAttr;
511 if (!matchPattern(value: rotateOp.getWidth(), pattern: m_Constant(bind_value: &widthAttr)) ||
512 widthAttr.getValue().getZExtValue() > subgroupSize)
513 return rewriter.notifyMatchFailure(
514 arg&: rotateOp,
515 msg: "rotate width is not a constant or larger than target subgroup size");
516
517 Location loc = rotateOp.getLoc();
518 auto scope = rewriter.getAttr<spirv::ScopeAttr>(args: spirv::Scope::Subgroup);
519 Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
520 location: loc, args&: scope, args: adaptor.getValue(), args: adaptor.getOffset(), args: adaptor.getWidth());
521 Value validVal;
522 if (widthAttr.getValue().getZExtValue() == subgroupSize) {
523 validVal = spirv::ConstantOp::getOne(type: rewriter.getI1Type(), loc, builder&: rewriter);
524 } else {
525 Value laneId = rewriter.create<gpu::LaneIdOp>(location: loc, args&: widthAttr);
526 validVal = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ult,
527 args&: laneId, args: adaptor.getWidth());
528 }
529
530 rewriter.replaceOp(op: rotateOp, newValues: {rotateResult, validVal});
531 return success();
532}
533
534//===----------------------------------------------------------------------===//
535// Group ops
536//===----------------------------------------------------------------------===//
537
538template <typename UniformOp, typename NonUniformOp>
539static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
540 Value arg, bool isGroup, bool isUniform,
541 std::optional<uint32_t> clusterSize) {
542 Type type = arg.getType();
543 auto scope = mlir::spirv::ScopeAttr::get(context: builder.getContext(),
544 value: isGroup ? spirv::Scope::Workgroup
545 : spirv::Scope::Subgroup);
546 auto groupOp = spirv::GroupOperationAttr::get(
547 context: builder.getContext(), value: clusterSize.has_value()
548 ? spirv::GroupOperation::ClusteredReduce
549 : spirv::GroupOperation::Reduce);
550 if (isUniform) {
551 return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
552 .getResult();
553 }
554
555 Value clusterSizeValue;
556 if (clusterSize.has_value())
557 clusterSizeValue = builder.create<spirv::ConstantOp>(
558 location: loc, args: builder.getI32Type(),
559 args: builder.getIntegerAttr(type: builder.getI32Type(), value: *clusterSize));
560
561 return builder
562 .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
563 .getResult();
564}
565
566static std::optional<Value>
567createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
568 gpu::AllReduceOperation opType, bool isGroup,
569 bool isUniform, std::optional<uint32_t> clusterSize) {
570 enum class ElemType { Float, Boolean, Integer };
571 using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
572 std::optional<uint32_t>);
573 struct OpHandler {
574 gpu::AllReduceOperation kind;
575 ElemType elemType;
576 FuncT func;
577 };
578
579 Type type = arg.getType();
580 ElemType elementType;
581 if (isa<FloatType>(Val: type)) {
582 elementType = ElemType::Float;
583 } else if (auto intTy = dyn_cast<IntegerType>(Val&: type)) {
584 elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
585 : ElemType::Integer;
586 } else {
587 return std::nullopt;
588 }
589
590 // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
591 // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
592 // reduction ops. We should account possible precision requirements in this
593 // conversion.
594
595 using ReduceType = gpu::AllReduceOperation;
596 const OpHandler handlers[] = {
597 {.kind: ReduceType::ADD, .elemType: ElemType::Integer,
598 .func: &createGroupReduceOpImpl<spirv::GroupIAddOp,
599 spirv::GroupNonUniformIAddOp>},
600 {.kind: ReduceType::ADD, .elemType: ElemType::Float,
601 .func: &createGroupReduceOpImpl<spirv::GroupFAddOp,
602 spirv::GroupNonUniformFAddOp>},
603 {.kind: ReduceType::MUL, .elemType: ElemType::Integer,
604 .func: &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
605 spirv::GroupNonUniformIMulOp>},
606 {.kind: ReduceType::MUL, .elemType: ElemType::Float,
607 .func: &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
608 spirv::GroupNonUniformFMulOp>},
609 {.kind: ReduceType::MINUI, .elemType: ElemType::Integer,
610 .func: &createGroupReduceOpImpl<spirv::GroupUMinOp,
611 spirv::GroupNonUniformUMinOp>},
612 {.kind: ReduceType::MINSI, .elemType: ElemType::Integer,
613 .func: &createGroupReduceOpImpl<spirv::GroupSMinOp,
614 spirv::GroupNonUniformSMinOp>},
615 {.kind: ReduceType::MINNUMF, .elemType: ElemType::Float,
616 .func: &createGroupReduceOpImpl<spirv::GroupFMinOp,
617 spirv::GroupNonUniformFMinOp>},
618 {.kind: ReduceType::MAXUI, .elemType: ElemType::Integer,
619 .func: &createGroupReduceOpImpl<spirv::GroupUMaxOp,
620 spirv::GroupNonUniformUMaxOp>},
621 {.kind: ReduceType::MAXSI, .elemType: ElemType::Integer,
622 .func: &createGroupReduceOpImpl<spirv::GroupSMaxOp,
623 spirv::GroupNonUniformSMaxOp>},
624 {.kind: ReduceType::MAXNUMF, .elemType: ElemType::Float,
625 .func: &createGroupReduceOpImpl<spirv::GroupFMaxOp,
626 spirv::GroupNonUniformFMaxOp>},
627 {.kind: ReduceType::MINIMUMF, .elemType: ElemType::Float,
628 .func: &createGroupReduceOpImpl<spirv::GroupFMinOp,
629 spirv::GroupNonUniformFMinOp>},
630 {.kind: ReduceType::MAXIMUMF, .elemType: ElemType::Float,
631 .func: &createGroupReduceOpImpl<spirv::GroupFMaxOp,
632 spirv::GroupNonUniformFMaxOp>}};
633
634 for (const OpHandler &handler : handlers)
635 if (handler.kind == opType && elementType == handler.elemType)
636 return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
637
638 return std::nullopt;
639}
640
641/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
642class GPUAllReduceConversion final
643 : public OpConversionPattern<gpu::AllReduceOp> {
644public:
645 using OpConversionPattern::OpConversionPattern;
646
647 LogicalResult
648 matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
649 ConversionPatternRewriter &rewriter) const override {
650 auto opType = op.getOp();
651
652 // gpu.all_reduce can have either reduction op attribute or reduction
653 // region. Only attribute version is supported.
654 if (!opType)
655 return failure();
656
657 auto result =
658 createGroupReduceOp(builder&: rewriter, loc: op.getLoc(), arg: adaptor.getValue(), opType: *opType,
659 /*isGroup*/ true, isUniform: op.getUniform(), clusterSize: std::nullopt);
660 if (!result)
661 return failure();
662
663 rewriter.replaceOp(op, newValues: *result);
664 return success();
665 }
666};
667
668/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
669class GPUSubgroupReduceConversion final
670 : public OpConversionPattern<gpu::SubgroupReduceOp> {
671public:
672 using OpConversionPattern::OpConversionPattern;
673
674 LogicalResult
675 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
676 ConversionPatternRewriter &rewriter) const override {
677 if (op.getClusterStride() > 1) {
678 return rewriter.notifyMatchFailure(
679 arg&: op, msg: "lowering for cluster stride > 1 is not implemented");
680 }
681
682 if (!isa<spirv::ScalarType>(Val: adaptor.getValue().getType()))
683 return rewriter.notifyMatchFailure(arg&: op, msg: "reduction type is not a scalar");
684
685 auto result = createGroupReduceOp(
686 builder&: rewriter, loc: op.getLoc(), arg: adaptor.getValue(), opType: adaptor.getOp(),
687 /*isGroup=*/false, isUniform: adaptor.getUniform(), clusterSize: op.getClusterSize());
688 if (!result)
689 return failure();
690
691 rewriter.replaceOp(op, newValues: *result);
692 return success();
693 }
694};
695
696// Formulate a unique variable/constant name after
697// searching in the module for existing variable/constant names.
698// This is to avoid name collision with existing variables.
699// Example: printfMsg0, printfMsg1, printfMsg2, ...
700static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
701 std::string name;
702 unsigned number = 0;
703
704 do {
705 name.clear();
706 name = (prefix + llvm::Twine(number++)).str();
707 } while (moduleOp.lookupSymbol(name));
708
709 return name;
710}
711
712/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
713
714LogicalResult GPUPrintfConversion::matchAndRewrite(
715 gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
716 ConversionPatternRewriter &rewriter) const {
717
718 Location loc = gpuPrintfOp.getLoc();
719
720 auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
721 if (!moduleOp)
722 return failure();
723
724 // SPIR-V global variable is used to initialize printf
725 // format string value, if there are multiple printf messages,
726 // each global var needs to be created with a unique name.
727 std::string globalVarName = makeVarName(moduleOp, prefix: llvm::Twine("printfMsg"));
728 spirv::GlobalVariableOp globalVar;
729
730 IntegerType i8Type = rewriter.getI8Type();
731 IntegerType i32Type = rewriter.getI32Type();
732
733 // Each character of printf format string is
734 // stored as a spec constant. We need to create
735 // unique name for this spec constant like
736 // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
737 // for existing spec constant names.
738 auto createSpecConstant = [&](unsigned value) {
739 auto attr = rewriter.getI8IntegerAttr(value);
740 std::string specCstName =
741 makeVarName(moduleOp, prefix: llvm::Twine(globalVarName) + "_sc");
742
743 return rewriter.create<spirv::SpecConstantOp>(
744 location: loc, args: rewriter.getStringAttr(bytes: specCstName), args&: attr);
745 };
746 {
747 Operation *parent =
748 SymbolTable::getNearestSymbolTable(from: gpuPrintfOp->getParentOp());
749
750 ConversionPatternRewriter::InsertionGuard guard(rewriter);
751
752 Block &entryBlock = *parent->getRegion(index: 0).begin();
753 rewriter.setInsertionPointToStart(
754 &entryBlock); // insertion point at module level
755
756 // Create Constituents with SpecConstant by scanning format string
757 // Each character of format string is stored as a spec constant
758 // and then these spec constants are used to create a
759 // SpecConstantCompositeOp.
760 llvm::SmallString<20> formatString(adaptor.getFormat());
761 formatString.push_back(Elt: '\0'); // Null terminate for C.
762 SmallVector<Attribute, 4> constituents;
763 for (char c : formatString) {
764 spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
765 constituents.push_back(Elt: SymbolRefAttr::get(symbol: cSpecConstantOp));
766 }
767
768 // Create SpecConstantCompositeOp to initialize the global variable
769 size_t contentSize = constituents.size();
770 auto globalType = spirv::ArrayType::get(elementType: i8Type, elementCount: contentSize);
771 spirv::SpecConstantCompositeOp specCstComposite;
772 // There will be one SpecConstantCompositeOp per printf message/global var,
773 // so no need do lookup for existing ones.
774 std::string specCstCompositeName =
775 (llvm::Twine(globalVarName) + "_scc").str();
776
777 specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
778 location: loc, args: TypeAttr::get(type: globalType),
779 args: rewriter.getStringAttr(bytes: specCstCompositeName),
780 args: rewriter.getArrayAttr(value: constituents));
781
782 auto ptrType = spirv::PointerType::get(
783 pointeeType: globalType, storageClass: spirv::StorageClass::UniformConstant);
784
785 // Define a GlobalVarOp initialized using specialized constants
786 // that is used to specify the printf format string
787 // to be passed to the SPIRV CLPrintfOp.
788 globalVar = rewriter.create<spirv::GlobalVariableOp>(
789 location: loc, args&: ptrType, args&: globalVarName, args: FlatSymbolRefAttr::get(symbol: specCstComposite));
790
791 globalVar->setAttr(name: "Constant", value: rewriter.getUnitAttr());
792 }
793 // Get SSA value of Global variable and create pointer to i8 to point to
794 // the format string.
795 Value globalPtr = rewriter.create<spirv::AddressOfOp>(location: loc, args&: globalVar);
796 Value fmtStr = rewriter.create<spirv::BitcastOp>(
797 location: loc,
798 args: spirv::PointerType::get(pointeeType: i8Type, storageClass: spirv::StorageClass::UniformConstant),
799 args&: globalPtr);
800
801 // Get printf arguments.
802 auto printfArgs = llvm::to_vector_of<Value, 4>(Range: adaptor.getArgs());
803
804 rewriter.create<spirv::CLPrintfOp>(location: loc, args&: i32Type, args&: fmtStr, args&: printfArgs);
805
806 // Need to erase the gpu.printf op as gpu.printf does not use result vs
807 // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
808 // printf op.
809 rewriter.eraseOp(op: gpuPrintfOp);
810
811 return success();
812}
813
814//===----------------------------------------------------------------------===//
815// GPU To SPIRV Patterns.
816//===----------------------------------------------------------------------===//
817
818void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
819 RewritePatternSet &patterns) {
820 patterns.add<
821 GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
822 GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
823 LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
824 LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
825 LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
826 LaunchConfigConversion<gpu::ThreadIdOp,
827 spirv::BuiltIn::LocalInvocationId>,
828 LaunchConfigConversion<gpu::GlobalIdOp,
829 spirv::BuiltIn::GlobalInvocationId>,
830 SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
831 spirv::BuiltIn::SubgroupId>,
832 SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
833 spirv::BuiltIn::NumSubgroups>,
834 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
835 spirv::BuiltIn::SubgroupSize>,
836 SingleDimLaunchConfigConversion<
837 gpu::LaneIdOp, spirv::BuiltIn::SubgroupLocalInvocationId>,
838 WorkGroupSizeConversion, GPUAllReduceConversion,
839 GPUSubgroupReduceConversion, GPUPrintfConversion>(arg: typeConverter,
840 args: patterns.getContext());
841}
842

source code of mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp