1//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
20#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Support/LogicalResult.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/FormatVariadic.h"
27
28#define DEBUG_TYPE "spirv-to-llvm-pattern"
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Constants
34//===----------------------------------------------------------------------===//
35
/// LLVM address space returned by the storage-class mapping helpers in this
/// file when no dedicated mapping applies.
constexpr unsigned defaultAddressSpace = 0U;
37
38//===----------------------------------------------------------------------===//
39// Utility functions
40//===----------------------------------------------------------------------===//
41
42/// Returns true if the given type is a signed integer or vector type.
43static bool isSignedIntegerOrVector(Type type) {
44 if (type.isSignedInteger())
45 return true;
46 if (auto vecType = dyn_cast<VectorType>(type))
47 return vecType.getElementType().isSignedInteger();
48 return false;
49}
50
51/// Returns true if the given type is an unsigned integer or vector type
52static bool isUnsignedIntegerOrVector(Type type) {
53 if (type.isUnsignedInteger())
54 return true;
55 if (auto vecType = dyn_cast<VectorType>(type))
56 return vecType.getElementType().isUnsignedInteger();
57 return false;
58}
59
60/// Returns the width of an integer or of the element type of an integer vector,
61/// if applicable.
62static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
63 if (auto intType = dyn_cast<IntegerType>(type))
64 return intType.getWidth();
65 if (auto vecType = dyn_cast<VectorType>(type))
66 if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
67 return intType.getWidth();
68 return std::nullopt;
69}
70
71/// Returns the bit width of integer, float or vector of float or integer values
72static unsigned getBitWidth(Type type) {
73 assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
74 "bitwidth is not supported for this type");
75 if (type.isIntOrFloat())
76 return type.getIntOrFloatBitWidth();
77 auto vecType = dyn_cast<VectorType>(type);
78 auto elementType = vecType.getElementType();
79 assert(elementType.isIntOrFloat() &&
80 "only integers and floats have a bitwidth");
81 return elementType.getIntOrFloatBitWidth();
82}
83
84/// Returns the bit width of LLVMType integer or vector.
85static unsigned getLLVMTypeBitWidth(Type type) {
86 return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
87 ? LLVM::getVectorElementType(type)
88 : type))
89 .getWidth();
90}
91
92/// Creates `IntegerAttribute` with all bits set for given type
93static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
94 if (auto vecType = dyn_cast<VectorType>(type)) {
95 auto integerType = cast<IntegerType>(vecType.getElementType());
96 return builder.getIntegerAttr(integerType, -1);
97 }
98 auto integerType = cast<IntegerType>(type);
99 return builder.getIntegerAttr(integerType, -1);
100}
101
102/// Creates `llvm.mlir.constant` with all bits set for the given type.
103static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
104 PatternRewriter &rewriter) {
105 if (isa<VectorType>(Val: srcType)) {
106 return rewriter.create<LLVM::ConstantOp>(
107 loc, dstType,
108 SplatElementsAttr::get(cast<ShapedType>(srcType),
109 minusOneIntegerAttribute(srcType, rewriter)));
110 }
111 return rewriter.create<LLVM::ConstantOp>(
112 loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
113}
114
115/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
116static Value createFPConstant(Location loc, Type srcType, Type dstType,
117 PatternRewriter &rewriter, double value) {
118 if (auto vecType = dyn_cast<VectorType>(srcType)) {
119 auto floatType = cast<FloatType>(vecType.getElementType());
120 return rewriter.create<LLVM::ConstantOp>(
121 loc, dstType,
122 SplatElementsAttr::get(vecType,
123 rewriter.getFloatAttr(floatType, value)));
124 }
125 auto floatType = cast<FloatType>(Val&: srcType);
126 return rewriter.create<LLVM::ConstantOp>(
127 loc, dstType, rewriter.getFloatAttr(floatType, value));
128}
129
130/// Utility function for bitfield ops:
131/// - `BitFieldInsert`
132/// - `BitFieldSExtract`
133/// - `BitFieldUExtract`
134/// Truncates or extends the value. If the bitwidth of the value is the same as
135/// `llvmType` bitwidth, the value remains unchanged.
136static Value optionallyTruncateOrExtend(Location loc, Value value,
137 Type llvmType,
138 PatternRewriter &rewriter) {
139 auto srcType = value.getType();
140 unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType);
141 unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType)
142 ? getLLVMTypeBitWidth(type: srcType)
143 : getBitWidth(type: srcType);
144
145 if (valueBitWidth < targetBitWidth)
146 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
147 // If the bit widths of `Count` and `Offset` are greater than the bit width
148 // of the target type, they are truncated. Truncation is safe since `Count`
149 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
150 // both values can be expressed in 8 bits.
151 if (valueBitWidth > targetBitWidth)
152 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
153 return value;
154}
155
156/// Broadcasts the value to vector with `numElements` number of elements.
157static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
158 LLVMTypeConverter &typeConverter,
159 ConversionPatternRewriter &rewriter) {
160 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
161 auto llvmVectorType = typeConverter.convertType(vectorType);
162 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
163 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
164 for (unsigned i = 0; i < numElements; ++i) {
165 auto index = rewriter.create<LLVM::ConstantOp>(
166 loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
167 broadcasted = rewriter.create<LLVM::InsertElementOp>(
168 loc, llvmVectorType, broadcasted, toBroadcast, index);
169 }
170 return broadcasted;
171}
172
173/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
174static Value optionallyBroadcast(Location loc, Value value, Type srcType,
175 LLVMTypeConverter &typeConverter,
176 ConversionPatternRewriter &rewriter) {
177 if (auto vectorType = dyn_cast<VectorType>(srcType)) {
178 unsigned numElements = vectorType.getNumElements();
179 return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter);
180 }
181 return value;
182}
183
184/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
185/// `BitFieldUExtract`.
186/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
187/// a vector type, construct a vector that has:
188/// - same number of elements as `Base`
189/// - each element has the type that is the same as the type of `Offset` or
190/// `Count`
191/// - each element has the same value as `Offset` or `Count`
192/// Then cast `Offset` and `Count` if their bit width is different
193/// from `Base` bit width.
194static Value processCountOrOffset(Location loc, Value value, Type srcType,
195 Type dstType, LLVMTypeConverter &converter,
196 ConversionPatternRewriter &rewriter) {
197 Value broadcasted =
198 optionallyBroadcast(loc, value, srcType, typeConverter&: converter, rewriter);
199 return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter);
200}
201
202/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
203/// offset to LLVM struct. Otherwise, the conversion is not supported.
204static Type convertStructTypeWithOffset(spirv::StructType type,
205 LLVMTypeConverter &converter) {
206 if (type != VulkanLayoutUtils::decorateType(structType: type))
207 return nullptr;
208
209 SmallVector<Type> elementsVector;
210 if (failed(result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
211 return nullptr;
212 return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector,
213 /*isPacked=*/false);
214}
215
216/// Converts SPIR-V struct with no offset to packed LLVM struct.
217static Type convertStructTypePacked(spirv::StructType type,
218 LLVMTypeConverter &converter) {
219 SmallVector<Type> elementsVector;
220 if (failed(result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
221 return nullptr;
222 return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector,
223 /*isPacked=*/true);
224}
225
226/// Creates LLVM dialect constant with the given value.
227static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
228 unsigned value) {
229 return rewriter.create<LLVM::ConstantOp>(
230 loc, IntegerType::get(rewriter.getContext(), 32),
231 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
232}
233
234/// Utility for `spirv.Load` and `spirv.Store` conversion.
235static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
236 ConversionPatternRewriter &rewriter,
237 LLVMTypeConverter &typeConverter,
238 unsigned alignment, bool isVolatile,
239 bool isNonTemporal) {
240 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
241 auto dstType = typeConverter.convertType(loadOp.getType());
242 if (!dstType)
243 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
244 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
245 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
246 isVolatile, isNonTemporal);
247 return success();
248 }
249 auto storeOp = cast<spirv::StoreOp>(op);
250 spirv::StoreOpAdaptor adaptor(operands);
251 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
252 adaptor.getPtr(), alignment,
253 isVolatile, isNonTemporal);
254 return success();
255}
256
257//===----------------------------------------------------------------------===//
258// Type conversion
259//===----------------------------------------------------------------------===//
260
261/// Converts SPIR-V array type to LLVM array. Natural stride (according to
262/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
263/// when converting ops that manipulate array types.
264static std::optional<Type> convertArrayType(spirv::ArrayType type,
265 TypeConverter &converter) {
266 unsigned stride = type.getArrayStride();
267 Type elementType = type.getElementType();
268 auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes();
269 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
270 return std::nullopt;
271
272 auto llvmElementType = converter.convertType(t: elementType);
273 unsigned numElements = type.getNumElements();
274 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
275}
276
277static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
278 // Based on
279 // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
280 // and clang/lib/Basic/Targets/SPIR.h.
281 switch (storageClass) {
282#define STORAGE_SPACE_MAP(storage, space) \
283 case spirv::StorageClass::storage: \
284 return space;
285 STORAGE_SPACE_MAP(Function, 0)
286 STORAGE_SPACE_MAP(CrossWorkgroup, 1)
287 STORAGE_SPACE_MAP(Input, 1)
288 STORAGE_SPACE_MAP(UniformConstant, 2)
289 STORAGE_SPACE_MAP(Workgroup, 3)
290 STORAGE_SPACE_MAP(Generic, 4)
291 STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
292 STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
293#undef STORAGE_SPACE_MAP
294 default:
295 return defaultAddressSpace;
296 }
297}
298
299static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
300 spirv::StorageClass storageClass) {
301 switch (clientAPI) {
302#define CLIENT_MAP(client, storage) \
303 case spirv::ClientAPI::client: \
304 return mapTo##client##AddressSpace(storage);
305 CLIENT_MAP(OpenCL, storageClass)
306#undef CLIENT_MAP
307 default:
308 return defaultAddressSpace;
309 }
310}
311
312/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
313/// modelled at the moment.
314static Type convertPointerType(spirv::PointerType type,
315 LLVMTypeConverter &converter,
316 spirv::ClientAPI clientAPI) {
317 unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
318 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
319}
320
321/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
322/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
323/// no modelling of array stride at the moment.
324static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
325 TypeConverter &converter) {
326 if (type.getArrayStride() != 0)
327 return std::nullopt;
328 auto elementType = converter.convertType(t: type.getElementType());
329 return LLVM::LLVMArrayType::get(elementType, 0);
330}
331
332/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
333/// member decorations. Also, only natural offset is supported.
334static Type convertStructType(spirv::StructType type,
335 LLVMTypeConverter &converter) {
336 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
337 type.getMemberDecorations(memberDecorations);
338 if (!memberDecorations.empty())
339 return nullptr;
340 if (type.hasOffset())
341 return convertStructTypeWithOffset(type, converter);
342 return convertStructTypePacked(type, converter);
343}
344
345//===----------------------------------------------------------------------===//
346// Operation conversion
347//===----------------------------------------------------------------------===//
348
349namespace {
350
351class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
352public:
353 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
354
355 LogicalResult
356 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
357 ConversionPatternRewriter &rewriter) const override {
358 auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
359 if (!dstType)
360 return rewriter.notifyMatchFailure(op, "type conversion failed");
361 // To use GEP we need to add a first 0 index to go through the pointer.
362 auto indices = llvm::to_vector<4>(adaptor.getIndices());
363 Type indexType = op.getIndices().front().getType();
364 auto llvmIndexType = typeConverter.convertType(indexType);
365 if (!llvmIndexType)
366 return rewriter.notifyMatchFailure(op, "type conversion failed");
367 Value zero = rewriter.create<LLVM::ConstantOp>(
368 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
369 indices.insert(indices.begin(), zero);
370
371 auto elementType = typeConverter.convertType(
372 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
373 if (!elementType)
374 return rewriter.notifyMatchFailure(op, "type conversion failed");
375 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
376 adaptor.getBasePtr(), indices);
377 return success();
378 }
379};
380
381class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
382public:
383 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
384
385 LogicalResult
386 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override {
388 auto dstType = typeConverter.convertType(op.getPointer().getType());
389 if (!dstType)
390 return rewriter.notifyMatchFailure(op, "type conversion failed");
391 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
392 op.getVariable());
393 return success();
394 }
395};
396
397class BitFieldInsertPattern
398 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
399public:
400 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
401
402 LogicalResult
403 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter) const override {
405 auto srcType = op.getType();
406 auto dstType = typeConverter.convertType(srcType);
407 if (!dstType)
408 return rewriter.notifyMatchFailure(op, "type conversion failed");
409 Location loc = op.getLoc();
410
411 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
412 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
413 typeConverter, rewriter);
414 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
415 typeConverter, rewriter);
416
417 // Create a mask with bits set outside [Offset, Offset + Count - 1].
418 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
419 Value maskShiftedByCount =
420 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
421 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
422 maskShiftedByCount, minusOne);
423 Value maskShiftedByCountAndOffset =
424 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
425 Value mask = rewriter.create<LLVM::XOrOp>(
426 loc, dstType, maskShiftedByCountAndOffset, minusOne);
427
428 // Extract unchanged bits from the `Base` that are outside of
429 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
430 Value baseAndMask =
431 rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
432 Value insertShiftedByOffset =
433 rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
434 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
435 insertShiftedByOffset);
436 return success();
437 }
438};
439
440/// Converts SPIR-V ConstantOp with scalar or vector type.
441class ConstantScalarAndVectorPattern
442 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
443public:
444 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
445
446 LogicalResult
447 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
448 ConversionPatternRewriter &rewriter) const override {
449 auto srcType = constOp.getType();
450 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
451 return failure();
452
453 auto dstType = typeConverter.convertType(srcType);
454 if (!dstType)
455 return rewriter.notifyMatchFailure(constOp, "type conversion failed");
456
457 // SPIR-V constant can be a signed/unsigned integer, which has to be
458 // casted to signless integer when converting to LLVM dialect. Removing the
459 // sign bit may have unexpected behaviour. However, it is better to handle
460 // it case-by-case, given that the purpose of the conversion is not to
461 // cover all possible corner cases.
462 if (isSignedIntegerOrVector(srcType) ||
463 isUnsignedIntegerOrVector(srcType)) {
464 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
465
466 if (isa<VectorType>(srcType)) {
467 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
468 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
469 constOp, dstType,
470 dstElementsAttr.mapValues(
471 signlessType, [&](const APInt &value) { return value; }));
472 return success();
473 }
474 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
475 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
476 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
477 return success();
478 }
479 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
480 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
481 return success();
482 }
483};
484
485class BitFieldSExtractPattern
486 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
487public:
488 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
489
490 LogicalResult
491 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
492 ConversionPatternRewriter &rewriter) const override {
493 auto srcType = op.getType();
494 auto dstType = typeConverter.convertType(srcType);
495 if (!dstType)
496 return rewriter.notifyMatchFailure(op, "type conversion failed");
497 Location loc = op.getLoc();
498
499 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
500 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
501 typeConverter, rewriter);
502 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
503 typeConverter, rewriter);
504
505 // Create a constant that holds the size of the `Base`.
506 IntegerType integerType;
507 if (auto vecType = dyn_cast<VectorType>(srcType))
508 integerType = cast<IntegerType>(vecType.getElementType());
509 else
510 integerType = cast<IntegerType>(srcType);
511
512 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
513 Value size =
514 isa<VectorType>(srcType)
515 ? rewriter.create<LLVM::ConstantOp>(
516 loc, dstType,
517 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
518 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
519
520 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
521 // at Offset + Count - 1 is the most significant bit now.
522 Value countPlusOffset =
523 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
524 Value amountToShiftLeft =
525 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
526 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
527 loc, dstType, op.getBase(), amountToShiftLeft);
528
529 // Shift the result right, filling the bits with the sign bit.
530 Value amountToShiftRight =
531 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
532 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
533 amountToShiftRight);
534 return success();
535 }
536};
537
538class BitFieldUExtractPattern
539 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
540public:
541 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
542
543 LogicalResult
544 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
545 ConversionPatternRewriter &rewriter) const override {
546 auto srcType = op.getType();
547 auto dstType = typeConverter.convertType(srcType);
548 if (!dstType)
549 return rewriter.notifyMatchFailure(op, "type conversion failed");
550 Location loc = op.getLoc();
551
552 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
553 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
554 typeConverter, rewriter);
555 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
556 typeConverter, rewriter);
557
558 // Create a mask with bits set at [0, Count - 1].
559 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
560 Value maskShiftedByCount =
561 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
562 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
563 minusOne);
564
565 // Shift `Base` by `Offset` and apply the mask on it.
566 Value shiftedBase =
567 rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
568 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
569 return success();
570 }
571};
572
573class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
574public:
575 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
576
577 LogicalResult
578 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
579 ConversionPatternRewriter &rewriter) const override {
580 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
581 branchOp.getTarget());
582 return success();
583 }
584};
585
586class BranchConditionalConversionPattern
587 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
588public:
589 using SPIRVToLLVMConversion<
590 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
591
592 LogicalResult
593 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
594 ConversionPatternRewriter &rewriter) const override {
595 // If branch weights exist, map them to 32-bit integer vector.
596 DenseI32ArrayAttr branchWeights = nullptr;
597 if (auto weights = op.getBranchWeights()) {
598 SmallVector<int32_t> weightValues;
599 for (auto weight : weights->getAsRange<IntegerAttr>())
600 weightValues.push_back(weight.getInt());
601 branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
602 }
603
604 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
605 op, op.getCondition(), op.getTrueBlockArguments(),
606 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
607 op.getFalseBlock());
608 return success();
609 }
610};
611
612/// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
613/// type is an aggregate type (struct or array). Otherwise, converts to
614/// `llvm.extractelement` that operates on vectors.
615class CompositeExtractPattern
616 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
617public:
618 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
619
620 LogicalResult
621 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
622 ConversionPatternRewriter &rewriter) const override {
623 auto dstType = this->typeConverter.convertType(op.getType());
624 if (!dstType)
625 return rewriter.notifyMatchFailure(op, "type conversion failed");
626
627 Type containerType = op.getComposite().getType();
628 if (isa<VectorType>(Val: containerType)) {
629 Location loc = op.getLoc();
630 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
631 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
632 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
633 op, dstType, adaptor.getComposite(), index);
634 return success();
635 }
636
637 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
638 op, adaptor.getComposite(),
639 LLVM::convertArrayToIndices(op.getIndices()));
640 return success();
641 }
642};
643
644/// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
645/// type is an aggregate type (struct or array). Otherwise, converts to
646/// `llvm.insertelement` that operates on vectors.
647class CompositeInsertPattern
648 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
649public:
650 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
651
652 LogicalResult
653 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655 auto dstType = this->typeConverter.convertType(op.getType());
656 if (!dstType)
657 return rewriter.notifyMatchFailure(op, "type conversion failed");
658
659 Type containerType = op.getComposite().getType();
660 if (isa<VectorType>(Val: containerType)) {
661 Location loc = op.getLoc();
662 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
663 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
664 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
665 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
666 return success();
667 }
668
669 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
670 op, adaptor.getComposite(), adaptor.getObject(),
671 LLVM::convertArrayToIndices(op.getIndices()));
672 return success();
673 }
674};
675
676/// Converts SPIR-V operations that have straightforward LLVM equivalent
677/// into LLVM dialect operations.
678template <typename SPIRVOp, typename LLVMOp>
679class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
680public:
681 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
682
683 LogicalResult
684 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
685 ConversionPatternRewriter &rewriter) const override {
686 auto dstType = this->typeConverter.convertType(op.getType());
687 if (!dstType)
688 return rewriter.notifyMatchFailure(op, "type conversion failed");
689 rewriter.template replaceOpWithNewOp<LLVMOp>(
690 op, dstType, adaptor.getOperands(), op->getAttrs());
691 return success();
692 }
693};
694
695/// Converts `spirv.ExecutionMode` into a global struct constant that holds
696/// execution mode information.
697class ExecutionModePattern
698 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
699public:
700 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
701
702 LogicalResult
703 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
704 ConversionPatternRewriter &rewriter) const override {
705 // First, create the global struct's name that would be associated with
706 // this entry point's execution mode. We set it to be:
707 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
708 ModuleOp module = op->getParentOfType<ModuleOp>();
709 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
710 std::string moduleName;
711 if (module.getName().has_value())
712 moduleName = "_" + module.getName()->str();
713 else
714 moduleName = "";
715 std::string executionModeInfoName = llvm::formatv(
716 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
717 static_cast<uint32_t>(executionModeAttr.getValue()));
718
719 MLIRContext *context = rewriter.getContext();
720 OpBuilder::InsertionGuard guard(rewriter);
721 rewriter.setInsertionPointToStart(module.getBody());
722
723 // Create a struct type, corresponding to the C struct below.
724 // struct {
725 // int32_t executionMode;
726 // int32_t values[]; // optional values
727 // };
728 auto llvmI32Type = IntegerType::get(context, 32);
729 SmallVector<Type, 2> fields;
730 fields.push_back(Elt: llvmI32Type);
731 ArrayAttr values = op.getValues();
732 if (!values.empty()) {
733 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
734 fields.push_back(Elt: arrayType);
735 }
736 auto structType = LLVM::LLVMStructType::getLiteral(context, types: fields);
737
738 // Create `llvm.mlir.global` with initializer region containing one block.
739 auto global = rewriter.create<LLVM::GlobalOp>(
740 UnknownLoc::get(context), structType, /*isConstant=*/true,
741 LLVM::Linkage::External, executionModeInfoName, Attribute(),
742 /*alignment=*/0);
743 Location loc = global.getLoc();
744 Region &region = global.getInitializerRegion();
745 Block *block = rewriter.createBlock(parent: &region);
746
747 // Initialize the struct and set the execution mode value.
748 rewriter.setInsertionPoint(block, insertPoint: block->begin());
749 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
750 Value executionMode = rewriter.create<LLVM::ConstantOp>(
751 loc, llvmI32Type,
752 rewriter.getI32IntegerAttr(
753 static_cast<uint32_t>(executionModeAttr.getValue())));
754 structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
755 executionMode, 0);
756
757 // Insert extra operands if they exist into execution mode info struct.
758 for (unsigned i = 0, e = values.size(); i < e; ++i) {
759 auto attr = values.getValue()[i];
760 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
761 structValue = rewriter.create<LLVM::InsertValueOp>(
762 loc, structValue, entry, ArrayRef<int64_t>({1, i}));
763 }
764 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
765 rewriter.eraseOp(op: op);
766 return success();
767 }
768};
769
770/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
771/// global returns a pointer, whereas in LLVM dialect the global holds an actual
772/// value. This difference is handled by `spirv.mlir.addressof` and
773/// `llvm.mlir.addressof`ops that both return a pointer.
774class GlobalVariablePattern
775 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
776public:
777 template <typename... Args>
778 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
779 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
780 std::forward<Args>(args)...),
781 clientAPI(clientAPI) {}
782
783 LogicalResult
784 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter) const override {
786 // Currently, there is no support of initialization with a constant value in
787 // SPIR-V dialect. Specialization constants are not considered as well.
788 if (op.getInitializer())
789 return failure();
790
791 auto srcType = cast<spirv::PointerType>(op.getType());
792 auto dstType = typeConverter.convertType(srcType.getPointeeType());
793 if (!dstType)
794 return rewriter.notifyMatchFailure(op, "type conversion failed");
795
796 // Limit conversion to the current invocation only or `StorageBuffer`
797 // required by SPIR-V runner.
798 // This is okay because multiple invocations are not supported yet.
799 auto storageClass = srcType.getStorageClass();
800 switch (storageClass) {
801 case spirv::StorageClass::Input:
802 case spirv::StorageClass::Private:
803 case spirv::StorageClass::Output:
804 case spirv::StorageClass::StorageBuffer:
805 case spirv::StorageClass::UniformConstant:
806 break;
807 default:
808 return failure();
809 }
810
811 // LLVM dialect spec: "If the global value is a constant, storing into it is
812 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
813 // storage class that is read-only.
814 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
815 (storageClass == spirv::StorageClass::UniformConstant);
816 // SPIR-V spec: "By default, functions and global variables are private to a
817 // module and cannot be accessed by other modules. However, a module may be
818 // written to export or import functions and global (module scope)
819 // variables.". Therefore, map 'Private' storage class to private linkage,
820 // 'Input' and 'Output' to external linkage.
821 auto linkage = storageClass == spirv::StorageClass::Private
822 ? LLVM::Linkage::Private
823 : LLVM::Linkage::External;
824 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
825 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
826 /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
827
828 // Attach location attribute if applicable
829 if (op.getLocationAttr())
830 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
831
832 return success();
833 }
834
835private:
836 spirv::ClientAPI clientAPI;
837};
838
839/// Converts SPIR-V cast ops that do not have straightforward LLVM
840/// equivalent in LLVM dialect.
841template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
842class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
843public:
844 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
845
846 LogicalResult
847 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
848 ConversionPatternRewriter &rewriter) const override {
849
850 Type fromType = op.getOperand().getType();
851 Type toType = op.getType();
852
853 auto dstType = this->typeConverter.convertType(toType);
854 if (!dstType)
855 return rewriter.notifyMatchFailure(op, "type conversion failed");
856
857 if (getBitWidth(type: fromType) < getBitWidth(type: toType)) {
858 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
859 adaptor.getOperands());
860 return success();
861 }
862 if (getBitWidth(type: fromType) > getBitWidth(type: toType)) {
863 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
864 adaptor.getOperands());
865 return success();
866 }
867 return failure();
868 }
869};
870
871class FunctionCallPattern
872 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
873public:
874 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
875
876 LogicalResult
877 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
878 ConversionPatternRewriter &rewriter) const override {
879 if (callOp.getNumResults() == 0) {
880 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
881 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
882 return success();
883 }
884
885 // Function returns a single result.
886 auto dstType = typeConverter.convertType(callOp.getType(0));
887 if (!dstType)
888 return rewriter.notifyMatchFailure(callOp, "type conversion failed");
889 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
890 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
891 return success();
892 }
893};
894
895/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
896template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
897class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
898public:
899 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
900
901 LogicalResult
902 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
903 ConversionPatternRewriter &rewriter) const override {
904
905 auto dstType = this->typeConverter.convertType(op.getType());
906 if (!dstType)
907 return rewriter.notifyMatchFailure(op, "type conversion failed");
908
909 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
910 op, dstType, predicate, op.getOperand1(), op.getOperand2());
911 return success();
912 }
913};
914
915/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
916template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
917class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
918public:
919 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
920
921 LogicalResult
922 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
923 ConversionPatternRewriter &rewriter) const override {
924
925 auto dstType = this->typeConverter.convertType(op.getType());
926 if (!dstType)
927 return rewriter.notifyMatchFailure(op, "type conversion failed");
928
929 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
930 op, dstType, predicate, op.getOperand1(), op.getOperand2());
931 return success();
932 }
933};
934
935class InverseSqrtPattern
936 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
937public:
938 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
939
940 LogicalResult
941 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943 auto srcType = op.getType();
944 auto dstType = typeConverter.convertType(srcType);
945 if (!dstType)
946 return rewriter.notifyMatchFailure(op, "type conversion failed");
947
948 Location loc = op.getLoc();
949 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
950 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
951 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
952 return success();
953 }
954};
955
956/// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
957template <typename SPIRVOp>
958class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
959public:
960 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
961
962 LogicalResult
963 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
964 ConversionPatternRewriter &rewriter) const override {
965 if (!op.getMemoryAccess()) {
966 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
967 this->typeConverter, /*alignment=*/0,
968 /*isVolatile=*/false,
969 /*isNonTemporal=*/false);
970 }
971 auto memoryAccess = *op.getMemoryAccess();
972 switch (memoryAccess) {
973 case spirv::MemoryAccess::Aligned:
974 case spirv::MemoryAccess::None:
975 case spirv::MemoryAccess::Nontemporal:
976 case spirv::MemoryAccess::Volatile: {
977 unsigned alignment =
978 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
979 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
980 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
981 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
982 this->typeConverter, alignment, isVolatile,
983 isNonTemporal);
984 }
985 default:
986 // There is no support of other memory access attributes.
987 return failure();
988 }
989 }
990};
991
992/// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
993template <typename SPIRVOp>
994class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
995public:
996 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
997
998 LogicalResult
999 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
1000 ConversionPatternRewriter &rewriter) const override {
1001 auto srcType = notOp.getType();
1002 auto dstType = this->typeConverter.convertType(srcType);
1003 if (!dstType)
1004 return rewriter.notifyMatchFailure(notOp, "type conversion failed");
1005
1006 Location loc = notOp.getLoc();
1007 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
1008 auto mask =
1009 isa<VectorType>(srcType)
1010 ? rewriter.create<LLVM::ConstantOp>(
1011 loc, dstType,
1012 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
1013 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
1014 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1015 notOp.getOperand(), mask);
1016 return success();
1017 }
1018};
1019
1020/// A template pattern that erases the given `SPIRVOp`.
1021template <typename SPIRVOp>
1022class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
1023public:
1024 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1025
1026 LogicalResult
1027 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1028 ConversionPatternRewriter &rewriter) const override {
1029 rewriter.eraseOp(op);
1030 return success();
1031 }
1032};
1033
1034class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1035public:
1036 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1037
1038 LogicalResult
1039 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1040 ConversionPatternRewriter &rewriter) const override {
1041 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1042 ArrayRef<Value>());
1043 return success();
1044 }
1045};
1046
1047class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1048public:
1049 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1050
1051 LogicalResult
1052 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1053 ConversionPatternRewriter &rewriter) const override {
1054 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1055 adaptor.getOperands());
1056 return success();
1057 }
1058};
1059
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
1063///
1064/// +------------------------------------+
1065/// | <code before spirv.mlir.loop> |
1066/// | llvm.br ^header |
1067/// +------------------------------------+
1068/// |
1069/// +----------------+ |
1070/// | | |
1071/// | V V
1072/// | +------------------------------------+
1073/// | | ^header: |
1074/// | | <header code> |
1075/// | | llvm.cond_br %cond, ^body, ^exit |
1076/// | +------------------------------------+
1077/// | |
1078/// | |----------------------+
1079/// | | |
1080/// | V |
1081/// | +------------------------------------+ |
1082/// | | ^body: | |
1083/// | | <body code> | |
1084/// | | llvm.br ^continue | |
1085/// | +------------------------------------+ |
1086/// | | |
1087/// | V |
1088/// | +------------------------------------+ |
1089/// | | ^continue: | |
1090/// | | <continue code> | |
1091/// | | llvm.br ^header | |
1092/// | +------------------------------------+ |
1093/// | | |
1094/// +---------------+ +----------------------+
1095/// |
1096/// V
1097/// +------------------------------------+
1098/// | ^exit: |
1099/// | llvm.br ^remaining |
1100/// +------------------------------------+
1101/// |
1102/// V
1103/// +------------------------------------+
1104/// | ^remaining: |
1105/// | <code after spirv.mlir.loop> |
1106/// +------------------------------------+
1107///
1108class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1109public:
1110 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1111
1112 LogicalResult
1113 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1114 ConversionPatternRewriter &rewriter) const override {
1115 // There is no support of loop control at the moment.
1116 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1117 return failure();
1118
1119 Location loc = loopOp.getLoc();
1120
1121 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1122 // be used in `endBlock`.
1123 Block *currentBlock = rewriter.getBlock();
1124 auto position = Block::iterator(loopOp);
1125 Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position);
1126
1127 // Remove entry block and create a branch in the current block going to the
1128 // header block.
1129 Block *entryBlock = loopOp.getEntryBlock();
1130 assert(entryBlock->getOperations().size() == 1);
1131 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1132 if (!brOp)
1133 return failure();
1134 Block *headerBlock = loopOp.getHeaderBlock();
1135 rewriter.setInsertionPointToEnd(currentBlock);
1136 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1137 rewriter.eraseBlock(block: entryBlock);
1138
1139 // Branch from merge block to end block.
1140 Block *mergeBlock = loopOp.getMergeBlock();
1141 Operation *terminator = mergeBlock->getTerminator();
1142 ValueRange terminatorOperands = terminator->getOperands();
1143 rewriter.setInsertionPointToEnd(mergeBlock);
1144 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1145
1146 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1147 rewriter.replaceOp(loopOp, endBlock->getArguments());
1148 return success();
1149 }
1150};
1151
1152/// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1153/// block. All blocks within selection should be reachable for conversion to
1154/// succeed.
1155class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1156public:
1157 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1158
1159 LogicalResult
1160 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1161 ConversionPatternRewriter &rewriter) const override {
1162 // There is no support for `Flatten` or `DontFlatten` selection control at
1163 // the moment. This are just compiler hints and can be performed during the
1164 // optimization passes.
1165 if (op.getSelectionControl() != spirv::SelectionControl::None)
1166 return failure();
1167
1168 // `spirv.mlir.selection` should have at least two blocks: one selection
1169 // header block and one merge block. If no blocks are present, or control
1170 // flow branches straight to merge block (two blocks are present), the op is
1171 // redundant and it is erased.
1172 if (op.getBody().getBlocks().size() <= 2) {
1173 rewriter.eraseOp(op: op);
1174 return success();
1175 }
1176
1177 Location loc = op.getLoc();
1178
1179 // Split the current block after `spirv.mlir.selection`. The remaining ops
1180 // will be used in `continueBlock`.
1181 auto *currentBlock = rewriter.getInsertionBlock();
1182 rewriter.setInsertionPointAfter(op);
1183 auto position = rewriter.getInsertionPoint();
1184 auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position);
1185
1186 // Extract conditional branch information from the header block. By SPIR-V
1187 // dialect spec, it should contain `spirv.BranchConditional` or
1188 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1189 // moment in the SPIR-V dialect. Remove this block when finished.
1190 auto *headerBlock = op.getHeaderBlock();
1191 assert(headerBlock->getOperations().size() == 1);
1192 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1193 headerBlock->getOperations().front());
1194 if (!condBrOp)
1195 return failure();
1196 rewriter.eraseBlock(block: headerBlock);
1197
1198 // Branch from merge block to continue block.
1199 auto *mergeBlock = op.getMergeBlock();
1200 Operation *terminator = mergeBlock->getTerminator();
1201 ValueRange terminatorOperands = terminator->getOperands();
1202 rewriter.setInsertionPointToEnd(mergeBlock);
1203 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1204
1205 // Link current block to `true` and `false` blocks within the selection.
1206 Block *trueBlock = condBrOp.getTrueBlock();
1207 Block *falseBlock = condBrOp.getFalseBlock();
1208 rewriter.setInsertionPointToEnd(currentBlock);
1209 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1210 condBrOp.getTrueTargetOperands(),
1211 falseBlock,
1212 condBrOp.getFalseTargetOperands());
1213
1214 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1215 rewriter.replaceOp(op, continueBlock->getArguments());
1216 return success();
1217 }
1218};
1219
1220/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1221/// puts a restriction on `Shift` and `Base` to have the same bit width,
1222/// `Shift` is zero or sign extended to match this specification. Cases when
1223/// `Shift` bit width > `Base` bit width are considered to be illegal.
1224template <typename SPIRVOp, typename LLVMOp>
1225class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1226public:
1227 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1228
1229 LogicalResult
1230 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1231 ConversionPatternRewriter &rewriter) const override {
1232
1233 auto dstType = this->typeConverter.convertType(op.getType());
1234 if (!dstType)
1235 return rewriter.notifyMatchFailure(op, "type conversion failed");
1236
1237 Type op1Type = op.getOperand1().getType();
1238 Type op2Type = op.getOperand2().getType();
1239
1240 if (op1Type == op2Type) {
1241 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1242 adaptor.getOperands());
1243 return success();
1244 }
1245
1246 std::optional<uint64_t> dstTypeWidth =
1247 getIntegerOrVectorElementWidth(dstType);
1248 std::optional<uint64_t> op2TypeWidth =
1249 getIntegerOrVectorElementWidth(type: op2Type);
1250
1251 if (!dstTypeWidth || !op2TypeWidth)
1252 return failure();
1253
1254 Location loc = op.getLoc();
1255 Value extended;
1256 if (op2TypeWidth < dstTypeWidth) {
1257 if (isUnsignedIntegerOrVector(type: op2Type)) {
1258 extended = rewriter.template create<LLVM::ZExtOp>(
1259 loc, dstType, adaptor.getOperand2());
1260 } else {
1261 extended = rewriter.template create<LLVM::SExtOp>(
1262 loc, dstType, adaptor.getOperand2());
1263 }
1264 } else if (op2TypeWidth == dstTypeWidth) {
1265 extended = adaptor.getOperand2();
1266 } else {
1267 return failure();
1268 }
1269
1270 Value result = rewriter.template create<LLVMOp>(
1271 loc, dstType, adaptor.getOperand1(), extended);
1272 rewriter.replaceOp(op, result);
1273 return success();
1274 }
1275};
1276
1277class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1278public:
1279 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1280
1281 LogicalResult
1282 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1283 ConversionPatternRewriter &rewriter) const override {
1284 auto dstType = typeConverter.convertType(tanOp.getType());
1285 if (!dstType)
1286 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1287
1288 Location loc = tanOp.getLoc();
1289 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1290 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1291 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1292 return success();
1293 }
1294};
1295
1296/// Convert `spirv.Tanh` to
1297///
1298/// exp(2x) - 1
1299/// -----------
1300/// exp(2x) + 1
1301///
1302class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1303public:
1304 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1305
1306 LogicalResult
1307 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1308 ConversionPatternRewriter &rewriter) const override {
1309 auto srcType = tanhOp.getType();
1310 auto dstType = typeConverter.convertType(srcType);
1311 if (!dstType)
1312 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1313
1314 Location loc = tanhOp.getLoc();
1315 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1316 Value multiplied =
1317 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1318 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1319 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1320 Value numerator =
1321 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1322 Value denominator =
1323 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1324 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1325 denominator);
1326 return success();
1327 }
1328};
1329
1330class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1331public:
1332 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1333
1334 LogicalResult
1335 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1336 ConversionPatternRewriter &rewriter) const override {
1337 auto srcType = varOp.getType();
1338 // Initialization is supported for scalars and vectors only.
1339 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1340 auto init = varOp.getInitializer();
1341 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1342 return failure();
1343
1344 auto dstType = typeConverter.convertType(srcType);
1345 if (!dstType)
1346 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1347
1348 Location loc = varOp.getLoc();
1349 Value size = createI32ConstantOf(loc, rewriter, value: 1);
1350 if (!init) {
1351 auto elementType = typeConverter.convertType(pointerTo);
1352 if (!elementType)
1353 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1354 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1355 size);
1356 return success();
1357 }
1358 auto elementType = typeConverter.convertType(pointerTo);
1359 if (!elementType)
1360 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1361 Value allocated =
1362 rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1363 rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1364 rewriter.replaceOp(varOp, allocated);
1365 return success();
1366 }
1367};
1368
1369//===----------------------------------------------------------------------===//
1370// BitcastOp conversion
1371//===----------------------------------------------------------------------===//
1372
1373class BitcastConversionPattern
1374 : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1375public:
1376 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1377
1378 LogicalResult
1379 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1380 ConversionPatternRewriter &rewriter) const override {
1381 auto dstType = typeConverter.convertType(bitcastOp.getType());
1382 if (!dstType)
1383 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1384
1385 // LLVM's opaque pointers do not require bitcasts.
1386 if (isa<LLVM::LLVMPointerType>(dstType)) {
1387 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1388 return success();
1389 }
1390
1391 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1392 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1393 return success();
1394 }
1395};
1396
1397//===----------------------------------------------------------------------===//
1398// FuncOp conversion
1399//===----------------------------------------------------------------------===//
1400
1401class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1402public:
1403 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1404
1405 LogicalResult
1406 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1407 ConversionPatternRewriter &rewriter) const override {
1408
1409 // Convert function signature. At the moment LLVMType converter is enough
1410 // for currently supported types.
1411 auto funcType = funcOp.getFunctionType();
1412 TypeConverter::SignatureConversion signatureConverter(
1413 funcType.getNumInputs());
1414 auto llvmType = typeConverter.convertFunctionSignature(
1415 funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
1416 signatureConverter);
1417 if (!llvmType)
1418 return failure();
1419
1420 // Create a new `LLVMFuncOp`
1421 Location loc = funcOp.getLoc();
1422 StringRef name = funcOp.getName();
1423 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1424
1425 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1426 MLIRContext *context = funcOp.getContext();
1427 switch (funcOp.getFunctionControl()) {
1428#define DISPATCH(functionControl, llvmAttr) \
1429 case functionControl: \
1430 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1431 break;
1432
1433 DISPATCH(spirv::FunctionControl::Inline,
1434 StringAttr::get(context, "alwaysinline"));
1435 DISPATCH(spirv::FunctionControl::DontInline,
1436 StringAttr::get(context, "noinline"));
1437 DISPATCH(spirv::FunctionControl::Pure,
1438 StringAttr::get(context, "readonly"));
1439 DISPATCH(spirv::FunctionControl::Const,
1440 StringAttr::get(context, "readnone"));
1441
1442#undef DISPATCH
1443
1444 // Default: if `spirv::FunctionControl::None`, then no attributes are
1445 // needed.
1446 default:
1447 break;
1448 }
1449
1450 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1451 newFuncOp.end());
1452 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1453 &signatureConverter))) {
1454 return failure();
1455 }
1456 rewriter.eraseOp(op: funcOp);
1457 return success();
1458 }
1459};
1460
1461//===----------------------------------------------------------------------===//
1462// ModuleOp conversion
1463//===----------------------------------------------------------------------===//
1464
1465class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1466public:
1467 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1468
1469 LogicalResult
1470 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1471 ConversionPatternRewriter &rewriter) const override {
1472
1473 auto newModuleOp =
1474 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1475 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1476
1477 // Remove the terminator block that was automatically added by builder
1478 rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back());
1479 rewriter.eraseOp(op: spvModuleOp);
1480 return success();
1481 }
1482};
1483
1484//===----------------------------------------------------------------------===//
1485// VectorShuffleOp conversion
1486//===----------------------------------------------------------------------===//
1487
1488class VectorShufflePattern
1489 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1490public:
1491 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1492 LogicalResult
1493 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter) const override {
1495 Location loc = op.getLoc();
1496 auto components = adaptor.getComponents();
1497 auto vector1 = adaptor.getVector1();
1498 auto vector2 = adaptor.getVector2();
1499 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1500 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1501 if (vector1Size == vector2Size) {
1502 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1503 op, vector1, vector2,
1504 LLVM::convertArrayToIndices<int32_t>(components));
1505 return success();
1506 }
1507
1508 auto dstType = typeConverter.convertType(op.getType());
1509 if (!dstType)
1510 return rewriter.notifyMatchFailure(op, "type conversion failed");
1511 auto scalarType = cast<VectorType>(dstType).getElementType();
1512 auto componentsArray = components.getValue();
1513 auto *context = rewriter.getContext();
1514 auto llvmI32Type = IntegerType::get(context, 32);
1515 Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1516 for (unsigned i = 0; i < componentsArray.size(); i++) {
1517 if (!isa<IntegerAttr>(componentsArray[i]))
1518 return op.emitError("unable to support non-constant component");
1519
1520 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1521 if (indexVal == -1)
1522 continue;
1523
1524 int offsetVal = 0;
1525 Value baseVector = vector1;
1526 if (indexVal >= vector1Size) {
1527 offsetVal = vector1Size;
1528 baseVector = vector2;
1529 }
1530
1531 Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1532 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1533 Value index = rewriter.create<LLVM::ConstantOp>(
1534 loc, llvmI32Type,
1535 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1536
1537 auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1538 loc, scalarType, baseVector, index);
1539 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1540 extractOp, dstIndex);
1541 }
1542 rewriter.replaceOp(op, targetOp);
1543 return success();
1544 }
1545};
1546} // namespace
1547
1548//===----------------------------------------------------------------------===//
1549// Pattern population
1550//===----------------------------------------------------------------------===//
1551
1552void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
1553 spirv::ClientAPI clientAPI) {
1554 typeConverter.addConversion(callback: [&](spirv::ArrayType type) {
1555 return convertArrayType(type, converter&: typeConverter);
1556 });
1557 typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) {
1558 return convertPointerType(type, typeConverter, clientAPI);
1559 });
1560 typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) {
1561 return convertRuntimeArrayType(type, converter&: typeConverter);
1562 });
1563 typeConverter.addConversion(callback: [&](spirv::StructType type) {
1564 return convertStructType(type, converter&: typeConverter);
1565 });
1566}
1567
1568void mlir::populateSPIRVToLLVMConversionPatterns(
1569 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1570 spirv::ClientAPI clientAPI) {
1571 patterns.add<
1572 // Arithmetic ops
1573 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1574 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1575 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1576 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1577 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1578 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1579 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1580 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1581 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1582 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1583 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1584 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1585 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1586
1587 // Bitwise ops
1588 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1589 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1590 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1591 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1592 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1593 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1594 NotPattern<spirv::NotOp>,
1595
1596 // Cast ops
1597 BitcastConversionPattern,
1598 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1599 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1600 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1601 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1602 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1603 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1604 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1605
1606 // Comparison ops
1607 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1608 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1609 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1610 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1611 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1612 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1613 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1614 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1615 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1616 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1617 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1618 LLVM::FCmpPredicate::uge>,
1619 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1620 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1621 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1622 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1623 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1624 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1625 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1626 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1627 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1628 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1629 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1630
1631 // Constant op
1632 ConstantScalarAndVectorPattern,
1633
1634 // Control Flow ops
1635 BranchConversionPattern, BranchConditionalConversionPattern,
1636 FunctionCallPattern, LoopPattern, SelectionPattern,
1637 ErasePattern<spirv::MergeOp>,
1638
1639 // Entry points and execution mode are handled separately.
1640 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1641
1642 // GLSL extended instruction set ops
1643 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1644 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1645 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1646 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1647 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1648 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1649 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1650 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1651 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1652 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1653 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1654 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1655 InverseSqrtPattern, TanPattern, TanhPattern,
1656
1657 // Logical ops
1658 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1659 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1660 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1661 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1662 NotPattern<spirv::LogicalNotOp>,
1663
1664 // Memory ops
1665 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1666 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1667
1668 // Miscellaneous ops
1669 CompositeExtractPattern, CompositeInsertPattern,
1670 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1671 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1672 VectorShufflePattern,
1673
1674 // Shift ops
1675 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1676 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1677 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1678
1679 // Return ops
1680 ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1681
1682 patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1683 typeConverter);
1684}
1685
1686void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1687 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1688 patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args&: typeConverter);
1689}
1690
1691void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1692 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1693 patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args&: typeConverter);
1694}
1695
1696//===----------------------------------------------------------------------===//
1697// Pre-conversion hooks
1698//===----------------------------------------------------------------------===//
1699
1700/// Hook for descriptor set and binding number encoding.
1701static constexpr StringRef kBinding = "binding";
1702static constexpr StringRef kDescriptorSet = "descriptor_set";
1703void mlir::encodeBindAttribute(ModuleOp module) {
1704 auto spvModules = module.getOps<spirv::ModuleOp>();
1705 for (auto spvModule : spvModules) {
1706 spvModule.walk([&](spirv::GlobalVariableOp op) {
1707 IntegerAttr descriptorSet =
1708 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1709 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1710 // For every global variable in the module, get the ones with descriptor
1711 // set and binding numbers.
1712 if (descriptorSet && binding) {
1713 // Encode these numbers into the variable's symbolic name. If the
1714 // SPIR-V module has a name, add it at the beginning.
1715 auto moduleAndName =
1716 spvModule.getName().has_value()
1717 ? spvModule.getName()->str() + "_" + op.getSymName().str()
1718 : op.getSymName().str();
1719 std::string name =
1720 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1721 std::to_string(descriptorSet.getInt()),
1722 std::to_string(binding.getInt()));
1723 auto nameAttr = StringAttr::get(op->getContext(), name);
1724
1725 // Replace all symbol uses and set the new symbol name. Finally, remove
1726 // descriptor set and binding attributes.
1727 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1728 op.emitError("unable to replace all symbol uses for ") << name;
1729 SymbolTable::setSymbolName(op, nameAttr);
1730 op->removeAttr(kDescriptorSet);
1731 op->removeAttr(kBinding);
1732 }
1733 });
1734 }
1735}
1736

// Source: mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp