1//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements C runtime wrappers around the VulkanRuntime.
10//
11//===----------------------------------------------------------------------===//
12
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <mutex>
#include <numeric>

#include "VulkanRuntime.h"
18
// Export annotation applied to every entry point of the
// vulkan-runtime-wrapper: `dllexport` when building for Windows, default
// symbol visibility otherwise.
#if defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else // !defined(_WIN32)
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // defined(_WIN32)
26
27namespace {
28
29class VulkanRuntimeManager {
30public:
31 VulkanRuntimeManager() = default;
32 VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
33 VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
34 ~VulkanRuntimeManager() = default;
35
36 void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
37 const VulkanHostMemoryBuffer &memBuffer) {
38 std::lock_guard<std::mutex> lock(mutex);
39 vulkanRuntime.setResourceData(desIndex: setIndex, bindIndex, hostMemBuffer: memBuffer);
40 }
41
42 void setEntryPoint(const char *entryPoint) {
43 std::lock_guard<std::mutex> lock(mutex);
44 vulkanRuntime.setEntryPoint(entryPoint);
45 }
46
47 void setNumWorkGroups(NumWorkGroups numWorkGroups) {
48 std::lock_guard<std::mutex> lock(mutex);
49 vulkanRuntime.setNumWorkGroups(numWorkGroups);
50 }
51
52 void setShaderModule(uint8_t *shader, uint32_t size) {
53 std::lock_guard<std::mutex> lock(mutex);
54 vulkanRuntime.setShaderModule(shader, size);
55 }
56
57 void runOnVulkan() {
58 std::lock_guard<std::mutex> lock(mutex);
59 if (failed(result: vulkanRuntime.initRuntime()) || failed(result: vulkanRuntime.run()) ||
60 failed(result: vulkanRuntime.updateHostMemoryBuffers()) ||
61 failed(result: vulkanRuntime.destroy())) {
62 std::cerr << "runOnVulkan failed";
63 }
64 }
65
66private:
67 VulkanRuntime vulkanRuntime;
68 std::mutex mutex;
69};
70
71} // namespace
72
/// Plain aggregate describing a rank-`N` memref whose elements have type `T`.
///
/// NOTE(review): presumably mirrors the descriptor layout MLIR passes for
/// ranked memrefs, so the member order and types must not change; confirm
/// against the lowering that produces the callers.
template <typename T, int N> struct MemRefDescriptor {
  // Two base pointers into the element storage.
  T *allocated;
  T *aligned;
  // Element offset, followed by per-dimension extents and strides.
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};
81
82template <typename T, uint32_t S>
83void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
84 BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
85 uint32_t size = sizeof(T);
86 for (unsigned i = 0; i < S; i++)
87 size *= ptr->sizes[i];
88 VulkanHostMemoryBuffer memBuffer{ptr->aligned, size};
89 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
90 ->setResourceData(setIndex, bindIndex, memBuffer);
91}
92
93extern "C" {
94/// Initializes `VulkanRuntimeManager` and returns a pointer to it.
95VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
96 return new VulkanRuntimeManager();
97}
98
99/// Deinitializes `VulkanRuntimeManager` by the given pointer.
100VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) {
101 delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
102}
103
104VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) {
105 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
106}
107
108VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager,
109 const char *entryPoint) {
110 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
111 ->setEntryPoint(entryPoint);
112}
113
114VULKAN_WRAPPER_SYMBOL_EXPORT void
115setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) {
116 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
117 ->setNumWorkGroups({.x: x, .y: y, .z: z});
118}
119
120VULKAN_WRAPPER_SYMBOL_EXPORT void
121setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
122 reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
123 ->setShaderModule(shader, size);
124}
125
126/// Binds the given memref to the given descriptor set and descriptor
127/// index.
128#define DECLARE_BIND_MEMREF(size, type, typeName) \
129 VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \
130 void *vkRuntimeManager, DescriptorSetIndex setIndex, \
131 BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \
132 bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \
133 }
134
135DECLARE_BIND_MEMREF(1, float, Float)
136DECLARE_BIND_MEMREF(2, float, Float)
137DECLARE_BIND_MEMREF(3, float, Float)
138DECLARE_BIND_MEMREF(1, int32_t, Int32)
139DECLARE_BIND_MEMREF(2, int32_t, Int32)
140DECLARE_BIND_MEMREF(3, int32_t, Int32)
141DECLARE_BIND_MEMREF(1, int16_t, Int16)
142DECLARE_BIND_MEMREF(2, int16_t, Int16)
143DECLARE_BIND_MEMREF(3, int16_t, Int16)
144DECLARE_BIND_MEMREF(1, int8_t, Int8)
145DECLARE_BIND_MEMREF(2, int8_t, Int8)
146DECLARE_BIND_MEMREF(3, int8_t, Int8)
147DECLARE_BIND_MEMREF(1, int16_t, Half)
148DECLARE_BIND_MEMREF(2, int16_t, Half)
149DECLARE_BIND_MEMREF(3, int16_t, Half)
150
151/// Fills the given 1D float memref with the given float value.
152VULKAN_WRAPPER_SYMBOL_EXPORT void
153_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
154 float value) {
155 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
156}
157
158/// Fills the given 2D float memref with the given float value.
159VULKAN_WRAPPER_SYMBOL_EXPORT void
160_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
161 float value) {
162 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
163}
164
165/// Fills the given 3D float memref with the given float value.
166VULKAN_WRAPPER_SYMBOL_EXPORT void
167_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
168 float value) {
169 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
170 value: value);
171}
172
173/// Fills the given 1D int memref with the given int value.
174VULKAN_WRAPPER_SYMBOL_EXPORT void
175_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
176 int32_t value) {
177 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
178}
179
180/// Fills the given 2D int memref with the given int value.
181VULKAN_WRAPPER_SYMBOL_EXPORT void
182_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
183 int32_t value) {
184 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
185}
186
187/// Fills the given 3D int memref with the given int value.
188VULKAN_WRAPPER_SYMBOL_EXPORT void
189_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
190 int32_t value) {
191 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
192 value: value);
193}
194
195/// Fills the given 1D int memref with the given int8 value.
196VULKAN_WRAPPER_SYMBOL_EXPORT void
197_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
198 int8_t value) {
199 std::fill_n(first: ptr->allocated, n: ptr->sizes[0], value: value);
200}
201
202/// Fills the given 2D int memref with the given int8 value.
203VULKAN_WRAPPER_SYMBOL_EXPORT void
204_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
205 int8_t value) {
206 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1], value: value);
207}
208
209/// Fills the given 3D int memref with the given int8 value.
210VULKAN_WRAPPER_SYMBOL_EXPORT void
211_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
212 int8_t value) {
213 std::fill_n(first: ptr->allocated, n: ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
214 value: value);
215}
216}
217

// Source: mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp