1//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/Debug.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define DEBUG_TYPE "spirv-to-llvm-pattern"
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Utility functions
35//===----------------------------------------------------------------------===//
36
37/// Returns true if the given type is a signed integer or vector type.
38static bool isSignedIntegerOrVector(Type type) {
39 if (type.isSignedInteger())
40 return true;
41 if (auto vecType = dyn_cast<VectorType>(type))
42 return vecType.getElementType().isSignedInteger();
43 return false;
44}
45
46/// Returns true if the given type is an unsigned integer or vector type
47static bool isUnsignedIntegerOrVector(Type type) {
48 if (type.isUnsignedInteger())
49 return true;
50 if (auto vecType = dyn_cast<VectorType>(type))
51 return vecType.getElementType().isUnsignedInteger();
52 return false;
53}
54
55/// Returns the width of an integer or of the element type of an integer vector,
56/// if applicable.
57static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
58 if (auto intType = dyn_cast<IntegerType>(type))
59 return intType.getWidth();
60 if (auto vecType = dyn_cast<VectorType>(type))
61 if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
62 return intType.getWidth();
63 return std::nullopt;
64}
65
66/// Returns the bit width of integer, float or vector of float or integer values
67static unsigned getBitWidth(Type type) {
68 assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
69 "bitwidth is not supported for this type");
70 if (type.isIntOrFloat())
71 return type.getIntOrFloatBitWidth();
72 auto vecType = dyn_cast<VectorType>(type);
73 auto elementType = vecType.getElementType();
74 assert(elementType.isIntOrFloat() &&
75 "only integers and floats have a bitwidth");
76 return elementType.getIntOrFloatBitWidth();
77}
78
79/// Returns the bit width of LLVMType integer or vector.
80static unsigned getLLVMTypeBitWidth(Type type) {
81 if (auto vecTy = dyn_cast<VectorType>(type))
82 type = vecTy.getElementType();
83 return cast<IntegerType>(type).getWidth();
84}
85
86/// Creates `IntegerAttribute` with all bits set for given type
87static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
88 if (auto vecType = dyn_cast<VectorType>(type)) {
89 auto integerType = cast<IntegerType>(vecType.getElementType());
90 return builder.getIntegerAttr(integerType, -1);
91 }
92 auto integerType = cast<IntegerType>(type);
93 return builder.getIntegerAttr(integerType, -1);
94}
95
96/// Creates `llvm.mlir.constant` with all bits set for the given type.
97static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
98 PatternRewriter &rewriter) {
99 if (isa<VectorType>(Val: srcType)) {
100 return rewriter.create<LLVM::ConstantOp>(
101 loc, dstType,
102 SplatElementsAttr::get(cast<ShapedType>(srcType),
103 minusOneIntegerAttribute(srcType, rewriter)));
104 }
105 return rewriter.create<LLVM::ConstantOp>(
106 loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
107}
108
109/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
110static Value createFPConstant(Location loc, Type srcType, Type dstType,
111 PatternRewriter &rewriter, double value) {
112 if (auto vecType = dyn_cast<VectorType>(srcType)) {
113 auto floatType = cast<FloatType>(vecType.getElementType());
114 return rewriter.create<LLVM::ConstantOp>(
115 loc, dstType,
116 SplatElementsAttr::get(vecType,
117 rewriter.getFloatAttr(floatType, value)));
118 }
119 auto floatType = cast<FloatType>(srcType);
120 return rewriter.create<LLVM::ConstantOp>(
121 loc, dstType, rewriter.getFloatAttr(floatType, value));
122}
123
124/// Utility function for bitfield ops:
125/// - `BitFieldInsert`
126/// - `BitFieldSExtract`
127/// - `BitFieldUExtract`
128/// Truncates or extends the value. If the bitwidth of the value is the same as
129/// `llvmType` bitwidth, the value remains unchanged.
130static Value optionallyTruncateOrExtend(Location loc, Value value,
131 Type llvmType,
132 PatternRewriter &rewriter) {
133 auto srcType = value.getType();
134 unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType);
135 unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType)
136 ? getLLVMTypeBitWidth(type: srcType)
137 : getBitWidth(type: srcType);
138
139 if (valueBitWidth < targetBitWidth)
140 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
141 // If the bit widths of `Count` and `Offset` are greater than the bit width
142 // of the target type, they are truncated. Truncation is safe since `Count`
143 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
144 // both values can be expressed in 8 bits.
145 if (valueBitWidth > targetBitWidth)
146 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
147 return value;
148}
149
150/// Broadcasts the value to vector with `numElements` number of elements.
151static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
152 const TypeConverter &typeConverter,
153 ConversionPatternRewriter &rewriter) {
154 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
155 auto llvmVectorType = typeConverter.convertType(vectorType);
156 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
157 Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
158 for (unsigned i = 0; i < numElements; ++i) {
159 auto index = rewriter.create<LLVM::ConstantOp>(
160 loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
161 broadcasted = rewriter.create<LLVM::InsertElementOp>(
162 loc, llvmVectorType, broadcasted, toBroadcast, index);
163 }
164 return broadcasted;
165}
166
167/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
168static Value optionallyBroadcast(Location loc, Value value, Type srcType,
169 const TypeConverter &typeConverter,
170 ConversionPatternRewriter &rewriter) {
171 if (auto vectorType = dyn_cast<VectorType>(srcType)) {
172 unsigned numElements = vectorType.getNumElements();
173 return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter);
174 }
175 return value;
176}
177
178/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
179/// `BitFieldUExtract`.
180/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
181/// a vector type, construct a vector that has:
182/// - same number of elements as `Base`
183/// - each element has the type that is the same as the type of `Offset` or
184/// `Count`
185/// - each element has the same value as `Offset` or `Count`
186/// Then cast `Offset` and `Count` if their bit width is different
187/// from `Base` bit width.
188static Value processCountOrOffset(Location loc, Value value, Type srcType,
189 Type dstType, const TypeConverter &converter,
190 ConversionPatternRewriter &rewriter) {
191 Value broadcasted =
192 optionallyBroadcast(loc, value, srcType, typeConverter: converter, rewriter);
193 return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter);
194}
195
196/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
197/// offset to LLVM struct. Otherwise, the conversion is not supported.
198static Type convertStructTypeWithOffset(spirv::StructType type,
199 const TypeConverter &converter) {
200 if (type != VulkanLayoutUtils::decorateType(structType: type))
201 return nullptr;
202
203 SmallVector<Type> elementsVector;
204 if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
205 return nullptr;
206 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
207 /*isPacked=*/false);
208}
209
210/// Converts SPIR-V struct with no offset to packed LLVM struct.
211static Type convertStructTypePacked(spirv::StructType type,
212 const TypeConverter &converter) {
213 SmallVector<Type> elementsVector;
214 if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
215 return nullptr;
216 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
217 /*isPacked=*/true);
218}
219
220/// Creates LLVM dialect constant with the given value.
221static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
222 unsigned value) {
223 return rewriter.create<LLVM::ConstantOp>(
224 loc, IntegerType::get(rewriter.getContext(), 32),
225 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
226}
227
228/// Utility for `spirv.Load` and `spirv.Store` conversion.
229static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
230 ConversionPatternRewriter &rewriter,
231 const TypeConverter &typeConverter,
232 unsigned alignment, bool isVolatile,
233 bool isNonTemporal) {
234 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
235 auto dstType = typeConverter.convertType(loadOp.getType());
236 if (!dstType)
237 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
238 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
239 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240 isVolatile, isNonTemporal);
241 return success();
242 }
243 auto storeOp = cast<spirv::StoreOp>(op);
244 spirv::StoreOpAdaptor adaptor(operands);
245 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
246 adaptor.getPtr(), alignment,
247 isVolatile, isNonTemporal);
248 return success();
249}
250
251//===----------------------------------------------------------------------===//
252// Type conversion
253//===----------------------------------------------------------------------===//
254
255/// Converts SPIR-V array type to LLVM array. Natural stride (according to
256/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
257/// when converting ops that manipulate array types.
258static std::optional<Type> convertArrayType(spirv::ArrayType type,
259 TypeConverter &converter) {
260 unsigned stride = type.getArrayStride();
261 Type elementType = type.getElementType();
262 auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes();
263 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
264 return std::nullopt;
265
266 auto llvmElementType = converter.convertType(t: elementType);
267 unsigned numElements = type.getNumElements();
268 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
269}
270
271/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
272/// modelled at the moment.
273static Type convertPointerType(spirv::PointerType type,
274 const TypeConverter &converter,
275 spirv::ClientAPI clientAPI) {
276 unsigned addressSpace =
277 storageClassToAddressSpace(clientAPI, type.getStorageClass());
278 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
279}
280
281/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
282/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
283/// no modelling of array stride at the moment.
284static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
285 TypeConverter &converter) {
286 if (type.getArrayStride() != 0)
287 return std::nullopt;
288 auto elementType = converter.convertType(t: type.getElementType());
289 return LLVM::LLVMArrayType::get(elementType, 0);
290}
291
292/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
293/// member decorations. Also, only natural offset is supported.
294static Type convertStructType(spirv::StructType type,
295 const TypeConverter &converter) {
296 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
297 type.getMemberDecorations(memberDecorations);
298 if (!memberDecorations.empty())
299 return nullptr;
300 if (type.hasOffset())
301 return convertStructTypeWithOffset(type, converter);
302 return convertStructTypePacked(type, converter);
303}
304
305//===----------------------------------------------------------------------===//
306// Operation conversion
307//===----------------------------------------------------------------------===//
308
309namespace {
310
311class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
312public:
313 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
314
315 LogicalResult
316 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
317 ConversionPatternRewriter &rewriter) const override {
318 auto dstType =
319 getTypeConverter()->convertType(op.getComponentPtr().getType());
320 if (!dstType)
321 return rewriter.notifyMatchFailure(op, "type conversion failed");
322 // To use GEP we need to add a first 0 index to go through the pointer.
323 auto indices = llvm::to_vector<4>(adaptor.getIndices());
324 Type indexType = op.getIndices().front().getType();
325 auto llvmIndexType = getTypeConverter()->convertType(indexType);
326 if (!llvmIndexType)
327 return rewriter.notifyMatchFailure(op, "type conversion failed");
328 Value zero = rewriter.create<LLVM::ConstantOp>(
329 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
330 indices.insert(indices.begin(), zero);
331
332 auto elementType = getTypeConverter()->convertType(
333 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
334 if (!elementType)
335 return rewriter.notifyMatchFailure(op, "type conversion failed");
336 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
337 adaptor.getBasePtr(), indices);
338 return success();
339 }
340};
341
342class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
343public:
344 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
345
346 LogicalResult
347 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
348 ConversionPatternRewriter &rewriter) const override {
349 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
350 if (!dstType)
351 return rewriter.notifyMatchFailure(op, "type conversion failed");
352 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
353 op.getVariable());
354 return success();
355 }
356};
357
358class BitFieldInsertPattern
359 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
360public:
361 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
362
363 LogicalResult
364 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
365 ConversionPatternRewriter &rewriter) const override {
366 auto srcType = op.getType();
367 auto dstType = getTypeConverter()->convertType(srcType);
368 if (!dstType)
369 return rewriter.notifyMatchFailure(op, "type conversion failed");
370 Location loc = op.getLoc();
371
372 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
373 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
374 *getTypeConverter(), rewriter);
375 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
376 *getTypeConverter(), rewriter);
377
378 // Create a mask with bits set outside [Offset, Offset + Count - 1].
379 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
380 Value maskShiftedByCount =
381 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
382 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
383 maskShiftedByCount, minusOne);
384 Value maskShiftedByCountAndOffset =
385 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
386 Value mask = rewriter.create<LLVM::XOrOp>(
387 loc, dstType, maskShiftedByCountAndOffset, minusOne);
388
389 // Extract unchanged bits from the `Base` that are outside of
390 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
391 Value baseAndMask =
392 rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
393 Value insertShiftedByOffset =
394 rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
395 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
396 insertShiftedByOffset);
397 return success();
398 }
399};
400
401/// Converts SPIR-V ConstantOp with scalar or vector type.
402class ConstantScalarAndVectorPattern
403 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
404public:
405 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
406
407 LogicalResult
408 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
409 ConversionPatternRewriter &rewriter) const override {
410 auto srcType = constOp.getType();
411 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412 return failure();
413
414 auto dstType = getTypeConverter()->convertType(srcType);
415 if (!dstType)
416 return rewriter.notifyMatchFailure(constOp, "type conversion failed");
417
418 // SPIR-V constant can be a signed/unsigned integer, which has to be
419 // casted to signless integer when converting to LLVM dialect. Removing the
420 // sign bit may have unexpected behaviour. However, it is better to handle
421 // it case-by-case, given that the purpose of the conversion is not to
422 // cover all possible corner cases.
423 if (isSignedIntegerOrVector(srcType) ||
424 isUnsignedIntegerOrVector(srcType)) {
425 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
426
427 if (isa<VectorType>(srcType)) {
428 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
429 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
430 constOp, dstType,
431 dstElementsAttr.mapValues(
432 signlessType, [&](const APInt &value) { return value; }));
433 return success();
434 }
435 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
436 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
437 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
438 return success();
439 }
440 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
441 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
442 return success();
443 }
444};
445
446class BitFieldSExtractPattern
447 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
448public:
449 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
450
451 LogicalResult
452 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
453 ConversionPatternRewriter &rewriter) const override {
454 auto srcType = op.getType();
455 auto dstType = getTypeConverter()->convertType(srcType);
456 if (!dstType)
457 return rewriter.notifyMatchFailure(op, "type conversion failed");
458 Location loc = op.getLoc();
459
460 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
461 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
462 *getTypeConverter(), rewriter);
463 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
464 *getTypeConverter(), rewriter);
465
466 // Create a constant that holds the size of the `Base`.
467 IntegerType integerType;
468 if (auto vecType = dyn_cast<VectorType>(srcType))
469 integerType = cast<IntegerType>(vecType.getElementType());
470 else
471 integerType = cast<IntegerType>(srcType);
472
473 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
474 Value size =
475 isa<VectorType>(srcType)
476 ? rewriter.create<LLVM::ConstantOp>(
477 loc, dstType,
478 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
479 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
480
481 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
482 // at Offset + Count - 1 is the most significant bit now.
483 Value countPlusOffset =
484 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
485 Value amountToShiftLeft =
486 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
487 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
488 loc, dstType, op.getBase(), amountToShiftLeft);
489
490 // Shift the result right, filling the bits with the sign bit.
491 Value amountToShiftRight =
492 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
493 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
494 amountToShiftRight);
495 return success();
496 }
497};
498
499class BitFieldUExtractPattern
500 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
501public:
502 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
503
504 LogicalResult
505 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 auto srcType = op.getType();
508 auto dstType = getTypeConverter()->convertType(srcType);
509 if (!dstType)
510 return rewriter.notifyMatchFailure(op, "type conversion failed");
511 Location loc = op.getLoc();
512
513 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
514 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
515 *getTypeConverter(), rewriter);
516 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
517 *getTypeConverter(), rewriter);
518
519 // Create a mask with bits set at [0, Count - 1].
520 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
521 Value maskShiftedByCount =
522 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
523 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
524 minusOne);
525
526 // Shift `Base` by `Offset` and apply the mask on it.
527 Value shiftedBase =
528 rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
529 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
530 return success();
531 }
532};
533
534class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
535public:
536 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
537
538 LogicalResult
539 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
542 branchOp.getTarget());
543 return success();
544 }
545};
546
547class BranchConditionalConversionPattern
548 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
549public:
550 using SPIRVToLLVMConversion<
551 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
552
553 LogicalResult
554 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
555 ConversionPatternRewriter &rewriter) const override {
556 // If branch weights exist, map them to 32-bit integer vector.
557 DenseI32ArrayAttr branchWeights = nullptr;
558 if (auto weights = op.getBranchWeights()) {
559 SmallVector<int32_t> weightValues;
560 for (auto weight : weights->getAsRange<IntegerAttr>())
561 weightValues.push_back(weight.getInt());
562 branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
563 }
564
565 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
566 op, op.getCondition(), op.getTrueBlockArguments(),
567 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
568 op.getFalseBlock());
569 return success();
570 }
571};
572
573/// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
574/// type is an aggregate type (struct or array). Otherwise, converts to
575/// `llvm.extractelement` that operates on vectors.
576class CompositeExtractPattern
577 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
578public:
579 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
580
581 LogicalResult
582 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
583 ConversionPatternRewriter &rewriter) const override {
584 auto dstType = this->getTypeConverter()->convertType(op.getType());
585 if (!dstType)
586 return rewriter.notifyMatchFailure(op, "type conversion failed");
587
588 Type containerType = op.getComposite().getType();
589 if (isa<VectorType>(Val: containerType)) {
590 Location loc = op.getLoc();
591 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
592 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
593 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
594 op, dstType, adaptor.getComposite(), index);
595 return success();
596 }
597
598 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
599 op, adaptor.getComposite(),
600 LLVM::convertArrayToIndices(op.getIndices()));
601 return success();
602 }
603};
604
605/// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
606/// type is an aggregate type (struct or array). Otherwise, converts to
607/// `llvm.insertelement` that operates on vectors.
608class CompositeInsertPattern
609 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
610public:
611 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
612
613 LogicalResult
614 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const override {
616 auto dstType = this->getTypeConverter()->convertType(op.getType());
617 if (!dstType)
618 return rewriter.notifyMatchFailure(op, "type conversion failed");
619
620 Type containerType = op.getComposite().getType();
621 if (isa<VectorType>(Val: containerType)) {
622 Location loc = op.getLoc();
623 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
624 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
625 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
626 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
627 return success();
628 }
629
630 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
631 op, adaptor.getComposite(), adaptor.getObject(),
632 LLVM::convertArrayToIndices(op.getIndices()));
633 return success();
634 }
635};
636
637/// Converts SPIR-V operations that have straightforward LLVM equivalent
638/// into LLVM dialect operations.
639template <typename SPIRVOp, typename LLVMOp>
640class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
641public:
642 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
643
644 LogicalResult
645 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
646 ConversionPatternRewriter &rewriter) const override {
647 auto dstType = this->getTypeConverter()->convertType(op.getType());
648 if (!dstType)
649 return rewriter.notifyMatchFailure(op, "type conversion failed");
650 rewriter.template replaceOpWithNewOp<LLVMOp>(
651 op, dstType, adaptor.getOperands(), op->getAttrs());
652 return success();
653 }
654};
655
656/// Converts `spirv.ExecutionMode` into a global struct constant that holds
657/// execution mode information.
658class ExecutionModePattern
659 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
660public:
661 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
662
663 LogicalResult
664 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
665 ConversionPatternRewriter &rewriter) const override {
666 // First, create the global struct's name that would be associated with
667 // this entry point's execution mode. We set it to be:
668 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
669 ModuleOp module = op->getParentOfType<ModuleOp>();
670 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671 std::string moduleName;
672 if (module.getName().has_value())
673 moduleName = "_" + module.getName()->str();
674 else
675 moduleName = "";
676 std::string executionModeInfoName = llvm::formatv(
677 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678 static_cast<uint32_t>(executionModeAttr.getValue()));
679
680 MLIRContext *context = rewriter.getContext();
681 OpBuilder::InsertionGuard guard(rewriter);
682 rewriter.setInsertionPointToStart(module.getBody());
683
684 // Create a struct type, corresponding to the C struct below.
685 // struct {
686 // int32_t executionMode;
687 // int32_t values[]; // optional values
688 // };
689 auto llvmI32Type = IntegerType::get(context, 32);
690 SmallVector<Type, 2> fields;
691 fields.push_back(Elt: llvmI32Type);
692 ArrayAttr values = op.getValues();
693 if (!values.empty()) {
694 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
695 fields.push_back(Elt: arrayType);
696 }
697 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698
699 // Create `llvm.mlir.global` with initializer region containing one block.
700 auto global = rewriter.create<LLVM::GlobalOp>(
701 UnknownLoc::get(context), structType, /*isConstant=*/true,
702 LLVM::Linkage::External, executionModeInfoName, Attribute(),
703 /*alignment=*/0);
704 Location loc = global.getLoc();
705 Region &region = global.getInitializerRegion();
706 Block *block = rewriter.createBlock(parent: &region);
707
708 // Initialize the struct and set the execution mode value.
709 rewriter.setInsertionPointToStart(block);
710 Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType);
711 Value executionMode = rewriter.create<LLVM::ConstantOp>(
712 loc, llvmI32Type,
713 rewriter.getI32IntegerAttr(
714 static_cast<uint32_t>(executionModeAttr.getValue())));
715 structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
716 executionMode, 0);
717
718 // Insert extra operands if they exist into execution mode info struct.
719 for (unsigned i = 0, e = values.size(); i < e; ++i) {
720 auto attr = values.getValue()[i];
721 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
722 structValue = rewriter.create<LLVM::InsertValueOp>(
723 loc, structValue, entry, ArrayRef<int64_t>({1, i}));
724 }
725 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
726 rewriter.eraseOp(op: op);
727 return success();
728 }
729};
730
731/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
732/// global returns a pointer, whereas in LLVM dialect the global holds an actual
733/// value. This difference is handled by `spirv.mlir.addressof` and
734/// `llvm.mlir.addressof`ops that both return a pointer.
735class GlobalVariablePattern
736 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
737public:
738 template <typename... Args>
739 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
740 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
741 std::forward<Args>(args)...),
742 clientAPI(clientAPI) {}
743
744 LogicalResult
745 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
746 ConversionPatternRewriter &rewriter) const override {
747 // Currently, there is no support of initialization with a constant value in
748 // SPIR-V dialect. Specialization constants are not considered as well.
749 if (op.getInitializer())
750 return failure();
751
752 auto srcType = cast<spirv::PointerType>(op.getType());
753 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754 if (!dstType)
755 return rewriter.notifyMatchFailure(op, "type conversion failed");
756
757 // Limit conversion to the current invocation only or `StorageBuffer`
758 // required by SPIR-V runner.
759 // This is okay because multiple invocations are not supported yet.
760 auto storageClass = srcType.getStorageClass();
761 switch (storageClass) {
762 case spirv::StorageClass::Input:
763 case spirv::StorageClass::Private:
764 case spirv::StorageClass::Output:
765 case spirv::StorageClass::StorageBuffer:
766 case spirv::StorageClass::UniformConstant:
767 break;
768 default:
769 return failure();
770 }
771
772 // LLVM dialect spec: "If the global value is a constant, storing into it is
773 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
774 // storage class that is read-only.
775 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776 (storageClass == spirv::StorageClass::UniformConstant);
777 // SPIR-V spec: "By default, functions and global variables are private to a
778 // module and cannot be accessed by other modules. However, a module may be
779 // written to export or import functions and global (module scope)
780 // variables.". Therefore, map 'Private' storage class to private linkage,
781 // 'Input' and 'Output' to external linkage.
782 auto linkage = storageClass == spirv::StorageClass::Private
783 ? LLVM::Linkage::Private
784 : LLVM::Linkage::External;
785 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
786 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
787 /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
788
789 // Attach location attribute if applicable
790 if (op.getLocationAttr())
791 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
792
793 return success();
794 }
795
796private:
797 spirv::ClientAPI clientAPI;
798};
799
800/// Converts SPIR-V cast ops that do not have straightforward LLVM
801/// equivalent in LLVM dialect.
802template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
803class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
804public:
805 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
806
807 LogicalResult
808 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
809 ConversionPatternRewriter &rewriter) const override {
810
811 Type fromType = op.getOperand().getType();
812 Type toType = op.getType();
813
814 auto dstType = this->getTypeConverter()->convertType(toType);
815 if (!dstType)
816 return rewriter.notifyMatchFailure(op, "type conversion failed");
817
818 if (getBitWidth(type: fromType) < getBitWidth(type: toType)) {
819 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
820 adaptor.getOperands());
821 return success();
822 }
823 if (getBitWidth(type: fromType) > getBitWidth(type: toType)) {
824 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
825 adaptor.getOperands());
826 return success();
827 }
828 return failure();
829 }
830};
831
832class FunctionCallPattern
833 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
834public:
835 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
836
837 LogicalResult
838 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
839 ConversionPatternRewriter &rewriter) const override {
840 if (callOp.getNumResults() == 0) {
841 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
842 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843 newOp.getProperties().operandSegmentSizes = {
844 static_cast<int32_t>(adaptor.getOperands().size()), 0};
845 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
846 return success();
847 }
848
849 // Function returns a single result.
850 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
851 if (!dstType)
852 return rewriter.notifyMatchFailure(callOp, "type conversion failed");
853 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
854 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855 newOp.getProperties().operandSegmentSizes = {
856 static_cast<int32_t>(adaptor.getOperands().size()), 0};
857 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
858 return success();
859 }
860};
861
862/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
863template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
864class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
865public:
866 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
867
868 LogicalResult
869 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
870 ConversionPatternRewriter &rewriter) const override {
871
872 auto dstType = this->getTypeConverter()->convertType(op.getType());
873 if (!dstType)
874 return rewriter.notifyMatchFailure(op, "type conversion failed");
875
876 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
877 op, dstType, predicate, op.getOperand1(), op.getOperand2());
878 return success();
879 }
880};
881
882/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
883template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
884class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
885public:
886 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
887
888 LogicalResult
889 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
890 ConversionPatternRewriter &rewriter) const override {
891
892 auto dstType = this->getTypeConverter()->convertType(op.getType());
893 if (!dstType)
894 return rewriter.notifyMatchFailure(op, "type conversion failed");
895
896 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
897 op, dstType, predicate, op.getOperand1(), op.getOperand2());
898 return success();
899 }
900};
901
902class InverseSqrtPattern
903 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
904public:
905 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
906
907 LogicalResult
908 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
909 ConversionPatternRewriter &rewriter) const override {
910 auto srcType = op.getType();
911 auto dstType = getTypeConverter()->convertType(srcType);
912 if (!dstType)
913 return rewriter.notifyMatchFailure(op, "type conversion failed");
914
915 Location loc = op.getLoc();
916 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
917 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
918 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
919 return success();
920 }
921};
922
923/// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
924template <typename SPIRVOp>
925class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
926public:
927 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
928
929 LogicalResult
930 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
931 ConversionPatternRewriter &rewriter) const override {
932 if (!op.getMemoryAccess()) {
933 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
934 *this->getTypeConverter(), /*alignment=*/0,
935 /*isVolatile=*/false,
936 /*isNonTemporal=*/false);
937 }
938 auto memoryAccess = *op.getMemoryAccess();
939 switch (memoryAccess) {
940 case spirv::MemoryAccess::Aligned:
941 case spirv::MemoryAccess::None:
942 case spirv::MemoryAccess::Nontemporal:
943 case spirv::MemoryAccess::Volatile: {
944 unsigned alignment =
945 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
948 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
949 *this->getTypeConverter(), alignment,
950 isVolatile, isNonTemporal);
951 }
952 default:
953 // There is no support of other memory access attributes.
954 return failure();
955 }
956 }
957};
958
959/// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
960template <typename SPIRVOp>
961class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
962public:
963 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
964
965 LogicalResult
966 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
967 ConversionPatternRewriter &rewriter) const override {
968 auto srcType = notOp.getType();
969 auto dstType = this->getTypeConverter()->convertType(srcType);
970 if (!dstType)
971 return rewriter.notifyMatchFailure(notOp, "type conversion failed");
972
973 Location loc = notOp.getLoc();
974 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
975 auto mask =
976 isa<VectorType>(srcType)
977 ? rewriter.create<LLVM::ConstantOp>(
978 loc, dstType,
979 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
980 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
981 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
982 notOp.getOperand(), mask);
983 return success();
984 }
985};
986
987/// A template pattern that erases the given `SPIRVOp`.
988template <typename SPIRVOp>
989class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
990public:
991 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
992
993 LogicalResult
994 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
995 ConversionPatternRewriter &rewriter) const override {
996 rewriter.eraseOp(op);
997 return success();
998 }
999};
1000
1001class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1002public:
1003 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1004
1005 LogicalResult
1006 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter) const override {
1008 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1009 ArrayRef<Value>());
1010 return success();
1011 }
1012};
1013
1014class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1015public:
1016 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1017
1018 LogicalResult
1019 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1020 ConversionPatternRewriter &rewriter) const override {
1021 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1022 adaptor.getOperands());
1023 return success();
1024 }
1025};
1026
1027static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1028 StringRef name,
1029 ArrayRef<Type> paramTypes,
1030 Type resultType,
1031 bool convergent = true) {
1032 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1033 SymbolTable::lookupSymbolIn(symbolTable, name));
1034 if (func)
1035 return func;
1036
1037 OpBuilder b(symbolTable->getRegion(index: 0));
1038 func = b.create<LLVM::LLVMFuncOp>(
1039 symbolTable->getLoc(), name,
1040 LLVM::LLVMFunctionType::get(resultType, paramTypes));
1041 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1042 func.setConvergent(convergent);
1043 func.setNoUnwind(true);
1044 func.setWillReturn(true);
1045 return func;
1046}
1047
1048static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1049 LLVM::LLVMFuncOp func,
1050 ValueRange args) {
1051 auto call = builder.create<LLVM::CallOp>(loc, func, args);
1052 call.setCConv(func.getCConv());
1053 call.setConvergentAttr(func.getConvergentAttr());
1054 call.setNoUnwindAttr(func.getNoUnwindAttr());
1055 call.setWillReturnAttr(func.getWillReturnAttr());
1056 return call;
1057}
1058
1059template <typename BarrierOpTy>
1060class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1061public:
1062 using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1063
1064 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1065
1066 static constexpr StringRef getFuncName();
1067
1068 LogicalResult
1069 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 constexpr StringRef funcName = getFuncName();
1072 Operation *symbolTable =
1073 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1074
1075 Type i32 = rewriter.getI32Type();
1076
1077 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1078 LLVM::LLVMFuncOp func =
1079 lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1080
1081 Location loc = controlBarrierOp->getLoc();
1082 Value execution = rewriter.create<LLVM::ConstantOp>(
1083 loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1084 Value memory = rewriter.create<LLVM::ConstantOp>(
1085 loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1086 Value semantics = rewriter.create<LLVM::ConstantOp>(
1087 loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1088
1089 auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1090 {execution, memory, semantics});
1091
1092 rewriter.replaceOp(controlBarrierOp, call);
1093 return success();
1094 }
1095};
1096
1097namespace {
1098
1099StringRef getTypeMangling(Type type, bool isSigned) {
1100 return llvm::TypeSwitch<Type, StringRef>(type)
1101 .Case<Float16Type>([](auto) { return "Dh"; })
1102 .Case<Float32Type>([](auto) { return "f"; })
1103 .Case<Float64Type>([](auto) { return "d"; })
1104 .Case<IntegerType>([isSigned](IntegerType intTy) {
1105 switch (intTy.getWidth()) {
1106 case 1:
1107 return "b";
1108 case 8:
1109 return (isSigned) ? "a" : "c";
1110 case 16:
1111 return (isSigned) ? "s" : "t";
1112 case 32:
1113 return (isSigned) ? "i" : "j";
1114 case 64:
1115 return (isSigned) ? "l" : "m";
1116 default:
1117 llvm_unreachable("Unsupported integer width");
1118 }
1119 })
1120 .Default([](auto) {
1121 llvm_unreachable("No mangling defined");
1122 return "";
1123 });
1124}
1125
/// Returns the mangled-name prefix of the `__spirv_*` builtin that implements
/// the group reduction `ReduceOp`.
///
/// Each string has the form `_Z<len><name>ii`: `<len>` is the character count
/// of `<name>` and the trailing `ii` encodes the two leading `i32` parameters
/// (execution scope and group operation). Callers append the codes of the
/// remaining parameters. Only the explicit specializations below are defined.
template <typename ReduceOp>
constexpr StringLiteral getGroupFuncName();

// Uniform group reductions (`spirv.Group*`).
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
  return "_Z17__spirv_GroupIAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
  return "_Z17__spirv_GroupFAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
  return "_Z17__spirv_GroupSMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
  return "_Z17__spirv_GroupUMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
  return "_Z17__spirv_GroupFMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
  return "_Z17__spirv_GroupSMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
  return "_Z17__spirv_GroupUMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
  return "_Z17__spirv_GroupFMaxii";
}

