//===- VulkanRuntime.h - MLIR Vulkan runtime --------------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file declares the Vulkan runtime API.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef VULKAN_RUNTIME_H |
14 | #define VULKAN_RUNTIME_H |
15 | |
16 | #include "mlir/Support/LogicalResult.h" |
17 | |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | #include <vulkan/vulkan.h> |
21 | |
22 | using namespace mlir; |
23 | |
/// Index identifying a descriptor set.
using DescriptorSetIndex = uint32_t;
/// Index identifying a binding inside a descriptor set.
using BindingIndex = uint32_t;
26 | |
27 | /// Struct containing information regarding to a device memory buffer. |
28 | struct VulkanDeviceMemoryBuffer { |
29 | BindingIndex bindingIndex{0}; |
30 | VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM}; |
31 | VkDescriptorBufferInfo bufferInfo{}; |
32 | VkBuffer hostBuffer{VK_NULL_HANDLE}; |
33 | VkDeviceMemory hostMemory{VK_NULL_HANDLE}; |
34 | VkBuffer deviceBuffer{VK_NULL_HANDLE}; |
35 | VkDeviceMemory deviceMemory{VK_NULL_HANDLE}; |
36 | uint32_t bufferSize{0}; |
37 | }; |
38 | |
/// A region of host memory: a start pointer together with its byte size.
/// Both fields default to "empty" (null pointer, zero size).
struct VulkanHostMemoryBuffer {
  /// Start of the host memory region.
  void *ptr = nullptr;
  /// Length of the host memory region, in bytes.
  uint32_t size = 0;
};
46 | |
/// Number of local workgroups to dispatch along each of the three dimensions.
/// Every dimension defaults to 1.
struct NumWorkGroups {
  uint32_t x = 1;
  uint32_t y = 1;
  uint32_t z = 1;
};
54 | |
55 | /// Struct containing information regarding a descriptor set. |
56 | struct DescriptorSetInfo { |
57 | /// Index of a descriptor set in descriptor sets. |
58 | DescriptorSetIndex descriptorSet{0}; |
59 | /// Number of descriptors in a set. |
60 | uint32_t descriptorSize{0}; |
61 | /// Type of a descriptor set. |
62 | VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM}; |
63 | }; |
64 | |
65 | /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding. |
66 | using ResourceData = std::unordered_map< |
67 | DescriptorSetIndex, |
68 | std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>; |
69 | |
/// The SPIR-V storage classes the Vulkan runtime deals with.
///
/// This deliberately mirrors the relevant part of spirv::StorageClass. It does
/// not reuse that type so that the runtime library stays independent of the
/// SPIR-V dialect and avoids pulling in its many dependencies. The enumerator
/// values are the SPIR-V StorageClass encodings.
enum class SPIRVStorageClass {
  Uniform = 2,
  StorageBuffer = 12,
};
78 | |
79 | /// StorageClass mapped into a descriptor set and a binding. |
80 | using ResourceStorageClassBindingMap = |
81 | std::unordered_map<DescriptorSetIndex, |
82 | std::unordered_map<BindingIndex, SPIRVStorageClass>>; |
83 | |
84 | /// Vulkan runtime. |
85 | /// The purpose of this class is to run SPIR-V compute shader on Vulkan |
86 | /// device. |
87 | /// Before the run, user must provide and set resource data with descriptors, |
88 | /// SPIR-V shader, number of work groups and entry point. After the creation of |
89 | /// VulkanRuntime, special methods must be called in the following |
90 | /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy(); |
91 | /// each method in the sequence returns success or failure depends on the Vulkan |
92 | /// result code. |
93 | class VulkanRuntime { |
94 | public: |
95 | explicit VulkanRuntime() = default; |
96 | VulkanRuntime(const VulkanRuntime &) = delete; |
97 | VulkanRuntime &operator=(const VulkanRuntime &) = delete; |
98 | |
99 | /// Sets needed data for Vulkan runtime. |
100 | void setResourceData(const ResourceData &resData); |
101 | void setResourceData(const DescriptorSetIndex desIndex, |
102 | const BindingIndex bindIndex, |
103 | const VulkanHostMemoryBuffer &hostMemBuffer); |
104 | void setShaderModule(uint8_t *shader, uint32_t size); |
105 | void setNumWorkGroups(const NumWorkGroups &numberWorkGroups); |
106 | void setResourceStorageClassBindingMap( |
107 | const ResourceStorageClassBindingMap &stClassData); |
108 | void setEntryPoint(const char *entryPointName); |
109 | |
110 | /// Runtime initialization. |
111 | LogicalResult initRuntime(); |
112 | |
113 | /// Runs runtime. |
114 | LogicalResult run(); |
115 | |
116 | /// Updates host memory buffers. |
117 | LogicalResult updateHostMemoryBuffers(); |
118 | |
119 | /// Destroys all created vulkan objects and resources. |
120 | LogicalResult destroy(); |
121 | |
122 | private: |
123 | //===--------------------------------------------------------------------===// |
124 | // Pipeline creation methods. |
125 | //===--------------------------------------------------------------------===// |
126 | |
127 | LogicalResult createInstance(); |
128 | LogicalResult createDevice(); |
129 | LogicalResult getBestComputeQueue(); |
130 | LogicalResult createMemoryBuffers(); |
131 | LogicalResult createShaderModule(); |
132 | void initDescriptorSetLayoutBindingMap(); |
133 | LogicalResult createDescriptorSetLayout(); |
134 | LogicalResult createPipelineLayout(); |
135 | LogicalResult createComputePipeline(); |
136 | LogicalResult createDescriptorPool(); |
137 | LogicalResult allocateDescriptorSets(); |
138 | LogicalResult setWriteDescriptors(); |
139 | LogicalResult createCommandPool(); |
140 | LogicalResult createQueryPool(); |
141 | LogicalResult createComputeCommandBuffer(); |
142 | LogicalResult submitCommandBuffersToQueue(); |
143 | // Copy resources from host (staging buffer) to device buffer or from device |
144 | // buffer to host buffer. |
145 | LogicalResult copyResource(bool deviceToHost); |
146 | |
147 | //===--------------------------------------------------------------------===// |
148 | // Helper methods. |
149 | //===--------------------------------------------------------------------===// |
150 | |
151 | /// Maps storage class to a descriptor type. |
152 | LogicalResult |
153 | mapStorageClassToDescriptorType(SPIRVStorageClass storageClass, |
154 | VkDescriptorType &descriptorType); |
155 | |
156 | /// Maps storage class to buffer usage flags. |
157 | LogicalResult |
158 | mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass, |
159 | VkBufferUsageFlagBits &bufferUsage); |
160 | |
161 | LogicalResult countDeviceMemorySize(); |
162 | |
163 | //===--------------------------------------------------------------------===// |
164 | // Vulkan objects. |
165 | //===--------------------------------------------------------------------===// |
166 | |
167 | VkInstance instance{VK_NULL_HANDLE}; |
168 | VkPhysicalDevice physicalDevice{VK_NULL_HANDLE}; |
169 | VkDevice device{VK_NULL_HANDLE}; |
170 | VkQueue queue{VK_NULL_HANDLE}; |
171 | |
172 | /// Specifies VulkanDeviceMemoryBuffers divided into sets. |
173 | std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>> |
174 | deviceMemoryBufferMap; |
175 | |
176 | /// Specifies shader module. |
177 | VkShaderModule shaderModule{VK_NULL_HANDLE}; |
178 | |
179 | /// Specifies layout bindings. |
180 | std::unordered_map<DescriptorSetIndex, |
181 | std::vector<VkDescriptorSetLayoutBinding>> |
182 | descriptorSetLayoutBindingMap; |
183 | |
184 | /// Specifies layouts of descriptor sets. |
185 | std::vector<VkDescriptorSetLayout> descriptorSetLayouts; |
186 | VkPipelineLayout pipelineLayout{VK_NULL_HANDLE}; |
187 | |
188 | /// Specifies descriptor sets. |
189 | std::vector<VkDescriptorSet> descriptorSets; |
190 | |
191 | /// Specifies a pool of descriptor set info, each descriptor set must have |
192 | /// information such as type, index and amount of bindings. |
193 | std::vector<DescriptorSetInfo> descriptorSetInfoPool; |
194 | VkDescriptorPool descriptorPool{VK_NULL_HANDLE}; |
195 | |
196 | /// Timestamp query. |
197 | VkQueryPool queryPool{VK_NULL_HANDLE}; |
198 | // Number of nonoseconds for timestamp to increase 1 |
199 | float timestampPeriod{0.f}; |
200 | |
201 | /// Computation pipeline. |
202 | VkPipeline pipeline{VK_NULL_HANDLE}; |
203 | VkCommandPool commandPool{VK_NULL_HANDLE}; |
204 | std::vector<VkCommandBuffer> commandBuffers; |
205 | |
206 | //===--------------------------------------------------------------------===// |
207 | // Vulkan memory context. |
208 | //===--------------------------------------------------------------------===// |
209 | |
210 | uint32_t queueFamilyIndex{0}; |
211 | VkQueueFamilyProperties queueFamilyProperties{}; |
212 | uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES}; |
213 | uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES}; |
214 | VkDeviceSize memorySize{0}; |
215 | |
216 | //===--------------------------------------------------------------------===// |
217 | // Vulkan execution context. |
218 | //===--------------------------------------------------------------------===// |
219 | |
220 | NumWorkGroups numWorkGroups; |
221 | const char *entryPoint{nullptr}; |
222 | uint8_t *binary{nullptr}; |
223 | uint32_t binarySize{0}; |
224 | |
225 | //===--------------------------------------------------------------------===// |
226 | // Vulkan resource data and storage classes. |
227 | //===--------------------------------------------------------------------===// |
228 | |
229 | ResourceData resourceData; |
230 | ResourceStorageClassBindingMap resourceStorageClassData; |
231 | }; |
232 | #endif |
233 | |