1 | //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert SPIR-V dialect to LLVM dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
16 | #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
20 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
21 | #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" |
22 | #include "mlir/IR/BuiltinOps.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | #include "llvm/ADT/TypeSwitch.h" |
26 | #include "llvm/Support/Debug.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | |
29 | #define DEBUG_TYPE "spirv-to-llvm-pattern" |
30 | |
31 | using namespace mlir; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Utility functions |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | /// Returns true if the given type is a signed integer or vector type. |
38 | static bool isSignedIntegerOrVector(Type type) { |
39 | if (type.isSignedInteger()) |
40 | return true; |
41 | if (auto vecType = dyn_cast<VectorType>(type)) |
42 | return vecType.getElementType().isSignedInteger(); |
43 | return false; |
44 | } |
45 | |
46 | /// Returns true if the given type is an unsigned integer or vector type |
47 | static bool isUnsignedIntegerOrVector(Type type) { |
48 | if (type.isUnsignedInteger()) |
49 | return true; |
50 | if (auto vecType = dyn_cast<VectorType>(type)) |
51 | return vecType.getElementType().isUnsignedInteger(); |
52 | return false; |
53 | } |
54 | |
55 | /// Returns the width of an integer or of the element type of an integer vector, |
56 | /// if applicable. |
57 | static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { |
58 | if (auto intType = dyn_cast<IntegerType>(type)) |
59 | return intType.getWidth(); |
60 | if (auto vecType = dyn_cast<VectorType>(type)) |
61 | if (auto intType = dyn_cast<IntegerType>(vecType.getElementType())) |
62 | return intType.getWidth(); |
63 | return std::nullopt; |
64 | } |
65 | |
66 | /// Returns the bit width of integer, float or vector of float or integer values |
67 | static unsigned getBitWidth(Type type) { |
68 | assert((type.isIntOrFloat() || isa<VectorType>(type)) && |
69 | "bitwidth is not supported for this type"); |
70 | if (type.isIntOrFloat()) |
71 | return type.getIntOrFloatBitWidth(); |
72 | auto vecType = dyn_cast<VectorType>(type); |
73 | auto elementType = vecType.getElementType(); |
74 | assert(elementType.isIntOrFloat() && |
75 | "only integers and floats have a bitwidth"); |
76 | return elementType.getIntOrFloatBitWidth(); |
77 | } |
78 | |
79 | /// Returns the bit width of LLVMType integer or vector. |
80 | static unsigned getLLVMTypeBitWidth(Type type) { |
81 | if (auto vecTy = dyn_cast<VectorType>(type)) |
82 | type = vecTy.getElementType(); |
83 | return cast<IntegerType>(type).getWidth(); |
84 | } |
85 | |
86 | /// Creates `IntegerAttribute` with all bits set for given type |
87 | static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { |
88 | if (auto vecType = dyn_cast<VectorType>(type)) { |
89 | auto integerType = cast<IntegerType>(vecType.getElementType()); |
90 | return builder.getIntegerAttr(integerType, -1); |
91 | } |
92 | auto integerType = cast<IntegerType>(type); |
93 | return builder.getIntegerAttr(integerType, -1); |
94 | } |
95 | |
96 | /// Creates `llvm.mlir.constant` with all bits set for the given type. |
97 | static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, |
98 | PatternRewriter &rewriter) { |
99 | if (isa<VectorType>(Val: srcType)) { |
100 | return rewriter.create<LLVM::ConstantOp>( |
101 | loc, dstType, |
102 | SplatElementsAttr::get(cast<ShapedType>(srcType), |
103 | minusOneIntegerAttribute(srcType, rewriter))); |
104 | } |
105 | return rewriter.create<LLVM::ConstantOp>( |
106 | loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); |
107 | } |
108 | |
109 | /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. |
110 | static Value createFPConstant(Location loc, Type srcType, Type dstType, |
111 | PatternRewriter &rewriter, double value) { |
112 | if (auto vecType = dyn_cast<VectorType>(srcType)) { |
113 | auto floatType = cast<FloatType>(vecType.getElementType()); |
114 | return rewriter.create<LLVM::ConstantOp>( |
115 | loc, dstType, |
116 | SplatElementsAttr::get(vecType, |
117 | rewriter.getFloatAttr(floatType, value))); |
118 | } |
119 | auto floatType = cast<FloatType>(srcType); |
120 | return rewriter.create<LLVM::ConstantOp>( |
121 | loc, dstType, rewriter.getFloatAttr(floatType, value)); |
122 | } |
123 | |
124 | /// Utility function for bitfield ops: |
125 | /// - `BitFieldInsert` |
126 | /// - `BitFieldSExtract` |
127 | /// - `BitFieldUExtract` |
128 | /// Truncates or extends the value. If the bitwidth of the value is the same as |
129 | /// `llvmType` bitwidth, the value remains unchanged. |
130 | static Value optionallyTruncateOrExtend(Location loc, Value value, |
131 | Type llvmType, |
132 | PatternRewriter &rewriter) { |
133 | auto srcType = value.getType(); |
134 | unsigned targetBitWidth = getLLVMTypeBitWidth(type: llvmType); |
135 | unsigned valueBitWidth = LLVM::isCompatibleType(type: srcType) |
136 | ? getLLVMTypeBitWidth(type: srcType) |
137 | : getBitWidth(type: srcType); |
138 | |
139 | if (valueBitWidth < targetBitWidth) |
140 | return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); |
141 | // If the bit widths of `Count` and `Offset` are greater than the bit width |
142 | // of the target type, they are truncated. Truncation is safe since `Count` |
143 | // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, |
144 | // both values can be expressed in 8 bits. |
145 | if (valueBitWidth > targetBitWidth) |
146 | return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); |
147 | return value; |
148 | } |
149 | |
150 | /// Broadcasts the value to vector with `numElements` number of elements. |
151 | static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, |
152 | const TypeConverter &typeConverter, |
153 | ConversionPatternRewriter &rewriter) { |
154 | auto vectorType = VectorType::get(numElements, toBroadcast.getType()); |
155 | auto llvmVectorType = typeConverter.convertType(vectorType); |
156 | auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); |
157 | Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType); |
158 | for (unsigned i = 0; i < numElements; ++i) { |
159 | auto index = rewriter.create<LLVM::ConstantOp>( |
160 | loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); |
161 | broadcasted = rewriter.create<LLVM::InsertElementOp>( |
162 | loc, llvmVectorType, broadcasted, toBroadcast, index); |
163 | } |
164 | return broadcasted; |
165 | } |
166 | |
167 | /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. |
168 | static Value optionallyBroadcast(Location loc, Value value, Type srcType, |
169 | const TypeConverter &typeConverter, |
170 | ConversionPatternRewriter &rewriter) { |
171 | if (auto vectorType = dyn_cast<VectorType>(srcType)) { |
172 | unsigned numElements = vectorType.getNumElements(); |
173 | return broadcast(loc, toBroadcast: value, numElements, typeConverter, rewriter); |
174 | } |
175 | return value; |
176 | } |
177 | |
178 | /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and |
179 | /// `BitFieldUExtract`. |
180 | /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of |
181 | /// a vector type, construct a vector that has: |
182 | /// - same number of elements as `Base` |
183 | /// - each element has the type that is the same as the type of `Offset` or |
184 | /// `Count` |
185 | /// - each element has the same value as `Offset` or `Count` |
186 | /// Then cast `Offset` and `Count` if their bit width is different |
187 | /// from `Base` bit width. |
188 | static Value processCountOrOffset(Location loc, Value value, Type srcType, |
189 | Type dstType, const TypeConverter &converter, |
190 | ConversionPatternRewriter &rewriter) { |
191 | Value broadcasted = |
192 | optionallyBroadcast(loc, value, srcType, typeConverter: converter, rewriter); |
193 | return optionallyTruncateOrExtend(loc, value: broadcasted, llvmType: dstType, rewriter); |
194 | } |
195 | |
196 | /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) |
197 | /// offset to LLVM struct. Otherwise, the conversion is not supported. |
198 | static Type convertStructTypeWithOffset(spirv::StructType type, |
199 | const TypeConverter &converter) { |
200 | if (type != VulkanLayoutUtils::decorateType(structType: type)) |
201 | return nullptr; |
202 | |
203 | SmallVector<Type> elementsVector; |
204 | if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
205 | return nullptr; |
206 | return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, |
207 | /*isPacked=*/false); |
208 | } |
209 | |
210 | /// Converts SPIR-V struct with no offset to packed LLVM struct. |
211 | static Type convertStructTypePacked(spirv::StructType type, |
212 | const TypeConverter &converter) { |
213 | SmallVector<Type> elementsVector; |
214 | if (failed(Result: converter.convertTypes(types: type.getElementTypes(), results&: elementsVector))) |
215 | return nullptr; |
216 | return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, |
217 | /*isPacked=*/true); |
218 | } |
219 | |
220 | /// Creates LLVM dialect constant with the given value. |
221 | static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, |
222 | unsigned value) { |
223 | return rewriter.create<LLVM::ConstantOp>( |
224 | loc, IntegerType::get(rewriter.getContext(), 32), |
225 | rewriter.getIntegerAttr(rewriter.getI32Type(), value)); |
226 | } |
227 | |
228 | /// Utility for `spirv.Load` and `spirv.Store` conversion. |
229 | static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, |
230 | ConversionPatternRewriter &rewriter, |
231 | const TypeConverter &typeConverter, |
232 | unsigned alignment, bool isVolatile, |
233 | bool isNonTemporal) { |
234 | if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { |
235 | auto dstType = typeConverter.convertType(loadOp.getType()); |
236 | if (!dstType) |
237 | return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed"); |
238 | rewriter.replaceOpWithNewOp<LLVM::LoadOp>( |
239 | loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, |
240 | isVolatile, isNonTemporal); |
241 | return success(); |
242 | } |
243 | auto storeOp = cast<spirv::StoreOp>(op); |
244 | spirv::StoreOpAdaptor adaptor(operands); |
245 | rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(), |
246 | adaptor.getPtr(), alignment, |
247 | isVolatile, isNonTemporal); |
248 | return success(); |
249 | } |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // Type conversion |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | /// Converts SPIR-V array type to LLVM array. Natural stride (according to |
256 | /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected |
257 | /// when converting ops that manipulate array types. |
258 | static std::optional<Type> convertArrayType(spirv::ArrayType type, |
259 | TypeConverter &converter) { |
260 | unsigned stride = type.getArrayStride(); |
261 | Type elementType = type.getElementType(); |
262 | auto sizeInBytes = cast<spirv::SPIRVType>(Val&: elementType).getSizeInBytes(); |
263 | if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) |
264 | return std::nullopt; |
265 | |
266 | auto llvmElementType = converter.convertType(t: elementType); |
267 | unsigned numElements = type.getNumElements(); |
268 | return LLVM::LLVMArrayType::get(llvmElementType, numElements); |
269 | } |
270 | |
271 | /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not |
272 | /// modelled at the moment. |
273 | static Type convertPointerType(spirv::PointerType type, |
274 | const TypeConverter &converter, |
275 | spirv::ClientAPI clientAPI) { |
276 | unsigned addressSpace = |
277 | storageClassToAddressSpace(clientAPI, type.getStorageClass()); |
278 | return LLVM::LLVMPointerType::get(type.getContext(), addressSpace); |
279 | } |
280 | |
281 | /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over |
282 | /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is |
283 | /// no modelling of array stride at the moment. |
284 | static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, |
285 | TypeConverter &converter) { |
286 | if (type.getArrayStride() != 0) |
287 | return std::nullopt; |
288 | auto elementType = converter.convertType(t: type.getElementType()); |
289 | return LLVM::LLVMArrayType::get(elementType, 0); |
290 | } |
291 | |
292 | /// Converts SPIR-V struct to LLVM struct. There is no support of structs with |
293 | /// member decorations. Also, only natural offset is supported. |
294 | static Type convertStructType(spirv::StructType type, |
295 | const TypeConverter &converter) { |
296 | SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; |
297 | type.getMemberDecorations(memberDecorations); |
298 | if (!memberDecorations.empty()) |
299 | return nullptr; |
300 | if (type.hasOffset()) |
301 | return convertStructTypeWithOffset(type, converter); |
302 | return convertStructTypePacked(type, converter); |
303 | } |
304 | |
305 | //===----------------------------------------------------------------------===// |
306 | // Operation conversion |
307 | //===----------------------------------------------------------------------===// |
308 | |
309 | namespace { |
310 | |
311 | class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { |
312 | public: |
313 | using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; |
314 | |
315 | LogicalResult |
316 | matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, |
317 | ConversionPatternRewriter &rewriter) const override { |
318 | auto dstType = |
319 | getTypeConverter()->convertType(op.getComponentPtr().getType()); |
320 | if (!dstType) |
321 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
322 | // To use GEP we need to add a first 0 index to go through the pointer. |
323 | auto indices = llvm::to_vector<4>(adaptor.getIndices()); |
324 | Type indexType = op.getIndices().front().getType(); |
325 | auto llvmIndexType = getTypeConverter()->convertType(indexType); |
326 | if (!llvmIndexType) |
327 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
328 | Value zero = rewriter.create<LLVM::ConstantOp>( |
329 | op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); |
330 | indices.insert(indices.begin(), zero); |
331 | |
332 | auto elementType = getTypeConverter()->convertType( |
333 | cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType()); |
334 | if (!elementType) |
335 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
336 | rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType, |
337 | adaptor.getBasePtr(), indices); |
338 | return success(); |
339 | } |
340 | }; |
341 | |
342 | class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { |
343 | public: |
344 | using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; |
345 | |
346 | LogicalResult |
347 | matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, |
348 | ConversionPatternRewriter &rewriter) const override { |
349 | auto dstType = getTypeConverter()->convertType(op.getPointer().getType()); |
350 | if (!dstType) |
351 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
352 | rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, |
353 | op.getVariable()); |
354 | return success(); |
355 | } |
356 | }; |
357 | |
358 | class BitFieldInsertPattern |
359 | : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { |
360 | public: |
361 | using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; |
362 | |
363 | LogicalResult |
364 | matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, |
365 | ConversionPatternRewriter &rewriter) const override { |
366 | auto srcType = op.getType(); |
367 | auto dstType = getTypeConverter()->convertType(srcType); |
368 | if (!dstType) |
369 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
370 | Location loc = op.getLoc(); |
371 | |
372 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
373 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
374 | *getTypeConverter(), rewriter); |
375 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
376 | *getTypeConverter(), rewriter); |
377 | |
378 | // Create a mask with bits set outside [Offset, Offset + Count - 1]. |
379 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
380 | Value maskShiftedByCount = |
381 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
382 | Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, |
383 | maskShiftedByCount, minusOne); |
384 | Value maskShiftedByCountAndOffset = |
385 | rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); |
386 | Value mask = rewriter.create<LLVM::XOrOp>( |
387 | loc, dstType, maskShiftedByCountAndOffset, minusOne); |
388 | |
389 | // Extract unchanged bits from the `Base` that are outside of |
390 | // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. |
391 | Value baseAndMask = |
392 | rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); |
393 | Value insertShiftedByOffset = |
394 | rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); |
395 | rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, |
396 | insertShiftedByOffset); |
397 | return success(); |
398 | } |
399 | }; |
400 | |
401 | /// Converts SPIR-V ConstantOp with scalar or vector type. |
402 | class ConstantScalarAndVectorPattern |
403 | : public SPIRVToLLVMConversion<spirv::ConstantOp> { |
404 | public: |
405 | using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; |
406 | |
407 | LogicalResult |
408 | matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, |
409 | ConversionPatternRewriter &rewriter) const override { |
410 | auto srcType = constOp.getType(); |
411 | if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat()) |
412 | return failure(); |
413 | |
414 | auto dstType = getTypeConverter()->convertType(srcType); |
415 | if (!dstType) |
416 | return rewriter.notifyMatchFailure(constOp, "type conversion failed"); |
417 | |
418 | // SPIR-V constant can be a signed/unsigned integer, which has to be |
419 | // casted to signless integer when converting to LLVM dialect. Removing the |
420 | // sign bit may have unexpected behaviour. However, it is better to handle |
421 | // it case-by-case, given that the purpose of the conversion is not to |
422 | // cover all possible corner cases. |
423 | if (isSignedIntegerOrVector(srcType) || |
424 | isUnsignedIntegerOrVector(srcType)) { |
425 | auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); |
426 | |
427 | if (isa<VectorType>(srcType)) { |
428 | auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue()); |
429 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
430 | constOp, dstType, |
431 | dstElementsAttr.mapValues( |
432 | signlessType, [&](const APInt &value) { return value; })); |
433 | return success(); |
434 | } |
435 | auto srcAttr = cast<IntegerAttr>(constOp.getValue()); |
436 | auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); |
437 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); |
438 | return success(); |
439 | } |
440 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
441 | constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); |
442 | return success(); |
443 | } |
444 | }; |
445 | |
446 | class BitFieldSExtractPattern |
447 | : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { |
448 | public: |
449 | using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; |
450 | |
451 | LogicalResult |
452 | matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, |
453 | ConversionPatternRewriter &rewriter) const override { |
454 | auto srcType = op.getType(); |
455 | auto dstType = getTypeConverter()->convertType(srcType); |
456 | if (!dstType) |
457 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
458 | Location loc = op.getLoc(); |
459 | |
460 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
461 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
462 | *getTypeConverter(), rewriter); |
463 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
464 | *getTypeConverter(), rewriter); |
465 | |
466 | // Create a constant that holds the size of the `Base`. |
467 | IntegerType integerType; |
468 | if (auto vecType = dyn_cast<VectorType>(srcType)) |
469 | integerType = cast<IntegerType>(vecType.getElementType()); |
470 | else |
471 | integerType = cast<IntegerType>(srcType); |
472 | |
473 | auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); |
474 | Value size = |
475 | isa<VectorType>(srcType) |
476 | ? rewriter.create<LLVM::ConstantOp>( |
477 | loc, dstType, |
478 | SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) |
479 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); |
480 | |
481 | // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit |
482 | // at Offset + Count - 1 is the most significant bit now. |
483 | Value countPlusOffset = |
484 | rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); |
485 | Value amountToShiftLeft = |
486 | rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); |
487 | Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( |
488 | loc, dstType, op.getBase(), amountToShiftLeft); |
489 | |
490 | // Shift the result right, filling the bits with the sign bit. |
491 | Value amountToShiftRight = |
492 | rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); |
493 | rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, |
494 | amountToShiftRight); |
495 | return success(); |
496 | } |
497 | }; |
498 | |
499 | class BitFieldUExtractPattern |
500 | : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { |
501 | public: |
502 | using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; |
503 | |
504 | LogicalResult |
505 | matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, |
506 | ConversionPatternRewriter &rewriter) const override { |
507 | auto srcType = op.getType(); |
508 | auto dstType = getTypeConverter()->convertType(srcType); |
509 | if (!dstType) |
510 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
511 | Location loc = op.getLoc(); |
512 | |
513 | // Process `Offset` and `Count`: broadcast and extend/truncate if needed. |
514 | Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, |
515 | *getTypeConverter(), rewriter); |
516 | Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, |
517 | *getTypeConverter(), rewriter); |
518 | |
519 | // Create a mask with bits set at [0, Count - 1]. |
520 | Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); |
521 | Value maskShiftedByCount = |
522 | rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); |
523 | Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, |
524 | minusOne); |
525 | |
526 | // Shift `Base` by `Offset` and apply the mask on it. |
527 | Value shiftedBase = |
528 | rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); |
529 | rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); |
530 | return success(); |
531 | } |
532 | }; |
533 | |
534 | class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { |
535 | public: |
536 | using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; |
537 | |
538 | LogicalResult |
539 | matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, |
540 | ConversionPatternRewriter &rewriter) const override { |
541 | rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), |
542 | branchOp.getTarget()); |
543 | return success(); |
544 | } |
545 | }; |
546 | |
547 | class BranchConditionalConversionPattern |
548 | : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { |
549 | public: |
550 | using SPIRVToLLVMConversion< |
551 | spirv::BranchConditionalOp>::SPIRVToLLVMConversion; |
552 | |
553 | LogicalResult |
554 | matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, |
555 | ConversionPatternRewriter &rewriter) const override { |
556 | // If branch weights exist, map them to 32-bit integer vector. |
557 | DenseI32ArrayAttr branchWeights = nullptr; |
558 | if (auto weights = op.getBranchWeights()) { |
559 | SmallVector<int32_t> weightValues; |
560 | for (auto weight : weights->getAsRange<IntegerAttr>()) |
561 | weightValues.push_back(weight.getInt()); |
562 | branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); |
563 | } |
564 | |
565 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
566 | op, op.getCondition(), op.getTrueBlockArguments(), |
567 | op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), |
568 | op.getFalseBlock()); |
569 | return success(); |
570 | } |
571 | }; |
572 | |
573 | /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container |
574 | /// type is an aggregate type (struct or array). Otherwise, converts to |
575 | /// `llvm.extractelement` that operates on vectors. |
576 | class CompositeExtractPattern |
577 | : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { |
578 | public: |
579 | using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; |
580 | |
581 | LogicalResult |
582 | matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, |
583 | ConversionPatternRewriter &rewriter) const override { |
584 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
585 | if (!dstType) |
586 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
587 | |
588 | Type containerType = op.getComposite().getType(); |
589 | if (isa<VectorType>(Val: containerType)) { |
590 | Location loc = op.getLoc(); |
591 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
592 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
593 | rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( |
594 | op, dstType, adaptor.getComposite(), index); |
595 | return success(); |
596 | } |
597 | |
598 | rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( |
599 | op, adaptor.getComposite(), |
600 | LLVM::convertArrayToIndices(op.getIndices())); |
601 | return success(); |
602 | } |
603 | }; |
604 | |
605 | /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container |
606 | /// type is an aggregate type (struct or array). Otherwise, converts to |
607 | /// `llvm.insertelement` that operates on vectors. |
608 | class CompositeInsertPattern |
609 | : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { |
610 | public: |
611 | using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; |
612 | |
613 | LogicalResult |
614 | matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, |
615 | ConversionPatternRewriter &rewriter) const override { |
616 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
617 | if (!dstType) |
618 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
619 | |
620 | Type containerType = op.getComposite().getType(); |
621 | if (isa<VectorType>(Val: containerType)) { |
622 | Location loc = op.getLoc(); |
623 | IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); |
624 | Value index = createI32ConstantOf(loc, rewriter, value.getInt()); |
625 | rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( |
626 | op, dstType, adaptor.getComposite(), adaptor.getObject(), index); |
627 | return success(); |
628 | } |
629 | |
630 | rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( |
631 | op, adaptor.getComposite(), adaptor.getObject(), |
632 | LLVM::convertArrayToIndices(op.getIndices())); |
633 | return success(); |
634 | } |
635 | }; |
636 | |
637 | /// Converts SPIR-V operations that have straightforward LLVM equivalent |
638 | /// into LLVM dialect operations. |
639 | template <typename SPIRVOp, typename LLVMOp> |
640 | class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
641 | public: |
642 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
643 | |
644 | LogicalResult |
645 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
646 | ConversionPatternRewriter &rewriter) const override { |
647 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
648 | if (!dstType) |
649 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
650 | rewriter.template replaceOpWithNewOp<LLVMOp>( |
651 | op, dstType, adaptor.getOperands(), op->getAttrs()); |
652 | return success(); |
653 | } |
654 | }; |
655 | |
656 | /// Converts `spirv.ExecutionMode` into a global struct constant that holds |
657 | /// execution mode information. |
658 | class ExecutionModePattern |
659 | : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { |
660 | public: |
661 | using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; |
662 | |
663 | LogicalResult |
664 | matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, |
665 | ConversionPatternRewriter &rewriter) const override { |
666 | // First, create the global struct's name that would be associated with |
667 | // this entry point's execution mode. We set it to be: |
668 | // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} |
669 | ModuleOp module = op->getParentOfType<ModuleOp>(); |
670 | spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); |
671 | std::string moduleName; |
672 | if (module.getName().has_value()) |
673 | moduleName = "_"+ module.getName()->str(); |
674 | else |
675 | moduleName = ""; |
676 | std::string executionModeInfoName = llvm::formatv( |
677 | "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(), |
678 | static_cast<uint32_t>(executionModeAttr.getValue())); |
679 | |
680 | MLIRContext *context = rewriter.getContext(); |
681 | OpBuilder::InsertionGuard guard(rewriter); |
682 | rewriter.setInsertionPointToStart(module.getBody()); |
683 | |
684 | // Create a struct type, corresponding to the C struct below. |
685 | // struct { |
686 | // int32_t executionMode; |
687 | // int32_t values[]; // optional values |
688 | // }; |
689 | auto llvmI32Type = IntegerType::get(context, 32); |
690 | SmallVector<Type, 2> fields; |
691 | fields.push_back(Elt: llvmI32Type); |
692 | ArrayAttr values = op.getValues(); |
693 | if (!values.empty()) { |
694 | auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); |
695 | fields.push_back(Elt: arrayType); |
696 | } |
697 | auto structType = LLVM::LLVMStructType::getLiteral(context, fields); |
698 | |
699 | // Create `llvm.mlir.global` with initializer region containing one block. |
700 | auto global = rewriter.create<LLVM::GlobalOp>( |
701 | UnknownLoc::get(context), structType, /*isConstant=*/true, |
702 | LLVM::Linkage::External, executionModeInfoName, Attribute(), |
703 | /*alignment=*/0); |
704 | Location loc = global.getLoc(); |
705 | Region ®ion = global.getInitializerRegion(); |
706 | Block *block = rewriter.createBlock(parent: ®ion); |
707 | |
708 | // Initialize the struct and set the execution mode value. |
709 | rewriter.setInsertionPointToStart(block); |
710 | Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType); |
711 | Value executionMode = rewriter.create<LLVM::ConstantOp>( |
712 | loc, llvmI32Type, |
713 | rewriter.getI32IntegerAttr( |
714 | static_cast<uint32_t>(executionModeAttr.getValue()))); |
715 | structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, |
716 | executionMode, 0); |
717 | |
718 | // Insert extra operands if they exist into execution mode info struct. |
719 | for (unsigned i = 0, e = values.size(); i < e; ++i) { |
720 | auto attr = values.getValue()[i]; |
721 | Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); |
722 | structValue = rewriter.create<LLVM::InsertValueOp>( |
723 | loc, structValue, entry, ArrayRef<int64_t>({1, i})); |
724 | } |
725 | rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); |
726 | rewriter.eraseOp(op: op); |
727 | return success(); |
728 | } |
729 | }; |
730 | |
731 | /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V |
732 | /// global returns a pointer, whereas in LLVM dialect the global holds an actual |
733 | /// value. This difference is handled by `spirv.mlir.addressof` and |
734 | /// `llvm.mlir.addressof`ops that both return a pointer. |
735 | class GlobalVariablePattern |
736 | : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { |
737 | public: |
738 | template <typename... Args> |
739 | GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args) |
740 | : SPIRVToLLVMConversion<spirv::GlobalVariableOp>( |
741 | std::forward<Args>(args)...), |
742 | clientAPI(clientAPI) {} |
743 | |
744 | LogicalResult |
745 | matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, |
746 | ConversionPatternRewriter &rewriter) const override { |
747 | // Currently, there is no support of initialization with a constant value in |
748 | // SPIR-V dialect. Specialization constants are not considered as well. |
749 | if (op.getInitializer()) |
750 | return failure(); |
751 | |
752 | auto srcType = cast<spirv::PointerType>(op.getType()); |
753 | auto dstType = getTypeConverter()->convertType(srcType.getPointeeType()); |
754 | if (!dstType) |
755 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
756 | |
757 | // Limit conversion to the current invocation only or `StorageBuffer` |
758 | // required by SPIR-V runner. |
759 | // This is okay because multiple invocations are not supported yet. |
760 | auto storageClass = srcType.getStorageClass(); |
761 | switch (storageClass) { |
762 | case spirv::StorageClass::Input: |
763 | case spirv::StorageClass::Private: |
764 | case spirv::StorageClass::Output: |
765 | case spirv::StorageClass::StorageBuffer: |
766 | case spirv::StorageClass::UniformConstant: |
767 | break; |
768 | default: |
769 | return failure(); |
770 | } |
771 | |
772 | // LLVM dialect spec: "If the global value is a constant, storing into it is |
773 | // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' |
774 | // storage class that is read-only. |
775 | bool isConstant = (storageClass == spirv::StorageClass::Input) || |
776 | (storageClass == spirv::StorageClass::UniformConstant); |
777 | // SPIR-V spec: "By default, functions and global variables are private to a |
778 | // module and cannot be accessed by other modules. However, a module may be |
779 | // written to export or import functions and global (module scope) |
780 | // variables.". Therefore, map 'Private' storage class to private linkage, |
781 | // 'Input' and 'Output' to external linkage. |
782 | auto linkage = storageClass == spirv::StorageClass::Private |
783 | ? LLVM::Linkage::Private |
784 | : LLVM::Linkage::External; |
785 | auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( |
786 | op, dstType, isConstant, linkage, op.getSymName(), Attribute(), |
787 | /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass)); |
788 | |
789 | // Attach location attribute if applicable |
790 | if (op.getLocationAttr()) |
791 | newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); |
792 | |
793 | return success(); |
794 | } |
795 | |
796 | private: |
797 | spirv::ClientAPI clientAPI; |
798 | }; |
799 | |
800 | /// Converts SPIR-V cast ops that do not have straightforward LLVM |
801 | /// equivalent in LLVM dialect. |
802 | template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> |
803 | class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
804 | public: |
805 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
806 | |
807 | LogicalResult |
808 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
809 | ConversionPatternRewriter &rewriter) const override { |
810 | |
811 | Type fromType = op.getOperand().getType(); |
812 | Type toType = op.getType(); |
813 | |
814 | auto dstType = this->getTypeConverter()->convertType(toType); |
815 | if (!dstType) |
816 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
817 | |
818 | if (getBitWidth(type: fromType) < getBitWidth(type: toType)) { |
819 | rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType, |
820 | adaptor.getOperands()); |
821 | return success(); |
822 | } |
823 | if (getBitWidth(type: fromType) > getBitWidth(type: toType)) { |
824 | rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType, |
825 | adaptor.getOperands()); |
826 | return success(); |
827 | } |
828 | return failure(); |
829 | } |
830 | }; |
831 | |
832 | class FunctionCallPattern |
833 | : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { |
834 | public: |
835 | using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; |
836 | |
837 | LogicalResult |
838 | matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, |
839 | ConversionPatternRewriter &rewriter) const override { |
840 | if (callOp.getNumResults() == 0) { |
841 | auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
842 | callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); |
843 | newOp.getProperties().operandSegmentSizes = { |
844 | static_cast<int32_t>(adaptor.getOperands().size()), 0}; |
845 | newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); |
846 | return success(); |
847 | } |
848 | |
849 | // Function returns a single result. |
850 | auto dstType = getTypeConverter()->convertType(callOp.getType(0)); |
851 | if (!dstType) |
852 | return rewriter.notifyMatchFailure(callOp, "type conversion failed"); |
853 | auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( |
854 | callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); |
855 | newOp.getProperties().operandSegmentSizes = { |
856 | static_cast<int32_t>(adaptor.getOperands().size()), 0}; |
857 | newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); |
858 | return success(); |
859 | } |
860 | }; |
861 | |
862 | /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" |
863 | template <typename SPIRVOp, LLVM::FCmpPredicate predicate> |
864 | class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
865 | public: |
866 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
867 | |
868 | LogicalResult |
869 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
870 | ConversionPatternRewriter &rewriter) const override { |
871 | |
872 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
873 | if (!dstType) |
874 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
875 | |
876 | rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( |
877 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
878 | return success(); |
879 | } |
880 | }; |
881 | |
882 | /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" |
883 | template <typename SPIRVOp, LLVM::ICmpPredicate predicate> |
884 | class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
885 | public: |
886 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
887 | |
888 | LogicalResult |
889 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
890 | ConversionPatternRewriter &rewriter) const override { |
891 | |
892 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
893 | if (!dstType) |
894 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
895 | |
896 | rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( |
897 | op, dstType, predicate, op.getOperand1(), op.getOperand2()); |
898 | return success(); |
899 | } |
900 | }; |
901 | |
902 | class InverseSqrtPattern |
903 | : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { |
904 | public: |
905 | using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion; |
906 | |
907 | LogicalResult |
908 | matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor, |
909 | ConversionPatternRewriter &rewriter) const override { |
910 | auto srcType = op.getType(); |
911 | auto dstType = getTypeConverter()->convertType(srcType); |
912 | if (!dstType) |
913 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
914 | |
915 | Location loc = op.getLoc(); |
916 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
917 | Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); |
918 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); |
919 | return success(); |
920 | } |
921 | }; |
922 | |
923 | /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect. |
924 | template <typename SPIRVOp> |
925 | class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
926 | public: |
927 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
928 | |
929 | LogicalResult |
930 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
931 | ConversionPatternRewriter &rewriter) const override { |
932 | if (!op.getMemoryAccess()) { |
933 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
934 | *this->getTypeConverter(), /*alignment=*/0, |
935 | /*isVolatile=*/false, |
936 | /*isNonTemporal=*/false); |
937 | } |
938 | auto memoryAccess = *op.getMemoryAccess(); |
939 | switch (memoryAccess) { |
940 | case spirv::MemoryAccess::Aligned: |
941 | case spirv::MemoryAccess::None: |
942 | case spirv::MemoryAccess::Nontemporal: |
943 | case spirv::MemoryAccess::Volatile: { |
944 | unsigned alignment = |
945 | memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; |
946 | bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; |
947 | bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; |
948 | return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, |
949 | *this->getTypeConverter(), alignment, |
950 | isVolatile, isNonTemporal); |
951 | } |
952 | default: |
953 | // There is no support of other memory access attributes. |
954 | return failure(); |
955 | } |
956 | } |
957 | }; |
958 | |
959 | /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect. |
960 | template <typename SPIRVOp> |
961 | class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
962 | public: |
963 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
964 | |
965 | LogicalResult |
966 | matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, |
967 | ConversionPatternRewriter &rewriter) const override { |
968 | auto srcType = notOp.getType(); |
969 | auto dstType = this->getTypeConverter()->convertType(srcType); |
970 | if (!dstType) |
971 | return rewriter.notifyMatchFailure(notOp, "type conversion failed"); |
972 | |
973 | Location loc = notOp.getLoc(); |
974 | IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); |
975 | auto mask = |
976 | isa<VectorType>(srcType) |
977 | ? rewriter.create<LLVM::ConstantOp>( |
978 | loc, dstType, |
979 | SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) |
980 | : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); |
981 | rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, |
982 | notOp.getOperand(), mask); |
983 | return success(); |
984 | } |
985 | }; |
986 | |
987 | /// A template pattern that erases the given `SPIRVOp`. |
988 | template <typename SPIRVOp> |
989 | class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { |
990 | public: |
991 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
992 | |
993 | LogicalResult |
994 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
995 | ConversionPatternRewriter &rewriter) const override { |
996 | rewriter.eraseOp(op); |
997 | return success(); |
998 | } |
999 | }; |
1000 | |
1001 | class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { |
1002 | public: |
1003 | using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; |
1004 | |
1005 | LogicalResult |
1006 | matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, |
1007 | ConversionPatternRewriter &rewriter) const override { |
1008 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), |
1009 | ArrayRef<Value>()); |
1010 | return success(); |
1011 | } |
1012 | }; |
1013 | |
1014 | class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { |
1015 | public: |
1016 | using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; |
1017 | |
1018 | LogicalResult |
1019 | matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, |
1020 | ConversionPatternRewriter &rewriter) const override { |
1021 | rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), |
1022 | adaptor.getOperands()); |
1023 | return success(); |
1024 | } |
1025 | }; |
1026 | |
1027 | static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, |
1028 | StringRef name, |
1029 | ArrayRef<Type> paramTypes, |
1030 | Type resultType, |
1031 | bool convergent = true) { |
1032 | auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>( |
1033 | SymbolTable::lookupSymbolIn(symbolTable, name)); |
1034 | if (func) |
1035 | return func; |
1036 | |
1037 | OpBuilder b(symbolTable->getRegion(index: 0)); |
1038 | func = b.create<LLVM::LLVMFuncOp>( |
1039 | symbolTable->getLoc(), name, |
1040 | LLVM::LLVMFunctionType::get(resultType, paramTypes)); |
1041 | func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); |
1042 | func.setConvergent(convergent); |
1043 | func.setNoUnwind(true); |
1044 | func.setWillReturn(true); |
1045 | return func; |
1046 | } |
1047 | |
1048 | static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, |
1049 | LLVM::LLVMFuncOp func, |
1050 | ValueRange args) { |
1051 | auto call = builder.create<LLVM::CallOp>(loc, func, args); |
1052 | call.setCConv(func.getCConv()); |
1053 | call.setConvergentAttr(func.getConvergentAttr()); |
1054 | call.setNoUnwindAttr(func.getNoUnwindAttr()); |
1055 | call.setWillReturnAttr(func.getWillReturnAttr()); |
1056 | return call; |
1057 | } |
1058 | |
1059 | template <typename BarrierOpTy> |
1060 | class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> { |
1061 | public: |
1062 | using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor; |
1063 | |
1064 | using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion; |
1065 | |
1066 | static constexpr StringRef getFuncName(); |
1067 | |
1068 | LogicalResult |
1069 | matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, |
1070 | ConversionPatternRewriter &rewriter) const override { |
1071 | constexpr StringRef funcName = getFuncName(); |
1072 | Operation *symbolTable = |
1073 | controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>(); |
1074 | |
1075 | Type i32 = rewriter.getI32Type(); |
1076 | |
1077 | Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); |
1078 | LLVM::LLVMFuncOp func = |
1079 | lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); |
1080 | |
1081 | Location loc = controlBarrierOp->getLoc(); |
1082 | Value execution = rewriter.create<LLVM::ConstantOp>( |
1083 | loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); |
1084 | Value memory = rewriter.create<LLVM::ConstantOp>( |
1085 | loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); |
1086 | Value semantics = rewriter.create<LLVM::ConstantOp>( |
1087 | loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); |
1088 | |
1089 | auto call = createSPIRVBuiltinCall(loc, rewriter, func, |
1090 | {execution, memory, semantics}); |
1091 | |
1092 | rewriter.replaceOp(controlBarrierOp, call); |
1093 | return success(); |
1094 | } |
1095 | }; |
1096 | |
1097 | namespace { |
1098 | |
1099 | StringRef getTypeMangling(Type type, bool isSigned) { |
1100 | return llvm::TypeSwitch<Type, StringRef>(type) |
1101 | .Case<Float16Type>([](auto) { return "Dh"; }) |
1102 | .Case<Float32Type>([](auto) { return "f"; }) |
1103 | .Case<Float64Type>([](auto) { return "d"; }) |
1104 | .Case<IntegerType>([isSigned](IntegerType intTy) { |
1105 | switch (intTy.getWidth()) { |
1106 | case 1: |
1107 | return "b"; |
1108 | case 8: |
1109 | return (isSigned) ? "a": "c"; |
1110 | case 16: |
1111 | return (isSigned) ? "s": "t"; |
1112 | case 32: |
1113 | return (isSigned) ? "i": "j"; |
1114 | case 64: |
1115 | return (isSigned) ? "l": "m"; |
1116 | default: |
1117 | llvm_unreachable("Unsupported integer width"); |
1118 | } |
1119 | }) |
1120 | .Default([](auto) { |
1121 | llvm_unreachable("No mangling defined"); |
1122 | return ""; |
1123 | }); |
1124 | } |
1125 | |
1126 | template <typename ReduceOp> |
1127 | constexpr StringLiteral getGroupFuncName(); |
1128 | |
1129 | template <> |
1130 | constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() { |
1131 | return "_Z17__spirv_GroupIAddii"; |
1132 | } |
1133 | template <> |
1134 | constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() { |
1135 | return "_Z17__spirv_GroupFAddii"; |
1136 | } |
1137 | template <> |
1138 | constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() { |
1139 | return "_Z17__spirv_GroupSMinii"; |
1140 | } |
1141 | template <> |
1142 | constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() { |
1143 | return "_Z17__spirv_GroupUMinii"; |
1144 | } |
1145 | template <> |
1146 | constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() { |
1147 | return "_Z17__spirv_GroupFMinii"; |
1148 | } |
1149 | template <> |
1150 | constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() { |
1151 | return "_Z17__spirv_GroupSMaxii"; |
1152 | } |
1153 | template <> |
1154 | constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() { |
1155 | return "_Z17__spirv_GroupUMaxii"; |
1156 | } |
1157 | template <> |
1158 | constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() { |
1159 | return "_Z17__spirv_GroupFMaxii"; |
1160 | } |
1161 | template <> |
1162 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() { |
1163 | return "_Z27__spirv_GroupNonUniformIAddii"; |
1164 | } |
1165 | template <> |
1166 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() { |
1167 | return "_Z27__spirv_GroupNonUniformFAddii"; |
1168 | } |
1169 | template <> |
1170 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() { |
1171 | return "_Z27__spirv_GroupNonUniformIMulii"; |
1172 | } |
1173 | template <> |
1174 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() { |
1175 | return "_Z27__spirv_GroupNonUniformFMulii"; |
1176 | } |
1177 | template <> |
1178 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() { |
1179 | return "_Z27__spirv_GroupNonUniformSMinii"; |
1180 | } |
1181 | template <> |
1182 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() { |
1183 | return "_Z27__spirv_GroupNonUniformUMinii"; |
1184 | } |
1185 | template <> |
1186 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() { |
1187 | return "_Z27__spirv_GroupNonUniformFMinii"; |
1188 | } |
1189 | template <> |
1190 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() { |
1191 | return "_Z27__spirv_GroupNonUniformSMaxii"; |
1192 | } |
1193 | template <> |
1194 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() { |
1195 | return "_Z27__spirv_GroupNonUniformUMaxii"; |
1196 | } |
1197 | template <> |
1198 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() { |
1199 | return "_Z27__spirv_GroupNonUniformFMaxii"; |
1200 | } |
1201 | template <> |
1202 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() { |
1203 | return "_Z33__spirv_GroupNonUniformBitwiseAndii"; |
1204 | } |
1205 | template <> |
1206 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() { |
1207 | return "_Z32__spirv_GroupNonUniformBitwiseOrii"; |
1208 | } |
1209 | template <> |
1210 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() { |
1211 | return "_Z33__spirv_GroupNonUniformBitwiseXorii"; |
1212 | } |
1213 | template <> |
1214 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() { |
1215 | return "_Z33__spirv_GroupNonUniformLogicalAndii"; |
1216 | } |
1217 | template <> |
1218 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() { |
1219 | return "_Z32__spirv_GroupNonUniformLogicalOrii"; |
1220 | } |
1221 | template <> |
1222 | constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() { |
1223 | return "_Z33__spirv_GroupNonUniformLogicalXorii"; |
1224 | } |
1225 | } // namespace |
1226 | |
1227 | template <typename ReduceOp, bool Signed = false, bool NonUniform = false> |
1228 | class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> { |
1229 | public: |
1230 | using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion; |
1231 | |
1232 | LogicalResult |
1233 | matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor, |
1234 | ConversionPatternRewriter &rewriter) const override { |
1235 | |
1236 | Type retTy = op.getResult().getType(); |
1237 | if (!retTy.isIntOrFloat()) { |
1238 | return failure(); |
1239 | } |
1240 | SmallString<36> funcName = getGroupFuncName<ReduceOp>(); |
1241 | funcName += getTypeMangling(type: retTy, isSigned: false); |
1242 | |
1243 | Type i32Ty = rewriter.getI32Type(); |
1244 | SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy}; |
1245 | if constexpr (NonUniform) { |
1246 | if (adaptor.getClusterSize()) { |
1247 | funcName += "j"; |
1248 | paramTypes.push_back(Elt: i32Ty); |
1249 | } |
1250 | } |
1251 | |
1252 | Operation *symbolTable = |
1253 | op->template getParentWithTrait<OpTrait::SymbolTable>(); |
1254 | |
1255 | LLVM::LLVMFuncOp func = |
1256 | lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); |
1257 | |
1258 | Location loc = op.getLoc(); |
1259 | Value scope = rewriter.create<LLVM::ConstantOp>( |
1260 | loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); |
1261 | Value groupOp = rewriter.create<LLVM::ConstantOp>( |
1262 | loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); |
1263 | SmallVector<Value> operands{scope, groupOp}; |
1264 | operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); |
1265 | |
1266 | auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands); |
1267 | rewriter.replaceOp(op, call); |
1268 | return success(); |
1269 | } |
1270 | }; |
1271 | |
1272 | template <> |
1273 | constexpr StringRef |
1274 | ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() { |
1275 | return "_Z22__spirv_ControlBarrieriii"; |
1276 | } |
1277 | |
1278 | template <> |
1279 | constexpr StringRef |
1280 | ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() { |
1281 | return "_Z33__spirv_ControlBarrierArriveINTELiii"; |
1282 | } |
1283 | |
1284 | template <> |
1285 | constexpr StringRef |
1286 | ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() { |
1287 | return "_Z31__spirv_ControlBarrierWaitINTELiii"; |
1288 | } |
1289 | |
1290 | /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection |
1291 | /// should be reachable for conversion to succeed. The structure of the loop in |
1292 | /// LLVM dialect will be the following: |
1293 | /// |
1294 | /// +------------------------------------+ |
1295 | /// | <code before spirv.mlir.loop> | |
1296 | /// | llvm.br ^header | |
1297 | /// +------------------------------------+ |
1298 | /// | |
1299 | /// +----------------+ | |
1300 | /// | | | |
1301 | /// | V V |
1302 | /// | +------------------------------------+ |
1303 | /// | | ^header: | |
1304 | /// | | <header code> | |
1305 | /// | | llvm.cond_br %cond, ^body, ^exit | |
1306 | /// | +------------------------------------+ |
1307 | /// | | |
1308 | /// | |----------------------+ |
1309 | /// | | | |
1310 | /// | V | |
1311 | /// | +------------------------------------+ | |
1312 | /// | | ^body: | | |
1313 | /// | | <body code> | | |
1314 | /// | | llvm.br ^continue | | |
1315 | /// | +------------------------------------+ | |
1316 | /// | | | |
1317 | /// | V | |
1318 | /// | +------------------------------------+ | |
1319 | /// | | ^continue: | | |
1320 | /// | | <continue code> | | |
1321 | /// | | llvm.br ^header | | |
1322 | /// | +------------------------------------+ | |
1323 | /// | | | |
1324 | /// +---------------+ +----------------------+ |
1325 | /// | |
1326 | /// V |
1327 | /// +------------------------------------+ |
1328 | /// | ^exit: | |
1329 | /// | llvm.br ^remaining | |
1330 | /// +------------------------------------+ |
1331 | /// | |
1332 | /// V |
1333 | /// +------------------------------------+ |
1334 | /// | ^remaining: | |
1335 | /// | <code after spirv.mlir.loop> | |
1336 | /// +------------------------------------+ |
1337 | /// |
1338 | class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { |
1339 | public: |
1340 | using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; |
1341 | |
1342 | LogicalResult |
1343 | matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, |
1344 | ConversionPatternRewriter &rewriter) const override { |
1345 | // There is no support of loop control at the moment. |
1346 | if (loopOp.getLoopControl() != spirv::LoopControl::None) |
1347 | return failure(); |
1348 | |
1349 | // `spirv.mlir.loop` with empty region is redundant and should be erased. |
1350 | if (loopOp.getBody().empty()) { |
1351 | rewriter.eraseOp(op: loopOp); |
1352 | return success(); |
1353 | } |
1354 | |
1355 | Location loc = loopOp.getLoc(); |
1356 | |
1357 | // Split the current block after `spirv.mlir.loop`. The remaining ops will |
1358 | // be used in `endBlock`. |
1359 | Block *currentBlock = rewriter.getBlock(); |
1360 | auto position = Block::iterator(loopOp); |
1361 | Block *endBlock = rewriter.splitBlock(block: currentBlock, before: position); |
1362 | |
1363 | // Remove entry block and create a branch in the current block going to the |
1364 | // header block. |
1365 | Block *entryBlock = loopOp.getEntryBlock(); |
1366 | assert(entryBlock->getOperations().size() == 1); |
1367 | auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); |
1368 | if (!brOp) |
1369 | return failure(); |
1370 | Block *headerBlock = loopOp.getHeaderBlock(); |
1371 | rewriter.setInsertionPointToEnd(currentBlock); |
1372 | rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); |
1373 | rewriter.eraseBlock(block: entryBlock); |
1374 | |
1375 | // Branch from merge block to end block. |
1376 | Block *mergeBlock = loopOp.getMergeBlock(); |
1377 | Operation *terminator = mergeBlock->getTerminator(); |
1378 | ValueRange terminatorOperands = terminator->getOperands(); |
1379 | rewriter.setInsertionPointToEnd(mergeBlock); |
1380 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); |
1381 | |
1382 | rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); |
1383 | rewriter.replaceOp(loopOp, endBlock->getArguments()); |
1384 | return success(); |
1385 | } |
1386 | }; |
1387 | |
1388 | /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header |
1389 | /// block. All blocks within selection should be reachable for conversion to |
1390 | /// succeed. |
1391 | class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { |
1392 | public: |
1393 | using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; |
1394 | |
1395 | LogicalResult |
1396 | matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, |
1397 | ConversionPatternRewriter &rewriter) const override { |
1398 | // There is no support for `Flatten` or `DontFlatten` selection control at |
1399 | // the moment. This are just compiler hints and can be performed during the |
1400 | // optimization passes. |
1401 | if (op.getSelectionControl() != spirv::SelectionControl::None) |
1402 | return failure(); |
1403 | |
1404 | // `spirv.mlir.selection` should have at least two blocks: one selection |
1405 | // header block and one merge block. If no blocks are present, or control |
1406 | // flow branches straight to merge block (two blocks are present), the op is |
1407 | // redundant and it is erased. |
1408 | if (op.getBody().getBlocks().size() <= 2) { |
1409 | rewriter.eraseOp(op: op); |
1410 | return success(); |
1411 | } |
1412 | |
1413 | Location loc = op.getLoc(); |
1414 | |
1415 | // Split the current block after `spirv.mlir.selection`. The remaining ops |
1416 | // will be used in `continueBlock`. |
1417 | auto *currentBlock = rewriter.getInsertionBlock(); |
1418 | rewriter.setInsertionPointAfter(op); |
1419 | auto position = rewriter.getInsertionPoint(); |
1420 | auto *continueBlock = rewriter.splitBlock(block: currentBlock, before: position); |
1421 | |
1422 | // Extract conditional branch information from the header block. By SPIR-V |
1423 | // dialect spec, it should contain `spirv.BranchConditional` or |
1424 | // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the |
1425 | // moment in the SPIR-V dialect. Remove this block when finished. |
1426 | auto *headerBlock = op.getHeaderBlock(); |
1427 | assert(headerBlock->getOperations().size() == 1); |
1428 | auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( |
1429 | headerBlock->getOperations().front()); |
1430 | if (!condBrOp) |
1431 | return failure(); |
1432 | rewriter.eraseBlock(block: headerBlock); |
1433 | |
1434 | // Branch from merge block to continue block. |
1435 | auto *mergeBlock = op.getMergeBlock(); |
1436 | Operation *terminator = mergeBlock->getTerminator(); |
1437 | ValueRange terminatorOperands = terminator->getOperands(); |
1438 | rewriter.setInsertionPointToEnd(mergeBlock); |
1439 | rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); |
1440 | |
1441 | // Link current block to `true` and `false` blocks within the selection. |
1442 | Block *trueBlock = condBrOp.getTrueBlock(); |
1443 | Block *falseBlock = condBrOp.getFalseBlock(); |
1444 | rewriter.setInsertionPointToEnd(currentBlock); |
1445 | rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, |
1446 | condBrOp.getTrueTargetOperands(), |
1447 | falseBlock, |
1448 | condBrOp.getFalseTargetOperands()); |
1449 | |
1450 | rewriter.inlineRegionBefore(op.getBody(), continueBlock); |
1451 | rewriter.replaceOp(op, continueBlock->getArguments()); |
1452 | return success(); |
1453 | } |
1454 | }; |
1455 | |
1456 | /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect |
1457 | /// puts a restriction on `Shift` and `Base` to have the same bit width, |
1458 | /// `Shift` is zero or sign extended to match this specification. Cases when |
1459 | /// `Shift` bit width > `Base` bit width are considered to be illegal. |
1460 | template <typename SPIRVOp, typename LLVMOp> |
1461 | class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { |
1462 | public: |
1463 | using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; |
1464 | |
1465 | LogicalResult |
1466 | matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, |
1467 | ConversionPatternRewriter &rewriter) const override { |
1468 | |
1469 | auto dstType = this->getTypeConverter()->convertType(op.getType()); |
1470 | if (!dstType) |
1471 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
1472 | |
1473 | Type op1Type = op.getOperand1().getType(); |
1474 | Type op2Type = op.getOperand2().getType(); |
1475 | |
1476 | if (op1Type == op2Type) { |
1477 | rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType, |
1478 | adaptor.getOperands()); |
1479 | return success(); |
1480 | } |
1481 | |
1482 | std::optional<uint64_t> dstTypeWidth = |
1483 | getIntegerOrVectorElementWidth(dstType); |
1484 | std::optional<uint64_t> op2TypeWidth = |
1485 | getIntegerOrVectorElementWidth(type: op2Type); |
1486 | |
1487 | if (!dstTypeWidth || !op2TypeWidth) |
1488 | return failure(); |
1489 | |
1490 | Location loc = op.getLoc(); |
1491 | Value extended; |
1492 | if (op2TypeWidth < dstTypeWidth) { |
1493 | if (isUnsignedIntegerOrVector(type: op2Type)) { |
1494 | extended = rewriter.template create<LLVM::ZExtOp>( |
1495 | loc, dstType, adaptor.getOperand2()); |
1496 | } else { |
1497 | extended = rewriter.template create<LLVM::SExtOp>( |
1498 | loc, dstType, adaptor.getOperand2()); |
1499 | } |
1500 | } else if (op2TypeWidth == dstTypeWidth) { |
1501 | extended = adaptor.getOperand2(); |
1502 | } else { |
1503 | return failure(); |
1504 | } |
1505 | |
1506 | Value result = rewriter.template create<LLVMOp>( |
1507 | loc, dstType, adaptor.getOperand1(), extended); |
1508 | rewriter.replaceOp(op, result); |
1509 | return success(); |
1510 | } |
1511 | }; |
1512 | |
1513 | class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { |
1514 | public: |
1515 | using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion; |
1516 | |
1517 | LogicalResult |
1518 | matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor, |
1519 | ConversionPatternRewriter &rewriter) const override { |
1520 | auto dstType = getTypeConverter()->convertType(tanOp.getType()); |
1521 | if (!dstType) |
1522 | return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); |
1523 | |
1524 | Location loc = tanOp.getLoc(); |
1525 | Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); |
1526 | Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); |
1527 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); |
1528 | return success(); |
1529 | } |
1530 | }; |
1531 | |
1532 | /// Convert `spirv.Tanh` to |
1533 | /// |
1534 | /// exp(2x) - 1 |
1535 | /// ----------- |
1536 | /// exp(2x) + 1 |
1537 | /// |
1538 | class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { |
1539 | public: |
1540 | using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; |
1541 | |
1542 | LogicalResult |
1543 | matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor, |
1544 | ConversionPatternRewriter &rewriter) const override { |
1545 | auto srcType = tanhOp.getType(); |
1546 | auto dstType = getTypeConverter()->convertType(srcType); |
1547 | if (!dstType) |
1548 | return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); |
1549 | |
1550 | Location loc = tanhOp.getLoc(); |
1551 | Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); |
1552 | Value multiplied = |
1553 | rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); |
1554 | Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); |
1555 | Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); |
1556 | Value numerator = |
1557 | rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); |
1558 | Value denominator = |
1559 | rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); |
1560 | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, |
1561 | denominator); |
1562 | return success(); |
1563 | } |
1564 | }; |
1565 | |
1566 | class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { |
1567 | public: |
1568 | using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; |
1569 | |
1570 | LogicalResult |
1571 | matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, |
1572 | ConversionPatternRewriter &rewriter) const override { |
1573 | auto srcType = varOp.getType(); |
1574 | // Initialization is supported for scalars and vectors only. |
1575 | auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType(); |
1576 | auto init = varOp.getInitializer(); |
1577 | if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo)) |
1578 | return failure(); |
1579 | |
1580 | auto dstType = getTypeConverter()->convertType(srcType); |
1581 | if (!dstType) |
1582 | return rewriter.notifyMatchFailure(varOp, "type conversion failed"); |
1583 | |
1584 | Location loc = varOp.getLoc(); |
1585 | Value size = createI32ConstantOf(loc, rewriter, value: 1); |
1586 | if (!init) { |
1587 | auto elementType = getTypeConverter()->convertType(pointerTo); |
1588 | if (!elementType) |
1589 | return rewriter.notifyMatchFailure(varOp, "type conversion failed"); |
1590 | rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType, |
1591 | size); |
1592 | return success(); |
1593 | } |
1594 | auto elementType = getTypeConverter()->convertType(pointerTo); |
1595 | if (!elementType) |
1596 | return rewriter.notifyMatchFailure(varOp, "type conversion failed"); |
1597 | Value allocated = |
1598 | rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); |
1599 | rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); |
1600 | rewriter.replaceOp(varOp, allocated); |
1601 | return success(); |
1602 | } |
1603 | }; |
1604 | |
1605 | //===----------------------------------------------------------------------===// |
1606 | // BitcastOp conversion |
1607 | //===----------------------------------------------------------------------===// |
1608 | |
1609 | class BitcastConversionPattern |
1610 | : public SPIRVToLLVMConversion<spirv::BitcastOp> { |
1611 | public: |
1612 | using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion; |
1613 | |
1614 | LogicalResult |
1615 | matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor, |
1616 | ConversionPatternRewriter &rewriter) const override { |
1617 | auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); |
1618 | if (!dstType) |
1619 | return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed"); |
1620 | |
1621 | // LLVM's opaque pointers do not require bitcasts. |
1622 | if (isa<LLVM::LLVMPointerType>(dstType)) { |
1623 | rewriter.replaceOp(bitcastOp, adaptor.getOperand()); |
1624 | return success(); |
1625 | } |
1626 | |
1627 | rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
1628 | bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs()); |
1629 | return success(); |
1630 | } |
1631 | }; |
1632 | |
1633 | //===----------------------------------------------------------------------===// |
1634 | // FuncOp conversion |
1635 | //===----------------------------------------------------------------------===// |
1636 | |
1637 | class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { |
1638 | public: |
1639 | using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; |
1640 | |
1641 | LogicalResult |
1642 | matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, |
1643 | ConversionPatternRewriter &rewriter) const override { |
1644 | |
1645 | // Convert function signature. At the moment LLVMType converter is enough |
1646 | // for currently supported types. |
1647 | auto funcType = funcOp.getFunctionType(); |
1648 | TypeConverter::SignatureConversion signatureConverter( |
1649 | funcType.getNumInputs()); |
1650 | auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter()) |
1651 | ->convertFunctionSignature( |
1652 | funcType, /*isVariadic=*/false, |
1653 | /*useBarePtrCallConv=*/false, signatureConverter); |
1654 | if (!llvmType) |
1655 | return failure(); |
1656 | |
1657 | // Create a new `LLVMFuncOp` |
1658 | Location loc = funcOp.getLoc(); |
1659 | StringRef name = funcOp.getName(); |
1660 | auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); |
1661 | |
1662 | // Convert SPIR-V Function Control to equivalent LLVM function attribute |
1663 | MLIRContext *context = funcOp.getContext(); |
1664 | switch (funcOp.getFunctionControl()) { |
1665 | case spirv::FunctionControl::Inline: |
1666 | newFuncOp.setAlwaysInline(true); |
1667 | break; |
1668 | case spirv::FunctionControl::DontInline: |
1669 | newFuncOp.setNoInline(true); |
1670 | break; |
1671 | |
1672 | #define DISPATCH(functionControl, llvmAttr) \ |
1673 | case functionControl: \ |
1674 | newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ |
1675 | break; |
1676 | |
1677 | DISPATCH(spirv::FunctionControl::Pure, |
1678 | StringAttr::get(context, "readonly")); |
1679 | DISPATCH(spirv::FunctionControl::Const, |
1680 | StringAttr::get(context, "readnone")); |
1681 | |
1682 | #undef DISPATCH |
1683 | |
1684 | // Default: if `spirv::FunctionControl::None`, then no attributes are |
1685 | // needed. |
1686 | default: |
1687 | break; |
1688 | } |
1689 | |
1690 | rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), |
1691 | newFuncOp.end()); |
1692 | if (failed(rewriter.convertRegionTypes( |
1693 | region: &newFuncOp.getBody(), converter: *getTypeConverter(), entryConversion: &signatureConverter))) { |
1694 | return failure(); |
1695 | } |
1696 | rewriter.eraseOp(op: funcOp); |
1697 | return success(); |
1698 | } |
1699 | }; |
1700 | |
1701 | //===----------------------------------------------------------------------===// |
1702 | // ModuleOp conversion |
1703 | //===----------------------------------------------------------------------===// |
1704 | |
1705 | class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { |
1706 | public: |
1707 | using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; |
1708 | |
1709 | LogicalResult |
1710 | matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, |
1711 | ConversionPatternRewriter &rewriter) const override { |
1712 | |
1713 | auto newModuleOp = |
1714 | rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); |
1715 | rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); |
1716 | |
1717 | // Remove the terminator block that was automatically added by builder |
1718 | rewriter.eraseBlock(block: &newModuleOp.getBodyRegion().back()); |
1719 | rewriter.eraseOp(op: spvModuleOp); |
1720 | return success(); |
1721 | } |
1722 | }; |
1723 | |
1724 | //===----------------------------------------------------------------------===// |
1725 | // VectorShuffleOp conversion |
1726 | //===----------------------------------------------------------------------===// |
1727 | |
1728 | class VectorShufflePattern |
1729 | : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { |
1730 | public: |
1731 | using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; |
1732 | LogicalResult |
1733 | matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, |
1734 | ConversionPatternRewriter &rewriter) const override { |
1735 | Location loc = op.getLoc(); |
1736 | auto components = adaptor.getComponents(); |
1737 | auto vector1 = adaptor.getVector1(); |
1738 | auto vector2 = adaptor.getVector2(); |
1739 | int vector1Size = cast<VectorType>(vector1.getType()).getNumElements(); |
1740 | int vector2Size = cast<VectorType>(vector2.getType()).getNumElements(); |
1741 | if (vector1Size == vector2Size) { |
1742 | rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( |
1743 | op, vector1, vector2, |
1744 | LLVM::convertArrayToIndices<int32_t>(components)); |
1745 | return success(); |
1746 | } |
1747 | |
1748 | auto dstType = getTypeConverter()->convertType(op.getType()); |
1749 | if (!dstType) |
1750 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
1751 | auto scalarType = cast<VectorType>(dstType).getElementType(); |
1752 | auto componentsArray = components.getValue(); |
1753 | auto *context = rewriter.getContext(); |
1754 | auto llvmI32Type = IntegerType::get(context, 32); |
1755 | Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType); |
1756 | for (unsigned i = 0; i < componentsArray.size(); i++) { |
1757 | if (!isa<IntegerAttr>(componentsArray[i])) |
1758 | return op.emitError("unable to support non-constant component"); |
1759 | |
1760 | int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt(); |
1761 | if (indexVal == -1) |
1762 | continue; |
1763 | |
1764 | int offsetVal = 0; |
1765 | Value baseVector = vector1; |
1766 | if (indexVal >= vector1Size) { |
1767 | offsetVal = vector1Size; |
1768 | baseVector = vector2; |
1769 | } |
1770 | |
1771 | Value dstIndex = rewriter.create<LLVM::ConstantOp>( |
1772 | loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); |
1773 | Value index = rewriter.create<LLVM::ConstantOp>( |
1774 | loc, llvmI32Type, |
1775 | rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); |
1776 | |
1777 | auto extractOp = rewriter.create<LLVM::ExtractElementOp>( |
1778 | loc, scalarType, baseVector, index); |
1779 | targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, |
1780 | extractOp, dstIndex); |
1781 | } |
1782 | rewriter.replaceOp(op, targetOp); |
1783 | return success(); |
1784 | } |
1785 | }; |
1786 | } // namespace |
1787 | |
1788 | //===----------------------------------------------------------------------===// |
1789 | // Pattern population |
1790 | //===----------------------------------------------------------------------===// |
1791 | |
1792 | void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, |
1793 | spirv::ClientAPI clientAPI) { |
1794 | typeConverter.addConversion(callback: [&](spirv::ArrayType type) { |
1795 | return convertArrayType(type, converter&: typeConverter); |
1796 | }); |
1797 | typeConverter.addConversion(callback: [&, clientAPI](spirv::PointerType type) { |
1798 | return convertPointerType(type, typeConverter, clientAPI); |
1799 | }); |
1800 | typeConverter.addConversion(callback: [&](spirv::RuntimeArrayType type) { |
1801 | return convertRuntimeArrayType(type, converter&: typeConverter); |
1802 | }); |
1803 | typeConverter.addConversion(callback: [&](spirv::StructType type) { |
1804 | return convertStructType(type, converter: typeConverter); |
1805 | }); |
1806 | } |
1807 | |
1808 | void mlir::populateSPIRVToLLVMConversionPatterns( |
1809 | const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, |
1810 | spirv::ClientAPI clientAPI) { |
1811 | patterns.add< |
1812 | // Arithmetic ops |
1813 | DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, |
1814 | DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, |
1815 | DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, |
1816 | DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, |
1817 | DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, |
1818 | DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, |
1819 | DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, |
1820 | DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, |
1821 | DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, |
1822 | DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, |
1823 | DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, |
1824 | DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, |
1825 | DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, |
1826 | |
1827 | // Bitwise ops |
1828 | BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, |
1829 | DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, |
1830 | DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, |
1831 | DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, |
1832 | DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, |
1833 | DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, |
1834 | NotPattern<spirv::NotOp>, |
1835 | |
1836 | // Cast ops |
1837 | BitcastConversionPattern, |
1838 | DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, |
1839 | DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, |
1840 | DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, |
1841 | DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, |
1842 | IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, |
1843 | IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, |
1844 | IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, |
1845 | |
1846 | // Comparison ops |
1847 | IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, |
1848 | IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, |
1849 | FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, |
1850 | FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, |
1851 | FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, |
1852 | FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, |
1853 | FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, |
1854 | FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, |
1855 | FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, |
1856 | FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, |
1857 | FComparePattern<spirv::FUnordGreaterThanEqualOp, |
1858 | LLVM::FCmpPredicate::uge>, |
1859 | FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, |
1860 | FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, |
1861 | FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, |
1862 | IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, |
1863 | IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, |
1864 | IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, |
1865 | IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, |
1866 | IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, |
1867 | IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, |
1868 | IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, |
1869 | IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, |
1870 | |
1871 | // Constant op |
1872 | ConstantScalarAndVectorPattern, |
1873 | |
1874 | // Control Flow ops |
1875 | BranchConversionPattern, BranchConditionalConversionPattern, |
1876 | FunctionCallPattern, LoopPattern, SelectionPattern, |
1877 | ErasePattern<spirv::MergeOp>, |
1878 | |
1879 | // Entry points and execution mode are handled separately. |
1880 | ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, |
1881 | |
1882 | // GLSL extended instruction set ops |
1883 | DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>, |
1884 | DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>, |
1885 | DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>, |
1886 | DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>, |
1887 | DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>, |
1888 | DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>, |
1889 | DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>, |
1890 | DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>, |
1891 | DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>, |
1892 | DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>, |
1893 | DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>, |
1894 | DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>, |
1895 | InverseSqrtPattern, TanPattern, TanhPattern, |
1896 | |
1897 | // Logical ops |
1898 | DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, |
1899 | DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, |
1900 | IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, |
1901 | IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, |
1902 | NotPattern<spirv::LogicalNotOp>, |
1903 | |
1904 | // Memory ops |
1905 | AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>, |
1906 | LoadStorePattern<spirv::StoreOp>, VariablePattern, |
1907 | |
1908 | // Miscellaneous ops |
1909 | CompositeExtractPattern, CompositeInsertPattern, |
1910 | DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, |
1911 | DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, |
1912 | VectorShufflePattern, |
1913 | |
1914 | // Shift ops |
1915 | ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, |
1916 | ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, |
1917 | ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, |
1918 | |
1919 | // Return ops |
1920 | ReturnPattern, ReturnValuePattern, |
1921 | |
1922 | // Barrier ops |
1923 | ControlBarrierPattern<spirv::ControlBarrierOp>, |
1924 | ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>, |
1925 | ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>, |
1926 | |
1927 | // Group reduction operations |
1928 | GroupReducePattern<spirv::GroupIAddOp>, |
1929 | GroupReducePattern<spirv::GroupFAddOp>, |
1930 | GroupReducePattern<spirv::GroupFMinOp>, |
1931 | GroupReducePattern<spirv::GroupUMinOp>, |
1932 | GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>, |
1933 | GroupReducePattern<spirv::GroupFMaxOp>, |
1934 | GroupReducePattern<spirv::GroupUMaxOp>, |
1935 | GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>, |
1936 | GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false, |
1937 | /*NonUniform=*/true>, |
1938 | GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false, |
1939 | /*NonUniform=*/true>, |
1940 | GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false, |
1941 | /*NonUniform=*/true>, |
1942 | GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false, |
1943 | /*NonUniform=*/true>, |
1944 | GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true, |
1945 | /*NonUniform=*/true>, |
1946 | GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false, |
1947 | /*NonUniform=*/true>, |
1948 | GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false, |
1949 | /*NonUniform=*/true>, |
1950 | GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true, |
1951 | /*NonUniform=*/true>, |
1952 | GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false, |
1953 | /*NonUniform=*/true>, |
1954 | GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false, |
1955 | /*NonUniform=*/true>, |
1956 | GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false, |
1957 | /*NonUniform=*/true>, |
1958 | GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false, |
1959 | /*NonUniform=*/true>, |
1960 | GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false, |
1961 | /*NonUniform=*/true>, |
1962 | GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false, |
1963 | /*NonUniform=*/true>, |
1964 | GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false, |
1965 | /*NonUniform=*/true>, |
1966 | GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false, |
1967 | /*NonUniform=*/true>>(patterns.getContext(), |
1968 | typeConverter); |
1969 | |
1970 | patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(), |
1971 | typeConverter); |
1972 | } |
1973 | |
1974 | void mlir::populateSPIRVToLLVMFunctionConversionPatterns( |
1975 | const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
1976 | patterns.add<FuncConversionPattern>(arg: patterns.getContext(), args: typeConverter); |
1977 | } |
1978 | |
1979 | void mlir::populateSPIRVToLLVMModuleConversionPatterns( |
1980 | const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
1981 | patterns.add<ModuleConversionPattern>(arg: patterns.getContext(), args: typeConverter); |
1982 | } |
1983 | |
1984 | //===----------------------------------------------------------------------===// |
1985 | // Pre-conversion hooks |
1986 | //===----------------------------------------------------------------------===// |
1987 | |
1988 | /// Hook for descriptor set and binding number encoding. |
1989 | static constexpr StringRef kBinding = "binding"; |
1990 | static constexpr StringRef kDescriptorSet = "descriptor_set"; |
1991 | void mlir::encodeBindAttribute(ModuleOp module) { |
1992 | auto spvModules = module.getOps<spirv::ModuleOp>(); |
1993 | for (auto spvModule : spvModules) { |
1994 | spvModule.walk([&](spirv::GlobalVariableOp op) { |
1995 | IntegerAttr descriptorSet = |
1996 | op->getAttrOfType<IntegerAttr>(kDescriptorSet); |
1997 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); |
1998 | // For every global variable in the module, get the ones with descriptor |
1999 | // set and binding numbers. |
2000 | if (descriptorSet && binding) { |
2001 | // Encode these numbers into the variable's symbolic name. If the |
2002 | // SPIR-V module has a name, add it at the beginning. |
2003 | auto moduleAndName = |
2004 | spvModule.getName().has_value() |
2005 | ? spvModule.getName()->str() + "_"+ op.getSymName().str() |
2006 | : op.getSymName().str(); |
2007 | std::string name = |
2008 | llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, |
2009 | std::to_string(descriptorSet.getInt()), |
2010 | std::to_string(binding.getInt())); |
2011 | auto nameAttr = StringAttr::get(op->getContext(), name); |
2012 | |
2013 | // Replace all symbol uses and set the new symbol name. Finally, remove |
2014 | // descriptor set and binding attributes. |
2015 | if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) |
2016 | op.emitError("unable to replace all symbol uses for ") << name; |
2017 | SymbolTable::setSymbolName(op, nameAttr); |
2018 | op->removeAttr(kDescriptorSet); |
2019 | op->removeAttr(kBinding); |
2020 | } |
2021 | }); |
2022 | } |
2023 | } |
2024 |
Definitions
- isSignedIntegerOrVector
- isUnsignedIntegerOrVector
- getIntegerOrVectorElementWidth
- getBitWidth
- getLLVMTypeBitWidth
- minusOneIntegerAttribute
- createConstantAllBitsSet
- createFPConstant
- optionallyTruncateOrExtend
- broadcast
- optionallyBroadcast
- processCountOrOffset
- convertStructTypeWithOffset
- convertStructTypePacked
- createI32ConstantOf
- replaceWithLoadOrStore
- convertArrayType
- convertPointerType
- convertRuntimeArrayType
- convertStructType
- AccessChainPattern
- matchAndRewrite
- AddressOfPattern
- matchAndRewrite
- BitFieldInsertPattern
- matchAndRewrite
- ConstantScalarAndVectorPattern
- matchAndRewrite
- BitFieldSExtractPattern
- matchAndRewrite
- BitFieldUExtractPattern
- matchAndRewrite
- BranchConversionPattern
- matchAndRewrite
- BranchConditionalConversionPattern
- matchAndRewrite
- CompositeExtractPattern
- matchAndRewrite
- CompositeInsertPattern
- matchAndRewrite
- DirectConversionPattern
- matchAndRewrite
- ExecutionModePattern
- matchAndRewrite
- GlobalVariablePattern
- GlobalVariablePattern
- matchAndRewrite
- IndirectCastPattern
- matchAndRewrite
- FunctionCallPattern
- matchAndRewrite
- FComparePattern
- matchAndRewrite
- IComparePattern
- matchAndRewrite
- InverseSqrtPattern
- matchAndRewrite
- LoadStorePattern
- matchAndRewrite
- NotPattern
- matchAndRewrite
- ErasePattern
- matchAndRewrite
- ReturnPattern
- matchAndRewrite
- ReturnValuePattern
- matchAndRewrite
- lookupOrCreateSPIRVFn
- createSPIRVBuiltinCall
- ControlBarrierPattern
- matchAndRewrite
- getTypeMangling
- GroupReducePattern
- matchAndRewrite
- LoopPattern
- matchAndRewrite
- SelectionPattern
- matchAndRewrite
- ShiftPattern
- matchAndRewrite
- TanPattern
- matchAndRewrite
- TanhPattern
- matchAndRewrite
- VariablePattern
- matchAndRewrite
- BitcastConversionPattern
- matchAndRewrite
- FuncConversionPattern
- matchAndRewrite
- ModuleConversionPattern
- matchAndRewrite
- VectorShufflePattern
- matchAndRewrite
- populateSPIRVToLLVMTypeConversion
- populateSPIRVToLLVMConversionPatterns
- populateSPIRVToLLVMFunctionConversionPatterns
- populateSPIRVToLLVMModuleConversionPatterns
- kBinding
- kDescriptorSet
Improve your Profiling and Debugging skills
Find out more