1 | //===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implements C runtime wrappers around the VulkanRuntime. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <numeric>

#include "VulkanRuntime.h"
18 | |
// Annotation placed on every C entry point of the vulkan-runtime-wrappers
// library so the symbol is exported from the shared library:
// `__declspec(dllexport)` on Windows, default symbol visibility elsewhere.
#if defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // defined(_WIN32)
26 | |
27 | namespace { |
28 | |
29 | class VulkanRuntimeManager { |
30 | public: |
31 | VulkanRuntimeManager() = default; |
32 | VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; |
33 | VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; |
34 | ~VulkanRuntimeManager() = default; |
35 | |
36 | void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, |
37 | const VulkanHostMemoryBuffer &memBuffer) { |
38 | std::lock_guard<std::mutex> lock(mutex); |
39 | vulkanRuntime.setResourceData(desIndex: setIndex, bindIndex, hostMemBuffer: memBuffer); |
40 | } |
41 | |
42 | void setEntryPoint(const char *entryPoint) { |
43 | std::lock_guard<std::mutex> lock(mutex); |
44 | vulkanRuntime.setEntryPoint(entryPoint); |
45 | } |
46 | |
47 | void setNumWorkGroups(NumWorkGroups numWorkGroups) { |
48 | std::lock_guard<std::mutex> lock(mutex); |
49 | vulkanRuntime.setNumWorkGroups(numWorkGroups); |
50 | } |
51 | |
52 | void setShaderModule(uint8_t *shader, uint32_t size) { |
53 | std::lock_guard<std::mutex> lock(mutex); |
54 | vulkanRuntime.setShaderModule(shader, size); |
55 | } |
56 | |
57 | void runOnVulkan() { |
58 | std::lock_guard<std::mutex> lock(mutex); |
59 | if (failed(result: vulkanRuntime.initRuntime()) || failed(result: vulkanRuntime.run()) || |
60 | failed(result: vulkanRuntime.updateHostMemoryBuffers()) || |
61 | failed(result: vulkanRuntime.destroy())) { |
62 | std::cerr << "runOnVulkan failed" ; |
63 | } |
64 | } |
65 | |
66 | private: |
67 | VulkanRuntime vulkanRuntime; |
68 | std::mutex mutex; |
69 | }; |
70 | |
71 | } // namespace |
72 | |
/// Host-side view of a ranked memref with `Rank` dimensions and elements of
/// type `ElementType`, as received through the `_mlir_ciface_*` / bindMemRef
/// entry points. Field meanings below follow MLIR's memref descriptor
/// convention (presumably matching the generated code's layout, so the member
/// order must not change).
template <typename ElementType, int Rank>
struct MemRefDescriptor {
  // Base pointer of the underlying allocation.
  ElementType *allocated;
  // Aligned pointer from which elements are addressed.
  ElementType *aligned;
  // Offset (in elements) of the first element relative to `aligned`.
  int64_t offset;
  // Extent of each dimension.
  int64_t sizes[Rank];
  // Stride (in elements) of each dimension.
  int64_t strides[Rank];
};
81 | |
82 | template <typename T, uint32_t S> |
83 | void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex, |
84 | BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) { |
85 | uint32_t size = sizeof(T); |
86 | for (unsigned i = 0; i < S; i++) |
87 | size *= ptr->sizes[i]; |
88 | VulkanHostMemoryBuffer memBuffer{ptr->aligned, size}; |
89 | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) |
90 | ->setResourceData(setIndex, bindIndex, memBuffer); |
91 | } |
92 | |
93 | extern "C" { |
94 | /// Initializes `VulkanRuntimeManager` and returns a pointer to it. |
95 | VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() { |
96 | return new VulkanRuntimeManager(); |
97 | } |
98 | |
99 | /// Deinitializes `VulkanRuntimeManager` by the given pointer. |
100 | VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) { |
101 | delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager); |
102 | } |
103 | |
104 | VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) { |
105 | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan(); |
106 | } |
107 | |
108 | VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager, |
109 | const char *entryPoint) { |
110 | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) |
111 | ->setEntryPoint(entryPoint); |
112 | } |
113 | |
114 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
115 | setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) { |
116 | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) |
117 | ->setNumWorkGroups({.x: x, .y: y, .z: z}); |
118 | } |
119 | |
120 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
121 | setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) { |
122 | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) |
123 | ->setShaderModule(shader, size); |
124 | } |
125 | |
126 | /// Binds the given memref to the given descriptor set and descriptor |
127 | /// index. |
128 | #define DECLARE_BIND_MEMREF(size, type, typeName) \ |
129 | VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \ |
130 | void *vkRuntimeManager, DescriptorSetIndex setIndex, \ |
131 | BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \ |
132 | bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \ |
133 | } |
134 | |
135 | DECLARE_BIND_MEMREF(1, float, Float) |
136 | DECLARE_BIND_MEMREF(2, float, Float) |
137 | DECLARE_BIND_MEMREF(3, float, Float) |
138 | DECLARE_BIND_MEMREF(1, int32_t, Int32) |
139 | DECLARE_BIND_MEMREF(2, int32_t, Int32) |
140 | DECLARE_BIND_MEMREF(3, int32_t, Int32) |
141 | DECLARE_BIND_MEMREF(1, int16_t, Int16) |
142 | DECLARE_BIND_MEMREF(2, int16_t, Int16) |
143 | DECLARE_BIND_MEMREF(3, int16_t, Int16) |
144 | DECLARE_BIND_MEMREF(1, int8_t, Int8) |
145 | DECLARE_BIND_MEMREF(2, int8_t, Int8) |
146 | DECLARE_BIND_MEMREF(3, int8_t, Int8) |
147 | DECLARE_BIND_MEMREF(1, int16_t, Half) |
148 | DECLARE_BIND_MEMREF(2, int16_t, Half) |
149 | DECLARE_BIND_MEMREF(3, int16_t, Half) |
150 | |
151 | /// Fills the given 1D float memref with the given float value. |
152 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
153 | _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT |
154 | float value) { |
155 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
156 | } |
157 | |
158 | /// Fills the given 2D float memref with the given float value. |
159 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
160 | _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT |
161 | float value) { |
162 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
163 | } |
164 | |
165 | /// Fills the given 3D float memref with the given float value. |
166 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
167 | _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT |
168 | float value) { |
169 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
170 | value: value); |
171 | } |
172 | |
173 | /// Fills the given 1D int memref with the given int value. |
174 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
175 | _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT |
176 | int32_t value) { |
177 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
178 | } |
179 | |
180 | /// Fills the given 2D int memref with the given int value. |
181 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
182 | _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT |
183 | int32_t value) { |
184 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
185 | } |
186 | |
187 | /// Fills the given 3D int memref with the given int value. |
188 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
189 | _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT |
190 | int32_t value) { |
191 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
192 | value: value); |
193 | } |
194 | |
195 | /// Fills the given 1D int memref with the given int8 value. |
196 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
197 | _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT |
198 | int8_t value) { |
199 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
200 | } |
201 | |
202 | /// Fills the given 2D int memref with the given int8 value. |
203 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
204 | _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT |
205 | int8_t value) { |
206 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
207 | } |
208 | |
209 | /// Fills the given 3D int memref with the given int8 value. |
210 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
211 | _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT |
212 | int8_t value) { |
213 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
214 | value: value); |
215 | } |
216 | } |
217 | |