1//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the cast and conversion operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24static LogicalResult verifyCastOp(Operation *op,
25 bool requireSameBitWidth = true,
26 bool skipBitWidthCheck = false) {
27 // Some CastOps have no limit on bit widths for result and operand type.
28 if (skipBitWidthCheck)
29 return success();
30
31 Type operandType = op->getOperand(idx: 0).getType();
32 Type resultType = op->getResult(idx: 0).getType();
33
34 // ODS checks that result type and operand type have the same shape. Check
35 // that composite types match and extract the element types, if any.
36 using TypePair = std::pair<Type, Type>;
37 auto [operandElemTy, resultElemTy] =
38 TypeSwitch<Type, TypePair>(operandType)
39 .Case<VectorType, spirv::CooperativeMatrixType,
40 spirv::JointMatrixINTELType>(
41 caseFn: [resultType](auto concreteOperandTy) -> TypePair {
42 if (auto concreteResultTy =
43 dyn_cast<decltype(concreteOperandTy)>(resultType)) {
44 return {concreteOperandTy.getElementType(),
45 concreteResultTy.getElementType()};
46 }
47 return {};
48 })
49 .Default(defaultFn: [resultType](Type operandType) -> TypePair {
50 return {operandType, resultType};
51 });
52
53 if (!operandElemTy || !resultElemTy)
54 return op->emitOpError(message: "incompatible operand and result types");
55
56 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
57 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
58 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
59
60 if (requireSameBitWidth) {
61 if (!isSameBitWidth) {
62 return op->emitOpError(
63 message: "expected the same bit widths for operand type and result "
64 "type, but provided ")
65 << operandElemTy << " and " << resultElemTy;
66 }
67 return success();
68 }
69
70 if (isSameBitWidth) {
71 return op->emitOpError(
72 message: "expected the different bit widths for operand type and result "
73 "type, but provided ")
74 << operandElemTy << " and " << resultElemTy;
75 }
76 return success();
77}
78
79//===----------------------------------------------------------------------===//
80// spirv.BitcastOp
81//===----------------------------------------------------------------------===//
82
83LogicalResult BitcastOp::verify() {
84 // TODO: The SPIR-V spec validation rules are different for different
85 // versions.
86 auto operandType = getOperand().getType();
87 auto resultType = getResult().getType();
88 if (operandType == resultType) {
89 return emitError("result type must be different from operand type");
90 }
91 if (llvm::isa<spirv::PointerType>(operandType) &&
92 !llvm::isa<spirv::PointerType>(resultType)) {
93 return emitError(
94 "unhandled bit cast conversion from pointer type to non-pointer type");
95 }
96 if (!llvm::isa<spirv::PointerType>(operandType) &&
97 llvm::isa<spirv::PointerType>(resultType)) {
98 return emitError(
99 "unhandled bit cast conversion from non-pointer type to pointer type");
100 }
101 auto operandBitWidth = getBitWidth(operandType);
102 auto resultBitWidth = getBitWidth(resultType);
103 if (operandBitWidth != resultBitWidth) {
104 return emitOpError("mismatch in result type bitwidth ")
105 << resultBitWidth << " and operand type bitwidth "
106 << operandBitWidth;
107 }
108 return success();
109}
110
111//===----------------------------------------------------------------------===//
112// spirv.ConvertPtrToUOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult ConvertPtrToUOp::verify() {
116 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
117 auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
118 if (!resultType || !resultType.isSignlessInteger())
119 return emitError("result must be a scalar type of unsigned integer");
120 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
121 if (!spirvModule)
122 return success();
123 auto addressingModel = spirvModule.getAddressingModel();
124 if ((addressingModel == spirv::AddressingModel::Logical) ||
125 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
126 operandType.getStorageClass() !=
127 spirv::StorageClass::PhysicalStorageBuffer))
128 return emitError("operand must be a physical pointer");
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// spirv.ConvertUToPtrOp
134//===----------------------------------------------------------------------===//
135
136LogicalResult ConvertUToPtrOp::verify() {
137 auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
138 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
139 if (!operandType || !operandType.isSignlessInteger())
140 return emitError("result must be a scalar type of unsigned integer");
141 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
142 if (!spirvModule)
143 return success();
144 auto addressingModel = spirvModule.getAddressingModel();
145 if ((addressingModel == spirv::AddressingModel::Logical) ||
146 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
147 resultType.getStorageClass() !=
148 spirv::StorageClass::PhysicalStorageBuffer))
149 return emitError("result must be a physical pointer");
150 return success();
151}
152
153//===----------------------------------------------------------------------===//
154// spirv.PtrCastToGenericOp
155//===----------------------------------------------------------------------===//
156
157LogicalResult PtrCastToGenericOp::verify() {
158 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
159 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
160
161 spirv::StorageClass operandStorage = operandType.getStorageClass();
162 if (operandStorage != spirv::StorageClass::Workgroup &&
163 operandStorage != spirv::StorageClass::CrossWorkgroup &&
164 operandStorage != spirv::StorageClass::Function)
165 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
166 ", or Function Storage Class");
167
168 spirv::StorageClass resultStorage = resultType.getStorageClass();
169 if (resultStorage != spirv::StorageClass::Generic)
170 return emitError("result type must be of storage class Generic");
171
172 Type operandPointeeType = operandType.getPointeeType();
173 Type resultPointeeType = resultType.getPointeeType();
174 if (operandPointeeType != resultPointeeType)
175 return emitOpError("pointer operand's pointee type must have the same "
176 "as the op result type, but found ")
177 << operandPointeeType << " vs " << resultPointeeType;
178 return success();
179}
180
181//===----------------------------------------------------------------------===//
182// spirv.GenericCastToPtrOp
183//===----------------------------------------------------------------------===//
184
185LogicalResult GenericCastToPtrOp::verify() {
186 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
187 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
188
189 spirv::StorageClass operandStorage = operandType.getStorageClass();
190 if (operandStorage != spirv::StorageClass::Generic)
191 return emitError("pointer type must be of storage class Generic");
192
193 spirv::StorageClass resultStorage = resultType.getStorageClass();
194 if (resultStorage != spirv::StorageClass::Workgroup &&
195 resultStorage != spirv::StorageClass::CrossWorkgroup &&
196 resultStorage != spirv::StorageClass::Function)
197 return emitError("result must point to the Workgroup, CrossWorkgroup, "
198 "or Function Storage Class");
199
200 Type operandPointeeType = operandType.getPointeeType();
201 Type resultPointeeType = resultType.getPointeeType();
202 if (operandPointeeType != resultPointeeType)
203 return emitOpError("pointer operand's pointee type must have the same "
204 "as the op result type, but found ")
205 << operandPointeeType << " vs " << resultPointeeType;
206 return success();
207}
208
209//===----------------------------------------------------------------------===//
210// spirv.GenericCastToPtrExplicitOp
211//===----------------------------------------------------------------------===//
212
213LogicalResult GenericCastToPtrExplicitOp::verify() {
214 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
215 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
216
217 spirv::StorageClass operandStorage = operandType.getStorageClass();
218 if (operandStorage != spirv::StorageClass::Generic)
219 return emitError("pointer type must be of storage class Generic");
220
221 spirv::StorageClass resultStorage = resultType.getStorageClass();
222 if (resultStorage != spirv::StorageClass::Workgroup &&
223 resultStorage != spirv::StorageClass::CrossWorkgroup &&
224 resultStorage != spirv::StorageClass::Function)
225 return emitError("result must point to the Workgroup, CrossWorkgroup, "
226 "or Function Storage Class");
227
228 Type operandPointeeType = operandType.getPointeeType();
229 Type resultPointeeType = resultType.getPointeeType();
230 if (operandPointeeType != resultPointeeType)
231 return emitOpError("pointer operand's pointee type must have the same "
232 "as the op result type, but found ")
233 << operandPointeeType << " vs " << resultPointeeType;
234 return success();
235}
236
237//===----------------------------------------------------------------------===//
238// spirv.ConvertFToSOp
239//===----------------------------------------------------------------------===//
240
241LogicalResult ConvertFToSOp::verify() {
242 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
243 /*skipBitWidthCheck=*/true);
244}
245
246//===----------------------------------------------------------------------===//
247// spirv.ConvertFToUOp
248//===----------------------------------------------------------------------===//
249
250LogicalResult ConvertFToUOp::verify() {
251 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
252 /*skipBitWidthCheck=*/true);
253}
254
255//===----------------------------------------------------------------------===//
256// spirv.ConvertSToFOp
257//===----------------------------------------------------------------------===//
258
259LogicalResult ConvertSToFOp::verify() {
260 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
261 /*skipBitWidthCheck=*/true);
262}
263
264//===----------------------------------------------------------------------===//
265// spirv.ConvertUToFOp
266//===----------------------------------------------------------------------===//
267
268LogicalResult ConvertUToFOp::verify() {
269 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
270 /*skipBitWidthCheck=*/true);
271}
272
273//===----------------------------------------------------------------------===//
274// spirv.INTELConvertBF16ToFOp
275//===----------------------------------------------------------------------===//
276
277LogicalResult INTELConvertBF16ToFOp::verify() {
278 auto operandType = getOperand().getType();
279 auto resultType = getResult().getType();
280 // ODS checks that vector result type and vector operand type have the same
281 // shape.
282 if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
283 unsigned operandNumElements = vectorType.getNumElements();
284 unsigned resultNumElements =
285 llvm::cast<VectorType>(resultType).getNumElements();
286 if (operandNumElements != resultNumElements) {
287 return emitOpError(
288 "operand and result must have same number of elements");
289 }
290 }
291 return success();
292}
293
294//===----------------------------------------------------------------------===//
295// spirv.INTELConvertFToBF16Op
296//===----------------------------------------------------------------------===//
297
298LogicalResult INTELConvertFToBF16Op::verify() {
299 auto operandType = getOperand().getType();
300 auto resultType = getResult().getType();
301 // ODS checks that vector result type and vector operand type have the same
302 // shape.
303 if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
304 unsigned operandNumElements = vectorType.getNumElements();
305 unsigned resultNumElements =
306 llvm::cast<VectorType>(resultType).getNumElements();
307 if (operandNumElements != resultNumElements) {
308 return emitOpError(
309 "operand and result must have same number of elements");
310 }
311 }
312 return success();
313}
314
315//===----------------------------------------------------------------------===//
316// spirv.FConvertOp
317//===----------------------------------------------------------------------===//
318
319LogicalResult spirv::FConvertOp::verify() {
320 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
321}
322
323//===----------------------------------------------------------------------===//
324// spirv.SConvertOp
325//===----------------------------------------------------------------------===//
326
327LogicalResult spirv::SConvertOp::verify() {
328 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
329}
330
331//===----------------------------------------------------------------------===//
332// spirv.UConvertOp
333//===----------------------------------------------------------------------===//
334
335LogicalResult spirv::UConvertOp::verify() {
336 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
337}
338
339} // namespace mlir::spirv
340

source code of mlir/lib/Dialect/SPIRV/IR/CastOps.cpp