//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

23inline void emitVulkanError(const char *api, VkResult error) {
24 std::cerr << " failed with error code " << error << " when executing " << api;
25}
26
27#define RETURN_ON_VULKAN_ERROR(result, api) \
28 if ((result) != VK_SUCCESS) { \
29 emitVulkanError(api, (result)); \
30 return failure(); \
31 }
32
33using namespace mlir;
34
35void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
36 numWorkGroups = numberWorkGroups;
37}
38
39void VulkanRuntime::setResourceStorageClassBindingMap(
40 const ResourceStorageClassBindingMap &stClassData) {
41 resourceStorageClassData = stClassData;
42}
43
44void VulkanRuntime::setResourceData(
45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
46 const VulkanHostMemoryBuffer &hostMemBuffer) {
47 resourceData[desIndex][bindIndex] = hostMemBuffer;
48 resourceStorageClassData[desIndex][bindIndex] =
49 SPIRVStorageClass::StorageBuffer;
50}
51
52void VulkanRuntime::setEntryPoint(const char *entryPointName) {
53 entryPoint = entryPointName;
54}
55
56void VulkanRuntime::setResourceData(const ResourceData &resData) {
57 resourceData = resData;
58}
59
60void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
61 binary = shader;
62 binarySize = size;
63}
64
65LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
67 switch (storageClass) {
68 case SPIRVStorageClass::StorageBuffer:
69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
70 break;
71 case SPIRVStorageClass::Uniform:
72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
73 break;
74 }
75 return success();
76}
77
78LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
80 switch (storageClass) {
81 case SPIRVStorageClass::StorageBuffer:
82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
83 break;
84 case SPIRVStorageClass::Uniform:
85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
86 break;
87 }
88 return success();
89}
90
91LogicalResult VulkanRuntime::countDeviceMemorySize() {
92 for (const auto &resourceDataMapPair : resourceData) {
93 const auto &resourceDataMap = resourceDataMapPair.second;
94 for (const auto &resourceDataBindingPair : resourceDataMap) {
95 if (resourceDataBindingPair.second.size) {
96 memorySize += resourceDataBindingPair.second.size;
97 } else {
98 std::cerr << "expected buffer size greater than zero for resource data";
99 return failure();
100 }
101 }
102 }
103 return success();
104}
105
106LogicalResult VulkanRuntime::initRuntime() {
107 if (resourceData.empty()) {
108 std::cerr << "Vulkan runtime needs at least one resource";
109 return failure();
110 }
111 if (!binarySize || !binary) {
112 std::cerr << "binary shader size must be greater than zero";
113 return failure();
114 }
115 if (failed(result: countDeviceMemorySize())) {
116 return failure();
117 }
118 return success();
119}
120
121LogicalResult VulkanRuntime::destroy() {
122 // According to Vulkan spec:
123 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
124 // used to gate the destruction of the device. Prior to destroying a device,
125 // an application is responsible for destroying/freeing any Vulkan objects
126 // that were created using that device as the first parameter of the
127 // corresponding vkCreate* or vkAllocate* command."
128 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
129
130 // Free and destroy.
131 vkFreeCommandBuffers(device, commandPool, commandBufferCount: commandBuffers.size(),
132 pCommandBuffers: commandBuffers.data());
133 vkDestroyQueryPool(device, queryPool, pAllocator: nullptr);
134 vkDestroyCommandPool(device, commandPool, pAllocator: nullptr);
135 vkFreeDescriptorSets(device, descriptorPool, descriptorSetCount: descriptorSets.size(),
136 pDescriptorSets: descriptorSets.data());
137 vkDestroyDescriptorPool(device, descriptorPool, pAllocator: nullptr);
138 vkDestroyPipeline(device, pipeline, pAllocator: nullptr);
139 vkDestroyPipelineLayout(device, pipelineLayout, pAllocator: nullptr);
140 for (auto &descriptorSetLayout : descriptorSetLayouts) {
141 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, pAllocator: nullptr);
142 }
143 vkDestroyShaderModule(device, shaderModule, pAllocator: nullptr);
144
145 // For each descriptor set.
146 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
147 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
148 // For each descriptor binding.
149 for (auto &memoryBuffer : deviceMemoryBuffers) {
150 vkFreeMemory(device, memory: memoryBuffer.deviceMemory, pAllocator: nullptr);
151 vkFreeMemory(device, memory: memoryBuffer.hostMemory, pAllocator: nullptr);
152 vkDestroyBuffer(device, buffer: memoryBuffer.hostBuffer, pAllocator: nullptr);
153 vkDestroyBuffer(device, buffer: memoryBuffer.deviceBuffer, pAllocator: nullptr);
154 }
155 }
156
157 vkDestroyDevice(device, pAllocator: nullptr);
158 vkDestroyInstance(instance, pAllocator: nullptr);
159 return success();
160}
161
162LogicalResult VulkanRuntime::run() {
163 // Create logical device, shader module and memory buffers.
164 if (failed(result: createInstance()) || failed(result: createDevice()) ||
165 failed(result: createMemoryBuffers()) || failed(result: createShaderModule())) {
166 return failure();
167 }
168
169 // Descriptor bindings divided into sets. Each descriptor binding
170 // must have a layout binding attached into a descriptor set layout.
171 // Each layout set must be binded into a pipeline layout.
172 initDescriptorSetLayoutBindingMap();
173 if (failed(result: createDescriptorSetLayout()) || failed(result: createPipelineLayout()) ||
174 // Each descriptor set must be allocated from a descriptor pool.
175 failed(result: createComputePipeline()) || failed(result: createDescriptorPool()) ||
176 failed(result: allocateDescriptorSets()) || failed(result: setWriteDescriptors()) ||
177 // Create command buffer.
178 failed(result: createCommandPool()) || failed(result: createQueryPool()) ||
179 failed(result: createComputeCommandBuffer())) {
180 return failure();
181 }
182
183 // Get working queue.
184 vkGetDeviceQueue(device, queueFamilyIndex, queueIndex: 0, pQueue: &queue);
185
186 if (failed(result: copyResource(/*deviceToHost=*/false)))
187 return failure();
188
189 auto submitStart = std::chrono::high_resolution_clock::now();
190 // Submit command buffer into the queue.
191 if (failed(result: submitCommandBuffersToQueue()))
192 return failure();
193 auto submitEnd = std::chrono::high_resolution_clock::now();
194
195 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
196 auto execEnd = std::chrono::high_resolution_clock::now();
197
198 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
199 d: submitEnd - submitStart);
200 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
201 d: execEnd - submitEnd);
202
203 if (queryPool != VK_NULL_HANDLE) {
204 uint64_t timestamps[2];
205 RETURN_ON_VULKAN_ERROR(
206 vkGetQueryPoolResults(
207 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
208 /*dataSize=*/sizeof(timestamps),
209 /*pData=*/reinterpret_cast<void *>(timestamps),
210 /*stride=*/sizeof(uint64_t),
211 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
212 "vkGetQueryPoolResults");
213 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
214 std::cout << "Compute shader execution time: " << std::setprecision(3)
215 << microsec << "us\n";
216 }
217
218 std::cout << "Command buffer submit time: " << submitDuration.count()
219 << "us\nWait idle time: " << execDuration.count() << "us\n";
220
221 return success();
222}
223
224LogicalResult VulkanRuntime::createInstance() {
225 VkApplicationInfo applicationInfo = {};
226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
227 applicationInfo.pNext = nullptr;
228 applicationInfo.pApplicationName = "MLIR Vulkan runtime";
229 applicationInfo.applicationVersion = 0;
230 applicationInfo.pEngineName = "mlir";
231 applicationInfo.engineVersion = 0;
232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
233
234 VkInstanceCreateInfo instanceCreateInfo = {};
235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
236 instanceCreateInfo.pNext = nullptr;
237 instanceCreateInfo.pApplicationInfo = &applicationInfo;
238 instanceCreateInfo.enabledLayerCount = 0;
239 instanceCreateInfo.ppEnabledLayerNames = nullptr;
240
241 std::vector<const char *> extNames;
242#if defined(__APPLE__)
243 // enumerate MoltenVK for Vulkan 1.0
244 instanceCreateInfo.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
245 // add KHR portability instance extensions
246 extNames.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
247 extNames.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
248#else
249 instanceCreateInfo.flags = 0;
250#endif // __APPLE__
251 instanceCreateInfo.enabledExtensionCount =
252 static_cast<uint32_t>(extNames.size());
253 instanceCreateInfo.ppEnabledExtensionNames = extNames.data();
254
255 RETURN_ON_VULKAN_ERROR(
256 vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
257 "vkCreateInstance");
258 return success();
259}
260
261LogicalResult VulkanRuntime::createDevice() {
262 uint32_t physicalDeviceCount = 0;
263 RETURN_ON_VULKAN_ERROR(
264 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
265 "vkEnumeratePhysicalDevices");
266
267 std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
268 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
269 &physicalDeviceCount,
270 physicalDevices.data()),
271 "vkEnumeratePhysicalDevices");
272
273 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
274 "physicalDeviceCount");
275
276 // TODO: find the best device.
277 physicalDevice = physicalDevices.front();
278 if (failed(result: getBestComputeQueue()))
279 return failure();
280
281 const float queuePriority = 1.0f;
282 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
283 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
284 deviceQueueCreateInfo.pNext = nullptr;
285 deviceQueueCreateInfo.flags = 0;
286 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
287 deviceQueueCreateInfo.queueCount = 1;
288 deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
289
290 // Structure specifying parameters of a newly created device.
291 VkDeviceCreateInfo deviceCreateInfo = {};
292 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
293 deviceCreateInfo.pNext = nullptr;
294 deviceCreateInfo.flags = 0;
295 deviceCreateInfo.queueCreateInfoCount = 1;
296 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
297 deviceCreateInfo.enabledLayerCount = 0;
298 deviceCreateInfo.ppEnabledLayerNames = nullptr;
299 deviceCreateInfo.enabledExtensionCount = 0;
300 deviceCreateInfo.ppEnabledExtensionNames = nullptr;
301 deviceCreateInfo.pEnabledFeatures = nullptr;
302
303 RETURN_ON_VULKAN_ERROR(
304 vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
305 "vkCreateDevice");
306
307 VkPhysicalDeviceMemoryProperties properties = {};
308 vkGetPhysicalDeviceMemoryProperties(physicalDevice, pMemoryProperties: &properties);
309
310 // Try to find memory type with following properties:
311 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
312 // with this type can be mapped for host access using vkMapMemory;
313 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
314 // management commands vkFlushMappedMemoryRanges and
315 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
316 // device or make device writes visible to the host, respectively.
317 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
318 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
319 properties.memoryTypes[i].propertyFlags) &&
320 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
321 properties.memoryTypes[i].propertyFlags) &&
322 (memorySize <=
323 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
324 hostMemoryTypeIndex = i;
325 break;
326 }
327 }
328
329 // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be
330 // used on the device. This will allow better performance access for GPU with
331 // on device memory.
332 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
333 if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
334 properties.memoryTypes[i].propertyFlags) &&
335 (memorySize <=
336 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
337 deviceMemoryTypeIndex = i;
338 break;
339 }
340 }
341
342 RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
343 deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
344 ? VK_INCOMPLETE
345 : VK_SUCCESS,
346 "invalid memoryTypeIndex");
347 return success();
348}
349
350LogicalResult VulkanRuntime::getBestComputeQueue() {
351 uint32_t queueFamilyPropertiesCount = 0;
352 vkGetPhysicalDeviceQueueFamilyProperties(
353 physicalDevice, pQueueFamilyPropertyCount: &queueFamilyPropertiesCount, pQueueFamilyProperties: nullptr);
354
355 std::vector<VkQueueFamilyProperties> familyProperties(
356 queueFamilyPropertiesCount);
357 vkGetPhysicalDeviceQueueFamilyProperties(
358 physicalDevice, pQueueFamilyPropertyCount: &queueFamilyPropertiesCount, pQueueFamilyProperties: familyProperties.data());
359
360 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
361 // compute operations. Try to find a compute-only queue first if possible.
362 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
363 auto flags = familyProperties[i].queueFlags;
364 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
365 queueFamilyIndex = i;
366 queueFamilyProperties = familyProperties[i];
367 return success();
368 }
369 }
370
371 // Otherwise use a queue that can also support graphics.
372 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
373 auto flags = familyProperties[i].queueFlags;
374 if ((flags & VK_QUEUE_COMPUTE_BIT)) {
375 queueFamilyIndex = i;
376 queueFamilyProperties = familyProperties[i];
377 return success();
378 }
379 }
380
381 std::cerr << "cannot find valid queue";
382 return failure();
383}
384
385LogicalResult VulkanRuntime::createMemoryBuffers() {
386 // For each descriptor set.
387 for (const auto &resourceDataMapPair : resourceData) {
388 std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
389 const auto descriptorSetIndex = resourceDataMapPair.first;
390 const auto &resourceDataMap = resourceDataMapPair.second;
391
392 // For each descriptor binding.
393 for (const auto &resourceDataBindingPair : resourceDataMap) {
394 // Create device memory buffer.
395 VulkanDeviceMemoryBuffer memoryBuffer;
396 memoryBuffer.bindingIndex = resourceDataBindingPair.first;
397 VkDescriptorType descriptorType = {};
398 VkBufferUsageFlagBits bufferUsage = {};
399
400 // Check that descriptor set has storage class map.
401 const auto resourceStorageClassMapIt =
402 resourceStorageClassData.find(x: descriptorSetIndex);
403 if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
404 std::cerr
405 << "cannot find storage class for resource in descriptor set: "
406 << descriptorSetIndex;
407 return failure();
408 }
409
410 // Check that specific descriptor binding has storage class.
411 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
412 const auto resourceStorageClassIt =
413 resourceStorageClassMap.find(x: resourceDataBindingPair.first);
414 if (resourceStorageClassIt == resourceStorageClassMap.end()) {
415 std::cerr
416 << "cannot find storage class for resource with descriptor index: "
417 << resourceDataBindingPair.first;
418 return failure();
419 }
420
421 const auto resourceStorageClassBinding = resourceStorageClassIt->second;
422 if (failed(result: mapStorageClassToDescriptorType(storageClass: resourceStorageClassBinding,
423 descriptorType)) ||
424 failed(result: mapStorageClassToBufferUsageFlag(storageClass: resourceStorageClassBinding,
425 bufferUsage))) {
426 std::cerr << "storage class for resource with descriptor binding: "
427 << resourceDataBindingPair.first
428 << " in the descriptor set: " << descriptorSetIndex
429 << " is not supported ";
430 return failure();
431 }
432
433 // Set descriptor type for the specific device memory buffer.
434 memoryBuffer.descriptorType = descriptorType;
435 const auto bufferSize = resourceDataBindingPair.second.size;
436 memoryBuffer.bufferSize = bufferSize;
437 // Specify memory allocation info.
438 VkMemoryAllocateInfo memoryAllocateInfo = {};
439 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
440 memoryAllocateInfo.pNext = nullptr;
441 memoryAllocateInfo.allocationSize = bufferSize;
442 memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;
443
444 // Allocate device memory.
445 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
446 nullptr,
447 &memoryBuffer.hostMemory),
448 "vkAllocateMemory");
449 memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
450 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
451 nullptr,
452 &memoryBuffer.deviceMemory),
453 "vkAllocateMemory");
454 void *payload;
455 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
456 bufferSize, 0,
457 reinterpret_cast<void **>(&payload)),
458 "vkMapMemory");
459
460 // Copy host memory into the mapped area.
461 std::memcpy(dest: payload, src: resourceDataBindingPair.second.ptr, n: bufferSize);
462 vkUnmapMemory(device, memory: memoryBuffer.hostMemory);
463
464 VkBufferCreateInfo bufferCreateInfo = {};
465 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
466 bufferCreateInfo.pNext = nullptr;
467 bufferCreateInfo.flags = 0;
468 bufferCreateInfo.size = bufferSize;
469 bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
470 VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
471 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
472 bufferCreateInfo.queueFamilyIndexCount = 1;
473 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
474 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
475 &memoryBuffer.hostBuffer),
476 "vkCreateBuffer");
477 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
478 &memoryBuffer.deviceBuffer),
479 "vkCreateBuffer");
480
481 // Bind buffer and device memory.
482 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
483 memoryBuffer.hostMemory, 0),
484 "vkBindBufferMemory");
485 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
486 memoryBuffer.deviceBuffer,
487 memoryBuffer.deviceMemory, 0),
488 "vkBindBufferMemory");
489
490 // Update buffer info.
491 memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
492 memoryBuffer.bufferInfo.offset = 0;
493 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
494 deviceMemoryBuffers.push_back(x: memoryBuffer);
495 }
496
497 // Associate device memory buffers with a descriptor set.
498 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
499 }
500 return success();
501}
502
503LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
504 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
505 .sType: VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
506 .pNext: nullptr,
507 .commandPool: commandPool,
508 .level: VK_COMMAND_BUFFER_LEVEL_PRIMARY,
509 .commandBufferCount: 1,
510 };
511 VkCommandBuffer commandBuffer;
512 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
513 &commandBufferAllocateInfo,
514 &commandBuffer),
515 "vkAllocateCommandBuffers");
516
517 VkCommandBufferBeginInfo commandBufferBeginInfo = {
518 .sType: VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
519 .pNext: nullptr,
520 .flags: 0,
521 .pInheritanceInfo: nullptr,
522 };
523 RETURN_ON_VULKAN_ERROR(
524 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
525 "vkBeginCommandBuffer");
526
527 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
528 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
529 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
530 for (const auto &memBuffer : deviceMemoryBuffers) {
531 VkBufferCopy copy = {.srcOffset: 0, .dstOffset: 0, .size: memBuffer.bufferSize};
532 if (deviceToHost)
533 vkCmdCopyBuffer(commandBuffer, srcBuffer: memBuffer.deviceBuffer,
534 dstBuffer: memBuffer.hostBuffer, regionCount: 1, pRegions: &copy);
535 else
536 vkCmdCopyBuffer(commandBuffer, srcBuffer: memBuffer.hostBuffer,
537 dstBuffer: memBuffer.deviceBuffer, regionCount: 1, pRegions: &copy);
538 }
539 }
540
541 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
542 "vkEndCommandBuffer");
543 VkSubmitInfo submitInfo = {
544 .sType: VK_STRUCTURE_TYPE_SUBMIT_INFO,
545 .pNext: nullptr,
546 .waitSemaphoreCount: 0,
547 .pWaitSemaphores: nullptr,
548 .pWaitDstStageMask: nullptr,
549 .commandBufferCount: 1,
550 .pCommandBuffers: &commandBuffer,
551 .signalSemaphoreCount: 0,
552 .pSignalSemaphores: nullptr,
553 };
554 submitInfo.pCommandBuffers = &commandBuffer;
555 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
556 "vkQueueSubmit");
557 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
558
559 vkFreeCommandBuffers(device, commandPool, commandBufferCount: 1, pCommandBuffers: &commandBuffer);
560 return success();
561}
562
563LogicalResult VulkanRuntime::createShaderModule() {
564 VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
565 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
566 shaderModuleCreateInfo.pNext = nullptr;
567 shaderModuleCreateInfo.flags = 0;
568 // Set size in bytes.
569 shaderModuleCreateInfo.codeSize = binarySize;
570 // Set pointer to the binary shader.
571 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
572 RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
573 nullptr, &shaderModule),
574 "vkCreateShaderModule");
575 return success();
576}
577
578void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
579 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
580 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
581 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
582 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
583
584 // Create a layout binding for each descriptor.
585 for (const auto &memBuffer : deviceMemoryBuffers) {
586 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
587 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
588 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
589 descriptorSetLayoutBinding.descriptorCount = 1;
590 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
591 descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
592 descriptorSetLayoutBindings.push_back(x: descriptorSetLayoutBinding);
593 }
594 descriptorSetLayoutBindingMap[descriptorSetIndex] =
595 descriptorSetLayoutBindings;
596 }
597}
598
599LogicalResult VulkanRuntime::createDescriptorSetLayout() {
600 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
601 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
602 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
603 // Each descriptor in a descriptor set must be the same type.
604 VkDescriptorType descriptorType =
605 deviceMemoryBuffers.front().descriptorType;
606 const uint32_t descriptorSize = deviceMemoryBuffers.size();
607 const auto descriptorSetLayoutBindingIt =
608 descriptorSetLayoutBindingMap.find(x: descriptorSetIndex);
609
610 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
611 std::cerr << "cannot find layout bindings for the set with number: "
612 << descriptorSetIndex;
613 return failure();
614 }
615
616 const auto &descriptorSetLayoutBindings =
617 descriptorSetLayoutBindingIt->second;
618 // Create descriptor set layout.
619 VkDescriptorSetLayout descriptorSetLayout = {};
620 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
621
622 descriptorSetLayoutCreateInfo.sType =
623 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
624 descriptorSetLayoutCreateInfo.pNext = nullptr;
625 descriptorSetLayoutCreateInfo.flags = 0;
626 // Amount of descriptor bindings in a layout set.
627 descriptorSetLayoutCreateInfo.bindingCount =
628 descriptorSetLayoutBindings.size();
629 descriptorSetLayoutCreateInfo.pBindings =
630 descriptorSetLayoutBindings.data();
631 RETURN_ON_VULKAN_ERROR(
632 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
633 nullptr, &descriptorSetLayout),
634 "vkCreateDescriptorSetLayout");
635
636 descriptorSetLayouts.push_back(x: descriptorSetLayout);
637 descriptorSetInfoPool.push_back(
638 x: {.descriptorSet: descriptorSetIndex, .descriptorSize: descriptorSize, .descriptorType: descriptorType});
639 }
640 return success();
641}
642
643LogicalResult VulkanRuntime::createPipelineLayout() {
644 // Associate descriptor sets with a pipeline layout.
645 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
646 pipelineLayoutCreateInfo.sType =
647 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
648 pipelineLayoutCreateInfo.pNext = nullptr;
649 pipelineLayoutCreateInfo.flags = 0;
650 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
651 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
652 pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
653 pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
654 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
655 &pipelineLayoutCreateInfo,
656 nullptr, &pipelineLayout),
657 "vkCreatePipelineLayout");
658 return success();
659}
660
661LogicalResult VulkanRuntime::createComputePipeline() {
662 VkPipelineShaderStageCreateInfo stageInfo = {};
663 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
664 stageInfo.pNext = nullptr;
665 stageInfo.flags = 0;
666 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
667 stageInfo.module = shaderModule;
668 // Set entry point.
669 stageInfo.pName = entryPoint;
670 stageInfo.pSpecializationInfo = nullptr;
671
672 VkComputePipelineCreateInfo computePipelineCreateInfo = {};
673 computePipelineCreateInfo.sType =
674 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
675 computePipelineCreateInfo.pNext = nullptr;
676 computePipelineCreateInfo.flags = 0;
677 computePipelineCreateInfo.stage = stageInfo;
678 computePipelineCreateInfo.layout = pipelineLayout;
679 computePipelineCreateInfo.basePipelineHandle = nullptr;
680 computePipelineCreateInfo.basePipelineIndex = 0;
681 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
682 &computePipelineCreateInfo,
683 nullptr, &pipeline),
684 "vkCreateComputePipelines");
685 return success();
686}
687
688LogicalResult VulkanRuntime::createDescriptorPool() {
689 std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
690 for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
691 // For each descriptor set populate descriptor pool size.
692 VkDescriptorPoolSize descriptorPoolSize = {};
693 descriptorPoolSize.type = descriptorSetInfo.descriptorType;
694 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
695 descriptorPoolSizes.push_back(x: descriptorPoolSize);
696 }
697
698 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
699 descriptorPoolCreateInfo.sType =
700 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
701 descriptorPoolCreateInfo.pNext = nullptr;
702 descriptorPoolCreateInfo.flags = 0;
703 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
704 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
705 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
706 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
707 &descriptorPoolCreateInfo,
708 nullptr, &descriptorPool),
709 "vkCreateDescriptorPool");
710 return success();
711}
712
713LogicalResult VulkanRuntime::allocateDescriptorSets() {
714 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
715 // Size of descriptor sets and descriptor layout sets is the same.
716 descriptorSets.resize(new_size: descriptorSetLayouts.size());
717 descriptorSetAllocateInfo.sType =
718 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
719 descriptorSetAllocateInfo.pNext = nullptr;
720 descriptorSetAllocateInfo.descriptorPool = descriptorPool;
721 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
722 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
723 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
724 &descriptorSetAllocateInfo,
725 descriptorSets.data()),
726 "vkAllocateDescriptorSets");
727 return success();
728}
729
730LogicalResult VulkanRuntime::setWriteDescriptors() {
731 if (descriptorSets.size() != descriptorSetInfoPool.size()) {
732 std::cerr << "Each descriptor set must have descriptor set information";
733 return failure();
734 }
735 // For each descriptor set.
736 auto descriptorSetIt = descriptorSets.begin();
737 // Each descriptor set is associated with descriptor set info.
738 for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
739 // For each device memory buffer in the descriptor set.
740 const auto &deviceMemoryBuffers =
741 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
742 for (const auto &memoryBuffer : deviceMemoryBuffers) {
743 // Structure describing descriptor sets to write to.
744 VkWriteDescriptorSet wSet = {};
745 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
746 wSet.pNext = nullptr;
747 // Descriptor set.
748 wSet.dstSet = *descriptorSetIt;
749 wSet.dstBinding = memoryBuffer.bindingIndex;
750 wSet.dstArrayElement = 0;
751 wSet.descriptorCount = 1;
752 wSet.descriptorType = memoryBuffer.descriptorType;
753 wSet.pImageInfo = nullptr;
754 wSet.pBufferInfo = &memoryBuffer.bufferInfo;
755 wSet.pTexelBufferView = nullptr;
756 vkUpdateDescriptorSets(device, descriptorWriteCount: 1, pDescriptorWrites: &wSet, descriptorCopyCount: 0, pDescriptorCopies: nullptr);
757 }
758 // Increment descriptor set iterator.
759 ++descriptorSetIt;
760 }
761 return success();
762}
763
764LogicalResult VulkanRuntime::createCommandPool() {
765 VkCommandPoolCreateInfo commandPoolCreateInfo = {};
766 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
767 commandPoolCreateInfo.pNext = nullptr;
768 commandPoolCreateInfo.flags = 0;
769 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
770 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
771 /*pAllocator=*/nullptr,
772 &commandPool),
773 "vkCreateCommandPool");
774 return success();
775}
776
777LogicalResult VulkanRuntime::createQueryPool() {
778 // Return directly if timestamp query is not supported.
779 if (queueFamilyProperties.timestampValidBits == 0)
780 return success();
781
782 // Get timestamp period for this physical device.
783 VkPhysicalDeviceProperties deviceProperties = {};
784 vkGetPhysicalDeviceProperties(physicalDevice, pProperties: &deviceProperties);
785 timestampPeriod = deviceProperties.limits.timestampPeriod;
786
787 // Create query pool.
788 VkQueryPoolCreateInfo queryPoolCreateInfo = {};
789 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
790 queryPoolCreateInfo.pNext = nullptr;
791 queryPoolCreateInfo.flags = 0;
792 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
793 queryPoolCreateInfo.queryCount = 2;
794 queryPoolCreateInfo.pipelineStatistics = 0;
795 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
796 /*pAllocator=*/nullptr, &queryPool),
797 "vkCreateQueryPool");
798
799 return success();
800}
801
802LogicalResult VulkanRuntime::createComputeCommandBuffer() {
803 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
804 commandBufferAllocateInfo.sType =
805 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
806 commandBufferAllocateInfo.pNext = nullptr;
807 commandBufferAllocateInfo.commandPool = commandPool;
808 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
809 commandBufferAllocateInfo.commandBufferCount = 1;
810
811 VkCommandBuffer commandBuffer;
812 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
813 &commandBufferAllocateInfo,
814 &commandBuffer),
815 "vkAllocateCommandBuffers");
816
817 VkCommandBufferBeginInfo commandBufferBeginInfo = {};
818 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
819 commandBufferBeginInfo.pNext = nullptr;
820 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
821 commandBufferBeginInfo.pInheritanceInfo = nullptr;
822
823 // Commands begin.
824 RETURN_ON_VULKAN_ERROR(
825 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
826 "vkBeginCommandBuffer");
827
828 if (queryPool != VK_NULL_HANDLE)
829 vkCmdResetQueryPool(commandBuffer, queryPool, firstQuery: 0, queryCount: 2);
830
831 vkCmdBindPipeline(commandBuffer, pipelineBindPoint: VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
832 vkCmdBindDescriptorSets(commandBuffer, pipelineBindPoint: VK_PIPELINE_BIND_POINT_COMPUTE,
833 layout: pipelineLayout, firstSet: 0, descriptorSetCount: descriptorSets.size(),
834 pDescriptorSets: descriptorSets.data(), dynamicOffsetCount: 0, pDynamicOffsets: nullptr);
835 // Get a timestamp before invoking the compute shader.
836 if (queryPool != VK_NULL_HANDLE)
837 vkCmdWriteTimestamp(commandBuffer, pipelineStage: VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
838 queryPool, query: 0);
839 vkCmdDispatch(commandBuffer, groupCountX: numWorkGroups.x, groupCountY: numWorkGroups.y,
840 groupCountZ: numWorkGroups.z);
841 // Get another timestamp after invoking the compute shader.
842 if (queryPool != VK_NULL_HANDLE)
843 vkCmdWriteTimestamp(commandBuffer, pipelineStage: VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
844 queryPool, query: 1);
845
846 // Commands end.
847 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
848 "vkEndCommandBuffer");
849
850 commandBuffers.push_back(x: commandBuffer);
851 return success();
852}
853
854LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
855 VkSubmitInfo submitInfo = {};
856 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
857 submitInfo.pNext = nullptr;
858 submitInfo.waitSemaphoreCount = 0;
859 submitInfo.pWaitSemaphores = nullptr;
860 submitInfo.pWaitDstStageMask = nullptr;
861 submitInfo.commandBufferCount = commandBuffers.size();
862 submitInfo.pCommandBuffers = commandBuffers.data();
863 submitInfo.signalSemaphoreCount = 0;
864 submitInfo.pSignalSemaphores = nullptr;
865 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
866 "vkQueueSubmit");
867 return success();
868}
869
870LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
871 // First copy back the data to the staging buffer.
872 (void)copyResource(/*deviceToHost=*/true);
873
874 // For each descriptor set.
875 for (auto &resourceDataMapPair : resourceData) {
876 auto &resourceDataMap = resourceDataMapPair.second;
877 auto &deviceMemoryBuffers =
878 deviceMemoryBufferMap[resourceDataMapPair.first];
879 // For each device memory buffer in the set.
880 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
881 if (resourceDataMap.count(x: deviceMemoryBuffer.bindingIndex)) {
882 void *payload;
883 auto &hostMemoryBuffer =
884 resourceDataMap[deviceMemoryBuffer.bindingIndex];
885 RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
886 deviceMemoryBuffer.hostMemory, 0,
887 hostMemoryBuffer.size, 0,
888 reinterpret_cast<void **>(&payload)),
889 "vkMapMemory");
890 std::memcpy(dest: hostMemoryBuffer.ptr, src: payload, n: hostMemoryBuffer.size);
891 vkUnmapMemory(device, memory: deviceMemoryBuffer.hostMemory);
892 }
893 }
894 }
895 return success();
896}
897

// source code of mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp