1 | //===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This implements Semantic Analysis for SPIRV constructs. |
9 | //===----------------------------------------------------------------------===// |
10 | |
11 | #include "clang/Sema/SemaSPIRV.h" |
12 | #include "clang/Basic/TargetBuiltins.h" |
13 | #include "clang/Basic/TargetInfo.h" |
14 | #include "clang/Sema/Sema.h" |
15 | |
// Minimal subset of the SPIR-V enumerants used by this file. Only the entries
// that are actually needed are listed; the numeric values come from the SPIR-V
// specification.
// FIXME: either use the SPIRV-Headers or generate a custom header using the
// grammar (like done with MLIR).
namespace spirv {
/// SPIR-V StorageClass operand values accepted by the builtins checked below.
enum class StorageClass : int {
  Workgroup = 4,      // mapped to the "local" address space in this file
  CrossWorkgroup = 5, // mapped to the "global" address space in this file
  Function = 7,       // mapped to the "private" address space in this file
};
} // namespace spirv
27 | |
28 | namespace clang { |
29 | |
30 | SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} |
31 | |
32 | static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) { |
33 | assert(TheCall->getNumArgs() > 1); |
34 | QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType(); |
35 | |
36 | for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) { |
37 | if (!S->getASTContext().hasSameUnqualifiedType( |
38 | T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) { |
39 | S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector) |
40 | << TheCall->getDirectCallee() << /*useAllTerminology*/ true |
41 | << SourceRange(TheCall->getArg(0)->getBeginLoc(), |
42 | TheCall->getArg(N - 1)->getEndLoc()); |
43 | return true; |
44 | } |
45 | } |
46 | return false; |
47 | } |
48 | |
49 | static std::optional<int> |
50 | processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) { |
51 | ExprResult Arg = |
52 | SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: Argument)); |
53 | if (Arg.isInvalid()) |
54 | return true; |
55 | Call->setArg(Arg: Argument, ArgExpr: Arg.get()); |
56 | |
57 | const Expr *IntArg = Arg.get(); |
58 | SmallVector<PartialDiagnosticAt, 8> Notes; |
59 | Expr::EvalResult Eval; |
60 | Eval.Diag = &Notes; |
61 | if ((!IntArg->EvaluateAsConstantExpr(Result&: Eval, Ctx: SemaRef.getASTContext())) || |
62 | !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) { |
63 | SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int) |
64 | << 0 << IntArg->getSourceRange(); |
65 | for (const PartialDiagnosticAt &PDiag : Notes) |
66 | SemaRef.Diag(PDiag.first, PDiag.second); |
67 | return true; |
68 | } |
69 | return {Eval.Val.getInt().getZExtValue()}; |
70 | } |
71 | |
72 | static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) { |
73 | if (SemaRef.checkArgCount(Call, DesiredArgCount: 2)) |
74 | return true; |
75 | |
76 | { |
77 | ExprResult Arg = |
78 | SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: 0)); |
79 | if (Arg.isInvalid()) |
80 | return true; |
81 | Call->setArg(Arg: 0, ArgExpr: Arg.get()); |
82 | |
83 | QualType Ty = Arg.get()->getType(); |
84 | const auto *PtrTy = Ty->getAs<PointerType>(); |
85 | auto AddressSpaceNotInGeneric = [&](LangAS AS) { |
86 | if (SemaRef.LangOpts.OpenCL) |
87 | return AS != LangAS::opencl_generic; |
88 | return AS != LangAS::Default; |
89 | }; |
90 | if (!PtrTy || |
91 | AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) { |
92 | SemaRef.Diag(Arg.get()->getBeginLoc(), |
93 | diag::err_spirv_builtin_generic_cast_invalid_arg) |
94 | << Call->getSourceRange(); |
95 | return true; |
96 | } |
97 | } |
98 | |
99 | spirv::StorageClass StorageClass; |
100 | if (std::optional<int> SCInt = |
101 | processConstant32BitIntArgument(SemaRef, Call, Argument: 1); |
102 | SCInt.has_value()) { |
103 | StorageClass = static_cast<spirv::StorageClass>(SCInt.value()); |
104 | if (StorageClass != spirv::StorageClass::CrossWorkgroup && |
105 | StorageClass != spirv::StorageClass::Workgroup && |
106 | StorageClass != spirv::StorageClass::Function) { |
107 | SemaRef.Diag(Call->getArg(1)->getBeginLoc(), |
108 | diag::err_spirv_enum_not_valid) |
109 | << 0 << Call->getArg(1)->getSourceRange(); |
110 | return true; |
111 | } |
112 | } else { |
113 | return true; |
114 | } |
115 | auto RT = Call->getArg(Arg: 0)->getType(); |
116 | RT = RT->getPointeeType(); |
117 | auto Qual = RT.getQualifiers(); |
118 | LangAS AddrSpace; |
119 | switch (static_cast<spirv::StorageClass>(StorageClass)) { |
120 | case spirv::StorageClass::CrossWorkgroup: |
121 | AddrSpace = |
122 | SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global; |
123 | break; |
124 | case spirv::StorageClass::Workgroup: |
125 | AddrSpace = |
126 | SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local; |
127 | break; |
128 | case spirv::StorageClass::Function: |
129 | AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private |
130 | : LangAS::opencl_private; |
131 | break; |
132 | } |
133 | Qual.setAddressSpace(AddrSpace); |
134 | Call->setType(SemaRef.getASTContext().getPointerType( |
135 | T: SemaRef.getASTContext().getQualifiedType(T: RT.getUnqualifiedType(), Qs: Qual))); |
136 | |
137 | return false; |
138 | } |
139 | |
140 | bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI, |
141 | unsigned BuiltinID, |
142 | CallExpr *TheCall) { |
143 | if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin && |
144 | TI.getTriple().getArch() != llvm::Triple::spirv) { |
145 | SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 0; |
146 | return true; |
147 | } |
148 | if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin && |
149 | TI.getTriple().getArch() != llvm::Triple::spirv32 && |
150 | TI.getTriple().getArch() != llvm::Triple::spirv64) { |
151 | SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 1; |
152 | return true; |
153 | } |
154 | |
155 | switch (BuiltinID) { |
156 | case SPIRV::BI__builtin_spirv_distance: { |
157 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2)) |
158 | return true; |
159 | |
160 | ExprResult A = TheCall->getArg(Arg: 0); |
161 | QualType ArgTyA = A.get()->getType(); |
162 | auto *VTyA = ArgTyA->getAs<VectorType>(); |
163 | if (VTyA == nullptr) { |
164 | SemaRef.Diag(A.get()->getBeginLoc(), |
165 | diag::err_typecheck_convert_incompatible) |
166 | << ArgTyA |
167 | << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 |
168 | << 0 << 0; |
169 | return true; |
170 | } |
171 | |
172 | ExprResult B = TheCall->getArg(Arg: 1); |
173 | QualType ArgTyB = B.get()->getType(); |
174 | auto *VTyB = ArgTyB->getAs<VectorType>(); |
175 | if (VTyB == nullptr) { |
176 | SemaRef.Diag(A.get()->getBeginLoc(), |
177 | diag::err_typecheck_convert_incompatible) |
178 | << ArgTyB |
179 | << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 |
180 | << 0 << 0; |
181 | return true; |
182 | } |
183 | |
184 | QualType RetTy = VTyA->getElementType(); |
185 | TheCall->setType(RetTy); |
186 | break; |
187 | } |
188 | case SPIRV::BI__builtin_spirv_length: { |
189 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1)) |
190 | return true; |
191 | ExprResult A = TheCall->getArg(Arg: 0); |
192 | QualType ArgTyA = A.get()->getType(); |
193 | auto *VTy = ArgTyA->getAs<VectorType>(); |
194 | if (VTy == nullptr) { |
195 | SemaRef.Diag(A.get()->getBeginLoc(), |
196 | diag::err_typecheck_convert_incompatible) |
197 | << ArgTyA |
198 | << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 |
199 | << 0 << 0; |
200 | return true; |
201 | } |
202 | QualType RetTy = VTy->getElementType(); |
203 | TheCall->setType(RetTy); |
204 | break; |
205 | } |
206 | case SPIRV::BI__builtin_spirv_reflect: { |
207 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2)) |
208 | return true; |
209 | |
210 | ExprResult A = TheCall->getArg(Arg: 0); |
211 | QualType ArgTyA = A.get()->getType(); |
212 | auto *VTyA = ArgTyA->getAs<VectorType>(); |
213 | if (VTyA == nullptr) { |
214 | SemaRef.Diag(A.get()->getBeginLoc(), |
215 | diag::err_typecheck_convert_incompatible) |
216 | << ArgTyA |
217 | << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 |
218 | << 0 << 0; |
219 | return true; |
220 | } |
221 | |
222 | ExprResult B = TheCall->getArg(Arg: 1); |
223 | QualType ArgTyB = B.get()->getType(); |
224 | auto *VTyB = ArgTyB->getAs<VectorType>(); |
225 | if (VTyB == nullptr) { |
226 | SemaRef.Diag(A.get()->getBeginLoc(), |
227 | diag::err_typecheck_convert_incompatible) |
228 | << ArgTyB |
229 | << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 |
230 | << 0 << 0; |
231 | return true; |
232 | } |
233 | |
234 | QualType RetTy = ArgTyA; |
235 | TheCall->setType(RetTy); |
236 | break; |
237 | } |
238 | case SPIRV::BI__builtin_spirv_smoothstep: { |
239 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3)) |
240 | return true; |
241 | |
242 | // Check if first argument has floating representation |
243 | ExprResult A = TheCall->getArg(Arg: 0); |
244 | QualType ArgTyA = A.get()->getType(); |
245 | if (!ArgTyA->hasFloatingRepresentation()) { |
246 | SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) |
247 | << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0 |
248 | << /* fp */ 1 << ArgTyA; |
249 | return true; |
250 | } |
251 | |
252 | if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall)) |
253 | return true; |
254 | |
255 | QualType RetTy = ArgTyA; |
256 | TheCall->setType(RetTy); |
257 | break; |
258 | } |
259 | case SPIRV::BI__builtin_spirv_faceforward: { |
260 | if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3)) |
261 | return true; |
262 | |
263 | // Check if first argument has floating representation |
264 | ExprResult A = TheCall->getArg(Arg: 0); |
265 | QualType ArgTyA = A.get()->getType(); |
266 | if (!ArgTyA->hasFloatingRepresentation()) { |
267 | SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) |
268 | << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0 |
269 | << /* fp */ 1 << ArgTyA; |
270 | return true; |
271 | } |
272 | |
273 | if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall)) |
274 | return true; |
275 | |
276 | QualType RetTy = ArgTyA; |
277 | TheCall->setType(RetTy); |
278 | break; |
279 | } |
280 | case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: { |
281 | return checkGenericCastToPtr(SemaRef, Call: TheCall); |
282 | } |
283 | } |
284 | return false; |
285 | } |
286 | } // namespace clang |
287 | |