/*
 * Copyright 2016-2021 Robert Konrad
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 * At your option, you may choose to accept this material under either:
 * 1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 * 2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_hlsl.hpp"
#include "GLSL.std.450.h"
#include <algorithm>
#include <assert.h>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

enum class ImageFormatNormalizedState
{
	None = 0,
	Unorm = 1,
	Snorm = 2
};

static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
		return ImageFormatNormalizedState::Unorm;

	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
		return ImageFormatNormalizedState::Snorm;

	default:
		break;
	}

	return ImageFormatNormalizedState::None;
}

static unsigned image_format_to_components(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatR16f:
	case ImageFormatR32f:
	case ImageFormatR8i:
	case ImageFormatR16i:
	case ImageFormatR32i:
	case ImageFormatR8ui:
	case ImageFormatR16ui:
	case ImageFormatR32ui:
		return 1;

	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRg16f:
	case ImageFormatRg32f:
	case ImageFormatRg8i:
	case ImageFormatRg16i:
	case ImageFormatRg32i:
	case ImageFormatRg8ui:
	case ImageFormatRg16ui:
	case ImageFormatRg32ui:
		return 2;

	case ImageFormatR11fG11fB10f:
		return 3;

	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
	case ImageFormatRgba16f:
	case ImageFormatRgba32f:
	case ImageFormatRgba8i:
	case ImageFormatRgba16i:
	case ImageFormatRgba32i:
	case ImageFormatRgba8ui:
	case ImageFormatRgba16ui:
	case ImageFormatRgba32ui:
	case ImageFormatRgb10a2ui:
		return 4;

	case ImageFormatUnknown:
		return 4; // Assume 4.

	default:
		SPIRV_CROSS_THROW("Unrecognized typed image format.");
	}
}
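// A quick illustration of the mapping above (derived from the switch, not exhaustive):
// Rg16f -> 2 components, R11fG11fB10f -> 3, Rgba32ui -> 4, and Unknown conservatively
// assumes 4 components.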
128static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
129{
130 switch (fmt)
131 {
132 case ImageFormatR8:
133 case ImageFormatR16:
134 if (basetype != SPIRType::Float)
135 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
136 return "unorm float";
137 case ImageFormatRg8:
138 case ImageFormatRg16:
139 if (basetype != SPIRType::Float)
140 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
141 return "unorm float2";
142 case ImageFormatRgba8:
143 case ImageFormatRgba16:
144 if (basetype != SPIRType::Float)
145 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
146 return "unorm float4";
147 case ImageFormatRgb10A2:
148 if (basetype != SPIRType::Float)
149 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
150 return "unorm float4";
151
152 case ImageFormatR8Snorm:
153 case ImageFormatR16Snorm:
154 if (basetype != SPIRType::Float)
155 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
156 return "snorm float";
157 case ImageFormatRg8Snorm:
158 case ImageFormatRg16Snorm:
159 if (basetype != SPIRType::Float)
160 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
161 return "snorm float2";
162 case ImageFormatRgba8Snorm:
163 case ImageFormatRgba16Snorm:
164 if (basetype != SPIRType::Float)
165 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
166 return "snorm float4";
167
168 case ImageFormatR16f:
169 case ImageFormatR32f:
170 if (basetype != SPIRType::Float)
171 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
172 return "float";
173 case ImageFormatRg16f:
174 case ImageFormatRg32f:
175 if (basetype != SPIRType::Float)
176 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
177 return "float2";
178 case ImageFormatRgba16f:
179 case ImageFormatRgba32f:
180 if (basetype != SPIRType::Float)
181 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
182 return "float4";
183
184 case ImageFormatR11fG11fB10f:
185 if (basetype != SPIRType::Float)
186 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
187 return "float3";
188
189 case ImageFormatR8i:
190 case ImageFormatR16i:
191 case ImageFormatR32i:
192 if (basetype != SPIRType::Int)
193 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
194 return "int";
195 case ImageFormatRg8i:
196 case ImageFormatRg16i:
197 case ImageFormatRg32i:
198 if (basetype != SPIRType::Int)
199 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
200 return "int2";
201 case ImageFormatRgba8i:
202 case ImageFormatRgba16i:
203 case ImageFormatRgba32i:
204 if (basetype != SPIRType::Int)
205 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
206 return "int4";
207
208 case ImageFormatR8ui:
209 case ImageFormatR16ui:
210 case ImageFormatR32ui:
211 if (basetype != SPIRType::UInt)
212 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
213 return "uint";
214 case ImageFormatRg8ui:
215 case ImageFormatRg16ui:
216 case ImageFormatRg32ui:
217 if (basetype != SPIRType::UInt)
218 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
219 return "uint2";
220 case ImageFormatRgba8ui:
221 case ImageFormatRgba16ui:
222 case ImageFormatRgba32ui:
223 if (basetype != SPIRType::UInt)
224 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
225 return "uint4";
226 case ImageFormatRgb10a2ui:
227 if (basetype != SPIRType::UInt)
228 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
229 return "uint4";
230
231 case ImageFormatUnknown:
232 switch (basetype)
233 {
234 case SPIRType::Float:
235 return "float4";
236 case SPIRType::Int:
237 return "int4";
238 case SPIRType::UInt:
239 return "uint4";
240 default:
241 SPIRV_CROSS_THROW("Unsupported base type for image.");
242 }
243
244 default:
245 SPIRV_CROSS_THROW("Unrecognized typed image format.");
246 }
247}
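// A few illustrative results of the mapping above: Rgba8 with a Float base type yields
// "unorm float4", R16Snorm yields "snorm float", R32ui with UInt yields "uint", and an
// Unknown format falls back to a 4-component vector of the base type.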
248
string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	const char *dim = nullptr;
	bool typed_load = false;
	uint32_t components = 4;

	bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);

	switch (type.image.dim)
	{
	case Dim1D:
		typed_load = type.image.sampled == 2;
		dim = "1D";
		break;
	case Dim2D:
		typed_load = type.image.sampled == 2;
		dim = "2D";
		break;
	case Dim3D:
		typed_load = type.image.sampled == 2;
		dim = "3D";
		break;
	case DimCube:
		if (type.image.sampled == 2)
			SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
		dim = "Cube";
		break;
	case DimRect:
		SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
	case DimBuffer:
		if (type.image.sampled == 1)
			return join("Buffer<", type_to_glsl(imagetype), components, ">");
		else if (type.image.sampled == 2)
		{
			if (interlocked_resources.count(id))
				return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
				            ">");

			typed_load = !force_image_srv && type.image.sampled == 2;

			const char *rw = force_image_srv ? "" : "RW";
			return join(rw, "Buffer<",
			            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
			                         join(type_to_glsl(imagetype), components),
			            ">");
		}
		else
			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce in runtime.");
	case DimSubpassData:
		dim = "2D";
		typed_load = false;
		break;
	default:
		SPIRV_CROSS_THROW("Invalid dimension.");
	}
	const char *arrayed = type.image.arrayed ? "Array" : "";
	const char *ms = type.image.ms ? "MS" : "";
	const char *rw = typed_load && !force_image_srv ? "RW" : "";

	if (force_image_srv)
		typed_load = false;

	if (typed_load && interlocked_resources.count(id))
		rw = "RasterizerOrdered";

	return join(rw, "Texture", dim, ms, arrayed, "<",
	            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
	                         join(type_to_glsl(imagetype), components),
	            ">");
}
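// Rough examples of the strings produced above (illustrative, not exhaustive):
//  - a sampled 2D float image                          -> "Texture2D<float4>"
//  - a writable 2D storage image with Rgba8 format     -> "RWTexture2D<unorm float4>"
//  - the same image decorated NonWritable, with the
//    nonwritable_uav_texture_as_srv option enabled     -> "Texture2D<float4>"
//  - an interlocked (rasterizer-ordered) storage buffer -> "RasterizerOrderedBuffer<...>"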
string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	string res;

	switch (imagetype.basetype)
	{
	case SPIRType::Int:
		res = "i";
		break;
	case SPIRType::UInt:
		res = "u";
		break;
	default:
		break;
	}

	if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
		return res + "subpassInput" + (type.image.ms ? "MS" : "");

	// If we're emulating subpassInput with samplers, force sampler2D
	// so we don't have to specify format.
	if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
	{
		// Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
		if (type.image.dim == DimBuffer && type.image.sampled == 1)
			res += "sampler";
		else
			res += type.image.sampled == 2 ? "image" : "texture";
	}
	else
		res += "sampler";

	switch (type.image.dim)
	{
	case Dim1D:
		res += "1D";
		break;
	case Dim2D:
		res += "2D";
		break;
	case Dim3D:
		res += "3D";
		break;
	case DimCube:
		res += "CUBE";
		break;

	case DimBuffer:
		res += "Buffer";
		break;

	case DimSubpassData:
		res += "2D";
		break;
	default:
		SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
	}

	if (type.image.ms)
		res += "MS";
	if (type.image.arrayed)
		res += "Array";

	return res;
}
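// Illustrative legacy names built above: a combined 2D float sampler becomes "sampler2D",
// a separate integer cube texture "itextureCUBE", and a multisampled, arrayed storage
// image "image2DMSArray" (prefix, kind, dimension, then MS/Array suffixes).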
string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
{
	if (hlsl_options.shader_model <= 30)
		return image_type_hlsl_legacy(type, id);
	else
		return image_type_hlsl_modern(type, id);
}

// The optional id parameter indicates the object whose type we are trying
// to find the description for. Most type descriptions do not depend on a
// specific object's use of that type.
string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
{
	// Ignore the pointer type since GLSL doesn't have pointers.

	switch (type.basetype)
	{
	case SPIRType::Struct:
		// Need OpName lookup here to get a "sensible" name for a struct.
		if (backend.explicit_struct_type)
			return join("struct ", to_name(type.self));
		else
			return to_name(type.self);

	case SPIRType::Image:
	case SPIRType::SampledImage:
		return image_type_hlsl(type, id);

	case SPIRType::Sampler:
		return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";

	case SPIRType::Void:
		return "void";

	default:
		break;
	}
426 if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
427 {
428 switch (type.basetype)
429 {
430 case SPIRType::Boolean:
431 return "bool";
432 case SPIRType::Int:
433 return backend.basic_int_type;
434 case SPIRType::UInt:
435 return backend.basic_uint_type;
436 case SPIRType::AtomicCounter:
437 return "atomic_uint";
438 case SPIRType::Half:
439 if (hlsl_options.enable_16bit_types)
440 return "half";
441 else
442 return "min16float";
443 case SPIRType::Short:
444 if (hlsl_options.enable_16bit_types)
445 return "int16_t";
446 else
447 return "min16int";
448 case SPIRType::UShort:
449 if (hlsl_options.enable_16bit_types)
450 return "uint16_t";
451 else
452 return "min16uint";
453 case SPIRType::Float:
454 return "float";
455 case SPIRType::Double:
456 return "double";
457 case SPIRType::Int64:
458 if (hlsl_options.shader_model < 60)
459 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
460 return "int64_t";
461 case SPIRType::UInt64:
462 if (hlsl_options.shader_model < 60)
463 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
464 return "uint64_t";
465 case SPIRType::AccelerationStructure:
466 return "RaytracingAccelerationStructure";
467 case SPIRType::RayQuery:
468 return "RayQuery<RAY_FLAG_NONE>";
469 default:
470 return "???";
471 }
472 }
473 else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
474 {
475 switch (type.basetype)
476 {
477 case SPIRType::Boolean:
478 return join(ts: "bool", ts: type.vecsize);
479 case SPIRType::Int:
480 return join(ts: "int", ts: type.vecsize);
481 case SPIRType::UInt:
482 return join(ts: "uint", ts: type.vecsize);
483 case SPIRType::Half:
484 return join(ts: hlsl_options.enable_16bit_types ? "half" : "min16float", ts: type.vecsize);
485 case SPIRType::Short:
486 return join(ts: hlsl_options.enable_16bit_types ? "int16_t" : "min16int", ts: type.vecsize);
487 case SPIRType::UShort:
488 return join(ts: hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", ts: type.vecsize);
489 case SPIRType::Float:
490 return join(ts: "float", ts: type.vecsize);
491 case SPIRType::Double:
492 return join(ts: "double", ts: type.vecsize);
493 case SPIRType::Int64:
494 return join(ts: "i64vec", ts: type.vecsize);
495 case SPIRType::UInt64:
496 return join(ts: "u64vec", ts: type.vecsize);
497 default:
498 return "???";
499 }
500 }
501 else
502 {
503 switch (type.basetype)
504 {
505 case SPIRType::Boolean:
506 return join(ts: "bool", ts: type.columns, ts: "x", ts: type.vecsize);
507 case SPIRType::Int:
508 return join(ts: "int", ts: type.columns, ts: "x", ts: type.vecsize);
509 case SPIRType::UInt:
510 return join(ts: "uint", ts: type.columns, ts: "x", ts: type.vecsize);
511 case SPIRType::Half:
512 return join(ts: hlsl_options.enable_16bit_types ? "half" : "min16float", ts: type.columns, ts: "x", ts: type.vecsize);
513 case SPIRType::Short:
514 return join(ts: hlsl_options.enable_16bit_types ? "int16_t" : "min16int", ts: type.columns, ts: "x", ts: type.vecsize);
515 case SPIRType::UShort:
516 return join(ts: hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", ts: type.columns, ts: "x", ts: type.vecsize);
517 case SPIRType::Float:
518 return join(ts: "float", ts: type.columns, ts: "x", ts: type.vecsize);
519 case SPIRType::Double:
520 return join(ts: "double", ts: type.columns, ts: "x", ts: type.vecsize);
521 // Matrix types not supported for int64/uint64.
522 default:
523 return "???";
524 }
525 }
526}
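// A sketch of the names produced above, assuming default backend type names: a scalar
// float -> "float", a 3-component uint vector -> "uint3", a SPIR-V matrix with 4 columns
// of 3-row vectors -> "float4x3", and Half -> "half" with enable_16bit_types, otherwise
// "min16float".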
527
void CompilerHLSL::emit_header()
{
	for (auto &header : header_lines)
		statement(header);

	if (header_lines.size() > 0)
	{
		statement("");
	}
}

void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
{
	add_resource_name(var.self);

	// The global copies of I/O variables should not contain interpolation qualifiers.
	// These are emitted inside the interface structs.
	auto &flags = ir.meta[var.self].decoration.decoration_flags;
	auto old_flags = flags;
	flags.reset();
	statement("static ", variable_decl(var), ";");
	flags = old_flags;
}
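// For example (variable name is illustrative), a flat-interpolated fragment input vColor
// is emitted here as a plain global "static float3 vColor;", while the "nointerpolation"
// qualifier only appears on the corresponding SPIRV_Cross_Input struct member.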
552const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
553{
554 // Input and output variables are handled specially in HLSL backend.
555 // The variables are declared as global, private variables, and do not need any qualifiers.
556 if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
557 var.storage == StorageClassPushConstant)
558 {
559 return "uniform ";
560 }
561
562 return "";
563}
564
565void CompilerHLSL::emit_builtin_outputs_in_struct()
566{
567 auto &execution = get_entry_point();
568
569 bool legacy = hlsl_options.shader_model <= 30;
570 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
571 const char *type = nullptr;
572 const char *semantic = nullptr;
573 auto builtin = static_cast<BuiltIn>(i);
574 switch (builtin)
575 {
576 case BuiltInPosition:
577 type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
578 semantic = legacy ? "POSITION" : "SV_Position";
579 break;
580
581 case BuiltInSampleMask:
582 if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
583 SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
584 type = "uint";
585 semantic = "SV_Coverage";
586 break;
587
588 case BuiltInFragDepth:
589 type = "float";
590 if (legacy)
591 {
592 semantic = "DEPTH";
593 }
594 else
595 {
596 if (hlsl_options.shader_model >= 50 && execution.flags.get(bit: ExecutionModeDepthGreater))
597 semantic = "SV_DepthGreaterEqual";
598 else if (hlsl_options.shader_model >= 50 && execution.flags.get(bit: ExecutionModeDepthLess))
599 semantic = "SV_DepthLessEqual";
600 else
601 semantic = "SV_Depth";
602 }
603 break;
604
605 case BuiltInClipDistance:
606 {
607 static const char *types[] = { "float", "float2", "float3", "float4" };
608
609 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
610 if (execution.model == ExecutionModelMeshEXT)
611 {
612 if (clip_distance_count > 4)
613 SPIRV_CROSS_THROW("Clip distance count > 4 not supported for mesh shaders.");
614
615 if (clip_distance_count == 1)
616 {
617 // Avoids having to hack up access_chain code. Makes it trivially indexable.
618 statement(ts: "float gl_ClipDistance[1] : SV_ClipDistance;");
619 }
620 else
621 {
622 // Replace array with vector directly, avoids any weird fixup path.
623 statement(ts&: types[clip_distance_count - 1], ts: " gl_ClipDistance : SV_ClipDistance;");
624 }
625 }
626 else
627 {
628 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
629 {
630 uint32_t to_declare = clip_distance_count - clip;
631 if (to_declare > 4)
632 to_declare = 4;
633
634 uint32_t semantic_index = clip / 4;
635
636 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts&: semantic_index,
637 ts: " : SV_ClipDistance", ts&: semantic_index, ts: ";");
638 }
639 }
640 break;
641 }
642
643 case BuiltInCullDistance:
644 {
645 static const char *types[] = { "float", "float2", "float3", "float4" };
646
647 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
648 if (execution.model == ExecutionModelMeshEXT)
649 {
650 if (cull_distance_count > 4)
651 SPIRV_CROSS_THROW("Cull distance count > 4 not supported for mesh shaders.");
652
653 if (cull_distance_count == 1)
654 {
655 // Avoids having to hack up access_chain code. Makes it trivially indexable.
656 statement(ts: "float gl_CullDistance[1] : SV_CullDistance;");
657 }
658 else
659 {
660 // Replace array with vector directly, avoids any weird fixup path.
661 statement(ts&: types[cull_distance_count - 1], ts: " gl_CullDistance : SV_CullDistance;");
662 }
663 }
664 else
665 {
666 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
667 {
668 uint32_t to_declare = cull_distance_count - cull;
669 if (to_declare > 4)
670 to_declare = 4;
671
672 uint32_t semantic_index = cull / 4;
673
674 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts&: semantic_index,
675 ts: " : SV_CullDistance", ts&: semantic_index, ts: ";");
676 }
677 }
678 break;
679 }
680
681 case BuiltInPointSize:
682 // If point_size_compat is enabled, just ignore PointSize.
683 // PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
684 // even if it means working around the missing feature.
685 if (legacy)
686 {
687 type = "float";
688 semantic = "PSIZE";
689 }
690 else if (!hlsl_options.point_size_compat)
691 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
692 break;
693
		case BuiltInLayer:
		case BuiltInPrimitiveId:
		case BuiltInViewportIndex:
		case BuiltInPrimitiveShadingRateKHR:
		case BuiltInCullPrimitiveEXT:
			// Per-primitive attributes are handled separately.
			break;

		case BuiltInPrimitivePointIndicesEXT:
		case BuiltInPrimitiveLineIndicesEXT:
		case BuiltInPrimitiveTriangleIndicesEXT:
			// The meshlet local-index buffer is handled separately.
			break;
707
708 default:
709 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
710 }
711
712 if (type && semantic)
713 statement(ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts: " : ", ts&: semantic, ts: ";");
714 });
715}
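// To illustrate the SV_ClipDistance packing above: with clip_distance_count == 5, the
// non-mesh path emits roughly
//   float4 gl_ClipDistance0 : SV_ClipDistance0;
//   float gl_ClipDistance1 : SV_ClipDistance1;
// (names as produced by builtin_to_glsl() plus the semantic index).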
716
717void CompilerHLSL::emit_builtin_primitive_outputs_in_struct()
718{
719 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
720 const char *type = nullptr;
721 const char *semantic = nullptr;
722 auto builtin = static_cast<BuiltIn>(i);
723 switch (builtin)
724 {
725 case BuiltInLayer:
726 {
727 if (hlsl_options.shader_model < 50)
728 SPIRV_CROSS_THROW("Render target array index output is only supported in SM 5.0 or higher.");
729 type = "uint";
730 semantic = "SV_RenderTargetArrayIndex";
731 break;
732 }
733
734 case BuiltInPrimitiveId:
735 type = "uint";
736 semantic = "SV_PrimitiveID";
737 break;
738
739 case BuiltInViewportIndex:
740 type = "uint";
741 semantic = "SV_ViewportArrayIndex";
742 break;
743
744 case BuiltInPrimitiveShadingRateKHR:
745 type = "uint";
746 semantic = "SV_ShadingRate";
747 break;
748
749 case BuiltInCullPrimitiveEXT:
750 type = "bool";
751 semantic = "SV_CullPrimitive";
752 break;
753
754 default:
755 break;
756 }
757
758 if (type && semantic)
759 statement(ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts: " : ", ts&: semantic, ts: ";");
760 });
761}
762
763void CompilerHLSL::emit_builtin_inputs_in_struct()
764{
765 bool legacy = hlsl_options.shader_model <= 30;
766 active_input_builtins.for_each_bit(op: [&](uint32_t i) {
767 const char *type = nullptr;
768 const char *semantic = nullptr;
769 auto builtin = static_cast<BuiltIn>(i);
770 switch (builtin)
771 {
772 case BuiltInFragCoord:
773 type = "float4";
774 semantic = legacy ? "VPOS" : "SV_Position";
775 break;
776
777 case BuiltInVertexId:
778 case BuiltInVertexIndex:
779 if (legacy)
780 SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
781 type = "uint";
782 semantic = "SV_VertexID";
783 break;
784
785 case BuiltInPrimitiveId:
786 type = "uint";
787 semantic = "SV_PrimitiveID";
788 break;
789
790 case BuiltInInstanceId:
791 case BuiltInInstanceIndex:
792 if (legacy)
793 SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
794 type = "uint";
795 semantic = "SV_InstanceID";
796 break;
797
798 case BuiltInSampleId:
799 if (legacy)
800 SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
801 type = "uint";
802 semantic = "SV_SampleIndex";
803 break;
804
805 case BuiltInSampleMask:
806 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
807 SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
808 type = "uint";
809 semantic = "SV_Coverage";
810 break;
811
812 case BuiltInGlobalInvocationId:
813 type = "uint3";
814 semantic = "SV_DispatchThreadID";
815 break;
816
817 case BuiltInLocalInvocationId:
818 type = "uint3";
819 semantic = "SV_GroupThreadID";
820 break;
821
822 case BuiltInLocalInvocationIndex:
823 type = "uint";
824 semantic = "SV_GroupIndex";
825 break;
826
827 case BuiltInWorkgroupId:
828 type = "uint3";
829 semantic = "SV_GroupID";
830 break;
831
832 case BuiltInFrontFacing:
833 type = "bool";
834 semantic = "SV_IsFrontFace";
835 break;
836
837 case BuiltInViewIndex:
838 if (hlsl_options.shader_model < 61 || (get_entry_point().model != ExecutionModelVertex && get_entry_point().model != ExecutionModelFragment))
839 SPIRV_CROSS_THROW("View Index input is only supported in VS and PS 6.1 or higher.");
840 type = "uint";
841 semantic = "SV_ViewID";
842 break;
843
844 case BuiltInNumWorkgroups:
845 case BuiltInSubgroupSize:
846 case BuiltInSubgroupLocalInvocationId:
847 case BuiltInSubgroupEqMask:
848 case BuiltInSubgroupLtMask:
849 case BuiltInSubgroupLeMask:
850 case BuiltInSubgroupGtMask:
851 case BuiltInSubgroupGeMask:
852 case BuiltInBaseVertex:
853 case BuiltInBaseInstance:
854 // Handled specially.
855 break;
856
857 case BuiltInHelperInvocation:
858 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
859 SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
860 break;
861
862 case BuiltInClipDistance:
863 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
864 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
865 {
866 uint32_t to_declare = clip_distance_count - clip;
867 if (to_declare > 4)
868 to_declare = 4;
869
870 uint32_t semantic_index = clip / 4;
871
872 static const char *types[] = { "float", "float2", "float3", "float4" };
873 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts&: semantic_index,
874 ts: " : SV_ClipDistance", ts&: semantic_index, ts: ";");
875 }
876 break;
877
878 case BuiltInCullDistance:
879 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
880 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
881 {
882 uint32_t to_declare = cull_distance_count - cull;
883 if (to_declare > 4)
884 to_declare = 4;
885
886 uint32_t semantic_index = cull / 4;
887
888 static const char *types[] = { "float", "float2", "float3", "float4" };
889 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts&: semantic_index,
890 ts: " : SV_CullDistance", ts&: semantic_index, ts: ";");
891 }
892 break;
893
894 case BuiltInPointCoord:
895 // PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
896 if (hlsl_options.point_coord_compat)
897 break;
898 else
899 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
900
901 case BuiltInLayer:
902 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
903 SPIRV_CROSS_THROW("Render target array index input is only supported in PS 5.0 or higher.");
904 type = "uint";
905 semantic = "SV_RenderTargetArrayIndex";
906 break;
907
908 default:
909 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
910 }
911
912 if (type && semantic)
913 statement(ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts: " : ", ts&: semantic, ts: ";");
914 });
915}
916
uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
{
	// TODO: Need to verify correctness.
	uint32_t elements = 0;

	if (type.basetype == SPIRType::Struct)
	{
		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
			elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
	}
	else
	{
		uint32_t array_multiplier = 1;
		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
		{
			if (type.array_size_literal[i])
				array_multiplier *= type.array[i];
			else
				array_multiplier *= evaluate_constant_u32(type.array[i]);
		}
		elements += array_multiplier * type.columns;
	}
	return elements;
}
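// Worked examples of the count above: float4 consumes 1 location, float4x4 consumes 4
// (one per column), float3[2] consumes 2, and a struct sums the counts of its members.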
string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
{
	string res;
	//if (flags & (1ull << DecorationSmooth))
	//	res += "linear ";
	if (flags.get(DecorationFlat))
		res += "nointerpolation ";
	if (flags.get(DecorationNoPerspective))
		res += "noperspective ";
	if (flags.get(DecorationCentroid))
		res += "centroid ";
	if (flags.get(DecorationPatch))
		res += "patch "; // Seems to be different in actual HLSL.
	if (flags.get(DecorationSample))
		res += "sample ";
	if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
		res += "precise "; // Not supported?

	return res;
}

std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
{
	if (em == ExecutionModelVertex && sc == StorageClassInput)
	{
		// We have a vertex attribute - we should look at remapping it if the user provided
		// vertex attribute hints.
		for (auto &attribute : remap_vertex_attributes)
			if (attribute.location == location)
				return attribute.semantic;
	}

	// Not a vertex attribute, or no remap_vertex_attributes entry.
	return join("TEXCOORD", location);
}
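// For example, location 3 on a non-remapped variable becomes "TEXCOORD3"; a vertex input
// whose location appears in remap_vertex_attributes uses the caller-provided semantic
// instead (e.g. "POSITION" - illustrative, whatever the remap supplies).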
978std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
979{
980 // We cannot emit static const initializer for block constants for practical reasons,
981 // so just inline the initializer.
982 // FIXME: There is a theoretical problem here if someone tries to composite extract
983 // into this initializer since we don't declare it properly, but that is somewhat non-sensical.
984 auto &type = get<SPIRType>(id: var.basetype);
985 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
986 auto *c = maybe_get<SPIRConstant>(id: var.initializer);
987 if (is_block && c)
988 return constant_expression(c: *c);
989 else
990 return CompilerGLSL::to_initializer_expression(var);
991}
992
void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
                                                         uint32_t location,
                                                         std::unordered_set<uint32_t> &active_locations)
{
	auto &execution = get_entry_point();
	auto type = get<SPIRType>(var.basetype);
	auto semantic = to_semantic(location, execution.model, var.storage);
	auto mbr_name = join(to_name(type.self), "_", to_member_name(type, member_index));
	auto &mbr_type = get<SPIRType>(type.member_types[member_index]);

	statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, member_index)),
	          type_to_glsl(mbr_type),
	          " ", mbr_name, type_to_array_glsl(mbr_type, var.self),
	          " : ", semantic, ";");

	// Structs and arrays should consume more locations.
	uint32_t consumed_locations = type_to_consumed_locations(mbr_type);
	for (uint32_t i = 0; i < consumed_locations; i++)
		active_locations.insert(location + i);
}
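// An illustrative emitted member (block and member names are hypothetical):
//   nointerpolation float3 VertexOutput_normal : TEXCOORD1;
// i.e. qualifiers, member type, "<BlockName>_<member>", any array suffix, then the semantic.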
1014void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
1015{
1016 auto &execution = get_entry_point();
1017 auto type = get<SPIRType>(id: var.basetype);
1018
1019 string binding;
1020 bool use_location_number = true;
1021 bool need_matrix_unroll = false;
1022 bool legacy = hlsl_options.shader_model <= 30;
1023 if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
1024 {
1025 // Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
1026 uint32_t index = get_decoration(id: var.self, decoration: DecorationIndex);
1027 uint32_t location = get_decoration(id: var.self, decoration: DecorationLocation);
1028
1029 if (index != 0 && location != 0)
1030 SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");
1031
1032 binding = join(ts: legacy ? "COLOR" : "SV_Target", ts: location + index);
1033 use_location_number = false;
1034 if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
1035 type.vecsize = 4;
1036 }
1037 else if (var.storage == StorageClassInput && execution.model == ExecutionModelVertex)
1038 {
1039 need_matrix_unroll = true;
1040 if (legacy) // Inputs must be floating-point in legacy targets.
1041 type.basetype = SPIRType::Float;
1042 }
1043
1044 const auto get_vacant_location = [&]() -> uint32_t {
1045 for (uint32_t i = 0; i < 64; i++)
1046 if (!active_locations.count(x: i))
1047 return i;
1048 SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
1049 };
1050
1051 auto name = to_name(id: var.self);
1052 if (use_location_number)
1053 {
1054 uint32_t location_number;
1055
1056 // If an explicit location exists, use it with TEXCOORD[N] semantic.
1057 // Otherwise, pick a vacant location.
1058 if (has_decoration(id: var.self, decoration: DecorationLocation))
1059 location_number = get_decoration(id: var.self, decoration: DecorationLocation);
1060 else
1061 location_number = get_vacant_location();
1062
1063 // Allow semantic remap if specified.
1064 auto semantic = to_semantic(location: location_number, em: execution.model, sc: var.storage);
1065
1066 if (need_matrix_unroll && type.columns > 1)
1067 {
1068 if (!type.array.empty())
1069 SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");
1070
1071 // Unroll matrices.
1072 for (uint32_t i = 0; i < type.columns; i++)
1073 {
1074 SPIRType newtype = type;
1075 newtype.columns = 1;
1076
1077 string effective_semantic;
1078 if (hlsl_options.flatten_matrix_vertex_input_semantics)
1079 effective_semantic = to_semantic(location: location_number, em: execution.model, sc: var.storage);
1080 else
1081 effective_semantic = join(ts&: semantic, ts: "_", ts&: i);
1082
1083 statement(ts: to_interpolation_qualifiers(flags: get_decoration_bitset(id: var.self)),
1084 ts: variable_decl(type: newtype, name: join(ts&: name, ts: "_", ts&: i)), ts: " : ", ts&: effective_semantic, ts: ";");
1085 active_locations.insert(x: location_number++);
1086 }
1087 }
1088 else
1089 {
1090 auto decl_type = type;
1091 if (execution.model == ExecutionModelMeshEXT)
1092 {
1093 decl_type.array.erase(itr: decl_type.array.begin());
1094 decl_type.array_size_literal.erase(itr: decl_type.array_size_literal.begin());
1095 }
1096 statement(ts: to_interpolation_qualifiers(flags: get_decoration_bitset(id: var.self)), ts: variable_decl(type: decl_type, name), ts: " : ",
1097 ts&: semantic, ts: ";");
1098
1099 // Structs and arrays should consume more locations.
1100 uint32_t consumed_locations = type_to_consumed_locations(type: decl_type);
1101 for (uint32_t i = 0; i < consumed_locations; i++)
1102 active_locations.insert(x: location_number + i);
1103 }
1104 }
1105 else
1106 {
1107 statement(ts: variable_decl(type, name), ts: " : ", ts&: binding, ts: ";");
1108 }
1109}
1110
1111std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
1112{
1113 switch (builtin)
1114 {
1115 case BuiltInVertexId:
1116 return "gl_VertexID";
1117 case BuiltInInstanceId:
1118 return "gl_InstanceID";
1119 case BuiltInNumWorkgroups:
1120 {
1121 if (!num_workgroups_builtin)
1122 SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
1123 "Cannot emit code for this builtin.");
1124
1125 auto &var = get<SPIRVariable>(id: num_workgroups_builtin);
1126 auto &type = get<SPIRType>(id: var.basetype);
1127 auto ret = join(ts: to_name(id: num_workgroups_builtin), ts: "_", ts: get_member_name(id: type.self, index: 0));
1128 ParsedIR::sanitize_underscores(str&: ret);
1129 return ret;
1130 }
1131 case BuiltInPointCoord:
1132 // Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
1133 return "float2(0.5f, 0.5f)";
1134 case BuiltInSubgroupLocalInvocationId:
1135 return "WaveGetLaneIndex()";
1136 case BuiltInSubgroupSize:
1137 return "WaveGetLaneCount()";
1138 case BuiltInHelperInvocation:
1139 return "IsHelperLane()";
1140
1141 default:
1142 return CompilerGLSL::builtin_to_glsl(builtin, storage);
1143 }
1144}
1145
1146void CompilerHLSL::emit_builtin_variables()
1147{
1148 Bitset builtins = active_input_builtins;
1149 builtins.merge_or(other: active_output_builtins);
1150
1151 std::unordered_map<uint32_t, ID> builtin_to_initializer;
1152
1153 // We need to declare sample mask with the same type that module declares it.
1154 // Sample mask is somewhat special in that SPIR-V has an array, and we can copy that array, so we need to
1155 // match sign.
1156 SPIRType::BaseType sample_mask_in_basetype = SPIRType::Void;
1157 SPIRType::BaseType sample_mask_out_basetype = SPIRType::Void;
1158
1159 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1160 if (!is_builtin_variable(var))
1161 return;
1162
1163 auto &type = this->get<SPIRType>(id: var.basetype);
1164 auto builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
1165
1166 if (var.storage == StorageClassInput && builtin == BuiltInSampleMask)
1167 sample_mask_in_basetype = type.basetype;
1168 else if (var.storage == StorageClassOutput && builtin == BuiltInSampleMask)
1169 sample_mask_out_basetype = type.basetype;
1170
1171 if (var.initializer && var.storage == StorageClassOutput)
1172 {
1173 auto *c = this->maybe_get<SPIRConstant>(id: var.initializer);
1174 if (!c)
1175 return;
1176
1177 if (type.basetype == SPIRType::Struct)
1178 {
1179 uint32_t member_count = uint32_t(type.member_types.size());
1180 for (uint32_t i = 0; i < member_count; i++)
1181 {
1182 if (has_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn))
1183 {
1184 builtin_to_initializer[get_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn)] =
1185 c->subconstants[i];
1186 }
1187 }
1188 }
1189 else if (has_decoration(id: var.self, decoration: DecorationBuiltIn))
1190 {
1191 builtin_to_initializer[builtin] = var.initializer;
1192 }
1193 }
1194 });
1195
1196 // Emit global variables for the interface variables which are statically used by the shader.
1197 builtins.for_each_bit(op: [&](uint32_t i) {
1198 const char *type = nullptr;
1199 auto builtin = static_cast<BuiltIn>(i);
1200 uint32_t array_size = 0;
1201
1202 string init_expr;
1203 auto init_itr = builtin_to_initializer.find(x: builtin);
1204 if (init_itr != builtin_to_initializer.end())
1205 init_expr = join(ts: " = ", ts: to_expression(id: init_itr->second));
1206
1207 if (get_execution_model() == ExecutionModelMeshEXT)
1208 {
1209 if (builtin == BuiltInPosition || builtin == BuiltInPointSize || builtin == BuiltInClipDistance ||
1210 builtin == BuiltInCullDistance || builtin == BuiltInLayer || builtin == BuiltInPrimitiveId ||
1211 builtin == BuiltInViewportIndex || builtin == BuiltInCullPrimitiveEXT ||
1212 builtin == BuiltInPrimitiveShadingRateKHR || builtin == BuiltInPrimitivePointIndicesEXT ||
1213 builtin == BuiltInPrimitiveLineIndicesEXT || builtin == BuiltInPrimitiveTriangleIndicesEXT)
1214 {
1215 return;
1216 }
1217 }
1218
1219 switch (builtin)
1220 {
1221 case BuiltInFragCoord:
1222 case BuiltInPosition:
1223 type = "float4";
1224 break;
1225
1226 case BuiltInFragDepth:
1227 type = "float";
1228 break;
1229
1230 case BuiltInVertexId:
1231 case BuiltInVertexIndex:
1232 case BuiltInInstanceIndex:
1233 type = "int";
1234 if (hlsl_options.support_nonzero_base_vertex_base_instance)
1235 base_vertex_info.used = true;
1236 break;
1237
1238 case BuiltInBaseVertex:
1239 case BuiltInBaseInstance:
1240 type = "int";
1241 base_vertex_info.used = true;
1242 break;
1243
1244 case BuiltInInstanceId:
1245 case BuiltInSampleId:
1246 type = "int";
1247 break;
1248
1249 case BuiltInPointSize:
1250 if (hlsl_options.point_size_compat || hlsl_options.shader_model <= 30)
1251 {
1252 // Just emit the global variable, it will be ignored.
1253 type = "float";
1254 break;
1255 }
1256 else
1257 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1258
1259 case BuiltInGlobalInvocationId:
1260 case BuiltInLocalInvocationId:
1261 case BuiltInWorkgroupId:
1262 type = "uint3";
1263 break;
1264
1265 case BuiltInLocalInvocationIndex:
1266 type = "uint";
1267 break;
1268
1269 case BuiltInFrontFacing:
1270 type = "bool";
1271 break;
1272
1273 case BuiltInNumWorkgroups:
1274 case BuiltInPointCoord:
1275 // Handled specially.
1276 break;
1277
1278 case BuiltInSubgroupLocalInvocationId:
1279 case BuiltInSubgroupSize:
1280 if (hlsl_options.shader_model < 60)
1281 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1282 break;
1283
1284 case BuiltInSubgroupEqMask:
1285 case BuiltInSubgroupLtMask:
1286 case BuiltInSubgroupLeMask:
1287 case BuiltInSubgroupGtMask:
1288 case BuiltInSubgroupGeMask:
1289 if (hlsl_options.shader_model < 60)
1290 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1291 type = "uint4";
1292 break;
1293
1294 case BuiltInHelperInvocation:
1295 if (hlsl_options.shader_model < 50)
1296 SPIRV_CROSS_THROW("Need SM 5.0 for Helper Invocation.");
1297 break;
1298
1299 case BuiltInClipDistance:
1300 array_size = clip_distance_count;
1301 type = "float";
1302 break;
1303
1304 case BuiltInCullDistance:
1305 array_size = cull_distance_count;
1306 type = "float";
1307 break;
1308
1309 case BuiltInSampleMask:
1310 if (active_input_builtins.get(bit: BuiltInSampleMask))
1311 type = sample_mask_in_basetype == SPIRType::UInt ? "uint" : "int";
1312 else
1313 type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int";
1314 array_size = 1;
1315 break;
1316
1317 case BuiltInPrimitiveId:
1318 case BuiltInViewIndex:
1319 case BuiltInLayer:
1320 type = "uint";
1321 break;
1322
1323 case BuiltInViewportIndex:
1324 case BuiltInPrimitiveShadingRateKHR:
1325 case BuiltInPrimitiveLineIndicesEXT:
1326 case BuiltInCullPrimitiveEXT:
1327 type = "uint";
1328 break;
1329
1330 default:
1331 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1332 }
1333
1334 StorageClass storage = active_input_builtins.get(bit: i) ? StorageClassInput : StorageClassOutput;
1335
1336 if (type)
1337 {
1338 if (array_size)
1339 statement(ts: "static ", ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage), ts: "[", ts&: array_size, ts: "]", ts&: init_expr, ts: ";");
1340 else
1341 statement(ts: "static ", ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage), ts&: init_expr, ts: ";");
1342 }
1343
		// SampleMask can be both in and out with sample builtin, in this case we have already
		// declared the input variable and we need to add the output one now.
		if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
		{
			type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int";
			if (array_size)
				statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), "[", array_size, "]", init_expr, ";");
			else
				statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
		}
	});

	if (base_vertex_info.used)
	{
		string binding_info;
		if (base_vertex_info.explicit_binding)
		{
			binding_info = join(" : register(b", base_vertex_info.register_index);
			if (base_vertex_info.register_space)
				binding_info += join(", space", base_vertex_info.register_space);
			binding_info += ")";
		}
		statement("cbuffer SPIRV_Cross_VertexInfo", binding_info);
		begin_scope();
		statement("int SPIRV_Cross_BaseVertex;");
		statement("int SPIRV_Cross_BaseInstance;");
		end_scope_decl();
		statement("");
	}
}
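// With an explicit binding of register index 4 in space 1 (illustrative numbers), the
// block above emits roughly:
//   cbuffer SPIRV_Cross_VertexInfo : register(b4, space1)
//   {
//       int SPIRV_Cross_BaseVertex;
//       int SPIRV_Cross_BaseInstance;
//   };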
void CompilerHLSL::set_hlsl_aux_buffer_binding(HLSLAuxBinding binding, uint32_t register_index, uint32_t register_space)
{
	if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
	{
		base_vertex_info.explicit_binding = true;
		base_vertex_info.register_space = register_space;
		base_vertex_info.register_index = register_index;
	}
}
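// Minimal usage sketch (caller side; the "hlsl" instance name is assumed):
//   hlsl.set_hlsl_aux_buffer_binding(HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE, 4, 1);
// pins the SPIRV_Cross_VertexInfo cbuffer above to register(b4, space1).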
1385void CompilerHLSL::unset_hlsl_aux_buffer_binding(HLSLAuxBinding binding)
1386{
1387 if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
1388 base_vertex_info.explicit_binding = false;
1389}
1390
1391bool CompilerHLSL::is_hlsl_aux_buffer_binding_used(HLSLAuxBinding binding) const
1392{
1393 if (binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
1394 return base_vertex_info.used;
1395 else
1396 return false;
1397}
1398
void CompilerHLSL::emit_composite_constants()
{
	// HLSL cannot declare structs or arrays inline, so we must move them out to
	// global constants directly.
	bool emitted = false;

	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		if (c.specialization)
			return;

		auto &type = this->get<SPIRType>(c.constant_type);

		if (type.basetype == SPIRType::Struct && is_builtin_type(type))
			return;

		if (type.basetype == SPIRType::Struct || !type.array.empty())
		{
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
			emitted = true;
		}
	});

	if (emitted)
		statement("");
}
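// For instance, a composite float-array constant ends up as a global along the lines of
//   static const float _21[2] = { 1.0f, 2.0f };
// (the "_21" name and the values are illustrative; actual names come from to_name()).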
1427void CompilerHLSL::emit_specialization_constants_and_structs()
1428{
1429 bool emitted = false;
1430 SpecializationConstant wg_x, wg_y, wg_z;
1431 ID workgroup_size_id = get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
1432
1433 std::unordered_set<TypeID> io_block_types;
1434 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, const SPIRVariable &var) {
1435 auto &type = this->get<SPIRType>(id: var.basetype);
1436 if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
1437 !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1438 interface_variable_exists_in_entry_point(id: var.self) &&
1439 has_decoration(id: type.self, decoration: DecorationBlock))
1440 {
1441 io_block_types.insert(x: type.self);
1442 }
1443 });
1444
1445 auto loop_lock = ir.create_loop_hard_lock();
1446 for (auto &id_ : ir.ids_for_constant_undef_or_type)
1447 {
1448 auto &id = ir.ids[id_];
1449
1450 if (id.get_type() == TypeConstant)
1451 {
1452 auto &c = id.get<SPIRConstant>();
1453
1454 if (c.self == workgroup_size_id)
1455 {
1456 statement(ts: "static const uint3 gl_WorkGroupSize = ",
1457 ts: constant_expression(c: get<SPIRConstant>(id: workgroup_size_id)), ts: ";");
1458 emitted = true;
1459 }
1460 else if (c.specialization)
1461 {
1462 auto &type = get<SPIRType>(id: c.constant_type);
1463 add_resource_name(id: c.self);
1464 auto name = to_name(id: c.self);
1465
1466 if (has_decoration(id: c.self, decoration: DecorationSpecId))
1467 {
1468 // HLSL does not support specialization constants, so fallback to macros.
1469 c.specialization_constant_macro_name =
1470 constant_value_macro_name(id: get_decoration(id: c.self, decoration: DecorationSpecId));
1471
1472 statement(ts: "#ifndef ", ts&: c.specialization_constant_macro_name);
1473 statement(ts: "#define ", ts&: c.specialization_constant_macro_name, ts: " ", ts: constant_expression(c));
1474 statement(ts: "#endif");
1475 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts&: c.specialization_constant_macro_name, ts: ";");
1476 }
1477 else
1478 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts: constant_expression(c), ts: ";");
1479
1480 emitted = true;
1481 }
1482 }
1483 else if (id.get_type() == TypeConstantOp)
1484 {
1485 auto &c = id.get<SPIRConstantOp>();
1486 auto &type = get<SPIRType>(id: c.basetype);
1487 add_resource_name(id: c.self);
1488 auto name = to_name(id: c.self);
1489 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts: constant_op_expression(cop: c), ts: ";");
1490 emitted = true;
1491 }
1492 else if (id.get_type() == TypeType)
1493 {
1494 auto &type = id.get<SPIRType>();
1495 bool is_non_io_block = has_decoration(id: type.self, decoration: DecorationBlock) &&
1496 io_block_types.count(x: type.self) == 0;
1497 bool is_buffer_block = has_decoration(id: type.self, decoration: DecorationBufferBlock);
1498 if (type.basetype == SPIRType::Struct && type.array.empty() &&
1499 !type.pointer && !is_non_io_block && !is_buffer_block)
1500 {
1501 if (emitted)
1502 statement(ts: "");
1503 emitted = false;
1504
1505 emit_struct(type);
1506 }
1507 }
1508 else if (id.get_type() == TypeUndef)
1509 {
1510 auto &undef = id.get<SPIRUndef>();
1511 auto &type = this->get<SPIRType>(id: undef.basetype);
1512 // OpUndef can be void for some reason ...
1513 if (type.basetype == SPIRType::Void)
1514 return;
1515
1516 string initializer;
1517 if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
1518 initializer = join(ts: " = ", ts: to_zero_initialized_expression(type_id: undef.basetype));
1519
1520 statement(ts: "static ", ts: variable_decl(type, name: to_name(id: undef.self), id: undef.self), ts&: initializer, ts: ";");
1521 emitted = true;
1522 }
1523 }
1524
1525 if (emitted)
1526 statement(ts: "");
1527}
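// The macro fallback above produces HLSL along these lines (the exact macro name comes
// from constant_value_macro_name(); the constant name and value are illustrative):
//   #ifndef SPIRV_CROSS_CONSTANT_ID_10
//   #define SPIRV_CROSS_CONSTANT_ID_10 4
//   #endif
//   static const int MY_CONSTANT = SPIRV_CROSS_CONSTANT_ID_10;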
1528
1529void CompilerHLSL::replace_illegal_names()
1530{
1531 static const unordered_set<string> keywords = {
1532 // Additional HLSL specific keywords.
1533 // From https://docs.microsoft.com/en-US/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-keywords
1534 "AppendStructuredBuffer", "asm", "asm_fragment",
1535 "BlendState", "bool", "break", "Buffer", "ByteAddressBuffer",
1536 "case", "cbuffer", "centroid", "class", "column_major", "compile",
1537 "compile_fragment", "CompileShader", "const", "continue", "ComputeShader",
1538 "ConsumeStructuredBuffer",
1539 "default", "DepthStencilState", "DepthStencilView", "discard", "do",
1540 "double", "DomainShader", "dword",
1541 "else", "export", "false", "float", "for", "fxgroup",
1542 "GeometryShader", "groupshared", "half", "HullShader",
1543 "indices", "if", "in", "inline", "inout", "InputPatch", "int", "interface",
1544 "line", "lineadj", "linear", "LineStream",
1545 "matrix", "min16float", "min10float", "min16int", "min16uint",
1546 "namespace", "nointerpolation", "noperspective", "NULL",
1547 "out", "OutputPatch",
1548 "payload", "packoffset", "pass", "pixelfragment", "PixelShader", "point",
1549 "PointStream", "precise", "RasterizerState", "RenderTargetView",
1550 "return", "register", "row_major", "RWBuffer", "RWByteAddressBuffer",
1551 "RWStructuredBuffer", "RWTexture1D", "RWTexture1DArray", "RWTexture2D",
1552 "RWTexture2DArray", "RWTexture3D", "sample", "sampler", "SamplerState",
1553 "SamplerComparisonState", "shared", "snorm", "stateblock", "stateblock_state",
1554 "static", "string", "struct", "switch", "StructuredBuffer", "tbuffer",
1555 "technique", "technique10", "technique11", "texture", "Texture1D",
1556 "Texture1DArray", "Texture2D", "Texture2DArray", "Texture2DMS", "Texture2DMSArray",
1557 "Texture3D", "TextureCube", "TextureCubeArray", "true", "typedef", "triangle",
1558 "triangleadj", "TriangleStream", "uint", "uniform", "unorm", "unsigned",
1559 "vector", "vertexfragment", "VertexShader", "vertices", "void", "volatile", "while",
1560 };
1561
1562 CompilerGLSL::replace_illegal_names(keywords);
1563 CompilerGLSL::replace_illegal_names();
1564}
1565
1566SPIRType::BaseType CompilerHLSL::get_builtin_basetype(BuiltIn builtin, SPIRType::BaseType default_type)
1567{
1568 switch (builtin)
1569 {
1570 case BuiltInSampleMask:
1571 // We declare sample mask array with module type, so always use default_type here.
1572 return default_type;
1573 default:
1574 return CompilerGLSL::get_builtin_basetype(builtin, default_type);
1575 }
1576}
1577
1578void CompilerHLSL::emit_resources()
1579{
1580 auto &execution = get_entry_point();
1581
1582 replace_illegal_names();
1583
1584 switch (execution.model)
1585 {
1586 case ExecutionModelGeometry:
1587 case ExecutionModelTessellationControl:
1588 case ExecutionModelTessellationEvaluation:
1589 case ExecutionModelMeshEXT:
1590 fixup_implicit_builtin_block_names(model: execution.model);
1591 break;
1592
1593 default:
1594 break;
1595 }
1596
1597 emit_specialization_constants_and_structs();
1598 emit_composite_constants();
1599
1600 bool emitted = false;
1601
1602 // Output UBOs and SSBOs
1603 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1604 auto &type = this->get<SPIRType>(id: var.basetype);
1605
1606 bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1607 bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBlock) ||
1608 ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBufferBlock);
1609
1610 if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1611 has_block_flags)
1612 {
1613 emit_buffer_block(type: var);
1614 emitted = true;
1615 }
1616 });
1617
1618 // Output push constant blocks
1619 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1620 auto &type = this->get<SPIRType>(id: var.basetype);
1621 if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1622 !is_hidden_variable(var))
1623 {
1624 emit_push_constant_block(var);
1625 emitted = true;
1626 }
1627 });
1628
1629 if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30 &&
1630 active_output_builtins.get(bit: BuiltInPosition))
1631 {
1632 statement(ts: "uniform float4 gl_HalfPixel;");
1633 emitted = true;
1634 }
1635
1636 bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1637
1638 // Output Uniform Constants (values, samplers, images, etc).
1639 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1640 auto &type = this->get<SPIRType>(id: var.basetype);
1641
1642 // If we're remapping separate samplers and images, only emit the combined samplers.
1643 if (skip_separate_image_sampler)
1644 {
1645 // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1646 bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1647 bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1648 bool separate_sampler = type.basetype == SPIRType::Sampler;
1649 if (!sampler_buffer && (separate_image || separate_sampler))
1650 return;
1651 }
1652
1653 if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1654 type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1655 !is_hidden_variable(var))
1656 {
1657 emit_uniform(var);
1658 emitted = true;
1659 }
1660 });
1661
1662 if (emitted)
1663 statement(ts: "");
1664 emitted = false;
1665
1666 // Emit builtin input and output variables here.
1667 emit_builtin_variables();
1668
1669 if (execution.model != ExecutionModelMeshEXT)
1670 {
1671 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1672 auto &type = this->get<SPIRType>(id: var.basetype);
1673
1674 if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1675 (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1676 interface_variable_exists_in_entry_point(id: var.self))
1677 {
1678 // Builtin variables are handled separately.
1679 emit_interface_block_globally(var);
1680 emitted = true;
1681 }
1682 });
1683 }
1684
1685 if (emitted)
1686 statement(ts: "");
1687 emitted = false;
1688
1689 require_input = false;
1690 require_output = false;
1691 unordered_set<uint32_t> active_inputs;
1692 unordered_set<uint32_t> active_outputs;
1693
1694 struct IOVariable
1695 {
1696 const SPIRVariable *var;
1697 uint32_t location;
1698 uint32_t block_member_index;
1699 bool block;
1700 };
1701
1702 SmallVector<IOVariable> input_variables;
1703 SmallVector<IOVariable> output_variables;
1704
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		auto &type = this->get<SPIRType>(var.basetype);
		bool block = has_decoration(type.self, DecorationBlock);

		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
			return;

		if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
		    interface_variable_exists_in_entry_point(var.self))
		{
			if (block)
			{
				for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
				{
					uint32_t location = get_declared_member_location(var, i, false);
					if (var.storage == StorageClassInput)
						input_variables.push_back({ &var, location, i, true });
					else
						output_variables.push_back({ &var, location, i, true });
				}
			}
			else
			{
				uint32_t location = get_decoration(var.self, DecorationLocation);
				if (var.storage == StorageClassInput)
					input_variables.push_back({ &var, location, 0, false });
				else
					output_variables.push_back({ &var, location, 0, false });
			}
		}
	});
1736
1737 const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
1738		// Sort input and output variables by the following criteria, from most robust to least robust:
1739 // - Location
1740 // - Variable has a location
1741 // - Name comparison
1742 // - Variable has a name
1743 // - Fallback: ID
1744 bool has_location_a = a.block || has_decoration(id: a.var->self, decoration: DecorationLocation);
1745 bool has_location_b = b.block || has_decoration(id: b.var->self, decoration: DecorationLocation);
1746
1747 if (has_location_a && has_location_b)
1748 return a.location < b.location;
1749 else if (has_location_a && !has_location_b)
1750 return true;
1751 else if (!has_location_a && has_location_b)
1752 return false;
1753
1754 const auto &name1 = to_name(id: a.var->self);
1755 const auto &name2 = to_name(id: b.var->self);
1756
1757 if (name1.empty() && name2.empty())
1758 return a.var->self < b.var->self;
1759 else if (name1.empty())
1760 return true;
1761 else if (name2.empty())
1762 return false;
1763
1764 return name1.compare(str: name2) < 0;
1765 };
1766
1767 auto input_builtins = active_input_builtins;
1768 input_builtins.clear(bit: BuiltInNumWorkgroups);
1769 input_builtins.clear(bit: BuiltInPointCoord);
1770 input_builtins.clear(bit: BuiltInSubgroupSize);
1771 input_builtins.clear(bit: BuiltInSubgroupLocalInvocationId);
1772 input_builtins.clear(bit: BuiltInSubgroupEqMask);
1773 input_builtins.clear(bit: BuiltInSubgroupLtMask);
1774 input_builtins.clear(bit: BuiltInSubgroupLeMask);
1775 input_builtins.clear(bit: BuiltInSubgroupGtMask);
1776 input_builtins.clear(bit: BuiltInSubgroupGeMask);
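	// None of these builtins are read from the stage input struct: they are either unsupported in HLSL
	// or synthesized in the generated entry point (e.g. the subgroup masks are reconstructed from
	// WaveGetLaneIndex()), so they must not force a SPIRV_Cross_Input declaration.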
1777
1778 if (!input_variables.empty() || !input_builtins.empty())
1779 {
1780 require_input = true;
1781 statement(ts: "struct SPIRV_Cross_Input");
1782
1783 begin_scope();
1784 sort(first: input_variables.begin(), last: input_variables.end(), comp: variable_compare);
1785 for (auto &var : input_variables)
1786 {
1787 if (var.block)
1788 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_inputs);
1789 else
1790 emit_interface_block_in_struct(var: *var.var, active_locations&: active_inputs);
1791 }
1792 emit_builtin_inputs_in_struct();
1793 end_scope_decl();
1794 statement(ts: "");
1795 }
1796
1797 const bool is_mesh_shader = execution.model == ExecutionModelMeshEXT;
1798 if (!output_variables.empty() || !active_output_builtins.empty())
1799 {
1800 sort(first: output_variables.begin(), last: output_variables.end(), comp: variable_compare);
1801 require_output = !is_mesh_shader;
1802
1803 statement(ts: is_mesh_shader ? "struct gl_MeshPerVertexEXT" : "struct SPIRV_Cross_Output");
1804 begin_scope();
1805 for (auto &var : output_variables)
1806 {
1807 if (is_per_primitive_variable(var: *var.var))
1808 continue;
1809 if (var.block && is_mesh_shader && var.block_member_index != 0)
1810 continue;
1811 if (var.block && !is_mesh_shader)
1812 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_outputs);
1813 else
1814 emit_interface_block_in_struct(var: *var.var, active_locations&: active_outputs);
1815 }
1816 emit_builtin_outputs_in_struct();
1817 if (!is_mesh_shader)
1818 emit_builtin_primitive_outputs_in_struct();
1819 end_scope_decl();
1820 statement(ts: "");
1821
1822 if (is_mesh_shader)
1823 {
1824 statement(ts: "struct gl_MeshPerPrimitiveEXT");
1825 begin_scope();
1826 for (auto &var : output_variables)
1827 {
1828 if (!is_per_primitive_variable(var: *var.var))
1829 continue;
1830 if (var.block && var.block_member_index != 0)
1831 continue;
1832
1833 emit_interface_block_in_struct(var: *var.var, active_locations&: active_outputs);
1834 }
1835 emit_builtin_primitive_outputs_in_struct();
1836 end_scope_decl();
1837 statement(ts: "");
1838 }
1839 }
1840
1841 // Global variables.
1842 for (auto global : global_variables)
1843 {
1844 auto &var = get<SPIRVariable>(id: global);
1845 if (is_hidden_variable(var, include_builtins: true))
1846 continue;
1847
1848 if (var.storage == StorageClassTaskPayloadWorkgroupEXT && is_mesh_shader)
1849 continue;
1850
1851 if (var.storage != StorageClassOutput)
1852 {
1853 if (!variable_is_lut(var))
1854 {
1855 add_resource_name(id: var.self);
1856
1857 const char *storage = nullptr;
1858 switch (var.storage)
1859 {
1860 case StorageClassWorkgroup:
1861 case StorageClassTaskPayloadWorkgroupEXT:
1862 storage = "groupshared";
1863 break;
1864
1865 default:
1866 storage = "static";
1867 break;
1868 }
1869
1870 string initializer;
1871 if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1872 !var.initializer && !var.static_expression && type_can_zero_initialize(type: get_variable_data_type(var)))
1873 {
1874 initializer = join(ts: " = ", ts: to_zero_initialized_expression(type_id: get_variable_data_type_id(var)));
1875 }
1876 statement(ts&: storage, ts: " ", ts: variable_decl(variable: var), ts&: initializer, ts: ";");
1877
1878 emitted = true;
1879 }
1880 }
1881 }
1882
1883 if (emitted)
1884 statement(ts: "");
1885
1886 if (requires_op_fmod)
1887 {
1888 static const char *types[] = {
1889 "float",
1890 "float2",
1891 "float3",
1892 "float4",
1893 };
1894
1895 for (auto &type : types)
1896 {
1897 statement(ts&: type, ts: " mod(", ts&: type, ts: " x, ", ts&: type, ts: " y)");
1898 begin_scope();
1899 statement(ts: "return x - y * floor(x / y);");
1900 end_scope();
1901 statement(ts: "");
1902 }
1903 }
1904
1905 emit_texture_size_variants(variant_mask: required_texture_size_variants.srv, vecsize_qualifier: "4", uav: false, type_qualifier: "");
1906 for (uint32_t norm = 0; norm < 3; norm++)
1907 {
1908 for (uint32_t comp = 0; comp < 4; comp++)
1909 {
1910 static const char *qualifiers[] = { "", "unorm ", "snorm " };
1911 static const char *vecsizes[] = { "", "2", "3", "4" };
1912 emit_texture_size_variants(variant_mask: required_texture_size_variants.uav[norm][comp], vecsize_qualifier: vecsizes[comp], uav: true,
1913 type_qualifier: qualifiers[norm]);
1914 }
1915 }
1916
1917 if (requires_fp16_packing)
1918 {
1919 // HLSL does not pack into a single word sadly :(
1920 statement(ts: "uint spvPackHalf2x16(float2 value)");
1921 begin_scope();
1922 statement(ts: "uint2 Packed = f32tof16(value);");
1923 statement(ts: "return Packed.x | (Packed.y << 16);");
1924 end_scope();
1925 statement(ts: "");
1926
1927 statement(ts: "float2 spvUnpackHalf2x16(uint value)");
1928 begin_scope();
1929 statement(ts: "return f16tof32(uint2(value & 0xffff, value >> 16));");
1930 end_scope();
1931 statement(ts: "");
1932 }
1933
1934 if (requires_uint2_packing)
1935 {
1936 statement(ts: "uint64_t spvPackUint2x32(uint2 value)");
1937 begin_scope();
1938 statement(ts: "return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1939 end_scope();
1940 statement(ts: "");
1941
1942 statement(ts: "uint2 spvUnpackUint2x32(uint64_t value)");
1943 begin_scope();
1944 statement(ts: "uint2 Unpacked;");
1945 statement(ts: "Unpacked.x = uint(value & 0xffffffff);");
1946 statement(ts: "Unpacked.y = uint(value >> 32);");
1947 statement(ts: "return Unpacked;");
1948 end_scope();
1949 statement(ts: "");
1950 }
1951
1952 if (requires_explicit_fp16_packing)
1953 {
1954 // HLSL does not pack into a single word sadly :(
1955 statement(ts: "uint spvPackFloat2x16(min16float2 value)");
1956 begin_scope();
1957 statement(ts: "uint2 Packed = f32tof16(value);");
1958 statement(ts: "return Packed.x | (Packed.y << 16);");
1959 end_scope();
1960 statement(ts: "");
1961
1962 statement(ts: "min16float2 spvUnpackFloat2x16(uint value)");
1963 begin_scope();
1964 statement(ts: "return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1965 end_scope();
1966 statement(ts: "");
1967 }
1968
1969	// HLSL does not seem to have builtins for these operations, so roll them by hand ...
1970 if (requires_unorm8_packing)
1971 {
1972 statement(ts: "uint spvPackUnorm4x8(float4 value)");
1973 begin_scope();
1974 statement(ts: "uint4 Packed = uint4(round(saturate(value) * 255.0));");
1975 statement(ts: "return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1976 end_scope();
1977 statement(ts: "");
1978
1979 statement(ts: "float4 spvUnpackUnorm4x8(uint value)");
1980 begin_scope();
1981 statement(ts: "uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1982 statement(ts: "return float4(Packed) / 255.0;");
1983 end_scope();
1984 statement(ts: "");
1985 }
1986
1987 if (requires_snorm8_packing)
1988 {
1989 statement(ts: "uint spvPackSnorm4x8(float4 value)");
1990 begin_scope();
1991 statement(ts: "int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1992 statement(ts: "return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1993 end_scope();
1994 statement(ts: "");
1995
1996 statement(ts: "float4 spvUnpackSnorm4x8(uint value)");
1997 begin_scope();
1998 statement(ts: "int SignedValue = int(value);");
1999 statement(ts: "int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
2000 statement(ts: "return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
2001 end_scope();
2002 statement(ts: "");
2003 }
2004
2005 if (requires_unorm16_packing)
2006 {
2007 statement(ts: "uint spvPackUnorm2x16(float2 value)");
2008 begin_scope();
2009 statement(ts: "uint2 Packed = uint2(round(saturate(value) * 65535.0));");
2010 statement(ts: "return Packed.x | (Packed.y << 16);");
2011 end_scope();
2012 statement(ts: "");
2013
2014 statement(ts: "float2 spvUnpackUnorm2x16(uint value)");
2015 begin_scope();
2016 statement(ts: "uint2 Packed = uint2(value & 0xffff, value >> 16);");
2017 statement(ts: "return float2(Packed) / 65535.0;");
2018 end_scope();
2019 statement(ts: "");
2020 }
2021
2022 if (requires_snorm16_packing)
2023 {
2024 statement(ts: "uint spvPackSnorm2x16(float2 value)");
2025 begin_scope();
2026 statement(ts: "int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
2027 statement(ts: "return uint(Packed.x | (Packed.y << 16));");
2028 end_scope();
2029 statement(ts: "");
2030
2031 statement(ts: "float2 spvUnpackSnorm2x16(uint value)");
2032 begin_scope();
2033 statement(ts: "int SignedValue = int(value);");
2034 statement(ts: "int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
2035 statement(ts: "return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
2036 end_scope();
2037 statement(ts: "");
2038 }
2039
2040 if (requires_bitfield_insert)
2041 {
2042 static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
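		// Note: the emitted helper special-cases Count == 32 because a 32-bit shift by 32 does not
		// produce the intended all-ones mask in HLSL (only the low bits of the shift count are used).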
2043 for (auto &type : types)
2044 {
2045 statement(ts&: type, ts: " spvBitfieldInsert(", ts&: type, ts: " Base, ", ts&: type, ts: " Insert, uint Offset, uint Count)");
2046 begin_scope();
2047 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
2048 statement(ts: "return (Base & ~Mask) | ((Insert << Offset) & Mask);");
2049 end_scope();
2050 statement(ts: "");
2051 }
2052 }
2053
2054 if (requires_bitfield_extract)
2055 {
2056 static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
2057 for (auto &type : unsigned_types)
2058 {
2059 statement(ts&: type, ts: " spvBitfieldUExtract(", ts&: type, ts: " Base, uint Offset, uint Count)");
2060 begin_scope();
2061 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
2062 statement(ts: "return (Base >> Offset) & Mask;");
2063 end_scope();
2064 statement(ts: "");
2065 }
2066
2067 // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
2068 static const char *signed_types[] = { "int", "int2", "int3", "int4" };
2069 for (auto &type : signed_types)
2070 {
2071 statement(ts&: type, ts: " spvBitfieldSExtract(", ts&: type, ts: " Base, int Offset, int Count)");
2072 begin_scope();
2073 statement(ts: "int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
2074 statement(ts&: type, ts: " Masked = (Base >> Offset) & Mask;");
2075 statement(ts: "int ExtendShift = (32 - Count) & 31;");
2076 statement(ts: "return (Masked << ExtendShift) >> ExtendShift;");
2077 end_scope();
2078 statement(ts: "");
2079 }
2080 }
2081
2082 if (requires_inverse_2x2)
2083 {
2084		statement(ts: "// Returns the inverse of a matrix, using the algorithm of calculating the classical");
2085		statement(ts: "// adjoint and dividing by the determinant. The input matrix is left unmodified.");
2086 statement(ts: "float2x2 spvInverse(float2x2 m)");
2087 begin_scope();
2088 statement(ts: "float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
2089 statement_no_indent(ts: "");
2090 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2091 statement(ts: "adj[0][0] = m[1][1];");
2092 statement(ts: "adj[0][1] = -m[0][1];");
2093 statement_no_indent(ts: "");
2094 statement(ts: "adj[1][0] = -m[1][0];");
2095 statement(ts: "adj[1][1] = m[0][0];");
2096 statement_no_indent(ts: "");
2097 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2098 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
2099 statement_no_indent(ts: "");
2100 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2101		statement(ts: "// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
2102 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2103 end_scope();
2104 statement(ts: "");
2105 }
2106
2107 if (requires_inverse_3x3)
2108 {
2109 statement(ts: "// Returns the determinant of a 2x2 matrix.");
2110 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
2111 begin_scope();
2112 statement(ts: "return a1 * b2 - b1 * a2;");
2113 end_scope();
2114 statement_no_indent(ts: "");
2115		statement(ts: "// Returns the inverse of a matrix, using the algorithm of calculating the classical");
2116		statement(ts: "// adjoint and dividing by the determinant. The input matrix is left unmodified.");
2117 statement(ts: "float3x3 spvInverse(float3x3 m)");
2118 begin_scope();
2119 statement(ts: "float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
2120 statement_no_indent(ts: "");
2121 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2122 statement(ts: "adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
2123 statement(ts: "adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
2124 statement(ts: "adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
2125 statement_no_indent(ts: "");
2126 statement(ts: "adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
2127 statement(ts: "adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
2128 statement(ts: "adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
2129 statement_no_indent(ts: "");
2130 statement(ts: "adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
2131 statement(ts: "adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
2132 statement(ts: "adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
2133 statement_no_indent(ts: "");
2134 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2135 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
2136 statement_no_indent(ts: "");
2137 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2138		statement(ts: "// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
2139 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2140 end_scope();
2141 statement(ts: "");
2142 }
2143
2144 if (requires_inverse_4x4)
2145 {
2146 if (!requires_inverse_3x3)
2147 {
2148 statement(ts: "// Returns the determinant of a 2x2 matrix.");
2149 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
2150 begin_scope();
2151 statement(ts: "return a1 * b2 - b1 * a2;");
2152 end_scope();
2153 statement(ts: "");
2154 }
2155
2156 statement(ts: "// Returns the determinant of a 3x3 matrix.");
2157 statement(ts: "float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
2158 "float c2, float c3)");
2159 begin_scope();
2160 statement(ts: "return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
2161 "spvDet2x2(a2, a3, "
2162 "b2, b3);");
2163 end_scope();
2164 statement_no_indent(ts: "");
2165		statement(ts: "// Returns the inverse of a matrix, using the algorithm of calculating the classical");
2166		statement(ts: "// adjoint and dividing by the determinant. The input matrix is left unmodified.");
2167 statement(ts: "float4x4 spvInverse(float4x4 m)");
2168 begin_scope();
2169 statement(ts: "float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
2170 statement_no_indent(ts: "");
2171 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2172 statement(
2173 ts: "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2174 "m[3][3]);");
2175 statement(
2176 ts: "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2177 "m[3][3]);");
2178 statement(
2179 ts: "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
2180 "m[3][3]);");
2181 statement(
2182 ts: "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
2183 "m[2][3]);");
2184 statement_no_indent(ts: "");
2185 statement(
2186 ts: "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2187 "m[3][3]);");
2188 statement(
2189 ts: "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2190 "m[3][3]);");
2191 statement(
2192 ts: "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
2193 "m[3][3]);");
2194 statement(
2195 ts: "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
2196 "m[2][3]);");
2197 statement_no_indent(ts: "");
2198 statement(
2199 ts: "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2200 "m[3][3]);");
2201 statement(
2202 ts: "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2203 "m[3][3]);");
2204 statement(
2205 ts: "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
2206 "m[3][3]);");
2207 statement(
2208 ts: "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
2209 "m[2][3]);");
2210 statement_no_indent(ts: "");
2211 statement(
2212 ts: "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2213 "m[3][2]);");
2214 statement(
2215 ts: "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2216 "m[3][2]);");
2217 statement(
2218 ts: "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
2219 "m[3][2]);");
2220 statement(
2221 ts: "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
2222 "m[2][2]);");
2223 statement_no_indent(ts: "");
2224 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
2225 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
2226 "* m[3][0]);");
2227 statement_no_indent(ts: "");
2228 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
2229		statement(ts: "// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
2230 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2231 end_scope();
2232 statement(ts: "");
2233 }
2234
2235 if (requires_scalar_reflect)
2236 {
2237 // FP16/FP64? No templates in HLSL.
2238 statement(ts: "float spvReflect(float i, float n)");
2239 begin_scope();
2240 statement(ts: "return i - 2.0 * dot(n, i) * n;");
2241 end_scope();
2242 statement(ts: "");
2243 }
2244
2245 if (requires_scalar_refract)
2246 {
2247 // FP16/FP64? No templates in HLSL.
2248 statement(ts: "float spvRefract(float i, float n, float eta)");
2249 begin_scope();
2250 statement(ts: "float NoI = n * i;");
2251 statement(ts: "float NoI2 = NoI * NoI;");
2252 statement(ts: "float k = 1.0 - eta * eta * (1.0 - NoI2);");
2253 statement(ts: "if (k < 0.0)");
2254 begin_scope();
2255 statement(ts: "return 0.0;");
2256 end_scope();
2257 statement(ts: "else");
2258 begin_scope();
2259 statement(ts: "return eta * i - (eta * NoI + sqrt(k)) * n;");
2260 end_scope();
2261 end_scope();
2262 statement(ts: "");
2263 }
2264
2265 if (requires_scalar_faceforward)
2266 {
2267 // FP16/FP64? No templates in HLSL.
2268 statement(ts: "float spvFaceForward(float n, float i, float nref)");
2269 begin_scope();
2270 statement(ts: "return i * nref < 0.0 ? n : -n;");
2271 end_scope();
2272 statement(ts: "");
2273 }
2274
2275 for (TypeID type_id : composite_selection_workaround_types)
2276 {
2277 // Need out variable since HLSL does not support returning arrays.
2278 auto &type = get<SPIRType>(id: type_id);
2279 auto type_str = type_to_glsl(type);
2280 auto type_arr_str = type_to_array_glsl(type, variable_id: 0);
2281 statement(ts: "void spvSelectComposite(out ", ts&: type_str, ts: " out_value", ts&: type_arr_str, ts: ", bool cond, ",
2282 ts&: type_str, ts: " true_val", ts&: type_arr_str, ts: ", ",
2283 ts&: type_str, ts: " false_val", ts&: type_arr_str, ts: ")");
2284 begin_scope();
2285 statement(ts: "if (cond)");
2286 begin_scope();
2287 statement(ts: "out_value = true_val;");
2288 end_scope();
2289 statement(ts: "else");
2290 begin_scope();
2291 statement(ts: "out_value = false_val;");
2292 end_scope();
2293 end_scope();
2294 statement(ts: "");
2295 }
2296
2297 if (is_mesh_shader && options.vertex.flip_vert_y)
2298 {
2299 statement(ts: "float4 spvFlipVertY(float4 v)");
2300 begin_scope();
2301 statement(ts: "return float4(v.x, -v.y, v.z, v.w);");
2302 end_scope();
2303 statement(ts: "");
2304 statement(ts: "float spvFlipVertY(float v)");
2305 begin_scope();
2306 statement(ts: "return -v;");
2307 end_scope();
2308 statement(ts: "");
2309 }
2310}
2311
2312void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
2313 const char *type_qualifier)
2314{
2315 if (variant_mask == 0)
2316 return;
2317
2318 static const char *types[QueryTypeCount] = { "float", "int", "uint" };
2319 static const char *dims[QueryDimCount] = { "Texture1D", "Texture1DArray", "Texture2D", "Texture2DArray",
2320 "Texture3D", "Buffer", "TextureCube", "TextureCubeArray",
2321 "Texture2DMS", "Texture2DMSArray" };
2322
2323 static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
2324
2325 static const char *ret_types[QueryDimCount] = {
2326 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
2327 };
2328
2329 static const uint32_t return_arguments[QueryDimCount] = {
2330 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
2331 };
2332
2333 for (uint32_t index = 0; index < QueryDimCount; index++)
2334 {
2335 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
2336 {
2337 uint32_t bit = 16 * type_index + index;
2338 uint64_t mask = 1ull << bit;
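			// The variant mask packs one bit per (dimension, component type) pair: bits [0, QueryDimCount)
			// cover float results, [16, 16 + QueryDimCount) int, and [32, 32 + QueryDimCount) uint.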
2339
2340 if ((variant_mask & mask) == 0)
2341 continue;
2342
2343 statement(ts&: ret_types[index], ts: " spv", ts: (uav ? "Image" : "Texture"), ts: "Size(", ts: (uav ? "RW" : ""),
2344 ts&: dims[index], ts: "<", ts&: type_qualifier, ts&: types[type_index], ts&: vecsize_qualifier, ts: "> Tex, ",
2345 ts: (uav ? "" : "uint Level, "), ts: "out uint Param)");
2346 begin_scope();
2347 statement(ts&: ret_types[index], ts: " ret;");
2348 switch (return_arguments[index])
2349 {
2350 case 1:
2351 if (has_lod[index] && !uav)
2352 statement(ts: "Tex.GetDimensions(Level, ret.x, Param);");
2353 else
2354 {
2355 statement(ts: "Tex.GetDimensions(ret.x);");
2356 statement(ts: "Param = 0u;");
2357 }
2358 break;
2359 case 2:
2360 if (has_lod[index] && !uav)
2361 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2362 else if (!uav)
2363 statement(ts: "Tex.GetDimensions(ret.x, ret.y, Param);");
2364 else
2365 {
2366 statement(ts: "Tex.GetDimensions(ret.x, ret.y);");
2367 statement(ts: "Param = 0u;");
2368 }
2369 break;
2370 case 3:
2371 if (has_lod[index] && !uav)
2372 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2373 else if (!uav)
2374 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2375 else
2376 {
2377 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z);");
2378 statement(ts: "Param = 0u;");
2379 }
2380 break;
2381 }
2382
2383 statement(ts: "return ret;");
2384 end_scope();
2385 statement(ts: "");
2386 }
2387 }
2388}
2389
2390void CompilerHLSL::analyze_meshlet_writes()
2391{
2392 uint32_t id_per_vertex = 0;
2393 uint32_t id_per_primitive = 0;
2394 bool need_per_primitive = false;
2395 bool need_per_vertex = false;
2396
2397 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
2398 auto &type = this->get<SPIRType>(id: var.basetype);
2399 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
2400 if (var.storage == StorageClassOutput && block && is_builtin_variable(var))
2401 {
2402 auto flags = get_buffer_block_flags(id: var.self);
2403 if (flags.get(bit: DecorationPerPrimitiveEXT))
2404 id_per_primitive = var.self;
2405 else
2406 id_per_vertex = var.self;
2407 }
2408 else if (var.storage == StorageClassOutput)
2409 {
2410 Bitset flags;
2411 if (block)
2412 flags = get_buffer_block_flags(id: var.self);
2413 else
2414 flags = get_decoration_bitset(id: var.self);
2415
2416 if (flags.get(bit: DecorationPerPrimitiveEXT))
2417 need_per_primitive = true;
2418 else
2419 need_per_vertex = true;
2420 }
2421 });
2422
2423	// If we have per-primitive outputs but no per-primitive builtins,
2424	// an empty version of gl_MeshPerPrimitiveEXT will be emitted.
2425 // If we don't use block IO for vertex output, we'll also need to synthesize the PerVertex block.
2426
2427 const auto generate_block = [&](const char *block_name, const char *instance_name, bool per_primitive) -> uint32_t {
2428 auto &execution = get_entry_point();
2429
2430 uint32_t op_type = ir.increase_bound_by(count: 4);
2431 uint32_t op_arr = op_type + 1;
2432 uint32_t op_ptr = op_type + 2;
2433 uint32_t op_var = op_type + 3;
2434
2435 auto &type = set<SPIRType>(id: op_type, args: OpTypeStruct);
2436 type.basetype = SPIRType::Struct;
2437 set_name(id: op_type, name: block_name);
2438 set_decoration(id: op_type, decoration: DecorationBlock);
2439 if (per_primitive)
2440 set_decoration(id: op_type, decoration: DecorationPerPrimitiveEXT);
2441
2442 auto &arr = set<SPIRType>(id: op_arr, args&: type);
2443 arr.parent_type = type.self;
2444 arr.array.push_back(t: per_primitive ? execution.output_primitives : execution.output_vertices);
2445 arr.array_size_literal.push_back(t: true);
2446
2447 auto &ptr = set<SPIRType>(id: op_ptr, args&: arr);
2448 ptr.parent_type = arr.self;
2449 ptr.pointer = true;
2450 ptr.pointer_depth++;
2451 ptr.storage = StorageClassOutput;
2452 set_decoration(id: op_ptr, decoration: DecorationBlock);
2453 set_name(id: op_ptr, name: block_name);
2454
2455 auto &var = set<SPIRVariable>(id: op_var, args&: op_ptr, args: StorageClassOutput);
2456 if (per_primitive)
2457 set_decoration(id: op_var, decoration: DecorationPerPrimitiveEXT);
2458 set_name(id: op_var, name: instance_name);
2459 execution.interface_variables.push_back(t: var.self);
2460
2461 return op_var;
2462 };
2463
2464 if (id_per_vertex == 0 && need_per_vertex)
2465 id_per_vertex = generate_block("gl_MeshPerVertexEXT", "gl_MeshVerticesEXT", false);
2466 if (id_per_primitive == 0 && need_per_primitive)
2467 id_per_primitive = generate_block("gl_MeshPerPrimitiveEXT", "gl_MeshPrimitivesEXT", true);
2468
2469 unordered_set<uint32_t> processed_func_ids;
2470 analyze_meshlet_writes(func_id: ir.default_entry_point, id_per_vertex, id_per_primitive, processed_func_ids);
2471}
2472
2473void CompilerHLSL::analyze_meshlet_writes(uint32_t func_id, uint32_t id_per_vertex, uint32_t id_per_primitive,
2474 std::unordered_set<uint32_t> &processed_func_ids)
2475{
2476 // Avoid processing a function more than once
2477 if (processed_func_ids.find(x: func_id) != processed_func_ids.end())
2478 return;
2479 processed_func_ids.insert(x: func_id);
2480
2481 auto &func = get<SPIRFunction>(id: func_id);
2482 // Recursively establish global args added to functions on which we depend.
2483 for (auto& block : func.blocks)
2484 {
2485 auto &b = get<SPIRBlock>(id: block);
2486 for (auto &i : b.ops)
2487 {
2488 auto ops = stream(instr: i);
2489 auto op = static_cast<Op>(i.op);
2490
2491 switch (op)
2492 {
2493 case OpFunctionCall:
2494 {
2495 // Then recurse into the function itself to extract globals used internally in the function
2496 uint32_t inner_func_id = ops[2];
2497 analyze_meshlet_writes(func_id: inner_func_id, id_per_vertex, id_per_primitive, processed_func_ids);
2498 auto &inner_func = get<SPIRFunction>(id: inner_func_id);
2499 for (auto &iarg : inner_func.arguments)
2500 {
2501 if (!iarg.alias_global_variable)
2502 continue;
2503
2504 bool already_declared = false;
2505 for (auto &arg : func.arguments)
2506 {
2507 if (arg.id == iarg.id)
2508 {
2509 already_declared = true;
2510 break;
2511 }
2512 }
2513
2514 if (!already_declared)
2515 {
2516 // basetype is effectively ignored here since we declare the argument
2517 // with explicit types. Just pass down a valid type.
2518 func.arguments.push_back(t: { .type: expression_type_id(id: iarg.id), .id: iarg.id,
2519 .read_count: iarg.read_count, .write_count: iarg.write_count, .alias_global_variable: true });
2520 }
2521 }
2522 break;
2523 }
2524
2525 case OpStore:
2526 case OpLoad:
2527 case OpInBoundsAccessChain:
2528 case OpAccessChain:
2529 case OpPtrAccessChain:
2530 case OpInBoundsPtrAccessChain:
2531 case OpArrayLength:
2532 {
2533 auto *var = maybe_get<SPIRVariable>(id: ops[op == OpStore ? 0 : 2]);
2534 if (var && (var->storage == StorageClassOutput || var->storage == StorageClassTaskPayloadWorkgroupEXT))
2535 {
2536 bool already_declared = false;
2537 auto builtin_type = BuiltIn(get_decoration(id: var->self, decoration: DecorationBuiltIn));
2538
2539 uint32_t var_id = var->self;
2540 if (var->storage != StorageClassTaskPayloadWorkgroupEXT &&
2541 builtin_type != BuiltInPrimitivePointIndicesEXT &&
2542 builtin_type != BuiltInPrimitiveLineIndicesEXT &&
2543 builtin_type != BuiltInPrimitiveTriangleIndicesEXT)
2544 {
2545 var_id = is_per_primitive_variable(var: *var) ? id_per_primitive : id_per_vertex;
2546 }
2547
2548 for (auto &arg : func.arguments)
2549 {
2550 if (arg.id == var_id)
2551 {
2552 already_declared = true;
2553 break;
2554 }
2555 }
2556
2557 if (!already_declared)
2558 {
2559 // basetype is effectively ignored here since we declare the argument
2560 // with explicit types. Just pass down a valid type.
2561 uint32_t type_id = expression_type_id(id: var_id);
2562 if (var->storage == StorageClassTaskPayloadWorkgroupEXT)
2563 func.arguments.push_back(t: { .type: type_id, .id: var_id, .read_count: 1u, .write_count: 0u, .alias_global_variable: true });
2564 else
2565 func.arguments.push_back(t: { .type: type_id, .id: var_id, .read_count: 1u, .write_count: 1u, .alias_global_variable: true });
2566 }
2567 }
2568 break;
2569 }
2570
2571 default:
2572 break;
2573 }
2574 }
2575 }
2576}
2577
2578string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2579{
2580 auto &flags = get_member_decoration_bitset(id: type.self, index);
2581
2582 // HLSL can emit row_major or column_major decoration in any struct.
2583 // Do not try to merge combined decorations for children like in GLSL.
2584
2585 // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2586 // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2587 if (flags.get(bit: DecorationColMajor))
2588 return "row_major ";
2589 else if (flags.get(bit: DecorationRowMajor))
2590 return "column_major ";
2591
2592 return "";
2593}
2594
2595void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2596 const string &qualifier, uint32_t base_offset)
2597{
2598 auto &membertype = get<SPIRType>(id: member_type_id);
2599
2600 Bitset memberflags;
2601 auto &memb = ir.meta[type.self].members;
2602 if (index < memb.size())
2603 memberflags = memb[index].decoration_flags;
2604
2605 string packing_offset;
2606 bool is_push_constant = type.storage == StorageClassPushConstant;
2607
2608 if ((has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2609 has_member_decoration(id: type.self, index, decoration: DecorationOffset))
2610 {
2611 uint32_t offset = memb[index].offset - base_offset;
2612 if (offset & 3)
2613 SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2614
2615 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
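		// Each 'c' register holds 16 bytes and each component 4 bytes, so e.g. a member at
		// byte offset 20 is emitted as ": packoffset(c1.y)".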
2616 packing_offset = join(ts: " : packoffset(c", ts: offset / 16, ts&: packing_swizzle[(offset & 15) >> 2], ts: ")");
2617 }
2618
2619 statement(ts: layout_for_member(type, index), ts: qualifier,
2620 ts: variable_decl(type: membertype, name: to_member_name(type, index)), ts&: packing_offset, ts: ";");
2621}
2622
2623void CompilerHLSL::emit_rayquery_function(const char *commited, const char *candidate, const uint32_t *ops)
2624{
2625 flush_variable_declaration(id: ops[0]);
2626 uint32_t is_commited = evaluate_constant_u32(id: ops[3]);
2627 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts&: is_commited ? commited : candidate), forward_rhs: false);
2628}
2629
2630void CompilerHLSL::emit_mesh_tasks(SPIRBlock &block)
2631{
2632 if (block.mesh.payload != 0)
2633 {
2634 statement(ts: "DispatchMesh(", ts: to_unpacked_expression(id: block.mesh.groups[0]), ts: ", ", ts: to_unpacked_expression(id: block.mesh.groups[1]), ts: ", ",
2635 ts: to_unpacked_expression(id: block.mesh.groups[2]), ts: ", ", ts: to_unpacked_expression(id: block.mesh.payload), ts: ");");
2636 }
2637 else
2638 {
2639 SPIRV_CROSS_THROW("Amplification shader in HLSL must have payload");
2640 }
2641}
2642
2643void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2644{
2645 auto &type = get<SPIRType>(id: var.basetype);
2646
2647 bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock);
2648
2649 if (flattened_buffer_blocks.count(x: var.self))
2650 {
2651 emit_buffer_block_flattened(type: var);
2652 }
2653 else if (is_uav)
2654 {
2655 Bitset flags = ir.get_buffer_block_flags(var);
2656 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
2657 bool is_coherent = flags.get(bit: DecorationCoherent) && !is_readonly;
2658 bool is_interlocked = interlocked_resources.count(x: var.self) > 0;
2659
2660 auto to_structuredbuffer_subtype_name = [this](const SPIRType &parent_type) -> std::string
2661 {
2662 if (parent_type.basetype == SPIRType::Struct && parent_type.member_types.size() == 1)
2663 {
2664				// Use the type of the first struct member, since a StructuredBuffer will have only one '._m0' field in SPIR-V.
2665 const auto &member0_type = this->get<SPIRType>(id: parent_type.member_types.front());
2666 return this->type_to_glsl(type: member0_type);
2667 }
2668 else
2669 {
2670 // Otherwise, this StructuredBuffer only has a basic subtype, e.g. StructuredBuffer<int>
2671 return this->type_to_glsl(type: parent_type);
2672 }
2673 };
2674
2675 std::string type_name;
2676 if (is_user_type_structured(id: var.self))
2677 type_name = join(ts: is_readonly ? "" : is_interlocked ? "RasterizerOrdered" : "RW", ts: "StructuredBuffer<", ts: to_structuredbuffer_subtype_name(type), ts: ">");
2678 else
2679 type_name = is_readonly ? "ByteAddressBuffer" : is_interlocked ? "RasterizerOrderedByteAddressBuffer" : "RWByteAddressBuffer";
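		// Illustrative result (resource name hypothetical): a writable block declared as a user structured
		// type becomes e.g. "RWStructuredBuffer<MyData> ssbo", while the default path falls back to
		// "RWByteAddressBuffer"/"ByteAddressBuffer", with the register binding appended below.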
2680
2681 add_resource_name(id: var.self);
2682 statement(ts: is_coherent ? "globallycoherent " : "", ts&: type_name, ts: " ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
2683 ts: to_resource_binding(var), ts: ";");
2684 }
2685 else
2686 {
2687 if (type.array.empty())
2688 {
2689			// Flatten the top-level struct so we can use packoffset;
2690			// this restriction is similar to GLSL, where layout(offset) is not possible on sub-structs.
2691 flattened_structs[var.self] = false;
2692
2693 // Prefer the block name if possible.
2694 auto buffer_name = to_name(id: type.self, allow_alias: false);
2695 if (ir.meta[type.self].decoration.alias.empty() ||
2696 resource_names.find(x: buffer_name) != end(cont&: resource_names) ||
2697 block_names.find(x: buffer_name) != end(cont&: block_names))
2698 {
2699 buffer_name = get_block_fallback_name(id: var.self);
2700 }
2701
2702 add_variable(variables_primary&: block_names, variables_secondary: resource_names, name&: buffer_name);
2703
2704 // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2705 // This cannot conflict with anything else, so we're safe now.
2706 if (buffer_name.empty())
2707 buffer_name = join(ts: "_", ts&: get<SPIRType>(id: var.basetype).self, ts: "_", ts: var.self);
2708
2709 uint32_t failed_index = 0;
2710 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index))
2711 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2712 else
2713 {
2714 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2715 failed_index, " (name: ", to_member_name(type, failed_index),
2716 ") cannot be expressed with either HLSL packing layout or packoffset."));
2717 }
2718
2719 block_names.insert(x: buffer_name);
2720
2721 // Save for post-reflection later.
2722 declared_block_names[var.self] = buffer_name;
2723
2724 type.member_name_cache.clear();
2725 // var.self can be used as a backup name for the block name,
2726 // so we need to make sure we don't disturb the name here on a recompile.
2727 // It will need to be reset if we have to recompile.
2728 preserve_alias_on_reset(id: var.self);
2729 add_resource_name(id: var.self);
2730 statement(ts: "cbuffer ", ts&: buffer_name, ts: to_resource_binding(var));
2731 begin_scope();
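			// Illustrative output (names hypothetical): members are flattened and prefixed with the block
			// instance name, e.g. "float4x4 ubo_mvp : packoffset(c0);" inside "cbuffer ubo".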
2732
2733 uint32_t i = 0;
2734 for (auto &member : type.member_types)
2735 {
2736 add_member_name(type, name: i);
2737 auto backup_name = get_member_name(id: type.self, index: i);
2738 auto member_name = to_member_name(type, index: i);
2739 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2740 ParsedIR::sanitize_underscores(str&: member_name);
2741 set_member_name(id: type.self, index: i, name: member_name);
2742 emit_struct_member(type, member_type_id: member, index: i, qualifier: "");
2743 set_member_name(id: type.self, index: i, name: backup_name);
2744 i++;
2745 }
2746
2747 end_scope_decl();
2748 statement(ts: "");
2749 }
2750 else
2751 {
2752 if (hlsl_options.shader_model < 51)
2753 SPIRV_CROSS_THROW(
2754 "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2755
2756 add_resource_name(id: type.self);
2757 add_resource_name(id: var.self);
2758
2759			// ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2760 uint32_t failed_index = 0;
2761 if (!buffer_is_packing_standard(type, packing: BufferPackingHLSLCbuffer, failed_index: &failed_index))
2762 {
2763 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2764 "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2765 ") cannot be expressed with normal HLSL packing rules."));
2766 }
2767
2768 emit_struct(type&: get<SPIRType>(id: type.self));
2769 statement(ts: "ConstantBuffer<", ts: to_name(id: type.self), ts: "> ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
2770 ts: to_resource_binding(var), ts: ";");
2771 }
2772 }
2773}
2774
2775void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2776{
2777 if (flattened_buffer_blocks.count(x: var.self))
2778 {
2779 emit_buffer_block_flattened(type: var);
2780 }
2781 else if (root_constants_layout.empty())
2782 {
2783 emit_buffer_block(var);
2784 }
2785 else
2786 {
2787 for (const auto &layout : root_constants_layout)
2788 {
2789 auto &type = get<SPIRType>(id: var.basetype);
2790
2791 uint32_t failed_index = 0;
2792 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index, start_offset: layout.start,
2793 end_offset: layout.end))
2794 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2795 else
2796 {
2797 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2798 ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2799 ") cannot be expressed with either HLSL packing layout or packoffset."));
2800 }
2801
2802 flattened_structs[var.self] = false;
2803 type.member_name_cache.clear();
2804 add_resource_name(id: var.self);
2805 auto &memb = ir.meta[type.self].members;
2806
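			// Illustrative (names hypothetical): this emits something like
			// "cbuffer SPIRV_CROSS_RootConstant_pc : register(b0)" (with a space qualifier where applicable),
			// followed by only the members that fall inside this layout's [start, end) byte range.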
2807 statement(ts: "cbuffer SPIRV_CROSS_RootConstant_", ts: to_name(id: var.self),
2808 ts: to_resource_register(flag: HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, space: 'b', binding: layout.binding, set: layout.space));
2809 begin_scope();
2810
2811			// Index of the next field in the generated root-constant constant buffer.
2812 auto constant_index = 0u;
2813
2814			// Iterate over all members of the push constant and check which of the fields
2815 // fit into the given root constant layout.
2816 for (auto i = 0u; i < memb.size(); i++)
2817 {
2818 const auto offset = memb[i].offset;
2819 if (layout.start <= offset && offset < layout.end)
2820 {
2821 const auto &member = type.member_types[i];
2822
2823 add_member_name(type, name: constant_index);
2824 auto backup_name = get_member_name(id: type.self, index: i);
2825 auto member_name = to_member_name(type, index: i);
2826 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2827 ParsedIR::sanitize_underscores(str&: member_name);
2828 set_member_name(id: type.self, index: constant_index, name: member_name);
2829 emit_struct_member(type, member_type_id: member, index: i, qualifier: "", base_offset: layout.start);
2830 set_member_name(id: type.self, index: constant_index, name: backup_name);
2831
2832 constant_index++;
2833 }
2834 }
2835
2836 end_scope_decl();
2837 }
2838 }
2839}
2840
2841string CompilerHLSL::to_sampler_expression(uint32_t id)
2842{
2843 auto expr = join(ts: "_", ts: to_non_uniform_aware_expression(id));
2844 auto index = expr.find_first_of(c: '[');
2845 if (index == string::npos)
2846 {
2847 return expr + "_sampler";
2848 }
2849 else
2850 {
2851 // We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
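		// Illustrative (name hypothetical): "_uTextures[3]" becomes "_uTextures_sampler[3]".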
2852 return expr.insert(pos: index, s: "_sampler");
2853 }
2854}
2855
2856void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2857{
2858 if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2859 {
2860 set<SPIRCombinedImageSampler>(id: result_id, args&: result_type, args&: image_id, args&: samp_id);
2861 }
2862 else
2863 {
2864 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2865 emit_op(result_type, result_id, rhs: to_combined_image_sampler(image_id, samp_id), forward_rhs: true, suppress_usage_tracking: true);
2866 }
2867}
2868
2869string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2870{
2871 string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2872
2873 if (hlsl_options.shader_model <= 30)
2874 return arg_str;
2875
2876 // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2877 auto &type = expression_type(id);
2878
2879 // We don't have to consider combined image samplers here via OpSampledImage because
2880 // those variables cannot be passed as arguments to functions.
2881 // Only global SampledImage variables may be used as arguments.
2882 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2883 arg_str += ", " + to_sampler_expression(id);
2884
2885 return arg_str;
2886}
2887
2888string CompilerHLSL::get_inner_entry_point_name() const
2889{
2890 auto &execution = get_entry_point();
2891
2892 if (hlsl_options.use_entry_point_name)
2893 {
2894 auto name = join(ts: execution.name, ts: "_inner");
2895 ParsedIR::sanitize_underscores(str&: name);
2896 return name;
2897 }
2898
2899 if (execution.model == ExecutionModelVertex)
2900 return "vert_main";
2901 else if (execution.model == ExecutionModelFragment)
2902 return "frag_main";
2903 else if (execution.model == ExecutionModelGLCompute)
2904 return "comp_main";
2905 else if (execution.model == ExecutionModelMeshEXT)
2906 return "mesh_main";
2907 else if (execution.model == ExecutionModelTaskEXT)
2908 return "task_main";
2909 else
2910 SPIRV_CROSS_THROW("Unsupported execution model.");
2911}
2912
2913void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2914{
2915 if (func.self != ir.default_entry_point)
2916 add_function_overload(func);
2917
2918 // Avoid shadow declarations.
2919 local_variable_names = resource_names;
2920
2921 string decl;
2922
2923 auto &type = get<SPIRType>(id: func.return_type);
2924 if (type.array.empty())
2925 {
2926 decl += flags_to_qualifiers_glsl(type, flags: return_flags);
2927 decl += type_to_glsl(type);
2928 decl += " ";
2929 }
2930 else
2931 {
2932 // We cannot return arrays in HLSL, so "return" through an out variable.
2933 decl = "void ";
2934 }
2935
2936 if (func.self == ir.default_entry_point)
2937 {
2938 decl += get_inner_entry_point_name();
2939 processing_entry_point = true;
2940 }
2941 else
2942 decl += to_name(id: func.self);
2943
2944 decl += "(";
2945 SmallVector<string> arglist;
2946
2947 if (!type.array.empty())
2948 {
2949 // Fake array returns by writing to an out array instead.
2950 string out_argument;
2951 out_argument += "out ";
2952 out_argument += type_to_glsl(type);
2953 out_argument += " ";
2954 out_argument += "spvReturnValue";
2955 out_argument += type_to_array_glsl(type, variable_id: 0);
2956 arglist.push_back(t: std::move(out_argument));
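		// Illustrative (signature hypothetical): a function returning float4[2] is declared as
		// "void fn(out float4 spvReturnValue[2], ...)" instead.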
2957 }
2958
2959 for (auto &arg : func.arguments)
2960 {
2961 // Do not pass in separate images or samplers if we're remapping
2962 // to combined image samplers.
2963 if (skip_argument(id: arg.id))
2964 continue;
2965
2966 // Might change the variable name if it already exists in this function.
2967		// SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2968		// to use the same name for multiple variables.
2969		// Since we want to make the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2970 add_local_variable_name(id: arg.id);
2971
2972 arglist.push_back(t: argument_decl(arg));
2973
2974 // Flatten a combined sampler to two separate arguments in modern HLSL.
2975 auto &arg_type = get<SPIRType>(id: arg.type);
2976 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2977 arg_type.image.dim != DimBuffer)
2978 {
2979 // Manufacture automatic sampler arg for SampledImage texture
2980 arglist.push_back(t: join(ts: is_depth_image(type: arg_type, id: arg.id) ? "SamplerComparisonState " : "SamplerState ",
2981 ts: to_sampler_expression(id: arg.id), ts: type_to_array_glsl(type: arg_type, variable_id: arg.id)));
2982 }
2983
2984 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2985 auto *var = maybe_get<SPIRVariable>(id: arg.id);
2986 if (var)
2987 var->parameter = &arg;
2988 }
2989
2990 for (auto &arg : func.shadow_arguments)
2991 {
2992 // Might change the variable name if it already exists in this function.
2993		// SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2994		// to use the same name for multiple variables.
2995		// Since we want to make the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2996 add_local_variable_name(id: arg.id);
2997
2998 arglist.push_back(t: argument_decl(arg));
2999
3000 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
3001 auto *var = maybe_get<SPIRVariable>(id: arg.id);
3002 if (var)
3003 var->parameter = &arg;
3004 }
3005
3006 decl += merge(list: arglist);
3007 decl += ")";
3008 statement(ts&: decl);
3009}
3010
3011void CompilerHLSL::emit_hlsl_entry_point()
3012{
3013 SmallVector<string> arguments;
3014
3015 if (require_input)
3016 arguments.push_back(t: "SPIRV_Cross_Input stage_input");
3017
3018 auto &execution = get_entry_point();
3019
3020 switch (execution.model)
3021 {
3022 case ExecutionModelTaskEXT:
3023 case ExecutionModelMeshEXT:
3024 case ExecutionModelGLCompute:
3025 {
3026 if (execution.model == ExecutionModelMeshEXT)
3027 {
3028 if (execution.flags.get(bit: ExecutionModeOutputTrianglesEXT))
3029 statement(ts: "[outputtopology(\"triangle\")]");
3030 else if (execution.flags.get(bit: ExecutionModeOutputLinesEXT))
3031 statement(ts: "[outputtopology(\"line\")]");
3032 else if (execution.flags.get(bit: ExecutionModeOutputPoints))
3033 SPIRV_CROSS_THROW("Topology mode \"points\" is not supported in DirectX");
3034
3035 auto &func = get<SPIRFunction>(id: ir.default_entry_point);
3036 for (auto &arg : func.arguments)
3037 {
3038 auto &var = get<SPIRVariable>(id: arg.id);
3039 auto &base_type = get<SPIRType>(id: var.basetype);
3040 bool block = has_decoration(id: base_type.self, decoration: DecorationBlock);
3041 if (var.storage == StorageClassTaskPayloadWorkgroupEXT)
3042 {
3043 arguments.push_back(t: "in payload " + variable_decl(variable: var));
3044 }
3045 else if (block)
3046 {
3047 auto flags = get_buffer_block_flags(id: var.self);
3048 if (flags.get(bit: DecorationPerPrimitiveEXT) || has_decoration(id: arg.id, decoration: DecorationPerPrimitiveEXT))
3049 {
3050 arguments.push_back(t: "out primitives gl_MeshPerPrimitiveEXT gl_MeshPrimitivesEXT[" +
3051 std::to_string(val: execution.output_primitives) + "]");
3052 }
3053 else
3054 {
3055 arguments.push_back(t: "out vertices gl_MeshPerVertexEXT gl_MeshVerticesEXT[" +
3056 std::to_string(val: execution.output_vertices) + "]");
3057 }
3058 }
3059 else
3060 {
3061 if (execution.flags.get(bit: ExecutionModeOutputTrianglesEXT))
3062 {
3063 arguments.push_back(t: "out indices uint3 gl_PrimitiveTriangleIndicesEXT[" +
3064 std::to_string(val: execution.output_primitives) + "]");
3065 }
3066 else
3067 {
3068 arguments.push_back(t: "out indices uint2 gl_PrimitiveLineIndicesEXT[" +
3069 std::to_string(val: execution.output_primitives) + "]");
3070 }
3071 }
3072 }
3073 }
3074 SpecializationConstant wg_x, wg_y, wg_z;
3075 get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
3076
3077 uint32_t x = execution.workgroup_size.x;
3078 uint32_t y = execution.workgroup_size.y;
3079 uint32_t z = execution.workgroup_size.z;
3080
3081 if (!execution.workgroup_size.constant && execution.flags.get(bit: ExecutionModeLocalSizeId))
3082 {
3083 if (execution.workgroup_size.id_x)
3084 x = get<SPIRConstant>(id: execution.workgroup_size.id_x).scalar();
3085 if (execution.workgroup_size.id_y)
3086 y = get<SPIRConstant>(id: execution.workgroup_size.id_y).scalar();
3087 if (execution.workgroup_size.id_z)
3088 z = get<SPIRConstant>(id: execution.workgroup_size.id_z).scalar();
3089 }
3090
3091 auto x_expr = wg_x.id ? get<SPIRConstant>(id: wg_x.id).specialization_constant_macro_name : to_string(val: x);
3092 auto y_expr = wg_y.id ? get<SPIRConstant>(id: wg_y.id).specialization_constant_macro_name : to_string(val: y);
3093 auto z_expr = wg_z.id ? get<SPIRConstant>(id: wg_z.id).specialization_constant_macro_name : to_string(val: z);
3094
3095 statement(ts: "[numthreads(", ts&: x_expr, ts: ", ", ts&: y_expr, ts: ", ", ts&: z_expr, ts: ")]");
3096 break;
3097 }
3098 case ExecutionModelFragment:
3099 if (execution.flags.get(bit: ExecutionModeEarlyFragmentTests))
3100 statement(ts: "[earlydepthstencil]");
3101 break;
3102 default:
3103 break;
3104 }
3105
3106 const char *entry_point_name;
3107 if (hlsl_options.use_entry_point_name)
3108 entry_point_name = get_entry_point().name.c_str();
3109 else
3110 entry_point_name = "main";
3111
3112 statement(ts: require_output ? "SPIRV_Cross_Output " : "void ", ts&: entry_point_name, ts: "(", ts: merge(list: arguments), ts: ")");
3113 begin_scope();
3114 bool legacy = hlsl_options.shader_model <= 30;
3115
3116 // Copy builtins from entry point arguments to globals.
3117 active_input_builtins.for_each_bit(op: [&](uint32_t i) {
3118 auto builtin = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassInput);
3119 switch (static_cast<BuiltIn>(i))
3120 {
3121 case BuiltInFragCoord:
3122			// VPOS in D3D9 is sampled at integer locations; apply a half-pixel offset to be consistent.
3123 // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
3124 // on a D3D10+ system with a different rasterization config?
3125 if (legacy)
3126 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
3127 else
3128 {
3129 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
3130				// ZW are undefined in D3D9, so only do this fixup in the non-legacy path.
3131 statement(ts&: builtin, ts: ".w = 1.0 / ", ts&: builtin, ts: ".w;");
3132 }
3133 break;
3134
3135 case BuiltInVertexId:
3136 case BuiltInVertexIndex:
3137 case BuiltInInstanceIndex:
3138 // D3D semantics are uint, but shader wants int.
3139 if (hlsl_options.support_nonzero_base_vertex_base_instance)
3140 {
3141 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
3142 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseInstance;");
3143 else
3144 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseVertex;");
3145 }
3146 else
3147 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
3148 break;
3149
3150 case BuiltInBaseVertex:
3151 statement(ts&: builtin, ts: " = SPIRV_Cross_BaseVertex;");
3152 break;
3153
3154 case BuiltInBaseInstance:
3155 statement(ts&: builtin, ts: " = SPIRV_Cross_BaseInstance;");
3156 break;
3157
3158 case BuiltInInstanceId:
3159 // D3D semantics are uint, but shader wants int.
3160 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
3161 break;
3162
3163 case BuiltInSampleMask:
3164 statement(ts&: builtin, ts: "[0] = stage_input.", ts&: builtin, ts: ";");
3165 break;
3166
3167 case BuiltInNumWorkgroups:
3168 case BuiltInPointCoord:
3169 case BuiltInSubgroupSize:
3170 case BuiltInSubgroupLocalInvocationId:
3171 case BuiltInHelperInvocation:
3172 break;
3173
3174 case BuiltInSubgroupEqMask:
3175 // Emulate these ...
3176 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
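			// For example, lane 33 ends up with gl_SubgroupEqMask = uint4(0u, 2u, 0u, 0u):
			// only bit 1 of the second 32-bit word is set.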
3177 statement(ts: "gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
3178 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
3179 statement(ts: "if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
3180 statement(ts: "if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
3181 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
3182 break;
3183
3184 case BuiltInSubgroupGeMask:
3185 // Emulate these ...
3186 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3187 statement(ts: "gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
3188 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
3189 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
3190 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
3191 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
3192 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
3193 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
3194 break;
3195
3196 case BuiltInSubgroupGtMask:
3197 // Emulate these ...
3198 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
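			// The Gt/Le masks reuse the Ge/Lt construction with the lane index shifted up by one,
			// since "greater than lane" is the same set as "greater than or equal to lane + 1".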
3199 statement(ts: "uint gt_lane_index = WaveGetLaneIndex() + 1;");
3200 statement(ts: "gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
3201 statement(ts: "if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
3202 statement(ts: "if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
3203 statement(ts: "if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
3204 statement(ts: "if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
3205 statement(ts: "if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
3206 statement(ts: "if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
3207 statement(ts: "if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
3208 break;
3209
3210 case BuiltInSubgroupLeMask:
3211 // Emulate these ...
3212 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3213 statement(ts: "uint le_lane_index = WaveGetLaneIndex() + 1;");
3214 statement(ts: "gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
3215 statement(ts: "if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
3216 statement(ts: "if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
3217 statement(ts: "if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
3218 statement(ts: "if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
3219 statement(ts: "if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
3220 statement(ts: "if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
3221 statement(ts: "if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
3222 break;
3223
3224 case BuiltInSubgroupLtMask:
3225 // Emulate these ...
3226 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
3227 statement(ts: "gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
3228 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
3229 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
3230 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
3231 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
3232 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
3233 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
3234 break;
3235
3236 case BuiltInClipDistance:
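			// Clip distances are packed into float4 members (gl_ClipDistance0, gl_ClipDistance1, ...) of the
			// stage input struct; unpack them here into the plain float array the shader body expects.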
3237 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
3238 statement(ts: "gl_ClipDistance[", ts&: clip, ts: "] = stage_input.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3],
3239 ts: ";");
3240 break;
3241
3242 case BuiltInCullDistance:
3243 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
3244 statement(ts: "gl_CullDistance[", ts&: cull, ts: "] = stage_input.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3],
3245 ts: ";");
3246 break;
3247
3248 default:
3249 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
3250 break;
3251 }
3252 });
3253
3254 // Copy from stage input struct to globals.
3255 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
3256 auto &type = this->get<SPIRType>(id: var.basetype);
3257 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
3258
3259 if (var.storage != StorageClassInput)
3260 return;
3261
3262 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
3263
3264 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
3265 interface_variable_exists_in_entry_point(id: var.self))
3266 {
3267 if (block)
3268 {
3269 auto type_name = to_name(id: type.self);
3270 auto var_name = to_name(id: var.self);
3271 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
3272 {
3273 auto mbr_name = to_member_name(type, index: mbr_idx);
3274 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
3275 statement(ts&: var_name, ts: ".", ts&: mbr_name, ts: " = stage_input.", ts&: flat_name, ts: ";");
3276 }
3277 }
3278 else
3279 {
3280 auto name = to_name(id: var.self);
3281 auto &mtype = this->get<SPIRType>(id: var.basetype);
3282 if (need_matrix_unroll && mtype.columns > 1)
3283 {
3284 // Unroll matrices.
3285 for (uint32_t col = 0; col < mtype.columns; col++)
3286 statement(ts&: name, ts: "[", ts&: col, ts: "] = stage_input.", ts&: name, ts: "_", ts&: col, ts: ";");
3287 }
3288 else
3289 {
3290 statement(ts&: name, ts: " = stage_input.", ts&: name, ts: ";");
3291 }
3292 }
3293 }
3294 });
3295
3296 // Run the shader.
3297 if (execution.model == ExecutionModelVertex ||
3298 execution.model == ExecutionModelFragment ||
3299 execution.model == ExecutionModelGLCompute ||
3300 execution.model == ExecutionModelMeshEXT ||
3301 execution.model == ExecutionModelTaskEXT)
3302 {
3303 // For mesh shaders, we receive special arguments that we must pass down as function arguments.
3304 // HLSL does not support proper reference types for passing these IO blocks,
3305		// but DXC appears to fix this up after inlining anyway.
3306 SmallVector<string> arglist;
3307 auto &func = get<SPIRFunction>(id: ir.default_entry_point);
3308		// The arguments are marked 'out'; avoid detecting reads and emitting them as inout.
3309 for (auto &arg : func.arguments)
3310 arglist.push_back(t: to_expression(id: arg.id, register_expression_read: false));
3311 statement(ts: get_inner_entry_point_name(), ts: "(", ts: merge(list: arglist), ts: ");");
3312 }
3313 else
3314 SPIRV_CROSS_THROW("Unsupported shader stage.");
3315
3316 // Copy stage outputs.
3317 if (require_output)
3318 {
3319 statement(ts: "SPIRV_Cross_Output stage_output;");
3320
3321 // Copy builtins from globals to return struct.
3322 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
3323 // PointSize doesn't exist in HLSL SM 4+.
3324 if (i == BuiltInPointSize && !legacy)
3325 return;
3326
3327 switch (static_cast<BuiltIn>(i))
3328 {
3329 case BuiltInClipDistance:
3330 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
3331 statement(ts: "stage_output.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3], ts: " = gl_ClipDistance[",
3332 ts&: clip, ts: "];");
3333 break;
3334
3335 case BuiltInCullDistance:
3336 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
3337 statement(ts: "stage_output.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3], ts: " = gl_CullDistance[",
3338 ts&: cull, ts: "];");
3339 break;
3340
3341 case BuiltInSampleMask:
3342 statement(ts: "stage_output.gl_SampleMask = gl_SampleMask[0];");
3343 break;
3344
3345 default:
3346 {
3347 auto builtin_expr = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassOutput);
3348 statement(ts: "stage_output.", ts&: builtin_expr, ts: " = ", ts&: builtin_expr, ts: ";");
3349 break;
3350 }
3351 }
3352 });
3353
3354 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
3355 auto &type = this->get<SPIRType>(id: var.basetype);
3356 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
3357
3358 if (var.storage != StorageClassOutput)
3359 return;
3360
3361 if (!var.remapped_variable && type.pointer &&
3362 !is_builtin_variable(var) &&
3363 interface_variable_exists_in_entry_point(id: var.self))
3364 {
3365 if (block)
3366 {
3367 // I/O blocks need to flatten output.
3368 auto type_name = to_name(id: type.self);
3369 auto var_name = to_name(id: var.self);
3370 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
3371 {
3372 auto mbr_name = to_member_name(type, index: mbr_idx);
3373 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
3374 statement(ts: "stage_output.", ts&: flat_name, ts: " = ", ts&: var_name, ts: ".", ts&: mbr_name, ts: ";");
3375 }
3376 }
3377 else
3378 {
3379 auto name = to_name(id: var.self);
3380
3381 if (legacy && execution.model == ExecutionModelFragment)
3382 {
3383 string output_filler;
3384 for (uint32_t size = type.vecsize; size < 4; ++size)
3385 output_filler += ", 0.0";
3386
3387 statement(ts: "stage_output.", ts&: name, ts: " = float4(", ts&: name, ts&: output_filler, ts: ");");
3388 }
3389 else
3390 {
3391 statement(ts: "stage_output.", ts&: name, ts: " = ", ts&: name, ts: ";");
3392 }
3393 }
3394 }
3395 });
3396
3397 statement(ts: "return stage_output;");
3398 }
3399
3400 end_scope();
3401}
3402
3403void CompilerHLSL::emit_fixup()
3404{
3405 if (is_vertex_like_shader() && active_output_builtins.get(bit: BuiltInPosition))
3406 {
3407 // Do various mangling on the gl_Position.
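		// flip_vert_y negates Y to account for differing window-space Y conventions, and fixup_clipspace remaps
		// the GL depth range z in [-w, w] to the D3D range [0, w] via z = (z + w) * 0.5.
		// On SM <= 3.0 a half-pixel correction (gl_HalfPixel) is also applied to compensate for D3D9's
		// half-pixel coordinate convention.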
3408 if (hlsl_options.shader_model <= 30)
3409 {
3410 statement(ts: "gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
3411 "gl_Position.w;");
3412 statement(ts: "gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
3413 "gl_Position.w;");
3414 }
3415
3416 if (options.vertex.flip_vert_y)
3417 statement(ts: "gl_Position.y = -gl_Position.y;");
3418 if (options.vertex.fixup_clipspace)
3419 statement(ts: "gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
3420 }
3421}
3422
3423void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
3424{
3425 if (sparse)
3426 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
3427
3428 auto *ops = stream(instr: i);
3429 auto op = static_cast<Op>(i.op);
3430 uint32_t length = i.length;
3431
3432 SmallVector<uint32_t> inherited_expressions;
3433
3434 uint32_t result_type = ops[0];
3435 uint32_t id = ops[1];
3436 VariableID img = ops[2];
3437 uint32_t coord = ops[3];
3438 uint32_t dref = 0;
3439 uint32_t comp = 0;
3440 bool gather = false;
3441 bool proj = false;
3442 const uint32_t *opt = nullptr;
3443 auto *combined_image = maybe_get<SPIRCombinedImageSampler>(id: img);
3444
3445 if (combined_image && has_decoration(id: img, decoration: DecorationNonUniform))
3446 {
3447 set_decoration(id: combined_image->image, decoration: DecorationNonUniform);
3448 set_decoration(id: combined_image->sampler, decoration: DecorationNonUniform);
3449 }
3450
3451 auto img_expr = to_non_uniform_aware_expression(id: combined_image ? combined_image->image : img);
3452
3453 inherited_expressions.push_back(t: coord);
3454
3455 switch (op)
3456 {
3457 case OpImageSampleDrefImplicitLod:
3458 case OpImageSampleDrefExplicitLod:
3459 dref = ops[4];
3460 opt = &ops[5];
3461 length -= 5;
3462 break;
3463
3464 case OpImageSampleProjDrefImplicitLod:
3465 case OpImageSampleProjDrefExplicitLod:
3466 dref = ops[4];
3467 proj = true;
3468 opt = &ops[5];
3469 length -= 5;
3470 break;
3471
3472 case OpImageDrefGather:
3473 dref = ops[4];
3474 opt = &ops[5];
3475 gather = true;
3476 length -= 5;
3477 break;
3478
3479 case OpImageGather:
3480 comp = ops[4];
3481 opt = &ops[5];
3482 gather = true;
3483 length -= 5;
3484 break;
3485
3486 case OpImageSampleProjImplicitLod:
3487 case OpImageSampleProjExplicitLod:
3488 opt = &ops[4];
3489 length -= 4;
3490 proj = true;
3491 break;
3492
3493 case OpImageQueryLod:
3494 opt = &ops[4];
3495 length -= 4;
3496 break;
3497
3498 default:
3499 opt = &ops[4];
3500 length -= 4;
3501 break;
3502 }
3503
3504 auto &imgtype = expression_type(id: img);
3505 uint32_t coord_components = 0;
3506 switch (imgtype.image.dim)
3507 {
3508 case spv::Dim1D:
3509 coord_components = 1;
3510 break;
3511 case spv::Dim2D:
3512 coord_components = 2;
3513 break;
3514 case spv::Dim3D:
3515 coord_components = 3;
3516 break;
3517 case spv::DimCube:
3518 coord_components = 3;
3519 break;
3520 case spv::DimBuffer:
3521 coord_components = 1;
3522 break;
3523 default:
3524 coord_components = 2;
3525 break;
3526 }
3527
3528 if (dref)
3529 inherited_expressions.push_back(t: dref);
3530
3531 if (imgtype.image.arrayed)
3532 coord_components++;
3533
3534 uint32_t bias = 0;
3535 uint32_t lod = 0;
3536 uint32_t grad_x = 0;
3537 uint32_t grad_y = 0;
3538 uint32_t coffset = 0;
3539 uint32_t offset = 0;
3540 uint32_t coffsets = 0;
3541 uint32_t sample = 0;
3542 uint32_t minlod = 0;
3543 uint32_t flags = 0;
3544
3545 if (length)
3546 {
3547 flags = opt[0];
3548 opt++;
3549 length--;
3550 }
3551
3552 auto test = [&](uint32_t &v, uint32_t flag) {
3553 if (length && (flags & flag))
3554 {
3555 v = *opt++;
3556 inherited_expressions.push_back(t: v);
3557 length--;
3558 }
3559 };
3560
3561 test(bias, ImageOperandsBiasMask);
3562 test(lod, ImageOperandsLodMask);
3563 test(grad_x, ImageOperandsGradMask);
3564 test(grad_y, ImageOperandsGradMask);
3565 test(coffset, ImageOperandsConstOffsetMask);
3566 test(offset, ImageOperandsOffsetMask);
3567 test(coffsets, ImageOperandsConstOffsetsMask);
3568 test(sample, ImageOperandsSampleMask);
3569 test(minlod, ImageOperandsMinLodMask);
3570
3571 string expr;
3572 string texop;
3573
3574 if (minlod != 0)
3575 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
3576
3577 if (op == OpImageFetch)
3578 {
3579 if (hlsl_options.shader_model < 40)
3580 {
3581 SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
3582 }
3583 texop += img_expr;
3584 texop += ".Load";
3585 }
3586 else if (op == OpImageQueryLod)
3587 {
3588 texop += img_expr;
3589 texop += ".CalculateLevelOfDetail";
3590 }
3591 else
3592 {
3593 auto &imgformat = get<SPIRType>(id: imgtype.image.type);
3594 if (hlsl_options.shader_model < 67 && imgformat.basetype != SPIRType::Float)
3595 {
3596 SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL SM < 6.7.");
3597 }
3598
3599 if (hlsl_options.shader_model >= 40)
3600 {
3601 texop += img_expr;
3602
3603 if (is_depth_image(type: imgtype, id: img))
3604 {
3605 if (gather)
3606 {
3607 texop += ".GatherCmp";
3608 }
3609 else if (lod || grad_x || grad_y)
3610 {
3611 // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
3612 texop += ".SampleCmpLevelZero";
3613 }
3614 else
3615 texop += ".SampleCmp";
3616 }
3617 else if (gather)
3618 {
3619 uint32_t comp_num = evaluate_constant_u32(id: comp);
3620 if (hlsl_options.shader_model >= 50)
3621 {
3622 switch (comp_num)
3623 {
3624 case 0:
3625 texop += ".GatherRed";
3626 break;
3627 case 1:
3628 texop += ".GatherGreen";
3629 break;
3630 case 2:
3631 texop += ".GatherBlue";
3632 break;
3633 case 3:
3634 texop += ".GatherAlpha";
3635 break;
3636 default:
3637 SPIRV_CROSS_THROW("Invalid component.");
3638 }
3639 }
3640 else
3641 {
3642 if (comp_num == 0)
3643 texop += ".Gather";
3644 else
3645 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
3646 }
3647 }
3648 else if (bias)
3649 texop += ".SampleBias";
3650 else if (grad_x || grad_y)
3651 texop += ".SampleGrad";
3652 else if (lod)
3653 texop += ".SampleLevel";
3654 else
3655 texop += ".Sample";
3656 }
3657 else
3658 {
3659 switch (imgtype.image.dim)
3660 {
3661 case Dim1D:
3662 texop += "tex1D";
3663 break;
3664 case Dim2D:
3665 texop += "tex2D";
3666 break;
3667 case Dim3D:
3668 texop += "tex3D";
3669 break;
3670 case DimCube:
3671 texop += "texCUBE";
3672 break;
3673 case DimRect:
3674 case DimBuffer:
3675 case DimSubpassData:
3676 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
3677 default:
3678 SPIRV_CROSS_THROW("Invalid dimension.");
3679 }
3680
3681 if (gather)
3682 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
3683 if (offset || coffset)
3684 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
3685
3686 if (grad_x || grad_y)
3687 texop += "grad";
3688 else if (lod)
3689 texop += "lod";
3690 else if (bias)
3691 texop += "bias";
3692 else if (proj || dref)
3693 texop += "proj";
3694 }
3695 }
3696
3697 expr += texop;
3698 expr += "(";
3699 if (hlsl_options.shader_model < 40)
3700 {
3701 if (combined_image)
3702 SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3703 expr += to_expression(id: img);
3704 }
3705 else if (op != OpImageFetch)
3706 {
3707 string sampler_expr;
3708 if (combined_image)
3709 sampler_expr = to_non_uniform_aware_expression(id: combined_image->sampler);
3710 else
3711 sampler_expr = to_sampler_expression(id: img);
3712 expr += sampler_expr;
3713 }
3714
3715 auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3716 if (comps == in_comps)
3717 return "";
3718
3719 switch (comps)
3720 {
3721 case 1:
3722 return ".x";
3723 case 2:
3724 return ".xy";
3725 case 3:
3726 return ".xyz";
3727 default:
3728 return "";
3729 }
3730 };
3731
3732 bool forward = should_forward(id: coord);
3733
3734 // The IR can give us more components than we need, so chop them off as needed.
3735 string coord_expr;
3736 auto &coord_type = expression_type(id: coord);
3737 if (coord_components != coord_type.vecsize)
3738 coord_expr = to_enclosed_expression(id: coord) + swizzle(coord_components, expression_type(id: coord).vecsize);
3739 else
3740 coord_expr = to_expression(id: coord);
3741
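	// Modern HLSL sampling intrinsics have no projective variants, so perform the perspective divide on the
	// coordinate manually here; SM 2/3 instead relies on the tex*proj intrinsics selected above.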
3742 if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3743 coord_expr = coord_expr + " / " + to_extract_component_expression(id: coord, index: coord_components);
3744
3745 if (hlsl_options.shader_model < 40)
3746 {
3747 if (dref)
3748 {
3749 if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3750 {
3751 SPIRV_CROSS_THROW(
3752 "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3753 }
3754
3755 if (grad_x || grad_y)
3756 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3757
3758 for (uint32_t size = coord_components; size < 2; ++size)
3759 coord_expr += ", 0.0";
3760
3761 forward = forward && should_forward(id: dref);
3762 coord_expr += ", " + to_expression(id: dref);
3763 }
3764 else if (lod || bias || proj)
3765 {
3766 for (uint32_t size = coord_components; size < 3; ++size)
3767 coord_expr += ", 0.0";
3768 }
3769
3770 if (lod)
3771 {
3772 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: lod) + ")";
3773 }
3774 else if (bias)
3775 {
3776 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: bias) + ")";
3777 }
3778 else if (proj)
3779 {
3780 coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(id: coord, index: coord_components) + ")";
3781 }
3782 else if (dref)
3783 {
3784			// A non-projected depth-compare sample is also fed through tex2Dproj (with w = 1.0), because the
3785			// regular tex2D intrinsic accepts only two coordinates.
3786 coord_expr = "float4(" + coord_expr + ", 1.0)";
3787 }
3788
3789 if (!!lod + !!bias + !!proj > 1)
3790 SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3791 }
3792
3793 if (op == OpImageFetch)
3794 {
3795 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3796 coord_expr =
3797 join(ts: "int", ts: coord_components + 1, ts: "(", ts&: coord_expr, ts: ", ", ts: lod ? to_expression(id: lod) : string("0"), ts: ")");
3798 }
3799 else
3800 expr += ", ";
3801 expr += coord_expr;
3802
3803 if (dref && hlsl_options.shader_model >= 40)
3804 {
3805 forward = forward && should_forward(id: dref);
3806 expr += ", ";
3807
3808 if (proj)
3809 expr += to_enclosed_expression(id: dref) + " / " + to_extract_component_expression(id: coord, index: coord_components);
3810 else
3811 expr += to_expression(id: dref);
3812 }
3813
3814 if (!dref && (grad_x || grad_y))
3815 {
3816 forward = forward && should_forward(id: grad_x);
3817 forward = forward && should_forward(id: grad_y);
3818 expr += ", ";
3819 expr += to_expression(id: grad_x);
3820 expr += ", ";
3821 expr += to_expression(id: grad_y);
3822 }
3823
3824 if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3825 {
3826 forward = forward && should_forward(id: lod);
3827 expr += ", ";
3828 expr += to_expression(id: lod);
3829 }
3830
3831 if (!dref && bias && hlsl_options.shader_model >= 40)
3832 {
3833 forward = forward && should_forward(id: bias);
3834 expr += ", ";
3835 expr += to_expression(id: bias);
3836 }
3837
3838 if (coffset)
3839 {
3840 forward = forward && should_forward(id: coffset);
3841 expr += ", ";
3842 expr += to_expression(id: coffset);
3843 }
3844 else if (offset)
3845 {
3846 forward = forward && should_forward(id: offset);
3847 expr += ", ";
3848 expr += to_expression(id: offset);
3849 }
3850
3851 if (sample)
3852 {
3853 expr += ", ";
3854 expr += to_expression(id: sample);
3855 }
3856
3857 expr += ")";
3858
3859 if (dref && hlsl_options.shader_model < 40)
3860 expr += ".x";
3861
3862 if (op == OpImageQueryLod)
3863 {
3864 // This is rather awkward.
3865 // textureQueryLod returns two values, the "accessed level",
3866 // as well as the actual LOD lambda.
3867 // As far as I can tell, there is no way to get the .x component
3868 // according to GLSL spec, and it depends on the sampler itself.
3869 // Just assume X == Y, so we will need to splat the result to a float2.
3870 statement(ts: "float _", ts&: id, ts: "_tmp = ", ts&: expr, ts: ";");
3871 statement(ts: "float2 _", ts&: id, ts: " = _", ts&: id, ts: "_tmp.xx;");
3872 set<SPIRExpression>(id, args: join(ts: "_", ts&: id), args&: result_type, args: true);
3873 }
3874 else
3875 {
3876 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: forward, suppress_usage_tracking: false);
3877 }
3878
3879 for (auto &inherit : inherited_expressions)
3880 inherit_expression_dependencies(dst: id, source: inherit);
3881
3882 switch (op)
3883 {
3884 case OpImageSampleDrefImplicitLod:
3885 case OpImageSampleImplicitLod:
3886 case OpImageSampleProjImplicitLod:
3887 case OpImageSampleProjDrefImplicitLod:
3888 register_control_dependent_expression(expr: id);
3889 break;
3890
3891 default:
3892 break;
3893 }
3894}
3895
3896string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3897{
3898 const auto &type = get<SPIRType>(id: var.basetype);
3899
3900 // We can remap push constant blocks, even if they don't have any binding decoration.
3901 if (type.storage != StorageClassPushConstant && !has_decoration(id: var.self, decoration: DecorationBinding))
3902 return "";
3903
3904 char space = '\0';
3905
3906 HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3907
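	// HLSL register classes: b = constant buffers (CBV), t = shader resource views (SRV),
	// u = unordered access views (UAV), s = samplers. When targeting SM 5.1+, the Vulkan descriptor set
	// additionally maps to a register space (see to_resource_register below).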
3908 switch (type.basetype)
3909 {
3910 case SPIRType::SampledImage:
3911 space = 't'; // SRV
3912 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3913 break;
3914
3915 case SPIRType::Image:
3916 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3917 {
3918 if (has_decoration(id: var.self, decoration: DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3919 {
3920 space = 't'; // SRV
3921 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3922 }
3923 else
3924 {
3925 space = 'u'; // UAV
3926 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3927 }
3928 }
3929 else
3930 {
3931 space = 't'; // SRV
3932 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3933 }
3934 break;
3935
3936 case SPIRType::Sampler:
3937 space = 's';
3938 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3939 break;
3940
3941 case SPIRType::AccelerationStructure:
3942 space = 't'; // SRV
3943 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3944 break;
3945
3946 case SPIRType::Struct:
3947 {
3948 auto storage = type.storage;
3949 if (storage == StorageClassUniform)
3950 {
3951 if (has_decoration(id: type.self, decoration: DecorationBufferBlock))
3952 {
3953 Bitset flags = ir.get_buffer_block_flags(var);
3954 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
3955				space = is_readonly ? 't' : 'u'; // SRV if read-only, UAV otherwise.
3956 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3957 }
3958 else if (has_decoration(id: type.self, decoration: DecorationBlock))
3959 {
3960 space = 'b'; // Constant buffers
3961 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3962 }
3963 }
3964 else if (storage == StorageClassPushConstant)
3965 {
3966 space = 'b'; // Constant buffers
3967 resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3968 }
3969 else if (storage == StorageClassStorageBuffer)
3970 {
3971 // UAV or SRV depending on readonly flag.
3972 Bitset flags = ir.get_buffer_block_flags(var);
3973 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
3974 space = is_readonly ? 't' : 'u';
3975 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3976 }
3977
3978 break;
3979 }
3980 default:
3981 break;
3982 }
3983
3984 if (!space)
3985 return "";
3986
3987 uint32_t desc_set =
3988 resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3989 uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3990
3991 if (has_decoration(id: var.self, decoration: DecorationBinding))
3992 binding = get_decoration(id: var.self, decoration: DecorationBinding);
3993 if (has_decoration(id: var.self, decoration: DecorationDescriptorSet))
3994 desc_set = get_decoration(id: var.self, decoration: DecorationDescriptorSet);
3995
3996 return to_resource_register(flag: resource_flags, space, binding, set: desc_set);
3997}
3998
3999string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
4000{
4001 // For combined image samplers.
4002 if (!has_decoration(id: var.self, decoration: DecorationBinding))
4003 return "";
4004
4005 return to_resource_register(flag: HLSL_BINDING_AUTO_SAMPLER_BIT, space: 's', binding: get_decoration(id: var.self, decoration: DecorationBinding),
4006 set: get_decoration(id: var.self, decoration: DecorationDescriptorSet));
4007}
4008
4009void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
4010{
4011 auto itr = resource_bindings.find(x: { .model: get_execution_model(), .desc_set: desc_set, .binding: binding });
4012 if (itr != end(cont&: resource_bindings))
4013 {
4014 auto &remap = itr->second;
4015 remap.second = true;
4016
4017 switch (type)
4018 {
4019 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
4020 case HLSL_BINDING_AUTO_CBV_BIT:
4021 desc_set = remap.first.cbv.register_space;
4022 binding = remap.first.cbv.register_binding;
4023 break;
4024
4025 case HLSL_BINDING_AUTO_SRV_BIT:
4026 desc_set = remap.first.srv.register_space;
4027 binding = remap.first.srv.register_binding;
4028 break;
4029
4030 case HLSL_BINDING_AUTO_SAMPLER_BIT:
4031 desc_set = remap.first.sampler.register_space;
4032 binding = remap.first.sampler.register_binding;
4033 break;
4034
4035 case HLSL_BINDING_AUTO_UAV_BIT:
4036 desc_set = remap.first.uav.register_space;
4037 binding = remap.first.uav.register_binding;
4038 break;
4039
4040 default:
4041 break;
4042 }
4043 }
4044}
4045
4046string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
4047{
4048 if ((flag & resource_binding_flags) == 0)
4049 {
4050 remap_hlsl_resource_binding(type: flag, desc_set&: space_set, binding);
4051
4052		// The push constant block did not have a binding, and there was no remap for it,
4053		// so declare it without a register binding.
4054 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
4055 return "";
4056
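		// For illustration, a binding of 3 in descriptor set 1 would emit roughly " : register(t3, space1)"
		// on SM 5.1+, and " : register(t3)" on older shader models where register spaces are unavailable.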
4057 if (hlsl_options.shader_model >= 51)
4058 return join(ts: " : register(", ts&: space, ts&: binding, ts: ", space", ts&: space_set, ts: ")");
4059 else
4060 return join(ts: " : register(", ts&: space, ts&: binding, ts: ")");
4061 }
4062 else
4063 return "";
4064}
4065
4066void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
4067{
4068 auto &type = get<SPIRType>(id: var.basetype);
4069 switch (type.basetype)
4070 {
4071 case SPIRType::SampledImage:
4072 case SPIRType::Image:
4073 {
4074 bool is_coherent = false;
4075 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
4076 is_coherent = has_decoration(id: var.self, decoration: DecorationCoherent);
4077
4078 statement(ts: is_coherent ? "globallycoherent " : "", ts: image_type_hlsl_modern(type, id: var.self), ts: " ",
4079 ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var), ts: ";");
4080
4081 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
4082 {
4083			// For combined image samplers, also emit the companion sampler state object.
4084 if (is_depth_image(type, id: var.self))
4085 statement(ts: "SamplerComparisonState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
4086 ts: to_resource_binding_sampler(var), ts: ";");
4087 else
4088 statement(ts: "SamplerState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self),
4089 ts: to_resource_binding_sampler(var), ts: ";");
4090 }
4091 break;
4092 }
4093
4094 case SPIRType::Sampler:
4095 if (comparison_ids.count(x: var.self))
4096 statement(ts: "SamplerComparisonState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var),
4097 ts: ";");
4098 else
4099 statement(ts: "SamplerState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type, variable_id: var.self), ts: to_resource_binding(var), ts: ";");
4100 break;
4101
4102 default:
4103 statement(ts: variable_decl(variable: var), ts: to_resource_binding(var), ts: ";");
4104 break;
4105 }
4106}
4107
4108void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
4109{
4110 auto &type = get<SPIRType>(id: var.basetype);
4111 switch (type.basetype)
4112 {
4113 case SPIRType::Sampler:
4114 case SPIRType::Image:
4115 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
4116
4117 default:
4118 statement(ts: variable_decl(variable: var), ts: ";");
4119 break;
4120 }
4121}
4122
4123void CompilerHLSL::emit_uniform(const SPIRVariable &var)
4124{
4125 add_resource_name(id: var.self);
4126 if (hlsl_options.shader_model >= 40)
4127 emit_modern_uniform(var);
4128 else
4129 emit_legacy_uniform(var);
4130}
4131
4132bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
4133{
4134 return false;
4135}
4136
4137string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
4138{
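	// Same-width integer reinterpretation is expressed as a plain type cast in HLSL (e.g. uint(...)),
	// while float <-> int reinterpretation maps to the asuint/asint/asfloat intrinsics, e.g. asuint(f)
	// for a float operand 'f' (operand names here are illustrative only).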
4139 if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
4140 return type_to_glsl(type: out_type);
4141 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
4142 return type_to_glsl(type: out_type);
4143 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
4144 return "asuint";
4145 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
4146 return type_to_glsl(type: out_type);
4147 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
4148 return type_to_glsl(type: out_type);
4149 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
4150 return "asint";
4151 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
4152 return "asfloat";
4153 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
4154 return "asfloat";
4155 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
4156 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
4157 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
4158 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
4159 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
4160 return "asdouble";
4161 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
4162 return "asdouble";
4163 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
4164 {
4165 if (!requires_explicit_fp16_packing)
4166 {
4167 requires_explicit_fp16_packing = true;
4168 force_recompile();
4169 }
4170 return "spvUnpackFloat2x16";
4171 }
4172 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
4173 {
4174 if (!requires_explicit_fp16_packing)
4175 {
4176 requires_explicit_fp16_packing = true;
4177 force_recompile();
4178 }
4179 return "spvPackFloat2x16";
4180 }
4181 else if (out_type.basetype == SPIRType::UShort && in_type.basetype == SPIRType::Half)
4182 {
4183 if (hlsl_options.shader_model < 40)
4184 SPIRV_CROSS_THROW("Half to UShort requires Shader Model 4.");
4185 return "(" + type_to_glsl(type: out_type) + ")f32tof16";
4186 }
4187 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UShort)
4188 {
4189 if (hlsl_options.shader_model < 40)
4190 SPIRV_CROSS_THROW("UShort to Half requires Shader Model 4.");
4191 return "(" + type_to_glsl(type: out_type) + ")f16tof32";
4192 }
4193 else
4194 return "";
4195}
4196
4197void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
4198{
4199 auto op = static_cast<GLSLstd450>(eop);
4200
4201 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4202 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, arguments: args, length: count);
4203 auto int_type = to_signed_basetype(width: integer_width);
4204 auto uint_type = to_unsigned_basetype(width: integer_width);
4205
4206 op = get_remapped_glsl_op(std450_op: op);
4207
4208 switch (op)
4209 {
4210 case GLSLstd450InverseSqrt:
4211 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rsqrt");
4212 break;
4213
4214 case GLSLstd450Fract:
4215 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "frac");
4216 break;
4217
4218 case GLSLstd450RoundEven:
4219 if (hlsl_options.shader_model < 40)
4220 SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
4221 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "round");
4222 break;
4223
4224 case GLSLstd450Trunc:
4225 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "trunc");
4226 break;
4227
4228 case GLSLstd450Acosh:
4229 case GLSLstd450Asinh:
4230 case GLSLstd450Atanh:
4231		// These are not supported in HLSL, so always emulate them.
4232 emit_emulated_ahyper_op(result_type, result_id: id, op0: args[0], op);
4233 break;
4234
4235 case GLSLstd450FMix:
4236 case GLSLstd450IMix:
4237 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "lerp");
4238 break;
4239
4240 case GLSLstd450Atan2:
4241 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "atan2");
4242 break;
4243
4244 case GLSLstd450Fma:
4245 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "mad");
4246 break;
4247
4248 case GLSLstd450InterpolateAtCentroid:
4249 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "EvaluateAttributeAtCentroid");
4250 break;
4251 case GLSLstd450InterpolateAtSample:
4252 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeAtSample");
4253 break;
4254 case GLSLstd450InterpolateAtOffset:
4255 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeSnapped");
4256 break;
4257
4258 case GLSLstd450PackHalf2x16:
4259 if (!requires_fp16_packing)
4260 {
4261 requires_fp16_packing = true;
4262 force_recompile();
4263 }
4264 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackHalf2x16");
4265 break;
4266
4267 case GLSLstd450UnpackHalf2x16:
4268 if (!requires_fp16_packing)
4269 {
4270 requires_fp16_packing = true;
4271 force_recompile();
4272 }
4273 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackHalf2x16");
4274 break;
4275
4276 case GLSLstd450PackSnorm4x8:
4277 if (!requires_snorm8_packing)
4278 {
4279 requires_snorm8_packing = true;
4280 force_recompile();
4281 }
4282 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm4x8");
4283 break;
4284
4285 case GLSLstd450UnpackSnorm4x8:
4286 if (!requires_snorm8_packing)
4287 {
4288 requires_snorm8_packing = true;
4289 force_recompile();
4290 }
4291 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm4x8");
4292 break;
4293
4294 case GLSLstd450PackUnorm4x8:
4295 if (!requires_unorm8_packing)
4296 {
4297 requires_unorm8_packing = true;
4298 force_recompile();
4299 }
4300 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm4x8");
4301 break;
4302
4303 case GLSLstd450UnpackUnorm4x8:
4304 if (!requires_unorm8_packing)
4305 {
4306 requires_unorm8_packing = true;
4307 force_recompile();
4308 }
4309 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm4x8");
4310 break;
4311
4312 case GLSLstd450PackSnorm2x16:
4313 if (!requires_snorm16_packing)
4314 {
4315 requires_snorm16_packing = true;
4316 force_recompile();
4317 }
4318 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm2x16");
4319 break;
4320
4321 case GLSLstd450UnpackSnorm2x16:
4322 if (!requires_snorm16_packing)
4323 {
4324 requires_snorm16_packing = true;
4325 force_recompile();
4326 }
4327 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm2x16");
4328 break;
4329
4330 case GLSLstd450PackUnorm2x16:
4331 if (!requires_unorm16_packing)
4332 {
4333 requires_unorm16_packing = true;
4334 force_recompile();
4335 }
4336 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm2x16");
4337 break;
4338
4339 case GLSLstd450UnpackUnorm2x16:
4340 if (!requires_unorm16_packing)
4341 {
4342 requires_unorm16_packing = true;
4343 force_recompile();
4344 }
4345 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm2x16");
4346 break;
4347
4348 case GLSLstd450PackDouble2x32:
4349 case GLSLstd450UnpackDouble2x32:
4350 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
4351
4352 case GLSLstd450FindILsb:
4353 {
4354 auto basetype = expression_type(id: args[0]).basetype;
4355 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbitlow", input_type: basetype, expected_result_type: basetype);
4356 break;
4357 }
4358
4359 case GLSLstd450FindSMsb:
4360 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: int_type, expected_result_type: int_type);
4361 break;
4362
4363 case GLSLstd450FindUMsb:
4364 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: uint_type, expected_result_type: uint_type);
4365 break;
4366
4367 case GLSLstd450MatrixInverse:
4368 {
4369 auto &type = get<SPIRType>(id: result_type);
4370 if (type.vecsize == 2 && type.columns == 2)
4371 {
4372 if (!requires_inverse_2x2)
4373 {
4374 requires_inverse_2x2 = true;
4375 force_recompile();
4376 }
4377 }
4378 else if (type.vecsize == 3 && type.columns == 3)
4379 {
4380 if (!requires_inverse_3x3)
4381 {
4382 requires_inverse_3x3 = true;
4383 force_recompile();
4384 }
4385 }
4386 else if (type.vecsize == 4 && type.columns == 4)
4387 {
4388 if (!requires_inverse_4x4)
4389 {
4390 requires_inverse_4x4 = true;
4391 force_recompile();
4392 }
4393 }
4394 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse");
4395 break;
4396 }
4397
4398 case GLSLstd450Normalize:
4399 // HLSL does not support scalar versions here.
4400 if (expression_type(id: args[0]).vecsize == 1)
4401 {
4402			// Normalizing a scalar returns -1 or 1 for valid input, so sign() does the job.
4403 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "sign");
4404 }
4405 else
4406 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4407 break;
4408
4409 case GLSLstd450Reflect:
4410 if (get<SPIRType>(id: result_type).vecsize == 1)
4411 {
4412 if (!requires_scalar_reflect)
4413 {
4414 requires_scalar_reflect = true;
4415 force_recompile();
4416 }
4417 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "spvReflect");
4418 }
4419 else
4420 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4421 break;
4422
4423 case GLSLstd450Refract:
4424 if (get<SPIRType>(id: result_type).vecsize == 1)
4425 {
4426 if (!requires_scalar_refract)
4427 {
4428 requires_scalar_refract = true;
4429 force_recompile();
4430 }
4431 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvRefract");
4432 }
4433 else
4434 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4435 break;
4436
4437 case GLSLstd450FaceForward:
4438 if (get<SPIRType>(id: result_type).vecsize == 1)
4439 {
4440 if (!requires_scalar_faceforward)
4441 {
4442 requires_scalar_faceforward = true;
4443 force_recompile();
4444 }
4445 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvFaceForward");
4446 }
4447 else
4448 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4449 break;
4450
4451 case GLSLstd450NMin:
4452 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FMin, args, count);
4453 break;
4454
4455 case GLSLstd450NMax:
4456 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FMax, args, count);
4457 break;
4458
4459 case GLSLstd450NClamp:
4460 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: GLSLstd450FClamp, args, count);
4461 break;
4462
4463 default:
4464 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
4465 break;
4466 }
4467}
4468
4469void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
4470{
4471 auto &type = get<SPIRType>(id: chain.basetype);
4472
4473 // Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
4474 auto ident = get_unique_identifier();
4475
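	// ByteAddressBuffer has no aggregate loads, so emit an unrolled loop in the output shader that reads one
	// array element per iteration. The generated HLSL looks roughly like:
	//   [unroll]
	//   for (int _ident = 0; _ident < 4; _ident++) { ...read element _ident... }
	// where the identifier name and trip count are illustrative only.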
4476 statement(ts: "[unroll]");
4477 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
4478 ts&: ident, ts: "++)");
4479 begin_scope();
4480 auto subchain = chain;
4481 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
4482 subchain.basetype = type.parent_type;
4483 if (!get<SPIRType>(id: subchain.basetype).array.empty())
4484 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
4485 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: "[", ts&: ident, ts: "]"), chain: subchain);
4486 end_scope();
4487}
4488
4489void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
4490{
4491 auto &type = get<SPIRType>(id: chain.basetype);
4492 auto subchain = chain;
4493 uint32_t member_count = uint32_t(type.member_types.size());
4494
4495 for (uint32_t i = 0; i < member_count; i++)
4496 {
4497 uint32_t offset = type_struct_member_offset(type, index: i);
4498 subchain.static_index = chain.static_index + offset;
4499 subchain.basetype = type.member_types[i];
4500
4501 subchain.matrix_stride = 0;
4502 subchain.array_stride = 0;
4503 subchain.row_major_matrix = false;
4504
4505 auto &member_type = get<SPIRType>(id: subchain.basetype);
4506 if (member_type.columns > 1)
4507 {
4508 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
4509 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
4510 }
4511
4512 if (!member_type.array.empty())
4513 subchain.array_stride = type_struct_member_array_stride(type, index: i);
4514
4515 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: ".", ts: to_member_name(type, index: i)), chain: subchain);
4516 }
4517}
4518
4519void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
4520{
4521 auto &type = get<SPIRType>(id: chain.basetype);
4522
4523 SPIRType target_type { is_scalar(type) ? OpTypeInt : type.op };
4524 target_type.basetype = SPIRType::UInt;
4525 target_type.vecsize = type.vecsize;
4526 target_type.columns = type.columns;
4527
4528 if (!type.array.empty())
4529 {
4530 read_access_chain_array(lhs, chain);
4531 return;
4532 }
4533 else if (type.basetype == SPIRType::Struct)
4534 {
4535 read_access_chain_struct(lhs, chain);
4536 return;
4537 }
4538 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4539 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
4540 "native 16-bit types are enabled.");
4541
4542 string base = chain.base;
4543 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
4544 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
4545
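	// SM 6.2 supports templated ByteAddressBuffer loads (e.g. buf.Load<float3>(offset)), which preserve the
	// element type directly. Older shader models only provide Load/Load2/Load3/Load4 returning uint data,
	// so the result must be bitcast back to the real type afterwards (see the bitcast at the end).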
4546 bool templated_load = hlsl_options.shader_model >= 62;
4547 string load_expr;
4548
4549 string template_expr;
4550 if (templated_load)
4551 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
4552
4553 // Load a vector or scalar.
4554 if (type.columns == 1 && !chain.row_major_matrix)
4555 {
4556 const char *load_op = nullptr;
4557 switch (type.vecsize)
4558 {
4559 case 1:
4560 load_op = "Load";
4561 break;
4562 case 2:
4563 load_op = "Load2";
4564 break;
4565 case 3:
4566 load_op = "Load3";
4567 break;
4568 case 4:
4569 load_op = "Load4";
4570 break;
4571 default:
4572 SPIRV_CROSS_THROW("Unknown vector size.");
4573 }
4574
4575 if (templated_load)
4576 load_op = "Load";
4577
4578 load_expr = join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ")");
4579 }
4580 else if (type.columns == 1)
4581 {
4582 // Strided load since we are loading a column from a row-major matrix.
4583 if (templated_load)
4584 {
4585 auto scalar_type = type;
4586 scalar_type.vecsize = 1;
4587 scalar_type.columns = 1;
4588 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4589 if (type.vecsize > 1)
4590 load_expr += type_to_glsl(type) + "(";
4591 }
4592 else if (type.vecsize > 1)
4593 {
4594 load_expr = type_to_glsl(type: target_type);
4595 load_expr += "(";
4596 }
4597
4598 for (uint32_t r = 0; r < type.vecsize; r++)
4599 {
4600 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4601 ts: chain.static_index + r * chain.matrix_stride, ts: ")");
4602 if (r + 1 < type.vecsize)
4603 load_expr += ", ";
4604 }
4605
4606 if (type.vecsize > 1)
4607 load_expr += ")";
4608 }
4609 else if (!chain.row_major_matrix)
4610 {
4611 // Load a matrix, column-major, the easy case.
4612 const char *load_op = nullptr;
4613 switch (type.vecsize)
4614 {
4615 case 1:
4616 load_op = "Load";
4617 break;
4618 case 2:
4619 load_op = "Load2";
4620 break;
4621 case 3:
4622 load_op = "Load3";
4623 break;
4624 case 4:
4625 load_op = "Load4";
4626 break;
4627 default:
4628 SPIRV_CROSS_THROW("Unknown vector size.");
4629 }
4630
4631 if (templated_load)
4632 {
4633 auto vector_type = type;
4634 vector_type.columns = 1;
4635 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
4636 load_expr = type_to_glsl(type);
4637 load_op = "Load";
4638 }
4639 else
4640 {
4641 // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
4642 // so row-major is technically column-major ...
4643 load_expr = type_to_glsl(type: target_type);
4644 }
4645 load_expr += "(";
4646
4647 for (uint32_t c = 0; c < type.columns; c++)
4648 {
4649 load_expr += join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4650 ts: chain.static_index + c * chain.matrix_stride, ts: ")");
4651 if (c + 1 < type.columns)
4652 load_expr += ", ";
4653 }
4654 load_expr += ")";
4655 }
4656 else
4657 {
4658 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
4659 // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
4660
4661 if (templated_load)
4662 {
4663 load_expr = type_to_glsl(type);
4664 auto scalar_type = type;
4665 scalar_type.vecsize = 1;
4666 scalar_type.columns = 1;
4667 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4668 }
4669 else
4670 load_expr = type_to_glsl(type: target_type);
4671
4672 load_expr += "(";
4673
4674 for (uint32_t c = 0; c < type.columns; c++)
4675 {
4676 for (uint32_t r = 0; r < type.vecsize; r++)
4677 {
4678 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4679 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ")");
4680
4681 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
4682 load_expr += ", ";
4683 }
4684 }
4685 load_expr += ")";
4686 }
4687
4688 if (!templated_load)
4689 {
4690 auto bitcast_op = bitcast_glsl_op(out_type: type, in_type: target_type);
4691 if (!bitcast_op.empty())
4692 load_expr = join(ts&: bitcast_op, ts: "(", ts&: load_expr, ts: ")");
4693 }
4694
4695 if (lhs.empty())
4696 {
4697 assert(expr);
4698 *expr = std::move(load_expr);
4699 }
4700 else
4701 statement(ts: lhs, ts: " = ", ts&: load_expr, ts: ";");
4702}
4703
4704void CompilerHLSL::emit_load(const Instruction &instruction)
4705{
4706 auto ops = stream(instr: instruction);
4707
4708 auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
4709 if (chain)
4710 {
4711 uint32_t result_type = ops[0];
4712 uint32_t id = ops[1];
4713 uint32_t ptr = ops[2];
4714
4715 auto &type = get<SPIRType>(id: result_type);
4716 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
4717
4718 if (composite_load)
4719 {
4720 // We cannot make this work in one single expression as we might have nested structures and arrays,
4721 // so unroll the load to an uninitialized temporary.
4722 emit_uninitialized_temporary_expression(type: result_type, id);
4723 read_access_chain(expr: nullptr, lhs: to_expression(id), chain: *chain);
4724 track_expression_read(id: chain->self);
4725 }
4726 else
4727 {
4728 string load_expr;
4729 read_access_chain(expr: &load_expr, lhs: "", chain: *chain);
4730
4731 bool forward = should_forward(id: ptr) && forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
4732
4733			// If we are forwarding this load, don't register the read to the access chain here;
4734			// defer that to when we actually use the expression, via the
4735			// add_implied_read_expression mechanism.
4736 if (!forward)
4737 track_expression_read(id: chain->self);
4738
4739 // Do not forward complex load sequences like matrices, structs and arrays.
4740 if (type.columns > 1)
4741 forward = false;
4742
4743 auto &e = emit_op(result_type, result_id: id, rhs: load_expr, forward_rhs: forward, suppress_usage_tracking: true);
4744 e.need_transpose = false;
4745 register_read(expr: id, chain: ptr, forwarded: forward);
4746 inherit_expression_dependencies(dst: id, source: ptr);
4747 if (forward)
4748 add_implied_read_expression(e, source: chain->self);
4749 }
4750 }
4751 else
4752 CompilerGLSL::emit_instruction(instr: instruction);
4753}
4754
4755void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4756 const SmallVector<uint32_t> &composite_chain)
4757{
4758 auto *ptype = &get<SPIRType>(id: chain.basetype);
4759 while (ptype->pointer)
4760 {
4761 ptype = &get<SPIRType>(id: ptype->basetype);
4762 }
4763 auto &type = *ptype;
4764
4765 // Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
4766 auto ident = get_unique_identifier();
4767
4768 uint32_t id = ir.increase_bound_by(count: 2);
4769 uint32_t int_type_id = id + 1;
4770 SPIRType int_type { OpTypeInt };
4771 int_type.basetype = SPIRType::Int;
4772 int_type.width = 32;
4773 set<SPIRType>(id: int_type_id, args&: int_type);
4774 set<SPIRExpression>(id, args&: ident, args&: int_type_id, args: true);
4775 set_name(id, name: ident);
4776 suppressed_usage_tracking.insert(x: id);
4777
4778 statement(ts: "[unroll]");
4779 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
4780 ts&: ident, ts: "++)");
4781 begin_scope();
4782 auto subchain = chain;
4783 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
4784 subchain.basetype = type.parent_type;
4785
4786	// Forcefully allow us to reference this ID here by setting the MSB (the literal-is-ID marker used by access_chain_internal).
4787 auto subcomposite_chain = composite_chain;
4788 subcomposite_chain.push_back(t: 0x80000000u | id);
4789
4790 if (!get<SPIRType>(id: subchain.basetype).array.empty())
4791 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
4792
4793 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4794 end_scope();
4795}
4796
4797void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4798 const SmallVector<uint32_t> &composite_chain)
4799{
4800 auto &type = get<SPIRType>(id: chain.basetype);
4801 uint32_t member_count = uint32_t(type.member_types.size());
4802 auto subchain = chain;
4803
4804 auto subcomposite_chain = composite_chain;
4805 subcomposite_chain.push_back(t: 0);
4806
4807 for (uint32_t i = 0; i < member_count; i++)
4808 {
4809 uint32_t offset = type_struct_member_offset(type, index: i);
4810 subchain.static_index = chain.static_index + offset;
4811 subchain.basetype = type.member_types[i];
4812
4813 subchain.matrix_stride = 0;
4814 subchain.array_stride = 0;
4815 subchain.row_major_matrix = false;
4816
4817 auto &member_type = get<SPIRType>(id: subchain.basetype);
4818 if (member_type.columns > 1)
4819 {
4820 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
4821 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
4822 }
4823
4824 if (!member_type.array.empty())
4825 subchain.array_stride = type_struct_member_array_stride(type, index: i);
4826
4827 subcomposite_chain.back() = i;
4828 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4829 }
4830}
4831
4832string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4833 bool enclose)
4834{
4835 string ret;
4836 if (composite_chain.empty())
4837 ret = to_expression(id: value);
4838 else
4839 {
4840 AccessChainMeta meta;
4841 ret = access_chain_internal(base: value, indices: composite_chain.data(), count: uint32_t(composite_chain.size()),
4842 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, meta: &meta);
4843 }
4844
4845 if (enclose)
4846 ret = enclose_expression(expr: ret);
4847 return ret;
4848}
4849
4850void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4851 const SmallVector<uint32_t> &composite_chain)
4852{
4853 auto &type = get<SPIRType>(id: chain.basetype);
4854
4855 // Make sure we trigger a read of the constituents in the access chain.
4856 track_expression_read(id: chain.self);
4857
4858 SPIRType target_type { is_scalar(type) ? OpTypeInt : type.op };
4859 target_type.basetype = SPIRType::UInt;
4860 target_type.vecsize = type.vecsize;
4861 target_type.columns = type.columns;
4862
4863 if (!type.array.empty())
4864 {
4865 write_access_chain_array(chain, value, composite_chain);
4866 register_write(chain: chain.self);
4867 return;
4868 }
4869 else if (type.basetype == SPIRType::Struct)
4870 {
4871 write_access_chain_struct(chain, value, composite_chain);
4872 register_write(chain: chain.self);
4873 return;
4874 }
4875 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4876 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4877 "native 16-bit types are enabled.");
4878
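	// Mirror of the load path: SM 6.2 can use templated RWByteAddressBuffer.Store<T>(offset, value); older
	// shader models store raw uint data with Store/Store2/Store3/Store4 and bitcast the value first.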
4879 bool templated_store = hlsl_options.shader_model >= 62;
4880
4881 auto base = chain.base;
4882 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
4883 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
4884
4885 string template_expr;
4886 if (templated_store)
4887 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
4888
4889 if (type.columns == 1 && !chain.row_major_matrix)
4890 {
4891 const char *store_op = nullptr;
4892 switch (type.vecsize)
4893 {
4894 case 1:
4895 store_op = "Store";
4896 break;
4897 case 2:
4898 store_op = "Store2";
4899 break;
4900 case 3:
4901 store_op = "Store3";
4902 break;
4903 case 4:
4904 store_op = "Store4";
4905 break;
4906 default:
4907 SPIRV_CROSS_THROW("Unknown vector size.");
4908 }
4909
4910 auto store_expr = write_access_chain_value(value, composite_chain, enclose: false);
4911
4912 if (!templated_store)
4913 {
4914 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4915 if (!bitcast_op.empty())
4916 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4917 }
4918 else
4919 store_op = "Store";
4920 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ", ",
4921 ts&: store_expr, ts: ");");
4922 }
4923 else if (type.columns == 1)
4924 {
4925 if (templated_store)
4926 {
4927 auto scalar_type = type;
4928 scalar_type.vecsize = 1;
4929 scalar_type.columns = 1;
4930 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4931 }
4932
4933 // Strided store.
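		// The components of this vector are not contiguous in the buffer (it comes from a row-major matrix),
		// so store each component separately at its strided offset.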
4934 for (uint32_t r = 0; r < type.vecsize; r++)
4935 {
4936 auto store_expr = write_access_chain_value(value, composite_chain, enclose: true);
4937 if (type.vecsize > 1)
4938 {
4939 store_expr += ".";
4940 store_expr += index_to_swizzle(index: r);
4941 }
4942 remove_duplicate_swizzle(op&: store_expr);
4943
4944 if (!templated_store)
4945 {
4946 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4947 if (!bitcast_op.empty())
4948 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4949 }
4950
4951 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4952 ts: chain.static_index + chain.matrix_stride * r, ts: ", ", ts&: store_expr, ts: ");");
4953 }
4954 }
4955 else if (!chain.row_major_matrix)
4956 {
4957 const char *store_op = nullptr;
4958 switch (type.vecsize)
4959 {
4960 case 1:
4961 store_op = "Store";
4962 break;
4963 case 2:
4964 store_op = "Store2";
4965 break;
4966 case 3:
4967 store_op = "Store3";
4968 break;
4969 case 4:
4970 store_op = "Store4";
4971 break;
4972 default:
4973 SPIRV_CROSS_THROW("Unknown vector size.");
4974 }
4975
4976 if (templated_store)
4977 {
4978 store_op = "Store";
4979 auto vector_type = type;
4980 vector_type.columns = 1;
4981 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
4982 }
4983
4984 for (uint32_t c = 0; c < type.columns; c++)
4985 {
4986 auto store_expr = join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "]");
4987
4988 if (!templated_store)
4989 {
4990 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4991 if (!bitcast_op.empty())
4992 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4993 }
4994
4995 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4996 ts: chain.static_index + c * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
4997 }
4998 }
4999 else
5000 {
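		// Row-major matrix: emit one scalar store per element, at offset
		// column * (element width in bytes) + row * matrix_stride.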
5001 if (templated_store)
5002 {
5003 auto scalar_type = type;
5004 scalar_type.vecsize = 1;
5005 scalar_type.columns = 1;
5006 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
5007 }
5008
5009 for (uint32_t r = 0; r < type.vecsize; r++)
5010 {
5011 for (uint32_t c = 0; c < type.columns; c++)
5012 {
5013 auto store_expr =
5014 join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "].", ts: index_to_swizzle(index: r));
5015 remove_duplicate_swizzle(op&: store_expr);
5016 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
5017 if (!bitcast_op.empty())
5018 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
5019 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
5020 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
5021 }
5022 }
5023 }
5024
5025 register_write(chain: chain.self);
5026}
5027
5028void CompilerHLSL::emit_store(const Instruction &instruction)
5029{
5030 auto ops = stream(instr: instruction);
5031 if (options.vertex.flip_vert_y)
5032 {
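		// Stores that hit the meshlet position Y component are routed through the spvFlipVertY() helper
		// when flip_vert_y is requested.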
5033 auto *expr = maybe_get<SPIRExpression>(id: ops[0]);
5034 if (expr != nullptr && expr->access_meshlet_position_y)
5035 {
5036 auto lhs = to_dereferenced_expression(id: ops[0]);
5037 auto rhs = to_unpacked_expression(id: ops[1]);
5038 statement(ts&: lhs, ts: " = spvFlipVertY(", ts&: rhs, ts: ");");
5039 register_write(chain: ops[0]);
5040 return;
5041 }
5042 }
5043
5044 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
5045 if (chain)
5046 write_access_chain(chain: *chain, value: ops[1], composite_chain: {});
5047 else
5048 CompilerGLSL::emit_instruction(instr: instruction);
5049}
5050
5051void CompilerHLSL::emit_access_chain(const Instruction &instruction)
5052{
5053 auto ops = stream(instr: instruction);
5054 uint32_t length = instruction.length;
5055
5056 bool need_byte_access_chain = false;
5057 auto &type = expression_type(id: ops[2]);
5058 const auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
5059
5060 if (chain)
5061 {
5062 // Keep tacking on an existing access chain.
5063 need_byte_access_chain = true;
5064 }
5065 else if (type.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock))
5066 {
5067 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
5068 // to emit SPIRAccessChain rather than a plain SPIRExpression.
5069 uint32_t chain_arguments = length - 3;
5070 if (chain_arguments > type.array.size())
5071 need_byte_access_chain = true;
5072 }
5073
5074 if (need_byte_access_chain)
5075 {
		// If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within
		// the block, not to an array of SSBOs.
5078 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
5079
5080 auto *backing_variable = maybe_get_backing_variable(chain: ops[2]);
5081
5082 if (backing_variable != nullptr && is_user_type_structured(id: backing_variable->self))
5083 {
5084 CompilerGLSL::emit_instruction(instr: instruction);
5085 return;
5086 }
5087
5088 string base;
5089 if (to_plain_buffer_length != 0)
5090 base = access_chain(base: ops[2], indices: &ops[3], count: to_plain_buffer_length, target_type: get<SPIRType>(id: ops[0]));
5091 else if (chain)
5092 base = chain->base;
5093 else
5094 base = to_expression(id: ops[2]);
5095
5096 // Start traversing type hierarchy at the proper non-pointer types.
5097 auto *basetype = &get_pointee_type(type);
5098
5099 // Traverse the type hierarchy down to the actual buffer types.
5100 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
5101 {
5102 assert(basetype->parent_type);
5103 basetype = &get<SPIRType>(id: basetype->parent_type);
5104 }
5105
5106 uint32_t matrix_stride = 0;
5107 uint32_t array_stride = 0;
5108 bool row_major_matrix = false;
5109
5110 // Inherit matrix information.
5111 if (chain)
5112 {
5113 matrix_stride = chain->matrix_stride;
5114 row_major_matrix = chain->row_major_matrix;
5115 array_stride = chain->array_stride;
5116 }
5117
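		// Fold the remaining indices into a flattened byte offset split into dynamic and static parts,
		// tracking matrix/array strides and row-majorness along the way.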
5118 auto offsets = flattened_access_chain_offset(basetype: *basetype, indices: &ops[3 + to_plain_buffer_length],
5119 count: length - 3 - to_plain_buffer_length, offset: 0, word_stride: 1, need_transpose: &row_major_matrix,
5120 matrix_stride: &matrix_stride, array_stride: &array_stride);
5121
5122 auto &e = set<SPIRAccessChain>(id: ops[1], args: ops[0], args: type.storage, args&: base, args&: offsets.first, args&: offsets.second);
5123 e.row_major_matrix = row_major_matrix;
5124 e.matrix_stride = matrix_stride;
5125 e.array_stride = array_stride;
5126 e.immutable = should_forward(id: ops[2]);
5127 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
5128
5129 if (chain)
5130 {
5131 e.dynamic_index += chain->dynamic_index;
5132 e.static_index += chain->static_index;
5133 }
5134
5135 for (uint32_t i = 2; i < length; i++)
5136 {
5137 inherit_expression_dependencies(dst: ops[1], source: ops[i]);
5138 add_implied_read_expression(e, source: ops[i]);
5139 }
5140 }
5141 else
5142 {
5143 CompilerGLSL::emit_instruction(instr: instruction);
5144 }
5145}
5146
5147void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
5148{
5149 const char *atomic_op = nullptr;
5150
5151 string value_expr;
5152 if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
5153 value_expr = to_expression(id: ops[op == OpAtomicCompareExchange ? 6 : 5]);
5154
5155 bool is_atomic_store = false;
5156
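	// HLSL has no dedicated atomic load/store: loads are emulated as InterlockedAdd with 0,
	// and stores as InterlockedExchange whose previous value is written to a throwaway temporary.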
5157 switch (op)
5158 {
5159 case OpAtomicIIncrement:
5160 atomic_op = "InterlockedAdd";
5161 value_expr = "1";
5162 break;
5163
5164 case OpAtomicIDecrement:
5165 atomic_op = "InterlockedAdd";
5166 value_expr = "-1";
5167 break;
5168
5169 case OpAtomicLoad:
5170 atomic_op = "InterlockedAdd";
5171 value_expr = "0";
5172 break;
5173
5174 case OpAtomicISub:
5175 atomic_op = "InterlockedAdd";
5176 value_expr = join(ts: "-", ts: enclose_expression(expr: value_expr));
5177 break;
5178
5179 case OpAtomicSMin:
5180 case OpAtomicUMin:
5181 atomic_op = "InterlockedMin";
5182 break;
5183
5184 case OpAtomicSMax:
5185 case OpAtomicUMax:
5186 atomic_op = "InterlockedMax";
5187 break;
5188
5189 case OpAtomicAnd:
5190 atomic_op = "InterlockedAnd";
5191 break;
5192
5193 case OpAtomicOr:
5194 atomic_op = "InterlockedOr";
5195 break;
5196
5197 case OpAtomicXor:
5198 atomic_op = "InterlockedXor";
5199 break;
5200
5201 case OpAtomicIAdd:
5202 atomic_op = "InterlockedAdd";
5203 break;
5204
5205 case OpAtomicExchange:
5206 atomic_op = "InterlockedExchange";
5207 break;
5208
5209 case OpAtomicStore:
5210 atomic_op = "InterlockedExchange";
5211 is_atomic_store = true;
5212 break;
5213
5214 case OpAtomicCompareExchange:
5215 if (length < 8)
5216 SPIRV_CROSS_THROW("Not enough data for opcode.");
5217 atomic_op = "InterlockedCompareExchange";
5218 value_expr = join(ts: to_expression(id: ops[7]), ts: ", ", ts&: value_expr);
5219 break;
5220
5221 default:
5222 SPIRV_CROSS_THROW("Unknown atomic opcode.");
5223 }
5224
5225 if (is_atomic_store)
5226 {
5227 auto &data_type = expression_type(id: ops[0]);
5228 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
5229
5230 auto &tmp_id = extra_sub_expressions[ops[0]];
5231 if (!tmp_id)
5232 {
5233 tmp_id = ir.increase_bound_by(count: 1);
5234 emit_uninitialized_temporary_expression(type: get_pointee_type(type: data_type).self, id: tmp_id);
5235 }
5236
5237 if (data_type.storage == StorageClassImage || !chain)
5238 {
5239 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[0]), ts: ", ",
5240 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
5241 }
5242 else
5243 {
5244 string base = chain->base;
5245 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
5246 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
			// RWByteAddressBuffer is always uint in its underlying type.
5248 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ",
5249 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
5250 }
5251 }
5252 else
5253 {
5254 uint32_t result_type = ops[0];
5255 uint32_t id = ops[1];
5256 forced_temporaries.insert(x: ops[1]);
5257
5258 auto &type = get<SPIRType>(id: result_type);
5259 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
5260
5261 auto &data_type = expression_type(id: ops[2]);
5262 auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
5263 SPIRType::BaseType expr_type;
5264 if (data_type.storage == StorageClassImage || !chain)
5265 {
5266 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: value_expr, ts: ", ", ts: to_name(id), ts: ");");
5267 expr_type = data_type.basetype;
5268 }
5269 else
5270 {
			// RWByteAddressBuffer is always uint in its underlying type.
5272 string base = chain->base;
5273 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
5274 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
5275 expr_type = SPIRType::UInt;
5276 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ", ts&: value_expr,
5277 ts: ", ", ts: to_name(id), ts: ");");
5278 }
5279
5280 auto expr = bitcast_expression(target_type: type, expr_type, expr: to_name(id));
5281 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
5282 }
5283 flush_all_atomic_capable_variables();
5284}
5285
5286void CompilerHLSL::emit_subgroup_op(const Instruction &i)
5287{
5288 if (hlsl_options.shader_model < 60)
		SPIRV_CROSS_THROW("Wave ops require SM 6.0 or higher.");
5290
5291 const uint32_t *ops = stream(instr: i);
5292 auto op = static_cast<Op>(i.op);
5293
5294 uint32_t result_type = ops[0];
5295 uint32_t id = ops[1];
5296
5297 auto scope = static_cast<Scope>(evaluate_constant_u32(id: ops[2]));
5298 if (scope != ScopeSubgroup)
5299 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
5300
5301 const auto make_inclusive_Sum = [&](const string &expr) -> string {
5302 return join(ts: expr, ts: " + ", ts: to_expression(id: ops[4]));
5303 };
5304
5305 const auto make_inclusive_Product = [&](const string &expr) -> string {
5306 return join(ts: expr, ts: " * ", ts: to_expression(id: ops[4]));
5307 };
5308
5309 // If we need to do implicit bitcasts, make sure we do it with the correct type.
5310 uint32_t integer_width = get_integer_width_for_instruction(instr: i);
5311 auto int_type = to_signed_basetype(width: integer_width);
5312 auto uint_type = to_unsigned_basetype(width: integer_width);
5313
5314#define make_inclusive_BitAnd(expr) ""
5315#define make_inclusive_BitOr(expr) ""
5316#define make_inclusive_BitXor(expr) ""
5317#define make_inclusive_Min(expr) ""
5318#define make_inclusive_Max(expr) ""
5319
5320 switch (op)
5321 {
5322 case OpGroupNonUniformElect:
5323 emit_op(result_type, result_id: id, rhs: "WaveIsFirstLane()", forward_rhs: true);
5324 break;
5325
5326 case OpGroupNonUniformBroadcast:
5327 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
5328 break;
5329
5330 case OpGroupNonUniformBroadcastFirst:
5331 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveReadLaneFirst");
5332 break;
5333
5334 case OpGroupNonUniformBallot:
5335 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveBallot");
5336 break;
5337
5338 case OpGroupNonUniformInverseBallot:
5339 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
5340
5341 case OpGroupNonUniformBallotBitExtract:
5342 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
5343
5344 case OpGroupNonUniformBallotFindLSB:
5345 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
5346
5347 case OpGroupNonUniformBallotFindMSB:
5348 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
5349
5350 case OpGroupNonUniformBallotBitCount:
5351 {
5352 auto operation = static_cast<GroupOperation>(ops[3]);
5353 bool forward = should_forward(id: ops[4]);
5354 if (operation == GroupOperationReduce)
5355 {
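			// The ballot is a uint4 bitmask, so a full reduction is the sum of countbits() over all four components.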
5356 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x) + countbits(",
5357 ts: to_enclosed_expression(id: ops[4]), ts: ".y)");
5358 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z) + countbits(",
5359 ts: to_enclosed_expression(id: ops[4]), ts: ".w)");
5360 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5361 inherit_expression_dependencies(dst: id, source: ops[4]);
5362 }
5363 else if (operation == GroupOperationInclusiveScan)
5364 {
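			// For the inclusive scan, mask the ballot with gl_SubgroupLeMask so only lanes up to and
			// including the current lane contribute to the popcount.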
5365 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x & gl_SubgroupLeMask.x) + countbits(",
5366 ts: to_enclosed_expression(id: ops[4]), ts: ".y & gl_SubgroupLeMask.y)");
5367 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z & gl_SubgroupLeMask.z) + countbits(",
5368 ts: to_enclosed_expression(id: ops[4]), ts: ".w & gl_SubgroupLeMask.w)");
5369 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5370 if (!active_input_builtins.get(bit: BuiltInSubgroupLeMask))
5371 {
5372 active_input_builtins.set(BuiltInSubgroupLeMask);
5373 force_recompile_guarantee_forward_progress();
5374 }
5375 }
5376 else if (operation == GroupOperationExclusiveScan)
5377 {
5378 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x & gl_SubgroupLtMask.x) + countbits(",
5379 ts: to_enclosed_expression(id: ops[4]), ts: ".y & gl_SubgroupLtMask.y)");
5380 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z & gl_SubgroupLtMask.z) + countbits(",
5381 ts: to_enclosed_expression(id: ops[4]), ts: ".w & gl_SubgroupLtMask.w)");
5382 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
5383 if (!active_input_builtins.get(bit: BuiltInSubgroupLtMask))
5384 {
5385 active_input_builtins.set(BuiltInSubgroupLtMask);
5386 force_recompile_guarantee_forward_progress();
5387 }
5388 }
5389 else
5390 SPIRV_CROSS_THROW("Invalid BitCount operation.");
5391 break;
5392 }
5393
5394 case OpGroupNonUniformShuffle:
5395 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
5396 break;
5397 case OpGroupNonUniformShuffleXor:
5398 {
5399 bool forward = should_forward(id: ops[3]);
5400 emit_op(result_type: ops[0], result_id: ops[1],
5401 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5402 ts: "WaveGetLaneIndex() ^ ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5403 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5404 break;
5405 }
5406 case OpGroupNonUniformShuffleUp:
5407 {
5408 bool forward = should_forward(id: ops[3]);
5409 emit_op(result_type: ops[0], result_id: ops[1],
5410 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5411 ts: "WaveGetLaneIndex() - ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5412 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5413 break;
5414 }
5415 case OpGroupNonUniformShuffleDown:
5416 {
5417 bool forward = should_forward(id: ops[3]);
5418 emit_op(result_type: ops[0], result_id: ops[1],
5419 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
5420 ts: "WaveGetLaneIndex() + ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
5421 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
5422 break;
5423 }
5424
5425 case OpGroupNonUniformAll:
5426 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllTrue");
5427 break;
5428
5429 case OpGroupNonUniformAny:
5430 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAnyTrue");
5431 break;
5432
5433 case OpGroupNonUniformAllEqual:
5434 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllEqual");
5435 break;
5436
5437 // clang-format off
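// Map OpGroupNonUniform arithmetic/bitwise reductions onto WaveActive* intrinsics and scans onto WavePrefix*.
// WavePrefix* intrinsics are exclusive scans, so inclusive scans are built by folding in the current lane's
// own value through the make_inclusive_* helpers defined above.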
5438#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
5439case OpGroupNonUniform##op: \
5440 { \
5441 auto operation = static_cast<GroupOperation>(ops[3]); \
5442 if (operation == GroupOperationReduce) \
5443 emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
5444 else if (operation == GroupOperationInclusiveScan && supports_scan) \
5445 { \
5446 bool forward = should_forward(ops[4]); \
5447 emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
5448 inherit_expression_dependencies(id, ops[4]); \
5449 } \
5450 else if (operation == GroupOperationExclusiveScan && supports_scan) \
5451 emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
5452 else if (operation == GroupOperationClusteredReduce) \
5453 SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
5454 else \
5455 SPIRV_CROSS_THROW("Invalid group operation."); \
5456 break; \
5457 }
5458
5459#define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
5460case OpGroupNonUniform##op: \
5461 { \
5462 auto operation = static_cast<GroupOperation>(ops[3]); \
5463 if (operation == GroupOperationReduce) \
5464 emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
5465 else \
5466 SPIRV_CROSS_THROW("Invalid group operation."); \
5467 break; \
5468 }
5469
5470 HLSL_GROUP_OP(FAdd, Sum, true)
5471 HLSL_GROUP_OP(FMul, Product, true)
5472 HLSL_GROUP_OP(FMin, Min, false)
5473 HLSL_GROUP_OP(FMax, Max, false)
5474 HLSL_GROUP_OP(IAdd, Sum, true)
5475 HLSL_GROUP_OP(IMul, Product, true)
5476 HLSL_GROUP_OP_CAST(SMin, Min, int_type)
5477 HLSL_GROUP_OP_CAST(SMax, Max, int_type)
5478 HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
5479 HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
5480 HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
5481 HLSL_GROUP_OP(BitwiseOr, BitOr, false)
5482 HLSL_GROUP_OP(BitwiseXor, BitXor, false)
5483 HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
5484 HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
5485 HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
5486
5487#undef HLSL_GROUP_OP
5488#undef HLSL_GROUP_OP_CAST
5489 // clang-format on
5490
5491 case OpGroupNonUniformQuadSwap:
5492 {
5493 uint32_t direction = evaluate_constant_u32(id: ops[4]);
5494 if (direction == 0)
5495 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossX");
5496 else if (direction == 1)
5497 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossY");
5498 else if (direction == 2)
5499 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossDiagonal");
5500 else
5501 SPIRV_CROSS_THROW("Invalid quad swap direction.");
5502 break;
5503 }
5504
5505 case OpGroupNonUniformQuadBroadcast:
5506 {
5507 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "QuadReadLaneAt");
5508 break;
5509 }
5510
5511 default:
5512 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
5513 }
5514
5515 register_control_dependent_expression(expr: id);
5516}
5517
5518void CompilerHLSL::emit_instruction(const Instruction &instruction)
5519{
5520 auto ops = stream(instr: instruction);
5521 auto opcode = static_cast<Op>(instruction.op);
5522
5523#define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
5524#define HLSL_BOP_CAST(op, type) \
5525 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
5526#define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
5527#define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
5528#define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
5529#define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
5530#define HLSL_BFOP_CAST(op, type) \
5531 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
5533#define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
5534
5535 // If we need to do implicit bitcasts, make sure we do it with the correct type.
5536 uint32_t integer_width = get_integer_width_for_instruction(instr: instruction);
5537 auto int_type = to_signed_basetype(width: integer_width);
5538 auto uint_type = to_unsigned_basetype(width: integer_width);
5539
5540 opcode = get_remapped_spirv_op(op: opcode);
5541
5542 switch (opcode)
5543 {
5544 case OpAccessChain:
5545 case OpInBoundsAccessChain:
5546 {
5547 emit_access_chain(instruction);
5548 break;
5549 }
5550 case OpBitcast:
5551 {
5552 auto bitcast_type = get_bitcast_type(result_type: ops[0], op0: ops[2]);
5553 if (bitcast_type == CompilerHLSL::TypeNormal)
5554 CompilerGLSL::emit_instruction(instr: instruction);
5555 else
5556 {
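			// These are the uint2 <-> 64-bit packing bitcasts; they go through the spvPackUint2x32()/spvUnpackUint2x32()
			// helpers, which are emitted on the recompile pass forced below.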
5557 if (!requires_uint2_packing)
5558 {
5559 requires_uint2_packing = true;
5560 force_recompile();
5561 }
5562
5563 if (bitcast_type == CompilerHLSL::TypePackUint2x32)
5564 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvPackUint2x32");
5565 else
5566 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvUnpackUint2x32");
5567 }
5568
5569 break;
5570 }
5571
5572 case OpSelect:
5573 {
5574 auto &value_type = expression_type(id: ops[3]);
5575 if (value_type.basetype == SPIRType::Struct || is_array(type: value_type))
5576 {
5577 // HLSL does not support ternary expressions on composites.
5578 // Cannot use branches, since we might be in a continue block
5579 // where explicit control flow is prohibited.
5580 // Emit a helper function where we can use control flow.
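			// Register the value type so the spvSelectComposite() helper is generated on the recompile pass,
			// then call it with an uninitialized temporary as the output argument.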
5581 TypeID value_type_id = expression_type_id(id: ops[3]);
5582 auto itr = std::find(first: composite_selection_workaround_types.begin(),
5583 last: composite_selection_workaround_types.end(),
5584 val: value_type_id);
5585 if (itr == composite_selection_workaround_types.end())
5586 {
5587 composite_selection_workaround_types.push_back(x: value_type_id);
5588 force_recompile();
5589 }
5590 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
5591 statement(ts: "spvSelectComposite(",
5592 ts: to_expression(id: ops[1]), ts: ", ", ts: to_expression(id: ops[2]), ts: ", ",
5593 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: ops[4]), ts: ");");
5594 }
5595 else
5596 CompilerGLSL::emit_instruction(instr: instruction);
5597 break;
5598 }
5599
5600 case OpStore:
5601 {
5602 emit_store(instruction);
5603 break;
5604 }
5605
5606 case OpLoad:
5607 {
5608 emit_load(instruction);
5609 break;
5610 }
5611
	case OpMatrixTimesVector:
	case OpVectorTimesMatrix:
	case OpMatrixTimesMatrix:
	{
		// Matrices are kept in a transposed state all the time, so always flip the multiplication order.
		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
		break;
	}
5632
5633 case OpOuterProduct:
5634 {
5635 uint32_t result_type = ops[0];
5636 uint32_t id = ops[1];
5637 uint32_t a = ops[2];
5638 uint32_t b = ops[3];
5639
5640 auto &type = get<SPIRType>(id: result_type);
5641 string expr = type_to_glsl_constructor(type);
5642 expr += "(";
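		// Build the result column by column: column c of the outer product is a * b[c].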
5643 for (uint32_t col = 0; col < type.columns; col++)
5644 {
5645 expr += to_enclosed_expression(id: a);
5646 expr += " * ";
5647 expr += to_extract_component_expression(id: b, index: col);
5648 if (col + 1 < type.columns)
5649 expr += ", ";
5650 }
5651 expr += ")";
5652 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: a) && should_forward(id: b));
5653 inherit_expression_dependencies(dst: id, source: a);
5654 inherit_expression_dependencies(dst: id, source: b);
5655 break;
5656 }
5657
5658 case OpFMod:
5659 {
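		// HLSL's fmod() keeps the sign of the numerator, which matches OpFRem (handled below) but not OpFMod.
		// Request the helper via requires_op_fmod and let the common GLSL path emit the call once it exists.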
5660 if (!requires_op_fmod)
5661 {
5662 requires_op_fmod = true;
5663 force_recompile();
5664 }
5665 CompilerGLSL::emit_instruction(instr: instruction);
5666 break;
5667 }
5668
5669 case OpFRem:
5670 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op: "fmod");
5671 break;
5672
5673 case OpImage:
5674 {
5675 uint32_t result_type = ops[0];
5676 uint32_t id = ops[1];
5677 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: ops[2]);
5678
5679 if (combined)
5680 {
5681 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: combined->image), forward_rhs: true, suppress_usage_tracking: true);
5682 auto *var = maybe_get_backing_variable(chain: combined->image);
5683 if (var)
5684 e.loaded_from = var->self;
5685 }
5686 else
5687 {
5688 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: ops[2]), forward_rhs: true, suppress_usage_tracking: true);
5689 auto *var = maybe_get_backing_variable(chain: ops[2]);
5690 if (var)
5691 e.loaded_from = var->self;
5692 }
5693 break;
5694 }
5695
5696 case OpDPdx:
5697 HLSL_UFOP(ddx);
5698 register_control_dependent_expression(expr: ops[1]);
5699 break;
5700
5701 case OpDPdy:
5702 HLSL_UFOP(ddy);
5703 register_control_dependent_expression(expr: ops[1]);
5704 break;
5705
5706 case OpDPdxFine:
5707 HLSL_UFOP(ddx_fine);
5708 register_control_dependent_expression(expr: ops[1]);
5709 break;
5710
5711 case OpDPdyFine:
5712 HLSL_UFOP(ddy_fine);
5713 register_control_dependent_expression(expr: ops[1]);
5714 break;
5715
5716 case OpDPdxCoarse:
5717 HLSL_UFOP(ddx_coarse);
5718 register_control_dependent_expression(expr: ops[1]);
5719 break;
5720
5721 case OpDPdyCoarse:
5722 HLSL_UFOP(ddy_coarse);
5723 register_control_dependent_expression(expr: ops[1]);
5724 break;
5725
5726 case OpFwidth:
5727 case OpFwidthCoarse:
5728 case OpFwidthFine:
5729 HLSL_UFOP(fwidth);
5730 register_control_dependent_expression(expr: ops[1]);
5731 break;
5732
5733 case OpLogicalNot:
5734 {
5735 auto result_type = ops[0];
5736 auto id = ops[1];
5737 auto &type = get<SPIRType>(id: result_type);
5738
5739 if (type.vecsize > 1)
5740 emit_unrolled_unary_op(result_type, result_id: id, operand: ops[2], op: "!");
5741 else
5742 HLSL_UOP(!);
5743 break;
5744 }
5745
5746 case OpIEqual:
5747 {
5748 auto result_type = ops[0];
5749 auto id = ops[1];
5750
5751 if (expression_type(id: ops[2]).vecsize > 1)
5752 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5753 else
5754 HLSL_BOP_CAST(==, int_type);
5755 break;
5756 }
5757
5758 case OpLogicalEqual:
5759 case OpFOrdEqual:
5760 case OpFUnordEqual:
5761 {
5762 // HLSL != operator is unordered.
5763 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5764 // isnan() is apparently implemented as x != x as well.
5765 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
5766 // HACK: FUnordEqual will be implemented as FOrdEqual.
5767
5768 auto result_type = ops[0];
5769 auto id = ops[1];
5770
5771 if (expression_type(id: ops[2]).vecsize > 1)
5772 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5773 else
5774 HLSL_BOP(==);
5775 break;
5776 }
5777
5778 case OpINotEqual:
5779 {
5780 auto result_type = ops[0];
5781 auto id = ops[1];
5782
5783 if (expression_type(id: ops[2]).vecsize > 1)
5784 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5785 else
5786 HLSL_BOP_CAST(!=, int_type);
5787 break;
5788 }
5789
5790 case OpLogicalNotEqual:
5791 case OpFOrdNotEqual:
5792 case OpFUnordNotEqual:
5793 {
5794 // HLSL != operator is unordered.
5795 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5796 // isnan() is apparently implemented as x != x as well.
5797
5798 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5799 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
5800 // Adding a lot of NaN checks would be a breaking change from perspective of performance.
5801 // SPIR-V will generally use isnan() checks when this even matters.
5802 // HACK: FOrdNotEqual will be implemented as FUnordEqual.
5803
5804 auto result_type = ops[0];
5805 auto id = ops[1];
5806
5807 if (expression_type(id: ops[2]).vecsize > 1)
5808 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5809 else
5810 HLSL_BOP(!=);
5811 break;
5812 }
5813
5814 case OpUGreaterThan:
5815 case OpSGreaterThan:
5816 {
5817 auto result_type = ops[0];
5818 auto id = ops[1];
5819 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5820
5821 if (expression_type(id: ops[2]).vecsize > 1)
5822 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: type);
5823 else
5824 HLSL_BOP_CAST(>, type);
5825 break;
5826 }
5827
5828 case OpFOrdGreaterThan:
5829 {
5830 auto result_type = ops[0];
5831 auto id = ops[1];
5832
5833 if (expression_type(id: ops[2]).vecsize > 1)
5834 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: SPIRType::Unknown);
5835 else
5836 HLSL_BOP(>);
5837 break;
5838 }
5839
5840 case OpFUnordGreaterThan:
5841 {
5842 auto result_type = ops[0];
5843 auto id = ops[1];
5844
5845 if (expression_type(id: ops[2]).vecsize > 1)
5846 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: true, expected_type: SPIRType::Unknown);
5847 else
5848 CompilerGLSL::emit_instruction(instr: instruction);
5849 break;
5850 }
5851
5852 case OpUGreaterThanEqual:
5853 case OpSGreaterThanEqual:
5854 {
5855 auto result_type = ops[0];
5856 auto id = ops[1];
5857
5858 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5859 if (expression_type(id: ops[2]).vecsize > 1)
5860 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: type);
5861 else
5862 HLSL_BOP_CAST(>=, type);
5863 break;
5864 }
5865
5866 case OpFOrdGreaterThanEqual:
5867 {
5868 auto result_type = ops[0];
5869 auto id = ops[1];
5870
5871 if (expression_type(id: ops[2]).vecsize > 1)
5872 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: SPIRType::Unknown);
5873 else
5874 HLSL_BOP(>=);
5875 break;
5876 }
5877
5878 case OpFUnordGreaterThanEqual:
5879 {
5880 auto result_type = ops[0];
5881 auto id = ops[1];
5882
5883 if (expression_type(id: ops[2]).vecsize > 1)
5884 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: true, expected_type: SPIRType::Unknown);
5885 else
5886 CompilerGLSL::emit_instruction(instr: instruction);
5887 break;
5888 }
5889
5890 case OpULessThan:
5891 case OpSLessThan:
5892 {
5893 auto result_type = ops[0];
5894 auto id = ops[1];
5895
5896 auto type = opcode == OpULessThan ? uint_type : int_type;
5897 if (expression_type(id: ops[2]).vecsize > 1)
5898 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: type);
5899 else
5900 HLSL_BOP_CAST(<, type);
5901 break;
5902 }
5903
5904 case OpFOrdLessThan:
5905 {
5906 auto result_type = ops[0];
5907 auto id = ops[1];
5908
5909 if (expression_type(id: ops[2]).vecsize > 1)
5910 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: SPIRType::Unknown);
5911 else
5912 HLSL_BOP(<);
5913 break;
5914 }
5915
5916 case OpFUnordLessThan:
5917 {
5918 auto result_type = ops[0];
5919 auto id = ops[1];
5920
5921 if (expression_type(id: ops[2]).vecsize > 1)
5922 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: true, expected_type: SPIRType::Unknown);
5923 else
5924 CompilerGLSL::emit_instruction(instr: instruction);
5925 break;
5926 }
5927
5928 case OpULessThanEqual:
5929 case OpSLessThanEqual:
5930 {
5931 auto result_type = ops[0];
5932 auto id = ops[1];
5933
5934 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5935 if (expression_type(id: ops[2]).vecsize > 1)
5936 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: type);
5937 else
5938 HLSL_BOP_CAST(<=, type);
5939 break;
5940 }
5941
5942 case OpFOrdLessThanEqual:
5943 {
5944 auto result_type = ops[0];
5945 auto id = ops[1];
5946
5947 if (expression_type(id: ops[2]).vecsize > 1)
5948 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: SPIRType::Unknown);
5949 else
5950 HLSL_BOP(<=);
5951 break;
5952 }
5953
5954 case OpFUnordLessThanEqual:
5955 {
5956 auto result_type = ops[0];
5957 auto id = ops[1];
5958
5959 if (expression_type(id: ops[2]).vecsize > 1)
5960 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: true, expected_type: SPIRType::Unknown);
5961 else
5962 CompilerGLSL::emit_instruction(instr: instruction);
5963 break;
5964 }
5965
5966 case OpImageQueryLod:
5967 emit_texture_op(i: instruction, sparse: false);
5968 break;
5969
5970 case OpImageQuerySizeLod:
5971 {
5972 auto result_type = ops[0];
5973 auto id = ops[1];
5974
5975 require_texture_query_variant(var_id: ops[2]);
5976 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
5977 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
5978
5979 auto expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ",
5980 ts: bitcast_expression(target_type: SPIRType::UInt, arg: ops[3]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
5981
5982 auto &restype = get<SPIRType>(id: ops[0]);
5983 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
5984 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
5985 break;
5986 }
5987
5988 case OpImageQuerySize:
5989 {
5990 auto result_type = ops[0];
5991 auto id = ops[1];
5992
5993 require_texture_query_variant(var_id: ops[2]);
5994 bool uav = expression_type(id: ops[2]).image.sampled == 2;
5995
5996 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
5997 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
5998 uav = false;
5999
6000 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
6001 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
6002
6003 string expr;
6004 if (uav)
6005 expr = join(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
6006 else
6007 expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts&: dummy_samples_levels, ts: ")");
6008
6009 auto &restype = get<SPIRType>(id: ops[0]);
6010 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
6011 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
6012 break;
6013 }
6014
6015 case OpImageQuerySamples:
6016 case OpImageQueryLevels:
6017 {
6018 auto result_type = ops[0];
6019 auto id = ops[1];
6020
6021 require_texture_query_variant(var_id: ops[2]);
6022 bool uav = expression_type(id: ops[2]).image.sampled == 2;
6023 if (opcode == OpImageQueryLevels && uav)
6024 SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
6025
6026 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
6027 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
6028 uav = false;
6029
6030 // Keep it simple and do not emit special variants to make this look nicer ...
6031 // This stuff is barely, if ever, used.
6032 forced_temporaries.insert(x: id);
6033 auto &type = get<SPIRType>(id: result_type);
6034 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
6035
6036 if (uav)
6037 statement(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts: to_name(id), ts: ");");
6038 else
6039 statement(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts: to_name(id), ts: ");");
6040
6041 auto &restype = get<SPIRType>(id: ops[0]);
6042 auto expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr: to_name(id));
6043 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
6044 break;
6045 }
6046
6047 case OpImageRead:
6048 {
6049 uint32_t result_type = ops[0];
6050 uint32_t id = ops[1];
6051 auto *var = maybe_get_backing_variable(chain: ops[2]);
6052 auto &type = expression_type(id: ops[2]);
6053 bool subpass_data = type.image.dim == DimSubpassData;
6054 bool pure = false;
6055
6056 string imgexpr;
6057
6058 if (subpass_data)
6059 {
6060 if (hlsl_options.shader_model < 40)
6061 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
6062
6063 // Similar to GLSL, implement subpass loads using texelFetch.
6064 if (type.image.ms)
6065 {
6066 uint32_t operands = ops[4];
6067 if (operands != ImageOperandsSampleMask || instruction.length != 6)
6068 SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
6069 uint32_t sample = ops[5];
6070 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int2(gl_FragCoord.xy), ", ts: to_expression(id: sample), ts: ")");
6071 }
6072 else
6073 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int3(int2(gl_FragCoord.xy), 0))");
6074
6075 pure = true;
6076 }
6077 else
6078 {
6079 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: "[", ts: to_expression(id: ops[3]), ts: "]");
			// Unlike GLSL, where all images are effectively "vec4" and only the interpretation of the data changes,
			// the underlying image type in HLSL depends on the declared image format, so the result may need reswizzling.
6082
6083 bool force_srv =
6084 hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(id: var->self, decoration: DecorationNonWritable);
6085 pure = force_srv;
6086
6087 if (var && !subpass_data && !force_srv)
6088 imgexpr = remap_swizzle(result_type: get<SPIRType>(id: result_type),
6089 input_components: image_format_to_components(fmt: get<SPIRType>(id: var->basetype).image.format), expr: imgexpr);
6090 }
6091
6092 if (var)
6093 {
6094 bool forward = forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
6095 auto &e = emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: forward);
6096
6097 if (!pure)
6098 {
6099 e.loaded_from = var->self;
6100 if (forward)
6101 var->dependees.push_back(t: id);
6102 }
6103 }
6104 else
6105 emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: false);
6106
6107 inherit_expression_dependencies(dst: id, source: ops[2]);
6108 if (type.image.ms)
6109 inherit_expression_dependencies(dst: id, source: ops[5]);
6110 break;
6111 }
6112
6113 case OpImageWrite:
6114 {
6115 auto *var = maybe_get_backing_variable(chain: ops[0]);
6116
		// Unlike GLSL, where all images are effectively "vec4" and only the interpretation of the data changes,
		// the underlying image type in HLSL depends on the declared image format, so the value may need narrowing
		// to the format's component count before the store.
6119 auto value_expr = to_expression(id: ops[2]);
6120 if (var)
6121 {
6122 auto &type = get<SPIRType>(id: var->basetype);
6123 auto narrowed_type = get<SPIRType>(id: type.image.type);
6124 narrowed_type.vecsize = image_format_to_components(fmt: type.image.format);
6125 value_expr = remap_swizzle(result_type: narrowed_type, input_components: expression_type(id: ops[2]).vecsize, expr: value_expr);
6126 }
6127
6128 statement(ts: to_non_uniform_aware_expression(id: ops[0]), ts: "[", ts: to_expression(id: ops[1]), ts: "] = ", ts&: value_expr, ts: ";");
6129 if (var && variable_storage_is_aliased(var: *var))
6130 flush_all_aliased_variables();
6131 break;
6132 }
6133
6134 case OpImageTexelPointer:
6135 {
6136 uint32_t result_type = ops[0];
6137 uint32_t id = ops[1];
6138
6139 auto expr = to_expression(id: ops[2]);
6140 expr += join(ts: "[", ts: to_expression(id: ops[3]), ts: "]");
6141 auto &e = set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
6142
6143 // When using the pointer, we need to know which variable it is actually loaded from.
6144 auto *var = maybe_get_backing_variable(chain: ops[2]);
6145 e.loaded_from = var ? var->self : ID(0);
6146 inherit_expression_dependencies(dst: id, source: ops[3]);
6147 break;
6148 }
6149
6150 case OpAtomicFAddEXT:
6151 case OpAtomicFMinEXT:
6152 case OpAtomicFMaxEXT:
6153 SPIRV_CROSS_THROW("Floating-point atomics are not supported in HLSL.");
6154
6155 case OpAtomicCompareExchange:
6156 case OpAtomicExchange:
6157 case OpAtomicISub:
6158 case OpAtomicSMin:
6159 case OpAtomicUMin:
6160 case OpAtomicSMax:
6161 case OpAtomicUMax:
6162 case OpAtomicAnd:
6163 case OpAtomicOr:
6164 case OpAtomicXor:
6165 case OpAtomicIAdd:
6166 case OpAtomicIIncrement:
6167 case OpAtomicIDecrement:
6168 case OpAtomicLoad:
6169 case OpAtomicStore:
6170 {
6171 emit_atomic(ops, length: instruction.length, op: opcode);
6172 break;
6173 }
6174
6175 case OpControlBarrier:
6176 case OpMemoryBarrier:
6177 {
6178 uint32_t memory;
6179 uint32_t semantics;
6180
6181 if (opcode == OpMemoryBarrier)
6182 {
6183 memory = evaluate_constant_u32(id: ops[0]);
6184 semantics = evaluate_constant_u32(id: ops[1]);
6185 }
6186 else
6187 {
6188 memory = evaluate_constant_u32(id: ops[1]);
6189 semantics = evaluate_constant_u32(id: ops[2]);
6190 }
6191
6192 if (memory == ScopeSubgroup)
6193 {
6194 // No Wave-barriers in HLSL.
6195 break;
6196 }
6197
		// We only care about these flags; acquire/release and friends are not relevant to HLSL.
6199 semantics = mask_relevant_memory_semantics(semantics);
6200
6201 if (opcode == OpMemoryBarrier)
6202 {
6203 // If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier
6204 // does what we need, so we avoid redundant barriers.
6205 const Instruction *next = get_next_instruction_in_block(instr: instruction);
6206 if (next && next->op == OpControlBarrier)
6207 {
6208 auto *next_ops = stream(instr: *next);
6209 uint32_t next_memory = evaluate_constant_u32(id: next_ops[1]);
6210 uint32_t next_semantics = evaluate_constant_u32(id: next_ops[2]);
6211 next_semantics = mask_relevant_memory_semantics(semantics: next_semantics);
6212
6213 // There is no "just execution barrier" in HLSL.
				// If there are no memory semantics for the next instruction, assume group shared memory is synced.
6215 if (next_semantics == 0)
6216 next_semantics = MemorySemanticsWorkgroupMemoryMask;
6217
6218 bool memory_scope_covered = false;
6219 if (next_memory == memory)
6220 memory_scope_covered = true;
6221 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
6222 {
6223 // If we only care about workgroup memory, either Device or Workgroup scope is fine,
6224 // scope does not have to match.
6225 if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
6226 (memory == ScopeDevice || memory == ScopeWorkgroup))
6227 {
6228 memory_scope_covered = true;
6229 }
6230 }
6231 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
6232 {
6233 // The control barrier has device scope, but the memory barrier just has workgroup scope.
6234 memory_scope_covered = true;
6235 }
6236
6237 // If we have the same memory scope, and all memory types are covered, we're good.
6238 if (memory_scope_covered && (semantics & next_semantics) == semantics)
6239 break;
6240 }
6241 }
6242
6243 // We are synchronizing some memory or syncing execution,
6244 // so we cannot forward any loads beyond the memory barrier.
6245 if (semantics || opcode == OpControlBarrier)
6246 {
6247 assert(current_emitting_block);
6248 flush_control_dependent_expressions(block: current_emitting_block->self);
6249 flush_all_active_variables();
6250 }
6251
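		// Map the semantics onto HLSL barriers: workgroup-only memory -> GroupMemoryBarrier*, device-only ->
		// DeviceMemoryBarrier*, anything mixed or unknown -> AllMemoryBarrier*. OpControlBarrier additionally
		// needs the *WithGroupSync variants.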
6252 if (opcode == OpControlBarrier)
6253 {
			// We cannot emit just an execution barrier; when there are no memory semantics, pick the cheapest option.
6255 if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
6256 statement(ts: "GroupMemoryBarrierWithGroupSync();");
6257 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
6258 statement(ts: "DeviceMemoryBarrierWithGroupSync();");
6259 else
6260 statement(ts: "AllMemoryBarrierWithGroupSync();");
6261 }
6262 else
6263 {
6264 if (semantics == MemorySemanticsWorkgroupMemoryMask)
6265 statement(ts: "GroupMemoryBarrier();");
6266 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
6267 statement(ts: "DeviceMemoryBarrier();");
6268 else
6269 statement(ts: "AllMemoryBarrier();");
6270 }
6271 break;
6272 }
6273
6274 case OpBitFieldInsert:
6275 {
6276 if (!requires_bitfield_insert)
6277 {
6278 requires_bitfield_insert = true;
6279 force_recompile();
6280 }
6281
6282 auto expr = join(ts: "spvBitfieldInsert(", ts: to_expression(id: ops[2]), ts: ", ", ts: to_expression(id: ops[3]), ts: ", ",
6283 ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ")");
6284
6285 bool forward =
6286 should_forward(id: ops[2]) && should_forward(id: ops[3]) && should_forward(id: ops[4]) && should_forward(id: ops[5]);
6287
6288 auto &restype = get<SPIRType>(id: ops[0]);
6289 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
6290 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
6291 break;
6292 }
6293
6294 case OpBitFieldSExtract:
6295 case OpBitFieldUExtract:
6296 {
6297 if (!requires_bitfield_extract)
6298 {
6299 requires_bitfield_extract = true;
6300 force_recompile();
6301 }
6302
6303 if (opcode == OpBitFieldSExtract)
6304 HLSL_TFOP(spvBitfieldSExtract);
6305 else
6306 HLSL_TFOP(spvBitfieldUExtract);
6307 break;
6308 }
6309
6310 case OpBitCount:
6311 {
6312 auto basetype = expression_type(id: ops[2]).basetype;
6313 emit_unary_func_op_cast(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "countbits", input_type: basetype, expected_result_type: basetype);
6314 break;
6315 }
6316
6317 case OpBitReverse:
6318 HLSL_UFOP(reversebits);
6319 break;
6320
6321 case OpArrayLength:
6322 {
6323 auto *var = maybe_get_backing_variable(chain: ops[2]);
6324 if (!var)
6325 SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
6326
6327 auto &type = get<SPIRType>(id: var->basetype);
6328 if (!has_decoration(id: type.self, decoration: DecorationBlock) && !has_decoration(id: type.self, decoration: DecorationBufferBlock))
6329 SPIRV_CROSS_THROW("Array length expression must point to a block type.");
6330
6331 // This must be 32-bit uint, so we're good to go.
6332 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
6333 statement(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".GetDimensions(", ts: to_expression(id: ops[1]), ts: ");");
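		// GetDimensions() returns the buffer size in bytes; turn it into the element count of the runtime
		// array by subtracting the member offset and dividing by the array stride.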
6334 uint32_t offset = type_struct_member_offset(type, index: ops[3]);
6335 uint32_t stride = type_struct_member_array_stride(type, index: ops[3]);
6336 statement(ts: to_expression(id: ops[1]), ts: " = (", ts: to_expression(id: ops[1]), ts: " - ", ts&: offset, ts: ") / ", ts&: stride, ts: ";");
6337 break;
6338 }
6339
6340 case OpIsHelperInvocationEXT:
6341 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
6342 SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
6343 // Helper lane state with demote is volatile by nature.
6344 // Do not forward this.
6345 emit_op(result_type: ops[0], result_id: ops[1], rhs: "IsHelperLane()", forward_rhs: false);
6346 break;
6347
6348 case OpBeginInvocationInterlockEXT:
6349 case OpEndInvocationInterlockEXT:
6350 if (hlsl_options.shader_model < 51)
6351 SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
6352 break; // Nothing to do in the body
6353
6354 case OpRayQueryInitializeKHR:
6355 {
6356 flush_variable_declaration(id: ops[0]);
6357
6358 std::string ray_desc_name = get_unique_identifier();
6359 statement(ts: "RayDesc ", ts&: ray_desc_name, ts: " = {", ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ", ",
6360 ts: to_expression(id: ops[6]), ts: ", ", ts: to_expression(id: ops[7]), ts: "};");
6361
6362 statement(ts: to_expression(id: ops[0]), ts: ".TraceRayInline(",
6363 ts: to_expression(id: ops[1]), ts: ", ", // acc structure
6364 ts: to_expression(id: ops[2]), ts: ", ", // ray flags
6365 ts: to_expression(id: ops[3]), ts: ", ", // mask
6366 ts&: ray_desc_name, ts: ");"); // ray
6367 break;
6368 }
6369 case OpRayQueryProceedKHR:
6370 {
6371 flush_variable_declaration(id: ops[0]);
6372 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".Proceed()"), forward_rhs: false);
6373 break;
6374 }
6375 case OpRayQueryTerminateKHR:
6376 {
6377 flush_variable_declaration(id: ops[0]);
6378 statement(ts: to_expression(id: ops[0]), ts: ".Abort();");
6379 break;
6380 }
	case OpRayQueryGenerateIntersectionKHR:
	{
		flush_variable_declaration(ops[0]);
		statement(to_expression(ops[0]), ".CommitProceduralPrimitiveHit(", to_expression(ops[1]), ");");
		break;
	}
	case OpRayQueryConfirmIntersectionKHR:
	{
		flush_variable_declaration(ops[0]);
		statement(to_expression(ops[0]), ".CommitNonOpaqueTriangleHit();");
		break;
	}
	case OpRayQueryGetIntersectionTypeKHR:
	{
		emit_rayquery_function(".CommittedStatus()", ".CandidateType()", ops);
		break;
	}
	case OpRayQueryGetIntersectionTKHR:
	{
		emit_rayquery_function(".CommittedRayT()", ".CandidateTriangleRayT()", ops);
		break;
	}
	case OpRayQueryGetIntersectionInstanceCustomIndexKHR:
	{
		emit_rayquery_function(".CommittedInstanceID()", ".CandidateInstanceID()", ops);
		break;
	}
	case OpRayQueryGetIntersectionInstanceIdKHR:
	{
		emit_rayquery_function(".CommittedInstanceIndex()", ".CandidateInstanceIndex()", ops);
		break;
	}
	case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
	{
		emit_rayquery_function(".CommittedInstanceContributionToHitGroupIndex()",
		                       ".CandidateInstanceContributionToHitGroupIndex()", ops);
		break;
	}
	case OpRayQueryGetIntersectionGeometryIndexKHR:
	{
		emit_rayquery_function(".CommittedGeometryIndex()",
		                       ".CandidateGeometryIndex()", ops);
		break;
	}
	case OpRayQueryGetIntersectionPrimitiveIndexKHR:
	{
		emit_rayquery_function(".CommittedPrimitiveIndex()", ".CandidatePrimitiveIndex()", ops);
		break;
	}
	case OpRayQueryGetIntersectionBarycentricsKHR:
	{
		emit_rayquery_function(".CommittedTriangleBarycentrics()", ".CandidateTriangleBarycentrics()", ops);
		break;
	}
	case OpRayQueryGetIntersectionFrontFaceKHR:
	{
		emit_rayquery_function(".CommittedTriangleFrontFace()", ".CandidateTriangleFrontFace()", ops);
		break;
	}
	case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".CandidateProceduralPrimitiveNonOpaque()"), false);
		break;
	}
	case OpRayQueryGetIntersectionObjectRayDirectionKHR:
	{
		emit_rayquery_function(".CommittedObjectRayDirection()", ".CandidateObjectRayDirection()", ops);
		break;
	}
	case OpRayQueryGetIntersectionObjectRayOriginKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_rayquery_function(".CommittedObjectRayOrigin()", ".CandidateObjectRayOrigin()", ops);
		break;
	}
	case OpRayQueryGetIntersectionObjectToWorldKHR:
	{
		emit_rayquery_function(".CommittedObjectToWorld4x3()", ".CandidateObjectToWorld4x3()", ops);
		break;
	}
	case OpRayQueryGetIntersectionWorldToObjectKHR:
	{
		emit_rayquery_function(".CommittedWorldToObject4x3()", ".CandidateWorldToObject4x3()", ops);
		break;
	}
	case OpRayQueryGetRayFlagsKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".RayFlags()"), false);
		break;
	}
	case OpRayQueryGetRayTMinKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".RayTMin()"), false);
		break;
	}
	case OpRayQueryGetWorldRayOriginKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".WorldRayOrigin()"), false);
		break;
	}
	case OpRayQueryGetWorldRayDirectionKHR:
	{
		flush_variable_declaration(ops[0]);
		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".WorldRayDirection()"), false);
		break;
	}
	case OpSetMeshOutputsEXT:
	{
		statement("SetMeshOutputCounts(", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");");
		break;
	}
	default:
		CompilerGLSL::emit_instruction(instruction);
		break;
	}
}

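// Registers that a texture size query helper is needed for this particular image variant.
// The variant is keyed by dimensionality/arrayness/multisampling plus the sampled type, and for UAVs also
// by the declared format (normalization state and component count). The first time a new variant is seen,
// a recompile pass is forced so the corresponding helper can be emitted.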
void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
{
	if (const auto *var = maybe_get_backing_variable(var_id))
		var_id = var->self;

	auto &type = expression_type(var_id);
	bool uav = type.image.sampled == 2;
	if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
		uav = false;

	uint32_t bit = 0;
	switch (type.image.dim)
	{
	case Dim1D:
		bit = type.image.arrayed ? Query1DArray : Query1D;
		break;

	case Dim2D:
		if (type.image.ms)
			bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
		else
			bit = type.image.arrayed ? Query2DArray : Query2D;
		break;

	case Dim3D:
		bit = Query3D;
		break;

	case DimCube:
		bit = type.image.arrayed ? QueryCubeArray : QueryCube;
		break;

	case DimBuffer:
		bit = QueryBuffer;
		break;

	default:
		SPIRV_CROSS_THROW("Unsupported query type.");
	}

	switch (get<SPIRType>(type.image.type).basetype)
	{
	case SPIRType::Float:
		bit += QueryTypeFloat;
		break;

	case SPIRType::Int:
		bit += QueryTypeInt;
		break;

	case SPIRType::UInt:
		bit += QueryTypeUInt;
		break;

	default:
		SPIRV_CROSS_THROW("Unsupported query type.");
	}

	auto norm_state = image_format_to_normalized_state(type.image.format);
	auto &variant = uav ? required_texture_size_variants
	                          .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
	                      required_texture_size_variants.srv;

	uint64_t mask = 1ull << bit;
	if ((variant & mask) == 0)
	{
		force_recompile();
		variant |= mask;
	}
}

void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
{
	root_constants_layout = std::move(layout);
}

void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
{
	remap_vertex_attributes.push_back(vertex_attributes);
}

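// HLSL has no builtin equivalent of gl_NumWorkGroups, so synthesize a fake uniform buffer
// ("SPIRV_Cross_NumWorkgroups") holding a single uint3 member and return its variable ID.
// The caller is expected to assign a binding to this variable and provide the dispatch size from the CPU side.
// Returns 0 if the shader does not use the builtin at all.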
VariableID CompilerHLSL::remap_num_workgroups_builtin()
{
	update_active_builtins();

	if (!active_input_builtins.get(BuiltInNumWorkgroups))
		return 0;

	// Create a new, fake UBO.
	uint32_t offset = ir.increase_bound_by(4);

	uint32_t uint_type_id = offset;
	uint32_t block_type_id = offset + 1;
	uint32_t block_pointer_type_id = offset + 2;
	uint32_t variable_id = offset + 3;

	SPIRType uint_type { OpTypeVector };
	uint_type.basetype = SPIRType::UInt;
	uint_type.width = 32;
	uint_type.vecsize = 3;
	uint_type.columns = 1;
	set<SPIRType>(uint_type_id, uint_type);

	SPIRType block_type { OpTypeStruct };
	block_type.basetype = SPIRType::Struct;
	block_type.member_types.push_back(uint_type_id);
	set<SPIRType>(block_type_id, block_type);
	set_decoration(block_type_id, DecorationBlock);
	set_member_name(block_type_id, 0, "count");
	set_member_decoration(block_type_id, 0, DecorationOffset, 0);

	SPIRType block_pointer_type = block_type;
	block_pointer_type.pointer = true;
	block_pointer_type.storage = StorageClassUniform;
	block_pointer_type.parent_type = block_type_id;
	auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);

	// Preserve self.
	ptr_type.self = block_type_id;

	set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
	ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";

	num_workgroups_builtin = variable_id;
	get_entry_point().interface_variables.push_back(num_workgroups_builtin);
	return variable_id;
}

void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
{
	resource_binding_flags = flags;
}

void CompilerHLSL::validate_shader_model()
{
	// Check for nonuniform qualifier.
	// Instead of looping over all decorations to find this, just look at capabilities.
	for (auto &cap : ir.declared_capabilities)
	{
		switch (cap)
		{
		case CapabilityShaderNonUniformEXT:
		case CapabilityRuntimeDescriptorArrayEXT:
			if (hlsl_options.shader_model < 51)
				SPIRV_CROSS_THROW(
				    "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
			break;

		case CapabilityVariablePointers:
		case CapabilityVariablePointersStorageBuffer:
			SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");

		default:
			break;
		}
	}

	if (ir.addressing_model != AddressingModelLogical)
		SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");

	if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
		SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
}

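// Main entry point for HLSL code generation. The inherited GLSL backend is configured for HLSL-style
// output (literal suffixes, initializer lists, no inline struct/array declarations, etc.), the module is
// validated against the requested shader model, and emission is repeated until no pass forces a recompile
// (for example because a new texture size query variant became required).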
string CompilerHLSL::compile()
{
	ir.fixup_reserved_names();

	// Do not deal with ES-isms like precision, older extensions and such.
	options.es = false;
	options.version = 450;
	options.vulkan_semantics = true;
	backend.float_literal_suffix = true;
	backend.double_literal_suffix = false;
	backend.long_long_literal_suffix = true;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "u";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.demote_literal = "discard";
	backend.boolean_mix_function = "";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = true;
	backend.unsized_array_supported = true;
	backend.explicit_struct_type = false;
	backend.use_initializer_list = true;
	backend.use_constructor_splatting = false;
	backend.can_swizzle_scalar = true;
	backend.can_declare_struct_inline = false;
	backend.can_declare_arrays_inline = false;
	backend.can_return_array = false;
	backend.nonuniform_qualifier = "NonUniformResourceIndex";
	backend.support_case_fallthrough = false;
	backend.force_merged_mesh_block = get_execution_model() == ExecutionModelMeshEXT;
	backend.force_gl_in_out_block = backend.force_merged_mesh_block;

	// SM 4.1 does not support precise for some reason.
	backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;

	fixup_anonymous_struct_names();
	fixup_type_alias();
	reorder_type_alias();
	build_function_control_flow_graphs_and_analyze();
	validate_shader_model();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_interlocked_resource_usage();
	if (get_execution_model() == ExecutionModelMeshEXT)
		analyze_meshlet_writes();

	// Subpass input needs SV_Position.
	if (need_subpass_input)
		active_input_builtins.set(BuiltInFragCoord);

	uint32_t pass_count = 0;
	do
	{
		reset(pass_count);

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_resources();

		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
		emit_hlsl_entry_point();

		pass_count++;
	} while (is_forcing_recompilation());

	// Entry point in HLSL is always main() for the time being.
	get_entry_point().name = "main";

	return buffer.str();
}

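// Translate SPIR-V selection/loop control hints into the corresponding HLSL attributes:
// [flatten], [branch], [unroll] and [loop].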
void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
{
	switch (block.hint)
	{
	case SPIRBlock::HintFlatten:
		statement("[flatten]");
		break;
	case SPIRBlock::HintDontFlatten:
		statement("[branch]");
		break;
	case SPIRBlock::HintUnroll:
		statement("[unroll]");
		break;
	case SPIRBlock::HintDontUnroll:
		statement("[loop]");
		break;
	default:
		break;
	}
}

string CompilerHLSL::get_unique_identifier()
{
	return join("_", unique_identifier_count++, "ident");
}

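// Registers an explicit HLSL register binding for a (stage, descriptor set, binding) tuple.
// The bool in the mapped pair tracks whether the binding was actually consumed during compilation,
// which is what is_hlsl_resource_binding_used() reports back to the caller.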
void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };
}

bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

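// Classifies whether a bitcast needs to pack a uint2 into a 64-bit integer, unpack a 64-bit integer
// into a uint2, or can be emitted as a normal bitcast, so the caller can emit the appropriate conversion.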
CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
{
	auto &rslt_type = get<SPIRType>(result_type);
	auto &expr_type = expression_type(op0);

	if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
	    expr_type.vecsize == 2)
		return BitcastType::TypePackUint2x32;
	else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
	         expr_type.basetype == SPIRType::BaseType::UInt64)
		return BitcastType::TypeUnpackUint64;

	return BitcastType::TypeNormal;
}

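// Helpers to force storage buffers to be emitted as UAVs. Either the global force_storage_buffer_as_uav
// option applies to every buffer, or individual (descriptor set, binding) pairs can be registered,
// overriding the case where a read-only buffer would otherwise be emitted as an SRV.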
bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
{
	if (hlsl_options.force_storage_buffer_as_uav)
	{
		return true;
	}

	const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
	const uint32_t binding = get_decoration(id, spv::DecorationBinding);

	return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
}

void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	force_uav_buffer_bindings.insert(pair);
}

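// When the preserve_structured_buffers option is enabled, inspect the SPV_GOOGLE_user_type decoration
// string to decide whether the resource was originally declared as a (RW/RasterizerOrdered)StructuredBuffer
// and should therefore be emitted as one again instead of a ByteAddressBuffer-style resource.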
bool CompilerHLSL::is_user_type_structured(uint32_t id) const
{
	if (hlsl_options.preserve_structured_buffers)
	{
		// Compare only the left-hand side of the string, as these user types can carry additional metadata
		// such as their subtypes, e.g. "structuredbuffer:int".
		const std::string &user_type = get_decoration_string(id, DecorationUserTypeGOOGLE);
		return user_type.compare(0, 16, "structuredbuffer") == 0 ||
		       user_type.compare(0, 18, "rwstructuredbuffer") == 0 ||
		       user_type.compare(0, 33, "rasterizerorderedstructuredbuffer") == 0;
	}
	return false;
}

