| 1 | //===-- LayoutUtils.cpp - Decorate composite type with layout information -===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utilities used to get alignment and layout information |
| 10 | // for types in SPIR-V dialect. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| 17 | |
| 18 | using namespace mlir; |
| 19 | |
| 20 | spirv::StructType |
| 21 | VulkanLayoutUtils::decorateType(spirv::StructType structType) { |
| 22 | Size size = 0; |
| 23 | Size alignment = 1; |
| 24 | return decorateType(structType, size, alignment); |
| 25 | } |
| 26 | |
| 27 | spirv::StructType |
| 28 | VulkanLayoutUtils::decorateType(spirv::StructType structType, |
| 29 | VulkanLayoutUtils::Size &size, |
| 30 | VulkanLayoutUtils::Size &alignment) { |
| 31 | if (structType.getNumElements() == 0) { |
| 32 | return structType; |
| 33 | } |
| 34 | |
| 35 | SmallVector<Type, 4> memberTypes; |
| 36 | SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo; |
| 37 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
| 38 | |
| 39 | Size structMemberOffset = 0; |
| 40 | Size maxMemberAlignment = 1; |
| 41 | |
| 42 | for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { |
| 43 | Size memberSize = 0; |
| 44 | Size memberAlignment = 1; |
| 45 | |
| 46 | Type memberType = |
| 47 | decorateType(type: structType.getElementType(i), size&: memberSize, alignment&: memberAlignment); |
| 48 | structMemberOffset = llvm::alignTo(Value: structMemberOffset, Align: memberAlignment); |
| 49 | memberTypes.push_back(Elt: memberType); |
| 50 | offsetInfo.push_back( |
| 51 | Elt: static_cast<spirv::StructType::OffsetInfo>(structMemberOffset)); |
| 52 | // If the member's size is the max value, it must be the last member and it |
| 53 | // must be a runtime array. |
| 54 | assert(memberSize != std::numeric_limits<Size>().max() || |
| 55 | (i + 1 == e && |
| 56 | isa<spirv::RuntimeArrayType>(structType.getElementType(i)))); |
| 57 | // According to the Vulkan spec: |
| 58 | // "A structure has a base alignment equal to the largest base alignment of |
| 59 | // any of its members." |
| 60 | structMemberOffset += memberSize; |
| 61 | maxMemberAlignment = std::max(a: maxMemberAlignment, b: memberAlignment); |
| 62 | } |
| 63 | |
| 64 | // According to the Vulkan spec: |
| 65 | // "The Offset decoration of a member must not place it between the end of a |
| 66 | // structure or an array and the next multiple of the alignment of that |
| 67 | // structure or array." |
| 68 | size = llvm::alignTo(Value: structMemberOffset, Align: maxMemberAlignment); |
| 69 | alignment = maxMemberAlignment; |
| 70 | structType.getMemberDecorations(memberDecorations); |
| 71 | |
| 72 | if (!structType.isIdentified()) |
| 73 | return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); |
| 74 | |
| 75 | // Identified structs are uniqued by identifier so it is not possible |
| 76 | // to create 2 structs with the same name but different decorations. |
| 77 | return nullptr; |
| 78 | } |
| 79 | |
| 80 | Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, |
| 81 | VulkanLayoutUtils::Size &alignment) { |
| 82 | if (isa<spirv::ScalarType>(Val: type)) { |
| 83 | alignment = getScalarTypeAlignment(scalarType: type); |
| 84 | // Vulkan spec does not specify any padding for a scalar type. |
| 85 | size = alignment; |
| 86 | return type; |
| 87 | } |
| 88 | if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) |
| 89 | return decorateType(structType, size, alignment); |
| 90 | if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) |
| 91 | return decorateType(arrayType, size, alignment); |
| 92 | if (auto vectorType = dyn_cast<VectorType>(type)) |
| 93 | return decorateType(vectorType, size, alignment); |
| 94 | if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) |
| 95 | return decorateType(matrixType, size, alignment); |
| 96 | if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) { |
| 97 | size = std::numeric_limits<Size>().max(); |
| 98 | return decorateType(arrayType, alignment); |
| 99 | } |
| 100 | if (isa<spirv::PointerType>(Val: type)) { |
| 101 | // TODO: Add support for `PhysicalStorageBufferAddresses`. |
| 102 | return nullptr; |
| 103 | } |
| 104 | llvm_unreachable("unhandled SPIR-V type" ); |
| 105 | } |
| 106 | |
| 107 | Type VulkanLayoutUtils::decorateType(VectorType vectorType, |
| 108 | VulkanLayoutUtils::Size &size, |
| 109 | VulkanLayoutUtils::Size &alignment) { |
| 110 | const unsigned numElements = vectorType.getNumElements(); |
| 111 | Type elementType = vectorType.getElementType(); |
| 112 | Size elementSize = 0; |
| 113 | Size elementAlignment = 1; |
| 114 | |
| 115 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
| 116 | // According to the Vulkan spec: |
| 117 | // 1. "A two-component vector has a base alignment equal to twice its scalar |
| 118 | // alignment." |
| 119 | // 2. "A three- or four-component vector has a base alignment equal to four |
| 120 | // times its scalar alignment." |
| 121 | size = elementSize * numElements; |
| 122 | alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; |
| 123 | return VectorType::get(numElements, memberType); |
| 124 | } |
| 125 | |
| 126 | Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, |
| 127 | VulkanLayoutUtils::Size &size, |
| 128 | VulkanLayoutUtils::Size &alignment) { |
| 129 | const unsigned numElements = arrayType.getNumElements(); |
| 130 | Type elementType = arrayType.getElementType(); |
| 131 | Size elementSize = 0; |
| 132 | Size elementAlignment = 1; |
| 133 | |
| 134 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
| 135 | // According to the Vulkan spec: |
| 136 | // "An array has a base alignment equal to the base alignment of its element |
| 137 | // type." |
| 138 | size = elementSize * numElements; |
| 139 | alignment = elementAlignment; |
| 140 | return spirv::ArrayType::get(elementType: memberType, elementCount: numElements, stride: elementSize); |
| 141 | } |
| 142 | |
| 143 | Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType, |
| 144 | VulkanLayoutUtils::Size &size, |
| 145 | VulkanLayoutUtils::Size &alignment) { |
| 146 | const unsigned numColumns = matrixType.getNumColumns(); |
| 147 | Type columnType = matrixType.getColumnType(); |
| 148 | unsigned numElements = matrixType.getNumElements(); |
| 149 | Type elementType = matrixType.getElementType(); |
| 150 | Size elementSize = 0; |
| 151 | Size elementAlignment = 1; |
| 152 | |
| 153 | decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
| 154 | // According to the Vulkan spec: |
| 155 | // "A matrix type inherits scalar alignment from the equivalent array |
| 156 | // declaration." |
| 157 | size = elementSize * numElements; |
| 158 | alignment = elementAlignment; |
| 159 | return spirv::MatrixType::get(columnType, columnCount: numColumns); |
| 160 | } |
| 161 | |
| 162 | Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType, |
| 163 | VulkanLayoutUtils::Size &alignment) { |
| 164 | Type elementType = arrayType.getElementType(); |
| 165 | Size elementSize = 0; |
| 166 | |
| 167 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment); |
| 168 | return spirv::RuntimeArrayType::get(elementType: memberType, stride: elementSize); |
| 169 | } |
| 170 | |
| 171 | VulkanLayoutUtils::Size |
| 172 | VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { |
| 173 | // According to the Vulkan spec: |
| 174 | // 1. "A scalar of size N has a scalar alignment of N." |
| 175 | // 2. "A scalar has a base alignment equal to its scalar alignment." |
| 176 | // 3. "A scalar, vector or matrix type has an extended alignment equal to its |
| 177 | // base alignment." |
| 178 | unsigned bitWidth = scalarType.getIntOrFloatBitWidth(); |
| 179 | if (bitWidth == 1) |
| 180 | return 1; |
| 181 | return bitWidth / 8; |
| 182 | } |
| 183 | |
| 184 | bool VulkanLayoutUtils::isLegalType(Type type) { |
| 185 | auto ptrType = dyn_cast<spirv::PointerType>(Val&: type); |
| 186 | if (!ptrType) { |
| 187 | return true; |
| 188 | } |
| 189 | |
| 190 | const spirv::StorageClass storageClass = ptrType.getStorageClass(); |
| 191 | auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType()); |
| 192 | if (!structType) { |
| 193 | return true; |
| 194 | } |
| 195 | |
| 196 | switch (storageClass) { |
| 197 | case spirv::StorageClass::Uniform: |
| 198 | case spirv::StorageClass::StorageBuffer: |
| 199 | case spirv::StorageClass::PushConstant: |
| 200 | case spirv::StorageClass::PhysicalStorageBuffer: |
| 201 | return structType.hasOffset() || !structType.getNumElements(); |
| 202 | default: |
| 203 | return true; |
| 204 | } |
| 205 | } |
| 206 | |