1//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for HLSL constructs.
9//===----------------------------------------------------------------------===//
10
11#include "clang/Sema/SemaHLSL.h"
12#include "clang/AST/ASTConsumer.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/Attrs.inc"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclBase.h"
18#include "clang/AST/DeclCXX.h"
19#include "clang/AST/DeclarationName.h"
20#include "clang/AST/DynamicRecursiveASTVisitor.h"
21#include "clang/AST/Expr.h"
22#include "clang/AST/Type.h"
23#include "clang/AST/TypeLoc.h"
24#include "clang/Basic/Builtins.h"
25#include "clang/Basic/DiagnosticSema.h"
26#include "clang/Basic/IdentifierTable.h"
27#include "clang/Basic/LLVM.h"
28#include "clang/Basic/SourceLocation.h"
29#include "clang/Basic/Specifiers.h"
30#include "clang/Basic/TargetInfo.h"
31#include "clang/Sema/Initialization.h"
32#include "clang/Sema/Lookup.h"
33#include "clang/Sema/ParsedAttr.h"
34#include "clang/Sema/Sema.h"
35#include "clang/Sema/Template.h"
36#include "llvm/ADT/ArrayRef.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/SmallVector.h"
39#include "llvm/ADT/StringExtras.h"
40#include "llvm/ADT/StringRef.h"
41#include "llvm/ADT/Twine.h"
42#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
43#include "llvm/Support/Casting.h"
44#include "llvm/Support/DXILABI.h"
45#include "llvm/Support/ErrorHandling.h"
46#include "llvm/Support/FormatVariadic.h"
47#include "llvm/TargetParser/Triple.h"
48#include <cmath>
49#include <cstddef>
50#include <iterator>
51#include <utility>
52
53using namespace clang;
54using RegisterType = HLSLResourceBindingAttr::RegisterType;
55
56static CXXRecordDecl *createHostLayoutStruct(Sema &S,
57 CXXRecordDecl *StructDecl);
58
59static RegisterType getRegisterType(ResourceClass RC) {
60 switch (RC) {
61 case ResourceClass::SRV:
62 return RegisterType::SRV;
63 case ResourceClass::UAV:
64 return RegisterType::UAV;
65 case ResourceClass::CBuffer:
66 return RegisterType::CBuffer;
67 case ResourceClass::Sampler:
68 return RegisterType::Sampler;
69 }
70 llvm_unreachable("unexpected ResourceClass value");
71}
72
73// Converts the first letter of string Slot to RegisterType.
74// Returns false if the letter does not correspond to a valid register type.
75static bool convertToRegisterType(StringRef Slot, RegisterType *RT) {
76 assert(RT != nullptr);
77 switch (Slot[0]) {
78 case 't':
79 case 'T':
80 *RT = RegisterType::SRV;
81 return true;
82 case 'u':
83 case 'U':
84 *RT = RegisterType::UAV;
85 return true;
86 case 'b':
87 case 'B':
88 *RT = RegisterType::CBuffer;
89 return true;
90 case 's':
91 case 'S':
92 *RT = RegisterType::Sampler;
93 return true;
94 case 'c':
95 case 'C':
96 *RT = RegisterType::C;
97 return true;
98 case 'i':
99 case 'I':
100 *RT = RegisterType::I;
101 return true;
102 default:
103 return false;
104 }
105}
106
107static ResourceClass getResourceClass(RegisterType RT) {
108 switch (RT) {
109 case RegisterType::SRV:
110 return ResourceClass::SRV;
111 case RegisterType::UAV:
112 return ResourceClass::UAV;
113 case RegisterType::CBuffer:
114 return ResourceClass::CBuffer;
115 case RegisterType::Sampler:
116 return ResourceClass::Sampler;
117 case RegisterType::C:
118 case RegisterType::I:
119 // Deliberately falling through to the unreachable below.
120 break;
121 }
122 llvm_unreachable("unexpected RegisterType value");
123}
124
125static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
126 const auto *BT = dyn_cast<BuiltinType>(Val: Type);
127 if (!BT) {
128 if (!Type->isEnumeralType())
129 return Builtin::NotBuiltin;
130 return Builtin::BI__builtin_get_spirv_spec_constant_int;
131 }
132
133 switch (BT->getKind()) {
134 case BuiltinType::Bool:
135 return Builtin::BI__builtin_get_spirv_spec_constant_bool;
136 case BuiltinType::Short:
137 return Builtin::BI__builtin_get_spirv_spec_constant_short;
138 case BuiltinType::Int:
139 return Builtin::BI__builtin_get_spirv_spec_constant_int;
140 case BuiltinType::LongLong:
141 return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
142 case BuiltinType::UShort:
143 return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
144 case BuiltinType::UInt:
145 return Builtin::BI__builtin_get_spirv_spec_constant_uint;
146 case BuiltinType::ULongLong:
147 return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
148 case BuiltinType::Half:
149 return Builtin::BI__builtin_get_spirv_spec_constant_half;
150 case BuiltinType::Float:
151 return Builtin::BI__builtin_get_spirv_spec_constant_float;
152 case BuiltinType::Double:
153 return Builtin::BI__builtin_get_spirv_spec_constant_double;
154 default:
155 return Builtin::NotBuiltin;
156 }
157}
158
// Appends a new (VD, ResClass) binding entry and returns a pointer to it.
// Invariants enforced by the asserts: a (VD, ResClass) pair is added at most
// once, and all entries for a given decl are contiguous in BindingsList
// (callers add all bindings for one decl back-to-back).
DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                      ResourceClass ResClass) {
  assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
         "DeclBindingInfo already added");
  assert(!hasBindingInfoForDecl(VD) || BindingsList.back().Decl == VD);
  // VarDecl may have multiple entries for different resource classes.
  // DeclToBindingListIndex stores the index of the first binding we saw
  // for this decl. If there are any additional ones then that index
  // shouldn't be updated.
  DeclToBindingListIndex.try_emplace(VD, BindingsList.size());
  return &BindingsList.emplace_back(VD, ResClass);
}
171
172DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD,
173 ResourceClass ResClass) {
174 auto Entry = DeclToBindingListIndex.find(Val: VD);
175 if (Entry != DeclToBindingListIndex.end()) {
176 for (unsigned Index = Entry->getSecond();
177 Index < BindingsList.size() && BindingsList[Index].Decl == VD;
178 ++Index) {
179 if (BindingsList[Index].ResClass == ResClass)
180 return &BindingsList[Index];
181 }
182 }
183 return nullptr;
184}
185
186bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
187 return DeclToBindingListIndex.contains(Val: VD);
188}
189
// SemaHLSL carries the HLSL-specific Sema state; construction only wires up
// the shared SemaBase.
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
191
// Called when parsing reaches the '{' of a cbuffer/tbuffer declaration.
// Creates the HLSLBufferDecl, tags it with the matching resource class
// (CBuffer for cbuffer, SRV for tbuffer), registers it in the current scope,
// and pushes it as the active DeclContext so subsequent declarations are
// parsed as buffer members (popped again in ActOnFinishBuffer).
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
                                 SourceLocation KwLoc, IdentifierInfo *Ident,
                                 SourceLocation IdentLoc,
                                 SourceLocation LBrace) {
  // For anonymous namespace, take the location of the left brace.
  DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
  HLSLBufferDecl *Result = HLSLBufferDecl::Create(
      getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);

  // if CBuffer is false, then it's a TBuffer
  auto RC = CBuffer ? llvm::hlsl::ResourceClass::CBuffer
                    : llvm::hlsl::ResourceClass::SRV;
  Result->addAttr(HLSLResourceClassAttr::CreateImplicit(getASTContext(), RC));

  // Make the buffer visible in the enclosing scope, then enter it.
  SemaRef.PushOnScopeChains(Result, BufferScope);
  SemaRef.PushDeclContext(BufferScope, Result);

  return Result;
}
211
212static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
213 QualType T) {
214 // Arrays and Structs are always aligned to new buffer rows
215 if (T->isArrayType() || T->isStructureType())
216 return 16;
217
218 // Vectors are aligned to the type they contain
219 if (const VectorType *VT = T->getAs<VectorType>())
220 return calculateLegacyCbufferFieldAlign(Context, T: VT->getElementType());
221
222 assert(Context.getTypeSize(T) <= 64 &&
223 "Scalar bit widths larger than 64 not supported");
224
225 // Scalar types are aligned to their byte width
226 return Context.getTypeSize(T) / 8;
227}
228
// Calculate the size of a legacy cbuffer type in bytes based on
// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                           QualType T) {
  // Legacy constant buffers are laid out in 16-byte rows.
  constexpr unsigned CBufferAlign = 16;
  if (const RecordType *RT = T->getAs<RecordType>()) {
    unsigned Size = 0;
    const RecordDecl *RD = RT->getDecl();
    for (const FieldDecl *Field : RD->fields()) {
      QualType Ty = Field->getType();
      unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
      unsigned FieldAlign = calculateLegacyCbufferFieldAlign(Context, Ty);

      // If the field crosses the row boundary after alignment it drops to the
      // next row
      unsigned AlignSize = llvm::alignTo(Size, FieldAlign);
      if ((AlignSize % CBufferAlign) + FieldSize > CBufferAlign) {
        FieldAlign = CBufferAlign;
      }

      Size = llvm::alignTo(Size, FieldAlign);
      Size += FieldSize;
    }
    return Size;
  }

  if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
    unsigned ElementCount = AT->getSize().getZExtValue();
    if (ElementCount == 0)
      return 0;

    // Every array element except the last starts on a fresh 16-byte row;
    // the final element contributes only its natural size, so the tail of
    // its row is not counted.
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, AT->getElementType());
    unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);
    return AlignedElementSize * (ElementCount - 1) + ElementSize;
  }

  if (const VectorType *VT = T->getAs<VectorType>()) {
    // Vector elements pack contiguously with no padding between them.
    unsigned ElementCount = VT->getNumElements();
    unsigned ElementSize =
        calculateLegacyCbufferSize(Context, VT->getElementType());
    return ElementSize * ElementCount;
  }

  // Scalar: size in bytes.
  return Context.getTypeSize(T) / 8;
}
275
// Validate packoffset:
// - if packoffset it used it must be set on all declarations inside the buffer
// - packoffset ranges must not overlap
static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
  llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;

  // Make sure the packoffset annotations are either on all declarations
  // or on none.
  bool HasPackOffset = false;
  bool HasNonPackOffset = false;
  for (auto *Field : BufDecl->buffer_decls()) {
    VarDecl *Var = dyn_cast<VarDecl>(Field);
    if (!Var)
      continue;
    if (Field->hasAttr<HLSLPackOffsetAttr>()) {
      PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());
      HasPackOffset = true;
    } else {
      HasNonPackOffset = true;
    }
  }

  // No annotations at all: nothing to validate.
  if (!HasPackOffset)
    return;

  // Partial annotation is only a warning; validation continues with the
  // annotated members.
  if (HasNonPackOffset)
    S.Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);

  // Make sure there is no overlap in packoffset - sort PackOffsetVec by offset
  // and compare adjacent values.
  bool IsValid = true;
  ASTContext &Context = S.getASTContext();
  std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),
            [](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,
               const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {
              return LHS.second->getOffsetInBytes() <
                     RHS.second->getOffsetInBytes();
            });
  // PackOffsetVec is non-empty here (HasPackOffset is true), so the
  // size() - 1 bound below cannot underflow.
  for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {
    VarDecl *Var = PackOffsetVec[i].first;
    HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;
    // An entry overlaps its successor if its end (offset + legacy cbuffer
    // size of the variable) extends past the successor's start offset.
    unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());
    unsigned Begin = Attr->getOffsetInBytes();
    unsigned End = Begin + Size;
    unsigned NextBegin = PackOffsetVec[i + 1].second->getOffsetInBytes();
    if (End > NextBegin) {
      VarDecl *NextVar = PackOffsetVec[i + 1].first;
      S.Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)
          << NextVar << Var;
      IsValid = false;
    }
  }
  BufDecl->setHasValidPackoffset(IsValid);
}
330
331// Returns true if the array has a zero size = if any of the dimensions is 0
332static bool isZeroSizedArray(const ConstantArrayType *CAT) {
333 while (CAT && !CAT->isZeroSize())
334 CAT = dyn_cast<ConstantArrayType>(
335 Val: CAT->getElementType()->getUnqualifiedDesugaredType());
336 return CAT != nullptr;
337}
338
339// Returns true if the record type is an HLSL resource class or an array of
340// resource classes
341static bool isResourceRecordTypeOrArrayOf(const Type *Ty) {
342 while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Val: Ty))
343 Ty = CAT->getArrayElementTypeNoTypeQual();
344 return HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty) != nullptr;
345}
346
347static bool isResourceRecordTypeOrArrayOf(VarDecl *VD) {
348 return isResourceRecordTypeOrArrayOf(Ty: VD->getType().getTypePtr());
349}
350
351// Returns true if the type is a leaf element type that is not valid to be
352// included in HLSL Buffer, such as a resource class, empty struct, zero-sized
353// array, or a builtin intangible type. Returns false it is a valid leaf element
354// type or if it is a record type that needs to be inspected further.
355static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
356 Ty = Ty->getUnqualifiedDesugaredType();
357 if (isResourceRecordTypeOrArrayOf(Ty))
358 return true;
359 if (Ty->isRecordType())
360 return Ty->getAsCXXRecordDecl()->isEmpty();
361 if (Ty->isConstantArrayType() &&
362 isZeroSizedArray(CAT: cast<ConstantArrayType>(Val: Ty)))
363 return true;
364 if (Ty->isHLSLBuiltinIntangibleType() || Ty->isHLSLAttributedResourceType())
365 return true;
366 return false;
367}
368
369// Returns true if the struct contains at least one element that prevents it
370// from being included inside HLSL Buffer as is, such as an intangible type,
371// empty struct, or zero-sized array. If it does, a new implicit layout struct
372// needs to be created for HLSL Buffer use that will exclude these unwanted
373// declarations (see createHostLayoutStruct function).
374static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
375 if (RD->getTypeForDecl()->isHLSLIntangibleType() || RD->isEmpty())
376 return true;
377 // check fields
378 for (const FieldDecl *Field : RD->fields()) {
379 QualType Ty = Field->getType();
380 if (isInvalidConstantBufferLeafElementType(Ty: Ty.getTypePtr()))
381 return true;
382 if (Ty->isRecordType() &&
383 requiresImplicitBufferLayoutStructure(RD: Ty->getAsCXXRecordDecl()))
384 return true;
385 }
386 // check bases
387 for (const CXXBaseSpecifier &Base : RD->bases())
388 if (requiresImplicitBufferLayoutStructure(
389 RD: Base.getType()->getAsCXXRecordDecl()))
390 return true;
391 return false;
392}
393
394static CXXRecordDecl *findRecordDeclInContext(IdentifierInfo *II,
395 DeclContext *DC) {
396 CXXRecordDecl *RD = nullptr;
397 for (NamedDecl *Decl :
398 DC->getNonTransparentContext()->lookup(Name: DeclarationName(II))) {
399 if (CXXRecordDecl *FoundRD = dyn_cast<CXXRecordDecl>(Val: Decl)) {
400 assert(RD == nullptr &&
401 "there should be at most 1 record by a given name in a scope");
402 RD = FoundRD;
403 }
404 }
405 return RD;
406}
407
// Creates a name for buffer layout struct using the provide name base.
// If the name must be unique (not previously defined), a suffix is added
// until a unique name is found.
// The produced names have the form "__cblayout_<base>" (or
// "__cblayout_anon" for unnamed declarations), optionally followed by
// "_<n>" to disambiguate.
static IdentifierInfo *getHostLayoutStructName(Sema &S, NamedDecl *BaseDecl,
                                               bool MustBeUnique) {
  ASTContext &AST = S.getASTContext();

  IdentifierInfo *NameBaseII = BaseDecl->getIdentifier();
  llvm::SmallString<64> Name("__cblayout_");
  if (NameBaseII) {
    Name.append(NameBaseII->getName());
  } else {
    // anonymous struct
    Name.append("anon");
    MustBeUnique = true;
  }

  // Remember the length of the suffix-free name so it can be restored on
  // every retry below.
  size_t NameLength = Name.size();
  IdentifierInfo *II = &AST.Idents.get(Name, tok::TokenKind::identifier);
  if (!MustBeUnique)
    return II;

  // First iteration probes the unsuffixed name (suffix == 0 skips the
  // append); later iterations try "_1", "_2", ...
  unsigned suffix = 0;
  while (true) {
    if (suffix != 0) {
      Name.append("_");
      Name.append(llvm::Twine(suffix).str());
      II = &AST.Idents.get(Name, tok::TokenKind::identifier);
    }
    if (!findRecordDeclInContext(II, BaseDecl->getDeclContext()))
      return II;
    // declaration with that name already exists - increment suffix and try
    // again until unique name is found
    suffix++;
    Name.truncate(NameLength);
  };
}
445
// Creates a field declaration of given name and type for HLSL buffer layout
// struct. Returns nullptr if the type cannot be use in HLSL Buffer layout.
static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
                                                 IdentifierInfo *II,
                                                 CXXRecordDecl *LayoutStruct) {
  // Resource classes, empty structs, zero-sized arrays, and intangible
  // builtins are excluded from the layout struct entirely.
  if (isInvalidConstantBufferLeafElementType(Ty))
    return nullptr;

  if (Ty->isRecordType()) {
    CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(RD)) {
      // Substitute a nested struct with its host-layout equivalent; if that
      // equivalent would be empty the field is dropped as well.
      RD = createHostLayoutStruct(S, RD);
      if (!RD)
        return nullptr;
      Ty = RD->getTypeForDecl();
    }
  }

  // Synthesize an implicit public field with no source locations.
  QualType QT = QualType(Ty, 0);
  ASTContext &AST = S.getASTContext();
  TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(QT, SourceLocation());
  auto *Field = FieldDecl::Create(AST, LayoutStruct, SourceLocation(),
                                  SourceLocation(), II, QT, TSI, nullptr, false,
                                  InClassInitStyle::ICIS_NoInit);
  Field->setAccess(AccessSpecifier::AS_public);
  return Field;
}
473
// Creates host layout struct for a struct included in HLSL Buffer.
// The layout struct will include only fields that are allowed in HLSL buffer.
// These fields will be filtered out:
// - resource classes
// - empty structs
// - zero-sized arrays
// Returns nullptr if the resulting layout struct would be empty.
static CXXRecordDecl *createHostLayoutStruct(Sema &S,
                                             CXXRecordDecl *StructDecl) {
  assert(requiresImplicitBufferLayoutStructure(StructDecl) &&
         "struct is already HLSL buffer compatible");

  ASTContext &AST = S.getASTContext();
  DeclContext *DC = StructDecl->getDeclContext();
  IdentifierInfo *II = getHostLayoutStructName(S, StructDecl, false);

  // reuse existing if the layout struct if it already exists
  if (CXXRecordDecl *RD = findRecordDeclInContext(II, DC))
    return RD;

  // The layout struct is implicit and packed; legacy cbuffer layout is
  // computed separately, so no natural C++ padding is wanted.
  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, DC, SourceLocation(),
                            SourceLocation(), II);
  LS->setImplicit(true);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->startDefinition();

  // copy base struct, create HLSL Buffer compatible version if needed
  if (unsigned NumBases = StructDecl->getNumBases()) {
    assert(NumBases == 1 && "HLSL supports only one base type");
    (void)NumBases;
    CXXBaseSpecifier Base = *StructDecl->bases_begin();
    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
    if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
      // Replace the base with its host-layout version; if that version is
      // empty (nullptr) the base is dropped entirely.
      BaseDecl = createHostLayoutStruct(S, BaseDecl);
      if (BaseDecl) {
        TypeSourceInfo *TSI = AST.getTrivialTypeSourceInfo(
            QualType(BaseDecl->getTypeForDecl(), 0));
        Base = CXXBaseSpecifier(SourceRange(), false, StructDecl->isClass(),
                                AS_none, TSI, SourceLocation());
      }
    }
    if (BaseDecl) {
      const CXXBaseSpecifier *BasesArray[1] = {&Base};
      LS->setBases(BasesArray, 1);
    }
  }

  // filter struct fields
  for (const FieldDecl *FD : StructDecl->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *NewFD =
            createFieldForHostLayoutStruct(S, Ty, FD->getIdentifier(), LS))
      LS->addDecl(NewFD);
  }
  LS->completeDefinition();

  // An entirely filtered-out struct produces no layout struct at all.
  if (LS->field_empty() && LS->getNumBases() == 0)
    return nullptr;

  DC->addDecl(LS);
  return LS;
}
537
// Creates host layout struct for HLSL Buffer. The struct will include only
// fields of types that are allowed in HLSL buffer and it will filter out:
// - static or groupshared variable declarations
// - resource classes
// - empty structs
// - zero-sized arrays
// - non-variable declarations
// The layout struct will be added to the HLSLBufferDecl declarations.
// Note: variables that make it into the layout struct also have their
// declared type requalified with the hlsl_constant address space.
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
  ASTContext &AST = S.getASTContext();
  IdentifierInfo *II = getHostLayoutStructName(S, BufDecl, true);

  // Implicit packed struct; legacy cbuffer packing is applied elsewhere.
  CXXRecordDecl *LS =
      CXXRecordDecl::Create(AST, TagDecl::TagKind::Struct, BufDecl,
                            SourceLocation(), SourceLocation(), II);
  LS->addAttr(PackedAttr::CreateImplicit(AST));
  LS->setImplicit(true);
  LS->startDefinition();

  for (Decl *D : BufDecl->buffer_decls()) {
    // Skip non-variables and variables that do not occupy buffer storage.
    VarDecl *VD = dyn_cast<VarDecl>(D);
    if (!VD || VD->getStorageClass() == SC_Static ||
        VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
      continue;
    const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
    if (FieldDecl *FD =
            createFieldForHostLayoutStruct(S, Ty, VD->getIdentifier(), LS)) {
      // add the field decl to the layout struct
      LS->addDecl(FD);
      // update address space of the original decl to hlsl_constant
      QualType NewTy =
          AST.getAddrSpaceQualType(VD->getType(), LangAS::hlsl_constant);
      VD->setType(NewTy);
    }
  }
  LS->completeDefinition();
  BufDecl->addLayoutStruct(LS);
}
576
// Attaches an implicit HLSLResourceBindingAttr to a cbuffer/tbuffer that has
// no explicit register() annotation. The attribute carries no register slot
// (RegSlot stays empty), space 0, and the implicit-binding order ID that
// codegen uses to assign slots in declaration order.
static void addImplicitBindingAttrToBuffer(Sema &S, HLSLBufferDecl *BufDecl,
                                           uint32_t ImplicitBindingOrderID) {
  // cbuffers bind as 'b' registers (CBuffer), tbuffers as 't' (SRV).
  RegisterType RT =
      BufDecl->isCBuffer() ? RegisterType::CBuffer : RegisterType::SRV;
  auto *Attr =
      HLSLResourceBindingAttr::CreateImplicit(S.getASTContext(), "", "0", {});
  std::optional<unsigned> RegSlot;
  Attr->setBinding(RT, RegSlot, 0);
  Attr->setImplicitBindingOrderID(ImplicitBindingOrderID);
  BufDecl->addAttr(Attr);
}
588
// Handle end of cbuffer/tbuffer declaration
// Validates packoffset annotations, builds the host layout struct, ensures
// the buffer has a (possibly implicit) resource binding, and pops the
// DeclContext pushed in ActOnStartBuffer.
void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
  auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
  BufDecl->setRBraceLoc(RBrace);

  validatePackoffset(SemaRef, BufDecl);

  // create buffer layout struct
  createHostLayoutStructForBuffer(SemaRef, BufDecl);

  // A buffer without an explicit register slot gets an implicit binding:
  // warn, then record the order ID on an existing or newly created attribute.
  HLSLResourceBindingAttr *RBA = Dcl->getAttr<HLSLResourceBindingAttr>();
  if (!RBA || !RBA->hasRegisterSlot()) {
    SemaRef.Diag(Dcl->getLocation(), diag::warn_hlsl_implicit_binding);
    // Use HLSLResourceBindingAttr to transfer implicit binding order_ID
    // to codegen. If it does not exist, create an implicit attribute.
    uint32_t OrderID = getNextImplicitBindingOrderID();
    if (RBA)
      RBA->setImplicitBindingOrderID(OrderID);
    else
      addImplicitBindingAttrToBuffer(SemaRef, BufDecl, OrderID);
  }

  SemaRef.PopDeclContext();
}
613
614HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
615 const AttributeCommonInfo &AL,
616 int X, int Y, int Z) {
617 if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
618 if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
619 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
620 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
621 }
622 return nullptr;
623 }
624 return ::new (getASTContext())
625 HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
626}
627
628HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
629 const AttributeCommonInfo &AL,
630 int Min, int Max, int Preferred,
631 int SpelledArgsCount) {
632 if (HLSLWaveSizeAttr *WS = D->getAttr<HLSLWaveSizeAttr>()) {
633 if (WS->getMin() != Min || WS->getMax() != Max ||
634 WS->getPreferred() != Preferred ||
635 WS->getSpelledArgsCount() != SpelledArgsCount) {
636 Diag(Loc: WS->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
637 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
638 }
639 return nullptr;
640 }
641 HLSLWaveSizeAttr *Result = ::new (getASTContext())
642 HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred);
643 Result->setSpelledArgsCount(SpelledArgsCount);
644 return Result;
645}
646
// Merges a [[vk::constant_id(Id)]] attribute onto D. Only meaningful when
// targeting SPIR-V; elsewhere the attribute is ignored with a warning.
// The annotated variable must be const-qualified and of a type that maps to
// a SPIR-V spec-constant builtin. A repeated attribute with a different Id
// is an error; with the same Id it is silently dropped. Returns the new
// attribute, or nullptr when nothing should be attached.
HLSLVkConstantIdAttr *
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
                                int Id) {

  // Spec constants are a SPIR-V-only concept.
  auto &TargetInfo = getASTContext().getTargetInfo();
  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
    return nullptr;
  }

  auto *VD = cast<VarDecl>(D);

  // The type must have a matching __builtin_get_spirv_spec_constant_*.
  if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
      Builtin::NotBuiltin) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  // Spec constants must be declared const.
  if (!VD->getType().isConstQualified()) {
    Diag(VD->getLocation(), diag::err_specialization_const);
    return nullptr;
  }

  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
    if (CI->getId() != Id) {
      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
      Diag(AL.getLoc(), diag::note_conflicting_attribute);
    }
    return nullptr;
  }

  HLSLVkConstantIdAttr *Result =
      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
  return Result;
}
682
683HLSLShaderAttr *
684SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
685 llvm::Triple::EnvironmentType ShaderType) {
686 if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
687 if (NT->getType() != ShaderType) {
688 Diag(Loc: NT->getLocation(), DiagID: diag::err_hlsl_attribute_param_mismatch) << AL;
689 Diag(Loc: AL.getLoc(), DiagID: diag::note_conflicting_attribute);
690 }
691 return nullptr;
692 }
693 return HLSLShaderAttr::Create(Ctx&: getASTContext(), Type: ShaderType, CommonInfo: AL);
694}
695
// Merges an HLSL parameter modifier (in/out/inout) onto D.
// An `in` followed by `out` (or vice versa) merges into a single `inout`
// attribute spanning both spellings; any other repetition is an error.
HLSLParamModifierAttr *
SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
                                 HLSLParamModifierAttr::Spelling Spelling) {
  // We can only merge an `in` attribute with an `out` attribute. All other
  // combinations of duplicated attributes are ill-formed.
  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
      // Drop the old attribute and synthesize `inout` covering the source
      // range from the first modifier through the second.
      D->dropAttr<HLSLParamModifierAttr>();
      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
      return HLSLParamModifierAttr::Create(
          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
          HLSLParamModifierAttr::Keyword_inout);
    }
    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
    Diag(PA->getLocation(), diag::note_conflicting_attribute);
    return nullptr;
  }
  return HLSLParamModifierAttr::Create(getASTContext(), AL);
}
716
// Called for each top-level function declaration. If the function is the
// configured HLSL entry point and the target triple names a concrete shader
// stage, ensures the function carries a matching shader attribute: an
// explicit mismatching attribute is an error; a missing one is added
// implicitly from the triple.
void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
  auto &TargetInfo = getASTContext().getTargetInfo();

  // Only the designated entry-point function is of interest here.
  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
    return;

  llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
  if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
      // The entry point is already annotated - check that it matches the
      // triple.
      if (Shader->getType() != Env) {
        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
            << Shader;
        FD->setInvalidDecl();
      }
    } else {
      // Implicitly add the shader attribute if the entry function isn't
      // explicitly annotated.
      FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
                                                 FD->getBeginLoc()));
    }
  } else {
    // Library and unknown environments have no single stage to enforce.
    switch (Env) {
    case llvm::Triple::UnknownEnvironment:
    case llvm::Triple::Library:
      break;
    default:
      llvm_unreachable("Unhandled environment in triple");
    }
  }
}
749
// Validates an entry-point function against its shader stage:
// - numthreads/WaveSize are only allowed on compute-like stages
//   (compute, amplification, mesh), and are required/version-gated there;
// - every parameter must carry a semantic annotation.
// Any violation marks the declaration invalid.
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
  auto &TargetInfo = getASTContext().getTargetInfo();
  // Shader-model version, used to gate WaveSize below.
  VersionTuple Ver = TargetInfo.getTriple().getOSVersion();
  switch (ST) {
  case llvm::Triple::Pixel:
  case llvm::Triple::Vertex:
  case llvm::Triple::Geometry:
  case llvm::Triple::Hull:
  case llvm::Triple::Domain:
  case llvm::Triple::RayGeneration:
  case llvm::Triple::Intersection:
  case llvm::Triple::AnyHit:
  case llvm::Triple::ClosestHit:
  case llvm::Triple::Miss:
  case llvm::Triple::Callable:
    // These stages may not carry compute-only attributes.
    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
      DiagnoseAttrStageMismatch(NT, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      DiagnoseAttrStageMismatch(WS, ST,
                                {llvm::Triple::Compute,
                                 llvm::Triple::Amplification,
                                 llvm::Triple::Mesh});
      FD->setInvalidDecl();
    }
    break;

  case llvm::Triple::Compute:
  case llvm::Triple::Amplification:
  case llvm::Triple::Mesh:
    // numthreads is mandatory on compute-like stages.
    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
          << llvm::Triple::getEnvironmentTypeName(ST);
      FD->setInvalidDecl();
    }
    // WaveSize requires SM 6.6+; its multi-argument form requires SM 6.8+.
    if (const auto *WS = FD->getAttr<HLSLWaveSizeAttr>()) {
      if (Ver < VersionTuple(6, 6)) {
        Diag(WS->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model)
            << WS << "6.6";
        FD->setInvalidDecl();
      } else if (WS->getSpelledArgsCount() > 1 && Ver < VersionTuple(6, 8)) {
        Diag(
            WS->getLocation(),
            diag::err_hlsl_attribute_number_arguments_insufficient_shader_model)
            << WS << WS->getSpelledArgsCount() << "6.8";
        FD->setInvalidDecl();
      }
    }
    break;
  default:
    llvm_unreachable("Unhandled environment in triple");
  }

  // Each entry-point parameter needs a semantic annotation.
  for (ParmVarDecl *Param : FD->parameters()) {
    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
      CheckSemanticAnnotation(FD, Param, AnnotationAttr);
    } else {
      // FIXME: Handle struct parameters where annotations are on struct fields.
      // See: https://github.com/llvm/llvm-project/issues/57875
      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
      FD->setInvalidDecl();
    }
  }
  // FIXME: Verify return type semantic annotation.
}
823
// Checks that a semantic annotation on an entry-point parameter is valid for
// the entry point's shader stage, diagnosing a stage mismatch otherwise.
void SemaHLSL::CheckSemanticAnnotation(
    FunctionDecl *EntryPoint, const Decl *Param,
    const HLSLAnnotationAttr *AnnotationAttr) {
  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
  assert(ShaderAttr && "Entry point has no shader attribute");
  llvm::Triple::EnvironmentType ST = ShaderAttr->getType();

  switch (AnnotationAttr->getKind()) {
  // Dispatch/group semantics are compute-only.
  case attr::HLSLSV_DispatchThreadID:
  case attr::HLSLSV_GroupIndex:
  case attr::HLSLSV_GroupThreadID:
  case attr::HLSLSV_GroupID:
    if (ST == llvm::Triple::Compute)
      return;
    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
    break;
  case attr::HLSLSV_Position:
    // TODO(#143523): allow use on other shader types & output once the overall
    // semantic logic is implemented.
    if (ST == llvm::Triple::Pixel)
      return;
    DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Pixel});
    break;
  default:
    llvm_unreachable("Unknown HLSLAnnotationAttr");
  }
}
851
852void SemaHLSL::DiagnoseAttrStageMismatch(
853 const Attr *A, llvm::Triple::EnvironmentType Stage,
854 std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
855 SmallVector<StringRef, 8> StageStrings;
856 llvm::transform(Range&: AllowedStages, d_first: std::back_inserter(x&: StageStrings),
857 F: [](llvm::Triple::EnvironmentType ST) {
858 return StringRef(
859 HLSLShaderAttr::ConvertEnvironmentTypeToStr(Val: ST));
860 });
861 Diag(Loc: A->getLoc(), DiagID: diag::err_hlsl_attr_unsupported_in_stage)
862 << A->getAttrName() << llvm::Triple::getEnvironmentTypeName(Kind: Stage)
863 << (AllowedStages.size() != 1) << join(R&: StageStrings, Separator: ", ");
864}
865
866template <CastKind Kind>
867static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
868 if (const auto *VTy = Ty->getAs<VectorType>())
869 Ty = VTy->getElementType();
870 Ty = S.getASTContext().getExtVectorType(VectorType: Ty, NumElts: Sz);
871 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
872}
873
874template <CastKind Kind>
875static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
876 E = S.ImpCastExprToType(E: E.get(), Type: Ty, CK: Kind);
877 return Ty;
878}
879
// Computes the common operand type for a binary operation between two HLSL
// vectors where at least one element type is floating point, inserting the
// required implicit cast on the operand that must change. For compound
// assignments the LHS type always wins since the left operand cannot change
// type. Returns the resulting common type.
static QualType handleFloatVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
  bool LHSFloat = LElTy->isRealFloatingType();
  bool RHSFloat = RElTy->isRealFloatingType();

  // Both sides floating: convert toward the higher-ranked floating type
  // (or unconditionally toward the LHS for compound assignment).
  if (LHSFloat && RHSFloat) {
    if (IsCompAssign ||
        SemaRef.getASTContext().getFloatingTypeOrder(LHS: LElTy, RHS: RElTy) > 0)
      return castElement<CK_FloatingCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_FloatingCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // Mixed int/float: promote the integer side to the floating side ...
  if (LHSFloat)
    return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: RHS, Ty: LHSType);

  assert(RHSFloat);
  // ... unless this is a compound assignment to an integer LHS, in which case
  // the floating RHS must be converted down to the integer LHS type.
  if (IsCompAssign)
    return castElement<clang::CK_FloatingToIntegral>(S&: SemaRef, E&: RHS, Ty: LHSType);

  return castElement<CK_IntegralToFloating>(S&: SemaRef, E&: LHS, Ty: RHSType);
}
903
// Computes the common operand type for a binary operation between two HLSL
// integer vectors, following C-style integer conversion rank rules adapted to
// vectors, and inserts the required implicit casts. For compound assignments
// the LHS type always wins. Returns the resulting common type.
static QualType handleIntegerVectorBinOpConversion(
    Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
    QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {

  // IntOrder > 0 means the LHS element type has higher rank than the RHS.
  int IntOrder = SemaRef.Context.getIntegerTypeOrder(LHS: LElTy, RHS: RElTy);
  bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
  bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
  auto &Ctx = SemaRef.getASTContext();

  // If both types have the same signedness, use the higher ranked type.
  if (LHSSigned == RHSSigned) {
    if (IsCompAssign || IntOrder >= 0)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // If the unsigned type has greater than or equal rank of the signed type, use
  // the unsigned type.
  if (IntOrder != (LHSSigned ? 1 : -1)) {
    if (IsCompAssign || RHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // At this point the signed type has higher rank than the unsigned type, which
  // means it will be the same size or bigger. If the signed type is bigger, it
  // can represent all the values of the unsigned type, so select it.
  if (Ctx.getIntWidth(T: LElTy) != Ctx.getIntWidth(T: RElTy)) {
    if (IsCompAssign || LHSSigned)
      return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: RHSType);
  }

  // This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
  // to C/C++ leaking through. The place this happens today is long vs long
  // long. When arguments are vector<unsigned long, N> and vector<long long, N>,
  // the long long has higher rank than long even though they are the same size.

  // If this is a compound assignment cast the right hand side to the left hand
  // side's type.
  if (IsCompAssign)
    return castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: LHSType);

  // If this isn't a compound assignment we convert to unsigned long long.
  QualType ElTy = Ctx.getCorrespondingUnsignedType(T: LHSSigned ? LElTy : RElTy);
  QualType NewTy = Ctx.getExtVectorType(
      VectorType: ElTy, NumElts: RHSType->castAs<VectorType>()->getNumElements());
  // Both operands are converted to the unsigned common type; the first cast's
  // returned type is intentionally discarded.
  (void)castElement<CK_IntegralCast>(S&: SemaRef, E&: RHS, Ty: NewTy);

  return castElement<CK_IntegralCast>(S&: SemaRef, E&: LHS, Ty: NewTy);
}
956
957static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
958 QualType SrcTy) {
959 if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
960 return CK_FloatingCast;
961 if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
962 return CK_IntegralCast;
963 if (DestTy->isRealFloatingType())
964 return CK_IntegralToFloating;
965 assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
966 return CK_FloatingToIntegral;
967}
968
// Performs the usual arithmetic conversions for HLSL vector binary operators:
// truncates or splats operands so both sides have the same element count, then
// delegates element-type conversion to the float/integer helpers above.
// Returns the common result type, or a null QualType on error (a diagnostic
// has already been emitted in that case).
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
                                               QualType LHSType,
                                               QualType RHSType,
                                               bool IsCompAssign) {
  const auto *LVecTy = LHSType->getAs<VectorType>();
  const auto *RVecTy = RHSType->getAs<VectorType>();
  auto &Ctx = getASTContext();

  // If the LHS is not a vector and this is a compound assignment, we truncate
  // the argument to a scalar then convert it to the LHS's type.
  if (!LVecTy && IsCompAssign) {
    QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: RElTy, CK: CK_HLSLVectorTruncation);
    RHSType = RHS.get()->getType();
    if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
      return LHSType;
    RHS = SemaRef.ImpCastExprToType(E: RHS.get(), Type: LHSType,
                                    CK: getScalarCastKind(Ctx, DestTy: LHSType, SrcTy: RHSType));
    return LHSType;
  }

  // EndSz is the element count of the result: the smaller of the two vector
  // sizes when both sides are vectors.
  unsigned EndSz = std::numeric_limits<unsigned>::max();
  unsigned LSz = 0;
  if (LVecTy)
    LSz = EndSz = LVecTy->getNumElements();
  if (RVecTy)
    EndSz = std::min(a: RVecTy->getNumElements(), b: EndSz);
  assert(EndSz != std::numeric_limits<unsigned>::max() &&
         "one of the above should have had a value");

  // In a compound assignment, the left operand does not change type, the right
  // operand is converted to the type of the left operand.
  if (IsCompAssign && LSz != EndSz) {
    Diag(Loc: LHS.get()->getBeginLoc(),
         DiagID: diag::err_hlsl_vector_compound_assignment_truncation)
        << LHSType << RHSType;
    return QualType();
  }

  // Truncate whichever vector operand is wider than the result size ...
  if (RVecTy && RVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
    castVector<CK_HLSLVectorTruncation>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // ... and splat whichever operand is a scalar.
  if (!RVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: RHS, Ty&: RHSType, Sz: EndSz);
  if (!IsCompAssign && !LVecTy)
    castVector<CK_VectorSplat>(S&: SemaRef, E&: LHS, Ty&: LHSType, Sz: EndSz);

  // If we're at the same type after resizing we can stop here.
  if (Ctx.hasSameUnqualifiedType(T1: LHSType, T2: RHSType))
    return Ctx.getCommonSugaredType(X: LHSType, Y: RHSType);

  QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
  QualType RElTy = RHSType->castAs<VectorType>()->getElementType();

  // Handle conversion for floating point vectors.
  if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
    return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);

  assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
         "HLSL Vectors can only contain integer or floating point types");
  return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
                                            LElTy, RElTy, IsCompAssign);
}
1035
1036void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
1037 BinaryOperatorKind Opc) {
1038 assert((Opc == BO_LOr || Opc == BO_LAnd) &&
1039 "Called with non-logical operator");
1040 llvm::SmallVector<char, 256> Buff;
1041 llvm::raw_svector_ostream OS(Buff);
1042 PrintingPolicy PP(SemaRef.getLangOpts());
1043 StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
1044 OS << NewFnName << "(";
1045 LHS->printPretty(OS, Helper: nullptr, Policy: PP);
1046 OS << ", ";
1047 RHS->printPretty(OS, Helper: nullptr, Policy: PP);
1048 OS << ")";
1049 SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
1050 SemaRef.Diag(Loc: LHS->getBeginLoc(), DiagID: diag::note_function_suggestion)
1051 << NewFnName << FixItHint::CreateReplacement(RemoveRange: FullRange, Code: OS.str());
1052}
1053
1054std::pair<IdentifierInfo *, bool>
1055SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) {
1056 llvm::hash_code Hash = llvm::hash_value(S: Signature);
1057 std::string IdStr = "__hlsl_rootsig_decl_" + std::to_string(val: Hash);
1058 IdentifierInfo *DeclIdent = &(getASTContext().Idents.get(Name: IdStr));
1059
1060 // Check if we have already found a decl of the same name.
1061 LookupResult R(SemaRef, DeclIdent, SourceLocation(),
1062 Sema::LookupOrdinaryName);
1063 bool Found = SemaRef.LookupQualifiedName(R, LookupCtx: SemaRef.CurContext);
1064 return {DeclIdent, Found};
1065}
1066
1067void SemaHLSL::ActOnFinishRootSignatureDecl(
1068 SourceLocation Loc, IdentifierInfo *DeclIdent,
1069 ArrayRef<hlsl::RootSignatureElement> RootElements) {
1070
1071 if (handleRootSignatureElements(Elements: RootElements))
1072 return;
1073
1074 SmallVector<llvm::hlsl::rootsig::RootElement> Elements;
1075 for (auto &RootSigElement : RootElements)
1076 Elements.push_back(Elt: RootSigElement.getElement());
1077
1078 auto *SignatureDecl = HLSLRootSignatureDecl::Create(
1079 C&: SemaRef.getASTContext(), /*DeclContext=*/DC: SemaRef.CurContext, Loc,
1080 ID: DeclIdent, Version: SemaRef.getLangOpts().HLSLRootSigVer, RootElements: Elements);
1081
1082 SignatureDecl->setImplicit();
1083 SemaRef.PushOnScopeChains(D: SignatureDecl, S: SemaRef.getCurScope());
1084}
1085
// Validates the elements of a parsed root signature in two phases: basic
// per-element value/flag range checks, then a register-range overlap analysis
// across all elements. Returns true when validation must stop the declaration
// from being created.
//
// NOTE(review): HadError from the basic validations is not folded into the
// final return value — only overlap errors (and the NumDescriptors == 0
// shortcut) produce a `true` return. Confirm this is intentional.
bool SemaHLSL::handleRootSignatureElements(
    ArrayRef<hlsl::RootSignatureElement> Elements) {
  // Define some common error handling functions
  bool HadError = false;
  auto ReportError = [this, &HadError](SourceLocation Loc, uint32_t LowerBound,
                                       uint32_t UpperBound) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << LowerBound << UpperBound;
  };

  auto ReportFloatError = [this, &HadError](SourceLocation Loc,
                                            float LowerBound,
                                            float UpperBound) {
    HadError = true;
    // Format the float bounds into short fixed strings so the diagnostic
    // stays readable.
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_value)
        << llvm::formatv(Fmt: "{0:f}", Vals&: LowerBound).sstr<6>()
        << llvm::formatv(Fmt: "{0:f}", Vals&: UpperBound).sstr<6>();
  };

  // Register numbers are valid in [0, 0xfffffffe]; 0xffffffff is reserved.
  auto VerifyRegister = [ReportError](SourceLocation Loc, uint32_t Register) {
    if (!llvm::hlsl::rootsig::verifyRegisterValue(RegisterValue: Register))
      ReportError(Loc, 0, 0xfffffffe);
  };

  auto VerifySpace = [ReportError](SourceLocation Loc, uint32_t Space) {
    if (!llvm::hlsl::rootsig::verifyRegisterSpace(RegisterSpace: Space))
      ReportError(Loc, 0, 0xffffffef);
  };

  const uint32_t Version =
      llvm::to_underlying(E: SemaRef.getLangOpts().HLSLRootSigVer);
  // NOTE(review): assumes the root-signature version enum values are 1-based
  // so Version - 1 yields the 0-based "version minor" diagnostic selector —
  // confirm against the enum definition.
  const uint32_t VersionEnum = Version - 1;
  auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
    HadError = true;
    this->Diag(Loc, DiagID: diag::err_hlsl_invalid_rootsig_flag)
        << /*version minor*/ VersionEnum;
  };

  // Iterate through the elements and do basic validations
  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    SourceLocation Loc = RootSigElem.getLocation();
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      VerifyRegister(Loc, Descriptor->Reg.Number);
      VerifySpace(Loc, Descriptor->Space);

      if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
              Version, FlagsVal: llvm::to_underlying(E: Descriptor->Flags)))
        ReportFlagError(Loc);
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      VerifyRegister(Loc, Constants->Reg.Number);
      VerifySpace(Loc, Constants->Space);
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      VerifyRegister(Loc, Sampler->Reg.Number);
      VerifySpace(Loc, Sampler->Space);

      assert(!std::isnan(Sampler->MaxLOD) && !std::isnan(Sampler->MinLOD) &&
             "By construction, parseFloatParam can't produce a NaN from a "
             "float_literal token");

      if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(MaxAnisotropy: Sampler->MaxAnisotropy))
        ReportError(Loc, 0, 16);
      if (!llvm::hlsl::rootsig::verifyMipLODBias(MipLODBias: Sampler->MipLODBias))
        ReportFloatError(Loc, -16.f, 15.99);
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      VerifyRegister(Loc, Clause->Reg.Number);
      VerifySpace(Loc, Clause->Space);

      if (!llvm::hlsl::rootsig::verifyNumDescriptors(NumDescriptors: Clause->NumDescriptors)) {
        // NumDescriptor could technically be ~0u but that is reserved for
        // unbounded, so the diagnostic will not report that as a valid int
        // value
        ReportError(Loc, 1, 0xfffffffe);
      }

      if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
              Version, Type: llvm::to_underlying(E: Clause->Type),
              FlagsVal: llvm::to_underlying(E: Clause->Flags)))
        ReportFlagError(Loc);
    }
  }

  using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
  using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
  // Pair each range with the element it came from so overlaps can be
  // diagnosed at the offending element's source location.
  using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;

  // 1. Collect RangeInfos
  llvm::SmallVector<InfoPairT> InfoPairs;
  for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
    const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
    if (const auto *Descriptor =
            std::get_if<llvm::hlsl::rootsig::RootDescriptor>(ptr: &Elem)) {
      RangeInfo Info;
      Info.LowerBound = Descriptor->Reg.Number;
      Info.UpperBound = Info.LowerBound; // use inclusive ranges []

      Info.Class =
          llvm::dxil::ResourceClass(llvm::to_underlying(E: Descriptor->Type));
      Info.Space = Descriptor->Space;
      Info.Visibility = Descriptor->Visibility;

      InfoPairs.push_back(Elt: {Info, &RootSigElem});
    } else if (const auto *Constants =
                   std::get_if<llvm::hlsl::rootsig::RootConstants>(ptr: &Elem)) {
      RangeInfo Info;
      Info.LowerBound = Constants->Reg.Number;
      Info.UpperBound = Info.LowerBound; // use inclusive ranges []

      Info.Class = llvm::dxil::ResourceClass::CBuffer;
      Info.Space = Constants->Space;
      Info.Visibility = Constants->Visibility;

      InfoPairs.push_back(Elt: {Info, &RootSigElem});
    } else if (const auto *Sampler =
                   std::get_if<llvm::hlsl::rootsig::StaticSampler>(ptr: &Elem)) {
      RangeInfo Info;
      Info.LowerBound = Sampler->Reg.Number;
      Info.UpperBound = Info.LowerBound; // use inclusive ranges []

      Info.Class = llvm::dxil::ResourceClass::Sampler;
      Info.Space = Sampler->Space;
      Info.Visibility = Sampler->Visibility;

      InfoPairs.push_back(Elt: {Info, &RootSigElem});
    } else if (const auto *Clause =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
                       ptr: &Elem)) {
      RangeInfo Info;
      Info.LowerBound = Clause->Reg.Number;
      // Relevant error will have already been reported above and needs to be
      // fixed before we can conduct range analysis, so shortcut error return
      if (Clause->NumDescriptors == 0)
        return true;
      Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
                            ? RangeInfo::Unbounded
                            : Info.LowerBound + Clause->NumDescriptors -
                                  1; // use inclusive ranges []

      Info.Class = Clause->Type;
      Info.Space = Clause->Space;

      // Note: a Clause does not hold its own Visibility; it is filled in when
      // the owning DescriptorTable element is processed below.
      InfoPairs.push_back(Elt: {Info, &RootSigElem});
    } else if (const auto *Table =
                   std::get_if<llvm::hlsl::rootsig::DescriptorTable>(ptr: &Elem)) {
      // Table holds the Visibility of all owned Clauses in Table, so iterate
      // owned Clauses and update their corresponding RangeInfo
      assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
      // The last Table->NumClauses elements of Infos are the owned Clauses
      // generated RangeInfo
      auto TableInfos =
          MutableArrayRef<InfoPairT>(InfoPairs).take_back(N: Table->NumClauses);
      for (InfoPairT &Pair : TableInfos)
        Pair.first.Visibility = Table->Visibility;
    }
  }

  // 2. Sort with RangeInfo's operator< to prepare it for findOverlapping
  llvm::sort(C&: InfoPairs,
             Comp: [](InfoPairT A, InfoPairT B) { return A.first < B.first; });

  llvm::SmallVector<RangeInfo> Infos;
  for (const InfoPairT &Pair : InfoPairs)
    Infos.push_back(Elt: Pair.first);

  // Helpers to report diagnostics
  uint32_t DuplicateCounter = 0;
  using ElemPair = std::pair<const hlsl::RootSignatureElement *,
                             const hlsl::RootSignatureElement *>;
  // Map an OverlappingRanges result back to the source elements that produced
  // the two conflicting RangeInfos.
  auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
                         OverlappingRanges Overlap) -> ElemPair {
    // Given we sorted the InfoPairs (and by implication) Infos, and,
    // that Overlap.B is the item retrieved from the ResourceRange. Then it is
    // guaranteed that Overlap.B <= Overlap.A.
    //
    // So we will find Overlap.B first and then continue to find Overlap.A
    // after
    auto InfoB = std::lower_bound(first: Infos.begin(), last: Infos.end(), val: *Overlap.B);
    auto DistB = std::distance(first: Infos.begin(), last: InfoB);
    auto PairB = InfoPairs.begin();
    std::advance(i&: PairB, n: DistB);

    auto InfoA = std::lower_bound(first: InfoB, last: Infos.end(), val: *Overlap.A);
    // Similarly, from the property that we have sorted the RangeInfos,
    // all duplicates will be processed one after the other. So
    // DuplicateCounter can be re-used for each set of duplicates we
    // encounter as we handle incoming errors
    DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
    auto DistA = std::distance(first: InfoB, last: InfoA) + DuplicateCounter;
    auto PairA = PairB;
    std::advance(i&: PairA, n: DistA);

    return {PairA->second, PairB->second};
  };

  auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
    auto Pair = GetElemPair(Overlap);
    const RangeInfo *Info = Overlap.A;
    const hlsl::RootSignatureElement *Elem = Pair.first;
    const RangeInfo *OInfo = Overlap.B;

    // When one side is visible to all stages, report the more specific
    // visibility of the other side.
    auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
                         ? OInfo->Visibility
                         : Info->Visibility;
    this->Diag(Loc: Elem->getLocation(), DiagID: diag::err_hlsl_resource_range_overlap)
        << llvm::to_underlying(E: Info->Class) << Info->LowerBound
        << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
        << Info->UpperBound << llvm::to_underlying(E: OInfo->Class)
        << OInfo->LowerBound
        << /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
        << OInfo->UpperBound << Info->Space << CommonVis;

    const hlsl::RootSignatureElement *OElem = Pair.second;
    this->Diag(Loc: OElem->getLocation(), DiagID: diag::note_hlsl_resource_range_here);
  };

  // 3. Invoke find overlapping ranges
  llvm::SmallVector<OverlappingRanges> Overlaps =
      llvm::hlsl::rootsig::findOverlappingRanges(Infos);
  for (OverlappingRanges Overlap : Overlaps)
    ReportOverlap(Overlap);

  return Overlaps.size() != 0;
}
1316
1317void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
1318 if (AL.getNumArgs() != 1) {
1319 Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
1320 return;
1321 }
1322
1323 IdentifierInfo *Ident = AL.getArgAsIdent(Arg: 0)->getIdentifierInfo();
1324 if (auto *RS = D->getAttr<RootSignatureAttr>()) {
1325 if (RS->getSignatureIdent() != Ident) {
1326 Diag(Loc: AL.getLoc(), DiagID: diag::err_disallowed_duplicate_attribute) << RS;
1327 return;
1328 }
1329
1330 Diag(Loc: AL.getLoc(), DiagID: diag::warn_duplicate_attribute_exact) << RS;
1331 return;
1332 }
1333
1334 LookupResult R(SemaRef, Ident, SourceLocation(), Sema::LookupOrdinaryName);
1335 if (SemaRef.LookupQualifiedName(R, LookupCtx: D->getDeclContext()))
1336 if (auto *SignatureDecl =
1337 dyn_cast<HLSLRootSignatureDecl>(Val: R.getFoundDecl())) {
1338 D->addAttr(A: ::new (getASTContext()) RootSignatureAttr(
1339 getASTContext(), AL, Ident, SignatureDecl));
1340 }
1341}
1342
// Validates [numthreads(X, Y, Z)]: each axis must satisfy the per-axis limit
// and the product must not exceed the per-group thread limit. When targeting
// DXIL the limits depend on the shader-model major version encoded in the
// triple's OS version.
void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
  llvm::VersionTuple SMVersion =
      getASTContext().getTargetInfo().getTriple().getOSVersion();
  bool IsDXIL = getASTContext().getTargetInfo().getTriple().getArch() ==
                llvm::Triple::dxil;

  // Defaults for SM 6+ and non-DXIL targets; tightened below for older SMs.
  uint32_t ZMax = 1024;
  uint32_t ThreadMax = 1024;
  if (IsDXIL && SMVersion.getMajor() <= 4) {
    ZMax = 1;
    ThreadMax = 768;
  } else if (IsDXIL && SMVersion.getMajor() == 5) {
    ZMax = 64;
    ThreadMax = 1024;
  }

  // X and Y are limited to 1024 on every shader model; only Z's limit varies.
  uint32_t X;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: X))
    return;
  if (X > 1024) {
    Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
         DiagID: diag::err_hlsl_numthreads_argument_oor)
        << 0 << 1024;
    return;
  }
  uint32_t Y;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Y))
    return;
  if (Y > 1024) {
    Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
         DiagID: diag::err_hlsl_numthreads_argument_oor)
        << 1 << 1024;
    return;
  }
  uint32_t Z;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Z))
    return;
  if (Z > ZMax) {
    SemaRef.Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
                 DiagID: diag::err_hlsl_numthreads_argument_oor)
        << 2 << ZMax;
    return;
  }

  // Each factor is <= 1024 (2^10), so the product fits in uint32_t without
  // overflow.
  if (X * Y * Z > ThreadMax) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_numthreads_invalid) << ThreadMax;
    return;
  }

  HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1396
// Returns true when Value is a power of two in the inclusive range [4, 128] —
// the set of wave sizes accepted by the [WaveSize] attribute.
static bool isValidWaveSizeValue(unsigned Value) {
  if (Value < 4 || Value > 128)
    return false;
  // Power-of-two test; Value is known non-zero here.
  return (Value & (Value - 1)) == 0;
}
1400
// Validates [WaveSize(min[, max[, preferred]])]. Each spelled argument must
// be a power of two in [4, 128]; with two or more arguments, max must exceed
// min, and with three, preferred must lie within [min, max].
void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
  // validate that the wavesize argument is a power of 2 between 4 and 128
  // inclusive
  unsigned SpelledArgsCount = AL.getNumArgs();
  if (SpelledArgsCount == 0 || SpelledArgsCount > 3)
    return;

  uint32_t Min;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Min))
    return;

  // Max and Preferred default to 0 when not spelled.
  uint32_t Max = 0;
  if (SpelledArgsCount > 1 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Max))
    return;

  uint32_t Preferred = 0;
  if (SpelledArgsCount > 2 &&
      !SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 2), Val&: Preferred))
    return;

  if (SpelledArgsCount > 2) {
    // Three-argument form: Preferred must be a valid wave size and fall
    // within the [Min, Max] range.
    if (!isValidWaveSizeValue(Value: Preferred)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize
          << Preferred;
      return;
    }
    // Preferred not in range.
    if (Preferred < Min || Preferred > Max) {
      Diag(Loc: AL.getArgAsExpr(Arg: 2)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << Min << Max << Preferred;
      return;
    }
  } else if (SpelledArgsCount > 1) {
    // Two-argument form: Max must be a valid wave size strictly greater than
    // Min (equal values only warn).
    if (!isValidWaveSizeValue(Value: Max)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 1)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Max;
      return;
    }
    if (Max < Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_invalid) << AL << 1;
      return;
    } else if (Max == Min) {
      Diag(Loc: AL.getLoc(), DiagID: diag::warn_attr_min_eq_max) << AL;
    }
  } else {
    // Single-argument form: only Min needs to be a valid wave size.
    if (!isValidWaveSizeValue(Value: Min)) {
      Diag(Loc: AL.getArgAsExpr(Arg: 0)->getExprLoc(),
           DiagID: diag::err_attribute_power_of_two_in_range)
          << AL << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize << Min;
      return;
    }
  }

  HLSLWaveSizeAttr *NewAttr =
      mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount);
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
1464
1465void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
1466 uint32_t ID;
1467 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: ID))
1468 return;
1469 D->addAttr(A: ::new (getASTContext())
1470 HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
1471}
1472
1473void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1474 uint32_t Id;
1475 if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: Id))
1476 return;
1477 HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1478 if (NewAttr)
1479 D->addAttr(A: NewAttr);
1480}
1481
1482bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
1483 const auto *VT = T->getAs<VectorType>();
1484
1485 if (!T->hasUnsignedIntegerRepresentation() ||
1486 (VT && VT->getNumElements() > 3)) {
1487 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1488 << AL << "uint/uint2/uint3";
1489 return false;
1490 }
1491
1492 return true;
1493}
1494
1495void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1496 auto *VD = cast<ValueDecl>(Val: D);
1497 if (!diagnoseInputIDType(T: VD->getType(), AL))
1498 return;
1499
1500 D->addAttr(A: ::new (getASTContext())
1501 HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
1502}
1503
1504bool SemaHLSL::diagnosePositionType(QualType T, const ParsedAttr &AL) {
1505 const auto *VT = T->getAs<VectorType>();
1506
1507 if (!T->hasFloatingRepresentation() || (VT && VT->getNumElements() > 4)) {
1508 Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_type)
1509 << AL << "float/float1/float2/float3/float4";
1510 return false;
1511 }
1512
1513 return true;
1514}
1515
1516void SemaHLSL::handleSV_PositionAttr(Decl *D, const ParsedAttr &AL) {
1517 auto *VD = cast<ValueDecl>(Val: D);
1518 if (!diagnosePositionType(T: VD->getType(), AL))
1519 return;
1520
1521 D->addAttr(A: ::new (getASTContext()) HLSLSV_PositionAttr(getASTContext(), AL));
1522}
1523
1524void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) {
1525 auto *VD = cast<ValueDecl>(Val: D);
1526 if (!diagnoseInputIDType(T: VD->getType(), AL))
1527 return;
1528
1529 D->addAttr(A: ::new (getASTContext())
1530 HLSLSV_GroupThreadIDAttr(getASTContext(), AL));
1531}
1532
1533void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
1534 auto *VD = cast<ValueDecl>(Val: D);
1535 if (!diagnoseInputIDType(T: VD->getType(), AL))
1536 return;
1537
1538 D->addAttr(A: ::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
1539}
1540
// Validates packoffset(cN.c) on a shader constant inside a cbuffer. The
// subcomponent/component arguments must be uint32 constants, and a non-zero
// component offset must keep the value inside one 128-bit register with
// compatible alignment.
void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
  // packoffset only applies to variables declared directly in a cbuffer.
  if (!isa<VarDecl>(Val: D) || !isa<HLSLBufferDecl>(Val: D->getDeclContext())) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attr_invalid_ast_node)
        << AL << "shader constant in a constant buffer";
    return;
  }

  uint32_t SubComponent;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 0), Val&: SubComponent))
    return;
  uint32_t Component;
  if (!SemaRef.checkUInt32Argument(AI: AL, Expr: AL.getArgAsExpr(Arg: 1), Val&: Component))
    return;

  QualType T = cast<VarDecl>(Val: D)->getType().getCanonicalType();
  // Check if T is an array or struct type.
  // TODO: mark matrix type as aggregate type.
  bool IsAggregateTy = (T->isArrayType() || T->isStructureType());

  // Check Component is valid for T.
  if (Component) {
    unsigned Size = getASTContext().getTypeSize(T);
    // Aggregates and anything larger than one 128-bit register must start at
    // component 0.
    if (IsAggregateTy || Size > 128) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
      return;
    } else {
      // Make sure Component + sizeof(T) <= 4.
      // Each component is 32 bits; the whole value must fit within the
      // 128-bit register starting at the given component.
      if ((Component * 32 + Size) > 128) {
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_cross_reg_boundary);
        return;
      }
      QualType EltTy = T;
      if (const auto *VT = T->getAs<VectorType>())
        EltTy = VT->getElementType();
      unsigned Align = getASTContext().getTypeAlign(T: EltTy);
      // 64-bit elements must start on an even component boundary.
      if (Align > 32 && Component == 1) {
        // NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.
        // So we only need to check Component 1 here.
        Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_packoffset_alignment_mismatch)
            << Align << EltTy;
        return;
      }
    }
  }

  D->addAttr(A: ::new (getASTContext()) HLSLPackOffsetAttr(
      getASTContext(), AL, SubComponent, Component));
}
1589
1590void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
1591 StringRef Str;
1592 SourceLocation ArgLoc;
1593 if (!SemaRef.checkStringLiteralArgumentAttr(Attr: AL, ArgNum: 0, Str, ArgLocation: &ArgLoc))
1594 return;
1595
1596 llvm::Triple::EnvironmentType ShaderType;
1597 if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Val: Str, Out&: ShaderType)) {
1598 Diag(Loc: AL.getLoc(), DiagID: diag::warn_attribute_type_not_supported)
1599 << AL << Str << ArgLoc;
1600 return;
1601 }
1602
1603 // FIXME: check function match the shader stage.
1604
1605 HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);
1606 if (NewAttr)
1607 D->addAttr(A: NewAttr);
1608}
1609
// Merges a list of HLSL resource type attributes into a single
// HLSLAttributedResourceType wrapping \p Wrapped. Duplicate attributes are
// diagnosed (warning) and abort the merge; a missing resource-class attribute
// is an error. On success stores the new type in \p ResType, optionally fills
// \p LocInfo with the attributes' source range and contained-type info, and
// returns true.
bool clang::CreateHLSLAttributedResourceType(
    Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList,
    QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo) {
  assert(AttrList.size() && "expected list of resource attributes");

  QualType ContainedTy = QualType();
  TypeSourceInfo *ContainedTyInfo = nullptr;
  // Track the overall source range spanned by the attribute list.
  SourceLocation LocBegin = AttrList[0]->getRange().getBegin();
  SourceLocation LocEnd = AttrList[0]->getRange().getEnd();

  HLSLAttributedResourceType::Attributes ResAttrs;

  bool HasResourceClass = false;
  for (const Attr *A : AttrList) {
    if (!A)
      continue;
    LocEnd = A->getRange().getEnd();
    switch (A->getKind()) {
    case attr::HLSLResourceClass: {
      ResourceClass RC = cast<HLSLResourceClassAttr>(Val: A)->getResourceClass();
      if (HasResourceClass) {
        // Duplicate resource class: exact repeats get the "exact" warning,
        // conflicting ones the generic duplicate warning.
        S.Diag(Loc: A->getLocation(), DiagID: ResAttrs.ResourceClass == RC
                                       ? diag::warn_duplicate_attribute_exact
                                       : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ResAttrs.ResourceClass = RC;
      HasResourceClass = true;
      break;
    }
    case attr::HLSLROV:
      if (ResAttrs.IsROV) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.IsROV = true;
      break;
    case attr::HLSLRawBuffer:
      if (ResAttrs.RawBuffer) {
        S.Diag(Loc: A->getLocation(), DiagID: diag::warn_duplicate_attribute_exact) << A;
        return false;
      }
      ResAttrs.RawBuffer = true;
      break;
    case attr::HLSLContainedType: {
      const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(Val: A);
      QualType Ty = CTAttr->getType();
      if (!ContainedTy.isNull()) {
        S.Diag(Loc: A->getLocation(), DiagID: ContainedTy == Ty
                                       ? diag::warn_duplicate_attribute_exact
                                       : diag::warn_duplicate_attribute)
            << A;
        return false;
      }
      ContainedTy = Ty;
      ContainedTyInfo = CTAttr->getTypeLoc();
      break;
    }
    default:
      llvm_unreachable("unhandled resource attribute type");
    }
  }

  // A resource class is mandatory; diagnose at the end of the attribute list.
  if (!HasResourceClass) {
    S.Diag(Loc: AttrList.back()->getRange().getEnd(),
           DiagID: diag::err_hlsl_missing_resource_class);
    return false;
  }

  ResType = S.getASTContext().getHLSLAttributedResourceType(
      Wrapped, Contained: ContainedTy, Attrs: ResAttrs);

  // Location info is only useful when a contained type was specified.
  if (LocInfo && ContainedTyInfo) {
    LocInfo->Range = SourceRange(LocBegin, LocEnd);
    LocInfo->ContainedTyInfo = ContainedTyInfo;
  }
  return true;
}
1689
1690// Validates and creates an HLSL attribute that is applied as type attribute on
1691// HLSL resource. The attributes are collected in HLSLResourcesTypeAttrs and at
1692// the end of the declaration they are applied to the declaration type by
1693// wrapping it in HLSLAttributedResourceType.
bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
  // Only allow resource type attributes on intangible (HLSL resource) types.
  if (!T->isHLSLResourceType()) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_hlsl_attribute_needs_intangible_type)
        << AL << getASTContext().HLSLResourceTy;
    return false;
  }

  // Validate number of arguments.
  if (!AL.checkExactlyNumArgs(S&: SemaRef, Num: AL.getMinArgs()))
    return false;

  // The semantic attribute created for the parsed attribute; set by the
  // switch below and appended to HLSLResourcesTypeAttrs on success.
  Attr *A = nullptr;

  AttributeCommonInfo ACI(
      AL.getLoc(), AttributeScopeInfo(AL.getScopeName(), AL.getScopeLoc()),
      AttributeCommonInfo::NoSemaHandlerAttribute,
      {
          AttributeCommonInfo::AS_CXX11, 0, false /*IsAlignas*/,
          false /*IsRegularKeywordAttribute*/
      });

  switch (AL.getKind()) {
  case ParsedAttr::AT_HLSLResourceClass: {
    // [[hlsl::resource_class(...)]] takes a single identifier naming the
    // resource class (SRV, UAV, CBuffer, Sampler).
    if (!AL.isArgIdent(Arg: 0)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return false;
    }

    IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);
    StringRef Identifier = Loc->getIdentifierInfo()->getName();
    SourceLocation ArgLoc = Loc->getLoc();

    // Validate resource class value
    ResourceClass RC;
    if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Val: Identifier, Out&: RC)) {
      Diag(Loc: ArgLoc, DiagID: diag::warn_attribute_type_not_supported)
          << "ResourceClass" << Identifier;
      return false;
    }
    A = HLSLResourceClassAttr::Create(Ctx&: getASTContext(), ResourceClass: RC, CommonInfo: ACI);
    break;
  }

  case ParsedAttr::AT_HLSLROV:
    A = HLSLROVAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLRawBuffer:
    A = HLSLRawBufferAttr::Create(Ctx&: getASTContext(), CommonInfo: ACI);
    break;

  case ParsedAttr::AT_HLSLContainedType: {
    // [[hlsl::contained_type(T)]] takes exactly one type argument, which must
    // be a complete type.
    if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_wrong_number_arguments) << AL << 1;
      return false;
    }

    TypeSourceInfo *TSI = nullptr;
    QualType QT = SemaRef.GetTypeFromParser(Ty: AL.getTypeArg(), TInfo: &TSI);
    assert(TSI && "no type source info for attribute argument");
    if (SemaRef.RequireCompleteType(Loc: TSI->getTypeLoc().getBeginLoc(), T: QT,
                                    DiagID: diag::err_incomplete_type))
      return false;
    A = HLSLContainedTypeAttr::Create(Ctx&: getASTContext(), Type: TSI, CommonInfo: ACI);
    break;
  }

  default:
    llvm_unreachable("unhandled HLSL attribute");
  }

  // Collect the attribute; it is applied to the declaration type at the end
  // of the declaration (see ProcessResourceTypeAttributes).
  HLSLResourcesTypeAttrs.emplace_back(Args&: A);
  return true;
}
1770
1771// Combines all resource type attributes and creates HLSLAttributedResourceType.
// Combines all collected resource type attributes and creates an
// HLSLAttributedResourceType wrapping CurrentType. Returns CurrentType
// unchanged when no attributes were collected or type creation failed.
// Always clears the collected attribute list.
QualType SemaHLSL::ProcessResourceTypeAttributes(QualType CurrentType) {
  if (!HLSLResourcesTypeAttrs.size())
    return CurrentType;

  QualType QT = CurrentType;
  HLSLAttributedResourceLocInfo LocInfo;
  if (CreateHLSLAttributedResourceType(S&: SemaRef, Wrapped: CurrentType,
                                       AttrList: HLSLResourcesTypeAttrs, ResType&: QT, LocInfo: &LocInfo)) {
    const HLSLAttributedResourceType *RT =
        cast<HLSLAttributedResourceType>(Val: QT.getTypePtr());

    // Temporarily store TypeLoc information for the new type.
    // It will be transferred to HLSLAttributesResourceTypeLoc
    // shortly after the type is created by TypeSpecLocFiller which
    // will call the TakeLocForHLSLAttribute method below.
    LocsForHLSLAttributedResources.insert(KV: std::pair(RT, LocInfo));
  }
  HLSLResourcesTypeAttrs.clear();
  return QT;
}
1792
1793// Returns source location for the HLSLAttributedResourceType
1794HLSLAttributedResourceLocInfo
1795SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
1796 HLSLAttributedResourceLocInfo LocInfo = {};
1797 auto I = LocsForHLSLAttributedResources.find(Val: RT);
1798 if (I != LocsForHLSLAttributedResources.end()) {
1799 LocInfo = I->second;
1800 LocsForHLSLAttributedResources.erase(I);
1801 return LocInfo;
1802 }
1803 LocInfo.Range = SourceRange();
1804 return LocInfo;
1805}
1806
1807// Walks though the global variable declaration, collects all resource binding
1808// requirements and adds them to Bindings
// Walks the fields of a user-defined record type used by global variable VD,
// collects all resource binding requirements and adds them to Bindings.
// Embedded records are scanned recursively.
void SemaHLSL::collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
                                                       const RecordType *RT) {
  const RecordDecl *RD = RT->getDecl();
  for (FieldDecl *FD : RD->fields()) {
    const Type *Ty = FD->getType()->getUnqualifiedDesugaredType();

    // Unwrap arrays
    // FIXME: Calculate array size while unwrapping
    assert(!Ty->isIncompleteArrayType() &&
           "incomplete arrays inside user defined types are not supported");
    while (Ty->isConstantArrayType()) {
      const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
      Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
    }

    // Only record types can be resources or contain nested resources.
    if (!Ty->isRecordType())
      continue;

    if (const HLSLAttributedResourceType *AttrResType =
            HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
      // Add a new DeclBindingInfo to Bindings if it does not already exist
      ResourceClass RC = AttrResType->getAttrs().ResourceClass;
      DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, ResClass: RC);
      if (!DBI)
        Bindings.addDeclBindingInfo(VD, ResClass: RC);
    } else if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty)) {
      // Recursively scan embedded struct or class; it would be nice to do this
      // without recursion, but tricky to correctly calculate the size of the
      // binding, which is something we are probably going to need to do later
      // on. Hopefully nesting of structs in structs too many levels is
      // unlikely.
      collectResourceBindingsOnUserRecordDecl(VD, RT);
    }
  }
}
1844
1845// Diagnose localized register binding errors for a single binding; does not
1846// diagnose resource binding on user record types, that will be done later
1847// in processResourceBindingOnDecl based on the information collected in
1848// collectResourceBindingsOnVarDecl.
1849// Returns false if the register binding is not valid.
static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc,
                                         Decl *D, RegisterType RegType,
                                         bool SpecifiedSpace) {
  // Numeric register type used to select the diagnostic wording.
  int RegTypeNum = static_cast<int>(RegType);

  // check if the decl type is groupshared
  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
    S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    return false;
  }

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: D)) {
    // cbuffer binds with register(b#), tbuffer with register(t#) (SRV).
    ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
                                                     : ResourceClass::SRV;
    if (RegType == getRegisterType(RC))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Samplers, UAVs, and SRVs are VarDecl types
  assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
  VarDecl *VD = cast<VarDecl>(Val: D);

  // Resource
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(
              RT: VD->getType().getTypePtr())) {
    // The register type must match the declared resource class.
    if (RegType == getRegisterType(RC: AttrResType->getAttrs().ResourceClass))
      return true;

    S.Diag(Loc: D->getLocation(), DiagID: diag::err_hlsl_binding_type_mismatch)
        << RegTypeNum;
    return false;
  }

  // Unwrap arrays to the element type before classifying the variable.
  const clang::Type *Ty = VD->getType().getTypePtr();
  while (Ty->isArrayType())
    Ty = Ty->getArrayElementTypeNoTypeQual();

  // Basic types
  if (Ty->isArithmeticType() || Ty->isVectorType()) {
    bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(Val: D->getDeclContext());
    // Space is only meaningful on resources or inside cbuffer/tbuffer.
    if (SpecifiedSpace && !DeclaredInCOrTBuffer)
      S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_space_on_global_constant);

    if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(Ctx: S.getASTContext()) ||
                                  Ty->isFloatingType() || Ty->isVectorType())) {
      // Register annotation on default constant buffer declaration ($Globals)
      if (RegType == RegisterType::CBuffer)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_deprecated_register_type_b);
      else if (RegType != RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
      else
        return true;
    } else {
      // Inside a cbuffer/tbuffer, register(c#) behaves like packoffset and is
      // only warned about; any other register type is a mismatch.
      if (RegType == RegisterType::C)
        S.Diag(Loc: ArgLoc, DiagID: diag::warn_hlsl_register_type_c_packoffset);
      else
        S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
    }
    return false;
  }
  if (Ty->isRecordType())
    // RecordTypes will be diagnosed in processResourceBindingOnDecl
    // that is called from ActOnVariableDeclarator
    return true;

  // Anything else is an error
  S.Diag(Loc: ArgLoc, DiagID: diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
  return false;
}
1925
1926static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
1927 RegisterType regType) {
1928 // make sure that there are no two register annotations
1929 // applied to the decl with the same register type
1930 bool RegisterTypesDetected[5] = {false};
1931 RegisterTypesDetected[static_cast<int>(regType)] = true;
1932
1933 for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) {
1934 if (HLSLResourceBindingAttr *attr =
1935 dyn_cast<HLSLResourceBindingAttr>(Val: *it)) {
1936
1937 RegisterType otherRegType = attr->getRegisterType();
1938 if (RegisterTypesDetected[static_cast<int>(otherRegType)]) {
1939 int otherRegTypeNum = static_cast<int>(otherRegType);
1940 S.Diag(Loc: TheDecl->getLocation(),
1941 DiagID: diag::err_hlsl_duplicate_register_annotation)
1942 << otherRegTypeNum;
1943 return false;
1944 }
1945 RegisterTypesDetected[static_cast<int>(otherRegType)] = true;
1946 }
1947 }
1948 return true;
1949}
1950
1951static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
1952 Decl *D, RegisterType RegType,
1953 bool SpecifiedSpace) {
1954
1955 // exactly one of these two types should be set
1956 assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
1957 (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
1958 "expecting VarDecl or HLSLBufferDecl");
1959
1960 // check if the declaration contains resource matching the register type
1961 if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace))
1962 return false;
1963
1964 // next, if multiple register annotations exist, check that none conflict.
1965 return ValidateMultipleRegisterAnnotations(S, TheDecl: D, regType: RegType);
1966}
1967
// Handles the HLSL `register(slot, space)` annotation: parses the slot and
// space identifiers, validates them against the declaration, and attaches an
// HLSLResourceBindingAttr on success.
void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
  if (isa<VarDecl>(Val: TheDecl)) {
    if (SemaRef.RequireCompleteType(Loc: TheDecl->getBeginLoc(),
                                    T: cast<ValueDecl>(Val: TheDecl)->getType(),
                                    DiagID: diag::err_incomplete_type))
      return;
  }

  StringRef Slot = "";
  StringRef Space = "";
  SourceLocation SlotLoc, SpaceLoc;

  if (!AL.isArgIdent(Arg: 0)) {
    Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
        << AL << AANT_ArgumentIdentifier;
    return;
  }
  IdentifierLoc *Loc = AL.getArgAsIdent(Arg: 0);

  if (AL.getNumArgs() == 2) {
    // Two arguments: register(slot, space).
    Slot = Loc->getIdentifierInfo()->getName();
    SlotLoc = Loc->getLoc();
    if (!AL.isArgIdent(Arg: 1)) {
      Diag(Loc: AL.getLoc(), DiagID: diag::err_attribute_argument_type)
          << AL << AANT_ArgumentIdentifier;
      return;
    }
    Loc = AL.getArgAsIdent(Arg: 1);
    Space = Loc->getIdentifierInfo()->getName();
    SpaceLoc = Loc->getLoc();
  } else {
    // One argument: either register(slot) with implicit space0, or
    // register(spaceN) with no slot.
    StringRef Str = Loc->getIdentifierInfo()->getName();
    if (Str.starts_with(Prefix: "space")) {
      Space = Str;
      SpaceLoc = Loc->getLoc();
    } else {
      Slot = Str;
      SlotLoc = Loc->getLoc();
      Space = "space0";
    }
  }

  RegisterType RegType = RegisterType::SRV;
  std::optional<unsigned> SlotNum;
  unsigned SpaceNum = 0;

  // Validate slot
  if (!Slot.empty()) {
    // The first character selects the register type (t/u/b/s/c/i).
    if (!convertToRegisterType(Slot, RT: &RegType)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_binding_type_invalid) << Slot.substr(Start: 0, N: 1);
      return;
    }
    // register(i#) is deprecated and ignored; warn and bail out.
    if (RegType == RegisterType::I) {
      Diag(Loc: SlotLoc, DiagID: diag::warn_hlsl_deprecated_register_type_i);
      return;
    }
    StringRef SlotNumStr = Slot.substr(Start: 1);
    unsigned N;
    if (SlotNumStr.getAsInteger(Radix: 10, Result&: N)) {
      Diag(Loc: SlotLoc, DiagID: diag::err_hlsl_unsupported_register_number);
      return;
    }
    SlotNum = N;
  }

  // Validate space
  if (!Space.starts_with(Prefix: "space")) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }
  StringRef SpaceNumStr = Space.substr(Start: 5);
  if (SpaceNumStr.getAsInteger(Radix: 10, Result&: SpaceNum)) {
    Diag(Loc: SpaceLoc, DiagID: diag::err_hlsl_expected_space) << Space;
    return;
  }

  // If we have slot, diagnose it is the right register type for the decl
  if (SlotNum.has_value())
    if (!DiagnoseHLSLRegisterAttribute(S&: SemaRef, ArgLoc&: SlotLoc, D: TheDecl, RegType,
                                       SpecifiedSpace: !SpaceLoc.isInvalid()))
      return;

  HLSLResourceBindingAttr *NewAttr =
      HLSLResourceBindingAttr::Create(Ctx&: getASTContext(), Slot, Space, CommonInfo: AL);
  if (NewAttr) {
    NewAttr->setBinding(RT: RegType, SlotNum, SpaceNum);
    TheDecl->addAttr(A: NewAttr);
  }
}
2057
// Handles the HLSL parameter modifier attribute (in/out/inout) by merging it
// with any existing modifier on the declaration and attaching the result.
void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {
  HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(
      D, AL,
      Spelling: static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
  if (NewAttr)
    D->addAttr(A: NewAttr);
}
2065
2066namespace {
2067
2068/// This class implements HLSL availability diagnostics for default
2069/// and relaxed mode
2070///
2071/// The goal of this diagnostic is to emit an error or warning when an
2072/// unavailable API is found in code that is reachable from the shader
2073/// entry function or from an exported function (when compiling a shader
2074/// library).
2075///
2076/// This is done by traversing the AST of all shader entry point functions
2077/// and of all exported functions, and any functions that are referenced
2078/// from this AST. In other words, any functions that are reachable from
2079/// the entry points.
class DiagnoseHLSLAvailability : public DynamicRecursiveASTVisitor {
  Sema &SemaRef;

  // Stack of functions to be scanned
  llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;

  // Tracks which environments functions have been scanned in.
  //
  // Maps FunctionDecl to an unsigned number that represents the set of shader
  // environments the function has been scanned for.
  // The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
  // to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
  // (verified by static_asserts in Triple.cpp), we can use it to index
  // individual bits in the set, as long as we shift the values to start with 0
  // by subtracting the value of llvm::Triple::Pixel first.
  //
  // The N'th bit in the set will be set if the function has been scanned
  // in shader environment whose llvm::Triple::EnvironmentType integer value
  // equals (llvm::Triple::Pixel + N).
  //
  // For example, if a function has been scanned in compute and pixel stage
  // environment, the value will be 0x21 (100001 binary) because:
  //
  //   (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
  //   (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
  //
  // A FunctionDecl is mapped to 0 (or not included in the map) if it has not
  // been scanned in any environment.
  llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;

  // Do not access these directly, use the get/set methods below to make
  // sure the values are in sync
  llvm::Triple::EnvironmentType CurrentShaderEnvironment;
  unsigned CurrentShaderStageBit;

  // True if scanning a function that was already scanned in a different
  // shader stage context, and therefore we should not report issues that
  // depend only on shader model version because they would be duplicate.
  bool ReportOnlyShaderStageIssues;

  // Helper methods for dealing with current stage context / environment
  void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
    static_assert(sizeof(unsigned) >= 4);
    assert(HLSLShaderAttr::isValidShaderType(ShaderType));
    assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
           "ShaderType is too big for this bitmap"); // 31 is reserved for
                                                     // "unknown"

    unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
    CurrentShaderEnvironment = ShaderType;
    CurrentShaderStageBit = (1 << bitmapIndex);
  }

  // Enters the "unknown stage" context used for exported library functions;
  // bit 31 is reserved for it.
  void SetUnknownShaderStageContext() {
    CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;
    CurrentShaderStageBit = (1 << 31);
  }

  llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {
    return CurrentShaderEnvironment;
  }

  bool InUnknownShaderStageContext() const {
    return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;
  }

  // Helper methods for dealing with shader stage bitmap

  // Marks FD as scanned in the current shader stage.
  void AddToScannedFunctions(const FunctionDecl *FD) {
    unsigned &ScannedStages = ScannedDecls[FD];
    ScannedStages |= CurrentShaderStageBit;
  }

  unsigned GetScannedStages(const FunctionDecl *FD) { return ScannedDecls[FD]; }

  bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {
    return WasAlreadyScannedInCurrentStage(ScannerStages: GetScannedStages(FD));
  }

  bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {
    return ScannerStages & CurrentShaderStageBit;
  }

  static bool NeverBeenScanned(unsigned ScannedStages) {
    return ScannedStages == 0;
  }

  // Scanning methods
  void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);
  void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,
                             SourceRange Range);
  const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);
  bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);

public:
  DiagnoseHLSLAvailability(Sema &SemaRef)
      : SemaRef(SemaRef),
        CurrentShaderEnvironment(llvm::Triple::UnknownEnvironment),
        CurrentShaderStageBit(0), ReportOnlyShaderStageIssues(false) {}

  // AST traversal methods
  void RunOnTranslationUnit(const TranslationUnitDecl *TU);
  void RunOnFunction(const FunctionDecl *FD);

  // Follow references to functions found in expressions.
  bool VisitDeclRefExpr(DeclRefExpr *DRE) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: DRE->getDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: DRE);
    return true;
  }

  bool VisitMemberExpr(MemberExpr *ME) override {
    FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: ME->getMemberDecl());
    if (FD)
      HandleFunctionOrMethodRef(FD, RefExpr: ME);
    return true;
  }
};
2197
// Processes a reference to a function or method found during AST traversal.
// Referenced functions with a body are queued for scanning; declarations
// without a body get their availability checked at the reference site.
void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,
                                                         Expr *RefExpr) {
  assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&
         "expected DeclRefExpr or MemberExpr");

  // has a definition -> add to stack to be scanned
  const FunctionDecl *FDWithBody = nullptr;
  if (FD->hasBody(Definition&: FDWithBody)) {
    if (!WasAlreadyScannedInCurrentStage(FD: FDWithBody))
      DeclsToScan.push_back(Elt: FDWithBody);
    return;
  }

  // no body -> diagnose availability
  const AvailabilityAttr *AA = FindAvailabilityAttr(D: FD);
  if (AA)
    CheckDeclAvailability(
        D: FD, AA, Range: SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));
}
2217
void DiagnoseHLSLAvailability::RunOnTranslationUnit(
    const TranslationUnitDecl *TU) {

  // Iterate over all shader entry functions and library exports, and for those
  // that have a body (definition), run diag scan on each, setting appropriate
  // shader environment context based on whether it is a shader entry function
  // or an exported function. Exported functions can be in namespaces and in
  // export declarations so we need to scan those declaration contexts as well.
  llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;
  DeclContextsToScan.push_back(Elt: TU);

  while (!DeclContextsToScan.empty()) {
    const DeclContext *DC = DeclContextsToScan.pop_back_val();
    for (auto &D : DC->decls()) {
      // do not scan implicit declaration generated by the implementation
      if (D->isImplicit())
        continue;

      // for namespace or export declaration add the context to the list to be
      // scanned later
      if (llvm::dyn_cast<NamespaceDecl>(Val: D) || llvm::dyn_cast<ExportDecl>(Val: D)) {
        DeclContextsToScan.push_back(Elt: llvm::dyn_cast<DeclContext>(Val: D));
        continue;
      }

      // skip over other decls or function decls without body
      const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(Val: D);
      if (!FD || !FD->isThisDeclarationADefinition())
        continue;

      // shader entry point
      if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {
        SetShaderStageContext(ShaderAttr->getType());
        RunOnFunction(FD);
        continue;
      }
      // exported library function
      // FIXME: replace this loop with external linkage check once issue #92071
      // is resolved
      bool isExport = FD->isInExportDeclContext();
      if (!isExport) {
        // Any redeclaration in an export context makes the function exported.
        for (const auto *Redecl : FD->redecls()) {
          if (Redecl->isInExportDeclContext()) {
            isExport = true;
            break;
          }
        }
      }
      if (isExport) {
        // Exported functions are scanned in the "unknown" stage context.
        SetUnknownShaderStageContext();
        RunOnFunction(FD);
        continue;
      }
    }
  }
}
2274
// Scans a single entry function and, transitively, everything reachable from
// it in the current shader stage context.
void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {
  assert(DeclsToScan.empty() && "DeclsToScan should be empty");
  DeclsToScan.push_back(Elt: FD);

  while (!DeclsToScan.empty()) {
    // Take one decl from the stack and check it by traversing its AST.
    // For any CallExpr found during the traversal add it's callee to the top of
    // the stack to be processed next. Functions already processed are stored in
    // ScannedDecls.
    const FunctionDecl *FD = DeclsToScan.pop_back_val();

    // Decl was already scanned
    const unsigned ScannedStages = GetScannedStages(FD);
    if (WasAlreadyScannedInCurrentStage(ScannerStages: ScannedStages))
      continue;

    // Suppress stage-independent diagnostics for functions already scanned
    // in some other stage; those were reported on the first scan.
    ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);

    AddToScannedFunctions(FD);
    TraverseStmt(S: FD->getBody());
  }
}
2297
2298bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(
2299 const AvailabilityAttr *AA) {
2300 IdentifierInfo *IIEnvironment = AA->getEnvironment();
2301 if (!IIEnvironment)
2302 return true;
2303
2304 llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();
2305 if (CurrentEnv == llvm::Triple::UnknownEnvironment)
2306 return false;
2307
2308 llvm::Triple::EnvironmentType AttrEnv =
2309 AvailabilityAttr::getEnvironmentType(Environment: IIEnvironment->getName());
2310
2311 return CurrentEnv == AttrEnv;
2312}
2313
2314const AvailabilityAttr *
2315DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {
2316 AvailabilityAttr const *PartialMatch = nullptr;
2317 // Check each AvailabilityAttr to find the one for this platform.
2318 // For multiple attributes with the same platform try to find one for this
2319 // environment.
2320 for (const auto *A : D->attrs()) {
2321 if (const auto *Avail = dyn_cast<AvailabilityAttr>(Val: A)) {
2322 StringRef AttrPlatform = Avail->getPlatform()->getName();
2323 StringRef TargetPlatform =
2324 SemaRef.getASTContext().getTargetInfo().getPlatformName();
2325
2326 // Match the platform name.
2327 if (AttrPlatform == TargetPlatform) {
2328 // Find the best matching attribute for this environment
2329 if (HasMatchingEnvironmentOrNone(AA: Avail))
2330 return Avail;
2331 PartialMatch = Avail;
2332 }
2333 }
2334 }
2335 return PartialMatch;
2336}
2337
2338// Check availability against target shader model version and current shader
2339// stage and emit diagnostic
void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
                                                     const AvailabilityAttr *AA,
                                                     SourceRange Range) {

  IdentifierInfo *IIEnv = AA->getEnvironment();

  if (!IIEnv) {
    // The availability attribute does not have environment -> it depends only
    // on shader model version and not on specific the shader stage.

    // Skip emitting the diagnostics if the diagnostic mode is set to
    // strict (-fhlsl-strict-availability) because all relevant diagnostics
    // were already emitted in the DiagnoseUnguardedAvailability scan
    // (SemaAvailability.cpp).
    if (SemaRef.getLangOpts().HLSLStrictAvailability)
      return;

    // Do not report shader-stage-independent issues if scanning a function
    // that was already scanned in a different shader stage context (they would
    // be duplicate)
    if (ReportOnlyShaderStageIssues)
      return;

  } else {
    // The availability attribute has environment -> we need to know
    // the current stage context to properly diagnose it.
    if (InUnknownShaderStageContext())
      return;
  }

  // Check introduced version and if environment matches
  bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);
  VersionTuple Introduced = AA->getIntroduced();
  VersionTuple TargetVersion =
      SemaRef.Context.getTargetInfo().getPlatformMinVersion();

  // Available in this version and stage -> nothing to report.
  if (TargetVersion >= Introduced && EnvironmentMatches)
    return;

  // Emit diagnostic message
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  llvm::StringRef PlatformName(
      AvailabilityAttr::getPrettyPlatformName(Platform: TI.getPlatformName()));

  llvm::StringRef CurrentEnvStr =
      llvm::Triple::getEnvironmentTypeName(Kind: GetCurrentShaderEnvironment());

  llvm::StringRef AttrEnvStr =
      AA->getEnvironment() ? AA->getEnvironment()->getName() : "";
  bool UseEnvironment = !AttrEnvStr.empty();

  // Environment matches -> version is too old; otherwise the API is
  // unavailable in this stage altogether.
  if (EnvironmentMatches) {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability)
        << Range << D << PlatformName << Introduced.getAsString()
        << UseEnvironment << CurrentEnvStr;
  } else {
    SemaRef.Diag(Loc: Range.getBegin(), DiagID: diag::warn_hlsl_availability_unavailable)
        << Range << D;
  }

  SemaRef.Diag(Loc: D->getLocation(), DiagID: diag::note_partial_availability_specified_here)
      << D << PlatformName << Introduced.getAsString()
      << SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()
      << UseEnvironment << AttrEnvStr << CurrentEnvStr;
}
2405
2406} // namespace
2407
// End-of-TU processing: materializes the default constant buffer ($Globals)
// for loose global constants and runs the HLSL availability diagnostics scan.
void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
  // process default CBuffer - create buffer layout struct and invoke codegenCGH
  if (!DefaultCBufferDecls.empty()) {
    HLSLBufferDecl *DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
        C&: SemaRef.getASTContext(), LexicalParent: SemaRef.getCurLexicalContext(),
        DefaultCBufferDecls);
    addImplicitBindingAttrToBuffer(S&: SemaRef, BufDecl: DefaultCBuffer,
                                   ImplicitBindingOrderID: getNextImplicitBindingOrderID());
    SemaRef.getCurLexicalContext()->addDecl(D: DefaultCBuffer);
    createHostLayoutStructForBuffer(S&: SemaRef, BufDecl: DefaultCBuffer);

    // Set HasValidPackoffset if any of the decls has a register(c#) annotation;
    for (const Decl *VD : DefaultCBufferDecls) {
      const HLSLResourceBindingAttr *RBA =
          VD->getAttr<HLSLResourceBindingAttr>();
      if (RBA && RBA->hasRegisterSlot() &&
          RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
        DefaultCBuffer->setHasValidPackoffset(true);
        break;
      }
    }

    // Hand the new buffer declaration to the AST consumer for codegen.
    DeclGroupRef DG(DefaultCBuffer);
    SemaRef.Consumer.HandleTopLevelDecl(D: DG);
  }
  diagnoseAvailabilityViolations(TU);
}
2435
// Runs the reachability-based HLSL availability diagnostics over the whole
// translation unit (see DiagnoseHLSLAvailability above).
void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
  // Skip running the diagnostics scan if the diagnostic mode is
  // strict (-fhlsl-strict-availability) and the target shader stage is known
  // because all relevant diagnostics were already emitted in the
  // DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).
  const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();
  if (SemaRef.getLangOpts().HLSLStrictAvailability &&
      TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)
    return;

  DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);
}
2448
2449static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
2450 assert(TheCall->getNumArgs() > 1);
2451 QualType ArgTy0 = TheCall->getArg(Arg: 0)->getType();
2452
2453 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
2454 if (!S->getASTContext().hasSameUnqualifiedType(
2455 T1: ArgTy0, T2: TheCall->getArg(Arg: I)->getType())) {
2456 S->Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vec_builtin_incompatible_vector)
2457 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
2458 << SourceRange(TheCall->getArg(Arg: 0)->getBeginLoc(),
2459 TheCall->getArg(Arg: N - 1)->getEndLoc());
2460 return true;
2461 }
2462 }
2463 return false;
2464}
2465
2466static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
2467 QualType ArgType = Arg->getType();
2468 if (!S->getASTContext().hasSameUnqualifiedType(T1: ArgType, T2: ExpectedType)) {
2469 S->Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_typecheck_convert_incompatible)
2470 << ArgType << ExpectedType << 1 << 0 << 0;
2471 return true;
2472 }
2473 return false;
2474}
2475
2476static bool CheckAllArgTypesAreCorrect(
2477 Sema *S, CallExpr *TheCall,
2478 llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2479 clang::QualType PassedType)>
2480 Check) {
2481 for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2482 Expr *Arg = TheCall->getArg(Arg: I);
2483 if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2484 return true;
2485 }
2486 return false;
2487}
2488
2489static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2490 int ArgOrdinal,
2491 clang::QualType PassedType) {
2492 clang::QualType BaseType =
2493 PassedType->isVectorType()
2494 ? PassedType->castAs<clang::VectorType>()->getElementType()
2495 : PassedType;
2496 if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
2497 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2498 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2499 << /* half or float */ 2 << PassedType;
2500 return false;
2501}
2502
2503static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
2504 unsigned ArgIndex) {
2505 auto *Arg = TheCall->getArg(Arg: ArgIndex);
2506 SourceLocation OrigLoc = Arg->getExprLoc();
2507 if (Arg->IgnoreCasts()->isModifiableLvalue(Ctx&: S->Context, Loc: &OrigLoc) ==
2508 Expr::MLV_Valid)
2509 return false;
2510 S->Diag(Loc: OrigLoc, DiagID: diag::error_hlsl_inout_lvalue) << Arg << 0;
2511 return true;
2512}
2513
2514static bool CheckNoDoubleVectors(Sema *S, SourceLocation Loc, int ArgOrdinal,
2515 clang::QualType PassedType) {
2516 const auto *VecTy = PassedType->getAs<VectorType>();
2517 if (!VecTy)
2518 return false;
2519
2520 if (VecTy->getElementType()->isDoubleType())
2521 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2522 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0 << /* fp */ 1
2523 << PassedType;
2524 return false;
2525}
2526
2527static bool CheckFloatingOrIntRepresentation(Sema *S, SourceLocation Loc,
2528 int ArgOrdinal,
2529 clang::QualType PassedType) {
2530 if (!PassedType->hasIntegerRepresentation() &&
2531 !PassedType->hasFloatingRepresentation())
2532 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2533 << ArgOrdinal << /* scalar or vector of */ 5 << /* integer */ 1
2534 << /* fp */ 1 << PassedType;
2535 return false;
2536}
2537
2538static bool CheckUnsignedIntVecRepresentation(Sema *S, SourceLocation Loc,
2539 int ArgOrdinal,
2540 clang::QualType PassedType) {
2541 if (auto *VecTy = PassedType->getAs<VectorType>())
2542 if (VecTy->getElementType()->isUnsignedIntegerType())
2543 return false;
2544
2545 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2546 << ArgOrdinal << /* vector of */ 4 << /* uint */ 3 << /* no fp */ 0
2547 << PassedType;
2548}
2549
2550// checks for unsigned ints of all sizes
2551static bool CheckUnsignedIntRepresentation(Sema *S, SourceLocation Loc,
2552 int ArgOrdinal,
2553 clang::QualType PassedType) {
2554 if (!PassedType->hasUnsignedIntegerRepresentation())
2555 return S->Diag(Loc, DiagID: diag::err_builtin_invalid_arg_type)
2556 << ArgOrdinal << /* scalar or vector of */ 5 << /* unsigned int */ 3
2557 << /* no fp */ 0 << PassedType;
2558 return false;
2559}
2560
2561static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
2562 QualType ReturnType) {
2563 auto *VecTyA = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
2564 if (VecTyA)
2565 ReturnType =
2566 S->Context.getExtVectorType(VectorType: ReturnType, NumElts: VecTyA->getNumElements());
2567
2568 TheCall->setType(ReturnType);
2569}
2570
2571static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
2572 unsigned ArgIndex) {
2573 assert(TheCall->getNumArgs() >= ArgIndex);
2574 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2575 auto *VTy = ArgType->getAs<VectorType>();
2576 // not the scalar or vector<scalar>
2577 if (!(S->Context.hasSameUnqualifiedType(T1: ArgType, T2: Scalar) ||
2578 (VTy &&
2579 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: Scalar)))) {
2580 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2581 DiagID: diag::err_typecheck_expect_scalar_or_vector)
2582 << ArgType << Scalar;
2583 return true;
2584 }
2585 return false;
2586}
2587
2588static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
2589 unsigned ArgIndex) {
2590 assert(TheCall->getNumArgs() >= ArgIndex);
2591 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2592 auto *VTy = ArgType->getAs<VectorType>();
2593 // not the scalar or vector<scalar>
2594 if (!(ArgType->isScalarType() ||
2595 (VTy && VTy->getElementType()->isScalarType()))) {
2596 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2597 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
2598 << ArgType << 1;
2599 return true;
2600 }
2601 return false;
2602}
2603
2604static bool CheckWaveActive(Sema *S, CallExpr *TheCall) {
2605 QualType BoolType = S->getASTContext().BoolTy;
2606 assert(TheCall->getNumArgs() >= 1);
2607 QualType ArgType = TheCall->getArg(Arg: 0)->getType();
2608 auto *VTy = ArgType->getAs<VectorType>();
2609 // is the bool or vector<bool>
2610 if (S->Context.hasSameUnqualifiedType(T1: ArgType, T2: BoolType) ||
2611 (VTy &&
2612 S->Context.hasSameUnqualifiedType(T1: VTy->getElementType(), T2: BoolType))) {
2613 S->Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
2614 DiagID: diag::err_typecheck_expect_any_scalar_or_vector)
2615 << ArgType << 0;
2616 return true;
2617 }
2618 return false;
2619}
2620
2621static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
2622 assert(TheCall->getNumArgs() == 3);
2623 Expr *Arg1 = TheCall->getArg(Arg: 1);
2624 Expr *Arg2 = TheCall->getArg(Arg: 2);
2625 if (!S->Context.hasSameUnqualifiedType(T1: Arg1->getType(), T2: Arg2->getType())) {
2626 S->Diag(Loc: TheCall->getBeginLoc(),
2627 DiagID: diag::err_typecheck_call_different_arg_types)
2628 << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2629 << Arg2->getSourceRange();
2630 return true;
2631 }
2632
2633 TheCall->setType(Arg1->getType());
2634 return false;
2635}
2636
2637static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
2638 assert(TheCall->getNumArgs() == 3);
2639 Expr *Arg1 = TheCall->getArg(Arg: 1);
2640 QualType Arg1Ty = Arg1->getType();
2641 Expr *Arg2 = TheCall->getArg(Arg: 2);
2642 QualType Arg2Ty = Arg2->getType();
2643
2644 QualType Arg1ScalarTy = Arg1Ty;
2645 if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2646 Arg1ScalarTy = VTy->getElementType();
2647
2648 QualType Arg2ScalarTy = Arg2Ty;
2649 if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2650 Arg2ScalarTy = VTy->getElementType();
2651
2652 if (!S->Context.hasSameUnqualifiedType(T1: Arg1ScalarTy, T2: Arg2ScalarTy))
2653 S->Diag(Loc: Arg1->getBeginLoc(), DiagID: diag::err_hlsl_builtin_scalar_vector_mismatch)
2654 << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2655
2656 QualType Arg0Ty = TheCall->getArg(Arg: 0)->getType();
2657 unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2658 unsigned Arg1Length = Arg1Ty->isVectorType()
2659 ? Arg1Ty->getAs<VectorType>()->getNumElements()
2660 : 0;
2661 unsigned Arg2Length = Arg2Ty->isVectorType()
2662 ? Arg2Ty->getAs<VectorType>()->getNumElements()
2663 : 0;
2664 if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2665 S->Diag(Loc: TheCall->getBeginLoc(),
2666 DiagID: diag::err_typecheck_vector_lengths_not_equal)
2667 << Arg0Ty << Arg1Ty << TheCall->getArg(Arg: 0)->getSourceRange()
2668 << Arg1->getSourceRange();
2669 return true;
2670 }
2671
2672 if (Arg2Length > 0 && Arg0Length != Arg2Length) {
2673 S->Diag(Loc: TheCall->getBeginLoc(),
2674 DiagID: diag::err_typecheck_vector_lengths_not_equal)
2675 << Arg0Ty << Arg2Ty << TheCall->getArg(Arg: 0)->getSourceRange()
2676 << Arg2->getSourceRange();
2677 return true;
2678 }
2679
2680 TheCall->setType(
2681 S->getASTContext().getExtVectorType(VectorType: Arg1ScalarTy, NumElts: Arg0Length));
2682 return false;
2683}
2684
2685static bool CheckResourceHandle(
2686 Sema *S, CallExpr *TheCall, unsigned ArgIndex,
2687 llvm::function_ref<bool(const HLSLAttributedResourceType *ResType)> Check =
2688 nullptr) {
2689 assert(TheCall->getNumArgs() >= ArgIndex);
2690 QualType ArgType = TheCall->getArg(Arg: ArgIndex)->getType();
2691 const HLSLAttributedResourceType *ResTy =
2692 ArgType.getTypePtr()->getAs<HLSLAttributedResourceType>();
2693 if (!ResTy) {
2694 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getBeginLoc(),
2695 DiagID: diag::err_typecheck_expect_hlsl_resource)
2696 << ArgType;
2697 return true;
2698 }
2699 if (Check && Check(ResTy)) {
2700 S->Diag(Loc: TheCall->getArg(Arg: ArgIndex)->getExprLoc(),
2701 DiagID: diag::err_invalid_hlsl_resource_type)
2702 << ArgType;
2703 return true;
2704 }
2705 return false;
2706}
2707
// Semantic validation for HLSL builtin calls: checks argument counts and
// types for each builtin and, where needed, computes the call's result type.
// Note: returning true in this case results in CheckBuiltinFunctionCall
// returning an ExprError
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntVecRepresentation))
      return true;

    // The check above guarantees arg0 is a vector of unsigned integers.
    auto *VTy = TheCall->getArg(Arg: 0)->getType()->getAs<VectorType>();
    // ensure arg integers are 32-bits
    uint64_t ElementBitCount = getASTContext()
                                   .getTypeSizeInChars(T: VTy->getElementType())
                                   .getQuantity() *
                               8;
    if (ElementBitCount != 32) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(),
                   DiagID: diag::err_integer_incorrect_bit_count)
          << 32 << ElementBitCount;
      return true;
    }

    // ensure both args are vectors of total bit size of a multiple of 64
    int NumElementsArg = VTy->getNumElements();
    if (NumElementsArg != 2 && NumElementsArg != 4) {
      SemaRef.Diag(Loc: TheCall->getBeginLoc(), DiagID: diag::err_vector_incorrect_bit_count)
          << 1 /*a multiple of*/ << 64 << NumElementsArg * ElementBitCount;
      return true;
    }

    // ensure first arg and second arg have the same type
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().UnsignedIntTy))
      return true;

    // Result is an hlsl_device-address-space lvalue pointer to the
    // resource's contained element type.
    auto *ResourceTy =
        TheCall->getArg(Arg: 0)->getType()->castAs<HLSLAttributedResourceType>();
    QualType ContainedTy = ResourceTy->getContainedType();
    auto ReturnType =
        SemaRef.Context.getAddrSpaceQualType(T: ContainedTy, AddressSpace: LangAS::hlsl_device);
    ReturnType = SemaRef.Context.getPointerType(T: ReturnType);
    TheCall->setType(ReturnType);
    TheCall->setValueKind(VK_LValue);

    break;
  }
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    // Args: handle, register index, space, range size, index, name.
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 6) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 3), ExpectedType: AST.IntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 4), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 5),
                            ExpectedType: AST.getPointerType(T: AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    // Same shape as handlefrombinding, but arg2 is a signed order id.
    ASTContext &AST = SemaRef.getASTContext();
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 6) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 2), ExpectedType: AST.IntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 3), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 4), ExpectedType: AST.UnsignedIntTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 5),
                            ExpectedType: AST.getPointerType(T: AST.CharTy.withConst())))
      return true;
    // use the type of the handle (arg0) as a return type
    QualType ResourceTy = TheCall->getArg(Arg: 0)->getType();
    TheCall->setType(ResourceTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_and:
  case Builtin::BI__builtin_hlsl_or: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_all:
  case Builtin::BI__builtin_hlsl_any: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_asdouble: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 0))
      return true;
    if (CheckScalarOrVector(
            S: &SemaRef, TheCall,
            /*only check for uint*/ Scalar: SemaRef.Context.UnsignedIntTy,
            /* arg index */ ArgIndex: 1))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;

    // Result is double (or a double vector matching the input width).
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().DoubleTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_dot: {
    // arg count is checked by BuiltinVectorToScalarMath
    if (SemaRef.BuiltinVectorToScalarMath(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall, Check: CheckNoDoubleVectors))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;

    const Expr *Arg = TheCall->getArg(Arg: 0);
    QualType ArgTy = Arg->getType();
    QualType EltTy = ArgTy;

    // Result is uint (or a uint vector matching the input width).
    QualType ResTy = SemaRef.Context.UnsignedIntTy;

    if (auto *VecTy = EltTy->getAs<VectorType>()) {
      EltTy = VecTy->getElementType();
      ResTy = SemaRef.Context.getExtVectorType(VectorType: ResTy, NumElts: VecTy->getNumElements());
    }

    if (!EltTy->isIntegerType()) {
      Diag(Loc: Arg->getBeginLoc(), DiagID: diag::err_builtin_invalid_arg_type)
          << 1 << /* scalar or vector of */ 5 << /* integer ty */ 1
          << /* no fp */ 0 << ArgTy;
      return true;
    }

    TheCall->setType(ResTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_select: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: getASTContext().BoolTy, ArgIndex: 0))
      return true;
    // Dispatch on the condition's shape: scalar bool vs vector of bool.
    QualType ArgTy = TheCall->getArg(Arg: 0)->getType();
    if (ArgTy->isBooleanType() && CheckBoolSelect(S: &SemaRef, TheCall))
      return true;
    auto *VTy = ArgTy->getAs<VectorType>();
    if (VTy && VTy->getElementType()->isBooleanType() &&
        CheckVectorSelect(S: &SemaRef, TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate:
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (!TheCall->getArg(Arg: 0)
             ->getType()
             ->hasFloatingRepresentation()) // half or float or double
      return SemaRef.Diag(Loc: TheCall->getArg(Arg: 0)->getBeginLoc(),
                          DiagID: diag::err_builtin_invalid_arg_type)
             << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
             << /* fp */ 1 << TheCall->getArg(Arg: 0)->getType();
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees:
  case Builtin::BI__builtin_hlsl_elementwise_radians:
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
  case Builtin::BI__builtin_hlsl_elementwise_frac: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    // isinf yields bool (or a bool vector matching the input width).
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().BoolTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    if (CheckAllArgsHaveSameType(S: &SemaRef, TheCall))
      return true;
    if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_mad: {
    if (SemaRef.BuiltinElementwiseTernaryMath(
            TheCall, /*ArgTyRestr=*/
            Sema::EltwiseBuiltinArgTyRestriction::None))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatingOrIntRepresentation))
      return true;
    // sign yields int (or an int vector matching the input width).
    SetElementTypeAsReturnType(S: &SemaRef, TheCall, ReturnType: getASTContext().IntTy);
    break;
  }
  case Builtin::BI__builtin_hlsl_step: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;

    ExprResult A = TheCall->getArg(Arg: 0);
    QualType ArgTyA = A.get()->getType();
    // return type is the same as the input type
    TheCall->setType(ArgTyA);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_active_max:
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;
    // Additionally reject bool/vector-of-bool operands.
    if (CheckWaveActive(S: &SemaRef, TheCall))
      return true;
    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  // Note these are llvm builtins that we want to catch invalid intrinsic
  // generation. Normal handling of these builitns will occur elsewhere.
  case Builtin::BI__builtin_elementwise_bitreverse: {
    // does not include a check for number of arguments
    // because that is done previously
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckUnsignedIntRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2))
      return true;

    // Ensure index parameter type can be interpreted as a uint
    ExprResult Index = TheCall->getArg(Arg: 1);
    QualType ArgTyIndex = Index.get()->getType();
    if (!ArgTyIndex->isIntegerType()) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_typecheck_convert_incompatible)
          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
      return true;
    }

    // Ensure input expr type is a scalar/vector and the same as the return type
    if (CheckAnyScalarOrVector(S: &SemaRef, TheCall, ArgIndex: 0))
      return true;

    ExprResult Expr = TheCall->getArg(Arg: 0);
    QualType ArgTyExpr = Expr.get()->getType();
    TheCall->setType(ArgTyExpr);
    break;
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    // Args: double in, uint lowbits out, uint highbits out.
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 3))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.DoubleTy, ArgIndex: 0) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 1) ||
        CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.UnsignedIntTy,
                            ArgIndex: 2))
      return true;

    // The two out parameters must be writable.
    if (CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 1) ||
        CheckModifiableLValue(S: &SemaRef, TheCall, ArgIndex: 2))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip: {
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 1))
      return true;

    if (CheckScalarOrVector(S: &SemaRef, TheCall, Scalar: SemaRef.Context.FloatTy, ArgIndex: 0))
      return true;
    break;
  }
  case Builtin::BI__builtin_elementwise_acos:
  case Builtin::BI__builtin_elementwise_asin:
  case Builtin::BI__builtin_elementwise_atan:
  case Builtin::BI__builtin_elementwise_atan2:
  case Builtin::BI__builtin_elementwise_ceil:
  case Builtin::BI__builtin_elementwise_cos:
  case Builtin::BI__builtin_elementwise_cosh:
  case Builtin::BI__builtin_elementwise_exp:
  case Builtin::BI__builtin_elementwise_exp2:
  case Builtin::BI__builtin_elementwise_exp10:
  case Builtin::BI__builtin_elementwise_floor:
  case Builtin::BI__builtin_elementwise_fmod:
  case Builtin::BI__builtin_elementwise_log:
  case Builtin::BI__builtin_elementwise_log2:
  case Builtin::BI__builtin_elementwise_log10:
  case Builtin::BI__builtin_elementwise_pow:
  case Builtin::BI__builtin_elementwise_roundeven:
  case Builtin::BI__builtin_elementwise_sin:
  case Builtin::BI__builtin_elementwise_sinh:
  case Builtin::BI__builtin_elementwise_sqrt:
  case Builtin::BI__builtin_elementwise_tan:
  case Builtin::BI__builtin_elementwise_tanh:
  case Builtin::BI__builtin_elementwise_trunc: {
    // In HLSL these transcendental/rounding builtins are restricted to
    // half/float (no double), unlike the generic elementwise forms.
    if (CheckAllArgTypesAreCorrect(S: &SemaRef, TheCall,
                                   Check: CheckFloatOrHalfRepresentation))
      return true;
    break;
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    // The handle must be a raw UAV buffer with a contained type.
    auto checkResTy = [](const HLSLAttributedResourceType *ResTy) -> bool {
      return !(ResTy->getAttrs().ResourceClass == ResourceClass::UAV &&
               ResTy->getAttrs().RawBuffer && ResTy->hasContainedType());
    };
    if (SemaRef.checkArgCount(Call: TheCall, DesiredArgCount: 2) ||
        CheckResourceHandle(S: &SemaRef, TheCall, ArgIndex: 0, Check: checkResTy) ||
        CheckArgTypeMatches(S: &SemaRef, Arg: TheCall->getArg(Arg: 1),
                            ExpectedType: SemaRef.getASTContext().IntTy))
      return true;
    // The increment must be the constant +1 or -1.
    Expr *OffsetExpr = TheCall->getArg(Arg: 1);
    std::optional<llvm::APSInt> Offset =
        OffsetExpr->getIntegerConstantExpr(Ctx: SemaRef.getASTContext());
    if (!Offset.has_value() || std::abs(i: Offset->getExtValue()) != 1) {
      SemaRef.Diag(Loc: TheCall->getArg(Arg: 1)->getBeginLoc(),
                   DiagID: diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
          << 1;
      return true;
    }
    break;
  }
  }
  return false;
}
3128
// Appends to List the sequence of leaf (scalar/builtin/opaque) types obtained
// by flattening BaseTy: constant arrays are expanded element-by-element,
// vectors element-wise, and aggregate records field-by-field (bases first for
// non-standard-layout types). Unions and non-aggregates are kept whole.
// Iterative over a work list, except for arrays (see comment below).
static void BuildFlattenedTypeList(QualType BaseTy,
                                   llvm::SmallVectorImpl<QualType> &List) {
  llvm::SmallVector<QualType, 16> WorkList;
  WorkList.push_back(Elt: BaseTy);
  while (!WorkList.empty()) {
    QualType T = WorkList.pop_back_val();
    // Compare/flatten on canonical unqualified types only.
    T = T.getCanonicalType().getUnqualifiedType();
    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
    if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
      llvm::SmallVector<QualType, 16> ElementFields;
      // Generally I've avoided recursion in this algorithm, but arrays of
      // structs could be time-consuming to flatten and churn through on the
      // work list. Hopefully nesting arrays of structs containing arrays
      // of structs too many levels deep is unlikely.
      BuildFlattenedTypeList(BaseTy: AT->getElementType(), List&: ElementFields);
      // Repeat the element's field list n times.
      for (uint64_t Ct = 0; Ct < AT->getZExtSize(); ++Ct)
        llvm::append_range(C&: List, R&: ElementFields);
      continue;
    }
    // Vectors can only have element types that are builtin types, so this can
    // add directly to the list instead of to the WorkList.
    if (const auto *VT = dyn_cast<VectorType>(Val&: T)) {
      List.insert(I: List.end(), NumToInsert: VT->getNumElements(), Elt: VT->getElementType());
      continue;
    }
    if (const auto *RT = dyn_cast<RecordType>(Val&: T)) {
      const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
      assert(RD && "HLSL record types should all be CXXRecordDecls!");

      // For standard-layout types all fields live in one class in the
      // hierarchy; flatten from that class directly.
      if (RD->isStandardLayout())
        RD = RD->getStandardLayoutBaseWithFields();

      // For types that we shouldn't decompose (unions and non-aggregates), just
      // add the type itself to the list.
      if (RD->isUnion() || !RD->isAggregate()) {
        List.push_back(Elt: T);
        continue;
      }

      llvm::SmallVector<QualType, 16> FieldTypes;
      for (const auto *FD : RD->fields())
        FieldTypes.push_back(Elt: FD->getType());
      // Reverse the newly added sub-range.
      // (WorkList pops from the back, so pushing fields in reverse makes them
      // flatten in declaration order.)
      std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
      llvm::append_range(C&: WorkList, R&: FieldTypes);

      // If this wasn't a standard layout type we may also have some base
      // classes to deal with.
      // Bases are pushed after fields, so they pop (and flatten) first.
      if (!RD->isStandardLayout()) {
        FieldTypes.clear();
        for (const auto &Base : RD->bases())
          FieldTypes.push_back(Elt: Base.getType());
        std::reverse(first: FieldTypes.begin(), last: FieldTypes.end());
        llvm::append_range(C&: WorkList, R&: FieldTypes);
      }
      continue;
    }
    // Leaf type: emit as-is.
    List.push_back(Elt: T);
  }
}
3190
3191bool SemaHLSL::IsTypedResourceElementCompatible(clang::QualType QT) {
3192 // null and array types are not allowed.
3193 if (QT.isNull() || QT->isArrayType())
3194 return false;
3195
3196 // UDT types are not allowed
3197 if (QT->isRecordType())
3198 return false;
3199
3200 if (QT->isBooleanType() || QT->isEnumeralType())
3201 return false;
3202
3203 // the only other valid builtin types are scalars or vectors
3204 if (QT->isArithmeticType()) {
3205 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3206 return false;
3207 return true;
3208 }
3209
3210 if (const VectorType *VT = QT->getAs<VectorType>()) {
3211 int ArraySize = VT->getNumElements();
3212
3213 if (ArraySize > 4)
3214 return false;
3215
3216 QualType ElTy = VT->getElementType();
3217 if (ElTy->isBooleanType())
3218 return false;
3219
3220 if (SemaRef.Context.getTypeSize(T: QT) / 8 > 16)
3221 return false;
3222 return true;
3223 }
3224
3225 return false;
3226}
3227
3228bool SemaHLSL::IsScalarizedLayoutCompatible(QualType T1, QualType T2) const {
3229 if (T1.isNull() || T2.isNull())
3230 return false;
3231
3232 T1 = T1.getCanonicalType().getUnqualifiedType();
3233 T2 = T2.getCanonicalType().getUnqualifiedType();
3234
3235 // If both types are the same canonical type, they're obviously compatible.
3236 if (SemaRef.getASTContext().hasSameType(T1, T2))
3237 return true;
3238
3239 llvm::SmallVector<QualType, 16> T1Types;
3240 BuildFlattenedTypeList(BaseTy: T1, List&: T1Types);
3241 llvm::SmallVector<QualType, 16> T2Types;
3242 BuildFlattenedTypeList(BaseTy: T2, List&: T2Types);
3243
3244 // Check the flattened type list
3245 return llvm::equal(LRange&: T1Types, RRange&: T2Types,
3246 P: [this](QualType LHS, QualType RHS) -> bool {
3247 return SemaRef.IsLayoutCompatible(T1: LHS, T2: RHS);
3248 });
3249}
3250
3251bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
3252 FunctionDecl *Old) {
3253 if (New->getNumParams() != Old->getNumParams())
3254 return true;
3255
3256 bool HadError = false;
3257
3258 for (unsigned i = 0, e = New->getNumParams(); i != e; ++i) {
3259 ParmVarDecl *NewParam = New->getParamDecl(i);
3260 ParmVarDecl *OldParam = Old->getParamDecl(i);
3261
3262 // HLSL parameter declarations for inout and out must match between
3263 // declarations. In HLSL inout and out are ambiguous at the call site,
3264 // but have different calling behavior, so you cannot overload a
3265 // method based on a difference between inout and out annotations.
3266 const auto *NDAttr = NewParam->getAttr<HLSLParamModifierAttr>();
3267 unsigned NSpellingIdx = (NDAttr ? NDAttr->getSpellingListIndex() : 0);
3268 const auto *ODAttr = OldParam->getAttr<HLSLParamModifierAttr>();
3269 unsigned OSpellingIdx = (ODAttr ? ODAttr->getSpellingListIndex() : 0);
3270
3271 if (NSpellingIdx != OSpellingIdx) {
3272 SemaRef.Diag(Loc: NewParam->getLocation(),
3273 DiagID: diag::err_hlsl_param_qualifier_mismatch)
3274 << NDAttr << NewParam;
3275 SemaRef.Diag(Loc: OldParam->getLocation(), DiagID: diag::note_previous_declaration_as)
3276 << ODAttr;
3277 HadError = true;
3278 }
3279 }
3280 return HadError;
3281}
3282
// Generally follows PerformScalarCast, with cases reordered for
// clarity of what types are supported
//
// Returns true when SrcTy can be cast to DestTy in HLSL. In practice this is
// true for any pair of bool/integral/floating scalar types; pointer, complex,
// and fixed-point kinds are unreachable because HLSL has no such types.
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {

  // Only scalar-to-scalar casts are considered here.
  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
    return false;

  // Identical (unqualified) types trivially cast to each other.
  if (SemaRef.getASTContext().hasSameUnqualifiedType(T1: SrcTy, T2: DestTy))
    return true;

  switch (SrcTy->getScalarTypeKind()) {
  case Type::STK_Bool: // casting from bool is like casting from an integer
  case Type::STK_Integral:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Bool:
    case Type::STK_Integral:
    case Type::STK_Floating:
      return true;
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    case Type::STK_IntegralComplex:
    case Type::STK_FloatingComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    }
    llvm_unreachable("Should have returned before this");

  case Type::STK_Floating:
    switch (DestTy->getScalarTypeKind()) {
    case Type::STK_Floating:
    case Type::STK_Bool:
    case Type::STK_Integral:
      return true;
    case Type::STK_FloatingComplex:
    case Type::STK_IntegralComplex:
      llvm_unreachable("HLSL doesn't support complex types.");
    case Type::STK_FixedPoint:
      llvm_unreachable("HLSL doesn't support fixed point types.");
    case Type::STK_CPointer:
    case Type::STK_ObjCObjectPointer:
    case Type::STK_BlockPointer:
    case Type::STK_MemberPointer:
      llvm_unreachable("HLSL doesn't support pointers.");
    }
    llvm_unreachable("Should have returned before this");

  case Type::STK_MemberPointer:
  case Type::STK_CPointer:
  case Type::STK_BlockPointer:
  case Type::STK_ObjCObjectPointer:
    llvm_unreachable("HLSL doesn't support pointers.");

  case Type::STK_FixedPoint:
    llvm_unreachable("HLSL doesn't support fixed point types.");

  case Type::STK_FloatingComplex:
  case Type::STK_IntegralComplex:
    llvm_unreachable("HLSL doesn't support complex types.");
  }

  llvm_unreachable("Unhandled scalar cast");
}
3349
3350// Detect if a type contains a bitfield. Will be removed when
3351// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
3352bool SemaHLSL::ContainsBitField(QualType BaseTy) {
3353 llvm::SmallVector<QualType, 16> WorkList;
3354 WorkList.push_back(Elt: BaseTy);
3355 while (!WorkList.empty()) {
3356 QualType T = WorkList.pop_back_val();
3357 T = T.getCanonicalType().getUnqualifiedType();
3358 // only check aggregate types
3359 if (const auto *AT = dyn_cast<ConstantArrayType>(Val&: T)) {
3360 WorkList.push_back(Elt: AT->getElementType());
3361 continue;
3362 }
3363 if (const auto *RT = dyn_cast<RecordType>(Val&: T)) {
3364 const RecordDecl *RD = RT->getDecl();
3365 if (RD->isUnion())
3366 continue;
3367
3368 const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Val: RD);
3369
3370 if (CXXD && CXXD->isStandardLayout())
3371 RD = CXXD->getStandardLayoutBaseWithFields();
3372
3373 for (const auto *FD : RD->fields()) {
3374 if (FD->isBitField())
3375 return true;
3376 WorkList.push_back(Elt: FD->getType());
3377 }
3378 continue;
3379 }
3380 }
3381 return false;
3382}
3383
3384// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
3385// Src is a scalar or a vector of length 1
3386// Or if Dest is a vector and Src is a vector of length 1
3387bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
3388
3389 QualType SrcTy = Src->getType();
3390 // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
3391 // going to be a vector splat from a scalar.
3392 if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
3393 DestTy->isScalarType())
3394 return false;
3395
3396 const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
3397
3398 // Src isn't a scalar or a vector of length 1
3399 if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
3400 return false;
3401
3402 if (SrcVecTy)
3403 SrcTy = SrcVecTy->getElementType();
3404
3405 if (ContainsBitField(BaseTy: DestTy))
3406 return false;
3407
3408 llvm::SmallVector<QualType> DestTypes;
3409 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
3410
3411 for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
3412 if (DestTypes[I]->isUnionType())
3413 return false;
3414 if (!CanPerformScalarCast(SrcTy, DestTy: DestTypes[I]))
3415 return false;
3416 }
3417 return true;
3418}
3419
3420// Can we perform an HLSL Elementwise cast?
3421// TODO: update this code when matrices are added; see issue #88060
3422bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
3423
3424 // Don't handle casts where LHS and RHS are any combination of scalar/vector
3425 // There must be an aggregate somewhere
3426 QualType SrcTy = Src->getType();
3427 if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
3428 return false;
3429
3430 if (SrcTy->isVectorType() &&
3431 (DestTy->isScalarType() || DestTy->isVectorType()))
3432 return false;
3433
3434 if (ContainsBitField(BaseTy: DestTy) || ContainsBitField(BaseTy: SrcTy))
3435 return false;
3436
3437 llvm::SmallVector<QualType> DestTypes;
3438 BuildFlattenedTypeList(BaseTy: DestTy, List&: DestTypes);
3439 llvm::SmallVector<QualType> SrcTypes;
3440 BuildFlattenedTypeList(BaseTy: SrcTy, List&: SrcTypes);
3441
3442 // Usually the size of SrcTypes must be greater than or equal to the size of
3443 // DestTypes.
3444 if (SrcTypes.size() < DestTypes.size())
3445 return false;
3446
3447 unsigned SrcSize = SrcTypes.size();
3448 unsigned DstSize = DestTypes.size();
3449 unsigned I;
3450 for (I = 0; I < DstSize && I < SrcSize; I++) {
3451 if (SrcTypes[I]->isUnionType() || DestTypes[I]->isUnionType())
3452 return false;
3453 if (!CanPerformScalarCast(SrcTy: SrcTypes[I], DestTy: DestTypes[I])) {
3454 return false;
3455 }
3456 }
3457
3458 // check the rest of the source type for unions.
3459 for (; I < SrcSize; I++) {
3460 if (SrcTypes[I]->isUnionType())
3461 return false;
3462 }
3463 return true;
3464}
3465
/// Wrap an argument bound to an HLSL `out`/`inout` parameter in an
/// HLSLOutArgExpr that models copy-in (for inout) and copy-out (writeback)
/// semantics. Ordinary (by-value) parameters pass through unchanged.
/// Emits a diagnostic and returns ExprError() for non-lvalue arguments or
/// disallowed scalar->vector extensions.
ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
  assert(Param->hasAttr<HLSLParamModifierAttr>() &&
         "We should not get here without a parameter modifier expression");
  const auto *Attr = Param->getAttr<HLSLParamModifierAttr>();
  if (Attr->getABI() == ParameterABI::Ordinary)
    return ExprResult(Arg);

  bool IsInOut = Attr->getABI() == ParameterABI::HLSLInOut;
  // out/inout arguments must be writable, i.e. lvalues.
  if (!Arg->isLValue()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_lvalue)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  ASTContext &Ctx = SemaRef.getASTContext();

  QualType Ty = Param->getType().getNonLValueExprType(Context: Ctx);

  // HLSL allows implicit conversions from scalars to vectors, but not the
  // inverse, so we need to disallow `inout` with scalar->vector or
  // scalar->matrix conversions.
  if (Arg->getType()->isScalarType() != Ty->isScalarType()) {
    SemaRef.Diag(Loc: Arg->getBeginLoc(), DiagID: diag::error_hlsl_inout_scalar_extension)
        << Arg << (IsInOut ? 1 : 0);
    return ExprError();
  }

  // Opaque value wrapping the caller's lvalue; evaluated once and shared by
  // the copy-in and the writeback below.
  auto *ArgOpV = new (Ctx) OpaqueValueExpr(Param->getBeginLoc(), Arg->getType(),
                                           VK_LValue, OK_Ordinary, Arg);

  // Parameters are initialized via copy initialization. This allows for
  // overload resolution of argument constructors.
  InitializedEntity Entity =
      InitializedEntity::InitializeParameter(Context&: Ctx, Type: Ty, Consumed: false);
  ExprResult Res =
      SemaRef.PerformCopyInitialization(Entity, EqualLoc: Param->getBeginLoc(), Init: ArgOpV);
  if (Res.isInvalid())
    return ExprError();
  Expr *Base = Res.get();
  // After the cast, drop the reference type when creating the exprs.
  // NOTE(review): Ty was already stripped of reference-ness above; this
  // second getNonLValueExprType looks redundant — confirm before removing.
  Ty = Ty.getNonLValueExprType(Context: Ctx);
  auto *OpV = new (Ctx)
      OpaqueValueExpr(Param->getBeginLoc(), Ty, VK_LValue, OK_Ordinary, Base);

  // Writebacks are performed with `=` binary operator, which allows for
  // overload resolution on writeback result expressions.
  Res = SemaRef.ActOnBinOp(S: SemaRef.getCurScope(), TokLoc: Param->getBeginLoc(),
                           Kind: tok::equal, LHSExpr: ArgOpV, RHSExpr: OpV);

  if (Res.isInvalid())
    return ExprError();
  Expr *Writeback = Res.get();
  auto *OutExpr =
      HLSLOutArgExpr::Create(C: Ctx, Ty, Base: ArgOpV, OpV, WB: Writeback, IsInOut);

  return ExprResult(OutExpr);
}
3523
3524QualType SemaHLSL::getInoutParameterType(QualType Ty) {
3525 // If HLSL gains support for references, all the cites that use this will need
3526 // to be updated with semantic checking to produce errors for
3527 // pointers/references.
3528 assert(!Ty->isReferenceType() &&
3529 "Pointer and reference types cannot be inout or out parameters");
3530 Ty = SemaRef.getASTContext().getLValueReferenceType(T: Ty);
3531 Ty.addRestrict();
3532 return Ty;
3533}
3534
3535static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
3536 QualType QT = VD->getType();
3537 return VD->getDeclContext()->isTranslationUnit() &&
3538 QT.getAddressSpace() == LangAS::Default &&
3539 VD->getStorageClass() != SC_Static &&
3540 !VD->hasAttr<HLSLVkConstantIdAttr>() &&
3541 !isInvalidConstantBufferLeafElementType(Ty: QT.getTypePtr());
3542}
3543
3544void SemaHLSL::deduceAddressSpace(VarDecl *Decl) {
3545 // The variable already has an address space (groupshared for ex).
3546 if (Decl->getType().hasAddressSpace())
3547 return;
3548
3549 if (Decl->getType()->isDependentType())
3550 return;
3551
3552 QualType Type = Decl->getType();
3553
3554 if (Decl->hasAttr<HLSLVkExtBuiltinInputAttr>()) {
3555 LangAS ImplAS = LangAS::hlsl_input;
3556 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
3557 Decl->setType(Type);
3558 return;
3559 }
3560
3561 if (Type->isSamplerT() || Type->isVoidType())
3562 return;
3563
3564 // Resource handles.
3565 if (isResourceRecordTypeOrArrayOf(Ty: Type->getUnqualifiedDesugaredType()))
3566 return;
3567
3568 // Only static globals belong to the Private address space.
3569 // Non-static globals belongs to the cbuffer.
3570 if (Decl->getStorageClass() != SC_Static && !Decl->isStaticDataMember())
3571 return;
3572
3573 LangAS ImplAS = LangAS::hlsl_private;
3574 Type = SemaRef.getASTContext().getAddrSpaceQualType(T: Type, AddressSpace: ImplAS);
3575 Decl->setType(Type);
3576}
3577
/// HLSL-specific processing of a variable declarator: completes the type,
/// assigns default-$Globals membership, collects resource bindings, forces
/// resource/spec-constant variables to static storage, validates explicit
/// register bindings, and finally deduces the implicit address space.
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
  if (VD->hasGlobalStorage()) {
    // make sure the declaration has a complete type
    if (SemaRef.RequireCompleteType(
            Loc: VD->getLocation(),
            T: SemaRef.getASTContext().getBaseElementType(QT: VD->getType()),
            DiagID: diag::err_typecheck_decl_incomplete_type)) {
      // Still deduce the address space so downstream code sees a
      // consistent (if invalid) declaration.
      VD->setInvalidDecl();
      deduceAddressSpace(Decl: VD);
      return;
    }

    // Global variables outside a cbuffer block that are not a resource, static,
    // groupshared, or an empty array or struct belong to the default constant
    // buffer $Globals (to be created at the end of the translation unit).
    if (IsDefaultBufferConstantDecl(VD)) {
      // update address space to hlsl_constant
      QualType NewTy = getASTContext().getAddrSpaceQualType(
          T: VD->getType(), AddressSpace: LangAS::hlsl_constant);
      VD->setType(NewTy);
      DefaultCBufferDecls.push_back(Elt: VD);
    }

    // find all resources bindings on decl
    if (VD->getType()->isHLSLIntangibleType())
      collectResourceBindingsOnVarDecl(D: VD);

    // Strip array dimensions to inspect the underlying element type.
    const Type *VarType = VD->getType().getTypePtr();
    while (VarType->isArrayType())
      VarType = VarType->getArrayElementTypeNoTypeQual();
    if (VarType->isHLSLResourceRecord() ||
        VD->hasAttr<HLSLVkConstantIdAttr>()) {
      // Make the variable for resources static. The global externally visible
      // storage is accessed through the handle, which is a member. The variable
      // itself is not externally visible.
      VD->setStorageClass(StorageClass::SC_Static);
    }

    // process explicit bindings
    processExplicitBindingsOnDecl(D: VD);
  }

  deduceAddressSpace(Decl: VD);
}
3622
3623static bool initVarDeclWithCtor(Sema &S, VarDecl *VD,
3624 MutableArrayRef<Expr *> Args) {
3625 InitializedEntity Entity = InitializedEntity::InitializeVariable(Var: VD);
3626 InitializationKind Kind = InitializationKind::CreateDirect(
3627 InitLoc: VD->getLocation(), LParenLoc: SourceLocation(), RParenLoc: SourceLocation());
3628
3629 InitializationSequence InitSeq(S, Entity, Kind, Args);
3630 if (InitSeq.Failed())
3631 return false;
3632
3633 ExprResult Init = InitSeq.Perform(S, Entity, Kind, Args);
3634 if (!Init.get())
3635 return false;
3636
3637 VD->setInit(S.MaybeCreateExprWithCleanups(SubExpr: Init.get()));
3638 VD->setInitStyle(VarDecl::CallInit);
3639 S.CheckCompleteVariableDeclaration(VD);
3640 return true;
3641}
3642
/// Synthesize the constructor call that initializes a global resource
/// variable's handle. Uses the explicit-binding constructor
/// (slot, space, range, index, name) when a register slot is present on the
/// declaration, otherwise the implicit-binding constructor
/// (space, range, index, order-id, name).
bool SemaHLSL::initGlobalResourceDecl(VarDecl *VD) {
  std::optional<uint32_t> RegisterSlot;
  uint32_t SpaceNo = 0;
  // Pull the slot/space out of any register() binding attribute.
  HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
  if (RBA) {
    if (RBA->hasRegisterSlot())
      RegisterSlot = RBA->getSlotNumber();
    SpaceNo = RBA->getSpaceNumber();
  }

  ASTContext &AST = SemaRef.getASTContext();
  uint64_t UIntTySize = AST.getTypeSize(T: AST.UnsignedIntTy);
  uint64_t IntTySize = AST.getTypeSize(T: AST.IntTy);
  // RangeSize = 1 and Index = 0: a plain (non-array) resource binding.
  IntegerLiteral *RangeSize = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(IntTySize, 1), type: AST.IntTy, l: SourceLocation());
  IntegerLiteral *Index = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, 0), type: AST.UnsignedIntTy, l: SourceLocation());
  IntegerLiteral *Space =
      IntegerLiteral::Create(C: AST, V: llvm::APInt(UIntTySize, SpaceNo),
                             type: AST.UnsignedIntTy, l: SourceLocation());
  // The variable's name is passed so diagnostics/reflection can refer to it.
  StringRef VarName = VD->getName();
  StringLiteral *Name = StringLiteral::Create(
      Ctx: AST, Str: VarName, Kind: StringLiteralKind::Ordinary, Pascal: false,
      Ty: AST.getStringLiteralArrayType(EltTy: AST.CharTy.withConst(), Length: VarName.size()),
      Locs: SourceLocation());

  // resource with explicit binding
  if (RegisterSlot.has_value()) {
    IntegerLiteral *RegSlot = IntegerLiteral::Create(
        C: AST, V: llvm::APInt(UIntTySize, RegisterSlot.value()), type: AST.UnsignedIntTy,
        l: SourceLocation());
    Expr *Args[] = {RegSlot, Space, RangeSize, Index, Name};
    return initVarDeclWithCtor(S&: SemaRef, VD, Args);
  }

  // resource with implicit binding; the order id keeps implicit bindings
  // assigned in a stable, declaration-order sequence.
  IntegerLiteral *OrderId = IntegerLiteral::Create(
      C: AST, V: llvm::APInt(UIntTySize, getNextImplicitBindingOrderID()),
      type: AST.UnsignedIntTy, l: SourceLocation());
  Expr *Args[] = {Space, RangeSize, Index, OrderId, Name};
  return initVarDeclWithCtor(S&: SemaRef, VD, Args);
}
3685
3686// Returns true if the initialization has been handled.
3687// Returns false to use default initialization.
3688bool SemaHLSL::ActOnUninitializedVarDecl(VarDecl *VD) {
3689 // Objects in the hlsl_constant address space are initialized
3690 // externally, so don't synthesize an implicit initializer.
3691 if (VD->getType().getAddressSpace() == LangAS::hlsl_constant)
3692 return true;
3693
3694 // Initialize resources
3695 if (!isResourceRecordTypeOrArrayOf(VD))
3696 return false;
3697
3698 // FIXME: We currectly support only simple resources - no arrays of resources
3699 // or resources in user defined structs.
3700 // (llvm/llvm-project#133835, llvm/llvm-project#133837)
3701 // Initialize resources at the global scope
3702 if (VD->hasGlobalStorage() && VD->getType()->isHLSLResourceRecord())
3703 return initGlobalResourceDecl(VD);
3704
3705 return false;
3706}
3707
// Walks through the global variable declaration, collects all resource binding
// requirements and adds them to Bindings.
void SemaHLSL::collectResourceBindingsOnVarDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && VD->getType()->isHLSLIntangibleType() &&
         "expected global variable that contains HLSL resource");

  // Cbuffers and Tbuffers are HLSLBufferDecl types
  // NOTE(review): VD is a VarDecl; this dyn_cast presumably never succeeds
  // unless HLSLBufferDecl shares the VarDecl hierarchy — confirm whether
  // this branch is reachable.
  if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(Val: VD)) {
    Bindings.addDeclBindingInfo(VD, ResClass: CBufferOrTBuffer->isCBuffer()
                                        ? ResourceClass::CBuffer
                                        : ResourceClass::SRV);
    return;
  }

  // Unwrap arrays
  // FIXME: Calculate array size while unwrapping
  const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
  while (Ty->isConstantArrayType()) {
    const ConstantArrayType *CAT = cast<ConstantArrayType>(Val: Ty);
    Ty = CAT->getElementType()->getUnqualifiedDesugaredType();
  }

  // Resource (or array of resources): bind using the resource class from
  // the handle's attributes.
  if (const HLSLAttributedResourceType *AttrResType =
          HLSLAttributedResourceType::findHandleTypeOnResource(RT: Ty)) {
    Bindings.addDeclBindingInfo(VD, ResClass: AttrResType->getAttrs().ResourceClass);
    return;
  }

  // User defined record type: recurse into its fields.
  if (const RecordType *RT = dyn_cast<RecordType>(Val: Ty))
    collectResourceBindingsOnUserRecordDecl(VD, RT);
}
3741
// Walks through the explicit resource binding attributes on the declaration,
// and makes sure there is a resource that matched the binding and updates
// DeclBindingInfoLists.
void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
  assert(VD->hasGlobalStorage() && "expected global variable");

  bool HasBinding = false;
  for (Attr *A : VD->attrs()) {
    // Only explicit register(...) bindings with a slot are of interest here.
    HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(Val: A);
    if (!RBA || !RBA->hasRegisterSlot())
      continue;
    HasBinding = true;

    RegisterType RT = RBA->getRegisterType();
    assert(RT != RegisterType::I && "invalid or obsolete register type should "
                                    "never have an attribute created");

    // register(c#) is a constant-buffer offset, not a resource binding; it
    // only warns if the decl also has resource bindings.
    if (RT == RegisterType::C) {
      if (Bindings.hasBindingInfoForDecl(VD))
        SemaRef.Diag(Loc: VD->getLocation(),
                     DiagID: diag::warn_hlsl_user_defined_type_missing_member)
            << static_cast<int>(RT);
      continue;
    }

    // Find DeclBindingInfo for this binding and update it, or report error
    // if it does not exist (user type does not contain resources with the
    // expected resource class).
    ResourceClass RC = getResourceClass(RT);
    if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, ResClass: RC)) {
      // update binding info
      BI->setBindingAttribute(A: RBA, BT: BindingType::Explicit);
    } else {
      SemaRef.Diag(Loc: VD->getLocation(),
                   DiagID: diag::warn_hlsl_user_defined_type_missing_member)
          << static_cast<int>(RT);
    }
  }

  // Resources with no explicit binding at all get an implicit-binding warning.
  if (!HasBinding && isResourceRecordTypeOrArrayOf(VD))
    SemaRef.Diag(Loc: VD->getLocation(), DiagID: diag::warn_hlsl_implicit_binding);
}
namespace {
/// Helper that flattens an HLSL initializer list into a sequence of scalar
/// initializers matched against the flattened destination type, then
/// regenerates properly-structured nested InitListExprs from that flat
/// sequence. DstIt/ArgIt are cursors into DestTypes/ArgExprs shared across
/// the recursive walks.
class InitListTransformer {
  Sema &S;
  ASTContext &Ctx;
  // Destination type being initialized (element type when wrapping).
  QualType InitTy;
  // Cursor into DestTypes during buildInitializerList.
  QualType *DstIt = nullptr;
  // Cursor into ArgExprs during generateInitLists.
  Expr **ArgIt = nullptr;
  // Is wrapping the destination type iterator required? This is only used for
  // incomplete array types where we loop over the destination type since we
  // don't know the full number of elements from the declaration.
  bool Wrap;

