//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

using namespace mlir;

static constexpr const char kSPIRVModule[] = "__spv__";

namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// This pattern is separate because, in Vulkan, the workgroup size is exposed
/// to shaders via a constant decorated with WorkgroupSize, so we cannot
/// generate a builtin variable here; instead, the information in the
/// `spirv.entry_point_abi` attribute on the surrounding FuncOp is used to
/// replace the gpu::BlockDimOp.
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
  WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}

  LogicalResult
  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
  using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  SmallVector<int32_t, 3> workGroupSizeAsInt32;
};

/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
  using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to erase the gpu.module_end terminator once its parent gpu.module
/// has been converted.
class GPUModuleEndConversion final
    : public OpConversionPattern<gpu::ModuleEndOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(endOp);
    return success();
  }
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
  using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//

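// On the Vulkan/shader path, a launch-config op such as `gpu.block_id x`
// lowers to a load of the corresponding builtin variable followed by a
// component extract, roughly (illustrative IR; the variable name is chosen by
// the builtin-variable helper):
//
//   %vec = spirv.Load "Input" %workgroup_id_ptr : vector<3xi32>
//   %id  = spirv.CompositeExtract %vec[0 : i32] : vector<3xi32>
//
// followed by a `spirv.UConvert` when the converted index type is not i32.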
template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();

  // For Vulkan, these SPIR-V builtin variables are required to be of type
  // vector<3xi32> by the spec:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
  //
  // For OpenCL, it depends on the Physical32/Physical64 addressing model:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  bool forShader =
      typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
  Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;

  Value vector =
      spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
  Value dim = rewriter.create<spirv::CompositeExtractOp>(
      op.getLoc(), builtinType, vector,
      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
  if (forShader && builtinType != indexType)
    dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
  rewriter.replaceOp(op, dim);
  return success();
}

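// Single-dimension launch-config ops (e.g. `gpu.subgroup_id`) lower to a
// direct load of a scalar i32 builtin variable, e.g. (illustrative):
//
//   %id = spirv.Load "Input" %subgroup_id_ptr : i32
//
// with an optional `spirv.UConvert` to the converted index type.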
template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
    SourceOp op, typename SourceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter->getIndexType();
  Type i32Type = rewriter.getIntegerType(32);

  // For Vulkan, these SPIR-V builtin variables are required to be scalars of
  // type i32 by the spec:
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
  // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
  //
  // For OpenCL, they are also required to be i32:
  // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
  Value builtinValue =
      spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
  if (i32Type != indexType)
    builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
                                                      builtinValue);
  rewriter.replaceOp(op, builtinValue);
  return success();
}

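// With an entry-point ABI such as
// `spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>`
// attached to the surrounding function, `gpu.block_dim y` folds directly to a
// `spirv.Constant 4` of the converted index type (illustrative).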
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
    gpu::BlockDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
  if (!workGroupSizeAttr)
    return failure();

  int val =
      workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
  auto convertedType =
      getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedType)
    return failure();
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      op, convertedType, IntegerAttr::get(convertedType, val));
  return success();
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

// Legalizes a GPU function as an entry SPIR-V function.
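// The result is a `spirv.func` whose arguments carry `spirv.interface_var_abi`
// attributes and which itself carries the `spirv.entry_point_abi` attribute;
// both are later materialized by the LowerABIAttributesPass. Roughly
// (illustrative):
//
//   spirv.func @kernel(
//       %arg0: !spirv.ptr<...> {spirv.interface_var_abi = ...}) "None"
//       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<...>} {...}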
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
                     ConversionPatternRewriter &rewriter,
                     spirv::EntryPointABIAttr entryPointInfo,
                     ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  auto fnType = funcOp.getFunctionType();
  if (fnType.getNumResults()) {
    funcOp.emitError("SPIR-V lowering only supports entry functions "
                     "with no return values right now");
    return nullptr;
  }
  if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
    funcOp.emitError(
        "lowering as entry functions requires ABI info for all arguments "
        "or none of them");
    return nullptr;
  }
  // Update the signature to valid SPIR-V types and add the ABI
  // attributes. These will be "materialized" by using the
  // LowerABIAttributesPass.
  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  {
    for (const auto &argType :
         enumerate(funcOp.getFunctionType().getInputs())) {
      auto convertedType = typeConverter.convertType(argType.value());
      if (!convertedType)
        return nullptr;
      signatureConverter.addInputs(argType.index(), convertedType);
    }
  }
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               std::nullopt));
  for (const auto &namedAttr : funcOp->getAttrs()) {
    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
        namedAttr.getName() == SymbolTable::getSymbolAttrName())
      continue;
    newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return nullptr;
  rewriter.eraseOp(funcOp);

  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);

  return newFuncOp;
}

/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
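/// For example, argument `i` gets descriptor set 0 and binding `i`, printed
/// like the following (scalars additionally get the StorageBuffer storage
/// class; illustrative):
///   #spirv.interface_var_abi<(0, i), StorageBuffer>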
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
  if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
    return success();

  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
            argIndex, spirv::getInterfaceVarABIAttrName()))
      return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in
    // a struct. The struct is held in a storage buffer.
    std::optional<spirv::StorageClass> sc;
    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
      sc = spirv::StorageClass::StorageBuffer;
    argABI.push_back(
        spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
  }
  return success();
}

LogicalResult GPUFuncOpConversion::matchAndRewrite(
    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!gpu::GPUDialect::isKernel(funcOp))
    return failure();

  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
  if (failed(
          getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
    argABI.clear();
    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
      // If the ABI is already specified, use it.
      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
          argIndex, spirv::getInterfaceVarABIAttrName());
      if (!abiAttr) {
        funcOp.emitRemark(
            "match failure: missing 'spirv.interface_var_abi' attribute at "
            "argument ")
            << argIndex;
        return failure();
      }
      argABI.push_back(abiAttr);
    }
  }

  auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
  if (!entryPointAttr) {
    funcOp.emitRemark(
        "match failure: missing 'spirv.entry_point_abi' attribute");
    return failure();
  }
  spirv::FuncOp newFuncOp = lowerAsEntryFunction(
      funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
  if (!newFuncOp)
    return failure();
  newFuncOp->removeAttr(
      rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
  return success();
}

//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//

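// A `gpu.module @kernels { ... }` becomes a `spirv.module` whose name is
// prefixed with `__spv__` to avoid symbol clashes, e.g. (illustrative; the
// addressing and memory models are deduced from the target environment):
//
//   spirv.module @__spv__kernels Logical GLSL450 { ... }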
LogicalResult GPUModuleConversion::matchAndRewrite(
    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
  spirv::AddressingModel addressingModel = spirv::getAddressingModel(
      targetEnv, typeConverter->getOptions().use64bitIndex);
  FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
  if (failed(memoryModel))
    return moduleOp.emitRemark(
        "cannot deduce memory model from 'spirv.target_env'");

  // Add a prefix to the module name to avoid symbol conflicts.
  std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
      StringRef(spvModuleName));

  // Move the region from the module op into the SPIR-V module.
  Region &spvModuleRegion = spvModule.getRegion();
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
                              spvModuleRegion.begin());
  // The spirv.module build method adds a block. Remove that.
  rewriter.eraseBlock(&spvModuleRegion.back());

  // Some of the patterns call `lookupTargetEnv` during conversion, and they
  // will fail if invoked after GPUModuleConversion unless we preserve the
  // `TargetEnv` attribute.
  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
          spirv::getTargetEnvAttrName()))
    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

  rewriter.eraseOp(moduleOp);
  return success();
}

//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//

LogicalResult GPUReturnOpConversion::matchAndRewrite(
    gpu::ReturnOp returnOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!adaptor.getOperands().empty())
    return failure();

  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
  return success();
}

//===----------------------------------------------------------------------===//
// Barrier.
//===----------------------------------------------------------------------===//

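// A `gpu.barrier` maps to a workgroup-scoped control barrier with
// acquire-release semantics on workgroup memory, e.g. (illustrative):
//
//   spirv.ControlBarrier <Workgroup>, <Workgroup>,
//                        <AcquireRelease|WorkgroupMemory>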
LogicalResult GPUBarrierConversion::matchAndRewrite(
    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MLIRContext *context = getContext();
  // Both execution and memory scope should be workgroup.
  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
  // Require acquire and release memory semantics for workgroup memory.
  auto memorySemantics = spirv::MemorySemanticsAttr::get(
      context, spirv::MemorySemantics::WorkgroupMemory |
                   spirv::MemorySemantics::AcquireRelease);
  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
                                                       memorySemantics);
  return success();
}

//===----------------------------------------------------------------------===//
// Shuffle
//===----------------------------------------------------------------------===//

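// A `gpu.shuffle` whose constant width equals the target's subgroup size
// lowers to a non-uniform shuffle plus a constant-true "valid" bit, e.g. for
// the xor mode (illustrative):
//
//   %result = spirv.GroupNonUniformShuffleXor <Subgroup> %value, %offset : ...
//   %valid  = spirv.Constant true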
LogicalResult GPUShuffleConversion::matchAndRewrite(
    gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Require the shuffle width to be the same as the target's subgroup size,
  // given that for SPIR-V non-uniform subgroup ops, we cannot select
  // participating invocations.
  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() != subgroupSize)
    return rewriter.notifyMatchFailure(
        shuffleOp, "shuffle width and target subgroup size mismatch");

  Location loc = shuffleOp.getLoc();
  Value trueVal =
      spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value result;

  switch (shuffleOp.getMode()) {
  case gpu::ShuffleMode::XOR:
    result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  case gpu::ShuffleMode::IDX:
    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
        loc, scope, adaptor.getValue(), adaptor.getOffset());
    break;
  default:
    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
  }

  rewriter.replaceOp(shuffleOp, {result, trueVal});
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//

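// Reductions dispatch on two axes: uniform reductions use the `Group*` ops
// (e.g. spirv.GroupIAdd), non-uniform ones use the `GroupNonUniform*` ops;
// the scope is Workgroup for gpu.all_reduce and Subgroup for
// gpu.subgroup_reduce. All of them use the Reduce group operation.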
template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
                                     Value arg, bool isGroup, bool isUniform) {
  Type type = arg.getType();
  auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                           isGroup ? spirv::Scope::Workgroup
                                                   : spirv::Scope::Subgroup);
  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
                                                spirv::GroupOperation::Reduce);
  if (isUniform) {
    return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
        .getResult();
  }
  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
      .getResult();
}

static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
                                                Location loc, Value arg,
                                                gpu::AllReduceOperation opType,
                                                bool isGroup, bool isUniform) {
  enum class ElemType { Float, Boolean, Integer };
  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
  struct OpHandler {
    gpu::AllReduceOperation kind;
    ElemType elemType;
    FuncT func;
  };

  Type type = arg.getType();
  ElemType elementType;
  if (isa<FloatType>(type)) {
    elementType = ElemType::Float;
  } else if (auto intTy = dyn_cast<IntegerType>(type)) {
    elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
                                                       : ElemType::Integer;
  } else {
    return std::nullopt;
  }

  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account for possible precision requirements in
  // this conversion.

  using ReduceType = gpu::AllReduceOperation;
  const OpHandler handlers[] = {
      {ReduceType::ADD, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIAddOp,
                                spirv::GroupNonUniformIAddOp>},
      {ReduceType::ADD, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFAddOp,
                                spirv::GroupNonUniformFAddOp>},
      {ReduceType::MUL, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupIMulKHROp,
                                spirv::GroupNonUniformIMulOp>},
      {ReduceType::MUL, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMulKHROp,
                                spirv::GroupNonUniformFMulOp>},
      {ReduceType::MINUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMinOp,
                                spirv::GroupNonUniformUMinOp>},
      {ReduceType::MINSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMinOp,
                                spirv::GroupNonUniformSMinOp>},
      {ReduceType::MINNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXUI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupUMaxOp,
                                spirv::GroupNonUniformUMaxOp>},
      {ReduceType::MAXSI, ElemType::Integer,
       &createGroupReduceOpImpl<spirv::GroupSMaxOp,
                                spirv::GroupNonUniformSMaxOp>},
      {ReduceType::MAXNUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>},
      {ReduceType::MINIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMinOp,
                                spirv::GroupNonUniformFMinOp>},
      {ReduceType::MAXIMUMF, ElemType::Float,
       &createGroupReduceOpImpl<spirv::GroupFMaxOp,
                                spirv::GroupNonUniformFMaxOp>}};

  for (const OpHandler &handler : handlers)
    if (handler.kind == opType && elementType == handler.elemType)
      return handler.func(builder, loc, arg, isGroup, isUniform);

  return std::nullopt;
}

/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
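/// Only the attribute form of the reduction is handled; e.g. (illustrative)
///   %sum = gpu.all_reduce add %val uniform {} : (f32) -> f32
/// becomes
///   %sum = spirv.GroupFAdd <Workgroup> <Reduce> %val : f32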
class GPUAllReduceConversion final
    : public OpConversionPattern<gpu::AllReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = op.getOp();

    // gpu.all_reduce can have either a reduction op attribute or a reduction
    // region. Only the attribute version is supported.
    if (!opType)
      return failure();

    auto result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
                            /*isGroup=*/true, op.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
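/// Same dispatch as above but at Subgroup scope; e.g. (illustrative)
///   %sum = gpu.subgroup_reduce add %val : (f32) -> f32
/// becomes
///   %sum = spirv.GroupNonUniformFAdd <Subgroup> <Reduce> %val : f32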
class GPUSubgroupReduceConversion final
    : public OpConversionPattern<gpu::SubgroupReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

    auto result = createGroupReduceOp(rewriter, op.getLoc(),
                                      adaptor.getValue(), adaptor.getOp(),
                                      /*isGroup=*/false, adaptor.getUniform());
    if (!result)
      return failure();

    rewriter.replaceOp(op, *result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
      GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
      LaunchConfigConversion<gpu::ThreadIdOp,
                             spirv::BuiltIn::LocalInvocationId>,
      LaunchConfigConversion<gpu::GlobalIdOp,
                             spirv::BuiltIn::GlobalInvocationId>,
      SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
                                      spirv::BuiltIn::SubgroupId>,
      SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
                                      spirv::BuiltIn::NumSubgroups>,
      SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                      spirv::BuiltIn::SubgroupSize>,
      WorkGroupSizeConversion, GPUAllReduceConversion,
      GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
}