// Non-uniform arithmetic reductions (`spirv.GroupNonUniform*`).
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
  return "_Z27__spirv_GroupNonUniformIAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
  return "_Z27__spirv_GroupNonUniformFAddii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
  return "_Z27__spirv_GroupNonUniformIMulii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
  return "_Z27__spirv_GroupNonUniformFMulii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
  return "_Z27__spirv_GroupNonUniformSMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
  return "_Z27__spirv_GroupNonUniformUMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
  return "_Z27__spirv_GroupNonUniformFMinii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
  return "_Z27__spirv_GroupNonUniformSMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
  return "_Z27__spirv_GroupNonUniformUMaxii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
  return "_Z27__spirv_GroupNonUniformFMaxii";
}

// Non-uniform bitwise reductions.
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
  return "_Z33__spirv_GroupNonUniformBitwiseAndii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
  return "_Z32__spirv_GroupNonUniformBitwiseOrii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
  return "_Z33__spirv_GroupNonUniformBitwiseXorii";
}

// Non-uniform logical reductions.
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
  return "_Z33__spirv_GroupNonUniformLogicalAndii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
  return "_Z32__spirv_GroupNonUniformLogicalOrii";
}
template <>
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
  return "_Z33__spirv_GroupNonUniformLogicalXorii";
}
1225} // namespace
1226
1227template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1228class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1229public:
1230 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1231
1232 LogicalResult
1233 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1234 ConversionPatternRewriter &rewriter) const override {
1235
1236 Type retTy = op.getResult().getType();
1237 if (!retTy.isIntOrFloat()) {
1238 return failure();
1239 }
1240 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1241 funcName += getTypeMangling(type: retTy, isSigned: false);
1242
1243 Type i32Ty = rewriter.getI32Type();
1244 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1245 if constexpr (NonUniform) {
1246 if (adaptor.getClusterSize()) {
1247 funcName += "j";
1248 paramTypes.push_back(Elt: i32Ty);
1249 }
1250 }
1251
1252 Operation *symbolTable =
1253 op->template getParentWithTrait<OpTrait::SymbolTable>();
1254
1255 LLVM::LLVMFuncOp func =
1256 lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1257
1258 Location loc = op.getLoc();
1259 Value scope = rewriter.create<LLVM::ConstantOp>(
1260 loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1261 Value groupOp = rewriter.create<LLVM::ConstantOp>(
1262 loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1263 SmallVector<Value> operands{scope, groupOp};
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1265
1266 auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1267 rewriter.replaceOp(op, call);
1268 return success();
1269 }
1270};
1271
// Mangled builtin names used by `ControlBarrierPattern`, one per supported
// barrier op. Each has the form `_Z<len><name>iii`, where `<len>` is the
// character count of `<name>` and `iii` encodes the three `i32` parameters.

/// Builtin called for `spirv::ControlBarrierOp`.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
  return "_Z22__spirv_ControlBarrieriii";
}