  /// Copy-initialize one scalar initializer E against the current flattened
  /// destination type and append the result to ArgExprs. Returns false only
  /// on a conversion failure.
  bool castInitializer(Expr *E) {
    assert(DstIt && "This should always be something!");
    if (DstIt == DestTypes.end()) {
      if (!Wrap) {
        ArgExprs.push_back(Elt: E);
        // This is odd, but it isn't technically a failure due to conversion, we
        // handle mismatched counts of arguments differently.
        return true;
      }
      // Incomplete array: wrap back to the element type's first leaf.
      DstIt = DestTypes.begin();
    }
    InitializedEntity Entity = InitializedEntity::InitializeParameter(
        Context&: Ctx, Type: *DstIt, /* Consumed (ObjC) */ Consumed: false);
    ExprResult Res = S.PerformCopyInitialization(Entity, EqualLoc: E->getBeginLoc(), Init: E);
    if (Res.isInvalid())
      return false;
    Expr *Init = Res.get();
    ArgExprs.push_back(Elt: Init);
    DstIt++;
    return true;
  }

  /// Recursively decompose initializer E (nested lists, vectors, arrays,
  /// records) into scalar initializers, feeding each to castInitializer.
  bool buildInitializerListImpl(Expr *E) {
    // If this is an initialization list, traverse the sub initializers.
    if (auto *Init = dyn_cast<InitListExpr>(Val: E)) {
      for (auto *SubInit : Init->inits())
        if (!buildInitializerListImpl(E: SubInit))
          return false;
      return true;
    }

    // If this is a scalar type, just enqueue the expression.
    QualType Ty = E->getType();

    // Non-aggregate records (e.g. intangible/resource types) are treated as
    // single units rather than decomposed.
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return castInitializer(E);

    // Vectors: emit one subscript expression per element.
    if (auto *VecTy = Ty->getAs<VectorType>()) {
      uint64_t Size = VecTy->getNumElements();

      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());

        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!castInitializer(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    // Constant arrays: recurse into each element via subscripting.
    if (auto *ArrTy = dyn_cast<ConstantArrayType>(Val: Ty.getTypePtr())) {
      uint64_t Size = ArrTy->getZExtSize();
      QualType SizeTy = Ctx.getSizeType();
      uint64_t SizeTySize = Ctx.getTypeSize(T: SizeTy);
      for (uint64_t I = 0; I < Size; ++I) {
        auto *Idx = IntegerLiteral::Create(C: Ctx, V: llvm::APInt(SizeTySize, I),
                                           type: SizeTy, l: SourceLocation());
        ExprResult ElExpr = S.CreateBuiltinArraySubscriptExpr(
            Base: E, LLoc: E->getBeginLoc(), Idx, RLoc: E->getEndLoc());
        if (ElExpr.isInvalid())
          return false;
        if (!buildInitializerListImpl(E: ElExpr.get()))
          return false;
      }
      return true;
    }

    // Records: walk base classes root-first (HLSL is single-inheritance),
    // then recurse into each field via member access.
    if (auto *RTy = Ty->getAs<RecordType>()) {
      llvm::SmallVector<const RecordType *> RecordTypes;
      RecordTypes.push_back(Elt: RTy);
      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordTypes.push_back(Elt: D->bases_begin()->getType()->getAs<RecordType>());
      }
      while (!RecordTypes.empty()) {
        const RecordType *RT = RecordTypes.pop_back_val();
        for (auto *FD : RT->getDecl()->fields()) {
          DeclAccessPair Found = DeclAccessPair::make(D: FD, AS: FD->getAccess());
          DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
          ExprResult Res = S.BuildFieldReferenceExpr(
              BaseExpr: E, IsArrow: false, OpLoc: E->getBeginLoc(), SS: CXXScopeSpec(), Field: FD, FoundDecl: Found, MemberNameInfo: NameInfo);
          if (Res.isInvalid())
            return false;
          if (!buildInitializerListImpl(E: Res.get()))
            return false;
        }
      }
    }
    return true;
  }

  /// Rebuild a (possibly nested) initializer for Ty, consuming scalar
  /// expressions from ArgExprs via ArgIt.
  Expr *generateInitListsImpl(QualType Ty) {
    assert(ArgIt != ArgExprs.end() && "Something is off in iteration!");
    if (Ty->isScalarType() || (Ty->isRecordType() && !Ty->isAggregateType()))
      return *(ArgIt++);

    llvm::SmallVector<Expr *> Inits;
    assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
    Ty = Ty.getDesugaredType(Context: Ctx);
    if (Ty->isVectorType() || Ty->isConstantArrayType()) {
      QualType ElTy;
      uint64_t Size = 0;
      if (auto *ATy = Ty->getAs<VectorType>()) {
        ElTy = ATy->getElementType();
        Size = ATy->getNumElements();
      } else {
        auto *VTy = cast<ConstantArrayType>(Val: Ty.getTypePtr());
        ElTy = VTy->getElementType();
        Size = VTy->getZExtSize();
      }
      for (uint64_t I = 0; I < Size; ++I)
        Inits.push_back(Elt: generateInitListsImpl(Ty: ElTy));
    }
    if (auto *RTy = Ty->getAs<RecordType>()) {
      // Mirror buildInitializerListImpl: bases root-first, then fields.
      llvm::SmallVector<const RecordType *> RecordTypes;
      RecordTypes.push_back(Elt: RTy);
      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
        assert(D->getNumBases() == 1 &&
               "HLSL doesn't support multiple inheritance");
        RecordTypes.push_back(Elt: D->bases_begin()->getType()->getAs<RecordType>());
      }
      while (!RecordTypes.empty()) {
        const RecordType *RT = RecordTypes.pop_back_val();
        for (auto *FD : RT->getDecl()->fields()) {
          Inits.push_back(Elt: generateInitListsImpl(Ty: FD->getType()));
        }
      }
    }
    // NOTE(review): Inits.front()/back() assume at least one element; an
    // empty aggregate here would be UB — presumably ruled out earlier,
    // verify against callers.
    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    NewInit->setType(Ty);
    return NewInit;
  }

public:
  // Flattened leaf types of the destination (element type when wrapping).
  llvm::SmallVector<QualType, 16> DestTypes;
  // Flat list of converted scalar initializer expressions.
  llvm::SmallVector<Expr *, 16> ArgExprs;
  InitListTransformer(Sema &SemaRef, const InitializedEntity &Entity)
      : S(SemaRef), Ctx(SemaRef.getASTContext()),
        Wrap(Entity.getType()->isIncompleteArrayType()) {
    InitTy = Entity.getType().getNonReferenceType();
    // When we're generating initializer lists for incomplete array types we
    // need to wrap around both when building the initializers and when
    // generating the final initializer lists.
    if (Wrap) {
      assert(InitTy->isIncompleteArrayType());
      const IncompleteArrayType *IAT = Ctx.getAsIncompleteArrayType(T: InitTy);
      InitTy = IAT->getElementType();
    }
    BuildFlattenedTypeList(BaseTy: InitTy, List&: DestTypes);
    DstIt = DestTypes.begin();
  }

  /// Phase 1: flatten and convert the initializers into ArgExprs.
  bool buildInitializerList(Expr *E) { return buildInitializerListImpl(E); }

  /// Phase 2: regenerate structured init lists from ArgExprs. For
  /// incomplete arrays, emits one element list per wrap-around pass and
  /// synthesizes the now-known constant array type.
  Expr *generateInitLists() {
    assert(!ArgExprs.empty() &&
           "Call buildInitializerList to generate argument expressions.");
    ArgIt = ArgExprs.begin();
    if (!Wrap)
      return generateInitListsImpl(Ty: InitTy);
    llvm::SmallVector<Expr *> Inits;
    while (ArgIt != ArgExprs.end())
      Inits.push_back(Elt: generateInitListsImpl(Ty: InitTy));

    auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
                                           Inits, Inits.back()->getEndLoc());
    llvm::APInt ArySize(64, Inits.size());
    NewInit->setType(Ctx.getConstantArrayType(EltTy: InitTy, ArySize, SizeExpr: nullptr,
                                              ASM: ArraySizeModifier::Normal, IndexTypeQuals: 0));
    return NewInit;
  }
};
} // namespace
3979
/// Rewrite an HLSL initializer list in place so its structure exactly
/// matches the destination type: flatten all initializers to scalars,
/// verify the element count, then rebuild nested init lists. Returns false
/// (after diagnosing) on conversion failure or count mismatch.
bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
                                 InitListExpr *Init) {
  // If the initializer is a scalar, just return it.
  if (Init->getType()->isScalarType())
    return true;
  ASTContext &Ctx = SemaRef.getASTContext();
  InitListTransformer ILT(SemaRef, Entity);

  for (unsigned I = 0; I < Init->getNumInits(); ++I) {
    Expr *E = Init->getInit(Init: I);
    // Side-effecting initializers must be evaluated exactly once, so wrap
    // them in an OpaqueValueExpr (materializing record temporaries first)
    // before the flattening below references them multiple times.
    if (E->HasSideEffects(Ctx)) {
      QualType Ty = E->getType();
      if (Ty->isRecordType())
        E = new (Ctx) MaterializeTemporaryExpr(Ty, E, E->isLValue());
      E = new (Ctx) OpaqueValueExpr(E->getBeginLoc(), Ty, E->getValueKind(),
                                    E->getObjectKind(), E);
      Init->setInit(Init: I, expr: E);
    }
    if (!ILT.buildInitializerList(E))
      return false;
  }
  size_t ExpectedSize = ILT.DestTypes.size();
  size_t ActualSize = ILT.ArgExprs.size();
  // For incomplete arrays it is completely arbitrary to choose whether we think
  // the user intended fewer or more elements. This implementation assumes that
  // the user intended more, and errors that there are too few initializers to
  // complete the final element.
  if (Entity.getType()->isIncompleteArrayType())
    ExpectedSize =
        ((ActualSize + ExpectedSize - 1) / ExpectedSize) * ExpectedSize;

  // An initializer list might be attempting to initialize a reference or
  // rvalue-reference. When checking the initializer we should look through
  // the reference.
  QualType InitTy = Entity.getType().getNonReferenceType();
  if (InitTy.hasAddressSpace())
    InitTy = SemaRef.getASTContext().removeAddrSpaceQualType(T: InitTy);
  if (ExpectedSize != ActualSize) {
    int TooManyOrFew = ActualSize > ExpectedSize ? 1 : 0;
    SemaRef.Diag(Loc: Init->getBeginLoc(), DiagID: diag::err_hlsl_incorrect_num_initializers)
        << TooManyOrFew << InitTy << ExpectedSize << ActualSize;
    return false;
  }

  // generateInitListsImpl will always return an InitListExpr here, because the
  // scalar case is handled above.
  auto *NewInit = cast<InitListExpr>(Val: ILT.generateInitLists());
  Init->resizeInits(Context: Ctx, NumInits: NewInit->getNumInits());
  for (unsigned I = 0; I < NewInit->getNumInits(); ++I)
    Init->updateInit(C: Ctx, Init: I, expr: NewInit->getInit(Init: I));
  return true;
}
4032
/// Handle initialization of a variable carrying [[vk::constant_id]]: rewrite
/// its initializer into a call to the spec-constant builtin, using the
/// original initializer as the default value. Returns true when the (possibly
/// rewritten) initializer is valid; false (with the decl marked invalid)
/// when the initializer is not a constant expression. Variables without the
/// attribute are left untouched.
bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
  const HLSLVkConstantIdAttr *ConstIdAttr =
      VDecl->getAttr<HLSLVkConstantIdAttr>();
  if (!ConstIdAttr)
    return true;

  ASTContext &Context = SemaRef.getASTContext();

  // Spec-constant defaults must be compile-time constants.
  APValue InitValue;
  if (!Init->isCXX11ConstantExpr(Ctx: Context, Result: &InitValue)) {
    Diag(Loc: VDecl->getLocation(), DiagID: diag::err_specialization_const);
    VDecl->setInvalidDecl();
    return false;
  }

  // Pick the builtin overload matching the variable's type.
  Builtin::ID BID =
      getSpecConstBuiltinId(Type: VDecl->getType()->getUnqualifiedDesugaredType());

  // Argument 1: The ID from the attribute
  int ConstantID = ConstIdAttr->getId();
  llvm::APInt IDVal(Context.getIntWidth(T: Context.IntTy), ConstantID);
  Expr *IdExpr = IntegerLiteral::Create(C: Context, V: IDVal, type: Context.IntTy,
                                        l: ConstIdAttr->getLocation());

  SmallVector<Expr *, 2> Args = {IdExpr, Init};
  Expr *C = SemaRef.BuildBuiltinCallExpr(Loc: Init->getExprLoc(), Id: BID, CallArgs: Args);
  // If the builtin's result type doesn't match the variable, insert a cast.
  // NOTE(review): the cast target is built from Init->getType() rather than
  // VDecl->getType(); presumably they agree here — verify.
  if (C->getType()->getCanonicalTypeUnqualified() !=
      VDecl->getType()->getCanonicalTypeUnqualified()) {
    C = SemaRef
            .BuildCStyleCastExpr(LParenLoc: SourceLocation(),
                                 Ty: Context.getTrivialTypeSourceInfo(
                                     T: Init->getType(), Loc: Init->getExprLoc()),
                                 RParenLoc: SourceLocation(), Op: C)
            .get();
  }
  Init = C;
  return true;
}
4071

source code of clang/lib/Sema/SemaHLSL.cpp