1 | //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implements C runtime wrappers around the VulkanRuntime. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>
18 | |
19 | #include "VulkanRuntime.h" |
20 | |
// Explicitly export the entry points of the vulkan-runtime-wrapper:
// default symbol visibility everywhere except Windows, where
// `__declspec(dllexport)` is used instead.

#ifndef _WIN32
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#endif // !_WIN32
28 | |
29 | namespace { |
30 | |
class VulkanModule;

// Handle type returned (as an opaque pointer) from `mgpuModuleGetFunction`.
// It records which module the function belongs to and the function's name.
struct VulkanFunction {
  // Module this function was requested from; not owned by the handle.
  VulkanModule *module;
  // Copy of the name passed at construction.
  std::string name;

  VulkanFunction(VulkanModule *owner, const char *functionName)
      : module{owner}, name{functionName} {}
};
41 | |
42 | // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage |
43 | // allocation of pointers returned from `mgpuModuleGetFunction`. |
44 | class VulkanModule { |
45 | public: |
46 | VulkanModule(const uint8_t *ptr, size_t sizeInBytes) |
47 | : blob(ptr, ptr + sizeInBytes) {} |
48 | ~VulkanModule() = default; |
49 | |
50 | VulkanFunction *getFunction(const char *name) { |
51 | return functions.emplace_back(args: std::make_unique<VulkanFunction>(args: this, args&: name)) |
52 | .get(); |
53 | } |
54 | |
55 | uint8_t *blobData() { return blob.data(); } |
56 | size_t blobSizeInBytes() const { return blob.size(); } |
57 | |
58 | private: |
59 | std::vector<uint8_t> blob; |
60 | std::vector<std::unique_ptr<VulkanFunction>> functions; |
61 | }; |
62 | |
63 | class VulkanRuntimeManager { |
64 | public: |
65 | VulkanRuntimeManager() = default; |
66 | VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; |
67 | VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; |
68 | ~VulkanRuntimeManager() = default; |
69 | |
70 | void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, |
71 | const VulkanHostMemoryBuffer &memBuffer) { |
72 | std::lock_guard<std::mutex> lock(mutex); |
73 | vulkanRuntime.setResourceData(desIndex: setIndex, bindIndex, hostMemBuffer: memBuffer); |
74 | } |
75 | |
76 | void setEntryPoint(const char *entryPoint) { |
77 | std::lock_guard<std::mutex> lock(mutex); |
78 | vulkanRuntime.setEntryPoint(entryPoint); |
79 | } |
80 | |
81 | void setNumWorkGroups(NumWorkGroups numWorkGroups) { |
82 | std::lock_guard<std::mutex> lock(mutex); |
83 | vulkanRuntime.setNumWorkGroups(numWorkGroups); |
84 | } |
85 | |
86 | void setShaderModule(uint8_t *shader, uint32_t size) { |
87 | std::lock_guard<std::mutex> lock(mutex); |
88 | vulkanRuntime.setShaderModule(shader, size); |
89 | } |
90 | |
91 | void runOnVulkan() { |
92 | std::lock_guard<std::mutex> lock(mutex); |
93 | if (failed(Result: vulkanRuntime.initRuntime()) || failed(Result: vulkanRuntime.run()) || |
94 | failed(Result: vulkanRuntime.updateHostMemoryBuffers()) || |
95 | failed(Result: vulkanRuntime.destroy())) { |
96 | std::cerr << "runOnVulkan failed" ; |
97 | } |
98 | } |
99 | |
100 | private: |
101 | VulkanRuntime vulkanRuntime; |
102 | std::mutex mutex; |
103 | }; |
104 | |
105 | } // namespace |
106 | |
// Mirror of the descriptor passed to the `_mlir_ciface_*` helpers in this
// file for a memref with element type `T` and `N` dimensions. The member
// order and types define the binary layout and must not be changed.
template <class T, int N>
struct MemRefDescriptor {
  T *allocated;       // Allocated pointer.
  T *aligned;         // Aligned pointer.
  int64_t offset;     // Offset value.
  int64_t sizes[N];   // One size per dimension.
  int64_t strides[N]; // One stride per dimension.
};
115 | |
116 | extern "C" { |
117 | |
118 | //===----------------------------------------------------------------------===// |
119 | // |
120 | // Wrappers intended for mlir-runner. Uses of GPU dialect operations get |
121 | // lowered to calls to these functions by GPUToLLVMConversionPass. |
122 | // |
123 | //===----------------------------------------------------------------------===// |
124 | |
125 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() { |
126 | return new VulkanRuntimeManager(); |
127 | } |
128 | |
129 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) { |
130 | delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager); |
131 | } |
132 | |
133 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) { |
134 | // Currently a no-op as the other operations are synchronous. |
135 | } |
136 | |
137 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data, |
138 | size_t gpuBlobSize) { |
139 | // gpuBlobSize is the size of the data in bytes. |
140 | return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize); |
141 | } |
142 | |
143 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) { |
144 | delete static_cast<VulkanModule *>(vkModule); |
145 | } |
146 | |
147 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule, |
148 | const char *name) { |
149 | if (!vkModule) |
150 | abort(); |
151 | return static_cast<VulkanModule *>(vkModule)->getFunction(name); |
152 | } |
153 | |
154 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
155 | mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, |
156 | size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, |
157 | size_t /*smem*/, void *vkRuntimeManager, void **params, |
158 | void ** /*extra*/, size_t paramsCount) { |
159 | auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); |
160 | |
161 | // GpuToLLVMConversionPass with the kernelBarePtrCallConv and |
162 | // kernelIntersperseSizeCallConv options will set up the params array like: |
163 | // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... } |
164 | const size_t paramsPerMemRef = 2; |
165 | if (paramsCount % paramsPerMemRef != 0) { |
166 | abort(); // This would indicate a serious calling convention mismatch. |
167 | } |
168 | const DescriptorSetIndex setIndex = 0; |
169 | BindingIndex bindIndex = 0; |
170 | for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) { |
171 | void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]); |
172 | size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]); |
173 | VulkanHostMemoryBuffer memBuffer{.ptr: memrefBufferBasePtr, |
174 | .size: static_cast<uint32_t>(memrefBufferSize)}; |
175 | manager->setResourceData(setIndex, bindIndex, memBuffer); |
176 | ++bindIndex; |
177 | } |
178 | |
179 | manager->setNumWorkGroups(NumWorkGroups{.x: static_cast<uint32_t>(gridX), |
180 | .y: static_cast<uint32_t>(gridY), |
181 | .z: static_cast<uint32_t>(gridZ)}); |
182 | |
183 | auto function = static_cast<VulkanFunction *>(vkKernel); |
184 | // Expected size should be in bytes. |
185 | manager->setShaderModule( |
186 | shader: function->module->blobData(), |
187 | size: static_cast<uint32_t>(function->module->blobSizeInBytes())); |
188 | manager->setEntryPoint(function->name.c_str()); |
189 | |
190 | manager->runOnVulkan(); |
191 | } |
192 | |
193 | //===----------------------------------------------------------------------===// |
194 | // |
195 | // Miscellaneous utility functions that can be directly used by tests. |
196 | // |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | /// Fills the given 1D float memref with the given float value. |
200 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
201 | _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT |
202 | float value) { |
203 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
204 | } |
205 | |
206 | /// Fills the given 2D float memref with the given float value. |
207 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
208 | _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT |
209 | float value) { |
210 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
211 | } |
212 | |
213 | /// Fills the given 3D float memref with the given float value. |
214 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
215 | _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT |
216 | float value) { |
217 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
218 | value: value); |
219 | } |
220 | |
221 | /// Fills the given 1D int memref with the given int value. |
222 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
223 | _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT |
224 | int32_t value) { |
225 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
226 | } |
227 | |
228 | /// Fills the given 2D int memref with the given int value. |
229 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
230 | _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT |
231 | int32_t value) { |
232 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
233 | } |
234 | |
235 | /// Fills the given 3D int memref with the given int value. |
236 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
237 | _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT |
238 | int32_t value) { |
239 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
240 | value: value); |
241 | } |
242 | |
243 | /// Fills the given 1D int memref with the given int8 value. |
244 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
245 | _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT |
246 | int8_t value) { |
247 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
248 | } |
249 | |
250 | /// Fills the given 2D int memref with the given int8 value. |
251 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
252 | _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT |
253 | int8_t value) { |
254 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
255 | } |
256 | |
257 | /// Fills the given 3D int memref with the given int8 value. |
258 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
259 | _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT |
260 | int8_t value) { |
261 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
262 | value: value); |
263 | } |
264 | } |
265 | |