| 1 | //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Implements C runtime wrappers around the VulkanRuntime. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include <iostream> |
| 14 | #include <mutex> |
| 15 | #include <numeric> |
| 16 | #include <string> |
| 17 | #include <vector> |
| 18 | |
| 19 | #include "VulkanRuntime.h" |
| 20 | |
// Marks a symbol as part of the public interface of the
// vulkan-runtime-wrapper library so that it stays visible to the runner:
// dllexport on Windows, default symbol visibility (GCC/Clang attribute)
// everywhere else.
#if defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // defined(_WIN32)
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | class VulkanModule; |
| 32 | |
| 33 | // Class to be a thing that can be returned from `mgpuModuleGetFunction`. |
| 34 | struct VulkanFunction { |
| 35 | VulkanModule *module; |
| 36 | std::string name; |
| 37 | |
| 38 | VulkanFunction(VulkanModule *module, const char *name) |
| 39 | : module(module), name(name) {} |
| 40 | }; |
| 41 | |
| 42 | // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage |
| 43 | // allocation of pointers returned from `mgpuModuleGetFunction`. |
| 44 | class VulkanModule { |
| 45 | public: |
| 46 | VulkanModule(const uint8_t *ptr, size_t sizeInBytes) |
| 47 | : blob(ptr, ptr + sizeInBytes) {} |
| 48 | ~VulkanModule() = default; |
| 49 | |
| 50 | VulkanFunction *getFunction(const char *name) { |
| 51 | return functions.emplace_back(args: std::make_unique<VulkanFunction>(args: this, args&: name)) |
| 52 | .get(); |
| 53 | } |
| 54 | |
| 55 | uint8_t *blobData() { return blob.data(); } |
| 56 | size_t blobSizeInBytes() const { return blob.size(); } |
| 57 | |
| 58 | private: |
| 59 | std::vector<uint8_t> blob; |
| 60 | std::vector<std::unique_ptr<VulkanFunction>> functions; |
| 61 | }; |
| 62 | |
| 63 | class VulkanRuntimeManager { |
| 64 | public: |
| 65 | VulkanRuntimeManager() = default; |
| 66 | VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; |
| 67 | VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; |
| 68 | ~VulkanRuntimeManager() = default; |
| 69 | |
| 70 | void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, |
| 71 | const VulkanHostMemoryBuffer &memBuffer) { |
| 72 | std::lock_guard<std::mutex> lock(mutex); |
| 73 | vulkanRuntime.setResourceData(desIndex: setIndex, bindIndex, hostMemBuffer: memBuffer); |
| 74 | } |
| 75 | |
| 76 | void setEntryPoint(const char *entryPoint) { |
| 77 | std::lock_guard<std::mutex> lock(mutex); |
| 78 | vulkanRuntime.setEntryPoint(entryPoint); |
| 79 | } |
| 80 | |
| 81 | void setNumWorkGroups(NumWorkGroups numWorkGroups) { |
| 82 | std::lock_guard<std::mutex> lock(mutex); |
| 83 | vulkanRuntime.setNumWorkGroups(numWorkGroups); |
| 84 | } |
| 85 | |
| 86 | void setShaderModule(uint8_t *shader, uint32_t size) { |
| 87 | std::lock_guard<std::mutex> lock(mutex); |
| 88 | vulkanRuntime.setShaderModule(shader, size); |
| 89 | } |
| 90 | |
| 91 | void runOnVulkan() { |
| 92 | std::lock_guard<std::mutex> lock(mutex); |
| 93 | if (failed(Result: vulkanRuntime.initRuntime()) || failed(Result: vulkanRuntime.run()) || |
| 94 | failed(Result: vulkanRuntime.updateHostMemoryBuffers()) || |
| 95 | failed(Result: vulkanRuntime.destroy())) { |
| 96 | std::cerr << "runOnVulkan failed" ; |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | private: |
| 101 | VulkanRuntime vulkanRuntime; |
| 102 | std::mutex mutex; |
| 103 | }; |
| 104 | |
| 105 | } // namespace |
| 106 | |
/// Host-side mirror of a ranked MLIR memref descriptor with element type
/// `ElementT` and rank `Rank`, as received by the `_mlir_ciface_*` helpers
/// in this file.
///
/// The field order and types form the binary layout shared with
/// MLIR-generated code, so fields must not be reordered or retyped.
template <typename ElementT, int Rank>
struct MemRefDescriptor {
  ElementT *allocated;   // Base pointer of the underlying allocation.
  ElementT *aligned;     // Aligned pointer through which data is accessed.
  int64_t offset;        // Offset, in elements, from `aligned`.
  int64_t sizes[Rank];   // Extent of each dimension.
  int64_t strides[Rank]; // Stride of each dimension, in elements.
};
| 115 | |
| 116 | extern "C" { |
| 117 | |
| 118 | //===----------------------------------------------------------------------===// |
| 119 | // |
| 120 | // Wrappers intended for mlir-runner. Uses of GPU dialect operations get |
| 121 | // lowered to calls to these functions by GPUToLLVMConversionPass. |
| 122 | // |
| 123 | //===----------------------------------------------------------------------===// |
| 124 | |
| 125 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() { |
| 126 | return new VulkanRuntimeManager(); |
| 127 | } |
| 128 | |
| 129 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) { |
| 130 | delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager); |
| 131 | } |
| 132 | |
| 133 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) { |
| 134 | // Currently a no-op as the other operations are synchronous. |
| 135 | } |
| 136 | |
| 137 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data, |
| 138 | size_t gpuBlobSize) { |
| 139 | // gpuBlobSize is the size of the data in bytes. |
| 140 | return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize); |
| 141 | } |
| 142 | |
| 143 | VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) { |
| 144 | delete static_cast<VulkanModule *>(vkModule); |
| 145 | } |
| 146 | |
| 147 | VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule, |
| 148 | const char *name) { |
| 149 | if (!vkModule) |
| 150 | abort(); |
| 151 | return static_cast<VulkanModule *>(vkModule)->getFunction(name); |
| 152 | } |
| 153 | |
| 154 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 155 | mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, |
| 156 | size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, |
| 157 | size_t /*smem*/, void *vkRuntimeManager, void **params, |
| 158 | void ** /*extra*/, size_t paramsCount) { |
| 159 | auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); |
| 160 | |
| 161 | // GpuToLLVMConversionPass with the kernelBarePtrCallConv and |
| 162 | // kernelIntersperseSizeCallConv options will set up the params array like: |
| 163 | // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... } |
| 164 | const size_t paramsPerMemRef = 2; |
| 165 | if (paramsCount % paramsPerMemRef != 0) { |
| 166 | abort(); // This would indicate a serious calling convention mismatch. |
| 167 | } |
| 168 | const DescriptorSetIndex setIndex = 0; |
| 169 | BindingIndex bindIndex = 0; |
| 170 | for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) { |
| 171 | void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]); |
| 172 | size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]); |
| 173 | VulkanHostMemoryBuffer memBuffer{.ptr: memrefBufferBasePtr, |
| 174 | .size: static_cast<uint32_t>(memrefBufferSize)}; |
| 175 | manager->setResourceData(setIndex, bindIndex, memBuffer); |
| 176 | ++bindIndex; |
| 177 | } |
| 178 | |
| 179 | manager->setNumWorkGroups(NumWorkGroups{.x: static_cast<uint32_t>(gridX), |
| 180 | .y: static_cast<uint32_t>(gridY), |
| 181 | .z: static_cast<uint32_t>(gridZ)}); |
| 182 | |
| 183 | auto function = static_cast<VulkanFunction *>(vkKernel); |
| 184 | // Expected size should be in bytes. |
| 185 | manager->setShaderModule( |
| 186 | shader: function->module->blobData(), |
| 187 | size: static_cast<uint32_t>(function->module->blobSizeInBytes())); |
| 188 | manager->setEntryPoint(function->name.c_str()); |
| 189 | |
| 190 | manager->runOnVulkan(); |
| 191 | } |
| 192 | |
| 193 | //===----------------------------------------------------------------------===// |
| 194 | // |
| 195 | // Miscellaneous utility functions that can be directly used by tests. |
| 196 | // |
| 197 | //===----------------------------------------------------------------------===// |
| 198 | |
| 199 | /// Fills the given 1D float memref with the given float value. |
| 200 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 201 | _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT |
| 202 | float value) { |
| 203 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
| 204 | } |
| 205 | |
| 206 | /// Fills the given 2D float memref with the given float value. |
| 207 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 208 | _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT |
| 209 | float value) { |
| 210 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
| 211 | } |
| 212 | |
| 213 | /// Fills the given 3D float memref with the given float value. |
| 214 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 215 | _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT |
| 216 | float value) { |
| 217 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
| 218 | value: value); |
| 219 | } |
| 220 | |
| 221 | /// Fills the given 1D int memref with the given int value. |
| 222 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 223 | _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT |
| 224 | int32_t value) { |
| 225 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
| 226 | } |
| 227 | |
| 228 | /// Fills the given 2D int memref with the given int value. |
| 229 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 230 | _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT |
| 231 | int32_t value) { |
| 232 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
| 233 | } |
| 234 | |
| 235 | /// Fills the given 3D int memref with the given int value. |
| 236 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 237 | _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT |
| 238 | int32_t value) { |
| 239 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
| 240 | value: value); |
| 241 | } |
| 242 | |
| 243 | /// Fills the given 1D int memref with the given int8 value. |
| 244 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 245 | _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT |
| 246 | int8_t value) { |
| 247 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value); |
| 248 | } |
| 249 | |
| 250 | /// Fills the given 2D int memref with the given int8 value. |
| 251 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 252 | _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT |
| 253 | int8_t value) { |
| 254 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value); |
| 255 | } |
| 256 | |
| 257 | /// Fills the given 3D int memref with the given int8 value. |
| 258 | VULKAN_WRAPPER_SYMBOL_EXPORT void |
| 259 | _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT |
| 260 | int8_t value) { |
| 261 | std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], |
| 262 | value: value); |
| 263 | } |
| 264 | } |
| 265 | |