1 | //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the types in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
16 | #include "mlir/IR/Attributes.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/TypeSwitch.h" |
20 | |
21 | #include <cstdint> |
22 | #include <iterator> |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::spirv; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // ArrayType |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | struct spirv::detail::ArrayTypeStorage : public TypeStorage { |
32 | using KeyTy = std::tuple<Type, unsigned, unsigned>; |
33 | |
34 | static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, |
35 | const KeyTy &key) { |
36 | return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key); |
37 | } |
38 | |
39 | bool operator==(const KeyTy &key) const { |
40 | return key == KeyTy(elementType, elementCount, stride); |
41 | } |
42 | |
43 | ArrayTypeStorage(const KeyTy &key) |
44 | : elementType(std::get<0>(t: key)), elementCount(std::get<1>(t: key)), |
45 | stride(std::get<2>(t: key)) {} |
46 | |
47 | Type elementType; |
48 | unsigned elementCount; |
49 | unsigned stride; |
50 | }; |
51 | |
52 | ArrayType ArrayType::get(Type elementType, unsigned elementCount) { |
53 | assert(elementCount && "ArrayType needs at least one element" ); |
54 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, |
55 | /*stride=*/args: 0); |
56 | } |
57 | |
58 | ArrayType ArrayType::get(Type elementType, unsigned elementCount, |
59 | unsigned stride) { |
60 | assert(elementCount && "ArrayType needs at least one element" ); |
61 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, args&: stride); |
62 | } |
63 | |
64 | unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; } |
65 | |
66 | Type ArrayType::getElementType() const { return getImpl()->elementType; } |
67 | |
68 | unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } |
69 | |
70 | void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
71 | std::optional<StorageClass> storage) { |
72 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
73 | } |
74 | |
75 | void ArrayType::getCapabilities( |
76 | SPIRVType::CapabilityArrayRefVector &capabilities, |
77 | std::optional<StorageClass> storage) { |
78 | llvm::cast<SPIRVType>(Val: getElementType()) |
79 | .getCapabilities(capabilities, storage); |
80 | } |
81 | |
82 | std::optional<int64_t> ArrayType::getSizeInBytes() { |
83 | auto elementType = llvm::cast<SPIRVType>(Val: getElementType()); |
84 | std::optional<int64_t> size = elementType.getSizeInBytes(); |
85 | if (!size) |
86 | return std::nullopt; |
87 | return (*size + getArrayStride()) * getNumElements(); |
88 | } |
89 | |
90 | //===----------------------------------------------------------------------===// |
91 | // CompositeType |
92 | //===----------------------------------------------------------------------===// |
93 | |
94 | bool CompositeType::classof(Type type) { |
95 | if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
96 | return isValid(vectorType); |
97 | return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType, |
98 | spirv::JointMatrixINTELType, spirv::MatrixType, |
99 | spirv::RuntimeArrayType, spirv::StructType>(Val: type); |
100 | } |
101 | |
102 | bool CompositeType::isValid(VectorType type) { |
103 | return type.getRank() == 1 && |
104 | llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && |
105 | llvm::isa<ScalarType>(type.getElementType()); |
106 | } |
107 | |
108 | Type CompositeType::getElementType(unsigned index) const { |
109 | return TypeSwitch<Type, Type>(*this) |
110 | .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, |
111 | RuntimeArrayType, VectorType>( |
112 | [](auto type) { return type.getElementType(); }) |
113 | .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); }) |
114 | .Case<StructType>( |
115 | [index](StructType type) { return type.getElementType(index); }) |
116 | .Default( |
117 | [](Type) -> Type { llvm_unreachable("invalid composite type" ); }); |
118 | } |
119 | |
120 | unsigned CompositeType::getNumElements() const { |
121 | if (auto arrayType = llvm::dyn_cast<ArrayType>(Val: *this)) |
122 | return arrayType.getNumElements(); |
123 | if (auto matrixType = llvm::dyn_cast<MatrixType>(Val: *this)) |
124 | return matrixType.getNumColumns(); |
125 | if (auto structType = llvm::dyn_cast<StructType>(Val: *this)) |
126 | return structType.getNumElements(); |
127 | if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) |
128 | return vectorType.getNumElements(); |
129 | if (llvm::isa<CooperativeMatrixType>(Val: *this)) { |
130 | llvm_unreachable( |
131 | "invalid to query number of elements of spirv Cooperative Matrix type" ); |
132 | } |
133 | if (llvm::isa<JointMatrixINTELType>(Val: *this)) { |
134 | llvm_unreachable( |
135 | "invalid to query number of elements of spirv::JointMatrix type" ); |
136 | } |
137 | if (llvm::isa<RuntimeArrayType>(Val: *this)) { |
138 | llvm_unreachable( |
139 | "invalid to query number of elements of spirv::RuntimeArray type" ); |
140 | } |
141 | llvm_unreachable("invalid composite type" ); |
142 | } |
143 | |
144 | bool CompositeType::hasCompileTimeKnownNumElements() const { |
145 | return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType, |
146 | RuntimeArrayType>(Val: *this); |
147 | } |
148 | |
149 | void CompositeType::getExtensions( |
150 | SPIRVType::ExtensionArrayRefVector &extensions, |
151 | std::optional<StorageClass> storage) { |
152 | TypeSwitch<Type>(*this) |
153 | .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType, |
154 | RuntimeArrayType, StructType>( |
155 | [&](auto type) { type.getExtensions(extensions, storage); }) |
156 | .Case<VectorType>([&](VectorType type) { |
157 | return llvm::cast<ScalarType>(type.getElementType()) |
158 | .getExtensions(extensions, storage); |
159 | }) |
160 | .Default([](Type) { llvm_unreachable("invalid composite type" ); }); |
161 | } |
162 | |
163 | void CompositeType::getCapabilities( |
164 | SPIRVType::CapabilityArrayRefVector &capabilities, |
165 | std::optional<StorageClass> storage) { |
166 | TypeSwitch<Type>(*this) |
167 | .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType, |
168 | RuntimeArrayType, StructType>( |
169 | [&](auto type) { type.getCapabilities(capabilities, storage); }) |
170 | .Case<VectorType>([&](VectorType type) { |
171 | auto vecSize = getNumElements(); |
172 | if (vecSize == 8 || vecSize == 16) { |
173 | static const Capability caps[] = {Capability::Vector16}; |
174 | ArrayRef<Capability> ref(caps, std::size(caps)); |
175 | capabilities.push_back(ref); |
176 | } |
177 | return llvm::cast<ScalarType>(type.getElementType()) |
178 | .getCapabilities(capabilities, storage); |
179 | }) |
180 | .Default([](Type) { llvm_unreachable("invalid composite type" ); }); |
181 | } |
182 | |
183 | std::optional<int64_t> CompositeType::getSizeInBytes() { |
184 | if (auto arrayType = llvm::dyn_cast<ArrayType>(Val&: *this)) |
185 | return arrayType.getSizeInBytes(); |
186 | if (auto structType = llvm::dyn_cast<StructType>(Val&: *this)) |
187 | return structType.getSizeInBytes(); |
188 | if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) { |
189 | std::optional<int64_t> elementSize = |
190 | llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes(); |
191 | if (!elementSize) |
192 | return std::nullopt; |
193 | return *elementSize * vectorType.getNumElements(); |
194 | } |
195 | return std::nullopt; |
196 | } |
197 | |
198 | //===----------------------------------------------------------------------===// |
199 | // CooperativeMatrixType |
200 | //===----------------------------------------------------------------------===// |
201 | |
202 | struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { |
203 | using KeyTy = |
204 | std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>; |
205 | |
206 | static CooperativeMatrixTypeStorage * |
207 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
208 | return new (allocator.allocate<CooperativeMatrixTypeStorage>()) |
209 | CooperativeMatrixTypeStorage(key); |
210 | } |
211 | |
212 | bool operator==(const KeyTy &key) const { |
213 | return key == KeyTy(elementType, rows, columns, scope, use); |
214 | } |
215 | |
216 | CooperativeMatrixTypeStorage(const KeyTy &key) |
217 | : elementType(std::get<0>(key)), rows(std::get<1>(key)), |
218 | columns(std::get<2>(key)), scope(std::get<3>(key)), |
219 | use(std::get<4>(key)) {} |
220 | |
221 | Type elementType; |
222 | uint32_t rows; |
223 | uint32_t columns; |
224 | Scope scope; |
225 | CooperativeMatrixUseKHR use; |
226 | }; |
227 | |
228 | CooperativeMatrixType CooperativeMatrixType::get(Type elementType, |
229 | uint32_t rows, |
230 | uint32_t columns, Scope scope, |
231 | CooperativeMatrixUseKHR use) { |
232 | return Base::get(elementType.getContext(), elementType, rows, columns, scope, |
233 | use); |
234 | } |
235 | |
236 | Type CooperativeMatrixType::getElementType() const { |
237 | return getImpl()->elementType; |
238 | } |
239 | |
240 | uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; } |
241 | |
242 | uint32_t CooperativeMatrixType::getColumns() const { |
243 | return getImpl()->columns; |
244 | } |
245 | |
246 | Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } |
247 | |
248 | CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const { |
249 | return getImpl()->use; |
250 | } |
251 | |
252 | void CooperativeMatrixType::getExtensions( |
253 | SPIRVType::ExtensionArrayRefVector &extensions, |
254 | std::optional<StorageClass> storage) { |
255 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
256 | static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix}; |
257 | extensions.push_back(exts); |
258 | } |
259 | |
260 | void CooperativeMatrixType::getCapabilities( |
261 | SPIRVType::CapabilityArrayRefVector &capabilities, |
262 | std::optional<StorageClass> storage) { |
263 | llvm::cast<SPIRVType>(Val: getElementType()) |
264 | .getCapabilities(capabilities, storage); |
265 | static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR}; |
266 | capabilities.push_back(caps); |
267 | } |
268 | |
269 | //===----------------------------------------------------------------------===// |
270 | // JointMatrixType |
271 | //===----------------------------------------------------------------------===// |
272 | |
273 | struct spirv::detail::JointMatrixTypeStorage : public TypeStorage { |
274 | using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>; |
275 | |
276 | static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator, |
277 | const KeyTy &key) { |
278 | return new (allocator.allocate<JointMatrixTypeStorage>()) |
279 | JointMatrixTypeStorage(key); |
280 | } |
281 | |
282 | bool operator==(const KeyTy &key) const { |
283 | return key == KeyTy(elementType, rows, columns, matrixLayout, scope); |
284 | } |
285 | |
286 | JointMatrixTypeStorage(const KeyTy &key) |
287 | : elementType(std::get<0>(key)), rows(std::get<1>(key)), |
288 | columns(std::get<2>(key)), scope(std::get<4>(key)), |
289 | matrixLayout(std::get<3>(key)) {} |
290 | |
291 | Type elementType; |
292 | unsigned rows; |
293 | unsigned columns; |
294 | Scope scope; |
295 | MatrixLayout matrixLayout; |
296 | }; |
297 | |
298 | JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope, |
299 | unsigned rows, unsigned columns, |
300 | MatrixLayout matrixLayout) { |
301 | return Base::get(elementType.getContext(), elementType, rows, columns, |
302 | matrixLayout, scope); |
303 | } |
304 | |
305 | Type JointMatrixINTELType::getElementType() const { |
306 | return getImpl()->elementType; |
307 | } |
308 | |
309 | Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; } |
310 | |
311 | unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; } |
312 | |
313 | unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; } |
314 | |
315 | MatrixLayout JointMatrixINTELType::getMatrixLayout() const { |
316 | return getImpl()->matrixLayout; |
317 | } |
318 | |
319 | void JointMatrixINTELType::getExtensions( |
320 | SPIRVType::ExtensionArrayRefVector &extensions, |
321 | std::optional<StorageClass> storage) { |
322 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
323 | static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix}; |
324 | ArrayRef<Extension> ref(exts, std::size(exts)); |
325 | extensions.push_back(Elt: ref); |
326 | } |
327 | |
328 | void JointMatrixINTELType::getCapabilities( |
329 | SPIRVType::CapabilityArrayRefVector &capabilities, |
330 | std::optional<StorageClass> storage) { |
331 | llvm::cast<SPIRVType>(Val: getElementType()) |
332 | .getCapabilities(capabilities, storage); |
333 | static const Capability caps[] = {Capability::JointMatrixINTEL}; |
334 | ArrayRef<Capability> ref(caps, std::size(caps)); |
335 | capabilities.push_back(Elt: ref); |
336 | } |
337 | |
338 | //===----------------------------------------------------------------------===// |
339 | // ImageType |
340 | //===----------------------------------------------------------------------===// |
341 | |
342 | template <typename T> |
343 | static constexpr unsigned getNumBits() { |
344 | return 0; |
345 | } |
346 | template <> |
347 | constexpr unsigned getNumBits<Dim>() { |
348 | static_assert((1 << 3) > getMaxEnumValForDim(), |
349 | "Not enough bits to encode Dim value" ); |
350 | return 3; |
351 | } |
352 | template <> |
353 | constexpr unsigned getNumBits<ImageDepthInfo>() { |
354 | static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(), |
355 | "Not enough bits to encode ImageDepthInfo value" ); |
356 | return 2; |
357 | } |
358 | template <> |
359 | constexpr unsigned getNumBits<ImageArrayedInfo>() { |
360 | static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(), |
361 | "Not enough bits to encode ImageArrayedInfo value" ); |
362 | return 1; |
363 | } |
364 | template <> |
365 | constexpr unsigned getNumBits<ImageSamplingInfo>() { |
366 | static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(), |
367 | "Not enough bits to encode ImageSamplingInfo value" ); |
368 | return 1; |
369 | } |
370 | template <> |
371 | constexpr unsigned getNumBits<ImageSamplerUseInfo>() { |
372 | static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(), |
373 | "Not enough bits to encode ImageSamplerUseInfo value" ); |
374 | return 2; |
375 | } |
376 | template <> |
377 | constexpr unsigned getNumBits<ImageFormat>() { |
378 | static_assert((1 << 6) > getMaxEnumValForImageFormat(), |
379 | "Not enough bits to encode ImageFormat value" ); |
380 | return 6; |
381 | } |
382 | |
383 | struct spirv::detail::ImageTypeStorage : public TypeStorage { |
384 | public: |
385 | using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
386 | ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>; |
387 | |
388 | static ImageTypeStorage *construct(TypeStorageAllocator &allocator, |
389 | const KeyTy &key) { |
390 | return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key); |
391 | } |
392 | |
393 | bool operator==(const KeyTy &key) const { |
394 | return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo, |
395 | samplerUseInfo, format); |
396 | } |
397 | |
398 | ImageTypeStorage(const KeyTy &key) |
399 | : elementType(std::get<0>(key)), dim(std::get<1>(key)), |
400 | depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)), |
401 | samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)), |
402 | format(std::get<6>(key)) {} |
403 | |
404 | Type elementType; |
405 | Dim dim : getNumBits<Dim>(); |
406 | ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>(); |
407 | ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>(); |
408 | ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>(); |
409 | ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>(); |
410 | ImageFormat format : getNumBits<ImageFormat>(); |
411 | }; |
412 | |
413 | ImageType |
414 | ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
415 | ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat> |
416 | value) { |
417 | return Base::get(std::get<0>(value).getContext(), value); |
418 | } |
419 | |
420 | Type ImageType::getElementType() const { return getImpl()->elementType; } |
421 | |
422 | Dim ImageType::getDim() const { return getImpl()->dim; } |
423 | |
424 | ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; } |
425 | |
426 | ImageArrayedInfo ImageType::getArrayedInfo() const { |
427 | return getImpl()->arrayedInfo; |
428 | } |
429 | |
430 | ImageSamplingInfo ImageType::getSamplingInfo() const { |
431 | return getImpl()->samplingInfo; |
432 | } |
433 | |
434 | ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { |
435 | return getImpl()->samplerUseInfo; |
436 | } |
437 | |
438 | ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } |
439 | |
440 | void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, |
441 | std::optional<StorageClass>) { |
442 | // Image types do not require extra extensions thus far. |
443 | } |
444 | |
445 | void ImageType::getCapabilities( |
446 | SPIRVType::CapabilityArrayRefVector &capabilities, |
447 | std::optional<StorageClass>) { |
448 | if (auto dimCaps = spirv::getCapabilities(getDim())) |
449 | capabilities.push_back(Elt: *dimCaps); |
450 | |
451 | if (auto fmtCaps = spirv::getCapabilities(getImageFormat())) |
452 | capabilities.push_back(Elt: *fmtCaps); |
453 | } |
454 | |
455 | //===----------------------------------------------------------------------===// |
456 | // PointerType |
457 | //===----------------------------------------------------------------------===// |
458 | |
459 | struct spirv::detail::PointerTypeStorage : public TypeStorage { |
460 | // (Type, StorageClass) as the key: Type stored in this struct, and |
461 | // StorageClass stored as TypeStorage's subclass data. |
462 | using KeyTy = std::pair<Type, StorageClass>; |
463 | |
464 | static PointerTypeStorage *construct(TypeStorageAllocator &allocator, |
465 | const KeyTy &key) { |
466 | return new (allocator.allocate<PointerTypeStorage>()) |
467 | PointerTypeStorage(key); |
468 | } |
469 | |
470 | bool operator==(const KeyTy &key) const { |
471 | return key == KeyTy(pointeeType, storageClass); |
472 | } |
473 | |
474 | PointerTypeStorage(const KeyTy &key) |
475 | : pointeeType(key.first), storageClass(key.second) {} |
476 | |
477 | Type pointeeType; |
478 | StorageClass storageClass; |
479 | }; |
480 | |
481 | PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { |
482 | return Base::get(pointeeType.getContext(), pointeeType, storageClass); |
483 | } |
484 | |
485 | Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } |
486 | |
487 | StorageClass PointerType::getStorageClass() const { |
488 | return getImpl()->storageClass; |
489 | } |
490 | |
491 | void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
492 | std::optional<StorageClass> storage) { |
493 | // Use this pointer type's storage class because this pointer indicates we are |
494 | // using the pointee type in that specific storage class. |
495 | llvm::cast<SPIRVType>(getPointeeType()) |
496 | .getExtensions(extensions, getStorageClass()); |
497 | |
498 | if (auto scExts = spirv::getExtensions(getStorageClass())) |
499 | extensions.push_back(Elt: *scExts); |
500 | } |
501 | |
502 | void PointerType::getCapabilities( |
503 | SPIRVType::CapabilityArrayRefVector &capabilities, |
504 | std::optional<StorageClass> storage) { |
505 | // Use this pointer type's storage class because this pointer indicates we are |
506 | // using the pointee type in that specific storage class. |
507 | llvm::cast<SPIRVType>(getPointeeType()) |
508 | .getCapabilities(capabilities, getStorageClass()); |
509 | |
510 | if (auto scCaps = spirv::getCapabilities(getStorageClass())) |
511 | capabilities.push_back(Elt: *scCaps); |
512 | } |
513 | |
514 | //===----------------------------------------------------------------------===// |
515 | // RuntimeArrayType |
516 | //===----------------------------------------------------------------------===// |
517 | |
518 | struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage { |
519 | using KeyTy = std::pair<Type, unsigned>; |
520 | |
521 | static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator, |
522 | const KeyTy &key) { |
523 | return new (allocator.allocate<RuntimeArrayTypeStorage>()) |
524 | RuntimeArrayTypeStorage(key); |
525 | } |
526 | |
527 | bool operator==(const KeyTy &key) const { |
528 | return key == KeyTy(elementType, stride); |
529 | } |
530 | |
531 | RuntimeArrayTypeStorage(const KeyTy &key) |
532 | : elementType(key.first), stride(key.second) {} |
533 | |
534 | Type elementType; |
535 | unsigned stride; |
536 | }; |
537 | |
538 | RuntimeArrayType RuntimeArrayType::get(Type elementType) { |
539 | return Base::get(ctx: elementType.getContext(), args&: elementType, /*stride=*/args: 0); |
540 | } |
541 | |
542 | RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) { |
543 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: stride); |
544 | } |
545 | |
546 | Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } |
547 | |
548 | unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } |
549 | |
550 | void RuntimeArrayType::getExtensions( |
551 | SPIRVType::ExtensionArrayRefVector &extensions, |
552 | std::optional<StorageClass> storage) { |
553 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
554 | } |
555 | |
556 | void RuntimeArrayType::getCapabilities( |
557 | SPIRVType::CapabilityArrayRefVector &capabilities, |
558 | std::optional<StorageClass> storage) { |
559 | { |
560 | static const Capability caps[] = {Capability::Shader}; |
561 | ArrayRef<Capability> ref(caps, std::size(caps)); |
562 | capabilities.push_back(Elt: ref); |
563 | } |
564 | llvm::cast<SPIRVType>(Val: getElementType()) |
565 | .getCapabilities(capabilities, storage); |
566 | } |
567 | |
568 | //===----------------------------------------------------------------------===// |
569 | // ScalarType |
570 | //===----------------------------------------------------------------------===// |
571 | |
572 | bool ScalarType::classof(Type type) { |
573 | if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) { |
574 | return isValid(floatType); |
575 | } |
576 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
577 | return isValid(intType); |
578 | } |
579 | return false; |
580 | } |
581 | |
582 | bool ScalarType::isValid(FloatType type) { |
583 | return llvm::is_contained(Set: {16u, 32u, 64u}, Element: type.getWidth()) && !type.isBF16(); |
584 | } |
585 | |
586 | bool ScalarType::isValid(IntegerType type) { |
587 | return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth()); |
588 | } |
589 | |
590 | void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
591 | std::optional<StorageClass> storage) { |
592 | // 8- or 16-bit integer/floating-point numbers will require extra extensions |
593 | // to appear in interface storage classes. See SPV_KHR_16bit_storage and |
594 | // SPV_KHR_8bit_storage for more details. |
595 | if (!storage) |
596 | return; |
597 | |
598 | switch (*storage) { |
599 | case StorageClass::PushConstant: |
600 | case StorageClass::StorageBuffer: |
601 | case StorageClass::Uniform: |
602 | if (getIntOrFloatBitWidth() == 8) { |
603 | static const Extension exts[] = {Extension::SPV_KHR_8bit_storage}; |
604 | ArrayRef<Extension> ref(exts, std::size(exts)); |
605 | extensions.push_back(Elt: ref); |
606 | } |
607 | [[fallthrough]]; |
608 | case StorageClass::Input: |
609 | case StorageClass::Output: |
610 | if (getIntOrFloatBitWidth() == 16) { |
611 | static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; |
612 | ArrayRef<Extension> ref(exts, std::size(exts)); |
613 | extensions.push_back(Elt: ref); |
614 | } |
615 | break; |
616 | default: |
617 | break; |
618 | } |
619 | } |
620 | |
621 | void ScalarType::getCapabilities( |
622 | SPIRVType::CapabilityArrayRefVector &capabilities, |
623 | std::optional<StorageClass> storage) { |
624 | unsigned bitwidth = getIntOrFloatBitWidth(); |
625 | |
626 | // 8- or 16-bit integer/floating-point numbers will require extra capabilities |
627 | // to appear in interface storage classes. See SPV_KHR_16bit_storage and |
628 | // SPV_KHR_8bit_storage for more details. |
629 | |
630 | #define STORAGE_CASE(storage, cap8, cap16) \ |
631 | case StorageClass::storage: { \ |
632 | if (bitwidth == 8) { \ |
633 | static const Capability caps[] = {Capability::cap8}; \ |
634 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
635 | capabilities.push_back(ref); \ |
636 | return; \ |
637 | } \ |
638 | if (bitwidth == 16) { \ |
639 | static const Capability caps[] = {Capability::cap16}; \ |
640 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
641 | capabilities.push_back(ref); \ |
642 | return; \ |
643 | } \ |
644 | /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \ |
645 | /* storage classes. Fall through to the next section. */ \ |
646 | } break |
647 | |
648 | // This part only handles the cases where special bitwidths appearing in |
649 | // interface storage classes. |
650 | if (storage) { |
651 | switch (*storage) { |
652 | STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); |
653 | STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess, |
654 | StorageBuffer16BitAccess); |
655 | STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, |
656 | StorageUniform16); |
657 | case StorageClass::Input: |
658 | case StorageClass::Output: { |
659 | if (bitwidth == 16) { |
660 | static const Capability caps[] = {Capability::StorageInputOutput16}; |
661 | ArrayRef<Capability> ref(caps, std::size(caps)); |
662 | capabilities.push_back(Elt: ref); |
663 | return; |
664 | } |
665 | break; |
666 | } |
667 | default: |
668 | break; |
669 | } |
670 | } |
671 | #undef STORAGE_CASE |
672 | |
673 | // For other non-interface storage classes, require a different set of |
674 | // capabilities for special bitwidths. |
675 | |
676 | #define WIDTH_CASE(type, width) \ |
677 | case width: { \ |
678 | static const Capability caps[] = {Capability::type##width}; \ |
679 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
680 | capabilities.push_back(ref); \ |
681 | } break |
682 | |
683 | if (auto intType = llvm::dyn_cast<IntegerType>(*this)) { |
684 | switch (bitwidth) { |
685 | WIDTH_CASE(Int, 8); |
686 | WIDTH_CASE(Int, 16); |
687 | WIDTH_CASE(Int, 64); |
688 | case 1: |
689 | case 32: |
690 | break; |
691 | default: |
692 | llvm_unreachable("invalid bitwidth to getCapabilities" ); |
693 | } |
694 | } else { |
695 | assert(llvm::isa<FloatType>(*this)); |
696 | switch (bitwidth) { |
697 | WIDTH_CASE(Float, 16); |
698 | WIDTH_CASE(Float, 64); |
699 | case 32: |
700 | break; |
701 | default: |
702 | llvm_unreachable("invalid bitwidth to getCapabilities" ); |
703 | } |
704 | } |
705 | |
706 | #undef WIDTH_CASE |
707 | } |
708 | |
709 | std::optional<int64_t> ScalarType::getSizeInBytes() { |
710 | auto bitWidth = getIntOrFloatBitWidth(); |
711 | // According to the SPIR-V spec: |
712 | // "There is no physical size or bit pattern defined for values with boolean |
713 | // type. If they are stored (in conjunction with OpVariable), they can only |
714 | // be used with logical addressing operations, not physical, and only with |
715 | // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, |
716 | // Private, Function, Input, and Output." |
717 | if (bitWidth == 1) |
718 | return std::nullopt; |
719 | return bitWidth / 8; |
720 | } |
721 | |
722 | //===----------------------------------------------------------------------===// |
723 | // SPIRVType |
724 | //===----------------------------------------------------------------------===// |
725 | |
726 | bool SPIRVType::classof(Type type) { |
727 | // Allow SPIR-V dialect types |
728 | if (llvm::isa<SPIRVDialect>(type.getDialect())) |
729 | return true; |
730 | if (llvm::isa<ScalarType>(Val: type)) |
731 | return true; |
732 | if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
733 | return CompositeType::isValid(vectorType); |
734 | return false; |
735 | } |
736 | |
737 | bool SPIRVType::isScalarOrVector() { |
738 | return isIntOrFloat() || llvm::isa<VectorType>(*this); |
739 | } |
740 | |
741 | void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
742 | std::optional<StorageClass> storage) { |
743 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) { |
744 | scalarType.getExtensions(extensions, storage); |
745 | } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) { |
746 | compositeType.getExtensions(extensions, storage); |
747 | } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) { |
748 | imageType.getExtensions(extensions, storage); |
749 | } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) { |
750 | sampledImageType.getExtensions(extensions, storage); |
751 | } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) { |
752 | matrixType.getExtensions(extensions, storage); |
753 | } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) { |
754 | ptrType.getExtensions(extensions, storage); |
755 | } else { |
756 | llvm_unreachable("invalid SPIR-V Type to getExtensions" ); |
757 | } |
758 | } |
759 | |
760 | void SPIRVType::getCapabilities( |
761 | SPIRVType::CapabilityArrayRefVector &capabilities, |
762 | std::optional<StorageClass> storage) { |
763 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) { |
764 | scalarType.getCapabilities(capabilities, storage); |
765 | } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) { |
766 | compositeType.getCapabilities(capabilities, storage); |
767 | } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) { |
768 | imageType.getCapabilities(capabilities, storage); |
769 | } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) { |
770 | sampledImageType.getCapabilities(capabilities, storage); |
771 | } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) { |
772 | matrixType.getCapabilities(capabilities, storage); |
773 | } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) { |
774 | ptrType.getCapabilities(capabilities, storage); |
775 | } else { |
776 | llvm_unreachable("invalid SPIR-V Type to getCapabilities" ); |
777 | } |
778 | } |
779 | |
780 | std::optional<int64_t> SPIRVType::getSizeInBytes() { |
781 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) |
782 | return scalarType.getSizeInBytes(); |
783 | if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) |
784 | return compositeType.getSizeInBytes(); |
785 | return std::nullopt; |
786 | } |
787 | |
788 | //===----------------------------------------------------------------------===// |
789 | // SampledImageType |
790 | //===----------------------------------------------------------------------===// |
791 | struct spirv::detail::SampledImageTypeStorage : public TypeStorage { |
792 | using KeyTy = Type; |
793 | |
794 | SampledImageTypeStorage(const KeyTy &key) : imageType{key} {} |
795 | |
796 | bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); } |
797 | |
798 | static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator, |
799 | const KeyTy &key) { |
800 | return new (allocator.allocate<SampledImageTypeStorage>()) |
801 | SampledImageTypeStorage(key); |
802 | } |
803 | |
804 | Type imageType; |
805 | }; |
806 | |
807 | SampledImageType SampledImageType::get(Type imageType) { |
808 | return Base::get(ctx: imageType.getContext(), args&: imageType); |
809 | } |
810 | |
811 | SampledImageType |
812 | SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
813 | Type imageType) { |
814 | return Base::getChecked(emitErrorFn: emitError, ctx: imageType.getContext(), args: imageType); |
815 | } |
816 | |
817 | Type SampledImageType::getImageType() const { return getImpl()->imageType; } |
818 | |
819 | LogicalResult |
820 | SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError, |
821 | Type imageType) { |
822 | if (!llvm::isa<ImageType>(Val: imageType)) |
823 | return emitError() << "expected image type" ; |
824 | |
825 | return success(); |
826 | } |
827 | |
828 | void SampledImageType::getExtensions( |
829 | SPIRVType::ExtensionArrayRefVector &extensions, |
830 | std::optional<StorageClass> storage) { |
831 | llvm::cast<ImageType>(Val: getImageType()).getExtensions(extensions, storage); |
832 | } |
833 | |
834 | void SampledImageType::getCapabilities( |
835 | SPIRVType::CapabilityArrayRefVector &capabilities, |
836 | std::optional<StorageClass> storage) { |
837 | llvm::cast<ImageType>(Val: getImageType()).getCapabilities(capabilities, storage); |
838 | } |
839 | |
840 | //===----------------------------------------------------------------------===// |
841 | // StructType |
842 | //===----------------------------------------------------------------------===// |
843 | |
844 | /// Type storage for SPIR-V structure types: |
845 | /// |
846 | /// Structures are uniqued using: |
847 | /// - for identified structs: |
848 | /// - a string identifier; |
849 | /// - for literal structs: |
850 | /// - a list of member types; |
851 | /// - a list of member offset info; |
852 | /// - a list of member decoration info. |
853 | /// |
854 | /// Identified structures only have a mutable component consisting of: |
855 | /// - a list of member types; |
856 | /// - a list of member offset info; |
857 | /// - a list of member decoration info. |
858 | struct spirv::detail::StructTypeStorage : public TypeStorage { |
859 | /// Construct a storage object for an identified struct type. A struct type |
860 | /// associated with such storage must call StructType::trySetBody(...) later |
861 | /// in order to mutate the storage object providing the actual content. |
862 | StructTypeStorage(StringRef identifier) |
863 | : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), |
864 | numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), |
865 | identifier(identifier) {} |
866 | |
867 | /// Construct a storage object for a literal struct type. A struct type |
868 | /// associated with such storage is immutable. |
869 | StructTypeStorage( |
870 | unsigned numMembers, Type const *memberTypes, |
871 | StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, |
872 | StructType::MemberDecorationInfo const *memberDecorationsInfo) |
873 | : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), |
874 | numMembers(numMembers), numMemberDecorations(numMemberDecorations), |
875 | memberDecorationsInfo(memberDecorationsInfo) {} |
876 | |
877 | /// A storage key is divided into 2 parts: |
878 | /// - for identified structs: |
879 | /// - a StringRef representing the struct identifier; |
880 | /// - for literal structs: |
881 | /// - an ArrayRef<Type> for member types; |
882 | /// - an ArrayRef<StructType::OffsetInfo> for member offset info; |
883 | /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration |
884 | /// info. |
885 | /// |
886 | /// An identified struct type is uniqued only by the first part (field 0) |
887 | /// of the key. |
888 | /// |
889 | /// A literal struct type is uniqued only by the second part (fields 1, 2, and |
890 | /// 3) of the key. The identifier field (field 0) must be empty. |
891 | using KeyTy = |
892 | std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, |
893 | ArrayRef<StructType::MemberDecorationInfo>>; |
894 | |
895 | /// For identified structs, return true if the given key contains the same |
896 | /// identifier. |
897 | /// |
898 | /// For literal structs, return true if the given key contains a matching list |
899 | /// of member types + offset info + decoration info. |
900 | bool operator==(const KeyTy &key) const { |
901 | if (isIdentified()) { |
902 | // Identified types are uniqued by their identifier. |
903 | return getIdentifier() == std::get<0>(t: key); |
904 | } |
905 | |
906 | return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), |
907 | getMemberDecorationsInfo()); |
908 | } |
909 | |
910 | /// If the given key contains a non-empty identifier, this method constructs |
911 | /// an identified struct and leaves the rest of the struct type data to be set |
912 | /// through a later call to StructType::trySetBody(...). |
913 | /// |
914 | /// If, on the other hand, the key contains an empty identifier, a literal |
915 | /// struct is constructed using the other fields of the key. |
916 | static StructTypeStorage *construct(TypeStorageAllocator &allocator, |
917 | const KeyTy &key) { |
918 | StringRef keyIdentifier = std::get<0>(t: key); |
919 | |
920 | if (!keyIdentifier.empty()) { |
921 | StringRef identifier = allocator.copyInto(str: keyIdentifier); |
922 | |
923 | // Identified StructType body/members will be set through trySetBody(...) |
924 | // later. |
925 | return new (allocator.allocate<StructTypeStorage>()) |
926 | StructTypeStorage(identifier); |
927 | } |
928 | |
929 | ArrayRef<Type> keyTypes = std::get<1>(t: key); |
930 | |
931 | // Copy the member type and layout information into the bump pointer |
932 | const Type *typesList = nullptr; |
933 | if (!keyTypes.empty()) { |
934 | typesList = allocator.copyInto(elements: keyTypes).data(); |
935 | } |
936 | |
937 | const StructType::OffsetInfo *offsetInfoList = nullptr; |
938 | if (!std::get<2>(t: key).empty()) { |
939 | ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(t: key); |
940 | assert(keyOffsetInfo.size() == keyTypes.size() && |
941 | "size of offset information must be same as the size of number of " |
942 | "elements" ); |
943 | offsetInfoList = allocator.copyInto(elements: keyOffsetInfo).data(); |
944 | } |
945 | |
946 | const StructType::MemberDecorationInfo *memberDecorationList = nullptr; |
947 | unsigned numMemberDecorations = 0; |
948 | if (!std::get<3>(t: key).empty()) { |
949 | auto keyMemberDecorations = std::get<3>(t: key); |
950 | numMemberDecorations = keyMemberDecorations.size(); |
951 | memberDecorationList = allocator.copyInto(elements: keyMemberDecorations).data(); |
952 | } |
953 | |
954 | return new (allocator.allocate<StructTypeStorage>()) |
955 | StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, |
956 | numMemberDecorations, memberDecorationList); |
957 | } |
958 | |
959 | ArrayRef<Type> getMemberTypes() const { |
960 | return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers); |
961 | } |
962 | |
963 | ArrayRef<StructType::OffsetInfo> getOffsetInfo() const { |
964 | if (offsetInfo) { |
965 | return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers); |
966 | } |
967 | return {}; |
968 | } |
969 | |
970 | ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const { |
971 | if (memberDecorationsInfo) { |
972 | return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo, |
973 | numMemberDecorations); |
974 | } |
975 | return {}; |
976 | } |
977 | |
978 | StringRef getIdentifier() const { return identifier; } |
979 | |
980 | bool isIdentified() const { return !identifier.empty(); } |
981 | |
982 | /// Sets the struct type content for identified structs. Calling this method |
983 | /// is only valid for identified structs. |
984 | /// |
985 | /// Fails under the following conditions: |
986 | /// - If called for a literal struct; |
987 | /// - If called for an identified struct whose body was set before (through a |
988 | /// call to this method) but with different contents from the passed |
989 | /// arguments. |
990 | LogicalResult mutate( |
991 | TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, |
992 | ArrayRef<StructType::OffsetInfo> structOffsetInfo, |
993 | ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { |
994 | if (!isIdentified()) |
995 | return failure(); |
996 | |
997 | if (memberTypesAndIsBodySet.getInt() && |
998 | (getMemberTypes() != structMemberTypes || |
999 | getOffsetInfo() != structOffsetInfo || |
1000 | getMemberDecorationsInfo() != structMemberDecorationInfo)) |
1001 | return failure(); |
1002 | |
1003 | memberTypesAndIsBodySet.setInt(true); |
1004 | numMembers = structMemberTypes.size(); |
1005 | |
1006 | // Copy the member type and layout information into the bump pointer. |
1007 | if (!structMemberTypes.empty()) |
1008 | memberTypesAndIsBodySet.setPointer( |
1009 | allocator.copyInto(elements: structMemberTypes).data()); |
1010 | |
1011 | if (!structOffsetInfo.empty()) { |
1012 | assert(structOffsetInfo.size() == structMemberTypes.size() && |
1013 | "size of offset information must be same as the size of number of " |
1014 | "elements" ); |
1015 | offsetInfo = allocator.copyInto(elements: structOffsetInfo).data(); |
1016 | } |
1017 | |
1018 | if (!structMemberDecorationInfo.empty()) { |
1019 | numMemberDecorations = structMemberDecorationInfo.size(); |
1020 | memberDecorationsInfo = |
1021 | allocator.copyInto(elements: structMemberDecorationInfo).data(); |
1022 | } |
1023 | |
1024 | return success(); |
1025 | } |
1026 | |
1027 | llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet; |
1028 | StructType::OffsetInfo const *offsetInfo; |
1029 | unsigned numMembers; |
1030 | unsigned numMemberDecorations; |
1031 | StructType::MemberDecorationInfo const *memberDecorationsInfo; |
1032 | StringRef identifier; |
1033 | }; |
1034 | |
1035 | StructType |
1036 | StructType::get(ArrayRef<Type> memberTypes, |
1037 | ArrayRef<StructType::OffsetInfo> offsetInfo, |
1038 | ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { |
1039 | assert(!memberTypes.empty() && "Struct needs at least one member type" ); |
1040 | // Sort the decorations. |
1041 | SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( |
1042 | memberDecorations.begin(), memberDecorations.end()); |
1043 | llvm::array_pod_sort(Start: sortedDecorations.begin(), End: sortedDecorations.end()); |
1044 | return Base::get(ctx: memberTypes.vec().front().getContext(), |
1045 | /*identifier=*/args: StringRef(), args&: memberTypes, args&: offsetInfo, |
1046 | args&: sortedDecorations); |
1047 | } |
1048 | |
1049 | StructType StructType::getIdentified(MLIRContext *context, |
1050 | StringRef identifier) { |
1051 | assert(!identifier.empty() && |
1052 | "StructType identifier must be non-empty string" ); |
1053 | |
1054 | return Base::get(ctx: context, args&: identifier, args: ArrayRef<Type>(), |
1055 | args: ArrayRef<StructType::OffsetInfo>(), |
1056 | args: ArrayRef<StructType::MemberDecorationInfo>()); |
1057 | } |
1058 | |
1059 | StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { |
1060 | StructType newStructType = Base::get( |
1061 | ctx: context, args&: identifier, args: ArrayRef<Type>(), args: ArrayRef<StructType::OffsetInfo>(), |
1062 | args: ArrayRef<StructType::MemberDecorationInfo>()); |
1063 | // Set an empty body in case this is a identified struct. |
1064 | if (newStructType.isIdentified() && |
1065 | failed(result: newStructType.trySetBody( |
1066 | memberTypes: ArrayRef<Type>(), offsetInfo: ArrayRef<StructType::OffsetInfo>(), |
1067 | memberDecorations: ArrayRef<StructType::MemberDecorationInfo>()))) |
1068 | return StructType(); |
1069 | |
1070 | return newStructType; |
1071 | } |
1072 | |
1073 | StringRef StructType::getIdentifier() const { return getImpl()->identifier; } |
1074 | |
1075 | bool StructType::isIdentified() const { return getImpl()->isIdentified(); } |
1076 | |
1077 | unsigned StructType::getNumElements() const { return getImpl()->numMembers; } |
1078 | |
1079 | Type StructType::getElementType(unsigned index) const { |
1080 | assert(getNumElements() > index && "member index out of range" ); |
1081 | return getImpl()->memberTypesAndIsBodySet.getPointer()[index]; |
1082 | } |
1083 | |
1084 | TypeRange StructType::getElementTypes() const { |
1085 | return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(), |
1086 | getNumElements()); |
1087 | } |
1088 | |
1089 | bool StructType::hasOffset() const { return getImpl()->offsetInfo; } |
1090 | |
1091 | uint64_t StructType::getMemberOffset(unsigned index) const { |
1092 | assert(getNumElements() > index && "member index out of range" ); |
1093 | return getImpl()->offsetInfo[index]; |
1094 | } |
1095 | |
1096 | void StructType::getMemberDecorations( |
1097 | SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations) |
1098 | const { |
1099 | memberDecorations.clear(); |
1100 | auto implMemberDecorations = getImpl()->getMemberDecorationsInfo(); |
1101 | memberDecorations.append(in_start: implMemberDecorations.begin(), |
1102 | in_end: implMemberDecorations.end()); |
1103 | } |
1104 | |
1105 | void StructType::getMemberDecorations( |
1106 | unsigned index, |
1107 | SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const { |
1108 | assert(getNumElements() > index && "member index out of range" ); |
1109 | auto memberDecorations = getImpl()->getMemberDecorationsInfo(); |
1110 | decorationsInfo.clear(); |
1111 | for (const auto &memberDecoration : memberDecorations) { |
1112 | if (memberDecoration.memberIndex == index) { |
1113 | decorationsInfo.push_back(Elt: memberDecoration); |
1114 | } |
1115 | if (memberDecoration.memberIndex > index) { |
1116 | // Early exit since the decorations are stored sorted. |
1117 | return; |
1118 | } |
1119 | } |
1120 | } |
1121 | |
1122 | LogicalResult |
1123 | StructType::trySetBody(ArrayRef<Type> memberTypes, |
1124 | ArrayRef<OffsetInfo> offsetInfo, |
1125 | ArrayRef<MemberDecorationInfo> memberDecorations) { |
1126 | return Base::mutate(args&: memberTypes, args&: offsetInfo, args&: memberDecorations); |
1127 | } |
1128 | |
1129 | void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
1130 | std::optional<StorageClass> storage) { |
1131 | for (Type elementType : getElementTypes()) |
1132 | llvm::cast<SPIRVType>(Val&: elementType).getExtensions(extensions, storage); |
1133 | } |
1134 | |
1135 | void StructType::getCapabilities( |
1136 | SPIRVType::CapabilityArrayRefVector &capabilities, |
1137 | std::optional<StorageClass> storage) { |
1138 | for (Type elementType : getElementTypes()) |
1139 | llvm::cast<SPIRVType>(Val&: elementType).getCapabilities(capabilities, storage); |
1140 | } |
1141 | |
/// Hashes a struct member decoration entry by combining its member index and
/// its decoration kind. Presumably this is what makes MemberDecorationInfo
/// usable in the hashed StructType storage key; verify against the uniquer.
/// NOTE(review): any other fields of MemberDecorationInfo (declared in the
/// header, not visible here) are not hashed. That is still correct as long as
/// operator== compares a superset of these fields, but confirm it.
llvm::hash_code spirv::hash_value(
    const StructType::MemberDecorationInfo &memberDecorationInfo) {
  return llvm::hash_combine(memberDecorationInfo.memberIndex,
                            memberDecorationInfo.decoration);
}
1147 | |
1148 | //===----------------------------------------------------------------------===// |
1149 | // MatrixType |
1150 | //===----------------------------------------------------------------------===// |
1151 | |
1152 | struct spirv::detail::MatrixTypeStorage : public TypeStorage { |
1153 | MatrixTypeStorage(Type columnType, uint32_t columnCount) |
1154 | : columnType(columnType), columnCount(columnCount) {} |
1155 | |
1156 | using KeyTy = std::tuple<Type, uint32_t>; |
1157 | |
1158 | static MatrixTypeStorage *construct(TypeStorageAllocator &allocator, |
1159 | const KeyTy &key) { |
1160 | |
1161 | // Initialize the memory using placement new. |
1162 | return new (allocator.allocate<MatrixTypeStorage>()) |
1163 | MatrixTypeStorage(std::get<0>(t: key), std::get<1>(t: key)); |
1164 | } |
1165 | |
1166 | bool operator==(const KeyTy &key) const { |
1167 | return key == KeyTy(columnType, columnCount); |
1168 | } |
1169 | |
1170 | Type columnType; |
1171 | const uint32_t columnCount; |
1172 | }; |
1173 | |
1174 | MatrixType MatrixType::get(Type columnType, uint32_t columnCount) { |
1175 | return Base::get(ctx: columnType.getContext(), args&: columnType, args&: columnCount); |
1176 | } |
1177 | |
1178 | MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
1179 | Type columnType, uint32_t columnCount) { |
1180 | return Base::getChecked(emitErrorFn: emitError, ctx: columnType.getContext(), args: columnType, |
1181 | args: columnCount); |
1182 | } |
1183 | |
1184 | LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError, |
1185 | Type columnType, uint32_t columnCount) { |
1186 | if (columnCount < 2 || columnCount > 4) |
1187 | return emitError() << "matrix can have 2, 3, or 4 columns only" ; |
1188 | |
1189 | if (!isValidColumnType(columnType)) |
1190 | return emitError() << "matrix columns must be vectors of floats" ; |
1191 | |
1192 | /// The underlying vectors (columns) must be of size 2, 3, or 4 |
1193 | ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape(); |
1194 | if (columnShape.size() != 1) |
1195 | return emitError() << "matrix columns must be 1D vectors" ; |
1196 | |
1197 | if (columnShape[0] < 2 || columnShape[0] > 4) |
1198 | return emitError() << "matrix columns must be of size 2, 3, or 4" ; |
1199 | |
1200 | return success(); |
1201 | } |
1202 | |
1203 | /// Returns true if the matrix elements are vectors of float elements |
1204 | bool MatrixType::isValidColumnType(Type columnType) { |
1205 | if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) { |
1206 | if (llvm::isa<FloatType>(vectorType.getElementType())) |
1207 | return true; |
1208 | } |
1209 | return false; |
1210 | } |
1211 | |
1212 | Type MatrixType::getColumnType() const { return getImpl()->columnType; } |
1213 | |
1214 | Type MatrixType::getElementType() const { |
1215 | return llvm::cast<VectorType>(getImpl()->columnType).getElementType(); |
1216 | } |
1217 | |
1218 | unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } |
1219 | |
1220 | unsigned MatrixType::getNumRows() const { |
1221 | return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0]; |
1222 | } |
1223 | |
1224 | unsigned MatrixType::getNumElements() const { |
1225 | return (getImpl()->columnCount) * getNumRows(); |
1226 | } |
1227 | |
1228 | void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
1229 | std::optional<StorageClass> storage) { |
1230 | llvm::cast<SPIRVType>(Val: getColumnType()).getExtensions(extensions, storage); |
1231 | } |
1232 | |
1233 | void MatrixType::getCapabilities( |
1234 | SPIRVType::CapabilityArrayRefVector &capabilities, |
1235 | std::optional<StorageClass> storage) { |
1236 | { |
1237 | static const Capability caps[] = {Capability::Matrix}; |
1238 | ArrayRef<Capability> ref(caps, std::size(caps)); |
1239 | capabilities.push_back(Elt: ref); |
1240 | } |
1241 | // Add any capabilities associated with the underlying vectors (i.e., columns) |
1242 | llvm::cast<SPIRVType>(Val: getColumnType()).getCapabilities(capabilities, storage); |
1243 | } |
1244 | |
1245 | //===----------------------------------------------------------------------===// |
1246 | // SPIR-V Dialect |
1247 | //===----------------------------------------------------------------------===// |
1248 | |
1249 | void SPIRVDialect::registerTypes() { |
1250 | addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType, |
1251 | MatrixType, PointerType, RuntimeArrayType, SampledImageType, |
1252 | StructType>(); |
1253 | } |
1254 | |