/*
 * Copyright 2016-2021 Robert Konrad
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 * At your option, you may choose to accept this material under either:
 * 1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 * 2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_hlsl.hpp"
#include "GLSL.std.450.h"
#include <algorithm>
#include <assert.h>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

enum class ImageFormatNormalizedState
{
	None = 0,
	Unorm = 1,
	Snorm = 2
};

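// Classifies a SPIR-V image format as UNORM, SNORM or neither, for use when emitting typed UAV texel types.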
static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
		return ImageFormatNormalizedState::Unorm;

	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
		return ImageFormatNormalizedState::Snorm;

	default:
		break;
	}

	return ImageFormatNormalizedState::None;
}

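// Returns how many components a typed image format carries (1 to 4). Unknown formats conservatively report 4.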
static unsigned image_format_to_components(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatR16f:
	case ImageFormatR32f:
	case ImageFormatR8i:
	case ImageFormatR16i:
	case ImageFormatR32i:
	case ImageFormatR8ui:
	case ImageFormatR16ui:
	case ImageFormatR32ui:
		return 1;

	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRg16f:
	case ImageFormatRg32f:
	case ImageFormatRg8i:
	case ImageFormatRg16i:
	case ImageFormatRg32i:
	case ImageFormatRg8ui:
	case ImageFormatRg16ui:
	case ImageFormatRg32ui:
		return 2;

	case ImageFormatR11fG11fB10f:
		return 3;

	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
	case ImageFormatRgba16f:
	case ImageFormatRgba32f:
	case ImageFormatRgba8i:
	case ImageFormatRgba16i:
	case ImageFormatRgba32i:
	case ImageFormatRgba8ui:
	case ImageFormatRgba16ui:
	case ImageFormatRgba32ui:
	case ImageFormatRgb10a2ui:
		return 4;

	case ImageFormatUnknown:
		return 4; // Assume 4.

	default:
		SPIRV_CROSS_THROW("Unrecognized typed image format.");
	}
}

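// Maps a SPIR-V image format plus its sampled base type to the HLSL texel template type (e.g. "unorm float4"),
// throwing if the format and base type do not agree.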
static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float";
	case ImageFormatRg8:
	case ImageFormatRg16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float2";
	case ImageFormatRgba8:
	case ImageFormatRgba16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float4";
	case ImageFormatRgb10A2:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float4";

	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float";
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float2";
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float4";

	case ImageFormatR16f:
	case ImageFormatR32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float";
	case ImageFormatRg16f:
	case ImageFormatRg32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float2";
	case ImageFormatRgba16f:
	case ImageFormatRgba32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float4";

	case ImageFormatR11fG11fB10f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float3";

	case ImageFormatR8i:
	case ImageFormatR16i:
	case ImageFormatR32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int";
	case ImageFormatRg8i:
	case ImageFormatRg16i:
	case ImageFormatRg32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int2";
	case ImageFormatRgba8i:
	case ImageFormatRgba16i:
	case ImageFormatRgba32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int4";

	case ImageFormatR8ui:
	case ImageFormatR16ui:
	case ImageFormatR32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint";
	case ImageFormatRg8ui:
	case ImageFormatRg16ui:
	case ImageFormatRg32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint2";
	case ImageFormatRgba8ui:
	case ImageFormatRgba16ui:
	case ImageFormatRgba32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint4";
	case ImageFormatRgb10a2ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint4";

	case ImageFormatUnknown:
		switch (basetype)
		{
		case SPIRType::Float:
			return "float4";
		case SPIRType::Int:
			return "int4";
		case SPIRType::UInt:
			return "uint4";
		default:
			SPIRV_CROSS_THROW("Unsupported base type for image.");
		}

	default:
		SPIRV_CROSS_THROW("Unrecognized typed image format.");
	}
}

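// Emits the SM 4.0+ resource type for an image or sampled image,
// e.g. Texture2D<float4>, RWTexture2D<unorm float4>, Buffer<uint4> or RasterizerOrdered variants.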
string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	const char *dim = nullptr;
	bool typed_load = false;
	uint32_t components = 4;

	bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);

	switch (type.image.dim)
	{
	case Dim1D:
		typed_load = type.image.sampled == 2;
		dim = "1D";
		break;
	case Dim2D:
		typed_load = type.image.sampled == 2;
		dim = "2D";
		break;
	case Dim3D:
		typed_load = type.image.sampled == 2;
		dim = "3D";
		break;
	case DimCube:
		if (type.image.sampled == 2)
			SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
		dim = "Cube";
		break;
	case DimRect:
		SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
	case DimBuffer:
		if (type.image.sampled == 1)
			return join("Buffer<", type_to_glsl(imagetype), components, ">");
		else if (type.image.sampled == 2)
		{
			if (interlocked_resources.count(id))
				return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
				            ">");

			typed_load = !force_image_srv && type.image.sampled == 2;

			const char *rw = force_image_srv ? "" : "RW";
			return join(rw, "Buffer<",
			            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
			                         join(type_to_glsl(imagetype), components),
			            ">");
		}
		else
			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce in runtime.");
	case DimSubpassData:
		dim = "2D";
		typed_load = false;
		break;
	default:
		SPIRV_CROSS_THROW("Invalid dimension.");
	}
	const char *arrayed = type.image.arrayed ? "Array" : "";
	const char *ms = type.image.ms ? "MS" : "";
	const char *rw = typed_load && !force_image_srv ? "RW" : "";

	if (force_image_srv)
		typed_load = false;

	if (typed_load && interlocked_resources.count(id))
		rw = "RasterizerOrdered";

	return join(rw, "Texture", dim, ms, arrayed, "<",
	            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
	                         join(type_to_glsl(imagetype), components),
	            ">");
}

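// Legacy (SM 3.0 and lower) path: images are expressed with GLSL-style sampler names such as sampler2D.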
string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	string res;

	switch (imagetype.basetype)
	{
	case SPIRType::Int:
		res = "i";
		break;
	case SPIRType::UInt:
		res = "u";
		break;
	default:
		break;
	}

	if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
		return res + "subpassInput" + (type.image.ms ? "MS" : "");

	// If we're emulating subpassInput with samplers, force sampler2D
	// so we don't have to specify format.
	if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
	{
		// Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
		if (type.image.dim == DimBuffer && type.image.sampled == 1)
			res += "sampler";
		else
			res += type.image.sampled == 2 ? "image" : "texture";
	}
	else
		res += "sampler";

	switch (type.image.dim)
	{
	case Dim1D:
		res += "1D";
		break;
	case Dim2D:
		res += "2D";
		break;
	case Dim3D:
		res += "3D";
		break;
	case DimCube:
		res += "CUBE";
		break;

	case DimBuffer:
		res += "Buffer";
		break;

	case DimSubpassData:
		res += "2D";
		break;
	default:
		SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
	}

	if (type.image.ms)
		res += "MS";
	if (type.image.arrayed)
		res += "Array";

	return res;
}

string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
{
	if (hlsl_options.shader_model <= 30)
		return image_type_hlsl_legacy(type, id);
	else
		return image_type_hlsl_modern(type, id);
}

// The optional id parameter indicates the object whose type we are trying
// to find the description for. It is optional. Most type descriptions do not
// depend on a specific object's use of that type.
string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
{
	// Ignore the pointer type since GLSL doesn't have pointers.

	switch (type.basetype)
	{
	case SPIRType::Struct:
		// Need OpName lookup here to get a "sensible" name for a struct.
		if (backend.explicit_struct_type)
			return join("struct ", to_name(type.self));
		else
			return to_name(type.self);

	case SPIRType::Image:
	case SPIRType::SampledImage:
		return image_type_hlsl(type, id);

	case SPIRType::Sampler:
		return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";

	case SPIRType::Void:
		return "void";

	default:
		break;
	}

	if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return "bool";
		case SPIRType::Int:
			return backend.basic_int_type;
		case SPIRType::UInt:
			return backend.basic_uint_type;
		case SPIRType::AtomicCounter:
			return "atomic_uint";
		case SPIRType::Half:
			if (hlsl_options.enable_16bit_types)
				return "half";
			else
				return "min16float";
		case SPIRType::Short:
			if (hlsl_options.enable_16bit_types)
				return "int16_t";
			else
				return "min16int";
		case SPIRType::UShort:
			if (hlsl_options.enable_16bit_types)
				return "uint16_t";
			else
				return "min16uint";
		case SPIRType::Float:
			return "float";
		case SPIRType::Double:
			return "double";
		case SPIRType::Int64:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
			return "int64_t";
		case SPIRType::UInt64:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
			return "uint64_t";
		case SPIRType::AccelerationStructure:
			return "RaytracingAccelerationStructure";
		case SPIRType::RayQuery:
			return "RayQuery<RAY_FLAG_NONE>";
		default:
			return "???";
		}
	}
	else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return join("bool", type.vecsize);
		case SPIRType::Int:
			return join("int", type.vecsize);
		case SPIRType::UInt:
			return join("uint", type.vecsize);
		case SPIRType::Half:
			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
		case SPIRType::Short:
			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
		case SPIRType::UShort:
			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
		case SPIRType::Float:
			return join("float", type.vecsize);
		case SPIRType::Double:
			return join("double", type.vecsize);
		case SPIRType::Int64:
			return join("i64vec", type.vecsize);
		case SPIRType::UInt64:
			return join("u64vec", type.vecsize);
		default:
			return "???";
		}
	}
	else
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return join("bool", type.columns, "x", type.vecsize);
		case SPIRType::Int:
			return join("int", type.columns, "x", type.vecsize);
		case SPIRType::UInt:
			return join("uint", type.columns, "x", type.vecsize);
		case SPIRType::Half:
			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
		case SPIRType::Short:
			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
		case SPIRType::UShort:
			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
		case SPIRType::Float:
			return join("float", type.columns, "x", type.vecsize);
		case SPIRType::Double:
			return join("double", type.columns, "x", type.vecsize);
		// Matrix types not supported for int64/uint64.
		default:
			return "???";
		}
	}
}

void CompilerHLSL::emit_header()
{
	for (auto &header : header_lines)
		statement(header);

	if (header_lines.size() > 0)
	{
		statement("");
	}
}

void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
{
	add_resource_name(var.self);

	// The global copies of I/O variables should not contain interpolation qualifiers.
	// These are emitted inside the interface structs.
	auto &flags = ir.meta[var.self].decoration.decoration_flags;
	auto old_flags = flags;
	flags.reset();
	statement("static ", variable_decl(var), ";");
	flags = old_flags;
}

const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
{
	// Input and output variables are handled specially in HLSL backend.
	// The variables are declared as global, private variables, and do not need any qualifiers.
	if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
	    var.storage == StorageClassPushConstant)
	{
		return "uniform ";
	}

	return "";
}

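// Declares the active output builtins as members of the stage output struct, mapped to their HLSL system-value semantics.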
void CompilerHLSL::emit_builtin_outputs_in_struct()
{
	auto &execution = get_entry_point();

	bool legacy = hlsl_options.shader_model <= 30;
	active_output_builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		const char *semantic = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		switch (builtin)
		{
		case BuiltInPosition:
			type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
			semantic = legacy ? "POSITION" : "SV_Position";
			break;

		case BuiltInSampleMask:
			if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
			type = "uint";
			semantic = "SV_Coverage";
			break;

		case BuiltInFragDepth:
			type = "float";
			if (legacy)
			{
				semantic = "DEPTH";
			}
			else
			{
				if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
					semantic = "SV_DepthGreaterEqual";
				else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
					semantic = "SV_DepthLessEqual";
				else
					semantic = "SV_Depth";
			}
			break;

		case BuiltInClipDistance:
		{
			static const char *types[] = { "float", "float2", "float3", "float4" };

			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
			if (execution.model == ExecutionModelMeshEXT)
			{
				if (clip_distance_count > 4)
					SPIRV_CROSS_THROW("Clip distance count > 4 not supported for mesh shaders.");

				if (clip_distance_count == 1)
				{
					// Avoids having to hack up access_chain code. Makes it trivially indexable.
					statement("float gl_ClipDistance[1] : SV_ClipDistance;");
				}
				else
				{
					// Replace array with vector directly, avoids any weird fixup path.
					statement(types[clip_distance_count - 1], " gl_ClipDistance : SV_ClipDistance;");
				}
			}
			else
			{
				for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
				{
					uint32_t to_declare = clip_distance_count - clip;
					if (to_declare > 4)
						to_declare = 4;

					uint32_t semantic_index = clip / 4;

					statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
					          " : SV_ClipDistance", semantic_index, ";");
				}
			}
			break;
		}

		case BuiltInCullDistance:
		{
			static const char *types[] = { "float", "float2", "float3", "float4" };

			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
			if (execution.model == ExecutionModelMeshEXT)
			{
				if (cull_distance_count > 4)
					SPIRV_CROSS_THROW("Cull distance count > 4 not supported for mesh shaders.");

				if (cull_distance_count == 1)
				{
					// Avoids having to hack up access_chain code. Makes it trivially indexable.
					statement("float gl_CullDistance[1] : SV_CullDistance;");
				}
				else
				{
					// Replace array with vector directly, avoids any weird fixup path.
					statement(types[cull_distance_count - 1], " gl_CullDistance : SV_CullDistance;");
				}
			}
			else
			{
				for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
				{
					uint32_t to_declare = cull_distance_count - cull;
					if (to_declare > 4)
						to_declare = 4;

					uint32_t semantic_index = cull / 4;

					statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
					          " : SV_CullDistance", semantic_index, ";");
				}
			}
			break;
		}

		case BuiltInPointSize:
			// If point_size_compat is enabled, just ignore PointSize.
			// PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
			// even if it means working around the missing feature.
			if (legacy)
			{
				type = "float";
				semantic = "PSIZE";
			}
			else if (!hlsl_options.point_size_compat)
				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
			break;

		case BuiltInLayer:
		case BuiltInPrimitiveId:
		case BuiltInViewportIndex:
		case BuiltInPrimitiveShadingRateKHR:
		case BuiltInCullPrimitiveEXT:
			// Per-primitive attributes are handled separately.
			break;

		case BuiltInPrimitivePointIndicesEXT:
		case BuiltInPrimitiveLineIndicesEXT:
		case BuiltInPrimitiveTriangleIndicesEXT:
			// Meshlet local-index buffer is handled separately.
			break;

		default:
			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
		}

		if (type && semantic)
			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
	});
}

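// Same idea as above, but for the per-primitive builtins (layer, primitive ID, viewport index, shading rate, cull flag).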
void CompilerHLSL::emit_builtin_primitive_outputs_in_struct()
{
	active_output_builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		const char *semantic = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		switch (builtin)
		{
		case BuiltInLayer:
		{
			if (hlsl_options.shader_model < 50)
				SPIRV_CROSS_THROW("Render target array index output is only supported in SM 5.0 or higher.");
			type = "uint";
			semantic = "SV_RenderTargetArrayIndex";
			break;
		}

		case BuiltInPrimitiveId:
			type = "uint";
			semantic = "SV_PrimitiveID";
			break;

		case BuiltInViewportIndex:
			type = "uint";
			semantic = "SV_ViewportArrayIndex";
			break;

		case BuiltInPrimitiveShadingRateKHR:
			type = "uint";
			semantic = "SV_ShadingRate";
			break;

		case BuiltInCullPrimitiveEXT:
			type = "bool";
			semantic = "SV_CullPrimitive";
			break;

		default:
			break;
		}

		if (type && semantic)
			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
	});
}

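// Declares the active input builtins in the stage input struct with their HLSL semantics,
// validating shader model requirements as it goes.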
void CompilerHLSL::emit_builtin_inputs_in_struct()
{
	bool legacy = hlsl_options.shader_model <= 30;
	active_input_builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		const char *semantic = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		switch (builtin)
		{
		case BuiltInFragCoord:
			type = "float4";
			semantic = legacy ? "VPOS" : "SV_Position";
			break;

		case BuiltInVertexId:
		case BuiltInVertexIndex:
			if (legacy)
				SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_VertexID";
			break;

		case BuiltInPrimitiveId:
			type = "uint";
			semantic = "SV_PrimitiveID";
			break;

		case BuiltInInstanceId:
		case BuiltInInstanceIndex:
			if (legacy)
				SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_InstanceID";
			break;

		case BuiltInSampleId:
			if (legacy)
				SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_SampleIndex";
			break;

		case BuiltInSampleMask:
			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
			type = "uint";
			semantic = "SV_Coverage";
			break;

		case BuiltInGlobalInvocationId:
			type = "uint3";
			semantic = "SV_DispatchThreadID";
			break;

		case BuiltInLocalInvocationId:
			type = "uint3";
			semantic = "SV_GroupThreadID";
			break;

		case BuiltInLocalInvocationIndex:
			type = "uint";
			semantic = "SV_GroupIndex";
			break;

		case BuiltInWorkgroupId:
			type = "uint3";
			semantic = "SV_GroupID";
			break;

		case BuiltInFrontFacing:
			type = "bool";
			semantic = "SV_IsFrontFace";
			break;

		case BuiltInViewIndex:
			if (hlsl_options.shader_model < 61 || (get_entry_point().model != ExecutionModelVertex && get_entry_point().model != ExecutionModelFragment))
				SPIRV_CROSS_THROW("View Index input is only supported in VS and PS 6.1 or higher.");
			type = "uint";
			semantic = "SV_ViewID";
			break;

		case BuiltInNumWorkgroups:
		case BuiltInSubgroupSize:
		case BuiltInSubgroupLocalInvocationId:
		case BuiltInSubgroupEqMask:
		case BuiltInSubgroupLtMask:
		case BuiltInSubgroupLeMask:
		case BuiltInSubgroupGtMask:
		case BuiltInSubgroupGeMask:
			// Handled specially.
			break;

		case BuiltInBaseVertex:
			if (hlsl_options.shader_model >= 68)
			{
				type = "uint";
				semantic = "SV_StartVertexLocation";
			}
			break;

		case BuiltInBaseInstance:
			if (hlsl_options.shader_model >= 68)
			{
				type = "uint";
				semantic = "SV_StartInstanceLocation";
			}
			break;

		case BuiltInHelperInvocation:
			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
			break;

		case BuiltInClipDistance:
			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
			{
				uint32_t to_declare = clip_distance_count - clip;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = clip / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
				          " : SV_ClipDistance", semantic_index, ";");
			}
			break;

		case BuiltInCullDistance:
			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
			{
				uint32_t to_declare = cull_distance_count - cull;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = cull / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
				          " : SV_CullDistance", semantic_index, ";");
			}
			break;

		case BuiltInPointCoord:
			// PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
			if (hlsl_options.point_coord_compat)
				break;
			else
				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");

		case BuiltInLayer:
			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Render target array index input is only supported in PS 5.0 or higher.");
			type = "uint";
			semantic = "SV_RenderTargetArrayIndex";
			break;

		case BuiltInBaryCoordKHR:
		case BuiltInBaryCoordNoPerspKHR:
			if (hlsl_options.shader_model < 61)
				SPIRV_CROSS_THROW("SM 6.1 is required for barycentrics.");
			type = builtin == BuiltInBaryCoordNoPerspKHR ? "noperspective float3" : "float3";
			if (active_input_builtins.get(BuiltInBaryCoordKHR) && active_input_builtins.get(BuiltInBaryCoordNoPerspKHR))
				semantic = builtin == BuiltInBaryCoordKHR ? "SV_Barycentrics0" : "SV_Barycentrics1";
			else
				semantic = "SV_Barycentrics";
			break;

		default:
			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
		}

		if (type && semantic)
			statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
	});
}

uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
{
	// TODO: Need to verify correctness.
	uint32_t elements = 0;

	if (type.basetype == SPIRType::Struct)
	{
		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
			elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
	}
	else
	{
		uint32_t array_multiplier = 1;
		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
		{
			if (type.array_size_literal[i])
				array_multiplier *= type.array[i];
			else
				array_multiplier *= evaluate_constant_u32(type.array[i]);
		}
		elements += array_multiplier * type.columns;
	}
	return elements;
}

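// Translates SPIR-V interpolation decorations into HLSL qualifiers (nointerpolation, noperspective, centroid, ...).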
string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
{
	string res;
	//if (flags & (1ull << DecorationSmooth))
	//	res += "linear ";
	if (flags.get(DecorationFlat) || flags.get(DecorationPerVertexKHR))
		res += "nointerpolation ";
	if (flags.get(DecorationNoPerspective))
		res += "noperspective ";
	if (flags.get(DecorationCentroid))
		res += "centroid ";
	if (flags.get(DecorationPatch))
		res += "patch "; // Seems to be different in actual HLSL.
	if (flags.get(DecorationSample))
		res += "sample ";
	if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
		res += "precise "; // Not supported?

	return res;
}

std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
{
	if (em == ExecutionModelVertex && sc == StorageClassInput)
	{
		// We have a vertex attribute - we should look at remapping it if the user provided
		// vertex attribute hints.
		for (auto &attribute : remap_vertex_attributes)
			if (attribute.location == location)
				return attribute.semantic;
	}

	// Not a vertex attribute, or no remap_vertex_attributes entry.
	return join("TEXCOORD", location);
}

std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
{
	// We cannot emit static const initializer for block constants for practical reasons,
	// so just inline the initializer.
	// FIXME: There is a theoretical problem here if someone tries to composite extract
	// into this initializer since we don't declare it properly, but that is somewhat non-sensical.
	auto &type = get<SPIRType>(var.basetype);
	bool is_block = has_decoration(type.self, DecorationBlock);
	auto *c = maybe_get<SPIRConstant>(var.initializer);
	if (is_block && c)
		return constant_expression(*c);
	else
		return CompilerGLSL::to_initializer_expression(var);
}

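// Emits a single block member of an I/O struct with its semantic and records the locations it consumes.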
void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
                                                         uint32_t location,
                                                         std::unordered_set<uint32_t> &active_locations)
{
	auto &execution = get_entry_point();
	auto type = get<SPIRType>(var.basetype);
	auto semantic = to_semantic(location, execution.model, var.storage);
	auto mbr_name = join(to_name(type.self), "_", to_member_name(type, member_index));
	auto &mbr_type = get<SPIRType>(type.member_types[member_index]);

	Bitset member_decorations = get_member_decoration_bitset(type.self, member_index);
	if (has_decoration(var.self, DecorationPerVertexKHR))
		member_decorations.set(DecorationPerVertexKHR);

	statement(to_interpolation_qualifiers(member_decorations),
	          type_to_glsl(mbr_type),
	          " ", mbr_name, type_to_array_glsl(mbr_type, var.self),
	          " : ", semantic, ";");

	// Structs and arrays should consume more locations.
	uint32_t consumed_locations = type_to_consumed_locations(mbr_type);
	for (uint32_t i = 0; i < consumed_locations; i++)
		active_locations.insert(location + i);
}

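// Emits a non-block I/O variable inside the stage input/output struct, picking a semantic
// and unrolling matrices for vertex inputs.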
void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
{
	auto &execution = get_entry_point();
	auto type = get<SPIRType>(var.basetype);

	string binding;
	bool use_location_number = true;
	bool need_matrix_unroll = false;
	bool legacy = hlsl_options.shader_model <= 30;
	if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
	{
		// Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
		uint32_t index = get_decoration(var.self, DecorationIndex);
		uint32_t location = get_decoration(var.self, DecorationLocation);

		if (index != 0 && location != 0)
			SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");

		binding = join(legacy ? "COLOR" : "SV_Target", location + index);
		use_location_number = false;
		if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
			type.vecsize = 4;
	}
	else if (var.storage == StorageClassInput && execution.model == ExecutionModelVertex)
	{
		need_matrix_unroll = true;
		if (legacy) // Inputs must be floating-point in legacy targets.
			type.basetype = SPIRType::Float;
	}

	const auto get_vacant_location = [&]() -> uint32_t {
		for (uint32_t i = 0; i < 64; i++)
			if (!active_locations.count(i))
				return i;
		SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
	};

	auto name = to_name(var.self);
	if (use_location_number)
	{
		uint32_t location_number;

		// If an explicit location exists, use it with TEXCOORD[N] semantic.
		// Otherwise, pick a vacant location.
		if (has_decoration(var.self, DecorationLocation))
			location_number = get_decoration(var.self, DecorationLocation);
		else
			location_number = get_vacant_location();

		// Allow semantic remap if specified.
		auto semantic = to_semantic(location_number, execution.model, var.storage);

		if (need_matrix_unroll && type.columns > 1)
		{
			if (!type.array.empty())
				SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");

			// Unroll matrices.
			for (uint32_t i = 0; i < type.columns; i++)
			{
				SPIRType newtype = type;
				newtype.columns = 1;

				string effective_semantic;
				if (hlsl_options.flatten_matrix_vertex_input_semantics)
					effective_semantic = to_semantic(location_number, execution.model, var.storage);
				else
					effective_semantic = join(semantic, "_", i);

				statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
				          variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
				active_locations.insert(location_number++);
			}
		}
		else
		{
			auto decl_type = type;
			if (execution.model == ExecutionModelMeshEXT || has_decoration(var.self, DecorationPerVertexKHR))
			{
				decl_type.array.erase(decl_type.array.begin());
				decl_type.array_size_literal.erase(decl_type.array_size_literal.begin());
			}
			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(decl_type, name), " : ",
			          semantic, ";");

			// Structs and arrays should consume more locations.
			uint32_t consumed_locations = type_to_consumed_locations(decl_type);
			for (uint32_t i = 0; i < consumed_locations; i++)
				active_locations.insert(location_number + i);
		}
	}
	else
	{
		statement(variable_decl(type, name), " : ", binding, ";");
	}
}

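// Overrides builtin name translation for cases where HLSL needs an expression or a remapped global
// instead of the GLSL-style name.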
std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
{
	switch (builtin)
	{
	case BuiltInVertexId:
		return "gl_VertexID";
	case BuiltInInstanceId:
		return "gl_InstanceID";
	case BuiltInNumWorkgroups:
	{
		if (!num_workgroups_builtin)
			SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
			                  "Cannot emit code for this builtin.");

		auto &var = get<SPIRVariable>(num_workgroups_builtin);
		auto &type = get<SPIRType>(var.basetype);
		auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
		ParsedIR::sanitize_underscores(ret);
		return ret;
	}
	case BuiltInPointCoord:
		// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
		return "float2(0.5f, 0.5f)";
	case BuiltInSubgroupLocalInvocationId:
		return "WaveGetLaneIndex()";
	case BuiltInSubgroupSize:
		return "WaveGetLaneCount()";
	case BuiltInHelperInvocation:
		return "IsHelperLane()";

	default:
		return CompilerGLSL::builtin_to_glsl(builtin, storage);
	}
}

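// Declares static globals for every builtin referenced by the shader so the generated entry point
// can copy data between them and the stage I/O structs.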
void CompilerHLSL::emit_builtin_variables()
{
	Bitset builtins = active_input_builtins;
	builtins.merge_or(active_output_builtins);

	std::unordered_map<uint32_t, ID> builtin_to_initializer;

	// We need to declare sample mask with the same type that module declares it.
	// Sample mask is somewhat special in that SPIR-V has an array, and we can copy that array, so we need to
	// match sign.
	SPIRType::BaseType sample_mask_in_basetype = SPIRType::Void;
	SPIRType::BaseType sample_mask_out_basetype = SPIRType::Void;

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (!is_builtin_variable(var))
			return;

		auto &type = this->get<SPIRType>(var.basetype);
		auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));

		if (var.storage == StorageClassInput && builtin == BuiltInSampleMask)
			sample_mask_in_basetype = type.basetype;
		else if (var.storage == StorageClassOutput && builtin == BuiltInSampleMask)
			sample_mask_out_basetype = type.basetype;

		if (var.initializer && var.storage == StorageClassOutput)
		{
			auto *c = this->maybe_get<SPIRConstant>(var.initializer);
			if (!c)
				return;

			if (type.basetype == SPIRType::Struct)
			{
				uint32_t member_count = uint32_t(type.member_types.size());
				for (uint32_t i = 0; i < member_count; i++)
				{
					if (has_member_decoration(type.self, i, DecorationBuiltIn))
					{
						builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
						    c->subconstants[i];
					}
				}
			}
			else if (has_decoration(var.self, DecorationBuiltIn))
			{
				builtin_to_initializer[builtin] = var.initializer;
			}
		}
	});

	// Emit global variables for the interface variables which are statically used by the shader.
	builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		uint32_t array_size = 0;

		string init_expr;
		auto init_itr = builtin_to_initializer.find(builtin);
		if (init_itr != builtin_to_initializer.end())
			init_expr = join(" = ", to_expression(init_itr->second));

		if (get_execution_model() == ExecutionModelMeshEXT)
		{
			if (builtin == BuiltInPosition || builtin == BuiltInPointSize || builtin == BuiltInClipDistance ||
			    builtin == BuiltInCullDistance || builtin == BuiltInLayer || builtin == BuiltInPrimitiveId ||
			    builtin == BuiltInViewportIndex || builtin == BuiltInCullPrimitiveEXT ||
			    builtin == BuiltInPrimitiveShadingRateKHR || builtin == BuiltInPrimitivePointIndicesEXT ||
			    builtin == BuiltInPrimitiveLineIndicesEXT || builtin == BuiltInPrimitiveTriangleIndicesEXT)
			{
				return;
			}
		}

		switch (builtin)
		{
		case BuiltInFragCoord:
		case BuiltInPosition:
			type = "float4";
			break;

		case BuiltInFragDepth:
			type = "float";
			break;

		case BuiltInVertexId:
		case BuiltInVertexIndex:
		case BuiltInInstanceIndex:
			type = "int";
			if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
				base_vertex_info.used = true;
			break;

		case BuiltInBaseVertex:
		case BuiltInBaseInstance:
			type = "int";
			base_vertex_info.used = true;
			break;

		case BuiltInInstanceId:
		case BuiltInSampleId:
			type = "int";
			break;

		case BuiltInPointSize:
			if (hlsl_options.point_size_compat || hlsl_options.shader_model <= 30)
			{
				// Just emit the global variable, it will be ignored.
				type = "float";
				break;
			}
			else
				SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));

		case BuiltInGlobalInvocationId:
		case BuiltInLocalInvocationId:
		case BuiltInWorkgroupId:
			type = "uint3";
			break;

		case BuiltInLocalInvocationIndex:
			type = "uint";
			break;

		case BuiltInFrontFacing:
			type = "bool";
			break;

		case BuiltInNumWorkgroups:
		case BuiltInPointCoord:
			// Handled specially.
			break;

		case BuiltInSubgroupLocalInvocationId:
		case BuiltInSubgroupSize:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
			break;

		case BuiltInSubgroupEqMask:
		case BuiltInSubgroupLtMask:
		case BuiltInSubgroupLeMask:
		case BuiltInSubgroupGtMask:
		case BuiltInSubgroupGeMask:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
			type = "uint4";
			break;

		case BuiltInHelperInvocation:
			if (hlsl_options.shader_model < 50)
				SPIRV_CROSS_THROW("Need SM 5.0 for Helper Invocation.");
			break;

		case BuiltInClipDistance:
			array_size = clip_distance_count;
			type = "float";
			break;

		case BuiltInCullDistance:
			array_size = cull_distance_count;
			type = "float";
			break;

		case BuiltInSampleMask:
			if (active_input_builtins.get(BuiltInSampleMask))
				type = sample_mask_in_basetype == SPIRType::UInt ? "uint" : "int";
			else
				type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int";
			array_size = 1;
			break;

		case BuiltInPrimitiveId:
		case BuiltInViewIndex:
		case BuiltInLayer:
			type = "uint";
			break;

		case BuiltInViewportIndex:
		case BuiltInPrimitiveShadingRateKHR:
		case BuiltInPrimitiveLineIndicesEXT:
		case BuiltInCullPrimitiveEXT:
			type = "uint";
			break;

		case BuiltInBaryCoordKHR:
		case BuiltInBaryCoordNoPerspKHR:
			if (hlsl_options.shader_model < 61)
				SPIRV_CROSS_THROW("Need SM 6.1 for barycentrics.");
			type = "float3";
			break;

		default:
			SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
		}

		StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;

		if (type)
		{
			if (array_size)
				statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
			else
				statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
		}

		// SampleMask can be both in and out with sample builtin, in this case we have already
		// declared the input variable and we need to add the output one now.
		if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
		{
			type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int";
			if (array_size)
				statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), "[", array_size, "]", init_expr, ";");
			else
				statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
		}
	});

	if (base_vertex_info.used && hlsl_options.shader_model < 68)
	{
		string binding_info;
		if (base_vertex_info.explicit_binding)
		{
			binding_info = join(" : register(b", base_vertex_info.register_index);
			if (base_vertex_info.register_space)
				binding_info += join(", space", base_vertex_info.register_space);
			binding_info += ")";
		}
		statement("cbuffer SPIRV_Cross_VertexInfo", binding_info);
		begin_scope();
		statement("int SPIRV_Cross_BaseVertex;");
		statement("int SPIRV_Cross_BaseInstance;");
		end_scope_decl();
		statement("");
	}
}

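// Lets the user pin the SPIRV_Cross_VertexInfo cbuffer (base vertex/instance emulation) to an explicit register and space.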
void CompilerHLSL::set_hlsl_aux_buffer_binding(HLSLAuxBinding binding, uint32_t register_index, uint32_t register_space)
{
	if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
	{
		base_vertex_info.explicit_binding = true;
		base_vertex_info.register_space = register_space;
		base_vertex_info.register_index = register_index;
	}
}

void CompilerHLSL::unset_hlsl_aux_buffer_binding(HLSLAuxBinding binding)
{
	if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
		base_vertex_info.explicit_binding = false;
}

bool CompilerHLSL::is_hlsl_aux_buffer_binding_used(HLSLAuxBinding binding) const
{
	if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
		return base_vertex_info.used;
	else
		return false;
}

void CompilerHLSL::emit_composite_constants()
{
	// HLSL cannot declare structs or arrays inline, so we must move them out to
	// global constants directly.
	bool emitted = false;

	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		if (c.specialization)
			return;

		auto &type = this->get<SPIRType>(c.constant_type);

		if (type.basetype == SPIRType::Struct && is_builtin_type(type))
			return;

		if (type.basetype == SPIRType::Struct || !type.array.empty())
		{
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
			emitted = true;
		}
	});

	if (emitted)
		statement("");
}

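// Emits specialization constants (as overridable macros with fallback values), constant ops,
// plain struct types and undef variables.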
void CompilerHLSL::emit_specialization_constants_and_structs()
{
	bool emitted = false;
	SpecializationConstant wg_x, wg_y, wg_z;
	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);

	std::unordered_set<TypeID> io_block_types;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, const SPIRVariable &var) {
		auto &type = this->get<SPIRType>(var.basetype);
		if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
		    !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
		    interface_variable_exists_in_entry_point(var.self) &&
		    has_decoration(type.self, DecorationBlock))
		{
			io_block_types.insert(type.self);
		}
	});

	auto loop_lock = ir.create_loop_hard_lock();
	for (auto &id_ : ir.ids_for_constant_undef_or_type)
	{
		auto &id = ir.ids[id_];

		if (id.get_type() == TypeConstant)
		{
			auto &c = id.get<SPIRConstant>();

			if (c.self == workgroup_size_id)
			{
				statement("static const uint3 gl_WorkGroupSize = ",
				          constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
				emitted = true;
			}
			else if (c.specialization)
			{
				auto &type = get<SPIRType>(c.constant_type);
				add_resource_name(c.self);
				auto name = to_name(c.self);

				if (has_decoration(c.self, DecorationSpecId))
				{
					// HLSL does not support specialization constants, so fall back to macros.
					c.specialization_constant_macro_name =
					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));

					statement("#ifndef ", c.specialization_constant_macro_name);
					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
					statement("#endif");
					statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
				}
				else
					statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");

				emitted = true;
			}
		}
		else if (id.get_type() == TypeConstantOp)
		{
			auto &c = id.get<SPIRConstantOp>();
			auto &type = get<SPIRType>(c.basetype);
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
			emitted = true;
		}
		else if (id.get_type() == TypeType)
		{
			auto &type = id.get<SPIRType>();
			bool is_non_io_block = has_decoration(type.self, DecorationBlock) &&
			                       io_block_types.count(type.self) == 0;
			bool is_buffer_block = has_decoration(type.self, DecorationBufferBlock);
			if (type.basetype == SPIRType::Struct && type.array.empty() &&
			    !type.pointer && !is_non_io_block && !is_buffer_block)
			{
				if (emitted)
					statement("");
				emitted = false;

				emit_struct(type);
			}
		}
		else if (id.get_type() == TypeUndef)
		{
			auto &undef = id.get<SPIRUndef>();
			auto &type = this->get<SPIRType>(undef.basetype);
			// OpUndef can be void for some reason ...
			if (type.basetype == SPIRType::Void)
				return;

			string initializer;
			if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
				initializer = join(" = ", to_zero_initialized_expression(undef.basetype));

			statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
			emitted = true;
		}
	}

	if (emitted)
		statement("");
}

void CompilerHLSL::replace_illegal_names()
{
	static const unordered_set<string> keywords = {
		// Additional HLSL specific keywords.
		// From https://docs.microsoft.com/en-US/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-keywords
		"AppendStructuredBuffer", "asm", "asm_fragment",
		"BlendState", "bool", "break", "Buffer", "ByteAddressBuffer",
		"case", "cbuffer", "centroid", "class", "column_major", "compile",
		"compile_fragment", "CompileShader", "const", "continue", "ComputeShader",
		"ConsumeStructuredBuffer",
		"default", "DepthStencilState", "DepthStencilView", "discard", "do",
		"double", "DomainShader", "dword",
		"else", "export", "false", "float", "for", "fxgroup",
		"GeometryShader", "groupshared", "half", "HullShader",
		"indices", "if", "in", "inline", "inout", "InputPatch", "int", "interface",
		"line", "lineadj", "linear", "LineStream",
		"matrix", "min16float", "min10float", "min16int", "min16uint",
		"namespace", "nointerpolation", "noperspective", "NULL",
		"out", "OutputPatch",
		"payload", "packoffset", "pass", "pixelfragment", "PixelShader", "point",
		"PointStream", "precise", "RasterizerState", "RenderTargetView",
		"return", "register", "row_major", "RWBuffer", "RWByteAddressBuffer",
		"RWStructuredBuffer", "RWTexture1D", "RWTexture1DArray", "RWTexture2D",
		"RWTexture2DArray", "RWTexture3D", "sample", "sampler", "SamplerState",
		"SamplerComparisonState", "shared", "snorm", "stateblock", "stateblock_state",
		"static", "string", "struct", "switch", "StructuredBuffer", "tbuffer",
		"technique", "technique10", "technique11", "texture", "Texture1D",
		"Texture1DArray", "Texture2D", "Texture2DArray", "Texture2DMS", "Texture2DMSArray",
		"Texture3D", "TextureCube", "TextureCubeArray", "true", "typedef", "triangle",
		"triangleadj", "TriangleStream", "uint", "uniform", "unorm", "unsigned",
		"vector", "vertexfragment", "VertexShader", "vertices", "void", "volatile", "while",
	};

	CompilerGLSL::replace_illegal_names(keywords);
	CompilerGLSL::replace_illegal_names();
}

SPIRType::BaseType CompilerHLSL::get_builtin_basetype(BuiltIn builtin, SPIRType::BaseType default_type)
{
	switch (builtin)
	{
	case BuiltInSampleMask:
		// We declare sample mask array with module type, so always use default_type here.
		return default_type;
	default:
		return CompilerGLSL::get_builtin_basetype(builtin, default_type);
	}
}

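// Top-level resource emission: buffer blocks, push constants, uniforms, builtin globals and the stage I/O declarations.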
1614void CompilerHLSL::emit_resources()
1615{
1616 auto &execution = get_entry_point();
1617
1618 replace_illegal_names();
1619
1620 switch (execution.model)
1621 {
1622 case ExecutionModelGeometry:
1623 case ExecutionModelTessellationControl:
1624 case ExecutionModelTessellationEvaluation:
1625 case ExecutionModelMeshEXT:
1626 fixup_implicit_builtin_block_names(model: execution.model);
1627 break;
1628
1629 default:
1630 break;
1631 }
1632
1633 emit_specialization_constants_and_structs();
1634 emit_composite_constants();
1635
1636 bool emitted = false;
1637
1638 // Output UBOs and SSBOs
1639 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1640 auto &type = this->get<SPIRType>(id: var.basetype);
1641
1642 bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1643 bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBlock) ||
1644 ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBufferBlock);
1645
1646 if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1647 has_block_flags)
1648 {
1649 emit_buffer_block(type: var);
1650 emitted = true;
1651 }
1652 });
1653
1654 // Output push constant blocks
1655 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1656 auto &type = this->get<SPIRType>(id: var.basetype);
1657 if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1658 !is_hidden_variable(var))
1659 {
1660 emit_push_constant_block(var);
1661 emitted = true;
1662 }
1663 });
1664
1665 if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30 &&
1666 active_output_builtins.get(bit: BuiltInPosition))
1667 {
1668 statement(ts: "uniform float4 gl_HalfPixel;");
1669 emitted = true;
1670 }
1671
1672 bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1673
1674 // Output Uniform Constants (values, samplers, images, etc).
1675 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1676 auto &type = this->get<SPIRType>(id: var.basetype);
1677
1678 // If we're remapping separate samplers and images, only emit the combined samplers.
1679 if (skip_separate_image_sampler)
1680 {
1681 // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1682 bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1683 bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1684 bool separate_sampler = type.basetype == SPIRType::Sampler;
1685 if (!sampler_buffer && (separate_image || separate_sampler))
1686 return;
1687 }
1688
1689 if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1690 type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1691 !is_hidden_variable(var))
1692 {
1693 emit_uniform(var);
1694 emitted = true;
1695 }
1696 });
1697
1698 if (emitted)
1699 statement(ts: "");
1700 emitted = false;
1701
1702 // Emit builtin input and output variables here.
1703 emit_builtin_variables();
1704
1705 if (execution.model != ExecutionModelMeshEXT)
1706 {
1707 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1708 auto &type = this->get<SPIRType>(id: var.basetype);
1709
1710 if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1711 (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1712 interface_variable_exists_in_entry_point(id: var.self))
1713 {
1714 // Builtin variables are handled separately.
1715 emit_interface_block_globally(var);
1716 emitted = true;
1717 }
1718 });
1719 }
1720
1721 if (emitted)
1722 statement(ts: "");
1723 emitted = false;
1724
1725 require_input = false;
1726 require_output = false;
1727 unordered_set<uint32_t> active_inputs;
1728 unordered_set<uint32_t> active_outputs;
1729
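	// Tracks one stage input or output to be emitted into the IO structs below.
	// For block IO, each block member is recorded individually via block_member_index.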
1730 struct IOVariable
1731 {
1732 const SPIRVariable *var;
1733 uint32_t location;
1734 uint32_t block_member_index;
1735 bool block;
1736 };
1737
1738 SmallVector<IOVariable> input_variables;
1739 SmallVector<IOVariable> output_variables;
1740
1741 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1742 auto &type = this->get<SPIRType>(id: var.basetype);
1743 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
1744
1745 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1746 return;
1747
1748 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1749 interface_variable_exists_in_entry_point(id: var.self))
1750 {
1751 if (block)
1752 {
1753 for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1754 {
1755 uint32_t location = get_declared_member_location(var, mbr_idx: i, strip_array: false);
1756 if (var.storage == StorageClassInput)
1757 input_variables.push_back(t: { .var: &var, .location: location, .block_member_index: i, .block: true });
1758 else
1759 output_variables.push_back(t: { .var: &var, .location: location, .block_member_index: i, .block: true });
1760 }
1761 }
1762 else
1763 {
1764 uint32_t location = get_decoration(id: var.self, decoration: DecorationLocation);
1765 if (var.storage == StorageClassInput)
1766 input_variables.push_back(t: { .var: &var, .location: location, .block_member_index: 0, .block: false });
1767 else
1768 output_variables.push_back(t: { .var: &var, .location: location, .block_member_index: 0, .block: false });
1769 }
1770 }
1771 });
1772
1773 const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
1774 		// Sort input and output variables by the following keys, from most to least robust:
1775 // - Location
1776 // - Variable has a location
1777 // - Name comparison
1778 // - Variable has a name
1779 // - Fallback: ID
1780 bool has_location_a = a.block || has_decoration(id: a.var->self, decoration: DecorationLocation);
1781 bool has_location_b = b.block || has_decoration(id: b.var->self, decoration: DecorationLocation);
1782
1783 if (has_location_a && has_location_b)
1784 return a.location < b.location;
1785 else if (has_location_a && !has_location_b)
1786 return true;
1787 else if (!has_location_a && has_location_b)
1788 return false;
1789
1790 const auto &name1 = to_name(id: a.var->self);
1791 const auto &name2 = to_name(id: b.var->self);
1792
1793 if (name1.empty() && name2.empty())
1794 return a.var->self < b.var->self;
1795 else if (name1.empty())
1796 return true;
1797 else if (name2.empty())
1798 return false;
1799
1800 return name1.compare(str: name2) < 0;
1801 };
1802
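	// These builtins never come from the stage input struct in HLSL (they are emulated in the
	// entry point or sourced elsewhere), so they must not force SPIRV_Cross_Input to be declared.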
1803 auto input_builtins = active_input_builtins;
1804 input_builtins.clear(bit: BuiltInNumWorkgroups);
1805 input_builtins.clear(bit: BuiltInPointCoord);
1806 input_builtins.clear(bit: BuiltInSubgroupSize);
1807 input_builtins.clear(bit: BuiltInSubgroupLocalInvocationId);
1808 input_builtins.clear(bit: BuiltInSubgroupEqMask);
1809 input_builtins.clear(bit: BuiltInSubgroupLtMask);
1810 input_builtins.clear(bit: BuiltInSubgroupLeMask);
1811 input_builtins.clear(bit: BuiltInSubgroupGtMask);
1812 input_builtins.clear(bit: BuiltInSubgroupGeMask);
1813
1814 if (!input_variables.empty() || !input_builtins.empty())
1815 {
1816 require_input = true;
1817 statement(ts: "struct SPIRV_Cross_Input");
1818
1819 begin_scope();
1820 sort(first: input_variables.begin(), last: input_variables.end(), comp: variable_compare);
1821 for (auto &var : input_variables)
1822 {
1823 if (var.block)
1824 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_inputs);
1825 else
1826 emit_interface_block_in_struct(var: *var.var, active_locations&: active_inputs);
1827 }
1828 emit_builtin_inputs_in_struct();
1829 end_scope_decl();
1830 statement(ts: "");
1831 }
1832
1833 const bool is_mesh_shader = execution.model == ExecutionModelMeshEXT;
1834 if (!output_variables.empty() || !active_output_builtins.empty())
1835 {
1836 sort(first: output_variables.begin(), last: output_variables.end(), comp: variable_compare);
1837 require_output = !is_mesh_shader;
1838
1839 statement(ts: is_mesh_shader ? "struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output");
1840 begin_scope();
1841 for (auto &var : output_variables)
1842 {
1843 if (is_per_primitive_variable(var: *var.var))
1844 continue;
1845 if (var.block && is_mesh_shader && var.block_member_index != 0)
1846 continue;
1847 if (var.block && !is_mesh_shader)
1848 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_outputs);
1849 else
1850 emit_interface_block_in_struct(var: *var.var, active_locations&: active_outputs);
1851 }
1852 emit_builtin_outputs_in_struct();
1853 if (!is_mesh_shader)
1854 emit_builtin_primitive_outputs_in_struct();
1855 end_scope_decl();
1856 statement(ts: "");
1857
1858 if (is_mesh_shader)
1859 {
1860 statement(ts: "struct gl_MeshPerPrimitiveEXT");
1861 begin_scope();
1862 for (auto &var : output_variables)
1863 {
1864 if (!is_per_primitive_variable(var: *var.var))
1865 continue;
1866 if (var.block && var.block_member_index != 0)
1867 continue;
1868
1869 emit_interface_block_in_struct(var: *var.var, active_locations&: active_outputs);
1870 }
1871 emit_builtin_primitive_outputs_in_struct();
1872 end_scope_decl();
1873 statement(ts: "");
1874 }
1875 }
1876
1877 // Global variables.
1878 for (auto global : global_variables)
1879 {
1880 auto &var = get<SPIRVariable>(id: global);
1881 if (is_hidden_variable(var, include_builtins: true))
1882 continue;
1883
1884 if (var.storage == StorageClassTaskPayloadWorkgroupEXT && is_mesh_shader)
1885 continue;
1886
1887 if (var.storage != StorageClassOutput)
1888 {
1889 if (!variable_is_lut(var))
1890 {
1891 add_resource_name(id: var.self);
1892
1893 const char *storage = nullptr;
1894 switch (var.storage)
1895 {
1896 case StorageClassWorkgroup:
1897 case StorageClassTaskPayloadWorkgroupEXT:
1898 storage = "groupshared";
1899 break;
1900
1901 default:
1902 storage = "static";
1903 break;
1904 }
1905
1906 string initializer;
1907 if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1908 !var.initializer && !var.static_expression && type_can_zero_initialize(type: get_variable_data_type(var)))
1909 {
1910 initializer = join(ts: " = ", ts: to_zero_initialized_expression(type_id: get_variable_data_type_id(var)));
1911 }
1912 statement(ts&: storage, ts: " ", ts: variable_decl(variable: var), ts&: initializer, ts: ";");
1913
1914 emitted = true;
1915 }
1916 }
1917 }
1918
1919 if (emitted)
1920 statement(ts: "");
1921
1922 if (requires_op_fmod)
1923 {
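		// GLSL mod() uses floored semantics (x - y * floor(x / y)), while HLSL's fmod() truncates
		// toward zero, so emit explicit helpers with the GLSL behavior.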
1924 static const char *types[] = {
1925 "float",
1926 "float2",
1927 "float3",
1928 "float4",
1929 };
1930
1931 for (auto &type : types)
1932 {
1933 statement(ts&: type, ts: " mod(", ts&: type, ts: " x, ", ts&: type, ts: " y)");
1934 begin_scope();
1935 statement(ts: "return x - y * floor(x / y);");
1936 end_scope();
1937 statement(ts: "");
1938 }
1939 }
1940
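	// SRVs are always declared as <type>4, so one set of variants suffices; UAV declarations carry
	// the exact component count and unorm/snorm qualifier, so emit one set per combination.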
1941 emit_texture_size_variants(variant_mask: required_texture_size_variants.srv, vecsize_qualifier: "4", uav: false, type_qualifier: "");
1942 for (uint32_t norm = 0; norm < 3; norm++)
1943 {
1944 for (uint32_t comp = 0; comp < 4; comp++)
1945 {
1946 static const char *qualifiers[] = { "", "unorm ", "snorm " };
1947 static const char *vecsizes[] = { "", "2", "3", "4" };
1948 emit_texture_size_variants(variant_mask: required_texture_size_variants.uav[norm][comp], vecsize_qualifier: vecsizes[comp], uav: true,
1949 type_qualifier: qualifiers[norm]);
1950 }
1951 }
1952
1953 if (requires_fp16_packing)
1954 {
1955 // HLSL does not pack into a single word sadly :(
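		// f32tof16 stores the half-precision bits in the low 16 bits of each uint component,
		// so the two components have to be merged into one word manually.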
1956 statement(ts: "uint spvPackHalf2x16(float2 value)");
1957 begin_scope();
1958 statement(ts: "uint2 Packed = f32tof16(value);");
1959 statement(ts: "return Packed.x | (Packed.y << 16);");
1960 end_scope();
1961 statement(ts: "");
1962
1963 statement(ts: "float2 spvUnpackHalf2x16(uint value)");
1964 begin_scope();
1965 statement(ts: "return f16tof32(uint2(value & 0xffff, value >> 16));");
1966 end_scope();
1967 statement(ts: "");
1968 }
1969
1970 if (requires_uint2_packing)
1971 {
1972 statement(ts: "uint64_t spvPackUint2x32(uint2 value)");
1973 begin_scope();
1974 statement(ts: "return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1975 end_scope();
1976 statement(ts: "");
1977
1978 statement(ts: "uint2 spvUnpackUint2x32(uint64_t value)");
1979 begin_scope();
1980 statement(ts: "uint2 Unpacked;");
1981 statement(ts: "Unpacked.x = uint(value & 0xffffffff);");
1982 statement(ts: "Unpacked.y = uint(value >> 32);");
1983 statement(ts: "return Unpacked;");
1984 end_scope();
1985 statement(ts: "");
1986 }
1987
1988 if (requires_explicit_fp16_packing)
1989 {
1990 // HLSL does not pack into a single word sadly :(
1991 statement(ts: "uint spvPackFloat2x16(min16float2 value)");
1992 begin_scope();
1993 statement(ts: "uint2 Packed = f32tof16(value);");
1994 statement(ts: "return Packed.x | (Packed.y << 16);");
1995 end_scope();
1996 statement(ts: "");
1997
1998 statement(ts: "min16float2 spvUnpackFloat2x16(uint value)");
1999 begin_scope();
2000 statement(ts: "return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
2001 end_scope();
2002 statement(ts: "");
2003 }
2004
2005 	// HLSL does not seem to have builtins for these operations, so roll them by hand ...
2006 if (requires_unorm8_packing)
2007 {
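		// x maps to bits 0-7 up through w in bits 24-31; e.g. float4(1.0, 0.25, 0.0, 1.0)
		// packs to 0xFF0040FFu (illustrative round-trip of the helpers emitted below).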
2008 statement(ts: "uint spvPackUnorm4x8(float4 value)");
2009 begin_scope();
2010 statement(ts: "uint4 Packed = uint4(round(saturate(value) * 255.0));");
2011 statement(ts: "return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
2012 end_scope();
2013 statement(ts: "");
2014
2015 statement(ts: "float4 spvUnpackUnorm4x8(uint value)");
2016 begin_scope();
2017 statement(ts: "uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
2018 statement(ts: "return float4(Packed) / 255.0;");
2019 end_scope();
2020 statement(ts: "");
2021 }
2022
2023 if (requires_snorm8_packing)
2024 {
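		// Each component is clamped to [-1, 1], scaled by 127 and masked to 8 bits; unpacking
		// sign-extends each byte by shifting it into the top of an int and arithmetic-shifting back down.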
2025 statement(ts: "uint spvPackSnorm4x8(float4 value)");
2026 begin_scope();
2027 statement(ts: "int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
2028 statement(ts: "return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
2029 end_scope();
2030 statement(ts: "");
2031
2032 statement(ts: "float4 spvUnpackSnorm4x8(uint value)");
2033 begin_scope();
2034 statement(ts: "int SignedValue = int(value);");
2035 statement(ts: "int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
2036 statement(ts: "return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
2037 end_scope();
2038 statement(ts: "");
2039 }
2040
2041 if (requires_unorm16_packing)
2042 {
2043 statement(ts: "uint spvPackUnorm2x16(float2 value)");
2044 begin_scope();
2045 statement(ts: "uint2 Packed = uint2(round(saturate(value) * 65535.0));");
2046 statement(ts: "return Packed.x | (Packed.y << 16);");
2047 end_scope();
2048 statement(ts: "");
2049
2050 statement(ts: "float2 spvUnpackUnorm2x16(uint value)");
2051 begin_scope();
2052 statement(ts: "uint2 Packed = uint2(value & 0xffff, value >> 16);");
2053 statement(ts: "return float2(Packed) / 65535.0;");
2054 end_scope();
2055 statement(ts: "");
2056 }
2057
2058 if (requires_snorm16_packing)
2059 {
2060 statement(ts: "uint spvPackSnorm2x16(float2 value)");
2061 begin_scope();
2062 statement(ts: "int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
2063 statement(ts: "return uint(Packed.x | (Packed.y << 16));");
2064 end_scope();
2065 statement(ts: "");
2066
2067 statement(ts: "float2 spvUnpackSnorm2x16(uint value)");
2068 begin_scope();
2069 statement(ts: "int SignedValue = int(value);");
2070 statement(ts: "int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
2071 statement(ts: "return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
2072 end_scope();
2073 statement(ts: "");
2074 }
2075
2076 if (requires_bitfield_insert)
2077 {
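		// A shift by 32 is undefined, so the full mask is special-cased when Count == 32.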
2078 static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
2079 for (auto &type : types)
2080 {
2081 statement(ts&: type, ts: " spvBitfieldInsert(", ts&: type, ts: " Base, ", ts&: type, ts: " Insert, uint Offset, uint Count)");
2082 begin_scope();
2083 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
2084 statement(ts: "return (Base & ~Mask) | ((Insert << Offset) & Mask);");
2085 end_scope();
2086 statement(ts: "");
2087 }
2088 }
2089
2090 if (requires_bitfield_extract)
2091 {
2092 static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
2093 for (auto &type : unsigned_types)
2094 {
2095 statement(ts&: type, ts: " spvBitfieldUExtract(", ts&: type, ts: " Base, uint Offset, uint Count)");
2096 begin_scope();
2097 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
2098 statement(ts: "return (Base >> Offset) & Mask;");
2099 end_scope();
2100 statement(ts: "");
2101 }
2102
2103 // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
2104 static const char *signed_types[] = { "int", "int2", "int3", "int4" };
2105 for (auto &type : signed_types)
2106 {
2107 statement(ts&: type, ts: " spvBitfieldSExtract(", ts&: type, ts: " Base, int Offset, int Count)");
2108 begin_scope();
2109 statement(ts: "int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
2110 statement(ts&: type, ts: " Masked = (Base >> Offset) & Mask;");
2111 statement(ts: "int ExtendShift = (32 - Count) & 31;");
2112 statement(ts: "return (Masked << ExtendShift) >> ExtendShift;");
2113 end_scope();
2114 statement(ts: "");
2115 }
2116 }
2117
2118 if (requires_inverse_2x2)
2119 {
2120 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
2121 		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
2122 statement(ts: "float2x2 spvInverse(float2x2 m)");
2123 begin_scope();
2124 statement(ts: "float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
2125 statement_no_indent(ts: "");
2126 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2127 statement(ts: "adj[0][0] = m[1][1];");
2128 statement(ts: "adj[0][1] = -m[0][1];");
2129 statement_no_indent(ts: "");
2130 statement(ts: "adj[1][0] = -m[1][0];");
2131 statement(ts: "adj[1][1] = m[0][0];");
2132 statement_no_indent(ts: "");
2133 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2134 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
2135 statement_no_indent(ts: "");
2136 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2137 		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
2138 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2139 end_scope();
2140 statement(ts: "");
2141 }
2142
2143 if (requires_inverse_3x3)
2144 {
2145 statement(ts: "// Returns the determinant of a 2x2 matrix.");
2146 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
2147 begin_scope();
2148 statement(ts: "return a1 * b2 - b1 * a2;");
2149 end_scope();
2150 statement_no_indent(ts: "");
2151 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
2152 		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
2153 statement(ts: "float3x3 spvInverse(float3x3 m)");
2154 begin_scope();
2155 statement(ts: "float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
2156 statement_no_indent(ts: "");
2157 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2158 statement(ts: "adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
2159 statement(ts: "adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
2160 statement(ts: "adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
2161 statement_no_indent(ts: "");
2162 statement(ts: "adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
2163 statement(ts: "adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
2164 statement(ts: "adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
2165 statement_no_indent(ts: "");
2166 statement(ts: "adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
2167 statement(ts: "adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
2168 statement(ts: "adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
2169 statement_no_indent(ts: "");
2170 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2171 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
2172 statement_no_indent(ts: "");
2173 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2174 		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
2175 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2176 end_scope();
2177 statement(ts: "");
2178 }
2179
2180 if (requires_inverse_4x4)
2181 {
2182 if (!requires_inverse_3x3)
2183 {
2184 statement(ts: "// Returns the determinant of a 2x2 matrix.");
2185 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
2186 begin_scope();
2187 statement(ts: "return a1 * b2 - b1 * a2;");
2188 end_scope();
2189 statement(ts: "");
2190 }
2191
2192 statement(ts: "// Returns the determinant of a 3x3 matrix.");
2193 statement(ts: "float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
2194 "float c2, float c3)");
2195 begin_scope();
2196 statement(ts: "return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
2197 "spvDet2x2(a2, a3, "
2198 "b2, b3);");
2199 end_scope();
2200 statement_no_indent(ts: "");
2201 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
2202 		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
2203 statement(ts: "float4x4 spvInverse(float4x4 m)");
2204 begin_scope();
2205 statement(ts: "float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
2206 statement_no_indent(ts: "");
2207 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2208 statement(
2209 ts: "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2210 "m[3][3]);");
2211 statement(
2212 ts: "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2213 "m[3][3]);");
2214 statement(
2215 ts: "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
2216 "m[3][3]);");
2217 statement(
2218 ts: "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
2219 "m[2][3]);");
2220 statement_no_indent(ts: "");
2221 statement(
2222 ts: "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2223 "m[3][3]);");
2224 statement(
2225 ts: "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2226 "m[3][3]);");
2227 statement(
2228 ts: "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
2229 "m[3][3]);");
2230 statement(
2231 ts: "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
2232 "m[2][3]);");
2233 statement_no_indent(ts: "");
2234 statement(
2235 ts: "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2236 "m[3][3]);");
2237 statement(
2238 ts: "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2239 "m[3][3]);");
2240 statement(
2241 ts: "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
2242 "m[3][3]);");
2243 statement(
2244 ts: "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
2245 "m[2][3]);");
2246 statement_no_indent(ts: "");
2247 statement(
2248 ts: "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2249 "m[3][2]);");
2250 statement(
2251 ts: "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2252 "m[3][2]);");
2253 statement(
2254 ts: "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
2255 "m[3][2]);");
2256 statement(
2257 ts: "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
2258 "m[2][2]);");
2259 statement_no_indent(ts: "");
2260 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2261 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
2262 "* m[3][0]);");
2263 statement_no_indent(ts: "");
2264 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2265 		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
2266 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2267 end_scope();
2268 statement(ts: "");
2269 }
2270
2271 if (requires_scalar_reflect)
2272 {
2273 // FP16/FP64? No templates in HLSL.
2274 statement(ts: "float spvReflect(float i, float n)");
2275 begin_scope();
2276 statement(ts: "return i - 2.0 * dot(n, i) * n;");
2277 end_scope();
2278 statement(ts: "");
2279 }
2280
2281 if (requires_scalar_refract)
2282 {
2283 // FP16/FP64? No templates in HLSL.
2284 statement(ts: "float spvRefract(float i, float n, float eta)");
2285 begin_scope();
2286 statement(ts: "float NoI = n * i;");
2287 statement(ts: "float NoI2 = NoI * NoI;");
2288 statement(ts: "float k = 1.0 - eta * eta * (1.0 - NoI2);");
2289 statement(ts: "if (k < 0.0)");
2290 begin_scope();
2291 statement(ts: "return 0.0;");
2292 end_scope();
2293 statement(ts: "else");
2294 begin_scope();
2295 statement(ts: "return eta * i - (eta * NoI + sqrt(k)) * n;");
2296 end_scope();
2297 end_scope();
2298 statement(ts: "");
2299 }
2300
2301 if (requires_scalar_faceforward)
2302 {
2303 // FP16/FP64? No templates in HLSL.
2304 statement(ts: "float spvFaceForward(float n, float i, float nref)");
2305 begin_scope();
2306 statement(ts: "return i * nref < 0.0 ? n : -n;");
2307 end_scope();
2308 statement(ts: "");
2309 }
2310
2311 for (TypeID type_id : composite_selection_workaround_types)
2312 {
2313 // Need out variable since HLSL does not support returning arrays.
2314 auto &type = get<SPIRType>(id: type_id);
2315 auto type_str = type_to_glsl(type);
2316 auto type_arr_str = type_to_array_glsl(type, variable_id: 0);
2317 statement(ts: "void spvSelectComposite(out ", ts&: type_str, ts: " out_value", ts&: type_arr_str, ts: ", bool cond, ",
2318 ts&: type_str, ts: " true_val", ts&: type_arr_str, ts: ", ",
2319 ts&: type_str, ts: " false_val", ts&: type_arr_str, ts: ")");
2320 begin_scope();
2321 statement(ts: "if (cond)");
2322 begin_scope();
2323 statement(ts: "out_value = true_val;");
2324 end_scope();
2325 statement(ts: "else");
2326 begin_scope();
2327 statement(ts: "out_value = false_val;");
2328 end_scope();
2329 end_scope();
2330 statement(ts: "");
2331 }
2332
2333 if (is_mesh_shader && options.vertex.flip_vert_y)
2334 {
2335 statement(ts: "float4 spvFlipVertY(float4 v)");
2336 begin_scope();
2337 statement(ts: "return float4(v.x, -v.y, v.z, v.w);");
2338 end_scope();
2339 statement(ts: "");
2340 statement(ts: "float spvFlipVertY(float v)");
2341 begin_scope();
2342 statement(ts: "return -v;");
2343 end_scope();
2344 statement(ts: "");
2345 }
2346}
2347
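// Emits the spvTextureSize/spvImageSize overloads requested in variant_mask. Each variant is
// encoded as bit (16 * type_index + dim_index); UAV variants additionally carry the unorm/snorm
// qualifier and vector size of the image type. A typical SRV overload looks like:
//   uint2 spvTextureSize(Texture2D<float4> Tex, uint Level, out uint Param)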
2348void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
2349 const char *type_qualifier)
2350{
2351 if (variant_mask == 0)
2352 return;
2353
2354 static const char *types[QueryTypeCount] = { "float", "int", "uint" };
2355 static const char *dims[QueryDimCount] = { "Texture1D", "Texture1DArray", "Texture2D", "Texture2DArray",
2356 "Texture3D", "Buffer", "TextureCube", "TextureCubeArray",
2357 "Texture2DMS", "Texture2DMSArray" };
2358
2359 static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
2360
2361 static const char *ret_types[QueryDimCount] = {
2362 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
2363 };
2364
2365 static const uint32_t return_arguments[QueryDimCount] = {
2366 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
2367 };
2368
2369 for (uint32_t index = 0; index < QueryDimCount; index++)
2370 {
2371 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
2372 {
2373 uint32_t bit = 16 * type_index + index;
2374 uint64_t mask = 1ull << bit;
2375
2376 if ((variant_mask & mask) == 0)
2377 continue;
2378
2379 statement(ts&: ret_types[index], ts: " spv", ts: (uav ? "Image" : "Texture"), ts: "Size(", ts: (uav ? "RW" : ""),
2380 ts&: dims[index], ts: "<", ts&: type_qualifier, ts&: types[type_index], ts&: vecsize_qualifier, ts: "> Tex, ",
2381 ts: (uav ? "" : "uint Level, "), ts: "out uint Param)");
2382 begin_scope();
2383 statement(ts&: ret_types[index], ts: " ret;");
2384 switch (return_arguments[index])
2385 {
2386 case 1:
2387 if (has_lod[index] && !uav)
2388 statement(ts: "Tex.GetDimensions(Level, ret.x, Param);");
2389 else
2390 {
2391 statement(ts: "Tex.GetDimensions(ret.x);");
2392 statement(ts: "Param = 0u;");
2393 }
2394 break;
2395 case 2:
2396 if (has_lod[index] && !uav)
2397 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2398 else if (!uav)
2399 statement(ts: "Tex.GetDimensions(ret.x, ret.y, Param);");
2400 else
2401 {
2402 statement(ts: "Tex.GetDimensions(ret.x, ret.y);");
2403 statement(ts: "Param = 0u;");
2404 }
2405 break;
2406 case 3:
2407 if (has_lod[index] && !uav)
2408 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2409 else if (!uav)
2410 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2411 else
2412 {
2413 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z);");
2414 statement(ts: "Param = 0u;");
2415 }
2416 break;
2417 }
2418
2419 statement(ts: "return ret;");
2420 end_scope();
2421 statement(ts: "");
2422 }
2423 }
2424}
2425
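// Mesh shader outputs become entry-point arguments in HLSL, so locate (or synthesize) the
// per-vertex and per-primitive output blocks and then propagate their IDs as arguments through
// every function in the call graph that writes mesh outputs.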
2426void CompilerHLSL::analyze_meshlet_writes()
2427{
2428 uint32_t id_per_vertex = 0;
2429 uint32_t id_per_primitive = 0;
2430 bool need_per_primitive = false;
2431 bool need_per_vertex = false;
2432
2433 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
2434 auto &type = this->get<SPIRType>(id: var.basetype);
2435 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
2436 if (var.storage == StorageClassOutput && block && is_builtin_variable(var))
2437 {
2438 auto flags = get_buffer_block_flags(id: var.self);
2439 if (flags.get(bit: DecorationPerPrimitiveEXT))
2440 id_per_primitive = var.self;
2441 else
2442 id_per_vertex = var.self;
2443 }
2444 else if (var.storage == StorageClassOutput)
2445 {
2446 Bitset flags;
2447 if (block)
2448 flags = get_buffer_block_flags(id: var.self);
2449 else
2450 flags = get_decoration_bitset(id: var.self);
2451
2452 if (flags.get(bit: DecorationPerPrimitiveEXT))
2453 need_per_primitive = true;
2454 else
2455 need_per_vertex = true;
2456 }
2457 });
2458
2459 // If we have per-primitive outputs, and no per-primitive builtins,
2460 	// an empty version of gl_MeshPerPrimitiveEXT will be emitted.
2461 // If we don't use block IO for vertex output, we'll also need to synthesize the PerVertex block.
2462
2463 const auto generate_block = [&](const char *block_name, const char *instance_name, bool per_primitive) -> uint32_t {
2464 auto &execution = get_entry_point();
2465
2466 uint32_t op_type = ir.increase_bound_by(count: 4);
2467 uint32_t op_arr = op_type + 1;
2468 uint32_t op_ptr = op_type + 2;
2469 uint32_t op_var = op_type + 3;
2470
2471 auto &type = set<SPIRType>(id: op_type, args: OpTypeStruct);
2472 type.basetype = SPIRType::Struct;
2473 set_name(id: op_type, name: block_name);
2474 set_decoration(id: op_type, decoration: DecorationBlock);
2475 if (per_primitive)
2476 set_decoration(id: op_type, decoration: DecorationPerPrimitiveEXT);
2477
2478 auto &arr = set<SPIRType>(id: op_arr, args&: type);
2479 arr.parent_type = type.self;
2480 arr.array.push_back(t: per_primitive ? execution.output_primitives : execution.output_vertices);
2481 arr.array_size_literal.push_back(t: true);
2482
2483 auto &ptr = set<SPIRType>(id: op_ptr, args&: arr);
2484 ptr.parent_type = arr.self;
2485 ptr.pointer = true;
2486 ptr.pointer_depth++;
2487 ptr.storage = StorageClassOutput;
2488 set_decoration(id: op_ptr, decoration: DecorationBlock);
2489 set_name(id: op_ptr, name: block_name);
2490
2491 auto &var = set<SPIRVariable>(id: op_var, args&: op_ptr, args: StorageClassOutput);
2492 if (per_primitive)
2493 set_decoration(id: op_var, decoration: DecorationPerPrimitiveEXT);
2494 set_name(id: op_var, name: instance_name);
2495 execution.interface_variables.push_back(t: var.self);
2496
2497 return op_var;
2498 };
2499
2500 if (id_per_vertex == 0 && need_per_vertex)
2501 id_per_vertex = generate_block("gl_MeshPerVertexEXT", "gl_MeshVerticesEXT", false);
2502 if (id_per_primitive == 0 && need_per_primitive)
2503 id_per_primitive = generate_block("gl_MeshPerPrimitiveEXT", "gl_MeshPrimitivesEXT", true);
2504
2505 unordered_set<uint32_t> processed_func_ids;
2506 analyze_meshlet_writes(func_id: ir.default_entry_point, id_per_vertex, id_per_primitive, processed_func_ids);
2507}
2508
2509void CompilerHLSL::analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vertex, uint32_t id_per_primitive,
2510 std::unordered_set<uint32_t> &processed_func_ids)
2511{
2512 // Avoid processing a function more than once
2513 if (processed_func_ids.find(x: func_id) != processed_func_ids.end())
2514 return;
2515 processed_func_ids.insert(x: func_id);
2516
2517 auto &func = get<SPIRFunction>(id: func_id);
2518 // Recursively establish global args added to functions on which we depend.
2519 for (auto& block : func.blocks)
2520 {
2521 auto &b = get<SPIRBlock>(id: block);
2522 for (auto &i : b.ops)
2523 {
2524 auto ops = stream(instr: i);
2525 auto op = static_cast<Op>(i.op);
2526
2527 switch (op)
2528 {
2529 case OpFunctionCall:
2530 {
2531 // Then recurse into the function itself to extract globals used internally in the function
2532 uint32_t inner_func_id = ops[2];
2533 analyze_meshlet_writes(func_id: inner_func_id, id_per_vertex, id_per_primitive, processed_func_ids);
2534 auto &inner_func = get<SPIRFunction>(id: inner_func_id);
2535 for (auto &iarg : inner_func.arguments)
2536 {
2537 if (!iarg.alias_global_variable)
2538 continue;
2539
2540 bool already_declared = false;
2541 for (auto &arg : func.arguments)
2542 {
2543 if (arg.id == iarg.id)
2544 {
2545 already_declared = true;
2546 break;
2547 }
2548 }
2549
2550 if (!already_declared)
2551 {
2552 // basetype is effectively ignored here since we declare the argument
2553 // with explicit types. Just pass down a valid type.
2554 func.arguments.push_back(t: { .type: expression_type_id(id: iarg.id), .id: iarg.id,
2555 .read_count: iarg.read_count, .write_count: iarg.write_count, .alias_global_variable: true });
2556 }
2557 }
2558 break;
2559 }
2560
2561 case OpStore:
2562 case OpLoad:
2563 case OpInBoundsAccessChain:
2564 case OpAccessChain:
2565 case OpPtrAccessChain:
2566 case OpInBoundsPtrAccessChain:
2567 case OpArrayLength:
2568 {
2569 auto *var = maybe_get<SPIRVariable>(id: ops[op == OpStore ? 0 : 2]);
2570 if (var && (var->storage == StorageClassOutput || var->storage == StorageClassTaskPayloadWorkgroupEXT))
2571 {
2572 bool already_declared = false;
2573 auto builtin_type = BuiltIn(get_decoration(id: var->self, decoration: DecorationBuiltIn));
2574
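				// Primitive index builtins and the task payload keep their own variable;
				// every other output access is redirected to the shared per-vertex or per-primitive block.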
2575 uint32_t var_id = var->self;
2576 if (var->storage != StorageClassTaskPayloadWorkgroupEXT &&
2577 builtin_type != BuiltInPrimitivePointIndicesEXT &&
2578 builtin_type != BuiltInPrimitiveLineIndicesEXT &&
2579 builtin_type != BuiltInPrimitiveTriangleIndicesEXT)
2580 {
2581 var_id = is_per_primitive_variable(var: *var) ? id_per_primitive : id_per_vertex;
2582 }
2583
2584 for (auto &arg : func.arguments)
2585 {
2586 if (arg.id == var_id)
2587 {
2588 already_declared = true;
2589 break;
2590 }
2591 }
2592
2593 if (!already_declared)
2594 {
2595 // basetype is effectively ignored here since we declare the argument
2596 // with explicit types. Just pass down a valid type.
2597 uint32_t type_id = expression_type_id(id: var_id);
2598 if (var->storage == StorageClassTaskPayloadWorkgroupEXT)
2599 func.arguments.push_back(t: { .type: type_id, .id: var_id, .read_count: 1u, .write_count: 0u, .alias_global_variable: true });
2600 else
2601 func.arguments.push_back(t: { .type: type_id, .id: var_id, .read_count: 1u, .write_count: 1u, .alias_global_variable: true });
2602 }
2603 }
2604 break;
2605 }
2606
2607 default:
2608 break;
2609 }
2610 }
2611 }
2612}
2613
2614string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2615{
2616 auto &flags = get_member_decoration_bitset(id: type.self, index);
2617
2618 // HLSL can emit row_major or column_major decoration in any struct.
2619 // Do not try to merge combined decorations for children like in GLSL.
2620
2621 // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2622 // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2623 if (flags.get(bit: DecorationColMajor))
2624 return "row_major ";
2625 else if (flags.get(bit: DecorationRowMajor))
2626 return "column_major ";
2627
2628 return "";
2629}
2630
2631void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2632 const string &qualifier, uint32_t base_offset)
2633{
2634 auto &membertype = get<SPIRType>(id: member_type_id);
2635
2636 Bitset memberflags;
2637 auto &memb = ir.meta[type.self].members;
2638 if (index < memb.size())
2639 memberflags = memb[index].decoration_flags;
2640
2641 string packing_offset;
2642 bool is_push_constant = type.storage == StorageClassPushConstant;
2643
2644 if ((has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2645 has_member_decoration(id: type.self, index, decoration: DecorationOffset))
2646 {
2647 uint32_t offset = memb[index].offset - base_offset;
2648 if (offset & 3)
2649 SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2650
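		// packoffset addresses 16-byte registers with a component swizzle;
		// e.g. a byte offset of 20 becomes ": packoffset(c1.y)".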
2651 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2652 packing_offset = join(ts: " : packoffset(c", ts: offset / 16, ts&: packing_swizzle[(offset & 15) >> 2], ts: ")");
2653 }
2654
2655 statement(ts: layout_for_member(type, index), ts: qualifier,
2656 ts: variable_decl(type: membertype, name: to_member_name(type, index)), ts&: packing_offset, ts: ";");
2657}
2658
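// Emits a RayQuery method call; ops[3] is a constant selecting between the committed-hit and
// candidate-hit variant of the method name.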
2659 void CompilerHLSL::emit_rayquery_function(const char *committed, const char *candidate, const uint32_t *ops)
2660 {
2661 	flush_variable_declaration(id: ops[0]);
2662 	uint32_t is_committed = evaluate_constant_u32(id: ops[3]);
2663 	emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts&: is_committed ? committed : candidate), forward_rhs: false);
2664}
2665
2666void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block)
2667{
2668 if (block.mesh.payload != 0)
2669 {
2670 statement(ts: "DispatchMesh(", ts: to_unpacked_expression(id: block.mesh.groups[0]), ts: ", ", ts: to_unpacked_expression(id: block.mesh.groups[1]), ts: ", ",
2671 ts: to_unpacked_expression(id: block.mesh.groups[2]), ts: ", ", ts: to_unpacked_expression(id: block.mesh.payload), ts: ");");
2672 }
2673 else
2674 {
2675 SPIRV_CROSS_THROW("Amplification shader in HLSL must have payload");
2676 }
2677}
2678
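// SSBOs (and BufferBlock UBOs) are emitted as (RW/RasterizerOrdered)ByteAddressBuffer, or as
// StructuredBuffer variants when the user marked the type as structured. Plain UBOs become a
// cbuffer with packoffset-qualified members, or ConstantBuffer<T> for UBO arrays on SM 5.1+.
// For example, the emitted declarations look roughly like (register assignment depends on the
// resource bindings):
//   RWByteAddressBuffer ssbo : register(u0);
//   cbuffer ubo : register(b0) { ... };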
2679void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2680{
2681 auto &type = get<SPIRType>(id: var.basetype);
2682
2683 bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock);
2684
2685 if (flattened_buffer_blocks.count(x: var.self))
2686 {
2687 emit_buffer_block_flattened(type: var);
2688 }
2689 else if (is_uav)
2690 {
2691 Bitset flags = ir.get_buffer_block_flags(var);
2692 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
2693 bool is_coherent = flags.get(bit: DecorationCoherent) && !is_readonly;
2694 bool is_interlocked = interlocked_resources.count(x: var.self) > 0;
2695
2696 auto to_structuredbuffer_subtype_name = [this](const SPIRType &parent_type) -> std::string
2697 {
2698 if (parent_type.basetype == SPIRType::Struct && parent_type.member_types.size() == 1)
2699 {
2700 				// Use the type of the first struct member, as a StructuredBuffer will have only one '._m0' field in SPIR-V
2701 const auto &member0_type = this->get<SPIRType>(id: parent_type.member_types.front());
2702 return this->type_to_glsl(type: member0_type);
2703 }
2704 else
2705 {
2706 // Otherwise, this StructuredBuffer only has a basic subtype, e.g. StructuredBuffer<int>
2707 return this->type_to_glsl(type: parent_type);
2708 }
2709 };
2710
2711 std::string type_name;
2712 if (is_user_type_structured(id: var.self))
2713 type_name = join(ts: is_readonly ? "" : is_interlocked ? "RasterizerOrdered" : "RW", ts: "StructuredBuffer<", ts: to_structuredbuffer_subtype_name(type), ts: ">");
2714 else
2715 type_name = is_readonly ? "ByteAddressBuffer" : is_interlocked ? "RasterizerOrderedByteAddressBuffer" : "RWByteAddressBuffer";
2716
2717 add_resource_name(id: var.self);
2718 statement(ts: is_coherent ? "globallycoherent " : "", ts&: type_name, ts: " ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
2719 ts: to_resource_binding(var), ts: ";");
2720 }
2721 else
2722 {
2723 if (type.array.empty())
2724 {
2725 // Flatten the top-level struct so we can use packoffset,
2726 // this restriction is similar to GLSL where layout(offset) is not possible on sub-structs.
2727 flattened_structs[var.self] = false;
2728
2729 // Prefer the block name if possible.
2730 auto buffer_name = to_name(id: type.self, allow_alias: false);
2731 if (ir.meta[type.self].decoration.alias.empty() ||
2732 resource_names.find(x: buffer_name) != end(cont&: resource_names) ||
2733 block_names.find(x: buffer_name) != end(cont&: block_names))
2734 {
2735 buffer_name = get_block_fallback_name(id: var.self);
2736 }
2737
2738 add_variable(variables_primary&: block_names, variables_secondary: resource_names, name&: buffer_name);
2739
2740 // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2741 // This cannot conflict with anything else, so we're safe now.
2742 if (buffer_name.empty())
2743 buffer_name = join(ts: "_", ts&: get<SPIRType>(id: var.basetype).self, ts: "_", ts: var.self);
2744
2745 uint32_t failed_index = 0;
2746 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index))
2747 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2748 else
2749 {
2750 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2751 failed_index, " (name: ", to_member_name(type, failed_index),
2752 ") cannot be expressed with either HLSL packing layout or packoffset."));
2753 }
2754
2755 block_names.insert(x: buffer_name);
2756
2757 // Save for post-reflection later.
2758 declared_block_names[var.self] = buffer_name;
2759
2760 type.member_name_cache.clear();
2761 // var.self can be used as a backup name for the block name,
2762 // so we need to make sure we don't disturb the name here on a recompile.
2763 // It will need to be reset if we have to recompile.
2764 preserve_alias_on_reset(id: var.self);
2765 add_resource_name(id: var.self);
2766 statement(ts: "cbuffer ", ts&: buffer_name, ts: to_resource_binding(var));
2767 begin_scope();
2768
2769 uint32_t i = 0;
2770 for (auto &member : type.member_types)
2771 {
2772 add_member_name(type, name: i);
2773 auto backup_name = get_member_name(id: type.self, index: i);
2774 auto member_name = to_member_name(type, index: i);
2775 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2776 ParsedIR::sanitize_underscores(str&: member_name);
2777 set_member_name(id: type.self, index: i, name: member_name);
2778 emit_struct_member(type, member_type_id: member, index: i, qualifier: "");
2779 set_member_name(id: type.self, index: i, name: backup_name);
2780 i++;
2781 }
2782
2783 end_scope_decl();
2784 statement(ts: "");
2785 }
2786 else
2787 {
2788 if (hlsl_options.shader_model < 51)
2789 SPIRV_CROSS_THROW(
2790 "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2791
2792 add_resource_name(id: type.self);
2793 add_resource_name(id: var.self);
2794
2795 			// ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2796 uint32_t failed_index = 0;
2797 if (!buffer_is_packing_standard(type, packing: BufferPackingHLSLCbuffer, failed_index: &failed_index))
2798 {
2799 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2800 "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2801 ") cannot be expressed with normal HLSL packing rules."));
2802 }
2803
2804 emit_struct(type&: get<SPIRType>(id: type.self));
2805 statement(ts: "ConstantBuffer<", ts: to_name(id: type.self), ts: "> ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
2806 ts: to_resource_binding(var), ts: ";");
2807 }
2808 }
2809}
2810
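// Push constants are emitted either as a regular cbuffer (when no root constant ranges were
// specified) or as one cbuffer per root constant range, containing only the members whose
// offsets fall inside that range.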
2811void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2812{
2813 if (flattened_buffer_blocks.count(x: var.self))
2814 {
2815 emit_buffer_block_flattened(type: var);
2816 }
2817 else if (root_constants_layout.empty())
2818 {
2819 emit_buffer_block(var);
2820 }
2821 else
2822 {
2823 for (const auto &layout : root_constants_layout)
2824 {
2825 auto &type = get<SPIRType>(id: var.basetype);
2826
2827 uint32_t failed_index = 0;
2828 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index, start_offset: layout.start,
2829 end_offset: layout.end))
2830 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2831 else
2832 {
2833 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2834 ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2835 ") cannot be expressed with either HLSL packing layout or packoffset."));
2836 }
2837
2838 flattened_structs[var.self] = false;
2839 type.member_name_cache.clear();
2840 add_resource_name(id: var.self);
2841 auto &memb = ir.meta[type.self].members;
2842
2843 statement(ts: "cbuffer SPIRV_CROSS_RootConstant_", ts: to_name(id: var.self),
2844 ts: to_resource_register(flag: HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, space: 'b', binding: layout.binding, set: layout.space));
2845 begin_scope();
2846
2847 			// Index of the next field in the generated root-constant cbuffer.
2848 auto constant_index = 0u;
2849
2850 			// Iterate over all members of the push constant block and check which of the fields
2851 			// fit into the given root constant layout.
2852 for (auto i = 0u; i < memb.size(); i++)
2853 {
2854 const auto offset = memb[i].offset;
2855 if (layout.start <= offset && offset < layout.end)
2856 {
2857 const auto &member = type.member_types[i];
2858
2859 add_member_name(type, name: constant_index);
2860 auto backup_name = get_member_name(id: type.self, index: i);
2861 auto member_name = to_member_name(type, index: i);
2862 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2863 ParsedIR::sanitize_underscores(str&: member_name);
2864 set_member_name(id: type.self, index: constant_index, name: member_name);
2865 emit_struct_member(type, member_type_id: member, index: i, qualifier: "", base_offset: layout.start);
2866 set_member_name(id: type.self, index: constant_index, name: backup_name);
2867
2868 constant_index++;
2869 }
2870 }
2871
2872 end_scope_decl();
2873 }
2874 }
2875}
2876
2877string CompilerHLSL::to_sampler_expression(uint32_t id)
2878{
2879 auto expr = join(ts: "_", ts: to_non_uniform_aware_expression(id));
2880 auto index = expr.find_first_of(c: '[');
2881 if (index == string::npos)
2882 {
2883 return expr + "_sampler";
2884 }
2885 else
2886 {
2887 // We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
2888 return expr.insert(pos: index, s: "_sampler");
2889 }
2890}
2891
2892void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2893{
2894 if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2895 {
2896 set<SPIRCombinedImageSampler>(id: result_id, args&: result_type, args&: image_id, args&: samp_id);
2897 }
2898 else
2899 {
2900 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2901 emit_op(result_type, result_id, rhs: to_combined_image_sampler(image_id, samp_id), forward_rhs: true, suppress_usage_tracking: true);
2902 }
2903}
2904
2905string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2906{
2907 string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2908
2909 if (hlsl_options.shader_model <= 30)
2910 return arg_str;
2911
2912 // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2913 auto &type = expression_type(id);
2914
2915 // We don't have to consider combined image samplers here via OpSampledImage because
2916 // those variables cannot be passed as arguments to functions.
2917 // Only global SampledImage variables may be used as arguments.
2918 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2919 arg_str += ", " + to_sampler_expression(id);
2920
2921 return arg_str;
2922}
2923
2924string CompilerHLSL::get_inner_entry_point_name() const
2925{
2926 auto &execution = get_entry_point();
2927
2928 if (hlsl_options.use_entry_point_name)
2929 {
2930 auto name = join(ts: execution.name, ts: "_inner");
2931 ParsedIR::sanitize_underscores(str&: name);
2932 return name;
2933 }
2934
2935 if (execution.model == ExecutionModelVertex)
2936 return "vert_main";
2937 else if (execution.model == ExecutionModelFragment)
2938 return "frag_main";
2939 else if (execution.model == ExecutionModelGLCompute)
2940 return "comp_main";
2941 else if (execution.model == ExecutionModelMeshEXT)
2942 return "mesh_main";
2943 else if (execution.model == ExecutionModelTaskEXT)
2944 return "task_main";
2945 else
2946 SPIRV_CROSS_THROW("Unsupported execution model.");
2947}
2948
2949void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2950{
2951 if (func.self != ir.default_entry_point)
2952 add_function_overload(func);
2953
2954 // Avoid shadow declarations.
2955 local_variable_names = resource_names;
2956
2957 string decl;
2958
2959 auto &type = get<SPIRType>(id: func.return_type);
2960 if (type.array.empty())
2961 {
2962 decl += flags_to_qualifiers_glsl(type, flags: return_flags);
2963 decl += type_to_glsl(type);
2964 decl += " ";
2965 }
2966 else
2967 {
2968 // We cannot return arrays in HLSL, so "return" through an out variable.
2969 decl = "void ";
2970 }
2971
2972 if (func.self == ir.default_entry_point)
2973 {
2974 decl += get_inner_entry_point_name();
2975 processing_entry_point = true;
2976 }
2977 else
2978 decl += to_name(id: func.self);
2979
2980 decl += "(";
2981 SmallVector<string> arglist;
2982
2983 if (!type.array.empty())
2984 {
2985 // Fake array returns by writing to an out array instead.
2986 string out_argument;
2987 out_argument += "out ";
2988 out_argument += type_to_glsl(type);
2989 out_argument += " ";
2990 out_argument += "spvReturnValue";
2991 out_argument += type_to_array_glsl(type, variable_id: 0);
2992 arglist.push_back(t: std::move(out_argument));
2993 }
2994
2995 for (auto &arg : func.arguments)
2996 {
2997 // Do not pass in separate images or samplers if we're remapping
2998 // to combined image samplers.
2999 if (skip_argument(id: arg.id))
3000 continue;
3001
3002 // Might change the variable name if it already exists in this function.
3003 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
3004 		// to use the same name for different variables.
3005 		// Since we want to keep the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
3006 add_local_variable_name(id: arg.id);
3007
3008 arglist.push_back(t: argument_decl(arg));
3009
3010 // Flatten a combined sampler to two separate arguments in modern HLSL.
3011 auto &arg_type = get<SPIRType>(id: arg.type);
3012 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
3013 arg_type.image.dim != DimBuffer)
3014 {
3015 // Manufacture automatic sampler arg for SampledImage texture
3016 arglist.push_back(t: join(ts: is_depth_image(type: arg_type, id: arg.id) ? "SamplerComparisonState " : "SamplerState ",
3017 ts: to_sampler_expression(id: arg.id), ts: type_to_array_glsl(type: arg_type, variable_id: arg.id)));
3018 }
3019
3020 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
3021 auto *var = maybe_get<SPIRVariable>(id: arg.id);
3022 if (var)
3023 var->parameter = &arg;
3024 }
3025
3026 for (auto &arg : func.shadow_arguments)
3027 {
3028 // Might change the variable name if it already exists in this function.
3029 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
3030 		// to use the same name for different variables.
3031 		// Since we want to keep the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
3032 add_local_variable_name(id: arg.id);
3033
3034 arglist.push_back(t: argument_decl(arg));
3035
3036 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
3037 auto *var = maybe_get<SPIRVariable>(id: arg.id);
3038 if (var)
3039 var->parameter = &arg;
3040 }
3041
3042 decl += merge(list: arglist);
3043 decl += ")";
3044 statement(ts&: decl);
3045}
3046
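// Emits the outer HLSL entry point: stage attributes such as [numthreads], [outputtopology]
// and [earlydepthstencil], copying of builtins and user inputs from SPIRV_Cross_Input into
// the shader globals, the call into the inner *_main function, and the write-back of outputs.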
3047void CompilerHLSL::emit_hlsl_entry_point()
3048{
3049 SmallVector<string> arguments;
3050
3051 if (require_input)
3052 arguments.push_back(t: "SPIRV_Cross_Input stage_input");
3053
3054 auto &execution = get_entry_point();
3055
3056 switch (execution.model)
3057 {
3058 case ExecutionModelTaskEXT:
3059 case ExecutionModelMeshEXT:
3060 case ExecutionModelGLCompute:
3061 {
3062 if (execution.model == ExecutionModelMeshEXT)
3063 {
3064 if (execution.flags.get(bit: ExecutionModeOutputTrianglesEXT))
3065 statement(ts: "[outputtopology(\"triangle\")]");
3066 else if (execution.flags.get(bit: ExecutionModeOutputLinesEXT))
3067 statement(ts: "[outputtopology(\"line\")]");
3068 else if (execution.flags.get(bit: ExecutionModeOutputPoints))
3069 SPIRV_CROSS_THROW("Topology mode \"points\" is not supported in DirectX");
3070
3071 auto &func = get<SPIRFunction>(id: ir.default_entry_point);
3072 for (auto &arg : func.arguments)
3073 {
3074 auto &var = get<SPIRVariable>(id: arg.id);
3075 auto &base_type = get<SPIRType>(id: var.basetype);
3076 bool block = has_decoration(id: base_type.self, decoration: DecorationBlock);
3077 if (var.storage == StorageClassTaskPayloadWorkgroupEXT)
3078 {
3079 arguments.push_back(t: "in payload " + variable_decl(variable: var));
3080 }
3081 else if (block)
3082 {
3083 auto flags = get_buffer_block_flags(id: var.self);
3084 if (flags.get(bit: DecorationPerPrimitiveEXT) || has_decoration(id: arg.id, decoration: DecorationPerPrimitiveEXT))
3085 {
3086 arguments.push_back(t: "out primitives gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[" +
3087 std::to_string(val: execution.output_primitives) + "]");
3088 }
3089 else
3090 {
3091 arguments.push_back(t: "out vertices gl_MeshPerVertexEXT gl_MeshVerticesEXT[" +
3092 std::to_string(val: execution.output_vertices) + "]");
3093 }
3094 }
3095 else
3096 {
3097 if (execution.flags.get(bit: ExecutionModeOutputTrianglesEXT))
3098 {
3099 arguments.push_back(t: "out indices uint3 gl_PrimitiveTriangleIndicesEXT[" +
3100 std::to_string(val: execution.output_primitives) + "]");
3101 }
3102 else
3103 {
3104 arguments.push_back(t: "out indices uint2 gl_PrimitiveLineIndicesEXT[" +
3105 std::to_string(val: execution.output_primitives) + "]");
3106 }
3107 }
3108 }
3109 }
3110 SpecializationConstant wg_x, wg_y, wg_z;
3111 get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
3112
3113 uint32_t x = execution.workgroup_size.x;
3114 uint32_t y = execution.workgroup_size.y;
3115 uint32_t z = execution.workgroup_size.z;
3116
3117 if (!execution.workgroup_size.constant && execution.flags.get(bit: ExecutionModeLocalSizeId))
3118 {
3119 if (execution.workgroup_size.id_x)
3120 x = get<SPIRConstant>(id: execution.workgroup_size.id_x).scalar();
3121 if (execution.workgroup_size.id_y)
3122 y = get<SPIRConstant>(id: execution.workgroup_size.id_y).scalar();
3123 if (execution.workgroup_size.id_z)
3124 z = get<SPIRConstant>(id: execution.workgroup_size.id_z).scalar();
3125 }
3126
3127 auto x_expr = wg_x.id ? get<SPIRConstant>(id: wg_x.id).specialization_constant_macro_name : to_string(val: x);
3128 auto y_expr = wg_y.id ? get<SPIRConstant>(id: wg_y.id).specialization_constant_macro_name : to_string(val: y);
3129 auto z_expr = wg_z.id ? get<SPIRConstant>(id: wg_z.id).specialization_constant_macro_name : to_string(val: z);
3130
3131 statement(ts: "[numthreads(", ts&: x_expr, ts: ", ", ts&: y_expr, ts: ", ", ts&: z_expr, ts: ")]");
3132 break;
3133 }
3134 case ExecutionModelFragment:
3135 if (execution.flags.get(bit: ExecutionModeEarlyFragmentTests))
3136 statement(ts: "[earlydepthstencil]");
3137 break;
3138 default:
3139 break;
3140 }
3141
3142 const char *entry_point_name;
3143 if (hlsl_options.use_entry_point_name)
3144 entry_point_name = get_entry_point().name.c_str();
3145 else
3146 entry_point_name = "main";
3147
3148 statement(ts: require_output ? "SPIRV_Cross_Output " : "void ", ts&: entry_point_name, ts: "(", ts: merge(list: arguments), ts: ")");
3149 begin_scope();
3150 bool legacy = hlsl_options.shader_model <= 30;
3151
3152 // Copy builtins from entry point arguments to globals.
3153 active_input_builtins.for_each_bit(op: [&](uint32_t i) {
3154 auto builtin = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassInput);
3155 switch (static_cast<BuiltIn>(i))
3156 {
3157 case BuiltInFragCoord:
3158 // VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent.
3159 // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
3160 // on a D3D10+ system with a different rasterization config?
3161 if (legacy)
3162 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
3163 else
3164 {
3165 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
3166 // ZW are undefined in D3D9, only do this fixup here.
3167 statement(ts&: builtin, ts: ".w = 1.0 / ", ts&: builtin, ts: ".w;");
3168 }
3169 break;
3170
3171 case BuiltInVertexId:
3172 case BuiltInVertexIndex:
3173 case BuiltInInstanceIndex:
3174 // D3D semantics are uint, but shader wants int.
3175 if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
3176 {
3177 if (hlsl_options.shader_model >= 68)
3178 {
3179 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
3180 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: " + stage_input.gl_BaseInstanceARB);");
3181 else
3182 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: " + stage_input.gl_BaseVertexARB);");
3183 }
3184 else
3185 {
3186 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
3187 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseInstance;");
3188 else
3189 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseVertex;");
3190 }
3191 }
3192 else
3193 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
3194 break;
3195
3196 case BuiltInBaseVertex:
3197 if (hlsl_options.shader_model >= 68)
3198 statement(ts&: builtin, ts: " = stage_input.gl_BaseVertexARB;");
3199 else
3200 statement(ts&: builtin, ts: " = SPIRV_Cross_BaseVertex;");
3201 break;
3202
3203 case BuiltInBaseInstance:
3204 if (hlsl_options.shader_model >= 68)
3205 statement(ts&: builtin, ts: " = stage_input.gl_BaseInstanceARB;");
3206 else
3207 statement(ts&: builtin, ts: " = SPIRV_Cross_BaseInstance;");
3208 break;
3209
3210 case BuiltInInstanceId:
3211 // D3D semantics are uint, but shader wants int.
3212 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
3213 break;
3214
3215 case BuiltInSampleMask:
3216 statement(ts&: builtin, ts: "[0] = stage_input.", ts&: builtin, ts: ";");
3217 break;
3218
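		// These builtins have no stage_input member to copy from; they are sourced elsewhere
		// (e.g. a user-provided cbuffer for NumWorkgroups, or wave intrinsics for the subgroup builtins).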
3219 case BuiltInNumWorkgroups:
3220 case BuiltInPointCoord:
3221 case BuiltInSubgroupSize:
3222 case BuiltInSubgroupLocalInvocationId:
3223 case BuiltInHelperInvocation:
3224 break;
3225
3226 case BuiltInSubgroupEqMask:
3227 // Emulate these ...
3228 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3229 statement(ts: "gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
3230 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
3231 statement(ts: "if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
3232 statement(ts: "if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
3233 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
3234 break;
3235
3236 case BuiltInSubgroupGeMask:
3237 // Emulate these ...
3238 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3239 statement(ts: "gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
3240 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
3241 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
3242 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
3243 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
3244 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
3245 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
3246 break;
3247
3248 case BuiltInSubgroupGtMask:
3249 // Emulate these ...
3250 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3251 statement(ts: "uint gt_lane_index = WaveGetLaneIndex() + 1;");
3252 statement(ts: "gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
3253 statement(ts: "if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
3254 statement(ts: "if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
3255 statement(ts: "if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
3256 statement(ts: "if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
3257 statement(ts: "if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
3258 statement(ts: "if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
3259 statement(ts: "if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
3260 break;
3261
3262 case BuiltInSubgroupLeMask:
3263 // Emulate these ...
3264 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3265 statement(ts: "uint le_lane_index = WaveGetLaneIndex() + 1;");
3266 statement(ts: "gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
3267 statement(ts: "if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
3268 statement(ts: "if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
3269 statement(ts: "if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
3270 statement(ts: "if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
3271 statement(ts: "if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
3272 statement(ts: "if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
3273 statement(ts: "if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
3274 break;
3275
3276 case BuiltInSubgroupLtMask:
3277 // Emulate these ...
3278 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3279 statement(ts: "gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
3280 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
3281 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
3282 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
3283 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
3284 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
3285 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
3286 break;
3287
3288 case BuiltInClipDistance:
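			// Clip distances arrive packed as float4 members (gl_ClipDistance0, gl_ClipDistance1, ...),
			// so unpack them into the scalar gl_ClipDistance[] array. Cull distances below use the same packing.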
3289 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
3290 statement(ts: "gl_ClipDistance[", ts&: clip, ts: "] = stage_input.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3],
3291 ts: ";");
3292 break;
3293
3294 case BuiltInCullDistance:
3295 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
3296 statement(ts: "gl_CullDistance[", ts&: cull, ts: "] = stage_input.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3],
3297 ts: ";");
3298 break;
3299
3300 default:
3301 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
3302 break;
3303 }
3304 });
3305
3306 // Copy from stage input struct to globals.
3307 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
3308 auto &type = this->get<SPIRType>(id: var.basetype);
3309 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
3310
3311 if (var.storage != StorageClassInput)
3312 return;
3313
3314 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
3315
3316 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
3317 interface_variable_exists_in_entry_point(id: var.self))
3318 {
3319 if (block)
3320 {
3321 auto type_name = to_name(id: type.self);
3322 auto var_name = to_name(id: var.self);
3323 bool is_per_vertex = has_decoration(id: var.self, decoration: DecorationPerVertexKHR);
3324 uint32_t array_size = is_per_vertex ? to_array_size_literal(type) : 0;
3325
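				// For pervertexEXT (barycentric) inputs, every member has to be fetched per vertex
				// through GetAttributeAtVertex() rather than copied directly.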
3326 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
3327 {
3328 auto mbr_name = to_member_name(type, index: mbr_idx);
3329 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
3330
3331 if (is_per_vertex)
3332 {
3333 for (uint32_t i = 0; i < array_size; i++)
3334 statement(ts&: var_name, ts: "[", ts&: i, ts: "].", ts&: mbr_name, ts: " = GetAttributeAtVertex(stage_input.", ts&: flat_name, ts: ", ", ts&: i, ts: ");");
3335 }
3336 else
3337 {
3338 statement(ts&: var_name, ts: ".", ts&: mbr_name, ts: " = stage_input.", ts&: flat_name, ts: ";");
3339 }
3340 }
3341 }
3342 else
3343 {
3344 auto name = to_name(id: var.self);
3345 auto &mtype = this->get<SPIRType>(id: var.basetype);
3346 if (need_matrix_unroll && mtype.columns > 1)
3347 {
3348 // Unroll matrices.
3349 for (uint32_t col = 0; col < mtype.columns; col++)
3350 statement(ts&: name, ts: "[", ts&: col, ts: "] = stage_input.", ts&: name, ts: "_", ts&: col, ts: ";");
3351 }
3352 else if (has_decoration(id: var.self, decoration: DecorationPerVertexKHR))
3353 {
3354 uint32_t array_size = to_array_size_literal(type);
3355 for (uint32_t i = 0; i < array_size; i++)
3356 statement(ts&: name, ts: "[", ts&: i, ts: "]", ts: " = GetAttributeAtVertex(stage_input.", ts&: name, ts: ", ", ts&: i, ts: ");");
3357 }
3358 else
3359 {
3360 statement(ts&: name, ts: " = stage_input.", ts&: name, ts: ";");
3361 }
3362 }
3363 }
3364 });
3365
3366 // Run the shader.
3367 if (execution.model == ExecutionModelVertex ||
3368 execution.model == ExecutionModelFragment ||
3369 execution.model == ExecutionModelGLCompute ||
3370 execution.model == ExecutionModelMeshEXT ||
3371 execution.model == ExecutionModelTaskEXT)
3372 {
3373 // For mesh shaders, we receive special arguments that we must pass down as function arguments.
3374 // HLSL does not support proper reference types for passing these IO blocks,
		// but DXC post-inlining seems to magically fix it up anyway *shrug*.
3376 SmallVector<string> arglist;
3377 auto &func = get<SPIRFunction>(id: ir.default_entry_point);
		// The arguments are marked 'out'; avoid detecting reads on them, which would make us emit them as inout.
3379 for (auto &arg : func.arguments)
3380 arglist.push_back(t: to_expression(id: arg.id, register_expression_read: false));
3381 statement(ts: get_inner_entry_point_name(), ts: "(", ts: merge(list: arglist), ts: ");");
3382 }
3383 else
3384 SPIRV_CROSS_THROW("Unsupported shader stage.");
3385
3386 // Copy stage outputs.
3387 if (require_output)
3388 {
3389 statement(ts: "SPIRV_Cross_Output stage_output;");
3390
3391 // Copy builtins from globals to return struct.
3392 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
3393 // PointSize doesn't exist in HLSL SM 4+.
3394 if (i == BuiltInPointSize && !legacy)
3395 return;
3396
3397 switch (static_cast<BuiltIn>(i))
3398 {
3399 case BuiltInClipDistance:
3400 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
3401 statement(ts: "stage_output.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3], ts: " = gl_ClipDistance[",
3402 ts&: clip, ts: "];");
3403 break;
3404
3405 case BuiltInCullDistance:
3406 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
3407 statement(ts: "stage_output.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3], ts: " = gl_CullDistance[",
3408 ts&: cull, ts: "];");
3409 break;
3410
3411 case BuiltInSampleMask:
3412 statement(ts: "stage_output.gl_SampleMask = gl_SampleMask[0];");
3413 break;
3414
3415 default:
3416 {
3417 auto builtin_expr = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassOutput);
3418 statement(ts: "stage_output.", ts&: builtin_expr, ts: " = ", ts&: builtin_expr, ts: ";");
3419 break;
3420 }
3421 }
3422 });
3423
3424 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
3425 auto &type = this->get<SPIRType>(id: var.basetype);
3426 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
3427
3428 if (var.storage != StorageClassOutput)
3429 return;
3430
3431 if (!var.remapped_variable && type.pointer &&
3432 !is_builtin_variable(var) &&
3433 interface_variable_exists_in_entry_point(id: var.self))
3434 {
3435 if (block)
3436 {
				// I/O blocks need to be flattened into individual stage_output members.
3438 auto type_name = to_name(id: type.self);
3439 auto var_name = to_name(id: var.self);
3440 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
3441 {
3442 auto mbr_name = to_member_name(type, index: mbr_idx);
3443 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
3444 statement(ts: "stage_output.", ts&: flat_name, ts: " = ", ts&: var_name, ts: ".", ts&: mbr_name, ts: ";");
3445 }
3446 }
3447 else
3448 {
3449 auto name = to_name(id: var.self);
3450
3451 if (legacy && execution.model == ExecutionModelFragment)
3452 {
3453 string output_filler;
3454 for (uint32_t size = type.vecsize; size < 4; ++size)
3455 output_filler += ", 0.0";
3456
3457 statement(ts: "stage_output.", ts&: name, ts: " = float4(", ts&: name, ts&: output_filler, ts: ");");
3458 }
3459 else
3460 {
3461 statement(ts: "stage_output.", ts&: name, ts: " = ", ts&: name, ts: ";");
3462 }
3463 }
3464 }
3465 });
3466
3467 statement(ts: "return stage_output;");
3468 }
3469
3470 end_scope();
3471}
3472
3473void CompilerHLSL::emit_fixup()
3474{
3475 if (is_vertex_like_shader() && active_output_builtins.get(bit: BuiltInPosition))
3476 {
3477 // Do various mangling on the gl_Position.
3478 if (hlsl_options.shader_model <= 30)
3479 {
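			// D3D9 rasterization places pixel centers at integer coordinates, so compensate with the
			// gl_HalfPixel uniform to match D3D10+ conventions.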
3480 statement(ts: "gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
3481 "gl_Position.w;");
3482 statement(ts: "gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
3483 "gl_Position.w;");
3484 }
3485
3486 if (options.vertex.flip_vert_y)
3487 statement(ts: "gl_Position.y = -gl_Position.y;");
3488 if (options.vertex.fixup_clipspace)
3489 statement(ts: "gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
3490 }
3491}
3492
3493void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
3494{
3495 if (sparse)
3496 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
3497
3498 auto *ops = stream(instr: i);
3499 auto op = static_cast<Op>(i.op);
3500 uint32_t length = i.length;
3501
3502 SmallVector<uint32_t> inherited_expressions;
3503
3504 uint32_t result_type = ops[0];
3505 uint32_t id = ops[1];
3506 VariableID img = ops[2];
3507 uint32_t coord = ops[3];
3508 uint32_t dref = 0;
3509 uint32_t comp = 0;
3510 bool gather = false;
3511 bool proj = false;
3512 const uint32_t *opt = nullptr;
3513 auto *combined_image = maybe_get<SPIRCombinedImageSampler>(id: img);
3514
3515 if (combined_image && has_decoration(id: img, decoration: DecorationNonUniform))
3516 {
3517 set_decoration(id: combined_image->image, decoration: DecorationNonUniform);
3518 set_decoration(id: combined_image->sampler, decoration: DecorationNonUniform);
3519 }
3520
3521 auto img_expr = to_non_uniform_aware_expression(id: combined_image ? combined_image->image : img);
3522
3523 inherited_expressions.push_back(t: coord);
3524
3525 switch (op)
3526 {
3527 case OpImageSampleDrefImplicitLod:
3528 case OpImageSampleDrefExplicitLod:
3529 dref = ops[4];
3530 opt = &ops[5];
3531 length -= 5;
3532 break;
3533
3534 case OpImageSampleProjDrefImplicitLod:
3535 case OpImageSampleProjDrefExplicitLod:
3536 dref = ops[4];
3537 proj = true;
3538 opt = &ops[5];
3539 length -= 5;
3540 break;
3541
3542 case OpImageDrefGather:
3543 dref = ops[4];
3544 opt = &ops[5];
3545 gather = true;
3546 length -= 5;
3547 break;
3548
3549 case OpImageGather:
3550 comp = ops[4];
3551 opt = &ops[5];
3552 gather = true;
3553 length -= 5;
3554 break;
3555
3556 case OpImageSampleProjImplicitLod:
3557 case OpImageSampleProjExplicitLod:
3558 opt = &ops[4];
3559 length -= 4;
3560 proj = true;
3561 break;
3562
3563 case OpImageQueryLod:
3564 opt = &ops[4];
3565 length -= 4;
3566 break;
3567
3568 default:
3569 opt = &ops[4];
3570 length -= 4;
3571 break;
3572 }
3573
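	// Work out how many coordinate components the HLSL texture op consumes for this dimensionality;
	// an extra component is appended below for arrayed images.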
3574 auto &imgtype = expression_type(id: img);
3575 uint32_t coord_components = 0;
3576 switch (imgtype.image.dim)
3577 {
3578 case spv::Dim1D:
3579 coord_components = 1;
3580 break;
3581 case spv::Dim2D:
3582 coord_components = 2;
3583 break;
3584 case spv::Dim3D:
3585 coord_components = 3;
3586 break;
3587 case spv::DimCube:
3588 coord_components = 3;
3589 break;
3590 case spv::DimBuffer:
3591 coord_components = 1;
3592 break;
3593 default:
3594 coord_components = 2;
3595 break;
3596 }
3597
3598 if (dref)
3599 inherited_expressions.push_back(t: dref);
3600
3601 if (imgtype.image.arrayed)
3602 coord_components++;
3603
3604 uint32_t bias = 0;
3605 uint32_t lod = 0;
3606 uint32_t grad_x = 0;
3607 uint32_t grad_y = 0;
3608 uint32_t coffset = 0;
3609 uint32_t offset = 0;
3610 uint32_t coffsets = 0;
3611 uint32_t sample = 0;
3612 uint32_t minlod = 0;
3613 uint32_t flags = 0;
3614
3615 if (length)
3616 {
3617 flags = opt[0];
3618 opt++;
3619 length--;
3620 }
3621
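	// Consume optional image operands in the order they appear in the instruction:
	// each flag that is set in the operand mask pulls the next operand word.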
3622 auto test = [&](uint32_t &v, uint32_t flag) {
3623 if (length && (flags & flag))
3624 {
3625 v = *opt++;
3626 inherited_expressions.push_back(t: v);
3627 length--;
3628 }
3629 };
3630
3631 test(bias, ImageOperandsBiasMask);
3632 test(lod, ImageOperandsLodMask);
3633 test(grad_x, ImageOperandsGradMask);
3634 test(grad_y, ImageOperandsGradMask);
3635 test(coffset, ImageOperandsConstOffsetMask);
3636 test(offset, ImageOperandsOffsetMask);
3637 test(coffsets, ImageOperandsConstOffsetsMask);
3638 test(sample, ImageOperandsSampleMask);
3639 test(minlod, ImageOperandsMinLodMask);
3640
3641 string expr;
3642 string texop;
3643
3644 if (minlod != 0)
3645 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
3646
3647 if (op == OpImageFetch)
3648 {
3649 if (hlsl_options.shader_model < 40)
3650 {
3651 SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
3652 }
3653 texop += img_expr;
3654 texop += ".Load";
3655 }
3656 else if (op == OpImageQueryLod)
3657 {
3658 texop += img_expr;
3659 texop += ".CalculateLevelOfDetail";
3660 }
3661 else
3662 {
3663 auto &imgformat = get<SPIRType>(id: imgtype.image.type);
3664 if (hlsl_options.shader_model < 67 && imgformat.basetype != SPIRType::Float)
3665 {
3666 SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL SM < 6.7.");
3667 }
3668
3669 if (hlsl_options.shader_model >= 40)
3670 {
3671 texop += img_expr;
3672
3673 if (is_depth_image(type: imgtype, id: img))
3674 {
3675 if (gather)
3676 {
3677 texop += ".GatherCmp";
3678 }
3679 else if (lod || grad_x || grad_y)
3680 {
3681 // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
3682 texop += ".SampleCmpLevelZero";
3683 }
3684 else
3685 texop += ".SampleCmp";
3686 }
3687 else if (gather)
3688 {
3689 uint32_t comp_num = evaluate_constant_u32(id: comp);
3690 if (hlsl_options.shader_model >= 50)
3691 {
3692 switch (comp_num)
3693 {
3694 case 0:
3695 texop += ".GatherRed";
3696 break;
3697 case 1:
3698 texop += ".GatherGreen";
3699 break;
3700 case 2:
3701 texop += ".GatherBlue";
3702 break;
3703 case 3:
3704 texop += ".GatherAlpha";
3705 break;
3706 default:
3707 SPIRV_CROSS_THROW("Invalid component.");
3708 }
3709 }
3710 else
3711 {
3712 if (comp_num == 0)
3713 texop += ".Gather";
3714 else
3715 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
3716 }
3717 }
3718 else if (bias)
3719 texop += ".SampleBias";
3720 else if (grad_x || grad_y)
3721 texop += ".SampleGrad";
3722 else if (lod)
3723 texop += ".SampleLevel";
3724 else
3725 texop += ".Sample";
3726 }
3727 else
3728 {
3729 switch (imgtype.image.dim)
3730 {
3731 case Dim1D:
3732 texop += "tex1D";
3733 break;
3734 case Dim2D:
3735 texop += "tex2D";
3736 break;
3737 case Dim3D:
3738 texop += "tex3D";
3739 break;
3740 case DimCube:
3741 texop += "texCUBE";
3742 break;
3743 case DimRect:
3744 case DimBuffer:
3745 case DimSubpassData:
3746 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
3747 default:
3748 SPIRV_CROSS_THROW("Invalid dimension.");
3749 }
3750
3751 if (gather)
3752 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
3753 if (offset || coffset)
3754 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
3755
3756 if (grad_x || grad_y)
3757 texop += "grad";
3758 else if (lod)
3759 texop += "lod";
3760 else if (bias)
3761 texop += "bias";
3762 else if (proj || dref)
3763 texop += "proj";
3764 }
3765 }
3766
3767 expr += texop;
3768 expr += "(";
3769 if (hlsl_options.shader_model < 40)
3770 {
3771 if (combined_image)
3772 SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3773 expr += to_expression(id: img);
3774 }
3775 else if (op != OpImageFetch)
3776 {
3777 string sampler_expr;
3778 if (combined_image)
3779 sampler_expr = to_non_uniform_aware_expression(id: combined_image->sampler);
3780 else
3781 sampler_expr = to_sampler_expression(id: img);
3782 expr += sampler_expr;
3783 }
3784
3785 auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3786 if (comps == in_comps)
3787 return "";
3788
3789 switch (comps)
3790 {
3791 case 1:
3792 return ".x";
3793 case 2:
3794 return ".xy";
3795 case 3:
3796 return ".xyz";
3797 default:
3798 return "";
3799 }
3800 };
3801
3802 bool forward = should_forward(id: coord);
3803
3804 // The IR can give us more components than we need, so chop them off as needed.
3805 string coord_expr;
3806 auto &coord_type = expression_type(id: coord);
3807 if (coord_components != coord_type.vecsize)
3808 coord_expr = to_enclosed_expression(id: coord) + swizzle(coord_components, expression_type(id: coord).vecsize);
3809 else
3810 coord_expr = to_expression(id: coord);
3811
3812 if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3813 coord_expr = coord_expr + " / " + to_extract_component_expression(id: coord, index: coord_components);
3814
3815 if (hlsl_options.shader_model < 40)
3816 {
3817 if (dref)
3818 {
3819 if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3820 {
3821 SPIRV_CROSS_THROW(
3822 "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3823 }
3824
3825 if (grad_x || grad_y)
3826 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3827
3828 for (uint32_t size = coord_components; size < 2; ++size)
3829 coord_expr += ", 0.0";
3830
3831 forward = forward && should_forward(id: dref);
3832 coord_expr += ", " + to_expression(id: dref);
3833 }
3834 else if (lod || bias || proj)
3835 {
3836 for (uint32_t size = coord_components; size < 3; ++size)
3837 coord_expr += ", 0.0";
3838 }
3839
3840 if (lod)
3841 {
3842 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: lod) + ")";
3843 }
3844 else if (bias)
3845 {
3846 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: bias) + ")";
3847 }
3848 else if (proj)
3849 {
3850 coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(id: coord, index: coord_components) + ")";
3851 }
3852 else if (dref)
3853 {
3854 // A "normal" sample gets fed into tex2Dproj as well, because the
3855 // regular tex2D accepts only two coordinates.
3856 coord_expr = "float4(" + coord_expr + ", 1.0)";
3857 }
3858
3859 if (!!lod + !!bias + !!proj > 1)
3860 SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3861 }
3862
3863 if (op == OpImageFetch)
3864 {
3865 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3866 coord_expr =
3867 join(ts: "int", ts: coord_components + 1, ts: "(", ts&: coord_expr, ts: ", ", ts: lod ? to_expression(id: lod) : string("0"), ts: ")");
3868 }
3869 else
3870 expr += ", ";
3871 expr += coord_expr;
3872
3873 if (dref && hlsl_options.shader_model >= 40)
3874 {
3875 forward = forward && should_forward(id: dref);
3876 expr += ", ";
3877
3878 if (proj)
3879 expr += to_enclosed_expression(id: dref) + " / " + to_extract_component_expression(id: coord, index: coord_components);
3880 else
3881 expr += to_expression(id: dref);
3882 }
3883
3884 if (!dref && (grad_x || grad_y))
3885 {
3886 forward = forward && should_forward(id: grad_x);
3887 forward = forward && should_forward(id: grad_y);
3888 expr += ", ";
3889 expr += to_expression(id: grad_x);
3890 expr += ", ";
3891 expr += to_expression(id: grad_y);
3892 }
3893
3894 if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3895 {
3896 forward = forward && should_forward(id: lod);
3897 expr += ", ";
3898 expr += to_expression(id: lod);
3899 }
3900
3901 if (!dref && bias && hlsl_options.shader_model >= 40)
3902 {
3903 forward = forward && should_forward(id: bias);
3904 expr += ", ";
3905 expr += to_expression(id: bias);
3906 }
3907
3908 if (coffset)
3909 {
3910 forward = forward && should_forward(id: coffset);
3911 expr += ", ";
3912 expr += to_expression(id: coffset);
3913 }
3914 else if (offset)
3915 {
3916 forward = forward && should_forward(id: offset);
3917 expr += ", ";
3918 expr += to_expression(id: offset);
3919 }
3920
3921 if (sample)
3922 {
3923 expr += ", ";
3924 expr += to_expression(id: sample);
3925 }
3926
3927 expr += ")";
3928
3929 if (dref && hlsl_options.shader_model < 40)
3930 expr += ".x";
3931
3932 if (op == OpImageQueryLod)
3933 {
		// This is rather awkward.
		// textureQueryLod returns two values: the "accessed level"
		// as well as the actual LOD lambda.
		// As far as I can tell, there is no way to get the .x component
		// as the GLSL spec defines it, since it depends on the sampler itself.
		// Just assume X == Y, so we need to splat the result to a float2.
3940 statement(ts: "float _", ts&: id, ts: "_tmp = ", ts&: expr, ts: ";");
3941 statement(ts: "float2 _", ts&: id, ts: " = _", ts&: id, ts: "_tmp.xx;");
3942 set<SPIRExpression>(id, args: join(ts: "_", ts&: id), args&: result_type, args: true);
3943 }
3944 else
3945 {
3946 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: forward, suppress_usage_tracking: false);
3947 }
3948
3949 for (auto &inherit : inherited_expressions)
3950 inherit_expression_dependencies(dst: id, source: inherit);
3951
3952 switch (op)
3953 {
3954 case OpImageSampleDrefImplicitLod:
3955 case OpImageSampleImplicitLod:
3956 case OpImageSampleProjImplicitLod:
3957 case OpImageSampleProjDrefImplicitLod:
3958 register_control_dependent_expression(expr: id);
3959 break;
3960
3961 default:
3962 break;
3963 }
3964}
3965
3966string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3967{
3968 const auto &type = get<SPIRType>(id: var.basetype);
3969
3970 // We can remap push constant blocks, even if they don't have any binding decoration.
3971 if (type.storage != StorageClassPushConstant && !has_decoration(id: var.self, decoration: DecorationBinding))
3972 return "";
3973
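	// Pick the HLSL register class: t for SRVs, u for UAVs, b for constant buffers, s for samplers.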
3974 char space = '\0';
3975
3976 HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3977
3978 switch (type.basetype)
3979 {
3980 case SPIRType::SampledImage:
3981 space = 't'; // SRV
3982 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3983 break;
3984
3985 case SPIRType::Image:
3986 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3987 {
3988 if (has_decoration(id: var.self, decoration: DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3989 {
3990 space = 't'; // SRV
3991 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3992 }
3993 else
3994 {
3995 space = 'u'; // UAV
3996 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3997 }
3998 }
3999 else
4000 {
4001 space = 't'; // SRV
4002 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
4003 }
4004 break;
4005
4006 case SPIRType::Sampler:
4007 space = 's';
4008 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
4009 break;
4010
4011 case SPIRType::AccelerationStructure:
4012 space = 't'; // SRV
4013 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
4014 break;
4015
4016 case SPIRType::Struct:
4017 {
4018 auto storage = type.storage;
4019 if (storage == StorageClassUniform)
4020 {
4021 if (has_decoration(id: type.self, decoration: DecorationBufferBlock))
4022 {
4023 Bitset flags = ir.get_buffer_block_flags(var);
4024 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
				space = is_readonly ? 't' : 'u'; // SRV if read-only, UAV otherwise.
4026 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
4027 }
4028 else if (has_decoration(id: type.self, decoration: DecorationBlock))
4029 {
4030 space = 'b'; // Constant buffers
4031 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
4032 }
4033 }
4034 else if (storage == StorageClassPushConstant)
4035 {
4036 space = 'b'; // Constant buffers
4037 resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
4038 }
4039 else if (storage == StorageClassStorageBuffer)
4040 {
4041 // UAV or SRV depending on readonly flag.
4042 Bitset flags = ir.get_buffer_block_flags(var);
4043 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
4044 space = is_readonly ? 't' : 'u';
4045 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
4046 }
4047
4048 break;
4049 }
4050 default:
4051 break;
4052 }
4053
4054 if (!space)
4055 return "";
4056
4057 uint32_t desc_set =
4058 resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
4059 uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
4060
4061 if (has_decoration(id: var.self, decoration: DecorationBinding))
4062 binding = get_decoration(id: var.self, decoration: DecorationBinding);
4063 if (has_decoration(id: var.self, decoration: DecorationDescriptorSet))
4064 desc_set = get_decoration(id: var.self, decoration: DecorationDescriptorSet);
4065
4066 return to_resource_register(flag: resource_flags, space, binding, set: desc_set);
4067}
4068
4069string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
4070{
4071 // For combined image samplers.
4072 if (!has_decoration(id: var.self, decoration: DecorationBinding))
4073 return "";
4074
4075 return to_resource_register(flag: HLSL_BINDING_AUTO_SAMPLER_BIT, space: 's', binding: get_decoration(id: var.self, decoration: DecorationBinding),
4076 set: get_decoration(id: var.self, decoration: DecorationDescriptorSet));
4077}
4078
4079void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
4080{
4081 auto itr = resource_bindings.find(x: { .model: get_execution_model(), .desc_set: desc_set, .binding: binding });
4082 if (itr != end(cont&: resource_bindings))
4083 {
4084 auto &remap = itr->second;
4085 remap.second = true;
4086
4087 switch (type)
4088 {
4089 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
4090 case HLSL_BINDING_AUTO_CBV_BIT:
4091 desc_set = remap.first.cbv.register_space;
4092 binding = remap.first.cbv.register_binding;
4093 break;
4094
4095 case HLSL_BINDING_AUTO_SRV_BIT:
4096 desc_set = remap.first.srv.register_space;
4097 binding = remap.first.srv.register_binding;
4098 break;
4099
4100 case HLSL_BINDING_AUTO_SAMPLER_BIT:
4101 desc_set = remap.first.sampler.register_space;
4102 binding = remap.first.sampler.register_binding;
4103 break;
4104
4105 case HLSL_BINDING_AUTO_UAV_BIT:
4106 desc_set = remap.first.uav.register_space;
4107 binding = remap.first.uav.register_binding;
4108 break;
4109
4110 default:
4111 break;
4112 }
4113 }
4114}
4115
4116string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
4117{
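	// Only emit an explicit register() binding when auto-assignment has not been requested
	// for this resource class via the resource binding flags.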
4118 if ((flag & resource_binding_flags) == 0)
4119 {
4120 remap_hlsl_resource_binding(type: flag, desc_set&: space_set, binding);
4121
		// The push constant block did not have a binding, and there was no remap for it,
		// so declare it without a register binding.
4124 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
4125 return "";
4126
4127 if (hlsl_options.shader_model >= 51)
4128 return join(ts: " : register(", ts&: space, ts&: binding, ts: ", space", ts&: space_set, ts: ")");
4129 else
4130 return join(ts: " : register(", ts&: space, ts&: binding, ts: ")");
4131 }
4132 else
4133 return "";
4134}
4135
4136void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
4137{
4138 auto &type = get<SPIRType>(id: var.basetype);
4139 switch (type.basetype)
4140 {
4141 case SPIRType::SampledImage:
4142 case SPIRType::Image:
4143 {
4144 bool is_coherent = false;
4145 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
4146 is_coherent = has_decoration(id: var.self, decoration: DecorationCoherent);
4147
4148 statement(ts: is_coherent ? "globallycoherent " : "", ts: image_type_hlsl_modern(type, id: var.self), ts: " ",
4149 ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var), ts: ";");
4150
4151 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
4152 {
4153 // For combined image samplers, also emit a combined image sampler.
4154 if (is_depth_image(type, id: var.self))
4155 statement(ts: "SamplerComparisonState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
4156 ts: to_resource_binding_sampler(var), ts: ";");
4157 else
4158 statement(ts: "SamplerState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
4159 ts: to_resource_binding_sampler(var), ts: ";");
4160 }
4161 break;
4162 }
4163
4164 case SPIRType::Sampler:
4165 if (comparison_ids.count(x: var.self))
4166 statement(ts: "SamplerComparisonState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var),
4167 ts: ";");
4168 else
4169 statement(ts: "SamplerState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var), ts: ";");
4170 break;
4171
4172 default:
4173 statement(ts: variable_decl(variable: var), ts: to_resource_binding(var), ts: ";");
4174 break;
4175 }
4176}
4177
4178void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
4179{
4180 auto &type = get<SPIRType>(id: var.basetype);
4181 switch (type.basetype)
4182 {
4183 case SPIRType::Sampler:
4184 case SPIRType::Image:
4185 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
4186
4187 default:
4188 statement(ts: variable_decl(variable: var), ts: ";");
4189 break;
4190 }
4191}
4192
4193void CompilerHLSL::emit_uniform(const SPIRVariable &var)
4194{
4195 add_resource_name(id: var.self);
4196 if (hlsl_options.shader_model >= 40)
4197 emit_modern_uniform(var);
4198 else
4199 emit_legacy_uniform(var);
4200}
4201
4202bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
4203{
4204 return false;
4205}
4206
4207string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
4208{
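	// Returns the HLSL expression used to reinterpret in_type as out_type: a plain type cast for
	// same-width integer conversions, asuint/asint/asfloat/asdouble-style intrinsics for
	// float/int reinterpretation, or an empty string when no special op is required.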
4209 if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
4210 return type_to_glsl(type: out_type);
4211 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
4212 return type_to_glsl(type: out_type);
4213 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
4214 return "asuint";
4215 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
4216 return type_to_glsl(type: out_type);
4217 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
4218 return type_to_glsl(type: out_type);
4219 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
4220 return "asint";
4221 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
4222 return "asfloat";
4223 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
4224 return "asfloat";
4225 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
4226 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
4227 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
4228 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
4229 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
4230 return "asdouble";
4231 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
4232 return "asdouble";
4233 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
4234 {
4235 if (!requires_explicit_fp16_packing)
4236 {
4237 requires_explicit_fp16_packing = true;
4238 force_recompile();
4239 }
4240 return "spvUnpackFloat2x16";
4241 }
4242 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
4243 {
4244 if (!requires_explicit_fp16_packing)
4245 {
4246 requires_explicit_fp16_packing = true;
4247 force_recompile();
4248 }
4249 return "spvPackFloat2x16";
4250 }
4251 else if (out_type.basetype == SPIRType::UShort && in_type.basetype == SPIRType::Half)
4252 {
4253 if (hlsl_options.shader_model < 40)
4254 SPIRV_CROSS_THROW("Half to UShort requires Shader Model 4.");
4255 return "(" + type_to_glsl(type: out_type) + ")f32tof16";
4256 }
4257 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UShort)
4258 {
4259 if (hlsl_options.shader_model < 40)
4260 SPIRV_CROSS_THROW("UShort to Half requires Shader Model 4.");
4261 return "(" + type_to_glsl(type: out_type) + ")f16tof32";
4262 }
4263 else
4264 return "";
4265}
4266
4267void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
4268{
4269 auto op = static_cast<GLSLstd450>(eop);
4270
4271 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4272 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, arguments: args, length: count);
4273 auto int_type = to_signed_basetype(width: integer_width);
4274 auto uint_type = to_unsigned_basetype(width: integer_width);
4275
4276 op = get_remapped_glsl_op(std450_op: op);
4277
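	// Most GLSL.std.450 ops map straight through to CompilerGLSL; the cases below handle HLSL renames
	// (rsqrt, frac, lerp, mad, ...) and ops that require emitting helper functions on demand.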
4278 switch (op)
4279 {
4280 case GLSLstd450InverseSqrt:
4281 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rsqrt");
4282 break;
4283
4284 case GLSLstd450Fract:
4285 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "frac");
4286 break;
4287
4288 case GLSLstd450RoundEven:
4289 if (hlsl_options.shader_model < 40)
4290 SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
4291 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "round");
4292 break;
4293
4294 case GLSLstd450Trunc:
4295 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "trunc");
4296 break;
4297
4298 case GLSLstd450Acosh:
4299 case GLSLstd450Asinh:
4300 case GLSLstd450Atanh:
4301 // These are not supported in HLSL, always emulate them.
4302 emit_emulated_ahyper_op(result_type, result_id: id, op0: args[0], op);
4303 break;
4304
4305 case GLSLstd450FMix:
4306 case GLSLstd450IMix:
4307 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "lerp");
4308 break;
4309
4310 case GLSLstd450Atan2:
4311 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "atan2");
4312 break;
4313
4314 case GLSLstd450Fma:
4315 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "mad");
4316 break;
4317
4318 case GLSLstd450InterpolateAtCentroid:
4319 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "EvaluateAttributeAtCentroid");
4320 break;
4321 case GLSLstd450InterpolateAtSample:
4322 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeAtSample");
4323 break;
4324 case GLSLstd450InterpolateAtOffset:
4325 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeSnapped");
4326 break;
4327
4328 case GLSLstd450PackHalf2x16:
4329 if (!requires_fp16_packing)
4330 {
4331 requires_fp16_packing = true;
4332 force_recompile();
4333 }
4334 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackHalf2x16");
4335 break;
4336
4337 case GLSLstd450UnpackHalf2x16:
4338 if (!requires_fp16_packing)
4339 {
4340 requires_fp16_packing = true;
4341 force_recompile();
4342 }
4343 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackHalf2x16");
4344 break;
4345
4346 case GLSLstd450PackSnorm4x8:
4347 if (!requires_snorm8_packing)
4348 {
4349 requires_snorm8_packing = true;
4350 force_recompile();
4351 }
4352 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm4x8");
4353 break;
4354
4355 case GLSLstd450UnpackSnorm4x8:
4356 if (!requires_snorm8_packing)
4357 {
4358 requires_snorm8_packing = true;
4359 force_recompile();
4360 }
4361 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm4x8");
4362 break;
4363
4364 case GLSLstd450PackUnorm4x8:
4365 if (!requires_unorm8_packing)
4366 {
4367 requires_unorm8_packing = true;
4368 force_recompile();
4369 }
4370 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm4x8");
4371 break;
4372
4373 case GLSLstd450UnpackUnorm4x8:
4374 if (!requires_unorm8_packing)
4375 {
4376 requires_unorm8_packing = true;
4377 force_recompile();
4378 }
4379 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm4x8");
4380 break;
4381
4382 case GLSLstd450PackSnorm2x16:
4383 if (!requires_snorm16_packing)
4384 {
4385 requires_snorm16_packing = true;
4386 force_recompile();
4387 }
4388 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm2x16");
4389 break;
4390
4391 case GLSLstd450UnpackSnorm2x16:
4392 if (!requires_snorm16_packing)
4393 {
4394 requires_snorm16_packing = true;
4395 force_recompile();
4396 }
4397 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm2x16");
4398 break;
4399
4400 case GLSLstd450PackUnorm2x16:
4401 if (!requires_unorm16_packing)
4402 {
4403 requires_unorm16_packing = true;
4404 force_recompile();
4405 }
4406 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm2x16");
4407 break;
4408
4409 case GLSLstd450UnpackUnorm2x16:
4410 if (!requires_unorm16_packing)
4411 {
4412 requires_unorm16_packing = true;
4413 force_recompile();
4414 }
4415 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm2x16");
4416 break;
4417
4418 case GLSLstd450PackDouble2x32:
4419 case GLSLstd450UnpackDouble2x32:
4420 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
4421
4422 case GLSLstd450FindILsb:
4423 {
4424 auto basetype = expression_type(id: args[0]).basetype;
4425 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbitlow", input_type: basetype, expected_result_type: basetype);
4426 break;
4427 }
4428
4429 case GLSLstd450FindSMsb:
4430 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: int_type, expected_result_type: int_type);
4431 break;
4432
4433 case GLSLstd450FindUMsb:
4434 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: uint_type, expected_result_type: uint_type);
4435 break;
4436
4437 case GLSLstd450MatrixInverse:
4438 {
4439 auto &type = get<SPIRType>(id: result_type);
4440 if (type.vecsize == 2 && type.columns == 2)
4441 {
4442 if (!requires_inverse_2x2)
4443 {
4444 requires_inverse_2x2 = true;
4445 force_recompile();
4446 }
4447 }
4448 else if (type.vecsize == 3 && type.columns == 3)
4449 {
4450 if (!requires_inverse_3x3)
4451 {
4452 requires_inverse_3x3 = true;
4453 force_recompile();
4454 }
4455 }
4456 else if (type.vecsize == 4 && type.columns == 4)
4457 {
4458 if (!requires_inverse_4x4)
4459 {
4460 requires_inverse_4x4 = true;
4461 force_recompile();
4462 }
4463 }
4464 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse");
4465 break;
4466 }
4467
4468 case GLSLstd450Normalize:
4469 // HLSL does not support scalar versions here.
4470 if (expression_type(id: args[0]).vecsize == 1)
4471 {
4472 // Returns -1 or 1 for valid input, sign() does the job.
4473 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "sign");
4474 }
4475 else
4476 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4477 break;
4478
4479 case GLSLstd450Reflect:
4480 if (get<SPIRType>(id: result_type).vecsize == 1)
4481 {
4482 if (!requires_scalar_reflect)
4483 {
4484 requires_scalar_reflect = true;
4485 force_recompile();
4486 }
4487 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "spvReflect");
4488 }
4489 else
4490 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4491 break;
4492
4493 case GLSLstd450Refract:
4494 if (get<SPIRType>(id: result_type).vecsize == 1)
4495 {
4496 if (!requires_scalar_refract)
4497 {
4498 requires_scalar_refract = true;
4499 force_recompile();
4500 }
4501 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvRefract");
4502 }
4503 else
4504 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4505 break;
4506
4507 case GLSLstd450FaceForward:
4508 if (get<SPIRType>(id: result_type).vecsize == 1)
4509 {
4510 if (!requires_scalar_faceforward)
4511 {
4512 requires_scalar_faceforward = true;
4513 force_recompile();
4514 }
4515 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvFaceForward");
4516 }
4517 else
4518 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4519 break;
4520
4521 case GLSLstd450NMin:
4522 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FMin, args, count);
4523 break;
4524
4525 case GLSLstd450NMax:
4526 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FMax, args, count);
4527 break;
4528
4529 case GLSLstd450NClamp:
4530 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FClamp, args, count);
4531 break;
4532
4533 default:
4534 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4535 break;
4536 }
4537}
4538
4539void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
4540{
4541 auto &type = get<SPIRType>(id: chain.basetype);
4542
	// Use a uniquely generated identifier here, since a plain loop counter could shadow or collide with an identifier in the access chain input or enclosing loops.
4544 auto ident = get_unique_identifier();
4545
4546 statement(ts: "[unroll]");
4547 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
4548 ts&: ident, ts: "++)");
4549 begin_scope();
4550 auto subchain = chain;
4551 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
4552 subchain.basetype = type.parent_type;
4553 if (!get<SPIRType>(id: subchain.basetype).array.empty())
4554 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
4555 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: "[", ts&: ident, ts: "]"), chain: subchain);
4556 end_scope();
4557}
4558
4559void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
4560{
4561 auto &type = get<SPIRType>(id: chain.basetype);
4562 auto subchain = chain;
4563 uint32_t member_count = uint32_t(type.member_types.size());
4564
4565 for (uint32_t i = 0; i < member_count; i++)
4566 {
4567 uint32_t offset = type_struct_member_offset(type, index: i);
4568 subchain.static_index = chain.static_index + offset;
4569 subchain.basetype = type.member_types[i];
4570
4571 subchain.matrix_stride = 0;
4572 subchain.array_stride = 0;
4573 subchain.row_major_matrix = false;
4574
4575 auto &member_type = get<SPIRType>(id: subchain.basetype);
4576 if (member_type.columns > 1)
4577 {
4578 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
4579 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
4580 }
4581
4582 if (!member_type.array.empty())
4583 subchain.array_stride = type_struct_member_array_stride(type, index: i);
4584
4585 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: ".", ts: to_member_name(type, index: i)), chain: subchain);
4586 }
4587}
4588
4589void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
4590{
4591 auto &type = get<SPIRType>(id: chain.basetype);
4592
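	// ByteAddressBuffer loads return uint data, so build a matching unsigned "transfer" type here;
	// for pre-SM 6.2 (non-templated) loads the result is bitcast back to the real type at the end.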
4593 SPIRType target_type { is_scalar(type) ? OpTypeInt : type.op };
4594 target_type.basetype = SPIRType::UInt;
4595 target_type.vecsize = type.vecsize;
4596 target_type.columns = type.columns;
4597
4598 if (!type.array.empty())
4599 {
4600 read_access_chain_array(lhs, chain);
4601 return;
4602 }
4603 else if (type.basetype == SPIRType::Struct)
4604 {
4605 read_access_chain_struct(lhs, chain);
4606 return;
4607 }
4608 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4609 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
4610 "native 16-bit types are enabled.");
4611
4612 string base = chain.base;
4613 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
4614 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
4615
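	// SM 6.2+ supports templated Load<T>() on byte address buffers, which avoids the uint round-trip
	// and is required for 16-bit types.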
4616 bool templated_load = hlsl_options.shader_model >= 62;
4617 string load_expr;
4618
4619 string template_expr;
4620 if (templated_load)
4621 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
4622
4623 // Load a vector or scalar.
4624 if (type.columns == 1 && !chain.row_major_matrix)
4625 {
4626 const char *load_op = nullptr;
4627 switch (type.vecsize)
4628 {
4629 case 1:
4630 load_op = "Load";
4631 break;
4632 case 2:
4633 load_op = "Load2";
4634 break;
4635 case 3:
4636 load_op = "Load3";
4637 break;
4638 case 4:
4639 load_op = "Load4";
4640 break;
4641 default:
4642 SPIRV_CROSS_THROW("Unknown vector size.");
4643 }
4644
4645 if (templated_load)
4646 load_op = "Load";
4647
4648 load_expr = join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ")");
4649 }
4650 else if (type.columns == 1)
4651 {
4652 // Strided load since we are loading a column from a row-major matrix.
4653 if (templated_load)
4654 {
4655 auto scalar_type = type;
4656 scalar_type.vecsize = 1;
4657 scalar_type.columns = 1;
4658 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4659 if (type.vecsize > 1)
4660 load_expr += type_to_glsl(type) + "(";
4661 }
4662 else if (type.vecsize > 1)
4663 {
4664 load_expr = type_to_glsl(type: target_type);
4665 load_expr += "(";
4666 }
4667
4668 for (uint32_t r = 0; r < type.vecsize; r++)
4669 {
4670 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4671 ts: chain.static_index + r * chain.matrix_stride, ts: ")");
4672 if (r + 1 < type.vecsize)
4673 load_expr += ", ";
4674 }
4675
4676 if (type.vecsize > 1)
4677 load_expr += ")";
4678 }
4679 else if (!chain.row_major_matrix)
4680 {
4681 // Load a matrix, column-major, the easy case.
4682 const char *load_op = nullptr;
4683 switch (type.vecsize)
4684 {
4685 case 1:
4686 load_op = "Load";
4687 break;
4688 case 2:
4689 load_op = "Load2";
4690 break;
4691 case 3:
4692 load_op = "Load3";
4693 break;
4694 case 4:
4695 load_op = "Load4";
4696 break;
4697 default:
4698 SPIRV_CROSS_THROW("Unknown vector size.");
4699 }
4700
4701 if (templated_load)
4702 {
4703 auto vector_type = type;
4704 vector_type.columns = 1;
4705 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
4706 load_expr = type_to_glsl(type);
4707 load_op = "Load";
4708 }
4709 else
4710 {
4711 // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
4712 // so row-major is technically column-major ...
4713 load_expr = type_to_glsl(type: target_type);
4714 }
4715 load_expr += "(";
4716
4717 for (uint32_t c = 0; c < type.columns; c++)
4718 {
4719 load_expr += join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4720 ts: chain.static_index + c * chain.matrix_stride, ts: ")");
4721 if (c + 1 < type.columns)
4722 load_expr += ", ";
4723 }
4724 load_expr += ")";
4725 }
4726 else
4727 {
4728 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
4729 // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
4730
4731 if (templated_load)
4732 {
4733 load_expr = type_to_glsl(type);
4734 auto scalar_type = type;
4735 scalar_type.vecsize = 1;
4736 scalar_type.columns = 1;
4737 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4738 }
4739 else
4740 load_expr = type_to_glsl(type: target_type);
4741
4742 load_expr += "(";
4743
4744 for (uint32_t c = 0; c < type.columns; c++)
4745 {
4746 for (uint32_t r = 0; r < type.vecsize; r++)
4747 {
4748 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4749 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ")");
4750
4751 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
4752 load_expr += ", ";
4753 }
4754 }
4755 load_expr += ")";
4756 }
4757
4758 if (!templated_load)
4759 {
4760 auto bitcast_op = bitcast_glsl_op(out_type: type, in_type: target_type);
4761 if (!bitcast_op.empty())
4762 load_expr = join(ts&: bitcast_op, ts: "(", ts&: load_expr, ts: ")");
4763 }
4764
4765 if (lhs.empty())
4766 {
4767 assert(expr);
4768 *expr = std::move(load_expr);
4769 }
4770 else
4771 statement(ts: lhs, ts: " = ", ts&: load_expr, ts: ";");
4772}
4773
4774void CompilerHLSL::emit_load(const Instruction &instruction)
4775{
4776 auto ops = stream(instr: instruction);
4777
4778 uint32_t result_type = ops[0];
4779 uint32_t id = ops[1];
4780 uint32_t ptr = ops[2];
4781
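	// A SPIRAccessChain here means we are loading from a (RW)ByteAddressBuffer, so the load has to be
	// decomposed into explicit Load calls instead of a plain dereference expression.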
4782 auto *chain = maybe_get<SPIRAccessChain>(id: ptr);
4783 if (chain)
4784 {
4785 auto &type = get<SPIRType>(id: result_type);
4786 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
4787
4788 if (composite_load)
4789 {
4790 // We cannot make this work in one single expression as we might have nested structures and arrays,
4791 // so unroll the load to an uninitialized temporary.
4792 emit_uninitialized_temporary_expression(type: result_type, id);
4793 read_access_chain(expr: nullptr, lhs: to_expression(id), chain: *chain);
4794 track_expression_read(id: chain->self);
4795 }
4796 else
4797 {
4798 string load_expr;
4799 read_access_chain(expr: &load_expr, lhs: "", chain: *chain);
4800
4801 bool forward = should_forward(id: ptr) && forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
4802
			// If we are forwarding this load, don't register the read of the access chain here;
			// defer that to when we actually use the expression,
			// via the add_implied_read_expression mechanism.
4806 if (!forward)
4807 track_expression_read(id: chain->self);
4808
4809 // Do not forward complex load sequences like matrices, structs and arrays.
4810 if (type.columns > 1)
4811 forward = false;
4812
4813 auto &e = emit_op(result_type, result_id: id, rhs: load_expr, forward_rhs: forward, suppress_usage_tracking: true);
4814 e.need_transpose = false;
4815 register_read(expr: id, chain: ptr, forwarded: forward);
4816 inherit_expression_dependencies(dst: id, source: ptr);
4817 if (forward)
4818 add_implied_read_expression(e, source: chain->self);
4819 }
4820 }
4821 else
4822 {
4823 // Very special case where we cannot rely on IO lowering.
4824 // Mesh shader clip/cull arrays ... Cursed.
4825 auto &res_type = get<SPIRType>(id: result_type);
4826 if (get_execution_model() == ExecutionModelMeshEXT &&
4827 has_decoration(id: ptr, decoration: DecorationBuiltIn) &&
4828 (get_decoration(id: ptr, decoration: DecorationBuiltIn) == BuiltInClipDistance ||
4829 get_decoration(id: ptr, decoration: DecorationBuiltIn) == BuiltInCullDistance) &&
4830 is_array(type: res_type) && !is_array(type: get<SPIRType>(id: res_type.parent_type)))
4831 {
4832 track_expression_read(id: ptr);
4833 string load_expr = "{ ";
4834 uint32_t num_elements = to_array_size_literal(type: res_type);
4835 for (uint32_t i = 0; i < num_elements; i++)
4836 {
4837 load_expr += join(ts: to_expression(id: ptr), ts: "[", ts&: i, ts: "]");
4838 if (i + 1 < num_elements)
4839 load_expr += ", ";
4840 }
4841 load_expr += " }";
4842 emit_op(result_type, result_id: id, rhs: load_expr, forward_rhs: false);
4843 register_read(expr: id, chain: ptr, forwarded: false);
4844 inherit_expression_dependencies(dst: id, source: ptr);
4845 }
4846 else
4847 {
4848 CompilerGLSL::emit_instruction(instr: instruction);
4849 }
4850 }
4851}
4852
4853void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4854 const SmallVector<uint32_t> &composite_chain)
4855{
4856 auto *ptype = &get<SPIRType>(id: chain.basetype);
4857 while (ptype->pointer)
4858 {
4859 ptype = &get<SPIRType>(id: ptype->basetype);
4860 }
4861 auto &type = *ptype;
4862
	// Use a uniquely generated identifier here, since a plain loop counter could shadow or collide with an identifier in the access chain input or enclosing loops.
4864 auto ident = get_unique_identifier();
4865
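	// Synthesize an int-typed SPIRExpression for the loop counter so it can be referenced from the
	// composite chain below (through the MSB-forced literal ID).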
4866 uint32_t id = ir.increase_bound_by(count: 2);
4867 uint32_t int_type_id = id + 1;
4868 SPIRType int_type { OpTypeInt };
4869 int_type.basetype = SPIRType::Int;
4870 int_type.width = 32;
4871 set<SPIRType>(id: int_type_id, args&: int_type);
4872 set<SPIRExpression>(id, args&: ident, args&: int_type_id, args: true);
4873 set_name(id, name: ident);
4874 suppressed_usage_tracking.insert(x: id);
4875
4876 statement(ts: "[unroll]");
4877 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
4878 ts&: ident, ts: "++)");
4879 begin_scope();
4880 auto subchain = chain;
4881 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
4882 subchain.basetype = type.parent_type;
4883
4884 // Forcefully allow us to use an ID here by setting MSB.
4885 auto subcomposite_chain = composite_chain;
4886 subcomposite_chain.push_back(t: 0x80000000u | id);
4887
4888 if (!get<SPIRType>(id: subchain.basetype).array.empty())
4889 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
4890
4891 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4892 end_scope();
4893}
4894
4895void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4896 const SmallVector<uint32_t> &composite_chain)
4897{
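	// Structs are written member by member: each member's byte offset, matrix stride, array stride and
	// row-major flag are pulled from the struct layout before recursing into write_access_chain().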
4898 auto &type = get<SPIRType>(id: chain.basetype);
4899 uint32_t member_count = uint32_t(type.member_types.size());
4900 auto subchain = chain;
4901
4902 auto subcomposite_chain = composite_chain;
4903 subcomposite_chain.push_back(t: 0);
4904
4905 for (uint32_t i = 0; i < member_count; i++)
4906 {
4907 uint32_t offset = type_struct_member_offset(type, index: i);
4908 subchain.static_index = chain.static_index + offset;
4909 subchain.basetype = type.member_types[i];
4910
4911 subchain.matrix_stride = 0;
4912 subchain.array_stride = 0;
4913 subchain.row_major_matrix = false;
4914
4915 auto &member_type = get<SPIRType>(id: subchain.basetype);
4916 if (member_type.columns > 1)
4917 {
4918 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
4919 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
4920 }
4921
4922 if (!member_type.array.empty())
4923 subchain.array_stride = type_struct_member_array_stride(type, index: i);
4924
4925 subcomposite_chain.back() = i;
4926 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4927 }
4928}
4929
4930string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4931 bool enclose)
4932{
4933 string ret;
4934 if (composite_chain.empty())
4935 ret = to_expression(id: value);
4936 else
4937 {
4938 AccessChainMeta meta;
4939 ret = access_chain_internal(base: value, indices: composite_chain.data(), count: uint32_t(composite_chain.size()),
4940 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, meta: &meta);
4941 }
4942
4943 if (enclose)
4944 ret = enclose_expression(expr: ret);
4945 return ret;
4946}
4947
4948void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4949 const SmallVector<uint32_t> &composite_chain)
4950{
4951 auto &type = get<SPIRType>(id: chain.basetype);
4952
4953 // Make sure we trigger a read of the constituents in the access chain.
4954 track_expression_read(id: chain.self);
4955
4956 SPIRType target_type { is_scalar(type) ? OpTypeInt : type.op };
4957 target_type.basetype = SPIRType::UInt;
4958 target_type.vecsize = type.vecsize;
4959 target_type.columns = type.columns;
4960
4961 if (!type.array.empty())
4962 {
4963 write_access_chain_array(chain, value, composite_chain);
4964 register_write(chain: chain.self);
4965 return;
4966 }
4967 else if (type.basetype == SPIRType::Struct)
4968 {
4969 write_access_chain_struct(chain, value, composite_chain);
4970 register_write(chain: chain.self);
4971 return;
4972 }
4973 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4974 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4975 "native 16-bit types are enabled.");
4976
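	// SM 6.2 RWByteAddressBuffers support templated Store<T>() which takes typed data directly;
	// older shader models must go through Store/Store2/Store3/Store4 with uint bit patterns,
	// hence the bitcasts below when templated_store is false.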
4977 bool templated_store = hlsl_options.shader_model >= 62;
4978
4979 auto base = chain.base;
4980 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
4981 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
4982
4983 string template_expr;
4984 if (templated_store)
4985 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
4986
4987 if (type.columns == 1 && !chain.row_major_matrix)
4988 {
4989 const char *store_op = nullptr;
4990 switch (type.vecsize)
4991 {
4992 case 1:
4993 store_op = "Store";
4994 break;
4995 case 2:
4996 store_op = "Store2";
4997 break;
4998 case 3:
4999 store_op = "Store3";
5000 break;
5001 case 4:
5002 store_op = "Store4";
5003 break;
5004 default:
5005 SPIRV_CROSS_THROW("Unknown vector size.");
5006 }
5007
5008 auto store_expr = write_access_chain_value(value, composite_chain, enclose: false);
5009
5010 if (!templated_store)
5011 {
5012 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
5013 if (!bitcast_op.empty())
5014 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
5015 }
5016 else
5017 store_op = "Store";
5018 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ", ",
5019 ts&: store_expr, ts: ");");
5020 }
5021 else if (type.columns == 1)
5022 {
5023 if (templated_store)
5024 {
5025 auto scalar_type = type;
5026 scalar_type.vecsize = 1;
5027 scalar_type.columns = 1;
5028 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
5029 }
5030
5031 // Strided store.
5032 for (uint32_t r = 0; r < type.vecsize; r++)
5033 {
5034 auto store_expr = write_access_chain_value(value, composite_chain, enclose: true);
5035 if (type.vecsize > 1)
5036 {
5037 store_expr += ".";
5038 store_expr += index_to_swizzle(index: r);
5039 }
5040 remove_duplicate_swizzle(op&: store_expr);
5041
5042 if (!templated_store)
5043 {
5044 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
5045 if (!bitcast_op.empty())
5046 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
5047 }
5048
5049 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
5050 ts: chain.static_index + chain.matrix_stride * r, ts: ", ", ts&: store_expr, ts: ");");
5051 }
5052 }
5053 else if (!chain.row_major_matrix)
5054 {
5055 const char *store_op = nullptr;
5056 switch (type.vecsize)
5057 {
5058 case 1:
5059 store_op = "Store";
5060 break;
5061 case 2:
5062 store_op = "Store2";
5063 break;
5064 case 3:
5065 store_op = "Store3";
5066 break;
5067 case 4:
5068 store_op = "Store4";
5069 break;
5070 default:
5071 SPIRV_CROSS_THROW("Unknown vector size.");
5072 }
5073
5074 if (templated_store)
5075 {
5076 store_op = "Store";
5077 auto vector_type = type;
5078 vector_type.columns = 1;
5079 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
5080 }
5081
5082 for (uint32_t c = 0; c < type.columns; c++)
5083 {
5084 auto store_expr = join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "]");
5085
5086 if (!templated_store)
5087 {
5088 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
5089 if (!bitcast_op.empty())
5090 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
5091 }
5092
5093 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
5094 ts: chain.static_index + c * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
5095 }
5096 }
5097 else
5098 {
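		// Row-major matrix in buffer memory: no single Store* call matches the layout,
		// so write one scalar at a time at static_index + c * (width / 8) + r * matrix_stride.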
5099 if (templated_store)
5100 {
5101 auto scalar_type = type;
5102 scalar_type.vecsize = 1;
5103 scalar_type.columns = 1;
5104 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
5105 }
5106
5107 for (uint32_t r = 0; r < type.vecsize; r++)
5108 {
5109 for (uint32_t c = 0; c < type.columns; c++)
5110 {
5111 auto store_expr =
5112 join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "].", ts: index_to_swizzle(index: r));
5113 remove_duplicate_swizzle(op&: store_expr);
5114 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
5115 if (!bitcast_op.empty())
5116 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
5117 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
5118 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
5119 }
5120 }
5121 }
5122
5123 register_write(chain: chain.self);
5124}
5125
5126void CompilerHLSL::emit_store(const Instruction &instruction)
5127{
5128 auto ops = stream(instr: instruction);
5129 if (options.vertex.flip_vert_y)
5130 {
5131 auto *expr = maybe_get<SPIRExpression>(id: ops[0]);
5132 if (expr != nullptr && expr->access_meshlet_position_y)
5133 {
5134 auto lhs = to_dereferenced_expression(id: ops[0]);
5135 auto rhs = to_unpacked_expression(id: ops[1]);
5136 statement(ts&: lhs, ts: " = spvFlipVertY(", ts&: rhs, ts: ");");
5137 register_write(chain: ops[0]);
5138 return;
5139 }
5140 }
5141
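	// Stores through a SPIRAccessChain target an RWByteAddressBuffer and are lowered to Store* calls.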
5142 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
5143 if (chain)
5144 write_access_chain(chain: *chain, value: ops[1], composite_chain: {});
5145 else
5146 CompilerGLSL::emit_instruction(instr: instruction);
5147}
5148
5149void CompilerHLSL::emit_access_chain(const Instruction &instruction)
5150{
5151 auto ops = stream(instr: instruction);
5152 uint32_t length = instruction.length;
5153
5154 bool need_byte_access_chain = false;
5155 auto &type = expression_type(id: ops[2]);
5156 const auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
5157
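	// A SPIRAccessChain models a byte offset into a (RW)ByteAddressBuffer as a dynamic expression plus a
	// static constant, rather than a plain indexing expression.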
5158 if (chain)
5159 {
5160 // Keep tacking on an existing access chain.
5161 need_byte_access_chain = true;
5162 }
5163 else if (type.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock))
5164 {
5165 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
5166 // to emit SPIRAccessChain rather than a plain SPIRExpression.
5167 uint32_t chain_arguments = length - 3;
5168 if (chain_arguments > type.array.size())
5169 need_byte_access_chain = true;
5170 }
5171
5172 if (need_byte_access_chain)
5173 {
5174 // If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
5175 // and not array of SSBO.
5176 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
5177
5178 auto *backing_variable = maybe_get_backing_variable(chain: ops[2]);
5179
5180 if (backing_variable != nullptr && is_user_type_structured(id: backing_variable->self))
5181 {
5182 CompilerGLSL::emit_instruction(instr: instruction);
5183 return;
5184 }
5185
5186 string base;
5187 if (to_plain_buffer_length != 0)
5188 base = access_chain(base: ops[2], indices: &ops[3], count: to_plain_buffer_length, target_type: get<SPIRType>(id: ops[0]));
5189 else if (chain)
5190 base = chain->base;
5191 else
5192 base = to_expression(id: ops[2]);
5193
5194 // Start traversing type hierarchy at the proper non-pointer types.
5195 auto *basetype = &get_pointee_type(type);
5196
5197 // Traverse the type hierarchy down to the actual buffer types.
5198 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
5199 {
5200 assert(basetype->parent_type);
5201 basetype = &get<SPIRType>(id: basetype->parent_type);
5202 }
5203
5204 uint32_t matrix_stride = 0;
5205 uint32_t array_stride = 0;
5206 bool row_major_matrix = false;
5207
5208 // Inherit matrix information.
5209 if (chain)
5210 {
5211 matrix_stride = chain->matrix_stride;
5212 row_major_matrix = chain->row_major_matrix;
5213 array_stride = chain->array_stride;
5214 }
5215
5216 auto offsets = flattened_access_chain_offset(basetype: *basetype, indices: &ops[3 + to_plain_buffer_length],
5217 count: length - 3 - to_plain_buffer_length, offset: 0, word_stride: 1, need_transpose: &row_major_matrix,
5218 matrix_stride: &matrix_stride, array_stride: &array_stride);
5219
5220 auto &e = set<SPIRAccessChain>(id: ops[1], args: ops[0], args: type.storage, args&: base, args&: offsets.first, args&: offsets.second);
5221 e.row_major_matrix = row_major_matrix;
5222 e.matrix_stride = matrix_stride;
5223 e.array_stride = array_stride;
5224 e.immutable = should_forward(id: ops[2]);
5225 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
5226
5227 if (chain)
5228 {
5229 e.dynamic_index += chain->dynamic_index;
5230 e.static_index += chain->static_index;
5231 }
5232
5233 for (uint32_t i = 2; i < length; i++)
5234 {
5235 inherit_expression_dependencies(dst: ops[1], source: ops[i]);
5236 add_implied_read_expression(e, source: ops[i]);
5237 }
5238 }
5239 else
5240 {
5241 CompilerGLSL::emit_instruction(instr: instruction);
5242 }
5243}
5244
5245void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
5246{
5247 const char *atomic_op = nullptr;
5248
5249 string value_expr;
5250 if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
5251 value_expr = to_expression(id: ops[op == OpAtomicCompareExchange ? 6 : 5]);
5252
5253 bool is_atomic_store = false;
5254
5255 switch (op)
5256 {
5257 case OpAtomicIIncrement:
5258 atomic_op = "InterlockedAdd";
5259 value_expr = "1";
5260 break;
5261
5262 case OpAtomicIDecrement:
5263 atomic_op = "InterlockedAdd";
5264 value_expr = "-1";
5265 break;
5266
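	// HLSL has no dedicated atomic load; emulate it as InterlockedAdd with 0, which returns the old value.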
5267 case OpAtomicLoad:
5268 atomic_op = "InterlockedAdd";
5269 value_expr = "0";
5270 break;
5271
5272 case OpAtomicISub:
5273 atomic_op = "InterlockedAdd";
5274 value_expr = join(ts: "-", ts: enclose_expression(expr: value_expr));
5275 break;
5276
5277 case OpAtomicSMin:
5278 case OpAtomicUMin:
5279 atomic_op = "InterlockedMin";
5280 break;
5281
5282 case OpAtomicSMax:
5283 case OpAtomicUMax:
5284 atomic_op = "InterlockedMax";
5285 break;
5286
5287 case OpAtomicAnd:
5288 atomic_op = "InterlockedAnd";
5289 break;
5290
5291 case OpAtomicOr:
5292 atomic_op = "InterlockedOr";
5293 break;
5294
5295 case OpAtomicXor:
5296 atomic_op = "InterlockedXor";
5297 break;
5298
5299 case OpAtomicIAdd:
5300 atomic_op = "InterlockedAdd";
5301 break;
5302
5303 case OpAtomicExchange:
5304 atomic_op = "InterlockedExchange";
5305 break;
5306
5307 case OpAtomicStore:
5308 atomic_op = "InterlockedExchange";
5309 is_atomic_store = true;
5310 break;
5311
5312 case OpAtomicCompareExchange:
5313 if (length < 8)
5314 SPIRV_CROSS_THROW("Not enough data for opcode.");
5315 atomic_op = "InterlockedCompareExchange";
5316 value_expr = join(ts: to_expression(id: ops[7]), ts: ", ", ts&: value_expr);
5317 break;
5318
5319 default:
5320 SPIRV_CROSS_THROW("Unknown atomic opcode.");
5321 }
5322
5323 if (is_atomic_store)
5324 {
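		// Atomic stores are emulated with InterlockedExchange; the previous value goes into a dummy
		// temporary that is never read.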
5325 auto &data_type = expression_type(id: ops[0]);
5326 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
5327
5328 auto &tmp_id = extra_sub_expressions[ops[0]];
5329 if (!tmp_id)
5330 {
5331 tmp_id = ir.increase_bound_by(count: 1);
5332 emit_uninitialized_temporary_expression(type: get_pointee_type(type: data_type).self, id: tmp_id);
5333 }
5334
5335 if (data_type.storage == StorageClassImage || !chain)
5336 {
5337 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[0]), ts: ", ",
5338 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
5339 }
5340 else
5341 {
5342 string base = chain->base;
5343 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
5344 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
5345 // RWByteAddressBuffer is always uint in its underlying type.
5346 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ",
5347 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
5348 }
5349 }
5350 else
5351 {
5352 uint32_t result_type = ops[0];
5353 uint32_t id = ops[1];
5354 forced_temporaries.insert(x: ops[1]);
5355
5356 auto &type = get<SPIRType>(id: result_type);
5357 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
5358
5359 auto &data_type = expression_type(id: ops[2]);
5360 auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
5361 SPIRType::BaseType expr_type;
5362 if (data_type.storage == StorageClassImage || !chain)
5363 {
5364 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: value_expr, ts: ", ", ts: to_name(id), ts: ");");
5365 expr_type = data_type.basetype;
5366 }
5367 else
5368 {
5369 // RWByteAddressBuffer is always uint in its underlying type.
5370 string base = chain->base;
5371 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
5372 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
5373 expr_type = SPIRType::UInt;
5374 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ", ts&: value_expr,
5375 ts: ", ", ts: to_name(id), ts: ");");
5376 }
5377
5378 auto expr = bitcast_expression(target_type: type, expr_type, expr: to_name(id));
5379 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
5380 }
5381 flush_all_atomic_capable_variables();
5382}
5383
5384void CompilerHLSL::emit_subgroup_op(const Instruction &i)
5385{
5386 if (hlsl_options.shader_model < 60)
5387 SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
5388
5389 const uint32_t *ops = stream(instr: i);
5390 auto op = static_cast<Op>(i.op);
5391
5392 uint32_t result_type = ops[0];
5393 uint32_t id = ops[1];
5394
5395 auto scope = static_cast<Scope>(evaluate_constant_u32(id: ops[2]));
5396 if (scope != ScopeSubgroup)
5397 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
5398
5399 const auto make_inclusive_Sum = [&](const string &expr) -> string {
5400 return join(ts: expr, ts: " + ", ts: to_expression(id: ops[4]));
5401 };
5402
5403 const auto make_inclusive_Product = [&](const string &expr) -> string {
5404 return join(ts: expr, ts: " * ", ts: to_expression(id: ops[4]));
5405 };
5406
5407 // If we need to do implicit bitcasts, make sure we do it with the correct type.
5408 uint32_t integer_width = get_integer_width_for_instruction(instr: i);
5409 auto int_type = to_signed_basetype(width: integer_width);
5410 auto uint_type = to_unsigned_basetype(width: integer_width);
5411
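// These expansions are never reached since the corresponding group ops below pass supports_scan = false;
// they only exist so the token pasting in HLSL_GROUP_OP compiles.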
5412#define make_inclusive_BitAnd(expr) ""
5413#define make_inclusive_BitOr(expr) ""
5414#define make_inclusive_BitXor(expr) ""
5415#define make_inclusive_Min(expr) ""
5416#define make_inclusive_Max(expr) ""
5417
5418 switch (op)
5419 {
5420 case OpGroupNonUniformElect:
5421 emit_op(result_type, result_id: id, rhs: "WaveIsFirstLane()", forward_rhs: true);
5422 break;
5423
5424 case OpGroupNonUniformBroadcast:
5425 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
5426 break;
5427
5428 case OpGroupNonUniformBroadcastFirst:
5429 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveReadLaneFirst");
5430 break;
5431
5432 case OpGroupNonUniformBallot:
5433 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveBallot");
5434 break;
5435
5436 case OpGroupNonUniformInverseBallot:
5437 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
5438
5439 case OpGroupNonUniformBallotBitExtract:
5440 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
5441
5442 case OpGroupNonUniformBallotFindLSB:
5443 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
5444
5445 case OpGroupNonUniformBallotFindMSB:
5446 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
5447
5448 case OpGroupNonUniformBallotBitCount:
5449 {
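		// WaveActiveBallot yields a uint4 lane mask. A Reduce bit count sums countbits() over all four
		// components; Inclusive/Exclusive scans additionally mask with gl_SubgroupLeMask / gl_SubgroupLtMask.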
5450 auto operation = static_cast<GroupOperation>(ops[3]);
5451 bool forward = should_forward(id: ops[4]);
5452 if (operation == GroupOperationReduce)
5453 {
5454 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x) + countbits(",
5455 ts: to_enclosed_expression(id: ops[4]), ts: ".y)");
5456 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z) + countbits(",
5457 ts: to_enclosed_expression(id: ops[4]), ts: ".w)");
5458 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5459 inherit_expression_dependencies(dst: id, source: ops[4]);
5460 }
5461 else if (operation == GroupOperationInclusiveScan)
5462 {
5463 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x & gl_SubgroupLeMask.x) + countbits(",
5464 ts: to_enclosed_expression(id: ops[4]), ts: ".y & gl_SubgroupLeMask.y)");
5465 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z & gl_SubgroupLeMask.z) + countbits(",
5466 ts: to_enclosed_expression(id: ops[4]), ts: ".w & gl_SubgroupLeMask.w)");
5467 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5468 if (!active_input_builtins.get(bit: BuiltInSubgroupLeMask))
5469 {
5470 active_input_builtins.set(BuiltInSubgroupLeMask);
5471 force_recompile_guarantee_forward_progress();
5472 }
5473 }
5474 else if (operation == GroupOperationExclusiveScan)
5475 {
5476 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x & gl_SubgroupLtMask.x) + countbits(",
5477 ts: to_enclosed_expression(id: ops[4]), ts: ".y & gl_SubgroupLtMask.y)");
5478 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z & gl_SubgroupLtMask.z) + countbits(",
5479 ts: to_enclosed_expression(id: ops[4]), ts: ".w & gl_SubgroupLtMask.w)");
5480 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5481 if (!active_input_builtins.get(bit: BuiltInSubgroupLtMask))
5482 {
5483 active_input_builtins.set(BuiltInSubgroupLtMask);
5484 force_recompile_guarantee_forward_progress();
5485 }
5486 }
5487 else
5488 SPIRV_CROSS_THROW("Invalid BitCount operation.");
5489 break;
5490 }
5491
5492 case OpGroupNonUniformShuffle:
5493 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
5494 break;
5495 case OpGroupNonUniformShuffleXor:
5496 {
5497 bool forward = should_forward(id: ops[3]);
5498 emit_op(result_type: ops[0], result_id: ops[1],
5499 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5500 ts: "WaveGetLaneIndex() ^ ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5501 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5502 break;
5503 }
5504 case OpGroupNonUniformShuffleUp:
5505 {
5506 bool forward = should_forward(id: ops[3]);
5507 emit_op(result_type: ops[0], result_id: ops[1],
5508 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5509 ts: "WaveGetLaneIndex() - ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5510 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5511 break;
5512 }
5513 case OpGroupNonUniformShuffleDown:
5514 {
5515 bool forward = should_forward(id: ops[3]);
5516 emit_op(result_type: ops[0], result_id: ops[1],
5517 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5518 ts: "WaveGetLaneIndex() + ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5519 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5520 break;
5521 }
5522
5523 case OpGroupNonUniformAll:
5524 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllTrue");
5525 break;
5526
5527 case OpGroupNonUniformAny:
5528 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAnyTrue");
5529 break;
5530
5531 case OpGroupNonUniformAllEqual:
5532 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllEqual");
5533 break;
5534
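	// HLSL only exposes exclusive prefix ops (WavePrefixSum / WavePrefixProduct), so inclusive scans are
	// built by folding the invocation's own value back in via the make_inclusive_* helpers above.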
5535 // clang-format off
5536#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
5537case OpGroupNonUniform##op: \
5538 { \
5539 auto operation = static_cast<GroupOperation>(ops[3]); \
5540 if (operation == GroupOperationReduce) \
5541 emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
5542 else if (operation == GroupOperationInclusiveScan && supports_scan) \
5543 { \
5544 bool forward = should_forward(ops[4]); \
5545 emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
5546 inherit_expression_dependencies(id, ops[4]); \
5547 } \
5548 else if (operation == GroupOperationExclusiveScan && supports_scan) \
5549 emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
5550 else if (operation == GroupOperationClusteredReduce) \
5551 SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
5552 else \
5553 SPIRV_CROSS_THROW("Invalid group operation."); \
5554 break; \
5555 }
5556
5557#define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
5558case OpGroupNonUniform##op: \
5559 { \
5560 auto operation = static_cast<GroupOperation>(ops[3]); \
5561 if (operation == GroupOperationReduce) \
5562 emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
5563 else \
5564 SPIRV_CROSS_THROW("Invalid group operation."); \
5565 break; \
5566 }
5567
5568 HLSL_GROUP_OP(FAdd, Sum, true)
5569 HLSL_GROUP_OP(FMul, Product, true)
5570 HLSL_GROUP_OP(FMin, Min, false)
5571 HLSL_GROUP_OP(FMax, Max, false)
5572 HLSL_GROUP_OP(IAdd, Sum, true)
5573 HLSL_GROUP_OP(IMul, Product, true)
5574 HLSL_GROUP_OP_CAST(SMin, Min, int_type)
5575 HLSL_GROUP_OP_CAST(SMax, Max, int_type)
5576 HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
5577 HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
5578 HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
5579 HLSL_GROUP_OP(BitwiseOr, BitOr, false)
5580 HLSL_GROUP_OP(BitwiseXor, BitXor, false)
5581 HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
5582 HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
5583 HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
5584
5585#undef HLSL_GROUP_OP
5586#undef HLSL_GROUP_OP_CAST
5587 // clang-format on
5588
5589 case OpGroupNonUniformQuadSwap:
5590 {
5591 uint32_t direction = evaluate_constant_u32(id: ops[4]);
5592 if (direction == 0)
5593 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossX");
5594 else if (direction == 1)
5595 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossY");
5596 else if (direction == 2)
5597 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossDiagonal");
5598 else
5599 SPIRV_CROSS_THROW("Invalid quad swap direction.");
5600 break;
5601 }
5602
5603 case OpGroupNonUniformQuadBroadcast:
5604 {
5605 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "QuadReadLaneAt");
5606 break;
5607 }
5608
5609 default:
5610 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
5611 }
5612
5613 register_control_dependent_expression(expr: id);
5614}
5615
5616void CompilerHLSL::emit_instruction(const Instruction &instruction)
5617{
5618 auto ops = stream(instr: instruction);
5619 auto opcode = static_cast<Op>(instruction.op);
5620
5621#define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
5622#define HLSL_BOP_CAST(op, type) \
5623 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
5624#define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
5625#define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
5626#define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
5627#define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
5628#define HLSL_BFOP_CAST(op, type) \
5629 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
5631#define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
5632
5633 // If we need to do implicit bitcasts, make sure we do it with the correct type.
5634 uint32_t integer_width = get_integer_width_for_instruction(instr: instruction);
5635 auto int_type = to_signed_basetype(width: integer_width);
5636 auto uint_type = to_unsigned_basetype(width: integer_width);
5637
5638 opcode = get_remapped_spirv_op(op: opcode);
5639
5640 switch (opcode)
5641 {
5642 case OpAccessChain:
5643 case OpInBoundsAccessChain:
5644 {
5645 emit_access_chain(instruction);
5646 break;
5647 }
5648 case OpBitcast:
5649 {
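		// Plain bitcasts map directly, but 64-bit <-> uint2 conversions go through emitted
		// spvPackUint2x32 / spvUnpackUint2x32 helpers, which require a recompile pass to pull in.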
5650 auto bitcast_type = get_bitcast_type(result_type: ops[0], op0: ops[2]);
5651 if (bitcast_type == CompilerHLSL::TypeNormal)
5652 CompilerGLSL::emit_instruction(instr: instruction);
5653 else
5654 {
5655 if (!requires_uint2_packing)
5656 {
5657 requires_uint2_packing = true;
5658 force_recompile();
5659 }
5660
5661 if (bitcast_type == CompilerHLSL::TypePackUint2x32)
5662 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvPackUint2x32");
5663 else
5664 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvUnpackUint2x32");
5665 }
5666
5667 break;
5668 }
5669
5670 case OpSelect:
5671 {
5672 auto &value_type = expression_type(id: ops[3]);
5673 if (value_type.basetype == SPIRType::Struct || is_array(type: value_type))
5674 {
5675 // HLSL does not support ternary expressions on composites.
5676 // Cannot use branches, since we might be in a continue block
5677 // where explicit control flow is prohibited.
5678 // Emit a helper function where we can use control flow.
5679 TypeID value_type_id = expression_type_id(id: ops[3]);
5680 auto itr = std::find(first: composite_selection_workaround_types.begin(),
5681 last: composite_selection_workaround_types.end(),
5682 val: value_type_id);
5683 if (itr == composite_selection_workaround_types.end())
5684 {
5685 composite_selection_workaround_types.push_back(x: value_type_id);
5686 force_recompile();
5687 }
5688 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
5689 statement(ts: "spvSelectComposite(",
5690 ts: to_expression(id: ops[1]), ts: ", ", ts: to_expression(id: ops[2]), ts: ", ",
5691 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: ops[4]), ts: ");");
5692 }
5693 else
5694 CompilerGLSL::emit_instruction(instr: instruction);
5695 break;
5696 }
5697
5698 case OpStore:
5699 {
5700 emit_store(instruction);
5701 break;
5702 }
5703
5704 case OpLoad:
5705 {
5706 emit_load(instruction);
5707 break;
5708 }
5709
5710 case OpMatrixTimesVector:
5711 {
5712 // Matrices are kept in a transposed state all the time, flip multiplication order always.
5713 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
5714 break;
5715 }
5716
5717 case OpVectorTimesMatrix:
5718 {
5719 // Matrices are kept in a transposed state all the time, flip multiplication order always.
5720 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
5721 break;
5722 }
5723
5724 case OpMatrixTimesMatrix:
5725 {
5726 // Matrices are kept in a transposed state all the time, flip multiplication order always.
5727 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
5728 break;
5729 }
5730
5731 case OpOuterProduct:
5732 {
5733 uint32_t result_type = ops[0];
5734 uint32_t id = ops[1];
5735 uint32_t a = ops[2];
5736 uint32_t b = ops[3];
5737
5738 auto &type = get<SPIRType>(id: result_type);
5739 string expr = type_to_glsl_constructor(type);
5740 expr += "(";
5741 for (uint32_t col = 0; col < type.columns; col++)
5742 {
5743 expr += to_enclosed_expression(id: a);
5744 expr += " * ";
5745 expr += to_extract_component_expression(id: b, index: col);
5746 if (col + 1 < type.columns)
5747 expr += ", ";
5748 }
5749 expr += ")";
5750 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: a) && should_forward(id: b));
5751 inherit_expression_dependencies(dst: id, source: a);
5752 inherit_expression_dependencies(dst: id, source: b);
5753 break;
5754 }
5755
5756 case OpFMod:
5757 {
5758 if (!requires_op_fmod)
5759 {
5760 requires_op_fmod = true;
5761 force_recompile();
5762 }
5763 CompilerGLSL::emit_instruction(instr: instruction);
5764 break;
5765 }
5766
5767 case OpFRem:
5768 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op: "fmod");
5769 break;
5770
5771 case OpImage:
5772 {
5773 uint32_t result_type = ops[0];
5774 uint32_t id = ops[1];
5775 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: ops[2]);
5776
5777 if (combined)
5778 {
5779 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: combined->image), forward_rhs: true, suppress_usage_tracking: true);
5780 auto *var = maybe_get_backing_variable(chain: combined->image);
5781 if (var)
5782 e.loaded_from = var->self;
5783 }
5784 else
5785 {
5786 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: ops[2]), forward_rhs: true, suppress_usage_tracking: true);
5787 auto *var = maybe_get_backing_variable(chain: ops[2]);
5788 if (var)
5789 e.loaded_from = var->self;
5790 }
5791 break;
5792 }
5793
5794 case OpDPdx:
5795 HLSL_UFOP(ddx);
5796 register_control_dependent_expression(expr: ops[1]);
5797 break;
5798
5799 case OpDPdy:
5800 HLSL_UFOP(ddy);
5801 register_control_dependent_expression(expr: ops[1]);
5802 break;
5803
5804 case OpDPdxFine:
5805 HLSL_UFOP(ddx_fine);
5806 register_control_dependent_expression(expr: ops[1]);
5807 break;
5808
5809 case OpDPdyFine:
5810 HLSL_UFOP(ddy_fine);
5811 register_control_dependent_expression(expr: ops[1]);
5812 break;
5813
5814 case OpDPdxCoarse:
5815 HLSL_UFOP(ddx_coarse);
5816 register_control_dependent_expression(expr: ops[1]);
5817 break;
5818
5819 case OpDPdyCoarse:
5820 HLSL_UFOP(ddy_coarse);
5821 register_control_dependent_expression(expr: ops[1]);
5822 break;
5823
5824 case OpFwidth:
5825 case OpFwidthCoarse:
5826 case OpFwidthFine:
5827 HLSL_UFOP(fwidth);
5828 register_control_dependent_expression(expr: ops[1]);
5829 break;
5830
5831 case OpLogicalNot:
5832 {
5833 auto result_type = ops[0];
5834 auto id = ops[1];
5835 auto &type = get<SPIRType>(id: result_type);
5836
5837 if (type.vecsize > 1)
5838 emit_unrolled_unary_op(result_type, result_id: id, operand: ops[2], op: "!");
5839 else
5840 HLSL_UOP(!);
5841 break;
5842 }
5843
5844 case OpIEqual:
5845 {
5846 auto result_type = ops[0];
5847 auto id = ops[1];
5848
5849 if (expression_type(id: ops[2]).vecsize > 1)
5850 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5851 else
5852 HLSL_BOP_CAST(==, int_type);
5853 break;
5854 }
5855
5856 case OpLogicalEqual:
5857 case OpFOrdEqual:
5858 case OpFUnordEqual:
5859 {
5860 // HLSL != operator is unordered.
5861 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5862 // isnan() is apparently implemented as x != x as well.
5863 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
5864 // HACK: FUnordEqual will be implemented as FOrdEqual.
5865
5866 auto result_type = ops[0];
5867 auto id = ops[1];
5868
5869 if (expression_type(id: ops[2]).vecsize > 1)
5870 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5871 else
5872 HLSL_BOP(==);
5873 break;
5874 }
5875
5876 case OpINotEqual:
5877 {
5878 auto result_type = ops[0];
5879 auto id = ops[1];
5880
5881 if (expression_type(id: ops[2]).vecsize > 1)
5882 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5883 else
5884 HLSL_BOP_CAST(!=, int_type);
5885 break;
5886 }
5887
5888 case OpLogicalNotEqual:
5889 case OpFOrdNotEqual:
5890 case OpFUnordNotEqual:
5891 {
5892 // HLSL != operator is unordered.
5893 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5894 // isnan() is apparently implemented as x != x as well.
5895
5896 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5897 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
5898 // Adding a lot of NaN checks would be a breaking change from perspective of performance.
5899 // SPIR-V will generally use isnan() checks when this even matters.
5900 // HACK: FOrdNotEqual will be implemented as FUnordEqual.
5901
5902 auto result_type = ops[0];
5903 auto id = ops[1];
5904
5905 if (expression_type(id: ops[2]).vecsize > 1)
5906 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5907 else
5908 HLSL_BOP(!=);
5909 break;
5910 }
5911
5912 case OpUGreaterThan:
5913 case OpSGreaterThan:
5914 {
5915 auto result_type = ops[0];
5916 auto id = ops[1];
5917 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5918
5919 if (expression_type(id: ops[2]).vecsize > 1)
5920 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: type);
5921 else
5922 HLSL_BOP_CAST(>, type);
5923 break;
5924 }
5925
5926 case OpFOrdGreaterThan:
5927 {
5928 auto result_type = ops[0];
5929 auto id = ops[1];
5930
5931 if (expression_type(id: ops[2]).vecsize > 1)
5932 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: SPIRType::Unknown);
5933 else
5934 HLSL_BOP(>);
5935 break;
5936 }
5937
5938 case OpFUnordGreaterThan:
5939 {
5940 auto result_type = ops[0];
5941 auto id = ops[1];
5942
5943 if (expression_type(id: ops[2]).vecsize > 1)
5944 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: true, expected_type: SPIRType::Unknown);
5945 else
5946 CompilerGLSL::emit_instruction(instr: instruction);
5947 break;
5948 }
5949
5950 case OpUGreaterThanEqual:
5951 case OpSGreaterThanEqual:
5952 {
5953 auto result_type = ops[0];
5954 auto id = ops[1];
5955
5956 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5957 if (expression_type(id: ops[2]).vecsize > 1)
5958 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: type);
5959 else
5960 HLSL_BOP_CAST(>=, type);
5961 break;
5962 }
5963
5964 case OpFOrdGreaterThanEqual:
5965 {
5966 auto result_type = ops[0];
5967 auto id = ops[1];
5968
5969 if (expression_type(id: ops[2]).vecsize > 1)
5970 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: SPIRType::Unknown);
5971 else
5972 HLSL_BOP(>=);
5973 break;
5974 }
5975
5976 case OpFUnordGreaterThanEqual:
5977 {
5978 auto result_type = ops[0];
5979 auto id = ops[1];
5980
5981 if (expression_type(id: ops[2]).vecsize > 1)
5982 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: true, expected_type: SPIRType::Unknown);
5983 else
5984 CompilerGLSL::emit_instruction(instr: instruction);
5985 break;
5986 }
5987
5988 case OpULessThan:
5989 case OpSLessThan:
5990 {
5991 auto result_type = ops[0];
5992 auto id = ops[1];
5993
5994 auto type = opcode == OpULessThan ? uint_type : int_type;
5995 if (expression_type(id: ops[2]).vecsize > 1)
5996 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: type);
5997 else
5998 HLSL_BOP_CAST(<, type);
5999 break;
6000 }
6001
6002 case OpFOrdLessThan:
6003 {
6004 auto result_type = ops[0];
6005 auto id = ops[1];
6006
6007 if (expression_type(id: ops[2]).vecsize > 1)
6008 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: SPIRType::Unknown);
6009 else
6010 HLSL_BOP(<);
6011 break;
6012 }
6013
6014 case OpFUnordLessThan:
6015 {
6016 auto result_type = ops[0];
6017 auto id = ops[1];
6018
6019 if (expression_type(id: ops[2]).vecsize > 1)
6020 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: true, expected_type: SPIRType::Unknown);
6021 else
6022 CompilerGLSL::emit_instruction(instr: instruction);
6023 break;
6024 }
6025
6026 case OpULessThanEqual:
6027 case OpSLessThanEqual:
6028 {
6029 auto result_type = ops[0];
6030 auto id = ops[1];
6031
6032 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
6033 if (expression_type(id: ops[2]).vecsize > 1)
6034 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: type);
6035 else
6036 HLSL_BOP_CAST(<=, type);
6037 break;
6038 }
6039
6040 case OpFOrdLessThanEqual:
6041 {
6042 auto result_type = ops[0];
6043 auto id = ops[1];
6044
6045 if (expression_type(id: ops[2]).vecsize > 1)
6046 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: SPIRType::Unknown);
6047 else
6048 HLSL_BOP(<=);
6049 break;
6050 }
6051
6052 case OpFUnordLessThanEqual:
6053 {
6054 auto result_type = ops[0];
6055 auto id = ops[1];
6056
6057 if (expression_type(id: ops[2]).vecsize > 1)
6058 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: true, expected_type: SPIRType::Unknown);
6059 else
6060 CompilerGLSL::emit_instruction(instr: instruction);
6061 break;
6062 }
6063
6064 case OpImageQueryLod:
6065 emit_texture_op(i: instruction, sparse: false);
6066 break;
6067
6068 case OpImageQuerySizeLod:
6069 {
6070 auto result_type = ops[0];
6071 auto id = ops[1];
6072
6073 require_texture_query_variant(var_id: ops[2]);
6074 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
6075 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
6076
6077 auto expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ",
6078 ts: bitcast_expression(target_type: SPIRType::UInt, arg: ops[3]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
6079
6080 auto &restype = get<SPIRType>(id: ops[0]);
6081 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
6082 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
6083 break;
6084 }
6085
6086 case OpImageQuerySize:
6087 {
6088 auto result_type = ops[0];
6089 auto id = ops[1];
6090
6091 require_texture_query_variant(var_id: ops[2]);
6092 bool uav = expression_type(id: ops[2]).image.sampled == 2;
6093
6094 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
6095 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
6096 uav = false;
6097
6098 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
6099 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
6100
6101 string expr;
6102 if (uav)
6103 expr = join(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
6104 else
6105 expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts&: dummy_samples_levels, ts: ")");
6106
6107 auto &restype = get<SPIRType>(id: ops[0]);
6108 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
6109 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
6110 break;
6111 }
6112
6113 case OpImageQuerySamples:
6114 case OpImageQueryLevels:
6115 {
6116 auto result_type = ops[0];
6117 auto id = ops[1];
6118
6119 require_texture_query_variant(var_id: ops[2]);
6120 bool uav = expression_type(id: ops[2]).image.sampled == 2;
6121 if (opcode == OpImageQueryLevels && uav)
6122 SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
6123
6124 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
6125 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
6126 uav = false;
6127
6128 // Keep it simple and do not emit special variants to make this look nicer ...
6129 // This stuff is barely, if ever, used.
6130 forced_temporaries.insert(x: id);
6131 auto &type = get<SPIRType>(id: result_type);
6132 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
6133
6134 if (uav)
6135 statement(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts: to_name(id), ts: ");");
6136 else
6137 statement(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts: to_name(id), ts: ");");
6138
6139 auto &restype = get<SPIRType>(id: ops[0]);
6140 auto expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr: to_name(id));
6141 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
6142 break;
6143 }
6144
6145 case OpImageRead:
6146 {
6147 uint32_t result_type = ops[0];
6148 uint32_t id = ops[1];
6149 auto *var = maybe_get_backing_variable(chain: ops[2]);
6150 auto &type = expression_type(id: ops[2]);
6151 bool subpass_data = type.image.dim == DimSubpassData;
6152 bool pure = false;
6153
6154 string imgexpr;
6155
6156 if (subpass_data)
6157 {
6158 if (hlsl_options.shader_model < 40)
6159 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
6160
6161 // Similar to the GLSL texelFetch path, implement subpass loads with Load() at gl_FragCoord.
6162 if (type.image.ms)
6163 {
6164 uint32_t operands = ops[4];
6165 if (operands != ImageOperandsSampleMask || instruction.length != 6)
6166 SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
6167 uint32_t sample = ops[5];
6168 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int2(gl_FragCoord.xy), ", ts: to_expression(id: sample), ts: ")");
6169 }
6170 else
6171 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int3(int2(gl_FragCoord.xy), 0))");
6172
6173 pure = true;
6174 }
6175 else
6176 {
6177 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: "[", ts: to_expression(id: ops[3]), ts: "]");
6178 // The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4"
6179 // and the format only changes how the data is interpreted, so remap the swizzle to the format's component count.
6180
6181 bool force_srv =
6182 hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(id: var->self, decoration: DecorationNonWritable);
6183 pure = force_srv;
6184
6185 if (var && !subpass_data && !force_srv)
6186 imgexpr = remap_swizzle(result_type: get<SPIRType>(id: result_type),
6187 input_components: image_format_to_components(fmt: get<SPIRType>(id: var->basetype).image.format), expr: imgexpr);
6188 }
6189
6190 if (var)
6191 {
6192 bool forward = forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
6193 auto &e = emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: forward);
6194
6195 if (!pure)
6196 {
6197 e.loaded_from = var->self;
6198 if (forward)
6199 var->dependees.push_back(t: id);
6200 }
6201 }
6202 else
6203 emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: false);
6204
6205 inherit_expression_dependencies(dst: id, source: ops[2]);
6206 if (type.image.ms)
6207 inherit_expression_dependencies(dst: id, source: ops[5]);
6208 break;
6209 }
6210
6211 case OpImageWrite:
6212 {
6213 auto *var = maybe_get_backing_variable(chain: ops[0]);
6214
6215 // The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4"
6216 // and the format only changes how the data is interpreted, so narrow the stored value to the format's component count.
6217 auto value_expr = to_expression(id: ops[2]);
6218 if (var)
6219 {
6220 auto &type = get<SPIRType>(id: var->basetype);
6221 auto narrowed_type = get<SPIRType>(id: type.image.type);
6222 narrowed_type.vecsize = image_format_to_components(fmt: type.image.format);
6223 value_expr = remap_swizzle(result_type: narrowed_type, input_components: expression_type(id: ops[2]).vecsize, expr: value_expr);
6224 }
6225
6226 statement(ts: to_non_uniform_aware_expression(id: ops[0]), ts: "[", ts: to_expression(id: ops[1]), ts: "] = ", ts&: value_expr, ts: ";");
6227 if (var && variable_storage_is_aliased(var: *var))
6228 flush_all_aliased_variables();
6229 break;
6230 }
6231
6232 case OpImageTexelPointer:
6233 {
6234 uint32_t result_type = ops[0];
6235 uint32_t id = ops[1];
6236
6237 auto expr = to_expression(id: ops[2]);
6238 expr += join(ts: "[", ts: to_expression(id: ops[3]), ts: "]");
6239 auto &e = set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
6240
6241 // When using the pointer, we need to know which variable it is actually loaded from.
6242 auto *var = maybe_get_backing_variable(chain: ops[2]);
6243 e.loaded_from = var ? var->self : ID(0);
6244 inherit_expression_dependencies(dst: id, source: ops[3]);
6245 break;
6246 }
6247
6248 case OpAtomicFAddEXT:
6249 case OpAtomicFMinEXT:
6250 case OpAtomicFMaxEXT:
6251 SPIRV_CROSS_THROW("Floating-point atomics are not supported in HLSL.");
6252
6253 case OpAtomicCompareExchange:
6254 case OpAtomicExchange:
6255 case OpAtomicISub:
6256 case OpAtomicSMin:
6257 case OpAtomicUMin:
6258 case OpAtomicSMax:
6259 case OpAtomicUMax:
6260 case OpAtomicAnd:
6261 case OpAtomicOr:
6262 case OpAtomicXor:
6263 case OpAtomicIAdd:
6264 case OpAtomicIIncrement:
6265 case OpAtomicIDecrement:
6266 case OpAtomicLoad:
6267 case OpAtomicStore:
6268 {
6269 emit_atomic(ops, length: instruction.length, op: opcode);
6270 break;
6271 }
6272
6273 case OpControlBarrier:
6274 case OpMemoryBarrier:
6275 {
6276 uint32_t memory;
6277 uint32_t semantics;
6278
6279 if (opcode == OpMemoryBarrier)
6280 {
6281 memory = evaluate_constant_u32(id: ops[0]);
6282 semantics = evaluate_constant_u32(id: ops[1]);
6283 }
6284 else
6285 {
6286 memory = evaluate_constant_u32(id: ops[1]);
6287 semantics = evaluate_constant_u32(id: ops[2]);
6288 }
6289
6290 if (memory == ScopeSubgroup)
6291 {
6292 // No Wave-barriers in HLSL.
6293 break;
6294 }
6295
6296 // We only care about these flags; acquire/release and friends are not relevant here.
6297 semantics = mask_relevant_memory_semantics(semantics);
6298
6299 if (opcode == OpMemoryBarrier)
6300 {
6301 // If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier
6302 // does what we need, so we avoid redundant barriers.
6303 const Instruction *next = get_next_instruction_in_block(instr: instruction);
6304 if (next && next->op == OpControlBarrier)
6305 {
6306 auto *next_ops = stream(instr: *next);
6307 uint32_t next_memory = evaluate_constant_u32(id: next_ops[1]);
6308 uint32_t next_semantics = evaluate_constant_u32(id: next_ops[2]);
6309 next_semantics = mask_relevant_memory_semantics(semantics: next_semantics);
6310
6311 // There is no "just execution barrier" in HLSL.
6312 // If the next instruction carries no memory semantics, treat it as if it synchronizes group shared memory.
6313 if (next_semantics == 0)
6314 next_semantics = MemorySemanticsWorkgroupMemoryMask;
6315
6316 bool memory_scope_covered = false;
6317 if (next_memory == memory)
6318 memory_scope_covered = true;
6319 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
6320 {
6321 // If we only care about workgroup memory, either Device or Workgroup scope is fine,
6322 // scope does not have to match.
6323 if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
6324 (memory == ScopeDevice || memory == ScopeWorkgroup))
6325 {
6326 memory_scope_covered = true;
6327 }
6328 }
6329 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
6330 {
6331 // The control barrier has device scope, but the memory barrier just has workgroup scope.
6332 memory_scope_covered = true;
6333 }
6334
6335 // If we have the same memory scope, and all memory types are covered, we're good.
6336 if (memory_scope_covered && (semantics & next_semantics) == semantics)
6337 break;
6338 }
6339 }
6340
6341 // We are synchronizing some memory or syncing execution,
6342 // so we cannot forward any loads beyond the memory barrier.
6343 if (semantics || opcode == OpControlBarrier)
6344 {
6345 assert(current_emitting_block);
6346 flush_control_dependent_expressions(block: current_emitting_block->self);
6347 flush_all_active_variables();
6348 }
6349
6350 if (opcode == OpControlBarrier)
6351 {
6352 // We cannot emit a pure execution barrier; when there are no memory semantics, pick the cheapest option.
6353 if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
6354 statement(ts: "GroupMemoryBarrierWithGroupSync();");
6355 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
6356 statement(ts: "DeviceMemoryBarrierWithGroupSync();");
6357 else
6358 statement(ts: "AllMemoryBarrierWithGroupSync();");
6359 }
6360 else
6361 {
6362 if (semantics == MemorySemanticsWorkgroupMemoryMask)
6363 statement(ts: "GroupMemoryBarrier();");
6364 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
6365 statement(ts: "DeviceMemoryBarrier();");
6366 else
6367 statement(ts: "AllMemoryBarrier();");
6368 }
6369 break;
6370 }
6371
6372 case OpBitFieldInsert:
6373 {
6374 if (!requires_bitfield_insert)
6375 {
6376 requires_bitfield_insert = true;
6377 force_recompile();
6378 }
6379
6380 auto expr = join(ts: "spvBitfieldInsert(", ts: to_expression(id: ops[2]), ts: ", ", ts: to_expression(id: ops[3]), ts: ", ",
6381 ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ")");
6382
6383 bool forward =
6384 should_forward(id: ops[2]) && should_forward(id: ops[3]) && should_forward(id: ops[4]) && should_forward(id: ops[5]);
6385
6386 auto &restype = get<SPIRType>(id: ops[0]);
6387 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
6388 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
6389 break;
6390 }
6391
6392 case OpBitFieldSExtract:
6393 case OpBitFieldUExtract:
6394 {
6395 if (!requires_bitfield_extract)
6396 {
6397 requires_bitfield_extract = true;
6398 force_recompile();
6399 }
6400
6401 if (opcode == OpBitFieldSExtract)
6402 HLSL_TFOP(spvBitfieldSExtract);
6403 else
6404 HLSL_TFOP(spvBitfieldUExtract);
6405 break;
6406 }
6407
6408 case OpBitCount:
6409 {
6410 auto basetype = expression_type(id: ops[2]).basetype;
6411 emit_unary_func_op_cast(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "countbits", input_type: basetype, expected_result_type: basetype);
6412 break;
6413 }
6414
6415 case OpBitReverse:
6416 HLSL_UFOP(reversebits);
6417 break;
6418
6419 case OpArrayLength:
6420 {
6421 auto *var = maybe_get_backing_variable(chain: ops[2]);
6422 if (!var)
6423 SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
6424
6425 auto &type = get<SPIRType>(id: var->basetype);
6426 if (!has_decoration(id: type.self, decoration: DecorationBlock) && !has_decoration(id: type.self, decoration: DecorationBufferBlock))
6427 SPIRV_CROSS_THROW("Array length expression must point to a block type.");
6428
6429 // This must be 32-bit uint, so we're good to go.
6430 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
6431 statement(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".GetDimensions(", ts: to_expression(id: ops[1]), ts: ");");
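		// GetDimensions() returns the total buffer size in bytes; subtract the member's byte offset and
		// divide by the array stride to recover the runtime array's element count.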
6432 uint32_t offset = type_struct_member_offset(type, index: ops[3]);
6433 uint32_t stride = type_struct_member_array_stride(type, index: ops[3]);
6434 statement(ts: to_expression(id: ops[1]), ts: " = (", ts: to_expression(id: ops[1]), ts: " - ", ts&: offset, ts: ") / ", ts&: stride, ts: ";");
6435 break;
6436 }
6437
6438 case OpIsHelperInvocationEXT:
6439 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
6440 SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
6441 // Helper lane state with demote is volatile by nature.
6442 // Do not forward this.
6443 emit_op(result_type: ops[0], result_id: ops[1], rhs: "IsHelperLane()", forward_rhs: false);
6444 break;
6445
6446 case OpBeginInvocationInterlockEXT:
6447 case OpEndInvocationInterlockEXT:
6448 if (hlsl_options.shader_model < 51)
6449 SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
6450 break; // Nothing to do in the body
6451
6452 case OpRayQueryInitializeKHR:
6453 {
6454 flush_variable_declaration(id: ops[0]);
6455
6456 std::string ray_desc_name = get_unique_identifier();
6457 statement(ts: "RayDesc ", ts&: ray_desc_name, ts: " = {", ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ", ",
6458 ts: to_expression(id: ops[6]), ts: ", ", ts: to_expression(id: ops[7]), ts: "};");
6459
6460 statement(ts: to_expression(id: ops[0]), ts: ".TraceRayInline(",
6461 ts: to_expression(id: ops[1]), ts: ", ", // acc structure
6462 ts: to_expression(id: ops[2]), ts: ", ", // ray flags
6463 ts: to_expression(id: ops[3]), ts: ", ", // mask
6464 ts&: ray_desc_name, ts: ");"); // ray
6465 break;
6466 }
6467 case OpRayQueryProceedKHR:
6468 {
6469 flush_variable_declaration(id: ops[0]);
6470 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".Proceed()"), forward_rhs: false);
6471 break;
6472 }
6473 case OpRayQueryTerminateKHR:
6474 {
6475 flush_variable_declaration(id: ops[0]);
6476 statement(ts: to_expression(id: ops[0]), ts: ".Abort();");
6477 break;
6478 }
6479 case OpRayQueryGenerateIntersectionKHR:
6480 {
6481 flush_variable_declaration(id: ops[0]);
6482 statement(ts: to_expression(id: ops[0]), ts: ".CommitProceduralPrimitiveHit(", ts: to_expression(id: ops[1]), ts: ");");
6483 break;
6484 }
6485 case OpRayQueryConfirmIntersectionKHR:
6486 {
6487 flush_variable_declaration(id: ops[0]);
6488 statement(ts: to_expression(id: ops[0]), ts: ".CommitNonOpaqueTriangleHit();");
6489 break;
6490 }
6491 case OpRayQueryGetIntersectionTypeKHR:
6492 {
6493 emit_rayquery_function(commited: ".CommittedStatus()", candidate: ".CandidateType()", ops);
6494 break;
6495 }
6496 case OpRayQueryGetIntersectionTKHR:
6497 {
6498 emit_rayquery_function(commited: ".CommittedRayT()", candidate: ".CandidateTriangleRayT()", ops);
6499 break;
6500 }
6501 case OpRayQueryGetIntersectionInstanceCustomIndexKHR:
6502 {
6503 emit_rayquery_function(commited: ".CommittedInstanceID()", candidate: ".CandidateInstanceID()", ops);
6504 break;
6505 }
6506 case OpRayQueryGetIntersectionInstanceIdKHR:
6507 {
6508 emit_rayquery_function(commited: ".CommittedInstanceIndex()", candidate: ".CandidateInstanceIndex()", ops);
6509 break;
6510 }
6511 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
6512 {
6513 emit_rayquery_function(commited: ".CommittedInstanceContributionToHitGroupIndex()",
6514 candidate: ".CandidateInstanceContributionToHitGroupIndex()", ops);
6515 break;
6516 }
6517 case OpRayQueryGetIntersectionGeometryIndexKHR:
6518 {
6519 emit_rayquery_function(commited: ".CommittedGeometryIndex()",
6520 candidate: ".CandidateGeometryIndex()", ops);
6521 break;
6522 }
6523 case OpRayQueryGetIntersectionPrimitiveIndexKHR:
6524 {
6525 emit_rayquery_function(commited: ".CommittedPrimitiveIndex()", candidate: ".CandidatePrimitiveIndex()", ops);
6526 break;
6527 }
6528 case OpRayQueryGetIntersectionBarycentricsKHR:
6529 {
6530 emit_rayquery_function(commited: ".CommittedTriangleBarycentrics()", candidate: ".CandidateTriangleBarycentrics()", ops);
6531 break;
6532 }
6533 case OpRayQueryGetIntersectionFrontFaceKHR:
6534 {
6535 emit_rayquery_function(commited: ".CommittedTriangleFrontFace()", candidate: ".CandidateTriangleFrontFace()", ops);
6536 break;
6537 }
6538 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
6539 {
6540 flush_variable_declaration(id: ops[0]);
6541 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".CandidateProceduralPrimitiveNonOpaque()"), forward_rhs: false);
6542 break;
6543 }
6544 case OpRayQueryGetIntersectionObjectRayDirectionKHR:
6545 {
6546 emit_rayquery_function(commited: ".CommittedObjectRayDirection()", candidate: ".CandidateObjectRayDirection()", ops);
6547 break;
6548 }
6549 case OpRayQueryGetIntersectionObjectRayOriginKHR:
6550 {
6551 flush_variable_declaration(id: ops[0]);
6552 emit_rayquery_function(commited: ".CommittedObjectRayOrigin()", candidate: ".CandidateObjectRayOrigin()", ops);
6553 break;
6554 }
6555 case OpRayQueryGetIntersectionObjectToWorldKHR:
6556 {
6557 emit_rayquery_function(commited: ".CommittedObjectToWorld4x3()", candidate: ".CandidateObjectToWorld4x3()", ops);
6558 break;
6559 }
6560 case OpRayQueryGetIntersectionWorldToObjectKHR:
6561 {
6562 emit_rayquery_function(commited: ".CommittedWorldToObject4x3()", candidate: ".CandidateWorldToObject4x3()", ops);
6563 break;
6564 }
6565 case OpRayQueryGetRayFlagsKHR:
6566 {
6567 flush_variable_declaration(id: ops[0]);
6568 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".RayFlags()"), forward_rhs: false);
6569 break;
6570 }
6571 case OpRayQueryGetRayTMinKHR:
6572 {
6573 flush_variable_declaration(id: ops[0]);
6574 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".RayTMin()"), forward_rhs: false);
6575 break;
6576 }
6577 case OpRayQueryGetWorldRayOriginKHR:
6578 {
6579 flush_variable_declaration(id: ops[0]);
6580 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".WorldRayOrigin()"), forward_rhs: false);
6581 break;
6582 }
6583 case OpRayQueryGetWorldRayDirectionKHR:
6584 {
6585 flush_variable_declaration(id: ops[0]);
6586 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".WorldRayDirection()"), forward_rhs: false);
6587 break;
6588 }
6589 case OpSetMeshOutputsEXT:
6590 {
6591 statement(ts: "SetMeshOutputCounts(", ts: to_unpacked_expression(id: ops[0]), ts: ", ", ts: to_unpacked_expression(id: ops[1]), ts: ");");
6592 break;
6593 }
6594 default:
6595 CompilerGLSL::emit_instruction(instr: instruction);
6596 break;
6597 }
6598}
6599
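// Records which texture size query helper variant is needed for the image
// behind var_id. The variant is keyed on dimensionality, multisampling,
// arrayness and sampled type; UAV images are additionally bucketed by
// normalization state (unorm/snorm/none) and component count, since the
// declared element type of the helper differs. Requiring a new variant forces
// another compile pass so the helper can be emitted.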
void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
{
	if (const auto *var = maybe_get_backing_variable(var_id))
		var_id = var->self;

	auto &type = expression_type(var_id);
	bool uav = type.image.sampled == 2;
	if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
		uav = false;

	uint32_t bit = 0;
	switch (type.image.dim)
	{
	case Dim1D:
		bit = type.image.arrayed ? Query1DArray : Query1D;
		break;

	case Dim2D:
		if (type.image.ms)
			bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
		else
			bit = type.image.arrayed ? Query2DArray : Query2D;
		break;

	case Dim3D:
		bit = Query3D;
		break;

	case DimCube:
		bit = type.image.arrayed ? QueryCubeArray : QueryCube;
		break;

	case DimBuffer:
		bit = QueryBuffer;
		break;

	default:
		SPIRV_CROSS_THROW("Unsupported query type.");
	}

	switch (get<SPIRType>(type.image.type).basetype)
	{
	case SPIRType::Float:
		bit += QueryTypeFloat;
		break;

	case SPIRType::Int:
		bit += QueryTypeInt;
		break;

	case SPIRType::UInt:
		bit += QueryTypeUInt;
		break;

	default:
		SPIRV_CROSS_THROW("Unsupported query type.");
	}

	auto norm_state = image_format_to_normalized_state(type.image.format);
	auto &variant = uav ? required_texture_size_variants
	                          .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
	                      required_texture_size_variants.srv;

	uint64_t mask = 1ull << bit;
	if ((variant & mask) == 0)
	{
		force_recompile();
		variant |= mask;
	}
}

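// Public interface: the provided layout describes which ranges of the push
// constant block should be emitted as root constant cbuffers; it is consumed
// later when resources are emitted.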
void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
{
	root_constants_layout = std::move(layout);
}

void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
{
	remap_vertex_attributes.push_back(vertex_attributes);
}

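// HLSL has no equivalent of gl_NumWorkGroups, so if the builtin is actually
// used we synthesize a uniform block with a single uint3 "count" member and
// return its variable ID (0 if the builtin is unused). The application is then
// expected to bind a constant buffer holding the dispatch dimensions for it.
//
// Illustrative caller-side sketch (variable names are assumptions, not part of
// this file):
//
//   CompilerHLSL compiler(std::move(spirv_words));
//   VariableID num_workgroups_id = compiler.remap_num_workgroups_builtin();
//   if (num_workgroups_id != 0)
//   {
//       // Optionally place it via set_decoration(num_workgroups_id, ...),
//       // then bind a cbuffer with the uint3 group count at dispatch time.
//   }
//   std::string hlsl = compiler.compile();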
VariableID CompilerHLSL::remap_num_workgroups_builtin()
{
	update_active_builtins();

	if (!active_input_builtins.get(BuiltInNumWorkgroups))
		return 0;

	// Create a new, fake UBO.
	uint32_t offset = ir.increase_bound_by(4);

	uint32_t uint_type_id = offset;
	uint32_t block_type_id = offset + 1;
	uint32_t block_pointer_type_id = offset + 2;
	uint32_t variable_id = offset + 3;

	SPIRType uint_type { OpTypeVector };
	uint_type.basetype = SPIRType::UInt;
	uint_type.width = 32;
	uint_type.vecsize = 3;
	uint_type.columns = 1;
	set<SPIRType>(uint_type_id, uint_type);

	SPIRType block_type { OpTypeStruct };
	block_type.basetype = SPIRType::Struct;
	block_type.member_types.push_back(uint_type_id);
	set<SPIRType>(block_type_id, block_type);
	set_decoration(block_type_id, DecorationBlock);
	set_member_name(block_type_id, 0, "count");
	set_member_decoration(block_type_id, 0, DecorationOffset, 0);

	SPIRType block_pointer_type = block_type;
	block_pointer_type.pointer = true;
	block_pointer_type.storage = StorageClassUniform;
	block_pointer_type.parent_type = block_type_id;
	auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);

	// Preserve self.
	ptr_type.self = block_type_id;

	set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
	ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";

	num_workgroups_builtin = variable_id;
	get_entry_point().interface_variables.push_back(num_workgroups_builtin);
	return variable_id;
}

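// Controls which resource classes get register() bindings assigned
// automatically (see HLSLBindingFlagBits in spirv_hlsl.hpp).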
void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
{
	resource_binding_flags = flags;
}

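// Rejects SPIR-V capabilities and addressing modes that cannot be expressed
// in the requested shader model before any code is generated.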
void CompilerHLSL::validate_shader_model()
{
	// Check for nonuniform qualifier.
	// Instead of looping over all decorations to find this, just look at capabilities.
	for (auto &cap : ir.declared_capabilities)
	{
		switch (cap)
		{
		case CapabilityShaderNonUniformEXT:
		case CapabilityRuntimeDescriptorArrayEXT:
			if (hlsl_options.shader_model < 51)
				SPIRV_CROSS_THROW(
				    "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
			break;

		case CapabilityVariablePointers:
		case CapabilityVariablePointersStorageBuffer:
			SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");

		default:
			break;
		}
	}

	if (ir.addressing_model != AddressingModelLogical)
		SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");

	if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
		SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
}

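// Main entry point. The HLSL backend reuses the GLSL code paths, so compile()
// first configures the base class for Vulkan-style GLSL with HLSL-appropriate
// backend switches, runs the shared analysis passes, and then emits the module
// repeatedly until no pass forces a recompile (e.g. because a new helper
// function or texture size query variant became required).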
string CompilerHLSL::compile()
{
	ir.fixup_reserved_names();

	// Do not deal with ES-isms like precision, older extensions and such.
	options.es = false;
	options.version = 450;
	options.vulkan_semantics = true;
	backend.float_literal_suffix = true;
	backend.double_literal_suffix = false;
	backend.long_long_literal_suffix = true;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "u";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.demote_literal = "discard";
	backend.boolean_mix_function = "";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = true;
	backend.unsized_array_supported = true;
	backend.explicit_struct_type = false;
	backend.use_initializer_list = true;
	backend.use_constructor_splatting = false;
	backend.can_swizzle_scalar = true;
	backend.can_declare_struct_inline = false;
	backend.can_declare_arrays_inline = false;
	backend.can_return_array = false;
	backend.nonuniform_qualifier = "NonUniformResourceIndex";
	backend.support_case_fallthrough = false;
	backend.force_merged_mesh_block = get_execution_model() == ExecutionModelMeshEXT;
	backend.force_gl_in_out_block = backend.force_merged_mesh_block;
	backend.supports_empty_struct = hlsl_options.shader_model <= 30;

	// SM 4.1 does not support precise for some reason.
	backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;

	fixup_anonymous_struct_names();
	fixup_type_alias();
	reorder_type_alias();
	build_function_control_flow_graphs_and_analyze();
	validate_shader_model();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_interlocked_resource_usage();
	if (get_execution_model() == ExecutionModelMeshEXT)
		analyze_meshlet_writes();

	// Subpass input needs SV_Position.
	if (need_subpass_input)
		active_input_builtins.set(BuiltInFragCoord);

	// Need to offset by BaseVertex/BaseInstance in SM 6.8+.
	if (hlsl_options.shader_model >= 68)
	{
		if (active_input_builtins.get(BuiltInVertexIndex))
			active_input_builtins.set(BuiltInBaseVertex);
		if (active_input_builtins.get(BuiltInInstanceIndex))
			active_input_builtins.set(BuiltInBaseInstance);
	}

	uint32_t pass_count = 0;
	do
	{
		reset(pass_count);

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_resources();

		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
		emit_hlsl_entry_point();

		pass_count++;
	} while (is_forcing_recompilation());

	// Entry point in HLSL is always main() for the time being.
	get_entry_point().name = "main";

	return buffer.str();
}

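// Maps SPIR-V selection/loop control hints to the corresponding HLSL
// attributes: [flatten], [branch], [unroll] and [loop].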
void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
{
	switch (block.hint)
	{
	case SPIRBlock::HintFlatten:
		statement("[flatten]");
		break;
	case SPIRBlock::HintDontFlatten:
		statement("[branch]");
		break;
	case SPIRBlock::HintUnroll:
		statement("[unroll]");
		break;
	case SPIRBlock::HintDontUnroll:
		statement("[loop]");
		break;
	default:
		break;
	}
}

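// Returns a fresh identifier of the form "_<N>ident" for compiler-generated
// temporaries, e.g. the RayDesc built for TraceRayInline above.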
string CompilerHLSL::get_unique_identifier()
{
	return join("_", unique_identifier_count++, "ident");
}

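// Registers an explicit (stage, set, binding) -> register/space remap.
// is_hlsl_resource_binding_used() below reports whether a registered remap was
// actually consumed during compilation (the bool stored alongside the binding).
//
// Illustrative sketch of filling in a remap (the chosen values are
// assumptions, not defaults of this API):
//
//   HLSLResourceBinding binding = {};
//   binding.stage = spv::ExecutionModelFragment;
//   binding.desc_set = 0;
//   binding.binding = 1;
//   binding.srv.register_binding = 3; // -> register(t3)
//   binding.srv.register_space = 0;   // -> space0
//   compiler.add_hlsl_resource_binding(binding);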
void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };
}

bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

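// HLSL cannot reinterpret between 64-bit and 32-bit scalars with a plain cast,
// so bitcasts between uint64_t and uint2 are dispatched to dedicated
// pack/unpack handling (TypePackUint2x32 / TypeUnpackUint64); everything else
// takes the normal bitcast path.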
CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
{
	auto &rslt_type = get<SPIRType>(result_type);
	auto &expr_type = expression_type(op0);

	if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
	    expr_type.vecsize == 2)
		return BitcastType::TypePackUint2x32;
	else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
	         expr_type.basetype == SPIRType::BaseType::UInt64)
		return BitcastType::TypeUnpackUint64;

	return BitcastType::TypeNormal;
}

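// Storage buffers can be forced to be emitted as UAVs (RWByteAddressBuffer)
// either globally via hlsl_options.force_storage_buffer_as_uav or per
// (descriptor set, binding) pair registered through
// set_hlsl_force_storage_buffer_as_uav().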
bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
{
	if (hlsl_options.force_storage_buffer_as_uav)
	{
		return true;
	}

	const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
	const uint32_t binding = get_decoration(id, spv::DecorationBinding);

	return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
}

void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	force_uav_buffer_bindings.insert(pair);
}

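// With preserve_structured_buffers enabled, buffers whose UserTypeGOOGLE
// decoration starts with "structuredbuffer", "rwstructuredbuffer" or
// "rasterizerorderedstructuredbuffer" keep their StructuredBuffer declaration
// instead of being lowered to a ByteAddressBuffer.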
bool CompilerHLSL::is_user_type_structured(uint32_t id) const
{
	if (hlsl_options.preserve_structured_buffers)
	{
		// Compare left hand side of string only as these user types can contain more meta data such as their subtypes,
		// e.g. "structuredbuffer:int"
		const std::string &user_type = get_decoration_string(id, DecorationUserTypeGOOGLE);
		return user_type.compare(0, 16, "structuredbuffer") == 0 ||
		       user_type.compare(0, 18, "rwstructuredbuffer") == 0 ||
		       user_type.compare(0, 33, "rasterizerorderedstructuredbuffer") == 0;
	}
	return false;
}

void CompilerHLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
	// Loading a full array of ClipDistance needs special consideration in mesh shaders
	// since we cannot lower them by wrapping the variables in global statics.
	// Fortunately, clip/cull is a proper vector in HLSL so we can lower with simple rvalue casts.
	if (get_execution_model() != ExecutionModelMeshEXT ||
	    !has_decoration(target_id, DecorationBuiltIn) ||
	    !is_array(expr_type))
	{
		CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
		return;
	}

	auto builtin = BuiltIn(get_decoration(target_id, DecorationBuiltIn));
	if (builtin != BuiltInClipDistance && builtin != BuiltInCullDistance)
	{
		CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type);
		return;
	}

	// Array of array means one thread is storing clip distance for all vertices. Nonsensical?
	if (is_array(get<SPIRType>(expr_type.parent_type)))
		SPIRV_CROSS_THROW("Attempting to store all mesh vertices in one go. This is not supported.");

	uint32_t num_clip = to_array_size_literal(expr_type);
	if (num_clip > 4)
		SPIRV_CROSS_THROW("Number of clip or cull distances exceeds 4, this will not work with mesh shaders.");

	auto unrolled_expr = join("float", num_clip, "(");
	for (uint32_t i = 0; i < num_clip; i++)
	{
		unrolled_expr += join(expr, "[", i, "]");
		if (i + 1 < num_clip)
			unrolled_expr += ", ";
	}

	unrolled_expr += ")";
	expr = std::move(unrolled_expr);
}
