//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This implements Semantic Analysis for HLSL constructs.
//===----------------------------------------------------------------------===//
10 | |
11 | #include "clang/Sema/SemaHLSL.h" |
12 | #include "clang/Basic/DiagnosticSema.h" |
13 | #include "clang/Basic/LLVM.h" |
14 | #include "clang/Basic/TargetInfo.h" |
15 | #include "clang/Sema/Sema.h" |
16 | #include "llvm/ADT/STLExtras.h" |
17 | #include "llvm/ADT/StringExtras.h" |
18 | #include "llvm/ADT/StringRef.h" |
19 | #include "llvm/Support/ErrorHandling.h" |
20 | #include "llvm/TargetParser/Triple.h" |
21 | #include <iterator> |
22 | |
23 | using namespace clang; |
24 | |
25 | SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} |
26 | |
27 | Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, |
28 | SourceLocation KwLoc, IdentifierInfo *Ident, |
29 | SourceLocation IdentLoc, |
30 | SourceLocation LBrace) { |
31 | // For anonymous namespace, take the location of the left brace. |
32 | DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); |
33 | HLSLBufferDecl *Result = HLSLBufferDecl::Create( |
34 | C&: getASTContext(), LexicalParent, CBuffer, KwLoc, ID: Ident, IDLoc: IdentLoc, LBrace); |
35 | |
36 | SemaRef.PushOnScopeChains(Result, BufferScope); |
37 | SemaRef.PushDeclContext(BufferScope, Result); |
38 | |
39 | return Result; |
40 | } |
41 | |
42 | void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { |
43 | auto *BufDecl = cast<HLSLBufferDecl>(Val: Dcl); |
44 | BufDecl->setRBraceLoc(RBrace); |
45 | SemaRef.PopDeclContext(); |
46 | } |
47 | |
48 | HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, |
49 | const AttributeCommonInfo &AL, |
50 | int X, int Y, int Z) { |
51 | if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { |
52 | if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { |
53 | Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; |
54 | Diag(AL.getLoc(), diag::note_conflicting_attribute); |
55 | } |
56 | return nullptr; |
57 | } |
58 | return ::new (getASTContext()) |
59 | HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); |
60 | } |
61 | |
62 | HLSLShaderAttr * |
63 | SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, |
64 | HLSLShaderAttr::ShaderType ShaderType) { |
65 | if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { |
66 | if (NT->getType() != ShaderType) { |
67 | Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; |
68 | Diag(AL.getLoc(), diag::note_conflicting_attribute); |
69 | } |
70 | return nullptr; |
71 | } |
72 | return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); |
73 | } |
74 | |
75 | HLSLParamModifierAttr * |
76 | SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, |
77 | HLSLParamModifierAttr::Spelling Spelling) { |
78 | // We can only merge an `in` attribute with an `out` attribute. All other |
79 | // combinations of duplicated attributes are ill-formed. |
80 | if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { |
81 | if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || |
82 | (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { |
83 | D->dropAttr<HLSLParamModifierAttr>(); |
84 | SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; |
85 | return HLSLParamModifierAttr::Create( |
86 | getASTContext(), /*MergedSpelling=*/true, AdjustedRange, |
87 | HLSLParamModifierAttr::Keyword_inout); |
88 | } |
89 | Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; |
90 | Diag(PA->getLocation(), diag::note_conflicting_attribute); |
91 | return nullptr; |
92 | } |
93 | return HLSLParamModifierAttr::Create(getASTContext(), AL); |
94 | } |
95 | |
96 | void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { |
97 | auto &TargetInfo = getASTContext().getTargetInfo(); |
98 | |
99 | if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) |
100 | return; |
101 | |
102 | StringRef Env = TargetInfo.getTriple().getEnvironmentName(); |
103 | HLSLShaderAttr::ShaderType ShaderType; |
104 | if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { |
105 | if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { |
106 | // The entry point is already annotated - check that it matches the |
107 | // triple. |
108 | if (Shader->getType() != ShaderType) { |
109 | Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) |
110 | << Shader; |
111 | FD->setInvalidDecl(); |
112 | } |
113 | } else { |
114 | // Implicitly add the shader attribute if the entry function isn't |
115 | // explicitly annotated. |
116 | FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType, |
117 | FD->getBeginLoc())); |
118 | } |
119 | } else { |
120 | switch (TargetInfo.getTriple().getEnvironment()) { |
121 | case llvm::Triple::UnknownEnvironment: |
122 | case llvm::Triple::Library: |
123 | break; |
124 | default: |
125 | llvm_unreachable("Unhandled environment in triple" ); |
126 | } |
127 | } |
128 | } |
129 | |
130 | void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { |
131 | const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); |
132 | assert(ShaderAttr && "Entry point has no shader attribute" ); |
133 | HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); |
134 | |
135 | switch (ST) { |
136 | case HLSLShaderAttr::Pixel: |
137 | case HLSLShaderAttr::Vertex: |
138 | case HLSLShaderAttr::Geometry: |
139 | case HLSLShaderAttr::Hull: |
140 | case HLSLShaderAttr::Domain: |
141 | case HLSLShaderAttr::RayGeneration: |
142 | case HLSLShaderAttr::Intersection: |
143 | case HLSLShaderAttr::AnyHit: |
144 | case HLSLShaderAttr::ClosestHit: |
145 | case HLSLShaderAttr::Miss: |
146 | case HLSLShaderAttr::Callable: |
147 | if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { |
148 | DiagnoseAttrStageMismatch(NT, ST, |
149 | {HLSLShaderAttr::Compute, |
150 | HLSLShaderAttr::Amplification, |
151 | HLSLShaderAttr::Mesh}); |
152 | FD->setInvalidDecl(); |
153 | } |
154 | break; |
155 | |
156 | case HLSLShaderAttr::Compute: |
157 | case HLSLShaderAttr::Amplification: |
158 | case HLSLShaderAttr::Mesh: |
159 | if (!FD->hasAttr<HLSLNumThreadsAttr>()) { |
160 | Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) |
161 | << HLSLShaderAttr::ConvertShaderTypeToStr(ST); |
162 | FD->setInvalidDecl(); |
163 | } |
164 | break; |
165 | } |
166 | |
167 | for (ParmVarDecl *Param : FD->parameters()) { |
168 | if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { |
169 | CheckSemanticAnnotation(EntryPoint: FD, Param, AnnotationAttr: AnnotationAttr); |
170 | } else { |
171 | // FIXME: Handle struct parameters where annotations are on struct fields. |
172 | // See: https://github.com/llvm/llvm-project/issues/57875 |
173 | Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); |
174 | Diag(Param->getLocation(), diag::note_previous_decl) << Param; |
175 | FD->setInvalidDecl(); |
176 | } |
177 | } |
178 | // FIXME: Verify return type semantic annotation. |
179 | } |
180 | |
181 | void SemaHLSL::CheckSemanticAnnotation( |
182 | FunctionDecl *EntryPoint, const Decl *Param, |
183 | const HLSLAnnotationAttr *AnnotationAttr) { |
184 | auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); |
185 | assert(ShaderAttr && "Entry point has no shader attribute" ); |
186 | HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); |
187 | |
188 | switch (AnnotationAttr->getKind()) { |
189 | case attr::HLSLSV_DispatchThreadID: |
190 | case attr::HLSLSV_GroupIndex: |
191 | if (ST == HLSLShaderAttr::Compute) |
192 | return; |
193 | DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute}); |
194 | break; |
195 | default: |
196 | llvm_unreachable("Unknown HLSLAnnotationAttr" ); |
197 | } |
198 | } |
199 | |
200 | void SemaHLSL::DiagnoseAttrStageMismatch( |
201 | const Attr *A, HLSLShaderAttr::ShaderType Stage, |
202 | std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) { |
203 | SmallVector<StringRef, 8> StageStrings; |
204 | llvm::transform(AllowedStages, std::back_inserter(x&: StageStrings), |
205 | [](HLSLShaderAttr::ShaderType ST) { |
206 | return StringRef( |
207 | HLSLShaderAttr::ConvertShaderTypeToStr(ST)); |
208 | }); |
209 | Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) |
210 | << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) |
211 | << (AllowedStages.size() != 1) << join(StageStrings, ", " ); |
212 | } |
213 | |