1//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities used to lower to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Operation.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Support/LLVM.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/LogicalResult.h"
37#include "llvm/Support/MathExtras.h"
38
39#include <functional>
40#include <optional>
41
42#define DEBUG_TYPE "mlir-spirv-conversion"
43
44using namespace mlir;
45
46namespace {
47
48//===----------------------------------------------------------------------===//
49// Utility functions
50//===----------------------------------------------------------------------===//
51
52static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
53 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
54 if (vecType.isScalable()) {
55 LLVM_DEBUG(llvm::dbgs()
56 << "--scalable vectors are not supported -> BAIL\n");
57 return std::nullopt;
58 }
59 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
60 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
61 1, mlir::spirv::getComputeVectorSize(size: vecType.getShape().back()));
62 if (!targetShape) {
63 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
64 return std::nullopt;
65 }
66 auto maybeShapeRatio = computeShapeRatio(shape: unrollShape, subShape: *targetShape);
67 if (!maybeShapeRatio) {
68 LLVM_DEBUG(llvm::dbgs()
69 << "--could not compute integral shape ratio -> BAIL\n");
70 return std::nullopt;
71 }
72 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
73 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
74 return std::nullopt;
75 }
76 LLVM_DEBUG(llvm::dbgs()
77 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
78 return targetShape;
79}
80
81/// Checks that `candidates` extension requirements are possible to be satisfied
82/// with the given `targetEnv`.
83///
84/// `candidates` is a vector of vector for extension requirements following
85/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
86/// convention.
87template <typename LabelT>
88static LogicalResult checkExtensionRequirements(
89 LabelT label, const spirv::TargetEnv &targetEnv,
90 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
91 for (const auto &ors : candidates) {
92 if (targetEnv.allows(ors))
93 continue;
94
95 LLVM_DEBUG({
96 SmallVector<StringRef> extStrings;
97 for (spirv::Extension ext : ors)
98 extStrings.push_back(spirv::stringifyExtension(ext));
99
100 llvm::dbgs() << label << " illegal: requires at least one extension in ["
101 << llvm::join(extStrings, ", ")
102 << "] but none allowed in target environment\n";
103 });
104 return failure();
105 }
106 return success();
107}
108
109/// Checks that `candidates`capability requirements are possible to be satisfied
110/// with the given `isAllowedFn`.
111///
112/// `candidates` is a vector of vector for capability requirements following
113/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
114/// convention.
115template <typename LabelT>
116static LogicalResult checkCapabilityRequirements(
117 LabelT label, const spirv::TargetEnv &targetEnv,
118 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
119 for (const auto &ors : candidates) {
120 if (targetEnv.allows(ors))
121 continue;
122
123 LLVM_DEBUG({
124 SmallVector<StringRef> capStrings;
125 for (spirv::Capability cap : ors)
126 capStrings.push_back(spirv::stringifyCapability(cap));
127
128 llvm::dbgs() << label << " illegal: requires at least one capability in ["
129 << llvm::join(capStrings, ", ")
130 << "] but none allowed in target environment\n";
131 });
132 return failure();
133 }
134 return success();
135}
136
137/// Returns true if the given `storageClass` needs explicit layout when used in
138/// Shader environments.
139static bool needsExplicitLayout(spirv::StorageClass storageClass) {
140 switch (storageClass) {
141 case spirv::StorageClass::PhysicalStorageBuffer:
142 case spirv::StorageClass::PushConstant:
143 case spirv::StorageClass::StorageBuffer:
144 case spirv::StorageClass::Uniform:
145 return true;
146 default:
147 return false;
148 }
149}
150
151/// Wraps the given `elementType` in a struct and gets the pointer to the
152/// struct. This is used to satisfy Vulkan interface requirements.
153static spirv::PointerType
154wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
155 auto structType = needsExplicitLayout(storageClass)
156 ? spirv::StructType::get(memberTypes: elementType, /*offsetInfo=*/0)
157 : spirv::StructType::get(memberTypes: elementType);
158 return spirv::PointerType::get(structType, storageClass);
159}
160
161//===----------------------------------------------------------------------===//
162// Type Conversion
163//===----------------------------------------------------------------------===//
164
165static spirv::ScalarType getIndexType(MLIRContext *ctx,
166 const SPIRVConversionOptions &options) {
167 return cast<spirv::ScalarType>(
168 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
169}
170
171// TODO: This is a utility function that should probably be exposed by the
172// SPIR-V dialect. Keeping it local till the use case arises.
173static std::optional<int64_t>
174getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
175 if (isa<spirv::ScalarType>(Val: type)) {
176 auto bitWidth = type.getIntOrFloatBitWidth();
177 // According to the SPIR-V spec:
178 // "There is no physical size or bit pattern defined for values with boolean
179 // type. If they are stored (in conjunction with OpVariable), they can only
180 // be used with logical addressing operations, not physical, and only with
181 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
182 // Private, Function, Input, and Output."
183 if (bitWidth == 1)
184 return std::nullopt;
185 return bitWidth / 8;
186 }
187
188 if (auto complexType = dyn_cast<ComplexType>(type)) {
189 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
190 if (!elementSize)
191 return std::nullopt;
192 return 2 * *elementSize;
193 }
194
195 if (auto vecType = dyn_cast<VectorType>(type)) {
196 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
197 if (!elementSize)
198 return std::nullopt;
199 return vecType.getNumElements() * *elementSize;
200 }
201
202 if (auto memRefType = dyn_cast<MemRefType>(type)) {
203 // TODO: Layout should also be controlled by the ABI attributes. For now
204 // using the layout from MemRef.
205 int64_t offset;
206 SmallVector<int64_t, 4> strides;
207 if (!memRefType.hasStaticShape() ||
208 failed(memRefType.getStridesAndOffset(strides, offset)))
209 return std::nullopt;
210
211 // To get the size of the memref object in memory, the total size is the
212 // max(stride * dimension-size) computed for all dimensions times the size
213 // of the element.
214 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
215 if (!elementSize)
216 return std::nullopt;
217
218 if (memRefType.getRank() == 0)
219 return elementSize;
220
221 auto dims = memRefType.getShape();
222 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
223 ShapedType::isDynamic(offset) ||
224 llvm::is_contained(strides, ShapedType::kDynamic))
225 return std::nullopt;
226
227 int64_t memrefSize = -1;
228 for (const auto &shape : enumerate(dims))
229 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
230
231 return (offset + memrefSize) * *elementSize;
232 }
233
234 if (auto tensorType = dyn_cast<TensorType>(Val&: type)) {
235 if (!tensorType.hasStaticShape())
236 return std::nullopt;
237
238 auto elementSize = getTypeNumBytes(options, type: tensorType.getElementType());
239 if (!elementSize)
240 return std::nullopt;
241
242 int64_t size = *elementSize;
243 for (auto shape : tensorType.getShape())
244 size *= shape;
245
246 return size;
247 }
248
249 // TODO: Add size computation for other types.
250 return std::nullopt;
251}
252
253/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
254static Type
255convertScalarType(const spirv::TargetEnv &targetEnv,
256 const SPIRVConversionOptions &options, spirv::ScalarType type,
257 std::optional<spirv::StorageClass> storageClass = {}) {
258 // Get extension and capability requirements for the given type.
259 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
260 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
261 type.getExtensions(extensions, storageClass);
262 type.getCapabilities(capabilities, storageClass);
263
264 // If all requirements are met, then we can accept this type as-is.
265 if (succeeded(Result: checkCapabilityRequirements(label: type, targetEnv, candidates: capabilities)) &&
266 succeeded(Result: checkExtensionRequirements(label: type, targetEnv, candidates: extensions)))
267 return type;
268
269 // Otherwise we need to adjust the type, which really means adjusting the
270 // bitwidth given this is a scalar type.
271 if (!options.emulateLT32BitScalarTypes)
272 return nullptr;
273
274 // We only emulate narrower scalar types here and do not truncate results.
275 if (type.getIntOrFloatBitWidth() > 32) {
276 LLVM_DEBUG(llvm::dbgs()
277 << type
278 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
279 return nullptr;
280 }
281
282 if (auto floatType = dyn_cast<FloatType>(type)) {
283 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
284 return Builder(targetEnv.getContext()).getF32Type();
285 }
286
287 auto intType = cast<IntegerType>(type);
288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
290 intType.getSignedness());
291}
292
293/// Converts a sub-byte integer `type` to i32 regardless of target environment.
294/// Returns a nullptr for unsupported integer types, including non sub-byte
295/// types.
296///
297/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
298/// the above given that these sub-byte types are not supported at all in
299/// SPIR-V; there are no compute/storage capability for them like other
300/// supported integer types.
301static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
302 IntegerType type) {
303 if (type.getWidth() > 8) {
304 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
305 return nullptr;
306 }
307 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
308 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
309 return nullptr;
310 }
311
312 if (!llvm::isPowerOf2_32(Value: type.getWidth())) {
313 LLVM_DEBUG(llvm::dbgs()
314 << "unsupported non-power-of-two bitwidth in sub-byte" << type
315 << "\n");
316 return nullptr;
317 }
318
319 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
320 return IntegerType::get(type.getContext(), /*width=*/32,
321 type.getSignedness());
322}
323
324/// Returns a type with the same shape but with any index element type converted
325/// to the matching integer type. This is a noop when the element type is not
326/// the index type.
327static ShapedType
328convertIndexElementType(ShapedType type,
329 const SPIRVConversionOptions &options) {
330 Type indexType = dyn_cast<IndexType>(type.getElementType());
331 if (!indexType)
332 return type;
333
334 return type.clone(getIndexType(type.getContext(), options));
335}
336
337/// Converts a vector `type` to a suitable type under the given `targetEnv`.
338static Type
339convertVectorType(const spirv::TargetEnv &targetEnv,
340 const SPIRVConversionOptions &options, VectorType type,
341 std::optional<spirv::StorageClass> storageClass = {}) {
342 type = cast<VectorType>(convertIndexElementType(type, options));
343 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
344 if (!scalarType) {
345 // If this is not a spec allowed scalar type, try to handle sub-byte integer
346 // types.
347 auto intType = dyn_cast<IntegerType>(type.getElementType());
348 if (!intType) {
349 LLVM_DEBUG(llvm::dbgs()
350 << type
351 << " illegal: cannot convert non-scalar element type\n");
352 return nullptr;
353 }
354
355 Type elementType = convertSubByteIntegerType(options, intType);
356 if (!elementType)
357 return nullptr;
358
359 if (type.getRank() <= 1 && type.getNumElements() == 1)
360 return elementType;
361
362 if (type.getNumElements() > 4) {
363 LLVM_DEBUG(llvm::dbgs()
364 << type << " illegal: > 4-element unimplemented\n");
365 return nullptr;
366 }
367
368 return VectorType::get(type.getShape(), elementType);
369 }
370
371 if (type.getRank() <= 1 && type.getNumElements() == 1)
372 return convertScalarType(targetEnv, options, scalarType, storageClass);
373
374 if (!spirv::CompositeType::isValid(type)) {
375 LLVM_DEBUG(llvm::dbgs()
376 << type << " illegal: not a valid composite type\n");
377 return nullptr;
378 }
379
380 // Get extension and capability requirements for the given type.
381 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
382 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
383 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
384 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
385
386 // If all requirements are met, then we can accept this type as-is.
387 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
388 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
389 return type;
390
391 auto elementType =
392 convertScalarType(targetEnv, options, scalarType, storageClass);
393 if (elementType)
394 return VectorType::get(type.getShape(), elementType);
395 return nullptr;
396}
397
398static Type
399convertComplexType(const spirv::TargetEnv &targetEnv,
400 const SPIRVConversionOptions &options, ComplexType type,
401 std::optional<spirv::StorageClass> storageClass = {}) {
402 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
403 if (!scalarType) {
404 LLVM_DEBUG(llvm::dbgs()
405 << type << " illegal: cannot convert non-scalar element type\n");
406 return nullptr;
407 }
408
409 auto elementType =
410 convertScalarType(targetEnv, options, scalarType, storageClass);
411 if (!elementType)
412 return nullptr;
413 if (elementType != type.getElementType()) {
414 LLVM_DEBUG(llvm::dbgs()
415 << type << " illegal: complex type emulation unsupported\n");
416 return nullptr;
417 }
418
419 return VectorType::get(2, elementType);
420}
421
422/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
423///
424/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
425/// create composite constants with OpConstantComposite to embed relative large
426/// constant values and use OpCompositeExtract and OpCompositeInsert to
427/// manipulate, like what we do for vectors.
428static Type convertTensorType(const spirv::TargetEnv &targetEnv,
429 const SPIRVConversionOptions &options,
430 TensorType type) {
431 // TODO: Handle dynamic shapes.
432 if (!type.hasStaticShape()) {
433 LLVM_DEBUG(llvm::dbgs()
434 << type << " illegal: dynamic shape unimplemented\n");
435 return nullptr;
436 }
437
438 type = cast<TensorType>(convertIndexElementType(type, options));
439 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(Val: type.getElementType());
440 if (!scalarType) {
441 LLVM_DEBUG(llvm::dbgs()
442 << type << " illegal: cannot convert non-scalar element type\n");
443 return nullptr;
444 }
445
446 std::optional<int64_t> scalarSize = getTypeNumBytes(options, type: scalarType);
447 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
448 if (!scalarSize || !tensorSize) {
449 LLVM_DEBUG(llvm::dbgs()
450 << type << " illegal: cannot deduce element count\n");
451 return nullptr;
452 }
453
454 int64_t arrayElemCount = *tensorSize / *scalarSize;
455 if (arrayElemCount == 0) {
456 LLVM_DEBUG(llvm::dbgs()
457 << type << " illegal: cannot handle zero-element tensors\n");
458 return nullptr;
459 }
460
461 Type arrayElemType = convertScalarType(targetEnv, options, type: scalarType);
462 if (!arrayElemType)
463 return nullptr;
464 std::optional<int64_t> arrayElemSize =
465 getTypeNumBytes(options, type: arrayElemType);
466 if (!arrayElemSize) {
467 LLVM_DEBUG(llvm::dbgs()
468 << type << " illegal: cannot deduce converted element size\n");
469 return nullptr;
470 }
471
472 return spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount);
473}
474
475static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
476 const SPIRVConversionOptions &options,
477 MemRefType type,
478 spirv::StorageClass storageClass) {
479 unsigned numBoolBits = options.boolNumBits;
480 if (numBoolBits != 8) {
481 LLVM_DEBUG(llvm::dbgs()
482 << "using non-8-bit storage for bool types unimplemented");
483 return nullptr;
484 }
485 auto elementType = dyn_cast<spirv::ScalarType>(
486 IntegerType::get(type.getContext(), numBoolBits));
487 if (!elementType)
488 return nullptr;
489 Type arrayElemType =
490 convertScalarType(targetEnv, options, elementType, storageClass);
491 if (!arrayElemType)
492 return nullptr;
493 std::optional<int64_t> arrayElemSize =
494 getTypeNumBytes(options, type: arrayElemType);
495 if (!arrayElemSize) {
496 LLVM_DEBUG(llvm::dbgs()
497 << type << " illegal: cannot deduce converted element size\n");
498 return nullptr;
499 }
500
501 if (!type.hasStaticShape()) {
502 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
503 // to the element.
504 if (targetEnv.allows(spirv::Capability::Kernel))
505 return spirv::PointerType::get(arrayElemType, storageClass);
506 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
507 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
508 // For Vulkan we need extra wrapping struct and array to satisfy interface
509 // needs.
510 return wrapInStructAndGetPointer(arrayType, storageClass);
511 }
512
513 if (type.getNumElements() == 0) {
514 LLVM_DEBUG(llvm::dbgs()
515 << type << " illegal: zero-element memrefs are not supported\n");
516 return nullptr;
517 }
518
519 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
520 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: *arrayElemSize);
521 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
522 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
523 if (targetEnv.allows(spirv::Capability::Kernel))
524 return spirv::PointerType::get(arrayType, storageClass);
525 return wrapInStructAndGetPointer(arrayType, storageClass);
526}
527
528static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
529 const SPIRVConversionOptions &options,
530 MemRefType type,
531 spirv::StorageClass storageClass) {
532 IntegerType elementType = cast<IntegerType>(type.getElementType());
533 Type arrayElemType = convertSubByteIntegerType(options, elementType);
534 if (!arrayElemType)
535 return nullptr;
536 int64_t arrayElemSize = *getTypeNumBytes(options, type: arrayElemType);
537
538 if (!type.hasStaticShape()) {
539 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
540 // to the element.
541 if (targetEnv.allows(spirv::Capability::Kernel))
542 return spirv::PointerType::get(arrayElemType, storageClass);
543 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
544 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
545 // For Vulkan we need extra wrapping struct and array to satisfy interface
546 // needs.
547 return wrapInStructAndGetPointer(arrayType, storageClass);
548 }
549
550 if (type.getNumElements() == 0) {
551 LLVM_DEBUG(llvm::dbgs()
552 << type << " illegal: zero-element memrefs are not supported\n");
553 return nullptr;
554 }
555
556 int64_t memrefSize =
557 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
558 int64_t arrayElemCount = llvm::divideCeil(Numerator: memrefSize, Denominator: arrayElemSize);
559 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
560 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
561 if (targetEnv.allows(spirv::Capability::Kernel))
562 return spirv::PointerType::get(arrayType, storageClass);
563 return wrapInStructAndGetPointer(arrayType, storageClass);
564}
565
566static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
567 const SPIRVConversionOptions &options,
568 MemRefType type) {
569 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
570 if (!attr) {
571 LLVM_DEBUG(
572 llvm::dbgs()
573 << type
574 << " illegal: expected memory space to be a SPIR-V storage class "
575 "attribute; please use MemorySpaceToStorageClassConverter to map "
576 "numeric memory spaces beforehand\n");
577 return nullptr;
578 }
579 spirv::StorageClass storageClass = attr.getValue();
580
581 if (isa<IntegerType>(type.getElementType())) {
582 if (type.getElementTypeBitWidth() == 1)
583 return convertBoolMemrefType(targetEnv, options, type, storageClass);
584 if (type.getElementTypeBitWidth() < 8)
585 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
586 }
587
588 Type arrayElemType;
589 Type elementType = type.getElementType();
590 if (auto vecType = dyn_cast<VectorType>(elementType)) {
591 arrayElemType =
592 convertVectorType(targetEnv, options, vecType, storageClass);
593 } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
594 arrayElemType =
595 convertComplexType(targetEnv, options, complexType, storageClass);
596 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
597 arrayElemType =
598 convertScalarType(targetEnv, options, scalarType, storageClass);
599 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
600 type = cast<MemRefType>(convertIndexElementType(type, options));
601 arrayElemType = type.getElementType();
602 } else {
603 LLVM_DEBUG(
604 llvm::dbgs()
605 << type
606 << " unhandled: can only convert scalar or vector element type\n");
607 return nullptr;
608 }
609 if (!arrayElemType)
610 return nullptr;
611
612 std::optional<int64_t> arrayElemSize =
613 getTypeNumBytes(options, type: arrayElemType);
614 if (!arrayElemSize) {
615 LLVM_DEBUG(llvm::dbgs()
616 << type << " illegal: cannot deduce converted element size\n");
617 return nullptr;
618 }
619
620 if (!type.hasStaticShape()) {
621 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
622 // to the element.
623 if (targetEnv.allows(spirv::Capability::Kernel))
624 return spirv::PointerType::get(arrayElemType, storageClass);
625 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
626 auto arrayType = spirv::RuntimeArrayType::get(elementType: arrayElemType, stride);
627 // For Vulkan we need extra wrapping struct and array to satisfy interface
628 // needs.
629 return wrapInStructAndGetPointer(arrayType, storageClass);
630 }
631
632 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
633 if (!memrefSize) {
634 LLVM_DEBUG(llvm::dbgs()
635 << type << " illegal: cannot deduce element count\n");
636 return nullptr;
637 }
638
639 if (*memrefSize == 0) {
640 LLVM_DEBUG(llvm::dbgs()
641 << type << " illegal: zero-element memrefs are not supported\n");
642 return nullptr;
643 }
644
645 int64_t arrayElemCount = llvm::divideCeil(Numerator: *memrefSize, Denominator: *arrayElemSize);
646 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
647 auto arrayType = spirv::ArrayType::get(elementType: arrayElemType, elementCount: arrayElemCount, stride);
648 if (targetEnv.allows(spirv::Capability::Kernel))
649 return spirv::PointerType::get(arrayType, storageClass);
650 return wrapInStructAndGetPointer(arrayType, storageClass);
651}
652
653//===----------------------------------------------------------------------===//
654// Type casting materialization
655//===----------------------------------------------------------------------===//
656
657/// Converts the given `inputs` to the original source `type` considering the
658/// `targetEnv`'s capabilities.
659///
660/// This function is meant to be used for source materialization in type
661/// converters. When the type converter needs to materialize a cast op back
662/// to some original source type, we need to check whether the original source
663/// type is supported in the target environment. If so, we can insert legal
664/// SPIR-V cast ops accordingly.
665///
666/// Note that in SPIR-V the capabilities for storage and compute are separate.
667/// This function is meant to handle the **compute** side; so it does not
668/// involve storage classes in its logic. The storage side is expected to be
669/// handled by MemRef conversion logic.
670static Value castToSourceType(const spirv::TargetEnv &targetEnv,
671 OpBuilder &builder, Type type, ValueRange inputs,
672 Location loc) {
673 // We can only cast one value in SPIR-V.
674 if (inputs.size() != 1) {
675 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
676 return castOp.getResult(0);
677 }
678 Value input = inputs.front();
679
680 // Only support integer types for now. Floating point types to be implemented.
681 if (!isa<IntegerType>(Val: type)) {
682 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
683 return castOp.getResult(0);
684 }
685 auto inputType = cast<IntegerType>(input.getType());
686
687 auto scalarType = dyn_cast<spirv::ScalarType>(Val&: type);
688 if (!scalarType) {
689 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
690 return castOp.getResult(0);
691 }
692
693 // Only support source type with a smaller bitwidth. This would mean we are
694 // truncating to go back so we don't need to worry about the signedness.
695 // For extension, we cannot have enough signal here to decide which op to use.
696 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
697 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
698 return castOp.getResult(0);
699 }
700
701 // Boolean values would need to use different ops than normal integer values.
702 if (type.isInteger(width: 1)) {
703 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
704 return builder.create<spirv::IEqualOp>(loc, input, one);
705 }
706
707 // Check that the source integer type is supported by the environment.
708 SmallVector<ArrayRef<spirv::Extension>, 1> exts;
709 SmallVector<ArrayRef<spirv::Capability>, 2> caps;
710 scalarType.getExtensions(exts);
711 scalarType.getCapabilities(caps);
712 if (failed(Result: checkCapabilityRequirements(label: type, targetEnv, candidates: caps)) ||
713 failed(Result: checkExtensionRequirements(label: type, targetEnv, candidates: exts))) {
714 auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
715 return castOp.getResult(0);
716 }
717
718 // We've already made sure this is truncating previously, so we don't need to
719 // care about signedness here. Still try to use a corresponding op for better
720 // consistency though.
721 if (type.isSignedInteger()) {
722 return builder.create<spirv::SConvertOp>(loc, type, input);
723 }
724 return builder.create<spirv::UConvertOp>(loc, type, input);
725}
726
727//===----------------------------------------------------------------------===//
728// Builtin Variables
729//===----------------------------------------------------------------------===//
730
731static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
732 spirv::BuiltIn builtin) {
733 // Look through all global variables in the given `body` block and check if
734 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
735 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
736 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
737 spirv::SPIRVDialect::getAttributeName(
738 spirv::Decoration::BuiltIn))) {
739 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
740 if (varBuiltIn && *varBuiltIn == builtin) {
741 return varOp;
742 }
743 }
744 }
745 return nullptr;
746}
747
748/// Gets name of global variable for a builtin.
749std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
750 StringRef suffix) {
751 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
752}
753
754/// Gets or inserts a global variable for a builtin within `body` block.
755static spirv::GlobalVariableOp
756getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
757 Type integerType, OpBuilder &builder,
758 StringRef prefix, StringRef suffix) {
759 if (auto varOp = getBuiltinVariable(body, builtin))
760 return varOp;
761
762 OpBuilder::InsertionGuard guard(builder);
763 builder.setInsertionPointToStart(&body);
764
765 spirv::GlobalVariableOp newVarOp;
766 switch (builtin) {
767 case spirv::BuiltIn::NumWorkgroups:
768 case spirv::BuiltIn::WorkgroupSize:
769 case spirv::BuiltIn::WorkgroupId:
770 case spirv::BuiltIn::LocalInvocationId:
771 case spirv::BuiltIn::GlobalInvocationId: {
772 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
773 spirv::StorageClass::Input);
774 std::string name = getBuiltinVarName(builtin, prefix, suffix);
775 newVarOp =
776 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
777 break;
778 }
779 case spirv::BuiltIn::SubgroupId:
780 case spirv::BuiltIn::NumSubgroups:
781 case spirv::BuiltIn::SubgroupSize:
782 case spirv::BuiltIn::SubgroupLocalInvocationId: {
783 auto ptrType =
784 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
785 std::string name = getBuiltinVarName(builtin, prefix, suffix);
786 newVarOp =
787 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
788 break;
789 }
790 default:
791 emitError(loc, message: "unimplemented builtin variable generation for ")
792 << stringifyBuiltIn(builtin);
793 }
794 return newVarOp;
795}
796
797//===----------------------------------------------------------------------===//
798// Push constant storage
799//===----------------------------------------------------------------------===//
800
801/// Returns the pointer type for the push constant storage containing
802/// `elementCount` 32-bit integer values.
803static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
804 Builder &builder,
805 Type indexType) {
806 auto arrayType = spirv::ArrayType::get(elementType: indexType, elementCount,
807 /*stride=*/4);
808 auto structType = spirv::StructType::get(memberTypes: {arrayType}, /*offsetInfo=*/0);
809 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
810}
811
812/// Returns the push constant varible containing `elementCount` 32-bit integer
813/// values in `body`. Returns null op if such an op does not exit.
814static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
815 unsigned elementCount) {
816 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
817 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
818 if (!ptrType)
819 continue;
820
821 // Note that Vulkan requires "There must be no more than one push constant
822 // block statically used per shader entry point." So we should always reuse
823 // the existing one.
824 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825 auto numElements = cast<spirv::ArrayType>(
826 cast<spirv::StructType>(ptrType.getPointeeType())
827 .getElementType(0))
828 .getNumElements();
829 if (numElements == elementCount)
830 return varOp;
831 }
832 }
833 return nullptr;
834}
835
836/// Gets or inserts a global variable for push constant storage containing
837/// `elementCount` 32-bit integer values in `block`.
838static spirv::GlobalVariableOp
839getOrInsertPushConstantVariable(Location loc, Block &block,
840 unsigned elementCount, OpBuilder &b,
841 Type indexType) {
842 if (auto varOp = getPushConstantVariable(block, elementCount))
843 return varOp;
844
845 auto builder = OpBuilder::atBlockBegin(block: &block, listener: b.getListener());
846 auto type = getPushConstantStorageType(elementCount, builder, indexType);
847 const char *name = "__push_constant_var__";
848 return builder.create<spirv::GlobalVariableOp>(loc, type, name,
849 /*initializer=*/nullptr);
850}
851
852//===----------------------------------------------------------------------===//
853// func::FuncOp Conversion Patterns
854//===----------------------------------------------------------------------===//
855
856/// A pattern for rewriting function signature to convert arguments of functions
857/// to be of valid SPIR-V types.
858struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
859 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
860
861 LogicalResult
862 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
863 ConversionPatternRewriter &rewriter) const override {
864 FunctionType fnType = funcOp.getFunctionType();
865 if (fnType.getNumResults() > 1)
866 return failure();
867
868 TypeConverter::SignatureConversion signatureConverter(
869 fnType.getNumInputs());
870 for (const auto &argType : enumerate(fnType.getInputs())) {
871 auto convertedType = getTypeConverter()->convertType(argType.value());
872 if (!convertedType)
873 return failure();
874 signatureConverter.addInputs(argType.index(), convertedType);
875 }
876
877 Type resultType;
878 if (fnType.getNumResults() == 1) {
879 resultType = getTypeConverter()->convertType(fnType.getResult(0));
880 if (!resultType)
881 return failure();
882 }
883
884 // Create the converted spirv.func op.
885 auto newFuncOp = rewriter.create<spirv::FuncOp>(
886 funcOp.getLoc(), funcOp.getName(),
887 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
888 resultType ? TypeRange(resultType)
889 : TypeRange()));
890
891 // Copy over all attributes other than the function name and type.
892 for (const auto &namedAttr : funcOp->getAttrs()) {
893 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
894 namedAttr.getName() != SymbolTable::getSymbolAttrName())
895 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
896 }
897
898 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
899 newFuncOp.end());
900 if (failed(rewriter.convertRegionTypes(
901 region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter)))
902 return failure();
903 rewriter.eraseOp(op: funcOp);
904 return success();
905 }
906};
907
908/// A pattern for rewriting function signature to convert vector arguments of
909/// functions to be of valid types
910struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
911 using OpRewritePattern::OpRewritePattern;
912
913 LogicalResult matchAndRewrite(func::FuncOp funcOp,
914 PatternRewriter &rewriter) const override {
915 FunctionType fnType = funcOp.getFunctionType();
916
917 // TODO: Handle declarations.
918 if (funcOp.isDeclaration()) {
919 LLVM_DEBUG(llvm::dbgs()
920 << fnType << " illegal: declarations are unsupported\n");
921 return failure();
922 }
923
924 // Create a new func op with the original type and copy the function body.
925 auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
926 funcOp.getName(), fnType);
927 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
928 newFuncOp.end());
929
930 Location loc = newFuncOp.getBody().getLoc();
931
932 Block &entryBlock = newFuncOp.getBlocks().front();
933 OpBuilder::InsertionGuard guard(rewriter);
934 rewriter.setInsertionPointToStart(&entryBlock);
935
936 TypeConverter::SignatureConversion oneToNTypeMapping(
937 fnType.getInputs().size());
938
939 // For arguments that are of illegal types and require unrolling.
940 // `unrolledInputNums` stores the indices of arguments that result from
941 // unrolling in the new function signature. `newInputNo` is a counter.
942 SmallVector<size_t> unrolledInputNums;
943 size_t newInputNo = 0;
944
945 // For arguments that are of legal types and do not require unrolling.
946 // `tmpOps` stores a mapping from temporary operations that serve as
947 // placeholders for new arguments that will be added later. These operations
948 // will be erased once the entry block's argument list is updated.
949 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
950
951 // This counts the number of new operations created.
952 size_t newOpCount = 0;
953
954 // Enumerate through the arguments.
955 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
956 // Check whether the argument is of vector type.
957 auto origVecType = dyn_cast<VectorType>(origType);
958 if (!origVecType) {
959 // We need a placeholder for the old argument that will be erased later.
960 Value result = rewriter.create<arith::ConstantOp>(
961 loc, origType, rewriter.getZeroAttr(origType));
962 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
963 tmpOps.insert({result.getDefiningOp(), newInputNo});
964 oneToNTypeMapping.addInputs(origInputNo, origType);
965 ++newInputNo;
966 ++newOpCount;
967 continue;
968 }
969 // Check whether the vector needs unrolling.
970 auto targetShape = getTargetShape(origVecType);
971 if (!targetShape) {
972 // We need a placeholder for the old argument that will be erased later.
973 Value result = rewriter.create<arith::ConstantOp>(
974 loc, origType, rewriter.getZeroAttr(origType));
975 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
976 tmpOps.insert({result.getDefiningOp(), newInputNo});
977 oneToNTypeMapping.addInputs(origInputNo, origType);
978 ++newInputNo;
979 ++newOpCount;
980 continue;
981 }
982 VectorType unrolledType =
983 VectorType::get(*targetShape, origVecType.getElementType());
984 auto originalShape =
985 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
986
987 // Prepare the result vector.
988 Value result = rewriter.create<arith::ConstantOp>(
989 loc, origVecType, rewriter.getZeroAttr(origVecType));
990 ++newOpCount;
991 // Prepare the placeholder for the new arguments that will be added later.
992 Value dummy = rewriter.create<arith::ConstantOp>(
993 loc, unrolledType, rewriter.getZeroAttr(unrolledType));
994 ++newOpCount;
995
996 // Create the `vector.insert_strided_slice` ops.
997 SmallVector<int64_t> strides(targetShape->size(), 1);
998 SmallVector<Type> newTypes;
999 for (SmallVector<int64_t> offsets :
1000 StaticTileOffsetRange(originalShape, *targetShape)) {
1001 result = rewriter.create<vector::InsertStridedSliceOp>(
1002 loc, dummy, result, offsets, strides);
1003 newTypes.push_back(unrolledType);
1004 unrolledInputNums.push_back(newInputNo);
1005 ++newInputNo;
1006 ++newOpCount;
1007 }
1008 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1009 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1010 }
1011
1012 // Change the function signature.
1013 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1014 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1015 rewriter.modifyOpInPlace(newFuncOp,
1016 [&] { newFuncOp.setFunctionType(newFnType); });
1017
1018 // Update the arguments in the entry block.
1019 entryBlock.eraseArguments(0, fnType.getNumInputs());
1020 SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1021 entryBlock.addArguments(types: convertedTypes, locs);
1022
1023 // Replace the placeholder values with the new arguments. We assume there is
1024 // only one block for now.
1025 size_t unrolledInputIdx = 0;
1026 for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1027 // We first look for operands that are placeholders for initially legal
1028 // arguments.
1029 Operation &curOp = op;
1030 for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
1031 Operation *operandOp = operandVal.getDefiningOp();
1032 if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1033 size_t idx = operandIdx;
1034 rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
1035 curOp.setOperand(idx, newFuncOp.getArgument(it->second));
1036 });
1037 }
1038 }
1039 // Since all newly created operations are in the beginning, reaching the
1040 // end of them means that any later `vector.insert_strided_slice` should
1041 // not be touched.
1042 if (count >= newOpCount)
1043 continue;
1044 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1045 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1046 rewriter.modifyOpInPlace(&curOp, [&] {
1047 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1048 });
1049 ++unrolledInputIdx;
1050 }
1051 }
1052
1053 // Erase the original funcOp. The `tmpOps` do not need to be erased since
1054 // they have no uses and will be handled by dead-code elimination.
1055 rewriter.eraseOp(op: funcOp);
1056 return success();
1057 }
1058};
1059
1060//===----------------------------------------------------------------------===//
1061// func::ReturnOp Conversion Patterns
1062//===----------------------------------------------------------------------===//
1063
1064/// A pattern for rewriting function signature and the return op to convert
1065/// vectors to be of valid types.
1066struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1067 using OpRewritePattern::OpRewritePattern;
1068
1069 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1070 PatternRewriter &rewriter) const override {
1071 // Check whether the parent funcOp is valid.
1072 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1073 if (!funcOp)
1074 return failure();
1075
1076 FunctionType fnType = funcOp.getFunctionType();
1077 TypeConverter::SignatureConversion oneToNTypeMapping(
1078 fnType.getResults().size());
1079 Location loc = returnOp.getLoc();
1080
1081 // For the new return op.
1082 SmallVector<Value> newOperands;
1083
1084 // Enumerate through the results.
1085 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1086 // Check whether the argument is of vector type.
1087 auto origVecType = dyn_cast<VectorType>(origType);
1088 if (!origVecType) {
1089 oneToNTypeMapping.addInputs(origResultNo, origType);
1090 newOperands.push_back(returnOp.getOperand(origResultNo));
1091 continue;
1092 }
1093 // Check whether the vector needs unrolling.
1094 auto targetShape = getTargetShape(origVecType);
1095 if (!targetShape) {
1096 // The original argument can be used.
1097 oneToNTypeMapping.addInputs(origResultNo, origType);
1098 newOperands.push_back(returnOp.getOperand(origResultNo));
1099 continue;
1100 }
1101 VectorType unrolledType =
1102 VectorType::get(*targetShape, origVecType.getElementType());
1103
1104 // Create `vector.extract_strided_slice` ops to form legal vectors from
1105 // the original operand of illegal type.
1106 auto originalShape =
1107 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1108 SmallVector<int64_t> strides(originalShape.size(), 1);
1109 SmallVector<int64_t> extractShape(originalShape.size(), 1);
1110 extractShape.back() = targetShape->back();
1111 SmallVector<Type> newTypes;
1112 Value returnValue = returnOp.getOperand(origResultNo);
1113 for (SmallVector<int64_t> offsets :
1114 StaticTileOffsetRange(originalShape, *targetShape)) {
1115 Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1116 loc, returnValue, offsets, extractShape, strides);
1117 if (originalShape.size() > 1) {
1118 SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1119 result =
1120 rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
1121 }
1122 newOperands.push_back(result);
1123 newTypes.push_back(unrolledType);
1124 }
1125 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1126 }
1127
1128 // Change the function signature.
1129 auto newFnType =
1130 FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1131 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1132 rewriter.modifyOpInPlace(funcOp,
1133 [&] { funcOp.setFunctionType(newFnType); });
1134
1135 // Replace the return op using the new operands. This will automatically
1136 // update the entry block as well.
1137 rewriter.replaceOp(returnOp,
1138 rewriter.create<func::ReturnOp>(loc, newOperands));
1139
1140 return success();
1141 }
1142};
1143
1144} // namespace
1145
1146//===----------------------------------------------------------------------===//
1147// Public function for builtin variables
1148//===----------------------------------------------------------------------===//
1149
1150Value mlir::spirv::getBuiltinVariableValue(Operation *op,
1151 spirv::BuiltIn builtin,
1152 Type integerType, OpBuilder &builder,
1153 StringRef prefix, StringRef suffix) {
1154 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
1155 if (!parent) {
1156 op->emitError(message: "expected operation to be within a module-like op");
1157 return nullptr;
1158 }
1159
1160 spirv::GlobalVariableOp varOp =
1161 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1162 builtin, integerType, builder, prefix, suffix);
1163 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1164 return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1165}
1166
1167//===----------------------------------------------------------------------===//
1168// Public function for pushing constant storage
1169//===----------------------------------------------------------------------===//
1170
1171Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1172 unsigned offset, Type integerType,
1173 OpBuilder &builder) {
1174 Location loc = op->getLoc();
1175 Operation *parent = SymbolTable::getNearestSymbolTable(from: op->getParentOp());
1176 if (!parent) {
1177 op->emitError(message: "expected operation to be within a module-like op");
1178 return nullptr;
1179 }
1180
1181 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1182 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1183
1184 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1185 Value offsetOp = builder.create<spirv::ConstantOp>(
1186 loc, integerType, builder.getI32IntegerAttr(offset));
1187 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1188 auto acOp = builder.create<spirv::AccessChainOp>(
1189 loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1190 return builder.create<spirv::LoadOp>(loc, acOp);
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// Public functions for index calculation
1195//===----------------------------------------------------------------------===//
1196
1197Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
1198 int64_t offset, Type integerType,
1199 Location loc, OpBuilder &builder) {
1200 assert(indices.size() == strides.size() &&
1201 "must provide indices for all dimensions");
1202
1203 // TODO: Consider moving to use affine.apply and patterns converting
1204 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1205 // broken down into progressive small steps so we can have intermediate steps
1206 // using other dialects. At the moment SPIR-V is the final sink.
1207
1208 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1209 loc, integerType, IntegerAttr::get(integerType, offset));
1210 for (const auto &index : llvm::enumerate(First&: indices)) {
1211 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1212 loc, integerType,
1213 IntegerAttr::get(integerType, strides[index.index()]));
1214 Value update =
1215 builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1216 linearizedIndex =
1217 builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1218 }
1219 return linearizedIndex;
1220}
1221
1222Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1223 MemRefType baseType, Value basePtr,
1224 ValueRange indices, Location loc,
1225 OpBuilder &builder) {
1226 // Get base and offset of the MemRefType and verify they are static.
1227
1228 int64_t offset;
1229 SmallVector<int64_t, 4> strides;
1230 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1231 llvm::is_contained(strides, ShapedType::kDynamic) ||
1232 ShapedType::isDynamic(offset)) {
1233 return nullptr;
1234 }
1235
1236 auto indexType = typeConverter.getIndexType();
1237
1238 SmallVector<Value, 2> linearizedIndices;
1239 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1240
1241 // Add a '0' at the start to index into the struct.
1242 linearizedIndices.push_back(Elt: zero);
1243
1244 if (baseType.getRank() == 0) {
1245 linearizedIndices.push_back(Elt: zero);
1246 } else {
1247 linearizedIndices.push_back(
1248 Elt: linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder));
1249 }
1250 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1251}
1252
1253Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1254 MemRefType baseType, Value basePtr,
1255 ValueRange indices, Location loc,
1256 OpBuilder &builder) {
1257 // Get base and offset of the MemRefType and verify they are static.
1258
1259 int64_t offset;
1260 SmallVector<int64_t, 4> strides;
1261 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1262 llvm::is_contained(strides, ShapedType::kDynamic) ||
1263 ShapedType::isDynamic(offset)) {
1264 return nullptr;
1265 }
1266
1267 auto indexType = typeConverter.getIndexType();
1268
1269 SmallVector<Value, 2> linearizedIndices;
1270 Value linearIndex;
1271 if (baseType.getRank() == 0) {
1272 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1273 } else {
1274 linearIndex =
1275 linearizeIndex(indices, strides, offset, integerType: indexType, loc, builder);
1276 }
1277 Type pointeeType =
1278 cast<spirv::PointerType>(Val: basePtr.getType()).getPointeeType();
1279 if (isa<spirv::ArrayType>(Val: pointeeType)) {
1280 linearizedIndices.push_back(Elt: linearIndex);
1281 return builder.create<spirv::AccessChainOp>(loc, basePtr,
1282 linearizedIndices);
1283 }
1284 return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1285 linearizedIndices);
1286}
1287
1288Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1289 MemRefType baseType, Value basePtr,
1290 ValueRange indices, Location loc,
1291 OpBuilder &builder) {
1292
1293 if (typeConverter.allows(spirv::Capability::Kernel)) {
1294 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1295 builder);
1296 }
1297
1298 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1299 builder);
1300}
1301
1302//===----------------------------------------------------------------------===//
1303// Public functions for vector unrolling
1304//===----------------------------------------------------------------------===//
1305
1306int mlir::spirv::getComputeVectorSize(int64_t size) {
1307 for (int i : {4, 3, 2}) {
1308 if (size % i == 0)
1309 return i;
1310 }
1311 return 1;
1312}
1313
1314SmallVector<int64_t>
1315mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1316 VectorType srcVectorType = op.getSourceVectorType();
1317 assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1318 int64_t vectorSize =
1319 mlir::spirv::getComputeVectorSize(size: srcVectorType.getDimSize(0));
1320 return {vectorSize};
1321}
1322
1323SmallVector<int64_t>
1324mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1325 VectorType vectorType = op.getResultVectorType();
1326 SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1327 nativeSize.back() =
1328 mlir::spirv::getComputeVectorSize(size: vectorType.getShape().back());
1329 return nativeSize;
1330}
1331
1332std::optional<SmallVector<int64_t>>
1333mlir::spirv::getNativeVectorShape(Operation *op) {
1334 if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1335 if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1336 SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1337 nativeSize.back() =
1338 mlir::spirv::getComputeVectorSize(size: vecType.getShape().back());
1339 return nativeSize;
1340 }
1341 }
1342
1343 return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1344 .Case<vector::ReductionOp, vector::TransposeOp>(
1345 [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1346 .Default([](Operation *) { return std::nullopt; });
1347}
1348
1349LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1350 MLIRContext *context = op->getContext();
1351 RewritePatternSet patterns(context);
1352 populateFuncOpVectorRewritePatterns(patterns);
1353 populateReturnOpVectorRewritePatterns(patterns);
1354 // We only want to apply signature conversion once to the existing func ops.
1355 // Without specifying strictMode, the greedy pattern rewriter will keep
1356 // looking for newly created func ops.
1357 return applyPatternsGreedily(op, std::move(patterns),
1358 GreedyRewriteConfig().setStrictness(
1359 GreedyRewriteStrictness::ExistingOps));
1360}
1361
1362LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1363 MLIRContext *context = op->getContext();
1364
1365 // Unroll vectors in function bodies to native vector size.
1366 {
1367 RewritePatternSet patterns(context);
1368 auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1369 [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1370 populateVectorUnrollPatterns(patterns, options);
1371 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1372 return failure();
1373 }
1374
1375 // Convert transpose ops into extract and insert pairs, in preparation of
1376 // further transformations to canonicalize/cancel.
1377 {
1378 RewritePatternSet patterns(context);
1379 vector::populateVectorTransposeLoweringPatterns(
1380 patterns, vector::VectorTransposeLowering::EltWise);
1381 vector::populateVectorShapeCastLoweringPatterns(patterns);
1382 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1383 return failure();
1384 }
1385
1386 // Run canonicalization to cast away leading size-1 dimensions.
1387 {
1388 RewritePatternSet patterns(context);
1389
1390 // We need to pull in casting way leading one dims.
1391 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1392 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1393 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1394
1395 // Decompose different rank insert_strided_slice and n-D
1396 // extract_slided_slice.
1397 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1398 patterns);
1399 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1400 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1401
1402 // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1403 // them up.
1404 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1405 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1406
1407 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1408 return failure();
1409 }
1410 return success();
1411}
1412
1413//===----------------------------------------------------------------------===//
1414// SPIR-V TypeConverter
1415//===----------------------------------------------------------------------===//
1416
1417SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
1418 const SPIRVConversionOptions &options)
1419 : targetEnv(targetAttr), options(options) {
1420 // Add conversions. The order matters here: later ones will be tried earlier.
1421
1422 // Allow all SPIR-V dialect specific types. This assumes all builtin types
1423 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1424 // were tried before.
1425 //
1426 // TODO: This assumes that the SPIR-V types are valid to use in the given
1427 // target environment, which should be the case if the whole pipeline is
1428 // driven by the same target environment. Still, we probably still want to
1429 // validate and convert to be safe.
1430 addConversion(callback: [](spirv::SPIRVType type) { return type; });
1431
1432 addConversion(callback: [this](IndexType /*indexType*/) { return getIndexType(); });
1433
1434 addConversion(callback: [this](IntegerType intType) -> std::optional<Type> {
1435 if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1436 return convertScalarType(this->targetEnv, this->options, scalarType);
1437 if (intType.getWidth() < 8)
1438 return convertSubByteIntegerType(this->options, intType);
1439 return Type();
1440 });
1441
1442 addConversion(callback: [this](FloatType floatType) -> std::optional<Type> {
1443 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1444 return convertScalarType(this->targetEnv, this->options, scalarType);
1445 return Type();
1446 });
1447
1448 addConversion(callback: [this](ComplexType complexType) {
1449 return convertComplexType(this->targetEnv, this->options, complexType);
1450 });
1451
1452 addConversion([this](VectorType vectorType) {
1453 return convertVectorType(this->targetEnv, this->options, vectorType);
1454 });
1455
1456 addConversion(callback: [this](TensorType tensorType) {
1457 return convertTensorType(targetEnv: this->targetEnv, options: this->options, type: tensorType);
1458 });
1459
1460 addConversion([this](MemRefType memRefType) {
1461 return convertMemrefType(this->targetEnv, this->options, memRefType);
1462 });
1463
1464 // Register some last line of defense casting logic.
1465 addSourceMaterialization(
1466 callback: [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1467 return castToSourceType(targetEnv: this->targetEnv, builder, type, inputs, loc);
1468 });
1469 addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1470 Location loc) {
1471 auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1472 return cast.getResult(0);
1473 });
1474}
1475
1476Type SPIRVTypeConverter::getIndexType() const {
1477 return ::getIndexType(ctx: getContext(), options);
1478}
1479
1480MLIRContext *SPIRVTypeConverter::getContext() const {
1481 return targetEnv.getAttr().getContext();
1482}
1483
1484bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1485 return targetEnv.allows(capability);
1486}
1487
1488//===----------------------------------------------------------------------===//
1489// SPIR-V ConversionTarget
1490//===----------------------------------------------------------------------===//
1491
1492std::unique_ptr<SPIRVConversionTarget>
1493SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
1494 std::unique_ptr<SPIRVConversionTarget> target(
1495 // std::make_unique does not work here because the constructor is private.
1496 new SPIRVConversionTarget(targetAttr));
1497 SPIRVConversionTarget *targetPtr = target.get();
1498 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1499 // We need to capture the raw pointer here because it is stable:
1500 // target will be destroyed once this function is returned.
1501 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1502 return target;
1503}
1504
1505SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1506 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1507
1508bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1509 // Make sure this op is available at the given version. Ops not implementing
1510 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1511 // SPIR-V versions.
1512 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1513 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1514 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1515 LLVM_DEBUG(llvm::dbgs()
1516 << op->getName() << " illegal: requiring min version "
1517 << spirv::stringifyVersion(*minVersion) << "\n");
1518 return false;
1519 }
1520 }
1521 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1522 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1523 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1524 LLVM_DEBUG(llvm::dbgs()
1525 << op->getName() << " illegal: requiring max version "
1526 << spirv::stringifyVersion(*maxVersion) << "\n");
1527 return false;
1528 }
1529 }
1530
1531 // Make sure this op's required extensions are allowed to use. Ops not
1532 // implementing QueryExtensionInterface do not require extensions to be
1533 // available.
1534 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1535 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1536 extensions.getExtensions())))
1537 return false;
1538
1539 // Make sure this op's required extensions are allowed to use. Ops not
1540 // implementing QueryCapabilityInterface do not require capabilities to be
1541 // available.
1542 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1543 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1544 capabilities.getCapabilities())))
1545 return false;
1546
1547 SmallVector<Type, 4> valueTypes;
1548 valueTypes.append(in_start: op->operand_type_begin(), in_end: op->operand_type_end());
1549 valueTypes.append(in_start: op->result_type_begin(), in_end: op->result_type_end());
1550
1551 // Ensure that all types have been converted to SPIRV types.
1552 if (llvm::any_of(Range&: valueTypes,
1553 P: [](Type t) { return !isa<spirv::SPIRVType>(Val: t); }))
1554 return false;
1555
1556 // Special treatment for global variables, whose type requirements are
1557 // conveyed by type attributes.
1558 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1559 valueTypes.push_back(Elt: globalVar.getType());
1560
1561 // Make sure the op's operands/results use types that are allowed by the
1562 // target environment.
1563 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1564 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1565 for (Type valueType : valueTypes) {
1566 typeExtensions.clear();
1567 cast<spirv::SPIRVType>(Val&: valueType).getExtensions(typeExtensions);
1568 if (failed(Result: checkExtensionRequirements(label: op->getName(), targetEnv: this->targetEnv,
1569 candidates: typeExtensions)))
1570 return false;
1571
1572 typeCapabilities.clear();
1573 cast<spirv::SPIRVType>(Val&: valueType).getCapabilities(typeCapabilities);
1574 if (failed(Result: checkCapabilityRequirements(label: op->getName(), targetEnv: this->targetEnv,
1575 candidates: typeCapabilities)))
1576 return false;
1577 }
1578
1579 return true;
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// Public functions for populating patterns
1584//===----------------------------------------------------------------------===//
1585
1586void mlir::populateBuiltinFuncToSPIRVPatterns(
1587 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1588 patterns.add<FuncOpConversion>(arg: typeConverter, args: patterns.getContext());
1589}
1590
1591void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1592 patterns.add<FuncOpVectorUnroll>(arg: patterns.getContext());
1593}
1594
1595void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
1596 patterns.add<ReturnOpVectorUnroll>(arg: patterns.getContext());
1597}
1598

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp