1//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities used to lower to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/MathExtras.h"
25
26#include <functional>
27#include <optional>
28
29#define DEBUG_TYPE "mlir-spirv-conversion"
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Utility functions
35//===----------------------------------------------------------------------===//
36
37/// Checks that `candidates` extension requirements are possible to be satisfied
38/// with the given `targetEnv`.
39///
40/// `candidates` is a vector of vector for extension requirements following
41/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42/// convention.
43template <typename LabelT>
44static LogicalResult checkExtensionRequirements(
45 LabelT label, const spirv::TargetEnv &targetEnv,
46 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
47 for (const auto &ors : candidates) {
48 if (targetEnv.allows(ors))
49 continue;
50
51 LLVM_DEBUG({
52 SmallVector<StringRef> extStrings;
53 for (spirv::Extension ext : ors)
54 extStrings.push_back(spirv::stringifyExtension(ext));
55
56 llvm::dbgs() << label << " illegal: requires at least one extension in ["
57 << llvm::join(extStrings, ", ")
58 << "] but none allowed in target environment\n";
59 });
60 return failure();
61 }
62 return success();
63}
64
65/// Checks that `candidates`capability requirements are possible to be satisfied
66/// with the given `isAllowedFn`.
67///
68/// `candidates` is a vector of vector for capability requirements following
69/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70/// convention.
71template <typename LabelT>
72static LogicalResult checkCapabilityRequirements(
73 LabelT label, const spirv::TargetEnv &targetEnv,
74 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
75 for (const auto &ors : candidates) {
76 if (targetEnv.allows(ors))
77 continue;
78
79 LLVM_DEBUG({
80 SmallVector<StringRef> capStrings;
81 for (spirv::Capability cap : ors)
82 capStrings.push_back(spirv::stringifyCapability(cap));
83
84 llvm::dbgs() << label << " illegal: requires at least one capability in ["
85 << llvm::join(capStrings, ", ")
86 << "] but none allowed in target environment\n";
87 });
88 return failure();
89 }
90 return success();
91}
92
93/// Returns true if the given `storageClass` needs explicit layout when used in
94/// Shader environments.
95static bool needsExplicitLayout(spirv::StorageClass storageClass) {
96 switch (storageClass) {
97 case spirv::StorageClass::PhysicalStorageBuffer:
98 case spirv::StorageClass::PushConstant:
99 case spirv::StorageClass::StorageBuffer:
100 case spirv::StorageClass::Uniform:
101 return true;
102 default:
103 return false;
104 }
105}
106
107/// Wraps the given `elementType` in a struct and gets the pointer to the
108/// struct. This is used to satisfy Vulkan interface requirements.
109static spirv::PointerType
110wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
111 auto structType = needsExplicitLayout(storageClass)
112 ? spirv::StructType::get(memberTypes: elementType, /*offsetInfo=*/0)
113 : spirv::StructType::get(memberTypes: elementType);
114 return spirv::PointerType::get(structType, storageClass);
115}
116
117//===----------------------------------------------------------------------===//
118// Type Conversion
119//===----------------------------------------------------------------------===//
120
121static spirv::ScalarType getIndexType(MLIRContext *ctx,
122 const SPIRVConversionOptions &options) {
123 return cast<spirv::ScalarType>(
124 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
125}
126
127Type SPIRVTypeConverter::getIndexType() const {
128 return ::getIndexType(ctx: getContext(), options);
129}
130
131MLIRContext *SPIRVTypeConverter::getContext() const {
132 return targetEnv.getAttr().getContext();
133}
134
135bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
136 return targetEnv.allows(capability);
137}
138
139// TODO: This is a utility function that should probably be exposed by the
140// SPIR-V dialect. Keeping it local till the use case arises.
141static std::optional<int64_t>
142getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
143 if (isa<spirv::ScalarType>(Val: type)) {
144 auto bitWidth = type.getIntOrFloatBitWidth();
145 // According to the SPIR-V spec:
146 // "There is no physical size or bit pattern defined for values with boolean
147 // type. If they are stored (in conjunction with OpVariable), they can only
148 // be used with logical addressing operations, not physical, and only with
149 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
150 // Private, Function, Input, and Output."
151 if (bitWidth == 1)
152 return std::nullopt;
153 return bitWidth / 8;
154 }
155
156 if (auto complexType = dyn_cast<ComplexType>(type)) {
157 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
158 if (!elementSize)
159 return std::nullopt;
160 return 2 * *elementSize;
161 }
162
163 if (auto vecType = dyn_cast<VectorType>(type)) {
164 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
165 if (!elementSize)
166 return std::nullopt;
167 return vecType.getNumElements() * *elementSize;
168 }
169
170 if (auto memRefType = dyn_cast<MemRefType>(type)) {
171 // TODO: Layout should also be controlled by the ABI attributes. For now
172 // using the layout from MemRef.
173 int64_t offset;
174 SmallVector<int64_t, 4> strides;
175 if (!memRefType.hasStaticShape() ||
176 failed(getStridesAndOffset(memRefType, strides, offset)))
177 return std::nullopt;
178
179 // To get the size of the memref object in memory, the total size is the
180 // max(stride * dimension-size) computed for all dimensions times the size
181 // of the element.
182 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
183 if (!elementSize)
184 return std::nullopt;
185
186 if (memRefType.getRank() == 0)
187 return elementSize;
188
189 auto dims = memRefType.getShape();
190 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
191 ShapedType::isDynamic(offset) ||
192 llvm::is_contained(strides, ShapedType::kDynamic))
193 return std::nullopt;
194
195 int64_t memrefSize = -1;
196 for (const auto &shape : enumerate(dims))
197 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
198
199 return (offset + memrefSize) * *elementSize;
200 }
201
202 if (auto tensorType = dyn_cast<TensorType>(Val&: type)) {
203 if (!tensorType.hasStaticShape())
204 return std::nullopt;
205
206 auto elementSize = getTypeNumBytes(options, type: tensorType.getElementType());
207 if (!elementSize)
208 return std::nullopt;
209
210 int64_t size = *elementSize;
211 for (auto shape : tensorType.getShape())
212 size *= shape;
213
214 return size;
215 }
216
217 // TODO: Add size computation for other types.
218 return std::nullopt;
219}
220
221/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
222static Type
223convertScalarType(const spirv::TargetEnv &targetEnv,
224 const SPIRVConversionOptions &options, spirv::ScalarType type,
225 std::optional<spirv::StorageClass> storageClass = {}) {
226 // Get extension and capability requirements for the given type.
227 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
228 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
229 type.getExtensions(extensions, storageClass);
230 type.getCapabilities(capabilities, storageClass);
231
232 // If all requirements are met, then we can accept this type as-is.
233 if (succeeded(result: checkCapabilityRequirements(label: type, targetEnv, candidates: capabilities)) &&
234 succeeded(result: checkExtensionRequirements(label: type, targetEnv, candidates: extensions)))
235 return type;
236
237 // Otherwise we need to adjust the type, which really means adjusting the
238 // bitwidth given this is a scalar type.
239 if (!options.emulateLT32BitScalarTypes)
240 return nullptr;
241
242 // We only emulate narrower scalar types here and do not truncate results.
243 if (type.getIntOrFloatBitWidth() > 32) {
244 LLVM_DEBUG(llvm::dbgs()
245 << type
246 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
247 return nullptr;
248 }
249
250 if (auto floatType = dyn_cast<FloatType>(Val&: type)) {
251 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
252 return Builder(targetEnv.getContext()).getF32Type();
253 }
254
255 auto intType = cast<IntegerType>(type);
256 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
257 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
258 intType.getSignedness());
259}
260
261/// Converts a sub-byte integer `type` to i32 regardless of target environment.
262///
263/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
264/// the above given that these sub-byte types are not supported at all in
265/// SPIR-V; there are no compute/storage capability for them like other
266/// supported integer types.
267static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
268 IntegerType type) {
269 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
270 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
271 return nullptr;
272 }
273
274 if (!llvm::isPowerOf2_32(Value: type.getWidth())) {
275 LLVM_DEBUG(llvm::dbgs()
276 << "unsupported non-power-of-two bitwidth in sub-byte" << type
277 << "\n");
278 return nullptr;
279 }
280
281 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
282 return IntegerType::get(type.getContext(), /*width=*/32,
283 type.getSignedness());
284}
285
286/// Returns a type with the same shape but with any index element type converted
287/// to the matching integer type. This is a noop when the element type is not
288/// the index type.
289static ShapedType
290convertIndexElementType(ShapedType type,
291 const SPIRVConversionOptions &options) {
292 Type indexType = dyn_cast<IndexType>(type.getElementType());
293 if (!indexType)
294 return type;
295
296 return type.clone(getIndexType(type.getContext(), options));
297}
298
299/// Converts a vector `type` to a suitable type under the given `targetEnv`.
300static Type
301convertVectorType(const spirv::TargetEnv &targetEnv,
302 const SPIRVConversionOptions &options, VectorType type,
303 std::optional<spirv::StorageClass> storageClass = {}) {
304 type = cast<VectorType>(convertIndexElementType(type, options));
305 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
306 if (!scalarType) {
307 // If this is not a spec allowed scalar type, try to handle sub-byte integer
308 // types.
309 auto intType = dyn_cast<IntegerType>(type.getElementType());
310 if (!intType) {
311 LLVM_DEBUG(llvm::dbgs()
312 << type
313 << " illegal: cannot convert non-scalar element type\n");
314 return nullptr;
315 }
316
317 Type elementType = convertSubByteIntegerType(options, intType);
318 if (type.getRank() <= 1 && type.getNumElements() == 1)
319 return elementType;
320
321 if (type.getNumElements() > 4) {
322 LLVM_DEBUG(llvm::dbgs()
323 << type << " illegal: > 4-element unimplemented\n");
324 return nullptr;
325 }
326
327 return VectorType::get(type.getShape(), elementType);
328 }
329
330 if (type.getRank() <= 1 && type.getNumElements() == 1)
331 return convertScalarType(targetEnv, options, scalarType, storageClass);
332
333 if (!spirv::CompositeType::isValid(type)) {
334 LLVM_DEBUG(llvm::dbgs()
335 << type << " illegal: not a valid composite type\n");
336 return nullptr;
337 }
338
339 // Get extension and capability requirements for the given type.
340 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
341 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
342 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
343 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
344
345 // If all requirements are met, then we can accept this type as-is.
346 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
347 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
348 return type;
349
350 auto elementType =
351 convertScalarType(targetEnv, options, scalarType, storageClass);
352 if (elementType)
353 return VectorType::get(type.getShape(), elementType);
354 return nullptr;
355}
356
357static Type
358convertComplexType(const spirv::TargetEnv &targetEnv,
359 const SPIRVConversionOptions &options, ComplexType type,
360 std::optional<spirv::StorageClass> storageClass = {}) {
361 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
362 if (!scalarType) {
363 LLVM_DEBUG(llvm::dbgs()
364 << type << " illegal: cannot convert non-scalar element type\n");
365 return nullptr;
366 }
367
368 auto elementType =
369 convertScalarType(targetEnv, options, scalarType, storageClass);
370 if (!elementType)
371 return nullptr;
372 if (elementType != type.getElementType()) {
373 LLVM_DEBUG(llvm::dbgs()
374 << type << " illegal: complex type emulation unsupported\n");
375 return nullptr;
376 }
377
378 return VectorType::get(2, elementType);
379}
380
381/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
382///
383/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
384/// create composite constants with OpConstantComposite to embed relative large
385/// constant values and use OpCompositeExtract and OpCompositeInsert to
386/// manipulate, like what we do for vectors.
387static Type convertTensorType(const spirv::TargetEnv &targetEnv,
388 const SPIRVConversionOptions &options,
389 TensorType type) {
390 // TODO: Handle dynamic shapes.
391 if (!type.hasStaticShape()) {
392 LLVM_DEBUG(llvm::dbgs()
393 << type << " illegal: dynamic shape unimplemented\n");
394 return nullptr;
395 }
396
397 type = cast<TensorType>(convertIndexElementType(type, options));
398 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(Val: type.getElementType());
399 if (!scalarType) {
400 LLVM_DEBUG(llvm::dbgs()
401 << type << " illegal: cannot convert non-scalar element type\n");
402 return nullptr;
403 }
404
405 std::optional<int64_t> scalarSize = getTypeNumBytes(options, type: scalarType);
406 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
407 if (!scalarSize || !tensorSize) {
408 LLVM_DEBUG(llvm::dbgs()
409 << type << " illegal: cannot deduce element count\n");
410 return nullptr;
411 }
412
413 int64_t arrayElemCount = *tensorSize / *scalarSize;
414 if (arrayElemCount == 0) {
415 LLVM_DEBUG(llvm::dbgs()
416 << type << " illegal: cannot handle zero-element tensors\n");
417 return nullptr;
418 }
419
420 Type arrayElemType = convertScalarType(targetEnv, options, type: scalarType);
421 if (!arrayElemType)
422 return nullptr;
423 std::optional<int64_t> arrayElemSize =
424 getTypeNumBytes(options, type: arrayElemType);
425 if (!arrayElemSize) {
426 LLVM_DEBUG(llvm::dbgs()
427 << type << " illegal: cannot deduce converted element size\n");
428 return nullptr;
429 }
430
431 return spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount);
432}
433
434static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
435 const SPIRVConversionOptions &options,
436 MemRefType type,
437 spirv::StorageClass storageClass) {
438 unsigned numBoolBits = options.boolNumBits;
439 if (numBoolBits != 8) {
440 LLVM_DEBUG(llvm::dbgs()
441 << "using non-8-bit storage for bool types unimplemented");
442 return nullptr;
443 }
444 auto elementType = dyn_cast<spirv::ScalarType>(
445 IntegerType::get(type.getContext(), numBoolBits));
446 if (!elementType)
447 return nullptr;
448 Type arrayElemType =
449 convertScalarType(targetEnv, options, elementType, storageClass);
450 if (!arrayElemType)
451 return nullptr;
452 std::optional<int64_t> arrayElemSize =
453 getTypeNumBytes(options, type: arrayElemType);
454 if (!arrayElemSize) {
455 LLVM_DEBUG(llvm::dbgs()
456 << type << " illegal: cannot deduce converted element size\n");
457 return nullptr;
458 }
459
460 if (!type.hasStaticShape()) {
461 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
462 // to the element.
463 if (targetEnv.allows(spirv::Capability::Kernel))
464 return spirv::PointerType::get(arrayElemType, storageClass);
465 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
466 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
467 // For Vulkan we need extra wrapping struct and array to satisfy interface
468 // needs.
469 return wrapInStructAndGetPointer(arrayType, storageClass);
470 }
471
472 if (type.getNumElements() == 0) {
473 LLVM_DEBUG(llvm::dbgs()
474 << type << " illegal: zero-element memrefs are not supported\n");
475 return nullptr;
476 }
477
478 int64_t memrefSize = llvm::divideCeil(Numerator: type.getNumElements() * numBoolBits, Denominator: 8);
479 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: *arrayElemSize);
480 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
481 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
482 if (targetEnv.allows(spirv::Capability::Kernel))
483 return spirv::PointerType::get(arrayType, storageClass);
484 return wrapInStructAndGetPointer(arrayType, storageClass);
485}
486
487static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
488 const SPIRVConversionOptions &options,
489 MemRefType type,
490 spirv::StorageClass storageClass) {
491 IntegerType elementType = cast<IntegerType>(type.getElementType());
492 Type arrayElemType = convertSubByteIntegerType(options, elementType);
493 if (!arrayElemType)
494 return nullptr;
495 int64_t arrayElemSize = *getTypeNumBytes(options, type: arrayElemType);
496
497 if (!type.hasStaticShape()) {
498 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
499 // to the element.
500 if (targetEnv.allows(spirv::Capability::Kernel))
501 return spirv::PointerType::get(arrayElemType, storageClass);
502 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
503 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
504 // For Vulkan we need extra wrapping struct and array to satisfy interface
505 // needs.
506 return wrapInStructAndGetPointer(arrayType, storageClass);
507 }
508
509 if (type.getNumElements() == 0) {
510 LLVM_DEBUG(llvm::dbgs()
511 << type << " illegal: zero-element memrefs are not supported\n");
512 return nullptr;
513 }
514
515 int64_t memrefSize =
516 llvm::divideCeil(Numerator: type.getNumElements() * elementType.getWidth(), Denominator: 8);
517 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: arrayElemSize);
518 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
519 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
520 if (targetEnv.allows(spirv::Capability::Kernel))
521 return spirv::PointerType::get(arrayType, storageClass);
522 return wrapInStructAndGetPointer(arrayType, storageClass);
523}
524
525static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
526 const SPIRVConversionOptions &options,
527 MemRefType type) {
528 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
529 if (!attr) {
530 LLVM_DEBUG(
531 llvm::dbgs()
532 << type
533 << " illegal: expected memory space to be a SPIR-V storage class "
534 "attribute; please use MemorySpaceToStorageClassConverter to map "
535 "numeric memory spaces beforehand\n");
536 return nullptr;
537 }
538 spirv::StorageClass storageClass = attr.getValue();
539
540 if (isa<IntegerType>(type.getElementType())) {
541 if (type.getElementTypeBitWidth() == 1)
542 return convertBoolMemrefType(targetEnv, options, type, storageClass);
543 if (type.getElementTypeBitWidth() < 8)
544 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
545 }
546
547 Type arrayElemType;
548 Type elementType = type.getElementType();
549 if (auto vecType = dyn_cast<VectorType>(elementType)) {
550 arrayElemType =
551 convertVectorType(targetEnv, options, vecType, storageClass);
552 } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
553 arrayElemType =
554 convertComplexType(targetEnv, options, complexType, storageClass);
555 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
556 arrayElemType =
557 convertScalarType(targetEnv, options, scalarType, storageClass);
558 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
559 type = cast<MemRefType>(convertIndexElementType(type, options));
560 arrayElemType = type.getElementType();
561 } else {
562 LLVM_DEBUG(
563 llvm::dbgs()
564 << type
565 << " unhandled: can only convert scalar or vector element type\n");
566 return nullptr;
567 }
568 if (!arrayElemType)
569 return nullptr;
570
571 std::optional<int64_t> arrayElemSize =
572 getTypeNumBytes(options, type: arrayElemType);
573 if (!arrayElemSize) {
574 LLVM_DEBUG(llvm::dbgs()
575 << type << " illegal: cannot deduce converted element size\n");
576 return nullptr;
577 }
578
579 if (!type.hasStaticShape()) {
580 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
581 // to the element.
582 if (targetEnv.allows(spirv::Capability::Kernel))
583 return spirv::PointerType::get(arrayElemType, storageClass);
584 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
585 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
586 // For Vulkan we need extra wrapping struct and array to satisfy interface
587 // needs.
588 return wrapInStructAndGetPointer(arrayType, storageClass);
589 }
590
591 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
592 if (!memrefSize) {
593 LLVM_DEBUG(llvm::dbgs()
594 << type << " illegal: cannot deduce element count\n");
595 return nullptr;
596 }
597
598 if (*memrefSize == 0) {
599 LLVM_DEBUG(llvm::dbgs()
600 << type << " illegal: zero-element memrefs are not supported\n");
601 return nullptr;
602 }
603
604 int64_t arrayElemCount = llvm::divideCeil(Numerator: *memrefSize, Denominator: *arrayElemSize);
605 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
606 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
607 if (targetEnv.allows(spirv::Capability::Kernel))
608 return spirv::PointerType::get(arrayType, storageClass);
609 return wrapInStructAndGetPointer(arrayType, storageClass);
610}
611
612//===----------------------------------------------------------------------===//
613// Type casting materialization
614//===----------------------------------------------------------------------===//
615
616/// Converts the given `inputs` to the original source `type` considering the
617/// `targetEnv`'s capabilities.
618///
619/// This function is meant to be used for source materialization in type
620/// converters. When the type converter needs to materialize a cast op back
621/// to some original source type, we need to check whether the original source
622/// type is supported in the target environment. If so, we can insert legal
623/// SPIR-V cast ops accordingly.
624///
625/// Note that in SPIR-V the capabilities for storage and compute are separate.
626/// This function is meant to handle the **compute** side; so it does not
627/// involve storage classes in its logic. The storage side is expected to be
628/// handled by MemRef conversion logic.
629std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
630 OpBuilder &builder, Type type,
631 ValueRange inputs, Location loc) {
632 // We can only cast one value in SPIR-V.
633 if (inputs.size() != 1) {
634 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
635 return castOp.getResult(0);
636 }
637 Value input = inputs.front();
638
639 // Only support integer types for now. Floating point types to be implemented.
640 if (!isa<IntegerType>(Val: type)) {
641 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
642 return castOp.getResult(0);
643 }
644 auto inputType = cast<IntegerType>(input.getType());
645
646 auto scalarType = dyn_cast<spirv::ScalarType>(Val&: type);
647 if (!scalarType) {
648 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
649 return castOp.getResult(0);
650 }
651
652 // Only support source type with a smaller bitwidth. This would mean we are
653 // truncating to go back so we don't need to worry about the signedness.
654 // For extension, we cannot have enough signal here to decide which op to use.
655 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
656 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
657 return castOp.getResult(0);
658 }
659
660 // Boolean values would need to use different ops than normal integer values.
661 if (type.isInteger(width: 1)) {
662 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
663 return builder.create<spirv::IEqualOp>(loc, input, one);
664 }
665
666 // Check that the source integer type is supported by the environment.
667 SmallVector<ArrayRef<spirv::Extension>, 1> exts;
668 SmallVector<ArrayRef<spirv::Capability>, 2> caps;
669 scalarType.getExtensions(exts);
670 scalarType.getCapabilities(caps);
671 if (failed(result: checkCapabilityRequirements(label: type, targetEnv, candidates: caps)) ||
672 failed(result: checkExtensionRequirements(label: type, targetEnv, candidates: exts))) {
673 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
674 return castOp.getResult(0);
675 }
676
677 // We've already made sure this is truncating previously, so we don't need to
678 // care about signedness here. Still try to use a corresponding op for better
679 // consistency though.
680 if (type.isSignedInteger()) {
681 return builder.create<spirv::SConvertOp>(loc, type, input);
682 }
683 return builder.create<spirv::UConvertOp>(loc, type, input);
684}
685
686//===----------------------------------------------------------------------===//
687// SPIRVTypeConverter
688//===----------------------------------------------------------------------===//
689
690SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
691 const SPIRVConversionOptions &options)
692 : targetEnv(targetAttr), options(options) {
693 // Add conversions. The order matters here: later ones will be tried earlier.
694
695 // Allow all SPIR-V dialect specific types. This assumes all builtin types
696 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
697 // were tried before.
698 //
699 // TODO: This assumes that the SPIR-V types are valid to use in the given
700 // target environment, which should be the case if the whole pipeline is
701 // driven by the same target environment. Still, we probably still want to
702 // validate and convert to be safe.
703 addConversion(callback: [](spirv::SPIRVType type) { return type; });
704
705 addConversion(callback: [this](IndexType /*indexType*/) { return getIndexType(); });
706
707 addConversion(callback: [this](IntegerType intType) -> std::optional<Type> {
708 if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
709 return convertScalarType(this->targetEnv, this->options, scalarType);
710 if (intType.getWidth() < 8)
711 return convertSubByteIntegerType(this->options, intType);
712 return Type();
713 });
714
715 addConversion(callback: [this](FloatType floatType) -> std::optional<Type> {
716 if (auto scalarType = dyn_cast<spirv::ScalarType>(Val&: floatType))
717 return convertScalarType(targetEnv: this->targetEnv, options: this->options, type: scalarType);
718 return Type();
719 });
720
721 addConversion(callback: [this](ComplexType complexType) {
722 return convertComplexType(this->targetEnv, this->options, complexType);
723 });
724
725 addConversion([this](VectorType vectorType) {
726 return convertVectorType(this->targetEnv, this->options, vectorType);
727 });
728
729 addConversion(callback: [this](TensorType tensorType) {
730 return convertTensorType(targetEnv: this->targetEnv, options: this->options, type: tensorType);
731 });
732
733 addConversion([this](MemRefType memRefType) {
734 return convertMemrefType(this->targetEnv, this->options, memRefType);
735 });
736
737 // Register some last line of defense casting logic.
738 addSourceMaterialization(
739 callback: [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
740 return castToSourceType(targetEnv: this->targetEnv, builder, type, inputs, loc);
741 });
742 addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
743 Location loc) {
744 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
745 return std::optional<Value>(cast.getResult(0));
746 });
747}
748
749//===----------------------------------------------------------------------===//
750// func::FuncOp Conversion Patterns
751//===----------------------------------------------------------------------===//
752
753namespace {
754/// A pattern for rewriting function signature to convert arguments of functions
755/// to be of valid SPIR-V types.
756class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
757public:
758 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
759
760 LogicalResult
761 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
762 ConversionPatternRewriter &rewriter) const override;
763};
764} // namespace
765
766LogicalResult
767FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
768 ConversionPatternRewriter &rewriter) const {
769 auto fnType = funcOp.getFunctionType();
770 if (fnType.getNumResults() > 1)
771 return failure();
772
773 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
774 for (const auto &argType : enumerate(fnType.getInputs())) {
775 auto convertedType = getTypeConverter()->convertType(argType.value());
776 if (!convertedType)
777 return failure();
778 signatureConverter.addInputs(argType.index(), convertedType);
779 }
780
781 Type resultType;
782 if (fnType.getNumResults() == 1) {
783 resultType = getTypeConverter()->convertType(fnType.getResult(0));
784 if (!resultType)
785 return failure();
786 }
787
788 // Create the converted spirv.func op.
789 auto newFuncOp = rewriter.create<spirv::FuncOp>(
790 funcOp.getLoc(), funcOp.getName(),
791 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
792 resultType ? TypeRange(resultType)
793 : TypeRange()));
794
795 // Copy over all attributes other than the function name and type.
796 for (const auto &namedAttr : funcOp->getAttrs()) {
797 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
798 namedAttr.getName() != SymbolTable::getSymbolAttrName())
799 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
800 }
801
802 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
803 newFuncOp.end());
804 if (failed(rewriter.convertRegionTypes(
805 region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter)))
806 return failure();
807 rewriter.eraseOp(op: funcOp);
808 return success();
809}
810
811void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
812 RewritePatternSet &patterns) {
813 patterns.add<FuncOpConversion>(arg&: typeConverter, args: patterns.getContext());
814}
815
816//===----------------------------------------------------------------------===//
817// Builtin Variables
818//===----------------------------------------------------------------------===//
819
820static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
821 spirv::BuiltIn builtin) {
822 // Look through all global variables in the given `body` block and check if
823 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
824 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
825 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
826 spirv::SPIRVDialect::getAttributeName(
827 spirv::Decoration::BuiltIn))) {
828 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
829 if (varBuiltIn && *varBuiltIn == builtin) {
830 return varOp;
831 }
832 }
833 }
834 return nullptr;
835}
836
837/// Gets name of global variable for a builtin.
838static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
839 StringRef suffix) {
840 return Twine(prefix).concat(Suffix: stringifyBuiltIn(builtin)).concat(suffix).str();
841}
842
843/// Gets or inserts a global variable for a builtin within `body` block.
844static spirv::GlobalVariableOp
845getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
846 Type integerType, OpBuilder &builder,
847 StringRef prefix, StringRef suffix) {
848 if (auto varOp = getBuiltinVariable(body, builtin))
849 return varOp;
850
851 OpBuilder::InsertionGuard guard(builder);
852 builder.setInsertionPointToStart(&body);
853
854 spirv::GlobalVariableOp newVarOp;
855 switch (builtin) {
856 case spirv::BuiltIn::NumWorkgroups:
857 case spirv::BuiltIn::WorkgroupSize:
858 case spirv::BuiltIn::WorkgroupId:
859 case spirv::BuiltIn::LocalInvocationId:
860 case spirv::BuiltIn::GlobalInvocationId: {
861 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
862 spirv::StorageClass::Input);
863 std::string name = getBuiltinVarName(builtin, prefix, suffix);
864 newVarOp =
865 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
866 break;
867 }
868 case spirv::BuiltIn::SubgroupId:
869 case spirv::BuiltIn::NumSubgroups:
870 case spirv::BuiltIn::SubgroupSize: {
871 auto ptrType =
872 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
873 std::string name = getBuiltinVarName(builtin, prefix, suffix);
874 newVarOp =
875 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
876 break;
877 }
878 default:
879 emitError(loc, message: "unimplemented builtin variable generation for ")
880 << stringifyBuiltIn(builtin);
881 }
882 return newVarOp;
883}
884
885Value mlir::spirv::getBuiltinVariableValue(Operation *op,
886 spirv::BuiltIn builtin,
887 Type integerType, OpBuilder &builder,
888 StringRef prefix, StringRef suffix) {
889 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
890 if (!parent) {
891 op->emitError(message: "expected operation to be within a module-like op");
892 return nullptr;
893 }
894
895 spirv::GlobalVariableOp varOp =
896 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
897 builtin, integerType, builder, prefix, suffix);
898 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
899 return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
900}
901
902//===----------------------------------------------------------------------===//
903// Push constant storage
904//===----------------------------------------------------------------------===//
905
906/// Returns the pointer type for the push constant storage containing
907/// `elementCount` 32-bit integer values.
908static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
909 Builder &builder,
910 Type indexType) {
911 auto arrayType = spirv::ArrayType::get(elementType: indexType, elementCount,
912 /*stride=*/4);
913 auto structType = spirv::StructType::get(memberTypes: {arrayType}, /*offsetInfo=*/0);
914 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
915}
916
917/// Returns the push constant varible containing `elementCount` 32-bit integer
918/// values in `body`. Returns null op if such an op does not exit.
919static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
920 unsigned elementCount) {
921 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
922 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
923 if (!ptrType)
924 continue;
925
926 // Note that Vulkan requires "There must be no more than one push constant
927 // block statically used per shader entry point." So we should always reuse
928 // the existing one.
929 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
930 auto numElements = cast<spirv::ArrayType>(
931 cast<spirv::StructType>(ptrType.getPointeeType())
932 .getElementType(0))
933 .getNumElements();
934 if (numElements == elementCount)
935 return varOp;
936 }
937 }
938 return nullptr;
939}
940
941/// Gets or inserts a global variable for push constant storage containing
942/// `elementCount` 32-bit integer values in `block`.
943static spirv::GlobalVariableOp
944getOrInsertPushConstantVariable(Location loc, Block &block,
945 unsigned elementCount, OpBuilder &b,
946 Type indexType) {
947 if (auto varOp = getPushConstantVariable(block, elementCount))
948 return varOp;
949
950 auto builder = OpBuilder::atBlockBegin(block: &block, listener: b.getListener());
951 auto type = getPushConstantStorageType(elementCount, builder, indexType);
952 const char *name = "__push_constant_var__";
953 return builder.create<spirv::GlobalVariableOp>(loc, type, name,
954 /*initializer=*/nullptr);
955}
956
957Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
958 unsigned offset, Type integerType,
959 OpBuilder &builder) {
960 Location loc = op->getLoc();
961 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
962 if (!parent) {
963 op->emitError(message: "expected operation to be within a module-like op");
964 return nullptr;
965 }
966
967 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
968 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
969
970 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
971 Value offsetOp = builder.create<spirv::ConstantOp>(
972 loc, integerType, builder.getI32IntegerAttr(offset));
973 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
974 auto acOp = builder.create<spirv::AccessChainOp>(
975 loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
976 return builder.create<spirv::LoadOp>(loc, acOp);
977}
978
979//===----------------------------------------------------------------------===//
980// Index calculation
981//===----------------------------------------------------------------------===//
982
983Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
984 int64_t offset, Type integerType,
985 Location loc, OpBuilder &builder) {
986 assert(indices.size() == strides.size() &&
987 "must provide indices for all dimensions");
988
989 // TODO: Consider moving to use affine.apply and patterns converting
990 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
991 // broken down into progressive small steps so we can have intermediate steps
992 // using other dialects. At the moment SPIR-V is the final sink.
993
994 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
995 loc, integerType, IntegerAttr::get(integerType, offset));
996 for (const auto &index : llvm::enumerate(First&: indices)) {
997 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
998 loc, integerType,
999 IntegerAttr::get(integerType, strides[index.index()]));
1000 Value update =
1001 builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1002 linearizedIndex =
1003 builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1004 }
1005 return linearizedIndex;
1006}
1007
1008Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1009 MemRefType baseType, Value basePtr,
1010 ValueRange indices, Location loc,
1011 OpBuilder &builder) {
1012 // Get base and offset of the MemRefType and verify they are static.
1013
1014 int64_t offset;
1015 SmallVector<int64_t, 4> strides;
1016 if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1017 llvm::is_contained(strides, ShapedType::kDynamic) ||
1018 ShapedType::isDynamic(offset)) {
1019 return nullptr;
1020 }
1021
1022 auto indexType = typeConverter.getIndexType();
1023
1024 SmallVector<Value, 2> linearizedIndices;
1025 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1026
1027 // Add a '0' at the start to index into the struct.
1028 linearizedIndices.push_back(Elt: zero);
1029
1030 if (baseType.getRank() == 0) {
1031 linearizedIndices.push_back(Elt: zero);
1032 } else {
1033 linearizedIndices.push_back(
1034 Elt: linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder));
1035 }
1036 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1037}
1038
1039Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1040 MemRefType baseType, Value basePtr,
1041 ValueRange indices, Location loc,
1042 OpBuilder &builder) {
1043 // Get base and offset of the MemRefType and verify they are static.
1044
1045 int64_t offset;
1046 SmallVector<int64_t, 4> strides;
1047 if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1048 llvm::is_contained(strides, ShapedType::kDynamic) ||
1049 ShapedType::isDynamic(offset)) {
1050 return nullptr;
1051 }
1052
1053 auto indexType = typeConverter.getIndexType();
1054
1055 SmallVector<Value, 2> linearizedIndices;
1056 Value linearIndex;
1057 if (baseType.getRank() == 0) {
1058 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1059 } else {
1060 linearIndex =
1061 linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder);
1062 }
1063 Type pointeeType =
1064 cast<spirv::PointerType>(Val: basePtr.getType()).getPointeeType();
1065 if (isa<spirv::ArrayType>(Val: pointeeType)) {
1066 linearizedIndices.push_back(Elt: linearIndex);
1067 return builder.create<spirv::AccessChainOp>(loc, basePtr,
1068 linearizedIndices);
1069 }
1070 return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1071 linearizedIndices);
1072}
1073
1074Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1075 MemRefType baseType, Value basePtr,
1076 ValueRange indices, Location loc,
1077 OpBuilder &builder) {
1078
1079 if (typeConverter.allows(spirv::Capability::Kernel)) {
1080 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1081 builder);
1082 }
1083
1084 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1085 builder);
1086}
1087
1088//===----------------------------------------------------------------------===//
1089// SPIR-V ConversionTarget
1090//===----------------------------------------------------------------------===//
1091
1092std::unique_ptr<SPIRVConversionTarget>
1093SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
1094 std::unique_ptr<SPIRVConversionTarget> target(
1095 // std::make_unique does not work here because the constructor is private.
1096 new SPIRVConversionTarget(targetAttr));
1097 SPIRVConversionTarget *targetPtr = target.get();
1098 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1099 // We need to capture the raw pointer here because it is stable:
1100 // target will be destroyed once this function is returned.
1101 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1102 return target;
1103}
1104
1105SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1106 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1107
1108bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1109 // Make sure this op is available at the given version. Ops not implementing
1110 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1111 // SPIR-V versions.
1112 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1113 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1114 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1115 LLVM_DEBUG(llvm::dbgs()
1116 << op->getName() << " illegal: requiring min version "
1117 << spirv::stringifyVersion(*minVersion) << "\n");
1118 return false;
1119 }
1120 }
1121 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1122 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1123 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1124 LLVM_DEBUG(llvm::dbgs()
1125 << op->getName() << " illegal: requiring max version "
1126 << spirv::stringifyVersion(*maxVersion) << "\n");
1127 return false;
1128 }
1129 }
1130
1131 // Make sure this op's required extensions are allowed to use. Ops not
1132 // implementing QueryExtensionInterface do not require extensions to be
1133 // available.
1134 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1135 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1136 extensions.getExtensions())))
1137 return false;
1138
1139 // Make sure this op's required extensions are allowed to use. Ops not
1140 // implementing QueryCapabilityInterface do not require capabilities to be
1141 // available.
1142 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1143 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1144 capabilities.getCapabilities())))
1145 return false;
1146
1147 SmallVector<Type, 4> valueTypes;
1148 valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end());
1149 valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end());
1150
1151 // Ensure that all types have been converted to SPIRV types.
1152 if (llvm::any_of(Range&: valueTypes,
1153 P: [](Type t) { return !isa<spirv::SPIRVType>(Val: t); }))
1154 return false;
1155
1156 // Special treatment for global variables, whose type requirements are
1157 // conveyed by type attributes.
1158 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1159 valueTypes.push_back(Elt: globalVar.getType());
1160
1161 // Make sure the op's operands/results use types that are allowed by the
1162 // target environment.
1163 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1164 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1165 for (Type valueType : valueTypes) {
1166 typeExtensions.clear();
1167 cast<spirv::SPIRVType>(Val&: valueType).getExtensions(typeExtensions);
1168 if (failed(result: checkExtensionRequirements(label: op->getName(), targetEnv: this->targetEnv,
1169 candidates: typeExtensions)))
1170 return false;
1171
1172 typeCapabilities.clear();
1173 cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(typeCapabilities);
1174 if (failed(result: checkCapabilityRequirements(label: op->getName(), targetEnv: this->targetEnv,
1175 candidates: typeCapabilities)))
1176 return false;
1177 }
1178
1179 return true;
1180}
1181

source code of mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp