1//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16#include "mlir/IR/BuiltinTypes.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20#include <algorithm>
21#include <cstdint>
22#include <numeric>
23
24using namespace mlir;
25using namespace mlir::spirv;
26
27//===----------------------------------------------------------------------===//
28// ArrayType
29//===----------------------------------------------------------------------===//
30
31struct spirv::detail::ArrayTypeStorage : public TypeStorage {
32 using KeyTy = std::tuple<Type, unsigned, unsigned>;
33
34 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
35 const KeyTy &key) {
36 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
37 }
38
39 bool operator==(const KeyTy &key) const {
40 return key == KeyTy(elementType, elementCount, stride);
41 }
42
43 ArrayTypeStorage(const KeyTy &key)
44 : elementType(std::get<0>(t: key)), elementCount(std::get<1>(t: key)),
45 stride(std::get<2>(t: key)) {}
46
47 Type elementType;
48 unsigned elementCount;
49 unsigned stride;
50};
51
52ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
53 assert(elementCount && "ArrayType needs at least one element");
54 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount,
55 /*stride=*/args: 0);
56}
57
58ArrayType ArrayType::get(Type elementType, unsigned elementCount,
59 unsigned stride) {
60 assert(elementCount && "ArrayType needs at least one element");
61 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, args&: stride);
62}
63
64unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
65
66Type ArrayType::getElementType() const { return getImpl()->elementType; }
67
68unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
69
70void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
71 std::optional<StorageClass> storage) {
72 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
73}
74
75void ArrayType::getCapabilities(
76 SPIRVType::CapabilityArrayRefVector &capabilities,
77 std::optional<StorageClass> storage) {
78 llvm::cast<SPIRVType>(Val: getElementType())
79 .getCapabilities(capabilities, storage);
80}
81
82std::optional<int64_t> ArrayType::getSizeInBytes() {
83 auto elementType = llvm::cast<SPIRVType>(Val: getElementType());
84 std::optional<int64_t> size = elementType.getSizeInBytes();
85 if (!size)
86 return std::nullopt;
87 return (*size + getArrayStride()) * getNumElements();
88}
89
90//===----------------------------------------------------------------------===//
91// CompositeType
92//===----------------------------------------------------------------------===//
93
94bool CompositeType::classof(Type type) {
95 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type))
96 return isValid(vectorType);
97 return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
98 spirv::MatrixType, spirv::RuntimeArrayType,
99 spirv::StructType, spirv::TensorArmType>(Val: type);
100}
101
102bool CompositeType::isValid(VectorType type) {
103 return type.getRank() == 1 &&
104 llvm::is_contained(Set: {2, 3, 4, 8, 16}, Element: type.getNumElements()) &&
105 llvm::isa<ScalarType>(Val: type.getElementType());
106}
107
108Type CompositeType::getElementType(unsigned index) const {
109 return TypeSwitch<Type, Type>(*this)
110 .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
111 TensorArmType>(caseFn: [](auto type) { return type.getElementType(); })
112 .Case<MatrixType>(caseFn: [](MatrixType type) { return type.getColumnType(); })
113 .Case<StructType>(
114 caseFn: [index](StructType type) { return type.getElementType(index); })
115 .Default(
116 defaultFn: [](Type) -> Type { llvm_unreachable("invalid composite type"); });
117}
118
119unsigned CompositeType::getNumElements() const {
120 if (auto arrayType = llvm::dyn_cast<ArrayType>(Val: *this))
121 return arrayType.getNumElements();
122 if (auto matrixType = llvm::dyn_cast<MatrixType>(Val: *this))
123 return matrixType.getNumColumns();
124 if (auto structType = llvm::dyn_cast<StructType>(Val: *this))
125 return structType.getNumElements();
126 if (auto vectorType = llvm::dyn_cast<VectorType>(Val: *this))
127 return vectorType.getNumElements();
128 if (auto tensorArmType = dyn_cast<TensorArmType>(Val: *this))
129 return tensorArmType.getNumElements();
130 if (llvm::isa<CooperativeMatrixType>(Val: *this)) {
131 llvm_unreachable(
132 "invalid to query number of elements of spirv Cooperative Matrix type");
133 }
134 if (llvm::isa<RuntimeArrayType>(Val: *this)) {
135 llvm_unreachable(
136 "invalid to query number of elements of spirv::RuntimeArray type");
137 }
138 llvm_unreachable("invalid composite type");
139}
140
141bool CompositeType::hasCompileTimeKnownNumElements() const {
142 return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(Val: *this);
143}
144
145void CompositeType::getExtensions(
146 SPIRVType::ExtensionArrayRefVector &extensions,
147 std::optional<StorageClass> storage) {
148 TypeSwitch<Type>(*this)
149 .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
150 StructType>(
151 caseFn: [&](auto type) { type.getExtensions(extensions, storage); })
152 .Case<VectorType>(caseFn: [&](VectorType type) {
153 return llvm::cast<ScalarType>(Val: type.getElementType())
154 .getExtensions(extensions, storage);
155 })
156 .Case<TensorArmType>(caseFn: [&](TensorArmType type) {
157 static constexpr Extension ext{Extension::SPV_ARM_tensors};
158 extensions.push_back(Elt: ext);
159 return llvm::cast<ScalarType>(Val: type.getElementType())
160 .getExtensions(extensions, storage);
161 })
162
163 .Default(defaultFn: [](Type) { llvm_unreachable("invalid composite type"); });
164}
165
166void CompositeType::getCapabilities(
167 SPIRVType::CapabilityArrayRefVector &capabilities,
168 std::optional<StorageClass> storage) {
169 TypeSwitch<Type>(*this)
170 .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
171 StructType>(
172 caseFn: [&](auto type) { type.getCapabilities(capabilities, storage); })
173 .Case<VectorType>(caseFn: [&](VectorType type) {
174 auto vecSize = getNumElements();
175 if (vecSize == 8 || vecSize == 16) {
176 static const Capability caps[] = {Capability::Vector16};
177 ArrayRef<Capability> ref(caps, std::size(caps));
178 capabilities.push_back(Elt: ref);
179 }
180 return llvm::cast<ScalarType>(Val: type.getElementType())
181 .getCapabilities(capabilities, storage);
182 })
183 .Case<TensorArmType>(caseFn: [&](TensorArmType type) {
184 static constexpr Capability cap{Capability::TensorsARM};
185 capabilities.push_back(Elt: cap);
186 return llvm::cast<ScalarType>(Val: type.getElementType())
187 .getCapabilities(capabilities, storage);
188 })
189 .Default(defaultFn: [](Type) { llvm_unreachable("invalid composite type"); });
190}
191
192std::optional<int64_t> CompositeType::getSizeInBytes() {
193 if (auto arrayType = llvm::dyn_cast<ArrayType>(Val&: *this))
194 return arrayType.getSizeInBytes();
195 if (auto structType = llvm::dyn_cast<StructType>(Val&: *this))
196 return structType.getSizeInBytes();
197 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: *this)) {
198 std::optional<int64_t> elementSize =
199 llvm::cast<ScalarType>(Val: vectorType.getElementType()).getSizeInBytes();
200 if (!elementSize)
201 return std::nullopt;
202 return *elementSize * vectorType.getNumElements();
203 }
204 if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(Val&: *this)) {
205 std::optional<int64_t> elementSize =
206 llvm::cast<ScalarType>(Val: tensorArmType.getElementType()).getSizeInBytes();
207 if (!elementSize)
208 return std::nullopt;
209 return *elementSize * tensorArmType.getNumElements();
210 }
211 return std::nullopt;
212}
213
214//===----------------------------------------------------------------------===//
215// CooperativeMatrixType
216//===----------------------------------------------------------------------===//
217
218struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
219 // In the specification dimensions of the Cooperative Matrix are 32-bit
220 // integers --- the initial implementation kept those values as such. However,
221 // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
222 // as 32-bits and expose it as int64_t through `getShape`, however, this
223 // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
224 // 32-bits integers would require an extra logic and storage. So, we diverge
225 // from the spec and internally represent the dimensions as 64-bit integers,
226 // so we can easily return an `ArrayRef` from `getShape` without any extra
227 // logic. Alternatively, we could store both rows and columns (both 32-bits)
228 // and shape (64-bits), assigning rows and columns to shape whenever
229 // `getShape` is called. This would be at the cost of extra logic and storage.
230 // Note: Because `ArrayRef` is returned we cannot construct an object in
231 // `getShape` on the fly.
232 using KeyTy =
233 std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
234
235 static CooperativeMatrixTypeStorage *
236 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
237 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
238 CooperativeMatrixTypeStorage(key);
239 }
240
241 bool operator==(const KeyTy &key) const {
242 return key == KeyTy(elementType, shape[0], shape[1], scope, use);
243 }
244
245 CooperativeMatrixTypeStorage(const KeyTy &key)
246 : elementType(std::get<0>(t: key)),
247 shape({std::get<1>(t: key), std::get<2>(t: key)}), scope(std::get<3>(t: key)),
248 use(std::get<4>(t: key)) {}
249
250 Type elementType;
251 // [#rows, #columns]
252 std::array<int64_t, 2> shape;
253 Scope scope;
254 CooperativeMatrixUseKHR use;
255};
256
257CooperativeMatrixType CooperativeMatrixType::get(Type elementType,
258 uint32_t rows,
259 uint32_t columns, Scope scope,
260 CooperativeMatrixUseKHR use) {
261 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: rows, args&: columns, args&: scope,
262 args&: use);
263}
264
265Type CooperativeMatrixType::getElementType() const {
266 return getImpl()->elementType;
267}
268
269uint32_t CooperativeMatrixType::getRows() const {
270 assert(getImpl()->shape[0] != ShapedType::kDynamic);
271 return static_cast<uint32_t>(getImpl()->shape[0]);
272}
273
274uint32_t CooperativeMatrixType::getColumns() const {
275 assert(getImpl()->shape[1] != ShapedType::kDynamic);
276 return static_cast<uint32_t>(getImpl()->shape[1]);
277}
278
279ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
280 return getImpl()->shape;
281}
282
283Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
284
285CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
286 return getImpl()->use;
287}
288
289void CooperativeMatrixType::getExtensions(
290 SPIRVType::ExtensionArrayRefVector &extensions,
291 std::optional<StorageClass> storage) {
292 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
293 static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
294 extensions.push_back(Elt: exts);
295}
296
297void CooperativeMatrixType::getCapabilities(
298 SPIRVType::CapabilityArrayRefVector &capabilities,
299 std::optional<StorageClass> storage) {
300 llvm::cast<SPIRVType>(Val: getElementType())
301 .getCapabilities(capabilities, storage);
302 static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
303 capabilities.push_back(Elt: caps);
304}
305
306//===----------------------------------------------------------------------===//
307// ImageType
308//===----------------------------------------------------------------------===//
309
310template <typename T>
311static constexpr unsigned getNumBits() {
312 return 0;
313}
314template <>
315constexpr unsigned getNumBits<Dim>() {
316 static_assert((1 << 3) > getMaxEnumValForDim(),
317 "Not enough bits to encode Dim value");
318 return 3;
319}
320template <>
321constexpr unsigned getNumBits<ImageDepthInfo>() {
322 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
323 "Not enough bits to encode ImageDepthInfo value");
324 return 2;
325}
326template <>
327constexpr unsigned getNumBits<ImageArrayedInfo>() {
328 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
329 "Not enough bits to encode ImageArrayedInfo value");
330 return 1;
331}
332template <>
333constexpr unsigned getNumBits<ImageSamplingInfo>() {
334 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
335 "Not enough bits to encode ImageSamplingInfo value");
336 return 1;
337}
338template <>
339constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
340 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
341 "Not enough bits to encode ImageSamplerUseInfo value");
342 return 2;
343}
344template <>
345constexpr unsigned getNumBits<ImageFormat>() {
346 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
347 "Not enough bits to encode ImageFormat value");
348 return 6;
349}
350
351struct spirv::detail::ImageTypeStorage : public TypeStorage {
352public:
353 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
354 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
355
356 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
357 const KeyTy &key) {
358 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
359 }
360
361 bool operator==(const KeyTy &key) const {
362 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
363 samplerUseInfo, format);
364 }
365
366 ImageTypeStorage(const KeyTy &key)
367 : elementType(std::get<0>(t: key)), dim(std::get<1>(t: key)),
368 depthInfo(std::get<2>(t: key)), arrayedInfo(std::get<3>(t: key)),
369 samplingInfo(std::get<4>(t: key)), samplerUseInfo(std::get<5>(t: key)),
370 format(std::get<6>(t: key)) {}
371
372 Type elementType;
373 Dim dim : getNumBits<Dim>();
374 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
375 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
376 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
377 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
378 ImageFormat format : getNumBits<ImageFormat>();
379};
380
381ImageType
382ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
383 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
384 value) {
385 return Base::get(ctx: std::get<0>(t&: value).getContext(), args&: value);
386}
387
388Type ImageType::getElementType() const { return getImpl()->elementType; }
389
390Dim ImageType::getDim() const { return getImpl()->dim; }
391
392ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
393
394ImageArrayedInfo ImageType::getArrayedInfo() const {
395 return getImpl()->arrayedInfo;
396}
397
398ImageSamplingInfo ImageType::getSamplingInfo() const {
399 return getImpl()->samplingInfo;
400}
401
402ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
403 return getImpl()->samplerUseInfo;
404}
405
406ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
407
408void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
409 std::optional<StorageClass>) {
410 // Image types do not require extra extensions thus far.
411}
412
413void ImageType::getCapabilities(
414 SPIRVType::CapabilityArrayRefVector &capabilities,
415 std::optional<StorageClass>) {
416 if (auto dimCaps = spirv::getCapabilities(value: getDim()))
417 capabilities.push_back(Elt: *dimCaps);
418
419 if (auto fmtCaps = spirv::getCapabilities(value: getImageFormat()))
420 capabilities.push_back(Elt: *fmtCaps);
421}
422
423//===----------------------------------------------------------------------===//
424// PointerType
425//===----------------------------------------------------------------------===//
426
427struct spirv::detail::PointerTypeStorage : public TypeStorage {
428 // (Type, StorageClass) as the key: Type stored in this struct, and
429 // StorageClass stored as TypeStorage's subclass data.
430 using KeyTy = std::pair<Type, StorageClass>;
431
432 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
433 const KeyTy &key) {
434 return new (allocator.allocate<PointerTypeStorage>())
435 PointerTypeStorage(key);
436 }
437
438 bool operator==(const KeyTy &key) const {
439 return key == KeyTy(pointeeType, storageClass);
440 }
441
442 PointerTypeStorage(const KeyTy &key)
443 : pointeeType(key.first), storageClass(key.second) {}
444
445 Type pointeeType;
446 StorageClass storageClass;
447};
448
449PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
450 return Base::get(ctx: pointeeType.getContext(), args&: pointeeType, args&: storageClass);
451}
452
453Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
454
455StorageClass PointerType::getStorageClass() const {
456 return getImpl()->storageClass;
457}
458
459void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
460 std::optional<StorageClass> storage) {
461 // Use this pointer type's storage class because this pointer indicates we are
462 // using the pointee type in that specific storage class.
463 llvm::cast<SPIRVType>(Val: getPointeeType())
464 .getExtensions(extensions, storage: getStorageClass());
465
466 if (auto scExts = spirv::getExtensions(value: getStorageClass()))
467 extensions.push_back(Elt: *scExts);
468}
469
470void PointerType::getCapabilities(
471 SPIRVType::CapabilityArrayRefVector &capabilities,
472 std::optional<StorageClass> storage) {
473 // Use this pointer type's storage class because this pointer indicates we are
474 // using the pointee type in that specific storage class.
475 llvm::cast<SPIRVType>(Val: getPointeeType())
476 .getCapabilities(capabilities, storage: getStorageClass());
477
478 if (auto scCaps = spirv::getCapabilities(value: getStorageClass()))
479 capabilities.push_back(Elt: *scCaps);
480}
481
482//===----------------------------------------------------------------------===//
483// RuntimeArrayType
484//===----------------------------------------------------------------------===//
485
486struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
487 using KeyTy = std::pair<Type, unsigned>;
488
489 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
490 const KeyTy &key) {
491 return new (allocator.allocate<RuntimeArrayTypeStorage>())
492 RuntimeArrayTypeStorage(key);
493 }
494
495 bool operator==(const KeyTy &key) const {
496 return key == KeyTy(elementType, stride);
497 }
498
499 RuntimeArrayTypeStorage(const KeyTy &key)
500 : elementType(key.first), stride(key.second) {}
501
502 Type elementType;
503 unsigned stride;
504};
505
506RuntimeArrayType RuntimeArrayType::get(Type elementType) {
507 return Base::get(ctx: elementType.getContext(), args&: elementType, /*stride=*/args: 0);
508}
509
510RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
511 return Base::get(ctx: elementType.getContext(), args&: elementType, args&: stride);
512}
513
514Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
515
516unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
517
518void RuntimeArrayType::getExtensions(
519 SPIRVType::ExtensionArrayRefVector &extensions,
520 std::optional<StorageClass> storage) {
521 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
522}
523
524void RuntimeArrayType::getCapabilities(
525 SPIRVType::CapabilityArrayRefVector &capabilities,
526 std::optional<StorageClass> storage) {
527 {
528 static const Capability caps[] = {Capability::Shader};
529 ArrayRef<Capability> ref(caps, std::size(caps));
530 capabilities.push_back(Elt: ref);
531 }
532 llvm::cast<SPIRVType>(Val: getElementType())
533 .getCapabilities(capabilities, storage);
534}
535
536//===----------------------------------------------------------------------===//
537// ScalarType
538//===----------------------------------------------------------------------===//
539
540bool ScalarType::classof(Type type) {
541 if (auto floatType = llvm::dyn_cast<FloatType>(Val&: type)) {
542 return isValid(floatType);
543 }
544 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: type)) {
545 return isValid(intType);
546 }
547 return false;
548}
549
550bool ScalarType::isValid(FloatType type) {
551 return llvm::is_contained(Set: {16u, 32u, 64u}, Element: type.getWidth());
552}
553
554bool ScalarType::isValid(IntegerType type) {
555 return llvm::is_contained(Set: {1u, 8u, 16u, 32u, 64u}, Element: type.getWidth());
556}
557
558void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
559 std::optional<StorageClass> storage) {
560 if (isa<BFloat16Type>(Val: *this)) {
561 static const Extension ext = Extension::SPV_KHR_bfloat16;
562 extensions.push_back(Elt: ext);
563 }
564
565 // 8- or 16-bit integer/floating-point numbers will require extra extensions
566 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
567 // SPV_KHR_8bit_storage for more details.
568 if (!storage)
569 return;
570
571 switch (*storage) {
572 case StorageClass::PushConstant:
573 case StorageClass::StorageBuffer:
574 case StorageClass::Uniform:
575 if (getIntOrFloatBitWidth() == 8) {
576 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
577 ArrayRef<Extension> ref(exts, std::size(exts));
578 extensions.push_back(Elt: ref);
579 }
580 [[fallthrough]];
581 case StorageClass::Input:
582 case StorageClass::Output:
583 if (getIntOrFloatBitWidth() == 16) {
584 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
585 ArrayRef<Extension> ref(exts, std::size(exts));
586 extensions.push_back(Elt: ref);
587 }
588 break;
589 default:
590 break;
591 }
592}
593
594void ScalarType::getCapabilities(
595 SPIRVType::CapabilityArrayRefVector &capabilities,
596 std::optional<StorageClass> storage) {
597 unsigned bitwidth = getIntOrFloatBitWidth();
598
599 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
600 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
601 // SPV_KHR_8bit_storage for more details.
602
603#define STORAGE_CASE(storage, cap8, cap16) \
604 case StorageClass::storage: { \
605 if (bitwidth == 8) { \
606 static const Capability caps[] = {Capability::cap8}; \
607 ArrayRef<Capability> ref(caps, std::size(caps)); \
608 capabilities.push_back(ref); \
609 return; \
610 } \
611 if (bitwidth == 16) { \
612 static const Capability caps[] = {Capability::cap16}; \
613 ArrayRef<Capability> ref(caps, std::size(caps)); \
614 capabilities.push_back(ref); \
615 return; \
616 } \
617 /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
618 /* storage classes. Fall through to the next section. */ \
619 } break
620
621 // This part only handles the cases where special bitwidths appearing in
622 // interface storage classes.
623 if (storage) {
624 switch (*storage) {
625 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
626 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
627 StorageBuffer16BitAccess);
628 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
629 StorageUniform16);
630 case StorageClass::Input:
631 case StorageClass::Output: {
632 if (bitwidth == 16) {
633 static const Capability caps[] = {Capability::StorageInputOutput16};
634 ArrayRef<Capability> ref(caps, std::size(caps));
635 capabilities.push_back(Elt: ref);
636 return;
637 }
638 break;
639 }
640 default:
641 break;
642 }
643 }
644#undef STORAGE_CASE
645
646 // For other non-interface storage classes, require a different set of
647 // capabilities for special bitwidths.
648
649#define WIDTH_CASE(type, width) \
650 case width: { \
651 static const Capability caps[] = {Capability::type##width}; \
652 ArrayRef<Capability> ref(caps, std::size(caps)); \
653 capabilities.push_back(ref); \
654 } break
655
656 if (auto intType = llvm::dyn_cast<IntegerType>(Val&: *this)) {
657 switch (bitwidth) {
658 WIDTH_CASE(Int, 8);
659 WIDTH_CASE(Int, 16);
660 WIDTH_CASE(Int, 64);
661 case 1:
662 case 32:
663 break;
664 default:
665 llvm_unreachable("invalid bitwidth to getCapabilities");
666 }
667 } else {
668 assert(llvm::isa<FloatType>(*this));
669 switch (bitwidth) {
670 case 16: {
671 if (isa<BFloat16Type>(Val: *this)) {
672 static const Capability cap = Capability::BFloat16TypeKHR;
673 capabilities.push_back(Elt: cap);
674 } else {
675 static const Capability cap = Capability::Float16;
676 capabilities.push_back(Elt: cap);
677 }
678 break;
679 }
680 WIDTH_CASE(Float, 64);
681 case 32:
682 break;
683 default:
684 llvm_unreachable("invalid bitwidth to getCapabilities");
685 }
686 }
687
688#undef WIDTH_CASE
689}
690
691std::optional<int64_t> ScalarType::getSizeInBytes() {
692 auto bitWidth = getIntOrFloatBitWidth();
693 // According to the SPIR-V spec:
694 // "There is no physical size or bit pattern defined for values with boolean
695 // type. If they are stored (in conjunction with OpVariable), they can only
696 // be used with logical addressing operations, not physical, and only with
697 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
698 // Private, Function, Input, and Output."
699 if (bitWidth == 1)
700 return std::nullopt;
701 return bitWidth / 8;
702}
703
704//===----------------------------------------------------------------------===//
705// SPIRVType
706//===----------------------------------------------------------------------===//
707
708bool SPIRVType::classof(Type type) {
709 // Allow SPIR-V dialect types
710 if (llvm::isa<SPIRVDialect>(Val: type.getDialect()))
711 return true;
712 if (llvm::isa<ScalarType>(Val: type))
713 return true;
714 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: type))
715 return CompositeType::isValid(type: vectorType);
716 if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(Val&: type))
717 return llvm::isa<ScalarType>(Val: tensorArmType.getElementType());
718 return false;
719}
720
721bool SPIRVType::isScalarOrVector() {
722 return isIntOrFloat() || llvm::isa<VectorType>(Val: *this);
723}
724
725void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
726 std::optional<StorageClass> storage) {
727 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) {
728 scalarType.getExtensions(extensions, storage);
729 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) {
730 compositeType.getExtensions(extensions, storage);
731 } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) {
732 imageType.getExtensions(extensions, storage);
733 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) {
734 sampledImageType.getExtensions(extensions, storage);
735 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) {
736 matrixType.getExtensions(extensions, storage);
737 } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) {
738 ptrType.getExtensions(extensions, storage);
739 } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(Val&: *this)) {
740 tensorArmType.getExtensions(extensions, storage);
741 } else {
742 llvm_unreachable("invalid SPIR-V Type to getExtensions");
743 }
744}
745
746void SPIRVType::getCapabilities(
747 SPIRVType::CapabilityArrayRefVector &capabilities,
748 std::optional<StorageClass> storage) {
749 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) {
750 scalarType.getCapabilities(capabilities, storage);
751 } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) {
752 compositeType.getCapabilities(capabilities, storage);
753 } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) {
754 imageType.getCapabilities(capabilities, storage);
755 } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) {
756 sampledImageType.getCapabilities(capabilities, storage);
757 } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) {
758 matrixType.getCapabilities(capabilities, storage);
759 } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) {
760 ptrType.getCapabilities(capabilities, storage);
761 } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(Val&: *this)) {
762 tensorArmType.getCapabilities(capabilities, storage);
763 } else {
764 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
765 }
766}
767
768std::optional<int64_t> SPIRVType::getSizeInBytes() {
769 if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this))
770 return scalarType.getSizeInBytes();
771 if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this))
772 return compositeType.getSizeInBytes();
773 return std::nullopt;
774}
775
776//===----------------------------------------------------------------------===//
777// SampledImageType
778//===----------------------------------------------------------------------===//
779struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
780 using KeyTy = Type;
781
782 SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
783
784 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
785
786 static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
787 const KeyTy &key) {
788 return new (allocator.allocate<SampledImageTypeStorage>())
789 SampledImageTypeStorage(key);
790 }
791
792 Type imageType;
793};
794
795SampledImageType SampledImageType::get(Type imageType) {
796 return Base::get(ctx: imageType.getContext(), args&: imageType);
797}
798
799SampledImageType
800SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
801 Type imageType) {
802 return Base::getChecked(emitErrorFn: emitError, ctx: imageType.getContext(), args: imageType);
803}
804
805Type SampledImageType::getImageType() const { return getImpl()->imageType; }
806
807LogicalResult
808SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
809 Type imageType) {
810 if (!llvm::isa<ImageType>(Val: imageType))
811 return emitError() << "expected image type";
812
813 return success();
814}
815
816void SampledImageType::getExtensions(
817 SPIRVType::ExtensionArrayRefVector &extensions,
818 std::optional<StorageClass> storage) {
819 llvm::cast<ImageType>(Val: getImageType()).getExtensions(extensions, storage);
820}
821
822void SampledImageType::getCapabilities(
823 SPIRVType::CapabilityArrayRefVector &capabilities,
824 std::optional<StorageClass> storage) {
825 llvm::cast<ImageType>(Val: getImageType()).getCapabilities(capabilities, storage);
826}
827
828//===----------------------------------------------------------------------===//
829// StructType
830//===----------------------------------------------------------------------===//
831
832/// Type storage for SPIR-V structure types:
833///
834/// Structures are uniqued using:
835/// - for identified structs:
836/// - a string identifier;
837/// - for literal structs:
838/// - a list of member types;
839/// - a list of member offset info;
840/// - a list of member decoration info.
841///
842/// Identified structures only have a mutable component consisting of:
843/// - a list of member types;
844/// - a list of member offset info;
845/// - a list of member decoration info.
846struct spirv::detail::StructTypeStorage : public TypeStorage {
847 /// Construct a storage object for an identified struct type. A struct type
848 /// associated with such storage must call StructType::trySetBody(...) later
849 /// in order to mutate the storage object providing the actual content.
850 StructTypeStorage(StringRef identifier)
851 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
852 numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
853 identifier(identifier) {}
854
855 /// Construct a storage object for a literal struct type. A struct type
856 /// associated with such storage is immutable.
857 StructTypeStorage(
858 unsigned numMembers, Type const *memberTypes,
859 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
860 StructType::MemberDecorationInfo const *memberDecorationsInfo)
861 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
862 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
863 memberDecorationsInfo(memberDecorationsInfo) {}
864
865 /// A storage key is divided into 2 parts:
866 /// - for identified structs:
867 /// - a StringRef representing the struct identifier;
868 /// - for literal structs:
869 /// - an ArrayRef<Type> for member types;
870 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
871 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
872 /// info.
873 ///
874 /// An identified struct type is uniqued only by the first part (field 0)
875 /// of the key.
876 ///
877 /// A literal struct type is uniqued only by the second part (fields 1, 2, and
878 /// 3) of the key. The identifier field (field 0) must be empty.
879 using KeyTy =
880 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
881 ArrayRef<StructType::MemberDecorationInfo>>;
882
883 /// For identified structs, return true if the given key contains the same
884 /// identifier.
885 ///
886 /// For literal structs, return true if the given key contains a matching list
887 /// of member types + offset info + decoration info.
888 bool operator==(const KeyTy &key) const {
889 if (isIdentified()) {
890 // Identified types are uniqued by their identifier.
891 return getIdentifier() == std::get<0>(t: key);
892 }
893
894 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
895 getMemberDecorationsInfo());
896 }
897
898 /// If the given key contains a non-empty identifier, this method constructs
899 /// an identified struct and leaves the rest of the struct type data to be set
900 /// through a later call to StructType::trySetBody(...).
901 ///
902 /// If, on the other hand, the key contains an empty identifier, a literal
903 /// struct is constructed using the other fields of the key.
904 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
905 const KeyTy &key) {
906 StringRef keyIdentifier = std::get<0>(t: key);
907
908 if (!keyIdentifier.empty()) {
909 StringRef identifier = allocator.copyInto(str: keyIdentifier);
910
911 // Identified StructType body/members will be set through trySetBody(...)
912 // later.
913 return new (allocator.allocate<StructTypeStorage>())
914 StructTypeStorage(identifier);
915 }
916
917 ArrayRef<Type> keyTypes = std::get<1>(t: key);
918
919 // Copy the member type and layout information into the bump pointer
920 const Type *typesList = nullptr;
921 if (!keyTypes.empty()) {
922 typesList = allocator.copyInto(elements: keyTypes).data();
923 }
924
925 const StructType::OffsetInfo *offsetInfoList = nullptr;
926 if (!std::get<2>(t: key).empty()) {
927 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(t: key);
928 assert(keyOffsetInfo.size() == keyTypes.size() &&
929 "size of offset information must be same as the size of number of "
930 "elements");
931 offsetInfoList = allocator.copyInto(elements: keyOffsetInfo).data();
932 }
933
934 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
935 unsigned numMemberDecorations = 0;
936 if (!std::get<3>(t: key).empty()) {
937 auto keyMemberDecorations = std::get<3>(t: key);
938 numMemberDecorations = keyMemberDecorations.size();
939 memberDecorationList = allocator.copyInto(elements: keyMemberDecorations).data();
940 }
941
942 return new (allocator.allocate<StructTypeStorage>())
943 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
944 numMemberDecorations, memberDecorationList);
945 }
946
947 ArrayRef<Type> getMemberTypes() const {
948 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
949 }
950
951 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
952 if (offsetInfo) {
953 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
954 }
955 return {};
956 }
957
958 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
959 if (memberDecorationsInfo) {
960 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
961 numMemberDecorations);
962 }
963 return {};
964 }
965
966 StringRef getIdentifier() const { return identifier; }
967
968 bool isIdentified() const { return !identifier.empty(); }
969
970 /// Sets the struct type content for identified structs. Calling this method
971 /// is only valid for identified structs.
972 ///
973 /// Fails under the following conditions:
974 /// - If called for a literal struct;
975 /// - If called for an identified struct whose body was set before (through a
976 /// call to this method) but with different contents from the passed
977 /// arguments.
978 LogicalResult mutate(
979 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
980 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
981 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
982 if (!isIdentified())
983 return failure();
984
985 if (memberTypesAndIsBodySet.getInt() &&
986 (getMemberTypes() != structMemberTypes ||
987 getOffsetInfo() != structOffsetInfo ||
988 getMemberDecorationsInfo() != structMemberDecorationInfo))
989 return failure();
990
991 memberTypesAndIsBodySet.setInt(true);
992 numMembers = structMemberTypes.size();
993
994 // Copy the member type and layout information into the bump pointer.
995 if (!structMemberTypes.empty())
996 memberTypesAndIsBodySet.setPointer(
997 allocator.copyInto(elements: structMemberTypes).data());
998
999 if (!structOffsetInfo.empty()) {
1000 assert(structOffsetInfo.size() == structMemberTypes.size() &&
1001 "size of offset information must be same as the size of number of "
1002 "elements");
1003 offsetInfo = allocator.copyInto(elements: structOffsetInfo).data();
1004 }
1005
1006 if (!structMemberDecorationInfo.empty()) {
1007 numMemberDecorations = structMemberDecorationInfo.size();
1008 memberDecorationsInfo =
1009 allocator.copyInto(elements: structMemberDecorationInfo).data();
1010 }
1011
1012 return success();
1013 }
1014
1015 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
1016 StructType::OffsetInfo const *offsetInfo;
1017 unsigned numMembers;
1018 unsigned numMemberDecorations;
1019 StructType::MemberDecorationInfo const *memberDecorationsInfo;
1020 StringRef identifier;
1021};
1022
1023StructType
1024StructType::get(ArrayRef<Type> memberTypes,
1025 ArrayRef<StructType::OffsetInfo> offsetInfo,
1026 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
1027 assert(!memberTypes.empty() && "Struct needs at least one member type");
1028 // Sort the decorations.
1029 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
1030 memberDecorations);
1031 llvm::array_pod_sort(Start: sortedDecorations.begin(), End: sortedDecorations.end());
1032 return Base::get(ctx: memberTypes.vec().front().getContext(),
1033 /*identifier=*/args: StringRef(), args&: memberTypes, args&: offsetInfo,
1034 args&: sortedDecorations);
1035}
1036
1037StructType StructType::getIdentified(MLIRContext *context,
1038 StringRef identifier) {
1039 assert(!identifier.empty() &&
1040 "StructType identifier must be non-empty string");
1041
1042 return Base::get(ctx: context, args&: identifier, args: ArrayRef<Type>(),
1043 args: ArrayRef<StructType::OffsetInfo>(),
1044 args: ArrayRef<StructType::MemberDecorationInfo>());
1045}
1046
1047StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
1048 StructType newStructType = Base::get(
1049 ctx: context, args&: identifier, args: ArrayRef<Type>(), args: ArrayRef<StructType::OffsetInfo>(),
1050 args: ArrayRef<StructType::MemberDecorationInfo>());
1051 // Set an empty body in case this is a identified struct.
1052 if (newStructType.isIdentified() &&
1053 failed(Result: newStructType.trySetBody(
1054 memberTypes: ArrayRef<Type>(), offsetInfo: ArrayRef<StructType::OffsetInfo>(),
1055 memberDecorations: ArrayRef<StructType::MemberDecorationInfo>())))
1056 return StructType();
1057
1058 return newStructType;
1059}
1060
1061StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1062
1063bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1064
1065unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1066
1067Type StructType::getElementType(unsigned index) const {
1068 assert(getNumElements() > index && "member index out of range");
1069 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1070}
1071
1072TypeRange StructType::getElementTypes() const {
1073 return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1074 getNumElements());
1075}
1076
1077bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1078
1079uint64_t StructType::getMemberOffset(unsigned index) const {
1080 assert(getNumElements() > index && "member index out of range");
1081 return getImpl()->offsetInfo[index];
1082}
1083
1084void StructType::getMemberDecorations(
1085 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1086 const {
1087 memberDecorations.clear();
1088 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1089 memberDecorations.append(in_start: implMemberDecorations.begin(),
1090 in_end: implMemberDecorations.end());
1091}
1092
1093void StructType::getMemberDecorations(
1094 unsigned index,
1095 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1096 assert(getNumElements() > index && "member index out of range");
1097 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1098 decorationsInfo.clear();
1099 for (const auto &memberDecoration : memberDecorations) {
1100 if (memberDecoration.memberIndex == index) {
1101 decorationsInfo.push_back(Elt: memberDecoration);
1102 }
1103 if (memberDecoration.memberIndex > index) {
1104 // Early exit since the decorations are stored sorted.
1105 return;
1106 }
1107 }
1108}
1109
1110LogicalResult
1111StructType::trySetBody(ArrayRef<Type> memberTypes,
1112 ArrayRef<OffsetInfo> offsetInfo,
1113 ArrayRef<MemberDecorationInfo> memberDecorations) {
1114 return Base::mutate(args&: memberTypes, args&: offsetInfo, args&: memberDecorations);
1115}
1116
1117void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1118 std::optional<StorageClass> storage) {
1119 for (Type elementType : getElementTypes())
1120 llvm::cast<SPIRVType>(Val&: elementType).getExtensions(extensions, storage);
1121}
1122
1123void StructType::getCapabilities(
1124 SPIRVType::CapabilityArrayRefVector &capabilities,
1125 std::optional<StorageClass> storage) {
1126 for (Type elementType : getElementTypes())
1127 llvm::cast<SPIRVType>(Val&: elementType).getCapabilities(capabilities, storage);
1128}
1129
1130llvm::hash_code spirv::hash_value(
1131 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1132 return llvm::hash_combine(args: memberDecorationInfo.memberIndex,
1133 args: memberDecorationInfo.decoration);
1134}
1135
1136//===----------------------------------------------------------------------===//
1137// MatrixType
1138//===----------------------------------------------------------------------===//
1139
1140struct spirv::detail::MatrixTypeStorage : public TypeStorage {
1141 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1142 : columnType(columnType), columnCount(columnCount) {}
1143
1144 using KeyTy = std::tuple<Type, uint32_t>;
1145
1146 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1147 const KeyTy &key) {
1148
1149 // Initialize the memory using placement new.
1150 return new (allocator.allocate<MatrixTypeStorage>())
1151 MatrixTypeStorage(std::get<0>(t: key), std::get<1>(t: key));
1152 }
1153
1154 bool operator==(const KeyTy &key) const {
1155 return key == KeyTy(columnType, columnCount);
1156 }
1157
1158 Type columnType;
1159 const uint32_t columnCount;
1160};
1161
1162MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1163 return Base::get(ctx: columnType.getContext(), args&: columnType, args&: columnCount);
1164}
1165
1166MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1167 Type columnType, uint32_t columnCount) {
1168 return Base::getChecked(emitErrorFn: emitError, ctx: columnType.getContext(), args: columnType,
1169 args: columnCount);
1170}
1171
1172LogicalResult
1173MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
1174 Type columnType, uint32_t columnCount) {
1175 if (columnCount < 2 || columnCount > 4)
1176 return emitError() << "matrix can have 2, 3, or 4 columns only";
1177
1178 if (!isValidColumnType(columnType))
1179 return emitError() << "matrix columns must be vectors of floats";
1180
1181 /// The underlying vectors (columns) must be of size 2, 3, or 4
1182 ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(Val&: columnType).getShape();
1183 if (columnShape.size() != 1)
1184 return emitError() << "matrix columns must be 1D vectors";
1185
1186 if (columnShape[0] < 2 || columnShape[0] > 4)
1187 return emitError() << "matrix columns must be of size 2, 3, or 4";
1188
1189 return success();
1190}
1191
1192/// Returns true if the matrix elements are vectors of float elements
1193bool MatrixType::isValidColumnType(Type columnType) {
1194 if (auto vectorType = llvm::dyn_cast<VectorType>(Val&: columnType)) {
1195 if (llvm::isa<FloatType>(Val: vectorType.getElementType()))
1196 return true;
1197 }
1198 return false;
1199}
1200
1201Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1202
1203Type MatrixType::getElementType() const {
1204 return llvm::cast<VectorType>(Val&: getImpl()->columnType).getElementType();
1205}
1206
1207unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1208
1209unsigned MatrixType::getNumRows() const {
1210 return llvm::cast<VectorType>(Val&: getImpl()->columnType).getShape()[0];
1211}
1212
1213unsigned MatrixType::getNumElements() const {
1214 return (getImpl()->columnCount) * getNumRows();
1215}
1216
1217void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1218 std::optional<StorageClass> storage) {
1219 llvm::cast<SPIRVType>(Val: getColumnType()).getExtensions(extensions, storage);
1220}
1221
1222void MatrixType::getCapabilities(
1223 SPIRVType::CapabilityArrayRefVector &capabilities,
1224 std::optional<StorageClass> storage) {
1225 {
1226 static const Capability caps[] = {Capability::Matrix};
1227 ArrayRef<Capability> ref(caps, std::size(caps));
1228 capabilities.push_back(Elt: ref);
1229 }
1230 // Add any capabilities associated with the underlying vectors (i.e., columns)
1231 llvm::cast<SPIRVType>(Val: getColumnType()).getCapabilities(capabilities, storage);
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// TensorArmType
1236//===----------------------------------------------------------------------===//
1237
1238struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
1239 using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
1240
1241 static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
1242 const KeyTy &key) {
1243 auto [shape, elementType] = key;
1244 shape = allocator.copyInto(elements: shape);
1245 return new (allocator.allocate<TensorArmTypeStorage>())
1246 TensorArmTypeStorage(shape, elementType);
1247 }
1248
1249 static llvm::hash_code hashKey(const KeyTy &key) {
1250 auto [shape, elementType] = key;
1251 return llvm::hash_combine(args: shape, args: elementType);
1252 }
1253
1254 bool operator==(const KeyTy &key) const {
1255 return key == KeyTy(shape, elementType);
1256 }
1257
1258 TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
1259 : shape(std::move(shape)), elementType(std::move(elementType)) {}
1260
1261 ArrayRef<int64_t> shape;
1262 Type elementType;
1263};
1264
1265TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
1266 return Base::get(ctx: elementType.getContext(), args&: shape, args&: elementType);
1267}
1268
1269TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1270 Type elementType) const {
1271 return TensorArmType::get(shape: shape.value_or(u: getShape()), elementType);
1272}
1273
1274Type TensorArmType::getElementType() const { return getImpl()->elementType; }
1275ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
1276
1277void TensorArmType::getExtensions(
1278 SPIRVType::ExtensionArrayRefVector &extensions,
1279 std::optional<StorageClass> storage) {
1280
1281 llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage);
1282 static constexpr Extension ext{Extension::SPV_ARM_tensors};
1283 extensions.push_back(Elt: ext);
1284}
1285
1286void TensorArmType::getCapabilities(
1287 SPIRVType::CapabilityArrayRefVector &capabilities,
1288 std::optional<StorageClass> storage) {
1289 llvm::cast<SPIRVType>(Val: getElementType())
1290 .getCapabilities(capabilities, storage);
1291 static constexpr Capability cap{Capability::TensorsARM};
1292 capabilities.push_back(Elt: cap);
1293}
1294
1295LogicalResult
1296TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
1297 ArrayRef<int64_t> shape, Type elementType) {
1298 if (llvm::is_contained(Range&: shape, Element: 0))
1299 return emitError() << "arm.tensor do not support dimensions = 0";
1300 if (llvm::any_of(Range&: shape, P: [](int64_t dim) { return dim < 0; }) &&
1301 llvm::any_of(Range&: shape, P: [](int64_t dim) { return dim > 0; }))
1302 return emitError()
1303 << "arm.tensor shape dimensions must be either fully dynamic or "
1304 "completed shaped";
1305 return success();
1306}
1307
1308//===----------------------------------------------------------------------===//
1309// SPIR-V Dialect
1310//===----------------------------------------------------------------------===//
1311
1312void SPIRVDialect::registerTypes() {
1313 addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
1314 RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
1315}
1316

// Source: mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp