1//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the cast and conversion operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24static LogicalResult verifyCastOp(Operation *op,
25 bool requireSameBitWidth = true,
26 bool skipBitWidthCheck = false) {
27 // Some CastOps have no limit on bit widths for result and operand type.
28 if (skipBitWidthCheck)
29 return success();
30
31 Type operandType = op->getOperand(idx: 0).getType();
32 Type resultType = op->getResult(idx: 0).getType();
33
34 // ODS checks that result type and operand type have the same shape. Check
35 // that composite types match and extract the element types, if any.
36 using TypePair = std::pair<Type, Type>;
37 auto [operandElemTy, resultElemTy] =
38 TypeSwitch<Type, TypePair>(operandType)
39 .Case<VectorType, spirv::CooperativeMatrixType>(
40 caseFn: [resultType](auto concreteOperandTy) -> TypePair {
41 if (auto concreteResultTy =
42 dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43 return {concreteOperandTy.getElementType(),
44 concreteResultTy.getElementType()};
45 }
46 return {};
47 })
48 .Default(defaultFn: [resultType](Type operandType) -> TypePair {
49 return {operandType, resultType};
50 });
51
52 if (!operandElemTy || !resultElemTy)
53 return op->emitOpError(message: "incompatible operand and result types");
54
55 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
58
59 if (requireSameBitWidth) {
60 if (!isSameBitWidth) {
61 return op->emitOpError(
62 message: "expected the same bit widths for operand type and result "
63 "type, but provided ")
64 << operandElemTy << " and " << resultElemTy;
65 }
66 return success();
67 }
68
69 if (isSameBitWidth) {
70 return op->emitOpError(
71 message: "expected the different bit widths for operand type and result "
72 "type, but provided ")
73 << operandElemTy << " and " << resultElemTy;
74 }
75 return success();
76}
77
78//===----------------------------------------------------------------------===//
79// spirv.BitcastOp
80//===----------------------------------------------------------------------===//
81
82LogicalResult BitcastOp::verify() {
83 // TODO: The SPIR-V spec validation rules are different for different
84 // versions.
85 auto operandType = getOperand().getType();
86 auto resultType = getResult().getType();
87 if (operandType == resultType) {
88 return emitError("result type must be different from operand type");
89 }
90 if (llvm::isa<spirv::PointerType>(operandType) &&
91 !llvm::isa<spirv::PointerType>(resultType)) {
92 return emitError(
93 "unhandled bit cast conversion from pointer type to non-pointer type");
94 }
95 if (!llvm::isa<spirv::PointerType>(operandType) &&
96 llvm::isa<spirv::PointerType>(resultType)) {
97 return emitError(
98 "unhandled bit cast conversion from non-pointer type to pointer type");
99 }
100 auto operandBitWidth = getBitWidth(operandType);
101 auto resultBitWidth = getBitWidth(resultType);
102 if (operandBitWidth != resultBitWidth) {
103 return emitOpError("mismatch in result type bitwidth ")
104 << resultBitWidth << " and operand type bitwidth "
105 << operandBitWidth;
106 }
107 return success();
108}
109
110//===----------------------------------------------------------------------===//
111// spirv.ConvertPtrToUOp
112//===----------------------------------------------------------------------===//
113
114LogicalResult ConvertPtrToUOp::verify() {
115 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
116 auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
117 if (!resultType || !resultType.isSignlessInteger())
118 return emitError("result must be a scalar type of unsigned integer");
119 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
120 if (!spirvModule)
121 return success();
122 auto addressingModel = spirvModule.getAddressingModel();
123 if ((addressingModel == spirv::AddressingModel::Logical) ||
124 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
125 operandType.getStorageClass() !=
126 spirv::StorageClass::PhysicalStorageBuffer))
127 return emitError("operand must be a physical pointer");
128 return success();
129}
130
131//===----------------------------------------------------------------------===//
132// spirv.ConvertUToPtrOp
133//===----------------------------------------------------------------------===//
134
135LogicalResult ConvertUToPtrOp::verify() {
136 auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
137 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
138 if (!operandType || !operandType.isSignlessInteger())
139 return emitError("result must be a scalar type of unsigned integer");
140 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
141 if (!spirvModule)
142 return success();
143 auto addressingModel = spirvModule.getAddressingModel();
144 if ((addressingModel == spirv::AddressingModel::Logical) ||
145 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
146 resultType.getStorageClass() !=
147 spirv::StorageClass::PhysicalStorageBuffer))
148 return emitError("result must be a physical pointer");
149 return success();
150}
151
152//===----------------------------------------------------------------------===//
153// spirv.PtrCastToGenericOp
154//===----------------------------------------------------------------------===//
155
156LogicalResult PtrCastToGenericOp::verify() {
157 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
158 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
159
160 spirv::StorageClass operandStorage = operandType.getStorageClass();
161 if (operandStorage != spirv::StorageClass::Workgroup &&
162 operandStorage != spirv::StorageClass::CrossWorkgroup &&
163 operandStorage != spirv::StorageClass::Function)
164 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
165 ", or Function Storage Class");
166
167 spirv::StorageClass resultStorage = resultType.getStorageClass();
168 if (resultStorage != spirv::StorageClass::Generic)
169 return emitError("result type must be of storage class Generic");
170
171 Type operandPointeeType = operandType.getPointeeType();
172 Type resultPointeeType = resultType.getPointeeType();
173 if (operandPointeeType != resultPointeeType)
174 return emitOpError("pointer operand's pointee type must have the same "
175 "as the op result type, but found ")
176 << operandPointeeType << " vs " << resultPointeeType;
177 return success();
178}
179
180//===----------------------------------------------------------------------===//
181// spirv.GenericCastToPtrOp
182//===----------------------------------------------------------------------===//
183
184LogicalResult GenericCastToPtrOp::verify() {
185 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
186 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
187
188 spirv::StorageClass operandStorage = operandType.getStorageClass();
189 if (operandStorage != spirv::StorageClass::Generic)
190 return emitError("pointer type must be of storage class Generic");
191
192 spirv::StorageClass resultStorage = resultType.getStorageClass();
193 if (resultStorage != spirv::StorageClass::Workgroup &&
194 resultStorage != spirv::StorageClass::CrossWorkgroup &&
195 resultStorage != spirv::StorageClass::Function)
196 return emitError("result must point to the Workgroup, CrossWorkgroup, "
197 "or Function Storage Class");
198
199 Type operandPointeeType = operandType.getPointeeType();
200 Type resultPointeeType = resultType.getPointeeType();
201 if (operandPointeeType != resultPointeeType)
202 return emitOpError("pointer operand's pointee type must have the same "
203 "as the op result type, but found ")
204 << operandPointeeType << " vs " << resultPointeeType;
205 return success();
206}
207
208//===----------------------------------------------------------------------===//
209// spirv.GenericCastToPtrExplicitOp
210//===----------------------------------------------------------------------===//
211
212LogicalResult GenericCastToPtrExplicitOp::verify() {
213 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
214 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
215
216 spirv::StorageClass operandStorage = operandType.getStorageClass();
217 if (operandStorage != spirv::StorageClass::Generic)
218 return emitError("pointer type must be of storage class Generic");
219
220 spirv::StorageClass resultStorage = resultType.getStorageClass();
221 if (resultStorage != spirv::StorageClass::Workgroup &&
222 resultStorage != spirv::StorageClass::CrossWorkgroup &&
223 resultStorage != spirv::StorageClass::Function)
224 return emitError("result must point to the Workgroup, CrossWorkgroup, "
225 "or Function Storage Class");
226
227 Type operandPointeeType = operandType.getPointeeType();
228 Type resultPointeeType = resultType.getPointeeType();
229 if (operandPointeeType != resultPointeeType)
230 return emitOpError("pointer operand's pointee type must have the same "
231 "as the op result type, but found ")
232 << operandPointeeType << " vs " << resultPointeeType;
233 return success();
234}
235
236//===----------------------------------------------------------------------===//
237// spirv.ConvertFToSOp
238//===----------------------------------------------------------------------===//
239
240LogicalResult ConvertFToSOp::verify() {
241 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
242 /*skipBitWidthCheck=*/true);
243}
244
245//===----------------------------------------------------------------------===//
246// spirv.ConvertFToUOp
247//===----------------------------------------------------------------------===//
248
249LogicalResult ConvertFToUOp::verify() {
250 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
251 /*skipBitWidthCheck=*/true);
252}
253
254//===----------------------------------------------------------------------===//
255// spirv.ConvertSToFOp
256//===----------------------------------------------------------------------===//
257
258LogicalResult ConvertSToFOp::verify() {
259 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
260 /*skipBitWidthCheck=*/true);
261}
262
263//===----------------------------------------------------------------------===//
264// spirv.ConvertUToFOp
265//===----------------------------------------------------------------------===//
266
267LogicalResult ConvertUToFOp::verify() {
268 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
269 /*skipBitWidthCheck=*/true);
270}
271
272//===----------------------------------------------------------------------===//
273// spirv.INTELConvertBF16ToFOp
274//===----------------------------------------------------------------------===//
275
276LogicalResult INTELConvertBF16ToFOp::verify() {
277 auto operandType = getOperand().getType();
278 auto resultType = getResult().getType();
279 // ODS checks that vector result type and vector operand type have the same
280 // shape.
281 if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
282 unsigned operandNumElements = vectorType.getNumElements();
283 unsigned resultNumElements =
284 llvm::cast<VectorType>(resultType).getNumElements();
285 if (operandNumElements != resultNumElements) {
286 return emitOpError(
287 "operand and result must have same number of elements");
288 }
289 }
290 return success();
291}
292
293//===----------------------------------------------------------------------===//
294// spirv.INTELConvertFToBF16Op
295//===----------------------------------------------------------------------===//
296
297LogicalResult INTELConvertFToBF16Op::verify() {
298 auto operandType = getOperand().getType();
299 auto resultType = getResult().getType();
300 // ODS checks that vector result type and vector operand type have the same
301 // shape.
302 if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
303 unsigned operandNumElements = vectorType.getNumElements();
304 unsigned resultNumElements =
305 llvm::cast<VectorType>(resultType).getNumElements();
306 if (operandNumElements != resultNumElements) {
307 return emitOpError(
308 "operand and result must have same number of elements");
309 }
310 }
311 return success();
312}
313
314//===----------------------------------------------------------------------===//
315// spirv.FConvertOp
316//===----------------------------------------------------------------------===//
317
318LogicalResult spirv::FConvertOp::verify() {
319 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
320}
321
322//===----------------------------------------------------------------------===//
323// spirv.SConvertOp
324//===----------------------------------------------------------------------===//
325
326LogicalResult spirv::SConvertOp::verify() {
327 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
328}
329
330//===----------------------------------------------------------------------===//
331// spirv.UConvertOp
332//===----------------------------------------------------------------------===//
333
334LogicalResult spirv::UConvertOp::verify() {
335 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
336}
337
338} // namespace mlir::spirv
339

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/CastOps.cpp