1//===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements C runtime wrappers around the VulkanRuntime.
10//
11//===----------------------------------------------------------------------===//
12
13#include <iostream>
14#include <mutex>
15#include <numeric>
16#include <string>
17#include <vector>
18
19#include "VulkanRuntime.h"
20
// Marks the entry points of the vulkan-runtime-wrapper as exported symbols:
// `__declspec(dllexport)` when `_WIN32` is defined, default symbol visibility
// otherwise.
#if defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else // !defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // defined(_WIN32)
28
29namespace {
30
class VulkanModule;

// Handle object whose address is what `mgpuModuleGetFunction` hands back. It
// remembers the module it was created for and keeps its own copy of the
// function name.
struct VulkanFunction {
  VulkanFunction(VulkanModule *module, const char *name)
      : module(module), name(name) {}

  // Module this function was requested from; stored as a plain pointer.
  VulkanModule *module;
  // Copy of the name given at construction time.
  std::string name;
};
41
42// Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage
43// allocation of pointers returned from `mgpuModuleGetFunction`.
44class VulkanModule {
45public:
46 VulkanModule(const uint8_t *ptr, size_t sizeInBytes)
47 : blob(ptr, ptr + sizeInBytes) {}
48 ~VulkanModule() = default;
49
50 VulkanFunction *getFunction(const char *name) {
51 return functions.emplace_back(args: std::make_unique<VulkanFunction>(args: this, args&: name))
52 .get();
53 }
54
55 uint8_t *blobData() { return blob.data(); }
56 size_t blobSizeInBytes() const { return blob.size(); }
57
58private:
59 std::vector<uint8_t> blob;
60 std::vector<std::unique_ptr<VulkanFunction>> functions;
61};
62
63class VulkanRuntimeManager {
64public:
65 VulkanRuntimeManager() = default;
66 VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
67 VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
68 ~VulkanRuntimeManager() = default;
69
70 void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
71 const VulkanHostMemoryBuffer &memBuffer) {
72 std::lock_guard<std::mutex> lock(mutex);
73 vulkanRuntime.setResourceData(desIndex: setIndex, bindIndex, hostMemBuffer: memBuffer);
74 }
75
76 void setEntryPoint(const char *entryPoint) {
77 std::lock_guard<std::mutex> lock(mutex);
78 vulkanRuntime.setEntryPoint(entryPoint);
79 }
80
81 void setNumWorkGroups(NumWorkGroups numWorkGroups) {
82 std::lock_guard<std::mutex> lock(mutex);
83 vulkanRuntime.setNumWorkGroups(numWorkGroups);
84 }
85
86 void setShaderModule(uint8_t *shader, uint32_t size) {
87 std::lock_guard<std::mutex> lock(mutex);
88 vulkanRuntime.setShaderModule(shader, size);
89 }
90
91 void runOnVulkan() {
92 std::lock_guard<std::mutex> lock(mutex);
93 if (failed(Result: vulkanRuntime.initRuntime()) || failed(Result: vulkanRuntime.run()) ||
94 failed(Result: vulkanRuntime.updateHostMemoryBuffers()) ||
95 failed(Result: vulkanRuntime.destroy())) {
96 std::cerr << "runOnVulkan failed";
97 }
98 }
99
100private:
101 VulkanRuntime vulkanRuntime;
102 std::mutex mutex;
103};
104
105} // namespace
106
// Plain aggregate describing a rank-`N` memref with element type `T`. The
// `_mlir_ciface_*` helpers in this file receive it by pointer. Field order is
// part of the layout and must not change.
template <class T, int N>
struct MemRefDescriptor {
  T *allocated;
  T *aligned;
  int64_t offset;
  // One entry per dimension.
  int64_t sizes[N];
  // One entry per dimension.
  int64_t strides[N];
};
115
116extern "C" {
117
118//===----------------------------------------------------------------------===//
119//
120// Wrappers intended for mlir-runner. Uses of GPU dialect operations get
121// lowered to calls to these functions by GPUToLLVMConversionPass.
122//
123//===----------------------------------------------------------------------===//
124
125VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
126 return new VulkanRuntimeManager();
127}
128
129VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
130 delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
131}
132
133VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) {
134 // Currently a no-op as the other operations are synchronous.
135}
136
137VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
138 size_t gpuBlobSize) {
139 // gpuBlobSize is the size of the data in bytes.
140 return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize);
141}
142
143VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
144 delete static_cast<VulkanModule *>(vkModule);
145}
146
147VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
148 const char *name) {
149 if (!vkModule)
150 abort();
151 return static_cast<VulkanModule *>(vkModule)->getFunction(name);
152}
153
154VULKAN_WRAPPER_SYMBOL_EXPORT void
155mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
156 size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
157 size_t /*smem*/, void *vkRuntimeManager, void **params,
158 void ** /*extra*/, size_t paramsCount) {
159 auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
160
161 // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
162 // kernelIntersperseSizeCallConv options will set up the params array like:
163 // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
164 const size_t paramsPerMemRef = 2;
165 if (paramsCount % paramsPerMemRef != 0) {
166 abort(); // This would indicate a serious calling convention mismatch.
167 }
168 const DescriptorSetIndex setIndex = 0;
169 BindingIndex bindIndex = 0;
170 for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
171 void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
172 size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
173 VulkanHostMemoryBuffer memBuffer{.ptr: memrefBufferBasePtr,
174 .size: static_cast<uint32_t>(memrefBufferSize)};
175 manager->setResourceData(setIndex, bindIndex, memBuffer);
176 ++bindIndex;
177 }
178
179 manager->setNumWorkGroups(NumWorkGroups{.x: static_cast<uint32_t>(gridX),
180 .y: static_cast<uint32_t>(gridY),
181 .z: static_cast<uint32_t>(gridZ)});
182
183 auto function = static_cast<VulkanFunction *>(vkKernel);
184 // Expected size should be in bytes.
185 manager->setShaderModule(
186 shader: function->module->blobData(),
187 size: static_cast<uint32_t>(function->module->blobSizeInBytes()));
188 manager->setEntryPoint(function->name.c_str());
189
190 manager->runOnVulkan();
191}
192
193//===----------------------------------------------------------------------===//
194//
195// Miscellaneous utility functions that can be directly used by tests.
196//
197//===----------------------------------------------------------------------===//
198
199/// Fills the given 1D float memref with the given float value.
200VULKAN_WRAPPER_SYMBOL_EXPORT void
201_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
202 float value) {
203 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
204}
205
206/// Fills the given 2D float memref with the given float value.
207VULKAN_WRAPPER_SYMBOL_EXPORT void
208_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
209 float value) {
210 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
211}
212
213/// Fills the given 3D float memref with the given float value.
214VULKAN_WRAPPER_SYMBOL_EXPORT void
215_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
216 float value) {
217 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
218 value: value);
219}
220
221/// Fills the given 1D int memref with the given int value.
222VULKAN_WRAPPER_SYMBOL_EXPORT void
223_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
224 int32_t value) {
225 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
226}
227
228/// Fills the given 2D int memref with the given int value.
229VULKAN_WRAPPER_SYMBOL_EXPORT void
230_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
231 int32_t value) {
232 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
233}
234
235/// Fills the given 3D int memref with the given int value.
236VULKAN_WRAPPER_SYMBOL_EXPORT void
237_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
238 int32_t value) {
239 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
240 value: value);
241}
242
243/// Fills the given 1D int memref with the given int8 value.
244VULKAN_WRAPPER_SYMBOL_EXPORT void
245_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
246 int8_t value) {
247 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
248}
249
250/// Fills the given 2D int memref with the given int8 value.
251VULKAN_WRAPPER_SYMBOL_EXPORT void
252_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
253 int8_t value) {
254 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
255}
256
257/// Fills the given 3D int memref with the given int8 value.
258VULKAN_WRAPPER_SYMBOL_EXPORT void
259_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
260 int8_t value) {
261 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
262 value: value);
263}
264}
265

// Source code of mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp