1//=== ParseHLSLRootSignature.cpp - Parse Root Signature -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/Parse/ParseHLSLRootSignature.h"
10
11#include "clang/Lex/LiteralSupport.h"
12
13using namespace llvm::hlsl::rootsig;
14
15namespace clang {
16namespace hlsl {
17
18using TokenKind = RootSignatureToken::Kind;
19
20static const TokenKind RootElementKeywords[] = {
21 TokenKind::kw_RootFlags,
22 TokenKind::kw_CBV,
23 TokenKind::kw_UAV,
24 TokenKind::kw_SRV,
25 TokenKind::kw_DescriptorTable,
26 TokenKind::kw_StaticSampler,
27};
28
29RootSignatureParser::RootSignatureParser(
30 llvm::dxbc::RootSignatureVersion Version,
31 SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
32 Preprocessor &PP)
33 : Version(Version), Elements(Elements), Signature(Signature),
34 Lexer(Signature->getString()), PP(PP), CurToken(0) {}
35
36bool RootSignatureParser::parse() {
37 // Iterate as many RootSignatureElements as possible, until we hit the
38 // end of the stream
39 bool HadError = false;
40 while (!peekExpectedToken(Expected: TokenKind::end_of_stream)) {
41 if (tryConsumeExpectedToken(Expected: TokenKind::kw_RootFlags)) {
42 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
43 auto Flags = parseRootFlags();
44 if (!Flags.has_value()) {
45 HadError = true;
46 skipUntilExpectedToken(Expected: RootElementKeywords);
47 continue;
48 }
49
50 Elements.emplace_back(Args&: ElementLoc, Args&: *Flags);
51 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_RootConstants)) {
52 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
53 auto Constants = parseRootConstants();
54 if (!Constants.has_value()) {
55 HadError = true;
56 skipUntilExpectedToken(Expected: RootElementKeywords);
57 continue;
58 }
59 Elements.emplace_back(Args&: ElementLoc, Args&: *Constants);
60 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_DescriptorTable)) {
61 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
62 auto Table = parseDescriptorTable();
63 if (!Table.has_value()) {
64 HadError = true;
65 // We are within a DescriptorTable, we will do our best to recover
66 // by skipping until we encounter the expected closing ')'.
67 skipUntilClosedParens();
68 consumeNextToken();
69 skipUntilExpectedToken(Expected: RootElementKeywords);
70 continue;
71 }
72 Elements.emplace_back(Args&: ElementLoc, Args&: *Table);
73 } else if (tryConsumeExpectedToken(
74 Expected: {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
75 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
76 auto Descriptor = parseRootDescriptor();
77 if (!Descriptor.has_value()) {
78 HadError = true;
79 skipUntilExpectedToken(Expected: RootElementKeywords);
80 continue;
81 }
82 Elements.emplace_back(Args&: ElementLoc, Args&: *Descriptor);
83 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_StaticSampler)) {
84 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
85 auto Sampler = parseStaticSampler();
86 if (!Sampler.has_value()) {
87 HadError = true;
88 skipUntilExpectedToken(Expected: RootElementKeywords);
89 continue;
90 }
91 Elements.emplace_back(Args&: ElementLoc, Args&: *Sampler);
92 } else {
93 HadError = true;
94 consumeNextToken(); // let diagnostic be at the start of invalid token
95 reportDiag(DiagID: diag::err_hlsl_invalid_token)
96 << /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
97 skipUntilExpectedToken(Expected: RootElementKeywords);
98 continue;
99 }
100
101 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma)) {
102 // ',' denotes another element, otherwise, expected to be at end of stream
103 break;
104 }
105 }
106
107 return HadError ||
108 consumeExpectedToken(Expected: TokenKind::end_of_stream,
109 DiagID: diag::err_expected_either, Context: TokenKind::pu_comma);
110}
111
112template <typename FlagType>
113static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
114 if (!Flags.has_value())
115 return Flag;
116
117 return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
118 llvm::to_underlying(Flag));
119}
120
121std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
122 assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
123 "Expects to only be invoked starting at given keyword");
124
125 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
126 Context: CurToken.TokKind))
127 return std::nullopt;
128
129 std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
130
131 // Handle valid empty case
132 if (tryConsumeExpectedToken(Expected: TokenKind::pu_r_paren))
133 return Flags;
134
135 // Handle the edge-case of '0' to specify no flags set
136 if (tryConsumeExpectedToken(Expected: TokenKind::int_literal)) {
137 if (!verifyZeroFlag()) {
138 reportDiag(DiagID: diag::err_hlsl_rootsig_non_zero_flag);
139 return std::nullopt;
140 }
141 } else {
142 // Otherwise, parse as many flags as possible
143 TokenKind Expected[] = {
144#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
145#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
146 };
147
148 do {
149 if (tryConsumeExpectedToken(Expected)) {
150 switch (CurToken.TokKind) {
151#define ROOT_FLAG_ENUM(NAME, LIT) \
152 case TokenKind::en_##NAME: \
153 Flags = maybeOrFlag<llvm::dxbc::RootFlags>(Flags, \
154 llvm::dxbc::RootFlags::NAME); \
155 break;
156#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
157 default:
158 llvm_unreachable("Switch for consumed enum token was not provided");
159 }
160 } else {
161 consumeNextToken(); // consume token to point at invalid token
162 reportDiag(DiagID: diag::err_hlsl_invalid_token)
163 << /*value=*/1 << /*value of*/ TokenKind::kw_RootFlags;
164 return std::nullopt;
165 }
166 } while (tryConsumeExpectedToken(Expected: TokenKind::pu_or));
167 }
168
169 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
170 Context: TokenKind::pu_comma))
171 return std::nullopt;
172
173 return Flags;
174}
175
176std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
177 assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
178 "Expects to only be invoked starting at given keyword");
179
180 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
181 Context: CurToken.TokKind))
182 return std::nullopt;
183
184 RootConstants Constants;
185
186 auto Params = parseRootConstantParams();
187 if (!Params.has_value())
188 return std::nullopt;
189
190 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
191 Context: TokenKind::pu_comma))
192 return std::nullopt;
193
194 // Check mandatory parameters where provided
195 if (!Params->Num32BitConstants.has_value()) {
196 reportDiag(DiagID: diag::err_hlsl_rootsig_missing_param)
197 << TokenKind::kw_num32BitConstants;
198 return std::nullopt;
199 }
200
201 Constants.Num32BitConstants = Params->Num32BitConstants.value();
202
203 if (!Params->Reg.has_value()) {
204 reportDiag(DiagID: diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
205 return std::nullopt;
206 }
207
208 Constants.Reg = Params->Reg.value();
209
210 // Fill in optional parameters
211 if (Params->Visibility.has_value())
212 Constants.Visibility = Params->Visibility.value();
213
214 if (Params->Space.has_value())
215 Constants.Space = Params->Space.value();
216
217 return Constants;
218}
219
220std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
221 assert((CurToken.TokKind == TokenKind::kw_CBV ||
222 CurToken.TokKind == TokenKind::kw_SRV ||
223 CurToken.TokKind == TokenKind::kw_UAV) &&
224 "Expects to only be invoked starting at given keyword");
225
226 TokenKind DescriptorKind = CurToken.TokKind;
227
228 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
229 Context: CurToken.TokKind))
230 return std::nullopt;
231
232 RootDescriptor Descriptor;
233 TokenKind ExpectedReg;
234 switch (DescriptorKind) {
235 default:
236 llvm_unreachable("Switch for consumed token was not provided");
237 case TokenKind::kw_CBV:
238 Descriptor.Type = DescriptorType::CBuffer;
239 ExpectedReg = TokenKind::bReg;
240 break;
241 case TokenKind::kw_SRV:
242 Descriptor.Type = DescriptorType::SRV;
243 ExpectedReg = TokenKind::tReg;
244 break;
245 case TokenKind::kw_UAV:
246 Descriptor.Type = DescriptorType::UAV;
247 ExpectedReg = TokenKind::uReg;
248 break;
249 }
250 Descriptor.setDefaultFlags(Version);
251
252 auto Params = parseRootDescriptorParams(DescKind: DescriptorKind, RegType: ExpectedReg);
253 if (!Params.has_value())
254 return std::nullopt;
255
256 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
257 Context: TokenKind::pu_comma))
258 return std::nullopt;
259
260 // Check mandatory parameters were provided
261 if (!Params->Reg.has_value()) {
262 reportDiag(DiagID: diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
263 return std::nullopt;
264 }
265
266 Descriptor.Reg = Params->Reg.value();
267
268 // Fill in optional values
269 if (Params->Space.has_value())
270 Descriptor.Space = Params->Space.value();
271
272 if (Params->Visibility.has_value())
273 Descriptor.Visibility = Params->Visibility.value();
274
275 if (Params->Flags.has_value())
276 Descriptor.Flags = Params->Flags.value();
277
278 return Descriptor;
279}
280
281std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
282 assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
283 "Expects to only be invoked starting at given keyword");
284
285 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
286 Context: CurToken.TokKind))
287 return std::nullopt;
288
289 DescriptorTable Table;
290 std::optional<llvm::dxbc::ShaderVisibility> Visibility;
291
292 // Iterate as many Clauses as possible, until we hit ')'
293 while (!peekExpectedToken(Expected: TokenKind::pu_r_paren)) {
294 if (tryConsumeExpectedToken(Expected: {TokenKind::kw_CBV, TokenKind::kw_SRV,
295 TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
296 // DescriptorTableClause - CBV, SRV, UAV, or Sampler
297 SourceLocation ElementLoc = getTokenLocation(Tok: CurToken);
298 auto Clause = parseDescriptorTableClause();
299 if (!Clause.has_value()) {
300 // We are within a DescriptorTableClause, we will do our best to recover
301 // by skipping until we encounter the expected closing ')'
302 skipUntilExpectedToken(Expected: TokenKind::pu_r_paren);
303 consumeNextToken();
304 return std::nullopt;
305 }
306 Elements.emplace_back(Args&: ElementLoc, Args&: *Clause);
307 Table.NumClauses++;
308 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_visibility)) {
309 // visibility = SHADER_VISIBILITY
310 if (Visibility.has_value()) {
311 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
312 return std::nullopt;
313 }
314
315 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
316 return std::nullopt;
317
318 Visibility = parseShaderVisibility(Context: TokenKind::kw_visibility);
319 if (!Visibility.has_value())
320 return std::nullopt;
321 } else {
322 consumeNextToken(); // let diagnostic be at the start of invalid token
323 reportDiag(DiagID: diag::err_hlsl_invalid_token)
324 << /*parameter=*/0 << /*param of*/ TokenKind::kw_DescriptorTable;
325 return std::nullopt;
326 }
327
328 // ',' denotes another element, otherwise, expected to be at ')'
329 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma))
330 break;
331 }
332
333 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
334 Context: TokenKind::pu_comma))
335 return std::nullopt;
336
337 // Fill in optional visibility
338 if (Visibility.has_value())
339 Table.Visibility = Visibility.value();
340
341 return Table;
342}
343
344std::optional<DescriptorTableClause>
345RootSignatureParser::parseDescriptorTableClause() {
346 assert((CurToken.TokKind == TokenKind::kw_CBV ||
347 CurToken.TokKind == TokenKind::kw_SRV ||
348 CurToken.TokKind == TokenKind::kw_UAV ||
349 CurToken.TokKind == TokenKind::kw_Sampler) &&
350 "Expects to only be invoked starting at given keyword");
351
352 TokenKind ParamKind = CurToken.TokKind;
353
354 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
355 Context: CurToken.TokKind))
356 return std::nullopt;
357
358 DescriptorTableClause Clause;
359 TokenKind ExpectedReg;
360 switch (ParamKind) {
361 default:
362 llvm_unreachable("Switch for consumed token was not provided");
363 case TokenKind::kw_CBV:
364 Clause.Type = ClauseType::CBuffer;
365 ExpectedReg = TokenKind::bReg;
366 break;
367 case TokenKind::kw_SRV:
368 Clause.Type = ClauseType::SRV;
369 ExpectedReg = TokenKind::tReg;
370 break;
371 case TokenKind::kw_UAV:
372 Clause.Type = ClauseType::UAV;
373 ExpectedReg = TokenKind::uReg;
374 break;
375 case TokenKind::kw_Sampler:
376 Clause.Type = ClauseType::Sampler;
377 ExpectedReg = TokenKind::sReg;
378 break;
379 }
380 Clause.setDefaultFlags(Version);
381
382 auto Params = parseDescriptorTableClauseParams(ClauseKind: ParamKind, RegType: ExpectedReg);
383 if (!Params.has_value())
384 return std::nullopt;
385
386 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
387 Context: TokenKind::pu_comma))
388 return std::nullopt;
389
390 // Check mandatory parameters were provided
391 if (!Params->Reg.has_value()) {
392 reportDiag(DiagID: diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
393 return std::nullopt;
394 }
395
396 Clause.Reg = Params->Reg.value();
397
398 // Fill in optional values
399 if (Params->NumDescriptors.has_value())
400 Clause.NumDescriptors = Params->NumDescriptors.value();
401
402 if (Params->Space.has_value())
403 Clause.Space = Params->Space.value();
404
405 if (Params->Offset.has_value())
406 Clause.Offset = Params->Offset.value();
407
408 if (Params->Flags.has_value())
409 Clause.Flags = Params->Flags.value();
410
411 return Clause;
412}
413
414std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
415 assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
416 "Expects to only be invoked starting at given keyword");
417
418 if (consumeExpectedToken(Expected: TokenKind::pu_l_paren, DiagID: diag::err_expected_after,
419 Context: CurToken.TokKind))
420 return std::nullopt;
421
422 StaticSampler Sampler;
423
424 auto Params = parseStaticSamplerParams();
425 if (!Params.has_value())
426 return std::nullopt;
427
428 if (consumeExpectedToken(Expected: TokenKind::pu_r_paren, DiagID: diag::err_expected_either,
429 Context: TokenKind::pu_comma))
430 return std::nullopt;
431
432 // Check mandatory parameters were provided
433 if (!Params->Reg.has_value()) {
434 reportDiag(DiagID: diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
435 return std::nullopt;
436 }
437
438 Sampler.Reg = Params->Reg.value();
439
440 // Fill in optional values
441 if (Params->Filter.has_value())
442 Sampler.Filter = Params->Filter.value();
443
444 if (Params->AddressU.has_value())
445 Sampler.AddressU = Params->AddressU.value();
446
447 if (Params->AddressV.has_value())
448 Sampler.AddressV = Params->AddressV.value();
449
450 if (Params->AddressW.has_value())
451 Sampler.AddressW = Params->AddressW.value();
452
453 if (Params->MipLODBias.has_value())
454 Sampler.MipLODBias = Params->MipLODBias.value();
455
456 if (Params->MaxAnisotropy.has_value())
457 Sampler.MaxAnisotropy = Params->MaxAnisotropy.value();
458
459 if (Params->CompFunc.has_value())
460 Sampler.CompFunc = Params->CompFunc.value();
461
462 if (Params->BorderColor.has_value())
463 Sampler.BorderColor = Params->BorderColor.value();
464
465 if (Params->MinLOD.has_value())
466 Sampler.MinLOD = Params->MinLOD.value();
467
468 if (Params->MaxLOD.has_value())
469 Sampler.MaxLOD = Params->MaxLOD.value();
470
471 if (Params->Space.has_value())
472 Sampler.Space = Params->Space.value();
473
474 if (Params->Visibility.has_value())
475 Sampler.Visibility = Params->Visibility.value();
476
477 return Sampler;
478}
479
480// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
481// order and only exactly once. The following methods will parse through as
482// many arguments as possible reporting an error if a duplicate is seen.
483std::optional<RootSignatureParser::ParsedConstantParams>
484RootSignatureParser::parseRootConstantParams() {
485 assert(CurToken.TokKind == TokenKind::pu_l_paren &&
486 "Expects to only be invoked starting at given token");
487
488 ParsedConstantParams Params;
489 while (!peekExpectedToken(Expected: TokenKind::pu_r_paren)) {
490 if (tryConsumeExpectedToken(Expected: TokenKind::kw_num32BitConstants)) {
491 // `num32BitConstants` `=` POS_INT
492 if (Params.Num32BitConstants.has_value()) {
493 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
494 return std::nullopt;
495 }
496
497 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
498 return std::nullopt;
499
500 auto Num32BitConstants = parseUIntParam();
501 if (!Num32BitConstants.has_value())
502 return std::nullopt;
503 Params.Num32BitConstants = Num32BitConstants;
504 } else if (tryConsumeExpectedToken(Expected: TokenKind::bReg)) {
505 // `b` POS_INT
506 if (Params.Reg.has_value()) {
507 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
508 return std::nullopt;
509 }
510 auto Reg = parseRegister();
511 if (!Reg.has_value())
512 return std::nullopt;
513 Params.Reg = Reg;
514 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_space)) {
515 // `space` `=` POS_INT
516 if (Params.Space.has_value()) {
517 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
518 return std::nullopt;
519 }
520
521 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
522 return std::nullopt;
523
524 auto Space = parseUIntParam();
525 if (!Space.has_value())
526 return std::nullopt;
527 Params.Space = Space;
528 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_visibility)) {
529 // `visibility` `=` SHADER_VISIBILITY
530 if (Params.Visibility.has_value()) {
531 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
532 return std::nullopt;
533 }
534
535 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
536 return std::nullopt;
537
538 auto Visibility = parseShaderVisibility(Context: TokenKind::kw_visibility);
539 if (!Visibility.has_value())
540 return std::nullopt;
541 Params.Visibility = Visibility;
542 } else {
543 consumeNextToken(); // let diagnostic be at the start of invalid token
544 reportDiag(DiagID: diag::err_hlsl_invalid_token)
545 << /*parameter=*/0 << /*param of*/ TokenKind::kw_RootConstants;
546 return std::nullopt;
547 }
548
549 // ',' denotes another element, otherwise, expected to be at ')'
550 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma))
551 break;
552 }
553
554 return Params;
555}
556
557std::optional<RootSignatureParser::ParsedRootDescriptorParams>
558RootSignatureParser::parseRootDescriptorParams(TokenKind DescKind,
559 TokenKind RegType) {
560 assert(CurToken.TokKind == TokenKind::pu_l_paren &&
561 "Expects to only be invoked starting at given token");
562
563 ParsedRootDescriptorParams Params;
564 while (!peekExpectedToken(Expected: TokenKind::pu_r_paren)) {
565 if (tryConsumeExpectedToken(Expected: RegType)) {
566 // ( `b` | `t` | `u`) POS_INT
567 if (Params.Reg.has_value()) {
568 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
569 return std::nullopt;
570 }
571 auto Reg = parseRegister();
572 if (!Reg.has_value())
573 return std::nullopt;
574 Params.Reg = Reg;
575 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_space)) {
576 // `space` `=` POS_INT
577 if (Params.Space.has_value()) {
578 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
579 return std::nullopt;
580 }
581
582 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
583 return std::nullopt;
584
585 auto Space = parseUIntParam();
586 if (!Space.has_value())
587 return std::nullopt;
588 Params.Space = Space;
589 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_visibility)) {
590 // `visibility` `=` SHADER_VISIBILITY
591 if (Params.Visibility.has_value()) {
592 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
593 return std::nullopt;
594 }
595
596 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
597 return std::nullopt;
598
599 auto Visibility = parseShaderVisibility(Context: TokenKind::kw_visibility);
600 if (!Visibility.has_value())
601 return std::nullopt;
602 Params.Visibility = Visibility;
603 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_flags)) {
604 // `flags` `=` ROOT_DESCRIPTOR_FLAGS
605 if (Params.Flags.has_value()) {
606 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
607 return std::nullopt;
608 }
609
610 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
611 return std::nullopt;
612
613 auto Flags = parseRootDescriptorFlags(Context: TokenKind::kw_flags);
614 if (!Flags.has_value())
615 return std::nullopt;
616 Params.Flags = Flags;
617 } else {
618 consumeNextToken(); // let diagnostic be at the start of invalid token
619 reportDiag(DiagID: diag::err_hlsl_invalid_token)
620 << /*parameter=*/0 << /*param of*/ DescKind;
621 return std::nullopt;
622 }
623
624 // ',' denotes another element, otherwise, expected to be at ')'
625 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma))
626 break;
627 }
628
629 return Params;
630}
631
632std::optional<RootSignatureParser::ParsedClauseParams>
633RootSignatureParser::parseDescriptorTableClauseParams(TokenKind ClauseKind,
634 TokenKind RegType) {
635 assert(CurToken.TokKind == TokenKind::pu_l_paren &&
636 "Expects to only be invoked starting at given token");
637
638 ParsedClauseParams Params;
639 while (!peekExpectedToken(Expected: TokenKind::pu_r_paren)) {
640 if (tryConsumeExpectedToken(Expected: RegType)) {
641 // ( `b` | `t` | `u` | `s`) POS_INT
642 if (Params.Reg.has_value()) {
643 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
644 return std::nullopt;
645 }
646 auto Reg = parseRegister();
647 if (!Reg.has_value())
648 return std::nullopt;
649 Params.Reg = Reg;
650 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_numDescriptors)) {
651 // `numDescriptors` `=` POS_INT | unbounded
652 if (Params.NumDescriptors.has_value()) {
653 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
654 return std::nullopt;
655 }
656
657 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
658 return std::nullopt;
659
660 std::optional<uint32_t> NumDescriptors;
661 if (tryConsumeExpectedToken(Expected: TokenKind::en_unbounded))
662 NumDescriptors = NumDescriptorsUnbounded;
663 else {
664 NumDescriptors = parseUIntParam();
665 if (!NumDescriptors.has_value())
666 return std::nullopt;
667 }
668
669 Params.NumDescriptors = NumDescriptors;
670 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_space)) {
671 // `space` `=` POS_INT
672 if (Params.Space.has_value()) {
673 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
674 return std::nullopt;
675 }
676
677 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
678 return std::nullopt;
679
680 auto Space = parseUIntParam();
681 if (!Space.has_value())
682 return std::nullopt;
683 Params.Space = Space;
684 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_offset)) {
685 // `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
686 if (Params.Offset.has_value()) {
687 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
688 return std::nullopt;
689 }
690
691 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
692 return std::nullopt;
693
694 std::optional<uint32_t> Offset;
695 if (tryConsumeExpectedToken(Expected: TokenKind::en_DescriptorRangeOffsetAppend))
696 Offset = DescriptorTableOffsetAppend;
697 else {
698 Offset = parseUIntParam();
699 if (!Offset.has_value())
700 return std::nullopt;
701 }
702
703 Params.Offset = Offset;
704 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_flags)) {
705 // `flags` `=` DESCRIPTOR_RANGE_FLAGS
706 if (Params.Flags.has_value()) {
707 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
708 return std::nullopt;
709 }
710
711 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
712 return std::nullopt;
713
714 auto Flags = parseDescriptorRangeFlags(Context: TokenKind::kw_flags);
715 if (!Flags.has_value())
716 return std::nullopt;
717 Params.Flags = Flags;
718 } else {
719 consumeNextToken(); // let diagnostic be at the start of invalid token
720 reportDiag(DiagID: diag::err_hlsl_invalid_token)
721 << /*parameter=*/0 << /*param of*/ ClauseKind;
722 return std::nullopt;
723 }
724
725 // ',' denotes another element, otherwise, expected to be at ')'
726 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma))
727 break;
728 }
729
730 return Params;
731}
732
733std::optional<RootSignatureParser::ParsedStaticSamplerParams>
734RootSignatureParser::parseStaticSamplerParams() {
735 assert(CurToken.TokKind == TokenKind::pu_l_paren &&
736 "Expects to only be invoked starting at given token");
737
738 ParsedStaticSamplerParams Params;
739 while (!peekExpectedToken(Expected: TokenKind::pu_r_paren)) {
740 if (tryConsumeExpectedToken(Expected: TokenKind::sReg)) {
741 // `s` POS_INT
742 if (Params.Reg.has_value()) {
743 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
744 return std::nullopt;
745 }
746 auto Reg = parseRegister();
747 if (!Reg.has_value())
748 return std::nullopt;
749 Params.Reg = Reg;
750 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_filter)) {
751 // `filter` `=` FILTER
752 if (Params.Filter.has_value()) {
753 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
754 return std::nullopt;
755 }
756
757 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
758 return std::nullopt;
759
760 auto Filter = parseSamplerFilter(Context: TokenKind::kw_filter);
761 if (!Filter.has_value())
762 return std::nullopt;
763 Params.Filter = Filter;
764 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_addressU)) {
765 // `addressU` `=` TEXTURE_ADDRESS
766 if (Params.AddressU.has_value()) {
767 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
768 return std::nullopt;
769 }
770
771 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
772 return std::nullopt;
773
774 auto AddressU = parseTextureAddressMode(Context: TokenKind::kw_addressU);
775 if (!AddressU.has_value())
776 return std::nullopt;
777 Params.AddressU = AddressU;
778 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_addressV)) {
779 // `addressV` `=` TEXTURE_ADDRESS
780 if (Params.AddressV.has_value()) {
781 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
782 return std::nullopt;
783 }
784
785 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
786 return std::nullopt;
787
788 auto AddressV = parseTextureAddressMode(Context: TokenKind::kw_addressV);
789 if (!AddressV.has_value())
790 return std::nullopt;
791 Params.AddressV = AddressV;
792 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_addressW)) {
793 // `addressW` `=` TEXTURE_ADDRESS
794 if (Params.AddressW.has_value()) {
795 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
796 return std::nullopt;
797 }
798
799 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
800 return std::nullopt;
801
802 auto AddressW = parseTextureAddressMode(Context: TokenKind::kw_addressW);
803 if (!AddressW.has_value())
804 return std::nullopt;
805 Params.AddressW = AddressW;
806 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_mipLODBias)) {
807 // `mipLODBias` `=` NUMBER
808 if (Params.MipLODBias.has_value()) {
809 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
810 return std::nullopt;
811 }
812
813 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
814 return std::nullopt;
815
816 auto MipLODBias = parseFloatParam();
817 if (!MipLODBias.has_value())
818 return std::nullopt;
819 Params.MipLODBias = MipLODBias;
820 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_maxAnisotropy)) {
821 // `maxAnisotropy` `=` POS_INT
822 if (Params.MaxAnisotropy.has_value()) {
823 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
824 return std::nullopt;
825 }
826
827 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
828 return std::nullopt;
829
830 auto MaxAnisotropy = parseUIntParam();
831 if (!MaxAnisotropy.has_value())
832 return std::nullopt;
833 Params.MaxAnisotropy = MaxAnisotropy;
834 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_comparisonFunc)) {
835 // `comparisonFunc` `=` COMPARISON_FUNC
836 if (Params.CompFunc.has_value()) {
837 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
838 return std::nullopt;
839 }
840
841 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
842 return std::nullopt;
843
844 auto CompFunc = parseComparisonFunc(Context: TokenKind::kw_comparisonFunc);
845 if (!CompFunc.has_value())
846 return std::nullopt;
847 Params.CompFunc = CompFunc;
848 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_borderColor)) {
849 // `borderColor` `=` STATIC_BORDER_COLOR
850 if (Params.BorderColor.has_value()) {
851 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
852 return std::nullopt;
853 }
854
855 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
856 return std::nullopt;
857
858 auto BorderColor = parseStaticBorderColor(Context: TokenKind::kw_borderColor);
859 if (!BorderColor.has_value())
860 return std::nullopt;
861 Params.BorderColor = BorderColor;
862 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_minLOD)) {
863 // `minLOD` `=` NUMBER
864 if (Params.MinLOD.has_value()) {
865 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
866 return std::nullopt;
867 }
868
869 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
870 return std::nullopt;
871
872 auto MinLOD = parseFloatParam();
873 if (!MinLOD.has_value())
874 return std::nullopt;
875 Params.MinLOD = MinLOD;
876 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_maxLOD)) {
877 // `maxLOD` `=` NUMBER
878 if (Params.MaxLOD.has_value()) {
879 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
880 return std::nullopt;
881 }
882
883 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
884 return std::nullopt;
885
886 auto MaxLOD = parseFloatParam();
887 if (!MaxLOD.has_value())
888 return std::nullopt;
889 Params.MaxLOD = MaxLOD;
890 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_space)) {
891 // `space` `=` POS_INT
892 if (Params.Space.has_value()) {
893 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
894 return std::nullopt;
895 }
896
897 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
898 return std::nullopt;
899
900 auto Space = parseUIntParam();
901 if (!Space.has_value())
902 return std::nullopt;
903 Params.Space = Space;
904 } else if (tryConsumeExpectedToken(Expected: TokenKind::kw_visibility)) {
905 // `visibility` `=` SHADER_VISIBILITY
906 if (Params.Visibility.has_value()) {
907 reportDiag(DiagID: diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
908 return std::nullopt;
909 }
910
911 if (consumeExpectedToken(Expected: TokenKind::pu_equal))
912 return std::nullopt;
913
914 auto Visibility = parseShaderVisibility(Context: TokenKind::kw_visibility);
915 if (!Visibility.has_value())
916 return std::nullopt;
917 Params.Visibility = Visibility;
918 } else {
919 consumeNextToken(); // let diagnostic be at the start of invalid token
920 reportDiag(DiagID: diag::err_hlsl_invalid_token)
921 << /*parameter=*/0 << /*param of*/ TokenKind::kw_StaticSampler;
922 return std::nullopt;
923 }
924
925 // ',' denotes another element, otherwise, expected to be at ')'
926 if (!tryConsumeExpectedToken(Expected: TokenKind::pu_comma))
927 break;
928 }
929
930 return Params;
931}
932
933std::optional<uint32_t> RootSignatureParser::parseUIntParam() {
934 assert(CurToken.TokKind == TokenKind::pu_equal &&
935 "Expects to only be invoked starting at given keyword");
936 tryConsumeExpectedToken(Expected: TokenKind::pu_plus);
937 if (consumeExpectedToken(Expected: TokenKind::int_literal, DiagID: diag::err_expected_after,
938 Context: CurToken.TokKind))
939 return std::nullopt;
940 return handleUIntLiteral();
941}
942
943std::optional<Register> RootSignatureParser::parseRegister() {
944 assert((CurToken.TokKind == TokenKind::bReg ||
945 CurToken.TokKind == TokenKind::tReg ||
946 CurToken.TokKind == TokenKind::uReg ||
947 CurToken.TokKind == TokenKind::sReg) &&
948 "Expects to only be invoked starting at given keyword");
949
950 Register Reg;
951 switch (CurToken.TokKind) {
952 default:
953 llvm_unreachable("Switch for consumed token was not provided");
954 case TokenKind::bReg:
955 Reg.ViewType = RegisterType::BReg;
956 break;
957 case TokenKind::tReg:
958 Reg.ViewType = RegisterType::TReg;
959 break;
960 case TokenKind::uReg:
961 Reg.ViewType = RegisterType::UReg;
962 break;
963 case TokenKind::sReg:
964 Reg.ViewType = RegisterType::SReg;
965 break;
966 }
967
968 auto Number = handleUIntLiteral();
969 if (!Number.has_value())
970 return std::nullopt; // propogate NumericLiteralParser error
971
972 Reg.Number = *Number;
973 return Reg;
974}
975
976std::optional<float> RootSignatureParser::parseFloatParam() {
977 assert(CurToken.TokKind == TokenKind::pu_equal &&
978 "Expects to only be invoked starting at given keyword");
979 // Consume sign modifier
980 bool Signed =
981 tryConsumeExpectedToken(Expected: {TokenKind::pu_plus, TokenKind::pu_minus});
982 bool Negated = Signed && CurToken.TokKind == TokenKind::pu_minus;
983
984 // DXC will treat a postive signed integer as unsigned
985 if (!Negated && tryConsumeExpectedToken(Expected: TokenKind::int_literal)) {
986 std::optional<uint32_t> UInt = handleUIntLiteral();
987 if (!UInt.has_value())
988 return std::nullopt;
989 return float(UInt.value());
990 }
991
992 if (Negated && tryConsumeExpectedToken(Expected: TokenKind::int_literal)) {
993 std::optional<int32_t> Int = handleIntLiteral(Negated);
994 if (!Int.has_value())
995 return std::nullopt;
996 return float(Int.value());
997 }
998
999 if (tryConsumeExpectedToken(Expected: TokenKind::float_literal)) {
1000 std::optional<float> Float = handleFloatLiteral(Negated);
1001 if (!Float.has_value())
1002 return std::nullopt;
1003 return Float.value();
1004 }
1005
1006 return std::nullopt;
1007}
1008
1009std::optional<llvm::dxbc::ShaderVisibility>
1010RootSignatureParser::parseShaderVisibility(TokenKind Context) {
1011 assert(CurToken.TokKind == TokenKind::pu_equal &&
1012 "Expects to only be invoked starting at given keyword");
1013
1014 TokenKind Expected[] = {
1015#define SHADER_VISIBILITY_ENUM(NAME, LIT) TokenKind::en_##NAME,
1016#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1017 };
1018
1019 if (!tryConsumeExpectedToken(Expected)) {
1020 consumeNextToken(); // consume token to point at invalid token
1021 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1022 << /*value=*/1 << /*value of*/ Context;
1023 return std::nullopt;
1024 }
1025
1026 switch (CurToken.TokKind) {
1027#define SHADER_VISIBILITY_ENUM(NAME, LIT) \
1028 case TokenKind::en_##NAME: \
1029 return llvm::dxbc::ShaderVisibility::NAME; \
1030 break;
1031#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1032 default:
1033 llvm_unreachable("Switch for consumed enum token was not provided");
1034 }
1035
1036 return std::nullopt;
1037}
1038
1039std::optional<llvm::dxbc::SamplerFilter>
1040RootSignatureParser::parseSamplerFilter(TokenKind Context) {
1041 assert(CurToken.TokKind == TokenKind::pu_equal &&
1042 "Expects to only be invoked starting at given keyword");
1043
1044 TokenKind Expected[] = {
1045#define FILTER_ENUM(NAME, LIT) TokenKind::en_##NAME,
1046#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1047 };
1048
1049 if (!tryConsumeExpectedToken(Expected)) {
1050 consumeNextToken(); // consume token to point at invalid token
1051 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1052 << /*value=*/1 << /*value of*/ Context;
1053 return std::nullopt;
1054 }
1055
1056 switch (CurToken.TokKind) {
1057#define FILTER_ENUM(NAME, LIT) \
1058 case TokenKind::en_##NAME: \
1059 return llvm::dxbc::SamplerFilter::NAME; \
1060 break;
1061#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1062 default:
1063 llvm_unreachable("Switch for consumed enum token was not provided");
1064 }
1065
1066 return std::nullopt;
1067}
1068
1069std::optional<llvm::dxbc::TextureAddressMode>
1070RootSignatureParser::parseTextureAddressMode(TokenKind Context) {
1071 assert(CurToken.TokKind == TokenKind::pu_equal &&
1072 "Expects to only be invoked starting at given keyword");
1073
1074 TokenKind Expected[] = {
1075#define TEXTURE_ADDRESS_MODE_ENUM(NAME, LIT) TokenKind::en_##NAME,
1076#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1077 };
1078
1079 if (!tryConsumeExpectedToken(Expected)) {
1080 consumeNextToken(); // consume token to point at invalid token
1081 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1082 << /*value=*/1 << /*value of*/ Context;
1083 return std::nullopt;
1084 }
1085
1086 switch (CurToken.TokKind) {
1087#define TEXTURE_ADDRESS_MODE_ENUM(NAME, LIT) \
1088 case TokenKind::en_##NAME: \
1089 return llvm::dxbc::TextureAddressMode::NAME; \
1090 break;
1091#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1092 default:
1093 llvm_unreachable("Switch for consumed enum token was not provided");
1094 }
1095
1096 return std::nullopt;
1097}
1098
1099std::optional<llvm::dxbc::ComparisonFunc>
1100RootSignatureParser::parseComparisonFunc(TokenKind Context) {
1101 assert(CurToken.TokKind == TokenKind::pu_equal &&
1102 "Expects to only be invoked starting at given keyword");
1103
1104 TokenKind Expected[] = {
1105#define COMPARISON_FUNC_ENUM(NAME, LIT) TokenKind::en_##NAME,
1106#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1107 };
1108
1109 if (!tryConsumeExpectedToken(Expected)) {
1110 consumeNextToken(); // consume token to point at invalid token
1111 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1112 << /*value=*/1 << /*value of*/ Context;
1113 return std::nullopt;
1114 }
1115
1116 switch (CurToken.TokKind) {
1117#define COMPARISON_FUNC_ENUM(NAME, LIT) \
1118 case TokenKind::en_##NAME: \
1119 return llvm::dxbc::ComparisonFunc::NAME; \
1120 break;
1121#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1122 default:
1123 llvm_unreachable("Switch for consumed enum token was not provided");
1124 }
1125
1126 return std::nullopt;
1127}
1128
1129std::optional<llvm::dxbc::StaticBorderColor>
1130RootSignatureParser::parseStaticBorderColor(TokenKind Context) {
1131 assert(CurToken.TokKind == TokenKind::pu_equal &&
1132 "Expects to only be invoked starting at given keyword");
1133
1134 TokenKind Expected[] = {
1135#define STATIC_BORDER_COLOR_ENUM(NAME, LIT) TokenKind::en_##NAME,
1136#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1137 };
1138
1139 if (!tryConsumeExpectedToken(Expected)) {
1140 consumeNextToken(); // consume token to point at invalid token
1141 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1142 << /*value=*/1 << /*value of*/ Context;
1143 return std::nullopt;
1144 }
1145
1146 switch (CurToken.TokKind) {
1147#define STATIC_BORDER_COLOR_ENUM(NAME, LIT) \
1148 case TokenKind::en_##NAME: \
1149 return llvm::dxbc::StaticBorderColor::NAME; \
1150 break;
1151#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1152 default:
1153 llvm_unreachable("Switch for consumed enum token was not provided");
1154 }
1155
1156 return std::nullopt;
1157}
1158
1159std::optional<llvm::dxbc::RootDescriptorFlags>
1160RootSignatureParser::parseRootDescriptorFlags(TokenKind Context) {
1161 assert(CurToken.TokKind == TokenKind::pu_equal &&
1162 "Expects to only be invoked starting at given keyword");
1163
1164 // Handle the edge-case of '0' to specify no flags set
1165 if (tryConsumeExpectedToken(Expected: TokenKind::int_literal)) {
1166 if (!verifyZeroFlag()) {
1167 reportDiag(DiagID: diag::err_hlsl_rootsig_non_zero_flag);
1168 return std::nullopt;
1169 }
1170 return llvm::dxbc::RootDescriptorFlags::None;
1171 }
1172
1173 TokenKind Expected[] = {
1174#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
1175#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1176 };
1177
1178 std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
1179
1180 do {
1181 if (tryConsumeExpectedToken(Expected)) {
1182 switch (CurToken.TokKind) {
1183#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \
1184 case TokenKind::en_##NAME: \
1185 Flags = maybeOrFlag<llvm::dxbc::RootDescriptorFlags>( \
1186 Flags, llvm::dxbc::RootDescriptorFlags::NAME); \
1187 break;
1188#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1189 default:
1190 llvm_unreachable("Switch for consumed enum token was not provided");
1191 }
1192 } else {
1193 consumeNextToken(); // consume token to point at invalid token
1194 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1195 << /*value=*/1 << /*value of*/ Context;
1196 return std::nullopt;
1197 }
1198 } while (tryConsumeExpectedToken(Expected: TokenKind::pu_or));
1199
1200 return Flags;
1201}
1202
1203std::optional<llvm::dxbc::DescriptorRangeFlags>
1204RootSignatureParser::parseDescriptorRangeFlags(TokenKind Context) {
1205 assert(CurToken.TokKind == TokenKind::pu_equal &&
1206 "Expects to only be invoked starting at given keyword");
1207
1208 // Handle the edge-case of '0' to specify no flags set
1209 if (tryConsumeExpectedToken(Expected: TokenKind::int_literal)) {
1210 if (!verifyZeroFlag()) {
1211 reportDiag(DiagID: diag::err_hlsl_rootsig_non_zero_flag);
1212 return std::nullopt;
1213 }
1214 return llvm::dxbc::DescriptorRangeFlags::None;
1215 }
1216
1217 TokenKind Expected[] = {
1218#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
1219#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1220 };
1221
1222 std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
1223
1224 do {
1225 if (tryConsumeExpectedToken(Expected)) {
1226 switch (CurToken.TokKind) {
1227#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
1228 case TokenKind::en_##NAME: \
1229 Flags = maybeOrFlag<llvm::dxbc::DescriptorRangeFlags>( \
1230 Flags, llvm::dxbc::DescriptorRangeFlags::NAME); \
1231 break;
1232#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
1233 default:
1234 llvm_unreachable("Switch for consumed enum token was not provided");
1235 }
1236 } else {
1237 consumeNextToken(); // consume token to point at invalid token
1238 reportDiag(DiagID: diag::err_hlsl_invalid_token)
1239 << /*value=*/1 << /*value of*/ Context;
1240 return std::nullopt;
1241 }
1242 } while (tryConsumeExpectedToken(Expected: TokenKind::pu_or));
1243
1244 return Flags;
1245}
1246
1247std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
1248 // Parse the numeric value and do semantic checks on its specification
1249 clang::NumericLiteralParser Literal(
1250 CurToken.NumSpelling, getTokenLocation(Tok: CurToken), PP.getSourceManager(),
1251 PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1252 if (Literal.hadError)
1253 return std::nullopt; // Error has already been reported so just return
1254
1255 assert(Literal.isIntegerLiteral() &&
1256 "NumSpelling can only consist of digits");
1257
1258 llvm::APSInt Val(32, /*IsUnsigned=*/true);
1259 if (Literal.GetIntegerValue(Val)) {
1260 // Report that the value has overflowed
1261 reportDiag(DiagID: diag::err_hlsl_number_literal_overflow)
1262 << /*integer type*/ 0 << /*is signed*/ 0;
1263 return std::nullopt;
1264 }
1265
1266 return Val.getExtValue();
1267}
1268
1269std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
1270 // Parse the numeric value and do semantic checks on its specification
1271 clang::NumericLiteralParser Literal(
1272 CurToken.NumSpelling, getTokenLocation(Tok: CurToken), PP.getSourceManager(),
1273 PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1274 if (Literal.hadError)
1275 return std::nullopt; // Error has already been reported so just return
1276
1277 assert(Literal.isIntegerLiteral() &&
1278 "NumSpelling can only consist of digits");
1279
1280 llvm::APSInt Val(32, /*IsUnsigned=*/true);
1281 // GetIntegerValue will overwrite Val from the parsed Literal and return
1282 // true if it overflows as a 32-bit unsigned int
1283 bool Overflowed = Literal.GetIntegerValue(Val);
1284
1285 // So we then need to check that it doesn't overflow as a 32-bit signed int:
1286 int64_t MaxNegativeMagnitude = -int64_t(std::numeric_limits<int32_t>::min());
1287 Overflowed |= (Negated && MaxNegativeMagnitude < Val.getExtValue());
1288
1289 int64_t MaxPositiveMagnitude = int64_t(std::numeric_limits<int32_t>::max());
1290 Overflowed |= (!Negated && MaxPositiveMagnitude < Val.getExtValue());
1291
1292 if (Overflowed) {
1293 // Report that the value has overflowed
1294 reportDiag(DiagID: diag::err_hlsl_number_literal_overflow)
1295 << /*integer type*/ 0 << /*is signed*/ 1;
1296 return std::nullopt;
1297 }
1298
1299 if (Negated)
1300 Val = -Val;
1301
1302 return int32_t(Val.getExtValue());
1303}
1304
1305std::optional<float> RootSignatureParser::handleFloatLiteral(bool Negated) {
1306 // Parse the numeric value and do semantic checks on its specification
1307 clang::NumericLiteralParser Literal(
1308 CurToken.NumSpelling, getTokenLocation(Tok: CurToken), PP.getSourceManager(),
1309 PP.getLangOpts(), PP.getTargetInfo(), PP.getDiagnostics());
1310 if (Literal.hadError)
1311 return std::nullopt; // Error has already been reported so just return
1312
1313 assert(Literal.isFloatingLiteral() &&
1314 "NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
1315 "will be caught and reported by NumericLiteralParser.");
1316
1317 // DXC used `strtod` to convert the token string to a float which corresponds
1318 // to:
1319 auto DXCSemantics = llvm::APFloat::Semantics::S_IEEEdouble;
1320 auto DXCRoundingMode = llvm::RoundingMode::NearestTiesToEven;
1321
1322 llvm::APFloat Val(llvm::APFloat::EnumToSemantics(S: DXCSemantics));
1323 llvm::APFloat::opStatus Status(Literal.GetFloatValue(Result&: Val, RM: DXCRoundingMode));
1324
1325 // Note: we do not error when opStatus::opInexact by itself as this just
1326 // denotes that rounding occured but not that it is invalid
1327 assert(!(Status & llvm::APFloat::opStatus::opInvalidOp) &&
1328 "NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
1329 "will be caught and reported by NumericLiteralParser.");
1330
1331 assert(!(Status & llvm::APFloat::opStatus::opDivByZero) &&
1332 "It is not possible for a division to be performed when "
1333 "constructing an APFloat from a string");
1334
1335 if (Status & llvm::APFloat::opStatus::opUnderflow) {
1336 // Report that the value has underflowed
1337 reportDiag(DiagID: diag::err_hlsl_number_literal_underflow);
1338 return std::nullopt;
1339 }
1340
1341 if (Status & llvm::APFloat::opStatus::opOverflow) {
1342 // Report that the value has overflowed
1343 reportDiag(DiagID: diag::err_hlsl_number_literal_overflow) << /*float type*/ 1;
1344 return std::nullopt;
1345 }
1346
1347 if (Negated)
1348 Val = -Val;
1349
1350 double DoubleVal = Val.convertToDouble();
1351 double FloatMax = double(std::numeric_limits<float>::max());
1352 if (FloatMax < DoubleVal || DoubleVal < -FloatMax) {
1353 // Report that the value has overflowed
1354 reportDiag(DiagID: diag::err_hlsl_number_literal_overflow) << /*float type*/ 1;
1355 return std::nullopt;
1356 }
1357
1358 return static_cast<float>(DoubleVal);
1359}
1360
1361bool RootSignatureParser::verifyZeroFlag() {
1362 assert(CurToken.TokKind == TokenKind::int_literal);
1363 auto X = handleUIntLiteral();
1364 return X.has_value() && X.value() == 0;
1365}
1366
1367bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
1368 return peekExpectedToken(AnyExpected: ArrayRef{Expected});
1369}
1370
1371bool RootSignatureParser::peekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
1372 RootSignatureToken Result = Lexer.peekNextToken();
1373 return llvm::is_contained(Range&: AnyExpected, Element: Result.TokKind);
1374}
1375
1376bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
1377 unsigned DiagID,
1378 TokenKind Context) {
1379 if (tryConsumeExpectedToken(Expected))
1380 return false;
1381
1382 // Report unexpected token kind error
1383 DiagnosticBuilder DB = reportDiag(DiagID);
1384 switch (DiagID) {
1385 case diag::err_expected:
1386 DB << Expected;
1387 break;
1388 case diag::err_expected_either:
1389 DB << Expected << Context;
1390 break;
1391 case diag::err_expected_after:
1392 DB << Context << Expected;
1393 break;
1394 default:
1395 break;
1396 }
1397 return true;
1398}
1399
1400bool RootSignatureParser::tryConsumeExpectedToken(TokenKind Expected) {
1401 return tryConsumeExpectedToken(Expected: ArrayRef{Expected});
1402}
1403
1404bool RootSignatureParser::tryConsumeExpectedToken(
1405 ArrayRef<TokenKind> AnyExpected) {
1406 // If not the expected token just return
1407 if (!peekExpectedToken(AnyExpected))
1408 return false;
1409 consumeNextToken();
1410 return true;
1411}
1412
1413bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
1414 return skipUntilExpectedToken(Expected: ArrayRef{Expected});
1415}
1416
1417bool RootSignatureParser::skipUntilExpectedToken(
1418 ArrayRef<TokenKind> AnyExpected) {
1419
1420 while (!peekExpectedToken(AnyExpected)) {
1421 if (peekExpectedToken(Expected: TokenKind::end_of_stream))
1422 return false;
1423 consumeNextToken();
1424 }
1425
1426 return true;
1427}
1428
1429bool RootSignatureParser::skipUntilClosedParens(uint32_t NumParens) {
1430 TokenKind ParenKinds[] = {
1431 TokenKind::pu_l_paren,
1432 TokenKind::pu_r_paren,
1433 };
1434 while (skipUntilExpectedToken(AnyExpected: ParenKinds)) {
1435 consumeNextToken();
1436 if (CurToken.TokKind == TokenKind::pu_r_paren)
1437 NumParens--;
1438 else
1439 NumParens++;
1440 if (NumParens == 0)
1441 return true;
1442 }
1443
1444 return false;
1445}
1446
1447SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
1448 return Signature->getLocationOfByte(ByteNo: Tok.LocOffset, SM: PP.getSourceManager(),
1449 Features: PP.getLangOpts(), Target: PP.getTargetInfo());
1450}
1451
1452} // namespace hlsl
1453} // namespace clang
1454

source code of clang/lib/Parse/ParseHLSLRootSignature.cpp