1/*
2 * Copyright 2016-2021 Robert Konrad
3 * SPDX-License-Identifier: Apache-2.0 OR MIT
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19/*
20 * At your option, you may choose to accept this material under either:
21 * 1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
22 * 2. The MIT License, found at <http://opensource.org/licenses/MIT>.
23 */
24
25#include "spirv_hlsl.hpp"
26#include "GLSL.std.450.h"
27#include <algorithm>
28#include <assert.h>
29
30using namespace spv;
31using namespace SPIRV_CROSS_NAMESPACE;
32using namespace std;
33
34enum class ImageFormatNormalizedState
35{
36 None = 0,
37 Unorm = 1,
38 Snorm = 2
39};
40
41static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
42{
43 switch (fmt)
44 {
45 case ImageFormatR8:
46 case ImageFormatR16:
47 case ImageFormatRg8:
48 case ImageFormatRg16:
49 case ImageFormatRgba8:
50 case ImageFormatRgba16:
51 case ImageFormatRgb10A2:
52 return ImageFormatNormalizedState::Unorm;
53
54 case ImageFormatR8Snorm:
55 case ImageFormatR16Snorm:
56 case ImageFormatRg8Snorm:
57 case ImageFormatRg16Snorm:
58 case ImageFormatRgba8Snorm:
59 case ImageFormatRgba16Snorm:
60 return ImageFormatNormalizedState::Snorm;
61
62 default:
63 break;
64 }
65
66 return ImageFormatNormalizedState::None;
67}
68
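// Number of components a typed image format carries, e.g. R32f -> 1, Rg16f -> 2,
// R11fG11fB10f -> 3, Rgba8 -> 4. Unknown formats are assumed to have 4 components.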
69static unsigned image_format_to_components(ImageFormat fmt)
70{
71 switch (fmt)
72 {
73 case ImageFormatR8:
74 case ImageFormatR16:
75 case ImageFormatR8Snorm:
76 case ImageFormatR16Snorm:
77 case ImageFormatR16f:
78 case ImageFormatR32f:
79 case ImageFormatR8i:
80 case ImageFormatR16i:
81 case ImageFormatR32i:
82 case ImageFormatR8ui:
83 case ImageFormatR16ui:
84 case ImageFormatR32ui:
85 return 1;
86
87 case ImageFormatRg8:
88 case ImageFormatRg16:
89 case ImageFormatRg8Snorm:
90 case ImageFormatRg16Snorm:
91 case ImageFormatRg16f:
92 case ImageFormatRg32f:
93 case ImageFormatRg8i:
94 case ImageFormatRg16i:
95 case ImageFormatRg32i:
96 case ImageFormatRg8ui:
97 case ImageFormatRg16ui:
98 case ImageFormatRg32ui:
99 return 2;
100
101 case ImageFormatR11fG11fB10f:
102 return 3;
103
104 case ImageFormatRgba8:
105 case ImageFormatRgba16:
106 case ImageFormatRgb10A2:
107 case ImageFormatRgba8Snorm:
108 case ImageFormatRgba16Snorm:
109 case ImageFormatRgba16f:
110 case ImageFormatRgba32f:
111 case ImageFormatRgba8i:
112 case ImageFormatRgba16i:
113 case ImageFormatRgba32i:
114 case ImageFormatRgba8ui:
115 case ImageFormatRgba16ui:
116 case ImageFormatRgba32ui:
117 case ImageFormatRgb10a2ui:
118 return 4;
119
120 case ImageFormatUnknown:
121 return 4; // Assume 4.
122
123 default:
124 SPIRV_CROSS_THROW("Unrecognized typed image format.");
125 }
126}
127
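// Maps an image format plus the sampled base type to the template argument used for typed
// UAV declarations, e.g. Rgba8 + Float -> "unorm float4", Rg16f + Float -> "float2",
// R32ui + UInt -> "uint". Throws if the format and the base type do not agree.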
128static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
129{
130 switch (fmt)
131 {
132 case ImageFormatR8:
133 case ImageFormatR16:
134 if (basetype != SPIRType::Float)
135 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
136 return "unorm float";
137 case ImageFormatRg8:
138 case ImageFormatRg16:
139 if (basetype != SPIRType::Float)
140 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
141 return "unorm float2";
142 case ImageFormatRgba8:
143 case ImageFormatRgba16:
144 if (basetype != SPIRType::Float)
145 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
146 return "unorm float4";
147 case ImageFormatRgb10A2:
148 if (basetype != SPIRType::Float)
149 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
150 return "unorm float4";
151
152 case ImageFormatR8Snorm:
153 case ImageFormatR16Snorm:
154 if (basetype != SPIRType::Float)
155 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
156 return "snorm float";
157 case ImageFormatRg8Snorm:
158 case ImageFormatRg16Snorm:
159 if (basetype != SPIRType::Float)
160 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
161 return "snorm float2";
162 case ImageFormatRgba8Snorm:
163 case ImageFormatRgba16Snorm:
164 if (basetype != SPIRType::Float)
165 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
166 return "snorm float4";
167
168 case ImageFormatR16f:
169 case ImageFormatR32f:
170 if (basetype != SPIRType::Float)
171 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
172 return "float";
173 case ImageFormatRg16f:
174 case ImageFormatRg32f:
175 if (basetype != SPIRType::Float)
176 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
177 return "float2";
178 case ImageFormatRgba16f:
179 case ImageFormatRgba32f:
180 if (basetype != SPIRType::Float)
181 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
182 return "float4";
183
184 case ImageFormatR11fG11fB10f:
185 if (basetype != SPIRType::Float)
186 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
187 return "float3";
188
189 case ImageFormatR8i:
190 case ImageFormatR16i:
191 case ImageFormatR32i:
192 if (basetype != SPIRType::Int)
193 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
194 return "int";
195 case ImageFormatRg8i:
196 case ImageFormatRg16i:
197 case ImageFormatRg32i:
198 if (basetype != SPIRType::Int)
199 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
200 return "int2";
201 case ImageFormatRgba8i:
202 case ImageFormatRgba16i:
203 case ImageFormatRgba32i:
204 if (basetype != SPIRType::Int)
205 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
206 return "int4";
207
208 case ImageFormatR8ui:
209 case ImageFormatR16ui:
210 case ImageFormatR32ui:
211 if (basetype != SPIRType::UInt)
212 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
213 return "uint";
214 case ImageFormatRg8ui:
215 case ImageFormatRg16ui:
216 case ImageFormatRg32ui:
217 if (basetype != SPIRType::UInt)
218 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
219 return "uint2";
220 case ImageFormatRgba8ui:
221 case ImageFormatRgba16ui:
222 case ImageFormatRgba32ui:
223 if (basetype != SPIRType::UInt)
224 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
225 return "uint4";
226 case ImageFormatRgb10a2ui:
227 if (basetype != SPIRType::UInt)
228 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
229 return "uint4";
230
231 case ImageFormatUnknown:
232 switch (basetype)
233 {
234 case SPIRType::Float:
235 return "float4";
236 case SPIRType::Int:
237 return "int4";
238 case SPIRType::UInt:
239 return "uint4";
240 default:
241 SPIRV_CROSS_THROW("Unsupported base type for image.");
242 }
243
244 default:
245 SPIRV_CROSS_THROW("Unrecognized typed image format.");
246 }
247}
248
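// Illustrative examples of what this returns on SM 4.0+: a sampled 2D float image becomes
// "Texture2D<float4>", a storage image with format Rgba8 becomes "RWTexture2D<unorm float4>",
// and a NonWritable storage image with nonwritable_uav_texture_as_srv enabled is demoted to a
// plain "Texture2D<float4>" SRV. Resources in interlocked_resources get the "RasterizerOrdered"
// prefix instead of "RW".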
249string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
250{
251 auto &imagetype = get<SPIRType>(id: type.image.type);
252 const char *dim = nullptr;
253 bool typed_load = false;
254 uint32_t components = 4;
255
256 bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, decoration: DecorationNonWritable);
257
258 switch (type.image.dim)
259 {
260 case Dim1D:
261 typed_load = type.image.sampled == 2;
262 dim = "1D";
263 break;
264 case Dim2D:
265 typed_load = type.image.sampled == 2;
266 dim = "2D";
267 break;
268 case Dim3D:
269 typed_load = type.image.sampled == 2;
270 dim = "3D";
271 break;
272 case DimCube:
273 if (type.image.sampled == 2)
274 SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
275 dim = "Cube";
276 break;
277 case DimRect:
278 SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
279 case DimBuffer:
280 if (type.image.sampled == 1)
281 return join(ts: "Buffer<", ts: type_to_glsl(type: imagetype), ts&: components, ts: ">");
282 else if (type.image.sampled == 2)
283 {
284 if (interlocked_resources.count(x: id))
285 return join(ts: "RasterizerOrderedBuffer<", ts: image_format_to_type(fmt: type.image.format, basetype: imagetype.basetype),
286 ts: ">");
287
288 typed_load = !force_image_srv && type.image.sampled == 2;
289
290 const char *rw = force_image_srv ? "" : "RW";
291 return join(ts&: rw, ts: "Buffer<",
292 ts: typed_load ? image_format_to_type(fmt: type.image.format, basetype: imagetype.basetype) :
293 join(ts: type_to_glsl(type: imagetype), ts&: components),
294 ts: ">");
295 }
296 else
			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce at runtime.");
298 case DimSubpassData:
299 dim = "2D";
300 typed_load = false;
301 break;
302 default:
303 SPIRV_CROSS_THROW("Invalid dimension.");
304 }
305 const char *arrayed = type.image.arrayed ? "Array" : "";
306 const char *ms = type.image.ms ? "MS" : "";
307 const char *rw = typed_load && !force_image_srv ? "RW" : "";
308
309 if (force_image_srv)
310 typed_load = false;
311
312 if (typed_load && interlocked_resources.count(x: id))
313 rw = "RasterizerOrdered";
314
315 return join(ts&: rw, ts: "Texture", ts&: dim, ts&: ms, ts&: arrayed, ts: "<",
316 ts: typed_load ? image_format_to_type(fmt: type.image.format, basetype: imagetype.basetype) :
317 join(ts: type_to_glsl(type: imagetype), ts&: components),
318 ts: ">");
319}
320
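// Legacy (SM 3.0 and lower) path: emits effect-style sampler names such as "sampler2D" or
// "samplerCUBE" rather than separate Texture*/SamplerState declarations.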
321string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
322{
323 auto &imagetype = get<SPIRType>(id: type.image.type);
324 string res;
325
326 switch (imagetype.basetype)
327 {
328 case SPIRType::Int:
329 res = "i";
330 break;
331 case SPIRType::UInt:
332 res = "u";
333 break;
334 default:
335 break;
336 }
337
338 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
339 return res + "subpassInput" + (type.image.ms ? "MS" : "");
340
341 // If we're emulating subpassInput with samplers, force sampler2D
342 // so we don't have to specify format.
343 if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
344 {
345 // Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
346 if (type.image.dim == DimBuffer && type.image.sampled == 1)
347 res += "sampler";
348 else
349 res += type.image.sampled == 2 ? "image" : "texture";
350 }
351 else
352 res += "sampler";
353
354 switch (type.image.dim)
355 {
356 case Dim1D:
357 res += "1D";
358 break;
359 case Dim2D:
360 res += "2D";
361 break;
362 case Dim3D:
363 res += "3D";
364 break;
365 case DimCube:
366 res += "CUBE";
367 break;
368
369 case DimBuffer:
370 res += "Buffer";
371 break;
372
373 case DimSubpassData:
374 res += "2D";
375 break;
376 default:
377 SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
378 }
379
380 if (type.image.ms)
381 res += "MS";
382 if (type.image.arrayed)
383 res += "Array";
384
385 return res;
386}
387
388string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
389{
390 if (hlsl_options.shader_model <= 30)
391 return image_type_hlsl_legacy(type, id);
392 else
393 return image_type_hlsl_modern(type, id);
394}
395
// The optional id parameter indicates the object whose type we are trying
// to find the description for. Most type descriptions do not depend on a
// specific object's use of that type.
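// For example, a scalar float maps to "float", a 3-component uint vector to "uint3", and a
// float matrix with 4 columns and 4 rows to "float4x4". 16-bit types map to half/int16_t/uint16_t
// when enable_16bit_types is set, otherwise to the min16* types.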
399string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
400{
401 // Ignore the pointer type since GLSL doesn't have pointers.
402
403 switch (type.basetype)
404 {
405 case SPIRType::Struct:
406 // Need OpName lookup here to get a "sensible" name for a struct.
407 if (backend.explicit_struct_type)
408 return join(ts: "struct ", ts: to_name(id: type.self));
409 else
410 return to_name(id: type.self);
411
412 case SPIRType::Image:
413 case SPIRType::SampledImage:
414 return image_type_hlsl(type, id);
415
416 case SPIRType::Sampler:
417 return comparison_ids.count(x: id) ? "SamplerComparisonState" : "SamplerState";
418
419 case SPIRType::Void:
420 return "void";
421
422 default:
423 break;
424 }
425
426 if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
427 {
428 switch (type.basetype)
429 {
430 case SPIRType::Boolean:
431 return "bool";
432 case SPIRType::Int:
433 return backend.basic_int_type;
434 case SPIRType::UInt:
435 return backend.basic_uint_type;
436 case SPIRType::AtomicCounter:
437 return "atomic_uint";
438 case SPIRType::Half:
439 if (hlsl_options.enable_16bit_types)
440 return "half";
441 else
442 return "min16float";
443 case SPIRType::Short:
444 if (hlsl_options.enable_16bit_types)
445 return "int16_t";
446 else
447 return "min16int";
448 case SPIRType::UShort:
449 if (hlsl_options.enable_16bit_types)
450 return "uint16_t";
451 else
452 return "min16uint";
453 case SPIRType::Float:
454 return "float";
455 case SPIRType::Double:
456 return "double";
457 case SPIRType::Int64:
458 if (hlsl_options.shader_model < 60)
459 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
460 return "int64_t";
461 case SPIRType::UInt64:
462 if (hlsl_options.shader_model < 60)
463 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
464 return "uint64_t";
465 case SPIRType::AccelerationStructure:
466 return "RaytracingAccelerationStructure";
467 case SPIRType::RayQuery:
468 return "RayQuery<RAY_FLAG_NONE>";
469 default:
470 return "???";
471 }
472 }
473 else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
474 {
475 switch (type.basetype)
476 {
477 case SPIRType::Boolean:
478 return join(ts: "bool", ts: type.vecsize);
479 case SPIRType::Int:
480 return join(ts: "int", ts: type.vecsize);
481 case SPIRType::UInt:
482 return join(ts: "uint", ts: type.vecsize);
483 case SPIRType::Half:
484 return join(ts: hlsl_options.enable_16bit_types ? "half" : "min16float", ts: type.vecsize);
485 case SPIRType::Short:
486 return join(ts: hlsl_options.enable_16bit_types ? "int16_t" : "min16int", ts: type.vecsize);
487 case SPIRType::UShort:
488 return join(ts: hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", ts: type.vecsize);
489 case SPIRType::Float:
490 return join(ts: "float", ts: type.vecsize);
491 case SPIRType::Double:
492 return join(ts: "double", ts: type.vecsize);
493 case SPIRType::Int64:
494 return join(ts: "i64vec", ts: type.vecsize);
495 case SPIRType::UInt64:
496 return join(ts: "u64vec", ts: type.vecsize);
497 default:
498 return "???";
499 }
500 }
501 else
502 {
503 switch (type.basetype)
504 {
505 case SPIRType::Boolean:
506 return join(ts: "bool", ts: type.columns, ts: "x", ts: type.vecsize);
507 case SPIRType::Int:
508 return join(ts: "int", ts: type.columns, ts: "x", ts: type.vecsize);
509 case SPIRType::UInt:
510 return join(ts: "uint", ts: type.columns, ts: "x", ts: type.vecsize);
511 case SPIRType::Half:
512 return join(ts: hlsl_options.enable_16bit_types ? "half" : "min16float", ts: type.columns, ts: "x", ts: type.vecsize);
513 case SPIRType::Short:
514 return join(ts: hlsl_options.enable_16bit_types ? "int16_t" : "min16int", ts: type.columns, ts: "x", ts: type.vecsize);
515 case SPIRType::UShort:
516 return join(ts: hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", ts: type.columns, ts: "x", ts: type.vecsize);
517 case SPIRType::Float:
518 return join(ts: "float", ts: type.columns, ts: "x", ts: type.vecsize);
519 case SPIRType::Double:
520 return join(ts: "double", ts: type.columns, ts: "x", ts: type.vecsize);
521 // Matrix types not supported for int64/uint64.
522 default:
523 return "???";
524 }
525 }
526}
527
528void CompilerHLSL::emit_header()
529{
530 for (auto &header : header_lines)
531 statement(ts&: header);
532
533 if (header_lines.size() > 0)
534 {
535 statement(ts: "");
536 }
537}
538
539void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
540{
541 add_resource_name(id: var.self);
542
543 // The global copies of I/O variables should not contain interpolation qualifiers.
544 // These are emitted inside the interface structs.
545 auto &flags = ir.meta[var.self].decoration.decoration_flags;
546 auto old_flags = flags;
547 flags.reset();
548 statement(ts: "static ", ts: variable_decl(variable: var), ts: ";");
549 flags = old_flags;
550}
551
552const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
553{
554 // Input and output variables are handled specially in HLSL backend.
555 // The variables are declared as global, private variables, and do not need any qualifiers.
556 if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
557 var.storage == StorageClassPushConstant)
558 {
559 return "uniform ";
560 }
561
562 return "";
563}
564
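// Emits the builtin members of the SPIRV_Cross_Output struct, e.g.
// "float4 gl_Position : SV_Position;". Clip/cull distances are split into float4-sized chunks
// with SV_ClipDistanceN / SV_CullDistanceN semantics.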
565void CompilerHLSL::emit_builtin_outputs_in_struct()
566{
567 auto &execution = get_entry_point();
568
569 bool legacy = hlsl_options.shader_model <= 30;
570 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
571 const char *type = nullptr;
572 const char *semantic = nullptr;
573 auto builtin = static_cast<BuiltIn>(i);
574 switch (builtin)
575 {
576 case BuiltInPosition:
577 type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
578 semantic = legacy ? "POSITION" : "SV_Position";
579 break;
580
581 case BuiltInSampleMask:
582 if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
583 SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
584 type = "uint";
585 semantic = "SV_Coverage";
586 break;
587
588 case BuiltInFragDepth:
589 type = "float";
590 if (legacy)
591 {
592 semantic = "DEPTH";
593 }
594 else
595 {
596 if (hlsl_options.shader_model >= 50 && execution.flags.get(bit: ExecutionModeDepthGreater))
597 semantic = "SV_DepthGreaterEqual";
598 else if (hlsl_options.shader_model >= 50 && execution.flags.get(bit: ExecutionModeDepthLess))
599 semantic = "SV_DepthLessEqual";
600 else
601 semantic = "SV_Depth";
602 }
603 break;
604
605 case BuiltInClipDistance:
606 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
607 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
608 {
609 uint32_t to_declare = clip_distance_count - clip;
610 if (to_declare > 4)
611 to_declare = 4;
612
613 uint32_t semantic_index = clip / 4;
614
615 static const char *types[] = { "float", "float2", "float3", "float4" };
616 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts&: semantic_index,
617 ts: " : SV_ClipDistance", ts&: semantic_index, ts: ";");
618 }
619 break;
620
621 case BuiltInCullDistance:
622 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
623 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
624 {
625 uint32_t to_declare = cull_distance_count - cull;
626 if (to_declare > 4)
627 to_declare = 4;
628
629 uint32_t semantic_index = cull / 4;
630
631 static const char *types[] = { "float", "float2", "float3", "float4" };
632 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts&: semantic_index,
633 ts: " : SV_CullDistance", ts&: semantic_index, ts: ";");
634 }
635 break;
636
637 case BuiltInPointSize:
638 // If point_size_compat is enabled, just ignore PointSize.
639 // PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
640 // even if it means working around the missing feature.
641 if (hlsl_options.point_size_compat)
642 break;
643 else
644 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
645
646 case BuiltInLayer:
647 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelGeometry)
648 SPIRV_CROSS_THROW("Render target array index output is only supported in GS 5.0 or higher.");
649 type = "uint";
650 semantic = "SV_RenderTargetArrayIndex";
651 break;
652
653 default:
654 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
655 }
656
657 if (type && semantic)
658 statement(ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts: " : ", ts&: semantic, ts: ";");
659 });
660}
661
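// Emits the builtin members of the SPIRV_Cross_Input struct, e.g.
// "float4 gl_FragCoord : SV_Position;" in fragment shaders or
// "uint3 gl_GlobalInvocationID : SV_DispatchThreadID;" in compute shaders.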
662void CompilerHLSL::emit_builtin_inputs_in_struct()
663{
664 bool legacy = hlsl_options.shader_model <= 30;
665 active_input_builtins.for_each_bit(op: [&](uint32_t i) {
666 const char *type = nullptr;
667 const char *semantic = nullptr;
668 auto builtin = static_cast<BuiltIn>(i);
669 switch (builtin)
670 {
671 case BuiltInFragCoord:
672 type = "float4";
673 semantic = legacy ? "VPOS" : "SV_Position";
674 break;
675
676 case BuiltInVertexId:
677 case BuiltInVertexIndex:
678 if (legacy)
679 SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
680 type = "uint";
681 semantic = "SV_VertexID";
682 break;
683
684 case BuiltInPrimitiveId:
685 type = "uint";
686 semantic = "SV_PrimitiveID";
687 break;
688
689 case BuiltInInstanceId:
690 case BuiltInInstanceIndex:
691 if (legacy)
692 SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
693 type = "uint";
694 semantic = "SV_InstanceID";
695 break;
696
697 case BuiltInSampleId:
698 if (legacy)
699 SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
700 type = "uint";
701 semantic = "SV_SampleIndex";
702 break;
703
704 case BuiltInSampleMask:
705 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
706 SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
707 type = "uint";
708 semantic = "SV_Coverage";
709 break;
710
711 case BuiltInGlobalInvocationId:
712 type = "uint3";
713 semantic = "SV_DispatchThreadID";
714 break;
715
716 case BuiltInLocalInvocationId:
717 type = "uint3";
718 semantic = "SV_GroupThreadID";
719 break;
720
721 case BuiltInLocalInvocationIndex:
722 type = "uint";
723 semantic = "SV_GroupIndex";
724 break;
725
726 case BuiltInWorkgroupId:
727 type = "uint3";
728 semantic = "SV_GroupID";
729 break;
730
731 case BuiltInFrontFacing:
732 type = "bool";
733 semantic = "SV_IsFrontFace";
734 break;
735
736 case BuiltInViewIndex:
737 if (hlsl_options.shader_model < 61 || (get_entry_point().model != ExecutionModelVertex && get_entry_point().model != ExecutionModelFragment))
738 SPIRV_CROSS_THROW("View Index input is only supported in VS and PS 6.1 or higher.");
739 type = "uint";
740 semantic = "SV_ViewID";
741 break;
742
743 case BuiltInNumWorkgroups:
744 case BuiltInSubgroupSize:
745 case BuiltInSubgroupLocalInvocationId:
746 case BuiltInSubgroupEqMask:
747 case BuiltInSubgroupLtMask:
748 case BuiltInSubgroupLeMask:
749 case BuiltInSubgroupGtMask:
750 case BuiltInSubgroupGeMask:
751 // Handled specially.
752 break;
753
754 case BuiltInHelperInvocation:
755 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
756 SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
757 break;
758
759 case BuiltInClipDistance:
760 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
761 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
762 {
763 uint32_t to_declare = clip_distance_count - clip;
764 if (to_declare > 4)
765 to_declare = 4;
766
767 uint32_t semantic_index = clip / 4;
768
769 static const char *types[] = { "float", "float2", "float3", "float4" };
770 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts&: semantic_index,
771 ts: " : SV_ClipDistance", ts&: semantic_index, ts: ";");
772 }
773 break;
774
775 case BuiltInCullDistance:
776 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
777 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
778 {
779 uint32_t to_declare = cull_distance_count - cull;
780 if (to_declare > 4)
781 to_declare = 4;
782
783 uint32_t semantic_index = cull / 4;
784
785 static const char *types[] = { "float", "float2", "float3", "float4" };
786 statement(ts&: types[to_declare - 1], ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts&: semantic_index,
787 ts: " : SV_CullDistance", ts&: semantic_index, ts: ";");
788 }
789 break;
790
791 case BuiltInPointCoord:
792 // PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
793 if (hlsl_options.point_coord_compat)
794 break;
795 else
796 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
797
798 case BuiltInLayer:
799 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
800 SPIRV_CROSS_THROW("Render target array index input is only supported in PS 5.0 or higher.");
801 type = "uint";
802 semantic = "SV_RenderTargetArrayIndex";
803 break;
804
805 default:
806 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
807 }
808
809 if (type && semantic)
810 statement(ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage: StorageClassInput), ts: " : ", ts&: semantic, ts: ";");
811 });
812}
813
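// Number of I/O locations a type occupies: a vector takes one location, a matrix takes one per
// column (float4x4 -> 4), arrays multiply by their size (float2[3] -> 3), and structs consume
// the sum of their members.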
814uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
815{
816 // TODO: Need to verify correctness.
817 uint32_t elements = 0;
818
819 if (type.basetype == SPIRType::Struct)
820 {
821 for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
822 elements += type_to_consumed_locations(type: get<SPIRType>(id: type.member_types[i]));
823 }
824 else
825 {
826 uint32_t array_multiplier = 1;
827 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
828 {
829 if (type.array_size_literal[i])
830 array_multiplier *= type.array[i];
831 else
832 array_multiplier *= evaluate_constant_u32(id: type.array[i]);
833 }
834 elements += array_multiplier * type.columns;
835 }
836 return elements;
837}
838
839string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
840{
841 string res;
842 //if (flags & (1ull << DecorationSmooth))
843 // res += "linear ";
844 if (flags.get(bit: DecorationFlat))
845 res += "nointerpolation ";
846 if (flags.get(bit: DecorationNoPerspective))
847 res += "noperspective ";
848 if (flags.get(bit: DecorationCentroid))
849 res += "centroid ";
850 if (flags.get(bit: DecorationPatch))
851 res += "patch "; // Seems to be different in actual HLSL.
852 if (flags.get(bit: DecorationSample))
853 res += "sample ";
854 if (flags.get(bit: DecorationInvariant) && backend.support_precise_qualifier)
855 res += "precise "; // Not supported?
856
857 return res;
858}
859
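// Maps a location to a semantic string. Vertex shader inputs honor any user-provided
// remap_vertex_attributes entry (e.g. location 0 -> "POSITION"); everything else falls back to
// "TEXCOORDn", so location 3 becomes "TEXCOORD3".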
860std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
861{
862 if (em == ExecutionModelVertex && sc == StorageClassInput)
863 {
864 // We have a vertex attribute - we should look at remapping it if the user provided
865 // vertex attribute hints.
866 for (auto &attribute : remap_vertex_attributes)
867 if (attribute.location == location)
868 return attribute.semantic;
869 }
870
871 // Not a vertex attribute, or no remap_vertex_attributes entry.
872 return join(ts: "TEXCOORD", ts&: location);
873}
874
875std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
876{
	// We cannot emit a static const initializer for block constants for practical reasons,
	// so just inline the initializer.
	// FIXME: There is a theoretical problem here if someone tries to composite extract
	// into this initializer since we don't declare it properly, but that is somewhat nonsensical.
881 auto &type = get<SPIRType>(id: var.basetype);
882 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
883 auto *c = maybe_get<SPIRConstant>(id: var.initializer);
884 if (is_block && c)
885 return constant_expression(c: *c);
886 else
887 return CompilerGLSL::to_initializer_expression(var);
888}
889
890void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
891 uint32_t location,
892 std::unordered_set<uint32_t> &active_locations)
893{
894 auto &execution = get_entry_point();
895 auto type = get<SPIRType>(id: var.basetype);
896 auto semantic = to_semantic(location, em: execution.model, sc: var.storage);
897 auto mbr_name = join(ts: to_name(id: type.self), ts: "_", ts: to_member_name(type, index: member_index));
898 auto &mbr_type = get<SPIRType>(id: type.member_types[member_index]);
899
900 statement(ts: to_interpolation_qualifiers(flags: get_member_decoration_bitset(id: type.self, index: member_index)),
901 ts: type_to_glsl(type: mbr_type),
902 ts: " ", ts&: mbr_name, ts: type_to_array_glsl(type: mbr_type),
903 ts: " : ", ts&: semantic, ts: ";");
904
905 // Structs and arrays should consume more locations.
906 uint32_t consumed_locations = type_to_consumed_locations(type: mbr_type);
907 for (uint32_t i = 0; i < consumed_locations; i++)
908 active_locations.insert(x: location + i);
909}
910
911void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
912{
913 auto &execution = get_entry_point();
914 auto type = get<SPIRType>(id: var.basetype);
915
916 string binding;
917 bool use_location_number = true;
918 bool legacy = hlsl_options.shader_model <= 30;
919 if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
920 {
921 // Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
922 uint32_t index = get_decoration(id: var.self, decoration: DecorationIndex);
923 uint32_t location = get_decoration(id: var.self, decoration: DecorationLocation);
924
925 if (index != 0 && location != 0)
926 SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");
927
928 binding = join(ts: legacy ? "COLOR" : "SV_Target", ts: location + index);
929 use_location_number = false;
930 if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
931 type.vecsize = 4;
932 }
933
934 const auto get_vacant_location = [&]() -> uint32_t {
935 for (uint32_t i = 0; i < 64; i++)
936 if (!active_locations.count(x: i))
937 return i;
938 SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
939 };
940
941 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
942
943 auto name = to_name(id: var.self);
944 if (use_location_number)
945 {
946 uint32_t location_number;
947
948 // If an explicit location exists, use it with TEXCOORD[N] semantic.
949 // Otherwise, pick a vacant location.
950 if (has_decoration(id: var.self, decoration: DecorationLocation))
951 location_number = get_decoration(id: var.self, decoration: DecorationLocation);
952 else
953 location_number = get_vacant_location();
954
955 // Allow semantic remap if specified.
956 auto semantic = to_semantic(location: location_number, em: execution.model, sc: var.storage);
957
958 if (need_matrix_unroll && type.columns > 1)
959 {
960 if (!type.array.empty())
961 SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");
962
963 // Unroll matrices.
964 for (uint32_t i = 0; i < type.columns; i++)
965 {
966 SPIRType newtype = type;
967 newtype.columns = 1;
968
969 string effective_semantic;
970 if (hlsl_options.flatten_matrix_vertex_input_semantics)
971 effective_semantic = to_semantic(location: location_number, em: execution.model, sc: var.storage);
972 else
973 effective_semantic = join(ts&: semantic, ts: "_", ts&: i);
974
975 statement(ts: to_interpolation_qualifiers(flags: get_decoration_bitset(id: var.self)),
976 ts: variable_decl(type: newtype, name: join(ts&: name, ts: "_", ts&: i)), ts: " : ", ts&: effective_semantic, ts: ";");
977 active_locations.insert(x: location_number++);
978 }
979 }
980 else
981 {
982 statement(ts: to_interpolation_qualifiers(flags: get_decoration_bitset(id: var.self)), ts: variable_decl(type, name), ts: " : ",
983 ts&: semantic, ts: ";");
984
985 // Structs and arrays should consume more locations.
986 uint32_t consumed_locations = type_to_consumed_locations(type);
987 for (uint32_t i = 0; i < consumed_locations; i++)
988 active_locations.insert(x: location_number + i);
989 }
990 }
991 else
992 statement(ts: variable_decl(type, name), ts: " : ", ts&: binding, ts: ";");
993}
994
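// HLSL-specific builtin name overrides: subgroup builtins map directly to wave intrinsics
// (WaveGetLaneIndex() / WaveGetLaneCount()), HelperInvocation becomes IsHelperLane(),
// PointCoord degrades to a constant float2(0.5f, 0.5f) when point_coord_compat is set, and
// NumWorkgroups reads from the cbuffer registered via remap_num_workgroups_builtin().
// Everything else keeps the GLSL name.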
995std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
996{
997 switch (builtin)
998 {
999 case BuiltInVertexId:
1000 return "gl_VertexID";
1001 case BuiltInInstanceId:
1002 return "gl_InstanceID";
1003 case BuiltInNumWorkgroups:
1004 {
1005 if (!num_workgroups_builtin)
1006 SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
1007 "Cannot emit code for this builtin.");
1008
1009 auto &var = get<SPIRVariable>(id: num_workgroups_builtin);
1010 auto &type = get<SPIRType>(id: var.basetype);
1011 auto ret = join(ts: to_name(id: num_workgroups_builtin), ts: "_", ts: get_member_name(id: type.self, index: 0));
1012 ParsedIR::sanitize_underscores(str&: ret);
1013 return ret;
1014 }
1015 case BuiltInPointCoord:
1016 // Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
1017 return "float2(0.5f, 0.5f)";
1018 case BuiltInSubgroupLocalInvocationId:
1019 return "WaveGetLaneIndex()";
1020 case BuiltInSubgroupSize:
1021 return "WaveGetLaneCount()";
1022 case BuiltInHelperInvocation:
1023 return "IsHelperLane()";
1024
1025 default:
1026 return CompilerGLSL::builtin_to_glsl(builtin, storage);
1027 }
1028}
1029
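// Declares file-scope stand-ins for the builtins used by the shader, e.g.
// "static float4 gl_FragCoord;" or "static float gl_ClipDistance[2];"; the generated entry
// point wrapper copies between these globals and the SV_* struct members.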
1030void CompilerHLSL::emit_builtin_variables()
1031{
1032 Bitset builtins = active_input_builtins;
1033 builtins.merge_or(other: active_output_builtins);
1034
1035 bool need_base_vertex_info = false;
1036
1037 std::unordered_map<uint32_t, ID> builtin_to_initializer;
1038 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1039 if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
1040 return;
1041
1042 auto *c = this->maybe_get<SPIRConstant>(id: var.initializer);
1043 if (!c)
1044 return;
1045
1046 auto &type = this->get<SPIRType>(id: var.basetype);
1047 if (type.basetype == SPIRType::Struct)
1048 {
1049 uint32_t member_count = uint32_t(type.member_types.size());
1050 for (uint32_t i = 0; i < member_count; i++)
1051 {
1052 if (has_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn))
1053 {
1054 builtin_to_initializer[get_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn)] =
1055 c->subconstants[i];
1056 }
1057 }
1058 }
1059 else if (has_decoration(id: var.self, decoration: DecorationBuiltIn))
1060 builtin_to_initializer[get_decoration(id: var.self, decoration: DecorationBuiltIn)] = var.initializer;
1061 });
1062
1063 // Emit global variables for the interface variables which are statically used by the shader.
1064 builtins.for_each_bit(op: [&](uint32_t i) {
1065 const char *type = nullptr;
1066 auto builtin = static_cast<BuiltIn>(i);
1067 uint32_t array_size = 0;
1068
1069 string init_expr;
1070 auto init_itr = builtin_to_initializer.find(x: builtin);
1071 if (init_itr != builtin_to_initializer.end())
1072 init_expr = join(ts: " = ", ts: to_expression(id: init_itr->second));
1073
1074 switch (builtin)
1075 {
1076 case BuiltInFragCoord:
1077 case BuiltInPosition:
1078 type = "float4";
1079 break;
1080
1081 case BuiltInFragDepth:
1082 type = "float";
1083 break;
1084
1085 case BuiltInVertexId:
1086 case BuiltInVertexIndex:
1087 case BuiltInInstanceIndex:
1088 type = "int";
1089 if (hlsl_options.support_nonzero_base_vertex_base_instance)
1090 need_base_vertex_info = true;
1091 break;
1092
1093 case BuiltInInstanceId:
1094 case BuiltInSampleId:
1095 type = "int";
1096 break;
1097
1098 case BuiltInPointSize:
1099 if (hlsl_options.point_size_compat)
1100 {
1101 // Just emit the global variable, it will be ignored.
1102 type = "float";
1103 break;
1104 }
1105 else
1106 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1107
1108 case BuiltInGlobalInvocationId:
1109 case BuiltInLocalInvocationId:
1110 case BuiltInWorkgroupId:
1111 type = "uint3";
1112 break;
1113
1114 case BuiltInLocalInvocationIndex:
1115 type = "uint";
1116 break;
1117
1118 case BuiltInFrontFacing:
1119 type = "bool";
1120 break;
1121
1122 case BuiltInNumWorkgroups:
1123 case BuiltInPointCoord:
1124 // Handled specially.
1125 break;
1126
1127 case BuiltInSubgroupLocalInvocationId:
1128 case BuiltInSubgroupSize:
1129 if (hlsl_options.shader_model < 60)
1130 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1131 break;
1132
1133 case BuiltInSubgroupEqMask:
1134 case BuiltInSubgroupLtMask:
1135 case BuiltInSubgroupLeMask:
1136 case BuiltInSubgroupGtMask:
1137 case BuiltInSubgroupGeMask:
1138 if (hlsl_options.shader_model < 60)
1139 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1140 type = "uint4";
1141 break;
1142
1143 case BuiltInHelperInvocation:
1144 if (hlsl_options.shader_model < 50)
1145 SPIRV_CROSS_THROW("Need SM 5.0 for Helper Invocation.");
1146 break;
1147
1148 case BuiltInClipDistance:
1149 array_size = clip_distance_count;
1150 type = "float";
1151 break;
1152
1153 case BuiltInCullDistance:
1154 array_size = cull_distance_count;
1155 type = "float";
1156 break;
1157
1158 case BuiltInSampleMask:
1159 type = "int";
1160 break;
1161
1162 case BuiltInPrimitiveId:
1163 case BuiltInViewIndex:
1164 case BuiltInLayer:
1165 type = "uint";
1166 break;
1167
1168 default:
1169 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1170 }
1171
1172 StorageClass storage = active_input_builtins.get(bit: i) ? StorageClassInput : StorageClassOutput;
1173
1174 if (type)
1175 {
1176 if (array_size)
1177 statement(ts: "static ", ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage), ts: "[", ts&: array_size, ts: "]", ts&: init_expr, ts: ";");
1178 else
1179 statement(ts: "static ", ts&: type, ts: " ", ts: builtin_to_glsl(builtin, storage), ts&: init_expr, ts: ";");
1180 }
1181
		// SampleMask can be both in and out with the sample builtin; in this case we have
		// already declared the input variable and need to add the output one now.
1184 if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(bit: i))
1185 {
1186 statement(ts: "static ", ts&: type, ts: " ", ts: this->builtin_to_glsl(builtin, storage: StorageClassOutput), ts&: init_expr, ts: ";");
1187 }
1188 });
1189
1190 if (need_base_vertex_info)
1191 {
1192 statement(ts: "cbuffer SPIRV_Cross_VertexInfo");
1193 begin_scope();
1194 statement(ts: "int SPIRV_Cross_BaseVertex;");
1195 statement(ts: "int SPIRV_Cross_BaseInstance;");
1196 end_scope_decl();
1197 statement(ts: "");
1198 }
1199}
1200
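// The hoisted constants end up as file-scope declarations, roughly of the form
// "static const float _42[2] = { 1.0, 2.0 };" (name and values illustrative).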
1201void CompilerHLSL::emit_composite_constants()
1202{
1203 // HLSL cannot declare structs or arrays inline, so we must move them out to
1204 // global constants directly.
1205 bool emitted = false;
1206
1207 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
1208 if (c.specialization)
1209 return;
1210
1211 auto &type = this->get<SPIRType>(id: c.constant_type);
1212
1213 if (type.basetype == SPIRType::Struct && is_builtin_type(type))
1214 return;
1215
1216 if (type.basetype == SPIRType::Struct || !type.array.empty())
1217 {
1218 add_resource_name(id: c.self);
1219 auto name = to_name(id: c.self);
1220 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts: constant_expression(c), ts: ";");
1221 emitted = true;
1222 }
1223 });
1224
1225 if (emitted)
1226 statement(ts: "");
1227}
1228
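// Specialization constants have no HLSL equivalent, so each constant decorated with SpecId is
// emitted as a macro-guarded default, roughly:
//   #ifndef SPIRV_CROSS_CONSTANT_ID_10
//   #define SPIRV_CROSS_CONSTANT_ID_10 <default value>
//   #endif
//   static const int SOME_CONSTANT = SPIRV_CROSS_CONSTANT_ID_10;
// (names illustrative), which lets callers override the value with -D at compile time.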
1229void CompilerHLSL::emit_specialization_constants_and_structs()
1230{
1231 bool emitted = false;
1232 SpecializationConstant wg_x, wg_y, wg_z;
1233 ID workgroup_size_id = get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
1234
1235 std::unordered_set<TypeID> io_block_types;
1236 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, const SPIRVariable &var) {
1237 auto &type = this->get<SPIRType>(id: var.basetype);
1238 if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
1239 !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1240 interface_variable_exists_in_entry_point(id: var.self) &&
1241 has_decoration(id: type.self, decoration: DecorationBlock))
1242 {
1243 io_block_types.insert(x: type.self);
1244 }
1245 });
1246
1247 auto loop_lock = ir.create_loop_hard_lock();
1248 for (auto &id_ : ir.ids_for_constant_or_type)
1249 {
1250 auto &id = ir.ids[id_];
1251
1252 if (id.get_type() == TypeConstant)
1253 {
1254 auto &c = id.get<SPIRConstant>();
1255
1256 if (c.self == workgroup_size_id)
1257 {
1258 statement(ts: "static const uint3 gl_WorkGroupSize = ",
1259 ts: constant_expression(c: get<SPIRConstant>(id: workgroup_size_id)), ts: ";");
1260 emitted = true;
1261 }
1262 else if (c.specialization)
1263 {
1264 auto &type = get<SPIRType>(id: c.constant_type);
1265 add_resource_name(id: c.self);
1266 auto name = to_name(id: c.self);
1267
1268 if (has_decoration(id: c.self, decoration: DecorationSpecId))
1269 {
1270 // HLSL does not support specialization constants, so fallback to macros.
1271 c.specialization_constant_macro_name =
1272 constant_value_macro_name(id: get_decoration(id: c.self, decoration: DecorationSpecId));
1273
1274 statement(ts: "#ifndef ", ts&: c.specialization_constant_macro_name);
1275 statement(ts: "#define ", ts&: c.specialization_constant_macro_name, ts: " ", ts: constant_expression(c));
1276 statement(ts: "#endif");
1277 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts&: c.specialization_constant_macro_name, ts: ";");
1278 }
1279 else
1280 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts: constant_expression(c), ts: ";");
1281
1282 emitted = true;
1283 }
1284 }
1285 else if (id.get_type() == TypeConstantOp)
1286 {
1287 auto &c = id.get<SPIRConstantOp>();
1288 auto &type = get<SPIRType>(id: c.basetype);
1289 add_resource_name(id: c.self);
1290 auto name = to_name(id: c.self);
1291 statement(ts: "static const ", ts: variable_decl(type, name), ts: " = ", ts: constant_op_expression(cop: c), ts: ";");
1292 emitted = true;
1293 }
1294 else if (id.get_type() == TypeType)
1295 {
1296 auto &type = id.get<SPIRType>();
1297 bool is_non_io_block = has_decoration(id: type.self, decoration: DecorationBlock) &&
1298 io_block_types.count(x: type.self) == 0;
1299 bool is_buffer_block = has_decoration(id: type.self, decoration: DecorationBufferBlock);
1300 if (type.basetype == SPIRType::Struct && type.array.empty() &&
1301 !type.pointer && !is_non_io_block && !is_buffer_block)
1302 {
1303 if (emitted)
1304 statement(ts: "");
1305 emitted = false;
1306
1307 emit_struct(type);
1308 }
1309 }
1310 }
1311
1312 if (emitted)
1313 statement(ts: "");
1314}
1315
1316void CompilerHLSL::replace_illegal_names()
1317{
1318 static const unordered_set<string> keywords = {
1319 // Additional HLSL specific keywords.
1320 // From https://docs.microsoft.com/en-US/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-keywords
1321 "AppendStructuredBuffer", "asm", "asm_fragment",
1322 "BlendState", "bool", "break", "Buffer", "ByteAddressBuffer",
1323 "case", "cbuffer", "centroid", "class", "column_major", "compile",
1324 "compile_fragment", "CompileShader", "const", "continue", "ComputeShader",
1325 "ConsumeStructuredBuffer",
1326 "default", "DepthStencilState", "DepthStencilView", "discard", "do",
1327 "double", "DomainShader", "dword",
1328 "else", "export", "false", "float", "for", "fxgroup",
1329 "GeometryShader", "groupshared", "half", "HullShader",
1330 "if", "in", "inline", "inout", "InputPatch", "int", "interface",
1331 "line", "lineadj", "linear", "LineStream",
1332 "matrix", "min16float", "min10float", "min16int", "min16uint",
1333 "namespace", "nointerpolation", "noperspective", "NULL",
1334 "out", "OutputPatch",
1335 "packoffset", "pass", "pixelfragment", "PixelShader", "point",
1336 "PointStream", "precise", "RasterizerState", "RenderTargetView",
1337 "return", "register", "row_major", "RWBuffer", "RWByteAddressBuffer",
1338 "RWStructuredBuffer", "RWTexture1D", "RWTexture1DArray", "RWTexture2D",
1339 "RWTexture2DArray", "RWTexture3D", "sample", "sampler", "SamplerState",
1340 "SamplerComparisonState", "shared", "snorm", "stateblock", "stateblock_state",
1341 "static", "string", "struct", "switch", "StructuredBuffer", "tbuffer",
1342 "technique", "technique10", "technique11", "texture", "Texture1D",
1343 "Texture1DArray", "Texture2D", "Texture2DArray", "Texture2DMS", "Texture2DMSArray",
1344 "Texture3D", "TextureCube", "TextureCubeArray", "true", "typedef", "triangle",
1345 "triangleadj", "TriangleStream", "uint", "uniform", "unorm", "unsigned",
1346 "vector", "vertexfragment", "VertexShader", "void", "volatile", "while",
1347 };
1348
1349 CompilerGLSL::replace_illegal_names(keywords);
1350 CompilerGLSL::replace_illegal_names();
1351}
1352
1353void CompilerHLSL::declare_undefined_values()
1354{
1355 bool emitted = false;
1356 ir.for_each_typed_id<SPIRUndef>(op: [&](uint32_t, const SPIRUndef &undef) {
1357 auto &type = this->get<SPIRType>(id: undef.basetype);
1358 // OpUndef can be void for some reason ...
1359 if (type.basetype == SPIRType::Void)
1360 return;
1361
1362 string initializer;
1363 if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
1364 initializer = join(ts: " = ", ts: to_zero_initialized_expression(type_id: undef.basetype));
1365
1366 statement(ts: "static ", ts: variable_decl(type, name: to_name(id: undef.self), id: undef.self), ts&: initializer, ts: ";");
1367 emitted = true;
1368 });
1369
1370 if (emitted)
1371 statement(ts: "");
1372}
1373
1374void CompilerHLSL::emit_resources()
1375{
1376 auto &execution = get_entry_point();
1377
1378 replace_illegal_names();
1379
1380 emit_specialization_constants_and_structs();
1381 emit_composite_constants();
1382
1383 bool emitted = false;
1384
1385 // Output UBOs and SSBOs
1386 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1387 auto &type = this->get<SPIRType>(id: var.basetype);
1388
1389 bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1390 bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBlock) ||
1391 ir.meta[type.self].decoration.decoration_flags.get(bit: DecorationBufferBlock);
1392
1393 if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1394 has_block_flags)
1395 {
1396 emit_buffer_block(type: var);
1397 emitted = true;
1398 }
1399 });
1400
1401 // Output push constant blocks
1402 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1403 auto &type = this->get<SPIRType>(id: var.basetype);
1404 if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1405 !is_hidden_variable(var))
1406 {
1407 emit_push_constant_block(var);
1408 emitted = true;
1409 }
1410 });
1411
1412 if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30 &&
1413 active_output_builtins.get(bit: BuiltInPosition))
1414 {
1415 statement(ts: "uniform float4 gl_HalfPixel;");
1416 emitted = true;
1417 }
1418
1419 bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1420
1421 // Output Uniform Constants (values, samplers, images, etc).
1422 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1423 auto &type = this->get<SPIRType>(id: var.basetype);
1424
1425 // If we're remapping separate samplers and images, only emit the combined samplers.
1426 if (skip_separate_image_sampler)
1427 {
1428 // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1429 bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1430 bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1431 bool separate_sampler = type.basetype == SPIRType::Sampler;
1432 if (!sampler_buffer && (separate_image || separate_sampler))
1433 return;
1434 }
1435
1436 if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1437 type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1438 !is_hidden_variable(var))
1439 {
1440 emit_uniform(var);
1441 emitted = true;
1442 }
1443 });
1444
1445 if (emitted)
1446 statement(ts: "");
1447 emitted = false;
1448
1449 // Emit builtin input and output variables here.
1450 emit_builtin_variables();
1451
1452 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1453 auto &type = this->get<SPIRType>(id: var.basetype);
1454
1455 if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1456 (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1457 interface_variable_exists_in_entry_point(id: var.self))
1458 {
1459 // Builtin variables are handled separately.
1460 emit_interface_block_globally(var);
1461 emitted = true;
1462 }
1463 });
1464
1465 if (emitted)
1466 statement(ts: "");
1467 emitted = false;
1468
1469 require_input = false;
1470 require_output = false;
1471 unordered_set<uint32_t> active_inputs;
1472 unordered_set<uint32_t> active_outputs;
1473
1474 struct IOVariable
1475 {
1476 const SPIRVariable *var;
1477 uint32_t location;
1478 uint32_t block_member_index;
1479 bool block;
1480 };
1481
1482 SmallVector<IOVariable> input_variables;
1483 SmallVector<IOVariable> output_variables;
1484
1485 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1486 auto &type = this->get<SPIRType>(id: var.basetype);
1487 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
1488
1489 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1490 return;
1491
1492 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1493 interface_variable_exists_in_entry_point(id: var.self))
1494 {
1495 if (block)
1496 {
1497 for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1498 {
1499 uint32_t location = get_declared_member_location(var, mbr_idx: i, strip_array: false);
1500 if (var.storage == StorageClassInput)
1501 input_variables.push_back(t: { .var: &var, .location: location, .block_member_index: i, .block: true });
1502 else
1503 output_variables.push_back(t: { .var: &var, .location: location, .block_member_index: i, .block: true });
1504 }
1505 }
1506 else
1507 {
1508 uint32_t location = get_decoration(id: var.self, decoration: DecorationLocation);
1509 if (var.storage == StorageClassInput)
1510 input_variables.push_back(t: { .var: &var, .location: location, .block_member_index: 0, .block: false });
1511 else
1512 output_variables.push_back(t: { .var: &var, .location: location, .block_member_index: 0, .block: false });
1513 }
1514 }
1515 });
1516
1517 const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
1518 // Sort input and output variables based on, from more robust to less robust:
1519 // - Location
1520 // - Variable has a location
1521 // - Name comparison
1522 // - Variable has a name
1523 // - Fallback: ID
1524 bool has_location_a = a.block || has_decoration(id: a.var->self, decoration: DecorationLocation);
1525 bool has_location_b = b.block || has_decoration(id: b.var->self, decoration: DecorationLocation);
1526
1527 if (has_location_a && has_location_b)
1528 return a.location < b.location;
1529 else if (has_location_a && !has_location_b)
1530 return true;
1531 else if (!has_location_a && has_location_b)
1532 return false;
1533
1534 const auto &name1 = to_name(id: a.var->self);
1535 const auto &name2 = to_name(id: b.var->self);
1536
1537 if (name1.empty() && name2.empty())
1538 return a.var->self < b.var->self;
1539 else if (name1.empty())
1540 return true;
1541 else if (name2.empty())
1542 return false;
1543
1544 return name1.compare(str: name2) < 0;
1545 };
1546
1547 auto input_builtins = active_input_builtins;
1548 input_builtins.clear(bit: BuiltInNumWorkgroups);
1549 input_builtins.clear(bit: BuiltInPointCoord);
1550 input_builtins.clear(bit: BuiltInSubgroupSize);
1551 input_builtins.clear(bit: BuiltInSubgroupLocalInvocationId);
1552 input_builtins.clear(bit: BuiltInSubgroupEqMask);
1553 input_builtins.clear(bit: BuiltInSubgroupLtMask);
1554 input_builtins.clear(bit: BuiltInSubgroupLeMask);
1555 input_builtins.clear(bit: BuiltInSubgroupGtMask);
1556 input_builtins.clear(bit: BuiltInSubgroupGeMask);
1557
1558 if (!input_variables.empty() || !input_builtins.empty())
1559 {
1560 require_input = true;
1561 statement(ts: "struct SPIRV_Cross_Input");
1562
1563 begin_scope();
1564 sort(first: input_variables.begin(), last: input_variables.end(), comp: variable_compare);
1565 for (auto &var : input_variables)
1566 {
1567 if (var.block)
1568 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_inputs);
1569 else
1570 emit_interface_block_in_struct(var: *var.var, active_locations&: active_inputs);
1571 }
1572 emit_builtin_inputs_in_struct();
1573 end_scope_decl();
1574 statement(ts: "");
1575 }
1576
1577 if (!output_variables.empty() || !active_output_builtins.empty())
1578 {
1579 require_output = true;
1580 statement(ts: "struct SPIRV_Cross_Output");
1581
1582 begin_scope();
1583 sort(first: output_variables.begin(), last: output_variables.end(), comp: variable_compare);
1584 for (auto &var : output_variables)
1585 {
1586 if (var.block)
1587 emit_interface_block_member_in_struct(var: *var.var, member_index: var.block_member_index, location: var.location, active_locations&: active_outputs);
1588 else
1589 emit_interface_block_in_struct(var: *var.var, active_locations&: active_outputs);
1590 }
1591 emit_builtin_outputs_in_struct();
1592 end_scope_decl();
1593 statement(ts: "");
1594 }
1595
1596 // Global variables.
1597 for (auto global : global_variables)
1598 {
1599 auto &var = get<SPIRVariable>(id: global);
1600 if (is_hidden_variable(var, include_builtins: true))
1601 continue;
1602
1603 if (var.storage != StorageClassOutput)
1604 {
1605 if (!variable_is_lut(var))
1606 {
1607 add_resource_name(id: var.self);
1608
1609 const char *storage = nullptr;
1610 switch (var.storage)
1611 {
1612 case StorageClassWorkgroup:
1613 storage = "groupshared";
1614 break;
1615
1616 default:
1617 storage = "static";
1618 break;
1619 }
1620
1621 string initializer;
1622 if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1623 !var.initializer && !var.static_expression && type_can_zero_initialize(type: get_variable_data_type(var)))
1624 {
1625 initializer = join(ts: " = ", ts: to_zero_initialized_expression(type_id: get_variable_data_type_id(var)));
1626 }
1627 statement(ts&: storage, ts: " ", ts: variable_decl(variable: var), ts&: initializer, ts: ";");
1628
1629 emitted = true;
1630 }
1631 }
1632 }
1633
1634 if (emitted)
1635 statement(ts: "");
1636
1637 declare_undefined_values();
1638
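	// GLSL mod() uses floor semantics while HLSL's fmod() truncates toward zero, so OpFMod is
	// lowered onto this helper, which computes x - y * floor(x / y) per component.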
1639 if (requires_op_fmod)
1640 {
1641 static const char *types[] = {
1642 "float",
1643 "float2",
1644 "float3",
1645 "float4",
1646 };
1647
1648 for (auto &type : types)
1649 {
1650 statement(ts&: type, ts: " mod(", ts&: type, ts: " x, ", ts&: type, ts: " y)");
1651 begin_scope();
1652 statement(ts: "return x - y * floor(x / y);");
1653 end_scope();
1654 statement(ts: "");
1655 }
1656 }
1657
1658 emit_texture_size_variants(variant_mask: required_texture_size_variants.srv, vecsize_qualifier: "4", uav: false, type_qualifier: "");
1659 for (uint32_t norm = 0; norm < 3; norm++)
1660 {
1661 for (uint32_t comp = 0; comp < 4; comp++)
1662 {
1663 static const char *qualifiers[] = { "", "unorm ", "snorm " };
1664 static const char *vecsizes[] = { "", "2", "3", "4" };
1665 emit_texture_size_variants(variant_mask: required_texture_size_variants.uav[norm][comp], vecsize_qualifier: vecsizes[comp], uav: true,
1666 type_qualifier: qualifiers[norm]);
1667 }
1668 }
1669
1670 if (requires_fp16_packing)
1671 {
1672 // HLSL does not pack into a single word sadly :(
1673 statement(ts: "uint spvPackHalf2x16(float2 value)");
1674 begin_scope();
1675 statement(ts: "uint2 Packed = f32tof16(value);");
1676 statement(ts: "return Packed.x | (Packed.y << 16);");
1677 end_scope();
1678 statement(ts: "");
1679
1680 statement(ts: "float2 spvUnpackHalf2x16(uint value)");
1681 begin_scope();
1682 statement(ts: "return f16tof32(uint2(value & 0xffff, value >> 16));");
1683 end_scope();
1684 statement(ts: "");
1685 }
1686
1687 if (requires_uint2_packing)
1688 {
1689 statement(ts: "uint64_t spvPackUint2x32(uint2 value)");
1690 begin_scope();
1691 statement(ts: "return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1692 end_scope();
1693 statement(ts: "");
1694
1695 statement(ts: "uint2 spvUnpackUint2x32(uint64_t value)");
1696 begin_scope();
1697 statement(ts: "uint2 Unpacked;");
1698 statement(ts: "Unpacked.x = uint(value & 0xffffffff);");
1699 statement(ts: "Unpacked.y = uint(value >> 32);");
1700 statement(ts: "return Unpacked;");
1701 end_scope();
1702 statement(ts: "");
1703 }
1704
1705 if (requires_explicit_fp16_packing)
1706 {
1707		// HLSL's f32tof16 sadly does not pack into a single word: it returns the half bits per component, so combine them by hand.
1708 statement(ts: "uint spvPackFloat2x16(min16float2 value)");
1709 begin_scope();
1710 statement(ts: "uint2 Packed = f32tof16(value);");
1711 statement(ts: "return Packed.x | (Packed.y << 16);");
1712 end_scope();
1713 statement(ts: "");
1714
1715 statement(ts: "min16float2 spvUnpackFloat2x16(uint value)");
1716 begin_scope();
1717 statement(ts: "return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1718 end_scope();
1719 statement(ts: "");
1720 }
1721
1722	// HLSL does not seem to have builtins for these operations, so roll them by hand ...
1723 if (requires_unorm8_packing)
1724 {
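		// Mirrors GLSL packUnorm4x8: round(saturate(v) * 255) with .x in bits 0-7 through .w in bits
		// 24-31; e.g. float4(1.0, 0.0, 0.5, 0.0) packs to 0x008000FF.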
1725 statement(ts: "uint spvPackUnorm4x8(float4 value)");
1726 begin_scope();
1727 statement(ts: "uint4 Packed = uint4(round(saturate(value) * 255.0));");
1728 statement(ts: "return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1729 end_scope();
1730 statement(ts: "");
1731
1732 statement(ts: "float4 spvUnpackUnorm4x8(uint value)");
1733 begin_scope();
1734 statement(ts: "uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1735 statement(ts: "return float4(Packed) / 255.0;");
1736 end_scope();
1737 statement(ts: "");
1738 }
1739
1740 if (requires_snorm8_packing)
1741 {
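		// The "& 0xff" keeps each component's low byte (two's complement); unpacking sign-extends by
		// shifting the word so the target byte lands in the top byte of an int, then arithmetic-shifting
		// back down by 24, e.g. byte 0x81 -> -127 -> -1.0 after dividing by 127.0.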
1742 statement(ts: "uint spvPackSnorm4x8(float4 value)");
1743 begin_scope();
1744 statement(ts: "int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1745 statement(ts: "return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1746 end_scope();
1747 statement(ts: "");
1748
1749 statement(ts: "float4 spvUnpackSnorm4x8(uint value)");
1750 begin_scope();
1751 statement(ts: "int SignedValue = int(value);");
1752 statement(ts: "int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
1753 statement(ts: "return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
1754 end_scope();
1755 statement(ts: "");
1756 }
1757
1758 if (requires_unorm16_packing)
1759 {
1760 statement(ts: "uint spvPackUnorm2x16(float2 value)");
1761 begin_scope();
1762 statement(ts: "uint2 Packed = uint2(round(saturate(value) * 65535.0));");
1763 statement(ts: "return Packed.x | (Packed.y << 16);");
1764 end_scope();
1765 statement(ts: "");
1766
1767 statement(ts: "float2 spvUnpackUnorm2x16(uint value)");
1768 begin_scope();
1769 statement(ts: "uint2 Packed = uint2(value & 0xffff, value >> 16);");
1770 statement(ts: "return float2(Packed) / 65535.0;");
1771 end_scope();
1772 statement(ts: "");
1773 }
1774
1775 if (requires_snorm16_packing)
1776 {
1777 statement(ts: "uint spvPackSnorm2x16(float2 value)");
1778 begin_scope();
1779 statement(ts: "int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
1780 statement(ts: "return uint(Packed.x | (Packed.y << 16));");
1781 end_scope();
1782 statement(ts: "");
1783
1784 statement(ts: "float2 spvUnpackSnorm2x16(uint value)");
1785 begin_scope();
1786 statement(ts: "int SignedValue = int(value);");
1787 statement(ts: "int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
1788 statement(ts: "return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
1789 end_scope();
1790 statement(ts: "");
1791 }
1792
1793 if (requires_bitfield_insert)
1794 {
1795 static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
1796 for (auto &type : types)
1797 {
1798 statement(ts&: type, ts: " spvBitfieldInsert(", ts&: type, ts: " Base, ", ts&: type, ts: " Insert, uint Offset, uint Count)");
1799 begin_scope();
1800 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
1801 statement(ts: "return (Base & ~Mask) | ((Insert << Offset) & Mask);");
1802 end_scope();
1803 statement(ts: "");
1804 }
1805 }
1806
1807 if (requires_bitfield_extract)
1808 {
1809 static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
1810 for (auto &type : unsigned_types)
1811 {
1812 statement(ts&: type, ts: " spvBitfieldUExtract(", ts&: type, ts: " Base, uint Offset, uint Count)");
1813 begin_scope();
1814 statement(ts: "uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
1815 statement(ts: "return (Base >> Offset) & Mask;");
1816 end_scope();
1817 statement(ts: "");
1818 }
1819
1820 // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
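		// E.g. with Offset = 4 and Count = 8 the masked value sits in bits 0-7, ExtendShift = 24, and
		// shifting left then arithmetic-shifting right by 24 replicates bit 7 into the upper bits.
		// The "& 31" keeps the Count == 32 case well-defined (ExtendShift becomes 0).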
1821 static const char *signed_types[] = { "int", "int2", "int3", "int4" };
1822 for (auto &type : signed_types)
1823 {
1824 statement(ts&: type, ts: " spvBitfieldSExtract(", ts&: type, ts: " Base, int Offset, int Count)");
1825 begin_scope();
1826 statement(ts: "int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
1827 statement(ts&: type, ts: " Masked = (Base >> Offset) & Mask;");
1828 statement(ts: "int ExtendShift = (32 - Count) & 31;");
1829 statement(ts: "return (Masked << ExtendShift) >> ExtendShift;");
1830 end_scope();
1831 statement(ts: "");
1832 }
1833 }
1834
1835 if (requires_inverse_2x2)
1836 {
1837 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1838		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
1839 statement(ts: "float2x2 spvInverse(float2x2 m)");
1840 begin_scope();
1841 statement(ts: "float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
1842 statement_no_indent(ts: "");
1843 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1844 statement(ts: "adj[0][0] = m[1][1];");
1845 statement(ts: "adj[0][1] = -m[0][1];");
1846 statement_no_indent(ts: "");
1847 statement(ts: "adj[1][0] = -m[1][0];");
1848 statement(ts: "adj[1][1] = m[0][0];");
1849 statement_no_indent(ts: "");
1850 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
1851 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1852 statement_no_indent(ts: "");
1853 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
1854		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1855 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1856 end_scope();
1857 statement(ts: "");
1858 }
1859
1860 if (requires_inverse_3x3)
1861 {
1862 statement(ts: "// Returns the determinant of a 2x2 matrix.");
1863 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
1864 begin_scope();
1865 statement(ts: "return a1 * b2 - b1 * a2;");
1866 end_scope();
1867 statement_no_indent(ts: "");
1868 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1869		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
1870 statement(ts: "float3x3 spvInverse(float3x3 m)");
1871 begin_scope();
1872 statement(ts: "float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
1873 statement_no_indent(ts: "");
1874 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1875 statement(ts: "adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1876 statement(ts: "adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1877 statement(ts: "adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1878 statement_no_indent(ts: "");
1879 statement(ts: "adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1880 statement(ts: "adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1881 statement(ts: "adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1882 statement_no_indent(ts: "");
1883 statement(ts: "adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1884 statement(ts: "adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1885 statement(ts: "adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1886 statement_no_indent(ts: "");
1887 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
1888 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1889 statement_no_indent(ts: "");
1890 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
1891		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1892 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1893 end_scope();
1894 statement(ts: "");
1895 }
1896
1897 if (requires_inverse_4x4)
1898 {
1899 if (!requires_inverse_3x3)
1900 {
1901 statement(ts: "// Returns the determinant of a 2x2 matrix.");
1902 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
1903 begin_scope();
1904 statement(ts: "return a1 * b2 - b1 * a2;");
1905 end_scope();
1906 statement(ts: "");
1907 }
1908
1909 statement(ts: "// Returns the determinant of a 3x3 matrix.");
1910 statement(ts: "float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1911 "float c2, float c3)");
1912 begin_scope();
1913 statement(ts: "return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1914 "spvDet2x2(a2, a3, "
1915 "b2, b3);");
1916 end_scope();
1917 statement_no_indent(ts: "");
1918 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1919		statement(ts: "// adjoint and dividing by the determinant. The input matrix is not modified.");
1920 statement(ts: "float4x4 spvInverse(float4x4 m)");
1921 begin_scope();
1922 statement(ts: "float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
1923 statement_no_indent(ts: "");
1924 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1925 statement(
1926 ts: "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1927 "m[3][3]);");
1928 statement(
1929 ts: "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1930 "m[3][3]);");
1931 statement(
1932 ts: "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1933 "m[3][3]);");
1934 statement(
1935 ts: "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1936 "m[2][3]);");
1937 statement_no_indent(ts: "");
1938 statement(
1939 ts: "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1940 "m[3][3]);");
1941 statement(
1942 ts: "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1943 "m[3][3]);");
1944 statement(
1945 ts: "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1946 "m[3][3]);");
1947 statement(
1948 ts: "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1949 "m[2][3]);");
1950 statement_no_indent(ts: "");
1951 statement(
1952 ts: "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1953 "m[3][3]);");
1954 statement(
1955 ts: "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1956 "m[3][3]);");
1957 statement(
1958 ts: "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1959 "m[3][3]);");
1960 statement(
1961 ts: "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1962 "m[2][3]);");
1963 statement_no_indent(ts: "");
1964 statement(
1965 ts: "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1966 "m[3][2]);");
1967 statement(
1968 ts: "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1969 "m[3][2]);");
1970 statement(
1971 ts: "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1972 "m[3][2]);");
1973 statement(
1974 ts: "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1975 "m[2][2]);");
1976 statement_no_indent(ts: "");
1977 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
1978 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1979 "* m[3][0]);");
1980 statement_no_indent(ts: "");
1981 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
1982		statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1983 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1984 end_scope();
1985 statement(ts: "");
1986 }
1987
1988 if (requires_scalar_reflect)
1989 {
1990 // FP16/FP64? No templates in HLSL.
1991 statement(ts: "float spvReflect(float i, float n)");
1992 begin_scope();
1993 statement(ts: "return i - 2.0 * dot(n, i) * n;");
1994 end_scope();
1995 statement(ts: "");
1996 }
1997
1998 if (requires_scalar_refract)
1999 {
2000 // FP16/FP64? No templates in HLSL.
2001 statement(ts: "float spvRefract(float i, float n, float eta)");
2002 begin_scope();
2003 statement(ts: "float NoI = n * i;");
2004 statement(ts: "float NoI2 = NoI * NoI;");
2005 statement(ts: "float k = 1.0 - eta * eta * (1.0 - NoI2);");
2006 statement(ts: "if (k < 0.0)");
2007 begin_scope();
2008 statement(ts: "return 0.0;");
2009 end_scope();
2010 statement(ts: "else");
2011 begin_scope();
2012 statement(ts: "return eta * i - (eta * NoI + sqrt(k)) * n;");
2013 end_scope();
2014 end_scope();
2015 statement(ts: "");
2016 }
2017
2018 if (requires_scalar_faceforward)
2019 {
2020 // FP16/FP64? No templates in HLSL.
2021 statement(ts: "float spvFaceForward(float n, float i, float nref)");
2022 begin_scope();
2023 statement(ts: "return i * nref < 0.0 ? n : -n;");
2024 end_scope();
2025 statement(ts: "");
2026 }
2027
2028 for (TypeID type_id : composite_selection_workaround_types)
2029 {
2030 // Need out variable since HLSL does not support returning arrays.
2031 auto &type = get<SPIRType>(id: type_id);
2032 auto type_str = type_to_glsl(type);
2033 auto type_arr_str = type_to_array_glsl(type);
2034 statement(ts: "void spvSelectComposite(out ", ts&: type_str, ts: " out_value", ts&: type_arr_str, ts: ", bool cond, ",
2035 ts&: type_str, ts: " true_val", ts&: type_arr_str, ts: ", ",
2036 ts&: type_str, ts: " false_val", ts&: type_arr_str, ts: ")");
2037 begin_scope();
2038 statement(ts: "if (cond)");
2039 begin_scope();
2040 statement(ts: "out_value = true_val;");
2041 end_scope();
2042 statement(ts: "else");
2043 begin_scope();
2044 statement(ts: "out_value = false_val;");
2045 end_scope();
2046 end_scope();
2047 statement(ts: "");
2048 }
2049}
2050
2051void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
2052 const char *type_qualifier)
2053{
2054 if (variant_mask == 0)
2055 return;
2056
2057 static const char *types[QueryTypeCount] = { "float", "int", "uint" };
2058 static const char *dims[QueryDimCount] = { "Texture1D", "Texture1DArray", "Texture2D", "Texture2DArray",
2059 "Texture3D", "Buffer", "TextureCube", "TextureCubeArray",
2060 "Texture2DMS", "Texture2DMSArray" };
2061
2062 static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
2063
2064 static const char *ret_types[QueryDimCount] = {
2065 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
2066 };
2067
2068 static const uint32_t return_arguments[QueryDimCount] = {
2069 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
2070 };
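	// variant_mask packs one bit per (query type, dimension) pair: bit index = 16 * type_index + dim,
	// so each of the float/int/uint sample types owns a 16-bit lane of the 64-bit mask and the
	// dimension selects the bit within that lane.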
2071
2072 for (uint32_t index = 0; index < QueryDimCount; index++)
2073 {
2074 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
2075 {
2076 uint32_t bit = 16 * type_index + index;
2077 uint64_t mask = 1ull << bit;
2078
2079 if ((variant_mask & mask) == 0)
2080 continue;
2081
2082 statement(ts&: ret_types[index], ts: " spv", ts: (uav ? "Image" : "Texture"), ts: "Size(", ts: (uav ? "RW" : ""),
2083 ts&: dims[index], ts: "<", ts&: type_qualifier, ts&: types[type_index], ts&: vecsize_qualifier, ts: "> Tex, ",
2084 ts: (uav ? "" : "uint Level, "), ts: "out uint Param)");
2085 begin_scope();
2086 statement(ts&: ret_types[index], ts: " ret;");
2087 switch (return_arguments[index])
2088 {
2089 case 1:
2090 if (has_lod[index] && !uav)
2091 statement(ts: "Tex.GetDimensions(Level, ret.x, Param);");
2092 else
2093 {
2094 statement(ts: "Tex.GetDimensions(ret.x);");
2095 statement(ts: "Param = 0u;");
2096 }
2097 break;
2098 case 2:
2099 if (has_lod[index] && !uav)
2100 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2101 else if (!uav)
2102 statement(ts: "Tex.GetDimensions(ret.x, ret.y, Param);");
2103 else
2104 {
2105 statement(ts: "Tex.GetDimensions(ret.x, ret.y);");
2106 statement(ts: "Param = 0u;");
2107 }
2108 break;
2109 case 3:
2110 if (has_lod[index] && !uav)
2111 statement(ts: "Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2112 else if (!uav)
2113 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2114 else
2115 {
2116 statement(ts: "Tex.GetDimensions(ret.x, ret.y, ret.z);");
2117 statement(ts: "Param = 0u;");
2118 }
2119 break;
2120 }
2121
2122 statement(ts: "return ret;");
2123 end_scope();
2124 statement(ts: "");
2125 }
2126 }
2127}
2128
2129string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2130{
2131 auto &flags = get_member_decoration_bitset(id: type.self, index);
2132
2133 // HLSL can emit row_major or column_major decoration in any struct.
2134 // Do not try to merge combined decorations for children like in GLSL.
2135
2136 // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2137 // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2138 if (flags.get(bit: DecorationColMajor))
2139 return "row_major ";
2140 else if (flags.get(bit: DecorationRowMajor))
2141 return "column_major ";
2142
2143 return "";
2144}
2145
2146void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2147 const string &qualifier, uint32_t base_offset)
2148{
2149 auto &membertype = get<SPIRType>(id: member_type_id);
2150
2151 Bitset memberflags;
2152 auto &memb = ir.meta[type.self].members;
2153 if (index < memb.size())
2154 memberflags = memb[index].decoration_flags;
2155
2156 string packing_offset;
2157 bool is_push_constant = type.storage == StorageClassPushConstant;
2158
2159 if ((has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2160 has_member_decoration(id: type.self, index, decoration: DecorationOffset))
2161 {
2162 uint32_t offset = memb[index].offset - base_offset;
2163 if (offset & 3)
2164 SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2165
2166 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
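		// E.g. a member at byte offset 20 gets " : packoffset(c1.y)": register 20 / 16 = 1,
		// lane (20 & 15) >> 2 = 1 -> ".y" (cbuffer registers are 16 bytes wide with four 4-byte lanes).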
2167 packing_offset = join(ts: " : packoffset(c", ts: offset / 16, ts&: packing_swizzle[(offset & 15) >> 2], ts: ")");
2168 }
2169
2170 statement(ts: layout_for_member(type, index), ts: qualifier,
2171 ts: variable_decl(type: membertype, name: to_member_name(type, index)), ts&: packing_offset, ts: ";");
2172}
2173
2174void CompilerHLSL::emit_rayquery_function(const char *committed, const char *candidate, const uint32_t *ops)
2175{
2176	flush_variable_declaration(id: ops[0]);
2177	uint32_t is_committed = evaluate_constant_u32(id: ops[3]);
2178	emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts&: is_committed ? committed : candidate), forward_rhs: false);
2179}
2180
2181void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2182{
2183 auto &type = get<SPIRType>(id: var.basetype);
2184
2185 bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock);
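	// Rough mapping applied below: SSBO-like blocks (StorageBuffer storage or BufferBlock decoration)
	// become ByteAddressBuffer SRVs when read-only, otherwise RWByteAddressBuffer (or
	// RasterizerOrderedByteAddressBuffer when interlocked), with "globallycoherent" for Coherent blocks;
	// UBO-like blocks become a flattened cbuffer, or ConstantBuffer<T> for arrays on SM 5.1+.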
2186
2187 if (flattened_buffer_blocks.count(x: var.self))
2188 {
2189 emit_buffer_block_flattened(type: var);
2190 }
2191 else if (is_uav)
2192 {
2193 Bitset flags = ir.get_buffer_block_flags(var);
2194 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
2195 bool is_coherent = flags.get(bit: DecorationCoherent) && !is_readonly;
2196 bool is_interlocked = interlocked_resources.count(x: var.self) > 0;
2197 const char *type_name = "ByteAddressBuffer ";
2198 if (!is_readonly)
2199 type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2200 add_resource_name(id: var.self);
2201 statement(ts: is_coherent ? "globallycoherent " : "", ts&: type_name, ts: to_name(id: var.self), ts: type_to_array_glsl(type),
2202 ts: to_resource_binding(var), ts: ";");
2203 }
2204 else
2205 {
2206 if (type.array.empty())
2207 {
2208			// Flatten the top-level struct so we can use packoffset;
2209			// this restriction is similar to GLSL, where layout(offset) is not possible on sub-structs.
2210 flattened_structs[var.self] = false;
2211
2212 // Prefer the block name if possible.
2213 auto buffer_name = to_name(id: type.self, allow_alias: false);
2214 if (ir.meta[type.self].decoration.alias.empty() ||
2215 resource_names.find(x: buffer_name) != end(cont&: resource_names) ||
2216 block_names.find(x: buffer_name) != end(cont&: block_names))
2217 {
2218 buffer_name = get_block_fallback_name(id: var.self);
2219 }
2220
2221 add_variable(variables_primary&: block_names, variables_secondary: resource_names, name&: buffer_name);
2222
2223 // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2224 // This cannot conflict with anything else, so we're safe now.
2225 if (buffer_name.empty())
2226 buffer_name = join(ts: "_", ts&: get<SPIRType>(id: var.basetype).self, ts: "_", ts: var.self);
2227
2228 uint32_t failed_index = 0;
2229 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index))
2230 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2231 else
2232 {
2233 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2234 failed_index, " (name: ", to_member_name(type, failed_index),
2235 ") cannot be expressed with either HLSL packing layout or packoffset."));
2236 }
2237
2238 block_names.insert(x: buffer_name);
2239
2240 // Save for post-reflection later.
2241 declared_block_names[var.self] = buffer_name;
2242
2243 type.member_name_cache.clear();
2244 // var.self can be used as a backup name for the block name,
2245 // so we need to make sure we don't disturb the name here on a recompile.
2246 // It will need to be reset if we have to recompile.
2247 preserve_alias_on_reset(id: var.self);
2248 add_resource_name(id: var.self);
2249 statement(ts: "cbuffer ", ts&: buffer_name, ts: to_resource_binding(var));
2250 begin_scope();
2251
2252 uint32_t i = 0;
2253 for (auto &member : type.member_types)
2254 {
2255 add_member_name(type, name: i);
2256 auto backup_name = get_member_name(id: type.self, index: i);
2257 auto member_name = to_member_name(type, index: i);
2258 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2259 ParsedIR::sanitize_underscores(str&: member_name);
2260 set_member_name(id: type.self, index: i, name: member_name);
2261 emit_struct_member(type, member_type_id: member, index: i, qualifier: "");
2262 set_member_name(id: type.self, index: i, name: backup_name);
2263 i++;
2264 }
2265
2266 end_scope_decl();
2267 statement(ts: "");
2268 }
2269 else
2270 {
2271 if (hlsl_options.shader_model < 51)
2272 SPIRV_CROSS_THROW(
2273 "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2274
2275 add_resource_name(id: type.self);
2276 add_resource_name(id: var.self);
2277
2278			// ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2279 uint32_t failed_index = 0;
2280 if (!buffer_is_packing_standard(type, packing: BufferPackingHLSLCbuffer, failed_index: &failed_index))
2281 {
2282 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2283 "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2284 ") cannot be expressed with normal HLSL packing rules."));
2285 }
2286
2287 emit_struct(type&: get<SPIRType>(id: type.self));
2288 statement(ts: "ConstantBuffer<", ts: to_name(id: type.self), ts: "> ", ts: to_name(id: var.self), ts: type_to_array_glsl(type),
2289 ts: to_resource_binding(var), ts: ";");
2290 }
2291 }
2292}
2293
2294void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2295{
2296 if (flattened_buffer_blocks.count(x: var.self))
2297 {
2298 emit_buffer_block_flattened(type: var);
2299 }
2300 else if (root_constants_layout.empty())
2301 {
2302 emit_buffer_block(var);
2303 }
2304 else
2305 {
2306 for (const auto &layout : root_constants_layout)
2307 {
2308 auto &type = get<SPIRType>(id: var.basetype);
2309
2310 uint32_t failed_index = 0;
2311 if (buffer_is_packing_standard(type, packing: BufferPackingHLSLCbufferPackOffset, failed_index: &failed_index, start_offset: layout.start,
2312 end_offset: layout.end))
2313 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationExplicitOffset);
2314 else
2315 {
2316 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2317 ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2318 ") cannot be expressed with either HLSL packing layout or packoffset."));
2319 }
2320
2321 flattened_structs[var.self] = false;
2322 type.member_name_cache.clear();
2323 add_resource_name(id: var.self);
2324 auto &memb = ir.meta[type.self].members;
2325
2326 statement(ts: "cbuffer SPIRV_CROSS_RootConstant_", ts: to_name(id: var.self),
2327 ts: to_resource_register(flag: HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, space: 'b', binding: layout.binding, set: layout.space));
2328 begin_scope();
2329
2330			// Index of the next field in the generated root constant cbuffer.
2331 auto constant_index = 0u;
2332
2333			// Iterate over all members of the push constant block and check which of the fields
2334			// fit into the given root constant layout.
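			// Only members whose byte offset falls inside [layout.start, layout.end) are emitted into
			// this cbuffer; its register comes from layout.binding / layout.space above.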
2335 for (auto i = 0u; i < memb.size(); i++)
2336 {
2337 const auto offset = memb[i].offset;
2338 if (layout.start <= offset && offset < layout.end)
2339 {
2340 const auto &member = type.member_types[i];
2341
2342 add_member_name(type, name: constant_index);
2343 auto backup_name = get_member_name(id: type.self, index: i);
2344 auto member_name = to_member_name(type, index: i);
2345 member_name = join(ts: to_name(id: var.self), ts: "_", ts&: member_name);
2346 ParsedIR::sanitize_underscores(str&: member_name);
2347 set_member_name(id: type.self, index: constant_index, name: member_name);
2348 emit_struct_member(type, member_type_id: member, index: i, qualifier: "", base_offset: layout.start);
2349 set_member_name(id: type.self, index: constant_index, name: backup_name);
2350
2351 constant_index++;
2352 }
2353 }
2354
2355 end_scope_decl();
2356 }
2357 }
2358}
2359
2360string CompilerHLSL::to_sampler_expression(uint32_t id)
2361{
2362 auto expr = join(ts: "_", ts: to_non_uniform_aware_expression(id));
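	// E.g. a (hypothetical) combined texture named "uTex" yields "_uTex_sampler"; for an arrayed
	// expression such as "_uTex[i]" the suffix is inserted before the subscript: "_uTex_sampler[i]".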
2363 auto index = expr.find_first_of(c: '[');
2364 if (index == string::npos)
2365 {
2366 return expr + "_sampler";
2367 }
2368 else
2369 {
2370 // We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
2371 return expr.insert(pos: index, s: "_sampler");
2372 }
2373}
2374
2375void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2376{
2377 if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2378 {
2379 set<SPIRCombinedImageSampler>(id: result_id, args&: result_type, args&: image_id, args&: samp_id);
2380 }
2381 else
2382 {
2383 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2384 emit_op(result_type, result_id, rhs: to_combined_image_sampler(image_id, samp_id), forward_rhs: true, suppress_usage_tracking: true);
2385 }
2386}
2387
2388string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2389{
2390 string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2391
2392 if (hlsl_options.shader_model <= 30)
2393 return arg_str;
2394
2395 // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2396 auto &type = expression_type(id);
2397
2398 // We don't have to consider combined image samplers here via OpSampledImage because
2399 // those variables cannot be passed as arguments to functions.
2400 // Only global SampledImage variables may be used as arguments.
2401 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2402 arg_str += ", " + to_sampler_expression(id);
2403
2404 return arg_str;
2405}
2406
2407void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2408{
2409 if (func.self != ir.default_entry_point)
2410 add_function_overload(func);
2411
2412 auto &execution = get_entry_point();
2413 // Avoid shadow declarations.
2414 local_variable_names = resource_names;
2415
2416 string decl;
2417
2418 auto &type = get<SPIRType>(id: func.return_type);
2419 if (type.array.empty())
2420 {
2421 decl += flags_to_qualifiers_glsl(type, flags: return_flags);
2422 decl += type_to_glsl(type);
2423 decl += " ";
2424 }
2425 else
2426 {
2427 // We cannot return arrays in HLSL, so "return" through an out variable.
2428 decl = "void ";
2429 }
2430
2431 if (func.self == ir.default_entry_point)
2432 {
2433 if (execution.model == ExecutionModelVertex)
2434 decl += "vert_main";
2435 else if (execution.model == ExecutionModelFragment)
2436 decl += "frag_main";
2437 else if (execution.model == ExecutionModelGLCompute)
2438 decl += "comp_main";
2439 else
2440 SPIRV_CROSS_THROW("Unsupported execution model.");
2441 processing_entry_point = true;
2442 }
2443 else
2444 decl += to_name(id: func.self);
2445
2446 decl += "(";
2447 SmallVector<string> arglist;
2448
2449 if (!type.array.empty())
2450 {
2451 // Fake array returns by writing to an out array instead.
2452 string out_argument;
2453 out_argument += "out ";
2454 out_argument += type_to_glsl(type);
2455 out_argument += " ";
2456 out_argument += "spvReturnValue";
2457 out_argument += type_to_array_glsl(type);
2458 arglist.push_back(t: std::move(out_argument));
2459 }
2460
2461 for (auto &arg : func.arguments)
2462 {
2463 // Do not pass in separate images or samplers if we're remapping
2464 // to combined image samplers.
2465 if (skip_argument(id: arg.id))
2466 continue;
2467
2468 // Might change the variable name if it already exists in this function.
2469 // SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2470		// to use the same name for multiple variables.
2471		// Since we want to keep the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2472 add_local_variable_name(id: arg.id);
2473
2474 arglist.push_back(t: argument_decl(arg));
2475
2476 // Flatten a combined sampler to two separate arguments in modern HLSL.
2477 auto &arg_type = get<SPIRType>(id: arg.type);
2478 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2479 arg_type.image.dim != DimBuffer)
2480 {
2481 // Manufacture automatic sampler arg for SampledImage texture
2482 arglist.push_back(t: join(ts: is_depth_image(type: arg_type, id: arg.id) ? "SamplerComparisonState " : "SamplerState ",
2483 ts: to_sampler_expression(id: arg.id), ts: type_to_array_glsl(type: arg_type)));
2484 }
2485
2486 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2487 auto *var = maybe_get<SPIRVariable>(id: arg.id);
2488 if (var)
2489 var->parameter = &arg;
2490 }
2491
2492 for (auto &arg : func.shadow_arguments)
2493 {
2494 // Might change the variable name if it already exists in this function.
2495 // SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2496		// to use the same name for multiple variables.
2497		// Since we want to keep the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2498 add_local_variable_name(id: arg.id);
2499
2500 arglist.push_back(t: argument_decl(arg));
2501
2502 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2503 auto *var = maybe_get<SPIRVariable>(id: arg.id);
2504 if (var)
2505 var->parameter = &arg;
2506 }
2507
2508 decl += merge(list: arglist);
2509 decl += ")";
2510 statement(ts&: decl);
2511}
2512
2513void CompilerHLSL::emit_hlsl_entry_point()
2514{
2515 SmallVector<string> arguments;
2516
2517 if (require_input)
2518 arguments.push_back(t: "SPIRV_Cross_Input stage_input");
2519
2520 auto &execution = get_entry_point();
2521
2522 switch (execution.model)
2523 {
2524 case ExecutionModelGLCompute:
2525 {
2526 SpecializationConstant wg_x, wg_y, wg_z;
2527 get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
2528
2529 uint32_t x = execution.workgroup_size.x;
2530 uint32_t y = execution.workgroup_size.y;
2531 uint32_t z = execution.workgroup_size.z;
2532
2533 if (!execution.workgroup_size.constant && execution.flags.get(bit: ExecutionModeLocalSizeId))
2534 {
2535 if (execution.workgroup_size.id_x)
2536 x = get<SPIRConstant>(id: execution.workgroup_size.id_x).scalar();
2537 if (execution.workgroup_size.id_y)
2538 y = get<SPIRConstant>(id: execution.workgroup_size.id_y).scalar();
2539 if (execution.workgroup_size.id_z)
2540 z = get<SPIRConstant>(id: execution.workgroup_size.id_z).scalar();
2541 }
2542
2543 auto x_expr = wg_x.id ? get<SPIRConstant>(id: wg_x.id).specialization_constant_macro_name : to_string(val: x);
2544 auto y_expr = wg_y.id ? get<SPIRConstant>(id: wg_y.id).specialization_constant_macro_name : to_string(val: y);
2545 auto z_expr = wg_z.id ? get<SPIRConstant>(id: wg_z.id).specialization_constant_macro_name : to_string(val: z);
2546
2547 statement(ts: "[numthreads(", ts&: x_expr, ts: ", ", ts&: y_expr, ts: ", ", ts&: z_expr, ts: ")]");
2548 break;
2549 }
2550 case ExecutionModelFragment:
2551 if (execution.flags.get(bit: ExecutionModeEarlyFragmentTests))
2552 statement(ts: "[earlydepthstencil]");
2553 break;
2554 default:
2555 break;
2556 }
2557
2558 statement(ts: require_output ? "SPIRV_Cross_Output " : "void ", ts: "main(", ts: merge(list: arguments), ts: ")");
2559 begin_scope();
2560 bool legacy = hlsl_options.shader_model <= 30;
2561
2562 // Copy builtins from entry point arguments to globals.
2563 active_input_builtins.for_each_bit(op: [&](uint32_t i) {
2564 auto builtin = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassInput);
2565 switch (static_cast<BuiltIn>(i))
2566 {
2567 case BuiltInFragCoord:
2568 // VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent.
2569 // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2570 // on a D3D10+ system with a different rasterization config?
2571 if (legacy)
2572 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2573 else
2574 {
2575 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
2576 // ZW are undefined in D3D9, only do this fixup here.
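				// gl_FragCoord.w is defined as 1/w_clip in GLSL, whereas SV_Position.w carries w itself,
				// hence the reciprocal below.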
2577 statement(ts&: builtin, ts: ".w = 1.0 / ", ts&: builtin, ts: ".w;");
2578 }
2579 break;
2580
2581 case BuiltInVertexId:
2582 case BuiltInVertexIndex:
2583 case BuiltInInstanceIndex:
2584 // D3D semantics are uint, but shader wants int.
2585 if (hlsl_options.support_nonzero_base_vertex_base_instance)
2586 {
2587 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2588 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseInstance;");
2589 else
2590 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ") + SPIRV_Cross_BaseVertex;");
2591 }
2592 else
2593 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
2594 break;
2595
2596 case BuiltInInstanceId:
2597 // D3D semantics are uint, but shader wants int.
2598 statement(ts&: builtin, ts: " = int(stage_input.", ts&: builtin, ts: ");");
2599 break;
2600
2601 case BuiltInNumWorkgroups:
2602 case BuiltInPointCoord:
2603 case BuiltInSubgroupSize:
2604 case BuiltInSubgroupLocalInvocationId:
2605 case BuiltInHelperInvocation:
2606 break;
2607
2608 case BuiltInSubgroupEqMask:
2609 // Emulate these ...
2610 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
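			// E.g. for lane 37 only the .y component survives the guards below: EqMask = uint4(0, 1u << 5, 0, 0);
			// whatever the out-of-range shifts produce in the other components is zeroed afterwards.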
2611 statement(ts: "gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2612 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2613 statement(ts: "if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2614 statement(ts: "if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2615 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2616 break;
2617
2618 case BuiltInSubgroupGeMask:
2619 // Emulate these ...
2620 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2621 statement(ts: "gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2622 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2623 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2624 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2625 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2626 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2627 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2628 break;
2629
2630 case BuiltInSubgroupGtMask:
2631 // Emulate these ...
2632 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2633 statement(ts: "uint gt_lane_index = WaveGetLaneIndex() + 1;");
2634 statement(ts: "gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2635 statement(ts: "if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2636 statement(ts: "if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2637 statement(ts: "if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2638 statement(ts: "if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2639 statement(ts: "if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2640 statement(ts: "if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2641 statement(ts: "if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2642 break;
2643
2644 case BuiltInSubgroupLeMask:
2645 // Emulate these ...
2646 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2647 statement(ts: "uint le_lane_index = WaveGetLaneIndex() + 1;");
2648 statement(ts: "gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2649 statement(ts: "if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2650 statement(ts: "if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2651 statement(ts: "if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2652 statement(ts: "if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2653 statement(ts: "if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2654 statement(ts: "if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2655 statement(ts: "if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2656 break;
2657
2658 case BuiltInSubgroupLtMask:
2659 // Emulate these ...
2660 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2661 statement(ts: "gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2662 statement(ts: "if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2663 statement(ts: "if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2664 statement(ts: "if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2665 statement(ts: "if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2666 statement(ts: "if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2667 statement(ts: "if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2668 break;
2669
2670 case BuiltInClipDistance:
2671 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2672 statement(ts: "gl_ClipDistance[", ts&: clip, ts: "] = stage_input.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3],
2673 ts: ";");
2674 break;
2675
2676 case BuiltInCullDistance:
2677 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2678 statement(ts: "gl_CullDistance[", ts&: cull, ts: "] = stage_input.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3],
2679 ts: ";");
2680 break;
2681
2682 default:
2683 statement(ts&: builtin, ts: " = stage_input.", ts&: builtin, ts: ";");
2684 break;
2685 }
2686 });
2687
2688 // Copy from stage input struct to globals.
2689 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
2690 auto &type = this->get<SPIRType>(id: var.basetype);
2691 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
2692
2693 if (var.storage != StorageClassInput)
2694 return;
2695
2696 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
2697
2698 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2699 interface_variable_exists_in_entry_point(id: var.self))
2700 {
2701 if (block)
2702 {
2703 auto type_name = to_name(id: type.self);
2704 auto var_name = to_name(id: var.self);
2705 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2706 {
2707 auto mbr_name = to_member_name(type, index: mbr_idx);
2708 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
2709 statement(ts&: var_name, ts: ".", ts&: mbr_name, ts: " = stage_input.", ts&: flat_name, ts: ";");
2710 }
2711 }
2712 else
2713 {
2714 auto name = to_name(id: var.self);
2715 auto &mtype = this->get<SPIRType>(id: var.basetype);
2716 if (need_matrix_unroll && mtype.columns > 1)
2717 {
2718 // Unroll matrices.
2719 for (uint32_t col = 0; col < mtype.columns; col++)
2720 statement(ts&: name, ts: "[", ts&: col, ts: "] = stage_input.", ts&: name, ts: "_", ts&: col, ts: ";");
2721 }
2722 else
2723 {
2724 statement(ts&: name, ts: " = stage_input.", ts&: name, ts: ";");
2725 }
2726 }
2727 }
2728 });
2729
2730 // Run the shader.
2731 if (execution.model == ExecutionModelVertex)
2732 statement(ts: "vert_main();");
2733 else if (execution.model == ExecutionModelFragment)
2734 statement(ts: "frag_main();");
2735 else if (execution.model == ExecutionModelGLCompute)
2736 statement(ts: "comp_main();");
2737 else
2738 SPIRV_CROSS_THROW("Unsupported shader stage.");
2739
2740 // Copy stage outputs.
2741 if (require_output)
2742 {
2743 statement(ts: "SPIRV_Cross_Output stage_output;");
2744
2745 // Copy builtins from globals to return struct.
2746 active_output_builtins.for_each_bit(op: [&](uint32_t i) {
2747 // PointSize doesn't exist in HLSL.
2748 if (i == BuiltInPointSize)
2749 return;
2750
2751 switch (static_cast<BuiltIn>(i))
2752 {
2753 case BuiltInClipDistance:
2754 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2755 statement(ts: "stage_output.gl_ClipDistance", ts: clip / 4, ts: ".", ts: "xyzw"[clip & 3], ts: " = gl_ClipDistance[",
2756 ts&: clip, ts: "];");
2757 break;
2758
2759 case BuiltInCullDistance:
2760 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2761 statement(ts: "stage_output.gl_CullDistance", ts: cull / 4, ts: ".", ts: "xyzw"[cull & 3], ts: " = gl_CullDistance[",
2762 ts&: cull, ts: "];");
2763 break;
2764
2765 default:
2766 {
2767 auto builtin_expr = builtin_to_glsl(builtin: static_cast<BuiltIn>(i), storage: StorageClassOutput);
2768 statement(ts: "stage_output.", ts&: builtin_expr, ts: " = ", ts&: builtin_expr, ts: ";");
2769 break;
2770 }
2771 }
2772 });
2773
2774 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
2775 auto &type = this->get<SPIRType>(id: var.basetype);
2776 bool block = has_decoration(id: type.self, decoration: DecorationBlock);
2777
2778 if (var.storage != StorageClassOutput)
2779 return;
2780
2781 if (!var.remapped_variable && type.pointer &&
2782 !is_builtin_variable(var) &&
2783 interface_variable_exists_in_entry_point(id: var.self))
2784 {
2785 if (block)
2786 {
2787 // I/O blocks need to flatten output.
2788 auto type_name = to_name(id: type.self);
2789 auto var_name = to_name(id: var.self);
2790 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2791 {
2792 auto mbr_name = to_member_name(type, index: mbr_idx);
2793 auto flat_name = join(ts&: type_name, ts: "_", ts&: mbr_name);
2794 statement(ts: "stage_output.", ts&: flat_name, ts: " = ", ts&: var_name, ts: ".", ts&: mbr_name, ts: ";");
2795 }
2796 }
2797 else
2798 {
2799 auto name = to_name(id: var.self);
2800
2801 if (legacy && execution.model == ExecutionModelFragment)
2802 {
2803 string output_filler;
2804 for (uint32_t size = type.vecsize; size < 4; ++size)
2805 output_filler += ", 0.0";
2806
2807 statement(ts: "stage_output.", ts&: name, ts: " = float4(", ts&: name, ts&: output_filler, ts: ");");
2808 }
2809 else
2810 {
2811 statement(ts: "stage_output.", ts&: name, ts: " = ", ts&: name, ts: ";");
2812 }
2813 }
2814 }
2815 });
2816
2817 statement(ts: "return stage_output;");
2818 }
2819
2820 end_scope();
2821}
2822
2823void CompilerHLSL::emit_fixup()
2824{
2825 if (is_vertex_like_shader() && active_output_builtins.get(bit: BuiltInPosition))
2826 {
2827 // Do various mangling on the gl_Position.
2828 if (hlsl_options.shader_model <= 30)
2829 {
2830 statement(ts: "gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2831 "gl_Position.w;");
2832 statement(ts: "gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2833 "gl_Position.w;");
2834 }
2835
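		// The SM 3.0 path above compensates for D3D9's half-pixel viewport offset via the gl_HalfPixel
		// uniform; flip_vert_y negates gl_Position.y, and fixup_clipspace remaps GL clip-space depth in
		// [-w, w] to D3D's [0, w] via (z + w) * 0.5.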
2836 if (options.vertex.flip_vert_y)
2837 statement(ts: "gl_Position.y = -gl_Position.y;");
2838 if (options.vertex.fixup_clipspace)
2839 statement(ts: "gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2840 }
2841}
2842
2843void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2844{
2845 if (sparse)
2846 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2847
2848 auto *ops = stream(instr: i);
2849 auto op = static_cast<Op>(i.op);
2850 uint32_t length = i.length;
2851
2852 SmallVector<uint32_t> inherited_expressions;
2853
2854 uint32_t result_type = ops[0];
2855 uint32_t id = ops[1];
2856 VariableID img = ops[2];
2857 uint32_t coord = ops[3];
2858 uint32_t dref = 0;
2859 uint32_t comp = 0;
2860 bool gather = false;
2861 bool proj = false;
2862 const uint32_t *opt = nullptr;
2863 auto *combined_image = maybe_get<SPIRCombinedImageSampler>(id: img);
2864
2865 if (combined_image && has_decoration(id: img, decoration: DecorationNonUniform))
2866 {
2867 set_decoration(id: combined_image->image, decoration: DecorationNonUniform);
2868 set_decoration(id: combined_image->sampler, decoration: DecorationNonUniform);
2869 }
2870
2871 auto img_expr = to_non_uniform_aware_expression(id: combined_image ? combined_image->image : img);
2872
2873 inherited_expressions.push_back(t: coord);
2874
2875 switch (op)
2876 {
2877 case OpImageSampleDrefImplicitLod:
2878 case OpImageSampleDrefExplicitLod:
2879 dref = ops[4];
2880 opt = &ops[5];
2881 length -= 5;
2882 break;
2883
2884 case OpImageSampleProjDrefImplicitLod:
2885 case OpImageSampleProjDrefExplicitLod:
2886 dref = ops[4];
2887 proj = true;
2888 opt = &ops[5];
2889 length -= 5;
2890 break;
2891
2892 case OpImageDrefGather:
2893 dref = ops[4];
2894 opt = &ops[5];
2895 gather = true;
2896 length -= 5;
2897 break;
2898
2899 case OpImageGather:
2900 comp = ops[4];
2901 opt = &ops[5];
2902 gather = true;
2903 length -= 5;
2904 break;
2905
2906 case OpImageSampleProjImplicitLod:
2907 case OpImageSampleProjExplicitLod:
2908 opt = &ops[4];
2909 length -= 4;
2910 proj = true;
2911 break;
2912
2913 case OpImageQueryLod:
2914 opt = &ops[4];
2915 length -= 4;
2916 break;
2917
2918 default:
2919 opt = &ops[4];
2920 length -= 4;
2921 break;
2922 }
2923
2924 auto &imgtype = expression_type(id: img);
2925 uint32_t coord_components = 0;
2926 switch (imgtype.image.dim)
2927 {
2928 case spv::Dim1D:
2929 coord_components = 1;
2930 break;
2931 case spv::Dim2D:
2932 coord_components = 2;
2933 break;
2934 case spv::Dim3D:
2935 coord_components = 3;
2936 break;
2937 case spv::DimCube:
2938 coord_components = 3;
2939 break;
2940 case spv::DimBuffer:
2941 coord_components = 1;
2942 break;
2943 default:
2944 coord_components = 2;
2945 break;
2946 }
2947
2948 if (dref)
2949 inherited_expressions.push_back(t: dref);
2950
2951 if (imgtype.image.arrayed)
2952 coord_components++;
2953
2954 uint32_t bias = 0;
2955 uint32_t lod = 0;
2956 uint32_t grad_x = 0;
2957 uint32_t grad_y = 0;
2958 uint32_t coffset = 0;
2959 uint32_t offset = 0;
2960 uint32_t coffsets = 0;
2961 uint32_t sample = 0;
2962 uint32_t minlod = 0;
2963 uint32_t flags = 0;
2964
2965 if (length)
2966 {
2967 flags = opt[0];
2968 opt++;
2969 length--;
2970 }
2971
2972 auto test = [&](uint32_t &v, uint32_t flag) {
2973 if (length && (flags & flag))
2974 {
2975 v = *opt++;
2976 inherited_expressions.push_back(t: v);
2977 length--;
2978 }
2979 };
2980
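	// Per the SPIR-V spec, the optional image operands follow the mask word in ascending bit order,
	// so they are consumed in that fixed sequence here.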
2981 test(bias, ImageOperandsBiasMask);
2982 test(lod, ImageOperandsLodMask);
2983 test(grad_x, ImageOperandsGradMask);
2984 test(grad_y, ImageOperandsGradMask);
2985 test(coffset, ImageOperandsConstOffsetMask);
2986 test(offset, ImageOperandsOffsetMask);
2987 test(coffsets, ImageOperandsConstOffsetsMask);
2988 test(sample, ImageOperandsSampleMask);
2989 test(minlod, ImageOperandsMinLodMask);
2990
2991 string expr;
2992 string texop;
2993
2994 if (minlod != 0)
2995 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2996
2997 if (op == OpImageFetch)
2998 {
2999 if (hlsl_options.shader_model < 40)
3000 {
3001 SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
3002 }
3003 texop += img_expr;
3004 texop += ".Load";
3005 }
3006 else if (op == OpImageQueryLod)
3007 {
3008 texop += img_expr;
3009 texop += ".CalculateLevelOfDetail";
3010 }
3011 else
3012 {
3013 auto &imgformat = get<SPIRType>(id: imgtype.image.type);
3014 if (imgformat.basetype != SPIRType::Float)
3015 {
3016 SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
3017 }
3018
3019 if (hlsl_options.shader_model >= 40)
3020 {
3021 texop += img_expr;
3022
3023 if (is_depth_image(type: imgtype, id: img))
3024 {
3025 if (gather)
3026 {
3027 SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
3028 }
3029 else if (lod || grad_x || grad_y)
3030 {
3031 // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
3032 texop += ".SampleCmpLevelZero";
3033 }
3034 else
3035 texop += ".SampleCmp";
3036 }
3037 else if (gather)
3038 {
3039 uint32_t comp_num = evaluate_constant_u32(id: comp);
3040 if (hlsl_options.shader_model >= 50)
3041 {
3042 switch (comp_num)
3043 {
3044 case 0:
3045 texop += ".GatherRed";
3046 break;
3047 case 1:
3048 texop += ".GatherGreen";
3049 break;
3050 case 2:
3051 texop += ".GatherBlue";
3052 break;
3053 case 3:
3054 texop += ".GatherAlpha";
3055 break;
3056 default:
3057 SPIRV_CROSS_THROW("Invalid component.");
3058 }
3059 }
3060 else
3061 {
3062 if (comp_num == 0)
3063 texop += ".Gather";
3064 else
3065 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
3066 }
3067 }
3068 else if (bias)
3069 texop += ".SampleBias";
3070 else if (grad_x || grad_y)
3071 texop += ".SampleGrad";
3072 else if (lod)
3073 texop += ".SampleLevel";
3074 else
3075 texop += ".Sample";
3076 }
3077 else
3078 {
3079 switch (imgtype.image.dim)
3080 {
3081 case Dim1D:
3082 texop += "tex1D";
3083 break;
3084 case Dim2D:
3085 texop += "tex2D";
3086 break;
3087 case Dim3D:
3088 texop += "tex3D";
3089 break;
3090 case DimCube:
3091 texop += "texCUBE";
3092 break;
3093 case DimRect:
3094 case DimBuffer:
3095 case DimSubpassData:
3096 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
3097 default:
3098 SPIRV_CROSS_THROW("Invalid dimension.");
3099 }
3100
3101 if (gather)
3102 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
3103 if (offset || coffset)
3104 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
3105
3106 if (grad_x || grad_y)
3107 texop += "grad";
3108 else if (lod)
3109 texop += "lod";
3110 else if (bias)
3111 texop += "bias";
3112 else if (proj || dref)
3113 texop += "proj";
3114 }
3115 }
3116
3117 expr += texop;
3118 expr += "(";
3119 if (hlsl_options.shader_model < 40)
3120 {
3121 if (combined_image)
3122 SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3123 expr += to_expression(id: img);
3124 }
3125 else if (op != OpImageFetch)
3126 {
3127 string sampler_expr;
3128 if (combined_image)
3129 sampler_expr = to_non_uniform_aware_expression(id: combined_image->sampler);
3130 else
3131 sampler_expr = to_sampler_expression(id: img);
3132 expr += sampler_expr;
3133 }
3134
3135 auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3136 if (comps == in_comps)
3137 return "";
3138
3139 switch (comps)
3140 {
3141 case 1:
3142 return ".x";
3143 case 2:
3144 return ".xy";
3145 case 3:
3146 return ".xyz";
3147 default:
3148 return "";
3149 }
3150 };
3151
3152 bool forward = should_forward(id: coord);
3153
3154 // The IR can give us more components than we need, so chop them off as needed.
3155 string coord_expr;
3156 auto &coord_type = expression_type(id: coord);
3157 if (coord_components != coord_type.vecsize)
3158 coord_expr = to_enclosed_expression(id: coord) + swizzle(coord_components, expression_type(id: coord).vecsize);
3159 else
3160 coord_expr = to_expression(id: coord);
3161
3162 if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3163 coord_expr = coord_expr + " / " + to_extract_component_expression(id: coord, index: coord_components);
3164
3165 if (hlsl_options.shader_model < 40)
3166 {
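		// Legacy (SM 2/3) sampling: tex2Dlod/tex2Dbias/tex2Dproj take a single float4 coordinate, so the
		// lod, bias or projective divisor is packed into .w below, while a depth-compare reference rides
		// in .z with w forced to 1.0 so the tex2Dproj divide is a no-op.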
3167 if (dref)
3168 {
3169 if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3170 {
3171 SPIRV_CROSS_THROW(
3172 "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3173 }
3174
3175 if (grad_x || grad_y)
3176 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3177
3178 for (uint32_t size = coord_components; size < 2; ++size)
3179 coord_expr += ", 0.0";
3180
3181 forward = forward && should_forward(id: dref);
3182 coord_expr += ", " + to_expression(id: dref);
3183 }
3184 else if (lod || bias || proj)
3185 {
3186 for (uint32_t size = coord_components; size < 3; ++size)
3187 coord_expr += ", 0.0";
3188 }
3189
3190 if (lod)
3191 {
3192 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: lod) + ")";
3193 }
3194 else if (bias)
3195 {
3196 coord_expr = "float4(" + coord_expr + ", " + to_expression(id: bias) + ")";
3197 }
3198 else if (proj)
3199 {
3200 coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(id: coord, index: coord_components) + ")";
3201 }
3202 else if (dref)
3203 {
3204 // A "normal" sample gets fed into tex2Dproj as well, because the
3205 // regular tex2D accepts only two coordinates.
3206 coord_expr = "float4(" + coord_expr + ", 1.0)";
3207 }
3208
3209 if (!!lod + !!bias + !!proj > 1)
3210 SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3211 }
3212
3213 if (op == OpImageFetch)
3214 {
3215 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3216 coord_expr =
3217 join(ts: "int", ts: coord_components + 1, ts: "(", ts&: coord_expr, ts: ", ", ts: lod ? to_expression(id: lod) : string("0"), ts: ")");
3218 }
3219 else
3220 expr += ", ";
3221 expr += coord_expr;
3222
3223 if (dref && hlsl_options.shader_model >= 40)
3224 {
3225 forward = forward && should_forward(id: dref);
3226 expr += ", ";
3227
3228 if (proj)
3229 expr += to_enclosed_expression(id: dref) + " / " + to_extract_component_expression(id: coord, index: coord_components);
3230 else
3231 expr += to_expression(id: dref);
3232 }
3233
3234 if (!dref && (grad_x || grad_y))
3235 {
3236 forward = forward && should_forward(id: grad_x);
3237 forward = forward && should_forward(id: grad_y);
3238 expr += ", ";
3239 expr += to_expression(id: grad_x);
3240 expr += ", ";
3241 expr += to_expression(id: grad_y);
3242 }
3243
3244 if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3245 {
3246 forward = forward && should_forward(id: lod);
3247 expr += ", ";
3248 expr += to_expression(id: lod);
3249 }
3250
3251 if (!dref && bias && hlsl_options.shader_model >= 40)
3252 {
3253 forward = forward && should_forward(id: bias);
3254 expr += ", ";
3255 expr += to_expression(id: bias);
3256 }
3257
3258 if (coffset)
3259 {
3260 forward = forward && should_forward(id: coffset);
3261 expr += ", ";
3262 expr += to_expression(id: coffset);
3263 }
3264 else if (offset)
3265 {
3266 forward = forward && should_forward(id: offset);
3267 expr += ", ";
3268 expr += to_expression(id: offset);
3269 }
3270
3271 if (sample)
3272 {
3273 expr += ", ";
3274 expr += to_expression(id: sample);
3275 }
3276
3277 expr += ")";
3278
3279 if (dref && hlsl_options.shader_model < 40)
3280 expr += ".x";
3281
3282 if (op == OpImageQueryLod)
3283 {
3284 // This is rather awkward.
3285 // textureQueryLod returns two values, the "accessed level",
3286 // as well as the actual LOD lambda.
3287 // As far as I can tell, there is no way to get the .x component
3288 // according to GLSL spec, and it depends on the sampler itself.
3289 // Just assume X == Y, so we will need to splat the result to a float2.
3290 statement(ts: "float _", ts&: id, ts: "_tmp = ", ts&: expr, ts: ";");
3291 statement(ts: "float2 _", ts&: id, ts: " = _", ts&: id, ts: "_tmp.xx;");
3292 set<SPIRExpression>(id, args: join(ts: "_", ts&: id), args&: result_type, args: true);
3293 }
3294 else
3295 {
3296 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: forward, suppress_usage_tracking: false);
3297 }
3298
3299 for (auto &inherit : inherited_expressions)
3300 inherit_expression_dependencies(dst: id, source: inherit);
3301
3302 switch (op)
3303 {
3304 case OpImageSampleDrefImplicitLod:
3305 case OpImageSampleImplicitLod:
3306 case OpImageSampleProjImplicitLod:
3307 case OpImageSampleProjDrefImplicitLod:
3308 register_control_dependent_expression(expr: id);
3309 break;
3310
3311 default:
3312 break;
3313 }
3314}
3315
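// Builds the " : register(...)" suffix for a resource variable. The register class is picked
// from the SPIR-V type and storage class (t = SRV, u = UAV, b = CBV, s = sampler), and the
// Vulkan descriptor set maps to a register space on SM 5.1+, so e.g. binding 3 in set 1 for an
// SRV becomes " : register(t3, space1)".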
3316string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3317{
3318 const auto &type = get<SPIRType>(id: var.basetype);
3319
3320 // We can remap push constant blocks, even if they don't have any binding decoration.
3321 if (type.storage != StorageClassPushConstant && !has_decoration(id: var.self, decoration: DecorationBinding))
3322 return "";
3323
3324 char space = '\0';
3325
3326 HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3327
3328 switch (type.basetype)
3329 {
3330 case SPIRType::SampledImage:
3331 space = 't'; // SRV
3332 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3333 break;
3334
3335 case SPIRType::Image:
3336 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3337 {
3338 if (has_decoration(id: var.self, decoration: DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3339 {
3340 space = 't'; // SRV
3341 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3342 }
3343 else
3344 {
3345 space = 'u'; // UAV
3346 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3347 }
3348 }
3349 else
3350 {
3351 space = 't'; // SRV
3352 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3353 }
3354 break;
3355
3356 case SPIRType::Sampler:
3357 space = 's';
3358 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3359 break;
3360
3361 case SPIRType::AccelerationStructure:
3362 space = 't'; // SRV
3363 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3364 break;
3365
3366 case SPIRType::Struct:
3367 {
3368 auto storage = type.storage;
3369 if (storage == StorageClassUniform)
3370 {
3371 if (has_decoration(id: type.self, decoration: DecorationBufferBlock))
3372 {
3373 Bitset flags = ir.get_buffer_block_flags(var);
3374 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
3375 				space = is_readonly ? 't' : 'u'; // SRV if read-only, otherwise UAV
3376 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3377 }
3378 else if (has_decoration(id: type.self, decoration: DecorationBlock))
3379 {
3380 space = 'b'; // Constant buffers
3381 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3382 }
3383 }
3384 else if (storage == StorageClassPushConstant)
3385 {
3386 space = 'b'; // Constant buffers
3387 resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3388 }
3389 else if (storage == StorageClassStorageBuffer)
3390 {
3391 // UAV or SRV depending on readonly flag.
3392 Bitset flags = ir.get_buffer_block_flags(var);
3393 bool is_readonly = flags.get(bit: DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(id: var.self);
3394 space = is_readonly ? 't' : 'u';
3395 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3396 }
3397
3398 break;
3399 }
3400 default:
3401 break;
3402 }
3403
3404 if (!space)
3405 return "";
3406
3407 uint32_t desc_set =
3408 resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3409 uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3410
3411 if (has_decoration(id: var.self, decoration: DecorationBinding))
3412 binding = get_decoration(id: var.self, decoration: DecorationBinding);
3413 if (has_decoration(id: var.self, decoration: DecorationDescriptorSet))
3414 desc_set = get_decoration(id: var.self, decoration: DecorationDescriptorSet);
3415
3416 return to_resource_register(flag: resource_flags, space, binding, set: desc_set);
3417}
3418
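// Emits the register binding for the sampler half of a combined image sampler,
// reusing the texture's DecorationBinding in the 's' register class.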
3419string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3420{
3421 // For combined image samplers.
3422 if (!has_decoration(id: var.self, decoration: DecorationBinding))
3423 return "";
3424
3425 return to_resource_register(flag: HLSL_BINDING_AUTO_SAMPLER_BIT, space: 's', binding: get_decoration(id: var.self, decoration: DecorationBinding),
3426 set: get_decoration(id: var.self, decoration: DecorationDescriptorSet));
3427}
3428
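// If the user registered an explicit HLSL resource binding for this
// (execution model, descriptor set, binding) tuple, rewrite set/binding to the requested
// register_space/register_binding for the given resource class and mark the remap as consumed.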
3429void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3430{
3431 auto itr = resource_bindings.find(x: { .model: get_execution_model(), .desc_set: desc_set, .binding: binding });
3432 if (itr != end(cont&: resource_bindings))
3433 {
3434 auto &remap = itr->second;
3435 remap.second = true;
3436
3437 switch (type)
3438 {
3439 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3440 case HLSL_BINDING_AUTO_CBV_BIT:
3441 desc_set = remap.first.cbv.register_space;
3442 binding = remap.first.cbv.register_binding;
3443 break;
3444
3445 case HLSL_BINDING_AUTO_SRV_BIT:
3446 desc_set = remap.first.srv.register_space;
3447 binding = remap.first.srv.register_binding;
3448 break;
3449
3450 case HLSL_BINDING_AUTO_SAMPLER_BIT:
3451 desc_set = remap.first.sampler.register_space;
3452 binding = remap.first.sampler.register_binding;
3453 break;
3454
3455 case HLSL_BINDING_AUTO_UAV_BIT:
3456 desc_set = remap.first.uav.register_space;
3457 binding = remap.first.uav.register_binding;
3458 break;
3459
3460 default:
3461 break;
3462 }
3463 }
3464}
3465
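// Emits " : register(<class><binding>[, space<set>])" unless this resource class was flagged
// for automatic assignment via resource_binding_flags, in which case no register is emitted and
// assignment is left to the toolchain. The space qualifier is only available from SM 5.1.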
3466string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3467{
3468 if ((flag & resource_binding_flags) == 0)
3469 {
3470 remap_hlsl_resource_binding(type: flag, desc_set&: space_set, binding);
3471
3472 		// The push constant block did not have a binding, and there was no remap for it,
3473 		// so declare it without a register binding.
3474 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3475 return "";
3476
3477 if (hlsl_options.shader_model >= 51)
3478 return join(ts: " : register(", ts&: space, ts&: binding, ts: ", space", ts&: space_set, ts: ")");
3479 else
3480 return join(ts: " : register(", ts&: space, ts&: binding, ts: ")");
3481 }
3482 else
3483 return "";
3484}
3485
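// SM 4.0+ resource declarations: Texture*/RWTexture* objects (globallycoherent when the
// Coherent decoration is present), SamplerState or SamplerComparisonState for samplers and
// for the sampler half of combined image samplers, and plain variable declarations for
// everything else, each with its register binding appended.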
3486void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3487{
3488 auto &type = get<SPIRType>(id: var.basetype);
3489 switch (type.basetype)
3490 {
3491 case SPIRType::SampledImage:
3492 case SPIRType::Image:
3493 {
3494 bool is_coherent = false;
3495 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3496 is_coherent = has_decoration(id: var.self, decoration: DecorationCoherent);
3497
3498 statement(ts: is_coherent ? "globallycoherent " : "", ts: image_type_hlsl_modern(type, id: var.self), ts: " ",
3499 ts: to_name(id: var.self), ts: type_to_array_glsl(type), ts: to_resource_binding(var), ts: ";");
3500
3501 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3502 {
3503 			// For combined image samplers, also emit the corresponding sampler state.
3504 if (is_depth_image(type, id: var.self))
3505 statement(ts: "SamplerComparisonState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type),
3506 ts: to_resource_binding_sampler(var), ts: ";");
3507 else
3508 statement(ts: "SamplerState ", ts: to_sampler_expression(id: var.self), ts: type_to_array_glsl(type),
3509 ts: to_resource_binding_sampler(var), ts: ";");
3510 }
3511 break;
3512 }
3513
3514 case SPIRType::Sampler:
3515 if (comparison_ids.count(x: var.self))
3516 statement(ts: "SamplerComparisonState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type), ts: to_resource_binding(var),
3517 ts: ";");
3518 else
3519 statement(ts: "SamplerState ", ts: to_name(id: var.self), ts: type_to_array_glsl(type), ts: to_resource_binding(var), ts: ";");
3520 break;
3521
3522 default:
3523 statement(ts: variable_decl(variable: var), ts: to_resource_binding(var), ts: ";");
3524 break;
3525 }
3526}
3527
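// SM 2/3 fallback: only plain uniform declarations are possible, and separate
// images/samplers cannot be expressed at all.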
3528void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3529{
3530 auto &type = get<SPIRType>(id: var.basetype);
3531 switch (type.basetype)
3532 {
3533 case SPIRType::Sampler:
3534 case SPIRType::Image:
3535 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3536
3537 default:
3538 statement(ts: variable_decl(variable: var), ts: ";");
3539 break;
3540 }
3541}
3542
3543void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3544{
3545 add_resource_name(id: var.self);
3546 if (hlsl_options.shader_model >= 40)
3547 emit_modern_uniform(var);
3548 else
3549 emit_legacy_uniform(var);
3550}
3551
3552bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3553{
3554 return false;
3555}
3556
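// Maps a SPIR-V bitcast to the matching HLSL reinterpretation intrinsic
// (asuint/asint/asfloat/asdouble). Same-width int <-> uint casts just use the target type as a
// constructor-style cast, and half <-> uint reinterprets go through generated
// spvPackFloat2x16/spvUnpackFloat2x16 helpers, which forces a recompile the first time they are
// needed. An empty return means no conversion is required.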
3557string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3558{
3559 if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3560 return type_to_glsl(type: out_type);
3561 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3562 return type_to_glsl(type: out_type);
3563 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3564 return "asuint";
3565 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3566 return type_to_glsl(type: out_type);
3567 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3568 return type_to_glsl(type: out_type);
3569 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3570 return "asint";
3571 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3572 return "asfloat";
3573 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3574 return "asfloat";
3575 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3576 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3577 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3578 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3579 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3580 return "asdouble";
3581 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3582 return "asdouble";
3583 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3584 {
3585 if (!requires_explicit_fp16_packing)
3586 {
3587 requires_explicit_fp16_packing = true;
3588 force_recompile();
3589 }
3590 return "spvUnpackFloat2x16";
3591 }
3592 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3593 {
3594 if (!requires_explicit_fp16_packing)
3595 {
3596 requires_explicit_fp16_packing = true;
3597 force_recompile();
3598 }
3599 return "spvPackFloat2x16";
3600 }
3601 else
3602 return "";
3603}
3604
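// GLSL.std.450 extended instructions with a direct HLSL counterpart are renamed here
// (rsqrt, frac, lerp, mad, firstbitlow/firstbithigh, ...). Ops with no intrinsic equivalent
// (the pack/unpack helpers, matrix inverse, scalar reflect/refract/faceforward) set a
// requires_* flag and force a recompile so the corresponding spv* helper gets emitted;
// anything else falls back to the GLSL implementation.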
3605void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3606{
3607 auto op = static_cast<GLSLstd450>(eop);
3608
3609 // If we need to do implicit bitcasts, make sure we do it with the correct type.
3610 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, arguments: args, length: count);
3611 auto int_type = to_signed_basetype(width: integer_width);
3612 auto uint_type = to_unsigned_basetype(width: integer_width);
3613
3614 op = get_remapped_glsl_op(std450_op: op);
3615
3616 switch (op)
3617 {
3618 case GLSLstd450InverseSqrt:
3619 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rsqrt");
3620 break;
3621
3622 case GLSLstd450Fract:
3623 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "frac");
3624 break;
3625
3626 case GLSLstd450RoundEven:
3627 if (hlsl_options.shader_model < 40)
3628 SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3629 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "round");
3630 break;
3631
3632 case GLSLstd450Acosh:
3633 case GLSLstd450Asinh:
3634 case GLSLstd450Atanh:
3635 		SPIRV_CROSS_THROW("Inverse hyperbolics are not supported in HLSL.");
3636
3637 case GLSLstd450FMix:
3638 case GLSLstd450IMix:
3639 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "lerp");
3640 break;
3641
3642 case GLSLstd450Atan2:
3643 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "atan2");
3644 break;
3645
3646 case GLSLstd450Fma:
3647 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "mad");
3648 break;
3649
3650 case GLSLstd450InterpolateAtCentroid:
3651 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "EvaluateAttributeAtCentroid");
3652 break;
3653 case GLSLstd450InterpolateAtSample:
3654 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeAtSample");
3655 break;
3656 case GLSLstd450InterpolateAtOffset:
3657 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "EvaluateAttributeSnapped");
3658 break;
3659
3660 case GLSLstd450PackHalf2x16:
3661 if (!requires_fp16_packing)
3662 {
3663 requires_fp16_packing = true;
3664 force_recompile();
3665 }
3666 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackHalf2x16");
3667 break;
3668
3669 case GLSLstd450UnpackHalf2x16:
3670 if (!requires_fp16_packing)
3671 {
3672 requires_fp16_packing = true;
3673 force_recompile();
3674 }
3675 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackHalf2x16");
3676 break;
3677
3678 case GLSLstd450PackSnorm4x8:
3679 if (!requires_snorm8_packing)
3680 {
3681 requires_snorm8_packing = true;
3682 force_recompile();
3683 }
3684 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm4x8");
3685 break;
3686
3687 case GLSLstd450UnpackSnorm4x8:
3688 if (!requires_snorm8_packing)
3689 {
3690 requires_snorm8_packing = true;
3691 force_recompile();
3692 }
3693 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm4x8");
3694 break;
3695
3696 case GLSLstd450PackUnorm4x8:
3697 if (!requires_unorm8_packing)
3698 {
3699 requires_unorm8_packing = true;
3700 force_recompile();
3701 }
3702 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm4x8");
3703 break;
3704
3705 case GLSLstd450UnpackUnorm4x8:
3706 if (!requires_unorm8_packing)
3707 {
3708 requires_unorm8_packing = true;
3709 force_recompile();
3710 }
3711 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm4x8");
3712 break;
3713
3714 case GLSLstd450PackSnorm2x16:
3715 if (!requires_snorm16_packing)
3716 {
3717 requires_snorm16_packing = true;
3718 force_recompile();
3719 }
3720 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackSnorm2x16");
3721 break;
3722
3723 case GLSLstd450UnpackSnorm2x16:
3724 if (!requires_snorm16_packing)
3725 {
3726 requires_snorm16_packing = true;
3727 force_recompile();
3728 }
3729 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackSnorm2x16");
3730 break;
3731
3732 case GLSLstd450PackUnorm2x16:
3733 if (!requires_unorm16_packing)
3734 {
3735 requires_unorm16_packing = true;
3736 force_recompile();
3737 }
3738 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvPackUnorm2x16");
3739 break;
3740
3741 case GLSLstd450UnpackUnorm2x16:
3742 if (!requires_unorm16_packing)
3743 {
3744 requires_unorm16_packing = true;
3745 force_recompile();
3746 }
3747 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvUnpackUnorm2x16");
3748 break;
3749
3750 case GLSLstd450PackDouble2x32:
3751 case GLSLstd450UnpackDouble2x32:
3752 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3753
3754 case GLSLstd450FindILsb:
3755 {
3756 auto basetype = expression_type(id: args[0]).basetype;
3757 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbitlow", input_type: basetype, expected_result_type: basetype);
3758 break;
3759 }
3760
3761 case GLSLstd450FindSMsb:
3762 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: int_type, expected_result_type: int_type);
3763 break;
3764
3765 case GLSLstd450FindUMsb:
3766 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "firstbithigh", input_type: uint_type, expected_result_type: uint_type);
3767 break;
3768
3769 case GLSLstd450MatrixInverse:
3770 {
3771 auto &type = get<SPIRType>(id: result_type);
3772 if (type.vecsize == 2 && type.columns == 2)
3773 {
3774 if (!requires_inverse_2x2)
3775 {
3776 requires_inverse_2x2 = true;
3777 force_recompile();
3778 }
3779 }
3780 else if (type.vecsize == 3 && type.columns == 3)
3781 {
3782 if (!requires_inverse_3x3)
3783 {
3784 requires_inverse_3x3 = true;
3785 force_recompile();
3786 }
3787 }
3788 else if (type.vecsize == 4 && type.columns == 4)
3789 {
3790 if (!requires_inverse_4x4)
3791 {
3792 requires_inverse_4x4 = true;
3793 force_recompile();
3794 }
3795 }
3796 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse");
3797 break;
3798 }
3799
3800 case GLSLstd450Normalize:
3801 // HLSL does not support scalar versions here.
3802 if (expression_type(id: args[0]).vecsize == 1)
3803 {
3804 // Returns -1 or 1 for valid input, sign() does the job.
3805 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "sign");
3806 }
3807 else
3808 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
3809 break;
3810
3811 case GLSLstd450Reflect:
3812 if (get<SPIRType>(id: result_type).vecsize == 1)
3813 {
3814 if (!requires_scalar_reflect)
3815 {
3816 requires_scalar_reflect = true;
3817 force_recompile();
3818 }
3819 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "spvReflect");
3820 }
3821 else
3822 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
3823 break;
3824
3825 case GLSLstd450Refract:
3826 if (get<SPIRType>(id: result_type).vecsize == 1)
3827 {
3828 if (!requires_scalar_refract)
3829 {
3830 requires_scalar_refract = true;
3831 force_recompile();
3832 }
3833 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvRefract");
3834 }
3835 else
3836 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
3837 break;
3838
3839 case GLSLstd450FaceForward:
3840 if (get<SPIRType>(id: result_type).vecsize == 1)
3841 {
3842 if (!requires_scalar_faceforward)
3843 {
3844 requires_scalar_faceforward = true;
3845 force_recompile();
3846 }
3847 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvFaceForward");
3848 }
3849 else
3850 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
3851 break;
3852
3853 default:
3854 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
3855 break;
3856 }
3857}
3858
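// Arrays loaded from a ByteAddressBuffer are unrolled element by element, roughly
// "[unroll] for (int _N = 0; _N < size; _N++) lhs[_N] = <load of element _N>;",
// with the per-element offset folded into the chain's dynamic index.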
3859void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3860{
3861 auto &type = get<SPIRType>(id: chain.basetype);
3862
3863 	// Need to use a reserved (generated) identifier here, since a plain name might shadow an identifier in the access chain input or in other loops.
3864 auto ident = get_unique_identifier();
3865
3866 statement(ts: "[unroll]");
3867 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
3868 ts&: ident, ts: "++)");
3869 begin_scope();
3870 auto subchain = chain;
3871 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
3872 subchain.basetype = type.parent_type;
3873 if (!get<SPIRType>(id: subchain.basetype).array.empty())
3874 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
3875 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: "[", ts&: ident, ts: "]"), chain: subchain);
3876 end_scope();
3877}
3878
3879void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3880{
3881 auto &type = get<SPIRType>(id: chain.basetype);
3882 auto subchain = chain;
3883 uint32_t member_count = uint32_t(type.member_types.size());
3884
3885 for (uint32_t i = 0; i < member_count; i++)
3886 {
3887 uint32_t offset = type_struct_member_offset(type, index: i);
3888 subchain.static_index = chain.static_index + offset;
3889 subchain.basetype = type.member_types[i];
3890
3891 subchain.matrix_stride = 0;
3892 subchain.array_stride = 0;
3893 subchain.row_major_matrix = false;
3894
3895 auto &member_type = get<SPIRType>(id: subchain.basetype);
3896 if (member_type.columns > 1)
3897 {
3898 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
3899 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
3900 }
3901
3902 if (!member_type.array.empty())
3903 subchain.array_stride = type_struct_member_array_stride(type, index: i);
3904
3905 read_access_chain(expr: nullptr, lhs: join(ts: lhs, ts: ".", ts: to_member_name(type, index: i)), chain: subchain);
3906 }
3907}
3908
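// Scalar, vector and matrix loads from a (RW)ByteAddressBuffer. On SM 6.2+ this uses templated
// Load<T>(); older targets use Load/Load2/Load3/Load4 on uint data followed by an
// asfloat/asint style bitcast. Row-major (from SPIR-V's point of view) matrices, and vectors
// sliced out of them, are gathered per element using the matrix stride recorded in the chain.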
3909void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3910{
3911 auto &type = get<SPIRType>(id: chain.basetype);
3912
3913 SPIRType target_type;
3914 target_type.basetype = SPIRType::UInt;
3915 target_type.vecsize = type.vecsize;
3916 target_type.columns = type.columns;
3917
3918 if (!type.array.empty())
3919 {
3920 read_access_chain_array(lhs, chain);
3921 return;
3922 }
3923 else if (type.basetype == SPIRType::Struct)
3924 {
3925 read_access_chain_struct(lhs, chain);
3926 return;
3927 }
3928 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3929 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3930 "native 16-bit types are enabled.");
3931
3932 string base = chain.base;
3933 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
3934 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
3935
3936 bool templated_load = hlsl_options.shader_model >= 62;
3937 string load_expr;
3938
3939 string template_expr;
3940 if (templated_load)
3941 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
3942
3943 // Load a vector or scalar.
3944 if (type.columns == 1 && !chain.row_major_matrix)
3945 {
3946 const char *load_op = nullptr;
3947 switch (type.vecsize)
3948 {
3949 case 1:
3950 load_op = "Load";
3951 break;
3952 case 2:
3953 load_op = "Load2";
3954 break;
3955 case 3:
3956 load_op = "Load3";
3957 break;
3958 case 4:
3959 load_op = "Load4";
3960 break;
3961 default:
3962 SPIRV_CROSS_THROW("Unknown vector size.");
3963 }
3964
3965 if (templated_load)
3966 load_op = "Load";
3967
3968 load_expr = join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ")");
3969 }
3970 else if (type.columns == 1)
3971 {
3972 // Strided load since we are loading a column from a row-major matrix.
3973 if (templated_load)
3974 {
3975 auto scalar_type = type;
3976 scalar_type.vecsize = 1;
3977 scalar_type.columns = 1;
3978 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
3979 if (type.vecsize > 1)
3980 load_expr += type_to_glsl(type) + "(";
3981 }
3982 else if (type.vecsize > 1)
3983 {
3984 load_expr = type_to_glsl(type: target_type);
3985 load_expr += "(";
3986 }
3987
3988 for (uint32_t r = 0; r < type.vecsize; r++)
3989 {
3990 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
3991 ts: chain.static_index + r * chain.matrix_stride, ts: ")");
3992 if (r + 1 < type.vecsize)
3993 load_expr += ", ";
3994 }
3995
3996 if (type.vecsize > 1)
3997 load_expr += ")";
3998 }
3999 else if (!chain.row_major_matrix)
4000 {
4001 // Load a matrix, column-major, the easy case.
4002 const char *load_op = nullptr;
4003 switch (type.vecsize)
4004 {
4005 case 1:
4006 load_op = "Load";
4007 break;
4008 case 2:
4009 load_op = "Load2";
4010 break;
4011 case 3:
4012 load_op = "Load3";
4013 break;
4014 case 4:
4015 load_op = "Load4";
4016 break;
4017 default:
4018 SPIRV_CROSS_THROW("Unknown vector size.");
4019 }
4020
4021 if (templated_load)
4022 {
4023 auto vector_type = type;
4024 vector_type.columns = 1;
4025 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
4026 load_expr = type_to_glsl(type);
4027 load_op = "Load";
4028 }
4029 else
4030 {
4031 // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
4032 // so row-major is technically column-major ...
4033 load_expr = type_to_glsl(type: target_type);
4034 }
4035 load_expr += "(";
4036
4037 for (uint32_t c = 0; c < type.columns; c++)
4038 {
4039 load_expr += join(ts&: base, ts: ".", ts&: load_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4040 ts: chain.static_index + c * chain.matrix_stride, ts: ")");
4041 if (c + 1 < type.columns)
4042 load_expr += ", ";
4043 }
4044 load_expr += ")";
4045 }
4046 else
4047 {
4048 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
4049 // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
4050
4051 if (templated_load)
4052 {
4053 load_expr = type_to_glsl(type);
4054 auto scalar_type = type;
4055 scalar_type.vecsize = 1;
4056 scalar_type.columns = 1;
4057 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4058 }
4059 else
4060 load_expr = type_to_glsl(type: target_type);
4061
4062 load_expr += "(";
4063
4064 for (uint32_t c = 0; c < type.columns; c++)
4065 {
4066 for (uint32_t r = 0; r < type.vecsize; r++)
4067 {
4068 load_expr += join(ts&: base, ts: ".Load", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4069 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ")");
4070
4071 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
4072 load_expr += ", ";
4073 }
4074 }
4075 load_expr += ")";
4076 }
4077
4078 if (!templated_load)
4079 {
4080 auto bitcast_op = bitcast_glsl_op(out_type: type, in_type: target_type);
4081 if (!bitcast_op.empty())
4082 load_expr = join(ts&: bitcast_op, ts: "(", ts&: load_expr, ts: ")");
4083 }
4084
4085 if (lhs.empty())
4086 {
4087 assert(expr);
4088 *expr = std::move(load_expr);
4089 }
4090 else
4091 statement(ts: lhs, ts: " = ", ts&: load_expr, ts: ";");
4092}
4093
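// OpLoad through a SPIRAccessChain means we are reading from a ByteAddressBuffer.
// Structs and arrays cannot be expressed as a single load expression, so they are unrolled
// into an uninitialized temporary; other loads become a single expression, though matrix
// loads are never forwarded.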
4094void CompilerHLSL::emit_load(const Instruction &instruction)
4095{
4096 auto ops = stream(instr: instruction);
4097
4098 auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
4099 if (chain)
4100 {
4101 uint32_t result_type = ops[0];
4102 uint32_t id = ops[1];
4103 uint32_t ptr = ops[2];
4104
4105 auto &type = get<SPIRType>(id: result_type);
4106 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
4107
4108 if (composite_load)
4109 {
4110 // We cannot make this work in one single expression as we might have nested structures and arrays,
4111 // so unroll the load to an uninitialized temporary.
4112 emit_uninitialized_temporary_expression(type: result_type, id);
4113 read_access_chain(expr: nullptr, lhs: to_expression(id), chain: *chain);
4114 track_expression_read(id: chain->self);
4115 }
4116 else
4117 {
4118 string load_expr;
4119 read_access_chain(expr: &load_expr, lhs: "", chain: *chain);
4120
4121 bool forward = should_forward(id: ptr) && forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
4122
4123 			// If we are forwarding this load,
4124 			// don't register the read of the access chain here; defer that until we actually use the
4125 			// expression, via the add_implied_read_expression mechanism.
4126 if (!forward)
4127 track_expression_read(id: chain->self);
4128
4129 // Do not forward complex load sequences like matrices, structs and arrays.
4130 if (type.columns > 1)
4131 forward = false;
4132
4133 auto &e = emit_op(result_type, result_id: id, rhs: load_expr, forward_rhs: forward, suppress_usage_tracking: true);
4134 e.need_transpose = false;
4135 register_read(expr: id, chain: ptr, forwarded: forward);
4136 inherit_expression_dependencies(dst: id, source: ptr);
4137 if (forward)
4138 add_implied_read_expression(e, source: chain->self);
4139 }
4140 }
4141 else
4142 CompilerGLSL::emit_instruction(instr: instruction);
4143}
4144
4145void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4146 const SmallVector<uint32_t> &composite_chain)
4147{
4148 auto &type = get<SPIRType>(id: chain.basetype);
4149
4150 	// Need to use a reserved (generated) identifier here, since a plain name might shadow an identifier in the access chain input or in other loops.
4151 auto ident = get_unique_identifier();
4152
4153 uint32_t id = ir.increase_bound_by(count: 2);
4154 uint32_t int_type_id = id + 1;
4155 SPIRType int_type;
4156 int_type.basetype = SPIRType::Int;
4157 int_type.width = 32;
4158 set<SPIRType>(id: int_type_id, args&: int_type);
4159 set<SPIRExpression>(id, args&: ident, args&: int_type_id, args: true);
4160 set_name(id, name: ident);
4161 suppressed_usage_tracking.insert(x: id);
4162
4163 statement(ts: "[unroll]");
4164 statement(ts: "for (int ", ts&: ident, ts: " = 0; ", ts&: ident, ts: " < ", ts: to_array_size(type, index: uint32_t(type.array.size() - 1)), ts: "; ",
4165 ts&: ident, ts: "++)");
4166 begin_scope();
4167 auto subchain = chain;
4168 subchain.dynamic_index = join(ts&: ident, ts: " * ", ts: chain.array_stride, ts: " + ", ts: chain.dynamic_index);
4169 subchain.basetype = type.parent_type;
4170
4171 // Forcefully allow us to use an ID here by setting MSB.
4172 auto subcomposite_chain = composite_chain;
4173 subcomposite_chain.push_back(t: 0x80000000u | id);
4174
4175 if (!get<SPIRType>(id: subchain.basetype).array.empty())
4176 subchain.array_stride = get_decoration(id: subchain.basetype, decoration: DecorationArrayStride);
4177
4178 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4179 end_scope();
4180}
4181
4182void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4183 const SmallVector<uint32_t> &composite_chain)
4184{
4185 auto &type = get<SPIRType>(id: chain.basetype);
4186 uint32_t member_count = uint32_t(type.member_types.size());
4187 auto subchain = chain;
4188
4189 auto subcomposite_chain = composite_chain;
4190 subcomposite_chain.push_back(t: 0);
4191
4192 for (uint32_t i = 0; i < member_count; i++)
4193 {
4194 uint32_t offset = type_struct_member_offset(type, index: i);
4195 subchain.static_index = chain.static_index + offset;
4196 subchain.basetype = type.member_types[i];
4197
4198 subchain.matrix_stride = 0;
4199 subchain.array_stride = 0;
4200 subchain.row_major_matrix = false;
4201
4202 auto &member_type = get<SPIRType>(id: subchain.basetype);
4203 if (member_type.columns > 1)
4204 {
4205 subchain.matrix_stride = type_struct_member_matrix_stride(type, index: i);
4206 subchain.row_major_matrix = has_member_decoration(id: type.self, index: i, decoration: DecorationRowMajor);
4207 }
4208
4209 if (!member_type.array.empty())
4210 subchain.array_stride = type_struct_member_array_stride(type, index: i);
4211
4212 subcomposite_chain.back() = i;
4213 write_access_chain(chain: subchain, value, composite_chain: subcomposite_chain);
4214 }
4215}
4216
4217string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4218 bool enclose)
4219{
4220 string ret;
4221 if (composite_chain.empty())
4222 ret = to_expression(id: value);
4223 else
4224 {
4225 AccessChainMeta meta;
4226 ret = access_chain_internal(base: value, indices: composite_chain.data(), count: uint32_t(composite_chain.size()),
4227 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, meta: &meta);
4228 }
4229
4230 if (enclose)
4231 ret = enclose_expression(expr: ret);
4232 return ret;
4233}
4234
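// Store counterpart of read_access_chain: writes a value through a (RW)ByteAddressBuffer
// access chain using Store/Store2/Store3/Store4 (or templated Store<T>() on SM 6.2+),
// bitcasting to uint on older targets and splitting row-major matrices into per-element stores.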
4235void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4236 const SmallVector<uint32_t> &composite_chain)
4237{
4238 auto &type = get<SPIRType>(id: chain.basetype);
4239
4240 // Make sure we trigger a read of the constituents in the access chain.
4241 track_expression_read(id: chain.self);
4242
4243 SPIRType target_type;
4244 target_type.basetype = SPIRType::UInt;
4245 target_type.vecsize = type.vecsize;
4246 target_type.columns = type.columns;
4247
4248 if (!type.array.empty())
4249 {
4250 write_access_chain_array(chain, value, composite_chain);
4251 register_write(chain: chain.self);
4252 return;
4253 }
4254 else if (type.basetype == SPIRType::Struct)
4255 {
4256 write_access_chain_struct(chain, value, composite_chain);
4257 register_write(chain: chain.self);
4258 return;
4259 }
4260 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4261 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4262 "native 16-bit types are enabled.");
4263
4264 bool templated_store = hlsl_options.shader_model >= 62;
4265
4266 auto base = chain.base;
4267 if (has_decoration(id: chain.self, decoration: DecorationNonUniform))
4268 convert_non_uniform_expression(expr&: base, ptr_id: chain.self);
4269
4270 string template_expr;
4271 if (templated_store)
4272 template_expr = join(ts: "<", ts: type_to_glsl(type), ts: ">");
4273
4274 if (type.columns == 1 && !chain.row_major_matrix)
4275 {
4276 const char *store_op = nullptr;
4277 switch (type.vecsize)
4278 {
4279 case 1:
4280 store_op = "Store";
4281 break;
4282 case 2:
4283 store_op = "Store2";
4284 break;
4285 case 3:
4286 store_op = "Store3";
4287 break;
4288 case 4:
4289 store_op = "Store4";
4290 break;
4291 default:
4292 SPIRV_CROSS_THROW("Unknown vector size.");
4293 }
4294
4295 auto store_expr = write_access_chain_value(value, composite_chain, enclose: false);
4296
4297 if (!templated_store)
4298 {
4299 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4300 if (!bitcast_op.empty())
4301 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4302 }
4303 else
4304 store_op = "Store";
4305 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index, ts: chain.static_index, ts: ", ",
4306 ts&: store_expr, ts: ");");
4307 }
4308 else if (type.columns == 1)
4309 {
4310 if (templated_store)
4311 {
4312 auto scalar_type = type;
4313 scalar_type.vecsize = 1;
4314 scalar_type.columns = 1;
4315 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4316 }
4317
4318 // Strided store.
4319 for (uint32_t r = 0; r < type.vecsize; r++)
4320 {
4321 auto store_expr = write_access_chain_value(value, composite_chain, enclose: true);
4322 if (type.vecsize > 1)
4323 {
4324 store_expr += ".";
4325 store_expr += index_to_swizzle(index: r);
4326 }
4327 remove_duplicate_swizzle(op&: store_expr);
4328
4329 if (!templated_store)
4330 {
4331 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4332 if (!bitcast_op.empty())
4333 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4334 }
4335
4336 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4337 ts: chain.static_index + chain.matrix_stride * r, ts: ", ", ts&: store_expr, ts: ");");
4338 }
4339 }
4340 else if (!chain.row_major_matrix)
4341 {
4342 const char *store_op = nullptr;
4343 switch (type.vecsize)
4344 {
4345 case 1:
4346 store_op = "Store";
4347 break;
4348 case 2:
4349 store_op = "Store2";
4350 break;
4351 case 3:
4352 store_op = "Store3";
4353 break;
4354 case 4:
4355 store_op = "Store4";
4356 break;
4357 default:
4358 SPIRV_CROSS_THROW("Unknown vector size.");
4359 }
4360
4361 if (templated_store)
4362 {
4363 store_op = "Store";
4364 auto vector_type = type;
4365 vector_type.columns = 1;
4366 template_expr = join(ts: "<", ts: type_to_glsl(type: vector_type), ts: ">");
4367 }
4368
4369 for (uint32_t c = 0; c < type.columns; c++)
4370 {
4371 auto store_expr = join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "]");
4372
4373 if (!templated_store)
4374 {
4375 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4376 if (!bitcast_op.empty())
4377 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4378 }
4379
4380 statement(ts&: base, ts: ".", ts&: store_op, ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4381 ts: chain.static_index + c * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
4382 }
4383 }
4384 else
4385 {
4386 if (templated_store)
4387 {
4388 auto scalar_type = type;
4389 scalar_type.vecsize = 1;
4390 scalar_type.columns = 1;
4391 template_expr = join(ts: "<", ts: type_to_glsl(type: scalar_type), ts: ">");
4392 }
4393
4394 for (uint32_t r = 0; r < type.vecsize; r++)
4395 {
4396 for (uint32_t c = 0; c < type.columns; c++)
4397 {
4398 auto store_expr =
4399 join(ts: write_access_chain_value(value, composite_chain, enclose: true), ts: "[", ts&: c, ts: "].", ts: index_to_swizzle(index: r));
4400 remove_duplicate_swizzle(op&: store_expr);
4401 auto bitcast_op = bitcast_glsl_op(out_type: target_type, in_type: type);
4402 if (!bitcast_op.empty())
4403 store_expr = join(ts&: bitcast_op, ts: "(", ts&: store_expr, ts: ")");
4404 statement(ts&: base, ts: ".Store", ts&: template_expr, ts: "(", ts: chain.dynamic_index,
4405 ts: chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ts: ", ", ts&: store_expr, ts: ");");
4406 }
4407 }
4408 }
4409
4410 register_write(chain: chain.self);
4411}
4412
4413void CompilerHLSL::emit_store(const Instruction &instruction)
4414{
4415 auto ops = stream(instr: instruction);
4416 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
4417 if (chain)
4418 write_access_chain(chain: *chain, value: ops[1], composite_chain: {});
4419 else
4420 CompilerGLSL::emit_instruction(instr: instruction);
4421}
4422
4423void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4424{
4425 auto ops = stream(instr: instruction);
4426 uint32_t length = instruction.length;
4427
4428 bool need_byte_access_chain = false;
4429 auto &type = expression_type(id: ops[2]);
4430 const auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
4431
4432 if (chain)
4433 {
4434 // Keep tacking on an existing access chain.
4435 need_byte_access_chain = true;
4436 }
4437 else if (type.storage == StorageClassStorageBuffer || has_decoration(id: type.self, decoration: DecorationBufferBlock))
4438 {
4439 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4440 // to emit SPIRAccessChain rather than a plain SPIRExpression.
4441 uint32_t chain_arguments = length - 3;
4442 if (chain_arguments > type.array.size())
4443 need_byte_access_chain = true;
4444 }
4445
4446 if (need_byte_access_chain)
4447 {
4448 		// If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4449 		// not to an array of SSBOs.
4450 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4451
4452 auto *backing_variable = maybe_get_backing_variable(chain: ops[2]);
4453
4454 string base;
4455 if (to_plain_buffer_length != 0)
4456 base = access_chain(base: ops[2], indices: &ops[3], count: to_plain_buffer_length, target_type: get<SPIRType>(id: ops[0]));
4457 else if (chain)
4458 base = chain->base;
4459 else
4460 base = to_expression(id: ops[2]);
4461
4462 // Start traversing type hierarchy at the proper non-pointer types.
4463 auto *basetype = &get_pointee_type(type);
4464
4465 // Traverse the type hierarchy down to the actual buffer types.
4466 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4467 {
4468 assert(basetype->parent_type);
4469 basetype = &get<SPIRType>(id: basetype->parent_type);
4470 }
4471
4472 uint32_t matrix_stride = 0;
4473 uint32_t array_stride = 0;
4474 bool row_major_matrix = false;
4475
4476 // Inherit matrix information.
4477 if (chain)
4478 {
4479 matrix_stride = chain->matrix_stride;
4480 row_major_matrix = chain->row_major_matrix;
4481 array_stride = chain->array_stride;
4482 }
4483
4484 auto offsets = flattened_access_chain_offset(basetype: *basetype, indices: &ops[3 + to_plain_buffer_length],
4485 count: length - 3 - to_plain_buffer_length, offset: 0, word_stride: 1, need_transpose: &row_major_matrix,
4486 matrix_stride: &matrix_stride, array_stride: &array_stride);
4487
4488 auto &e = set<SPIRAccessChain>(id: ops[1], args: ops[0], args: type.storage, args&: base, args&: offsets.first, args&: offsets.second);
4489 e.row_major_matrix = row_major_matrix;
4490 e.matrix_stride = matrix_stride;
4491 e.array_stride = array_stride;
4492 e.immutable = should_forward(id: ops[2]);
4493 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4494
4495 if (chain)
4496 {
4497 e.dynamic_index += chain->dynamic_index;
4498 e.static_index += chain->static_index;
4499 }
4500
4501 for (uint32_t i = 2; i < length; i++)
4502 {
4503 inherit_expression_dependencies(dst: ops[1], source: ops[i]);
4504 add_implied_read_expression(e, source: ops[i]);
4505 }
4506 }
4507 else
4508 {
4509 CompilerGLSL::emit_instruction(instr: instruction);
4510 }
4511}
4512
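// SPIR-V atomics map to the Interlocked* intrinsics. Atomic loads are emulated as
// InterlockedAdd(..., 0) and atomic stores as InterlockedExchange with a dummy "original value"
// temporary, since HLSL has no direct atomic load/store. For RWByteAddressBuffer targets the
// intrinsic is a member function taking the byte offset.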
4513void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4514{
4515 const char *atomic_op = nullptr;
4516
4517 string value_expr;
4518 if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4519 value_expr = to_expression(id: ops[op == OpAtomicCompareExchange ? 6 : 5]);
4520
4521 bool is_atomic_store = false;
4522
4523 switch (op)
4524 {
4525 case OpAtomicIIncrement:
4526 atomic_op = "InterlockedAdd";
4527 value_expr = "1";
4528 break;
4529
4530 case OpAtomicIDecrement:
4531 atomic_op = "InterlockedAdd";
4532 value_expr = "-1";
4533 break;
4534
4535 case OpAtomicLoad:
4536 atomic_op = "InterlockedAdd";
4537 value_expr = "0";
4538 break;
4539
4540 case OpAtomicISub:
4541 atomic_op = "InterlockedAdd";
4542 value_expr = join(ts: "-", ts: enclose_expression(expr: value_expr));
4543 break;
4544
4545 case OpAtomicSMin:
4546 case OpAtomicUMin:
4547 atomic_op = "InterlockedMin";
4548 break;
4549
4550 case OpAtomicSMax:
4551 case OpAtomicUMax:
4552 atomic_op = "InterlockedMax";
4553 break;
4554
4555 case OpAtomicAnd:
4556 atomic_op = "InterlockedAnd";
4557 break;
4558
4559 case OpAtomicOr:
4560 atomic_op = "InterlockedOr";
4561 break;
4562
4563 case OpAtomicXor:
4564 atomic_op = "InterlockedXor";
4565 break;
4566
4567 case OpAtomicIAdd:
4568 atomic_op = "InterlockedAdd";
4569 break;
4570
4571 case OpAtomicExchange:
4572 atomic_op = "InterlockedExchange";
4573 break;
4574
4575 case OpAtomicStore:
4576 atomic_op = "InterlockedExchange";
4577 is_atomic_store = true;
4578 break;
4579
4580 case OpAtomicCompareExchange:
4581 if (length < 8)
4582 SPIRV_CROSS_THROW("Not enough data for opcode.");
4583 atomic_op = "InterlockedCompareExchange";
4584 value_expr = join(ts: to_expression(id: ops[7]), ts: ", ", ts&: value_expr);
4585 break;
4586
4587 default:
4588 SPIRV_CROSS_THROW("Unknown atomic opcode.");
4589 }
4590
4591 if (is_atomic_store)
4592 {
4593 auto &data_type = expression_type(id: ops[0]);
4594 auto *chain = maybe_get<SPIRAccessChain>(id: ops[0]);
4595
4596 auto &tmp_id = extra_sub_expressions[ops[0]];
4597 if (!tmp_id)
4598 {
4599 tmp_id = ir.increase_bound_by(count: 1);
4600 emit_uninitialized_temporary_expression(type: get_pointee_type(type: data_type).self, id: tmp_id);
4601 }
4602
4603 if (data_type.storage == StorageClassImage || !chain)
4604 {
4605 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[0]), ts: ", ",
4606 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
4607 }
4608 else
4609 {
4610 string base = chain->base;
4611 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
4612 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
4613 			// RWByteAddressBuffer is always uint in its underlying type.
4614 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ",
4615 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: tmp_id), ts: ");");
4616 }
4617 }
4618 else
4619 {
4620 uint32_t result_type = ops[0];
4621 uint32_t id = ops[1];
4622 forced_temporaries.insert(x: ops[1]);
4623
4624 auto &type = get<SPIRType>(id: result_type);
4625 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
4626
4627 auto &data_type = expression_type(id: ops[2]);
4628 auto *chain = maybe_get<SPIRAccessChain>(id: ops[2]);
4629 SPIRType::BaseType expr_type;
4630 if (data_type.storage == StorageClassImage || !chain)
4631 {
4632 statement(ts&: atomic_op, ts: "(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: value_expr, ts: ", ", ts: to_name(id), ts: ");");
4633 expr_type = data_type.basetype;
4634 }
4635 else
4636 {
4637 			// RWByteAddressBuffer is always uint in its underlying type.
4638 string base = chain->base;
4639 if (has_decoration(id: chain->self, decoration: DecorationNonUniform))
4640 convert_non_uniform_expression(expr&: base, ptr_id: chain->self);
4641 expr_type = SPIRType::UInt;
4642 statement(ts&: base, ts: ".", ts&: atomic_op, ts: "(", ts&: chain->dynamic_index, ts&: chain->static_index, ts: ", ", ts&: value_expr,
4643 ts: ", ", ts: to_name(id), ts: ");");
4644 }
4645
4646 auto expr = bitcast_expression(target_type: type, expr_type, expr: to_name(id));
4647 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
4648 }
4649 flush_all_atomic_capable_variables();
4650}
4651
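// Subgroup operations require the SM 6.0 wave intrinsics. Reductions map to WaveActive*,
// exclusive scans to WavePrefix*, and inclusive scans are built as "WavePrefix*(x) op x" via
// the make_inclusive_* helpers. Ballot-style queries with no HLSL equivalent are rejected.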
4652void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4653{
4654 if (hlsl_options.shader_model < 60)
4655 SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
4656
4657 const uint32_t *ops = stream(instr: i);
4658 auto op = static_cast<Op>(i.op);
4659
4660 uint32_t result_type = ops[0];
4661 uint32_t id = ops[1];
4662
4663 auto scope = static_cast<Scope>(evaluate_constant_u32(id: ops[2]));
4664 if (scope != ScopeSubgroup)
4665 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4666
4667 const auto make_inclusive_Sum = [&](const string &expr) -> string {
4668 return join(ts: expr, ts: " + ", ts: to_expression(id: ops[4]));
4669 };
4670
4671 const auto make_inclusive_Product = [&](const string &expr) -> string {
4672 return join(ts: expr, ts: " * ", ts: to_expression(id: ops[4]));
4673 };
4674
4675 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4676 uint32_t integer_width = get_integer_width_for_instruction(instr: i);
4677 auto int_type = to_signed_basetype(width: integer_width);
4678 auto uint_type = to_unsigned_basetype(width: integer_width);
4679
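	// Placeholder expansions: these expand to empty strings and only exist so the
	// HLSL_GROUP_OP macro compiles for ops declared with supports_scan = false below,
	// where the inclusive-scan branch is unreachable.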
4680#define make_inclusive_BitAnd(expr) ""
4681#define make_inclusive_BitOr(expr) ""
4682#define make_inclusive_BitXor(expr) ""
4683#define make_inclusive_Min(expr) ""
4684#define make_inclusive_Max(expr) ""
4685
4686 switch (op)
4687 {
4688 case OpGroupNonUniformElect:
4689 emit_op(result_type, result_id: id, rhs: "WaveIsFirstLane()", forward_rhs: true);
4690 break;
4691
4692 case OpGroupNonUniformBroadcast:
4693 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
4694 break;
4695
4696 case OpGroupNonUniformBroadcastFirst:
4697 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveReadLaneFirst");
4698 break;
4699
4700 case OpGroupNonUniformBallot:
4701 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveBallot");
4702 break;
4703
4704 case OpGroupNonUniformInverseBallot:
4705 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4706
4707 case OpGroupNonUniformBallotBitExtract:
4708 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4709
4710 case OpGroupNonUniformBallotFindLSB:
4711 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4712
4713 case OpGroupNonUniformBallotFindMSB:
4714 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4715
4716 case OpGroupNonUniformBallotBitCount:
4717 {
4718 auto operation = static_cast<GroupOperation>(ops[3]);
4719 if (operation == GroupOperationReduce)
4720 {
4721 bool forward = should_forward(id: ops[4]);
4722 auto left = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".x) + countbits(",
4723 ts: to_enclosed_expression(id: ops[4]), ts: ".y)");
4724 auto right = join(ts: "countbits(", ts: to_enclosed_expression(id: ops[4]), ts: ".z) + countbits(",
4725 ts: to_enclosed_expression(id: ops[4]), ts: ".w)");
4726 emit_op(result_type, result_id: id, rhs: join(ts&: left, ts: " + ", ts&: right), forward_rhs: forward);
4727 inherit_expression_dependencies(dst: id, source: ops[4]);
4728 }
4729 else if (operation == GroupOperationInclusiveScan)
4730 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4731 else if (operation == GroupOperationExclusiveScan)
4732 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4733 else
4734 SPIRV_CROSS_THROW("Invalid BitCount operation.");
4735 break;
4736 }
4737
4738 case OpGroupNonUniformShuffle:
4739 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "WaveReadLaneAt");
4740 break;
4741 case OpGroupNonUniformShuffleXor:
4742 {
4743 bool forward = should_forward(id: ops[3]);
4744 emit_op(result_type: ops[0], result_id: ops[1],
4745 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
4746 ts: "WaveGetLaneIndex() ^ ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
4747 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
4748 break;
4749 }
4750 case OpGroupNonUniformShuffleUp:
4751 {
4752 bool forward = should_forward(id: ops[3]);
4753 emit_op(result_type: ops[0], result_id: ops[1],
4754 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
4755 ts: "WaveGetLaneIndex() - ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
4756 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
4757 break;
4758 }
4759 case OpGroupNonUniformShuffleDown:
4760 {
4761 bool forward = should_forward(id: ops[3]);
4762 emit_op(result_type: ops[0], result_id: ops[1],
4763 rhs: join(ts: "WaveReadLaneAt(", ts: to_unpacked_expression(id: ops[3]), ts: ", ",
4764 ts: "WaveGetLaneIndex() + ", ts: to_enclosed_expression(id: ops[4]), ts: ")"), forward_rhs: forward);
4765 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
4766 break;
4767 }
4768
4769 case OpGroupNonUniformAll:
4770 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllTrue");
4771 break;
4772
4773 case OpGroupNonUniformAny:
4774 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAnyTrue");
4775 break;
4776
4777 case OpGroupNonUniformAllEqual:
4778 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "WaveActiveAllEqual");
4779 break;
4780
4781 // clang-format off
4782#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4783case OpGroupNonUniform##op: \
4784 { \
4785 auto operation = static_cast<GroupOperation>(ops[3]); \
4786 if (operation == GroupOperationReduce) \
4787 emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4788 else if (operation == GroupOperationInclusiveScan && supports_scan) \
4789 { \
4790 bool forward = should_forward(ops[4]); \
4791 emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4792 inherit_expression_dependencies(id, ops[4]); \
4793 } \
4794 else if (operation == GroupOperationExclusiveScan && supports_scan) \
4795 emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4796 else if (operation == GroupOperationClusteredReduce) \
4797 SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4798 else \
4799 SPIRV_CROSS_THROW("Invalid group operation."); \
4800 break; \
4801 }
4802
4803#define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4804case OpGroupNonUniform##op: \
4805 { \
4806 auto operation = static_cast<GroupOperation>(ops[3]); \
4807 if (operation == GroupOperationReduce) \
4808 emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4809 else \
4810 SPIRV_CROSS_THROW("Invalid group operation."); \
4811 break; \
4812 }
4813
4814 HLSL_GROUP_OP(FAdd, Sum, true)
4815 HLSL_GROUP_OP(FMul, Product, true)
4816 HLSL_GROUP_OP(FMin, Min, false)
4817 HLSL_GROUP_OP(FMax, Max, false)
4818 HLSL_GROUP_OP(IAdd, Sum, true)
4819 HLSL_GROUP_OP(IMul, Product, true)
4820 HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4821 HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4822 HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4823 HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4824 HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4825 HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4826 HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4827 HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
4828 HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
4829 HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
4830
4831#undef HLSL_GROUP_OP
4832#undef HLSL_GROUP_OP_CAST
4833 // clang-format on
4834
4835 case OpGroupNonUniformQuadSwap:
4836 {
4837 uint32_t direction = evaluate_constant_u32(id: ops[4]);
4838 if (direction == 0)
4839 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossX");
4840 else if (direction == 1)
4841 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossY");
4842 else if (direction == 2)
4843 emit_unary_func_op(result_type, result_id: id, op0: ops[3], op: "QuadReadAcrossDiagonal");
4844 else
4845 SPIRV_CROSS_THROW("Invalid quad swap direction.");
4846 break;
4847 }
4848
4849 case OpGroupNonUniformQuadBroadcast:
4850 {
4851 emit_binary_func_op(result_type, result_id: id, op0: ops[3], op1: ops[4], op: "QuadReadLaneAt");
4852 break;
4853 }
4854
4855 default:
4856 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4857 }
4858
4859 register_control_dependent_expression(expr: id);
4860}
4861
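// Main opcode dispatcher. The HLSL_* macros below are shorthands for the generic
// emit_unary/binary/trinary helpers; anything without HLSL-specific handling falls through to
// CompilerGLSL::emit_instruction.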
4862void CompilerHLSL::emit_instruction(const Instruction &instruction)
4863{
4864 auto ops = stream(instr: instruction);
4865 auto opcode = static_cast<Op>(instruction.op);
4866
4867#define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4868#define HLSL_BOP_CAST(op, type) \
4869 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4870#define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4871#define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4872#define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4873#define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4874#define HLSL_BFOP_CAST(op, type) \
4875 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4876#define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4877#define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
4878
4879 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4880 uint32_t integer_width = get_integer_width_for_instruction(instr: instruction);
4881 auto int_type = to_signed_basetype(width: integer_width);
4882 auto uint_type = to_unsigned_basetype(width: integer_width);
4883
4884 opcode = get_remapped_spirv_op(op: opcode);
4885
4886 switch (opcode)
4887 {
4888 case OpAccessChain:
4889 case OpInBoundsAccessChain:
4890 {
4891 emit_access_chain(instruction);
4892 break;
4893 }
4894 case OpBitcast:
4895 {
4896 auto bitcast_type = get_bitcast_type(result_type: ops[0], op0: ops[2]);
4897 if (bitcast_type == CompilerHLSL::TypeNormal)
4898 CompilerGLSL::emit_instruction(instr: instruction);
4899 else
4900 {
4901 if (!requires_uint2_packing)
4902 {
4903 requires_uint2_packing = true;
4904 force_recompile();
4905 }
4906
4907 if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4908 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvPackUint2x32");
4909 else
4910 emit_unary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "spvUnpackUint2x32");
4911 }
4912
4913 break;
4914 }
4915
4916 case OpSelect:
4917 {
4918 auto &value_type = expression_type(id: ops[3]);
4919 if (value_type.basetype == SPIRType::Struct || is_array(type: value_type))
4920 {
4921 // HLSL does not support ternary expressions on composites.
4922 // Cannot use branches, since we might be in a continue block
4923 // where explicit control flow is prohibited.
4924 // Emit a helper function where we can use control flow.
4925 TypeID value_type_id = expression_type_id(id: ops[3]);
4926 auto itr = std::find(first: composite_selection_workaround_types.begin(),
4927 last: composite_selection_workaround_types.end(),
4928 val: value_type_id);
4929 if (itr == composite_selection_workaround_types.end())
4930 {
4931 composite_selection_workaround_types.push_back(x: value_type_id);
4932 force_recompile();
4933 }
4934 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
4935 statement(ts: "spvSelectComposite(",
4936 ts: to_expression(id: ops[1]), ts: ", ", ts: to_expression(id: ops[2]), ts: ", ",
4937 ts: to_expression(id: ops[3]), ts: ", ", ts: to_expression(id: ops[4]), ts: ");");
4938 }
4939 else
4940 CompilerGLSL::emit_instruction(instr: instruction);
4941 break;
4942 }
4943
4944 case OpStore:
4945 {
4946 emit_store(instruction);
4947 break;
4948 }
4949
4950 case OpLoad:
4951 {
4952 emit_load(instruction);
4953 break;
4954 }
4955
4956 case OpMatrixTimesVector:
4957 {
4958 // Matrices are kept in a transposed state all the time, flip multiplication order always.
4959 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
4960 break;
4961 }
4962
4963 case OpVectorTimesMatrix:
4964 {
4965 // Matrices are kept in a transposed state all the time, flip multiplication order always.
4966 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
4967 break;
4968 }
4969
4970 case OpMatrixTimesMatrix:
4971 {
4972 // Matrices are kept in a transposed state all the time, flip multiplication order always.
4973 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[3], op1: ops[2], op: "mul");
4974 break;
4975 }
4976
4977 case OpOuterProduct:
4978 {
4979 uint32_t result_type = ops[0];
4980 uint32_t id = ops[1];
4981 uint32_t a = ops[2];
4982 uint32_t b = ops[3];
4983
4984 auto &type = get<SPIRType>(id: result_type);
4985 string expr = type_to_glsl_constructor(type);
4986 expr += "(";
4987 for (uint32_t col = 0; col < type.columns; col++)
4988 {
4989 expr += to_enclosed_expression(id: a);
4990 expr += " * ";
4991 expr += to_extract_component_expression(id: b, index: col);
4992 if (col + 1 < type.columns)
4993 expr += ", ";
4994 }
4995 expr += ")";
4996 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: a) && should_forward(id: b));
4997 inherit_expression_dependencies(dst: id, source: a);
4998 inherit_expression_dependencies(dst: id, source: b);
4999 break;
5000 }
5001
5002 case OpFMod:
5003 {
5004 if (!requires_op_fmod)
5005 {
5006 requires_op_fmod = true;
5007 force_recompile();
5008 }
5009 CompilerGLSL::emit_instruction(instr: instruction);
5010 break;
5011 }
5012
5013 case OpFRem:
5014 emit_binary_func_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op: "fmod");
5015 break;
5016
5017 case OpImage:
5018 {
5019 uint32_t result_type = ops[0];
5020 uint32_t id = ops[1];
5021 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: ops[2]);
5022
5023 if (combined)
5024 {
5025 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: combined->image), forward_rhs: true, suppress_usage_tracking: true);
5026 auto *var = maybe_get_backing_variable(chain: combined->image);
5027 if (var)
5028 e.loaded_from = var->self;
5029 }
5030 else
5031 {
5032 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: ops[2]), forward_rhs: true, suppress_usage_tracking: true);
5033 auto *var = maybe_get_backing_variable(chain: ops[2]);
5034 if (var)
5035 e.loaded_from = var->self;
5036 }
5037 break;
5038 }
5039
5040 case OpDPdx:
5041 HLSL_UFOP(ddx);
5042 register_control_dependent_expression(expr: ops[1]);
5043 break;
5044
5045 case OpDPdy:
5046 HLSL_UFOP(ddy);
5047 register_control_dependent_expression(expr: ops[1]);
5048 break;
5049
5050 case OpDPdxFine:
5051 HLSL_UFOP(ddx_fine);
5052 register_control_dependent_expression(expr: ops[1]);
5053 break;
5054
5055 case OpDPdyFine:
5056 HLSL_UFOP(ddy_fine);
5057 register_control_dependent_expression(expr: ops[1]);
5058 break;
5059
5060 case OpDPdxCoarse:
5061 HLSL_UFOP(ddx_coarse);
5062 register_control_dependent_expression(expr: ops[1]);
5063 break;
5064
5065 case OpDPdyCoarse:
5066 HLSL_UFOP(ddy_coarse);
5067 register_control_dependent_expression(expr: ops[1]);
5068 break;
5069
5070 case OpFwidth:
5071 case OpFwidthCoarse:
5072 case OpFwidthFine:
5073 HLSL_UFOP(fwidth);
5074 register_control_dependent_expression(expr: ops[1]);
5075 break;
5076
5077 case OpLogicalNot:
5078 {
5079 auto result_type = ops[0];
5080 auto id = ops[1];
5081 auto &type = get<SPIRType>(id: result_type);
5082
5083 if (type.vecsize > 1)
5084 emit_unrolled_unary_op(result_type, result_id: id, operand: ops[2], op: "!");
5085 else
5086 HLSL_UOP(!);
5087 break;
5088 }
5089
5090 case OpIEqual:
5091 {
5092 auto result_type = ops[0];
5093 auto id = ops[1];
5094
5095 if (expression_type(id: ops[2]).vecsize > 1)
5096 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5097 else
5098 HLSL_BOP_CAST(==, int_type);
5099 break;
5100 }
5101
5102 case OpLogicalEqual:
5103 case OpFOrdEqual:
5104 case OpFUnordEqual:
5105 {
5106 // HLSL != operator is unordered.
5107 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5108 // isnan() is apparently implemented as x != x as well.
5109 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
5110 // HACK: FUnordEqual will be implemented as FOrdEqual.
5111
5112 auto result_type = ops[0];
5113 auto id = ops[1];
5114
5115 if (expression_type(id: ops[2]).vecsize > 1)
5116 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "==", negate: false, expected_type: SPIRType::Unknown);
5117 else
5118 HLSL_BOP(==);
5119 break;
5120 }
5121
5122 case OpINotEqual:
5123 {
5124 auto result_type = ops[0];
5125 auto id = ops[1];
5126
5127 if (expression_type(id: ops[2]).vecsize > 1)
5128 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5129 else
5130 HLSL_BOP_CAST(!=, int_type);
5131 break;
5132 }
5133
5134 case OpLogicalNotEqual:
5135 case OpFOrdNotEqual:
5136 case OpFUnordNotEqual:
5137 {
5138 // HLSL != operator is unordered.
5139 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5140 // isnan() is apparently implemented as x != x as well.
5141
5142 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5143 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
5144 // Adding a lot of NaN checks would be a breaking change from perspective of performance.
5145 // SPIR-V will generally use isnan() checks when this even matters.
5146 // HACK: FOrdNotEqual will be implemented as FUnordEqual.
5147
5148 auto result_type = ops[0];
5149 auto id = ops[1];
5150
5151 if (expression_type(id: ops[2]).vecsize > 1)
5152 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "!=", negate: false, expected_type: SPIRType::Unknown);
5153 else
5154 HLSL_BOP(!=);
5155 break;
5156 }
5157
5158 case OpUGreaterThan:
5159 case OpSGreaterThan:
5160 {
5161 auto result_type = ops[0];
5162 auto id = ops[1];
5163 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5164
5165 if (expression_type(id: ops[2]).vecsize > 1)
5166 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: type);
5167 else
5168 HLSL_BOP_CAST(>, type);
5169 break;
5170 }
5171
5172 case OpFOrdGreaterThan:
5173 {
5174 auto result_type = ops[0];
5175 auto id = ops[1];
5176
5177 if (expression_type(id: ops[2]).vecsize > 1)
5178 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: false, expected_type: SPIRType::Unknown);
5179 else
5180 HLSL_BOP(>);
5181 break;
5182 }
5183
5184 case OpFUnordGreaterThan:
5185 {
5186 auto result_type = ops[0];
5187 auto id = ops[1];
5188
5189 if (expression_type(id: ops[2]).vecsize > 1)
5190 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: true, expected_type: SPIRType::Unknown);
5191 else
5192 CompilerGLSL::emit_instruction(instr: instruction);
5193 break;
5194 }
5195
5196 case OpUGreaterThanEqual:
5197 case OpSGreaterThanEqual:
5198 {
5199 auto result_type = ops[0];
5200 auto id = ops[1];
5201
5202 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5203 if (expression_type(id: ops[2]).vecsize > 1)
5204 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: type);
5205 else
5206 HLSL_BOP_CAST(>=, type);
5207 break;
5208 }
5209
5210 case OpFOrdGreaterThanEqual:
5211 {
5212 auto result_type = ops[0];
5213 auto id = ops[1];
5214
5215 if (expression_type(id: ops[2]).vecsize > 1)
5216 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: false, expected_type: SPIRType::Unknown);
5217 else
5218 HLSL_BOP(>=);
5219 break;
5220 }
5221
5222 case OpFUnordGreaterThanEqual:
5223 {
5224 auto result_type = ops[0];
5225 auto id = ops[1];
5226
5227 if (expression_type(id: ops[2]).vecsize > 1)
5228 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: true, expected_type: SPIRType::Unknown);
5229 else
5230 CompilerGLSL::emit_instruction(instr: instruction);
5231 break;
5232 }
5233
5234 case OpULessThan:
5235 case OpSLessThan:
5236 {
5237 auto result_type = ops[0];
5238 auto id = ops[1];
5239
5240 auto type = opcode == OpULessThan ? uint_type : int_type;
5241 if (expression_type(id: ops[2]).vecsize > 1)
5242 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: type);
5243 else
5244 HLSL_BOP_CAST(<, type);
5245 break;
5246 }
5247
5248 case OpFOrdLessThan:
5249 {
5250 auto result_type = ops[0];
5251 auto id = ops[1];
5252
5253 if (expression_type(id: ops[2]).vecsize > 1)
5254 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<", negate: false, expected_type: SPIRType::Unknown);
5255 else
5256 HLSL_BOP(<);
5257 break;
5258 }
5259
5260 case OpFUnordLessThan:
5261 {
5262 auto result_type = ops[0];
5263 auto id = ops[1];
5264
5265 if (expression_type(id: ops[2]).vecsize > 1)
5266 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">=", negate: true, expected_type: SPIRType::Unknown);
5267 else
5268 CompilerGLSL::emit_instruction(instr: instruction);
5269 break;
5270 }
5271
5272 case OpULessThanEqual:
5273 case OpSLessThanEqual:
5274 {
5275 auto result_type = ops[0];
5276 auto id = ops[1];
5277
5278 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5279 if (expression_type(id: ops[2]).vecsize > 1)
5280 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: type);
5281 else
5282 HLSL_BOP_CAST(<=, type);
5283 break;
5284 }
5285
5286 case OpFOrdLessThanEqual:
5287 {
5288 auto result_type = ops[0];
5289 auto id = ops[1];
5290
5291 if (expression_type(id: ops[2]).vecsize > 1)
5292 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: "<=", negate: false, expected_type: SPIRType::Unknown);
5293 else
5294 HLSL_BOP(<=);
5295 break;
5296 }
5297
5298 case OpFUnordLessThanEqual:
5299 {
5300 auto result_type = ops[0];
5301 auto id = ops[1];
5302
5303 if (expression_type(id: ops[2]).vecsize > 1)
5304 emit_unrolled_binary_op(result_type, result_id: id, op0: ops[2], op1: ops[3], op: ">", negate: true, expected_type: SPIRType::Unknown);
5305 else
5306 CompilerGLSL::emit_instruction(instr: instruction);
5307 break;
5308 }
5309
5310 case OpImageQueryLod:
5311 emit_texture_op(i: instruction, sparse: false);
5312 break;
5313
5314 case OpImageQuerySizeLod:
5315 {
5316 auto result_type = ops[0];
5317 auto id = ops[1];
5318
5319 require_texture_query_variant(var_id: ops[2]);
5320 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
5321 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
5322
5323 auto expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ",
5324 ts: bitcast_expression(target_type: SPIRType::UInt, arg: ops[3]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
5325
5326 auto &restype = get<SPIRType>(id: ops[0]);
5327 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
5328 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
5329 break;
5330 }
5331
5332 case OpImageQuerySize:
5333 {
5334 auto result_type = ops[0];
5335 auto id = ops[1];
5336
5337 require_texture_query_variant(var_id: ops[2]);
5338 bool uav = expression_type(id: ops[2]).image.sampled == 2;
5339
5340 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
5341 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
5342 uav = false;
5343
5344 auto dummy_samples_levels = join(ts: get_fallback_name(id), ts: "_dummy_parameter");
5345 statement(ts: "uint ", ts&: dummy_samples_levels, ts: ";");
5346
5347 string expr;
5348 if (uav)
5349 expr = join(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts&: dummy_samples_levels, ts: ")");
5350 else
5351 expr = join(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts&: dummy_samples_levels, ts: ")");
5352
5353 auto &restype = get<SPIRType>(id: ops[0]);
5354 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
5355 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: true);
5356 break;
5357 }
5358
5359 case OpImageQuerySamples:
5360 case OpImageQueryLevels:
5361 {
5362 auto result_type = ops[0];
5363 auto id = ops[1];
5364
5365 require_texture_query_variant(var_id: ops[2]);
5366 bool uav = expression_type(id: ops[2]).image.sampled == 2;
5367 if (opcode == OpImageQueryLevels && uav)
5368 SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5369
5370 if (const auto *var = maybe_get_backing_variable(chain: ops[2]))
5371 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var->self, decoration: DecorationNonWritable))
5372 uav = false;
5373
5374 // Keep it simple and do not emit special variants to make this look nicer ...
5375 // This stuff is barely, if ever, used.
5376 forced_temporaries.insert(x: id);
5377 auto &type = get<SPIRType>(id: result_type);
5378 statement(ts: variable_decl(type, name: to_name(id)), ts: ";");
5379
5380 if (uav)
5381 statement(ts: "spvImageSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", ", ts: to_name(id), ts: ");");
5382 else
5383 statement(ts: "spvTextureSize(", ts: to_non_uniform_aware_expression(id: ops[2]), ts: ", 0u, ", ts: to_name(id), ts: ");");
5384
5385 auto &restype = get<SPIRType>(id: ops[0]);
5386 auto expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr: to_name(id));
5387 set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
5388 break;
5389 }
5390
5391 case OpImageRead:
5392 {
5393 uint32_t result_type = ops[0];
5394 uint32_t id = ops[1];
5395 auto *var = maybe_get_backing_variable(chain: ops[2]);
5396 auto &type = expression_type(id: ops[2]);
5397 bool subpass_data = type.image.dim == DimSubpassData;
5398 bool pure = false;
5399
5400 string imgexpr;
5401
5402 if (subpass_data)
5403 {
5404 if (hlsl_options.shader_model < 40)
5405 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5406
5407 // Similar to GLSL, implement subpass loads using texelFetch.
5408 if (type.image.ms)
5409 {
5410 uint32_t operands = ops[4];
5411 if (operands != ImageOperandsSampleMask || instruction.length != 6)
5412 SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5413 uint32_t sample = ops[5];
5414 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int2(gl_FragCoord.xy), ", ts: to_expression(id: sample), ts: ")");
5415 }
5416 else
5417 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".Load(int3(int2(gl_FragCoord.xy), 0))");
5418
5419 pure = true;
5420 }
5421 else
5422 {
5423 imgexpr = join(ts: to_non_uniform_aware_expression(id: ops[2]), ts: "[", ts: to_expression(id: ops[3]), ts: "]");
5424 // The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5425 // except that the underlying type changes how the data is interpreted.
5426
5427 bool force_srv =
5428 hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(id: var->self, decoration: DecorationNonWritable);
5429 pure = force_srv;
5430
5431 if (var && !subpass_data && !force_srv)
5432 imgexpr = remap_swizzle(result_type: get<SPIRType>(id: result_type),
5433 input_components: image_format_to_components(fmt: get<SPIRType>(id: var->basetype).image.format), expr: imgexpr);
5434 }
5435
5436 if (var)
5437 {
5438 bool forward = forced_temporaries.find(x: id) == end(cont&: forced_temporaries);
5439 auto &e = emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: forward);
5440
5441 if (!pure)
5442 {
5443 e.loaded_from = var->self;
5444 if (forward)
5445 var->dependees.push_back(t: id);
5446 }
5447 }
5448 else
5449 emit_op(result_type, result_id: id, rhs: imgexpr, forward_rhs: false);
5450
5451 inherit_expression_dependencies(dst: id, source: ops[2]);
5452 if (type.image.ms)
5453 inherit_expression_dependencies(dst: id, source: ops[5]);
5454 break;
5455 }
5456
5457 case OpImageWrite:
5458 {
5459 auto *var = maybe_get_backing_variable(chain: ops[0]);
5460
5461 // The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5462 // except that the underlying type changes how the data is interpreted.
5463 auto value_expr = to_expression(id: ops[2]);
5464 if (var)
5465 {
5466 auto &type = get<SPIRType>(id: var->basetype);
5467 auto narrowed_type = get<SPIRType>(id: type.image.type);
5468 narrowed_type.vecsize = image_format_to_components(fmt: type.image.format);
5469 value_expr = remap_swizzle(result_type: narrowed_type, input_components: expression_type(id: ops[2]).vecsize, expr: value_expr);
5470 }
5471
5472 statement(ts: to_non_uniform_aware_expression(id: ops[0]), ts: "[", ts: to_expression(id: ops[1]), ts: "] = ", ts&: value_expr, ts: ";");
5473 if (var && variable_storage_is_aliased(var: *var))
5474 flush_all_aliased_variables();
5475 break;
5476 }
5477
5478 case OpImageTexelPointer:
5479 {
5480 uint32_t result_type = ops[0];
5481 uint32_t id = ops[1];
5482
5483 auto expr = to_expression(id: ops[2]);
5484 expr += join(ts: "[", ts: to_expression(id: ops[3]), ts: "]");
5485 auto &e = set<SPIRExpression>(id, args&: expr, args&: result_type, args: true);
5486
5487 // When using the pointer, we need to know which variable it is actually loaded from.
5488 auto *var = maybe_get_backing_variable(chain: ops[2]);
5489 e.loaded_from = var ? var->self : ID(0);
5490 inherit_expression_dependencies(dst: id, source: ops[3]);
5491 break;
5492 }
5493
5494 case OpAtomicCompareExchange:
5495 case OpAtomicExchange:
5496 case OpAtomicISub:
5497 case OpAtomicSMin:
5498 case OpAtomicUMin:
5499 case OpAtomicSMax:
5500 case OpAtomicUMax:
5501 case OpAtomicAnd:
5502 case OpAtomicOr:
5503 case OpAtomicXor:
5504 case OpAtomicIAdd:
5505 case OpAtomicIIncrement:
5506 case OpAtomicIDecrement:
5507 case OpAtomicLoad:
5508 case OpAtomicStore:
5509 {
5510 emit_atomic(ops, length: instruction.length, op: opcode);
5511 break;
5512 }
5513
5514 case OpControlBarrier:
5515 case OpMemoryBarrier:
5516 {
5517 uint32_t memory;
5518 uint32_t semantics;
5519
5520 if (opcode == OpMemoryBarrier)
5521 {
5522 memory = evaluate_constant_u32(id: ops[0]);
5523 semantics = evaluate_constant_u32(id: ops[1]);
5524 }
5525 else
5526 {
5527 memory = evaluate_constant_u32(id: ops[1]);
5528 semantics = evaluate_constant_u32(id: ops[2]);
5529 }
5530
5531 if (memory == ScopeSubgroup)
5532 {
5533 // No Wave-barriers in HLSL.
5534 break;
5535 }
5536
5537 // We only care about these flags, acquire/release and friends are not relevant to GLSL.
5538 semantics = mask_relevant_memory_semantics(semantics);
5539
5540 if (opcode == OpMemoryBarrier)
5541 {
5542 // If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier
5543 // does what we need, so we avoid redundant barriers.
5544 const Instruction *next = get_next_instruction_in_block(instr: instruction);
5545 if (next && next->op == OpControlBarrier)
5546 {
5547 auto *next_ops = stream(instr: *next);
5548 uint32_t next_memory = evaluate_constant_u32(id: next_ops[1]);
5549 uint32_t next_semantics = evaluate_constant_u32(id: next_ops[2]);
5550 next_semantics = mask_relevant_memory_semantics(semantics: next_semantics);
5551
5552 // There is no "just execution barrier" in HLSL.
5553 // If there are no memory semantics for next instruction, we will imply group shared memory is synced.
5554 if (next_semantics == 0)
5555 next_semantics = MemorySemanticsWorkgroupMemoryMask;
5556
5557 bool memory_scope_covered = false;
5558 if (next_memory == memory)
5559 memory_scope_covered = true;
5560 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5561 {
5562 // If we only care about workgroup memory, either Device or Workgroup scope is fine,
5563 // scope does not have to match.
5564 if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5565 (memory == ScopeDevice || memory == ScopeWorkgroup))
5566 {
5567 memory_scope_covered = true;
5568 }
5569 }
5570 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5571 {
5572 // The control barrier has device scope, but the memory barrier just has workgroup scope.
5573 memory_scope_covered = true;
5574 }
5575
5576 // If we have the same memory scope, and all memory types are covered, we're good.
5577 if (memory_scope_covered && (semantics & next_semantics) == semantics)
5578 break;
5579 }
5580 }
5581
5582 // We are synchronizing some memory or syncing execution,
5583 // so we cannot forward any loads beyond the memory barrier.
5584 if (semantics || opcode == OpControlBarrier)
5585 {
5586 assert(current_emitting_block);
5587 flush_control_dependent_expressions(block: current_emitting_block->self);
5588 flush_all_active_variables();
5589 }
5590
5591 if (opcode == OpControlBarrier)
5592 {
5593 // We cannot emit just execution barrier, for no memory semantics pick the cheapest option.
5594 if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5595 statement(ts: "GroupMemoryBarrierWithGroupSync();");
5596 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5597 statement(ts: "DeviceMemoryBarrierWithGroupSync();");
5598 else
5599 statement(ts: "AllMemoryBarrierWithGroupSync();");
5600 }
5601 else
5602 {
5603 if (semantics == MemorySemanticsWorkgroupMemoryMask)
5604 statement(ts: "GroupMemoryBarrier();");
5605 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5606 statement(ts: "DeviceMemoryBarrier();");
5607 else
5608 statement(ts: "AllMemoryBarrier();");
5609 }
5610 break;
5611 }
5612
5613 case OpBitFieldInsert:
5614 {
5615 if (!requires_bitfield_insert)
5616 {
5617 requires_bitfield_insert = true;
5618 force_recompile();
5619 }
5620
5621 auto expr = join(ts: "spvBitfieldInsert(", ts: to_expression(id: ops[2]), ts: ", ", ts: to_expression(id: ops[3]), ts: ", ",
5622 ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ")");
5623
5624 bool forward =
5625 should_forward(id: ops[2]) && should_forward(id: ops[3]) && should_forward(id: ops[4]) && should_forward(id: ops[5]);
5626
5627 auto &restype = get<SPIRType>(id: ops[0]);
5628 expr = bitcast_expression(target_type: restype, expr_type: SPIRType::UInt, expr);
5629 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
5630 break;
5631 }
5632
5633 case OpBitFieldSExtract:
5634 case OpBitFieldUExtract:
5635 {
5636 if (!requires_bitfield_extract)
5637 {
5638 requires_bitfield_extract = true;
5639 force_recompile();
5640 }
5641
5642 if (opcode == OpBitFieldSExtract)
5643 HLSL_TFOP(spvBitfieldSExtract);
5644 else
5645 HLSL_TFOP(spvBitfieldUExtract);
5646 break;
5647 }
5648
5649 case OpBitCount:
5650 {
5651 auto basetype = expression_type(id: ops[2]).basetype;
5652 emit_unary_func_op_cast(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "countbits", input_type: basetype, expected_result_type: basetype);
5653 break;
5654 }
5655
5656 case OpBitReverse:
5657 HLSL_UFOP(reversebits);
5658 break;
5659
5660 case OpArrayLength:
5661 {
5662 auto *var = maybe_get_backing_variable(chain: ops[2]);
5663 if (!var)
5664 SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5665
5666 auto &type = get<SPIRType>(id: var->basetype);
5667 if (!has_decoration(id: type.self, decoration: DecorationBlock) && !has_decoration(id: type.self, decoration: DecorationBufferBlock))
5668 SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5669
5670 // This must be 32-bit uint, so we're good to go.
5671 emit_uninitialized_temporary_expression(type: ops[0], id: ops[1]);
5672 statement(ts: to_non_uniform_aware_expression(id: ops[2]), ts: ".GetDimensions(", ts: to_expression(id: ops[1]), ts: ");");
5673 uint32_t offset = type_struct_member_offset(type, index: ops[3]);
5674 uint32_t stride = type_struct_member_array_stride(type, index: ops[3]);
5675 statement(ts: to_expression(id: ops[1]), ts: " = (", ts: to_expression(id: ops[1]), ts: " - ", ts&: offset, ts: ") / ", ts&: stride, ts: ";");
5676 break;
5677 }
5678
5679 case OpIsHelperInvocationEXT:
5680 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
5681 SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
5682 // Helper lane state with demote is volatile by nature.
5683 // Do not forward this.
5684 emit_op(result_type: ops[0], result_id: ops[1], rhs: "IsHelperLane()", forward_rhs: false);
5685 break;
5686
5687 case OpBeginInvocationInterlockEXT:
5688 case OpEndInvocationInterlockEXT:
5689 if (hlsl_options.shader_model < 51)
5690 SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5691 break; // Nothing to do in the body
5692
5693 case OpRayQueryInitializeKHR:
5694 {
5695 flush_variable_declaration(id: ops[0]);
5696
5697 std::string ray_desc_name = get_unique_identifier();
5698 statement(ts: "RayDesc ", ts&: ray_desc_name, ts: " = {", ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[5]), ts: ", ",
5699 ts: to_expression(id: ops[6]), ts: ", ", ts: to_expression(id: ops[7]), ts: "};");
5700
5701 statement(ts: to_expression(id: ops[0]), ts: ".TraceRayInline(",
5702 ts: to_expression(id: ops[1]), ts: ", ", // acc structure
5703 ts: to_expression(id: ops[2]), ts: ", ", // ray flags
5704 ts: to_expression(id: ops[3]), ts: ", ", // mask
5705 ts&: ray_desc_name, ts: ");"); // ray
5706 break;
5707 }
5708 case OpRayQueryProceedKHR:
5709 {
5710 flush_variable_declaration(id: ops[0]);
5711 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".Proceed()"), forward_rhs: false);
5712 break;
5713 }
5714 case OpRayQueryTerminateKHR:
5715 {
5716 flush_variable_declaration(id: ops[0]);
5717 statement(ts: to_expression(id: ops[0]), ts: ".Abort();");
5718 break;
5719 }
5720 case OpRayQueryGenerateIntersectionKHR:
5721 {
5722 flush_variable_declaration(id: ops[0]);
5723 statement(ts: to_expression(id: ops[0]), ts: ".CommitProceduralPrimitiveHit(", ts: ops[1], ts: ");");
5724 break;
5725 }
5726 case OpRayQueryConfirmIntersectionKHR:
5727 {
5728 flush_variable_declaration(id: ops[0]);
5729 statement(ts: to_expression(id: ops[0]), ts: ".CommitNonOpaqueTriangleHit();");
5730 break;
5731 }
5732 case OpRayQueryGetIntersectionTypeKHR:
5733 {
5734 emit_rayquery_function(commited: ".CommittedStatus()", candidate: ".CandidateType()", ops);
5735 break;
5736 }
5737 case OpRayQueryGetIntersectionTKHR:
5738 {
5739 emit_rayquery_function(commited: ".CommittedRayT()", candidate: ".CandidateTriangleRayT()", ops);
5740 break;
5741 }
5742 case OpRayQueryGetIntersectionInstanceCustomIndexKHR:
5743 {
5744 emit_rayquery_function(commited: ".CommittedInstanceID()", candidate: ".CandidateInstanceID()", ops);
5745 break;
5746 }
5747 case OpRayQueryGetIntersectionInstanceIdKHR:
5748 {
5749 emit_rayquery_function(commited: ".CommittedInstanceIndex()", candidate: ".CandidateInstanceIndex()", ops);
5750 break;
5751 }
5752 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
5753 {
5754 emit_rayquery_function(commited: ".CommittedInstanceContributionToHitGroupIndex()",
5755 candidate: ".CandidateInstanceContributionToHitGroupIndex()", ops);
5756 break;
5757 }
5758 case OpRayQueryGetIntersectionGeometryIndexKHR:
5759 {
5760 emit_rayquery_function(commited: ".CommittedGeometryIndex()",
5761 candidate: ".CandidateGeometryIndex()", ops);
5762 break;
5763 }
5764 case OpRayQueryGetIntersectionPrimitiveIndexKHR:
5765 {
5766 emit_rayquery_function(commited: ".CommittedPrimitiveIndex()", candidate: ".CandidatePrimitiveIndex()", ops);
5767 break;
5768 }
5769 case OpRayQueryGetIntersectionBarycentricsKHR:
5770 {
5771 emit_rayquery_function(commited: ".CommittedTriangleBarycentrics()", candidate: ".CandidateTriangleBarycentrics()", ops);
5772 break;
5773 }
5774 case OpRayQueryGetIntersectionFrontFaceKHR:
5775 {
5776 emit_rayquery_function(commited: ".CommittedTriangleFrontFace()", candidate: ".CandidateTriangleFrontFace()", ops);
5777 break;
5778 }
5779 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
5780 {
5781 flush_variable_declaration(id: ops[0]);
5782 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".CandidateProceduralPrimitiveNonOpaque()"), forward_rhs: false);
5783 break;
5784 }
5785 case OpRayQueryGetIntersectionObjectRayDirectionKHR:
5786 {
5787 emit_rayquery_function(commited: ".CommittedObjectRayDirection()", candidate: ".CandidateObjectRayDirection()", ops);
5788 break;
5789 }
5790 case OpRayQueryGetIntersectionObjectRayOriginKHR:
5791 {
5792 flush_variable_declaration(id: ops[0]);
5793 emit_rayquery_function(commited: ".CommittedObjectRayOrigin()", candidate: ".CandidateObjectRayOrigin()", ops);
5794 break;
5795 }
5796 case OpRayQueryGetIntersectionObjectToWorldKHR:
5797 {
5798 emit_rayquery_function(commited: ".CommittedObjectToWorld4x3()", candidate: ".CandidateObjectToWorld4x3()", ops);
5799 break;
5800 }
5801 case OpRayQueryGetIntersectionWorldToObjectKHR:
5802 {
5803 emit_rayquery_function(commited: ".CommittedWorldToObject4x3()", candidate: ".CandidateWorldToObject4x3()", ops);
5804 break;
5805 }
5806 case OpRayQueryGetRayFlagsKHR:
5807 {
5808 flush_variable_declaration(id: ops[0]);
5809 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".RayFlags()"), forward_rhs: false);
5810 break;
5811 }
5812 case OpRayQueryGetRayTMinKHR:
5813 {
5814 flush_variable_declaration(id: ops[0]);
5815 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".RayTMin()"), forward_rhs: false);
5816 break;
5817 }
5818 case OpRayQueryGetWorldRayOriginKHR:
5819 {
5820 flush_variable_declaration(id: ops[0]);
5821 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".WorldRayOrigin()"), forward_rhs: false);
5822 break;
5823 }
5824 case OpRayQueryGetWorldRayDirectionKHR:
5825 {
5826 flush_variable_declaration(id: ops[0]);
5827 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".WorldRayDirection()"), forward_rhs: false);
5828 break;
5829 }
5830 default:
5831 CompilerGLSL::emit_instruction(instr: instruction);
5832 break;
5833 }
5834}
5835
5836void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5837{
5838 if (const auto *var = maybe_get_backing_variable(chain: var_id))
5839 var_id = var->self;
5840
5841 auto &type = expression_type(id: var_id);
5842 bool uav = type.image.sampled == 2;
5843 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id: var_id, decoration: DecorationNonWritable))
5844 uav = false;
5845
5846 uint32_t bit = 0;
5847 switch (type.image.dim)
5848 {
5849 case Dim1D:
5850 bit = type.image.arrayed ? Query1DArray : Query1D;
5851 break;
5852
5853 case Dim2D:
5854 if (type.image.ms)
5855 bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5856 else
5857 bit = type.image.arrayed ? Query2DArray : Query2D;
5858 break;
5859
5860 case Dim3D:
5861 bit = Query3D;
5862 break;
5863
5864 case DimCube:
5865 bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5866 break;
5867
5868 case DimBuffer:
5869 bit = QueryBuffer;
5870 break;
5871
5872 default:
5873 SPIRV_CROSS_THROW("Unsupported query type.");
5874 }
5875
5876 switch (get<SPIRType>(id: type.image.type).basetype)
5877 {
5878 case SPIRType::Float:
5879 bit += QueryTypeFloat;
5880 break;
5881
5882 case SPIRType::Int:
5883 bit += QueryTypeInt;
5884 break;
5885
5886 case SPIRType::UInt:
5887 bit += QueryTypeUInt;
5888 break;
5889
5890 default:
5891 SPIRV_CROSS_THROW("Unsupported query type.");
5892 }
5893
5894 auto norm_state = image_format_to_normalized_state(fmt: type.image.format);
5895 auto &variant = uav ? required_texture_size_variants
5896 .uav[uint32_t(norm_state)][image_format_to_components(fmt: type.image.format) - 1] :
5897 required_texture_size_variants.srv;
5898
5899 uint64_t mask = 1ull << bit;
5900 if ((variant & mask) == 0)
5901 {
5902 force_recompile();
5903 variant |= mask;
5904 }
5905}
5906
5907void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5908{
5909 root_constants_layout = std::move(layout);
5910}
5911
5912void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5913{
5914 remap_vertex_attributes.push_back(t: vertex_attributes);
5915}
5916
5917VariableID CompilerHLSL::remap_num_workgroups_builtin()
5918{
5919 update_active_builtins();
5920
5921 if (!active_input_builtins.get(bit: BuiltInNumWorkgroups))
5922 return 0;
5923
5924 // Create a new, fake UBO.
5925 uint32_t offset = ir.increase_bound_by(count: 4);
5926
5927 uint32_t uint_type_id = offset;
5928 uint32_t block_type_id = offset + 1;
5929 uint32_t block_pointer_type_id = offset + 2;
5930 uint32_t variable_id = offset + 3;
5931
5932 SPIRType uint_type;
5933 uint_type.basetype = SPIRType::UInt;
5934 uint_type.width = 32;
5935 uint_type.vecsize = 3;
5936 uint_type.columns = 1;
5937 set<SPIRType>(id: uint_type_id, args&: uint_type);
5938
5939 SPIRType block_type;
5940 block_type.basetype = SPIRType::Struct;
5941 block_type.member_types.push_back(t: uint_type_id);
5942 set<SPIRType>(id: block_type_id, args&: block_type);
5943 set_decoration(id: block_type_id, decoration: DecorationBlock);
5944 set_member_name(id: block_type_id, index: 0, name: "count");
5945 set_member_decoration(id: block_type_id, index: 0, decoration: DecorationOffset, argument: 0);
5946
5947 SPIRType block_pointer_type = block_type;
5948 block_pointer_type.pointer = true;
5949 block_pointer_type.storage = StorageClassUniform;
5950 block_pointer_type.parent_type = block_type_id;
5951 auto &ptr_type = set<SPIRType>(id: block_pointer_type_id, args&: block_pointer_type);
5952
5953 // Preserve self.
5954 ptr_type.self = block_type_id;
5955
5956 set<SPIRVariable>(id: variable_id, args&: block_pointer_type_id, args: StorageClassUniform);
5957 ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5958
5959 num_workgroups_builtin = variable_id;
5960 get_entry_point().interface_variables.push_back(t: num_workgroups_builtin);
5961 return variable_id;
5962}
5963
5964void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5965{
5966 resource_binding_flags = flags;
5967}
5968
5969void CompilerHLSL::validate_shader_model()
5970{
5971 // Check for nonuniform qualifier.
5972 // Instead of looping over all decorations to find this, just look at capabilities.
5973 for (auto &cap : ir.declared_capabilities)
5974 {
5975 switch (cap)
5976 {
5977 case CapabilityShaderNonUniformEXT:
5978 case CapabilityRuntimeDescriptorArrayEXT:
5979 if (hlsl_options.shader_model < 51)
5980 SPIRV_CROSS_THROW(
5981 "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5982 break;
5983
5984 case CapabilityVariablePointers:
5985 case CapabilityVariablePointersStorageBuffer:
5986 SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5987
5988 default:
5989 break;
5990 }
5991 }
5992
5993 if (ir.addressing_model != AddressingModelLogical)
5994 SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5995
5996 if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5997 SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5998}
5999
6000string CompilerHLSL::compile()
6001{
6002 ir.fixup_reserved_names();
6003
6004 // Do not deal with ES-isms like precision, older extensions and such.
6005 options.es = false;
6006 options.version = 450;
6007 options.vulkan_semantics = true;
6008 backend.float_literal_suffix = true;
6009 backend.double_literal_suffix = false;
6010 backend.long_long_literal_suffix = true;
6011 backend.uint32_t_literal_suffix = true;
6012 backend.int16_t_literal_suffix = "";
6013 backend.uint16_t_literal_suffix = "u";
6014 backend.basic_int_type = "int";
6015 backend.basic_uint_type = "uint";
6016 backend.demote_literal = "discard";
6017 backend.boolean_mix_function = "";
6018 backend.swizzle_is_function = false;
6019 backend.shared_is_implied = true;
6020 backend.unsized_array_supported = true;
6021 backend.explicit_struct_type = false;
6022 backend.use_initializer_list = true;
6023 backend.use_constructor_splatting = false;
6024 backend.can_swizzle_scalar = true;
6025 backend.can_declare_struct_inline = false;
6026 backend.can_declare_arrays_inline = false;
6027 backend.can_return_array = false;
6028 backend.nonuniform_qualifier = "NonUniformResourceIndex";
6029 backend.support_case_fallthrough = false;
6030
6031 // SM 4.1 does not support precise for some reason.
6032 backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
6033
6034 fixup_anonymous_struct_names();
6035 fixup_type_alias();
6036 reorder_type_alias();
6037 build_function_control_flow_graphs_and_analyze();
6038 validate_shader_model();
6039 update_active_builtins();
6040 analyze_image_and_sampler_usage();
6041 analyze_interlocked_resource_usage();
6042
6043 // Subpass input needs SV_Position.
6044 if (need_subpass_input)
6045 active_input_builtins.set(BuiltInFragCoord);
6046
6047 uint32_t pass_count = 0;
6048 do
6049 {
6050 reset(iteration_count: pass_count);
6051
6052 // Move constructor for this type is broken on GCC 4.9 ...
6053 buffer.reset();
6054
6055 emit_header();
6056 emit_resources();
6057
6058 emit_function(func&: get<SPIRFunction>(id: ir.default_entry_point), return_flags: Bitset());
6059 emit_hlsl_entry_point();
6060
6061 pass_count++;
6062 } while (is_forcing_recompilation());
6063
6064 // Entry point in HLSL is always main() for the time being.
6065 get_entry_point().name = "main";
6066
6067 return buffer.str();
6068}
6069
6070void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
6071{
6072 switch (block.hint)
6073 {
6074 case SPIRBlock::HintFlatten:
6075 statement(ts: "[flatten]");
6076 break;
6077 case SPIRBlock::HintDontFlatten:
6078 statement(ts: "[branch]");
6079 break;
6080 case SPIRBlock::HintUnroll:
6081 statement(ts: "[unroll]");
6082 break;
6083 case SPIRBlock::HintDontUnroll:
6084 statement(ts: "[loop]");
6085 break;
6086 default:
6087 break;
6088 }
6089}
6090
6091string CompilerHLSL::get_unique_identifier()
6092{
6093 return join(ts: "_", ts: unique_identifier_count++, ts: "ident");
6094}
6095
6096void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
6097{
6098 StageSetBinding tuple = { .model: binding.stage, .desc_set: binding.desc_set, .binding: binding.binding };
6099 resource_bindings[tuple] = { binding, false };
6100}
6101
6102bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
6103{
6104 StageSetBinding tuple = { .model: model, .desc_set: desc_set, .binding: binding };
6105 auto itr = resource_bindings.find(x: tuple);
6106 return itr != end(cont: resource_bindings) && itr->second.second;
6107}
6108
6109CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
6110{
6111 auto &rslt_type = get<SPIRType>(id: result_type);
6112 auto &expr_type = expression_type(id: op0);
6113
6114 if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
6115 expr_type.vecsize == 2)
6116 return BitcastType::TypePackUint2x32;
6117 else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
6118 expr_type.basetype == SPIRType::BaseType::UInt64)
6119 return BitcastType::TypeUnpackUint64;
6120
6121 return BitcastType::TypeNormal;
6122}
6123
6124bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
6125{
6126 if (hlsl_options.force_storage_buffer_as_uav)
6127 {
6128 return true;
6129 }
6130
6131 const uint32_t desc_set = get_decoration(id, decoration: spv::DecorationDescriptorSet);
6132 const uint32_t binding = get_decoration(id, decoration: spv::DecorationBinding);
6133
6134 return (force_uav_buffer_bindings.find(x: { .desc_set: desc_set, .binding: binding }) != force_uav_buffer_bindings.end());
6135}
6136
6137void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
6138{
6139 SetBindingPair pair = { .desc_set: desc_set, .binding: binding };
6140 force_uav_buffer_bindings.insert(x: pair);
6141}
6142
6143bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
6144{
6145 return (builtin == BuiltInSampleMask);
6146}
6147

source code of qtshadertools/src/3rdparty/SPIRV-Cross/spirv_hlsl.cpp