1 | //===-- LayoutUtils.cpp - Decorate composite type with layout information -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements utilities used to compute alignment and layout
// information for types in the SPIR-V dialect.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
17 | |
18 | using namespace mlir; |
19 | |
20 | spirv::StructType |
21 | VulkanLayoutUtils::decorateType(spirv::StructType structType) { |
22 | Size size = 0; |
23 | Size alignment = 1; |
24 | return decorateType(structType, size, alignment); |
25 | } |
26 | |
27 | spirv::StructType |
28 | VulkanLayoutUtils::decorateType(spirv::StructType structType, |
29 | VulkanLayoutUtils::Size &size, |
30 | VulkanLayoutUtils::Size &alignment) { |
31 | if (structType.getNumElements() == 0) { |
32 | return structType; |
33 | } |
34 | |
35 | SmallVector<Type, 4> memberTypes; |
36 | SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo; |
37 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
38 | |
39 | Size structMemberOffset = 0; |
40 | Size maxMemberAlignment = 1; |
41 | |
42 | for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) { |
43 | Size memberSize = 0; |
44 | Size memberAlignment = 1; |
45 | |
46 | Type memberType = |
47 | decorateType(type: structType.getElementType(i), size&: memberSize, alignment&: memberAlignment); |
48 | structMemberOffset = llvm::alignTo(Value: structMemberOffset, Align: memberAlignment); |
49 | memberTypes.push_back(Elt: memberType); |
50 | offsetInfo.push_back( |
51 | Elt: static_cast<spirv::StructType::OffsetInfo>(structMemberOffset)); |
52 | // If the member's size is the max value, it must be the last member and it |
53 | // must be a runtime array. |
54 | assert(memberSize != std::numeric_limits<Size>().max() || |
55 | (i + 1 == e && |
56 | isa<spirv::RuntimeArrayType>(structType.getElementType(i)))); |
57 | // According to the Vulkan spec: |
58 | // "A structure has a base alignment equal to the largest base alignment of |
59 | // any of its members." |
60 | structMemberOffset += memberSize; |
61 | maxMemberAlignment = std::max(a: maxMemberAlignment, b: memberAlignment); |
62 | } |
63 | |
64 | // According to the Vulkan spec: |
65 | // "The Offset decoration of a member must not place it between the end of a |
66 | // structure or an array and the next multiple of the alignment of that |
67 | // structure or array." |
68 | size = llvm::alignTo(Value: structMemberOffset, Align: maxMemberAlignment); |
69 | alignment = maxMemberAlignment; |
70 | structType.getMemberDecorations(memberDecorations); |
71 | |
72 | if (!structType.isIdentified()) |
73 | return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations); |
74 | |
75 | // Identified structs are uniqued by identifier so it is not possible |
76 | // to create 2 structs with the same name but different decorations. |
77 | return nullptr; |
78 | } |
79 | |
80 | Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size, |
81 | VulkanLayoutUtils::Size &alignment) { |
82 | if (isa<spirv::ScalarType>(Val: type)) { |
83 | alignment = getScalarTypeAlignment(scalarType: type); |
84 | // Vulkan spec does not specify any padding for a scalar type. |
85 | size = alignment; |
86 | return type; |
87 | } |
88 | if (auto structType = dyn_cast<spirv::StructType>(Val&: type)) |
89 | return decorateType(structType, size, alignment); |
90 | if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type)) |
91 | return decorateType(arrayType, size, alignment); |
92 | if (auto vectorType = dyn_cast<VectorType>(type)) |
93 | return decorateType(vectorType, size, alignment); |
94 | if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type)) |
95 | return decorateType(matrixType, size, alignment); |
96 | if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) { |
97 | size = std::numeric_limits<Size>().max(); |
98 | return decorateType(arrayType, alignment); |
99 | } |
100 | if (isa<spirv::PointerType>(Val: type)) { |
101 | // TODO: Add support for `PhysicalStorageBufferAddresses`. |
102 | return nullptr; |
103 | } |
104 | llvm_unreachable("unhandled SPIR-V type" ); |
105 | } |
106 | |
107 | Type VulkanLayoutUtils::decorateType(VectorType vectorType, |
108 | VulkanLayoutUtils::Size &size, |
109 | VulkanLayoutUtils::Size &alignment) { |
110 | const unsigned numElements = vectorType.getNumElements(); |
111 | Type elementType = vectorType.getElementType(); |
112 | Size elementSize = 0; |
113 | Size elementAlignment = 1; |
114 | |
115 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
116 | // According to the Vulkan spec: |
117 | // 1. "A two-component vector has a base alignment equal to twice its scalar |
118 | // alignment." |
119 | // 2. "A three- or four-component vector has a base alignment equal to four |
120 | // times its scalar alignment." |
121 | size = elementSize * numElements; |
122 | alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4; |
123 | return VectorType::get(numElements, memberType); |
124 | } |
125 | |
126 | Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType, |
127 | VulkanLayoutUtils::Size &size, |
128 | VulkanLayoutUtils::Size &alignment) { |
129 | const unsigned numElements = arrayType.getNumElements(); |
130 | Type elementType = arrayType.getElementType(); |
131 | Size elementSize = 0; |
132 | Size elementAlignment = 1; |
133 | |
134 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
135 | // According to the Vulkan spec: |
136 | // "An array has a base alignment equal to the base alignment of its element |
137 | // type." |
138 | size = elementSize * numElements; |
139 | alignment = elementAlignment; |
140 | return spirv::ArrayType::get(elementType: memberType, elementCount: numElements, stride: elementSize); |
141 | } |
142 | |
143 | Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType, |
144 | VulkanLayoutUtils::Size &size, |
145 | VulkanLayoutUtils::Size &alignment) { |
146 | const unsigned numColumns = matrixType.getNumColumns(); |
147 | Type columnType = matrixType.getColumnType(); |
148 | unsigned numElements = matrixType.getNumElements(); |
149 | Type elementType = matrixType.getElementType(); |
150 | Size elementSize = 0; |
151 | Size elementAlignment = 1; |
152 | |
153 | decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment); |
154 | // According to the Vulkan spec: |
155 | // "A matrix type inherits scalar alignment from the equivalent array |
156 | // declaration." |
157 | size = elementSize * numElements; |
158 | alignment = elementAlignment; |
159 | return spirv::MatrixType::get(columnType, columnCount: numColumns); |
160 | } |
161 | |
162 | Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType, |
163 | VulkanLayoutUtils::Size &alignment) { |
164 | Type elementType = arrayType.getElementType(); |
165 | Size elementSize = 0; |
166 | |
167 | Type memberType = decorateType(type: elementType, size&: elementSize, alignment); |
168 | return spirv::RuntimeArrayType::get(elementType: memberType, stride: elementSize); |
169 | } |
170 | |
171 | VulkanLayoutUtils::Size |
172 | VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) { |
173 | // According to the Vulkan spec: |
174 | // 1. "A scalar of size N has a scalar alignment of N." |
175 | // 2. "A scalar has a base alignment equal to its scalar alignment." |
176 | // 3. "A scalar, vector or matrix type has an extended alignment equal to its |
177 | // base alignment." |
178 | unsigned bitWidth = scalarType.getIntOrFloatBitWidth(); |
179 | if (bitWidth == 1) |
180 | return 1; |
181 | return bitWidth / 8; |
182 | } |
183 | |
184 | bool VulkanLayoutUtils::isLegalType(Type type) { |
185 | auto ptrType = dyn_cast<spirv::PointerType>(Val&: type); |
186 | if (!ptrType) { |
187 | return true; |
188 | } |
189 | |
190 | const spirv::StorageClass storageClass = ptrType.getStorageClass(); |
191 | auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType()); |
192 | if (!structType) { |
193 | return true; |
194 | } |
195 | |
196 | switch (storageClass) { |
197 | case spirv::StorageClass::Uniform: |
198 | case spirv::StorageClass::StorageBuffer: |
199 | case spirv::StorageClass::PushConstant: |
200 | case spirv::StorageClass::PhysicalStorageBuffer: |
201 | return structType.hasOffset() || !structType.getNumElements(); |
202 | default: |
203 | return true; |
204 | } |
205 | } |
206 | |