//===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to get alignment and layout information
// for types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
16
17using namespace mlir;
18
19spirv::StructType
20VulkanLayoutUtils::decorateType(spirv::StructType structType) {
21 Size size = 0;
22 Size alignment = 1;
23 return decorateType(structType, size, alignment);
24}
25
26spirv::StructType
27VulkanLayoutUtils::decorateType(spirv::StructType structType,
28 VulkanLayoutUtils::Size &size,
29 VulkanLayoutUtils::Size &alignment) {
30 if (structType.getNumElements() == 0) {
31 return structType;
32 }
33
34 SmallVector<Type, 4> memberTypes;
35 SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
36 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
37
38 Size structMemberOffset = 0;
39 Size maxMemberAlignment = 1;
40
41 for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
42 Size memberSize = 0;
43 Size memberAlignment = 1;
44
45 Type memberType =
46 decorateType(type: structType.getElementType(i), size&: memberSize, alignment&: memberAlignment);
47 structMemberOffset = llvm::alignTo(Value: structMemberOffset, Align: memberAlignment);
48 memberTypes.push_back(Elt: memberType);
49 offsetInfo.push_back(
50 Elt: static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
51 // If the member's size is the max value, it must be the last member and it
52 // must be a runtime array.
53 assert(memberSize != std::numeric_limits<Size>().max() ||
54 (i + 1 == e &&
55 isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
56 // According to the Vulkan spec:
57 // "A structure has a base alignment equal to the largest base alignment of
58 // any of its members."
59 structMemberOffset += memberSize;
60 maxMemberAlignment = std::max(a: maxMemberAlignment, b: memberAlignment);
61 }
62
63 // According to the Vulkan spec:
64 // "The Offset decoration of a member must not place it between the end of a
65 // structure or an array and the next multiple of the alignment of that
66 // structure or array."
67 size = llvm::alignTo(Value: structMemberOffset, Align: maxMemberAlignment);
68 alignment = maxMemberAlignment;
69 structType.getMemberDecorations(memberDecorations);
70
71 if (!structType.isIdentified())
72 return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
73
74 // Identified structs are uniqued by identifier so it is not possible
75 // to create 2 structs with the same name but different decorations.
76 return nullptr;
77}
78
79Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
80 VulkanLayoutUtils::Size &alignment) {
81 if (isa<spirv::ScalarType>(Val: type)) {
82 alignment = getScalarTypeAlignment(scalarType: type);
83 // Vulkan spec does not specify any padding for a scalar type.
84 size = alignment;
85 return type;
86 }
87 if (auto structType = dyn_cast<spirv::StructType>(Val&: type))
88 return decorateType(structType, size, alignment);
89 if (auto arrayType = dyn_cast<spirv::ArrayType>(Val&: type))
90 return decorateType(arrayType, size, alignment);
91 if (auto vectorType = dyn_cast<VectorType>(Val&: type))
92 return decorateType(vectorType, size, alignment);
93 if (auto matrixType = dyn_cast<spirv::MatrixType>(Val&: type))
94 return decorateType(matrixType, size, alignment);
95 if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(Val&: type)) {
96 size = std::numeric_limits<Size>().max();
97 return decorateType(arrayType, alignment);
98 }
99 if (isa<spirv::PointerType>(Val: type)) {
100 // TODO: Add support for `PhysicalStorageBufferAddresses`.
101 return nullptr;
102 }
103 llvm_unreachable("unhandled SPIR-V type");
104}
105
106Type VulkanLayoutUtils::decorateType(VectorType vectorType,
107 VulkanLayoutUtils::Size &size,
108 VulkanLayoutUtils::Size &alignment) {
109 const unsigned numElements = vectorType.getNumElements();
110 Type elementType = vectorType.getElementType();
111 Size elementSize = 0;
112 Size elementAlignment = 1;
113
114 Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment);
115 // According to the Vulkan spec:
116 // 1. "A two-component vector has a base alignment equal to twice its scalar
117 // alignment."
118 // 2. "A three- or four-component vector has a base alignment equal to four
119 // times its scalar alignment."
120 size = elementSize * numElements;
121 alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
122 return VectorType::get(shape: numElements, elementType: memberType);
123}
124
125Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
126 VulkanLayoutUtils::Size &size,
127 VulkanLayoutUtils::Size &alignment) {
128 const unsigned numElements = arrayType.getNumElements();
129 Type elementType = arrayType.getElementType();
130 Size elementSize = 0;
131 Size elementAlignment = 1;
132
133 Type memberType = decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment);
134 // According to the Vulkan spec:
135 // "An array has a base alignment equal to the base alignment of its element
136 // type."
137 size = elementSize * numElements;
138 alignment = elementAlignment;
139 return spirv::ArrayType::get(elementType: memberType, elementCount: numElements, stride: elementSize);
140}
141
142Type VulkanLayoutUtils::decorateType(spirv::MatrixType matrixType,
143 VulkanLayoutUtils::Size &size,
144 VulkanLayoutUtils::Size &alignment) {
145 const unsigned numColumns = matrixType.getNumColumns();
146 Type columnType = matrixType.getColumnType();
147 unsigned numElements = matrixType.getNumElements();
148 Type elementType = matrixType.getElementType();
149 Size elementSize = 0;
150 Size elementAlignment = 1;
151
152 decorateType(type: elementType, size&: elementSize, alignment&: elementAlignment);
153 // According to the Vulkan spec:
154 // "A matrix type inherits scalar alignment from the equivalent array
155 // declaration."
156 size = elementSize * numElements;
157 alignment = elementAlignment;
158 return spirv::MatrixType::get(columnType, columnCount: numColumns);
159}
160
161Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
162 VulkanLayoutUtils::Size &alignment) {
163 Type elementType = arrayType.getElementType();
164 Size elementSize = 0;
165
166 Type memberType = decorateType(type: elementType, size&: elementSize, alignment);
167 return spirv::RuntimeArrayType::get(elementType: memberType, stride: elementSize);
168}
169
170VulkanLayoutUtils::Size
171VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
172 // According to the Vulkan spec:
173 // 1. "A scalar of size N has a scalar alignment of N."
174 // 2. "A scalar has a base alignment equal to its scalar alignment."
175 // 3. "A scalar, vector or matrix type has an extended alignment equal to its
176 // base alignment."
177 unsigned bitWidth = scalarType.getIntOrFloatBitWidth();
178 if (bitWidth == 1)
179 return 1;
180 return bitWidth / 8;
181}
182
183bool VulkanLayoutUtils::isLegalType(Type type) {
184 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
185 if (!ptrType) {
186 return true;
187 }
188
189 const spirv::StorageClass storageClass = ptrType.getStorageClass();
190 auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
191 if (!structType) {
192 return true;
193 }
194
195 switch (storageClass) {
196 case spirv::StorageClass::Uniform:
197 case spirv::StorageClass::StorageBuffer:
198 case spirv::StorageClass::PushConstant:
199 case spirv::StorageClass::PhysicalStorageBuffer:
200 return structType.hasOffset() || !structType.getNumElements();
201 default:
202 return true;
203 }
204}
205

// source code of mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp