1 | //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert SPIR-V dialect to LLVM dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
20 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
21 | #include "mlir/IR/BuiltinOps.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Support/LogicalResult.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | #include "llvm/Support/Debug.h" |
26 | #include "llvm/Support/FormatVariadic.h" |
27 | |
28 | #define DEBUG_TYPE "spirv-to-llvm-pattern" |
29 | |
30 | using namespace mlir; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Constants |
34 | //===----------------------------------------------------------------------===// |
35 | |
/// LLVM address space returned whenever a SPIR-V storage class has no
/// client-API-specific mapping.
constexpr unsigned defaultAddressSpace{0};
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Utility functions |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// Returns true if the given type is a signed integer or vector type. |
43 | static bool isSignedIntegerOrVector(Type type) { |
44 | if (type.isSignedInteger()) |
45 | return true; |
46 | if (auto vecType = dyn_cast<VectorType>(type)) |
47 | return vecType.getElementType().isSignedInteger(); |
48 | return false; |
49 | } |
50 | |
51 | /// Returns true if the given type is an unsigned integer or vector type |
52 | static bool isUnsignedIntegerOrVector(Type type) { |
53 | if (type.isUnsignedInteger()) |
54 | return true; |
55 | if (auto vecType = dyn_cast<VectorType>(type)) |
56 | return vecType.getElementType().isUnsignedInteger(); |
57 | return false; |
58 | } |
59 | |
60 | /// Returns the width of an integer or of the element type of an integer vector, |
61 | /// if applicable. |
62 | static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { |
63 | if (auto intType = dyn_cast<IntegerType>(type)) |
64 | return intType.getWidth(); |
65 | if (auto vecType = dyn_cast<VectorType>(type)) |
66 | if (auto intType = dyn_cast<IntegerType>(vecType.getElementType())) |
67 | return intType.getWidth(); |
68 | return std::nullopt; |
69 | } |
70 | |
71 | /// Returns the bit width of integer, float or vector of float or integer values |
72 | static unsigned getBitWidth(Type type) { |
73 | assert((type.isIntOrFloat() || isa<VectorType>(type)) && |
74 | "bitwidth is not supported for this type" ); |
75 | if (type.isIntOrFloat()) |
76 | return type.getIntOrFloatBitWidth(); |
77 | auto vecType = dyn_cast<VectorType>(type); |
78 | auto elementType = vecType.getElementType(); |
79 | assert(elementType.isIntOrFloat() && |
80 | "only integers and floats have a bitwidth" ); |
81 | return elementType.getIntOrFloatBitWidth(); |
82 | } |
83 | |
84 | /// Returns the bit width of LLVMType integer or vector. |
85 | static unsigned getLLVMTypeBitWidth(Type type) { |
86 | return cast<IntegerType>((LLVM::isCompatibleVectorType(type) |
87 | ? LLVM::getVectorElementType(type) |
88 | : type)) |
89 | .getWidth(); |
90 | } |
91 | |
92 | /// Creates `IntegerAttribute` with all bits set for given type |
93 | static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { |
94 | if (auto vecType = dyn_cast<VectorType>(type)) { |
95 | auto integerType = cast<IntegerType>(vecType.getElementType()); |
96 | return builder.getIntegerAttr(integerType, -1); |
97 | } |
98 | auto integerType = cast<IntegerType>(type); |
99 | return builder.getIntegerAttr(integerType, -1); |
100 | } |
101 | |
102 | /// Creates `llvm.mlir.constant` with all bits set for the given type. |
103 | static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, |
104 | PatternRewriter &rewriter) { |
105 | if (isa<VectorType>(Val: srcType)) { |
106 | return rewriter.create<LLVM::ConstantOp>( |
107 | loc, dstType, |
108 | SplatElementsAttr::get(cast<ShapedType>(srcType), |
109 | minusOneIntegerAttribute(srcType, rewriter))); |
110 | } |
111 | return rewriter.create<LLVM::ConstantOp>( |
112 | loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); |
113 | } |
114 | |
115 | /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. |
116 | static Value createFPConstant(Location loc, Type srcType, Type dstType, |
117 | PatternRewriter &rewriter, double value) { |
118 | if (auto vecType = dyn_cast<VectorType>(srcType)) { |
119 | auto floatType = cast<FloatType>(vecType.getElementType()); |
120 | return rewriter.create<LLVM::ConstantOp>( |
121 | loc, dstType, |
122 | SplatElementsAttr::get(vecType, |
123 | rewriter.getFloatAttr(floatType, value))); |
124 | } |
125 | auto floatType = cast<FloatType>(Val&: srcType); |
126 | return rewriter.create<LLVM::ConstantOp>( |
127 | loc, dstType, rewriter.getFloatAttr(floatType, value)); |
128 | } |
129 | |
130 | /// Utility function for bitfield ops: |
131 | /// - `BitFieldInsert` |
132 | /// - `BitFieldSExtract` |
133 | /// - `BitFieldUExtract` |
134 | /// Truncates or extends the value. If the bitwidth of the value is the same as |
135 | /// `llvmType` bitwidth, the value remains unchanged. |
136 | static Value optionallyTruncateOrExtend(Location loc, Value value, |
137 | Type llvmType, |
138 | PatternRewriter &rewriter) { |
139 | auto srcType = value.getType(); |
140 | unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType); |
141 | unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType) |
142 | ? getLLVMTypeBitWidth(type: srcType) |
143 | : getBitWidth(type: srcType); |
144 | |
145 | if (valueBitWidth < targetBitWidth) |
146 | return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); |
147 | // If the bit widths of `Count` and `Offset` are greater than the bit width |
148 | // of the target type, they are truncated. Truncation is safe since `Count` |
149 | // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, |
150 | // both values can be expressed in 8 bits. |
151 | if (valueBitWidth > targetBitWidth) |
152 | return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); |
153 | return value; |
154 | } |
155 | |
156 | /// Broadcasts the value to vector with `numElements` number of elements. |
157 | static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, |
158 | LLVMTypeConverter &typeConverter, |
159 | ConversionPatternRewriter &rewriter) { |
160 | auto vectorType = VectorType::get(numElements, toBroadcast.getType()); |
161 | auto llvmVectorType = typeConverter.convertType(vectorType); |
162 | auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); |
163 | Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); |
164 | for (unsigned i = 0; i < numElements; ++i) { |
165 | auto index = rewriter.create<LLVM::ConstantOp>( |
166 | loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); |
167 | broadcasted = rewriter.create<LLVM::InsertElementOp>( |
168 | loc, llvmVectorType, broadcasted, toBroadcast, index); |
169 | } |
170 | return broadcasted; |
171 | } |
172 | |
173 | /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. |
174 | static Value optionallyBroadcast(Location loc, Value value, Type srcType, |
175 | LLVMTypeConverter &typeConverter, |
176 | ConversionPatternRewriter &rewriter) { |
177 | if (auto vectorType = dyn_cast<VectorType>(srcType)) { |
178 | unsigned numElements = vectorType.getNumElements(); |
179 | return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter); |
180 | } |
181 | return value; |
182 | } |
183 | |
184 | /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and |
185 | /// `BitFieldUExtract`. |
186 | /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of |
187 | /// a vector type, construct a vector that has: |
188 | /// - same number of elements as `Base` |
189 | /// - each element has the type that is the same as the type of `Offset` or |
190 | /// `Count` |
191 | /// - each element has the same value as `Offset` or `Count` |
192 | /// Then cast `Offset` and `Count` if their bit width is different |
193 | /// from `Base` bit width. |
194 | static Value processCountOrOffset(Location loc, Value value, Type srcType, |
195 | Type dstType, LLVMTypeConverter &converter, |
196 | ConversionPatternRewriter &rewriter) { |
197 | Value broadcasted = |
198 | optionallyBroadcast(loc, value, srcType, typeConverter&: converter, rewriter); |
199 | return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter); |
200 | } |
201 | |
202 | /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) |
203 | /// offset to LLVM struct. Otherwise, the conversion is not supported. |
204 | static Type convertStructTypeWithOffset(spirv::StructType type, |
205 | LLVMTypeConverter &converter) { |
206 | if (type != VulkanLayoutUtils::decorateType(structType: type)) |
207 | return nullptr; |
208 | |
209 | SmallVector<Type> elementsVector; |
210 | if (failed(result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
211 | return nullptr; |
212 | return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector, |
213 | /*isPacked=*/false); |
214 | } |
215 | |
216 | /// Converts SPIR-V struct with no offset to packed LLVM struct. |
217 | static Type convertStructTypePacked(spirv::StructType type, |
218 | LLVMTypeConverter &converter) { |
219 | SmallVector<Type> elementsVector; |
220 | if (failed(result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
221 | return nullptr; |
222 | return LLVM::LLVMStructType::getLiteral(context: type.getContext(), types: elementsVector, |
223 | /*isPacked=*/true); |
224 | } |
225 | |
226 | /// Creates LLVM dialect constant with the given value. |
227 | static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, |
228 | unsigned value) { |
229 | return rewriter.create<LLVM::ConstantOp>( |
230 | loc, IntegerType::get(rewriter.getContext(), 32), |
231 | rewriter.getIntegerAttr(rewriter.getI32Type(), value)); |
232 | } |
233 | |
234 | /// Utility for `spirv.Load` and `spirv.Store` conversion. |
235 | static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, |
236 | ConversionPatternRewriter &rewriter, |
237 | LLVMTypeConverter &typeConverter, |
238 | unsigned alignment, bool isVolatile, |
239 | bool isNonTemporal) { |
240 | if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { |
241 | auto dstType = typeConverter.convertType(loadOp.getType()); |
242 | if (!dstType) |
243 | return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed" ); |
244 | rewriter.replaceOpWithNewOp<LLVM::LoadOp>( |
245 | loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, |
246 | isVolatile, isNonTemporal); |
247 | return success(); |
248 | } |
249 | auto storeOp = cast<spirv::StoreOp>(op); |
250 | spirv::StoreOpAdaptor adaptor(operands); |
251 | rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(), |
252 | adaptor.getPtr(), alignment, |
253 | isVolatile, isNonTemporal); |
254 | return success(); |
255 | } |
256 | |
257 | //===----------------------------------------------------------------------===// |
258 | // Type conversion |
259 | //===----------------------------------------------------------------------===// |
260 | |
261 | /// Converts SPIR-V array type to LLVM array. Natural stride (according to |
262 | /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected |
263 | /// when converting ops that manipulate array types. |
264 | static std::optional<Type> convertArrayType(spirv::ArrayType type, |
265 | TypeConverter &converter) { |
266 | unsigned stride = type.getArrayStride(); |
267 | Type elementType = type.getElementType(); |
268 | auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes(); |
269 | if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) |
270 | return std::nullopt; |
271 | |
272 | auto llvmElementType = converter.convertType(t: elementType); |
273 | unsigned numElements = type.getNumElements(); |
274 | return LLVM::LLVMArrayType::get(llvmElementType, numElements); |
275 | } |
276 | |
277 | static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) { |
278 | // Based on |
279 | // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form |
280 | // and clang/lib/Basic/Targets/SPIR.h. |
281 | switch (storageClass) { |
282 | #define STORAGE_SPACE_MAP(storage, space) \ |
283 | case spirv::StorageClass::storage: \ |
284 | return space; |
285 | STORAGE_SPACE_MAP(Function, 0) |
286 | STORAGE_SPACE_MAP(CrossWorkgroup, 1) |
287 | STORAGE_SPACE_MAP(Input, 1) |
288 | STORAGE_SPACE_MAP(UniformConstant, 2) |
289 | STORAGE_SPACE_MAP(Workgroup, 3) |
290 | STORAGE_SPACE_MAP(Generic, 4) |
291 | STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5) |
292 | STORAGE_SPACE_MAP(HostOnlyINTEL, 6) |
293 | #undef STORAGE_SPACE_MAP |
294 | default: |
295 | return defaultAddressSpace; |
296 | } |
297 | } |
298 | |
299 | static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, |
300 | spirv::StorageClass storageClass) { |
301 | switch (clientAPI) { |
302 | #define CLIENT_MAP(client, storage) \ |
303 | case spirv::ClientAPI::client: \ |
304 | return mapTo##client##AddressSpace(storage); |
305 | CLIENT_MAP(OpenCL, storageClass) |
306 | #undef CLIENT_MAP |
307 | default: |
308 | return defaultAddressSpace; |
309 | } |
310 | } |
311 | |
312 | /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not |
313 | /// modelled at the moment. |
314 | static Type convertPointerType(spirv::PointerType type, |
315 | LLVMTypeConverter &converter, |
316 | spirv::ClientAPI clientAPI) { |
317 | unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass()); |
318 | return LLVM::LLVMPointerType::get(type.getContext(), addressSpace); |
319 | } |
320 | |
321 | /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over |
322 | /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is |
323 | /// no modelling of array stride at the moment. |
324 | static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, |
325 | TypeConverter &converter) { |
326 | if (type.getArrayStride() != 0) |
327 | return std::nullopt; |
328 | auto elementType = converter.convertType(t: type.getElementType()); |
329 | return LLVM::LLVMArrayType::get(elementType, 0); |
330 | } |
331 | |
332 | /// Converts SPIR-V struct to LLVM struct. There is no support of structs with |
333 | /// member decorations. Also, only natural offset is supported. |
334 | static Type convertStructType(spirv::StructType type, |
335 | LLVMTypeConverter &converter) { |
336 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
337 | type.getMemberDecorations(memberDecorations); |
338 | if (!memberDecorations.empty()) |
339 | return nullptr; |
340 | if (type.hasOffset()) |
341 | return convertStructTypeWithOffset(type, converter); |
342 | return convertStructTypePacked(type, converter); |
343 | } |
344 | |
345 | //===----------------------------------------------------------------------===// |
346 | // Operation conversion |
347 | //===----------------------------------------------------------------------===// |
348 | |
349 | namespace { |
350 | |
351 | class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { |
352 | public: |
353 | using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; |
354 | |
355 | LogicalResult |
356 | matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, |
357 | ConversionPatternRewriter &rewriter) const override { |
358 | auto dstType = typeConverter.convertType(op.getComponentPtr().getType()); |
359 | if (!dstType) |
360 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
361 | // To use GEP we need to add a first 0 index to go through the pointer. |
362 | auto indices = llvm::to_vector<4>(adaptor.getIndices()); |
363 | Type indexType = op.getIndices().front().getType(); |
364 | auto llvmIndexType = typeConverter.convertType(indexType); |
365 | if (!llvmIndexType) |
366 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
367 | Value zero = rewriter.create<LLVM::ConstantOp>( |
368 | op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); |
369 | indices.insert(indices.begin(), zero); |
370 | |
371 | auto elementType = typeConverter.convertType( |
372 | cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType()); |
373 | if (!elementType) |
374 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
375 | rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType, |
376 | adaptor.getBasePtr(), indices); |
377 | return success(); |
378 | } |
379 | }; |
380 | |
381 | class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { |
382 | public: |
383 | using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; |
384 | |
385 | LogicalResult |
386 | matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, |
387 | ConversionPatternRewriter &rewriter) const override { |
388 | auto dstType = typeConverter.convertType(op.getPointer().getType()); |
389 | if (!dstType) |
390 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
391 | rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, |
392 | op.getVariable()); |
393 | return success(); |
394 | } |
395 | }; |
396 | |
397 | class BitFieldInsertPattern |
398 | : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { |
399 | public: |
400 | using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; |
401 | |
402 | LogicalResult |
403 | matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, |
404 | ConversionPatternRewriter &rewriter) const override { |
405 | auto srcType = op.getType(); |
406 | auto dstType = typeConverter.convertType(srcType); |
407 | if (!dstType) |
408 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
409 | Location loc = op.getLoc(); |
410 | |
411 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
412 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
413 | typeConverter, rewriter); |
414 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
415 | typeConverter, rewriter); |
416 | |
417 | // Create a mask with bits set outside [Offset, Offset + Count - 1]. |
418 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
419 | Value maskShiftedByCount = |
420 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
421 | Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, |
422 | maskShiftedByCount, minusOne); |
423 | Value maskShiftedByCountAndOffset = |
424 | rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); |
425 | Value mask = rewriter.create<LLVM::XOrOp>( |
426 | loc, dstType, maskShiftedByCountAndOffset, minusOne); |
427 | |
428 | // Extract unchanged bits from the `Base` that are outside of |
429 | // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. |
430 | Value baseAndMask = |
431 | rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); |
432 | Value insertShiftedByOffset = |
433 | rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); |
434 | rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, |
435 | insertShiftedByOffset); |
436 | return success(); |
437 | } |
438 | }; |
439 | |
440 | /// Converts SPIR-V ConstantOp with scalar or vector type. |
441 | class ConstantScalarAndVectorPattern |
442 | : public SPIRVToLLVMConversion<spirv::ConstantOp> { |
443 | public: |
444 | using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; |
445 | |
446 | LogicalResult |
447 | matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, |
448 | ConversionPatternRewriter &rewriter) const override { |
449 | auto srcType = constOp.getType(); |
450 | if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat()) |
451 | return failure(); |
452 | |
453 | auto dstType = typeConverter.convertType(srcType); |
454 | if (!dstType) |
455 | return rewriter.notifyMatchFailure(constOp, "type conversion failed" ); |
456 | |
457 | // SPIR-V constant can be a signed/unsigned integer, which has to be |
458 | // casted to signless integer when converting to LLVM dialect. Removing the |
459 | // sign bit may have unexpected behaviour. However, it is better to handle |
460 | // it case-by-case, given that the purpose of the conversion is not to |
461 | // cover all possible corner cases. |
462 | if (isSignedIntegerOrVector(srcType) || |
463 | isUnsignedIntegerOrVector(srcType)) { |
464 | auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); |
465 | |
466 | if (isa<VectorType>(srcType)) { |
467 | auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue()); |
468 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
469 | constOp, dstType, |
470 | dstElementsAttr.mapValues( |
471 | signlessType, [&](const APInt &value) { return value; })); |
472 | return success(); |
473 | } |
474 | auto srcAttr = cast<IntegerAttr>(constOp.getValue()); |
475 | auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); |
476 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); |
477 | return success(); |
478 | } |
479 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
480 | constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); |
481 | return success(); |
482 | } |
483 | }; |
484 | |
485 | class |
486 | : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { |
487 | public: |
488 | using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; |
489 | |
490 | LogicalResult |
491 | matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, |
492 | ConversionPatternRewriter &rewriter) const override { |
493 | auto srcType = op.getType(); |
494 | auto dstType = typeConverter.convertType(srcType); |
495 | if (!dstType) |
496 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
497 | Location loc = op.getLoc(); |
498 | |
499 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
500 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
501 | typeConverter, rewriter); |
502 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
503 | typeConverter, rewriter); |
504 | |
505 | // Create a constant that holds the size of the `Base`. |
506 | IntegerType integerType; |
507 | if (auto vecType = dyn_cast<VectorType>(srcType)) |
508 | integerType = cast<IntegerType>(vecType.getElementType()); |
509 | else |
510 | integerType = cast<IntegerType>(srcType); |
511 | |
512 | auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); |
513 | Value size = |
514 | isa<VectorType>(srcType) |
515 | ? rewriter.create<LLVM::ConstantOp>( |
516 | loc, dstType, |
517 | SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) |
518 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); |
519 | |
520 | // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit |
521 | // at Offset + Count - 1 is the most significant bit now. |
522 | Value countPlusOffset = |
523 | rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); |
524 | Value amountToShiftLeft = |
525 | rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); |
526 | Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( |
527 | loc, dstType, op.getBase(), amountToShiftLeft); |
528 | |
529 | // Shift the result right, filling the bits with the sign bit. |
530 | Value amountToShiftRight = |
531 | rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); |
532 | rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, |
533 | amountToShiftRight); |
534 | return success(); |
535 | } |
536 | }; |
537 | |
538 | class |
539 | : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { |
540 | public: |
541 | using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; |
542 | |
543 | LogicalResult |
544 | matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, |
545 | ConversionPatternRewriter &rewriter) const override { |
546 | auto srcType = op.getType(); |
547 | auto dstType = typeConverter.convertType(srcType); |
548 | if (!dstType) |
549 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
550 | Location loc = op.getLoc(); |
551 | |
552 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
553 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
554 | typeConverter, rewriter); |
555 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
556 | typeConverter, rewriter); |
557 | |
558 | // Create a mask with bits set at [0, Count - 1]. |
559 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
560 | Value maskShiftedByCount = |
561 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
562 | Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, |
563 | minusOne); |
564 | |
565 | // Shift `Base` by `Offset` and apply the mask on it. |
566 | Value shiftedBase = |
567 | rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); |
568 | rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); |
569 | return success(); |
570 | } |
571 | }; |
572 | |
573 | class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { |
574 | public: |
575 | using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; |
576 | |
577 | LogicalResult |
578 | matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, |
579 | ConversionPatternRewriter &rewriter) const override { |
580 | rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), |
581 | branchOp.getTarget()); |
582 | return success(); |
583 | } |
584 | }; |
585 | |
586 | class BranchConditionalConversionPattern |
587 | : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { |
588 | public: |
589 | using SPIRVToLLVMConversion< |
590 | spirv::BranchConditionalOp>::SPIRVToLLVMConversion; |
591 | |
592 | LogicalResult |
593 | matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, |
594 | ConversionPatternRewriter &rewriter) const override { |
595 | // If branch weights exist, map them to 32-bit integer vector. |
596 | DenseI32ArrayAttr branchWeights = nullptr; |
597 | if (auto weights = op.getBranchWeights()) { |
598 | SmallVector<int32_t> weightValues; |
599 | for (auto weight : weights->getAsRange<IntegerAttr>()) |
600 | weightValues.push_back(weight.getInt()); |
601 | branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); |
602 | } |
603 | |
604 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
605 | op, op.getCondition(), op.getTrueBlockArguments(), |
606 | op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), |
607 | op.getFalseBlock()); |
608 | return success(); |
609 | } |
610 | }; |
611 | |
612 | /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container |
613 | /// type is an aggregate type (struct or array). Otherwise, converts to |
614 | /// `llvm.extractelement` that operates on vectors. |
615 | class |
616 | : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { |
617 | public: |
618 | using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; |
619 | |
620 | LogicalResult |
621 | matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, |
622 | ConversionPatternRewriter &rewriter) const override { |
623 | auto dstType = this->typeConverter.convertType(op.getType()); |
624 | if (!dstType) |
625 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
626 | |
627 | Type containerType = op.getComposite().getType(); |
628 | if (isa<VectorType>(Val: containerType)) { |
629 | Location loc = op.getLoc(); |
630 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
631 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
632 | rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( |
633 | op, dstType, adaptor.getComposite(), index); |
634 | return success(); |
635 | } |
636 | |
637 | rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( |
638 | op, adaptor.getComposite(), |
639 | LLVM::convertArrayToIndices(op.getIndices())); |
640 | return success(); |
641 | } |
642 | }; |
643 | |
644 | /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container |
645 | /// type is an aggregate type (struct or array). Otherwise, converts to |
646 | /// `llvm.insertelement` that operates on vectors. |
647 | class CompositeInsertPattern |
648 | : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { |
649 | public: |
650 | using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; |
651 | |
652 | LogicalResult |
653 | matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, |
654 | ConversionPatternRewriter &rewriter) const override { |
655 | auto dstType = this->typeConverter.convertType(op.getType()); |
656 | if (!dstType) |
657 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
658 | |
659 | Type containerType = op.getComposite().getType(); |
660 | if (isa<VectorType>(Val: containerType)) { |
661 | Location loc = op.getLoc(); |
662 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
663 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
664 | rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( |
665 | op, dstType, adaptor.getComposite(), adaptor.getObject(), index); |
666 | return success(); |
667 | } |
668 | |
669 | rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( |
670 | op, adaptor.getComposite(), adaptor.getObject(), |
671 | LLVM::convertArrayToIndices(op.getIndices())); |
672 | return success(); |
673 | } |
674 | }; |
675 | |
676 | /// Converts SPIR-V operations that have straightforward LLVM equivalent |
677 | /// into LLVM dialect operations. |
678 | template <typename SPIRVOp, typename LLVMOp> |
679 | class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
680 | public: |
681 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
682 | |
683 | LogicalResult |
684 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
685 | ConversionPatternRewriter &rewriter) const override { |
686 | auto dstType = this->typeConverter.convertType(op.getType()); |
687 | if (!dstType) |
688 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
689 | rewriter.template replaceOpWithNewOp<LLVMOp>( |
690 | op, dstType, adaptor.getOperands(), op->getAttrs()); |
691 | return success(); |
692 | } |
693 | }; |
694 | |
695 | /// Converts `spirv.ExecutionMode` into a global struct constant that holds |
696 | /// execution mode information. |
697 | class ExecutionModePattern |
698 | : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { |
699 | public: |
700 | using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; |
701 | |
702 | LogicalResult |
703 | matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, |
704 | ConversionPatternRewriter &rewriter) const override { |
705 | // First, create the global struct's name that would be associated with |
706 | // this entry point's execution mode. We set it to be: |
707 | // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} |
708 | ModuleOp module = op->getParentOfType<ModuleOp>(); |
709 | spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); |
710 | std::string moduleName; |
711 | if (module.getName().has_value()) |
712 | moduleName = "_" + module.getName()->str(); |
713 | else |
714 | moduleName = "" ; |
715 | std::string executionModeInfoName = llvm::formatv( |
716 | "__spv_{0}_{1}_execution_mode_info_{2}" , moduleName, op.getFn().str(), |
717 | static_cast<uint32_t>(executionModeAttr.getValue())); |
718 | |
719 | MLIRContext *context = rewriter.getContext(); |
720 | OpBuilder::InsertionGuard guard(rewriter); |
721 | rewriter.setInsertionPointToStart(module.getBody()); |
722 | |
723 | // Create a struct type, corresponding to the C struct below. |
724 | // struct { |
725 | // int32_t executionMode; |
726 | // int32_t values[]; // optional values |
727 | // }; |
728 | auto llvmI32Type = IntegerType::get(context, 32); |
729 | SmallVector<Type, 2> fields; |
730 | fields.push_back(Elt: llvmI32Type); |
731 | ArrayAttr values = op.getValues(); |
732 | if (!values.empty()) { |
733 | auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); |
734 | fields.push_back(Elt: arrayType); |
735 | } |
736 | auto structType = LLVM::LLVMStructType::getLiteral(context, types: fields); |
737 | |
738 | // Create `llvm.mlir.global` with initializer region containing one block. |
739 | auto global = rewriter.create<LLVM::GlobalOp>( |
740 | UnknownLoc::get(context), structType, /*isConstant=*/true, |
741 | LLVM::Linkage::External, executionModeInfoName, Attribute(), |
742 | /*alignment=*/0); |
743 | Location loc = global.getLoc(); |
744 | Region ®ion = global.getInitializerRegion(); |
745 | Block *block = rewriter.createBlock(parent: ®ion); |
746 | |
747 | // Initialize the struct and set the execution mode value. |
748 | rewriter.setInsertionPoint(block, insertPoint: block->begin()); |
749 | Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); |
750 | Value executionMode = rewriter.create<LLVM::ConstantOp>( |
751 | loc, llvmI32Type, |
752 | rewriter.getI32IntegerAttr( |
753 | static_cast<uint32_t>(executionModeAttr.getValue()))); |
754 | structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, |
755 | executionMode, 0); |
756 | |
757 | // Insert extra operands if they exist into execution mode info struct. |
758 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
759 | auto attr = values.getValue()[i]; |
760 | Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); |
761 | structValue = rewriter.create<LLVM::InsertValueOp>( |
762 | loc, structValue, entry, ArrayRef<int64_t>({1, i})); |
763 | } |
764 | rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); |
765 | rewriter.eraseOp(op: op); |
766 | return success(); |
767 | } |
768 | }; |
769 | |
770 | /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V |
771 | /// global returns a pointer, whereas in LLVM dialect the global holds an actual |
772 | /// value. This difference is handled by `spirv.mlir.addressof` and |
773 | /// `llvm.mlir.addressof`ops that both return a pointer. |
774 | class GlobalVariablePattern |
775 | : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { |
776 | public: |
777 | template <typename... Args> |
778 | GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args) |
779 | : SPIRVToLLVMConversion<spirv::GlobalVariableOp>( |
780 | std::forward<Args>(args)...), |
781 | clientAPI(clientAPI) {} |
782 | |
783 | LogicalResult |
784 | matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, |
785 | ConversionPatternRewriter &rewriter) const override { |
786 | // Currently, there is no support of initialization with a constant value in |
787 | // SPIR-V dialect. Specialization constants are not considered as well. |
788 | if (op.getInitializer()) |
789 | return failure(); |
790 | |
791 | auto srcType = cast<spirv::PointerType>(op.getType()); |
792 | auto dstType = typeConverter.convertType(srcType.getPointeeType()); |
793 | if (!dstType) |
794 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
795 | |
796 | // Limit conversion to the current invocation only or `StorageBuffer` |
797 | // required by SPIR-V runner. |
798 | // This is okay because multiple invocations are not supported yet. |
799 | auto storageClass = srcType.getStorageClass(); |
800 | switch (storageClass) { |
801 | case spirv::StorageClass::Input: |
802 | case spirv::StorageClass::Private: |
803 | case spirv::StorageClass::Output: |
804 | case spirv::StorageClass::StorageBuffer: |
805 | case spirv::StorageClass::UniformConstant: |
806 | break; |
807 | default: |
808 | return failure(); |
809 | } |
810 | |
811 | // LLVM dialect spec: "If the global value is a constant, storing into it is |
812 | // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' |
813 | // storage class that is read-only. |
814 | bool isConstant = (storageClass == spirv::StorageClass::Input) || |
815 | (storageClass == spirv::StorageClass::UniformConstant); |
816 | // SPIR-V spec: "By default, functions and global variables are private to a |
817 | // module and cannot be accessed by other modules. However, a module may be |
818 | // written to export or import functions and global (module scope) |
819 | // variables.". Therefore, map 'Private' storage class to private linkage, |
820 | // 'Input' and 'Output' to external linkage. |
821 | auto linkage = storageClass == spirv::StorageClass::Private |
822 | ? LLVM::Linkage::Private |
823 | : LLVM::Linkage::External; |
824 | auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( |
825 | op, dstType, isConstant, linkage, op.getSymName(), Attribute(), |
826 | /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass)); |
827 | |
828 | // Attach location attribute if applicable |
829 | if (op.getLocationAttr()) |
830 | newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); |
831 | |
832 | return success(); |
833 | } |
834 | |
835 | private: |
836 | spirv::ClientAPI clientAPI; |
837 | }; |
838 | |
839 | /// Converts SPIR-V cast ops that do not have straightforward LLVM |
840 | /// equivalent in LLVM dialect. |
841 | template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> |
842 | class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
843 | public: |
844 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
845 | |
846 | LogicalResult |
847 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
848 | ConversionPatternRewriter &rewriter) const override { |
849 | |
850 | Type fromType = op.getOperand().getType(); |
851 | Type toType = op.getType(); |
852 | |
853 | auto dstType = this->typeConverter.convertType(toType); |
854 | if (!dstType) |
855 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
856 | |
857 | if (getBitWidth(type: fromType) < getBitWidth(type: toType)) { |
858 | rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType, |
859 | adaptor.getOperands()); |
860 | return success(); |
861 | } |
862 | if (getBitWidth(type: fromType) > getBitWidth(type: toType)) { |
863 | rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType, |
864 | adaptor.getOperands()); |
865 | return success(); |
866 | } |
867 | return failure(); |
868 | } |
869 | }; |
870 | |
871 | class FunctionCallPattern |
872 | : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { |
873 | public: |
874 | using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; |
875 | |
876 | LogicalResult |
877 | matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, |
878 | ConversionPatternRewriter &rewriter) const override { |
879 | if (callOp.getNumResults() == 0) { |
880 | rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
881 | callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); |
882 | return success(); |
883 | } |
884 | |
885 | // Function returns a single result. |
886 | auto dstType = typeConverter.convertType(callOp.getType(0)); |
887 | if (!dstType) |
888 | return rewriter.notifyMatchFailure(callOp, "type conversion failed" ); |
889 | rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
890 | callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); |
891 | return success(); |
892 | } |
893 | }; |
894 | |
895 | /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" |
896 | template <typename SPIRVOp, LLVM::FCmpPredicate predicate> |
897 | class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
898 | public: |
899 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
900 | |
901 | LogicalResult |
902 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
903 | ConversionPatternRewriter &rewriter) const override { |
904 | |
905 | auto dstType = this->typeConverter.convertType(op.getType()); |
906 | if (!dstType) |
907 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
908 | |
909 | rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( |
910 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
911 | return success(); |
912 | } |
913 | }; |
914 | |
915 | /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" |
916 | template <typename SPIRVOp, LLVM::ICmpPredicate predicate> |
917 | class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
918 | public: |
919 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
920 | |
921 | LogicalResult |
922 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
923 | ConversionPatternRewriter &rewriter) const override { |
924 | |
925 | auto dstType = this->typeConverter.convertType(op.getType()); |
926 | if (!dstType) |
927 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
928 | |
929 | rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( |
930 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
931 | return success(); |
932 | } |
933 | }; |
934 | |
935 | class InverseSqrtPattern |
936 | : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { |
937 | public: |
938 | using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion; |
939 | |
940 | LogicalResult |
941 | matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor, |
942 | ConversionPatternRewriter &rewriter) const override { |
943 | auto srcType = op.getType(); |
944 | auto dstType = typeConverter.convertType(srcType); |
945 | if (!dstType) |
946 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
947 | |
948 | Location loc = op.getLoc(); |
949 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
950 | Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); |
951 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); |
952 | return success(); |
953 | } |
954 | }; |
955 | |
956 | /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect. |
957 | template <typename SPIRVOp> |
958 | class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
959 | public: |
960 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
961 | |
962 | LogicalResult |
963 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
964 | ConversionPatternRewriter &rewriter) const override { |
965 | if (!op.getMemoryAccess()) { |
966 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
967 | this->typeConverter, /*alignment=*/0, |
968 | /*isVolatile=*/false, |
969 | /*isNonTemporal=*/false); |
970 | } |
971 | auto memoryAccess = *op.getMemoryAccess(); |
972 | switch (memoryAccess) { |
973 | case spirv::MemoryAccess::Aligned: |
974 | case spirv::MemoryAccess::None: |
975 | case spirv::MemoryAccess::Nontemporal: |
976 | case spirv::MemoryAccess::Volatile: { |
977 | unsigned alignment = |
978 | memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; |
979 | bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; |
980 | bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; |
981 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
982 | this->typeConverter, alignment, isVolatile, |
983 | isNonTemporal); |
984 | } |
985 | default: |
986 | // There is no support of other memory access attributes. |
987 | return failure(); |
988 | } |
989 | } |
990 | }; |
991 | |
992 | /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect. |
993 | template <typename SPIRVOp> |
994 | class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
995 | public: |
996 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
997 | |
998 | LogicalResult |
999 | matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, |
1000 | ConversionPatternRewriter &rewriter) const override { |
1001 | auto srcType = notOp.getType(); |
1002 | auto dstType = this->typeConverter.convertType(srcType); |
1003 | if (!dstType) |
1004 | return rewriter.notifyMatchFailure(notOp, "type conversion failed" ); |
1005 | |
1006 | Location loc = notOp.getLoc(); |
1007 | IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); |
1008 | auto mask = |
1009 | isa<VectorType>(srcType) |
1010 | ? rewriter.create<LLVM::ConstantOp>( |
1011 | loc, dstType, |
1012 | SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) |
1013 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); |
1014 | rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, |
1015 | notOp.getOperand(), mask); |
1016 | return success(); |
1017 | } |
1018 | }; |
1019 | |
1020 | /// A template pattern that erases the given `SPIRVOp`. |
1021 | template <typename SPIRVOp> |
1022 | class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
1023 | public: |
1024 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
1025 | |
1026 | LogicalResult |
1027 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
1028 | ConversionPatternRewriter &rewriter) const override { |
1029 | rewriter.eraseOp(op); |
1030 | return success(); |
1031 | } |
1032 | }; |
1033 | |
1034 | class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { |
1035 | public: |
1036 | using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; |
1037 | |
1038 | LogicalResult |
1039 | matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, |
1040 | ConversionPatternRewriter &rewriter) const override { |
1041 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), |
1042 | ArrayRef<Value>()); |
1043 | return success(); |
1044 | } |
1045 | }; |
1046 | |
1047 | class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { |
1048 | public: |
1049 | using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; |
1050 | |
1051 | LogicalResult |
1052 | matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, |
1053 | ConversionPatternRewriter &rewriter) const override { |
1054 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), |
1055 | adaptor.getOperands()); |
1056 | return success(); |
1057 | } |
1058 | }; |
1059 | |
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
/// should be reachable for conversion to succeed. The structure of the loop in
1062 | /// LLVM dialect will be the following: |
1063 | /// |
1064 | /// +------------------------------------+ |
1065 | /// | <code before spirv.mlir.loop> | |
1066 | /// | llvm.br ^header | |
1067 | /// +------------------------------------+ |
1068 | /// | |
1069 | /// +----------------+ | |
1070 | /// | | | |
1071 | /// | V V |
1072 | /// | +------------------------------------+ |
1073 | /// | | ^header: | |
1074 | /// | | <header code> | |
1075 | /// | | llvm.cond_br %cond, ^body, ^exit | |
1076 | /// | +------------------------------------+ |
1077 | /// | | |
1078 | /// | |----------------------+ |
1079 | /// | | | |
1080 | /// | V | |
1081 | /// | +------------------------------------+ | |
1082 | /// | | ^body: | | |
1083 | /// | | <body code> | | |
1084 | /// | | llvm.br ^continue | | |
1085 | /// | +------------------------------------+ | |
1086 | /// | | | |
1087 | /// | V | |
1088 | /// | +------------------------------------+ | |
1089 | /// | | ^continue: | | |
1090 | /// | | <continue code> | | |
1091 | /// | | llvm.br ^header | | |
1092 | /// | +------------------------------------+ | |
1093 | /// | | | |
1094 | /// +---------------+ +----------------------+ |
1095 | /// | |
1096 | /// V |
1097 | /// +------------------------------------+ |
1098 | /// | ^exit: | |
1099 | /// | llvm.br ^remaining | |
1100 | /// +------------------------------------+ |
1101 | /// | |
1102 | /// V |
1103 | /// +------------------------------------+ |
1104 | /// | ^remaining: | |
1105 | /// | <code after spirv.mlir.loop> | |
1106 | /// +------------------------------------+ |
1107 | /// |
1108 | class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { |
1109 | public: |
1110 | using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; |
1111 | |
1112 | LogicalResult |
1113 | matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, |
1114 | ConversionPatternRewriter &rewriter) const override { |
1115 | // There is no support of loop control at the moment. |
1116 | if (loopOp.getLoopControl() != spirv::LoopControl::None) |
1117 | return failure(); |
1118 | |
1119 | Location loc = loopOp.getLoc(); |
1120 | |
1121 | // Split the current block after `spirv.mlir.loop`. The remaining ops will |
1122 | // be used in `endBlock`. |
1123 | Block *currentBlock = rewriter.getBlock(); |
1124 | auto position = Block::iterator(loopOp); |
1125 | Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position); |
1126 | |
1127 | // Remove entry block and create a branch in the current block going to the |
1128 | // header block. |
1129 | Block *entryBlock = loopOp.getEntryBlock(); |
1130 | assert(entryBlock->getOperations().size() == 1); |
1131 | auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); |
1132 | if (!brOp) |
1133 | return failure(); |
1134 | Block * = loopOp.getHeaderBlock(); |
1135 | rewriter.setInsertionPointToEnd(currentBlock); |
1136 | rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); |
1137 | rewriter.eraseBlock(block: entryBlock); |
1138 | |
1139 | // Branch from merge block to end block. |
1140 | Block *mergeBlock = loopOp.getMergeBlock(); |
1141 | Operation *terminator = mergeBlock->getTerminator(); |
1142 | ValueRange terminatorOperands = terminator->getOperands(); |
1143 | rewriter.setInsertionPointToEnd(mergeBlock); |
1144 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); |
1145 | |
1146 | rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); |
1147 | rewriter.replaceOp(loopOp, endBlock->getArguments()); |
1148 | return success(); |
1149 | } |
1150 | }; |
1151 | |
1152 | /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header |
1153 | /// block. All blocks within selection should be reachable for conversion to |
1154 | /// succeed. |
1155 | class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { |
1156 | public: |
1157 | using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; |
1158 | |
1159 | LogicalResult |
1160 | matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, |
1161 | ConversionPatternRewriter &rewriter) const override { |
1162 | // There is no support for `Flatten` or `DontFlatten` selection control at |
1163 | // the moment. This are just compiler hints and can be performed during the |
1164 | // optimization passes. |
1165 | if (op.getSelectionControl() != spirv::SelectionControl::None) |
1166 | return failure(); |
1167 | |
1168 | // `spirv.mlir.selection` should have at least two blocks: one selection |
1169 | // header block and one merge block. If no blocks are present, or control |
1170 | // flow branches straight to merge block (two blocks are present), the op is |
1171 | // redundant and it is erased. |
1172 | if (op.getBody().getBlocks().size() <= 2) { |
1173 | rewriter.eraseOp(op: op); |
1174 | return success(); |
1175 | } |
1176 | |
1177 | Location loc = op.getLoc(); |
1178 | |
1179 | // Split the current block after `spirv.mlir.selection`. The remaining ops |
1180 | // will be used in `continueBlock`. |
1181 | auto *currentBlock = rewriter.getInsertionBlock(); |
1182 | rewriter.setInsertionPointAfter(op); |
1183 | auto position = rewriter.getInsertionPoint(); |
1184 | auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position); |
1185 | |
1186 | // Extract conditional branch information from the header block. By SPIR-V |
1187 | // dialect spec, it should contain `spirv.BranchConditional` or |
1188 | // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the |
1189 | // moment in the SPIR-V dialect. Remove this block when finished. |
1190 | auto * = op.getHeaderBlock(); |
1191 | assert(headerBlock->getOperations().size() == 1); |
1192 | auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( |
1193 | headerBlock->getOperations().front()); |
1194 | if (!condBrOp) |
1195 | return failure(); |
1196 | rewriter.eraseBlock(block: headerBlock); |
1197 | |
1198 | // Branch from merge block to continue block. |
1199 | auto *mergeBlock = op.getMergeBlock(); |
1200 | Operation *terminator = mergeBlock->getTerminator(); |
1201 | ValueRange terminatorOperands = terminator->getOperands(); |
1202 | rewriter.setInsertionPointToEnd(mergeBlock); |
1203 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); |
1204 | |
1205 | // Link current block to `true` and `false` blocks within the selection. |
1206 | Block *trueBlock = condBrOp.getTrueBlock(); |
1207 | Block *falseBlock = condBrOp.getFalseBlock(); |
1208 | rewriter.setInsertionPointToEnd(currentBlock); |
1209 | rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, |
1210 | condBrOp.getTrueTargetOperands(), |
1211 | falseBlock, |
1212 | condBrOp.getFalseTargetOperands()); |
1213 | |
1214 | rewriter.inlineRegionBefore(op.getBody(), continueBlock); |
1215 | rewriter.replaceOp(op, continueBlock->getArguments()); |
1216 | return success(); |
1217 | } |
1218 | }; |
1219 | |
1220 | /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect |
1221 | /// puts a restriction on `Shift` and `Base` to have the same bit width, |
1222 | /// `Shift` is zero or sign extended to match this specification. Cases when |
1223 | /// `Shift` bit width > `Base` bit width are considered to be illegal. |
1224 | template <typename SPIRVOp, typename LLVMOp> |
1225 | class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
1226 | public: |
1227 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
1228 | |
1229 | LogicalResult |
1230 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
1231 | ConversionPatternRewriter &rewriter) const override { |
1232 | |
1233 | auto dstType = this->typeConverter.convertType(op.getType()); |
1234 | if (!dstType) |
1235 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
1236 | |
1237 | Type op1Type = op.getOperand1().getType(); |
1238 | Type op2Type = op.getOperand2().getType(); |
1239 | |
1240 | if (op1Type == op2Type) { |
1241 | rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType, |
1242 | adaptor.getOperands()); |
1243 | return success(); |
1244 | } |
1245 | |
1246 | std::optional<uint64_t> dstTypeWidth = |
1247 | getIntegerOrVectorElementWidth(dstType); |
1248 | std::optional<uint64_t> op2TypeWidth = |
1249 | getIntegerOrVectorElementWidth(type: op2Type); |
1250 | |
1251 | if (!dstTypeWidth || !op2TypeWidth) |
1252 | return failure(); |
1253 | |
1254 | Location loc = op.getLoc(); |
1255 | Value extended; |
1256 | if (op2TypeWidth < dstTypeWidth) { |
1257 | if (isUnsignedIntegerOrVector(type: op2Type)) { |
1258 | extended = rewriter.template create<LLVM::ZExtOp>( |
1259 | loc, dstType, adaptor.getOperand2()); |
1260 | } else { |
1261 | extended = rewriter.template create<LLVM::SExtOp>( |
1262 | loc, dstType, adaptor.getOperand2()); |
1263 | } |
1264 | } else if (op2TypeWidth == dstTypeWidth) { |
1265 | extended = adaptor.getOperand2(); |
1266 | } else { |
1267 | return failure(); |
1268 | } |
1269 | |
1270 | Value result = rewriter.template create<LLVMOp>( |
1271 | loc, dstType, adaptor.getOperand1(), extended); |
1272 | rewriter.replaceOp(op, result); |
1273 | return success(); |
1274 | } |
1275 | }; |
1276 | |
1277 | class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { |
1278 | public: |
1279 | using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion; |
1280 | |
1281 | LogicalResult |
1282 | matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor, |
1283 | ConversionPatternRewriter &rewriter) const override { |
1284 | auto dstType = typeConverter.convertType(tanOp.getType()); |
1285 | if (!dstType) |
1286 | return rewriter.notifyMatchFailure(tanOp, "type conversion failed" ); |
1287 | |
1288 | Location loc = tanOp.getLoc(); |
1289 | Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); |
1290 | Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); |
1291 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); |
1292 | return success(); |
1293 | } |
1294 | }; |
1295 | |
1296 | /// Convert `spirv.Tanh` to |
1297 | /// |
1298 | /// exp(2x) - 1 |
1299 | /// ----------- |
1300 | /// exp(2x) + 1 |
1301 | /// |
1302 | class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { |
1303 | public: |
1304 | using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; |
1305 | |
1306 | LogicalResult |
1307 | matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor, |
1308 | ConversionPatternRewriter &rewriter) const override { |
1309 | auto srcType = tanhOp.getType(); |
1310 | auto dstType = typeConverter.convertType(srcType); |
1311 | if (!dstType) |
1312 | return rewriter.notifyMatchFailure(tanhOp, "type conversion failed" ); |
1313 | |
1314 | Location loc = tanhOp.getLoc(); |
1315 | Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); |
1316 | Value multiplied = |
1317 | rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); |
1318 | Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); |
1319 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
1320 | Value numerator = |
1321 | rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); |
1322 | Value denominator = |
1323 | rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); |
1324 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, |
1325 | denominator); |
1326 | return success(); |
1327 | } |
1328 | }; |
1329 | |
1330 | class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { |
1331 | public: |
1332 | using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; |
1333 | |
1334 | LogicalResult |
1335 | matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, |
1336 | ConversionPatternRewriter &rewriter) const override { |
1337 | auto srcType = varOp.getType(); |
1338 | // Initialization is supported for scalars and vectors only. |
1339 | auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType(); |
1340 | auto init = varOp.getInitializer(); |
1341 | if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo)) |
1342 | return failure(); |
1343 | |
1344 | auto dstType = typeConverter.convertType(srcType); |
1345 | if (!dstType) |
1346 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
1347 | |
1348 | Location loc = varOp.getLoc(); |
1349 | Value size = createI32ConstantOf(loc, rewriter, value: 1); |
1350 | if (!init) { |
1351 | auto elementType = typeConverter.convertType(pointerTo); |
1352 | if (!elementType) |
1353 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
1354 | rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType, |
1355 | size); |
1356 | return success(); |
1357 | } |
1358 | auto elementType = typeConverter.convertType(pointerTo); |
1359 | if (!elementType) |
1360 | return rewriter.notifyMatchFailure(varOp, "type conversion failed" ); |
1361 | Value allocated = |
1362 | rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); |
1363 | rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); |
1364 | rewriter.replaceOp(varOp, allocated); |
1365 | return success(); |
1366 | } |
1367 | }; |
1368 | |
1369 | //===----------------------------------------------------------------------===// |
1370 | // BitcastOp conversion |
1371 | //===----------------------------------------------------------------------===// |
1372 | |
1373 | class BitcastConversionPattern |
1374 | : public SPIRVToLLVMConversion<spirv::BitcastOp> { |
1375 | public: |
1376 | using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion; |
1377 | |
1378 | LogicalResult |
1379 | matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor, |
1380 | ConversionPatternRewriter &rewriter) const override { |
1381 | auto dstType = typeConverter.convertType(bitcastOp.getType()); |
1382 | if (!dstType) |
1383 | return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed" ); |
1384 | |
1385 | // LLVM's opaque pointers do not require bitcasts. |
1386 | if (isa<LLVM::LLVMPointerType>(dstType)) { |
1387 | rewriter.replaceOp(bitcastOp, adaptor.getOperand()); |
1388 | return success(); |
1389 | } |
1390 | |
1391 | rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
1392 | bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs()); |
1393 | return success(); |
1394 | } |
1395 | }; |
1396 | |
1397 | //===----------------------------------------------------------------------===// |
1398 | // FuncOp conversion |
1399 | //===----------------------------------------------------------------------===// |
1400 | |
1401 | class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { |
1402 | public: |
1403 | using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; |
1404 | |
1405 | LogicalResult |
1406 | matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, |
1407 | ConversionPatternRewriter &rewriter) const override { |
1408 | |
1409 | // Convert function signature. At the moment LLVMType converter is enough |
1410 | // for currently supported types. |
1411 | auto funcType = funcOp.getFunctionType(); |
1412 | TypeConverter::SignatureConversion signatureConverter( |
1413 | funcType.getNumInputs()); |
1414 | auto llvmType = typeConverter.convertFunctionSignature( |
1415 | funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false, |
1416 | signatureConverter); |
1417 | if (!llvmType) |
1418 | return failure(); |
1419 | |
1420 | // Create a new `LLVMFuncOp` |
1421 | Location loc = funcOp.getLoc(); |
1422 | StringRef name = funcOp.getName(); |
1423 | auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); |
1424 | |
1425 | // Convert SPIR-V Function Control to equivalent LLVM function attribute |
1426 | MLIRContext *context = funcOp.getContext(); |
1427 | switch (funcOp.getFunctionControl()) { |
1428 | #define DISPATCH(functionControl, llvmAttr) \ |
1429 | case functionControl: \ |
1430 | newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ |
1431 | break; |
1432 | |
1433 | DISPATCH(spirv::FunctionControl::Inline, |
1434 | StringAttr::get(context, "alwaysinline" )); |
1435 | DISPATCH(spirv::FunctionControl::DontInline, |
1436 | StringAttr::get(context, "noinline" )); |
1437 | DISPATCH(spirv::FunctionControl::Pure, |
1438 | StringAttr::get(context, "readonly" )); |
1439 | DISPATCH(spirv::FunctionControl::Const, |
1440 | StringAttr::get(context, "readnone" )); |
1441 | |
1442 | #undef DISPATCH |
1443 | |
1444 | // Default: if `spirv::FunctionControl::None`, then no attributes are |
1445 | // needed. |
1446 | default: |
1447 | break; |
1448 | } |
1449 | |
1450 | rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), |
1451 | newFuncOp.end()); |
1452 | if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, |
1453 | &signatureConverter))) { |
1454 | return failure(); |
1455 | } |
1456 | rewriter.eraseOp(op: funcOp); |
1457 | return success(); |
1458 | } |
1459 | }; |
1460 | |
1461 | //===----------------------------------------------------------------------===// |
1462 | // ModuleOp conversion |
1463 | //===----------------------------------------------------------------------===// |
1464 | |
1465 | class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { |
1466 | public: |
1467 | using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; |
1468 | |
1469 | LogicalResult |
1470 | matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, |
1471 | ConversionPatternRewriter &rewriter) const override { |
1472 | |
1473 | auto newModuleOp = |
1474 | rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); |
1475 | rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); |
1476 | |
1477 | // Remove the terminator block that was automatically added by builder |
1478 | rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back()); |
1479 | rewriter.eraseOp(op: spvModuleOp); |
1480 | return success(); |
1481 | } |
1482 | }; |
1483 | |
1484 | //===----------------------------------------------------------------------===// |
1485 | // VectorShuffleOp conversion |
1486 | //===----------------------------------------------------------------------===// |
1487 | |
1488 | class VectorShufflePattern |
1489 | : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { |
1490 | public: |
1491 | using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; |
1492 | LogicalResult |
1493 | matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, |
1494 | ConversionPatternRewriter &rewriter) const override { |
1495 | Location loc = op.getLoc(); |
1496 | auto components = adaptor.getComponents(); |
1497 | auto vector1 = adaptor.getVector1(); |
1498 | auto vector2 = adaptor.getVector2(); |
1499 | int vector1Size = cast<VectorType>(vector1.getType()).getNumElements(); |
1500 | int vector2Size = cast<VectorType>(vector2.getType()).getNumElements(); |
1501 | if (vector1Size == vector2Size) { |
1502 | rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( |
1503 | op, vector1, vector2, |
1504 | LLVM::convertArrayToIndices<int32_t>(components)); |
1505 | return success(); |
1506 | } |
1507 | |
1508 | auto dstType = typeConverter.convertType(op.getType()); |
1509 | if (!dstType) |
1510 | return rewriter.notifyMatchFailure(op, "type conversion failed" ); |
1511 | auto scalarType = cast<VectorType>(dstType).getElementType(); |
1512 | auto componentsArray = components.getValue(); |
1513 | auto *context = rewriter.getContext(); |
1514 | auto llvmI32Type = IntegerType::get(context, 32); |
1515 | Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); |
1516 | for (unsigned i = 0; i < componentsArray.size(); i++) { |
1517 | if (!isa<IntegerAttr>(componentsArray[i])) |
1518 | return op.emitError("unable to support non-constant component" ); |
1519 | |
1520 | int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt(); |
1521 | if (indexVal == -1) |
1522 | continue; |
1523 | |
1524 | int offsetVal = 0; |
1525 | Value baseVector = vector1; |
1526 | if (indexVal >= vector1Size) { |
1527 | offsetVal = vector1Size; |
1528 | baseVector = vector2; |
1529 | } |
1530 | |
1531 | Value dstIndex = rewriter.create<LLVM::ConstantOp>( |
1532 | loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); |
1533 | Value index = rewriter.create<LLVM::ConstantOp>( |
1534 | loc, llvmI32Type, |
1535 | rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); |
1536 | |
1537 | auto = rewriter.create<LLVM::ExtractElementOp>( |
1538 | loc, scalarType, baseVector, index); |
1539 | targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, |
1540 | extractOp, dstIndex); |
1541 | } |
1542 | rewriter.replaceOp(op, targetOp); |
1543 | return success(); |
1544 | } |
1545 | }; |
1546 | } // namespace |
1547 | |
1548 | //===----------------------------------------------------------------------===// |
1549 | // Pattern population |
1550 | //===----------------------------------------------------------------------===// |
1551 | |
1552 | void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, |
1553 | spirv::ClientAPI clientAPI) { |
1554 | typeConverter.addConversion(callback: [&](spirv::ArrayType type) { |
1555 | return convertArrayType(type, converter&: typeConverter); |
1556 | }); |
1557 | typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) { |
1558 | return convertPointerType(type, typeConverter, clientAPI); |
1559 | }); |
1560 | typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) { |
1561 | return convertRuntimeArrayType(type, converter&: typeConverter); |
1562 | }); |
1563 | typeConverter.addConversion(callback: [&](spirv::StructType type) { |
1564 | return convertStructType(type, converter&: typeConverter); |
1565 | }); |
1566 | } |
1567 | |
1568 | void mlir::populateSPIRVToLLVMConversionPatterns( |
1569 | LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, |
1570 | spirv::ClientAPI clientAPI) { |
1571 | patterns.add< |
1572 | // Arithmetic ops |
1573 | DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, |
1574 | DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, |
1575 | DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, |
1576 | DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, |
1577 | DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, |
1578 | DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, |
1579 | DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, |
1580 | DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, |
1581 | DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, |
1582 | DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, |
1583 | DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, |
1584 | DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, |
1585 | DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, |
1586 | |
1587 | // Bitwise ops |
1588 | BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, |
1589 | DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, |
1590 | DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, |
1591 | DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, |
1592 | DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, |
1593 | DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, |
1594 | NotPattern<spirv::NotOp>, |
1595 | |
1596 | // Cast ops |
1597 | BitcastConversionPattern, |
1598 | DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, |
1599 | DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, |
1600 | DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, |
1601 | DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, |
1602 | IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, |
1603 | IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, |
1604 | IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, |
1605 | |
1606 | // Comparison ops |
1607 | IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, |
1608 | IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, |
1609 | FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, |
1610 | FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, |
1611 | FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, |
1612 | FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, |
1613 | FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, |
1614 | FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, |
1615 | FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, |
1616 | FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, |
1617 | FComparePattern<spirv::FUnordGreaterThanEqualOp, |
1618 | LLVM::FCmpPredicate::uge>, |
1619 | FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, |
1620 | FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, |
1621 | FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, |
1622 | IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, |
1623 | IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, |
1624 | IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, |
1625 | IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, |
1626 | IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, |
1627 | IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, |
1628 | IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, |
1629 | IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, |
1630 | |
1631 | // Constant op |
1632 | ConstantScalarAndVectorPattern, |
1633 | |
1634 | // Control Flow ops |
1635 | BranchConversionPattern, BranchConditionalConversionPattern, |
1636 | FunctionCallPattern, LoopPattern, SelectionPattern, |
1637 | ErasePattern<spirv::MergeOp>, |
1638 | |
1639 | // Entry points and execution mode are handled separately. |
1640 | ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, |
1641 | |
1642 | // GLSL extended instruction set ops |
1643 | DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>, |
1644 | DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>, |
1645 | DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>, |
1646 | DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>, |
1647 | DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>, |
1648 | DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>, |
1649 | DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>, |
1650 | DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>, |
1651 | DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>, |
1652 | DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>, |
1653 | DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>, |
1654 | DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>, |
1655 | InverseSqrtPattern, TanPattern, TanhPattern, |
1656 | |
1657 | // Logical ops |
1658 | DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, |
1659 | DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, |
1660 | IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, |
1661 | IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, |
1662 | NotPattern<spirv::LogicalNotOp>, |
1663 | |
1664 | // Memory ops |
1665 | AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>, |
1666 | LoadStorePattern<spirv::StoreOp>, VariablePattern, |
1667 | |
1668 | // Miscellaneous ops |
1669 | CompositeExtractPattern, CompositeInsertPattern, |
1670 | DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, |
1671 | DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, |
1672 | VectorShufflePattern, |
1673 | |
1674 | // Shift ops |
1675 | ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, |
1676 | ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, |
1677 | ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, |
1678 | |
1679 | // Return ops |
1680 | ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter); |
1681 | |
1682 | patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(), |
1683 | typeConverter); |
1684 | } |
1685 | |
1686 | void mlir::populateSPIRVToLLVMFunctionConversionPatterns( |
1687 | LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
1688 | patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args&: typeConverter); |
1689 | } |
1690 | |
1691 | void mlir::populateSPIRVToLLVMModuleConversionPatterns( |
1692 | LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
1693 | patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args&: typeConverter); |
1694 | } |
1695 | |
1696 | //===----------------------------------------------------------------------===// |
1697 | // Pre-conversion hooks |
1698 | //===----------------------------------------------------------------------===// |
1699 | |
1700 | /// Hook for descriptor set and binding number encoding. |
1701 | static constexpr StringRef kBinding = "binding" ; |
1702 | static constexpr StringRef kDescriptorSet = "descriptor_set" ; |
1703 | void mlir::encodeBindAttribute(ModuleOp module) { |
1704 | auto spvModules = module.getOps<spirv::ModuleOp>(); |
1705 | for (auto spvModule : spvModules) { |
1706 | spvModule.walk([&](spirv::GlobalVariableOp op) { |
1707 | IntegerAttr descriptorSet = |
1708 | op->getAttrOfType<IntegerAttr>(kDescriptorSet); |
1709 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); |
1710 | // For every global variable in the module, get the ones with descriptor |
1711 | // set and binding numbers. |
1712 | if (descriptorSet && binding) { |
1713 | // Encode these numbers into the variable's symbolic name. If the |
1714 | // SPIR-V module has a name, add it at the beginning. |
1715 | auto moduleAndName = |
1716 | spvModule.getName().has_value() |
1717 | ? spvModule.getName()->str() + "_" + op.getSymName().str() |
1718 | : op.getSymName().str(); |
1719 | std::string name = |
1720 | llvm::formatv("{0}_descriptor_set{1}_binding{2}" , moduleAndName, |
1721 | std::to_string(descriptorSet.getInt()), |
1722 | std::to_string(binding.getInt())); |
1723 | auto nameAttr = StringAttr::get(op->getContext(), name); |
1724 | |
1725 | // Replace all symbol uses and set the new symbol name. Finally, remove |
1726 | // descriptor set and binding attributes. |
1727 | if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) |
1728 | op.emitError("unable to replace all symbol uses for " ) << name; |
1729 | SymbolTable::setSymbolName(op, nameAttr); |
1730 | op->removeAttr(kDescriptorSet); |
1731 | op->removeAttr(kBinding); |
1732 | } |
1733 | }); |
1734 | } |
1735 | } |
1736 | |