//===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This implements Semantic Analysis for SPIRV constructs.
//===----------------------------------------------------------------------===//
10
#include "clang/Sema/SemaSPIRV.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Sema.h"
#include <optional>
15
// SPIR-V enumerants. Enums have only the required entries, see the SPIR-V
// specification for the values.
// FIXME: either use the SPIRV-Headers or generate a custom header using the
// grammar (like done with MLIR).
namespace spirv {
/// Storage-class enumerants referenced by the semantic checks in this file.
/// The numeric values are the ones a caller passes as the constant
/// storage-class argument of a builtin.
enum class StorageClass : int {
  Workgroup = 4,
  CrossWorkgroup = 5,
  Function = 7
};
} // namespace spirv
27
28namespace clang {
29
30SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
31
32static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
33 assert(TheCall->getNumArgs() > 1);
34 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
35
36 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
37 if (!S->getASTContext().hasSameUnqualifiedType(
38 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
39 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
40 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
41 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
42 TheCall->getArg(N - 1)->getEndLoc());
43 return true;
44 }
45 }
46 return false;
47}
48
49static std::optional<int>
50processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
51 ExprResult Arg =
52 SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: Argument));
53 if (Arg.isInvalid())
54 return true;
55 Call->setArg(Arg: Argument, ArgExpr: Arg.get());
56
57 const Expr *IntArg = Arg.get();
58 SmallVector<PartialDiagnosticAt, 8> Notes;
59 Expr::EvalResult Eval;
60 Eval.Diag = &Notes;
61 if ((!IntArg->EvaluateAsConstantExpr(Result&: Eval, Ctx: SemaRef.getASTContext())) ||
62 !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) {
63 SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int)
64 << 0 << IntArg->getSourceRange();
65 for (const PartialDiagnosticAt &PDiag : Notes)
66 SemaRef.Diag(PDiag.first, PDiag.second);
67 return true;
68 }
69 return {Eval.Val.getInt().getZExtValue()};
70}
71
72static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) {
73 if (SemaRef.checkArgCount(Call, DesiredArgCount: 2))
74 return true;
75
76 {
77 ExprResult Arg =
78 SemaRef.DefaultFunctionArrayLvalueConversion(E: Call->getArg(Arg: 0));
79 if (Arg.isInvalid())
80 return true;
81 Call->setArg(Arg: 0, ArgExpr: Arg.get());
82
83 QualType Ty = Arg.get()->getType();
84 const auto *PtrTy = Ty->getAs<PointerType>();
85 auto AddressSpaceNotInGeneric = [&](LangAS AS) {
86 if (SemaRef.LangOpts.OpenCL)
87 return AS != LangAS::opencl_generic;
88 return AS != LangAS::Default;
89 };
90 if (!PtrTy ||
91 AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) {
92 SemaRef.Diag(Arg.get()->getBeginLoc(),
93 diag::err_spirv_builtin_generic_cast_invalid_arg)
94 << Call->getSourceRange();
95 return true;
96 }
97 }
98
99 spirv::StorageClass StorageClass;
100 if (std::optional<int> SCInt =
101 processConstant32BitIntArgument(SemaRef, Call, Argument: 1);
102 SCInt.has_value()) {
103 StorageClass = static_cast<spirv::StorageClass>(SCInt.value());
104 if (StorageClass != spirv::StorageClass::CrossWorkgroup &&
105 StorageClass != spirv::StorageClass::Workgroup &&
106 StorageClass != spirv::StorageClass::Function) {
107 SemaRef.Diag(Call->getArg(1)->getBeginLoc(),
108 diag::err_spirv_enum_not_valid)
109 << 0 << Call->getArg(1)->getSourceRange();
110 return true;
111 }
112 } else {
113 return true;
114 }
115 auto RT = Call->getArg(Arg: 0)->getType();
116 RT = RT->getPointeeType();
117 auto Qual = RT.getQualifiers();
118 LangAS AddrSpace;
119 switch (static_cast<spirv::StorageClass>(StorageClass)) {
120 case spirv::StorageClass::CrossWorkgroup:
121 AddrSpace =
122 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global;
123 break;
124 case spirv::StorageClass::Workgroup:
125 AddrSpace =
126 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local;
127 break;
128 case spirv::StorageClass::Function:
129 AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private
130 : LangAS::opencl_private;
131 break;
132 }
133 Qual.setAddressSpace(AddrSpace);
134 Call->setType(SemaRef.getASTContext().getPointerType(
135 T: SemaRef.getASTContext().getQualifiedType(T: RT.getUnqualifiedType(), Qs: Qual)));
136
137 return false;
138}
139
140bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
141 unsigned BuiltinID,
142 CallExpr *TheCall) {
143 if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin &&
144 TI.getTriple().getArch() != llvm::Triple::spirv) {
145 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 0;
146 return true;
147 }
148 if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin &&
149 TI.getTriple().getArch() != llvm::Triple::spirv32 &&
150 TI.getTriple().getArch() != llvm::Triple::spirv64) {
151 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 1;
152 return true;
153 }
154
155 switch (BuiltinID) {
156 case SPIRV::BI__builtin_spirv_distance: {
157 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
158 return true;
159
160 ExprResult A = TheCall->getArg(Arg: 0);
161 QualType ArgTyA = A.get()->getType();
162 auto *VTyA = ArgTyA->getAs<VectorType>();
163 if (VTyA == nullptr) {
164 SemaRef.Diag(A.get()->getBeginLoc(),
165 diag::err_typecheck_convert_incompatible)
166 << ArgTyA
167 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
168 << 0 << 0;
169 return true;
170 }
171
172 ExprResult B = TheCall->getArg(Arg: 1);
173 QualType ArgTyB = B.get()->getType();
174 auto *VTyB = ArgTyB->getAs<VectorType>();
175 if (VTyB == nullptr) {
176 SemaRef.Diag(A.get()->getBeginLoc(),
177 diag::err_typecheck_convert_incompatible)
178 << ArgTyB
179 << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
180 << 0 << 0;
181 return true;
182 }
183
184 QualType RetTy = VTyA->getElementType();
185 TheCall->setType(RetTy);
186 break;
187 }
188 case SPIRV::BI__builtin_spirv_length: {
189 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
190 return true;
191 ExprResult A = TheCall->getArg(Arg: 0);
192 QualType ArgTyA = A.get()->getType();
193 auto *VTy = ArgTyA->getAs<VectorType>();
194 if (VTy == nullptr) {
195 SemaRef.Diag(A.get()->getBeginLoc(),
196 diag::err_typecheck_convert_incompatible)
197 << ArgTyA
198 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
199 << 0 << 0;
200 return true;
201 }
202 QualType RetTy = VTy->getElementType();
203 TheCall->setType(RetTy);
204 break;
205 }
206 case SPIRV::BI__builtin_spirv_reflect: {
207 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
208 return true;
209
210 ExprResult A = TheCall->getArg(Arg: 0);
211 QualType ArgTyA = A.get()->getType();
212 auto *VTyA = ArgTyA->getAs<VectorType>();
213 if (VTyA == nullptr) {
214 SemaRef.Diag(A.get()->getBeginLoc(),
215 diag::err_typecheck_convert_incompatible)
216 << ArgTyA
217 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
218 << 0 << 0;
219 return true;
220 }
221
222 ExprResult B = TheCall->getArg(Arg: 1);
223 QualType ArgTyB = B.get()->getType();
224 auto *VTyB = ArgTyB->getAs<VectorType>();
225 if (VTyB == nullptr) {
226 SemaRef.Diag(A.get()->getBeginLoc(),
227 diag::err_typecheck_convert_incompatible)
228 << ArgTyB
229 << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
230 << 0 << 0;
231 return true;
232 }
233
234 QualType RetTy = ArgTyA;
235 TheCall->setType(RetTy);
236 break;
237 }
238 case SPIRV::BI__builtin_spirv_smoothstep: {
239 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
240 return true;
241
242 // Check if first argument has floating representation
243 ExprResult A = TheCall->getArg(Arg: 0);
244 QualType ArgTyA = A.get()->getType();
245 if (!ArgTyA->hasFloatingRepresentation()) {
246 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
247 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
248 << /* fp */ 1 << ArgTyA;
249 return true;
250 }
251
252 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
253 return true;
254
255 QualType RetTy = ArgTyA;
256 TheCall->setType(RetTy);
257 break;
258 }
259 case SPIRV::BI__builtin_spirv_faceforward: {
260 if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
261 return true;
262
263 // Check if first argument has floating representation
264 ExprResult A = TheCall->getArg(Arg: 0);
265 QualType ArgTyA = A.get()->getType();
266 if (!ArgTyA->hasFloatingRepresentation()) {
267 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
268 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
269 << /* fp */ 1 << ArgTyA;
270 return true;
271 }
272
273 if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
274 return true;
275
276 QualType RetTy = ArgTyA;
277 TheCall->setType(RetTy);
278 break;
279 }
280 case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: {
281 return checkGenericCastToPtr(SemaRef, Call: TheCall);
282 }
283 }
284 return false;
285}
286} // namespace clang
287

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of clang/lib/Sema/SemaSPIRV.cpp