1//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities used to lower to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Operation.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Support/LLVM.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
36
37#include <optional>
38
39#define DEBUG_TYPE "mlir-spirv-conversion"
40
41using namespace mlir;
42
43namespace {
44
45//===----------------------------------------------------------------------===//
46// Utility functions
47//===----------------------------------------------------------------------===//
48
49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 << "--scalable vectors are not supported -> BAIL\n");
54 return std::nullopt;
55 }
56 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(Range: vecType.getShape());
57 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58 1, mlir::spirv::getComputeVectorSize(size: vecType.getShape().back()));
59 if (!targetShape) {
60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61 return std::nullopt;
62 }
63 auto maybeShapeRatio = computeShapeRatio(shape: unrollShape, subShape: *targetShape);
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 << "--could not compute integral shape ratio -> BAIL\n");
67 return std::nullopt;
68 }
69 if (llvm::all_of(Range&: *maybeShapeRatio, P: [](int64_t v) { return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71 return std::nullopt;
72 }
73 LLVM_DEBUG(llvm::dbgs()
74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75 return targetShape;
76}
77
78/// Checks that `candidates` extension requirements are possible to be satisfied
79/// with the given `targetEnv`.
80///
81/// `candidates` is a vector of vector for extension requirements following
82/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83/// convention.
84template <typename LabelT>
85static LogicalResult checkExtensionRequirements(
86 LabelT label, const spirv::TargetEnv &targetEnv,
87 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
88 for (const auto &ors : candidates) {
89 if (targetEnv.allows(ors))
90 continue;
91
92 LLVM_DEBUG({
93 SmallVector<StringRef> extStrings;
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
96
97 llvm::dbgs() << label << " illegal: requires at least one extension in ["
98 << llvm::join(extStrings, ", ")
99 << "] but none allowed in target environment\n";
100 });
101 return failure();
102 }
103 return success();
104}
105
106/// Checks that `candidates`capability requirements are possible to be satisfied
107/// with the given `isAllowedFn`.
108///
109/// `candidates` is a vector of vector for capability requirements following
110/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111/// convention.
112template <typename LabelT>
113static LogicalResult checkCapabilityRequirements(
114 LabelT label, const spirv::TargetEnv &targetEnv,
115 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
116 for (const auto &ors : candidates) {
117 if (targetEnv.allows(ors))
118 continue;
119
120 LLVM_DEBUG({
121 SmallVector<StringRef> capStrings;
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
124
125 llvm::dbgs() << label << " illegal: requires at least one capability in ["
126 << llvm::join(capStrings, ", ")
127 << "] but none allowed in target environment\n";
128 });
129 return failure();
130 }
131 return success();
132}
133
134/// Returns true if the given `storageClass` needs explicit layout when used in
135/// Shader environments.
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
142 return true;
143 default:
144 return false;
145 }
146}
147
148/// Wraps the given `elementType` in a struct and gets the pointer to the
149/// struct. This is used to satisfy Vulkan interface requirements.
150static spirv::PointerType
151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
153 ? spirv::StructType::get(memberTypes: elementType, /*offsetInfo=*/0)
154 : spirv::StructType::get(memberTypes: elementType);
155 return spirv::PointerType::get(pointeeType: structType, storageClass);
156}
157
158//===----------------------------------------------------------------------===//
159// Type Conversion
160//===----------------------------------------------------------------------===//
161
162static spirv::ScalarType getIndexType(MLIRContext *ctx,
163 const SPIRVConversionOptions &options) {
164 return cast<spirv::ScalarType>(
165 Val: IntegerType::get(context: ctx, width: options.use64bitIndex ? 64 : 32));
166}
167
168// TODO: This is a utility function that should probably be exposed by the
169// SPIR-V dialect. Keeping it local till the use case arises.
170static std::optional<int64_t>
171getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172 if (isa<spirv::ScalarType>(Val: type)) {
173 auto bitWidth = type.getIntOrFloatBitWidth();
174 // According to the SPIR-V spec:
175 // "There is no physical size or bit pattern defined for values with boolean
176 // type. If they are stored (in conjunction with OpVariable), they can only
177 // be used with logical addressing operations, not physical, and only with
178 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179 // Private, Function, Input, and Output."
180 if (bitWidth == 1)
181 return std::nullopt;
182 return bitWidth / 8;
183 }
184
185 if (auto complexType = dyn_cast<ComplexType>(Val&: type)) {
186 auto elementSize = getTypeNumBytes(options, type: complexType.getElementType());
187 if (!elementSize)
188 return std::nullopt;
189 return 2 * *elementSize;
190 }
191
192 if (auto vecType = dyn_cast<VectorType>(Val&: type)) {
193 auto elementSize = getTypeNumBytes(options, type: vecType.getElementType());
194 if (!elementSize)
195 return std::nullopt;
196 return vecType.getNumElements() * *elementSize;
197 }
198
199 if (auto memRefType = dyn_cast<MemRefType>(Val&: type)) {
200 // TODO: Layout should also be controlled by the ABI attributes. For now
201 // using the layout from MemRef.
202 int64_t offset;
203 SmallVector<int64_t, 4> strides;
204 if (!memRefType.hasStaticShape() ||
205 failed(Result: memRefType.getStridesAndOffset(strides, offset)))
206 return std::nullopt;
207
208 // To get the size of the memref object in memory, the total size is the
209 // max(stride * dimension-size) computed for all dimensions times the size
210 // of the element.
211 auto elementSize = getTypeNumBytes(options, type: memRefType.getElementType());
212 if (!elementSize)
213 return std::nullopt;
214
215 if (memRefType.getRank() == 0)
216 return elementSize;
217
218 auto dims = memRefType.getShape();
219 if (llvm::is_contained(Range&: dims, Element: ShapedType::kDynamic) ||
220 ShapedType::isDynamic(dValue: offset) ||
221 llvm::is_contained(Range&: strides, Element: ShapedType::kDynamic))
222 return std::nullopt;
223
224 int64_t memrefSize = -1;
225 for (const auto &shape : enumerate(First&: dims))
226 memrefSize = std::max(a: memrefSize, b: shape.value() * strides[shape.index()]);
227
228 return (offset + memrefSize) * *elementSize;
229 }
230
231 if (auto tensorType = dyn_cast<TensorType>(Val&: type)) {
232 if (!tensorType.hasStaticShape())
233 return std::nullopt;
234
235 auto elementSize = getTypeNumBytes(options, type: tensorType.getElementType());
236 if (!elementSize)
237 return std::nullopt;
238
239 int64_t size = *elementSize;
240 for (auto shape : tensorType.getShape())
241 size *= shape;
242
243 return size;
244 }
245
246 // TODO: Add size computation for other types.
247 return std::nullopt;
248}
249
250/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
251static Type
252convertScalarType(const spirv::TargetEnv &targetEnv,
253 const SPIRVConversionOptions &options, spirv::ScalarType type,
254 std::optional<spirv::StorageClass> storageClass = {}) {
255 // Get extension and capability requirements for the given type.
256 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
257 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
258 type.getExtensions(extensions, storage: storageClass);
259 type.getCapabilities(capabilities, storage: storageClass);
260
261 // If all requirements are met, then we can accept this type as-is.
262 if (succeeded(Result: checkCapabilityRequirements(label: type, targetEnv, candidates: capabilities)) &&
263 succeeded(Result: checkExtensionRequirements(label: type, targetEnv, candidates: extensions)))
264 return type;
265
266 // Otherwise we need to adjust the type, which really means adjusting the
267 // bitwidth given this is a scalar type.
268 if (!options.emulateLT32BitScalarTypes)
269 return nullptr;
270
271 // We only emulate narrower scalar types here and do not truncate results.
272 if (type.getIntOrFloatBitWidth() > 32) {
273 LLVM_DEBUG(llvm::dbgs()
274 << type
275 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
276 return nullptr;
277 }
278
279 if (auto floatType = dyn_cast<FloatType>(Val&: type)) {
280 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
281 return Builder(targetEnv.getContext()).getF32Type();
282 }
283
284 auto intType = cast<IntegerType>(Val&: type);
285 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
286 return IntegerType::get(context: targetEnv.getContext(), /*width=*/32,
287 signedness: intType.getSignedness());
288}
289
290/// Converts a sub-byte integer `type` to i32 regardless of target environment.
291/// Returns a nullptr for unsupported integer types, including non sub-byte
292/// types.
293///
294/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
295/// the above given that these sub-byte types are not supported at all in
296/// SPIR-V; there are no compute/storage capability for them like other
297/// supported integer types.
298static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
299 IntegerType type) {
300 if (type.getWidth() > 8) {
301 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
302 return nullptr;
303 }
304 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
305 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
306 return nullptr;
307 }
308
309 if (!llvm::isPowerOf2_32(Value: type.getWidth())) {
310 LLVM_DEBUG(llvm::dbgs()
311 << "unsupported non-power-of-two bitwidth in sub-byte" << type
312 << "\n");
313 return nullptr;
314 }
315
316 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
317 return IntegerType::get(context: type.getContext(), /*width=*/32,
318 signedness: type.getSignedness());
319}
320
321/// Returns a type with the same shape but with any index element type converted
322/// to the matching integer type. This is a noop when the element type is not
323/// the index type.
324static ShapedType
325convertIndexElementType(ShapedType type,
326 const SPIRVConversionOptions &options) {
327 Type indexType = dyn_cast<IndexType>(Val: type.getElementType());
328 if (!indexType)
329 return type;
330
331 return type.clone(elementType: getIndexType(ctx: type.getContext(), options));
332}
333
334/// Converts a vector `type` to a suitable type under the given `targetEnv`.
335static Type
336convertVectorType(const spirv::TargetEnv &targetEnv,
337 const SPIRVConversionOptions &options, VectorType type,
338 std::optional<spirv::StorageClass> storageClass = {}) {
339 type = cast<VectorType>(Val: convertIndexElementType(type, options));
340 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(Val: type.getElementType());
341 if (!scalarType) {
342 // If this is not a spec allowed scalar type, try to handle sub-byte integer
343 // types.
344 auto intType = dyn_cast<IntegerType>(Val: type.getElementType());
345 if (!intType) {
346 LLVM_DEBUG(llvm::dbgs()
347 << type
348 << " illegal: cannot convert non-scalar element type\n");
349 return nullptr;
350 }
351
352 Type elementType = convertSubByteIntegerType(options, type: intType);
353 if (!elementType)
354 return nullptr;
355
356 if (type.getRank() <= 1 && type.getNumElements() == 1)
357 return elementType;
358
359 if (type.getNumElements() > 4) {
360 LLVM_DEBUG(llvm::dbgs()
361 << type << " illegal: > 4-element unimplemented\n");
362 return nullptr;
363 }
364
365 return VectorType::get(shape: type.getShape(), elementType);
366 }
367
368 if (type.getRank() <= 1 && type.getNumElements() == 1)
369 return convertScalarType(targetEnv, options, type: scalarType, storageClass);
370
371 if (!spirv::CompositeType::isValid(type)) {
372 LLVM_DEBUG(llvm::dbgs()
373 << type << " illegal: not a valid composite type\n");
374 return nullptr;
375 }
376
377 // Get extension and capability requirements for the given type.
378 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
379 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
380 cast<spirv::CompositeType>(Val&: type).getExtensions(extensions, storage: storageClass);
381 cast<spirv::CompositeType>(Val&: type).getCapabilities(capabilities, storage: storageClass);
382
383 // If all requirements are met, then we can accept this type as-is.
384 if (succeeded(Result: checkCapabilityRequirements(label: type, targetEnv, candidates: capabilities)) &&
385 succeeded(Result: checkExtensionRequirements(label: type, targetEnv, candidates: extensions)))
386 return type;
387
388 auto elementType =
389 convertScalarType(targetEnv, options, type: scalarType, storageClass);
390 if (elementType)
391 return VectorType::get(shape: type.getShape(), elementType);
392 return nullptr;
393}
394
395static Type
396convertComplexType(const spirv::TargetEnv &targetEnv,
397 const SPIRVConversionOptions &options, ComplexType type,
398 std::optional<spirv::StorageClass> storageClass = {}) {
399 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(Val: type.getElementType());
400 if (!scalarType) {
401 LLVM_DEBUG(llvm::dbgs()
402 << type << " illegal: cannot convert non-scalar element type\n");
403 return nullptr;
404 }
405
406 auto elementType =
407 convertScalarType(targetEnv, options, type: scalarType, storageClass);
408 if (!elementType)
409 return nullptr;
410 if (elementType != type.getElementType()) {
411 LLVM_DEBUG(llvm::dbgs()
412 << type << " illegal: complex type emulation unsupported\n");
413 return nullptr;
414 }
415
416 return VectorType::get(shape: 2, elementType);
417}
418
419/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
420///
421/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
422/// create composite constants with OpConstantComposite to embed relative large
423/// constant values and use OpCompositeExtract and OpCompositeInsert to
424/// manipulate, like what we do for vectors.
425static Type convertTensorType(const spirv::TargetEnv &targetEnv,
426 const SPIRVConversionOptions &options,
427 TensorType type) {
428 // TODO: Handle dynamic shapes.
429 if (!type.hasStaticShape()) {
430 LLVM_DEBUG(llvm::dbgs()
431 << type << " illegal: dynamic shape unimplemented\n");
432 return nullptr;
433 }
434
435 type = cast<TensorType>(Val: convertIndexElementType(type, options));
436 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(Val: type.getElementType());
437 if (!scalarType) {
438 LLVM_DEBUG(llvm::dbgs()
439 << type << " illegal: cannot convert non-scalar element type\n");
440 return nullptr;
441 }
442
443 std::optional<int64_t> scalarSize = getTypeNumBytes(options, type: scalarType);
444 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
445 if (!scalarSize || !tensorSize) {
446 LLVM_DEBUG(llvm::dbgs()
447 << type << " illegal: cannot deduce element count\n");
448 return nullptr;
449 }
450
451 int64_t arrayElemCount = *tensorSize / *scalarSize;
452 if (arrayElemCount == 0) {
453 LLVM_DEBUG(llvm::dbgs()
454 << type << " illegal: cannot handle zero-element tensors\n");
455 return nullptr;
456 }
457
458 Type arrayElemType = convertScalarType(targetEnv, options, type: scalarType);
459 if (!arrayElemType)
460 return nullptr;
461 std::optional<int64_t> arrayElemSize =
462 getTypeNumBytes(options, type: arrayElemType);
463 if (!arrayElemSize) {
464 LLVM_DEBUG(llvm::dbgs()
465 << type << " illegal: cannot deduce converted element size\n");
466 return nullptr;
467 }
468
469 return spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount);
470}
471
472static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
473 const SPIRVConversionOptions &options,
474 MemRefType type,
475 spirv::StorageClass storageClass) {
476 unsigned numBoolBits = options.boolNumBits;
477 if (numBoolBits != 8) {
478 LLVM_DEBUG(llvm::dbgs()
479 << "using non-8-bit storage for bool types unimplemented");
480 return nullptr;
481 }
482 auto elementType = dyn_cast<spirv::ScalarType>(
483 Val: IntegerType::get(context: type.getContext(), width: numBoolBits));
484 if (!elementType)
485 return nullptr;
486 Type arrayElemType =
487 convertScalarType(targetEnv, options, type: elementType, storageClass);
488 if (!arrayElemType)
489 return nullptr;
490 std::optional<int64_t> arrayElemSize =
491 getTypeNumBytes(options, type: arrayElemType);
492 if (!arrayElemSize) {
493 LLVM_DEBUG(llvm::dbgs()
494 << type << " illegal: cannot deduce converted element size\n");
495 return nullptr;
496 }
497
498 if (!type.hasStaticShape()) {
499 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
500 // to the element.
501 if (targetEnv.allows(spirv::Capability::Kernel))
502 return spirv::PointerType::get(pointeeType: arrayElemType, storageClass);
503 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
504 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
505 // For Vulkan we need extra wrapping struct and array to satisfy interface
506 // needs.
507 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
508 }
509
510 if (type.getNumElements() == 0) {
511 LLVM_DEBUG(llvm::dbgs()
512 << type << " illegal: zero-element memrefs are not supported\n");
513 return nullptr;
514 }
515
516 int64_t memrefSize = llvm::divideCeil(Numerator: type.getNumElements() * numBoolBits, Denominator: 8);
517 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: *arrayElemSize);
518 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
519 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
520 if (targetEnv.allows(spirv::Capability::Kernel))
521 return spirv::PointerType::get(pointeeType: arrayType, storageClass);
522 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
523}
524
525static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
526 const SPIRVConversionOptions &options,
527 MemRefType type,
528 spirv::StorageClass storageClass) {
529 IntegerType elementType = cast<IntegerType>(Val: type.getElementType());
530 Type arrayElemType = convertSubByteIntegerType(options, type: elementType);
531 if (!arrayElemType)
532 return nullptr;
533 int64_t arrayElemSize = *getTypeNumBytes(options, type: arrayElemType);
534
535 if (!type.hasStaticShape()) {
536 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
537 // to the element.
538 if (targetEnv.allows(spirv::Capability::Kernel))
539 return spirv::PointerType::get(pointeeType: arrayElemType, storageClass);
540 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
541 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
542 // For Vulkan we need extra wrapping struct and array to satisfy interface
543 // needs.
544 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
545 }
546
547 if (type.getNumElements() == 0) {
548 LLVM_DEBUG(llvm::dbgs()
549 << type << " illegal: zero-element memrefs are not supported\n");
550 return nullptr;
551 }
552
553 int64_t memrefSize =
554 llvm::divideCeil(Numerator: type.getNumElements() * elementType.getWidth(), Denominator: 8);
555 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: arrayElemSize);
556 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
557 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
558 if (targetEnv.allows(spirv::Capability::Kernel))
559 return spirv::PointerType::get(pointeeType: arrayType, storageClass);
560 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
561}
562
563static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
564 const SPIRVConversionOptions &options,
565 MemRefType type) {
566 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(Val: type.getMemorySpace());
567 if (!attr) {
568 LLVM_DEBUG(
569 llvm::dbgs()
570 << type
571 << " illegal: expected memory space to be a SPIR-V storage class "
572 "attribute; please use MemorySpaceToStorageClassConverter to map "
573 "numeric memory spaces beforehand\n");
574 return nullptr;
575 }
576 spirv::StorageClass storageClass = attr.getValue();
577
578 if (isa<IntegerType>(Val: type.getElementType())) {
579 if (type.getElementTypeBitWidth() == 1)
580 return convertBoolMemrefType(targetEnv, options, type, storageClass);
581 if (type.getElementTypeBitWidth() < 8)
582 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
583 }
584
585 Type arrayElemType;
586 Type elementType = type.getElementType();
587 if (auto vecType = dyn_cast<VectorType>(Val&: elementType)) {
588 arrayElemType =
589 convertVectorType(targetEnv, options, type: vecType, storageClass);
590 } else if (auto complexType = dyn_cast<ComplexType>(Val&: elementType)) {
591 arrayElemType =
592 convertComplexType(targetEnv, options, type: complexType, storageClass);
593 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(Val&: elementType)) {
594 arrayElemType =
595 convertScalarType(targetEnv, options, type: scalarType, storageClass);
596 } else if (auto indexType = dyn_cast<IndexType>(Val&: elementType)) {
597 type = cast<MemRefType>(Val: convertIndexElementType(type, options));
598 arrayElemType = type.getElementType();
599 } else {
600 LLVM_DEBUG(
601 llvm::dbgs()
602 << type
603 << " unhandled: can only convert scalar or vector element type\n");
604 return nullptr;
605 }
606 if (!arrayElemType)
607 return nullptr;
608
609 std::optional<int64_t> arrayElemSize =
610 getTypeNumBytes(options, type: arrayElemType);
611 if (!arrayElemSize) {
612 LLVM_DEBUG(llvm::dbgs()
613 << type << " illegal: cannot deduce converted element size\n");
614 return nullptr;
615 }
616
617 if (!type.hasStaticShape()) {
618 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
619 // to the element.
620 if (targetEnv.allows(spirv::Capability::Kernel))
621 return spirv::PointerType::get(pointeeType: arrayElemType, storageClass);
622 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
623 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
624 // For Vulkan we need extra wrapping struct and array to satisfy interface
625 // needs.
626 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
627 }
628
629 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
630 if (!memrefSize) {
631 LLVM_DEBUG(llvm::dbgs()
632 << type << " illegal: cannot deduce element count\n");
633 return nullptr;
634 }
635
636 if (*memrefSize == 0) {
637 LLVM_DEBUG(llvm::dbgs()
638 << type << " illegal: zero-element memrefs are not supported\n");
639 return nullptr;
640 }
641
642 int64_t arrayElemCount = llvm::divideCeil(Numerator: *memrefSize, Denominator: *arrayElemSize);
643 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
644 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
645 if (targetEnv.allows(spirv::Capability::Kernel))
646 return spirv::PointerType::get(pointeeType: arrayType, storageClass);
647 return wrapInStructAndGetPointer(elementType: arrayType, storageClass);
648}
649
650//===----------------------------------------------------------------------===//
651// Type casting materialization
652//===----------------------------------------------------------------------===//
653
654/// Converts the given `inputs` to the original source `type` considering the
655/// `targetEnv`'s capabilities.
656///
657/// This function is meant to be used for source materialization in type
658/// converters. When the type converter needs to materialize a cast op back
659/// to some original source type, we need to check whether the original source
660/// type is supported in the target environment. If so, we can insert legal
661/// SPIR-V cast ops accordingly.
662///
663/// Note that in SPIR-V the capabilities for storage and compute are separate.
664/// This function is meant to handle the **compute** side; so it does not
665/// involve storage classes in its logic. The storage side is expected to be
666/// handled by MemRef conversion logic.
667static Value castToSourceType(const spirv::TargetEnv &targetEnv,
668 OpBuilder &builder, Type type, ValueRange inputs,
669 Location loc) {
670 // We can only cast one value in SPIR-V.
671 if (inputs.size() != 1) {
672 auto castOp = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
673 return castOp.getResult(i: 0);
674 }
675 Value input = inputs.front();
676
677 // Only support integer types for now. Floating point types to be implemented.
678 if (!isa<IntegerType>(Val: type)) {
679 auto castOp = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
680 return castOp.getResult(i: 0);
681 }
682 auto inputType = cast<IntegerType>(Val: input.getType());
683
684 auto scalarType = dyn_cast<spirv::ScalarType>(Val&: type);
685 if (!scalarType) {
686 auto castOp = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
687 return castOp.getResult(i: 0);
688 }
689
690 // Only support source type with a smaller bitwidth. This would mean we are
691 // truncating to go back so we don't need to worry about the signedness.
692 // For extension, we cannot have enough signal here to decide which op to use.
693 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
694 auto castOp = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
695 return castOp.getResult(i: 0);
696 }
697
698 // Boolean values would need to use different ops than normal integer values.
699 if (type.isInteger(width: 1)) {
700 Value one = spirv::ConstantOp::getOne(type: inputType, loc, builder);
701 return builder.create<spirv::IEqualOp>(location: loc, args&: input, args&: one);
702 }
703
704 // Check that the source integer type is supported by the environment.
705 SmallVector<ArrayRef<spirv::Extension>, 1> exts;
706 SmallVector<ArrayRef<spirv::Capability>, 2> caps;
707 scalarType.getExtensions(extensions&: exts);
708 scalarType.getCapabilities(capabilities&: caps);
709 if (failed(Result: checkCapabilityRequirements(label: type, targetEnv, candidates: caps)) ||
710 failed(Result: checkExtensionRequirements(label: type, targetEnv, candidates: exts))) {
711 auto castOp = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
712 return castOp.getResult(i: 0);
713 }
714
715 // We've already made sure this is truncating previously, so we don't need to
716 // care about signedness here. Still try to use a corresponding op for better
717 // consistency though.
718 if (type.isSignedInteger()) {
719 return builder.create<spirv::SConvertOp>(location: loc, args&: type, args&: input);
720 }
721 return builder.create<spirv::UConvertOp>(location: loc, args&: type, args&: input);
722}
723
724//===----------------------------------------------------------------------===//
725// Builtin Variables
726//===----------------------------------------------------------------------===//
727
728static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
729 spirv::BuiltIn builtin) {
730 // Look through all global variables in the given `body` block and check if
731 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
732 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
733 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
734 name: spirv::SPIRVDialect::getAttributeName(
735 decoration: spirv::Decoration::BuiltIn))) {
736 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
737 if (varBuiltIn == builtin) {
738 return varOp;
739 }
740 }
741 }
742 return nullptr;
743}
744
745/// Gets name of global variable for a builtin.
746std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
747 StringRef suffix) {
748 return Twine(prefix).concat(Suffix: stringifyBuiltIn(builtin)).concat(Suffix: suffix).str();
749}
750
751/// Gets or inserts a global variable for a builtin within `body` block.
752static spirv::GlobalVariableOp
753getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
754 Type integerType, OpBuilder &builder,
755 StringRef prefix, StringRef suffix) {
756 if (auto varOp = getBuiltinVariable(body, builtin))
757 return varOp;
758
759 OpBuilder::InsertionGuard guard(builder);
760 builder.setInsertionPointToStart(&body);
761
762 spirv::GlobalVariableOp newVarOp;
763 switch (builtin) {
764 case spirv::BuiltIn::NumWorkgroups:
765 case spirv::BuiltIn::WorkgroupSize:
766 case spirv::BuiltIn::WorkgroupId:
767 case spirv::BuiltIn::LocalInvocationId:
768 case spirv::BuiltIn::GlobalInvocationId: {
769 auto ptrType = spirv::PointerType::get(pointeeType: VectorType::get(shape: {3}, elementType: integerType),
770 storageClass: spirv::StorageClass::Input);
771 std::string name = getBuiltinVarName(builtin, prefix, suffix);
772 newVarOp =
773 builder.create<spirv::GlobalVariableOp>(location: loc, args&: ptrType, args&: name, args&: builtin);
774 break;
775 }
776 case spirv::BuiltIn::SubgroupId:
777 case spirv::BuiltIn::NumSubgroups:
778 case spirv::BuiltIn::SubgroupSize:
779 case spirv::BuiltIn::SubgroupLocalInvocationId: {
780 auto ptrType =
781 spirv::PointerType::get(pointeeType: integerType, storageClass: spirv::StorageClass::Input);
782 std::string name = getBuiltinVarName(builtin, prefix, suffix);
783 newVarOp =
784 builder.create<spirv::GlobalVariableOp>(location: loc, args&: ptrType, args&: name, args&: builtin);
785 break;
786 }
787 default:
788 emitError(loc, message: "unimplemented builtin variable generation for ")
789 << stringifyBuiltIn(builtin);
790 }
791 return newVarOp;
792}
793
794//===----------------------------------------------------------------------===//
795// Push constant storage
796//===----------------------------------------------------------------------===//
797
798/// Returns the pointer type for the push constant storage containing
799/// `elementCount` 32-bit integer values.
800static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
801 Builder &builder,
802 Type indexType) {
803 auto arrayType = spirv::ArrayType::get(elementType: indexType, elementCount,
804 /*stride=*/4);
805 auto structType = spirv::StructType::get(memberTypes: {arrayType}, /*offsetInfo=*/0);
806 return spirv::PointerType::get(pointeeType: structType, storageClass: spirv::StorageClass::PushConstant);
807}
808
809/// Returns the push constant varible containing `elementCount` 32-bit integer
810/// values in `body`. Returns null op if such an op does not exit.
811static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
812 unsigned elementCount) {
813 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
814 auto ptrType = dyn_cast<spirv::PointerType>(Val: varOp.getType());
815 if (!ptrType)
816 continue;
817
818 // Note that Vulkan requires "There must be no more than one push constant
819 // block statically used per shader entry point." So we should always reuse
820 // the existing one.
821 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
822 auto numElements = cast<spirv::ArrayType>(
823 Val: cast<spirv::StructType>(Val: ptrType.getPointeeType())
824 .getElementType(0))
825 .getNumElements();
826 if (numElements == elementCount)
827 return varOp;
828 }
829 }
830 return nullptr;
831}
832
833/// Gets or inserts a global variable for push constant storage containing
834/// `elementCount` 32-bit integer values in `block`.
835static spirv::GlobalVariableOp
836getOrInsertPushConstantVariable(Location loc, Block &block,
837 unsigned elementCount, OpBuilder &b,
838 Type indexType) {
839 if (auto varOp = getPushConstantVariable(body&: block, elementCount))
840 return varOp;
841
842 auto builder = OpBuilder::atBlockBegin(block: &block, listener: b.getListener());
843 auto type = getPushConstantStorageType(elementCount, builder, indexType);
844 const char *name = "__push_constant_var__";
845 return builder.create<spirv::GlobalVariableOp>(location: loc, args&: type, args&: name,
846 /*initializer=*/args: nullptr);
847}
848
849//===----------------------------------------------------------------------===//
850// func::FuncOp Conversion Patterns
851//===----------------------------------------------------------------------===//
852
853/// A pattern for rewriting function signature to convert arguments of functions
854/// to be of valid SPIR-V types.
855struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
856 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
857
858 LogicalResult
859 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
860 ConversionPatternRewriter &rewriter) const override {
861 FunctionType fnType = funcOp.getFunctionType();
862 if (fnType.getNumResults() > 1)
863 return failure();
864
865 TypeConverter::SignatureConversion signatureConverter(
866 fnType.getNumInputs());
867 for (const auto &argType : enumerate(First: fnType.getInputs())) {
868 auto convertedType = getTypeConverter()->convertType(t: argType.value());
869 if (!convertedType)
870 return failure();
871 signatureConverter.addInputs(origInputNo: argType.index(), types: convertedType);
872 }
873
874 Type resultType;
875 if (fnType.getNumResults() == 1) {
876 resultType = getTypeConverter()->convertType(t: fnType.getResult(i: 0));
877 if (!resultType)
878 return failure();
879 }
880
881 // Create the converted spirv.func op.
882 auto newFuncOp = rewriter.create<spirv::FuncOp>(
883 location: funcOp.getLoc(), args: funcOp.getName(),
884 args: rewriter.getFunctionType(inputs: signatureConverter.getConvertedTypes(),
885 results: resultType ? TypeRange(resultType)
886 : TypeRange()));
887
888 // Copy over all attributes other than the function name and type.
889 for (const auto &namedAttr : funcOp->getAttrs()) {
890 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
891 namedAttr.getName() != SymbolTable::getSymbolAttrName())
892 newFuncOp->setAttr(name: namedAttr.getName(), value: namedAttr.getValue());
893 }
894
895 rewriter.inlineRegionBefore(region&: funcOp.getBody(), parent&: newFuncOp.getBody(),
896 before: newFuncOp.end());
897 if (failed(Result: rewriter.convertRegionTypes(
898 region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter)))
899 return failure();
900 rewriter.eraseOp(op: funcOp);
901 return success();
902 }
903};
904
905/// A pattern for rewriting function signature to convert vector arguments of
906/// functions to be of valid types
907struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
908 using OpRewritePattern::OpRewritePattern;
909
910 LogicalResult matchAndRewrite(func::FuncOp funcOp,
911 PatternRewriter &rewriter) const override {
912 FunctionType fnType = funcOp.getFunctionType();
913
914 // TODO: Handle declarations.
915 if (funcOp.isDeclaration()) {
916 LLVM_DEBUG(llvm::dbgs()
917 << fnType << " illegal: declarations are unsupported\n");
918 return failure();
919 }
920
921 // Create a new func op with the original type and copy the function body.
922 auto newFuncOp = rewriter.create<func::FuncOp>(location: funcOp.getLoc(),
923 args: funcOp.getName(), args&: fnType);
924 rewriter.inlineRegionBefore(region&: funcOp.getBody(), parent&: newFuncOp.getBody(),
925 before: newFuncOp.end());
926
927 Location loc = newFuncOp.getBody().getLoc();
928
929 Block &entryBlock = newFuncOp.getBlocks().front();
930 OpBuilder::InsertionGuard guard(rewriter);
931 rewriter.setInsertionPointToStart(&entryBlock);
932
933 TypeConverter::SignatureConversion oneToNTypeMapping(
934 fnType.getInputs().size());
935
936 // For arguments that are of illegal types and require unrolling.
937 // `unrolledInputNums` stores the indices of arguments that result from
938 // unrolling in the new function signature. `newInputNo` is a counter.
939 SmallVector<size_t> unrolledInputNums;
940 size_t newInputNo = 0;
941
942 // For arguments that are of legal types and do not require unrolling.
943 // `tmpOps` stores a mapping from temporary operations that serve as
944 // placeholders for new arguments that will be added later. These operations
945 // will be erased once the entry block's argument list is updated.
946 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
947
948 // This counts the number of new operations created.
949 size_t newOpCount = 0;
950
951 // Enumerate through the arguments.
952 for (auto [origInputNo, origType] : enumerate(First: fnType.getInputs())) {
953 // Check whether the argument is of vector type.
954 auto origVecType = dyn_cast<VectorType>(Val: origType);
955 if (!origVecType) {
956 // We need a placeholder for the old argument that will be erased later.
957 Value result = rewriter.create<arith::ConstantOp>(
958 location: loc, args: origType, args: rewriter.getZeroAttr(type: origType));
959 rewriter.replaceAllUsesWith(from: newFuncOp.getArgument(idx: origInputNo), to: result);
960 tmpOps.insert(KV: {result.getDefiningOp(), newInputNo});
961 oneToNTypeMapping.addInputs(origInputNo, types: origType);
962 ++newInputNo;
963 ++newOpCount;
964 continue;
965 }
966 // Check whether the vector needs unrolling.
967 auto targetShape = getTargetShape(vecType: origVecType);
968 if (!targetShape) {
969 // We need a placeholder for the old argument that will be erased later.
970 Value result = rewriter.create<arith::ConstantOp>(
971 location: loc, args: origType, args: rewriter.getZeroAttr(type: origType));
972 rewriter.replaceAllUsesWith(from: newFuncOp.getArgument(idx: origInputNo), to: result);
973 tmpOps.insert(KV: {result.getDefiningOp(), newInputNo});
974 oneToNTypeMapping.addInputs(origInputNo, types: origType);
975 ++newInputNo;
976 ++newOpCount;
977 continue;
978 }
979 VectorType unrolledType =
980 VectorType::get(shape: *targetShape, elementType: origVecType.getElementType());
981 auto originalShape =
982 llvm::to_vector_of<int64_t, 4>(Range: origVecType.getShape());
983
984 // Prepare the result vector.
985 Value result = rewriter.create<arith::ConstantOp>(
986 location: loc, args&: origVecType, args: rewriter.getZeroAttr(type: origVecType));
987 ++newOpCount;
988 // Prepare the placeholder for the new arguments that will be added later.
989 Value dummy = rewriter.create<arith::ConstantOp>(
990 location: loc, args&: unrolledType, args: rewriter.getZeroAttr(type: unrolledType));
991 ++newOpCount;
992
993 // Create the `vector.insert_strided_slice` ops.
994 SmallVector<int64_t> strides(targetShape->size(), 1);
995 SmallVector<Type> newTypes;
996 for (SmallVector<int64_t> offsets :
997 StaticTileOffsetRange(originalShape, *targetShape)) {
998 result = rewriter.create<vector::InsertStridedSliceOp>(
999 location: loc, args&: dummy, args&: result, args&: offsets, args&: strides);
1000 newTypes.push_back(Elt: unrolledType);
1001 unrolledInputNums.push_back(Elt: newInputNo);
1002 ++newInputNo;
1003 ++newOpCount;
1004 }
1005 rewriter.replaceAllUsesWith(from: newFuncOp.getArgument(idx: origInputNo), to: result);
1006 oneToNTypeMapping.addInputs(origInputNo, types: newTypes);
1007 }
1008
1009 // Change the function signature.
1010 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1011 auto newFnType = fnType.clone(inputs: convertedTypes, results: fnType.getResults());
1012 rewriter.modifyOpInPlace(root: newFuncOp,
1013 callable: [&] { newFuncOp.setFunctionType(newFnType); });
1014
1015 // Update the arguments in the entry block.
1016 entryBlock.eraseArguments(start: 0, num: fnType.getNumInputs());
1017 SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1018 entryBlock.addArguments(types: convertedTypes, locs);
1019
1020 // Replace all uses of placeholders for initially legal arguments with their
1021 // original function arguments (that were added to `newFuncOp`).
1022 for (auto &[placeholderOp, argIdx] : tmpOps) {
1023 if (!placeholderOp)
1024 continue;
1025 Value replacement = newFuncOp.getArgument(idx: argIdx);
1026 rewriter.replaceAllUsesWith(from: placeholderOp->getResult(idx: 0), to: replacement);
1027 }
1028
1029 // Replace dummy operands of new `vector.insert_strided_slice` ops with
1030 // their corresponding new function arguments. The new
1031 // `vector.insert_strided_slice` ops are inserted only into the entry block,
1032 // so iterating over that block is sufficient.
1033 size_t unrolledInputIdx = 0;
1034 for (auto [count, op] : enumerate(First&: entryBlock.getOperations())) {
1035 Operation &curOp = op;
1036 // Since all newly created operations are in the beginning, reaching the
1037 // end of them means that any later `vector.insert_strided_slice` should
1038 // not be touched.
1039 if (count >= newOpCount)
1040 continue;
1041 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(Val&: op)) {
1042 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1043 rewriter.modifyOpInPlace(root: &curOp, callable: [&] {
1044 curOp.setOperand(idx: 0, value: newFuncOp.getArgument(idx: unrolledInputNo));
1045 });
1046 ++unrolledInputIdx;
1047 }
1048 }
1049
1050 // Erase the original funcOp. The `tmpOps` do not need to be erased since
1051 // they have no uses and will be handled by dead-code elimination.
1052 rewriter.eraseOp(op: funcOp);
1053 return success();
1054 }
1055};
1056
1057//===----------------------------------------------------------------------===//
1058// func::ReturnOp Conversion Patterns
1059//===----------------------------------------------------------------------===//
1060
/// A pattern for rewriting function signature and the return op to convert
/// vectors to be of valid types.
///
/// Each illegal (too-wide) vector result is split into several results of the
/// target (native) vector type via `vector.extract_strided_slice`; the parent
/// function's result types are updated in place to match.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
                                PatternRewriter &rewriter) const override {
    // Check whether the parent funcOp is valid.
    auto funcOp = dyn_cast<func::FuncOp>(Val: returnOp->getParentOp());
    if (!funcOp)
      return failure();

    FunctionType fnType = funcOp.getFunctionType();
    // Maps each original result to the one or more results that replace it.
    TypeConverter::SignatureConversion oneToNTypeMapping(
        fnType.getResults().size());
    Location loc = returnOp.getLoc();

    // For the new return op.
    SmallVector<Value> newOperands;

    // Enumerate through the results.
    for (auto [origResultNo, origType] : enumerate(First: fnType.getResults())) {
      // Non-vector results are legal; pass the operand through unchanged.
      auto origVecType = dyn_cast<VectorType>(Val: origType);
      if (!origVecType) {
        oneToNTypeMapping.addInputs(origInputNo: origResultNo, types: origType);
        newOperands.push_back(Elt: returnOp.getOperand(i: origResultNo));
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(vecType: origVecType);
      if (!targetShape) {
        // The original argument can be used.
        oneToNTypeMapping.addInputs(origInputNo: origResultNo, types: origType);
        newOperands.push_back(Elt: returnOp.getOperand(i: origResultNo));
        continue;
      }
      VectorType unrolledType =
          VectorType::get(shape: *targetShape, elementType: origVecType.getElementType());

      // Create `vector.extract_strided_slice` ops to form legal vectors from
      // the original operand of illegal type.
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(Range: origVecType.getShape());
      SmallVector<int64_t> strides(originalShape.size(), 1);
      // Each extracted tile is size 1 in all outer dimensions and
      // target-shaped along the innermost dimension.
      SmallVector<int64_t> extractShape(originalShape.size(), 1);
      extractShape.back() = targetShape->back();
      SmallVector<Type> newTypes;
      Value returnValue = returnOp.getOperand(i: origResultNo);
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
            location: loc, args&: returnValue, args&: offsets, args&: extractShape, args&: strides);
        // For rank > 1, drop the leading unit dimensions so the extracted
        // slice matches the 1-D unrolled type.
        if (originalShape.size() > 1) {
          SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
          result =
              rewriter.create<vector::ExtractOp>(location: loc, args&: result, args&: extractIndices);
        }
        newOperands.push_back(Elt: result);
        newTypes.push_back(Elt: unrolledType);
      }
      oneToNTypeMapping.addInputs(origInputNo: origResultNo, types: newTypes);
    }

    // Change the function signature.
    auto newFnType =
        FunctionType::get(context: rewriter.getContext(), inputs: TypeRange(fnType.getInputs()),
                          results: TypeRange(oneToNTypeMapping.getConvertedTypes()));
    rewriter.modifyOpInPlace(root: funcOp,
                             callable: [&] { funcOp.setFunctionType(newFnType); });

    // Replace the return op using the new operands. This will automatically
    // update the entry block as well.
    rewriter.replaceOp(op: returnOp,
                       newOp: rewriter.create<func::ReturnOp>(location: loc, args&: newOperands));

    return success();
  }
};
1140
1141} // namespace
1142
1143//===----------------------------------------------------------------------===//
1144// Public function for builtin variables
1145//===----------------------------------------------------------------------===//
1146
1147Value mlir::spirv::getBuiltinVariableValue(Operation *op,
1148 spirv::BuiltIn builtin,
1149 Type integerType, OpBuilder &builder,
1150 StringRef prefix, StringRef suffix) {
1151 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
1152 if (!parent) {
1153 op->emitError(message: "expected operation to be within a module-like op");
1154 return nullptr;
1155 }
1156
1157 spirv::GlobalVariableOp varOp =
1158 getOrInsertBuiltinVariable(body&: *parent->getRegion(index: 0).begin(), loc: op->getLoc(),
1159 builtin, integerType, builder, prefix, suffix);
1160 Value ptr = builder.create<spirv::AddressOfOp>(location: op->getLoc(), args&: varOp);
1161 return builder.create<spirv::LoadOp>(location: op->getLoc(), args&: ptr);
1162}
1163
1164//===----------------------------------------------------------------------===//
1165// Public function for pushing constant storage
1166//===----------------------------------------------------------------------===//
1167
1168Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1169 unsigned offset, Type integerType,
1170 OpBuilder &builder) {
1171 Location loc = op->getLoc();
1172 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
1173 if (!parent) {
1174 op->emitError(message: "expected operation to be within a module-like op");
1175 return nullptr;
1176 }
1177
1178 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1179 loc, block&: parent->getRegion(index: 0).front(), elementCount, b&: builder, indexType: integerType);
1180
1181 Value zeroOp = spirv::ConstantOp::getZero(type: integerType, loc, builder);
1182 Value offsetOp = builder.create<spirv::ConstantOp>(
1183 location: loc, args&: integerType, args: builder.getI32IntegerAttr(value: offset));
1184 auto addrOp = builder.create<spirv::AddressOfOp>(location: loc, args&: varOp);
1185 auto acOp = builder.create<spirv::AccessChainOp>(
1186 location: loc, args&: addrOp, args: llvm::ArrayRef({zeroOp, offsetOp}));
1187 return builder.create<spirv::LoadOp>(location: loc, args&: acOp);
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// Public functions for index calculation
1192//===----------------------------------------------------------------------===//
1193
1194Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
1195 int64_t offset, Type integerType,
1196 Location loc, OpBuilder &builder) {
1197 assert(indices.size() == strides.size() &&
1198 "must provide indices for all dimensions");
1199
1200 // TODO: Consider moving to use affine.apply and patterns converting
1201 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1202 // broken down into progressive small steps so we can have intermediate steps
1203 // using other dialects. At the moment SPIR-V is the final sink.
1204
1205 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1206 location: loc, args&: integerType, args: IntegerAttr::get(type: integerType, value: offset));
1207 for (const auto &index : llvm::enumerate(First&: indices)) {
1208 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1209 location: loc, args&: integerType,
1210 args: IntegerAttr::get(type: integerType, value: strides[index.index()]));
1211 Value update =
1212 builder.createOrFold<spirv::IMulOp>(location: loc, args&: index.value(), args&: strideVal);
1213 linearizedIndex =
1214 builder.createOrFold<spirv::IAddOp>(location: loc, args&: update, args&: linearizedIndex);
1215 }
1216 return linearizedIndex;
1217}
1218
1219Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1220 MemRefType baseType, Value basePtr,
1221 ValueRange indices, Location loc,
1222 OpBuilder &builder) {
1223 // Get base and offset of the MemRefType and verify they are static.
1224
1225 int64_t offset;
1226 SmallVector<int64_t, 4> strides;
1227 if (failed(Result: baseType.getStridesAndOffset(strides, offset)) ||
1228 llvm::is_contained(Range&: strides, Element: ShapedType::kDynamic) ||
1229 ShapedType::isDynamic(dValue: offset)) {
1230 return nullptr;
1231 }
1232
1233 auto indexType = typeConverter.getIndexType();
1234
1235 SmallVector<Value, 2> linearizedIndices;
1236 auto zero = spirv::ConstantOp::getZero(type: indexType, loc, builder);
1237
1238 // Add a '0' at the start to index into the struct.
1239 linearizedIndices.push_back(Elt: zero);
1240
1241 if (baseType.getRank() == 0) {
1242 linearizedIndices.push_back(Elt: zero);
1243 } else {
1244 linearizedIndices.push_back(
1245 Elt: linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder));
1246 }
1247 return builder.create<spirv::AccessChainOp>(location: loc, args&: basePtr, args&: linearizedIndices);
1248}
1249
1250Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1251 MemRefType baseType, Value basePtr,
1252 ValueRange indices, Location loc,
1253 OpBuilder &builder) {
1254 // Get base and offset of the MemRefType and verify they are static.
1255
1256 int64_t offset;
1257 SmallVector<int64_t, 4> strides;
1258 if (failed(Result: baseType.getStridesAndOffset(strides, offset)) ||
1259 llvm::is_contained(Range&: strides, Element: ShapedType::kDynamic) ||
1260 ShapedType::isDynamic(dValue: offset)) {
1261 return nullptr;
1262 }
1263
1264 auto indexType = typeConverter.getIndexType();
1265
1266 SmallVector<Value, 2> linearizedIndices;
1267 Value linearIndex;
1268 if (baseType.getRank() == 0) {
1269 linearIndex = spirv::ConstantOp::getZero(type: indexType, loc, builder);
1270 } else {
1271 linearIndex =
1272 linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder);
1273 }
1274 Type pointeeType =
1275 cast<spirv::PointerType>(Val: basePtr.getType()).getPointeeType();
1276 if (isa<spirv::ArrayType>(Val: pointeeType)) {
1277 linearizedIndices.push_back(Elt: linearIndex);
1278 return builder.create<spirv::AccessChainOp>(location: loc, args&: basePtr,
1279 args&: linearizedIndices);
1280 }
1281 return builder.create<spirv::PtrAccessChainOp>(location: loc, args&: basePtr, args&: linearIndex,
1282 args&: linearizedIndices);
1283}
1284
1285Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1286 MemRefType baseType, Value basePtr,
1287 ValueRange indices, Location loc,
1288 OpBuilder &builder) {
1289
1290 if (typeConverter.allows(capability: spirv::Capability::Kernel)) {
1291 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1292 builder);
1293 }
1294
1295 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1296 builder);
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// Public functions for vector unrolling
1301//===----------------------------------------------------------------------===//
1302
1303int mlir::spirv::getComputeVectorSize(int64_t size) {
1304 for (int i : {4, 3, 2}) {
1305 if (size % i == 0)
1306 return i;
1307 }
1308 return 1;
1309}
1310
1311SmallVector<int64_t>
1312mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1313 VectorType srcVectorType = op.getSourceVectorType();
1314 assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1315 int64_t vectorSize =
1316 mlir::spirv::getComputeVectorSize(size: srcVectorType.getDimSize(idx: 0));
1317 return {vectorSize};
1318}
1319
1320SmallVector<int64_t>
1321mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1322 VectorType vectorType = op.getResultVectorType();
1323 SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1324 nativeSize.back() =
1325 mlir::spirv::getComputeVectorSize(size: vectorType.getShape().back());
1326 return nativeSize;
1327}
1328
1329std::optional<SmallVector<int64_t>>
1330mlir::spirv::getNativeVectorShape(Operation *op) {
1331 if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1332 if (auto vecType = dyn_cast<VectorType>(Val: op->getResultTypes()[0])) {
1333 SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1334 nativeSize.back() =
1335 mlir::spirv::getComputeVectorSize(size: vecType.getShape().back());
1336 return nativeSize;
1337 }
1338 }
1339
1340 return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1341 .Case<vector::ReductionOp, vector::TransposeOp>(
1342 caseFn: [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1343 .Default(defaultFn: [](Operation *) { return std::nullopt; });
1344}
1345
1346LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1347 MLIRContext *context = op->getContext();
1348 RewritePatternSet patterns(context);
1349 populateFuncOpVectorRewritePatterns(patterns);
1350 populateReturnOpVectorRewritePatterns(patterns);
1351 // We only want to apply signature conversion once to the existing func ops.
1352 // Without specifying strictMode, the greedy pattern rewriter will keep
1353 // looking for newly created func ops.
1354 return applyPatternsGreedily(op, patterns: std::move(patterns),
1355 config: GreedyRewriteConfig().setStrictness(
1356 GreedyRewriteStrictness::ExistingOps));
1357}
1358
1359LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1360 MLIRContext *context = op->getContext();
1361
1362 // Unroll vectors in function bodies to native vector size.
1363 {
1364 RewritePatternSet patterns(context);
1365 auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1366 [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1367 populateVectorUnrollPatterns(patterns, options);
1368 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
1369 return failure();
1370 }
1371
1372 // Convert transpose ops into extract and insert pairs, in preparation of
1373 // further transformations to canonicalize/cancel.
1374 {
1375 RewritePatternSet patterns(context);
1376 vector::populateVectorTransposeLoweringPatterns(
1377 patterns, vectorTransposeLowering: vector::VectorTransposeLowering::EltWise);
1378 vector::populateVectorShapeCastLoweringPatterns(patterns);
1379 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
1380 return failure();
1381 }
1382
1383 // Run canonicalization to cast away leading size-1 dimensions.
1384 {
1385 RewritePatternSet patterns(context);
1386
1387 // We need to pull in casting way leading one dims.
1388 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1389 vector::ReductionOp::getCanonicalizationPatterns(results&: patterns, context);
1390 vector::TransposeOp::getCanonicalizationPatterns(results&: patterns, context);
1391
1392 // Decompose different rank insert_strided_slice and n-D
1393 // extract_slided_slice.
1394 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1395 patterns);
1396 vector::InsertOp::getCanonicalizationPatterns(results&: patterns, context);
1397 vector::ExtractOp::getCanonicalizationPatterns(results&: patterns, context);
1398
1399 // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1400 // them up.
1401 vector::BroadcastOp::getCanonicalizationPatterns(results&: patterns, context);
1402 vector::ShapeCastOp::getCanonicalizationPatterns(results&: patterns, context);
1403
1404 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
1405 return failure();
1406 }
1407 return success();
1408}
1409
1410//===----------------------------------------------------------------------===//
1411// SPIR-V TypeConverter
1412//===----------------------------------------------------------------------===//
1413
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                       const SPIRVConversionOptions &options)
    : targetEnv(targetAttr), options(options) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: This assumes that the SPIR-V types are valid to use in the given
  // target environment, which should be the case if the whole pipeline is
  // driven by the same target environment. Still, we probably still want to
  // validate and convert to be safe.
  addConversion(callback: [](spirv::SPIRVType type) { return type; });

  // `index` lowers to the integer type selected by the conversion options.
  addConversion(callback: [this](IndexType /*indexType*/) { return getIndexType(); });

  // Integers: SPIR-V scalar-compatible widths go through the target-aware
  // scalar conversion; sub-byte widths take a dedicated emulation path; all
  // other widths are rejected (null type).
  addConversion(callback: [this](IntegerType intType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(Val&: intType))
      return convertScalarType(targetEnv: this->targetEnv, options: this->options, type: scalarType);
    if (intType.getWidth() < 8)
      return convertSubByteIntegerType(options: this->options, type: intType);
    return Type();
  });

  // Floats: only SPIR-V scalar-compatible float types are convertible.
  addConversion(callback: [this](FloatType floatType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(Val&: floatType))
      return convertScalarType(targetEnv: this->targetEnv, options: this->options, type: scalarType);
    return Type();
  });

  addConversion(callback: [this](ComplexType complexType) {
    return convertComplexType(targetEnv: this->targetEnv, options: this->options, type: complexType);
  });

  addConversion(callback: [this](VectorType vectorType) {
    return convertVectorType(targetEnv: this->targetEnv, options: this->options, type: vectorType);
  });

  addConversion(callback: [this](TensorType tensorType) {
    return convertTensorType(targetEnv: this->targetEnv, options: this->options, type: tensorType);
  });

  addConversion(callback: [this](MemRefType memRefType) {
    return convertMemrefType(targetEnv: this->targetEnv, options: this->options, type: memRefType);
  });

  // Register some last line of defense casting logic.
  addSourceMaterialization(
      callback: [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
        return castToSourceType(targetEnv: this->targetEnv, builder, type, inputs, loc);
      });
  // Target materialization: bridge any remaining type mismatches with an
  // unrealized_conversion_cast to be reconciled later in the pipeline.
  addTargetMaterialization(callback: [](OpBuilder &builder, Type type, ValueRange inputs,
                                Location loc) {
    auto cast = builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs);
    return cast.getResult(i: 0);
  });
}
1472
/// Returns the integer type that `index` lowers to, as configured by the
/// conversion options.
Type SPIRVTypeConverter::getIndexType() const {
  return ::getIndexType(ctx: getContext(), options);
}
1476
/// Returns the MLIR context of the wrapped target environment attribute.
MLIRContext *SPIRVTypeConverter::getContext() const {
  return targetEnv.getAttr().getContext();
}
1480
/// Returns whether the target environment allows the given capability.
bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
  return targetEnv.allows(capability);
}
1484
1485//===----------------------------------------------------------------------===//
1486// SPIR-V ConversionTarget
1487//===----------------------------------------------------------------------===//
1488
1489std::unique_ptr<SPIRVConversionTarget>
1490SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
1491 std::unique_ptr<SPIRVConversionTarget> target(
1492 // std::make_unique does not work here because the constructor is private.
1493 new SPIRVConversionTarget(targetAttr));
1494 SPIRVConversionTarget *targetPtr = target.get();
1495 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1496 // We need to capture the raw pointer here because it is stable:
1497 // target will be destroyed once this function is returned.
1498 callback: [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1499 return target;
1500}
1501
/// Constructs a conversion target over the context of `targetAttr`, keeping
/// the target environment for the legality checks.
SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1504
1505bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1506 // Make sure this op is available at the given version. Ops not implementing
1507 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1508 // SPIR-V versions.
1509 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(Val: op)) {
1510 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1511 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1512 LLVM_DEBUG(llvm::dbgs()
1513 << op->getName() << " illegal: requiring min version "
1514 << spirv::stringifyVersion(*minVersion) << "\n");
1515 return false;
1516 }
1517 }
1518 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(Val: op)) {
1519 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1520 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1521 LLVM_DEBUG(llvm::dbgs()
1522 << op->getName() << " illegal: requiring max version "
1523 << spirv::stringifyVersion(*maxVersion) << "\n");
1524 return false;
1525 }
1526 }
1527
1528 // Make sure this op's required extensions are allowed to use. Ops not
1529 // implementing QueryExtensionInterface do not require extensions to be
1530 // available.
1531 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(Val: op))
1532 if (failed(Result: checkExtensionRequirements(label: op->getName(), targetEnv: this->targetEnv,
1533 candidates: extensions.getExtensions())))
1534 return false;
1535
1536 // Make sure this op's required extensions are allowed to use. Ops not
1537 // implementing QueryCapabilityInterface do not require capabilities to be
1538 // available.
1539 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(Val: op))
1540 if (failed(Result: checkCapabilityRequirements(label: op->getName(), targetEnv: this->targetEnv,
1541 candidates: capabilities.getCapabilities())))
1542 return false;
1543
1544 SmallVector<Type, 4> valueTypes;
1545 valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end());
1546 valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end());
1547
1548 // Ensure that all types have been converted to SPIRV types.
1549 if (llvm::any_of(Range&: valueTypes,
1550 P: [](Type t) { return !isa<spirv::SPIRVType>(Val: t); }))
1551 return false;
1552
1553 // Special treatment for global variables, whose type requirements are
1554 // conveyed by type attributes.
1555 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(Val: op))
1556 valueTypes.push_back(Elt: globalVar.getType());
1557
1558 // Make sure the op's operands/results use types that are allowed by the
1559 // target environment.
1560 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1561 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1562 for (Type valueType : valueTypes) {
1563 typeExtensions.clear();
1564 cast<spirv::SPIRVType>(Val&: valueType).getExtensions(extensions&: typeExtensions);
1565 if (failed(Result: checkExtensionRequirements(label: op->getName(), targetEnv: this->targetEnv,
1566 candidates: typeExtensions)))
1567 return false;
1568
1569 typeCapabilities.clear();
1570 cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(capabilities&: typeCapabilities);
1571 if (failed(Result: checkCapabilityRequirements(label: op->getName(), targetEnv: this->targetEnv,
1572 candidates: typeCapabilities)))
1573 return false;
1574 }
1575
1576 return true;
1577}
1578
1579//===----------------------------------------------------------------------===//
1580// Public functions for populating patterns
1581//===----------------------------------------------------------------------===//
1582
1583void mlir::populateBuiltinFuncToSPIRVPatterns(
1584 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1585 patterns.add<FuncOpConversion>(arg: typeConverter, args: patterns.getContext());
1586}
1587
1588void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1589 patterns.add<FuncOpVectorUnroll>(arg: patterns.getContext());
1590}
1591
1592void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
1593 patterns.add<ReturnOpVectorUnroll>(arg: patterns.getContext());
1594}
1595

// source code of mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp