1//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20#include "mlir/IR/BuiltinOps.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/FormatVariadic.h"
25
26#define DEBUG_TYPE "spirv-to-llvm-pattern"
27
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// Utility functions
32//===----------------------------------------------------------------------===//
33
34/// Returns true if the given type is a signed integer or vector type.
35static bool isSignedIntegerOrVector(Type type) {
36 if (type.isSignedInteger())
37 return true;
38 if (auto vecType = dyn_cast<VectorType>(Val&: type))
39 return vecType.getElementType().isSignedInteger();
40 return false;
41}
42
43/// Returns true if the given type is an unsigned integer or vector type
44static bool isUnsignedIntegerOrVector(Type type) {
45 if (type.isUnsignedInteger())
46 return true;
47 if (auto vecType = dyn_cast<VectorType>(Val&: type))
48 return vecType.getElementType().isUnsignedInteger();
49 return false;
50}
51
52/// Returns the width of an integer or of the element type of an integer vector,
53/// if applicable.
54static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55 if (auto intType = dyn_cast<IntegerType>(Val&: type))
56 return intType.getWidth();
57 if (auto vecType = dyn_cast<VectorType>(Val&: type))
58 if (auto intType = dyn_cast<IntegerType>(Val: vecType.getElementType()))
59 return intType.getWidth();
60 return std::nullopt;
61}
62
63/// Returns the bit width of integer, float or vector of float or integer values
64static unsigned getBitWidth(Type type) {
65 assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66 "bitwidth is not supported for this type");
67 if (type.isIntOrFloat())
68 return type.getIntOrFloatBitWidth();
69 auto vecType = dyn_cast<VectorType>(Val&: type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
74}
75
76/// Returns the bit width of LLVMType integer or vector.
77static unsigned getLLVMTypeBitWidth(Type type) {
78 if (auto vecTy = dyn_cast<VectorType>(Val&: type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(Val&: type).getWidth();
81}
82
83/// Creates `IntegerAttribute` with all bits set for given type
84static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85 if (auto vecType = dyn_cast<VectorType>(Val&: type)) {
86 auto integerType = cast<IntegerType>(Val: vecType.getElementType());
87 return builder.getIntegerAttr(type: integerType, value: -1);
88 }
89 auto integerType = cast<IntegerType>(Val&: type);
90 return builder.getIntegerAttr(type: integerType, value: -1);
91}
92
93/// Creates `llvm.mlir.constant` with all bits set for the given type.
94static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95 PatternRewriter &rewriter) {
96 if (isa<VectorType>(Val: srcType)) {
97 return rewriter.create<LLVM::ConstantOp>(
98 location: loc, args&: dstType,
99 args: SplatElementsAttr::get(type: cast<ShapedType>(Val&: srcType),
100 values: minusOneIntegerAttribute(type: srcType, builder: rewriter)));
101 }
102 return rewriter.create<LLVM::ConstantOp>(
103 location: loc, args&: dstType, args: minusOneIntegerAttribute(type: srcType, builder: rewriter));
104}
105
106/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107static Value createFPConstant(Location loc, Type srcType, Type dstType,
108 PatternRewriter &rewriter, double value) {
109 if (auto vecType = dyn_cast<VectorType>(Val&: srcType)) {
110 auto floatType = cast<FloatType>(Val: vecType.getElementType());
111 return rewriter.create<LLVM::ConstantOp>(
112 location: loc, args&: dstType,
113 args: SplatElementsAttr::get(type: vecType,
114 values: rewriter.getFloatAttr(type: floatType, value)));
115 }
116 auto floatType = cast<FloatType>(Val&: srcType);
117 return rewriter.create<LLVM::ConstantOp>(
118 location: loc, args&: dstType, args: rewriter.getFloatAttr(type: floatType, value));
119}
120
121/// Utility function for bitfield ops:
122/// - `BitFieldInsert`
123/// - `BitFieldSExtract`
124/// - `BitFieldUExtract`
125/// Truncates or extends the value. If the bitwidth of the value is the same as
126/// `llvmType` bitwidth, the value remains unchanged.
127static Value optionallyTruncateOrExtend(Location loc, Value value,
128 Type llvmType,
129 PatternRewriter &rewriter) {
130 auto srcType = value.getType();
131 unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType);
132 unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType)
133 ? getLLVMTypeBitWidth(type: srcType)
134 : getBitWidth(type: srcType);
135
136 if (valueBitWidth < targetBitWidth)
137 return rewriter.create<LLVM::ZExtOp>(location: loc, args&: llvmType, args&: value);
138 // If the bit widths of `Count` and `Offset` are greater than the bit width
139 // of the target type, they are truncated. Truncation is safe since `Count`
140 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141 // both values can be expressed in 8 bits.
142 if (valueBitWidth > targetBitWidth)
143 return rewriter.create<LLVM::TruncOp>(location: loc, args&: llvmType, args&: value);
144 return value;
145}
146
147/// Broadcasts the value to vector with `numElements` number of elements.
148static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149 const TypeConverter &typeConverter,
150 ConversionPatternRewriter &rewriter) {
151 auto vectorType = VectorType::get(shape: numElements, elementType: toBroadcast.getType());
152 auto llvmVectorType = typeConverter.convertType(t: vectorType);
153 auto llvmI32Type = typeConverter.convertType(t: rewriter.getIntegerType(width: 32));
154 Value broadcasted = rewriter.create<LLVM::PoisonOp>(location: loc, args&: llvmVectorType);
155 for (unsigned i = 0; i < numElements; ++i) {
156 auto index = rewriter.create<LLVM::ConstantOp>(
157 location: loc, args&: llvmI32Type, args: rewriter.getI32IntegerAttr(value: i));
158 broadcasted = rewriter.create<LLVM::InsertElementOp>(
159 location: loc, args&: llvmVectorType, args&: broadcasted, args&: toBroadcast, args&: index);
160 }
161 return broadcasted;
162}
163
164/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166 const TypeConverter &typeConverter,
167 ConversionPatternRewriter &rewriter) {
168 if (auto vectorType = dyn_cast<VectorType>(Val&: srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter);
171 }
172 return value;
173}
174
175/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176/// `BitFieldUExtract`.
177/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178/// a vector type, construct a vector that has:
179/// - same number of elements as `Base`
180/// - each element has the type that is the same as the type of `Offset` or
181/// `Count`
182/// - each element has the same value as `Offset` or `Count`
183/// Then cast `Offset` and `Count` if their bit width is different
184/// from `Base` bit width.
185static Value processCountOrOffset(Location loc, Value value, Type srcType,
186 Type dstType, const TypeConverter &converter,
187 ConversionPatternRewriter &rewriter) {
188 Value broadcasted =
189 optionallyBroadcast(loc, value, srcType, typeConverter: converter, rewriter);
190 return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter);
191}
192
193/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194/// offset to LLVM struct. Otherwise, the conversion is not supported.
195static Type convertStructTypeWithOffset(spirv::StructType type,
196 const TypeConverter &converter) {
197 if (type != VulkanLayoutUtils::decorateType(structType: type))
198 return nullptr;
199
200 SmallVector<Type> elementsVector;
201 if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
202 return nullptr;
203 return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector,
204 /*isPacked=*/false);
205}
206
207/// Converts SPIR-V struct with no offset to packed LLVM struct.
208static Type convertStructTypePacked(spirv::StructType type,
209 const TypeConverter &converter) {
210 SmallVector<Type> elementsVector;
211 if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector)))
212 return nullptr;
213 return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector,
214 /*isPacked=*/true);
215}
216
217/// Creates LLVM dialect constant with the given value.
218static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
219 unsigned value) {
220 return rewriter.create<LLVM::ConstantOp>(
221 location: loc, args: IntegerType::get(context: rewriter.getContext(), width: 32),
222 args: rewriter.getIntegerAttr(type: rewriter.getI32Type(), value));
223}
224
225/// Utility for `spirv.Load` and `spirv.Store` conversion.
226static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227 ConversionPatternRewriter &rewriter,
228 const TypeConverter &typeConverter,
229 unsigned alignment, bool isVolatile,
230 bool isNonTemporal) {
231 if (auto loadOp = dyn_cast<spirv::LoadOp>(Val: op)) {
232 auto dstType = typeConverter.convertType(t: loadOp.getType());
233 if (!dstType)
234 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
235 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236 op: loadOp, args&: dstType, args: spirv::LoadOpAdaptor(operands).getPtr(), args&: alignment,
237 args&: isVolatile, args&: isNonTemporal);
238 return success();
239 }
240 auto storeOp = cast<spirv::StoreOp>(Val: op);
241 spirv::StoreOpAdaptor adaptor(operands);
242 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op: storeOp, args: adaptor.getValue(),
243 args: adaptor.getPtr(), args&: alignment,
244 args&: isVolatile, args&: isNonTemporal);
245 return success();
246}
247
248//===----------------------------------------------------------------------===//
249// Type conversion
250//===----------------------------------------------------------------------===//
251
252/// Converts SPIR-V array type to LLVM array. Natural stride (according to
253/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254/// when converting ops that manipulate array types.
255static std::optional<Type> convertArrayType(spirv::ArrayType type,
256 TypeConverter &converter) {
257 unsigned stride = type.getArrayStride();
258 Type elementType = type.getElementType();
259 auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261 return std::nullopt;
262
263 auto llvmElementType = converter.convertType(t: elementType);
264 unsigned numElements = type.getNumElements();
265 return LLVM::LLVMArrayType::get(elementType: llvmElementType, numElements);
266}
267
268/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
269/// modelled at the moment.
270static Type convertPointerType(spirv::PointerType type,
271 const TypeConverter &converter,
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
274 storageClassToAddressSpace(clientAPI, storageClass: type.getStorageClass());
275 return LLVM::LLVMPointerType::get(context: type.getContext(), addressSpace);
276}
277
278/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280/// no modelling of array stride at the moment.
281static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282 TypeConverter &converter) {
283 if (type.getArrayStride() != 0)
284 return std::nullopt;
285 auto elementType = converter.convertType(t: type.getElementType());
286 return LLVM::LLVMArrayType::get(elementType, numElements: 0);
287}
288
289/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
290/// member decorations. Also, only natural offset is supported.
291static Type convertStructType(spirv::StructType type,
292 const TypeConverter &converter) {
293 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
294 type.getMemberDecorations(memberDecorations);
295 if (!memberDecorations.empty())
296 return nullptr;
297 if (type.hasOffset())
298 return convertStructTypeWithOffset(type, converter);
299 return convertStructTypePacked(type, converter);
300}
301
302//===----------------------------------------------------------------------===//
303// Operation conversion
304//===----------------------------------------------------------------------===//
305
306namespace {
307
308class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309public:
310 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
311
312 LogicalResult
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315 auto dstType =
316 getTypeConverter()->convertType(t: op.getComponentPtr().getType());
317 if (!dstType)
318 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
319 // To use GEP we need to add a first 0 index to go through the pointer.
320 auto indices = llvm::to_vector<4>(Range: adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(t: indexType);
323 if (!llvmIndexType)
324 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
325 Value zero = rewriter.create<LLVM::ConstantOp>(
326 location: op.getLoc(), args&: llvmIndexType, args: rewriter.getIntegerAttr(type: indexType, value: 0));
327 indices.insert(I: indices.begin(), Elt: zero);
328
329 auto elementType = getTypeConverter()->convertType(
330 t: cast<spirv::PointerType>(Val: op.getBasePtr().getType()).getPointeeType());
331 if (!elementType)
332 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
333 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, args&: dstType, args&: elementType,
334 args: adaptor.getBasePtr(), args&: indices);
335 return success();
336 }
337};
338
339class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
340public:
341 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
342
343 LogicalResult
344 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter) const override {
346 auto dstType = getTypeConverter()->convertType(t: op.getPointer().getType());
347 if (!dstType)
348 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
349 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, args&: dstType,
350 args: op.getVariable());
351 return success();
352 }
353};
354
355class BitFieldInsertPattern
356 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
357public:
358 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
359
360 LogicalResult
361 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const override {
363 auto srcType = op.getType();
364 auto dstType = getTypeConverter()->convertType(t: srcType);
365 if (!dstType)
366 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
367 Location loc = op.getLoc();
368
369 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
370 Value offset = processCountOrOffset(loc, value: op.getOffset(), srcType, dstType,
371 converter: *getTypeConverter(), rewriter);
372 Value count = processCountOrOffset(loc, value: op.getCount(), srcType, dstType,
373 converter: *getTypeConverter(), rewriter);
374
375 // Create a mask with bits set outside [Offset, Offset + Count - 1].
376 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
377 Value maskShiftedByCount =
378 rewriter.create<LLVM::ShlOp>(location: loc, args&: dstType, args&: minusOne, args&: count);
379 Value negated = rewriter.create<LLVM::XOrOp>(location: loc, args&: dstType,
380 args&: maskShiftedByCount, args&: minusOne);
381 Value maskShiftedByCountAndOffset =
382 rewriter.create<LLVM::ShlOp>(location: loc, args&: dstType, args&: negated, args&: offset);
383 Value mask = rewriter.create<LLVM::XOrOp>(
384 location: loc, args&: dstType, args&: maskShiftedByCountAndOffset, args&: minusOne);
385
386 // Extract unchanged bits from the `Base` that are outside of
387 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
388 Value baseAndMask =
389 rewriter.create<LLVM::AndOp>(location: loc, args&: dstType, args: op.getBase(), args&: mask);
390 Value insertShiftedByOffset =
391 rewriter.create<LLVM::ShlOp>(location: loc, args&: dstType, args: op.getInsert(), args&: offset);
392 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, args&: dstType, args&: baseAndMask,
393 args&: insertShiftedByOffset);
394 return success();
395 }
396};
397
398/// Converts SPIR-V ConstantOp with scalar or vector type.
399class ConstantScalarAndVectorPattern
400 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
401public:
402 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
403
404 LogicalResult
405 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
406 ConversionPatternRewriter &rewriter) const override {
407 auto srcType = constOp.getType();
408 if (!isa<VectorType>(Val: srcType) && !srcType.isIntOrFloat())
409 return failure();
410
411 auto dstType = getTypeConverter()->convertType(t: srcType);
412 if (!dstType)
413 return rewriter.notifyMatchFailure(arg&: constOp, msg: "type conversion failed");
414
415 // SPIR-V constant can be a signed/unsigned integer, which has to be
416 // casted to signless integer when converting to LLVM dialect. Removing the
417 // sign bit may have unexpected behaviour. However, it is better to handle
418 // it case-by-case, given that the purpose of the conversion is not to
419 // cover all possible corner cases.
420 if (isSignedIntegerOrVector(type: srcType) ||
421 isUnsignedIntegerOrVector(type: srcType)) {
422 auto signlessType = rewriter.getIntegerType(width: getBitWidth(type: srcType));
423
424 if (isa<VectorType>(Val: srcType)) {
425 auto dstElementsAttr = cast<DenseIntElementsAttr>(Val: constOp.getValue());
426 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
427 op: constOp, args&: dstType,
428 args: dstElementsAttr.mapValues(
429 newElementType: signlessType, mapping: [&](const APInt &value) { return value; }));
430 return success();
431 }
432 auto srcAttr = cast<IntegerAttr>(Val: constOp.getValue());
433 auto dstAttr = rewriter.getIntegerAttr(type: signlessType, value: srcAttr.getValue());
434 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op: constOp, args&: dstType, args&: dstAttr);
435 return success();
436 }
437 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
438 op: constOp, args&: dstType, args: adaptor.getOperands(), args: constOp->getAttrs());
439 return success();
440 }
441};
442
443class BitFieldSExtractPattern
444 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
445public:
446 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
447
448 LogicalResult
449 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
450 ConversionPatternRewriter &rewriter) const override {
451 auto srcType = op.getType();
452 auto dstType = getTypeConverter()->convertType(t: srcType);
453 if (!dstType)
454 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
455 Location loc = op.getLoc();
456
457 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
458 Value offset = processCountOrOffset(loc, value: op.getOffset(), srcType, dstType,
459 converter: *getTypeConverter(), rewriter);
460 Value count = processCountOrOffset(loc, value: op.getCount(), srcType, dstType,
461 converter: *getTypeConverter(), rewriter);
462
463 // Create a constant that holds the size of the `Base`.
464 IntegerType integerType;
465 if (auto vecType = dyn_cast<VectorType>(Val&: srcType))
466 integerType = cast<IntegerType>(Val: vecType.getElementType());
467 else
468 integerType = cast<IntegerType>(Val&: srcType);
469
470 auto baseSize = rewriter.getIntegerAttr(type: integerType, value: getBitWidth(type: srcType));
471 Value size =
472 isa<VectorType>(Val: srcType)
473 ? rewriter.create<LLVM::ConstantOp>(
474 location: loc, args&: dstType,
475 args: SplatElementsAttr::get(type: cast<ShapedType>(Val&: srcType), values: baseSize))
476 : rewriter.create<LLVM::ConstantOp>(location: loc, args&: dstType, args&: baseSize);
477
478 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
479 // at Offset + Count - 1 is the most significant bit now.
480 Value countPlusOffset =
481 rewriter.create<LLVM::AddOp>(location: loc, args&: dstType, args&: count, args&: offset);
482 Value amountToShiftLeft =
483 rewriter.create<LLVM::SubOp>(location: loc, args&: dstType, args&: size, args&: countPlusOffset);
484 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
485 location: loc, args&: dstType, args: op.getBase(), args&: amountToShiftLeft);
486
487 // Shift the result right, filling the bits with the sign bit.
488 Value amountToShiftRight =
489 rewriter.create<LLVM::AddOp>(location: loc, args&: dstType, args&: offset, args&: amountToShiftLeft);
490 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, args&: dstType, args&: baseShiftedLeft,
491 args&: amountToShiftRight);
492 return success();
493 }
494};
495
496class BitFieldUExtractPattern
497 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
498public:
499 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
500
501 LogicalResult
502 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
503 ConversionPatternRewriter &rewriter) const override {
504 auto srcType = op.getType();
505 auto dstType = getTypeConverter()->convertType(t: srcType);
506 if (!dstType)
507 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
508 Location loc = op.getLoc();
509
510 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
511 Value offset = processCountOrOffset(loc, value: op.getOffset(), srcType, dstType,
512 converter: *getTypeConverter(), rewriter);
513 Value count = processCountOrOffset(loc, value: op.getCount(), srcType, dstType,
514 converter: *getTypeConverter(), rewriter);
515
516 // Create a mask with bits set at [0, Count - 1].
517 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
518 Value maskShiftedByCount =
519 rewriter.create<LLVM::ShlOp>(location: loc, args&: dstType, args&: minusOne, args&: count);
520 Value mask = rewriter.create<LLVM::XOrOp>(location: loc, args&: dstType, args&: maskShiftedByCount,
521 args&: minusOne);
522
523 // Shift `Base` by `Offset` and apply the mask on it.
524 Value shiftedBase =
525 rewriter.create<LLVM::LShrOp>(location: loc, args&: dstType, args: op.getBase(), args&: offset);
526 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, args&: dstType, args&: shiftedBase, args&: mask);
527 return success();
528 }
529};
530
531class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
532public:
533 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
534
535 LogicalResult
536 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
537 ConversionPatternRewriter &rewriter) const override {
538 rewriter.replaceOpWithNewOp<LLVM::BrOp>(op: branchOp, args: adaptor.getOperands(),
539 args: branchOp.getTarget());
540 return success();
541 }
542};
543
544class BranchConditionalConversionPattern
545 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
546public:
547 using SPIRVToLLVMConversion<
548 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
549
550 LogicalResult
551 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
552 ConversionPatternRewriter &rewriter) const override {
553 // If branch weights exist, map them to 32-bit integer vector.
554 DenseI32ArrayAttr branchWeights = nullptr;
555 if (auto weights = op.getBranchWeights()) {
556 SmallVector<int32_t> weightValues;
557 for (auto weight : weights->getAsRange<IntegerAttr>())
558 weightValues.push_back(Elt: weight.getInt());
559 branchWeights = DenseI32ArrayAttr::get(context: getContext(), content: weightValues);
560 }
561
562 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
563 op, args: op.getCondition(), args: op.getTrueBlockArguments(),
564 args: op.getFalseBlockArguments(), args&: branchWeights, args: op.getTrueBlock(),
565 args: op.getFalseBlock());
566 return success();
567 }
568};
569
570/// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
571/// type is an aggregate type (struct or array). Otherwise, converts to
572/// `llvm.extractelement` that operates on vectors.
573class CompositeExtractPattern
574 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
575public:
576 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
577
578 LogicalResult
579 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
580 ConversionPatternRewriter &rewriter) const override {
581 auto dstType = this->getTypeConverter()->convertType(t: op.getType());
582 if (!dstType)
583 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
584
585 Type containerType = op.getComposite().getType();
586 if (isa<VectorType>(Val: containerType)) {
587 Location loc = op.getLoc();
588 IntegerAttr value = cast<IntegerAttr>(Val: op.getIndices()[0]);
589 Value index = createI32ConstantOf(loc, rewriter, value: value.getInt());
590 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
591 op, args&: dstType, args: adaptor.getComposite(), args&: index);
592 return success();
593 }
594
595 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
596 op, args: adaptor.getComposite(),
597 args: LLVM::convertArrayToIndices(attrs: op.getIndices()));
598 return success();
599 }
600};
601
602/// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
603/// type is an aggregate type (struct or array). Otherwise, converts to
604/// `llvm.insertelement` that operates on vectors.
605class CompositeInsertPattern
606 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
607public:
608 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
609
610 LogicalResult
611 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
612 ConversionPatternRewriter &rewriter) const override {
613 auto dstType = this->getTypeConverter()->convertType(t: op.getType());
614 if (!dstType)
615 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
616
617 Type containerType = op.getComposite().getType();
618 if (isa<VectorType>(Val: containerType)) {
619 Location loc = op.getLoc();
620 IntegerAttr value = cast<IntegerAttr>(Val: op.getIndices()[0]);
621 Value index = createI32ConstantOf(loc, rewriter, value: value.getInt());
622 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
623 op, args&: dstType, args: adaptor.getComposite(), args: adaptor.getObject(), args&: index);
624 return success();
625 }
626
627 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
628 op, args: adaptor.getComposite(), args: adaptor.getObject(),
629 args: LLVM::convertArrayToIndices(attrs: op.getIndices()));
630 return success();
631 }
632};
633
634/// Converts SPIR-V operations that have straightforward LLVM equivalent
635/// into LLVM dialect operations.
636template <typename SPIRVOp, typename LLVMOp>
637class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
638public:
639 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
640
641 LogicalResult
642 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
643 ConversionPatternRewriter &rewriter) const override {
644 auto dstType = this->getTypeConverter()->convertType(op.getType());
645 if (!dstType)
646 return rewriter.notifyMatchFailure(op, "type conversion failed");
647 rewriter.template replaceOpWithNewOp<LLVMOp>(
648 op, dstType, adaptor.getOperands(), op->getAttrs());
649 return success();
650 }
651};
652
653/// Converts `spirv.ExecutionMode` into a global struct constant that holds
654/// execution mode information.
655class ExecutionModePattern
656 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
657public:
658 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
659
660 LogicalResult
661 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
662 ConversionPatternRewriter &rewriter) const override {
663 // First, create the global struct's name that would be associated with
664 // this entry point's execution mode. We set it to be:
665 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
666 ModuleOp module = op->getParentOfType<ModuleOp>();
667 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
668 std::string moduleName;
669 if (module.getName().has_value())
670 moduleName = "_" + module.getName()->str();
671 else
672 moduleName = "";
673 std::string executionModeInfoName = llvm::formatv(
674 Fmt: "__spv_{0}_{1}_execution_mode_info_{2}", Vals&: moduleName, Vals: op.getFn().str(),
675 Vals: static_cast<uint32_t>(executionModeAttr.getValue()));
676
677 MLIRContext *context = rewriter.getContext();
678 OpBuilder::InsertionGuard guard(rewriter);
679 rewriter.setInsertionPointToStart(module.getBody());
680
681 // Create a struct type, corresponding to the C struct below.
682 // struct {
683 // int32_t executionMode;
684 // int32_t values[]; // optional values
685 // };
686 auto llvmI32Type = IntegerType::get(context, width: 32);
687 SmallVector<Type, 2> fields;
688 fields.push_back(Elt: llvmI32Type);
689 ArrayAttr values = op.getValues();
690 if (!values.empty()) {
691 auto arrayType = LLVM::LLVMArrayType::get(elementType: llvmI32Type, numElements: values.size());
692 fields.push_back(Elt: arrayType);
693 }
694 auto structType = LLVM::LLVMStructType::getLiteral(context, types: fields);
695
696 // Create `llvm.mlir.global` with initializer region containing one block.
697 auto global = rewriter.create<LLVM::GlobalOp>(
698 location: UnknownLoc::get(context), args&: structType, /*isConstant=*/args: true,
699 args: LLVM::Linkage::External, args&: executionModeInfoName, args: Attribute(),
700 /*alignment=*/args: 0);
701 Location loc = global.getLoc();
702 Region &region = global.getInitializerRegion();
703 Block *block = rewriter.createBlock(parent: &region);
704
705 // Initialize the struct and set the execution mode value.
706 rewriter.setInsertionPointToStart(block);
707 Value structValue = rewriter.create<LLVM::PoisonOp>(location: loc, args&: structType);
708 Value executionMode = rewriter.create<LLVM::ConstantOp>(
709 location: loc, args&: llvmI32Type,
710 args: rewriter.getI32IntegerAttr(
711 value: static_cast<uint32_t>(executionModeAttr.getValue())));
712 structValue = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: structValue,
713 args&: executionMode, args: 0);
714
715 // Insert extra operands if they exist into execution mode info struct.
716 for (unsigned i = 0, e = values.size(); i < e; ++i) {
717 auto attr = values.getValue()[i];
718 Value entry = rewriter.create<LLVM::ConstantOp>(location: loc, args&: llvmI32Type, args&: attr);
719 structValue = rewriter.create<LLVM::InsertValueOp>(
720 location: loc, args&: structValue, args&: entry, args: ArrayRef<int64_t>({1, i}));
721 }
722 rewriter.create<LLVM::ReturnOp>(location: loc, args: ArrayRef<Value>({structValue}));
723 rewriter.eraseOp(op);
724 return success();
725 }
726};
727
728/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
729/// global returns a pointer, whereas in LLVM dialect the global holds an actual
730/// value. This difference is handled by `spirv.mlir.addressof` and
731/// `llvm.mlir.addressof`ops that both return a pointer.
732class GlobalVariablePattern
733 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
734public:
735 template <typename... Args>
736 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
737 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
738 std::forward<Args>(args)...),
739 clientAPI(clientAPI) {}
740
741 LogicalResult
742 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
743 ConversionPatternRewriter &rewriter) const override {
744 // Currently, there is no support of initialization with a constant value in
745 // SPIR-V dialect. Specialization constants are not considered as well.
746 if (op.getInitializer())
747 return failure();
748
749 auto srcType = cast<spirv::PointerType>(Val: op.getType());
750 auto dstType = getTypeConverter()->convertType(t: srcType.getPointeeType());
751 if (!dstType)
752 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
753
754 // Limit conversion to the current invocation only or `StorageBuffer`
755 // required by SPIR-V runner.
756 // This is okay because multiple invocations are not supported yet.
757 auto storageClass = srcType.getStorageClass();
758 switch (storageClass) {
759 case spirv::StorageClass::Input:
760 case spirv::StorageClass::Private:
761 case spirv::StorageClass::Output:
762 case spirv::StorageClass::StorageBuffer:
763 case spirv::StorageClass::UniformConstant:
764 break;
765 default:
766 return failure();
767 }
768
769 // LLVM dialect spec: "If the global value is a constant, storing into it is
770 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
771 // storage class that is read-only.
772 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
773 (storageClass == spirv::StorageClass::UniformConstant);
774 // SPIR-V spec: "By default, functions and global variables are private to a
775 // module and cannot be accessed by other modules. However, a module may be
776 // written to export or import functions and global (module scope)
777 // variables.". Therefore, map 'Private' storage class to private linkage,
778 // 'Input' and 'Output' to external linkage.
779 auto linkage = storageClass == spirv::StorageClass::Private
780 ? LLVM::Linkage::Private
781 : LLVM::Linkage::External;
782 StringAttr locationAttrName = op.getLocationAttrName();
783 IntegerAttr locationAttr = op.getLocationAttr();
784 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
785 op, args&: dstType, args&: isConstant, args&: linkage, args: op.getSymName(), args: Attribute(),
786 /*alignment=*/args: 0, args: storageClassToAddressSpace(clientAPI, storageClass));
787
788 // Attach location attribute if applicable
789 if (locationAttr)
790 newGlobalOp->setAttr(name: locationAttrName, value: locationAttr);
791
792 return success();
793 }
794
795private:
796 spirv::ClientAPI clientAPI;
797};
798
799/// Converts SPIR-V cast ops that do not have straightforward LLVM
800/// equivalent in LLVM dialect.
801template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
802class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
803public:
804 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
805
806 LogicalResult
807 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
808 ConversionPatternRewriter &rewriter) const override {
809
810 Type fromType = op.getOperand().getType();
811 Type toType = op.getType();
812
813 auto dstType = this->getTypeConverter()->convertType(toType);
814 if (!dstType)
815 return rewriter.notifyMatchFailure(op, "type conversion failed");
816
817 if (getBitWidth(type: fromType) < getBitWidth(type: toType)) {
818 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
819 adaptor.getOperands());
820 return success();
821 }
822 if (getBitWidth(type: fromType) > getBitWidth(type: toType)) {
823 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
824 adaptor.getOperands());
825 return success();
826 }
827 return failure();
828 }
829};
830
831class FunctionCallPattern
832 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
833public:
834 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
835
836 LogicalResult
837 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838 ConversionPatternRewriter &rewriter) const override {
839 if (callOp.getNumResults() == 0) {
840 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841 op: callOp, args: TypeRange(), args: adaptor.getOperands(), args: callOp->getAttrs());
842 newOp.getProperties().operandSegmentSizes = {
843 static_cast<int32_t>(adaptor.getOperands().size()), 0};
844 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr(values: {});
845 return success();
846 }
847
848 // Function returns a single result.
849 auto dstType = getTypeConverter()->convertType(t: callOp.getType(i: 0));
850 if (!dstType)
851 return rewriter.notifyMatchFailure(arg&: callOp, msg: "type conversion failed");
852 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
853 op: callOp, args&: dstType, args: adaptor.getOperands(), args: callOp->getAttrs());
854 newOp.getProperties().operandSegmentSizes = {
855 static_cast<int32_t>(adaptor.getOperands().size()), 0};
856 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr(values: {});
857 return success();
858 }
859};
860
861/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
862template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
863class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
864public:
865 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
866
867 LogicalResult
868 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
869 ConversionPatternRewriter &rewriter) const override {
870
871 auto dstType = this->getTypeConverter()->convertType(op.getType());
872 if (!dstType)
873 return rewriter.notifyMatchFailure(op, "type conversion failed");
874
875 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
876 op, dstType, predicate, op.getOperand1(), op.getOperand2());
877 return success();
878 }
879};
880
881/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
882template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
883class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
884public:
885 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
886
887 LogicalResult
888 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
889 ConversionPatternRewriter &rewriter) const override {
890
891 auto dstType = this->getTypeConverter()->convertType(op.getType());
892 if (!dstType)
893 return rewriter.notifyMatchFailure(op, "type conversion failed");
894
895 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
896 op, dstType, predicate, op.getOperand1(), op.getOperand2());
897 return success();
898 }
899};
900
901class InverseSqrtPattern
902 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
903public:
904 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
905
906 LogicalResult
907 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
908 ConversionPatternRewriter &rewriter) const override {
909 auto srcType = op.getType();
910 auto dstType = getTypeConverter()->convertType(t: srcType);
911 if (!dstType)
912 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
913
914 Location loc = op.getLoc();
915 Value one = createFPConstant(loc, srcType, dstType, rewriter, value: 1.0);
916 Value sqrt = rewriter.create<LLVM::SqrtOp>(location: loc, args&: dstType, args: op.getOperand());
917 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, args&: dstType, args&: one, args&: sqrt);
918 return success();
919 }
920};
921
922/// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
923template <typename SPIRVOp>
924class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
925public:
926 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
927
928 LogicalResult
929 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
930 ConversionPatternRewriter &rewriter) const override {
931 if (!op.getMemoryAccess()) {
932 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
933 *this->getTypeConverter(), /*alignment=*/0,
934 /*isVolatile=*/false,
935 /*isNonTemporal=*/false);
936 }
937 auto memoryAccess = *op.getMemoryAccess();
938 switch (memoryAccess) {
939 case spirv::MemoryAccess::Aligned:
940 case spirv::MemoryAccess::None:
941 case spirv::MemoryAccess::Nontemporal:
942 case spirv::MemoryAccess::Volatile: {
943 unsigned alignment =
944 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
945 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
946 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
947 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
948 *this->getTypeConverter(), alignment,
949 isVolatile, isNonTemporal);
950 }
951 default:
952 // There is no support of other memory access attributes.
953 return failure();
954 }
955 }
956};
957
958/// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
959template <typename SPIRVOp>
960class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
961public:
962 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
963
964 LogicalResult
965 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
966 ConversionPatternRewriter &rewriter) const override {
967 auto srcType = notOp.getType();
968 auto dstType = this->getTypeConverter()->convertType(srcType);
969 if (!dstType)
970 return rewriter.notifyMatchFailure(notOp, "type conversion failed");
971
972 Location loc = notOp.getLoc();
973 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
974 auto mask =
975 isa<VectorType>(srcType)
976 ? rewriter.create<LLVM::ConstantOp>(
977 loc, dstType,
978 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
979 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
980 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
981 notOp.getOperand(), mask);
982 return success();
983 }
984};
985
986/// A template pattern that erases the given `SPIRVOp`.
987template <typename SPIRVOp>
988class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
989public:
990 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
991
992 LogicalResult
993 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
994 ConversionPatternRewriter &rewriter) const override {
995 rewriter.eraseOp(op);
996 return success();
997 }
998};
999
1000class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1001public:
1002 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1003
1004 LogicalResult
1005 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1006 ConversionPatternRewriter &rewriter) const override {
1007 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op: returnOp, args: ArrayRef<Type>(),
1008 args: ArrayRef<Value>());
1009 return success();
1010 }
1011};
1012
1013class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1014public:
1015 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1016
1017 LogicalResult
1018 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1019 ConversionPatternRewriter &rewriter) const override {
1020 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op: returnValueOp, args: ArrayRef<Type>(),
1021 args: adaptor.getOperands());
1022 return success();
1023 }
1024};
1025
1026static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1027 StringRef name,
1028 ArrayRef<Type> paramTypes,
1029 Type resultType,
1030 bool convergent = true) {
1031 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032 Val: SymbolTable::lookupSymbolIn(op: symbolTable, symbol: name));
1033 if (func)
1034 return func;
1035
1036 OpBuilder b(symbolTable->getRegion(index: 0));
1037 func = b.create<LLVM::LLVMFuncOp>(
1038 location: symbolTable->getLoc(), args&: name,
1039 args: LLVM::LLVMFunctionType::get(result: resultType, arguments: paramTypes));
1040 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041 func.setConvergent(convergent);
1042 func.setNoUnwind(true);
1043 func.setWillReturn(true);
1044 return func;
1045}
1046
1047static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1048 LLVM::LLVMFuncOp func,
1049 ValueRange args) {
1050 auto call = builder.create<LLVM::CallOp>(location: loc, args&: func, args);
1051 call.setCConv(func.getCConv());
1052 call.setConvergentAttr(func.getConvergentAttr());
1053 call.setNoUnwindAttr(func.getNoUnwindAttr());
1054 call.setWillReturnAttr(func.getWillReturnAttr());
1055 return call;
1056}
1057
1058template <typename BarrierOpTy>
1059class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1060public:
1061 using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1062
1063 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1064
1065 static constexpr StringRef getFuncName();
1066
1067 LogicalResult
1068 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1069 ConversionPatternRewriter &rewriter) const override {
1070 constexpr StringRef funcName = getFuncName();
1071 Operation *symbolTable =
1072 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1073
1074 Type i32 = rewriter.getI32Type();
1075
1076 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1077 LLVM::LLVMFuncOp func =
1078 lookupOrCreateSPIRVFn(symbolTable, name: funcName, paramTypes: {i32, i32, i32}, resultType: voidTy);
1079
1080 Location loc = controlBarrierOp->getLoc();
1081 Value execution = rewriter.create<LLVM::ConstantOp>(
1082 location: loc, args&: i32, args: static_cast<int32_t>(adaptor.getExecutionScope()));
1083 Value memory = rewriter.create<LLVM::ConstantOp>(
1084 location: loc, args&: i32, args: static_cast<int32_t>(adaptor.getMemoryScope()));
1085 Value semantics = rewriter.create<LLVM::ConstantOp>(
1086 location: loc, args&: i32, args: static_cast<int32_t>(adaptor.getMemorySemantics()));
1087
1088 auto call = createSPIRVBuiltinCall(loc, builder&: rewriter, func,
1089 args: {execution, memory, semantics});
1090
1091 rewriter.replaceOp(controlBarrierOp, call);
1092 return success();
1093 }
1094};
1095
1096namespace {
1097
1098StringRef getTypeMangling(Type type, bool isSigned) {
1099 return llvm::TypeSwitch<Type, StringRef>(type)
1100 .Case<Float16Type>(caseFn: [](auto) { return "Dh"; })
1101 .Case<Float32Type>(caseFn: [](auto) { return "f"; })
1102 .Case<Float64Type>(caseFn: [](auto) { return "d"; })
1103 .Case<IntegerType>(caseFn: [isSigned](IntegerType intTy) {
1104 switch (intTy.getWidth()) {
1105 case 1:
1106 return "b";
1107 case 8:
1108 return (isSigned) ? "a" : "c";
1109 case 16:
1110 return (isSigned) ? "s" : "t";
1111 case 32:
1112 return (isSigned) ? "i" : "j";
1113 case 64:
1114 return (isSigned) ? "l" : "m";
1115 default:
1116 llvm_unreachable("Unsupported integer width");
1117 }
1118 })
1119 .Default(defaultFn: [](auto) {
1120 llvm_unreachable("No mangling defined");
1121 return "";
1122 });
1123}
1124
1125template <typename ReduceOp>
1126constexpr StringLiteral getGroupFuncName();
1127
1128template <>
1129constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1130 return "_Z17__spirv_GroupIAddii";
1131}
1132template <>
1133constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1134 return "_Z17__spirv_GroupFAddii";
1135}
1136template <>
1137constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1138 return "_Z17__spirv_GroupSMinii";
1139}
1140template <>
1141constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1142 return "_Z17__spirv_GroupUMinii";
1143}
1144template <>
1145constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1146 return "_Z17__spirv_GroupFMinii";
1147}
1148template <>
1149constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1150 return "_Z17__spirv_GroupSMaxii";
1151}
1152template <>
1153constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1154 return "_Z17__spirv_GroupUMaxii";
1155}
1156template <>
1157constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1158 return "_Z17__spirv_GroupFMaxii";
1159}
1160template <>
1161constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1162 return "_Z27__spirv_GroupNonUniformIAddii";
1163}
1164template <>
1165constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1166 return "_Z27__spirv_GroupNonUniformFAddii";
1167}
1168template <>
1169constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1170 return "_Z27__spirv_GroupNonUniformIMulii";
1171}
1172template <>
1173constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1174 return "_Z27__spirv_GroupNonUniformFMulii";
1175}
1176template <>
1177constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1178 return "_Z27__spirv_GroupNonUniformSMinii";
1179}
1180template <>
1181constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1182 return "_Z27__spirv_GroupNonUniformUMinii";
1183}
1184template <>
1185constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1186 return "_Z27__spirv_GroupNonUniformFMinii";
1187}
1188template <>
1189constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1190 return "_Z27__spirv_GroupNonUniformSMaxii";
1191}
1192template <>
1193constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1194 return "_Z27__spirv_GroupNonUniformUMaxii";
1195}
1196template <>
1197constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1198 return "_Z27__spirv_GroupNonUniformFMaxii";
1199}
1200template <>
1201constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1202 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1203}
1204template <>
1205constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1206 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1207}
1208template <>
1209constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1210 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1211}
1212template <>
1213constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1214 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1215}
1216template <>
1217constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1218 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1219}
1220template <>
1221constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1222 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1223}
1224} // namespace
1225
1226template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1227class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1228public:
1229 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1230
1231 LogicalResult
1232 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1233 ConversionPatternRewriter &rewriter) const override {
1234
1235 Type retTy = op.getResult().getType();
1236 if (!retTy.isIntOrFloat()) {
1237 return failure();
1238 }
1239 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1240 funcName += getTypeMangling(type: retTy, isSigned: false);
1241
1242 Type i32Ty = rewriter.getI32Type();
1243 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1244 if constexpr (NonUniform) {
1245 if (adaptor.getClusterSize()) {
1246 funcName += "j";
1247 paramTypes.push_back(Elt: i32Ty);
1248 }
1249 }
1250
1251 Operation *symbolTable =
1252 op->template getParentWithTrait<OpTrait::SymbolTable>();
1253
1254 LLVM::LLVMFuncOp func =
1255 lookupOrCreateSPIRVFn(symbolTable, name: funcName, paramTypes, resultType: retTy);
1256
1257 Location loc = op.getLoc();
1258 Value scope = rewriter.create<LLVM::ConstantOp>(
1259 location: loc, args&: i32Ty, args: static_cast<int32_t>(adaptor.getExecutionScope()));
1260 Value groupOp = rewriter.create<LLVM::ConstantOp>(
1261 location: loc, args&: i32Ty, args: static_cast<int32_t>(adaptor.getGroupOperation()));
1262 SmallVector<Value> operands{scope, groupOp};
1263 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1264
1265 auto call = createSPIRVBuiltinCall(loc, builder&: rewriter, func, args: operands);
1266 rewriter.replaceOp(op, call);
1267 return success();
1268 }
1269};
1270
1271template <>
1272constexpr StringRef
1273ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1274 return "_Z22__spirv_ControlBarrieriii";
1275}
1276
1277template <>
1278constexpr StringRef
1279ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1280 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1281}
1282
1283template <>
1284constexpr StringRef
1285ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1286 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1287}
1288
1289/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1290/// should be reachable for conversion to succeed. The structure of the loop in
1291/// LLVM dialect will be the following:
1292///
1293/// +------------------------------------+
1294/// | <code before spirv.mlir.loop> |
1295/// | llvm.br ^header |
1296/// +------------------------------------+
1297/// |
1298/// +----------------+ |
1299/// | | |
1300/// | V V
1301/// | +------------------------------------+
1302/// | | ^header: |
1303/// | | <header code> |
1304/// | | llvm.cond_br %cond, ^body, ^exit |
1305/// | +------------------------------------+
1306/// | |
1307/// | |----------------------+
1308/// | | |
1309/// | V |
1310/// | +------------------------------------+ |
1311/// | | ^body: | |
1312/// | | <body code> | |
1313/// | | llvm.br ^continue | |
1314/// | +------------------------------------+ |
1315/// | | |
1316/// | V |
1317/// | +------------------------------------+ |
1318/// | | ^continue: | |
1319/// | | <continue code> | |
1320/// | | llvm.br ^header | |
1321/// | +------------------------------------+ |
1322/// | | |
1323/// +---------------+ +----------------------+
1324/// |
1325/// V
1326/// +------------------------------------+
1327/// | ^exit: |
1328/// | llvm.br ^remaining |
1329/// +------------------------------------+
1330/// |
1331/// V
1332/// +------------------------------------+
1333/// | ^remaining: |
1334/// | <code after spirv.mlir.loop> |
1335/// +------------------------------------+
1336///
1337class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1338public:
1339 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1340
1341 LogicalResult
1342 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1343 ConversionPatternRewriter &rewriter) const override {
1344 // There is no support of loop control at the moment.
1345 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1346 return failure();
1347
1348 // `spirv.mlir.loop` with empty region is redundant and should be erased.
1349 if (loopOp.getBody().empty()) {
1350 rewriter.eraseOp(op: loopOp);
1351 return success();
1352 }
1353
1354 Location loc = loopOp.getLoc();
1355
1356 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1357 // be used in `endBlock`.
1358 Block *currentBlock = rewriter.getBlock();
1359 auto position = Block::iterator(loopOp);
1360 Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position);
1361
1362 // Remove entry block and create a branch in the current block going to the
1363 // header block.
1364 Block *entryBlock = loopOp.getEntryBlock();
1365 assert(entryBlock->getOperations().size() == 1);
1366 auto brOp = dyn_cast<spirv::BranchOp>(Val&: entryBlock->getOperations().front());
1367 if (!brOp)
1368 return failure();
1369 Block *headerBlock = loopOp.getHeaderBlock();
1370 rewriter.setInsertionPointToEnd(currentBlock);
1371 rewriter.create<LLVM::BrOp>(location: loc, args: brOp.getBlockArguments(), args&: headerBlock);
1372 rewriter.eraseBlock(block: entryBlock);
1373
1374 // Branch from merge block to end block.
1375 Block *mergeBlock = loopOp.getMergeBlock();
1376 Operation *terminator = mergeBlock->getTerminator();
1377 ValueRange terminatorOperands = terminator->getOperands();
1378 rewriter.setInsertionPointToEnd(mergeBlock);
1379 rewriter.create<LLVM::BrOp>(location: loc, args&: terminatorOperands, args&: endBlock);
1380
1381 rewriter.inlineRegionBefore(region&: loopOp.getBody(), before: endBlock);
1382 rewriter.replaceOp(op: loopOp, newValues: endBlock->getArguments());
1383 return success();
1384 }
1385};
1386
1387/// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1388/// block. All blocks within selection should be reachable for conversion to
1389/// succeed.
1390class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1391public:
1392 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1393
1394 LogicalResult
1395 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1396 ConversionPatternRewriter &rewriter) const override {
1397 // There is no support for `Flatten` or `DontFlatten` selection control at
1398 // the moment. This are just compiler hints and can be performed during the
1399 // optimization passes.
1400 if (op.getSelectionControl() != spirv::SelectionControl::None)
1401 return failure();
1402
1403 // `spirv.mlir.selection` should have at least two blocks: one selection
1404 // header block and one merge block. If no blocks are present, or control
1405 // flow branches straight to merge block (two blocks are present), the op is
1406 // redundant and it is erased.
1407 if (op.getBody().getBlocks().size() <= 2) {
1408 rewriter.eraseOp(op);
1409 return success();
1410 }
1411
1412 Location loc = op.getLoc();
1413
1414 // Split the current block after `spirv.mlir.selection`. The remaining ops
1415 // will be used in `continueBlock`.
1416 auto *currentBlock = rewriter.getInsertionBlock();
1417 rewriter.setInsertionPointAfter(op);
1418 auto position = rewriter.getInsertionPoint();
1419 auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position);
1420
1421 // Extract conditional branch information from the header block. By SPIR-V
1422 // dialect spec, it should contain `spirv.BranchConditional` or
1423 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1424 // moment in the SPIR-V dialect. Remove this block when finished.
1425 auto *headerBlock = op.getHeaderBlock();
1426 assert(headerBlock->getOperations().size() == 1);
1427 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1428 Val&: headerBlock->getOperations().front());
1429 if (!condBrOp)
1430 return failure();
1431
1432 // Branch from merge block to continue block.
1433 auto *mergeBlock = op.getMergeBlock();
1434 Operation *terminator = mergeBlock->getTerminator();
1435 ValueRange terminatorOperands = terminator->getOperands();
1436 rewriter.setInsertionPointToEnd(mergeBlock);
1437 rewriter.create<LLVM::BrOp>(location: loc, args&: terminatorOperands, args&: continueBlock);
1438
1439 // Link current block to `true` and `false` blocks within the selection.
1440 Block *trueBlock = condBrOp.getTrueBlock();
1441 Block *falseBlock = condBrOp.getFalseBlock();
1442 rewriter.setInsertionPointToEnd(currentBlock);
1443 rewriter.create<LLVM::CondBrOp>(location: loc, args: condBrOp.getCondition(), args&: trueBlock,
1444 args: condBrOp.getTrueTargetOperands(),
1445 args&: falseBlock,
1446 args: condBrOp.getFalseTargetOperands());
1447
1448 rewriter.eraseBlock(block: headerBlock);
1449 rewriter.inlineRegionBefore(region&: op.getBody(), before: continueBlock);
1450 rewriter.replaceOp(op, newValues: continueBlock->getArguments());
1451 return success();
1452 }
1453};
1454
1455/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1456/// puts a restriction on `Shift` and `Base` to have the same bit width,
1457/// `Shift` is zero or sign extended to match this specification. Cases when
1458/// `Shift` bit width > `Base` bit width are considered to be illegal.
1459template <typename SPIRVOp, typename LLVMOp>
1460class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1461public:
1462 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1463
1464 LogicalResult
1465 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1466 ConversionPatternRewriter &rewriter) const override {
1467
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1469 if (!dstType)
1470 return rewriter.notifyMatchFailure(op, "type conversion failed");
1471
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1474
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1478 return success();
1479 }
1480
1481 std::optional<uint64_t> dstTypeWidth =
1482 getIntegerOrVectorElementWidth(dstType);
1483 std::optional<uint64_t> op2TypeWidth =
1484 getIntegerOrVectorElementWidth(type: op2Type);
1485
1486 if (!dstTypeWidth || !op2TypeWidth)
1487 return failure();
1488
1489 Location loc = op.getLoc();
1490 Value extended;
1491 if (op2TypeWidth < dstTypeWidth) {
1492 if (isUnsignedIntegerOrVector(type: op2Type)) {
1493 extended = rewriter.template create<LLVM::ZExtOp>(
1494 loc, dstType, adaptor.getOperand2());
1495 } else {
1496 extended = rewriter.template create<LLVM::SExtOp>(
1497 loc, dstType, adaptor.getOperand2());
1498 }
1499 } else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1501 } else {
1502 return failure();
1503 }
1504
1505 Value result = rewriter.template create<LLVMOp>(
1506 loc, dstType, adaptor.getOperand1(), extended);
1507 rewriter.replaceOp(op, result);
1508 return success();
1509 }
1510};
1511
1512class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1513public:
1514 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1515
1516 LogicalResult
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518 ConversionPatternRewriter &rewriter) const override {
1519 auto dstType = getTypeConverter()->convertType(t: tanOp.getType());
1520 if (!dstType)
1521 return rewriter.notifyMatchFailure(arg&: tanOp, msg: "type conversion failed");
1522
1523 Location loc = tanOp.getLoc();
1524 Value sin = rewriter.create<LLVM::SinOp>(location: loc, args&: dstType, args: tanOp.getOperand());
1525 Value cos = rewriter.create<LLVM::CosOp>(location: loc, args&: dstType, args: tanOp.getOperand());
1526 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op: tanOp, args&: dstType, args&: sin, args&: cos);
1527 return success();
1528 }
1529};
1530
1531/// Convert `spirv.Tanh` to
1532///
1533/// exp(2x) - 1
1534/// -----------
1535/// exp(2x) + 1
1536///
1537class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1538public:
1539 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1540
1541 LogicalResult
1542 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1543 ConversionPatternRewriter &rewriter) const override {
1544 auto srcType = tanhOp.getType();
1545 auto dstType = getTypeConverter()->convertType(t: srcType);
1546 if (!dstType)
1547 return rewriter.notifyMatchFailure(arg&: tanhOp, msg: "type conversion failed");
1548
1549 Location loc = tanhOp.getLoc();
1550 Value two = createFPConstant(loc, srcType, dstType, rewriter, value: 2.0);
1551 Value multiplied =
1552 rewriter.create<LLVM::FMulOp>(location: loc, args&: dstType, args&: two, args: tanhOp.getOperand());
1553 Value exponential = rewriter.create<LLVM::ExpOp>(location: loc, args&: dstType, args&: multiplied);
1554 Value one = createFPConstant(loc, srcType, dstType, rewriter, value: 1.0);
1555 Value numerator =
1556 rewriter.create<LLVM::FSubOp>(location: loc, args&: dstType, args&: exponential, args&: one);
1557 Value denominator =
1558 rewriter.create<LLVM::FAddOp>(location: loc, args&: dstType, args&: exponential, args&: one);
1559 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op: tanhOp, args&: dstType, args&: numerator,
1560 args&: denominator);
1561 return success();
1562 }
1563};
1564
1565class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1566public:
1567 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1568
1569 LogicalResult
1570 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1571 ConversionPatternRewriter &rewriter) const override {
1572 auto srcType = varOp.getType();
1573 // Initialization is supported for scalars and vectors only.
1574 auto pointerTo = cast<spirv::PointerType>(Val&: srcType).getPointeeType();
1575 auto init = varOp.getInitializer();
1576 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(Val: pointerTo))
1577 return failure();
1578
1579 auto dstType = getTypeConverter()->convertType(t: srcType);
1580 if (!dstType)
1581 return rewriter.notifyMatchFailure(arg&: varOp, msg: "type conversion failed");
1582
1583 Location loc = varOp.getLoc();
1584 Value size = createI32ConstantOf(loc, rewriter, value: 1);
1585 if (!init) {
1586 auto elementType = getTypeConverter()->convertType(t: pointerTo);
1587 if (!elementType)
1588 return rewriter.notifyMatchFailure(arg&: varOp, msg: "type conversion failed");
1589 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(op: varOp, args&: dstType, args&: elementType,
1590 args&: size);
1591 return success();
1592 }
1593 auto elementType = getTypeConverter()->convertType(t: pointerTo);
1594 if (!elementType)
1595 return rewriter.notifyMatchFailure(arg&: varOp, msg: "type conversion failed");
1596 Value allocated =
1597 rewriter.create<LLVM::AllocaOp>(location: loc, args&: dstType, args&: elementType, args&: size);
1598 rewriter.create<LLVM::StoreOp>(location: loc, args: adaptor.getInitializer(), args&: allocated);
1599 rewriter.replaceOp(op: varOp, newValues: allocated);
1600 return success();
1601 }
1602};
1603
1604//===----------------------------------------------------------------------===//
1605// BitcastOp conversion
1606//===----------------------------------------------------------------------===//
1607
1608class BitcastConversionPattern
1609 : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1610public:
1611 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1612
1613 LogicalResult
1614 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1615 ConversionPatternRewriter &rewriter) const override {
1616 auto dstType = getTypeConverter()->convertType(t: bitcastOp.getType());
1617 if (!dstType)
1618 return rewriter.notifyMatchFailure(arg&: bitcastOp, msg: "type conversion failed");
1619
1620 // LLVM's opaque pointers do not require bitcasts.
1621 if (isa<LLVM::LLVMPointerType>(Val: dstType)) {
1622 rewriter.replaceOp(op: bitcastOp, newValues: adaptor.getOperand());
1623 return success();
1624 }
1625
1626 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1627 op: bitcastOp, args&: dstType, args: adaptor.getOperands(), args: bitcastOp->getAttrs());
1628 return success();
1629 }
1630};
1631
1632//===----------------------------------------------------------------------===//
1633// FuncOp conversion
1634//===----------------------------------------------------------------------===//
1635
1636class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1637public:
1638 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1639
1640 LogicalResult
1641 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1642 ConversionPatternRewriter &rewriter) const override {
1643
1644 // Convert function signature. At the moment LLVMType converter is enough
1645 // for currently supported types.
1646 auto funcType = funcOp.getFunctionType();
1647 TypeConverter::SignatureConversion signatureConverter(
1648 funcType.getNumInputs());
1649 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1650 ->convertFunctionSignature(
1651 funcTy: funcType, /*isVariadic=*/false,
1652 /*useBarePtrCallConv=*/false, result&: signatureConverter);
1653 if (!llvmType)
1654 return failure();
1655
1656 // Create a new `LLVMFuncOp`
1657 Location loc = funcOp.getLoc();
1658 StringRef name = funcOp.getName();
1659 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(location: loc, args&: name, args&: llvmType);
1660
1661 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1662 MLIRContext *context = funcOp.getContext();
1663 switch (funcOp.getFunctionControl()) {
1664 case spirv::FunctionControl::Inline:
1665 newFuncOp.setAlwaysInline(true);
1666 break;
1667 case spirv::FunctionControl::DontInline:
1668 newFuncOp.setNoInline(true);
1669 break;
1670
1671#define DISPATCH(functionControl, llvmAttr) \
1672 case functionControl: \
1673 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1674 break;
1675
1676 DISPATCH(spirv::FunctionControl::Pure,
1677 StringAttr::get(context, "readonly"));
1678 DISPATCH(spirv::FunctionControl::Const,
1679 StringAttr::get(context, "readnone"));
1680
1681#undef DISPATCH
1682
1683 // Default: if `spirv::FunctionControl::None`, then no attributes are
1684 // needed.
1685 default:
1686 break;
1687 }
1688
1689 rewriter.inlineRegionBefore(region&: funcOp.getBody(), parent&: newFuncOp.getBody(),
1690 before: newFuncOp.end());
1691 if (failed(Result: rewriter.convertRegionTypes(
1692 region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter))) {
1693 return failure();
1694 }
1695 rewriter.eraseOp(op: funcOp);
1696 return success();
1697 }
1698};
1699
1700//===----------------------------------------------------------------------===//
1701// ModuleOp conversion
1702//===----------------------------------------------------------------------===//
1703
1704class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1705public:
1706 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1707
1708 LogicalResult
1709 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1710 ConversionPatternRewriter &rewriter) const override {
1711
1712 auto newModuleOp =
1713 rewriter.create<ModuleOp>(location: spvModuleOp.getLoc(), args: spvModuleOp.getName());
1714 rewriter.inlineRegionBefore(region&: spvModuleOp.getRegion(), before: newModuleOp.getBody());
1715
1716 // Remove the terminator block that was automatically added by builder
1717 rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back());
1718 rewriter.eraseOp(op: spvModuleOp);
1719 return success();
1720 }
1721};
1722
1723//===----------------------------------------------------------------------===//
1724// VectorShuffleOp conversion
1725//===----------------------------------------------------------------------===//
1726
1727class VectorShufflePattern
1728 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1729public:
1730 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1731 LogicalResult
1732 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1733 ConversionPatternRewriter &rewriter) const override {
1734 Location loc = op.getLoc();
1735 auto components = adaptor.getComponents();
1736 auto vector1 = adaptor.getVector1();
1737 auto vector2 = adaptor.getVector2();
1738 int vector1Size = cast<VectorType>(Val: vector1.getType()).getNumElements();
1739 int vector2Size = cast<VectorType>(Val: vector2.getType()).getNumElements();
1740 if (vector1Size == vector2Size) {
1741 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1742 op, args&: vector1, args&: vector2,
1743 args: LLVM::convertArrayToIndices<int32_t>(attrs: components));
1744 return success();
1745 }
1746
1747 auto dstType = getTypeConverter()->convertType(t: op.getType());
1748 if (!dstType)
1749 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
1750 auto scalarType = cast<VectorType>(Val&: dstType).getElementType();
1751 auto componentsArray = components.getValue();
1752 auto *context = rewriter.getContext();
1753 auto llvmI32Type = IntegerType::get(context, width: 32);
1754 Value targetOp = rewriter.create<LLVM::PoisonOp>(location: loc, args&: dstType);
1755 for (unsigned i = 0; i < componentsArray.size(); i++) {
1756 if (!isa<IntegerAttr>(Val: componentsArray[i]))
1757 return op.emitError(message: "unable to support non-constant component");
1758
1759 int indexVal = cast<IntegerAttr>(Val: componentsArray[i]).getInt();
1760 if (indexVal == -1)
1761 continue;
1762
1763 int offsetVal = 0;
1764 Value baseVector = vector1;
1765 if (indexVal >= vector1Size) {
1766 offsetVal = vector1Size;
1767 baseVector = vector2;
1768 }
1769
1770 Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1771 location: loc, args&: llvmI32Type, args: rewriter.getIntegerAttr(type: rewriter.getI32Type(), value: i));
1772 Value index = rewriter.create<LLVM::ConstantOp>(
1773 location: loc, args&: llvmI32Type,
1774 args: rewriter.getIntegerAttr(type: rewriter.getI32Type(), value: indexVal - offsetVal));
1775
1776 auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1777 location: loc, args&: scalarType, args&: baseVector, args&: index);
1778 targetOp = rewriter.create<LLVM::InsertElementOp>(location: loc, args&: dstType, args&: targetOp,
1779 args&: extractOp, args&: dstIndex);
1780 }
1781 rewriter.replaceOp(op, newValues: targetOp);
1782 return success();
1783 }
1784};
1785} // namespace
1786
1787//===----------------------------------------------------------------------===//
1788// Pattern population
1789//===----------------------------------------------------------------------===//
1790
1791void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
1792 spirv::ClientAPI clientAPI) {
1793 typeConverter.addConversion(callback: [&](spirv::ArrayType type) {
1794 return convertArrayType(type, converter&: typeConverter);
1795 });
1796 typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) {
1797 return convertPointerType(type, converter: typeConverter, clientAPI);
1798 });
1799 typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) {
1800 return convertRuntimeArrayType(type, converter&: typeConverter);
1801 });
1802 typeConverter.addConversion(callback: [&](spirv::StructType type) {
1803 return convertStructType(type, converter: typeConverter);
1804 });
1805}
1806
1807void mlir::populateSPIRVToLLVMConversionPatterns(
1808 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1809 spirv::ClientAPI clientAPI) {
1810 patterns.add<
1811 // Arithmetic ops
1812 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1813 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1814 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1815 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1816 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1817 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1818 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1819 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1820 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1821 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1822 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1823 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1824 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1825
1826 // Bitwise ops
1827 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1828 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1829 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1830 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1831 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1832 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1833 NotPattern<spirv::NotOp>,
1834
1835 // Cast ops
1836 BitcastConversionPattern,
1837 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1838 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1839 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1840 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1841 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1842 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1843 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1844
1845 // Comparison ops
1846 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1847 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1848 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1849 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1850 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1851 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1852 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1853 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1854 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1855 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1856 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1857 LLVM::FCmpPredicate::uge>,
1858 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1859 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1860 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1861 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1862 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1863 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1864 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1865 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1866 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1867 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1868 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1869
1870 // Constant op
1871 ConstantScalarAndVectorPattern,
1872
1873 // Control Flow ops
1874 BranchConversionPattern, BranchConditionalConversionPattern,
1875 FunctionCallPattern, LoopPattern, SelectionPattern,
1876 ErasePattern<spirv::MergeOp>,
1877
1878 // Entry points and execution mode are handled separately.
1879 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1880
1881 // GLSL extended instruction set ops
1882 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1883 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1884 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1885 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1886 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1887 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1888 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1889 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1890 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1891 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1892 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1893 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1894 InverseSqrtPattern, TanPattern, TanhPattern,
1895
1896 // Logical ops
1897 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1898 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1899 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1900 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1901 NotPattern<spirv::LogicalNotOp>,
1902
1903 // Memory ops
1904 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1905 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1906
1907 // Miscellaneous ops
1908 CompositeExtractPattern, CompositeInsertPattern,
1909 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1910 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1911 VectorShufflePattern,
1912
1913 // Shift ops
1914 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1915 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1916 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1917
1918 // Return ops
1919 ReturnPattern, ReturnValuePattern,
1920
1921 // Barrier ops
1922 ControlBarrierPattern<spirv::ControlBarrierOp>,
1923 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1924 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1925
1926 // Group reduction operations
1927 GroupReducePattern<spirv::GroupIAddOp>,
1928 GroupReducePattern<spirv::GroupFAddOp>,
1929 GroupReducePattern<spirv::GroupFMinOp>,
1930 GroupReducePattern<spirv::GroupUMinOp>,
1931 GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1932 GroupReducePattern<spirv::GroupFMaxOp>,
1933 GroupReducePattern<spirv::GroupUMaxOp>,
1934 GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1935 GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1936 /*NonUniform=*/true>,
1937 GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1938 /*NonUniform=*/true>,
1939 GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1940 /*NonUniform=*/true>,
1941 GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1942 /*NonUniform=*/true>,
1943 GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1944 /*NonUniform=*/true>,
1945 GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1946 /*NonUniform=*/true>,
1947 GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1948 /*NonUniform=*/true>,
1949 GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1950 /*NonUniform=*/true>,
1951 GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1952 /*NonUniform=*/true>,
1953 GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1954 /*NonUniform=*/true>,
1955 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1956 /*NonUniform=*/true>,
1957 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1958 /*NonUniform=*/true>,
1959 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1960 /*NonUniform=*/true>,
1961 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1962 /*NonUniform=*/true>,
1963 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1964 /*NonUniform=*/true>,
1965 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1966 /*NonUniform=*/true>>(arg: patterns.getContext(),
1967 args: typeConverter);
1968
1969 patterns.add<GlobalVariablePattern>(arg&: clientAPI, args: patterns.getContext(),
1970 args: typeConverter);
1971}
1972
1973void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1974 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1975 patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args: typeConverter);
1976}
1977
1978void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1979 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1980 patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args: typeConverter);
1981}
1982
1983//===----------------------------------------------------------------------===//
1984// Pre-conversion hooks
1985//===----------------------------------------------------------------------===//
1986
1987/// Hook for descriptor set and binding number encoding.
1988static constexpr StringRef kBinding = "binding";
1989static constexpr StringRef kDescriptorSet = "descriptor_set";
1990void mlir::encodeBindAttribute(ModuleOp module) {
1991 auto spvModules = module.getOps<spirv::ModuleOp>();
1992 for (auto spvModule : spvModules) {
1993 spvModule.walk(callback: [&](spirv::GlobalVariableOp op) {
1994 IntegerAttr descriptorSet =
1995 op->getAttrOfType<IntegerAttr>(name: kDescriptorSet);
1996 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(name: kBinding);
1997 // For every global variable in the module, get the ones with descriptor
1998 // set and binding numbers.
1999 if (descriptorSet && binding) {
2000 // Encode these numbers into the variable's symbolic name. If the
2001 // SPIR-V module has a name, add it at the beginning.
2002 auto moduleAndName =
2003 spvModule.getName().has_value()
2004 ? spvModule.getName()->str() + "_" + op.getSymName().str()
2005 : op.getSymName().str();
2006 std::string name =
2007 llvm::formatv(Fmt: "{0}_descriptor_set{1}_binding{2}", Vals&: moduleAndName,
2008 Vals: std::to_string(val: descriptorSet.getInt()),
2009 Vals: std::to_string(val: binding.getInt()));
2010 auto nameAttr = StringAttr::get(context: op->getContext(), bytes: name);
2011
2012 // Replace all symbol uses and set the new symbol name. Finally, remove
2013 // descriptor set and binding attributes.
2014 if (failed(Result: SymbolTable::replaceAllSymbolUses(oldSymbol: op, newSymbolName: nameAttr, from: spvModule)))
2015 op.emitError(message: "unable to replace all symbol uses for ") << name;
2016 SymbolTable::setSymbolName(symbol: op, name: nameAttr);
2017 op->removeAttr(name: kDescriptorSet);
2018 op->removeAttr(name: kBinding);
2019 }
2020 });
2021 }
2022}
2023

// Source: mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp