1//===- ConvertGPULaunchFuncToVulkanLaunchFunc.cpp - MLIR conversion pass --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert a gpu launch function into a vulkan
// launch function. It creates a SPIR-V binary shader from the `spirv::ModuleOp`
// using the `spirv::serialize` function, and attaches the binary data, the
// entry point name and the kernel operand element types as attributes to the
// vulkan launch call op.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
17
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Target/SPIRV/Serialization.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
// Attribute names attached to the generated vulkan launch call op.
static constexpr char kSPIRVBlobAttrName[] = "spirv_blob";
static constexpr char kSPIRVEntryPointAttrName[] = "spirv_entry_point";
static constexpr char kSPIRVElementTypesAttrName[] = "spirv_element_types";
// Symbol name of the external function that performs the Vulkan launch.
static constexpr char kVulkanLaunch[] = "vulkanLaunch";
40
41namespace {
42
43/// A pass to convert gpu launch op to vulkan launch call op, by creating a
44/// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
45/// function and attaching binary data and entry point name as an attributes to
46/// created vulkan launch call op.
47class ConvertGpuLaunchFuncToVulkanLaunchFunc
48 : public impl::ConvertGpuLaunchFuncToVulkanLaunchFuncBase<
49 ConvertGpuLaunchFuncToVulkanLaunchFunc> {
50public:
51 void runOnOperation() override;
52
53private:
54 /// Creates a SPIR-V binary shader from the given `module` using
55 /// `spirv::serialize` function.
56 LogicalResult createBinaryShader(ModuleOp module,
57 std::vector<char> &binaryShader);
58
59 /// Converts the given `launchOp` to vulkan launch call.
60 void convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp);
61
62 /// Checks where the given type is supported by Vulkan runtime.
63 bool isSupportedType(Type type) {
64 if (auto memRefType = dyn_cast_or_null<MemRefType>(type)) {
65 auto elementType = memRefType.getElementType();
66 return memRefType.hasRank() &&
67 (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
68 (elementType.isIntOrFloat());
69 }
70 return false;
71 }
72
73 /// Declares the vulkan launch function. Returns an error if the any type of
74 /// operand is unsupported by Vulkan runtime.
75 LogicalResult declareVulkanLaunchFunc(Location loc,
76 gpu::LaunchFuncOp launchOp);
77
78private:
79 /// The number of vulkan launch configuration operands, placed at the leading
80 /// positions of the operand list.
81 static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
82};
83
84} // namespace
85
86void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
87 bool done = false;
88 getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
89 if (done) {
90 op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
91 return signalPassFailure();
92 }
93 done = true;
94 convertGpuLaunchFunc(op);
95 });
96
97 // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
98 for (auto gpuModule :
99 llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
100 gpuModule.erase();
101
102 for (auto spirvModule :
103 llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
104 spirvModule.erase();
105}
106
107LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
108 Location loc, gpu::LaunchFuncOp launchOp) {
109 auto builder = OpBuilder::atBlockEnd(block: getOperation().getBody());
110
111 // Workgroup size is written into the kernel. So to properly modelling
112 // vulkan launch, we have to skip local workgroup size configuration here.
113 SmallVector<Type, 8> gpuLaunchTypes(launchOp.getOperandTypes());
114 // The first kVulkanLaunchNumConfigOperands of the gpu.launch_func op are the
115 // same as the config operands for the vulkan launch call op.
116 SmallVector<Type, 8> vulkanLaunchTypes(gpuLaunchTypes.begin(),
117 gpuLaunchTypes.begin() +
118 kVulkanLaunchNumConfigOperands);
119 vulkanLaunchTypes.append(gpuLaunchTypes.begin() +
120 gpu::LaunchOp::kNumConfigOperands,
121 gpuLaunchTypes.end());
122
123 // Check that all operands have supported types except those for the
124 // launch configuration.
125 for (auto type :
126 llvm::drop_begin(vulkanLaunchTypes, kVulkanLaunchNumConfigOperands)) {
127 if (!isSupportedType(type))
128 return launchOp.emitError() << type << " is unsupported to run on Vulkan";
129 }
130
131 // Declare vulkan launch function.
132 auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
133 builder.create<func::FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
134
135 return success();
136}
137
138LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
139 ModuleOp module, std::vector<char> &binaryShader) {
140 bool done = false;
141 SmallVector<uint32_t, 0> binary;
142 for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
143 if (done)
144 return spirvModule.emitError("should only contain one 'spirv.module' op");
145 done = true;
146
147 if (failed(spirv::serialize(spirvModule, binary)))
148 return failure();
149 }
150 binaryShader.resize(new_size: binary.size() * sizeof(uint32_t));
151 std::memcpy(dest: binaryShader.data(), src: reinterpret_cast<char *>(binary.data()),
152 n: binaryShader.size());
153 return success();
154}
155
156void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
157 gpu::LaunchFuncOp launchOp) {
158 ModuleOp module = getOperation();
159 OpBuilder builder(launchOp);
160 Location loc = launchOp.getLoc();
161
162 // Serialize `spirv::Module` into binary form.
163 std::vector<char> binary;
164 if (failed(createBinaryShader(module: module, binaryShader&: binary)))
165 return signalPassFailure();
166
167 // Declare vulkan launch function.
168 if (failed(declareVulkanLaunchFunc(loc, launchOp)))
169 return signalPassFailure();
170
171 SmallVector<Value, 8> gpuLaunchOperands(launchOp.getOperands());
172 SmallVector<Value, 8> vulkanLaunchOperands(
173 gpuLaunchOperands.begin(),
174 gpuLaunchOperands.begin() + kVulkanLaunchNumConfigOperands);
175 vulkanLaunchOperands.append(gpuLaunchOperands.begin() +
176 gpu::LaunchOp::kNumConfigOperands,
177 gpuLaunchOperands.end());
178
179 // Create vulkan launch call op.
180 auto vulkanLaunchCallOp = builder.create<func::CallOp>(
181 loc, TypeRange{}, SymbolRefAttr::get(builder.getContext(), kVulkanLaunch),
182 vulkanLaunchOperands);
183
184 // Set SPIR-V binary shader data as an attribute.
185 vulkanLaunchCallOp->setAttr(
186 kSPIRVBlobAttrName,
187 builder.getStringAttr(StringRef(binary.data(), binary.size())));
188
189 // Set entry point name as an attribute.
190 vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
191 launchOp.getKernelName());
192
193 // Add MemRef element types before they're lost when lowering to LLVM.
194 SmallVector<Type> elementTypes;
195 for (Type type : llvm::drop_begin(launchOp.getOperandTypes(),
196 gpu::LaunchOp::kNumConfigOperands)) {
197 // The below cast always succeeds as it has already been verified in
198 // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
199 // types.
200 elementTypes.push_back(cast<MemRefType>(type).getElementType());
201 }
202 vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
203 builder.getTypeArrayAttr(elementTypes));
204
205 launchOp.erase();
206}
207
208std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
209mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() {
210 return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>();
211}
212

source code of mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp