1 | //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the cast and conversion operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | |
15 | #include "SPIRVOpUtils.h" |
16 | #include "SPIRVParsingUtils.h" |
17 | |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | |
20 | using namespace mlir::spirv::AttrNames; |
21 | |
22 | namespace mlir::spirv { |
23 | |
24 | static LogicalResult verifyCastOp(Operation *op, |
25 | bool requireSameBitWidth = true, |
26 | bool skipBitWidthCheck = false) { |
27 | // Some CastOps have no limit on bit widths for result and operand type. |
28 | if (skipBitWidthCheck) |
29 | return success(); |
30 | |
31 | Type operandType = op->getOperand(idx: 0).getType(); |
32 | Type resultType = op->getResult(idx: 0).getType(); |
33 | |
34 | // ODS checks that result type and operand type have the same shape. Check |
35 | // that composite types match and extract the element types, if any. |
36 | using TypePair = std::pair<Type, Type>; |
37 | auto [operandElemTy, resultElemTy] = |
38 | TypeSwitch<Type, TypePair>(operandType) |
39 | .Case<VectorType, spirv::CooperativeMatrixType>( |
40 | caseFn: [resultType](auto concreteOperandTy) -> TypePair { |
41 | if (auto concreteResultTy = |
42 | dyn_cast<decltype(concreteOperandTy)>(resultType)) { |
43 | return {concreteOperandTy.getElementType(), |
44 | concreteResultTy.getElementType()}; |
45 | } |
46 | return {}; |
47 | }) |
48 | .Default(defaultFn: [resultType](Type operandType) -> TypePair { |
49 | return {operandType, resultType}; |
50 | }); |
51 | |
52 | if (!operandElemTy || !resultElemTy) |
53 | return op->emitOpError(message: "incompatible operand and result types" ); |
54 | |
55 | unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); |
56 | unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); |
57 | bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; |
58 | |
59 | if (requireSameBitWidth) { |
60 | if (!isSameBitWidth) { |
61 | return op->emitOpError( |
62 | message: "expected the same bit widths for operand type and result " |
63 | "type, but provided " ) |
64 | << operandElemTy << " and " << resultElemTy; |
65 | } |
66 | return success(); |
67 | } |
68 | |
69 | if (isSameBitWidth) { |
70 | return op->emitOpError( |
71 | message: "expected the different bit widths for operand type and result " |
72 | "type, but provided " ) |
73 | << operandElemTy << " and " << resultElemTy; |
74 | } |
75 | return success(); |
76 | } |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // spirv.BitcastOp |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | LogicalResult BitcastOp::verify() { |
83 | // TODO: The SPIR-V spec validation rules are different for different |
84 | // versions. |
85 | auto operandType = getOperand().getType(); |
86 | auto resultType = getResult().getType(); |
87 | if (operandType == resultType) { |
88 | return emitError("result type must be different from operand type" ); |
89 | } |
90 | if (llvm::isa<spirv::PointerType>(operandType) && |
91 | !llvm::isa<spirv::PointerType>(resultType)) { |
92 | return emitError( |
93 | "unhandled bit cast conversion from pointer type to non-pointer type" ); |
94 | } |
95 | if (!llvm::isa<spirv::PointerType>(operandType) && |
96 | llvm::isa<spirv::PointerType>(resultType)) { |
97 | return emitError( |
98 | "unhandled bit cast conversion from non-pointer type to pointer type" ); |
99 | } |
100 | auto operandBitWidth = getBitWidth(operandType); |
101 | auto resultBitWidth = getBitWidth(resultType); |
102 | if (operandBitWidth != resultBitWidth) { |
103 | return emitOpError("mismatch in result type bitwidth " ) |
104 | << resultBitWidth << " and operand type bitwidth " |
105 | << operandBitWidth; |
106 | } |
107 | return success(); |
108 | } |
109 | |
110 | //===----------------------------------------------------------------------===// |
111 | // spirv.ConvertPtrToUOp |
112 | //===----------------------------------------------------------------------===// |
113 | |
114 | LogicalResult ConvertPtrToUOp::verify() { |
115 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
116 | auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType()); |
117 | if (!resultType || !resultType.isSignlessInteger()) |
118 | return emitError("result must be a scalar type of unsigned integer" ); |
119 | auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
120 | if (!spirvModule) |
121 | return success(); |
122 | auto addressingModel = spirvModule.getAddressingModel(); |
123 | if ((addressingModel == spirv::AddressingModel::Logical) || |
124 | (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
125 | operandType.getStorageClass() != |
126 | spirv::StorageClass::PhysicalStorageBuffer)) |
127 | return emitError("operand must be a physical pointer" ); |
128 | return success(); |
129 | } |
130 | |
131 | //===----------------------------------------------------------------------===// |
132 | // spirv.ConvertUToPtrOp |
133 | //===----------------------------------------------------------------------===// |
134 | |
135 | LogicalResult ConvertUToPtrOp::verify() { |
136 | auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType()); |
137 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
138 | if (!operandType || !operandType.isSignlessInteger()) |
139 | return emitError("result must be a scalar type of unsigned integer" ); |
140 | auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
141 | if (!spirvModule) |
142 | return success(); |
143 | auto addressingModel = spirvModule.getAddressingModel(); |
144 | if ((addressingModel == spirv::AddressingModel::Logical) || |
145 | (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
146 | resultType.getStorageClass() != |
147 | spirv::StorageClass::PhysicalStorageBuffer)) |
148 | return emitError("result must be a physical pointer" ); |
149 | return success(); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // spirv.PtrCastToGenericOp |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | LogicalResult PtrCastToGenericOp::verify() { |
157 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
158 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
159 | |
160 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
161 | if (operandStorage != spirv::StorageClass::Workgroup && |
162 | operandStorage != spirv::StorageClass::CrossWorkgroup && |
163 | operandStorage != spirv::StorageClass::Function) |
164 | return emitError("pointer must point to the Workgroup, CrossWorkgroup" |
165 | ", or Function Storage Class" ); |
166 | |
167 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
168 | if (resultStorage != spirv::StorageClass::Generic) |
169 | return emitError("result type must be of storage class Generic" ); |
170 | |
171 | Type operandPointeeType = operandType.getPointeeType(); |
172 | Type resultPointeeType = resultType.getPointeeType(); |
173 | if (operandPointeeType != resultPointeeType) |
174 | return emitOpError("pointer operand's pointee type must have the same " |
175 | "as the op result type, but found " ) |
176 | << operandPointeeType << " vs " << resultPointeeType; |
177 | return success(); |
178 | } |
179 | |
180 | //===----------------------------------------------------------------------===// |
181 | // spirv.GenericCastToPtrOp |
182 | //===----------------------------------------------------------------------===// |
183 | |
184 | LogicalResult GenericCastToPtrOp::verify() { |
185 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
186 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
187 | |
188 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
189 | if (operandStorage != spirv::StorageClass::Generic) |
190 | return emitError("pointer type must be of storage class Generic" ); |
191 | |
192 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
193 | if (resultStorage != spirv::StorageClass::Workgroup && |
194 | resultStorage != spirv::StorageClass::CrossWorkgroup && |
195 | resultStorage != spirv::StorageClass::Function) |
196 | return emitError("result must point to the Workgroup, CrossWorkgroup, " |
197 | "or Function Storage Class" ); |
198 | |
199 | Type operandPointeeType = operandType.getPointeeType(); |
200 | Type resultPointeeType = resultType.getPointeeType(); |
201 | if (operandPointeeType != resultPointeeType) |
202 | return emitOpError("pointer operand's pointee type must have the same " |
203 | "as the op result type, but found " ) |
204 | << operandPointeeType << " vs " << resultPointeeType; |
205 | return success(); |
206 | } |
207 | |
208 | //===----------------------------------------------------------------------===// |
209 | // spirv.GenericCastToPtrExplicitOp |
210 | //===----------------------------------------------------------------------===// |
211 | |
212 | LogicalResult GenericCastToPtrExplicitOp::verify() { |
213 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
214 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
215 | |
216 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
217 | if (operandStorage != spirv::StorageClass::Generic) |
218 | return emitError("pointer type must be of storage class Generic" ); |
219 | |
220 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
221 | if (resultStorage != spirv::StorageClass::Workgroup && |
222 | resultStorage != spirv::StorageClass::CrossWorkgroup && |
223 | resultStorage != spirv::StorageClass::Function) |
224 | return emitError("result must point to the Workgroup, CrossWorkgroup, " |
225 | "or Function Storage Class" ); |
226 | |
227 | Type operandPointeeType = operandType.getPointeeType(); |
228 | Type resultPointeeType = resultType.getPointeeType(); |
229 | if (operandPointeeType != resultPointeeType) |
230 | return emitOpError("pointer operand's pointee type must have the same " |
231 | "as the op result type, but found " ) |
232 | << operandPointeeType << " vs " << resultPointeeType; |
233 | return success(); |
234 | } |
235 | |
236 | //===----------------------------------------------------------------------===// |
237 | // spirv.ConvertFToSOp |
238 | //===----------------------------------------------------------------------===// |
239 | |
240 | LogicalResult ConvertFToSOp::verify() { |
241 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
242 | /*skipBitWidthCheck=*/true); |
243 | } |
244 | |
245 | //===----------------------------------------------------------------------===// |
246 | // spirv.ConvertFToUOp |
247 | //===----------------------------------------------------------------------===// |
248 | |
249 | LogicalResult ConvertFToUOp::verify() { |
250 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
251 | /*skipBitWidthCheck=*/true); |
252 | } |
253 | |
254 | //===----------------------------------------------------------------------===// |
255 | // spirv.ConvertSToFOp |
256 | //===----------------------------------------------------------------------===// |
257 | |
258 | LogicalResult ConvertSToFOp::verify() { |
259 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
260 | /*skipBitWidthCheck=*/true); |
261 | } |
262 | |
263 | //===----------------------------------------------------------------------===// |
264 | // spirv.ConvertUToFOp |
265 | //===----------------------------------------------------------------------===// |
266 | |
267 | LogicalResult ConvertUToFOp::verify() { |
268 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
269 | /*skipBitWidthCheck=*/true); |
270 | } |
271 | |
272 | //===----------------------------------------------------------------------===// |
273 | // spirv.INTELConvertBF16ToFOp |
274 | //===----------------------------------------------------------------------===// |
275 | |
276 | LogicalResult INTELConvertBF16ToFOp::verify() { |
277 | auto operandType = getOperand().getType(); |
278 | auto resultType = getResult().getType(); |
279 | // ODS checks that vector result type and vector operand type have the same |
280 | // shape. |
281 | if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
282 | unsigned operandNumElements = vectorType.getNumElements(); |
283 | unsigned resultNumElements = |
284 | llvm::cast<VectorType>(resultType).getNumElements(); |
285 | if (operandNumElements != resultNumElements) { |
286 | return emitOpError( |
287 | "operand and result must have same number of elements" ); |
288 | } |
289 | } |
290 | return success(); |
291 | } |
292 | |
293 | //===----------------------------------------------------------------------===// |
294 | // spirv.INTELConvertFToBF16Op |
295 | //===----------------------------------------------------------------------===// |
296 | |
297 | LogicalResult INTELConvertFToBF16Op::verify() { |
298 | auto operandType = getOperand().getType(); |
299 | auto resultType = getResult().getType(); |
300 | // ODS checks that vector result type and vector operand type have the same |
301 | // shape. |
302 | if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
303 | unsigned operandNumElements = vectorType.getNumElements(); |
304 | unsigned resultNumElements = |
305 | llvm::cast<VectorType>(resultType).getNumElements(); |
306 | if (operandNumElements != resultNumElements) { |
307 | return emitOpError( |
308 | "operand and result must have same number of elements" ); |
309 | } |
310 | } |
311 | return success(); |
312 | } |
313 | |
314 | //===----------------------------------------------------------------------===// |
315 | // spirv.FConvertOp |
316 | //===----------------------------------------------------------------------===// |
317 | |
318 | LogicalResult spirv::FConvertOp::verify() { |
319 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
320 | } |
321 | |
322 | //===----------------------------------------------------------------------===// |
323 | // spirv.SConvertOp |
324 | //===----------------------------------------------------------------------===// |
325 | |
326 | LogicalResult spirv::SConvertOp::verify() { |
327 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
328 | } |
329 | |
330 | //===----------------------------------------------------------------------===// |
331 | // spirv.UConvertOp |
332 | //===----------------------------------------------------------------------===// |
333 | |
334 | LogicalResult spirv::UConvertOp::verify() { |
335 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
336 | } |
337 | |
338 | } // namespace mlir::spirv |
339 | |