1 | //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the cast and conversion operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | |
15 | #include "SPIRVOpUtils.h" |
16 | #include "SPIRVParsingUtils.h" |
17 | |
18 | #include "llvm/ADT/TypeSwitch.h" |
19 | |
20 | using namespace mlir::spirv::AttrNames; |
21 | |
22 | namespace mlir::spirv { |
23 | |
24 | static LogicalResult verifyCastOp(Operation *op, |
25 | bool requireSameBitWidth = true, |
26 | bool skipBitWidthCheck = false) { |
27 | // Some CastOps have no limit on bit widths for result and operand type. |
28 | if (skipBitWidthCheck) |
29 | return success(); |
30 | |
31 | Type operandType = op->getOperand(idx: 0).getType(); |
32 | Type resultType = op->getResult(idx: 0).getType(); |
33 | |
34 | // ODS checks that result type and operand type have the same shape. Check |
35 | // that composite types match and extract the element types, if any. |
36 | using TypePair = std::pair<Type, Type>; |
37 | auto [operandElemTy, resultElemTy] = |
38 | TypeSwitch<Type, TypePair>(operandType) |
39 | .Case<VectorType, spirv::CooperativeMatrixType, |
40 | spirv::JointMatrixINTELType>( |
41 | caseFn: [resultType](auto concreteOperandTy) -> TypePair { |
42 | if (auto concreteResultTy = |
43 | dyn_cast<decltype(concreteOperandTy)>(resultType)) { |
44 | return {concreteOperandTy.getElementType(), |
45 | concreteResultTy.getElementType()}; |
46 | } |
47 | return {}; |
48 | }) |
49 | .Default(defaultFn: [resultType](Type operandType) -> TypePair { |
50 | return {operandType, resultType}; |
51 | }); |
52 | |
53 | if (!operandElemTy || !resultElemTy) |
54 | return op->emitOpError(message: "incompatible operand and result types" ); |
55 | |
56 | unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); |
57 | unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); |
58 | bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; |
59 | |
60 | if (requireSameBitWidth) { |
61 | if (!isSameBitWidth) { |
62 | return op->emitOpError( |
63 | message: "expected the same bit widths for operand type and result " |
64 | "type, but provided " ) |
65 | << operandElemTy << " and " << resultElemTy; |
66 | } |
67 | return success(); |
68 | } |
69 | |
70 | if (isSameBitWidth) { |
71 | return op->emitOpError( |
72 | message: "expected the different bit widths for operand type and result " |
73 | "type, but provided " ) |
74 | << operandElemTy << " and " << resultElemTy; |
75 | } |
76 | return success(); |
77 | } |
78 | |
79 | //===----------------------------------------------------------------------===// |
80 | // spirv.BitcastOp |
81 | //===----------------------------------------------------------------------===// |
82 | |
83 | LogicalResult BitcastOp::verify() { |
84 | // TODO: The SPIR-V spec validation rules are different for different |
85 | // versions. |
86 | auto operandType = getOperand().getType(); |
87 | auto resultType = getResult().getType(); |
88 | if (operandType == resultType) { |
89 | return emitError("result type must be different from operand type" ); |
90 | } |
91 | if (llvm::isa<spirv::PointerType>(operandType) && |
92 | !llvm::isa<spirv::PointerType>(resultType)) { |
93 | return emitError( |
94 | "unhandled bit cast conversion from pointer type to non-pointer type" ); |
95 | } |
96 | if (!llvm::isa<spirv::PointerType>(operandType) && |
97 | llvm::isa<spirv::PointerType>(resultType)) { |
98 | return emitError( |
99 | "unhandled bit cast conversion from non-pointer type to pointer type" ); |
100 | } |
101 | auto operandBitWidth = getBitWidth(operandType); |
102 | auto resultBitWidth = getBitWidth(resultType); |
103 | if (operandBitWidth != resultBitWidth) { |
104 | return emitOpError("mismatch in result type bitwidth " ) |
105 | << resultBitWidth << " and operand type bitwidth " |
106 | << operandBitWidth; |
107 | } |
108 | return success(); |
109 | } |
110 | |
111 | //===----------------------------------------------------------------------===// |
112 | // spirv.ConvertPtrToUOp |
113 | //===----------------------------------------------------------------------===// |
114 | |
115 | LogicalResult ConvertPtrToUOp::verify() { |
116 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
117 | auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType()); |
118 | if (!resultType || !resultType.isSignlessInteger()) |
119 | return emitError("result must be a scalar type of unsigned integer" ); |
120 | auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
121 | if (!spirvModule) |
122 | return success(); |
123 | auto addressingModel = spirvModule.getAddressingModel(); |
124 | if ((addressingModel == spirv::AddressingModel::Logical) || |
125 | (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
126 | operandType.getStorageClass() != |
127 | spirv::StorageClass::PhysicalStorageBuffer)) |
128 | return emitError("operand must be a physical pointer" ); |
129 | return success(); |
130 | } |
131 | |
132 | //===----------------------------------------------------------------------===// |
133 | // spirv.ConvertUToPtrOp |
134 | //===----------------------------------------------------------------------===// |
135 | |
136 | LogicalResult ConvertUToPtrOp::verify() { |
137 | auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType()); |
138 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
139 | if (!operandType || !operandType.isSignlessInteger()) |
140 | return emitError("result must be a scalar type of unsigned integer" ); |
141 | auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); |
142 | if (!spirvModule) |
143 | return success(); |
144 | auto addressingModel = spirvModule.getAddressingModel(); |
145 | if ((addressingModel == spirv::AddressingModel::Logical) || |
146 | (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && |
147 | resultType.getStorageClass() != |
148 | spirv::StorageClass::PhysicalStorageBuffer)) |
149 | return emitError("result must be a physical pointer" ); |
150 | return success(); |
151 | } |
152 | |
153 | //===----------------------------------------------------------------------===// |
154 | // spirv.PtrCastToGenericOp |
155 | //===----------------------------------------------------------------------===// |
156 | |
157 | LogicalResult PtrCastToGenericOp::verify() { |
158 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
159 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
160 | |
161 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
162 | if (operandStorage != spirv::StorageClass::Workgroup && |
163 | operandStorage != spirv::StorageClass::CrossWorkgroup && |
164 | operandStorage != spirv::StorageClass::Function) |
165 | return emitError("pointer must point to the Workgroup, CrossWorkgroup" |
166 | ", or Function Storage Class" ); |
167 | |
168 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
169 | if (resultStorage != spirv::StorageClass::Generic) |
170 | return emitError("result type must be of storage class Generic" ); |
171 | |
172 | Type operandPointeeType = operandType.getPointeeType(); |
173 | Type resultPointeeType = resultType.getPointeeType(); |
174 | if (operandPointeeType != resultPointeeType) |
175 | return emitOpError("pointer operand's pointee type must have the same " |
176 | "as the op result type, but found " ) |
177 | << operandPointeeType << " vs " << resultPointeeType; |
178 | return success(); |
179 | } |
180 | |
181 | //===----------------------------------------------------------------------===// |
182 | // spirv.GenericCastToPtrOp |
183 | //===----------------------------------------------------------------------===// |
184 | |
185 | LogicalResult GenericCastToPtrOp::verify() { |
186 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
187 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
188 | |
189 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
190 | if (operandStorage != spirv::StorageClass::Generic) |
191 | return emitError("pointer type must be of storage class Generic" ); |
192 | |
193 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
194 | if (resultStorage != spirv::StorageClass::Workgroup && |
195 | resultStorage != spirv::StorageClass::CrossWorkgroup && |
196 | resultStorage != spirv::StorageClass::Function) |
197 | return emitError("result must point to the Workgroup, CrossWorkgroup, " |
198 | "or Function Storage Class" ); |
199 | |
200 | Type operandPointeeType = operandType.getPointeeType(); |
201 | Type resultPointeeType = resultType.getPointeeType(); |
202 | if (operandPointeeType != resultPointeeType) |
203 | return emitOpError("pointer operand's pointee type must have the same " |
204 | "as the op result type, but found " ) |
205 | << operandPointeeType << " vs " << resultPointeeType; |
206 | return success(); |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // spirv.GenericCastToPtrExplicitOp |
211 | //===----------------------------------------------------------------------===// |
212 | |
213 | LogicalResult GenericCastToPtrExplicitOp::verify() { |
214 | auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); |
215 | auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); |
216 | |
217 | spirv::StorageClass operandStorage = operandType.getStorageClass(); |
218 | if (operandStorage != spirv::StorageClass::Generic) |
219 | return emitError("pointer type must be of storage class Generic" ); |
220 | |
221 | spirv::StorageClass resultStorage = resultType.getStorageClass(); |
222 | if (resultStorage != spirv::StorageClass::Workgroup && |
223 | resultStorage != spirv::StorageClass::CrossWorkgroup && |
224 | resultStorage != spirv::StorageClass::Function) |
225 | return emitError("result must point to the Workgroup, CrossWorkgroup, " |
226 | "or Function Storage Class" ); |
227 | |
228 | Type operandPointeeType = operandType.getPointeeType(); |
229 | Type resultPointeeType = resultType.getPointeeType(); |
230 | if (operandPointeeType != resultPointeeType) |
231 | return emitOpError("pointer operand's pointee type must have the same " |
232 | "as the op result type, but found " ) |
233 | << operandPointeeType << " vs " << resultPointeeType; |
234 | return success(); |
235 | } |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | // spirv.ConvertFToSOp |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | LogicalResult ConvertFToSOp::verify() { |
242 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
243 | /*skipBitWidthCheck=*/true); |
244 | } |
245 | |
246 | //===----------------------------------------------------------------------===// |
247 | // spirv.ConvertFToUOp |
248 | //===----------------------------------------------------------------------===// |
249 | |
250 | LogicalResult ConvertFToUOp::verify() { |
251 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
252 | /*skipBitWidthCheck=*/true); |
253 | } |
254 | |
255 | //===----------------------------------------------------------------------===// |
256 | // spirv.ConvertSToFOp |
257 | //===----------------------------------------------------------------------===// |
258 | |
259 | LogicalResult ConvertSToFOp::verify() { |
260 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
261 | /*skipBitWidthCheck=*/true); |
262 | } |
263 | |
264 | //===----------------------------------------------------------------------===// |
265 | // spirv.ConvertUToFOp |
266 | //===----------------------------------------------------------------------===// |
267 | |
268 | LogicalResult ConvertUToFOp::verify() { |
269 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, |
270 | /*skipBitWidthCheck=*/true); |
271 | } |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // spirv.INTELConvertBF16ToFOp |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | LogicalResult INTELConvertBF16ToFOp::verify() { |
278 | auto operandType = getOperand().getType(); |
279 | auto resultType = getResult().getType(); |
280 | // ODS checks that vector result type and vector operand type have the same |
281 | // shape. |
282 | if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
283 | unsigned operandNumElements = vectorType.getNumElements(); |
284 | unsigned resultNumElements = |
285 | llvm::cast<VectorType>(resultType).getNumElements(); |
286 | if (operandNumElements != resultNumElements) { |
287 | return emitOpError( |
288 | "operand and result must have same number of elements" ); |
289 | } |
290 | } |
291 | return success(); |
292 | } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // spirv.INTELConvertFToBF16Op |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | LogicalResult INTELConvertFToBF16Op::verify() { |
299 | auto operandType = getOperand().getType(); |
300 | auto resultType = getResult().getType(); |
301 | // ODS checks that vector result type and vector operand type have the same |
302 | // shape. |
303 | if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { |
304 | unsigned operandNumElements = vectorType.getNumElements(); |
305 | unsigned resultNumElements = |
306 | llvm::cast<VectorType>(resultType).getNumElements(); |
307 | if (operandNumElements != resultNumElements) { |
308 | return emitOpError( |
309 | "operand and result must have same number of elements" ); |
310 | } |
311 | } |
312 | return success(); |
313 | } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // spirv.FConvertOp |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | LogicalResult spirv::FConvertOp::verify() { |
320 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
321 | } |
322 | |
323 | //===----------------------------------------------------------------------===// |
324 | // spirv.SConvertOp |
325 | //===----------------------------------------------------------------------===// |
326 | |
327 | LogicalResult spirv::SConvertOp::verify() { |
328 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
329 | } |
330 | |
331 | //===----------------------------------------------------------------------===// |
332 | // spirv.UConvertOp |
333 | //===----------------------------------------------------------------------===// |
334 | |
335 | LogicalResult spirv::UConvertOp::verify() { |
336 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); |
337 | } |
338 | |
339 | } // namespace mlir::spirv |
340 | |