1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20
21#include <cstdint>
22#include <iterator>
23
24using namespace mlir;
25using namespace mlir::spirv;
26
27//===----------------------------------------------------------------------===//
28// ArrayType
29//===----------------------------------------------------------------------===//
30
31struct spirv::detail::ArrayTypeStorage : public TypeStorage {
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
33
34 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
35 const KeyTy &key) {
36 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37 }
38
39 bool operator==(const KeyTy &key) const {
40 return key == KeyTy(elementType, elementCount, stride);
41 }
42
43 ArrayTypeStorage(const KeyTy &key)
44 : elementType(std::get<0>(t: key)), elementCount(std::get<1>(t: key)),
45 stride(std::get<2>(t: key)) {}
46
47 Type elementType;
48 unsigned elementCount;
49 unsigned stride;
50};
51
52ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53 assert(elementCount && "ArrayType needs at least one element");
54 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount,
55 /*stride=*/args: 0);
56}
57
58ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59 unsigned stride) {
60 assert(elementCount && "ArrayType needs at least one element");
61 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, args&: stride);
62}
63
64unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65
66Type ArrayType::getElementType() const { return getImpl()->elementType; }
67
68unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69
70void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
73}
74
75void ArrayType::getCapabilities(
76 SPIRVType::CapabilityArrayRefVector &capabilities,
77 std::optional<StorageClass> storage) {
78 llvm::cast<SPIRVType>(Val: getElementType())
79 .getCapabilities(capabilities, storage);
80}
81
82std::optional<int64_t> ArrayType::getSizeInBytes() {
83 auto elementType = llvm::cast<SPIRVType>(Val: getElementType());
84 std::optional<int64_t> size = elementType.getSizeInBytes();
85 if (!size)
86 return std::nullopt;
87 return (*size + getArrayStride()) * getNumElements();
88}
89
90//===----------------------------------------------------------------------===//
91// CompositeType
92//===----------------------------------------------------------------------===//
93
94bool CompositeType::classof(Type type) {
95 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
96 return isValid(vectorType);
97 return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
98 spirv::JointMatrixINTELType, spirv::MatrixType,
99 spirv::RuntimeArrayType, spirv::StructType>(Val: type);
100}
101
102bool CompositeType::isValid(VectorType type) {
103 return type.getRank() == 1 &&
104 llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
105 llvm::isa<ScalarType>(type.getElementType());
106}
107
108Type CompositeType::getElementType(unsigned index) const {
109 return TypeSwitch<Type, Type>(*this)
110 .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
111 RuntimeArrayType, VectorType>(
112 [](auto type) { return type.getElementType(); })
113 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
114 .Case<StructType>(
115 [index](StructType type) { return type.getElementType(index); })
116 .Default(
117 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
118}
119
120unsigned CompositeType::getNumElements() const {
121 if (auto arrayType = llvm::dyn_cast<ArrayType>(Val: *this))
122 return arrayType.getNumElements();
123 if (auto matrixType = llvm::dyn_cast<MatrixType>(Val: *this))
124 return matrixType.getNumColumns();
125 if (auto structType = llvm::dyn_cast<StructType>(Val: *this))
126 return structType.getNumElements();
127 if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
128 return vectorType.getNumElements();
129 if (llvm::isa<CooperativeMatrixType>(Val: *this)) {
130 llvm_unreachable(
131 "invalid to query number of elements of spirv Cooperative Matrix type");
132 }
133 if (llvm::isa<JointMatrixINTELType>(Val: *this)) {
134 llvm_unreachable(
135 "invalid to query number of elements of spirv::JointMatrix type");
136 }
137 if (llvm::isa<RuntimeArrayType>(Val: *this)) {
138 llvm_unreachable(
139 "invalid to query number of elements of spirv::RuntimeArray type");
140 }
141 llvm_unreachable("invalid composite type");
142}
143
144bool CompositeType::hasCompileTimeKnownNumElements() const {
145 return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
146 RuntimeArrayType>(Val: *this);
147}
148
149void CompositeType::getExtensions(
150 SPIRVType::ExtensionArrayRefVector &extensions,
151 std::optional<StorageClass> storage) {
152 TypeSwitch<Type>(*this)
153 .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
154 RuntimeArrayType, StructType>(
155 [&](auto type) { type.getExtensions(extensions, storage); })
156 .Case<VectorType>([&](VectorType type) {
157 return llvm::cast<ScalarType>(type.getElementType())
158 .getExtensions(extensions, storage);
159 })
160 .Default([](Type) { llvm_unreachable("invalid composite type"); });
161}
162
163void CompositeType::getCapabilities(
164 SPIRVType::CapabilityArrayRefVector &capabilities,
165 std::optional<StorageClass> storage) {
166 TypeSwitch<Type>(*this)
167 .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
168 RuntimeArrayType, StructType>(
169 [&](auto type) { type.getCapabilities(capabilities, storage); })
170 .Case<VectorType>([&](VectorType type) {
171 auto vecSize = getNumElements();
172 if (vecSize == 8 || vecSize == 16) {
173 static const Capability caps[] = {Capability::Vector16};
174 ArrayRef<Capability> ref(caps, std::size(caps));
175 capabilities.push_back(ref);
176 }
177 return llvm::cast<ScalarType>(type.getElementType())
178 .getCapabilities(capabilities, storage);
179 })
180 .Default([](Type) { llvm_unreachable("invalid composite type"); });
181}
182
183std::optional<int64_t> CompositeType::getSizeInBytes() {
184 if (auto arrayType = llvm::dyn_cast<ArrayType>(Val&: *this))
185 return arrayType.getSizeInBytes();
186 if (auto structType = llvm::dyn_cast<StructType>(Val&: *this))
187 return structType.getSizeInBytes();
188 if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
189 std::optional<int64_t> elementSize =
190 llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
191 if (!elementSize)
192 return std::nullopt;
193 return *elementSize * vectorType.getNumElements();
194 }
195 return std::nullopt;
196}
197
198//===----------------------------------------------------------------------===//
199// CooperativeMatrixType
200//===----------------------------------------------------------------------===//
201
202struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
203 using KeyTy =
204 std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
205
206 static CooperativeMatrixTypeStorage *
207 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
208 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
209 CooperativeMatrixTypeStorage(key);
210 }
211
212 bool operator==(const KeyTy &key) const {
213 return key == KeyTy(elementType, rows, columns, scope, use);
214 }
215
216 CooperativeMatrixTypeStorage(const KeyTy &key)
217 : elementType(std::get<0>(key)), rows(std::get<1>(key)),
218 columns(std::get<2>(key)), scope(std::get<3>(key)),
219 use(std::get<4>(key)) {}
220
221 Type elementType;
222 uint32_t rows;
223 uint32_t columns;
224 Scope scope;
225 CooperativeMatrixUseKHR use;
226};
227
228CooperativeMatrixType CooperativeMatrixType::get(Type elementType,
229 uint32_t rows,
230 uint32_t columns, Scope scope,
231 CooperativeMatrixUseKHR use) {
232 return Base::get(elementType.getContext(), elementType, rows, columns, scope,
233 use);
234}
235
236Type CooperativeMatrixType::getElementType() const {
237 return getImpl()->elementType;
238}
239
240uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
241
242uint32_t CooperativeMatrixType::getColumns() const {
243 return getImpl()->columns;
244}
245
246Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
247
248CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
249 return getImpl()->use;
250}
251
252void CooperativeMatrixType::getExtensions(
253 SPIRVType::ExtensionArrayRefVector &extensions,
254 std::optional<StorageClass> storage) {
255 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
256 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
257 extensions.push_back(exts);
258}
259
260void CooperativeMatrixType::getCapabilities(
261 SPIRVType::CapabilityArrayRefVector &capabilities,
262 std::optional<StorageClass> storage) {
263 llvm::cast<SPIRVType>(Val: getElementType())
264 .getCapabilities(capabilities, storage);
265 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
266 capabilities.push_back(caps);
267}
268
269//===----------------------------------------------------------------------===//
270// JointMatrixType
271//===----------------------------------------------------------------------===//
272
273struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
274 using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
275
276 static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
277 const KeyTy &key) {
278 return new (allocator.allocate<JointMatrixTypeStorage>())
279 JointMatrixTypeStorage(key);
280 }
281
282 bool operator==(const KeyTy &key) const {
283 return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
284 }
285
286 JointMatrixTypeStorage(const KeyTy &key)
287 : elementType(std::get<0>(key)), rows(std::get<1>(key)),
288 columns(std::get<2>(key)), scope(std::get<4>(key)),
289 matrixLayout(std::get<3>(key)) {}
290
291 Type elementType;
292 unsigned rows;
293 unsigned columns;
294 Scope scope;
295 MatrixLayout matrixLayout;
296};
297
298JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
299 unsigned rows, unsigned columns,
300 MatrixLayout matrixLayout) {
301 return Base::get(elementType.getContext(), elementType, rows, columns,
302 matrixLayout, scope);
303}
304
305Type JointMatrixINTELType::getElementType() const {
306 return getImpl()->elementType;
307}
308
309Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
310
311unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
312
313unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
314
315MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
316 return getImpl()->matrixLayout;
317}
318
319void JointMatrixINTELType::getExtensions(
320 SPIRVType::ExtensionArrayRefVector &extensions,
321 std::optional<StorageClass> storage) {
322 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
323 static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
324 ArrayRef<Extension> ref(exts, std::size(exts));
325 extensions.push_back(Elt: ref);
326}
327
328void JointMatrixINTELType::getCapabilities(
329 SPIRVType::CapabilityArrayRefVector &capabilities,
330 std::optional<StorageClass> storage) {
331 llvm::cast<SPIRVType>(Val: getElementType())
332 .getCapabilities(capabilities, storage);
333 static const Capability caps[] = {Capability::JointMatrixINTEL};
334 ArrayRef<Capability> ref(caps, std::size(caps));
335 capabilities.push_back(Elt: ref);
336}
337
338//===----------------------------------------------------------------------===//
339// ImageType
340//===----------------------------------------------------------------------===//
341
342template <typename T>
343static constexpr unsigned getNumBits() {
344 return 0;
345}
346template <>
347constexpr unsigned getNumBits<Dim>() {
348 static_assert((1 << 3) > getMaxEnumValForDim(),
349 "Not enough bits to encode Dim value");
350 return 3;
351}
352template <>
353constexpr unsigned getNumBits<ImageDepthInfo>() {
354 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
355 "Not enough bits to encode ImageDepthInfo value");
356 return 2;
357}
358template <>
359constexpr unsigned getNumBits<ImageArrayedInfo>() {
360 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
361 "Not enough bits to encode ImageArrayedInfo value");
362 return 1;
363}
364template <>
365constexpr unsigned getNumBits<ImageSamplingInfo>() {
366 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
367 "Not enough bits to encode ImageSamplingInfo value");
368 return 1;
369}
370template <>
371constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
372 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
373 "Not enough bits to encode ImageSamplerUseInfo value");
374 return 2;
375}
376template <>
377constexpr unsigned getNumBits<ImageFormat>() {
378 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
379 "Not enough bits to encode ImageFormat value");
380 return 6;
381}
382
383struct spirv::detail::ImageTypeStorage : public TypeStorage {
384public:
385 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
386 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
387
388 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
389 const KeyTy &key) {
390 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
391 }
392
393 bool operator==(const KeyTy &key) const {
394 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
395 samplerUseInfo, format);
396 }
397
398 ImageTypeStorage(const KeyTy &key)
399 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
400 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
401 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
402 format(std::get<6>(key)) {}
403
404 Type elementType;
405 Dim dim : getNumBits<Dim>();
406 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
407 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
408 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
409 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
410 ImageFormat format : getNumBits<ImageFormat>();
411};
412
413ImageType
414ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
415 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
416 value) {
417 return Base::get(std::get<0>(value).getContext(), value);
418}
419
420Type ImageType::getElementType() const { return getImpl()->elementType; }
421
422Dim ImageType::getDim() const { return getImpl()->dim; }
423
424ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
425
426ImageArrayedInfo ImageType::getArrayedInfo() const {
427 return getImpl()->arrayedInfo;
428}
429
430ImageSamplingInfo ImageType::getSamplingInfo() const {
431 return getImpl()->samplingInfo;
432}
433
434ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
435 return getImpl()->samplerUseInfo;
436}
437
438ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
439
440void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
441 std::optional<StorageClass>) {
442 // Image types do not require extra extensions thus far.
443}
444
445void ImageType::getCapabilities(
446 SPIRVType::CapabilityArrayRefVector &capabilities,
447 std::optional<StorageClass>) {
448 if (auto dimCaps = spirv::getCapabilities(getDim()))
449 capabilities.push_back(Elt: *dimCaps);
450
451 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
452 capabilities.push_back(Elt: *fmtCaps);
453}
454
455//===----------------------------------------------------------------------===//
456// PointerType
457//===----------------------------------------------------------------------===//
458
459struct spirv::detail::PointerTypeStorage : public TypeStorage {
460 // (Type, StorageClass) as the key: Type stored in this struct, and
461 // StorageClass stored as TypeStorage's subclass data.
462 using KeyTy = std::pair<Type, StorageClass>;
463
464 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
465 const KeyTy &key) {
466 return new (allocator.allocate<PointerTypeStorage>())
467 PointerTypeStorage(key);
468 }
469
470 bool operator==(const KeyTy &key) const {
471 return key == KeyTy(pointeeType, storageClass);
472 }
473
474 PointerTypeStorage(const KeyTy &key)
475 : pointeeType(key.first), storageClass(key.second) {}
476
477 Type pointeeType;
478 StorageClass storageClass;
479};
480
481PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
482 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
483}
484
485Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
486
487StorageClass PointerType::getStorageClass() const {
488 return getImpl()->storageClass;
489}
490
491void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
492 std::optional<StorageClass> storage) {
493 // Use this pointer type's storage class because this pointer indicates we are
494 // using the pointee type in that specific storage class.
495 llvm::cast<SPIRVType>(getPointeeType())
496 .getExtensions(extensions, getStorageClass());
497
498 if (auto scExts = spirv::getExtensions(getStorageClass()))
499 extensions.push_back(Elt: *scExts);
500}
501
502void PointerType::getCapabilities(
503 SPIRVType::CapabilityArrayRefVector &capabilities,
504 std::optional<StorageClass> storage) {
505 // Use this pointer type's storage class because this pointer indicates we are
506 // using the pointee type in that specific storage class.
507 llvm::cast<SPIRVType>(getPointeeType())
508 .getCapabilities(capabilities, getStorageClass());
509
510 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
511 capabilities.push_back(Elt: *scCaps);
512}
513
514//===----------------------------------------------------------------------===//
515// RuntimeArrayType
516//===----------------------------------------------------------------------===//
517
518struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
519 using KeyTy = std::pair<Type, unsigned>;
520
521 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
522 const KeyTy &key) {
523 return new (allocator.allocate<RuntimeArrayTypeStorage>())
524 RuntimeArrayTypeStorage(key);
525 }
526
527 bool operator==(const KeyTy &key) const {
528 return key == KeyTy(elementType, stride);
529 }
530
531 RuntimeArrayTypeStorage(const KeyTy &key)
532 : elementType(key.first), stride(key.second) {}
533
534 Type elementType;
535 unsigned stride;
536};
537
538RuntimeArrayType RuntimeArrayType::get(Type elementType) {
539 return Base::get(ctx: elementType.getContext(), args&: elementType, /*stride=*/args: 0);
540}
541
542RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
543 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: stride);
544}
545
546Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
547
548unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
549
550void RuntimeArrayType::getExtensions(
551 SPIRVType::ExtensionArrayRefVector &extensions,
552 std::optional<StorageClass> storage) {
553 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
554}
555
556void RuntimeArrayType::getCapabilities(
557 SPIRVType::CapabilityArrayRefVector &capabilities,
558 std::optional<StorageClass> storage) {
559 {
560 static const Capability caps[] = {Capability::Shader};
561 ArrayRef<Capability> ref(caps, std::size(caps));
562 capabilities.push_back(Elt: ref);
563 }
564 llvm::cast<SPIRVType>(Val: getElementType())
565 .getCapabilities(capabilities, storage);
566}
567
568//===----------------------------------------------------------------------===//
569// ScalarType
570//===----------------------------------------------------------------------===//
571
572bool ScalarType::classof(Type type) {
573 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
574 return isValid(floatType);
575 }
576 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
577 return isValid(intType);
578 }
579 return false;
580}
581
582bool ScalarType::isValid(FloatType type) {
583 return llvm::is_contained(Set: {16u, 32u, 64u}, Element: type.getWidth()) && !type.isBF16();
584}
585
586bool ScalarType::isValid(IntegerType type) {
587 return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
588}
589
590void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
591 std::optional<StorageClass> storage) {
592 // 8- or 16-bit integer/floating-point numbers will require extra extensions
593 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
594 // SPV_KHR_8bit_storage for more details.
595 if (!storage)
596 return;
597
598 switch (*storage) {
599 case StorageClass::PushConstant:
600 case StorageClass::StorageBuffer:
601 case StorageClass::Uniform:
602 if (getIntOrFloatBitWidth() == 8) {
603 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
604 ArrayRef<Extension> ref(exts, std::size(exts));
605 extensions.push_back(Elt: ref);
606 }
607 [[fallthrough]];
608 case StorageClass::Input:
609 case StorageClass::Output:
610 if (getIntOrFloatBitWidth() == 16) {
611 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
612 ArrayRef<Extension> ref(exts, std::size(exts));
613 extensions.push_back(Elt: ref);
614 }
615 break;
616 default:
617 break;
618 }
619}
620
621void ScalarType::getCapabilities(
622 SPIRVType::CapabilityArrayRefVector &capabilities,
623 std::optional<StorageClass> storage) {
624 unsigned bitwidth = getIntOrFloatBitWidth();
625
626 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
627 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
628 // SPV_KHR_8bit_storage for more details.
629
630#define STORAGE_CASE(storage, cap8, cap16) \
631 case StorageClass::storage: { \
632 if (bitwidth == 8) { \
633 static const Capability caps[] = {Capability::cap8}; \
634 ArrayRef<Capability> ref(caps, std::size(caps)); \
635 capabilities.push_back(ref); \
636 return; \
637 } \
638 if (bitwidth == 16) { \
639 static const Capability caps[] = {Capability::cap16}; \
640 ArrayRef<Capability> ref(caps, std::size(caps)); \
641 capabilities.push_back(ref); \
642 return; \
643 } \
644 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
645 /* storage classes. Fall through to the next section. */ \
646 } break
647
648 // This part only handles the cases where special bitwidths appearing in
649 // interface storage classes.
650 if (storage) {
651 switch (*storage) {
652 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
653 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
654 StorageBuffer16BitAccess);
655 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
656 StorageUniform16);
657 case StorageClass::Input:
658 case StorageClass::Output: {
659 if (bitwidth == 16) {
660 static const Capability caps[] = {Capability::StorageInputOutput16};
661 ArrayRef<Capability> ref(caps, std::size(caps));
662 capabilities.push_back(Elt: ref);
663 return;
664 }
665 break;
666 }
667 default:
668 break;
669 }
670 }
671#undef STORAGE_CASE
672
673 // For other non-interface storage classes, require a different set of
674 // capabilities for special bitwidths.
675
676#define WIDTH_CASE(type, width) \
677 case width: { \
678 static const Capability caps[] = {Capability::type##width}; \
679 ArrayRef<Capability> ref(caps, std::size(caps)); \
680 capabilities.push_back(ref); \
681 } break
682
683 if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
684 switch (bitwidth) {
685 WIDTH_CASE(Int, 8);
686 WIDTH_CASE(Int, 16);
687 WIDTH_CASE(Int, 64);
688 case 1:
689 case 32:
690 break;
691 default:
692 llvm_unreachable("invalid bitwidth to getCapabilities");
693 }
694 } else {
695 assert(llvm::isa<FloatType>(*this));
696 switch (bitwidth) {
697 WIDTH_CASE(Float, 16);
698 WIDTH_CASE(Float, 64);
699 case 32:
700 break;
701 default:
702 llvm_unreachable("invalid bitwidth to getCapabilities");
703 }
704 }
705
706#undef WIDTH_CASE
707}
708
709std::optional<int64_t> ScalarType::getSizeInBytes() {
710 auto bitWidth = getIntOrFloatBitWidth();
711 // According to the SPIR-V spec:
712 // "There is no physical size or bit pattern defined for values with boolean
713 // type. If they are stored (in conjunction with OpVariable), they can only
714 // be used with logical addressing operations, not physical, and only with
715 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
716 // Private, Function, Input, and Output."
717 if (bitWidth == 1)
718 return std::nullopt;
719 return bitWidth / 8;
720}
721
722//===----------------------------------------------------------------------===//
723// SPIRVType
724//===----------------------------------------------------------------------===//
725
726bool SPIRVType::classof(Type type) {
727 // Allow SPIR-V dialect types
728 if (llvm::isa<SPIRVDialect>(type.getDialect()))
729 return true;
730 if (llvm::isa<ScalarType>(Val: type))
731 return true;
732 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
733 return CompositeType::isValid(vectorType);
734 return false;
735}
736
737bool SPIRVType::isScalarOrVector() {
738 return isIntOrFloat() || llvm::isa<VectorType>(*this);
739}
740
741void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
742 std::optional<StorageClass> storage) {
743 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) {
744 scalarType.getExtensions(extensions, storage);
745 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) {
746 compositeType.getExtensions(extensions, storage);
747 } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) {
748 imageType.getExtensions(extensions, storage);
749 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) {
750 sampledImageType.getExtensions(extensions, storage);
751 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) {
752 matrixType.getExtensions(extensions, storage);
753 } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) {
754 ptrType.getExtensions(extensions, storage);
755 } else {
756 llvm_unreachable("invalid SPIR-V Type to getExtensions");
757 }
758}
759
760void SPIRVType::getCapabilities(
761 SPIRVType::CapabilityArrayRefVector &capabilities,
762 std::optional<StorageClass> storage) {
763 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) {
764 scalarType.getCapabilities(capabilities, storage);
765 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) {
766 compositeType.getCapabilities(capabilities, storage);
767 } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) {
768 imageType.getCapabilities(capabilities, storage);
769 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) {
770 sampledImageType.getCapabilities(capabilities, storage);
771 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) {
772 matrixType.getCapabilities(capabilities, storage);
773 } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) {
774 ptrType.getCapabilities(capabilities, storage);
775 } else {
776 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
777 }
778}
779
780std::optional<int64_t> SPIRVType::getSizeInBytes() {
781 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this))
782 return scalarType.getSizeInBytes();
783 if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this))
784 return compositeType.getSizeInBytes();
785 return std::nullopt;
786}
787
788//===----------------------------------------------------------------------===//
789// SampledImageType
790//===----------------------------------------------------------------------===//
791struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
792 using KeyTy = Type;
793
794 SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
795
796 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
797
798 static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
799 const KeyTy &key) {
800 return new (allocator.allocate<SampledImageTypeStorage>())
801 SampledImageTypeStorage(key);
802 }
803
804 Type imageType;
805};
806
807SampledImageType SampledImageType::get(Type imageType) {
808 return Base::get(ctx: imageType.getContext(), args&: imageType);
809}
810
811SampledImageType
812SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
813 Type imageType) {
814 return Base::getChecked(emitErrorFn: emitError, ctx: imageType.getContext(), args: imageType);
815}
816
817Type SampledImageType::getImageType() const { return getImpl()->imageType; }
818
819LogicalResult
820SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
821 Type imageType) {
822 if (!llvm::isa<ImageType>(Val: imageType))
823 return emitError() << "expected image type";
824
825 return success();
826}
827
828void SampledImageType::getExtensions(
829 SPIRVType::ExtensionArrayRefVector &extensions,
830 std::optional<StorageClass> storage) {
831 llvm::cast<ImageType>(Val: getImageType()).getExtensions(extensions, storage);
832}
833
834void SampledImageType::getCapabilities(
835 SPIRVType::CapabilityArrayRefVector &capabilities,
836 std::optional<StorageClass> storage) {
837 llvm::cast<ImageType>(Val: getImageType()).getCapabilities(capabilities, storage);
838}
839
840//===----------------------------------------------------------------------===//
841// StructType
842//===----------------------------------------------------------------------===//
843
844/// Type storage for SPIR-V structure types:
845///
846/// Structures are uniqued using:
847/// - for identified structs:
848/// - a string identifier;
849/// - for literal structs:
850/// - a list of member types;
851/// - a list of member offset info;
852/// - a list of member decoration info.
853///
854/// Identified structures only have a mutable component consisting of:
855/// - a list of member types;
856/// - a list of member offset info;
857/// - a list of member decoration info.
858struct spirv::detail::StructTypeStorage : public TypeStorage {
859 /// Construct a storage object for an identified struct type. A struct type
860 /// associated with such storage must call StructType::trySetBody(...) later
861 /// in order to mutate the storage object providing the actual content.
862 StructTypeStorage(StringRef identifier)
863 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
864 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
865 identifier(identifier) {}
866
867 /// Construct a storage object for a literal struct type. A struct type
868 /// associated with such storage is immutable.
869 StructTypeStorage(
870 unsigned numMembers, Type const *memberTypes,
871 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
872 StructType::MemberDecorationInfo const *memberDecorationsInfo)
873 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
874 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
875 memberDecorationsInfo(memberDecorationsInfo) {}
876
877 /// A storage key is divided into 2 parts:
878 /// - for identified structs:
879 /// - a StringRef representing the struct identifier;
880 /// - for literal structs:
881 /// - an ArrayRef<Type> for member types;
882 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
883 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
884 /// info.
885 ///
886 /// An identified struct type is uniqued only by the first part (field 0)
887 /// of the key.
888 ///
889 /// A literal struct type is uniqued only by the second part (fields 1, 2, and
890 /// 3) of the key. The identifier field (field 0) must be empty.
891 using KeyTy =
892 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
893 ArrayRef<StructType::MemberDecorationInfo>>;
894
895 /// For identified structs, return true if the given key contains the same
896 /// identifier.
897 ///
898 /// For literal structs, return true if the given key contains a matching list
899 /// of member types + offset info + decoration info.
900 bool operator==(const KeyTy &key) const {
901 if (isIdentified()) {
902 // Identified types are uniqued by their identifier.
903 return getIdentifier() == std::get<0>(t: key);
904 }
905
906 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
907 getMemberDecorationsInfo());
908 }
909
910 /// If the given key contains a non-empty identifier, this method constructs
911 /// an identified struct and leaves the rest of the struct type data to be set
912 /// through a later call to StructType::trySetBody(...).
913 ///
914 /// If, on the other hand, the key contains an empty identifier, a literal
915 /// struct is constructed using the other fields of the key.
916 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
917 const KeyTy &key) {
918 StringRef keyIdentifier = std::get<0>(t: key);
919
920 if (!keyIdentifier.empty()) {
921 StringRef identifier = allocator.copyInto(str: keyIdentifier);
922
923 // Identified StructType body/members will be set through trySetBody(...)
924 // later.
925 return new (allocator.allocate<StructTypeStorage>())
926 StructTypeStorage(identifier);
927 }
928
929 ArrayRef<Type> keyTypes = std::get<1>(t: key);
930
931 // Copy the member type and layout information into the bump pointer
932 const Type *typesList = nullptr;
933 if (!keyTypes.empty()) {
934 typesList = allocator.copyInto(elements: keyTypes).data();
935 }
936
937 const StructType::OffsetInfo *offsetInfoList = nullptr;
938 if (!std::get<2>(t: key).empty()) {
939 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(t: key);
940 assert(keyOffsetInfo.size() == keyTypes.size() &&
941 "size of offset information must be same as the size of number of "
942 "elements");
943 offsetInfoList = allocator.copyInto(elements: keyOffsetInfo).data();
944 }
945
946 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
947 unsigned numMemberDecorations = 0;
948 if (!std::get<3>(t: key).empty()) {
949 auto keyMemberDecorations = std::get<3>(t: key);
950 numMemberDecorations = keyMemberDecorations.size();
951 memberDecorationList = allocator.copyInto(elements: keyMemberDecorations).data();
952 }
953
954 return new (allocator.allocate<StructTypeStorage>())
955 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
956 numMemberDecorations, memberDecorationList);
957 }
958
959 ArrayRef<Type> getMemberTypes() const {
960 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
961 }
962
963 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
964 if (offsetInfo) {
965 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
966 }
967 return {};
968 }
969
970 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
971 if (memberDecorationsInfo) {
972 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
973 numMemberDecorations);
974 }
975 return {};
976 }
977
978 StringRef getIdentifier() const { return identifier; }
979
980 bool isIdentified() const { return !identifier.empty(); }
981
982 /// Sets the struct type content for identified structs. Calling this method
983 /// is only valid for identified structs.
984 ///
985 /// Fails under the following conditions:
986 /// - If called for a literal struct;
987 /// - If called for an identified struct whose body was set before (through a
988 /// call to this method) but with different contents from the passed
989 /// arguments.
990 LogicalResult mutate(
991 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
992 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
993 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
994 if (!isIdentified())
995 return failure();
996
997 if (memberTypesAndIsBodySet.getInt() &&
998 (getMemberTypes() != structMemberTypes ||
999 getOffsetInfo() != structOffsetInfo ||
1000 getMemberDecorationsInfo() != structMemberDecorationInfo))
1001 return failure();
1002
1003 memberTypesAndIsBodySet.setInt(true);
1004 numMembers = structMemberTypes.size();
1005
1006 // Copy the member type and layout information into the bump pointer.
1007 if (!structMemberTypes.empty())
1008 memberTypesAndIsBodySet.setPointer(
1009 allocator.copyInto(elements: structMemberTypes).data());
1010
1011 if (!structOffsetInfo.empty()) {
1012 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1013 "size of offset information must be same as the size of number of "
1014 "elements");
1015 offsetInfo = allocator.copyInto(elements: structOffsetInfo).data();
1016 }
1017
1018 if (!structMemberDecorationInfo.empty()) {
1019 numMemberDecorations = structMemberDecorationInfo.size();
1020 memberDecorationsInfo =
1021 allocator.copyInto(elements: structMemberDecorationInfo).data();
1022 }
1023
1024 return success();
1025 }
1026
1027 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1028 StructType::OffsetInfo const *offsetInfo;
1029 unsigned numMembers;
1030 unsigned numMemberDecorations;
1031 StructType::MemberDecorationInfo const *memberDecorationsInfo;
1032 StringRef identifier;
1033};
1034
1035StructType
1036StructType::get(ArrayRef<Type> memberTypes,
1037 ArrayRef<StructType::OffsetInfo> offsetInfo,
1038 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1039 assert(!memberTypes.empty() && "Struct needs at least one member type");
1040 // Sort the decorations.
1041 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
1042 memberDecorations.begin(), memberDecorations.end());
1043 llvm::array_pod_sort(Start: sortedDecorations.begin(), End: sortedDecorations.end());
1044 return Base::get(ctx: memberTypes.vec().front().getContext(),
1045 /*identifier=*/args: StringRef(), args&: memberTypes, args&: offsetInfo,
1046 args&: sortedDecorations);
1047}
1048
1049StructType StructType::getIdentified(MLIRContext *context,
1050 StringRef identifier) {
1051 assert(!identifier.empty() &&
1052 "StructType identifier must be non-empty string");
1053
1054 return Base::get(ctx: context, args&: identifier, args: ArrayRef<Type>(),
1055 args: ArrayRef<StructType::OffsetInfo>(),
1056 args: ArrayRef<StructType::MemberDecorationInfo>());
1057}
1058
1059StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1060 StructType newStructType = Base::get(
1061 ctx: context, args&: identifier, args: ArrayRef<Type>(), args: ArrayRef<StructType::OffsetInfo>(),
1062 args: ArrayRef<StructType::MemberDecorationInfo>());
1063 // Set an empty body in case this is a identified struct.
1064 if (newStructType.isIdentified() &&
1065 failed(result: newStructType.trySetBody(
1066 memberTypes: ArrayRef<Type>(), offsetInfo: ArrayRef<StructType::OffsetInfo>(),
1067 memberDecorations: ArrayRef<StructType::MemberDecorationInfo>())))
1068 return StructType();
1069
1070 return newStructType;
1071}
1072
1073StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1074
1075bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1076
1077unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1078
1079Type StructType::getElementType(unsigned index) const {
1080 assert(getNumElements() > index && "member index out of range");
1081 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1082}
1083
1084TypeRange StructType::getElementTypes() const {
1085 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1086 getNumElements());
1087}
1088
1089bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1090
1091uint64_t StructType::getMemberOffset(unsigned index) const {
1092 assert(getNumElements() > index && "member index out of range");
1093 return getImpl()->offsetInfo[index];
1094}
1095
1096void StructType::getMemberDecorations(
1097 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1098 const {
1099 memberDecorations.clear();
1100 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1101 memberDecorations.append(in_start: implMemberDecorations.begin(),
1102 in_end: implMemberDecorations.end());
1103}
1104
1105void StructType::getMemberDecorations(
1106 unsigned index,
1107 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1108 assert(getNumElements() > index && "member index out of range");
1109 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1110 decorationsInfo.clear();
1111 for (const auto &memberDecoration : memberDecorations) {
1112 if (memberDecoration.memberIndex == index) {
1113 decorationsInfo.push_back(Elt: memberDecoration);
1114 }
1115 if (memberDecoration.memberIndex > index) {
1116 // Early exit since the decorations are stored sorted.
1117 return;
1118 }
1119 }
1120}
1121
1122LogicalResult
1123StructType::trySetBody(ArrayRef<Type> memberTypes,
1124 ArrayRef<OffsetInfo> offsetInfo,
1125 ArrayRef<MemberDecorationInfo> memberDecorations) {
1126 return Base::mutate(args&: memberTypes, args&: offsetInfo, args&: memberDecorations);
1127}
1128
1129void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1130 std::optional<StorageClass> storage) {
1131 for (Type elementType : getElementTypes())
1132 llvm::cast<SPIRVType>(Val&: elementType).getExtensions(extensions, storage);
1133}
1134
1135void StructType::getCapabilities(
1136 SPIRVType::CapabilityArrayRefVector &capabilities,
1137 std::optional<StorageClass> storage) {
1138 for (Type elementType : getElementTypes())
1139 llvm::cast<SPIRVType>(Val&: elementType).getCapabilities(capabilities, storage);
1140}
1141
1142llvm::hash_code spirv::hash_value(
1143 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1144 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1145 memberDecorationInfo.decoration);
1146}
1147
1148//===----------------------------------------------------------------------===//
1149// MatrixType
1150//===----------------------------------------------------------------------===//
1151
1152struct spirv::detail::MatrixTypeStorage : public TypeStorage {
1153 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1154 : columnType(columnType), columnCount(columnCount) {}
1155
1156 using KeyTy = std::tuple<Type, uint32_t>;
1157
1158 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1159 const KeyTy &key) {
1160
1161 // Initialize the memory using placement new.
1162 return new (allocator.allocate<MatrixTypeStorage>())
1163 MatrixTypeStorage(std::get<0>(t: key), std::get<1>(t: key));
1164 }
1165
1166 bool operator==(const KeyTy &key) const {
1167 return key == KeyTy(columnType, columnCount);
1168 }
1169
1170 Type columnType;
1171 const uint32_t columnCount;
1172};
1173
1174MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1175 return Base::get(ctx: columnType.getContext(), args&: columnType, args&: columnCount);
1176}
1177
1178MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1179 Type columnType, uint32_t columnCount) {
1180 return Base::getChecked(emitErrorFn: emitError, ctx: columnType.getContext(), args: columnType,
1181 args: columnCount);
1182}
1183
1184LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
1185 Type columnType, uint32_t columnCount) {
1186 if (columnCount < 2 || columnCount > 4)
1187 return emitError() << "matrix can have 2, 3, or 4 columns only";
1188
1189 if (!isValidColumnType(columnType))
1190 return emitError() << "matrix columns must be vectors of floats";
1191
1192 /// The underlying vectors (columns) must be of size 2, 3, or 4
1193 ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
1194 if (columnShape.size() != 1)
1195 return emitError() << "matrix columns must be 1D vectors";
1196
1197 if (columnShape[0] < 2 || columnShape[0] > 4)
1198 return emitError() << "matrix columns must be of size 2, 3, or 4";
1199
1200 return success();
1201}
1202
1203/// Returns true if the matrix elements are vectors of float elements
1204bool MatrixType::isValidColumnType(Type columnType) {
1205 if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
1206 if (llvm::isa<FloatType>(vectorType.getElementType()))
1207 return true;
1208 }
1209 return false;
1210}
1211
1212Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1213
1214Type MatrixType::getElementType() const {
1215 return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
1216}
1217
1218unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1219
1220unsigned MatrixType::getNumRows() const {
1221 return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
1222}
1223
1224unsigned MatrixType::getNumElements() const {
1225 return (getImpl()->columnCount) * getNumRows();
1226}
1227
1228void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1229 std::optional<StorageClass> storage) {
1230 llvm::cast<SPIRVType>(Val: getColumnType()).getExtensions(extensions, storage);
1231}
1232
1233void MatrixType::getCapabilities(
1234 SPIRVType::CapabilityArrayRefVector &capabilities,
1235 std::optional<StorageClass> storage) {
1236 {
1237 static const Capability caps[] = {Capability::Matrix};
1238 ArrayRef<Capability> ref(caps, std::size(caps));
1239 capabilities.push_back(Elt: ref);
1240 }
1241 // Add any capabilities associated with the underlying vectors (i.e., columns)
1242 llvm::cast<SPIRVType>(Val: getColumnType()).getCapabilities(capabilities, storage);
1243}
1244
1245//===----------------------------------------------------------------------===//
1246// SPIR-V Dialect
1247//===----------------------------------------------------------------------===//
1248
/// Registers the SPIR-V dialect's type classes listed below with the dialect.
void SPIRVDialect::registerTypes() {
  addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
           MatrixType, PointerType, RuntimeArrayType, SampledImageType,
           StructType>();
}
1254

source code of mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp