//===- VulkanRuntime.h - MLIR Vulkan runtime --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares Vulkan runtime API.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef VULKAN_RUNTIME_H
14#define VULKAN_RUNTIME_H
15
16#include "mlir/Support/LogicalResult.h"
17
18#include <unordered_map>
19#include <vector>
20#include <vulkan/vulkan.h>
21
22using namespace mlir;
23
/// Index identifying a descriptor set.
using DescriptorSetIndex = uint32_t;
/// Index identifying a binding inside a descriptor set.
using BindingIndex = uint32_t;
26
27/// Struct containing information regarding to a device memory buffer.
28struct VulkanDeviceMemoryBuffer {
29 BindingIndex bindingIndex{0};
30 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
31 VkDescriptorBufferInfo bufferInfo{};
32 VkBuffer hostBuffer{VK_NULL_HANDLE};
33 VkDeviceMemory hostMemory{VK_NULL_HANDLE};
34 VkBuffer deviceBuffer{VK_NULL_HANDLE};
35 VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
36 uint32_t bufferSize{0};
37};
38
/// Describes a region of host memory: where it starts and how large it is.
struct VulkanHostMemoryBuffer {
  /// Start address of the host memory region; null until set.
  void *ptr = nullptr;
  /// Size of the host memory region in bytes.
  uint32_t size = 0;
};
46
/// Number of local workgroups to dispatch along each of the three dimensions.
/// Every dimension defaults to 1.
struct NumWorkGroups {
  uint32_t x = 1;
  uint32_t y = 1;
  uint32_t z = 1;
};
54
55/// Struct containing information regarding a descriptor set.
56struct DescriptorSetInfo {
57 /// Index of a descriptor set in descriptor sets.
58 DescriptorSetIndex descriptorSet{0};
59 /// Number of descriptors in a set.
60 uint32_t descriptorSize{0};
61 /// Type of a descriptor set.
62 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
63};
64
65/// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
66using ResourceData = std::unordered_map<
67 DescriptorSetIndex,
68 std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>;
69
/// SPIR-V storage classes understood by the runtime.
/// This intentionally duplicates spirv::StorageClass so that the Vulkan
/// runtime library stays detached from the SPIR-V dialect and does not pick
/// up its dependencies.
enum class SPIRVStorageClass {
  /// Uniform storage class.
  Uniform = 2,
  /// StorageBuffer storage class.
  StorageBuffer = 12,
};
78
79/// StorageClass mapped into a descriptor set and a binding.
80using ResourceStorageClassBindingMap =
81 std::unordered_map<DescriptorSetIndex,
82 std::unordered_map<BindingIndex, SPIRVStorageClass>>;
83
84/// Vulkan runtime.
85/// The purpose of this class is to run SPIR-V compute shader on Vulkan
86/// device.
87/// Before the run, user must provide and set resource data with descriptors,
88/// SPIR-V shader, number of work groups and entry point. After the creation of
89/// VulkanRuntime, special methods must be called in the following
90/// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
91/// each method in the sequence returns success or failure depends on the Vulkan
92/// result code.
93class VulkanRuntime {
94public:
95 explicit VulkanRuntime() = default;
96 VulkanRuntime(const VulkanRuntime &) = delete;
97 VulkanRuntime &operator=(const VulkanRuntime &) = delete;
98
99 /// Sets needed data for Vulkan runtime.
100 void setResourceData(const ResourceData &resData);
101 void setResourceData(const DescriptorSetIndex desIndex,
102 const BindingIndex bindIndex,
103 const VulkanHostMemoryBuffer &hostMemBuffer);
104 void setShaderModule(uint8_t *shader, uint32_t size);
105 void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
106 void setResourceStorageClassBindingMap(
107 const ResourceStorageClassBindingMap &stClassData);
108 void setEntryPoint(const char *entryPointName);
109
110 /// Runtime initialization.
111 LogicalResult initRuntime();
112
113 /// Runs runtime.
114 LogicalResult run();
115
116 /// Updates host memory buffers.
117 LogicalResult updateHostMemoryBuffers();
118
119 /// Destroys all created vulkan objects and resources.
120 LogicalResult destroy();
121
122private:
123 //===--------------------------------------------------------------------===//
124 // Pipeline creation methods.
125 //===--------------------------------------------------------------------===//
126
127 LogicalResult createInstance();
128 LogicalResult createDevice();
129 LogicalResult getBestComputeQueue();
130 LogicalResult createMemoryBuffers();
131 LogicalResult createShaderModule();
132 void initDescriptorSetLayoutBindingMap();
133 LogicalResult createDescriptorSetLayout();
134 LogicalResult createPipelineLayout();
135 LogicalResult createComputePipeline();
136 LogicalResult createDescriptorPool();
137 LogicalResult allocateDescriptorSets();
138 LogicalResult setWriteDescriptors();
139 LogicalResult createCommandPool();
140 LogicalResult createQueryPool();
141 LogicalResult createComputeCommandBuffer();
142 LogicalResult submitCommandBuffersToQueue();
143 // Copy resources from host (staging buffer) to device buffer or from device
144 // buffer to host buffer.
145 LogicalResult copyResource(bool deviceToHost);
146
147 //===--------------------------------------------------------------------===//
148 // Helper methods.
149 //===--------------------------------------------------------------------===//
150
151 /// Maps storage class to a descriptor type.
152 LogicalResult
153 mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,
154 VkDescriptorType &descriptorType);
155
156 /// Maps storage class to buffer usage flags.
157 LogicalResult
158 mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,
159 VkBufferUsageFlagBits &bufferUsage);
160
161 LogicalResult countDeviceMemorySize();
162
163 //===--------------------------------------------------------------------===//
164 // Vulkan objects.
165 //===--------------------------------------------------------------------===//
166
167 VkInstance instance{VK_NULL_HANDLE};
168 VkPhysicalDevice physicalDevice{VK_NULL_HANDLE};
169 VkDevice device{VK_NULL_HANDLE};
170 VkQueue queue{VK_NULL_HANDLE};
171
172 /// Specifies VulkanDeviceMemoryBuffers divided into sets.
173 std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>>
174 deviceMemoryBufferMap;
175
176 /// Specifies shader module.
177 VkShaderModule shaderModule{VK_NULL_HANDLE};
178
179 /// Specifies layout bindings.
180 std::unordered_map<DescriptorSetIndex,
181 std::vector<VkDescriptorSetLayoutBinding>>
182 descriptorSetLayoutBindingMap;
183
184 /// Specifies layouts of descriptor sets.
185 std::vector<VkDescriptorSetLayout> descriptorSetLayouts;
186 VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
187
188 /// Specifies descriptor sets.
189 std::vector<VkDescriptorSet> descriptorSets;
190
191 /// Specifies a pool of descriptor set info, each descriptor set must have
192 /// information such as type, index and amount of bindings.
193 std::vector<DescriptorSetInfo> descriptorSetInfoPool;
194 VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
195
196 /// Timestamp query.
197 VkQueryPool queryPool{VK_NULL_HANDLE};
198 // Number of nonoseconds for timestamp to increase 1
199 float timestampPeriod{0.f};
200
201 /// Computation pipeline.
202 VkPipeline pipeline{VK_NULL_HANDLE};
203 VkCommandPool commandPool{VK_NULL_HANDLE};
204 std::vector<VkCommandBuffer> commandBuffers;
205
206 //===--------------------------------------------------------------------===//
207 // Vulkan memory context.
208 //===--------------------------------------------------------------------===//
209
210 uint32_t queueFamilyIndex{0};
211 VkQueueFamilyProperties queueFamilyProperties{};
212 uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
213 uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
214 VkDeviceSize memorySize{0};
215
216 //===--------------------------------------------------------------------===//
217 // Vulkan execution context.
218 //===--------------------------------------------------------------------===//
219
220 NumWorkGroups numWorkGroups;
221 const char *entryPoint{nullptr};
222 uint8_t *binary{nullptr};
223 uint32_t binarySize{0};
224
225 //===--------------------------------------------------------------------===//
226 // Vulkan resource data and storage classes.
227 //===--------------------------------------------------------------------===//
228
229 ResourceData resourceData;
230 ResourceStorageClassBindingMap resourceStorageClassData;
231};
232#endif
233

// Source: mlir/tools/mlir-vulkan-runner/VulkanRuntime.h