/// Builtin called for `spirv::INTELControlBarrierArriveOp`.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
  return "_Z33__spirv_ControlBarrierArriveINTELiii";
}

/// Builtin called for `spirv::INTELControlBarrierWaitOp`.
template <>
constexpr StringRef
ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
  return "_Z31__spirv_ControlBarrierWaitINTELiii";
}
1289
1290/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1291/// should be reachable for conversion to succeed. The structure of the loop in
1292/// LLVM dialect will be the following:
1293///
1294/// +------------------------------------+
1295/// | <code before spirv.mlir.loop> |
1296/// | llvm.br ^header |
1297/// +------------------------------------+
1298/// |
1299/// +----------------+ |
1300/// | | |
1301/// | V V
1302/// | +------------------------------------+
1303/// | | ^header: |
1304/// | | <header code> |
1305/// | | llvm.cond_br %cond, ^body, ^exit |
1306/// | +------------------------------------+
1307/// | |
1308/// | |----------------------+
1309/// | | |
1310/// | V |
1311/// | +------------------------------------+ |
1312/// | | ^body: | |
1313/// | | <body code> | |
1314/// | | llvm.br ^continue | |
1315/// | +------------------------------------+ |
1316/// | | |
1317/// | V |
1318/// | +------------------------------------+ |
1319/// | | ^continue: | |
1320/// | | <continue code> | |
1321/// | | llvm.br ^header | |
1322/// | +------------------------------------+ |
1323/// | | |
1324/// +---------------+ +----------------------+
1325/// |
1326/// V
1327/// +------------------------------------+
1328/// | ^exit: |
1329/// | llvm.br ^remaining |
1330/// +------------------------------------+
1331/// |
1332/// V
1333/// +------------------------------------+
1334/// | ^remaining: |
1335/// | <code after spirv.mlir.loop> |
1336/// +------------------------------------+
1337///
1338class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1339public:
1340 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1341
1342 LogicalResult
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter) const override {
1345 // There is no support of loop control at the moment.
1346 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1347 return failure();
1348
1349 // `spirv.mlir.loop` with empty region is redundant and should be erased.
1350 if (loopOp.getBody().empty()) {
1351 rewriter.eraseOp(op: loopOp);
1352 return success();
1353 }
1354
1355 Location loc = loopOp.getLoc();
1356
1357 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1358 // be used in `endBlock`.
1359 Block *currentBlock = rewriter.getBlock();
1360 auto position = Block::iterator(loopOp);
1361 Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position);
1362
1363 // Remove entry block and create a branch in the current block going to the
1364 // header block.
1365 Block *entryBlock = loopOp.getEntryBlock();
1366 assert(entryBlock->getOperations().size() == 1);
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1368 if (!brOp)
1369 return failure();
1370 Block *headerBlock = loopOp.getHeaderBlock();
1371 rewriter.setInsertionPointToEnd(currentBlock);
1372 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1373 rewriter.eraseBlock(block: entryBlock);
1374
1375 // Branch from merge block to end block.
1376 Block *mergeBlock = loopOp.getMergeBlock();
1377 Operation *terminator = mergeBlock->getTerminator();
1378 ValueRange terminatorOperands = terminator->getOperands();
1379 rewriter.setInsertionPointToEnd(mergeBlock);
1380 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1381
1382 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1383 rewriter.replaceOp(loopOp, endBlock->getArguments());
1384 return success();
1385 }
1386};
1387
1388/// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1389/// block. All blocks within selection should be reachable for conversion to
1390/// succeed.
1391class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1392public:
1393 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1394
1395 LogicalResult
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397 ConversionPatternRewriter &rewriter) const override {
1398 // There is no support for `Flatten` or `DontFlatten` selection control at
1399 // the moment. This are just compiler hints and can be performed during the
1400 // optimization passes.
1401 if (op.getSelectionControl() != spirv::SelectionControl::None)
1402 return failure();
1403
1404 // `spirv.mlir.selection` should have at least two blocks: one selection
1405 // header block and one merge block. If no blocks are present, or control
1406 // flow branches straight to merge block (two blocks are present), the op is
1407 // redundant and it is erased.
1408 if (op.getBody().getBlocks().size() <= 2) {
1409 rewriter.eraseOp(op: op);
1410 return success();
1411 }
1412
1413 Location loc = op.getLoc();
1414
1415 // Split the current block after `spirv.mlir.selection`. The remaining ops
1416 // will be used in `continueBlock`.
1417 auto *currentBlock = rewriter.getInsertionBlock();
1418 rewriter.setInsertionPointAfter(op);
1419 auto position = rewriter.getInsertionPoint();
1420 auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position);
1421
1422 // Extract conditional branch information from the header block. By SPIR-V
1423 // dialect spec, it should contain `spirv.BranchConditional` or
1424 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1425 // moment in the SPIR-V dialect. Remove this block when finished.
1426 auto *headerBlock = op.getHeaderBlock();
1427 assert(headerBlock->getOperations().size() == 1);
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1429 headerBlock->getOperations().front());
1430 if (!condBrOp)
1431 return failure();
1432 rewriter.eraseBlock(block: headerBlock);
1433
1434 // Branch from merge block to continue block.
1435 auto *mergeBlock = op.getMergeBlock();
1436 Operation *terminator = mergeBlock->getTerminator();
1437 ValueRange terminatorOperands = terminator->getOperands();
1438 rewriter.setInsertionPointToEnd(mergeBlock);
1439 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1440
1441 // Link current block to `true` and `false` blocks within the selection.
1442 Block *trueBlock = condBrOp.getTrueBlock();
1443 Block *falseBlock = condBrOp.getFalseBlock();
1444 rewriter.setInsertionPointToEnd(currentBlock);
1445 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1446 condBrOp.getTrueTargetOperands(),
1447 falseBlock,
1448 condBrOp.getFalseTargetOperands());
1449
1450 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1451 rewriter.replaceOp(op, continueBlock->getArguments());
1452 return success();
1453 }
1454};
1455
1456/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1457/// puts a restriction on `Shift` and `Base` to have the same bit width,
1458/// `Shift` is zero or sign extended to match this specification. Cases when
1459/// `Shift` bit width > `Base` bit width are considered to be illegal.
1460template <typename SPIRVOp, typename LLVMOp>
1461class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1462public:
1463 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1464
1465 LogicalResult
1466 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1467 ConversionPatternRewriter &rewriter) const override {
1468
1469 auto dstType = this->getTypeConverter()->convertType(op.getType());
1470 if (!dstType)
1471 return rewriter.notifyMatchFailure(op, "type conversion failed");
1472
1473 Type op1Type = op.getOperand1().getType();
1474 Type op2Type = op.getOperand2().getType();
1475
1476 if (op1Type == op2Type) {
1477 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1478 adaptor.getOperands());
1479 return success();
1480 }
1481
1482 std::optional<uint64_t> dstTypeWidth =
1483 getIntegerOrVectorElementWidth(dstType);
1484 std::optional<uint64_t> op2TypeWidth =
1485 getIntegerOrVectorElementWidth(type: op2Type);
1486
1487 if (!dstTypeWidth || !op2TypeWidth)
1488 return failure();
1489
1490 Location loc = op.getLoc();
1491 Value extended;
1492 if (op2TypeWidth < dstTypeWidth) {
1493 if (isUnsignedIntegerOrVector(type: op2Type)) {
1494 extended = rewriter.template create<LLVM::ZExtOp>(
1495 loc, dstType, adaptor.getOperand2());
1496 } else {
1497 extended = rewriter.template create<LLVM::SExtOp>(
1498 loc, dstType, adaptor.getOperand2());
1499 }
1500 } else if (op2TypeWidth == dstTypeWidth) {
1501 extended = adaptor.getOperand2();
1502 } else {
1503 return failure();
1504 }
1505
1506 Value result = rewriter.template create<LLVMOp>(
1507 loc, dstType, adaptor.getOperand1(), extended);
1508 rewriter.replaceOp(op, result);
1509 return success();
1510 }
1511};
1512
1513class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1514public:
1515 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1516
1517 LogicalResult
1518 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1519 ConversionPatternRewriter &rewriter) const override {
1520 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1521 if (!dstType)
1522 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1523
1524 Location loc = tanOp.getLoc();
1525 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1526 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1527 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1528 return success();
1529 }
1530};
1531
1532/// Convert `spirv.Tanh` to
1533///
1534/// exp(2x) - 1
1535/// -----------
1536/// exp(2x) + 1
1537///
1538class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1539public:
1540 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1541
1542 LogicalResult
1543 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1544 ConversionPatternRewriter &rewriter) const override {
1545 auto srcType = tanhOp.getType();
1546 auto dstType = getTypeConverter()->convertType(srcType);
1547 if (!dstType)
1548 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1549
1550 Location loc = tanhOp.getLoc();
1551 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1552 Value multiplied =
1553 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1554 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1555 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1556 Value numerator =
1557 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1558 Value denominator =
1559 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1560 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1561 denominator);
1562 return success();
1563 }
1564};
1565
1566class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1567public:
1568 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1569
1570 LogicalResult
1571 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1572 ConversionPatternRewriter &rewriter) const override {
1573 auto srcType = varOp.getType();
1574 // Initialization is supported for scalars and vectors only.
1575 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1576 auto init = varOp.getInitializer();
1577 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1578 return failure();
1579
1580 auto dstType = getTypeConverter()->convertType(srcType);
1581 if (!dstType)
1582 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1583
1584 Location loc = varOp.getLoc();
1585 Value size = createI32ConstantOf(loc, rewriter, value: 1);
1586 if (!init) {
1587 auto elementType = getTypeConverter()->convertType(pointerTo);
1588 if (!elementType)
1589 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1590 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1591 size);
1592 return success();
1593 }
1594 auto elementType = getTypeConverter()->convertType(pointerTo);
1595 if (!elementType)
1596 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1597 Value allocated =
1598 rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1599 rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1600 rewriter.replaceOp(varOp, allocated);
1601 return success();
1602 }
1603};
1604
1605//===----------------------------------------------------------------------===//
1606// BitcastOp conversion
1607//===----------------------------------------------------------------------===//
1608
1609class BitcastConversionPattern
1610 : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1611public:
1612 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1613
1614 LogicalResult
1615 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1616 ConversionPatternRewriter &rewriter) const override {
1617 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1618 if (!dstType)
1619 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1620
1621 // LLVM's opaque pointers do not require bitcasts.
1622 if (isa<LLVM::LLVMPointerType>(dstType)) {
1623 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1624 return success();
1625 }
1626
1627 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1628 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1629 return success();
1630 }
1631};
1632
1633//===----------------------------------------------------------------------===//
1634// FuncOp conversion
1635//===----------------------------------------------------------------------===//
1636
1637class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1638public:
1639 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1640
1641 LogicalResult
1642 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1643 ConversionPatternRewriter &rewriter) const override {
1644
1645 // Convert function signature. At the moment LLVMType converter is enough
1646 // for currently supported types.
1647 auto funcType = funcOp.getFunctionType();
1648 TypeConverter::SignatureConversion signatureConverter(
1649 funcType.getNumInputs());
1650 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1651 ->convertFunctionSignature(
1652 funcType, /*isVariadic=*/false,
1653 /*useBarePtrCallConv=*/false, signatureConverter);
1654 if (!llvmType)
1655 return failure();
1656
1657 // Create a new `LLVMFuncOp`
1658 Location loc = funcOp.getLoc();
1659 StringRef name = funcOp.getName();
1660 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1661
1662 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1663 MLIRContext *context = funcOp.getContext();
1664 switch (funcOp.getFunctionControl()) {
1665 case spirv::FunctionControl::Inline:
1666 newFuncOp.setAlwaysInline(true);
1667 break;
1668 case spirv::FunctionControl::DontInline:
1669 newFuncOp.setNoInline(true);
1670 break;
1671
1672#define DISPATCH(functionControl, llvmAttr) \
1673 case functionControl: \
1674 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1675 break;
1676
1677 DISPATCH(spirv::FunctionControl::Pure,
1678 StringAttr::get(context, "readonly"));
1679 DISPATCH(spirv::FunctionControl::Const,
1680 StringAttr::get(context, "readnone"));
1681
1682#undef DISPATCH
1683
1684 // Default: if `spirv::FunctionControl::None`, then no attributes are
1685 // needed.
1686 default:
1687 break;
1688 }
1689
1690 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1691 newFuncOp.end());
1692 if (failed(rewriter.convertRegionTypes(
1693 region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter))) {
1694 return failure();
1695 }
1696 rewriter.eraseOp(op: funcOp);
1697 return success();
1698 }
1699};
1700
1701//===----------------------------------------------------------------------===//
1702// ModuleOp conversion
1703//===----------------------------------------------------------------------===//
1704
1705class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1706public:
1707 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1708
1709 LogicalResult
1710 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1711 ConversionPatternRewriter &rewriter) const override {
1712
1713 auto newModuleOp =
1714 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1715 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1716
1717 // Remove the terminator block that was automatically added by builder
1718 rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back());
1719 rewriter.eraseOp(op: spvModuleOp);
1720 return success();
1721 }
1722};
1723
1724//===----------------------------------------------------------------------===//
1725// VectorShuffleOp conversion
1726//===----------------------------------------------------------------------===//
1727
1728class VectorShufflePattern
1729 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1730public:
1731 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1732 LogicalResult
1733 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1734 ConversionPatternRewriter &rewriter) const override {
1735 Location loc = op.getLoc();
1736 auto components = adaptor.getComponents();
1737 auto vector1 = adaptor.getVector1();
1738 auto vector2 = adaptor.getVector2();
1739 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1740 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1741 if (vector1Size == vector2Size) {
1742 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1743 op, vector1, vector2,
1744 LLVM::convertArrayToIndices<int32_t>(components));
1745 return success();
1746 }
1747
1748 auto dstType = getTypeConverter()->convertType(op.getType());
1749 if (!dstType)
1750 return rewriter.notifyMatchFailure(op, "type conversion failed");
1751 auto scalarType = cast<VectorType>(dstType).getElementType();
1752 auto componentsArray = components.getValue();
1753 auto *context = rewriter.getContext();
1754 auto llvmI32Type = IntegerType::get(context, 32);
1755 Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType);
1756 for (unsigned i = 0; i < componentsArray.size(); i++) {
1757 if (!isa<IntegerAttr>(componentsArray[i]))
1758 return op.emitError("unable to support non-constant component");
1759
1760 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1761 if (indexVal == -1)
1762 continue;
1763
1764 int offsetVal = 0;
1765 Value baseVector = vector1;
1766 if (indexVal >= vector1Size) {
1767 offsetVal = vector1Size;
1768 baseVector = vector2;
1769 }
1770
1771 Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1772 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1773 Value index = rewriter.create<LLVM::ConstantOp>(
1774 loc, llvmI32Type,
1775 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1776
1777 auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1778 loc, scalarType, baseVector, index);
1779 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1780 extractOp, dstIndex);
1781 }
1782 rewriter.replaceOp(op, targetOp);
1783 return success();
1784 }
1785};
1786} // namespace
1787
1788//===----------------------------------------------------------------------===//
1789// Pattern population
1790//===----------------------------------------------------------------------===//
1791
1792void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
1793 spirv::ClientAPI clientAPI) {
1794 typeConverter.addConversion(callback: [&](spirv::ArrayType type) {
1795 return convertArrayType(type, converter&: typeConverter);
1796 });
1797 typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) {
1798 return convertPointerType(type, typeConverter, clientAPI);
1799 });
1800 typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) {
1801 return convertRuntimeArrayType(type, converter&: typeConverter);
1802 });
1803 typeConverter.addConversion(callback: [&](spirv::StructType type) {
1804 return convertStructType(type, converter: typeConverter);
1805 });
1806}
1807
1808void mlir::populateSPIRVToLLVMConversionPatterns(
1809 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1810 spirv::ClientAPI clientAPI) {
1811 patterns.add<
1812 // Arithmetic ops
1813 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1814 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1815 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1816 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1817 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1818 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1819 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1820 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1821 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1822 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1823 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1824 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1825 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1826
1827 // Bitwise ops
1828 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1829 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1830 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1831 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1832 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1833 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1834 NotPattern<spirv::NotOp>,
1835
1836 // Cast ops
1837 BitcastConversionPattern,
1838 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1839 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1840 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1841 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1842 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1843 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1844 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1845
1846 // Comparison ops
1847 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1848 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1849 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1850 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1851 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1852 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1853 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1854 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1855 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1856 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1857 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1858 LLVM::FCmpPredicate::uge>,
1859 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1860 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1861 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1862 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1863 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1864 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1865 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1866 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1867 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1868 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1869 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1870
1871 // Constant op
1872 ConstantScalarAndVectorPattern,
1873
1874 // Control Flow ops
1875 BranchConversionPattern, BranchConditionalConversionPattern,
1876 FunctionCallPattern, LoopPattern, SelectionPattern,
1877 ErasePattern<spirv::MergeOp>,
1878
1879 // Entry points and execution mode are handled separately.
1880 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1881
1882 // GLSL extended instruction set ops
1883 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1884 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1885 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1886 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1887 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1888 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1889 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1890 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1891 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1892 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1893 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1894 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1895 InverseSqrtPattern, TanPattern, TanhPattern,
1896
1897 // Logical ops
1898 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1899 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1900 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1901 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1902 NotPattern<spirv::LogicalNotOp>,
1903
1904 // Memory ops
1905 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1906 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1907
1908 // Miscellaneous ops
1909 CompositeExtractPattern, CompositeInsertPattern,
1910 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1911 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1912 VectorShufflePattern,
1913
1914 // Shift ops
1915 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1916 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1917 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1918
1919 // Return ops
1920 ReturnPattern, ReturnValuePattern,
1921
1922 // Barrier ops
1923 ControlBarrierPattern<spirv::ControlBarrierOp>,
1924 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1925 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1926
1927 // Group reduction operations
1928 GroupReducePattern<spirv::GroupIAddOp>,
1929 GroupReducePattern<spirv::GroupFAddOp>,
1930 GroupReducePattern<spirv::GroupFMinOp>,
1931 GroupReducePattern<spirv::GroupUMinOp>,
1932 GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1933 GroupReducePattern<spirv::GroupFMaxOp>,
1934 GroupReducePattern<spirv::GroupUMaxOp>,
1935 GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1936 GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1937 /*NonUniform=*/true>,
1938 GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1939 /*NonUniform=*/true>,
1940 GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1941 /*NonUniform=*/true>,
1942 GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1943 /*NonUniform=*/true>,
1944 GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1945 /*NonUniform=*/true>,
1946 GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1947 /*NonUniform=*/true>,
1948 GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1949 /*NonUniform=*/true>,
1950 GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1951 /*NonUniform=*/true>,
1952 GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1953 /*NonUniform=*/true>,
1954 GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1955 /*NonUniform=*/true>,
1956 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1957 /*NonUniform=*/true>,
1958 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1959 /*NonUniform=*/true>,
1960 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1961 /*NonUniform=*/true>,
1962 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1963 /*NonUniform=*/true>,
1964 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1965 /*NonUniform=*/true>,
1966 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1967 /*NonUniform=*/true>>(patterns.getContext(),
1968 typeConverter);
1969
1970 patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1971 typeConverter);
1972}
1973
1974void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1975 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1976 patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args: typeConverter);
1977}
1978
1979void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1980 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1981 patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args: typeConverter);
1982}
1983
1984//===----------------------------------------------------------------------===//
1985// Pre-conversion hooks
1986//===----------------------------------------------------------------------===//
1987
1988/// Hook for descriptor set and binding number encoding.
1989static constexpr StringRef kBinding = "binding";
1990static constexpr StringRef kDescriptorSet = "descriptor_set";
1991void mlir::encodeBindAttribute(ModuleOp module) {
1992 auto spvModules = module.getOps<spirv::ModuleOp>();
1993 for (auto spvModule : spvModules) {
1994 spvModule.walk([&](spirv::GlobalVariableOp op) {
1995 IntegerAttr descriptorSet =
1996 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1997 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1998 // For every global variable in the module, get the ones with descriptor
1999 // set and binding numbers.
2000 if (descriptorSet && binding) {
2001 // Encode these numbers into the variable's symbolic name. If the
2002 // SPIR-V module has a name, add it at the beginning.
2003 auto moduleAndName =
2004 spvModule.getName().has_value()
2005 ? spvModule.getName()->str() + "_" + op.getSymName().str()
2006 : op.getSymName().str();
2007 std::string name =
2008 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
2009 std::to_string(descriptorSet.getInt()),
2010 std::to_string(binding.getInt()));
2011 auto nameAttr = StringAttr::get(op->getContext(), name);
2012
2013 // Replace all symbol uses and set the new symbol name. Finally, remove
2014 // descriptor set and binding attributes.
2015 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
2016 op.emitError("unable to replace all symbol uses for ") << name;
2017 SymbolTable::setSymbolName(op, nameAttr);
2018 op->removeAttr(kDescriptorSet);
2019 op->removeAttr(kBinding);
2020 }
2021 });
2022 }
2023}
2024

// Source code of mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp, as rendered
// by the code browser provided by KDAB.