1 | //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the types in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
16 | #include "mlir/IR/Attributes.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/TypeSwitch.h" |
20 | |
21 | #include <cstdint> |
22 | #include <iterator> |
23 | |
24 | using namespace mlir; |
25 | using namespace mlir::spirv; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // ArrayType |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | struct spirv::detail::ArrayTypeStorage : public TypeStorage { |
32 | using KeyTy = std::tuple<Type, unsigned, unsigned>; |
33 | |
34 | static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, |
35 | const KeyTy &key) { |
36 | return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key); |
37 | } |
38 | |
39 | bool operator==(const KeyTy &key) const { |
40 | return key == KeyTy(elementType, elementCount, stride); |
41 | } |
42 | |
43 | ArrayTypeStorage(const KeyTy &key) |
44 | : elementType(std::get<0>(t: key)), elementCount(std::get<1>(t: key)), |
45 | stride(std::get<2>(t: key)) {} |
46 | |
47 | Type elementType; |
48 | unsigned elementCount; |
49 | unsigned stride; |
50 | }; |
51 | |
52 | ArrayType ArrayType::get(Type elementType, unsigned elementCount) { |
53 | assert(elementCount && "ArrayType needs at least one element"); |
54 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, |
55 | /*stride=*/args: 0); |
56 | } |
57 | |
58 | ArrayType ArrayType::get(Type elementType, unsigned elementCount, |
59 | unsigned stride) { |
60 | assert(elementCount && "ArrayType needs at least one element"); |
61 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: elementCount, args&: stride); |
62 | } |
63 | |
64 | unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; } |
65 | |
66 | Type ArrayType::getElementType() const { return getImpl()->elementType; } |
67 | |
68 | unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } |
69 | |
70 | void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
71 | std::optional<StorageClass> storage) { |
72 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
73 | } |
74 | |
75 | void ArrayType::getCapabilities( |
76 | SPIRVType::CapabilityArrayRefVector &capabilities, |
77 | std::optional<StorageClass> storage) { |
78 | llvm::cast<SPIRVType>(Val: getElementType()) |
79 | .getCapabilities(capabilities, storage); |
80 | } |
81 | |
82 | std::optional<int64_t> ArrayType::getSizeInBytes() { |
83 | auto elementType = llvm::cast<SPIRVType>(Val: getElementType()); |
84 | std::optional<int64_t> size = elementType.getSizeInBytes(); |
85 | if (!size) |
86 | return std::nullopt; |
87 | return (*size + getArrayStride()) * getNumElements(); |
88 | } |
89 | |
90 | //===----------------------------------------------------------------------===// |
91 | // CompositeType |
92 | //===----------------------------------------------------------------------===// |
93 | |
94 | bool CompositeType::classof(Type type) { |
95 | if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
96 | return isValid(vectorType); |
97 | return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType, |
98 | spirv::MatrixType, spirv::RuntimeArrayType, |
99 | spirv::StructType>(Val: type); |
100 | } |
101 | |
102 | bool CompositeType::isValid(VectorType type) { |
103 | return type.getRank() == 1 && |
104 | llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && |
105 | llvm::isa<ScalarType>(type.getElementType()); |
106 | } |
107 | |
108 | Type CompositeType::getElementType(unsigned index) const { |
109 | return TypeSwitch<Type, Type>(*this) |
110 | .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>( |
111 | [](auto type) { return type.getElementType(); }) |
112 | .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); }) |
113 | .Case<StructType>( |
114 | [index](StructType type) { return type.getElementType(index); }) |
115 | .Default( |
116 | [](Type) -> Type { llvm_unreachable("invalid composite type"); }); |
117 | } |
118 | |
119 | unsigned CompositeType::getNumElements() const { |
120 | if (auto arrayType = llvm::dyn_cast<ArrayType>(Val: *this)) |
121 | return arrayType.getNumElements(); |
122 | if (auto matrixType = llvm::dyn_cast<MatrixType>(Val: *this)) |
123 | return matrixType.getNumColumns(); |
124 | if (auto structType = llvm::dyn_cast<StructType>(Val: *this)) |
125 | return structType.getNumElements(); |
126 | if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) |
127 | return vectorType.getNumElements(); |
128 | if (llvm::isa<CooperativeMatrixType>(Val: *this)) { |
129 | llvm_unreachable( |
130 | "invalid to query number of elements of spirv Cooperative Matrix type"); |
131 | } |
132 | if (llvm::isa<RuntimeArrayType>(Val: *this)) { |
133 | llvm_unreachable( |
134 | "invalid to query number of elements of spirv::RuntimeArray type"); |
135 | } |
136 | llvm_unreachable("invalid composite type"); |
137 | } |
138 | |
139 | bool CompositeType::hasCompileTimeKnownNumElements() const { |
140 | return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(Val: *this); |
141 | } |
142 | |
143 | void CompositeType::getExtensions( |
144 | SPIRVType::ExtensionArrayRefVector &extensions, |
145 | std::optional<StorageClass> storage) { |
146 | TypeSwitch<Type>(*this) |
147 | .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType, |
148 | StructType>( |
149 | [&](auto type) { type.getExtensions(extensions, storage); }) |
150 | .Case<VectorType>([&](VectorType type) { |
151 | return llvm::cast<ScalarType>(type.getElementType()) |
152 | .getExtensions(extensions, storage); |
153 | }) |
154 | .Default([](Type) { llvm_unreachable("invalid composite type"); }); |
155 | } |
156 | |
157 | void CompositeType::getCapabilities( |
158 | SPIRVType::CapabilityArrayRefVector &capabilities, |
159 | std::optional<StorageClass> storage) { |
160 | TypeSwitch<Type>(*this) |
161 | .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType, |
162 | StructType>( |
163 | [&](auto type) { type.getCapabilities(capabilities, storage); }) |
164 | .Case<VectorType>([&](VectorType type) { |
165 | auto vecSize = getNumElements(); |
166 | if (vecSize == 8 || vecSize == 16) { |
167 | static const Capability caps[] = {Capability::Vector16}; |
168 | ArrayRef<Capability> ref(caps, std::size(caps)); |
169 | capabilities.push_back(ref); |
170 | } |
171 | return llvm::cast<ScalarType>(type.getElementType()) |
172 | .getCapabilities(capabilities, storage); |
173 | }) |
174 | .Default([](Type) { llvm_unreachable("invalid composite type"); }); |
175 | } |
176 | |
177 | std::optional<int64_t> CompositeType::getSizeInBytes() { |
178 | if (auto arrayType = llvm::dyn_cast<ArrayType>(Val&: *this)) |
179 | return arrayType.getSizeInBytes(); |
180 | if (auto structType = llvm::dyn_cast<StructType>(Val&: *this)) |
181 | return structType.getSizeInBytes(); |
182 | if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) { |
183 | std::optional<int64_t> elementSize = |
184 | llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes(); |
185 | if (!elementSize) |
186 | return std::nullopt; |
187 | return *elementSize * vectorType.getNumElements(); |
188 | } |
189 | return std::nullopt; |
190 | } |
191 | |
192 | //===----------------------------------------------------------------------===// |
193 | // CooperativeMatrixType |
194 | //===----------------------------------------------------------------------===// |
195 | |
196 | struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { |
197 | // In the specification dimensions of the Cooperative Matrix are 32-bit |
198 | // integers --- the initial implementation kept those values as such. However, |
199 | // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape |
200 | // as 32-bits and expose it as int64_t through `getShape`, however, this |
201 | // method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two |
202 | // 32-bits integers would require an extra logic and storage. So, we diverge |
203 | // from the spec and internally represent the dimensions as 64-bit integers, |
204 | // so we can easily return an `ArrayRef` from `getShape` without any extra |
205 | // logic. Alternatively, we could store both rows and columns (both 32-bits) |
206 | // and shape (64-bits), assigning rows and columns to shape whenever |
207 | // `getShape` is called. This would be at the cost of extra logic and storage. |
208 | // Note: Because `ArrayRef` is returned we cannot construct an object in |
209 | // `getShape` on the fly. |
210 | using KeyTy = |
211 | std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>; |
212 | |
213 | static CooperativeMatrixTypeStorage * |
214 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
215 | return new (allocator.allocate<CooperativeMatrixTypeStorage>()) |
216 | CooperativeMatrixTypeStorage(key); |
217 | } |
218 | |
219 | bool operator==(const KeyTy &key) const { |
220 | return key == KeyTy(elementType, shape[0], shape[1], scope, use); |
221 | } |
222 | |
223 | CooperativeMatrixTypeStorage(const KeyTy &key) |
224 | : elementType(std::get<0>(key)), |
225 | shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)), |
226 | use(std::get<4>(key)) {} |
227 | |
228 | Type elementType; |
229 | // [#rows, #columns] |
230 | std::array<int64_t, 2> shape; |
231 | Scope scope; |
232 | CooperativeMatrixUseKHR use; |
233 | }; |
234 | |
235 | CooperativeMatrixType CooperativeMatrixType::get(Type elementType, |
236 | uint32_t rows, |
237 | uint32_t columns, Scope scope, |
238 | CooperativeMatrixUseKHR use) { |
239 | return Base::get(elementType.getContext(), elementType, rows, columns, scope, |
240 | use); |
241 | } |
242 | |
243 | Type CooperativeMatrixType::getElementType() const { |
244 | return getImpl()->elementType; |
245 | } |
246 | |
247 | uint32_t CooperativeMatrixType::getRows() const { |
248 | assert(getImpl()->shape[0] != ShapedType::kDynamic); |
249 | return static_cast<uint32_t>(getImpl()->shape[0]); |
250 | } |
251 | |
252 | uint32_t CooperativeMatrixType::getColumns() const { |
253 | assert(getImpl()->shape[1] != ShapedType::kDynamic); |
254 | return static_cast<uint32_t>(getImpl()->shape[1]); |
255 | } |
256 | |
257 | ArrayRef<int64_t> CooperativeMatrixType::getShape() const { |
258 | return getImpl()->shape; |
259 | } |
260 | |
261 | Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } |
262 | |
263 | CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const { |
264 | return getImpl()->use; |
265 | } |
266 | |
267 | void CooperativeMatrixType::getExtensions( |
268 | SPIRVType::ExtensionArrayRefVector &extensions, |
269 | std::optional<StorageClass> storage) { |
270 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
271 | static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix}; |
272 | extensions.push_back(exts); |
273 | } |
274 | |
275 | void CooperativeMatrixType::getCapabilities( |
276 | SPIRVType::CapabilityArrayRefVector &capabilities, |
277 | std::optional<StorageClass> storage) { |
278 | llvm::cast<SPIRVType>(Val: getElementType()) |
279 | .getCapabilities(capabilities, storage); |
280 | static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR}; |
281 | capabilities.push_back(caps); |
282 | } |
283 | |
284 | //===----------------------------------------------------------------------===// |
285 | // ImageType |
286 | //===----------------------------------------------------------------------===// |
287 | |
288 | template <typename T> |
289 | static constexpr unsigned getNumBits() { |
290 | return 0; |
291 | } |
292 | template <> |
293 | constexpr unsigned getNumBits<Dim>() { |
294 | static_assert((1 << 3) > getMaxEnumValForDim(), |
295 | "Not enough bits to encode Dim value"); |
296 | return 3; |
297 | } |
298 | template <> |
299 | constexpr unsigned getNumBits<ImageDepthInfo>() { |
300 | static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(), |
301 | "Not enough bits to encode ImageDepthInfo value"); |
302 | return 2; |
303 | } |
304 | template <> |
305 | constexpr unsigned getNumBits<ImageArrayedInfo>() { |
306 | static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(), |
307 | "Not enough bits to encode ImageArrayedInfo value"); |
308 | return 1; |
309 | } |
310 | template <> |
311 | constexpr unsigned getNumBits<ImageSamplingInfo>() { |
312 | static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(), |
313 | "Not enough bits to encode ImageSamplingInfo value"); |
314 | return 1; |
315 | } |
316 | template <> |
317 | constexpr unsigned getNumBits<ImageSamplerUseInfo>() { |
318 | static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(), |
319 | "Not enough bits to encode ImageSamplerUseInfo value"); |
320 | return 2; |
321 | } |
322 | template <> |
323 | constexpr unsigned getNumBits<ImageFormat>() { |
324 | static_assert((1 << 6) > getMaxEnumValForImageFormat(), |
325 | "Not enough bits to encode ImageFormat value"); |
326 | return 6; |
327 | } |
328 | |
329 | struct spirv::detail::ImageTypeStorage : public TypeStorage { |
330 | public: |
331 | using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
332 | ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>; |
333 | |
334 | static ImageTypeStorage *construct(TypeStorageAllocator &allocator, |
335 | const KeyTy &key) { |
336 | return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key); |
337 | } |
338 | |
339 | bool operator==(const KeyTy &key) const { |
340 | return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo, |
341 | samplerUseInfo, format); |
342 | } |
343 | |
344 | ImageTypeStorage(const KeyTy &key) |
345 | : elementType(std::get<0>(key)), dim(std::get<1>(key)), |
346 | depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)), |
347 | samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)), |
348 | format(std::get<6>(key)) {} |
349 | |
350 | Type elementType; |
351 | Dim dim : getNumBits<Dim>(); |
352 | ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>(); |
353 | ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>(); |
354 | ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>(); |
355 | ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>(); |
356 | ImageFormat format : getNumBits<ImageFormat>(); |
357 | }; |
358 | |
359 | ImageType |
360 | ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
361 | ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat> |
362 | value) { |
363 | return Base::get(std::get<0>(value).getContext(), value); |
364 | } |
365 | |
366 | Type ImageType::getElementType() const { return getImpl()->elementType; } |
367 | |
368 | Dim ImageType::getDim() const { return getImpl()->dim; } |
369 | |
370 | ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; } |
371 | |
372 | ImageArrayedInfo ImageType::getArrayedInfo() const { |
373 | return getImpl()->arrayedInfo; |
374 | } |
375 | |
376 | ImageSamplingInfo ImageType::getSamplingInfo() const { |
377 | return getImpl()->samplingInfo; |
378 | } |
379 | |
380 | ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { |
381 | return getImpl()->samplerUseInfo; |
382 | } |
383 | |
384 | ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } |
385 | |
386 | void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, |
387 | std::optional<StorageClass>) { |
388 | // Image types do not require extra extensions thus far. |
389 | } |
390 | |
391 | void ImageType::getCapabilities( |
392 | SPIRVType::CapabilityArrayRefVector &capabilities, |
393 | std::optional<StorageClass>) { |
394 | if (auto dimCaps = spirv::getCapabilities(getDim())) |
395 | capabilities.push_back(Elt: *dimCaps); |
396 | |
397 | if (auto fmtCaps = spirv::getCapabilities(getImageFormat())) |
398 | capabilities.push_back(Elt: *fmtCaps); |
399 | } |
400 | |
401 | //===----------------------------------------------------------------------===// |
402 | // PointerType |
403 | //===----------------------------------------------------------------------===// |
404 | |
405 | struct spirv::detail::PointerTypeStorage : public TypeStorage { |
406 | // (Type, StorageClass) as the key: Type stored in this struct, and |
407 | // StorageClass stored as TypeStorage's subclass data. |
408 | using KeyTy = std::pair<Type, StorageClass>; |
409 | |
410 | static PointerTypeStorage *construct(TypeStorageAllocator &allocator, |
411 | const KeyTy &key) { |
412 | return new (allocator.allocate<PointerTypeStorage>()) |
413 | PointerTypeStorage(key); |
414 | } |
415 | |
416 | bool operator==(const KeyTy &key) const { |
417 | return key == KeyTy(pointeeType, storageClass); |
418 | } |
419 | |
420 | PointerTypeStorage(const KeyTy &key) |
421 | : pointeeType(key.first), storageClass(key.second) {} |
422 | |
423 | Type pointeeType; |
424 | StorageClass storageClass; |
425 | }; |
426 | |
427 | PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { |
428 | return Base::get(pointeeType.getContext(), pointeeType, storageClass); |
429 | } |
430 | |
431 | Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } |
432 | |
433 | StorageClass PointerType::getStorageClass() const { |
434 | return getImpl()->storageClass; |
435 | } |
436 | |
437 | void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
438 | std::optional<StorageClass> storage) { |
439 | // Use this pointer type's storage class because this pointer indicates we are |
440 | // using the pointee type in that specific storage class. |
441 | llvm::cast<SPIRVType>(getPointeeType()) |
442 | .getExtensions(extensions, getStorageClass()); |
443 | |
444 | if (auto scExts = spirv::getExtensions(getStorageClass())) |
445 | extensions.push_back(Elt: *scExts); |
446 | } |
447 | |
448 | void PointerType::getCapabilities( |
449 | SPIRVType::CapabilityArrayRefVector &capabilities, |
450 | std::optional<StorageClass> storage) { |
451 | // Use this pointer type's storage class because this pointer indicates we are |
452 | // using the pointee type in that specific storage class. |
453 | llvm::cast<SPIRVType>(getPointeeType()) |
454 | .getCapabilities(capabilities, getStorageClass()); |
455 | |
456 | if (auto scCaps = spirv::getCapabilities(getStorageClass())) |
457 | capabilities.push_back(Elt: *scCaps); |
458 | } |
459 | |
460 | //===----------------------------------------------------------------------===// |
461 | // RuntimeArrayType |
462 | //===----------------------------------------------------------------------===// |
463 | |
464 | struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage { |
465 | using KeyTy = std::pair<Type, unsigned>; |
466 | |
467 | static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator, |
468 | const KeyTy &key) { |
469 | return new (allocator.allocate<RuntimeArrayTypeStorage>()) |
470 | RuntimeArrayTypeStorage(key); |
471 | } |
472 | |
473 | bool operator==(const KeyTy &key) const { |
474 | return key == KeyTy(elementType, stride); |
475 | } |
476 | |
477 | RuntimeArrayTypeStorage(const KeyTy &key) |
478 | : elementType(key.first), stride(key.second) {} |
479 | |
480 | Type elementType; |
481 | unsigned stride; |
482 | }; |
483 | |
484 | RuntimeArrayType RuntimeArrayType::get(Type elementType) { |
485 | return Base::get(ctx: elementType.getContext(), args&: elementType, /*stride=*/args: 0); |
486 | } |
487 | |
488 | RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) { |
489 | return Base::get(ctx: elementType.getContext(), args&: elementType, args&: stride); |
490 | } |
491 | |
492 | Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } |
493 | |
494 | unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } |
495 | |
496 | void RuntimeArrayType::getExtensions( |
497 | SPIRVType::ExtensionArrayRefVector &extensions, |
498 | std::optional<StorageClass> storage) { |
499 | llvm::cast<SPIRVType>(Val: getElementType()).getExtensions(extensions, storage); |
500 | } |
501 | |
502 | void RuntimeArrayType::getCapabilities( |
503 | SPIRVType::CapabilityArrayRefVector &capabilities, |
504 | std::optional<StorageClass> storage) { |
505 | { |
506 | static const Capability caps[] = {Capability::Shader}; |
507 | ArrayRef<Capability> ref(caps, std::size(caps)); |
508 | capabilities.push_back(Elt: ref); |
509 | } |
510 | llvm::cast<SPIRVType>(Val: getElementType()) |
511 | .getCapabilities(capabilities, storage); |
512 | } |
513 | |
514 | //===----------------------------------------------------------------------===// |
515 | // ScalarType |
516 | //===----------------------------------------------------------------------===// |
517 | |
518 | bool ScalarType::classof(Type type) { |
519 | if (auto floatType = llvm::dyn_cast<FloatType>(type)) { |
520 | return isValid(floatType); |
521 | } |
522 | if (auto intType = llvm::dyn_cast<IntegerType>(type)) { |
523 | return isValid(intType); |
524 | } |
525 | return false; |
526 | } |
527 | |
528 | bool ScalarType::isValid(FloatType type) { |
529 | return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16(); |
530 | } |
531 | |
532 | bool ScalarType::isValid(IntegerType type) { |
533 | return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth()); |
534 | } |
535 | |
536 | void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
537 | std::optional<StorageClass> storage) { |
538 | // 8- or 16-bit integer/floating-point numbers will require extra extensions |
539 | // to appear in interface storage classes. See SPV_KHR_16bit_storage and |
540 | // SPV_KHR_8bit_storage for more details. |
541 | if (!storage) |
542 | return; |
543 | |
544 | switch (*storage) { |
545 | case StorageClass::PushConstant: |
546 | case StorageClass::StorageBuffer: |
547 | case StorageClass::Uniform: |
548 | if (getIntOrFloatBitWidth() == 8) { |
549 | static const Extension exts[] = {Extension::SPV_KHR_8bit_storage}; |
550 | ArrayRef<Extension> ref(exts, std::size(exts)); |
551 | extensions.push_back(Elt: ref); |
552 | } |
553 | [[fallthrough]]; |
554 | case StorageClass::Input: |
555 | case StorageClass::Output: |
556 | if (getIntOrFloatBitWidth() == 16) { |
557 | static const Extension exts[] = {Extension::SPV_KHR_16bit_storage}; |
558 | ArrayRef<Extension> ref(exts, std::size(exts)); |
559 | extensions.push_back(Elt: ref); |
560 | } |
561 | break; |
562 | default: |
563 | break; |
564 | } |
565 | } |
566 | |
567 | void ScalarType::getCapabilities( |
568 | SPIRVType::CapabilityArrayRefVector &capabilities, |
569 | std::optional<StorageClass> storage) { |
570 | unsigned bitwidth = getIntOrFloatBitWidth(); |
571 | |
572 | // 8- or 16-bit integer/floating-point numbers will require extra capabilities |
573 | // to appear in interface storage classes. See SPV_KHR_16bit_storage and |
574 | // SPV_KHR_8bit_storage for more details. |
575 | |
576 | #define STORAGE_CASE(storage, cap8, cap16) \ |
577 | case StorageClass::storage: { \ |
578 | if (bitwidth == 8) { \ |
579 | static const Capability caps[] = {Capability::cap8}; \ |
580 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
581 | capabilities.push_back(ref); \ |
582 | return; \ |
583 | } \ |
584 | if (bitwidth == 16) { \ |
585 | static const Capability caps[] = {Capability::cap16}; \ |
586 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
587 | capabilities.push_back(ref); \ |
588 | return; \ |
589 | } \ |
590 | /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \ |
591 | /* storage classes. Fall through to the next section. */ \ |
592 | } break |
593 | |
594 | // This part only handles the cases where special bitwidths appearing in |
595 | // interface storage classes. |
596 | if (storage) { |
597 | switch (*storage) { |
598 | STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); |
599 | STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess, |
600 | StorageBuffer16BitAccess); |
601 | STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, |
602 | StorageUniform16); |
603 | case StorageClass::Input: |
604 | case StorageClass::Output: { |
605 | if (bitwidth == 16) { |
606 | static const Capability caps[] = {Capability::StorageInputOutput16}; |
607 | ArrayRef<Capability> ref(caps, std::size(caps)); |
608 | capabilities.push_back(Elt: ref); |
609 | return; |
610 | } |
611 | break; |
612 | } |
613 | default: |
614 | break; |
615 | } |
616 | } |
617 | #undef STORAGE_CASE |
618 | |
619 | // For other non-interface storage classes, require a different set of |
620 | // capabilities for special bitwidths. |
621 | |
622 | #define WIDTH_CASE(type, width) \ |
623 | case width: { \ |
624 | static const Capability caps[] = {Capability::type##width}; \ |
625 | ArrayRef<Capability> ref(caps, std::size(caps)); \ |
626 | capabilities.push_back(ref); \ |
627 | } break |
628 | |
629 | if (auto intType = llvm::dyn_cast<IntegerType>(*this)) { |
630 | switch (bitwidth) { |
631 | WIDTH_CASE(Int, 8); |
632 | WIDTH_CASE(Int, 16); |
633 | WIDTH_CASE(Int, 64); |
634 | case 1: |
635 | case 32: |
636 | break; |
637 | default: |
638 | llvm_unreachable("invalid bitwidth to getCapabilities"); |
639 | } |
640 | } else { |
641 | assert(llvm::isa<FloatType>(*this)); |
642 | switch (bitwidth) { |
643 | WIDTH_CASE(Float, 16); |
644 | WIDTH_CASE(Float, 64); |
645 | case 32: |
646 | break; |
647 | default: |
648 | llvm_unreachable("invalid bitwidth to getCapabilities"); |
649 | } |
650 | } |
651 | |
652 | #undef WIDTH_CASE |
653 | } |
654 | |
655 | std::optional<int64_t> ScalarType::getSizeInBytes() { |
656 | auto bitWidth = getIntOrFloatBitWidth(); |
657 | // According to the SPIR-V spec: |
658 | // "There is no physical size or bit pattern defined for values with boolean |
659 | // type. If they are stored (in conjunction with OpVariable), they can only |
660 | // be used with logical addressing operations, not physical, and only with |
661 | // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, |
662 | // Private, Function, Input, and Output." |
663 | if (bitWidth == 1) |
664 | return std::nullopt; |
665 | return bitWidth / 8; |
666 | } |
667 | |
668 | //===----------------------------------------------------------------------===// |
669 | // SPIRVType |
670 | //===----------------------------------------------------------------------===// |
671 | |
672 | bool SPIRVType::classof(Type type) { |
673 | // Allow SPIR-V dialect types |
674 | if (llvm::isa<SPIRVDialect>(type.getDialect())) |
675 | return true; |
676 | if (llvm::isa<ScalarType>(Val: type)) |
677 | return true; |
678 | if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
679 | return CompositeType::isValid(vectorType); |
680 | return false; |
681 | } |
682 | |
683 | bool SPIRVType::isScalarOrVector() { |
684 | return isIntOrFloat() || llvm::isa<VectorType>(*this); |
685 | } |
686 | |
687 | void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
688 | std::optional<StorageClass> storage) { |
689 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) { |
690 | scalarType.getExtensions(extensions, storage); |
691 | } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) { |
692 | compositeType.getExtensions(extensions, storage); |
693 | } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) { |
694 | imageType.getExtensions(extensions, storage); |
695 | } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) { |
696 | sampledImageType.getExtensions(extensions, storage); |
697 | } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) { |
698 | matrixType.getExtensions(extensions, storage); |
699 | } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) { |
700 | ptrType.getExtensions(extensions, storage); |
701 | } else { |
702 | llvm_unreachable("invalid SPIR-V Type to getExtensions"); |
703 | } |
704 | } |
705 | |
706 | void SPIRVType::getCapabilities( |
707 | SPIRVType::CapabilityArrayRefVector &capabilities, |
708 | std::optional<StorageClass> storage) { |
709 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) { |
710 | scalarType.getCapabilities(capabilities, storage); |
711 | } else if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) { |
712 | compositeType.getCapabilities(capabilities, storage); |
713 | } else if (auto imageType = llvm::dyn_cast<ImageType>(Val&: *this)) { |
714 | imageType.getCapabilities(capabilities, storage); |
715 | } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(Val&: *this)) { |
716 | sampledImageType.getCapabilities(capabilities, storage); |
717 | } else if (auto matrixType = llvm::dyn_cast<MatrixType>(Val&: *this)) { |
718 | matrixType.getCapabilities(capabilities, storage); |
719 | } else if (auto ptrType = llvm::dyn_cast<PointerType>(Val&: *this)) { |
720 | ptrType.getCapabilities(capabilities, storage); |
721 | } else { |
722 | llvm_unreachable("invalid SPIR-V Type to getCapabilities"); |
723 | } |
724 | } |
725 | |
726 | std::optional<int64_t> SPIRVType::getSizeInBytes() { |
727 | if (auto scalarType = llvm::dyn_cast<ScalarType>(Val&: *this)) |
728 | return scalarType.getSizeInBytes(); |
729 | if (auto compositeType = llvm::dyn_cast<CompositeType>(Val&: *this)) |
730 | return compositeType.getSizeInBytes(); |
731 | return std::nullopt; |
732 | } |
733 | |
734 | //===----------------------------------------------------------------------===// |
735 | // SampledImageType |
736 | //===----------------------------------------------------------------------===// |
737 | struct spirv::detail::SampledImageTypeStorage : public TypeStorage { |
738 | using KeyTy = Type; |
739 | |
740 | SampledImageTypeStorage(const KeyTy &key) : imageType{key} {} |
741 | |
742 | bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); } |
743 | |
744 | static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator, |
745 | const KeyTy &key) { |
746 | return new (allocator.allocate<SampledImageTypeStorage>()) |
747 | SampledImageTypeStorage(key); |
748 | } |
749 | |
750 | Type imageType; |
751 | }; |
752 | |
753 | SampledImageType SampledImageType::get(Type imageType) { |
754 | return Base::get(ctx: imageType.getContext(), args&: imageType); |
755 | } |
756 | |
757 | SampledImageType |
758 | SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
759 | Type imageType) { |
760 | return Base::getChecked(emitErrorFn: emitError, ctx: imageType.getContext(), args: imageType); |
761 | } |
762 | |
763 | Type SampledImageType::getImageType() const { return getImpl()->imageType; } |
764 | |
765 | LogicalResult |
766 | SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, |
767 | Type imageType) { |
768 | if (!llvm::isa<ImageType>(Val: imageType)) |
769 | return emitError() << "expected image type"; |
770 | |
771 | return success(); |
772 | } |
773 | |
774 | void SampledImageType::getExtensions( |
775 | SPIRVType::ExtensionArrayRefVector &extensions, |
776 | std::optional<StorageClass> storage) { |
777 | llvm::cast<ImageType>(Val: getImageType()).getExtensions(extensions, storage); |
778 | } |
779 | |
780 | void SampledImageType::getCapabilities( |
781 | SPIRVType::CapabilityArrayRefVector &capabilities, |
782 | std::optional<StorageClass> storage) { |
783 | llvm::cast<ImageType>(Val: getImageType()).getCapabilities(capabilities, storage); |
784 | } |
785 | |
786 | //===----------------------------------------------------------------------===// |
787 | // StructType |
788 | //===----------------------------------------------------------------------===// |
789 | |
790 | /// Type storage for SPIR-V structure types: |
791 | /// |
792 | /// Structures are uniqued using: |
793 | /// - for identified structs: |
794 | /// - a string identifier; |
795 | /// - for literal structs: |
796 | /// - a list of member types; |
797 | /// - a list of member offset info; |
798 | /// - a list of member decoration info. |
799 | /// |
800 | /// Identified structures only have a mutable component consisting of: |
801 | /// - a list of member types; |
802 | /// - a list of member offset info; |
803 | /// - a list of member decoration info. |
804 | struct spirv::detail::StructTypeStorage : public TypeStorage { |
805 | /// Construct a storage object for an identified struct type. A struct type |
806 | /// associated with such storage must call StructType::trySetBody(...) later |
807 | /// in order to mutate the storage object providing the actual content. |
808 | StructTypeStorage(StringRef identifier) |
809 | : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), |
810 | numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), |
811 | identifier(identifier) {} |
812 | |
813 | /// Construct a storage object for a literal struct type. A struct type |
814 | /// associated with such storage is immutable. |
815 | StructTypeStorage( |
816 | unsigned numMembers, Type const *memberTypes, |
817 | StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, |
818 | StructType::MemberDecorationInfo const *memberDecorationsInfo) |
819 | : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), |
820 | numMembers(numMembers), numMemberDecorations(numMemberDecorations), |
821 | memberDecorationsInfo(memberDecorationsInfo) {} |
822 | |
823 | /// A storage key is divided into 2 parts: |
824 | /// - for identified structs: |
825 | /// - a StringRef representing the struct identifier; |
826 | /// - for literal structs: |
827 | /// - an ArrayRef<Type> for member types; |
828 | /// - an ArrayRef<StructType::OffsetInfo> for member offset info; |
829 | /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration |
830 | /// info. |
831 | /// |
832 | /// An identified struct type is uniqued only by the first part (field 0) |
833 | /// of the key. |
834 | /// |
835 | /// A literal struct type is uniqued only by the second part (fields 1, 2, and |
836 | /// 3) of the key. The identifier field (field 0) must be empty. |
837 | using KeyTy = |
838 | std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, |
839 | ArrayRef<StructType::MemberDecorationInfo>>; |
840 | |
841 | /// For identified structs, return true if the given key contains the same |
842 | /// identifier. |
843 | /// |
844 | /// For literal structs, return true if the given key contains a matching list |
845 | /// of member types + offset info + decoration info. |
846 | bool operator==(const KeyTy &key) const { |
847 | if (isIdentified()) { |
848 | // Identified types are uniqued by their identifier. |
849 | return getIdentifier() == std::get<0>(t: key); |
850 | } |
851 | |
852 | return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), |
853 | getMemberDecorationsInfo()); |
854 | } |
855 | |
856 | /// If the given key contains a non-empty identifier, this method constructs |
857 | /// an identified struct and leaves the rest of the struct type data to be set |
858 | /// through a later call to StructType::trySetBody(...). |
859 | /// |
860 | /// If, on the other hand, the key contains an empty identifier, a literal |
861 | /// struct is constructed using the other fields of the key. |
862 | static StructTypeStorage *construct(TypeStorageAllocator &allocator, |
863 | const KeyTy &key) { |
864 | StringRef keyIdentifier = std::get<0>(t: key); |
865 | |
866 | if (!keyIdentifier.empty()) { |
867 | StringRef identifier = allocator.copyInto(str: keyIdentifier); |
868 | |
869 | // Identified StructType body/members will be set through trySetBody(...) |
870 | // later. |
871 | return new (allocator.allocate<StructTypeStorage>()) |
872 | StructTypeStorage(identifier); |
873 | } |
874 | |
875 | ArrayRef<Type> keyTypes = std::get<1>(t: key); |
876 | |
877 | // Copy the member type and layout information into the bump pointer |
878 | const Type *typesList = nullptr; |
879 | if (!keyTypes.empty()) { |
880 | typesList = allocator.copyInto(elements: keyTypes).data(); |
881 | } |
882 | |
883 | const StructType::OffsetInfo *offsetInfoList = nullptr; |
884 | if (!std::get<2>(t: key).empty()) { |
885 | ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(t: key); |
886 | assert(keyOffsetInfo.size() == keyTypes.size() && |
887 | "size of offset information must be same as the size of number of " |
888 | "elements"); |
889 | offsetInfoList = allocator.copyInto(elements: keyOffsetInfo).data(); |
890 | } |
891 | |
892 | const StructType::MemberDecorationInfo *memberDecorationList = nullptr; |
893 | unsigned numMemberDecorations = 0; |
894 | if (!std::get<3>(t: key).empty()) { |
895 | auto keyMemberDecorations = std::get<3>(t: key); |
896 | numMemberDecorations = keyMemberDecorations.size(); |
897 | memberDecorationList = allocator.copyInto(elements: keyMemberDecorations).data(); |
898 | } |
899 | |
900 | return new (allocator.allocate<StructTypeStorage>()) |
901 | StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, |
902 | numMemberDecorations, memberDecorationList); |
903 | } |
904 | |
905 | ArrayRef<Type> getMemberTypes() const { |
906 | return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers); |
907 | } |
908 | |
909 | ArrayRef<StructType::OffsetInfo> getOffsetInfo() const { |
910 | if (offsetInfo) { |
911 | return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers); |
912 | } |
913 | return {}; |
914 | } |
915 | |
916 | ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const { |
917 | if (memberDecorationsInfo) { |
918 | return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo, |
919 | numMemberDecorations); |
920 | } |
921 | return {}; |
922 | } |
923 | |
924 | StringRef getIdentifier() const { return identifier; } |
925 | |
926 | bool isIdentified() const { return !identifier.empty(); } |
927 | |
928 | /// Sets the struct type content for identified structs. Calling this method |
929 | /// is only valid for identified structs. |
930 | /// |
931 | /// Fails under the following conditions: |
932 | /// - If called for a literal struct; |
933 | /// - If called for an identified struct whose body was set before (through a |
934 | /// call to this method) but with different contents from the passed |
935 | /// arguments. |
936 | LogicalResult mutate( |
937 | TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, |
938 | ArrayRef<StructType::OffsetInfo> structOffsetInfo, |
939 | ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { |
940 | if (!isIdentified()) |
941 | return failure(); |
942 | |
943 | if (memberTypesAndIsBodySet.getInt() && |
944 | (getMemberTypes() != structMemberTypes || |
945 | getOffsetInfo() != structOffsetInfo || |
946 | getMemberDecorationsInfo() != structMemberDecorationInfo)) |
947 | return failure(); |
948 | |
949 | memberTypesAndIsBodySet.setInt(true); |
950 | numMembers = structMemberTypes.size(); |
951 | |
952 | // Copy the member type and layout information into the bump pointer. |
953 | if (!structMemberTypes.empty()) |
954 | memberTypesAndIsBodySet.setPointer( |
955 | allocator.copyInto(elements: structMemberTypes).data()); |
956 | |
957 | if (!structOffsetInfo.empty()) { |
958 | assert(structOffsetInfo.size() == structMemberTypes.size() && |
959 | "size of offset information must be same as the size of number of " |
960 | "elements"); |
961 | offsetInfo = allocator.copyInto(elements: structOffsetInfo).data(); |
962 | } |
963 | |
964 | if (!structMemberDecorationInfo.empty()) { |
965 | numMemberDecorations = structMemberDecorationInfo.size(); |
966 | memberDecorationsInfo = |
967 | allocator.copyInto(elements: structMemberDecorationInfo).data(); |
968 | } |
969 | |
970 | return success(); |
971 | } |
972 | |
973 | llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet; |
974 | StructType::OffsetInfo const *offsetInfo; |
975 | unsigned numMembers; |
976 | unsigned numMemberDecorations; |
977 | StructType::MemberDecorationInfo const *memberDecorationsInfo; |
978 | StringRef identifier; |
979 | }; |
980 | |
981 | StructType |
982 | StructType::get(ArrayRef<Type> memberTypes, |
983 | ArrayRef<StructType::OffsetInfo> offsetInfo, |
984 | ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { |
985 | assert(!memberTypes.empty() && "Struct needs at least one member type"); |
986 | // Sort the decorations. |
987 | SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( |
988 | memberDecorations); |
989 | llvm::array_pod_sort(Start: sortedDecorations.begin(), End: sortedDecorations.end()); |
990 | return Base::get(ctx: memberTypes.vec().front().getContext(), |
991 | /*identifier=*/args: StringRef(), args&: memberTypes, args&: offsetInfo, |
992 | args&: sortedDecorations); |
993 | } |
994 | |
995 | StructType StructType::getIdentified(MLIRContext *context, |
996 | StringRef identifier) { |
997 | assert(!identifier.empty() && |
998 | "StructType identifier must be non-empty string"); |
999 | |
1000 | return Base::get(ctx: context, args&: identifier, args: ArrayRef<Type>(), |
1001 | args: ArrayRef<StructType::OffsetInfo>(), |
1002 | args: ArrayRef<StructType::MemberDecorationInfo>()); |
1003 | } |
1004 | |
1005 | StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { |
1006 | StructType newStructType = Base::get( |
1007 | ctx: context, args&: identifier, args: ArrayRef<Type>(), args: ArrayRef<StructType::OffsetInfo>(), |
1008 | args: ArrayRef<StructType::MemberDecorationInfo>()); |
1009 | // Set an empty body in case this is a identified struct. |
1010 | if (newStructType.isIdentified() && |
1011 | failed(Result: newStructType.trySetBody( |
1012 | memberTypes: ArrayRef<Type>(), offsetInfo: ArrayRef<StructType::OffsetInfo>(), |
1013 | memberDecorations: ArrayRef<StructType::MemberDecorationInfo>()))) |
1014 | return StructType(); |
1015 | |
1016 | return newStructType; |
1017 | } |
1018 | |
1019 | StringRef StructType::getIdentifier() const { return getImpl()->identifier; } |
1020 | |
1021 | bool StructType::isIdentified() const { return getImpl()->isIdentified(); } |
1022 | |
1023 | unsigned StructType::getNumElements() const { return getImpl()->numMembers; } |
1024 | |
1025 | Type StructType::getElementType(unsigned index) const { |
1026 | assert(getNumElements() > index && "member index out of range"); |
1027 | return getImpl()->memberTypesAndIsBodySet.getPointer()[index]; |
1028 | } |
1029 | |
1030 | TypeRange StructType::getElementTypes() const { |
1031 | return TypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(), |
1032 | getNumElements()); |
1033 | } |
1034 | |
1035 | bool StructType::hasOffset() const { return getImpl()->offsetInfo; } |
1036 | |
1037 | uint64_t StructType::getMemberOffset(unsigned index) const { |
1038 | assert(getNumElements() > index && "member index out of range"); |
1039 | return getImpl()->offsetInfo[index]; |
1040 | } |
1041 | |
1042 | void StructType::getMemberDecorations( |
1043 | SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations) |
1044 | const { |
1045 | memberDecorations.clear(); |
1046 | auto implMemberDecorations = getImpl()->getMemberDecorationsInfo(); |
1047 | memberDecorations.append(in_start: implMemberDecorations.begin(), |
1048 | in_end: implMemberDecorations.end()); |
1049 | } |
1050 | |
1051 | void StructType::getMemberDecorations( |
1052 | unsigned index, |
1053 | SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const { |
1054 | assert(getNumElements() > index && "member index out of range"); |
1055 | auto memberDecorations = getImpl()->getMemberDecorationsInfo(); |
1056 | decorationsInfo.clear(); |
1057 | for (const auto &memberDecoration : memberDecorations) { |
1058 | if (memberDecoration.memberIndex == index) { |
1059 | decorationsInfo.push_back(Elt: memberDecoration); |
1060 | } |
1061 | if (memberDecoration.memberIndex > index) { |
1062 | // Early exit since the decorations are stored sorted. |
1063 | return; |
1064 | } |
1065 | } |
1066 | } |
1067 | |
1068 | LogicalResult |
1069 | StructType::trySetBody(ArrayRef<Type> memberTypes, |
1070 | ArrayRef<OffsetInfo> offsetInfo, |
1071 | ArrayRef<MemberDecorationInfo> memberDecorations) { |
1072 | return Base::mutate(args&: memberTypes, args&: offsetInfo, args&: memberDecorations); |
1073 | } |
1074 | |
1075 | void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
1076 | std::optional<StorageClass> storage) { |
1077 | for (Type elementType : getElementTypes()) |
1078 | llvm::cast<SPIRVType>(Val&: elementType).getExtensions(extensions, storage); |
1079 | } |
1080 | |
1081 | void StructType::getCapabilities( |
1082 | SPIRVType::CapabilityArrayRefVector &capabilities, |
1083 | std::optional<StorageClass> storage) { |
1084 | for (Type elementType : getElementTypes()) |
1085 | llvm::cast<SPIRVType>(Val&: elementType).getCapabilities(capabilities, storage); |
1086 | } |
1087 | |
1088 | llvm::hash_code spirv::hash_value( |
1089 | const StructType::MemberDecorationInfo &memberDecorationInfo) { |
1090 | return llvm::hash_combine(memberDecorationInfo.memberIndex, |
1091 | memberDecorationInfo.decoration); |
1092 | } |
1093 | |
1094 | //===----------------------------------------------------------------------===// |
1095 | // MatrixType |
1096 | //===----------------------------------------------------------------------===// |
1097 | |
1098 | struct spirv::detail::MatrixTypeStorage : public TypeStorage { |
1099 | MatrixTypeStorage(Type columnType, uint32_t columnCount) |
1100 | : columnType(columnType), columnCount(columnCount) {} |
1101 | |
1102 | using KeyTy = std::tuple<Type, uint32_t>; |
1103 | |
1104 | static MatrixTypeStorage *construct(TypeStorageAllocator &allocator, |
1105 | const KeyTy &key) { |
1106 | |
1107 | // Initialize the memory using placement new. |
1108 | return new (allocator.allocate<MatrixTypeStorage>()) |
1109 | MatrixTypeStorage(std::get<0>(t: key), std::get<1>(t: key)); |
1110 | } |
1111 | |
1112 | bool operator==(const KeyTy &key) const { |
1113 | return key == KeyTy(columnType, columnCount); |
1114 | } |
1115 | |
1116 | Type columnType; |
1117 | const uint32_t columnCount; |
1118 | }; |
1119 | |
1120 | MatrixType MatrixType::get(Type columnType, uint32_t columnCount) { |
1121 | return Base::get(ctx: columnType.getContext(), args&: columnType, args&: columnCount); |
1122 | } |
1123 | |
1124 | MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
1125 | Type columnType, uint32_t columnCount) { |
1126 | return Base::getChecked(emitErrorFn: emitError, ctx: columnType.getContext(), args: columnType, |
1127 | args: columnCount); |
1128 | } |
1129 | |
1130 | LogicalResult |
1131 | MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, |
1132 | Type columnType, uint32_t columnCount) { |
1133 | if (columnCount < 2 || columnCount > 4) |
1134 | return emitError() << "matrix can have 2, 3, or 4 columns only"; |
1135 | |
1136 | if (!isValidColumnType(columnType)) |
1137 | return emitError() << "matrix columns must be vectors of floats"; |
1138 | |
1139 | /// The underlying vectors (columns) must be of size 2, 3, or 4 |
1140 | ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape(); |
1141 | if (columnShape.size() != 1) |
1142 | return emitError() << "matrix columns must be 1D vectors"; |
1143 | |
1144 | if (columnShape[0] < 2 || columnShape[0] > 4) |
1145 | return emitError() << "matrix columns must be of size 2, 3, or 4"; |
1146 | |
1147 | return success(); |
1148 | } |
1149 | |
1150 | /// Returns true if the matrix elements are vectors of float elements |
1151 | bool MatrixType::isValidColumnType(Type columnType) { |
1152 | if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) { |
1153 | if (llvm::isa<FloatType>(vectorType.getElementType())) |
1154 | return true; |
1155 | } |
1156 | return false; |
1157 | } |
1158 | |
1159 | Type MatrixType::getColumnType() const { return getImpl()->columnType; } |
1160 | |
1161 | Type MatrixType::getElementType() const { |
1162 | return llvm::cast<VectorType>(getImpl()->columnType).getElementType(); |
1163 | } |
1164 | |
1165 | unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } |
1166 | |
1167 | unsigned MatrixType::getNumRows() const { |
1168 | return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0]; |
1169 | } |
1170 | |
1171 | unsigned MatrixType::getNumElements() const { |
1172 | return (getImpl()->columnCount) * getNumRows(); |
1173 | } |
1174 | |
1175 | void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
1176 | std::optional<StorageClass> storage) { |
1177 | llvm::cast<SPIRVType>(Val: getColumnType()).getExtensions(extensions, storage); |
1178 | } |
1179 | |
1180 | void MatrixType::getCapabilities( |
1181 | SPIRVType::CapabilityArrayRefVector &capabilities, |
1182 | std::optional<StorageClass> storage) { |
1183 | { |
1184 | static const Capability caps[] = {Capability::Matrix}; |
1185 | ArrayRef<Capability> ref(caps, std::size(caps)); |
1186 | capabilities.push_back(Elt: ref); |
1187 | } |
1188 | // Add any capabilities associated with the underlying vectors (i.e., columns) |
1189 | llvm::cast<SPIRVType>(Val: getColumnType()).getCapabilities(capabilities, storage); |
1190 | } |
1191 | |
1192 | //===----------------------------------------------------------------------===// |
1193 | // SPIR-V Dialect |
1194 | //===----------------------------------------------------------------------===// |
1195 | |
1196 | void SPIRVDialect::registerTypes() { |
1197 | addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType, |
1198 | RuntimeArrayType, SampledImageType, StructType>(); |
1199 | } |
1200 |
Definitions
- ArrayTypeStorage
- construct
- operator==
- ArrayTypeStorage
- get
- get
- getNumElements
- getElementType
- getArrayStride
- getExtensions
- getCapabilities
- getSizeInBytes
- classof
- isValid
- getElementType
- getNumElements
- hasCompileTimeKnownNumElements
- getExtensions
- getCapabilities
- getSizeInBytes
- CooperativeMatrixTypeStorage
- construct
- operator==
- CooperativeMatrixTypeStorage
- get
- getElementType
- getRows
- getColumns
- getShape
- getScope
- getUse
- getExtensions
- getCapabilities
- getNumBits
- ImageTypeStorage
- construct
- operator==
- ImageTypeStorage
- get
- getElementType
- getDim
- getDepthInfo
- getArrayedInfo
- getSamplingInfo
- getSamplerUseInfo
- getImageFormat
- getExtensions
- getCapabilities
- PointerTypeStorage
- construct
- operator==
- PointerTypeStorage
- get
- getPointeeType
- getStorageClass
- getExtensions
- getCapabilities
- RuntimeArrayTypeStorage
- construct
- operator==
- RuntimeArrayTypeStorage
- get
- get
- getElementType
- getArrayStride
- getExtensions
- getCapabilities
- classof
- isValid
- isValid
- getExtensions
- getCapabilities
- getSizeInBytes
- classof
- isScalarOrVector
- getExtensions
- getCapabilities
- getSizeInBytes
- SampledImageTypeStorage
- SampledImageTypeStorage
- operator==
- construct
- get
- getChecked
- getImageType
- verifyInvariants
- getExtensions
- getCapabilities
- StructTypeStorage
- StructTypeStorage
- StructTypeStorage
- operator==
- construct
- getMemberTypes
- getOffsetInfo
- getMemberDecorationsInfo
- getIdentifier
- isIdentified
- mutate
- get
- getIdentified
- getEmpty
- getIdentifier
- isIdentified
- getNumElements
- getElementType
- getElementTypes
- hasOffset
- getMemberOffset
- getMemberDecorations
- getMemberDecorations
- trySetBody
- getExtensions
- getCapabilities
- hash_value
- MatrixTypeStorage
- MatrixTypeStorage
- construct
- operator==
- get
- getChecked
- verifyInvariants
- isValidColumnType
- getColumnType
- getElementType
- getNumColumns
- getNumRows
- getNumElements
- getExtensions
Learn to use CMake with our Intro Training
Find out more