1/*
2 * Copyright 2016-2021 The Brenwill Workshop Ltd.
3 * SPDX-License-Identifier: Apache-2.0 OR MIT
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18/*
19 * At your option, you may choose to accept this material under either:
20 * 1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
21 * 2. The MIT License, found at <http://opensource.org/licenses/MIT>.
22 */
23
24#include "spirv_msl.hpp"
25#include "GLSL.std.450.h"
26
27#include <algorithm>
28#include <assert.h>
29#include <numeric>
30
31using namespace spv;
32using namespace SPIRV_CROSS_NAMESPACE;
33using namespace std;
34
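// Sentinel values: a location or component of ~0u means "unknown / not assigned".
// force_inline is the qualifier string prepended to generated helper functions that
// should always be inlined in the emitted MSL.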
35static const uint32_t k_unknown_location = ~0u;
36static const uint32_t k_unknown_component = ~0u;
37static const char *force_inline = "static inline __attribute__((always_inline))";
38
39CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
40 : CompilerGLSL(std::move(spirv_))
41{
42}
43
44CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
45 : CompilerGLSL(ir_, word_count)
46{
47}
48
49CompilerMSL::CompilerMSL(const ParsedIR &ir_)
50 : CompilerGLSL(ir_)
51{
52}
53
54CompilerMSL::CompilerMSL(ParsedIR &&ir_)
55 : CompilerGLSL(std::move(ir_))
56{
57}
58
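// Registers an interface-variable description for a stage input, keyed by (location, component).
// If the entry names a builtin, it is also recorded per builtin unless one was already registered.
// add_msl_shader_output() below mirrors this for stage outputs.
//
// Illustrative use (field values are examples only; `compiler` is any CompilerMSL instance):
//   MSLShaderInterfaceVariable si;
//   si.location = 0;
//   si.component = 0;
//   compiler.add_msl_shader_input(si);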
59void CompilerMSL::add_msl_shader_input(const MSLShaderInterfaceVariable &si)
60{
61 inputs_by_location[{.location: si.location, .component: si.component}] = si;
62 if (si.builtin != BuiltInMax && !inputs_by_builtin.count(x: si.builtin))
63 inputs_by_builtin[si.builtin] = si;
64}
65
66void CompilerMSL::add_msl_shader_output(const MSLShaderInterfaceVariable &so)
67{
68 outputs_by_location[{.location: so.location, .component: so.component}] = so;
69 if (so.builtin != BuiltInMax && !outputs_by_builtin.count(x: so.builtin))
70 outputs_by_builtin[so.builtin] = so;
71}
72
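// Registers an explicit MSL binding for the given (stage, descriptor set, binding) tuple. The bool
// stored alongside it tracks whether the binding was actually consumed during compilation (see
// is_msl_resource_binding_used()). When pad_argument_buffer_resources is enabled, a reverse lookup
// from MSL argument-buffer index to binding number is maintained as well, which is why every
// resource in the set must then supply a base type; the throw below fires otherwise.
//
// Illustrative use (values are examples only; `compiler` is any CompilerMSL instance):
//   MSLResourceBinding rb;
//   rb.stage = spv::ExecutionModelFragment;
//   rb.desc_set = 0;
//   rb.binding = 1;
//   rb.basetype = SPIRType::SampledImage;
//   rb.msl_texture = 0;
//   rb.msl_sampler = 0;
//   compiler.add_msl_resource_binding(rb);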
73void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
74{
75 StageSetBinding tuple = { .model: binding.stage, .desc_set: binding.desc_set, .binding: binding.binding };
76 resource_bindings[tuple] = { binding, false };
77
	// If we may need to pad argument buffer members so that their MSL argument buffer indices line up
	// positionally, also maintain a reverse lookup from argument buffer index to binding number.
80 if (msl_options.pad_argument_buffer_resources)
81 {
82 StageSetBinding arg_idx_tuple = { .model: binding.stage, .desc_set: binding.desc_set, .binding: k_unknown_component };
83
84#define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \
85 arg_idx_tuple.binding = binding.msl_##rez; \
86 resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding
87
88 switch (binding.basetype)
89 {
90 case SPIRType::Void:
91 case SPIRType::Boolean:
92 case SPIRType::SByte:
93 case SPIRType::UByte:
94 case SPIRType::Short:
95 case SPIRType::UShort:
96 case SPIRType::Int:
97 case SPIRType::UInt:
98 case SPIRType::Int64:
99 case SPIRType::UInt64:
100 case SPIRType::AtomicCounter:
101 case SPIRType::Half:
102 case SPIRType::Float:
103 case SPIRType::Double:
104 ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer);
105 break;
106 case SPIRType::Image:
107 ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
108 break;
109 case SPIRType::Sampler:
110 ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
111 break;
112 case SPIRType::SampledImage:
113 ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
114 ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
115 break;
116 default:
117 SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, "
118 "all descriptor set resources must be supplied with a base type by the app.");
119 }
120#undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP
121 }
122}
123
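// Marks the buffer at (desc_set, binding) as using a dynamic offset. `index` selects the entry in the
// spvDynamicOffsets helper buffer holding that offset; the variable-ID half of the stored pair starts
// at 0 and appears to be resolved later, once the matching buffer variable is seen
// (emit_entry_point_declarations() skips entries that never get one).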
124void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
125{
126 SetBindingPair pair = { .desc_set: desc_set, .binding: binding };
127 buffers_requiring_dynamic_offset[pair] = { index, 0 };
128}
129
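// Records (desc_set, binding) as an inline uniform block so later emission can special-case it;
// this function itself only adds the pair to a set.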
130void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
131{
132 SetBindingPair pair = { .desc_set: desc_set, .binding: binding };
133 inline_uniform_blocks.insert(x: pair);
134}
135
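// Excludes a descriptor set from argument-buffer translation by setting its bit in
// argument_buffer_discrete_mask. Sets at or beyond kMaxArgumentBuffers are silently ignored.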
136void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
137{
138 if (desc_set < kMaxArgumentBuffers)
139 argument_buffer_discrete_mask |= 1u << desc_set;
140}
141
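// Selects whether the argument buffer for desc_set is declared in the device address space
// (device_storage == true) or in the default address space, again ignoring sets at or beyond
// kMaxArgumentBuffers.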
142void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
143{
144 if (desc_set < kMaxArgumentBuffers)
145 {
146 if (device_storage)
147 argument_buffer_device_storage_mask |= 1u << desc_set;
148 else
149 argument_buffer_device_storage_mask &= ~(1u << desc_set);
150 }
151}
152
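// Reports whether a stage-input location was actually consumed by the shader, excluding locations that
// were only allocated internally as fallbacks; is_msl_shader_output_used() below is the output-side
// counterpart.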
153bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
154{
155 // Don't report internal location allocations to app.
156 return location_inputs_in_use.count(x: location) != 0 &&
157 location_inputs_in_use_fallback.count(x: location) == 0;
158}
159
160bool CompilerMSL::is_msl_shader_output_used(uint32_t location)
161{
162 // Don't report internal location allocations to app.
163 return location_outputs_in_use.count(x: location) != 0 &&
164 location_outputs_in_use_fallback.count(x: location) == 0;
165}
166
167uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const
168{
169 auto itr = builtin_to_automatic_input_location.find(x: builtin);
170 if (itr == builtin_to_automatic_input_location.end())
171 return k_unknown_location;
172 else
173 return itr->second;
174}
175
176uint32_t CompilerMSL::get_automatic_builtin_output_location(spv::BuiltIn builtin) const
177{
178 auto itr = builtin_to_automatic_output_location.find(x: builtin);
179 if (itr == builtin_to_automatic_output_location.end())
180 return k_unknown_location;
181 else
182 return itr->second;
183}
184
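// True if a binding registered via add_msl_resource_binding() was consumed while compiling;
// this is the bool half of the resource_bindings map entry.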
185bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
186{
187 StageSetBinding tuple = { .model: model, .desc_set: desc_set, .binding: binding };
188 auto itr = resource_bindings.find(x: tuple);
189 return itr != end(cont: resource_bindings) && itr->second.second;
190}
191
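// A variable counts as a runtime-sized descriptor array when its data type is a runtime array and no
// explicit element count was supplied through the resource bindings.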
192bool CompilerMSL::is_var_runtime_size_array(const SPIRVariable &var) const
193{
194 auto& type = get_variable_data_type(var);
195 return is_runtime_size_array(type) && get_resource_array_size(type, id: var.self) == 0;
196}
197
198// Returns the size of the array of resources used by the variable with the specified type and id.
199// The size is first retrieved from the type, but in the case of runtime array sizing,
200// the size is retrieved from the resource binding added using add_msl_resource_binding().
201uint32_t CompilerMSL::get_resource_array_size(const SPIRType &type, uint32_t id) const
202{
203 uint32_t array_size = to_array_size_literal(type);
204
205 if (id == 0)
206 return array_size;
207
208 // If we have argument buffers, we need to honor the ABI by using the correct array size
209 // from the layout. Only use shader declared size if we're not using argument buffers.
210 uint32_t desc_set = get_decoration(id, decoration: DecorationDescriptorSet);
211 if (!descriptor_set_is_argument_buffer(desc_set) && array_size)
212 return array_size;
213
214 StageSetBinding tuple = { .model: get_entry_point().model, .desc_set: desc_set,
215 .binding: get_decoration(id, decoration: DecorationBinding) };
216 auto itr = resource_bindings.find(x: tuple);
217 return itr != end(cont: resource_bindings) ? itr->second.first.count : array_size;
218}
219
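// The four accessors below return the automatically assigned MSL resource indices recorded as extended
// decorations during compilation. Primary is the main index; the secondary, tertiary and quaternary
// slots cover auxiliary resources a single SPIR-V resource can expand into in MSL (for example the
// sampler half of a combined image-sampler). The exact meaning depends on the resource type, so treat
// this as a rough guide rather than a contract.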
220uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
221{
222 return get_extended_decoration(id, decoration: SPIRVCrossDecorationResourceIndexPrimary);
223}
224
225uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
226{
227 return get_extended_decoration(id, decoration: SPIRVCrossDecorationResourceIndexSecondary);
228}
229
230uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
231{
232 return get_extended_decoration(id, decoration: SPIRVCrossDecorationResourceIndexTertiary);
233}
234
235uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
236{
237 return get_extended_decoration(id, decoration: SPIRVCrossDecorationResourceIndexQuaternary);
238}
239
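// Overrides the number of components emitted for the fragment output at `location`
// (typically used to widen an output so it matches the attachment's component count).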
240void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
241{
242 fragment_output_components[location] = components;
243}
244
245bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
246{
247 return (builtin == BuiltInSampleMask);
248}
249
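// Scans the entry point's interface and synthesizes whatever the MSL translation needs but the SPIR-V
// did not declare: gl_FragCoord / gl_Layer / gl_ViewIndex for subpass inputs without frame-buffer
// fetch, gl_SampleID, vertex/instance/base IDs for capture-to-buffer and multiview, tessellation and
// subgroup builtins, an output gl_Position for vertex-like stages, and internal helper buffers such as
// spvSwizzleConstants, spvBufferSizeConstants, spvViewMask, spvDynamicOffsets and spvDispatchBase.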
250void CompilerMSL::build_implicit_builtins()
251{
252 bool need_sample_pos = active_input_builtins.get(bit: BuiltInSamplePosition);
253 bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
254 !msl_options.vertex_for_tessellation;
255 bool need_tesc_params = is_tesc_shader();
256 bool need_tese_params = is_tese_shader() && msl_options.raw_buffer_tese_input;
257 bool need_subgroup_mask =
258 active_input_builtins.get(bit: BuiltInSubgroupEqMask) || active_input_builtins.get(bit: BuiltInSubgroupGeMask) ||
259 active_input_builtins.get(bit: BuiltInSubgroupGtMask) || active_input_builtins.get(bit: BuiltInSubgroupLeMask) ||
260 active_input_builtins.get(bit: BuiltInSubgroupLtMask);
261 bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(bit: BuiltInSubgroupGeMask) ||
262 active_input_builtins.get(bit: BuiltInSubgroupGtMask));
263 bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
264 msl_options.multiview_layered_rendering &&
265 (msl_options.multiview || active_input_builtins.get(bit: BuiltInViewIndex));
266 bool need_dispatch_base =
267 msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
268 (active_input_builtins.get(bit: BuiltInWorkgroupId) || active_input_builtins.get(bit: BuiltInGlobalInvocationId));
269 bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
270 bool need_vertex_base_params =
271 need_grid_params &&
272 (active_input_builtins.get(bit: BuiltInVertexId) || active_input_builtins.get(bit: BuiltInVertexIndex) ||
273 active_input_builtins.get(bit: BuiltInBaseVertex) || active_input_builtins.get(bit: BuiltInInstanceId) ||
274 active_input_builtins.get(bit: BuiltInInstanceIndex) || active_input_builtins.get(bit: BuiltInBaseInstance));
275 bool need_local_invocation_index = (msl_options.emulate_subgroups && active_input_builtins.get(bit: BuiltInSubgroupId)) || is_mesh_shader();
276 bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(bit: BuiltInNumSubgroups);
277 bool force_frag_depth_passthrough =
278 get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input &&
279 msl_options.enable_frag_depth_builtin && msl_options.input_attachment_is_ds_attachment;
280
281 if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
282 need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
283 needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || needs_helper_invocation ||
284 has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough || is_mesh_shader())
285 {
286 bool has_frag_coord = false;
287 bool has_sample_id = false;
288 bool has_vertex_idx = false;
289 bool has_base_vertex = false;
290 bool has_instance_idx = false;
291 bool has_base_instance = false;
292 bool has_invocation_id = false;
293 bool has_primitive_id = false;
294 bool has_subgroup_invocation_id = false;
295 bool has_subgroup_size = false;
296 bool has_view_idx = false;
297 bool has_layer = false;
298 bool has_helper_invocation = false;
299 bool has_local_invocation_index = false;
300 bool has_workgroup_size = false;
301 bool has_frag_depth = false;
302 uint32_t workgroup_id_type = 0;
303
304 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
305 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
306 return;
307 if (!interface_variable_exists_in_entry_point(id: var.self))
308 return;
309 if (!has_decoration(id: var.self, decoration: DecorationBuiltIn))
310 return;
311
312 BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
313
314 if (var.storage == StorageClassOutput)
315 {
316 if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask)
317 {
318 builtin_sample_mask_id = var.self;
319 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInSampleMask, id: var.self);
320 does_shader_write_sample_mask = true;
321 }
322
323 if (force_frag_depth_passthrough && builtin == BuiltInFragDepth)
324 {
325 builtin_frag_depth_id = var.self;
326 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInFragDepth, id: var.self);
327 has_frag_depth = true;
328 }
329 }
330
331 if (builtin == BuiltInPrimitivePointIndicesEXT ||
332 builtin == BuiltInPrimitiveLineIndicesEXT ||
333 builtin == BuiltInPrimitiveTriangleIndicesEXT)
334 {
335 builtin_mesh_primitive_indices_id = var.self;
336 }
337
338 if (var.storage != StorageClassInput)
339 return;
340
			// When not using Metal's native frame-buffer fetch API for subpass inputs, reading an input
			// attachment requires gl_FragCoord (and gl_Layer or gl_ViewIndex for layered/multiview cases).
342 if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
343 {
344 switch (builtin)
345 {
346 case BuiltInFragCoord:
347 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInFragCoord, id: var.self);
348 builtin_frag_coord_id = var.self;
349 has_frag_coord = true;
350 break;
351 case BuiltInLayer:
352 if (!msl_options.arrayed_subpass_input || msl_options.multiview)
353 break;
354 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInLayer, id: var.self);
355 builtin_layer_id = var.self;
356 has_layer = true;
357 break;
358 case BuiltInViewIndex:
359 if (!msl_options.multiview)
360 break;
361 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInViewIndex, id: var.self);
362 builtin_view_idx_id = var.self;
363 has_view_idx = true;
364 break;
365 default:
366 break;
367 }
368 }
369
370 if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
371 {
372 builtin_sample_id_id = var.self;
373 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSampleId, id: var.self);
374 has_sample_id = true;
375 }
376
377 if (need_vertex_params)
378 {
379 switch (builtin)
380 {
381 case BuiltInVertexIndex:
382 builtin_vertex_idx_id = var.self;
383 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInVertexIndex, id: var.self);
384 has_vertex_idx = true;
385 break;
386 case BuiltInBaseVertex:
387 builtin_base_vertex_id = var.self;
388 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInBaseVertex, id: var.self);
389 has_base_vertex = true;
390 break;
391 case BuiltInInstanceIndex:
392 builtin_instance_idx_id = var.self;
393 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInInstanceIndex, id: var.self);
394 has_instance_idx = true;
395 break;
396 case BuiltInBaseInstance:
397 builtin_base_instance_id = var.self;
398 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInBaseInstance, id: var.self);
399 has_base_instance = true;
400 break;
401 default:
402 break;
403 }
404 }
405
406 if (need_tesc_params && builtin == BuiltInInvocationId)
407 {
408 builtin_invocation_id_id = var.self;
409 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInInvocationId, id: var.self);
410 has_invocation_id = true;
411 }
412
413 if ((need_tesc_params || need_tese_params) && builtin == BuiltInPrimitiveId)
414 {
415 builtin_primitive_id_id = var.self;
416 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInPrimitiveId, id: var.self);
417 has_primitive_id = true;
418 }
419
420 if (need_tese_params && builtin == BuiltInTessLevelOuter)
421 {
422 tess_level_outer_var_id = var.self;
423 }
424
425 if (need_tese_params && builtin == BuiltInTessLevelInner)
426 {
427 tess_level_inner_var_id = var.self;
428 }
429
430 if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
431 {
432 builtin_subgroup_invocation_id_id = var.self;
433 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSubgroupLocalInvocationId, id: var.self);
434 has_subgroup_invocation_id = true;
435 }
436
437 if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
438 {
439 builtin_subgroup_size_id = var.self;
440 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSubgroupSize, id: var.self);
441 has_subgroup_size = true;
442 }
443
444 if (need_multiview)
445 {
446 switch (builtin)
447 {
448 case BuiltInInstanceIndex:
449 // The view index here is derived from the instance index.
450 builtin_instance_idx_id = var.self;
451 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInInstanceIndex, id: var.self);
452 has_instance_idx = true;
453 break;
454 case BuiltInBaseInstance:
455 // If a non-zero base instance is used, we need to adjust for it when calculating the view index.
456 builtin_base_instance_id = var.self;
457 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInBaseInstance, id: var.self);
458 has_base_instance = true;
459 break;
460 case BuiltInViewIndex:
461 builtin_view_idx_id = var.self;
462 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInViewIndex, id: var.self);
463 has_view_idx = true;
464 break;
465 default:
466 break;
467 }
468 }
469
470 if (needs_helper_invocation && builtin == BuiltInHelperInvocation)
471 {
472 builtin_helper_invocation_id = var.self;
473 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInHelperInvocation, id: var.self);
474 has_helper_invocation = true;
475 }
476
477 if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
478 {
479 builtin_local_invocation_index_id = var.self;
480 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInLocalInvocationIndex, id: var.self);
481 has_local_invocation_index = true;
482 }
483
484 if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
485 {
486 builtin_workgroup_size_id = var.self;
487 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInWorkgroupSize, id: var.self);
488 has_workgroup_size = true;
489 }
490
491 // The base workgroup needs to have the same type and vector size
492 // as the workgroup or invocation ID, so keep track of the type that
493 // was used.
494 if (need_dispatch_base && workgroup_id_type == 0 &&
495 (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
496 workgroup_id_type = var.basetype;
497 });
498
		// When not using Metal's native frame-buffer fetch API for subpass inputs, synthesize whichever of
		// gl_FragCoord, gl_Layer and gl_ViewIndex the shader needs but did not declare.
500 if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
501 (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
502 (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
503 {
504 if (!has_frag_coord)
505 {
506 uint32_t offset = ir.increase_bound_by(count: 3);
507 uint32_t type_id = offset;
508 uint32_t type_ptr_id = offset + 1;
509 uint32_t var_id = offset + 2;
510
511 // Create gl_FragCoord.
512 SPIRType vec4_type { OpTypeVector };
513 vec4_type.basetype = SPIRType::Float;
514 vec4_type.width = 32;
515 vec4_type.vecsize = 4;
516 set<SPIRType>(id: type_id, args&: vec4_type);
517
518 SPIRType vec4_type_ptr = vec4_type;
519 vec4_type_ptr.op = OpTypePointer;
520 vec4_type_ptr.pointer = true;
521 vec4_type_ptr.pointer_depth++;
522 vec4_type_ptr.parent_type = type_id;
523 vec4_type_ptr.storage = StorageClassInput;
524 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: vec4_type_ptr);
525 ptr_type.self = type_id;
526
527 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
528 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInFragCoord);
529 builtin_frag_coord_id = var_id;
530 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInFragCoord, id: var_id);
531 }
532
533 if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
534 {
535 uint32_t offset = ir.increase_bound_by(count: 2);
536 uint32_t type_ptr_id = offset;
537 uint32_t var_id = offset + 1;
538
539 // Create gl_Layer.
540 SPIRType uint_type_ptr = get_uint_type();
541 uint_type_ptr.op = OpTypePointer;
542 uint_type_ptr.pointer = true;
543 uint_type_ptr.pointer_depth++;
544 uint_type_ptr.parent_type = get_uint_type_id();
545 uint_type_ptr.storage = StorageClassInput;
546 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
547 ptr_type.self = get_uint_type_id();
548
549 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
550 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInLayer);
551 builtin_layer_id = var_id;
552 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInLayer, id: var_id);
553 }
554
555 if (!has_view_idx && msl_options.multiview)
556 {
557 uint32_t offset = ir.increase_bound_by(count: 2);
558 uint32_t type_ptr_id = offset;
559 uint32_t var_id = offset + 1;
560
561 // Create gl_ViewIndex.
562 SPIRType uint_type_ptr = get_uint_type();
563 uint_type_ptr.op = OpTypePointer;
564 uint_type_ptr.pointer = true;
565 uint_type_ptr.pointer_depth++;
566 uint_type_ptr.parent_type = get_uint_type_id();
567 uint_type_ptr.storage = StorageClassInput;
568 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
569 ptr_type.self = get_uint_type_id();
570
571 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
572 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInViewIndex);
573 builtin_view_idx_id = var_id;
574 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInViewIndex, id: var_id);
575 }
576 }
577
578 if (!has_sample_id && (need_sample_pos || needs_sample_id))
579 {
580 uint32_t offset = ir.increase_bound_by(count: 2);
581 uint32_t type_ptr_id = offset;
582 uint32_t var_id = offset + 1;
583
584 // Create gl_SampleID.
585 SPIRType uint_type_ptr = get_uint_type();
586 uint_type_ptr.op = OpTypePointer;
587 uint_type_ptr.pointer = true;
588 uint_type_ptr.pointer_depth++;
589 uint_type_ptr.parent_type = get_uint_type_id();
590 uint_type_ptr.storage = StorageClassInput;
591 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
592 ptr_type.self = get_uint_type_id();
593
594 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
595 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInSampleId);
596 builtin_sample_id_id = var_id;
597 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSampleId, id: var_id);
598 }
599
600 if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
601 (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
602 {
603 uint32_t type_ptr_id = ir.increase_bound_by(count: 1);
604
605 SPIRType uint_type_ptr = get_uint_type();
606 uint_type_ptr.op = OpTypePointer;
607 uint_type_ptr.pointer = true;
608 uint_type_ptr.pointer_depth++;
609 uint_type_ptr.parent_type = get_uint_type_id();
610 uint_type_ptr.storage = StorageClassInput;
611 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
612 ptr_type.self = get_uint_type_id();
613
614 if (need_vertex_params && !has_vertex_idx)
615 {
616 uint32_t var_id = ir.increase_bound_by(count: 1);
617
618 // Create gl_VertexIndex.
619 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
620 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInVertexIndex);
621 builtin_vertex_idx_id = var_id;
622 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInVertexIndex, id: var_id);
623 }
624
625 if (need_vertex_params && !has_base_vertex)
626 {
627 uint32_t var_id = ir.increase_bound_by(count: 1);
628
629 // Create gl_BaseVertex.
630 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
631 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInBaseVertex);
632 builtin_base_vertex_id = var_id;
633 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInBaseVertex, id: var_id);
634 }
635
636 if (!has_instance_idx) // Needed by both multiview and tessellation
637 {
638 uint32_t var_id = ir.increase_bound_by(count: 1);
639
640 // Create gl_InstanceIndex.
641 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
642 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInInstanceIndex);
643 builtin_instance_idx_id = var_id;
644 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInInstanceIndex, id: var_id);
645 }
646
647 if (!has_base_instance) // Needed by both multiview and tessellation
648 {
649 uint32_t var_id = ir.increase_bound_by(count: 1);
650
651 // Create gl_BaseInstance.
652 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
653 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInBaseInstance);
654 builtin_base_instance_id = var_id;
655 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInBaseInstance, id: var_id);
656 }
657
658 if (need_multiview)
659 {
660 // Multiview shaders are not allowed to write to gl_Layer, ostensibly because
661 // it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
662 // Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
663 // gl_Layer is an output in vertex-pipeline shaders.
664 uint32_t type_ptr_out_id = ir.increase_bound_by(count: 2);
665 SPIRType uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.op = OpTypePointer;
667 uint_type_ptr_out.pointer = true;
668 uint_type_ptr_out.pointer_depth++;
669 uint_type_ptr_out.parent_type = get_uint_type_id();
670 uint_type_ptr_out.storage = StorageClassOutput;
671 auto &ptr_out_type = set<SPIRType>(id: type_ptr_out_id, args&: uint_type_ptr_out);
672 ptr_out_type.self = get_uint_type_id();
673 uint32_t var_id = type_ptr_out_id + 1;
674 set<SPIRVariable>(id: var_id, args&: type_ptr_out_id, args: StorageClassOutput);
675 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInLayer);
676 builtin_layer_id = var_id;
677 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInLayer, id: var_id);
678 }
679
680 if (need_multiview && !has_view_idx)
681 {
682 uint32_t var_id = ir.increase_bound_by(count: 1);
683
684 // Create gl_ViewIndex.
685 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
686 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInViewIndex);
687 builtin_view_idx_id = var_id;
688 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInViewIndex, id: var_id);
689 }
690 }
691
692 if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
693 (need_tese_params && !has_primitive_id) || need_grid_params)
694 {
695 uint32_t type_ptr_id = ir.increase_bound_by(count: 1);
696
697 SPIRType uint_type_ptr = get_uint_type();
698 uint_type_ptr.op = OpTypePointer;
699 uint_type_ptr.pointer = true;
700 uint_type_ptr.pointer_depth++;
701 uint_type_ptr.parent_type = get_uint_type_id();
702 uint_type_ptr.storage = StorageClassInput;
703 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
704 ptr_type.self = get_uint_type_id();
705
706 if ((need_tesc_params && msl_options.multi_patch_workgroup) || need_grid_params)
707 {
708 uint32_t var_id = ir.increase_bound_by(count: 1);
709
710 // Create gl_GlobalInvocationID.
711 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
712 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInGlobalInvocationId);
713 builtin_invocation_id_id = var_id;
714 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInGlobalInvocationId, id: var_id);
715 }
716 else if (need_tesc_params && !has_invocation_id)
717 {
718 uint32_t var_id = ir.increase_bound_by(count: 1);
719
720 // Create gl_InvocationID.
721 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
722 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInInvocationId);
723 builtin_invocation_id_id = var_id;
724 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInInvocationId, id: var_id);
725 }
726
727 if ((need_tesc_params || need_tese_params) && !has_primitive_id)
728 {
729 uint32_t var_id = ir.increase_bound_by(count: 1);
730
731 // Create gl_PrimitiveID.
732 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
733 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInPrimitiveId);
734 builtin_primitive_id_id = var_id;
735 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInPrimitiveId, id: var_id);
736 }
737
738 if (need_grid_params)
739 {
740 uint32_t var_id = ir.increase_bound_by(count: 1);
741
742 set<SPIRVariable>(id: var_id, args: build_extended_vector_type(type_id: get_uint_type_id(), components: 3), args: StorageClassInput);
743 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInStageInputSize);
744 get_entry_point().interface_variables.push_back(t: var_id);
745 set_name(id: var_id, name: "spvStageInputSize");
746 builtin_stage_input_size_id = var_id;
747 }
748 }
749
750 if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
751 {
752 uint32_t offset = ir.increase_bound_by(count: 2);
753 uint32_t type_ptr_id = offset;
754 uint32_t var_id = offset + 1;
755
756 // Create gl_SubgroupInvocationID.
757 SPIRType uint_type_ptr = get_uint_type();
758 uint_type_ptr.op = OpTypePointer;
759 uint_type_ptr.pointer = true;
760 uint_type_ptr.pointer_depth++;
761 uint_type_ptr.parent_type = get_uint_type_id();
762 uint_type_ptr.storage = StorageClassInput;
763 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
764 ptr_type.self = get_uint_type_id();
765
766 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
767 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInSubgroupLocalInvocationId);
768 builtin_subgroup_invocation_id_id = var_id;
769 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSubgroupLocalInvocationId, id: var_id);
770 }
771
772 if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
773 {
774 uint32_t offset = ir.increase_bound_by(count: 2);
775 uint32_t type_ptr_id = offset;
776 uint32_t var_id = offset + 1;
777
778 // Create gl_SubgroupSize.
779 SPIRType uint_type_ptr = get_uint_type();
780 uint_type_ptr.op = OpTypePointer;
781 uint_type_ptr.pointer = true;
782 uint_type_ptr.pointer_depth++;
783 uint_type_ptr.parent_type = get_uint_type_id();
784 uint_type_ptr.storage = StorageClassInput;
785 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
786 ptr_type.self = get_uint_type_id();
787
788 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
789 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInSubgroupSize);
790 builtin_subgroup_size_id = var_id;
791 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInSubgroupSize, id: var_id);
792 }
793
794 if (need_dispatch_base || need_vertex_base_params)
795 {
796 if (workgroup_id_type == 0)
797 workgroup_id_type = build_extended_vector_type(type_id: get_uint_type_id(), components: 3);
798 uint32_t var_id;
799 if (msl_options.supports_msl_version(major: 1, minor: 2))
800 {
801 // If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
802 // to convey this information and save a buffer slot.
803 uint32_t offset = ir.increase_bound_by(count: 1);
804 var_id = offset;
805
806 set<SPIRVariable>(id: var_id, args&: workgroup_id_type, args: StorageClassInput);
807 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInDispatchBase);
808 get_entry_point().interface_variables.push_back(t: var_id);
809 }
810 else
811 {
812 // Otherwise, we need to fall back to a good ol' fashioned buffer.
813 uint32_t offset = ir.increase_bound_by(count: 2);
814 var_id = offset;
815 uint32_t type_id = offset + 1;
816
817 SPIRType var_type = get<SPIRType>(id: workgroup_id_type);
818 var_type.storage = StorageClassUniform;
819 set<SPIRType>(id: type_id, args&: var_type);
820
821 set<SPIRVariable>(id: var_id, args&: type_id, args: StorageClassUniform);
822 // This should never match anything.
823 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: ~(5u));
824 set_decoration(id: var_id, decoration: DecorationBinding, argument: msl_options.indirect_params_buffer_index);
825 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationResourceIndexPrimary,
826 value: msl_options.indirect_params_buffer_index);
827 }
828 set_name(id: var_id, name: "spvDispatchBase");
829 builtin_dispatch_base_id = var_id;
830 }
831
832 if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask)
833 {
834 uint32_t offset = ir.increase_bound_by(count: 2);
835 uint32_t var_id = offset + 1;
836
837 // Create gl_SampleMask.
838 SPIRType uint_type_ptr_out = get_uint_type();
839 uint_type_ptr_out.op = OpTypePointer;
840 uint_type_ptr_out.pointer = true;
841 uint_type_ptr_out.pointer_depth++;
842 uint_type_ptr_out.parent_type = get_uint_type_id();
843 uint_type_ptr_out.storage = StorageClassOutput;
844
845 auto &ptr_out_type = set<SPIRType>(id: offset, args&: uint_type_ptr_out);
846 ptr_out_type.self = get_uint_type_id();
847 set<SPIRVariable>(id: var_id, args&: offset, args: StorageClassOutput);
848 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInSampleMask);
849 builtin_sample_mask_id = var_id;
850 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInSampleMask, id: var_id);
851 }
852
853 if (!has_helper_invocation && needs_helper_invocation)
854 {
855 uint32_t offset = ir.increase_bound_by(count: 3);
856 uint32_t type_id = offset;
857 uint32_t type_ptr_id = offset + 1;
858 uint32_t var_id = offset + 2;
859
860 // Create gl_HelperInvocation.
861 SPIRType bool_type { OpTypeBool };
862 bool_type.basetype = SPIRType::Boolean;
863 bool_type.width = 8;
864 bool_type.vecsize = 1;
865 set<SPIRType>(id: type_id, args&: bool_type);
866
867 SPIRType bool_type_ptr_in = bool_type;
868 bool_type_ptr_in.op = spv::OpTypePointer;
869 bool_type_ptr_in.pointer = true;
870 bool_type_ptr_in.pointer_depth++;
871 bool_type_ptr_in.parent_type = type_id;
872 bool_type_ptr_in.storage = StorageClassInput;
873
874 auto &ptr_in_type = set<SPIRType>(id: type_ptr_id, args&: bool_type_ptr_in);
875 ptr_in_type.self = type_id;
876 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
877 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInHelperInvocation);
878 builtin_helper_invocation_id = var_id;
879 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInHelperInvocation, id: var_id);
880 }
881
882 if (need_local_invocation_index && !has_local_invocation_index)
883 {
884 uint32_t offset = ir.increase_bound_by(count: 2);
885 uint32_t type_ptr_id = offset;
886 uint32_t var_id = offset + 1;
887
888 // Create gl_LocalInvocationIndex.
889 SPIRType uint_type_ptr = get_uint_type();
890 uint_type_ptr.op = OpTypePointer;
891 uint_type_ptr.pointer = true;
892 uint_type_ptr.pointer_depth++;
893 uint_type_ptr.parent_type = get_uint_type_id();
894 uint_type_ptr.storage = StorageClassInput;
895
896 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
897 ptr_type.self = get_uint_type_id();
898 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
899 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInLocalInvocationIndex);
900 builtin_local_invocation_index_id = var_id;
901 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInLocalInvocationIndex, id: var_id);
902 }
903
904 if (need_workgroup_size && !has_workgroup_size)
905 {
906 uint32_t offset = ir.increase_bound_by(count: 2);
907 uint32_t type_ptr_id = offset;
908 uint32_t var_id = offset + 1;
909
910 // Create gl_WorkgroupSize.
911 uint32_t type_id = build_extended_vector_type(type_id: get_uint_type_id(), components: 3);
912 SPIRType uint_type_ptr = get<SPIRType>(id: type_id);
913 uint_type_ptr.op = OpTypePointer;
914 uint_type_ptr.pointer = true;
915 uint_type_ptr.pointer_depth++;
916 uint_type_ptr.parent_type = type_id;
917 uint_type_ptr.storage = StorageClassInput;
918
919 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
920 ptr_type.self = type_id;
921 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassInput);
922 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInWorkgroupSize);
923 builtin_workgroup_size_id = var_id;
924 mark_implicit_builtin(storage: StorageClassInput, builtin: BuiltInWorkgroupSize, id: var_id);
925 }
926
927 if (!has_frag_depth && force_frag_depth_passthrough)
928 {
929 uint32_t offset = ir.increase_bound_by(count: 3);
930 uint32_t type_id = offset;
931 uint32_t type_ptr_id = offset + 1;
932 uint32_t var_id = offset + 2;
933
934 // Create gl_FragDepth
935 SPIRType float_type { OpTypeFloat };
936 float_type.basetype = SPIRType::Float;
937 float_type.width = 32;
938 float_type.vecsize = 1;
939 set<SPIRType>(id: type_id, args&: float_type);
940
941 SPIRType float_type_ptr_in = float_type;
942 float_type_ptr_in.op = spv::OpTypePointer;
943 float_type_ptr_in.pointer = true;
944 float_type_ptr_in.pointer_depth++;
945 float_type_ptr_in.parent_type = type_id;
946 float_type_ptr_in.storage = StorageClassOutput;
947
948 auto &ptr_in_type = set<SPIRType>(id: type_ptr_id, args&: float_type_ptr_in);
949 ptr_in_type.self = type_id;
950 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassOutput);
951 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInFragDepth);
952 builtin_frag_depth_id = var_id;
953 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInFragDepth, id: var_id);
954 active_output_builtins.set(BuiltInFragDepth);
955 }
956 }
957
958 if (needs_swizzle_buffer_def)
959 {
960 uint32_t var_id = build_constant_uint_array_pointer();
961 set_name(id: var_id, name: "spvSwizzleConstants");
962 // This should never match anything.
963 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: kSwizzleBufferBinding);
964 set_decoration(id: var_id, decoration: DecorationBinding, argument: msl_options.swizzle_buffer_index);
965 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationResourceIndexPrimary, value: msl_options.swizzle_buffer_index);
966 swizzle_buffer_id = var_id;
967 }
968
969 if (needs_buffer_size_buffer())
970 {
971 uint32_t var_id = build_constant_uint_array_pointer();
972 set_name(id: var_id, name: "spvBufferSizeConstants");
973 // This should never match anything.
974 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: kBufferSizeBufferBinding);
975 set_decoration(id: var_id, decoration: DecorationBinding, argument: msl_options.buffer_size_buffer_index);
976 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationResourceIndexPrimary, value: msl_options.buffer_size_buffer_index);
977 buffer_size_buffer_id = var_id;
978 }
979
980 if (needs_view_mask_buffer())
981 {
982 uint32_t var_id = build_constant_uint_array_pointer();
983 set_name(id: var_id, name: "spvViewMask");
984 // This should never match anything.
985 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: ~(4u));
986 set_decoration(id: var_id, decoration: DecorationBinding, argument: msl_options.view_mask_buffer_index);
987 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationResourceIndexPrimary, value: msl_options.view_mask_buffer_index);
988 view_mask_buffer_id = var_id;
989 }
990
991 if (!buffers_requiring_dynamic_offset.empty())
992 {
993 uint32_t var_id = build_constant_uint_array_pointer();
994 set_name(id: var_id, name: "spvDynamicOffsets");
995 // This should never match anything.
996 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: ~(5u));
997 set_decoration(id: var_id, decoration: DecorationBinding, argument: msl_options.dynamic_offsets_buffer_index);
998 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationResourceIndexPrimary,
999 value: msl_options.dynamic_offsets_buffer_index);
1000 dynamic_offsets_buffer_id = var_id;
1001 }
1002
1003 // If we're returning a struct from a vertex-like entry point, we must return a position attribute.
1004 bool need_position = (get_execution_model() == ExecutionModelVertex || is_tese_shader()) &&
1005 !capture_output_to_buffer && !get_is_rasterization_disabled() &&
1006 !active_output_builtins.get(bit: BuiltInPosition);
1007
1008 if (need_position)
1009 {
1010 // If we can get away with returning void from entry point, we don't need to care.
1011 // If there is at least one other stage output, we need to return [[position]],
1012 // so we need to create one if it doesn't appear in the SPIR-V. Before adding the
1013 // implicit variable, check if it actually exists already, but just has not been used
1014 // or initialized, and if so, mark it as active, and do not create the implicit variable.
1015 bool has_output = false;
1016 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1017 if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(id: var.self))
1018 {
1019 has_output = true;
1020
1021 // Check if the var is the Position builtin
1022 if (has_decoration(id: var.self, decoration: DecorationBuiltIn) && get_decoration(id: var.self, decoration: DecorationBuiltIn) == BuiltInPosition)
1023 active_output_builtins.set(BuiltInPosition);
1024
				// If the var is a struct, check whether any member is the Position builtin
1026 auto &var_type = get_variable_element_type(var);
1027 if (var_type.basetype == SPIRType::Struct)
1028 {
1029 auto mbr_cnt = var_type.member_types.size();
1030 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1031 {
1032 auto builtin = BuiltInMax;
1033 bool is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
1034 if (is_builtin && builtin == BuiltInPosition)
1035 active_output_builtins.set(BuiltInPosition);
1036 }
1037 }
1038 }
1039 });
1040 need_position = has_output && !active_output_builtins.get(bit: BuiltInPosition);
1041 }
1042
1043 if (need_position)
1044 {
1045 uint32_t offset = ir.increase_bound_by(count: 3);
1046 uint32_t type_id = offset;
1047 uint32_t type_ptr_id = offset + 1;
1048 uint32_t var_id = offset + 2;
1049
1050 // Create gl_Position.
1051 SPIRType vec4_type { OpTypeVector };
1052 vec4_type.basetype = SPIRType::Float;
1053 vec4_type.width = 32;
1054 vec4_type.vecsize = 4;
1055 set<SPIRType>(id: type_id, args&: vec4_type);
1056
1057 SPIRType vec4_type_ptr = vec4_type;
1058 vec4_type_ptr.op = OpTypePointer;
1059 vec4_type_ptr.pointer = true;
1060 vec4_type_ptr.pointer_depth++;
1061 vec4_type_ptr.parent_type = type_id;
1062 vec4_type_ptr.storage = StorageClassOutput;
1063 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: vec4_type_ptr);
1064 ptr_type.self = type_id;
1065
1066 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassOutput);
1067 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: BuiltInPosition);
1068 mark_implicit_builtin(storage: StorageClassOutput, builtin: BuiltInPosition, id: var_id);
1069 }
1070
1071 if (is_mesh_shader())
1072 {
1073 uint32_t offset = ir.increase_bound_by(count: 2);
1074 uint32_t type_ptr_id = offset;
1075 uint32_t var_id = offset + 1;
1076
1077 // Create variable to store meshlet size.
1078 uint32_t type_id = build_extended_vector_type(type_id: get_uint_type_id(), components: 2);
1079 SPIRType uint_type_ptr = get<SPIRType>(id: type_id);
1080 uint_type_ptr.op = OpTypePointer;
1081 uint_type_ptr.pointer = true;
1082 uint_type_ptr.pointer_depth++;
1083 uint_type_ptr.parent_type = type_id;
1084 uint_type_ptr.storage = StorageClassWorkgroup;
1085
1086 auto &ptr_type = set<SPIRType>(id: type_ptr_id, args&: uint_type_ptr);
1087 ptr_type.self = type_id;
1088 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassWorkgroup);
1089 set_name(id: var_id, name: "spvMeshSizes");
1090 builtin_mesh_sizes_id = var_id;
1091 }
1092
1093 if (get_execution_model() == spv::ExecutionModelTaskEXT)
1094 {
1095 uint32_t offset = ir.increase_bound_by(count: 3);
1096 uint32_t type_id = offset;
1097 uint32_t type_ptr_id = offset + 1;
1098 uint32_t var_id = offset + 2;
1099
1100 SPIRType mesh_grid_type { OpTypeStruct };
1101 mesh_grid_type.basetype = SPIRType::MeshGridProperties;
1102 set<SPIRType>(id: type_id, args&: mesh_grid_type);
1103
1104 SPIRType mesh_grid_type_ptr = mesh_grid_type;
1105 mesh_grid_type_ptr.op = spv::OpTypePointer;
1106 mesh_grid_type_ptr.pointer = true;
1107 mesh_grid_type_ptr.pointer_depth++;
1108 mesh_grid_type_ptr.parent_type = type_id;
1109 mesh_grid_type_ptr.storage = StorageClassOutput;
1110
1111 auto &ptr_in_type = set<SPIRType>(id: type_ptr_id, args&: mesh_grid_type_ptr);
1112 ptr_in_type.self = type_id;
1113 set<SPIRVariable>(id: var_id, args&: type_ptr_id, args: StorageClassOutput);
1114 set_name(id: var_id, name: "spvMgp");
1115 builtin_task_grid_id = var_id;
1116 }
1117}
1118
1119// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
1120// If not, it marks it as active and forces a recompilation.
1121// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
1122void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
1123{
1124 Bitset *active_builtins = nullptr;
1125 switch (storage)
1126 {
1127 case StorageClassInput:
1128 active_builtins = &active_input_builtins;
1129 break;
1130
1131 case StorageClassOutput:
1132 active_builtins = &active_output_builtins;
1133 break;
1134
1135 default:
1136 break;
1137 }
1138
1139 // At this point, the specified builtin variable must have already been declared in the entry point.
1140 // If not, mark as active and force recompile.
1141 if (active_builtins != nullptr && !active_builtins->get(bit: builtin))
1142 {
1143 active_builtins->set(builtin);
1144 force_recompile();
1145 }
1146}
1147
1148void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
1149{
1150 Bitset *active_builtins = nullptr;
1151 switch (storage)
1152 {
1153 case StorageClassInput:
1154 active_builtins = &active_input_builtins;
1155 break;
1156
1157 case StorageClassOutput:
1158 active_builtins = &active_output_builtins;
1159 break;
1160
1161 default:
1162 break;
1163 }
1164
1165 assert(active_builtins != nullptr);
1166 active_builtins->set(builtin);
1167
1168 auto &var = get_entry_point().interface_variables;
1169 if (find(first: begin(cont&: var), last: end(cont&: var), val: VariableID(id)) == end(cont&: var))
1170 var.push_back(t: id);
1171}
1172
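// Builds the type chain and variable for a pointer to an array of 32-bit uints (ArrayStride 4),
// used by the internal helper buffers created above (swizzle constants, buffer sizes, view mask,
// dynamic offsets). Returns the ID of the new variable.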
1173uint32_t CompilerMSL::build_constant_uint_array_pointer()
1174{
1175 uint32_t offset = ir.increase_bound_by(count: 3);
1176 uint32_t type_ptr_id = offset;
1177 uint32_t type_ptr_ptr_id = offset + 1;
1178 uint32_t var_id = offset + 2;
1179
1180 // Create a buffer to hold extra data, including the swizzle constants.
1181 SPIRType uint_type_pointer = get_uint_type();
1182 uint_type_pointer.op = OpTypePointer;
1183 uint_type_pointer.pointer = true;
1184 uint_type_pointer.pointer_depth++;
1185 uint_type_pointer.parent_type = get_uint_type_id();
1186 uint_type_pointer.storage = StorageClassUniform;
1187 set<SPIRType>(id: type_ptr_id, args&: uint_type_pointer);
1188 set_decoration(id: type_ptr_id, decoration: DecorationArrayStride, argument: 4);
1189
1190 SPIRType uint_type_pointer2 = uint_type_pointer;
1191 uint_type_pointer2.pointer_depth++;
1192 uint_type_pointer2.parent_type = type_ptr_id;
1193 set<SPIRType>(id: type_ptr_ptr_id, args&: uint_type_pointer2);
1194
1195 set<SPIRVariable>(id: var_id, args&: type_ptr_ptr_id, args: StorageClassUniformConstant);
1196 return var_id;
1197}
1198
1199static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
1200{
1201 switch (addr)
1202 {
1203 case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
1204 return join(ts&: prefix, ts: "address::clamp_to_edge");
1205 case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
1206 return join(ts&: prefix, ts: "address::clamp_to_zero");
1207 case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
1208 return join(ts&: prefix, ts: "address::clamp_to_border");
1209 case MSL_SAMPLER_ADDRESS_REPEAT:
1210 return join(ts&: prefix, ts: "address::repeat");
1211 case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
1212 return join(ts&: prefix, ts: "address::mirrored_repeat");
1213 default:
1214 SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
1215 }
1216}
1217
1218SPIRType &CompilerMSL::get_stage_in_struct_type()
1219{
1220 auto &si_var = get<SPIRVariable>(id: stage_in_var_id);
1221 return get_variable_data_type(var: si_var);
1222}
1223
1224SPIRType &CompilerMSL::get_stage_out_struct_type()
1225{
1226 auto &so_var = get<SPIRVariable>(id: stage_out_var_id);
1227 return get_variable_data_type(var: so_var);
1228}
1229
1230SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
1231{
1232 auto &si_var = get<SPIRVariable>(id: patch_stage_in_var_id);
1233 return get_variable_data_type(var: si_var);
1234}
1235
1236SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
1237{
1238 auto &so_var = get<SPIRVariable>(id: patch_stage_out_var_id);
1239 return get_variable_data_type(var: so_var);
1240}
1241
1242std::string CompilerMSL::get_tess_factor_struct_name()
1243{
1244 if (is_tessellating_triangles())
1245 return "MTLTriangleTessellationFactorsHalf";
1246 return "MTLQuadTessellationFactorsHalf";
1247}
1248
1249SPIRType &CompilerMSL::get_uint_type()
1250{
1251 return get<SPIRType>(id: get_uint_type_id());
1252}
1253
1254uint32_t CompilerMSL::get_uint_type_id()
1255{
1256 if (uint_type_id != 0)
1257 return uint_type_id;
1258
1259 uint_type_id = ir.increase_bound_by(count: 1);
1260
1261 SPIRType type { OpTypeInt };
1262 type.basetype = SPIRType::UInt;
1263 type.width = 32;
1264 set<SPIRType>(id: uint_type_id, args&: type);
1265 return uint_type_id;
1266}
1267
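// Emits declarations that must appear at the top of the entry point body: constexpr samplers,
// aliases for buffers with dynamic offsets, spvDescriptorArray wrappers for runtime-sized descriptor
// arrays, and plain arrays gathering individually passed buffer arguments.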
1268void CompilerMSL::emit_entry_point_declarations()
1269{
1270 // FIXME: Get test coverage here ...
1271 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
1272 declare_complex_constant_arrays();
1273
1274 // Emit constexpr samplers here.
1275 for (auto &samp : constexpr_samplers_by_id)
1276 {
1277 auto &var = get<SPIRVariable>(id: samp.first);
1278 auto &type = get<SPIRType>(id: var.basetype);
1279 if (type.basetype == SPIRType::Sampler)
1280 add_resource_name(id: samp.first);
1281
1282 SmallVector<string> args;
1283 auto &s = samp.second;
1284
1285 if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
1286 args.push_back(t: "coord::pixel");
1287
1288 if (s.min_filter == s.mag_filter)
1289 {
1290 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1291 args.push_back(t: "filter::linear");
1292 }
1293 else
1294 {
1295 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1296 args.push_back(t: "min_filter::linear");
1297 if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
1298 args.push_back(t: "mag_filter::linear");
1299 }
1300
1301 switch (s.mip_filter)
1302 {
1303 case MSL_SAMPLER_MIP_FILTER_NONE:
1304 // Default
1305 break;
1306 case MSL_SAMPLER_MIP_FILTER_NEAREST:
1307 args.push_back(t: "mip_filter::nearest");
1308 break;
1309 case MSL_SAMPLER_MIP_FILTER_LINEAR:
1310 args.push_back(t: "mip_filter::linear");
1311 break;
1312 default:
1313 SPIRV_CROSS_THROW("Invalid mip filter.");
1314 }
1315
1316 if (s.s_address == s.t_address && s.s_address == s.r_address)
1317 {
1318 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1319 args.push_back(t: create_sampler_address(prefix: "", addr: s.s_address));
1320 }
1321 else
1322 {
1323 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1324 args.push_back(t: create_sampler_address(prefix: "s_", addr: s.s_address));
1325 if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1326 args.push_back(t: create_sampler_address(prefix: "t_", addr: s.t_address));
1327 if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1328 args.push_back(t: create_sampler_address(prefix: "r_", addr: s.r_address));
1329 }
1330
1331 if (s.compare_enable)
1332 {
1333 switch (s.compare_func)
1334 {
1335 case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
1336 args.push_back(t: "compare_func::always");
1337 break;
1338 case MSL_SAMPLER_COMPARE_FUNC_NEVER:
1339 args.push_back(t: "compare_func::never");
1340 break;
1341 case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
1342 args.push_back(t: "compare_func::equal");
1343 break;
1344 case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
1345 args.push_back(t: "compare_func::not_equal");
1346 break;
1347 case MSL_SAMPLER_COMPARE_FUNC_LESS:
1348 args.push_back(t: "compare_func::less");
1349 break;
1350 case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
1351 args.push_back(t: "compare_func::less_equal");
1352 break;
1353 case MSL_SAMPLER_COMPARE_FUNC_GREATER:
1354 args.push_back(t: "compare_func::greater");
1355 break;
1356 case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
1357 args.push_back(t: "compare_func::greater_equal");
1358 break;
1359 default:
1360 SPIRV_CROSS_THROW("Invalid sampler compare function.");
1361 }
1362 }
1363
1364 if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
1365 s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
1366 {
1367 switch (s.border_color)
1368 {
1369 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
1370 args.push_back(t: "border_color::opaque_black");
1371 break;
1372 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
1373 args.push_back(t: "border_color::opaque_white");
1374 break;
1375 case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
1376 args.push_back(t: "border_color::transparent_black");
1377 break;
1378 default:
1379 SPIRV_CROSS_THROW("Invalid sampler border color.");
1380 }
1381 }
1382
1383 if (s.anisotropy_enable)
1384 args.push_back(t: join(ts: "max_anisotropy(", ts&: s.max_anisotropy, ts: ")"));
1385 if (s.lod_clamp_enable)
1386 {
1387 args.push_back(t: join(ts: "lod_clamp(", ts: format_float(value: s.lod_clamp_min), ts: ", ", ts: format_float(value: s.lod_clamp_max), ts: ")"));
1388 }
1389
1390 // If we would emit no arguments, then omit the parentheses entirely. Otherwise,
1391 // we'll wind up with a "most vexing parse" situation.
1392 if (args.empty())
1393 statement(ts: "constexpr sampler ",
1394 ts: type.basetype == SPIRType::SampledImage ? to_sampler_expression(id: samp.first) : to_name(id: samp.first),
1395 ts: ";");
1396 else
1397 statement(ts: "constexpr sampler ",
1398 ts: type.basetype == SPIRType::SampledImage ? to_sampler_expression(id: samp.first) : to_name(id: samp.first),
1399 ts: "(", ts: merge(list: args), ts: ");");
1400 }
1401
1402 // Emit dynamic buffers here.
1403 for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
1404 {
1405 if (!dynamic_buffer.second.second)
1406 {
1407 // Could happen if no buffer was used at requested binding point.
1408 continue;
1409 }
1410
1411 const auto &var = get<SPIRVariable>(id: dynamic_buffer.second.second);
1412 uint32_t var_id = var.self;
1413 const auto &type = get_variable_data_type(var);
1414 string name = to_name(id: var.self);
1415 uint32_t desc_set = get_decoration(id: var.self, decoration: DecorationDescriptorSet);
1416 uint32_t arg_id = argument_buffer_ids[desc_set];
1417 uint32_t base_index = dynamic_buffer.second.first;
1418
1419 if (is_array(type))
1420 {
1421 is_using_builtin_array = true;
1422 statement(ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ", ts: to_restrict(id: var_id, space: true), ts&: name,
1423 ts: type_to_array_glsl(type, variable_id: var_id), ts: " =");
1424
1425 uint32_t array_size = get_resource_array_size(type, id: var_id);
1426 if (array_size == 0)
1427 SPIRV_CROSS_THROW("Size of runtime array with dynamic offset could not be determined from resource bindings.");
1428
1429 begin_scope();
1430
1431 for (uint32_t i = 0; i < array_size; i++)
1432 {
1433 statement(ts: "(", ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ",
1434 ts: to_restrict(id: var_id, space: false), ts: ")((", ts: get_argument_address_space(argument: var), ts: " char* ",
1435 ts: to_restrict(id: var_id, space: false), ts: ")", ts: to_name(id: arg_id), ts: ".", ts: ensure_valid_name(name, pfx: "m"),
1436 ts: "[", ts&: i, ts: "]", ts: " + ", ts: to_name(id: dynamic_offsets_buffer_id), ts: "[", ts: base_index + i, ts: "]),");
1437 }
1438
1439 end_scope_decl();
1440 statement_no_indent(ts: "");
1441 is_using_builtin_array = false;
1442 }
1443 else
1444 {
1445 statement(ts: get_argument_address_space(argument: var), ts: " auto& ", ts: to_restrict(id: var_id, space: true), ts&: name, ts: " = *(",
1446 ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ", ts: to_restrict(id: var_id, space: false), ts: ")((",
1447 ts: get_argument_address_space(argument: var), ts: " char* ", ts: to_restrict(id: var_id, space: false), ts: ")", ts: to_name(id: arg_id), ts: ".",
1448 ts: ensure_valid_name(name, pfx: "m"), ts: " + ", ts: to_name(id: dynamic_offsets_buffer_id), ts: "[", ts&: base_index, ts: "]);");
1449 }
1450 }
1451
1452 bool has_runtime_array_declaration = false;
1453 for (SPIRVariable *arg : entry_point_bindings)
1454 {
1455 const auto &var = *arg;
1456 const auto &type = get_variable_data_type(var);
1457 const auto &buffer_type = get_variable_element_type(var);
1458 const string name = to_name(id: var.self);
1459
1460 if (is_var_runtime_size_array(var))
1461 {
1462 if (msl_options.argument_buffers_tier < Options::ArgumentBuffersTier::Tier2)
1463 {
1464 SPIRV_CROSS_THROW("Unsized array of descriptors requires argument buffer tier 2");
1465 }
1466
1467 string resource_name;
1468 if (descriptor_set_is_argument_buffer(desc_set: get_decoration(id: var.self, decoration: DecorationDescriptorSet)))
1469 resource_name = ir.meta[var.self].decoration.qualified_alias;
1470 else
1471 resource_name = name + "_";
1472
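			// Runtime-sized descriptor bindings are exposed through a spvDescriptorArray view,
			// roughly (illustrative): "spvDescriptorArray<texture2d<float>> textures {textures_};".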
1473 switch (type.basetype)
1474 {
1475 case SPIRType::Image:
1476 case SPIRType::Sampler:
1477 case SPIRType::AccelerationStructure:
1478 statement(ts: "spvDescriptorArray<", ts: type_to_glsl(type: buffer_type, id: var.self), ts: "> ", ts: name, ts: " {", ts&: resource_name, ts: "};");
1479 break;
1480 case SPIRType::SampledImage:
1481 statement(ts: "spvDescriptorArray<", ts: type_to_glsl(type: buffer_type, id: var.self), ts: "> ", ts: name, ts: " {", ts&: resource_name, ts: "};");
1482 // Unsupported with argument buffer for now.
1483 statement(ts: "spvDescriptorArray<sampler> ", ts: name, ts: "Smplr {", ts: name, ts: "Smplr_};");
1484 break;
1485 case SPIRType::Struct:
1486 statement(ts: "spvDescriptorArray<", ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type: buffer_type), ts: "*> ",
1487 ts: name, ts: " {", ts&: resource_name, ts: "};");
1488 break;
1489 default:
1490 break;
1491 }
1492 has_runtime_array_declaration = true;
1493 }
1494 else if (!type.array.empty() && type.basetype == SPIRType::Struct)
1495 {
1496 // Emit only buffer arrays here.
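			// The emitted array simply gathers the individually bound buffers, roughly (illustrative):
			//   device SSBO* buffers[] = { buffers_0, buffers_1, };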
1497 statement(ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type: buffer_type), ts: "* ",
1498 ts: to_restrict(id: var.self, space: true), ts: name, ts: "[] =");
1499 begin_scope();
1500 uint32_t array_size = get_resource_array_size(type, id: var.self);
1501 for (uint32_t i = 0; i < array_size; ++i)
1502 statement(ts: name, ts: "_", ts&: i, ts: ",");
1503 end_scope_decl();
1504 statement_no_indent(ts: "");
1505 }
1506 }
1507
1508 if (has_runtime_array_declaration)
1509 statement_no_indent(ts: "");
1510
1511 // Emit buffer aliases here.
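	// A discrete alias re-types an untyped binding slot shared by aliased buffers, roughly (illustrative):
	//   device auto& ssbo = *(device SSBO*)spvBufferAliasSet0Binding1;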
1512 for (auto &var_id : buffer_aliases_discrete)
1513 {
1514 const auto &var = get<SPIRVariable>(id: var_id);
1515 const auto &type = get_variable_data_type(var);
1516 auto addr_space = get_argument_address_space(argument: var);
1517 auto name = to_name(id: var_id);
1518
1519 uint32_t desc_set = get_decoration(id: var_id, decoration: DecorationDescriptorSet);
1520 uint32_t desc_binding = get_decoration(id: var_id, decoration: DecorationBinding);
1521 auto alias_name = join(ts: "spvBufferAliasSet", ts&: desc_set, ts: "Binding", ts&: desc_binding);
1522
1523 statement(ts&: addr_space, ts: " auto& ", ts: to_restrict(id: var_id, space: true),
1524 ts&: name,
1525 ts: " = *(", ts&: addr_space, ts: " ", ts: type_to_glsl(type), ts: "*)", ts&: alias_name, ts: ";");
1526 }
1527 // Discrete descriptors are processed in entry point emission every compiler iteration.
1528 buffer_aliases_discrete.clear();
1529
1530 for (auto &var_pair : buffer_aliases_argument)
1531 {
1532 uint32_t var_id = var_pair.first;
1533 uint32_t alias_id = var_pair.second;
1534
1535 const auto &var = get<SPIRVariable>(id: var_id);
1536 const auto &type = get_variable_data_type(var);
1537 auto addr_space = get_argument_address_space(argument: var);
1538
1539 if (type.array.empty())
1540 {
1541 statement(ts&: addr_space, ts: " auto& ", ts: to_restrict(id: var_id, space: true), ts: to_name(id: var_id), ts: " = (", ts&: addr_space, ts: " ",
1542 ts: type_to_glsl(type), ts: "&)", ts&: ir.meta[alias_id].decoration.qualified_alias, ts: ";");
1543 }
1544 else
1545 {
1546 const char *desc_addr_space = descriptor_address_space(id: var_id, storage: var.storage, plain_address_space: "thread");
1547
1548 // Esoteric type cast. Reference to array of pointers.
1549			// The auto here resolves to the UBO or SSBO type. The address space of the reference needs to refer to the
1550 // address space of the argument buffer itself, which is usually constant, but can be const device for
1551 // large argument buffers.
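			// The emitted cast reads roughly (illustrative, hypothetical names):
			//   constant auto& bufs = (device SSBO* constant (&)[4])spvDescriptorSet0.m_bufs;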
1552 is_using_builtin_array = true;
1553 statement(ts&: desc_addr_space, ts: " auto& ", ts: to_restrict(id: var_id, space: true), ts: to_name(id: var_id), ts: " = (", ts&: addr_space, ts: " ",
1554 ts: type_to_glsl(type), ts: "* ", ts&: desc_addr_space, ts: " (&)",
1555 ts: type_to_array_glsl(type, variable_id: var_id), ts: ")", ts&: ir.meta[alias_id].decoration.qualified_alias, ts: ";");
1556 is_using_builtin_array = false;
1557 }
1558 }
1559
1560 // Emit disabled fragment outputs.
1561 std::sort(first: disabled_frag_outputs.begin(), last: disabled_frag_outputs.end());
1562 for (uint32_t var_id : disabled_frag_outputs)
1563 {
1564 auto &var = get<SPIRVariable>(id: var_id);
1565 add_local_variable_name(id: var_id);
1566 statement(ts: CompilerGLSL::variable_decl(variable: var), ts: ";");
1567 var.deferred_declaration = false;
1568 }
1569
1570	// Holds SetMeshOutputsEXT information. Declared threadgroup since the first thread to set it wins.
1571 if (processing_entry_point && is_mesh_shader())
1572 statement(ts: "threadgroup uint2 spvMeshSizes;");
1573}
1574
1575string CompilerMSL::compile()
1576{
1577 replace_illegal_entry_point_names();
1578 ir.fixup_reserved_names();
1579
1580 // Do not deal with GLES-isms like precision, older extensions and such.
1581 options.vulkan_semantics = true;
1582 options.es = false;
1583 options.version = 450;
1584 backend.null_pointer_literal = "nullptr";
1585 backend.float_literal_suffix = false;
1586 backend.uint32_t_literal_suffix = true;
1587 backend.int16_t_literal_suffix = "";
1588 backend.uint16_t_literal_suffix = "";
1589 backend.basic_int_type = "int";
1590 backend.basic_uint_type = "uint";
1591 backend.basic_int8_type = "char";
1592 backend.basic_uint8_type = "uchar";
1593 backend.basic_int16_type = "short";
1594 backend.basic_uint16_type = "ushort";
1595 backend.boolean_mix_function = "select";
1596 backend.swizzle_is_function = false;
1597 backend.shared_is_implied = false;
1598 backend.use_initializer_list = true;
1599 backend.use_typed_initializer_list = true;
1600 backend.native_row_major_matrix = false;
1601 backend.unsized_array_supported = false;
1602 backend.can_declare_arrays_inline = false;
1603 backend.allow_truncated_access_chain = true;
1604 backend.comparison_image_samples_scalar = true;
1605 backend.native_pointers = true;
1606 backend.nonuniform_qualifier = "";
1607 backend.support_small_type_sampling_result = true;
1608 backend.force_merged_mesh_block = false;
1609 backend.force_gl_in_out_block = get_execution_model() == ExecutionModelMeshEXT;
1610 backend.supports_empty_struct = true;
1611 backend.support_64bit_switch = true;
1612 backend.boolean_in_struct_remapped_type = SPIRType::Short;
1613
1614 // Allow Metal to use the array<T> template unless we force it off.
1615 backend.can_return_array = !msl_options.force_native_arrays;
1616 backend.array_is_value_type = !msl_options.force_native_arrays;
1617 // Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
1618 backend.array_is_value_type_in_buffer_blocks = false;
1619 backend.support_pointer_to_pointer = true;
1620 backend.implicit_c_integer_promotion_rules = true;
1621
1622 capture_output_to_buffer = msl_options.capture_output_to_buffer;
1623 is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
1624
1625 if (is_mesh_shader() && !get_entry_point().flags.get(bit: ExecutionModeOutputPoints))
1626 msl_options.enable_point_size_builtin = false;
1627
1628	// Initialize the array here rather than in the constructor; MSVC 2013 workaround.
1629 for (auto &id : next_metal_resource_ids)
1630 id = 0;
1631
1632 fixup_anonymous_struct_names();
1633 fixup_type_alias();
1634 replace_illegal_names();
1635 if (get_execution_model() == ExecutionModelMeshEXT)
1636 {
1637		// Emit a proxy entry point for the sake of the copy pass.
1638 emit_mesh_entry_point();
1639 }
1640 sync_entry_point_aliases_and_names();
1641
1642 build_function_control_flow_graphs_and_analyze();
1643 update_active_builtins();
1644 analyze_image_and_sampler_usage();
1645 analyze_sampled_image_usage();
1646 analyze_interlocked_resource_usage();
1647 preprocess_op_codes();
1648 build_implicit_builtins();
1649
1650 if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
1651 {
1652 string builtin_helper_invocation = builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput);
1653 string discard_expr = join(ts&: builtin_helper_invocation, ts: " = true, discard_fragment()");
1654 if (msl_options.force_fragment_with_side_effects_execution)
1655 discard_expr = join(ts: "!", ts&: builtin_helper_invocation, ts: " ? (", ts&: discard_expr, ts: ") : (void)0");
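		// The resulting expression is roughly "gl_HelperInvocation = true, discard_fragment()",
		// or, when fragment execution is forced,
		// "!gl_HelperInvocation ? (gl_HelperInvocation = true, discard_fragment()) : (void)0".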
1656 backend.discard_literal = discard_expr;
1657 backend.demote_literal = discard_expr;
1658 }
1659 else
1660 {
1661 backend.discard_literal = "discard_fragment()";
1662 backend.demote_literal = "discard_fragment()";
1663 }
1664
1665 fixup_image_load_store_access();
1666
1667 set_enabled_interface_variables(get_active_interface_variables());
1668 if (msl_options.force_active_argument_buffer_resources)
1669 activate_argument_buffer_resources();
1670
1671 if (swizzle_buffer_id)
1672 add_active_interface_variable(var_id: swizzle_buffer_id);
1673 if (buffer_size_buffer_id)
1674 add_active_interface_variable(var_id: buffer_size_buffer_id);
1675 if (view_mask_buffer_id)
1676 add_active_interface_variable(var_id: view_mask_buffer_id);
1677 if (dynamic_offsets_buffer_id)
1678 add_active_interface_variable(var_id: dynamic_offsets_buffer_id);
1679 if (builtin_layer_id)
1680 add_active_interface_variable(var_id: builtin_layer_id);
1681 if (builtin_dispatch_base_id && !msl_options.supports_msl_version(major: 1, minor: 2))
1682 add_active_interface_variable(var_id: builtin_dispatch_base_id);
1683 if (builtin_sample_mask_id)
1684 add_active_interface_variable(var_id: builtin_sample_mask_id);
1685 if (builtin_frag_depth_id)
1686 add_active_interface_variable(var_id: builtin_frag_depth_id);
1687
1688 // Create structs to hold input, output and uniform variables.
1689 // Do output first to ensure out. is declared at top of entry function.
1690 qual_pos_var_name = "";
1691 if (is_mesh_shader())
1692 {
1693 fixup_implicit_builtin_block_names(model: get_execution_model());
1694 }
1695 else
1696 {
1697 stage_out_var_id = add_interface_block(storage: StorageClassOutput);
1698 patch_stage_out_var_id = add_interface_block(storage: StorageClassOutput, patch: true);
1699 stage_in_var_id = add_interface_block(storage: StorageClassInput);
1700 }
1701
1702 if (is_tese_shader())
1703 patch_stage_in_var_id = add_interface_block(storage: StorageClassInput, patch: true);
1704
1705 if (is_tesc_shader())
1706 stage_out_ptr_var_id = add_interface_block_pointer(ib_var_id: stage_out_var_id, storage: StorageClassOutput);
1707 if (is_tessellation_shader())
1708 stage_in_ptr_var_id = add_interface_block_pointer(ib_var_id: stage_in_var_id, storage: StorageClassInput);
1709
1710 if (is_mesh_shader())
1711 {
1712 mesh_out_per_vertex = add_meshlet_block(per_primitive: false);
1713 mesh_out_per_primitive = add_meshlet_block(per_primitive: true);
1714 }
1715
1716 // Metal vertex functions that define no output must disable rasterization and return void.
1717 if (!stage_out_var_id)
1718 is_rasterization_disabled = true;
1719
1720 // Convert the use of global variables to recursively-passed function parameters
1721 localize_global_variables();
1722 extract_global_variables_from_functions();
1723
1724 // Mark any non-stage-in structs to be tightly packed.
1725 mark_packable_structs();
1726 reorder_type_alias();
1727
1728 // Add fixup hooks required by shader inputs and outputs. This needs to happen before
1729 // the loop, so the hooks aren't added multiple times.
1730 fix_up_shader_inputs_outputs();
1731
1732 // If we are using argument buffers, we create argument buffer structures for them here.
1733 // These buffers will be used in the entry point, not the individual resources.
1734 if (msl_options.argument_buffers)
1735 {
1736 if (!msl_options.supports_msl_version(major: 2, minor: 0))
1737 SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
1738 analyze_argument_buffers();
1739 }
1740
1741 uint32_t pass_count = 0;
1742 do
1743 {
1744 reset(iteration_count: pass_count);
1745
1746 // Start bindings at zero.
1747 next_metal_resource_index_buffer = 0;
1748 next_metal_resource_index_texture = 0;
1749 next_metal_resource_index_sampler = 0;
1750 for (auto &id : next_metal_resource_ids)
1751 id = 0;
1752
1753 // Move constructor for this type is broken on GCC 4.9 ...
1754 buffer.reset();
1755
1756 emit_header();
1757 emit_custom_templates();
1758 emit_custom_functions();
1759 emit_specialization_constants_and_structs();
1760 emit_resources();
1761 emit_function(func&: get<SPIRFunction>(id: ir.default_entry_point), return_flags: Bitset());
1762
1763 pass_count++;
1764 } while (is_forcing_recompilation());
1765
1766 return buffer.str();
1767}
1768
1769// Register the need to output any custom functions.
1770void CompilerMSL::preprocess_op_codes()
1771{
1772 OpCodePreprocessor preproc(*this);
1773 traverse_all_reachable_opcodes(block: get<SPIRFunction>(id: ir.default_entry_point), handler&: preproc);
1774
1775 suppress_missing_prototypes = preproc.suppress_missing_prototypes;
1776
1777 if (preproc.uses_atomics)
1778 {
1779 add_header_line(str: "#include <metal_atomic>");
1780 add_pragma_line(line: "#pragma clang diagnostic ignored \"-Wunused-variable\"");
1781 }
1782
1783 // Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
1784 // resources must disable rasterization and return void.
1785 if ((preproc.uses_buffer_write && !msl_options.supports_msl_version(major: 2, minor: 1)) ||
1786 (preproc.uses_image_write && !msl_options.supports_msl_version(major: 2, minor: 2)))
1787 is_rasterization_disabled = true;
1788
1789 // Tessellation control shaders are run as compute functions in Metal, and so
1790 // must capture their output to a buffer.
1791 if (is_tesc_shader() || (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
1792 {
1793 is_rasterization_disabled = true;
1794 capture_output_to_buffer = true;
1795 }
1796
1797 if (preproc.needs_subgroup_invocation_id)
1798 needs_subgroup_invocation_id = true;
1799 if (preproc.needs_subgroup_size)
1800 needs_subgroup_size = true;
1801	// build_implicit_builtins() hasn't run yet; in fact, this needs to execute before it
1802	// so that gl_SampleID will get added. That also means we have to check here whether
1803	// that function would add gl_FragCoord.
1804 if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
1805 (is_sample_rate() && (active_input_builtins.get(bit: BuiltInFragCoord) ||
1806 (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
1807 needs_sample_id = true;
1808 if (preproc.needs_helper_invocation || active_input_builtins.get(bit: BuiltInHelperInvocation))
1809 needs_helper_invocation = true;
1810
1811	// OpKill is lowered to a block terminator by the parser, so it never shows up as an
1812	// opcode; identify its use by inspecting block terminators instead.
1813 ir.for_each_typed_id<SPIRBlock>(op: [&preproc](uint32_t, SPIRBlock &block) {
1814 if (block.terminator == SPIRBlock::Kill)
1815 preproc.uses_discard = true;
1816 });
1817
1818	// Fragment shaders that both write to storage resources and discard fragments
1819	// need checks on the writes, to work around Metal allowing these writes despite
1820	// the fragment being dead. We may also need to force Metal to keep executing such
1821	// fragment shaders instead of discarding them prematurely.
1822 if (preproc.uses_discard && (preproc.uses_buffer_write || preproc.uses_image_write))
1823 {
1824 bool should_enable = (msl_options.check_discarded_frag_stores || msl_options.force_fragment_with_side_effects_execution);
1825 frag_shader_needs_discard_checks |= msl_options.check_discarded_frag_stores;
1826 needs_helper_invocation |= should_enable;
1827 // Fragment discard store checks imply manual HelperInvocation updates.
1828 msl_options.manual_helper_invocation_updates |= should_enable;
1829 }
1830
1831 if (is_intersection_query())
1832 {
1833 add_header_line(str: "#if __METAL_VERSION__ >= 230");
1834 add_header_line(str: "#include <metal_raytracing>");
1835 add_header_line(str: "using namespace metal::raytracing;");
1836 add_header_line(str: "#endif");
1837 }
1838}
1839
1840// Move the Private and Workgroup global variables to the entry function.
1841// Non-constant variables cannot have global scope in Metal.
1842void CompilerMSL::localize_global_variables()
1843{
1844 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
1845 auto iter = global_variables.begin();
1846 while (iter != global_variables.end())
1847 {
1848 uint32_t v_id = *iter;
1849 auto &var = get<SPIRVariable>(id: v_id);
1850 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup ||
1851 var.storage == StorageClassTaskPayloadWorkgroupEXT)
1852 {
1853 if (!variable_is_lut(var))
1854 entry_func.add_local_variable(id: v_id);
1855 iter = global_variables.erase(itr: iter);
1856 }
1857 else if (var.storage == StorageClassOutput && is_mesh_shader())
1858 {
1859 entry_func.add_local_variable(id: v_id);
1860 iter = global_variables.erase(itr: iter);
1861 }
1862 else
1863 iter++;
1864 }
1865}
1866
1867// For any global variable accessed directly by a function,
1868// extract that variable and add it as an argument to that function.
1869void CompilerMSL::extract_global_variables_from_functions()
1870{
1871 // Uniforms
1872 unordered_set<uint32_t> global_var_ids;
1873 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
1874 // Some builtins resolve directly to a function call which does not need any declared variables.
1875 // Skip these.
1876 if (var.storage == StorageClassInput && has_decoration(id: var.self, decoration: DecorationBuiltIn))
1877 {
1878 auto bi_type = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
1879 if (bi_type == BuiltInHelperInvocation && !needs_manual_helper_invocation_updates())
1880 return;
1881 if (bi_type == BuiltInHelperInvocation && needs_manual_helper_invocation_updates())
1882 {
1883 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
1884 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
1885 else if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
1886 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
1887 // Make sure this is declared and initialized.
1888 // Force this to have the proper name.
1889 set_name(id: var.self, name: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput));
1890 auto &entry_func = this->get<SPIRFunction>(id: ir.default_entry_point);
1891 entry_func.add_local_variable(id: var.self);
1892 vars_needing_early_declaration.push_back(t: var.self);
1893 entry_func.fixup_hooks_in.push_back(t: [this, &var]()
1894 { statement(ts: to_name(id: var.self), ts: " = simd_is_helper_thread();"); });
1895 }
1896 }
1897
1898 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1899 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1900 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1901 {
1902 global_var_ids.insert(x: var.self);
1903 }
1904 });
1905
1906 // Local vars that are declared in the main function and accessed directly by a function
1907 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
1908 for (auto &var : entry_func.local_variables)
1909 if (get<SPIRVariable>(id: var).storage != StorageClassFunction)
1910 global_var_ids.insert(x: var);
1911
1912 std::set<uint32_t> added_arg_ids;
1913 unordered_set<uint32_t> processed_func_ids;
1914 extract_global_variables_from_function(func_id: ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1915}
1916
1917// MSL does not support the use of global variables for shader input content.
1918// For any global variable accessed directly by the specified function, extract that variable,
1919// add it as an argument to that function, and add its ID to the added_arg_ids collection.
1920void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
1921 unordered_set<uint32_t> &global_var_ids,
1922 unordered_set<uint32_t> &processed_func_ids)
1923{
1924 // Avoid processing a function more than once
1925 if (processed_func_ids.find(x: func_id) != processed_func_ids.end())
1926 {
1927 // Return function global variables
1928 added_arg_ids = function_global_vars[func_id];
1929 return;
1930 }
1931
1932 processed_func_ids.insert(x: func_id);
1933
1934 auto &func = get<SPIRFunction>(id: func_id);
1935
1936 // Recursively establish global args added to functions on which we depend.
1937 for (auto block : func.blocks)
1938 {
1939 auto &b = get<SPIRBlock>(id: block);
1940 for (auto &i : b.ops)
1941 {
1942 auto ops = stream(instr: i);
1943 auto op = static_cast<Op>(i.op);
1944
1945 switch (op)
1946 {
1947 case OpLoad:
1948 case OpInBoundsAccessChain:
1949 case OpAccessChain:
1950 case OpPtrAccessChain:
1951 case OpArrayLength:
1952 {
1953 uint32_t base_id = ops[2];
1954 if (global_var_ids.find(x: base_id) != global_var_ids.end())
1955 added_arg_ids.insert(x: base_id);
1956
1957 // Use Metal's native frame-buffer fetch API for subpass inputs.
1958 auto &type = get<SPIRType>(id: ops[0]);
1959 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
1960 (!msl_options.use_framebuffer_fetch_subpasses))
1961 {
1962 // Implicitly reads gl_FragCoord.
1963 assert(builtin_frag_coord_id != 0);
1964 added_arg_ids.insert(x: builtin_frag_coord_id);
1965 if (msl_options.multiview)
1966 {
1967 // Implicitly reads gl_ViewIndex.
1968 assert(builtin_view_idx_id != 0);
1969 added_arg_ids.insert(x: builtin_view_idx_id);
1970 }
1971 else if (msl_options.arrayed_subpass_input)
1972 {
1973 // Implicitly reads gl_Layer.
1974 assert(builtin_layer_id != 0);
1975 added_arg_ids.insert(x: builtin_layer_id);
1976 }
1977 }
1978
1979 break;
1980 }
1981
1982 case OpFunctionCall:
1983 {
1984 // First see if any of the function call args are globals
1985 for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
1986 {
1987 uint32_t arg_id = ops[arg_idx];
1988 if (global_var_ids.find(x: arg_id) != global_var_ids.end())
1989 added_arg_ids.insert(x: arg_id);
1990 }
1991
1992 // Then recurse into the function itself to extract globals used internally in the function
1993 uint32_t inner_func_id = ops[2];
1994 std::set<uint32_t> inner_func_args;
1995 extract_global_variables_from_function(func_id: inner_func_id, added_arg_ids&: inner_func_args, global_var_ids,
1996 processed_func_ids);
1997 added_arg_ids.insert(first: inner_func_args.begin(), last: inner_func_args.end());
1998 break;
1999 }
2000
2001 case OpStore:
2002 {
2003 uint32_t base_id = ops[0];
2004 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2005 {
2006 added_arg_ids.insert(x: base_id);
2007
2008 if (msl_options.input_attachment_is_ds_attachment && base_id == builtin_frag_depth_id)
2009 writes_to_depth = true;
2010 }
2011
2012 uint32_t rvalue_id = ops[1];
2013 if (global_var_ids.find(x: rvalue_id) != global_var_ids.end())
2014 added_arg_ids.insert(x: rvalue_id);
2015
2016 if (needs_frag_discard_checks())
2017 added_arg_ids.insert(x: builtin_helper_invocation_id);
2018
2019 break;
2020 }
2021
2022 case OpSelect:
2023 {
2024 uint32_t base_id = ops[3];
2025 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2026 added_arg_ids.insert(x: base_id);
2027 base_id = ops[4];
2028 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2029 added_arg_ids.insert(x: base_id);
2030 break;
2031 }
2032
2033 case OpAtomicExchange:
2034 case OpAtomicCompareExchange:
2035 case OpAtomicStore:
2036 case OpAtomicIIncrement:
2037 case OpAtomicIDecrement:
2038 case OpAtomicIAdd:
2039 case OpAtomicFAddEXT:
2040 case OpAtomicISub:
2041 case OpAtomicSMin:
2042 case OpAtomicUMin:
2043 case OpAtomicSMax:
2044 case OpAtomicUMax:
2045 case OpAtomicAnd:
2046 case OpAtomicOr:
2047 case OpAtomicXor:
2048 case OpImageWrite:
2049 {
2050 if (needs_frag_discard_checks())
2051 added_arg_ids.insert(x: builtin_helper_invocation_id);
2052 uint32_t ptr = 0;
2053 if (op == OpAtomicStore || op == OpImageWrite)
2054 ptr = ops[0];
2055 else
2056 ptr = ops[2];
2057 if (global_var_ids.find(x: ptr) != global_var_ids.end())
2058 added_arg_ids.insert(x: ptr);
2059 break;
2060 }
2061
2062 // Emulate texture2D atomic operations
2063 case OpImageTexelPointer:
2064 {
2065 // When using the pointer, we need to know which variable it is actually loaded from.
2066 uint32_t base_id = ops[2];
2067 auto *var = maybe_get_backing_variable(chain: base_id);
2068 if (var)
2069 {
2070 if (atomic_image_vars_emulated.count(x: var->self) &&
2071 !get<SPIRType>(id: var->basetype).array.empty())
2072 {
2073 SPIRV_CROSS_THROW(
2074 "Cannot emulate array of storage images with atomics. Use MSL 3.1 for native support.");
2075 }
2076
2077 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2078 added_arg_ids.insert(x: base_id);
2079 }
2080 break;
2081 }
2082
2083 case OpExtInst:
2084 {
2085 uint32_t extension_set = ops[2];
2086 if (get<SPIRExtension>(id: extension_set).ext == SPIRExtension::GLSL)
2087 {
2088 auto op_450 = static_cast<GLSLstd450>(ops[3]);
2089 switch (op_450)
2090 {
2091 case GLSLstd450InterpolateAtCentroid:
2092 case GLSLstd450InterpolateAtSample:
2093 case GLSLstd450InterpolateAtOffset:
2094 {
2095 // For these, we really need the stage-in block. It is theoretically possible to pass the
2096 // interpolant object, but a) doing so would require us to create an entirely new variable
2097 // with Interpolant type, and b) if we have a struct or array, handling all the members and
2098 // elements could get unwieldy fast.
2099 added_arg_ids.insert(x: stage_in_var_id);
2100 break;
2101 }
2102
2103 case GLSLstd450Modf:
2104 case GLSLstd450Frexp:
2105 {
2106 uint32_t base_id = ops[5];
2107 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2108 added_arg_ids.insert(x: base_id);
2109 break;
2110 }
2111
2112 default:
2113 break;
2114 }
2115 }
2116 break;
2117 }
2118
2119 case OpGroupNonUniformInverseBallot:
2120 {
2121 added_arg_ids.insert(x: builtin_subgroup_invocation_id_id);
2122 break;
2123 }
2124
2125 case OpGroupNonUniformBallotFindLSB:
2126 case OpGroupNonUniformBallotFindMSB:
2127 {
2128 added_arg_ids.insert(x: builtin_subgroup_size_id);
2129 break;
2130 }
2131
2132 case OpGroupNonUniformBallotBitCount:
2133 {
2134 auto operation = static_cast<GroupOperation>(ops[3]);
2135 switch (operation)
2136 {
2137 case GroupOperationReduce:
2138 added_arg_ids.insert(x: builtin_subgroup_size_id);
2139 break;
2140 case GroupOperationInclusiveScan:
2141 case GroupOperationExclusiveScan:
2142 added_arg_ids.insert(x: builtin_subgroup_invocation_id_id);
2143 break;
2144 default:
2145 break;
2146 }
2147 break;
2148 }
2149
2150 case OpDemoteToHelperInvocation:
2151 if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
2152 added_arg_ids.insert(x: builtin_helper_invocation_id);
2153 break;
2154
2155 case OpIsHelperInvocationEXT:
2156 if (needs_manual_helper_invocation_updates())
2157 added_arg_ids.insert(x: builtin_helper_invocation_id);
2158 break;
2159
2160 case OpRayQueryInitializeKHR:
2161 case OpRayQueryProceedKHR:
2162 case OpRayQueryTerminateKHR:
2163 case OpRayQueryGenerateIntersectionKHR:
2164 case OpRayQueryConfirmIntersectionKHR:
2165 {
2166				// Ray queries access memory directly; we need to check whether the object must be passed down when it uses the Private storage class.
2167 uint32_t base_id = ops[0];
2168 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2169 added_arg_ids.insert(x: base_id);
2170 break;
2171 }
2172
2173 case OpRayQueryGetRayTMinKHR:
2174 case OpRayQueryGetRayFlagsKHR:
2175 case OpRayQueryGetWorldRayOriginKHR:
2176 case OpRayQueryGetWorldRayDirectionKHR:
2177 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
2178 case OpRayQueryGetIntersectionTypeKHR:
2179 case OpRayQueryGetIntersectionTKHR:
2180 case OpRayQueryGetIntersectionInstanceCustomIndexKHR:
2181 case OpRayQueryGetIntersectionInstanceIdKHR:
2182 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
2183 case OpRayQueryGetIntersectionGeometryIndexKHR:
2184 case OpRayQueryGetIntersectionPrimitiveIndexKHR:
2185 case OpRayQueryGetIntersectionBarycentricsKHR:
2186 case OpRayQueryGetIntersectionFrontFaceKHR:
2187 case OpRayQueryGetIntersectionObjectRayDirectionKHR:
2188 case OpRayQueryGetIntersectionObjectRayOriginKHR:
2189 case OpRayQueryGetIntersectionObjectToWorldKHR:
2190 case OpRayQueryGetIntersectionWorldToObjectKHR:
2191 {
2192				// Ray queries access memory directly; we need to check whether the object must be passed down when it uses the Private storage class.
2193 uint32_t base_id = ops[2];
2194 if (global_var_ids.find(x: base_id) != global_var_ids.end())
2195 added_arg_ids.insert(x: base_id);
2196 break;
2197 }
2198
2199 case OpSetMeshOutputsEXT:
2200 {
2201 if (builtin_local_invocation_index_id != 0)
2202 added_arg_ids.insert(x: builtin_local_invocation_index_id);
2203 if (builtin_mesh_sizes_id != 0)
2204 added_arg_ids.insert(x: builtin_mesh_sizes_id);
2205 break;
2206 }
2207
2208 default:
2209 break;
2210 }
2211
2212 if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill &&
2213 needs_helper_invocation)
2214 added_arg_ids.insert(x: builtin_helper_invocation_id);
2215
2216 // TODO: Add all other operations which can affect memory.
2217			// We should consider a more unified system here to reduce boilerplate.
2218 // This kind of analysis is done in several places ...
2219 }
2220
2221 if (b.terminator == SPIRBlock::EmitMeshTasks && builtin_task_grid_id != 0)
2222 added_arg_ids.insert(x: builtin_task_grid_id);
2223 }
2224
2225 function_global_vars[func_id] = added_arg_ids;
2226
2227 // Add the global variables as arguments to the function
2228 if (func_id != ir.default_entry_point)
2229 {
2230 bool control_point_added_in = false;
2231 bool control_point_added_out = false;
2232 bool patch_added_in = false;
2233 bool patch_added_out = false;
2234
2235 for (uint32_t arg_id : added_arg_ids)
2236 {
2237 auto &var = get<SPIRVariable>(id: arg_id);
2238 uint32_t type_id = var.basetype;
2239 auto *p_type = &get<SPIRType>(id: type_id);
2240 BuiltIn bi_type = BuiltIn(get_decoration(id: arg_id, decoration: DecorationBuiltIn));
2241
2242 bool is_patch = has_decoration(id: arg_id, decoration: DecorationPatch) || is_patch_block(type: *p_type);
2243 bool is_block = has_decoration(id: p_type->self, decoration: DecorationBlock);
2244 bool is_control_point_storage =
2245 !is_patch && ((is_tessellation_shader() && var.storage == StorageClassInput) ||
2246 (is_tesc_shader() && var.storage == StorageClassOutput));
2247 bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput;
2248 bool is_builtin = is_builtin_variable(var);
2249 bool variable_is_stage_io =
2250 !is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
2251 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
2252 p_type->basetype == SPIRType::Struct;
2253 bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) &&
2254 variable_is_stage_io;
2255
2256 // If output is masked it is not considered part of the global stage IO interface.
2257 if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput)
2258 is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var);
2259
2260 if (is_redirected_to_global_stage_io)
2261 {
2262 // Tessellation control shaders see inputs and per-point outputs as arrays.
2263 // Similarly, tessellation evaluation shaders see per-point inputs as arrays.
2264 // We collected them into a structure; we must pass the array of this
2265 // structure to the function.
2266 std::string name;
2267 if (is_patch)
2268 name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name;
2269 else
2270 name = var.storage == StorageClassInput ? "gl_in" : "gl_out";
2271
2272 if (var.storage == StorageClassOutput && has_decoration(id: p_type->self, decoration: DecorationBlock))
2273 {
2274 // If we're redirecting a block, we might still need to access the original block
2275 // variable if we're masking some members.
2276 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++)
2277 {
2278 if (is_stage_output_block_member_masked(var, index: mbr_idx, strip_array: true))
2279 {
2280 func.add_parameter(parameter_type: var.basetype, id: var.self, alias_global_variable: true);
2281 break;
2282 }
2283 }
2284 }
2285
2286 if (var.storage == StorageClassInput)
2287 {
2288 auto &added_in = is_patch ? patch_added_in : control_point_added_in;
2289 if (added_in)
2290 continue;
2291 arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id;
2292 added_in = true;
2293 }
2294 else if (var.storage == StorageClassOutput)
2295 {
2296 auto &added_out = is_patch ? patch_added_out : control_point_added_out;
2297 if (added_out)
2298 continue;
2299 arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id;
2300 added_out = true;
2301 }
2302
2303 type_id = get<SPIRVariable>(id: arg_id).basetype;
2304 uint32_t next_id = ir.increase_bound_by(count: 1);
2305 func.add_parameter(parameter_type: type_id, id: next_id, alias_global_variable: true);
2306 set<SPIRVariable>(id: next_id, args&: type_id, args: StorageClassFunction, args: 0, args&: arg_id);
2307
2308 set_name(id: next_id, name);
2309 if (is_tese_shader() && msl_options.raw_buffer_tese_input && var.storage == StorageClassInput)
2310 set_decoration(id: next_id, decoration: DecorationNonWritable);
2311 }
2312 else if (is_builtin && is_mesh_shader())
2313 {
2314 uint32_t next_id = ir.increase_bound_by(count: 1);
2315 func.add_parameter(parameter_type: type_id, id: next_id, alias_global_variable: true);
2316 auto &v = set<SPIRVariable>(id: next_id, args&: type_id, args: StorageClassFunction, args: 0, args&: arg_id);
2317 v.storage = StorageClassWorkgroup;
2318
2319 // Ensure the existing variable has a valid name and the new variable has all the same meta info
2320 set_name(id: arg_id, name: ensure_valid_name(name: to_name(id: arg_id), pfx: "v"));
2321 ir.meta[next_id] = ir.meta[arg_id];
2322 }
2323 else if (is_builtin && has_decoration(id: p_type->self, decoration: DecorationBlock))
2324 {
2325 // Get the pointee type
2326 type_id = get_pointee_type_id(type_id);
2327 p_type = &get<SPIRType>(id: type_id);
2328
2329 uint32_t mbr_idx = 0;
2330 for (auto &mbr_type_id : p_type->member_types)
2331 {
2332 BuiltIn builtin = BuiltInMax;
2333 is_builtin = is_member_builtin(type: *p_type, index: mbr_idx, builtin: &builtin);
2334 if (is_builtin && has_active_builtin(builtin, storage: var.storage))
2335 {
2336						// Add an argument variable with the same type and decorations as the member.
2337 uint32_t next_ids = ir.increase_bound_by(count: 2);
2338 uint32_t ptr_type_id = next_ids + 0;
2339 uint32_t var_id = next_ids + 1;
2340
2341 // Make sure we have an actual pointer type,
2342 // so that we will get the appropriate address space when declaring these builtins.
2343 auto &ptr = set<SPIRType>(id: ptr_type_id, args&: get<SPIRType>(id: mbr_type_id));
2344 ptr.self = mbr_type_id;
2345 ptr.storage = var.storage;
2346 ptr.pointer = true;
2347 ptr.pointer_depth++;
2348 ptr.parent_type = mbr_type_id;
2349
2350 func.add_parameter(parameter_type: mbr_type_id, id: var_id, alias_global_variable: true);
2351 set<SPIRVariable>(id: var_id, args&: ptr_type_id, args: StorageClassFunction);
2352 ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
2353 }
2354 mbr_idx++;
2355 }
2356 }
2357 else
2358 {
2359 uint32_t next_id = ir.increase_bound_by(count: 1);
2360 func.add_parameter(parameter_type: type_id, id: next_id, alias_global_variable: true);
2361 set<SPIRVariable>(id: next_id, args&: type_id, args: StorageClassFunction, args: 0, args&: arg_id);
2362
2363 // Ensure the new variable has all the same meta info
2364 ir.meta[next_id] = ir.meta[arg_id];
2365 }
2366 }
2367 }
2368}
2369
2370// For all variables that are some form of non-input-output interface block, mark that all the structs
2371// that are recursively contained within the type referenced by that variable should be packed tightly.
2372void CompilerMSL::mark_packable_structs()
2373{
2374 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
2375 if (var.storage != StorageClassFunction && !is_hidden_variable(var))
2376 {
2377 auto &type = this->get<SPIRType>(id: var.basetype);
2378 if (type.pointer &&
2379 (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
2380 type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
2381 (has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock)))
2382 mark_as_packable(type);
2383 }
2384
2385 if (var.storage == StorageClassWorkgroup)
2386 {
2387 auto *type = &this->get<SPIRType>(id: var.basetype);
2388 if (type->basetype == SPIRType::Struct)
2389 mark_as_workgroup_struct(type&: *type);
2390 }
2391 });
2392
2393 // Physical storage buffer pointers can appear outside of the context of a variable, if the address
2394 // is calculated from a ulong or uvec2 and cast to a pointer, so check if they need to be packed too.
2395 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t, SPIRType &type) {
2396 if (type.basetype == SPIRType::Struct && type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
2397 mark_as_packable(type);
2398 });
2399}
2400
2401// If the specified type is a struct, it and any nested structs
2402// are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
2403void CompilerMSL::mark_as_packable(SPIRType &type)
2404{
2405	// If this is not the base type (e.g. it's a pointer or array), tunnel down
2406 if (type.parent_type)
2407 {
2408 mark_as_packable(type&: get<SPIRType>(id: type.parent_type));
2409 return;
2410 }
2411
2412 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
2413 if (type.basetype == SPIRType::Struct && !has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationBufferBlockRepacked))
2414 {
2415 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationBufferBlockRepacked);
2416
2417 // Recurse
2418 uint32_t mbr_cnt = uint32_t(type.member_types.size());
2419 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
2420 {
2421 uint32_t mbr_type_id = type.member_types[mbr_idx];
2422 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
2423 mark_as_packable(type&: mbr_type);
2424 if (mbr_type.type_alias)
2425 {
2426 auto &mbr_type_alias = get<SPIRType>(id: mbr_type.type_alias);
2427 mark_as_packable(type&: mbr_type_alias);
2428 }
2429 }
2430 }
2431}
2432
2433// If the specified type is a struct, it and any nested structs
2434// are marked as used with workgroup storage using the SPIRVCrossDecorationWorkgroupStruct decoration.
2435void CompilerMSL::mark_as_workgroup_struct(SPIRType &type)
2436{
2437	// If this is not the base type (e.g. it's a pointer or array), tunnel down
2438 if (type.parent_type)
2439 {
2440 mark_as_workgroup_struct(type&: get<SPIRType>(id: type.parent_type));
2441 return;
2442 }
2443
2444 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
2445 if (type.basetype == SPIRType::Struct && !has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationWorkgroupStruct))
2446 {
2447 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationWorkgroupStruct);
2448
2449 // Recurse
2450 uint32_t mbr_cnt = uint32_t(type.member_types.size());
2451 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
2452 {
2453 uint32_t mbr_type_id = type.member_types[mbr_idx];
2454 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
2455 mark_as_workgroup_struct(type&: mbr_type);
2456 if (mbr_type.type_alias)
2457 {
2458 auto &mbr_type_alias = get<SPIRType>(id: mbr_type.type_alias);
2459 mark_as_workgroup_struct(type&: mbr_type_alias);
2460 }
2461 }
2462 }
2463}
2464
2465// If a shader input or output exists at the location, mark it as being used by this shader.
2466void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type,
2467 StorageClass storage, bool fallback)
2468{
2469 uint32_t count = type_to_location_count(type);
2470 switch (storage)
2471 {
2472 case StorageClassInput:
2473 for (uint32_t i = 0; i < count; i++)
2474 {
2475 location_inputs_in_use.insert(x: location + i);
2476 if (fallback)
2477 location_inputs_in_use_fallback.insert(x: location + i);
2478 }
2479 break;
2480 case StorageClassOutput:
2481 for (uint32_t i = 0; i < count; i++)
2482 {
2483 location_outputs_in_use.insert(x: location + i);
2484 if (fallback)
2485 location_outputs_in_use_fallback.insert(x: location + i);
2486 }
2487 break;
2488 default:
2489 return;
2490 }
2491}
2492
2493uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
2494{
2495 auto itr = fragment_output_components.find(x: location);
2496 if (itr == end(cont: fragment_output_components))
2497 return 4;
2498 else
2499 return itr->second;
2500}
2501
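// Builds a new vector type with the requested component count, optionally overriding the base type,
// and re-wraps any array and pointer layers of the original type around the widened vector.
// Used e.g. to pad fragment outputs or stage IO members up to a target component count.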
2502uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
2503{
2504 assert(components > 1);
2505 uint32_t new_type_id = ir.increase_bound_by(count: 1);
2506 const auto *p_old_type = &get<SPIRType>(id: type_id);
2507 const SPIRType *old_ptr_t = nullptr;
2508 const SPIRType *old_array_t = nullptr;
2509
2510 if (is_pointer(type: *p_old_type))
2511 {
2512 old_ptr_t = p_old_type;
2513 p_old_type = &get_pointee_type(type: *old_ptr_t);
2514 }
2515
2516 if (is_array(type: *p_old_type))
2517 {
2518 old_array_t = p_old_type;
2519 p_old_type = &get_type(id: old_array_t->parent_type);
2520 }
2521
2522 auto *type = &set<SPIRType>(id: new_type_id, args: *p_old_type);
2523 assert(is_scalar(*type) || is_vector(*type));
2524 type->op = OpTypeVector;
2525 type->vecsize = components;
2526 if (basetype != SPIRType::Unknown)
2527 type->basetype = basetype;
2528 type->self = new_type_id;
2529 // We want parent type to point to the scalar type.
2530 type->parent_type = is_scalar(type: *p_old_type) ? TypeID(p_old_type->self) : p_old_type->parent_type;
2531 assert(is_scalar(get<SPIRType>(type->parent_type)));
2532 type->array.clear();
2533 type->array_size_literal.clear();
2534 type->pointer = false;
2535
2536 if (old_array_t)
2537 {
2538 uint32_t array_type_id = ir.increase_bound_by(count: 1);
2539 type = &set<SPIRType>(id: array_type_id, args&: *type);
2540 type->op = OpTypeArray;
2541 type->parent_type = new_type_id;
2542 type->array = old_array_t->array;
2543 type->array_size_literal = old_array_t->array_size_literal;
2544 new_type_id = array_type_id;
2545 }
2546
2547 if (old_ptr_t)
2548 {
2549 uint32_t ptr_type_id = ir.increase_bound_by(count: 1);
2550 type = &set<SPIRType>(id: ptr_type_id, args&: *type);
2551 type->op = OpTypePointer;
2552 type->parent_type = new_type_id;
2553 type->storage = old_ptr_t->storage;
2554 type->pointer = true;
2555 type->pointer_depth++;
2556 new_type_id = ptr_type_id;
2557 }
2558
2559 return new_type_id;
2560}
2561
2562uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
2563{
2564 uint32_t new_type_id = ir.increase_bound_by(count: 1);
2565 SPIRType &type = set<SPIRType>(id: new_type_id, args&: get<SPIRType>(id: type_id));
2566 type.basetype = SPIRType::Interpolant;
2567 type.parent_type = type_id;
2568 // In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
2569 // Add this decoration so we know which argument to pass to the template.
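	// In MSL this maps to e.g. "interpolant<float4, interpolation::no_perspective>" (illustrative).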
2570 if (is_noperspective)
2571 set_decoration(id: new_type_id, decoration: DecorationNoPerspective);
2572 return new_type_id;
2573}
2574
2575bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref,
2576 SPIRVariable &var,
2577 const SPIRType &type,
2578 InterfaceBlockMeta &meta)
2579{
2580 // Deal with Component decorations.
2581 const InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
2582 uint32_t location = ~0u;
2583 if (has_decoration(id: var.self, decoration: DecorationLocation))
2584 {
2585 location = get_decoration(id: var.self, decoration: DecorationLocation);
2586 auto location_meta_itr = meta.location_meta.find(x: location);
2587 if (location_meta_itr != end(cont&: meta.location_meta))
2588 location_meta = &location_meta_itr->second;
2589 }
2590
2591 // Check if we need to pad fragment output to match a certain number of components.
2592 if (location_meta)
2593 {
2594 bool pad_fragment_output = has_decoration(id: var.self, decoration: DecorationLocation) &&
2595 msl_options.pad_fragment_output_components &&
2596 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2597
2598 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
2599 uint32_t start_component = get_decoration(id: var.self, decoration: DecorationComponent);
2600 uint32_t type_components = type.vecsize;
2601 uint32_t num_components = location_meta->num_components;
2602
2603 if (pad_fragment_output)
2604 {
2605 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation);
2606 num_components = max<uint32_t>(a: num_components, b: get_target_components_for_fragment_location(location: locn));
2607 }
2608
2609 // We have already declared an IO block member as m_location_N.
2610 // Just emit an early-declared variable and fixup as needed.
2611 // Arrays need to be unrolled here since each location might need a different number of components.
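		// For example (illustrative), a two-component input at location 3 starting at component 1
		// is copied in as "foo = in.m_location_3.yz;".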
2612 entry_func.add_local_variable(id: var.self);
2613 vars_needing_early_declaration.push_back(t: var.self);
2614
2615 if (var.storage == StorageClassInput)
2616 {
2617 entry_func.fixup_hooks_in.push_back(t: [=, &type, &var]() {
2618 if (!type.array.empty())
2619 {
2620 uint32_t array_size = to_array_size_literal(type);
2621 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2622 {
2623 statement(ts: to_name(id: var.self), ts: "[", ts&: loc_off, ts: "]", ts: " = ", ts: ib_var_ref,
2624 ts: ".m_location_", ts: location + loc_off,
2625 ts: vector_swizzle(vecsize: type_components, index: start_component), ts: ";");
2626 }
2627 }
2628 else
2629 {
2630 statement(ts: to_name(id: var.self), ts: " = ", ts: ib_var_ref, ts: ".m_location_", ts: location,
2631 ts: vector_swizzle(vecsize: type_components, index: start_component), ts: ";");
2632 }
2633 });
2634 }
2635 else
2636 {
2637 entry_func.fixup_hooks_out.push_back(t: [=, &type, &var]() {
2638 if (!type.array.empty())
2639 {
2640 uint32_t array_size = to_array_size_literal(type);
2641 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2642 {
2643 statement(ts: ib_var_ref, ts: ".m_location_", ts: location + loc_off,
2644 ts: vector_swizzle(vecsize: type_components, index: start_component), ts: " = ",
2645 ts: to_name(id: var.self), ts: "[", ts&: loc_off, ts: "];");
2646 }
2647 }
2648 else
2649 {
2650 statement(ts: ib_var_ref, ts: ".m_location_", ts: location,
2651 ts: vector_swizzle(vecsize: type_components, index: start_component), ts: " = ", ts: to_name(id: var.self), ts: ";");
2652 }
2653 });
2654 }
2655 return true;
2656 }
2657 else
2658 return false;
2659}
2660
2661void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2662 SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
2663{
2664 bool is_builtin = is_builtin_variable(var);
2665 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
2666 bool is_flat = has_decoration(id: var.self, decoration: DecorationFlat);
2667 bool is_noperspective = has_decoration(id: var.self, decoration: DecorationNoPerspective);
2668 bool is_centroid = has_decoration(id: var.self, decoration: DecorationCentroid);
2669 bool is_sample = has_decoration(id: var.self, decoration: DecorationSample);
2670
2671 // Add a reference to the variable type to the interface struct.
2672 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2673 uint32_t type_id = ensure_correct_builtin_type(type_id: var.basetype, builtin);
2674 var.basetype = type_id;
2675
2676 type_id = get_pointee_type_id(type_id: var.basetype);
2677 if (meta.strip_array && is_array(type: get<SPIRType>(id: type_id)))
2678 type_id = get<SPIRType>(id: type_id).parent_type;
2679 auto &type = get<SPIRType>(id: type_id);
2680 uint32_t target_components = 0;
2681 uint32_t type_components = type.vecsize;
2682
2683 bool padded_output = false;
2684 bool padded_input = false;
2685 uint32_t start_component = 0;
2686
2687 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
2688
2689 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta))
2690 return;
2691
2692 bool pad_fragment_output = has_decoration(id: var.self, decoration: DecorationLocation) &&
2693 msl_options.pad_fragment_output_components &&
2694 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2695
2696 if (pad_fragment_output)
2697 {
2698 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation);
2699 target_components = get_target_components_for_fragment_location(location: locn);
2700 if (type_components < target_components)
2701 {
2702 // Make a new type here.
2703 type_id = build_extended_vector_type(type_id, components: target_components);
2704 padded_output = true;
2705 }
2706 }
2707
2708 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2709 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id, is_noperspective));
2710 else
2711 ib_type.member_types.push_back(t: type_id);
2712
2713 // Give the member a name
2714 string mbr_name = ensure_valid_name(name: to_expression(id: var.self), pfx: "m");
2715 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
2716
2717 // Update the original variable reference to include the structure reference
2718 string qual_var_name = ib_var_ref + "." + mbr_name;
2719 // If using pull-model interpolation, need to add a call to the correct interpolation method.
2720 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2721 {
2722 if (is_centroid)
2723 qual_var_name += ".interpolate_at_centroid()";
2724 else if (is_sample)
2725 qual_var_name += join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
2726 else
2727 qual_var_name += ".interpolate_at_center()";
2728 }
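	// With pull-model interpolation the qualified access thus reads roughly
	// "in.m_foo.interpolate_at_sample(gl_SampleID)" (illustrative member name).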
2729
2730 if (padded_output || padded_input)
2731 {
2732 entry_func.add_local_variable(id: var.self);
2733 vars_needing_early_declaration.push_back(t: var.self);
2734
2735 if (padded_output)
2736 {
2737 entry_func.fixup_hooks_out.push_back(t: [=, &var]() {
2738 statement(ts: qual_var_name, ts: vector_swizzle(vecsize: type_components, index: start_component), ts: " = ", ts: to_name(id: var.self),
2739 ts: ";");
2740 });
2741 }
2742 else
2743 {
2744 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
2745 statement(ts: to_name(id: var.self), ts: " = ", ts: qual_var_name, ts: vector_swizzle(vecsize: type_components, index: start_component),
2746 ts: ";");
2747 });
2748 }
2749 }
2750 else if (!meta.strip_array)
2751 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
2752
2753 if (var.storage == StorageClassOutput && var.initializer != ID(0))
2754 {
2755 if (padded_output || padded_input)
2756 {
2757 entry_func.fixup_hooks_in.push_back(
2758 t: [=, &var]() { statement(ts: to_name(id: var.self), ts: " = ", ts: to_expression(id: var.initializer), ts: ";"); });
2759 }
2760 else
2761 {
2762 if (meta.strip_array)
2763 {
2764 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
2765 uint32_t index = get_extended_decoration(id: var.self, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
2766 auto invocation = to_tesc_invocation_id();
2767 statement(ts: to_expression(id: stage_out_ptr_var_id), ts: "[",
2768 ts&: invocation, ts: "].",
2769 ts: to_member_name(type: ib_type, index), ts: " = ", ts: to_expression(id: var.initializer), ts: "[",
2770 ts&: invocation, ts: "];");
2771 });
2772 }
2773 else
2774 {
2775 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
2776 statement(ts: qual_var_name, ts: " = ", ts: to_expression(id: var.initializer), ts: ";");
2777 });
2778 }
2779 }
2780 }
2781
2782 // Copy the variable location from the original variable to the member
2783 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation))
2784 {
2785 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation);
2786 uint32_t comp = get_decoration(id: var.self, decoration: DecorationComponent);
2787 if (storage == StorageClassInput)
2788 {
2789 type_id = ensure_correct_input_type(type_id: var.basetype, location: locn, component: comp, num_components: 0, strip_array: meta.strip_array);
2790 var.basetype = type_id;
2791
2792 type_id = get_pointee_type_id(type_id);
2793 if (meta.strip_array && is_array(type: get<SPIRType>(id: type_id)))
2794 type_id = get<SPIRType>(id: type_id).parent_type;
2795 if (pull_model_inputs.count(x: var.self))
2796 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
2797 else
2798 ib_type.member_types[ib_mbr_idx] = type_id;
2799 }
2800 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2801 if (comp)
2802 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: comp);
2803 mark_location_as_used_by_shader(location: locn, type: get<SPIRType>(id: type_id), storage);
2804 }
2805 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
2806 {
2807 uint32_t locn = inputs_by_builtin[builtin].location;
2808 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2809 mark_location_as_used_by_shader(location: locn, type, storage);
2810 }
2811 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
2812 {
2813 uint32_t locn = outputs_by_builtin[builtin].location;
2814 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2815 mark_location_as_used_by_shader(location: locn, type, storage);
2816 }
2817
2818 if (get_decoration_bitset(id: var.self).get(bit: DecorationComponent))
2819 {
2820 uint32_t component = get_decoration(id: var.self, decoration: DecorationComponent);
2821 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: component);
2822 }
2823
2824 if (get_decoration_bitset(id: var.self).get(bit: DecorationIndex))
2825 {
2826 uint32_t index = get_decoration(id: var.self, decoration: DecorationIndex);
2827 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: index);
2828 }
2829
2830 // Mark the member as builtin if needed
2831 if (is_builtin)
2832 {
2833 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
2834 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2835 qual_pos_var_name = qual_var_name;
2836 }
2837
2838 // Copy interpolation decorations if needed
2839 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
2840 {
2841 if (is_flat)
2842 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
2843 if (is_noperspective)
2844 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
2845 if (is_centroid)
2846 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
2847 if (is_sample)
2848 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
2849 }
2850
2851 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
2852}
2853
2854void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2855 SPIRType &ib_type, SPIRVariable &var,
2856 InterfaceBlockMeta &meta)
2857{
2858 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
2859 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2860 uint32_t elem_cnt = 0;
2861
2862 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type: var_type, meta))
2863 return;
2864
2865 if (is_matrix(type: var_type))
2866 {
2867 if (is_array(type: var_type))
2868 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2869
2870 elem_cnt = var_type.columns;
2871 }
2872 else if (is_array(type: var_type))
2873 {
2874 if (var_type.array.size() != 1)
2875 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2876
2877 elem_cnt = to_array_size_literal(type: var_type);
2878 }
2879
2880 bool is_builtin = is_builtin_variable(var);
2881 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
2882 bool is_flat = has_decoration(id: var.self, decoration: DecorationFlat);
2883 bool is_noperspective = has_decoration(id: var.self, decoration: DecorationNoPerspective);
2884 bool is_centroid = has_decoration(id: var.self, decoration: DecorationCentroid);
2885 bool is_sample = has_decoration(id: var.self, decoration: DecorationSample);
2886
2887 auto *usable_type = &var_type;
2888 if (usable_type->pointer)
2889 usable_type = &get<SPIRType>(id: usable_type->parent_type);
2890 while (is_array(type: *usable_type) || is_matrix(type: *usable_type))
2891 usable_type = &get<SPIRType>(id: usable_type->parent_type);
2892
2893 // If a builtin, force it to have the proper name.
2894 if (is_builtin)
2895 set_name(id: var.self, name: builtin_to_glsl(builtin, storage: StorageClassFunction));
2896
2897 bool flatten_from_ib_var = false;
2898 string flatten_from_ib_mbr_name;
2899
2900 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2901 {
2902 // Also declare [[clip_distance]] attribute here.
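// Hedged illustration (not verbatim compiler output): the member added just below typically
// surfaces in the generated MSL stage-out struct roughly as
//     float gl_ClipDistance [[clip_distance]] [2];
// alongside the flattened per-element user varyings emitted further down in this function.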
2903 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2904 ib_type.member_types.push_back(t: get_variable_data_type_id(var));
2905 set_member_decoration(id: ib_type.self, index: clip_array_mbr_idx, decoration: DecorationBuiltIn, argument: BuiltInClipDistance);
2906
2907 flatten_from_ib_mbr_name = builtin_to_glsl(builtin: BuiltInClipDistance, storage: StorageClassOutput);
2908 set_member_name(id: ib_type.self, index: clip_array_mbr_idx, name: flatten_from_ib_mbr_name);
2909
2910 // When we flatten, we flatten directly from the "out" struct,
2911 // not from a function variable.
2912 flatten_from_ib_var = true;
2913
2914 if (!msl_options.enable_clip_distance_user_varying)
2915 return;
2916 }
2917 else if (!meta.strip_array)
2918 {
2919 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2920 entry_func.add_local_variable(id: var.self);
2921 // We need to declare the variable early and at entry-point scope.
2922 vars_needing_early_declaration.push_back(t: var.self);
2923 }
2924
2925 for (uint32_t i = 0; i < elem_cnt; i++)
2926 {
2927 // Add a reference to the variable type to the interface struct.
2928 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2929
2930 uint32_t target_components = 0;
2931 bool padded_output = false;
2932 uint32_t type_id = usable_type->self;
2933
2934 // Check if we need to pad fragment output to match a certain number of components.
2935 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation) && msl_options.pad_fragment_output_components &&
2936 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
2937 {
2938 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation) + i;
2939 target_components = get_target_components_for_fragment_location(location: locn);
2940 if (usable_type->vecsize < target_components)
2941 {
2942 // Make a new type here.
2943 type_id = build_extended_vector_type(type_id: usable_type->self, components: target_components);
2944 padded_output = true;
2945 }
2946 }
2947
2948 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2949 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: get_pointee_type_id(type_id), is_noperspective));
2950 else
2951 ib_type.member_types.push_back(t: get_pointee_type_id(type_id));
2952
2953 // Give the member a name
2954 string mbr_name = ensure_valid_name(name: join(ts: to_expression(id: var.self), ts: "_", ts&: i), pfx: "m");
2955 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
2956
2957 // There is no qualified alias since we need to flatten the internal array on return.
2958 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation))
2959 {
2960 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation) + i;
2961 uint32_t comp = get_decoration(id: var.self, decoration: DecorationComponent);
2962 if (storage == StorageClassInput)
2963 {
2964 var.basetype = ensure_correct_input_type(type_id: var.basetype, location: locn, component: comp, num_components: 0, strip_array: meta.strip_array);
2965 uint32_t mbr_type_id = ensure_correct_input_type(type_id: usable_type->self, location: locn, component: comp, num_components: 0, strip_array: meta.strip_array);
2966 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2967 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
2968 else
2969 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2970 }
2971 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2972 if (comp)
2973 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: comp);
2974 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2975 }
2976 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
2977 {
2978 uint32_t locn = inputs_by_builtin[builtin].location + i;
2979 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2980 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2981 }
2982 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
2983 {
2984 uint32_t locn = outputs_by_builtin[builtin].location + i;
2985 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2986 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2987 }
2988 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2989 {
2990 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
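// Hedged illustration (member name is a placeholder): with these decorations the element tends
// to be emitted roughly as
//     float gl_ClipDistance_0 [[user(clip0)]];
// so the next stage can consume the distances as ordinary user varyings.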
2991 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
2992 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: i);
2993 }
2994
2995 if (get_decoration_bitset(id: var.self).get(bit: DecorationIndex))
2996 {
2997 uint32_t index = get_decoration(id: var.self, decoration: DecorationIndex);
2998 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: index);
2999 }
3000
3001 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
3002 {
3003 // Copy interpolation decorations if needed
3004 if (is_flat)
3005 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
3006 if (is_noperspective)
3007 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
3008 if (is_centroid)
3009 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
3010 if (is_sample)
3011 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
3012 }
3013
3014 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
3015
3016 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
3017 if (!meta.strip_array)
3018 {
3019 switch (storage)
3020 {
3021 case StorageClassInput:
3022 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
3023 if (pull_model_inputs.count(x: var.self))
3024 {
3025 string lerp_call;
3026 if (is_centroid)
3027 lerp_call = ".interpolate_at_centroid()";
3028 else if (is_sample)
3029 lerp_call = join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
3030 else
3031 lerp_call = ".interpolate_at_center()";
3032 statement(ts: to_name(id: var.self), ts: "[", ts: i, ts: "] = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts&: lerp_call, ts: ";");
3033 }
3034 else
3035 {
3036 statement(ts: to_name(id: var.self), ts: "[", ts: i, ts: "] = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts: ";");
3037 }
3038 });
3039 break;
3040
3041 case StorageClassOutput:
3042 entry_func.fixup_hooks_out.push_back(t: [=, &var]() {
3043 if (padded_output)
3044 {
3045 auto &padded_type = this->get<SPIRType>(id: type_id);
3046 statement(
3047 ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ",
3048 ts: remap_swizzle(result_type: padded_type, input_components: usable_type->vecsize, expr: join(ts: to_name(id: var.self), ts: "[", ts: i, ts: "]")),
3049 ts: ";");
3050 }
3051 else if (flatten_from_ib_var)
3052 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: ib_var_ref, ts: ".", ts: flatten_from_ib_mbr_name, ts: "[", ts: i,
3053 ts: "];");
3054 else
3055 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: to_name(id: var.self), ts: "[", ts: i, ts: "];");
3056 });
3057 break;
3058
3059 default:
3060 break;
3061 }
3062 }
3063 }
3064}
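// Hedged illustration of the flattening above (variable and struct names are assumptions): an
// output such as 'float4 vColor[2]' is split into interface members 'vColor_0' and 'vColor_1',
// and the fixup-out hooks copy them back element by element, roughly:
//     out.vColor_0 = vColor[0];
//     out.vColor_1 = vColor[1];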
3065
3066void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage,
3067 const string &ib_var_ref, SPIRType &ib_type,
3068 SPIRVariable &var, SPIRType &var_type,
3069 uint32_t mbr_idx, InterfaceBlockMeta &meta,
3070 const string &mbr_name_qual,
3071 const string &var_chain_qual,
3072 uint32_t &location, uint32_t &var_mbr_idx,
3073 const Bitset &interpolation_qual)
3074{
3075 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3076
3077 BuiltIn builtin = BuiltInMax;
3078 bool is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
3079 bool is_flat = interpolation_qual.get(bit: DecorationFlat) ||
3080 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationFlat) ||
3081 has_decoration(id: var.self, decoration: DecorationFlat);
3082 bool is_noperspective = interpolation_qual.get(bit: DecorationNoPerspective) ||
3083 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationNoPerspective) ||
3084 has_decoration(id: var.self, decoration: DecorationNoPerspective);
3085 bool is_centroid = interpolation_qual.get(bit: DecorationCentroid) ||
3086 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationCentroid) ||
3087 has_decoration(id: var.self, decoration: DecorationCentroid);
3088 bool is_sample = interpolation_qual.get(bit: DecorationSample) ||
3089 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationSample) ||
3090 has_decoration(id: var.self, decoration: DecorationSample);
3091
3092 Bitset inherited_qual;
3093 if (is_flat)
3094 inherited_qual.set(DecorationFlat);
3095 if (is_noperspective)
3096 inherited_qual.set(DecorationNoPerspective);
3097 if (is_centroid)
3098 inherited_qual.set(DecorationCentroid);
3099 if (is_sample)
3100 inherited_qual.set(DecorationSample);
3101
3102 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
3103 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
3104
3105 bool mbr_is_indexable = false;
3106 uint32_t elem_cnt = 1;
3107 if (is_matrix(type: mbr_type))
3108 {
3109 if (is_array(type: mbr_type))
3110 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
3111
3112 mbr_is_indexable = true;
3113 elem_cnt = mbr_type.columns;
3114 }
3115 else if (is_array(type: mbr_type))
3116 {
3117 if (mbr_type.array.size() != 1)
3118 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
3119
3120 mbr_is_indexable = true;
3121 elem_cnt = to_array_size_literal(type: mbr_type);
3122 }
3123
3124 auto *usable_type = &mbr_type;
3125 if (usable_type->pointer)
3126 usable_type = &get<SPIRType>(id: usable_type->parent_type);
3127 while (is_array(type: *usable_type) || is_matrix(type: *usable_type))
3128 usable_type = &get<SPIRType>(id: usable_type->parent_type);
3129
3130 bool flatten_from_ib_var = false;
3131 string flatten_from_ib_mbr_name;
3132
3133 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
3134 {
3135 // Also declare [[clip_distance]] attribute here.
3136 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
3137 ib_type.member_types.push_back(t: mbr_type_id);
3138 set_member_decoration(id: ib_type.self, index: clip_array_mbr_idx, decoration: DecorationBuiltIn, argument: BuiltInClipDistance);
3139
3140 flatten_from_ib_mbr_name = builtin_to_glsl(builtin: BuiltInClipDistance, storage: StorageClassOutput);
3141 set_member_name(id: ib_type.self, index: clip_array_mbr_idx, name: flatten_from_ib_mbr_name);
3142
3143 // When we flatten, we flatten directly from the "out" struct,
3144 // not from a function variable.
3145 flatten_from_ib_var = true;
3146
3147 if (!msl_options.enable_clip_distance_user_varying)
3148 return;
3149 }
3150
3151 // Recursively handle nested structures.
3152 if (mbr_type.basetype == SPIRType::Struct)
3153 {
3154 for (uint32_t i = 0; i < elem_cnt; i++)
3155 {
3156 string mbr_name = append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx) + (mbr_is_indexable ? join(ts: "_", ts&: i) : "");
3157 string var_chain = join(ts: var_chain_qual, ts: ".", ts: to_member_name(type: var_type, index: mbr_idx), ts: (mbr_is_indexable ? join(ts: "[", ts&: i, ts: "]") : ""));
3158 uint32_t sub_mbr_cnt = uint32_t(mbr_type.member_types.size());
3159 for (uint32_t sub_mbr_idx = 0; sub_mbr_idx < sub_mbr_cnt; sub_mbr_idx++)
3160 {
3161 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3162 var, var_type&: mbr_type, mbr_idx: sub_mbr_idx,
3163 meta, mbr_name_qual: mbr_name, var_chain_qual: var_chain,
3164 location, var_mbr_idx, interpolation_qual: inherited_qual);
3165				// FIXME: Recursive structs and tessellation break here.
3166 var_mbr_idx++;
3167 }
3168 }
3169 return;
3170 }
3171
3172 for (uint32_t i = 0; i < elem_cnt; i++)
3173 {
3174 // Add a reference to the variable type to the interface struct.
3175 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3176 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3177 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: usable_type->self, is_noperspective));
3178 else
3179 ib_type.member_types.push_back(t: usable_type->self);
3180
3181 // Give the member a name
3182 string mbr_name = ensure_valid_name(name: append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx) + (mbr_is_indexable ? join(ts: "_", ts&: i) : ""), pfx: "m");
3183 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3184
3185 // Once we determine the location of the first member within nested structures,
3186 // from a var of the topmost structure, the remaining flattened members of
3187 // the nested structures will have consecutive location values. At this point,
3188 // we've recursively tunnelled into structs, arrays, and matrices, and are
3189 // down to a single location for each member now.
3190 if (!is_builtin && location != UINT32_MAX)
3191 {
3192 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3193 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3194 location++;
3195 }
3196 else if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation))
3197 {
3198 location = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation) + i;
3199 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3200 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3201 location++;
3202 }
3203 else if (has_decoration(id: var.self, decoration: DecorationLocation))
3204 {
3205 location = get_accumulated_member_location(var, mbr_idx, strip_array: meta.strip_array) + i;
3206 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3207 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3208 location++;
3209 }
3210 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
3211 {
3212 location = inputs_by_builtin[builtin].location + i;
3213 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3214 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3215 location++;
3216 }
3217 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
3218 {
3219 location = outputs_by_builtin[builtin].location + i;
3220 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3221 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3222 location++;
3223 }
3224 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
3225 {
3226 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
3227 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3228 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: i);
3229 }
3230
3231 if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent))
3232 SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays is not supported.");
3233
3234 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
3235 {
3236 // Copy interpolation decorations if needed
3237 if (is_flat)
3238 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
3239 if (is_noperspective)
3240 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
3241 if (is_centroid)
3242 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
3243 if (is_sample)
3244 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
3245 }
3246
3247 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
3248 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: var_mbr_idx);
3249
3250 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
3251 if (!meta.strip_array && meta.allow_local_declaration)
3252 {
3253 string var_chain = join(ts: var_chain_qual, ts: ".", ts: to_member_name(type: var_type, index: mbr_idx), ts: (mbr_is_indexable ? join(ts: "[", ts&: i, ts: "]") : ""));
3254 switch (storage)
3255 {
3256 case StorageClassInput:
3257 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
3258 string lerp_call;
3259 if (pull_model_inputs.count(x: var.self))
3260 {
3261 if (is_centroid)
3262 lerp_call = ".interpolate_at_centroid()";
3263 else if (is_sample)
3264 lerp_call = join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
3265 else
3266 lerp_call = ".interpolate_at_center()";
3267 }
3268 statement(ts: var_chain, ts: " = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts&: lerp_call, ts: ";");
3269 });
3270 break;
3271
3272 case StorageClassOutput:
3273 entry_func.fixup_hooks_out.push_back(t: [=]() {
3274 if (flatten_from_ib_var)
3275 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: ib_var_ref, ts: ".", ts: flatten_from_ib_mbr_name, ts: "[", ts: i, ts: "];");
3276 else
3277 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: var_chain, ts: ";");
3278 });
3279 break;
3280
3281 default:
3282 break;
3283 }
3284 }
3285 }
3286}
3287
3288void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage,
3289 const string &ib_var_ref, SPIRType &ib_type,
3290 SPIRVariable &var, SPIRType &var_type,
3291 uint32_t mbr_idx, InterfaceBlockMeta &meta,
3292 const string &mbr_name_qual,
3293 const string &var_chain_qual,
3294 uint32_t &location, uint32_t &var_mbr_idx)
3295{
3296 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3297
3298 BuiltIn builtin = BuiltInMax;
3299 bool is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
3300 bool is_flat =
3301 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationFlat) || has_decoration(id: var.self, decoration: DecorationFlat);
3302 bool is_noperspective = has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationNoPerspective) ||
3303 has_decoration(id: var.self, decoration: DecorationNoPerspective);
3304 bool is_centroid = has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationCentroid) ||
3305 has_decoration(id: var.self, decoration: DecorationCentroid);
3306 bool is_sample =
3307 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationSample) || has_decoration(id: var.self, decoration: DecorationSample);
3308
3309 // Add a reference to the member to the interface struct.
3310 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
3311 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3312 mbr_type_id = ensure_correct_builtin_type(type_id: mbr_type_id, builtin);
3313 var_type.member_types[mbr_idx] = mbr_type_id;
3314 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3315 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective));
3316 else
3317 ib_type.member_types.push_back(t: mbr_type_id);
3318
3319 // Give the member a name
3320 string mbr_name = ensure_valid_name(name: append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx), pfx: "m");
3321 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3322
3323 // Update the original variable reference to include the structure reference
3324 string qual_var_name = ib_var_ref + "." + mbr_name;
3325 // If using pull-model interpolation, need to add a call to the correct interpolation method.
3326 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3327 {
3328 if (is_centroid)
3329 qual_var_name += ".interpolate_at_centroid()";
3330 else if (is_sample)
3331 qual_var_name += join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
3332 else
3333 qual_var_name += ".interpolate_at_center()";
3334 }
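// Hedged illustration ('m_foo' and the builtin name are placeholders): for a sample-qualified
// pull-model input the qualified alias built above reads roughly
//     in.m_foo.interpolate_at_sample(gl_SampleID)
// and that expression is substituted wherever the member is read.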
3335
3336 bool flatten_stage_out = false;
3337 string var_chain = var_chain_qual + "." + to_member_name(type: var_type, index: mbr_idx);
3338 if (is_builtin && !meta.strip_array)
3339 {
3340		// For the builtin gl_PerVertex, we cannot treat it as a block anyway,
3341		// so redirect accesses to the qualified member name instead.
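		// Hedged illustration: a write to gl_Position in the shader body is then emitted directly
		// as something like 'out.gl_Position = ...;' instead of going through a local copy.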
3342 set_member_qualified_name(type_id: var_type.self, index: mbr_idx, name: qual_var_name);
3343 }
3344 else if (!meta.strip_array && meta.allow_local_declaration)
3345 {
3346 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
3347 switch (storage)
3348 {
3349 case StorageClassInput:
3350 entry_func.fixup_hooks_in.push_back(t: [=]() {
3351 statement(ts: var_chain, ts: " = ", ts: qual_var_name, ts: ";");
3352 });
3353 break;
3354
3355 case StorageClassOutput:
3356 flatten_stage_out = true;
3357 entry_func.fixup_hooks_out.push_back(t: [=]() {
3358 statement(ts: qual_var_name, ts: " = ", ts: var_chain, ts: ";");
3359 });
3360 break;
3361
3362 default:
3363 break;
3364 }
3365 }
3366
3367 // Once we determine the location of the first member within nested structures,
3368 // from a var of the topmost structure, the remaining flattened members of
3369 // the nested structures will have consecutive location values. At this point,
3370 // we've recursively tunnelled into structs, arrays, and matrices, and are
3371 // down to a single location for each member now.
3372 if (!is_builtin && location != UINT32_MAX)
3373 {
3374 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3375 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3376 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3377 }
3378 else if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation))
3379 {
3380 location = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation);
3381 uint32_t comp = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent);
3382 if (storage == StorageClassInput)
3383 {
3384 mbr_type_id = ensure_correct_input_type(type_id: mbr_type_id, location, component: comp, num_components: 0, strip_array: meta.strip_array);
3385 var_type.member_types[mbr_idx] = mbr_type_id;
3386 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3387 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
3388 else
3389 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
3390 }
3391 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3392 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3393 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3394 }
3395 else if (has_decoration(id: var.self, decoration: DecorationLocation))
3396 {
3397 location = get_accumulated_member_location(var, mbr_idx, strip_array: meta.strip_array);
3398 if (storage == StorageClassInput)
3399 {
3400 mbr_type_id = ensure_correct_input_type(type_id: mbr_type_id, location, component: 0, num_components: 0, strip_array: meta.strip_array);
3401 var_type.member_types[mbr_idx] = mbr_type_id;
3402 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3403 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
3404 else
3405 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
3406 }
3407 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3408 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3409 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3410 }
3411 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
3412 {
3413 location = inputs_by_builtin[builtin].location;
3414 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3415 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3416 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3417 }
3418 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
3419 {
3420 location = outputs_by_builtin[builtin].location;
3421 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3422 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3423 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3424 }
3425
3426 // Copy the component location, if present.
3427 if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent))
3428 {
3429 uint32_t comp = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent);
3430 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: comp);
3431 }
3432
3433 // Mark the member as builtin if needed
3434 if (is_builtin)
3435 {
3436 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3437 if (builtin == BuiltInPosition && storage == StorageClassOutput)
3438 qual_pos_var_name = qual_var_name;
3439 }
3440
3441 const SPIRConstant *c = nullptr;
3442 if (!flatten_stage_out && var.storage == StorageClassOutput &&
3443 var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(id: var.initializer)))
3444 {
3445 if (meta.strip_array)
3446 {
3447 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
3448 auto &type = this->get<SPIRType>(id: var.basetype);
3449 uint32_t index = get_extended_member_decoration(type: var.self, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
3450
3451 auto invocation = to_tesc_invocation_id();
3452 auto constant_chain = join(ts: to_expression(id: var.initializer), ts: "[", ts&: invocation, ts: "]");
3453 statement(ts: to_expression(id: stage_out_ptr_var_id), ts: "[",
3454 ts&: invocation, ts: "].",
3455 ts: to_member_name(type: ib_type, index), ts: " = ",
3456 ts&: constant_chain, ts: ".", ts: to_member_name(type, index: mbr_idx), ts: ";");
3457 });
3458 }
3459 else
3460 {
3461 entry_func.fixup_hooks_in.push_back(t: [=]() {
3462 statement(ts: qual_var_name, ts: " = ", ts: constant_expression(
3463 c: this->get<SPIRConstant>(id: c->subconstants[mbr_idx])), ts: ";");
3464 });
3465 }
3466 }
3467
3468 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
3469 {
3470 // Copy interpolation decorations if needed
3471 if (is_flat)
3472 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
3473 if (is_noperspective)
3474 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
3475 if (is_centroid)
3476 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
3477 if (is_sample)
3478 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
3479 }
3480
3481 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
3482 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: var_mbr_idx);
3483}
3484
3485// In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
3486// But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
3487// individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
3488// levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
3489// float2 containing the inner levels.
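// Hedged sketch of the resulting stage-in members (attribute indices N and M are placeholders):
// for a triangle domain the two built-ins share a single member, roughly
//     float4 gl_TessLevel [[attribute(N)]];
// with the outer levels read from .xyz and the inner level from .w, while a quad domain gets
//     float4 gl_TessLevelOuter [[attribute(N)]];
//     float2 gl_TessLevelInner [[attribute(M)]];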
3490void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
3491 SPIRVariable &var)
3492{
3493 auto &var_type = get_variable_element_type(var);
3494
3495 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3496 bool triangles = is_tessellating_triangles();
3497 string mbr_name;
3498
3499 // Add a reference to the variable type to the interface struct.
3500 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3501
3502 const auto mark_locations = [&](const SPIRType &new_var_type) {
3503 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation))
3504 {
3505 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation);
3506 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
3507 mark_location_as_used_by_shader(location: locn, type: new_var_type, storage: StorageClassInput);
3508 }
3509 else if (inputs_by_builtin.count(x: builtin))
3510 {
3511 uint32_t locn = inputs_by_builtin[builtin].location;
3512 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
3513 mark_location_as_used_by_shader(location: locn, type: new_var_type, storage: StorageClassInput);
3514 }
3515 };
3516
3517 if (triangles)
3518 {
3519		// Triangles are tricky, because the outer and inner levels must share a single member in the struct.
3520 mbr_name = "gl_TessLevel";
3521
3522		// If we already added that shared member for the other tess-level builtin, we can skip this step.
3523 if (!added_builtin_tess_level)
3524 {
3525 uint32_t type_id = build_extended_vector_type(type_id: var_type.self, components: 4);
3526
3527 ib_type.member_types.push_back(t: type_id);
3528
3529 // Give the member a name
3530 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3531
3532			// We cannot decorate the shared member as both TessLevelOuter and TessLevelInner,
3533			// but the important part is that it's marked as builtin so we can get automatic attribute assignment if needed.
3534 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3535
3536 mark_locations(var_type);
3537 added_builtin_tess_level = true;
3538 }
3539 }
3540 else
3541 {
3542 mbr_name = builtin_to_glsl(builtin, storage: StorageClassFunction);
3543
3544 uint32_t type_id = build_extended_vector_type(type_id: var_type.self, components: builtin == BuiltInTessLevelOuter ? 4 : 2);
3545
3546 uint32_t ptr_type_id = ir.increase_bound_by(count: 1);
3547 auto &new_var_type = set<SPIRType>(id: ptr_type_id, args&: get<SPIRType>(id: type_id));
3548 new_var_type.pointer = true;
3549 new_var_type.pointer_depth++;
3550 new_var_type.storage = StorageClassInput;
3551 new_var_type.parent_type = type_id;
3552
3553 ib_type.member_types.push_back(t: type_id);
3554
3555 // Give the member a name
3556 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3557 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3558
3559 mark_locations(new_var_type);
3560 }
3561
3562 add_tess_level_input(base_ref: ib_var_ref, mbr_name, var);
3563}
3564
3565void CompilerMSL::add_tess_level_input(const std::string &base_ref, const std::string &mbr_name, SPIRVariable &var)
3566{
3567 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3568 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3569
3570 // Force the variable to have the proper name.
3571 string var_name = builtin_to_glsl(builtin, storage: StorageClassFunction);
3572 set_name(id: var.self, name: var_name);
3573
3574 // We need to declare the variable early and at entry-point scope.
3575 entry_func.add_local_variable(id: var.self);
3576 vars_needing_early_declaration.push_back(t: var.self);
3577 bool triangles = is_tessellating_triangles();
3578
3579 if (builtin == BuiltInTessLevelOuter)
3580 {
3581 entry_func.fixup_hooks_in.push_back(
3582 t: [=]()
3583 {
3584 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[0];");
3585 statement(ts: var_name, ts: "[1] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[1];");
3586 statement(ts: var_name, ts: "[2] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[2];");
3587 if (!triangles)
3588 statement(ts: var_name, ts: "[3] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[3];");
3589 });
3590 }
3591 else
3592 {
3593 entry_func.fixup_hooks_in.push_back(t: [=]() {
3594 if (triangles)
3595 {
3596 if (msl_options.raw_buffer_tese_input)
3597 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: ";");
3598 else
3599 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[3];");
3600 }
3601 else
3602 {
3603 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[0];");
3604 statement(ts: var_name, ts: "[1] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[1];");
3605 }
3606 });
3607 }
3608}
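// Hedged illustration of the emitted fixup code (the stage-in reference and member names are
// placeholders): for a quad domain the hooks above produce roughly
//     gl_TessLevelOuter[0] = <stage-in>.gl_TessLevelOuter[0];  /* ... through [3] */
//     gl_TessLevelInner[0] = <stage-in>.gl_TessLevelInner[0];
//     gl_TessLevelInner[1] = <stage-in>.gl_TessLevelInner[1];
// while the triangle path reads all four levels out of the single packed gl_TessLevel member.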
3609
3610bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const
3611{
3612 if (storage == StorageClassOutput)
3613 return !capture_output_to_buffer;
3614 else if (storage == StorageClassInput)
3615 return !(is_tesc_shader() && msl_options.multi_patch_workgroup) &&
3616 !(is_tese_shader() && msl_options.raw_buffer_tese_input);
3617 else
3618 return false;
3619}
3620
3621string CompilerMSL::to_tesc_invocation_id()
3622{
3623 if (msl_options.multi_patch_workgroup)
3624 {
3625		// N.B.: builtin_invocation_id_id here is the dispatch-wide global invocation ID,
3626		// not the per-patch tessellation control invocation ID (gl_InvocationID).
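		// As a hedged worked example: with output_vertices == 4, global invocation 10 maps to
		// invocation 10 % 4 == 2 within its patch.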
3627 return join(ts: to_expression(id: builtin_invocation_id_id), ts: ".x % ", ts&: get_entry_point().output_vertices);
3628 }
3629 else
3630 return builtin_to_glsl(builtin: BuiltInInvocationId, storage: StorageClassInput);
3631}
3632
3633void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array)
3634{
3635 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3636 bool threadgroup_storage = variable_decl_is_remapped_storage(variable: masked_var, storage: StorageClassWorkgroup);
3637
3638 if (threadgroup_storage && msl_options.multi_patch_workgroup)
3639 {
3640		// We need one threadgroup block per patch, so emulate that by handing each patch a slice of a shared threadgroup array.
3641 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var]() {
3642 auto &type = get_variable_data_type(var: masked_var);
3643 add_local_variable_name(id: masked_var.self);
3644
3645 const uint32_t max_control_points_per_patch = 32u;
3646 uint32_t max_num_instances =
3647 (max_control_points_per_patch + get_entry_point().output_vertices - 1u) /
3648 get_entry_point().output_vertices;
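			// Hedged worked example: with 4 output vertices per patch this evaluates to
			// (32 + 4 - 1) / 4 == 8 threadgroup slices.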
3649 statement(ts: "threadgroup ", ts: type_to_glsl(type), ts: " ",
3650 ts: "spvStorage", ts: to_name(id: masked_var.self), ts: "[", ts&: max_num_instances, ts: "]",
3651 ts: type_to_array_glsl(type, variable_id: 0), ts: ";");
3652
3653 // Assign a threadgroup slice to each PrimitiveID.
3654 // We assume here that workgroup size is rounded to 32,
3655 // since that's the maximum number of control points per patch.
3656 // We cannot size the array based on fixed dispatch parameters,
3657 // since Metal does not allow that. :(
3658 // FIXME: We will likely need an option to support passing down target workgroup size,
3659 // so we can emit appropriate size here.
3660 statement(ts: "threadgroup auto ",
3661 ts: "&", ts: to_name(id: masked_var.self),
3662 ts: " = spvStorage", ts: to_name(id: masked_var.self), ts: "[",
3663 ts: "(", ts: to_expression(id: builtin_invocation_id_id), ts: ".x / ",
3664 ts&: get_entry_point().output_vertices, ts: ") % ",
3665 ts&: max_num_instances, ts: "];");
3666 });
3667 }
3668 else
3669 {
3670 entry_func.add_local_variable(id: masked_var.self);
3671 }
3672
3673 if (!threadgroup_storage)
3674 {
3675 vars_needing_early_declaration.push_back(t: masked_var.self);
3676 }
3677 else if (masked_var.initializer)
3678 {
3679 // Cannot directly initialize threadgroup variables. Need fixup hooks.
3680 ID initializer = masked_var.initializer;
3681 if (strip_array)
3682 {
3683 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var, initializer]() {
3684 auto invocation = to_tesc_invocation_id();
3685 statement(ts: to_expression(id: masked_var.self), ts: "[",
3686 ts&: invocation, ts: "] = ",
3687 ts: to_expression(id: initializer), ts: "[",
3688 ts&: invocation, ts: "];");
3689 });
3690 }
3691 else
3692 {
3693 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var, initializer]() {
3694 statement(ts: to_expression(id: masked_var.self), ts: " = ", ts: to_expression(id: initializer), ts: ";");
3695 });
3696 }
3697 }
3698}
3699
3700void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
3701 SPIRVariable &var, InterfaceBlockMeta &meta)
3702{
3703 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3704 // Tessellation control I/O variables and tessellation evaluation per-point inputs are
3705 // usually declared as arrays. In these cases, we want to add the element type to the
3706 // interface block, since in Metal it's the interface block itself which is arrayed.
3707 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
3708 bool is_builtin = is_builtin_variable(var);
3709 auto builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3710 bool is_block = has_decoration(id: var_type.self, decoration: DecorationBlock);
3711
3712 // If stage variables are masked out, emit them as plain variables instead.
3713 // For builtins, we query them one by one later.
3714 // IO blocks are not masked here, we need to mask them per-member instead.
3715 if (storage == StorageClassOutput && is_stage_output_variable_masked(var))
3716 {
3717		// Even when we ignore a masked output, we must still emit it, since the shader body may still read or write it.
3718		// Instead, just emit it as an early local declaration.
3719 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3720 return;
3721 }
3722
3723 if (storage == StorageClassInput && has_decoration(id: var.self, decoration: DecorationPerVertexKHR))
3724 SPIRV_CROSS_THROW("PerVertexKHR decoration is not supported in MSL.");
3725
3726	// If variable names alias, they will end up with wrong names in the interface struct, because
3727	// there might be aliases in the member name cache, which would cause a mismatch in the fixup_in code.
3728 // Make sure to register the variables as unique resource names ahead of time.
3729 // This would normally conflict with the name cache when emitting local variables,
3730 // but this happens in the setup stage, before we hit compilation loops.
3731 // The name cache is cleared before we actually emit code, so this is safe.
3732 add_resource_name(id: var.self);
3733
3734 if (var_type.basetype == SPIRType::Struct)
3735 {
3736 bool block_requires_flattening =
3737 variable_storage_requires_stage_io(storage) || (is_block && var_type.array.empty());
3738 bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration;
3739
3740 if (needs_local_declaration)
3741 {
3742 // For I/O blocks or structs, we will need to pass the block itself around
3743 // to functions if they are used globally in leaf functions.
3744 // Rather than passing down member by member,
3745 // we unflatten I/O blocks while running the shader,
3746 // and pass the actual struct type down to leaf functions.
3747 // We then unflatten inputs, and flatten outputs in the "fixup" stages.
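			// Hedged illustration (names are placeholders): an input block 'VertexData vin' would be
			// declared as a local at entry-point scope and repopulated member by member in the
			// fixup-in hooks, roughly
			//     vin.normal = in.vin_normal;
			//     vin.uv = in.vin_uv;
			// with the reverse copies emitted for output blocks in the fixup-out hooks.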
3748 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3749 }
3750
3751 if (!block_requires_flattening)
3752 {
3753 // In Metal tessellation shaders, the interface block itself is arrayed. This makes things
3754 // very complicated, since stage-in structures in MSL don't support nested structures.
3755 // Luckily, for stage-out when capturing output, we can avoid this and just add
3756 // composite members directly, because the stage-out structure is stored to a buffer,
3757 // not returned.
3758 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3759 }
3760 else
3761 {
3762 bool masked_block = false;
3763 uint32_t location = UINT32_MAX;
3764 uint32_t var_mbr_idx = 0;
3765 uint32_t elem_cnt = 1;
3766 if (is_matrix(type: var_type))
3767 {
3768 if (is_array(type: var_type))
3769 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
3770
3771 elem_cnt = var_type.columns;
3772 }
3773 else if (is_array(type: var_type))
3774 {
3775 if (var_type.array.size() != 1)
3776 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
3777
3778 elem_cnt = to_array_size_literal(type: var_type);
3779 }
3780
3781 for (uint32_t elem_idx = 0; elem_idx < elem_cnt; elem_idx++)
3782 {
3783 // Flatten the struct members into the interface struct
3784 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3785 {
3786 builtin = BuiltInMax;
3787 is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
3788 auto &mbr_type = get<SPIRType>(id: var_type.member_types[mbr_idx]);
3789
3790 if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, index: mbr_idx, strip_array: meta.strip_array))
3791 {
3792 location = UINT32_MAX; // Skip this member and resolve location again on next var member
3793
3794 if (is_block)
3795 masked_block = true;
3796
3797					// Non-builtin masked block output members are simply ignored here, since the shader will still
3798					// access the block variable as-is. They're just not flattened into the stage-out struct.
3799 if (is_builtin && !meta.strip_array)
3800 {
3801 // Emit a fake variable instead.
3802 uint32_t ids = ir.increase_bound_by(count: 2);
3803 uint32_t ptr_type_id = ids + 0;
3804 uint32_t var_id = ids + 1;
3805
3806 auto ptr_type = mbr_type;
3807 ptr_type.pointer = true;
3808 ptr_type.pointer_depth++;
3809 ptr_type.parent_type = var_type.member_types[mbr_idx];
3810 ptr_type.storage = StorageClassOutput;
3811
3812 uint32_t initializer = 0;
3813 if (var.initializer)
3814 if (auto *c = maybe_get<SPIRConstant>(id: var.initializer))
3815 initializer = c->subconstants[mbr_idx];
3816
3817 set<SPIRType>(id: ptr_type_id, args&: ptr_type);
3818 set<SPIRVariable>(id: var_id, args&: ptr_type_id, args: StorageClassOutput, args&: initializer);
3819 entry_func.add_local_variable(id: var_id);
3820 vars_needing_early_declaration.push_back(t: var_id);
3821 set_name(id: var_id, name: builtin_to_glsl(builtin, storage: StorageClassOutput));
3822 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: builtin);
3823 }
3824 }
3825 else if (!is_builtin || has_active_builtin(builtin, storage))
3826 {
3827 bool is_composite_type = is_matrix(type: mbr_type) || is_array(type: mbr_type) || mbr_type.basetype == SPIRType::Struct;
3828 bool attribute_load_store =
3829 storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3830 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3831
3832 // Clip/CullDistance always need to be declared as user attributes.
3833 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3834 is_builtin = false;
3835
3836 const string var_name = to_name(id: var.self);
3837 string mbr_name_qual = var_name;
3838 string var_chain_qual = var_name;
3839 if (elem_cnt > 1)
3840 {
3841 mbr_name_qual += join(ts: "_", ts&: elem_idx);
3842 var_chain_qual += join(ts: "[", ts&: elem_idx, ts: "]");
3843 }
3844
3845 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3846 {
3847 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3848 var, var_type, mbr_idx, meta,
3849 mbr_name_qual, var_chain_qual,
3850 location, var_mbr_idx, interpolation_qual: {});
3851 }
3852 else
3853 {
3854 add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3855 var, var_type, mbr_idx, meta,
3856 mbr_name_qual, var_chain_qual,
3857 location, var_mbr_idx);
3858 }
3859 }
3860 var_mbr_idx++;
3861 }
3862 }
3863
3864 // If we're redirecting a block, we might still need to access the original block
3865 // variable if we're masking some members.
3866 if (masked_block && !needs_local_declaration && (!is_builtin_variable(var) || is_tesc_shader()))
3867 {
3868 if (is_builtin_variable(var))
3869 {
3870 // Ensure correct names for the block members if we're actually going to
3871 // declare gl_PerVertex.
3872 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3873 {
3874 set_member_name(id: var_type.self, index: mbr_idx, name: builtin_to_glsl(
3875 builtin: BuiltIn(get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationBuiltIn)),
3876 storage: StorageClassOutput));
3877 }
3878
3879 set_name(id: var_type.self, name: "gl_PerVertex");
3880 set_name(id: var.self, name: "gl_out_masked");
3881 stage_out_masked_builtin_type_id = var_type.self;
3882 }
3883 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3884 }
3885 }
3886 }
3887 else if (is_tese_shader() && storage == StorageClassInput && !meta.strip_array && is_builtin &&
3888 (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
3889 {
3890 add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
3891 }
3892 else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
3893 type_is_integral(type: var_type) || type_is_floating_point(type: var_type))
3894 {
3895 if (!is_builtin || has_active_builtin(builtin, storage))
3896 {
3897 bool is_composite_type = is_matrix(type: var_type) || is_array(type: var_type);
3898 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3899 bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3900
3901			// Clip/CullDistance always need to be declared as user attributes.
3902 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3903 is_builtin = false;
3904
3905 // MSL does not allow matrices or arrays in input or output variables, so need to handle it specially.
3906 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3907 {
3908 add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3909 }
3910 else
3911 {
3912 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3913 }
3914 }
3915 }
3916}
3917
3918// Fix up the mapping of variables to interface member indices, which is used to compile access chains
3919// for per-vertex variables in a tessellation control shader.
3920void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
3921{
3922 // Only needed for tessellation shaders and pull-model interpolants.
3923 // Need to redirect interface indices back to variables themselves.
3924	// For structs, each member of the struct needs a separate instance.
3925 if (!is_tesc_shader() && !(is_tese_shader() && storage == StorageClassInput) &&
3926 !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
3927 !pull_model_inputs.empty()))
3928 return;
3929
3930 auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
3931 for (uint32_t i = 0; i < mbr_cnt; i++)
3932 {
3933 uint32_t var_id = get_extended_member_decoration(type: ib_type_id, index: i, decoration: SPIRVCrossDecorationInterfaceOrigID);
3934 if (!var_id)
3935 continue;
3936 auto &var = get<SPIRVariable>(id: var_id);
3937
3938 auto &type = get_variable_element_type(var);
3939
3940 bool flatten_composites = variable_storage_requires_stage_io(storage: var.storage);
3941 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
3942
3943 uint32_t mbr_idx = uint32_t(-1);
3944 if (type.basetype == SPIRType::Struct && (flatten_composites || is_block))
3945 mbr_idx = get_extended_member_decoration(type: ib_type_id, index: i, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
3946
3947 if (mbr_idx != uint32_t(-1))
3948 {
3949 // Only set the lowest InterfaceMemberIndex for each variable member.
3950 // IB struct members will be emitted in-order w.r.t. interface member index.
3951 if (!has_extended_member_decoration(type: var_id, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex))
3952 set_extended_member_decoration(type: var_id, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: i);
3953 }
3954 else
3955 {
3956 // Only set the lowest InterfaceMemberIndex for each variable.
3957 // IB struct members will be emitted in-order w.r.t. interface member index.
3958 if (!has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationInterfaceMemberIndex))
3959 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: i);
3960 }
3961 }
3962}
3963
3964// Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
3965// Returns the ID of the newly added variable, or zero if no variable was added.
3966uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
3967{
3968 // Accumulate the variables that should appear in the interface struct.
3969 SmallVector<SPIRVariable *> vars;
3970 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
3971 bool has_seen_barycentric = false;
3972
3973 InterfaceBlockMeta meta;
3974
3975	// Varying interfaces between stages which use the "user()" attribute can be dealt with
3976	// without explicit packing and unpacking of components. For any variables which link against the runtime
3977	// in some way (vertex attributes, fragment outputs, etc.), we need to pack and unpack components explicitly.
3978 bool pack_components =
3979 (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
3980 (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
3981 (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
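	// Hedged illustration: if one variable covers location 0, components 0..1 and another covers
	// location 0, components 2..3, the location metadata gathered below widens the shared interface
	// member to a 4-component vector and swizzles each variable's values in and out of it.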
3982
3983 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t var_id, SPIRVariable &var) {
3984 if (var.storage != storage)
3985 return;
3986
3987 auto &type = this->get<SPIRType>(id: var.basetype);
3988
3989 bool is_builtin = is_builtin_variable(var);
3990 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
3991
3992 auto bi_type = BuiltInMax;
3993 bool builtin_is_gl_in_out = false;
3994 if (is_builtin && !is_block)
3995 {
3996 bi_type = BuiltIn(get_decoration(id: var_id, decoration: DecorationBuiltIn));
3997 builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3998 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
3999 }
4000
4001 if (is_builtin && is_block)
4002 builtin_is_gl_in_out = true;
4003
4004 uint32_t location = get_decoration(id: var_id, decoration: DecorationLocation);
4005
4006 bool builtin_is_stage_in_out = builtin_is_gl_in_out ||
4007 bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
4008 bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR ||
4009 bi_type == BuiltInFragDepth ||
4010 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask;
4011
4012 // These builtins are part of the stage in/out structs.
4013 bool is_interface_block_builtin =
4014 builtin_is_stage_in_out || (is_tese_shader() && !msl_options.raw_buffer_tese_input &&
4015 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
4016
4017 bool is_active = interface_variable_exists_in_entry_point(id: var.self);
4018 if (is_builtin && is_active)
4019 {
4020 // Only emit the builtin if it's active in this entry point. Interface variable list might lie.
4021 if (is_block)
4022 {
4023 // If any builtin is active, the block is active.
4024 uint32_t mbr_cnt = uint32_t(type.member_types.size());
4025 for (uint32_t i = 0; !is_active && i < mbr_cnt; i++)
4026 is_active = has_active_builtin(builtin: BuiltIn(get_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn)), storage);
4027 }
4028 else
4029 {
4030 is_active = has_active_builtin(builtin: bi_type, storage);
4031 }
4032 }
4033
4034 bool filter_patch_decoration = (has_decoration(id: var_id, decoration: DecorationPatch) || is_patch_block(type)) == patch;
4035
4036 bool hidden = is_hidden_variable(var, include_builtins: incl_builtins);
4037
4038 // ClipDistance and CullDistance are never hidden; we need to emulate them when used as inputs.
4039 if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance)
4040 hidden = false;
4041
4042 // It's not enough to simply avoid marking fragment outputs if the pipeline won't
4043 // accept them. We can't put them in the struct at all, or otherwise the compiler
4044 // complains that the outputs weren't explicitly marked.
4045 // Frag depth and stencil outputs are incompatible with explicit early fragment tests.
4046 // In GLSL, depth and stencil outputs are just ignored when explicit early fragment tests are required.
4047 // In Metal, it's a compilation error, so we need to exclude them from the output struct.
4048 if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
4049 ((is_builtin && ((bi_type == BuiltInFragDepth && (!msl_options.enable_frag_depth_builtin || uses_explicit_early_fragment_test())) ||
4050 (bi_type == BuiltInFragStencilRefEXT && (!msl_options.enable_frag_stencil_ref_builtin || uses_explicit_early_fragment_test())))) ||
4051 (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
4052 {
4053 hidden = true;
4054 disabled_frag_outputs.push_back(t: var_id);
4055 // If a builtin, force it to have the proper name, and mark it as not part of the output struct.
4056 if (is_builtin)
4057 {
4058 set_name(id: var_id, name: builtin_to_glsl(builtin: bi_type, storage: StorageClassFunction));
4059 mask_stage_output_by_builtin(builtin: bi_type);
4060 }
4061 }
4062
4063 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
4064 if (is_active && (bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR))
4065 {
4066 if (has_seen_barycentric)
4067 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
4068 has_seen_barycentric = true;
4069 hidden = false;
4070 }
4071
4072 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
4073 (!is_builtin || is_interface_block_builtin))
4074 {
4075 vars.push_back(t: &var);
4076
4077 if (!is_builtin)
4078 {
4079 // Need to deal specially with DecorationComponent.
4080 // Multiple variables can alias the same Location, so we try to make sure each location is declared only once.
4081 // We will swizzle data in and out to make this work.
4082 // This is only relevant for vertex inputs and fragment outputs.
4083 // Technically tessellation as well, but it is too complicated to support.
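// For illustration (hypothetical fragment shader): two outputs sharing location 3,
//   layout(location = 3, component = 0) out vec2 a;
//   layout(location = 3, component = 2) out vec2 b;
// are folded into a single 4-component interface member (named along the lines of
// "m_location_3" further below), and accesses to 'a' and 'b' are swizzled into its
// .xy and .zw components respectively.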
4084 uint32_t component = get_decoration(id: var_id, decoration: DecorationComponent);
4085 if (component != 0)
4086 {
4087 if (is_tessellation_shader())
4088 SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
4089 else if (pack_components)
4090 {
4091 uint32_t array_size = 1;
4092 if (!type.array.empty())
4093 array_size = to_array_size_literal(type);
4094
4095 for (uint32_t location_offset = 0; location_offset < array_size; location_offset++)
4096 {
4097 auto &location_meta = meta.location_meta[location + location_offset];
4098 location_meta.num_components = max<uint32_t>(a: location_meta.num_components, b: component + type.vecsize);
4099
4100 // For variables sharing location, decorations and base type must match.
4101 location_meta.base_type_id = type.self;
4102 location_meta.flat = has_decoration(id: var.self, decoration: DecorationFlat);
4103 location_meta.noperspective = has_decoration(id: var.self, decoration: DecorationNoPerspective);
4104 location_meta.centroid = has_decoration(id: var.self, decoration: DecorationCentroid);
4105 location_meta.sample = has_decoration(id: var.self, decoration: DecorationSample);
4106 }
4107 }
4108 }
4109 }
4110 }
4111
4112 if (is_tese_shader() && msl_options.raw_buffer_tese_input && patch && storage == StorageClassInput &&
4113 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner))
4114 {
4115 // In this case, we won't add the builtin to the interface struct,
4116 // but we still need the hook to run to populate the arrays.
4117 string base_ref = join(ts&: tess_factor_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: "]");
4118 const char *mbr_name =
4119 bi_type == BuiltInTessLevelOuter ? "edgeTessellationFactor" : "insideTessellationFactor";
4120 add_tess_level_input(base_ref, mbr_name, var);
4121 if (inputs_by_builtin.count(x: bi_type))
4122 {
4123 uint32_t locn = inputs_by_builtin[bi_type].location;
4124 mark_location_as_used_by_shader(location: locn, type, storage: StorageClassInput);
4125 }
4126 }
4127 });
4128
4129 // If no variables qualify, leave.
4130 // For patch input in a tessellation evaluation shader, the per-vertex stage inputs
4131 // are included in a special patch control point array.
4132 if (vars.empty() &&
4133 !(!msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch && stage_in_var_id))
4134 return 0;
4135
4136 // Add a new typed variable for this interface structure.
4137 // The initializer expression is allocated here, but populated when the function
4138 // declaration is emitted, because it is cleared after each compilation pass.
4139 uint32_t next_id = ir.increase_bound_by(count: 3);
4140 uint32_t ib_type_id = next_id++;
4141 auto &ib_type = set<SPIRType>(id: ib_type_id, args: OpTypeStruct);
4142 ib_type.basetype = SPIRType::Struct;
4143 ib_type.storage = storage;
4144 set_decoration(id: ib_type_id, decoration: DecorationBlock);
4145
4146 uint32_t ib_var_id = next_id++;
4147 auto &var = set<SPIRVariable>(id: ib_var_id, args&: ib_type_id, args&: storage, args: 0);
4148 var.initializer = next_id++;
4149
4150 string ib_var_ref;
4151 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
4152 switch (storage)
4153 {
4154 case StorageClassInput:
4155 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
4156 switch (get_execution_model())
4157 {
4158 case ExecutionModelTessellationControl:
4159 // Add a hook to populate the shared workgroup memory containing the gl_in array.
4160 entry_func.fixup_hooks_in.push_back(t: [=]() {
4161 // Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
4162 if (msl_options.multi_patch_workgroup)
4163 {
4164 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
4165 // not the TC invocation ID.
4166 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_in = &",
4167 ts&: input_buffer_var_name, ts: "[min(", ts: to_expression(id: builtin_invocation_id_id), ts: ".x / ",
4168 ts&: get_entry_point().output_vertices,
4169 ts: ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
4170 }
4171 else
4172 {
4173 // It's safe to use InvocationId here because it's directly mapped to a
4174 // Metal builtin, and therefore doesn't need a hook.
4175 statement(ts: "if (", ts: to_expression(id: builtin_invocation_id_id), ts: " < spvIndirectParams[0])");
4176 statement(ts: " ", ts&: input_wg_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4177 ts: "] = ", ts: ib_var_ref, ts: ";");
4178 statement(ts: "threadgroup_barrier(mem_flags::mem_threadgroup);");
4179 statement(ts: "if (", ts: to_expression(id: builtin_invocation_id_id),
4180 ts: " >= ", ts&: get_entry_point().output_vertices, ts: ")");
4181 statement(ts: " return;");
4182 }
4183 });
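// For illustration (typical default identifiers, e.g. 4 output vertices), the
// non-multi-patch hook above emits MSL along the lines of:
//   if (gl_InvocationID < spvIndirectParams[0])
//       gl_in[gl_InvocationID] = in;
//   threadgroup_barrier(mem_flags::mem_threadgroup);
//   if (gl_InvocationID >= 4)
//       return;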
4184 break;
4185 case ExecutionModelTessellationEvaluation:
4186 if (!msl_options.raw_buffer_tese_input)
4187 break;
4188 if (patch)
4189 {
4190 entry_func.fixup_hooks_in.push_back(
4191 t: [=]()
4192 {
4193 statement(ts: "const device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4194 ts: " = ", ts&: patch_input_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
4195 ts: "];");
4196 });
4197 }
4198 else
4199 {
4200 entry_func.fixup_hooks_in.push_back(
4201 t: [=]()
4202 {
4203 statement(ts: "const device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_in = &",
4204 ts&: input_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: " * ",
4205 ts&: get_entry_point().output_vertices, ts: "];");
4206 });
4207 }
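// For illustration, with typical default names ("main0" standing in for the entry point
// and N for output_vertices), the hooks above emit roughly:
//   patch:      const device main0_patchIn& patchIn = spvPatchIn[gl_PrimitiveID];
//   per-vertex: const device main0_in* gl_in = &spvIn[gl_PrimitiveID * N];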
4208 break;
4209 default:
4210 break;
4211 }
4212 break;
4213
4214 case StorageClassOutput:
4215 {
4216 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
4217
4218 // Add the output interface struct as a local variable to the entry function.
4219 // If the entry point should return the output struct, set the entry function
4220 // to return the output interface struct, otherwise to return nothing.
4221 // Watch out for the rare case where the terminator of the last entry point block is a
4222 // Kill, instead of a Return. Based on SPIR-V's block-domination rules, we assume that
4223 // any block that has a Kill will also have a terminating Return, except the last block.
4224 // Indicate the output var requires early initialization.
4225 bool ep_should_return_output = !get_is_rasterization_disabled();
4226 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
4227 if (!capture_output_to_buffer)
4228 {
4229 entry_func.add_local_variable(id: ib_var_id);
4230 for (auto &blk_id : entry_func.blocks)
4231 {
4232 auto &blk = get<SPIRBlock>(id: blk_id);
4233 if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back()))
4234 blk.return_value = rtn_id;
4235 }
4236 vars_needing_early_declaration.push_back(t: ib_var_id);
4237 }
4238 else
4239 {
4240 switch (get_execution_model())
4241 {
4242 case ExecutionModelVertex:
4243 case ExecutionModelTessellationEvaluation:
4244 // Instead of declaring a struct variable to hold the output and then
4245 // copying that to the output buffer, we'll declare the output variable
4246 // as a reference to the final output element in the buffer. Then we can
4247 // avoid the extra copy.
4248 entry_func.fixup_hooks_in.push_back(t: [=]() {
4249 if (stage_out_var_id)
4250 {
4251 // The first member of the indirect buffer is always the number of vertices
4252 // to draw.
4253 // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
4254 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
4255 {
4256 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4257 ts: " = ", ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4258 ts: ".y * ", ts: to_expression(id: builtin_stage_input_size_id), ts: ".x + ",
4259 ts: to_expression(id: builtin_invocation_id_id), ts: ".x];");
4260 }
4261 else if (msl_options.enable_base_index_zero)
4262 {
4263 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4264 ts: " = ", ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_instance_idx_id),
4265 ts: " * spvIndirectParams[0] + ", ts: to_expression(id: builtin_vertex_idx_id), ts: "];");
4266 }
4267 else
4268 {
4269 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4270 ts: " = ", ts&: output_buffer_var_name, ts: "[(", ts: to_expression(id: builtin_instance_idx_id),
4271 ts: " - ", ts: to_expression(id: builtin_base_instance_id), ts: ") * spvIndirectParams[0] + ",
4272 ts: to_expression(id: builtin_vertex_idx_id), ts: " - ",
4273 ts: to_expression(id: builtin_base_vertex_id), ts: "];");
4274 }
4275 }
4276 });
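// For illustration, the default (non-base-index-zero) path above emits roughly:
//   device main0_out& out = spvOut[(gl_InstanceIndex - gl_BaseInstance) * spvIndirectParams[0] +
//                                  gl_VertexIndex - gl_BaseVertex];
// with "main0" and "spvOut" standing in for the actual entry-point and output buffer names.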
4277 break;
4278 case ExecutionModelTessellationControl:
4279 if (msl_options.multi_patch_workgroup)
4280 {
4281 // We cannot use PrimitiveId here, because the hook may not have run yet.
4282 if (patch)
4283 {
4284 entry_func.fixup_hooks_in.push_back(t: [=]() {
4285 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4286 ts: " = ", ts&: patch_output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4287 ts: ".x / ", ts&: get_entry_point().output_vertices, ts: "];");
4288 });
4289 }
4290 else
4291 {
4292 entry_func.fixup_hooks_in.push_back(t: [=]() {
4293 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_out = &",
4294 ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id), ts: ".x - ",
4295 ts: to_expression(id: builtin_invocation_id_id), ts: ".x % ",
4296 ts&: get_entry_point().output_vertices, ts: "];");
4297 });
4298 }
4299 }
4300 else
4301 {
4302 if (patch)
4303 {
4304 entry_func.fixup_hooks_in.push_back(t: [=]() {
4305 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4306 ts: " = ", ts&: patch_output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
4307 ts: "];");
4308 });
4309 }
4310 else
4311 {
4312 entry_func.fixup_hooks_in.push_back(t: [=]() {
4313 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_out = &",
4314 ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: " * ",
4315 ts&: get_entry_point().output_vertices, ts: "];");
4316 });
4317 }
4318 }
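// For illustration, the single-patch-per-threadgroup hooks above emit roughly:
//   patch:      device main0_patchOut& patchOut = spvPatchOut[gl_PrimitiveID];
//   per-vertex: device main0_out* gl_out = &spvOut[gl_PrimitiveID * N];
// ("main0", the buffer names, and N = output_vertices are placeholders for the actual shader).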
4319 break;
4320 default:
4321 break;
4322 }
4323 }
4324 break;
4325 }
4326
4327 default:
4328 break;
4329 }
4330
4331 set_name(id: ib_type_id, name: to_name(id: ir.default_entry_point) + "_" + ib_var_ref);
4332 set_name(id: ib_var_id, name: ib_var_ref);
4333
4334 for (auto *p_var : vars)
4335 {
4336 bool strip_array = (is_tesc_shader() || (is_tese_shader() && storage == StorageClassInput)) && !patch;
4337
4338 // Fixing up flattened stores in TESC is impossible since the memory is group shared either via
4339 // device (not masked) or threadgroup (masked) storage classes and it's race condition city.
4340 meta.strip_array = strip_array;
4341 meta.allow_local_declaration = !strip_array && !(is_tesc_shader() && storage == StorageClassOutput);
4342 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: *p_var, meta);
4343 }
4344
4345 if (((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
4346 (is_tese_shader() && msl_options.raw_buffer_tese_input)) &&
4347 storage == StorageClassInput)
4348 {
4349 // For tessellation inputs, add all outputs from the previous stage to ensure
4350 // the struct containing them is the correct size and layout.
4351 for (auto &input : inputs_by_location)
4352 {
4353 if (location_inputs_in_use.count(x: input.first.location) != 0)
4354 continue;
4355
4356 if (patch != (input.second.rate == MSL_SHADER_VARIABLE_RATE_PER_PATCH))
4357 continue;
4358
4359 // Tessellation levels have their own struct, so there's no need to add them here.
4360 if (input.second.builtin == BuiltInTessLevelOuter || input.second.builtin == BuiltInTessLevelInner)
4361 continue;
4362
4363 // Create a fake variable to put at the location.
4364 uint32_t offset = ir.increase_bound_by(count: 5);
4365 uint32_t type_id = offset;
4366 uint32_t vec_type_id = offset + 1;
4367 uint32_t array_type_id = offset + 2;
4368 uint32_t ptr_type_id = offset + 3;
4369 uint32_t var_id = offset + 4;
4370
4371 SPIRType type { OpTypeInt };
4372 switch (input.second.format)
4373 {
4374 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4375 case MSL_SHADER_VARIABLE_FORMAT_ANY16:
4376 type.basetype = SPIRType::UShort;
4377 type.width = 16;
4378 break;
4379 case MSL_SHADER_VARIABLE_FORMAT_ANY32:
4380 default:
4381 type.basetype = SPIRType::UInt;
4382 type.width = 32;
4383 break;
4384 }
4385 set<SPIRType>(id: type_id, args&: type);
4386 if (input.second.vecsize > 1)
4387 {
4388 type.op = OpTypeVector;
4389 type.vecsize = input.second.vecsize;
4390 set<SPIRType>(id: vec_type_id, args&: type);
4391 type_id = vec_type_id;
4392 }
4393
4394 type.op = OpTypeArray;
4395 type.array.push_back(t: 0);
4396 type.array_size_literal.push_back(t: true);
4397 type.parent_type = type_id;
4398 set<SPIRType>(id: array_type_id, args&: type);
4399 type.self = type_id;
4400
4401 type.op = OpTypePointer;
4402 type.pointer = true;
4403 type.pointer_depth++;
4404 type.parent_type = array_type_id;
4405 type.storage = storage;
4406 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: type);
4407 ptr_type.self = array_type_id;
4408
4409 auto &fake_var = set<SPIRVariable>(id: var_id, args&: ptr_type_id, args&: storage);
4410 set_decoration(id: var_id, decoration: DecorationLocation, argument: input.first.location);
4411 if (input.first.component)
4412 set_decoration(id: var_id, decoration: DecorationComponent, argument: input.first.component);
4413
4414 meta.strip_array = true;
4415 meta.allow_local_declaration = false;
4416 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: fake_var, meta);
4417 }
4418 }
4419
4420 if (capture_output_to_buffer && storage == StorageClassOutput)
4421 {
4422 // For captured output, add all inputs from the next stage to ensure
4423 // the struct containing them is the correct size and layout. This is
4424 // necessary for certain implicit builtins that may nonetheless be read,
4425 // even when they aren't written.
4426 for (auto &output : outputs_by_location)
4427 {
4428 if (location_outputs_in_use.count(x: output.first.location) != 0)
4429 continue;
4430
4431 // Create a fake variable to put at the location.
4432 uint32_t offset = ir.increase_bound_by(count: 5);
4433 uint32_t type_id = offset;
4434 uint32_t vec_type_id = offset + 1;
4435 uint32_t array_type_id = offset + 2;
4436 uint32_t ptr_type_id = offset + 3;
4437 uint32_t var_id = offset + 4;
4438
4439 SPIRType type { OpTypeInt };
4440 switch (output.second.format)
4441 {
4442 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4443 case MSL_SHADER_VARIABLE_FORMAT_ANY16:
4444 type.basetype = SPIRType::UShort;
4445 type.width = 16;
4446 break;
4447 case MSL_SHADER_VARIABLE_FORMAT_ANY32:
4448 default:
4449 type.basetype = SPIRType::UInt;
4450 type.width = 32;
4451 break;
4452 }
4453 set<SPIRType>(id: type_id, args&: type);
4454 if (output.second.vecsize > 1)
4455 {
4456 type.op = OpTypeVector;
4457 type.vecsize = output.second.vecsize;
4458 set<SPIRType>(id: vec_type_id, args&: type);
4459 type_id = vec_type_id;
4460 }
4461
4462 if (is_tesc_shader())
4463 {
4464 type.op = OpTypeArray;
4465 type.array.push_back(t: 0);
4466 type.array_size_literal.push_back(t: true);
4467 type.parent_type = type_id;
4468 set<SPIRType>(id: array_type_id, args&: type);
4469 }
4470
4471 type.op = OpTypePointer;
4472 type.pointer = true;
4473 type.pointer_depth++;
4474 type.parent_type = is_tesc_shader() ? array_type_id : type_id;
4475 type.storage = storage;
4476 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: type);
4477 ptr_type.self = type.parent_type;
4478
4479 auto &fake_var = set<SPIRVariable>(id: var_id, args&: ptr_type_id, args&: storage);
4480 set_decoration(id: var_id, decoration: DecorationLocation, argument: output.first.location);
4481 if (output.first.component)
4482 set_decoration(id: var_id, decoration: DecorationComponent, argument: output.first.component);
4483
4484 meta.strip_array = true;
4485 meta.allow_local_declaration = false;
4486 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: fake_var, meta);
4487 }
4488 }
4489
4490 // When multiple variables need to access the same location,
4491 // unroll the locations one by one and flatten the output or input as necessary.
4492 for (auto &loc : meta.location_meta)
4493 {
4494 uint32_t location = loc.first;
4495 auto &location_meta = loc.second;
4496
4497 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
4498 uint32_t type_id = build_extended_vector_type(type_id: location_meta.base_type_id, components: location_meta.num_components);
4499 ib_type.member_types.push_back(t: type_id);
4500
4501 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: join(ts: "m_location_", ts&: location));
4502 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
4503 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: type_id), storage);
4504
4505 if (location_meta.flat)
4506 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
4507 if (location_meta.noperspective)
4508 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
4509 if (location_meta.centroid)
4510 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
4511 if (location_meta.sample)
4512 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
4513 }
4514
4515 // Sort the members of the structure by their locations.
4516 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType);
4517 member_sorter.sort();
4518
4519 // The member indices were saved to the original variables, but after the members
4520 // were sorted, those indices are now likely incorrect. Fix those up now.
4521 fix_up_interface_member_indices(storage, ib_type_id);
4522
4523 // For patch inputs, add one more member, holding the array of control point data.
4524 if (is_tese_shader() && !msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch &&
4525 stage_in_var_id)
4526 {
4527 uint32_t pcp_type_id = ir.increase_bound_by(count: 1);
4528 auto &pcp_type = set<SPIRType>(id: pcp_type_id, args&: ib_type);
4529 pcp_type.basetype = SPIRType::ControlPointArray;
4530 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
4531 pcp_type.storage = storage;
4532 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
4533 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
4534 ib_type.member_types.push_back(t: pcp_type_id);
4535 set_member_name(id: ib_type.self, index: mbr_idx, name: "gl_in");
4536 }
4537
4538 if (storage == StorageClassInput)
4539 set_decoration(id: ib_var_id, decoration: DecorationNonWritable);
4540
4541 return ib_var_id;
4542}
4543
4544uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
4545{
4546 if (!ib_var_id)
4547 return 0;
4548
4549 uint32_t ib_ptr_var_id;
4550 uint32_t next_id = ir.increase_bound_by(count: 3);
4551 auto &ib_type = expression_type(id: ib_var_id);
4552 if (is_tesc_shader() || (is_tese_shader() && msl_options.raw_buffer_tese_input))
4553 {
4554 // Tessellation control per-vertex I/O (and tessellation evaluation per-vertex input in
4555 // raw-buffer mode) is presented as an array, so we must do the same with our struct here.
4556 uint32_t ib_ptr_type_id = next_id++;
4557 auto &ib_ptr_type = set<SPIRType>(id: ib_ptr_type_id, args: ib_type);
4558 ib_ptr_type.op = OpTypePointer;
4559 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
4560 ib_ptr_type.pointer = true;
4561 ib_ptr_type.pointer_depth++;
4562 ib_ptr_type.storage = storage == StorageClassInput ?
4563 ((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
4564 (is_tese_shader() && msl_options.raw_buffer_tese_input) ?
4565 StorageClassStorageBuffer :
4566 StorageClassWorkgroup) :
4567 StorageClassStorageBuffer;
4568 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
4569 // To ensure that get_variable_data_type() doesn't strip off the pointer,
4570 // which we need, use another pointer.
4571 uint32_t ib_ptr_ptr_type_id = next_id++;
4572 auto &ib_ptr_ptr_type = set<SPIRType>(id: ib_ptr_ptr_type_id, args&: ib_ptr_type);
4573 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
4574 ib_ptr_ptr_type.type_alias = ib_type.self;
4575 ib_ptr_ptr_type.storage = StorageClassFunction;
4576 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
4577
4578 ib_ptr_var_id = next_id;
4579 set<SPIRVariable>(id: ib_ptr_var_id, args&: ib_ptr_ptr_type_id, args: StorageClassFunction, args: 0);
4580 set_name(id: ib_ptr_var_id, name: storage == StorageClassInput ? "gl_in" : "gl_out");
4581 if (storage == StorageClassInput)
4582 set_decoration(id: ib_ptr_var_id, decoration: DecorationNonWritable);
4583 }
4584 else
4585 {
4586 // Tessellation evaluation per-vertex inputs are also presented as arrays.
4587 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
4588 // which is a container that can be used to access the control point data.
4589 // To represent this, a special 'ControlPointArray' type has been added to the
4590 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
4591 // backend (i.e. this one).
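// For illustration, in the emitted MSL this typically surfaces as a member such as
//   patch_control_point<main0_in> gl_in;
// inside the patch stage-in struct ("main0_in" standing in for the per-vertex input struct name).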
4592 uint32_t pcp_type_id = next_id++;
4593 auto &pcp_type = set<SPIRType>(id: pcp_type_id, args: ib_type);
4594 pcp_type.basetype = SPIRType::ControlPointArray;
4595 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
4596 pcp_type.storage = storage;
4597 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
4598
4599 ib_ptr_var_id = next_id;
4600 set<SPIRVariable>(id: ib_ptr_var_id, args&: pcp_type_id, args&: storage, args: 0);
4601 set_name(id: ib_ptr_var_id, name: "gl_in");
4602 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(ts&: patch_stage_in_var_name, ts: ".gl_in");
4603 }
4604 return ib_ptr_var_id;
4605}
4606
4607uint32_t CompilerMSL::add_meshlet_block(bool per_primitive)
4608{
4609 // Accumulate the variables that should appear in the interface struct.
4610 SmallVector<SPIRVariable *> vars;
4611
4612 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
4613 if (var.storage != StorageClassOutput || var.self == builtin_mesh_primitive_indices_id)
4614 return;
4615 if (is_per_primitive_variable(var) != per_primitive)
4616 return;
4617 vars.push_back(t: &var);
4618 });
4619
4620 if (vars.empty())
4621 return 0;
4622
4623 uint32_t next_id = ir.increase_bound_by(count: 1);
4624 auto &type = set<SPIRType>(id: next_id, args: SPIRType(OpTypeStruct));
4625 type.basetype = SPIRType::Struct;
4626
4627 InterfaceBlockMeta meta;
4628 for (auto *p_var : vars)
4629 {
4630 meta.strip_array = true;
4631 meta.allow_local_declaration = false;
4632 add_variable_to_interface_block(storage: StorageClassOutput, ib_var_ref: "", ib_type&: type, var&: *p_var, meta);
4633 }
4634
4635 if (per_primitive)
4636 set_name(id: type.self, name: "spvPerPrimitive");
4637 else
4638 set_name(id: type.self, name: "spvPerVertex");
4639
4640 return next_id;
4641}
4642
4643// Ensure that the type is compatible with the builtin.
4644// If it is, simply return the given type ID.
4645 // Otherwise, create a new type, and return its ID.
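// For example, a gl_Layer or gl_ViewportIndex declared as a signed int in SPIR-V is
// redeclared here as uint (with the pointer re-wrapped if needed), since that is the type
// the rest of the MSL backend expects for these builtins.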
4646uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
4647{
4648 auto &type = get<SPIRType>(id: type_id);
4649 auto &pointee_type = get_pointee_type(type);
4650
4651 if ((builtin == BuiltInSampleMask && is_array(type: pointee_type)) ||
4652 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
4653 pointee_type.basetype != SPIRType::UInt))
4654 {
4655 uint32_t next_id = ir.increase_bound_by(count: is_pointer(type) ? 2 : 1);
4656 uint32_t base_type_id = next_id++;
4657 auto &base_type = set<SPIRType>(id: base_type_id, args: OpTypeInt);
4658 base_type.basetype = SPIRType::UInt;
4659 base_type.width = 32;
4660
4661 if (!is_pointer(type))
4662 return base_type_id;
4663
4664 uint32_t ptr_type_id = next_id++;
4665 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: base_type);
4666 ptr_type.op = spv::OpTypePointer;
4667 ptr_type.pointer = true;
4668 ptr_type.pointer_depth++;
4669 ptr_type.storage = type.storage;
4670 ptr_type.parent_type = base_type_id;
4671 return ptr_type_id;
4672 }
4673
4674 return type_id;
4675}
4676
4677// Ensure that the type is compatible with the shader input.
4678// If it is, simply return the given type ID.
4679// Otherwise, create a new type, and return its ID.
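// For example, an int2 attribute fed by a MSL_SHADER_VARIABLE_FORMAT_UINT8 vertex format is
// redeclared as a uint vector (widened to num_components if necessary) so the fetched data
// matches what the shader expects.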
4680uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t component, uint32_t num_components, bool strip_array)
4681{
4682 auto &type = get<SPIRType>(id: type_id);
4683
4684 uint32_t max_array_dimensions = strip_array ? 1 : 0;
4685
4686 // Struct and array types must match exactly.
4687 if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions)
4688 return type_id;
4689
4690 auto p_va = inputs_by_location.find(x: {.location: location, .component: component});
4691 if (p_va == end(cont&: inputs_by_location))
4692 {
4693 if (num_components > type.vecsize)
4694 return build_extended_vector_type(type_id, components: num_components);
4695 else
4696 return type_id;
4697 }
4698
4699 if (num_components == 0)
4700 num_components = p_va->second.vecsize;
4701
4702 switch (p_va->second.format)
4703 {
4704 case MSL_SHADER_VARIABLE_FORMAT_UINT8:
4705 {
4706 switch (type.basetype)
4707 {
4708 case SPIRType::UByte:
4709 case SPIRType::UShort:
4710 case SPIRType::UInt:
4711 if (num_components > type.vecsize)
4712 return build_extended_vector_type(type_id, components: num_components);
4713 else
4714 return type_id;
4715
4716 case SPIRType::Short:
4717 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4718 basetype: SPIRType::UShort);
4719 case SPIRType::Int:
4720 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4721 basetype: SPIRType::UInt);
4722
4723 default:
4724 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
4725 }
4726 }
4727
4728 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4729 {
4730 switch (type.basetype)
4731 {
4732 case SPIRType::UShort:
4733 case SPIRType::UInt:
4734 if (num_components > type.vecsize)
4735 return build_extended_vector_type(type_id, components: num_components);
4736 else
4737 return type_id;
4738
4739 case SPIRType::Int:
4740 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4741 basetype: SPIRType::UInt);
4742
4743 default:
4744 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
4745 }
4746 }
4747
4748 default:
4749 if (num_components > type.vecsize)
4750 type_id = build_extended_vector_type(type_id, components: num_components);
4751 break;
4752 }
4753
4754 return type_id;
4755}
4756
4757void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
4758{
4759 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
4760 if (has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPhysicalTypePacked))
4761 return;
4762
4763 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4764
4765 // Problem case! Struct needs to be placed at an awkward alignment.
4766 // Mark every member of the child struct as packed.
4767 uint32_t mbr_cnt = uint32_t(type.member_types.size());
4768 for (uint32_t i = 0; i < mbr_cnt; i++)
4769 {
4770 auto &mbr_type = get<SPIRType>(id: type.member_types[i]);
4771 if (mbr_type.basetype == SPIRType::Struct)
4772 {
4773 // Recursively mark structs as packed.
4774 auto *struct_type = &mbr_type;
4775 while (!struct_type->array.empty())
4776 struct_type = &get<SPIRType>(id: struct_type->parent_type);
4777 mark_struct_members_packed(type: *struct_type);
4778 }
4779 else if (!is_scalar(type: mbr_type))
4780 set_extended_member_decoration(type: type.self, index: i, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4781 }
4782}
4783
4784void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
4785{
4786 uint32_t mbr_cnt = uint32_t(type.member_types.size());
4787 for (uint32_t i = 0; i < mbr_cnt; i++)
4788 {
4789 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
4790 auto &mbr_type = get<SPIRType>(id: type.member_types[i]);
4791 if (mbr_type.basetype == SPIRType::Struct && !(mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer))
4792 {
4793 auto *struct_type = &mbr_type;
4794 while (!struct_type->array.empty())
4795 struct_type = &get<SPIRType>(id: struct_type->parent_type);
4796
4797 if (has_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPhysicalTypePacked))
4798 continue;
4799
4800 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(struct_type: type, index: i);
4801 uint32_t msl_size = get_declared_struct_member_size_msl(struct_type: type, index: i);
4802 uint32_t spirv_offset = type_struct_member_offset(type, index: i);
4803 uint32_t spirv_offset_next;
4804 if (i + 1 < mbr_cnt)
4805 spirv_offset_next = type_struct_member_offset(type, index: i + 1);
4806 else
4807 spirv_offset_next = spirv_offset + msl_size;
4808
4809 // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
4810 // and the next member will be placed at offset 12.
4811 bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
4812 bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
4813 uint32_t array_stride = 0;
4814 bool struct_needs_explicit_padding = false;
4815
4816 // Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
4817 if (!mbr_type.array.empty())
4818 {
4819 array_stride = type_struct_member_array_stride(type, index: i);
4820 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
4821 for (uint32_t dim = 0; dim < dimensions; dim++)
4822 {
4823 uint32_t array_size = to_array_size_literal(type: mbr_type, index: dim);
4824 array_stride /= max<uint32_t>(a: array_size, b: 1u);
4825 }
4826
4827 // Set expected struct size based on ArrayStride.
4828 struct_needs_explicit_padding = true;
4829
4830 // If the struct size is larger than the array stride, we might still be able to fit it if we pack tightly.
4831 if (get_declared_struct_size_msl(struct_type: *struct_type) > array_stride)
4832 struct_is_too_large = true;
4833 }
4834
4835 if (struct_is_misaligned || struct_is_too_large)
4836 mark_struct_members_packed(type: *struct_type);
4837 mark_scalar_layout_structs(type: *struct_type);
4838
4839 if (struct_needs_explicit_padding)
4840 {
4841 msl_size = get_declared_struct_size_msl(struct_type: *struct_type, ignore_alignment: true, ignore_padding: true);
4842 if (array_stride < msl_size)
4843 {
4844 SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
4845 }
4846 else
4847 {
4848 if (has_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget))
4849 {
4850 if (array_stride !=
4851 get_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget))
4852 SPIRV_CROSS_THROW(
4853 "A struct is used with different array strides. Cannot express this in MSL.");
4854 }
4855 else
4856 set_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget, value: array_stride);
4857 }
4858 }
4859 }
4860 }
4861}
4862
4863// Sort the members of the struct type by offset, and pack and then pad members where needed
4864// to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
4865// occurs first, followed by padding, because packing a member reduces both its size and its
4866// natural alignment, possibly requiring a padding member to be added ahead of it.
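// For illustration: given SPIR-V offsets { float a; /* 0 */ float b; /* 8 */ }, 'b' lies
// farther away than natural C packing would place it, so an inert padding member (emitted
// as a char array) is inserted before it. Conversely, { vec3 a; /* 0 */ float b; /* 12 */ }
// requires 'a' to be declared packed_float3 so that 'b' can land at offset 12.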
4867void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
4868{
4869 // We align structs recursively, so stop any redundant work.
4870 ID &ib_type_id = ib_type.self;
4871 if (aligned_structs.count(x: ib_type_id))
4872 return;
4873 aligned_structs.insert(x: ib_type_id);
4874
4875 // Sort the members of the interface structure by their offset.
4876 // They should already be sorted per SPIR-V spec anyway.
4877 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
4878 member_sorter.sort();
4879
4880 auto mbr_cnt = uint32_t(ib_type.member_types.size());
4881
4882 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4883 {
4884 // Pack any dependent struct types before we pack a parent struct.
4885 auto &mbr_type = get<SPIRType>(id: ib_type.member_types[mbr_idx]);
4886 if (mbr_type.basetype == SPIRType::Struct)
4887 align_struct(ib_type&: mbr_type, aligned_structs);
4888 }
4889
4890 // Test the alignment of each member, and if a member should be closer to the previous
4891 // member than the default spacing expects, it is likely that the previous member is in
4892 // a packed format. If so, and the previous member is packable, pack it.
4893 // For example, this applies to any 3-element vector that is followed by a scalar.
4894 uint32_t msl_offset = 0;
4895 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4896 {
4897 // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
4898 // offsets, array strides and matrix strides.
4899 ensure_member_packing_rules_msl(ib_type, index: mbr_idx);
4900
4901 // Align current offset to the current member's default alignment. If the member was packed, it will observe
4902 // the updated alignment here.
4903 uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(struct_type: ib_type, index: mbr_idx) - 1;
4904 uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4905
4906 // Fetch the member offset as declared in the SPIRV.
4907 uint32_t spirv_mbr_offset = get_member_decoration(id: ib_type_id, index: mbr_idx, decoration: DecorationOffset);
4908 if (spirv_mbr_offset > aligned_msl_offset)
4909 {
4910 // Since MSL and SPIR-V have slightly different struct member alignment and
4911 // size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
4912 // away than C-packing expects, add an inert padding member before the member.
4913 uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
4914 set_extended_member_decoration(type: ib_type_id, index: mbr_idx, decoration: SPIRVCrossDecorationPaddingTarget, value: padding_bytes);
4915
4916 // Re-align as a sanity check that aligning post-padding matches up.
4917 msl_offset += padding_bytes;
4918 aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4919 }
4920 else if (spirv_mbr_offset < aligned_msl_offset)
4921 {
4922 // This should not happen, but deal with unexpected scenarios.
4923 // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
4924 SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
4925 }
4926
4927 assert(aligned_msl_offset == spirv_mbr_offset);
4928
4929 // Increment the current offset to be positioned immediately after the current member.
4930 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
4931 if (mbr_idx + 1 < mbr_cnt)
4932 msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(struct_type: ib_type, index: mbr_idx);
4933 }
4934}
4935
4936bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
4937{
4938 auto &mbr_type = get<SPIRType>(id: type.member_types[index]);
4939 uint32_t spirv_offset = get_member_decoration(id: type.self, index, decoration: DecorationOffset);
4940
4941 if (index + 1 < type.member_types.size())
4942 {
4943 // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
4944 // we *must* perform some kind of remapping; there is no way around it.
4945 // We can always pad after this member if necessary, so that case is fine.
4946 uint32_t spirv_offset_next = get_member_decoration(id: type.self, index: index + 1, decoration: DecorationOffset);
4947 assert(spirv_offset_next >= spirv_offset);
4948 uint32_t maximum_size = spirv_offset_next - spirv_offset;
4949 uint32_t msl_mbr_size = get_declared_struct_member_size_msl(struct_type: type, index);
4950 if (msl_mbr_size > maximum_size)
4951 return false;
4952 }
4953
4954 if (is_array(type: mbr_type))
4955 {
4956 // If we have an array type, array stride must match exactly with SPIR-V.
4957
4958 // An exception to this requirement is if we have one array element.
4959 // This comes from DX scalar layout workaround.
4960 // If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
4961 // In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
4962 bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
4963
4964 if (!relax_array_stride)
4965 {
4966 uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
4967 uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(struct_type: type, index);
4968 if (spirv_array_stride != msl_array_stride)
4969 return false;
4970 }
4971 }
4972
4973 if (is_matrix(type: mbr_type))
4974 {
4975 // Need to check MatrixStride as well.
4976 uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
4977 uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(struct_type: type, index);
4978 if (spirv_matrix_stride != msl_matrix_stride)
4979 return false;
4980 }
4981
4982 // Now, we check alignment.
4983 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(struct_type: type, index);
4984 if ((spirv_offset % msl_alignment) != 0)
4985 return false;
4986
4987 // We're in the clear.
4988 return true;
4989}
4990
4991// Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
4992// If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
4993 // In odd cases we need to emit both packed and remapped types, e.g. for weird matrices or arrays with unusual array strides.
4994void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
4995{
4996 if (validate_member_packing_rules_msl(type: ib_type, index))
4997 return;
4998
4999 // We failed validation.
5000 // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
5001 // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
5002 // that struct alignment == max alignment of all members and struct size depends on this alignment.
5003 // Can't repack structs, but can repack pointers to structs.
5004 auto &mbr_type = get<SPIRType>(id: ib_type.member_types[index]);
5005 bool is_buff_ptr = mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer;
5006 if (mbr_type.basetype == SPIRType::Struct && !is_buff_ptr)
5007 SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
5008
5009 // Perform remapping here.
5010 // There is nothing to be gained by using packed scalars, so don't attempt it.
5011 if (!is_scalar(type: ib_type))
5012 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5013
5014 // Try validating again, now with packed.
5015 if (validate_member_packing_rules_msl(type: ib_type, index))
5016 return;
5017
5018 // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
5019 // A lot of work goes here ...
5020 // We will need remapping on Load and Store to translate the types between Logical and Physical.
5021
5022 // First, we check if we have a small-vector std140 array.
5023 // We detect this if we have an array of vectors whose array stride is larger than the vector's natural size.
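// For illustration: an std140 "vec2 v[4]" has ArrayStride 16, i.e. elems_per_stride == 4,
// so the physical type below becomes a 4-component vector array; loads and stores are then
// remapped so the logical vec2 view is preserved.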
5024 if (!mbr_type.array.empty() && !is_matrix(type: mbr_type))
5025 {
5026 uint32_t array_stride = type_struct_member_array_stride(type: ib_type, index);
5027
5028 // Hack off array-of-arrays until we find the array stride per element we must have to make it work.
5029 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
5030 for (uint32_t dim = 0; dim < dimensions; dim++)
5031 array_stride /= max<uint32_t>(a: to_array_size_literal(type: mbr_type, index: dim), b: 1u);
5032
5033 // Pointers are 8 bytes
5034 uint32_t mbr_width_in_bytes = is_buff_ptr ? 8 : (mbr_type.width / 8);
5035 uint32_t elems_per_stride = array_stride / mbr_width_in_bytes;
5036
5037 if (elems_per_stride == 3)
5038 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
5039 else if (elems_per_stride > 4 && elems_per_stride != 8)
5040 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
5041
5042 if (elems_per_stride == 8)
5043 {
5044 if (mbr_type.width == 16)
5045 add_spv_func_and_recompile(spv_func: SPVFuncImplPaddedStd140);
5046 else
5047 SPIRV_CROSS_THROW("Unexpected type in std140 wide array resolve.");
5048 }
5049
5050 auto physical_type = mbr_type;
5051 physical_type.vecsize = elems_per_stride;
5052 physical_type.parent_type = 0;
5053
5054 // If this is a physical buffer pointer, replace type with a ulongn vector.
5055 if (is_buff_ptr)
5056 {
5057 physical_type.width = 64;
5058 physical_type.basetype = to_unsigned_basetype(width: physical_type.width);
5059 physical_type.pointer = false;
5060 physical_type.pointer_depth = 0;
5061 physical_type.forward_pointer = false;
5062 }
5063
5064 uint32_t type_id = ir.increase_bound_by(count: 1);
5065 set<SPIRType>(id: type_id, args&: physical_type);
5066 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID, value: type_id);
5067 set_decoration(id: type_id, decoration: DecorationArrayStride, argument: array_stride);
5068
5069 // Remove packed_ for vectors of size 1, 2 and 4.
5070 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5071 }
5072 else if (is_matrix(type: mbr_type))
5073 {
5074 // MatrixStride might be std140-esque.
5075 uint32_t matrix_stride = type_struct_member_matrix_stride(type: ib_type, index);
5076
5077 uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
5078
5079 if (elems_per_stride == 3)
5080 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
5081 else if (elems_per_stride > 4 && elems_per_stride != 8)
5082 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
5083
5084 if (elems_per_stride == 8)
5085 {
5086 if (mbr_type.basetype != SPIRType::Half)
5087 SPIRV_CROSS_THROW("Unexpected type in std140 wide matrix stride resolve.");
5088 add_spv_func_and_recompile(spv_func: SPVFuncImplPaddedStd140);
5089 }
5090
5091 bool row_major = has_member_decoration(id: ib_type.self, index, decoration: DecorationRowMajor);
5092 auto physical_type = mbr_type;
5093 physical_type.parent_type = 0;
5094
5095 if (row_major)
5096 physical_type.columns = elems_per_stride;
5097 else
5098 physical_type.vecsize = elems_per_stride;
5099 uint32_t type_id = ir.increase_bound_by(count: 1);
5100 set<SPIRType>(id: type_id, args&: physical_type);
5101 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID, value: type_id);
5102
5103 // Remove packed_ for vectors of size 1, 2 and 4.
5104 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5105 }
5106 else
5107 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
5108
5109 // Try validating again, now with physical type remapping.
5110 if (validate_member_packing_rules_msl(type: ib_type, index))
5111 return;
5112
5113 // We might have a particular odd scalar layout case where the last element of an array
5114 // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
5115 // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
5116 // so we hack around it by declaring the offending array or matrix with one less array size/col/row,
5117 // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
5118 // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
5119
5120 // E.g. we might observe a physical layout of:
5121 // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
5122 uint32_t type_id = get_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
5123 auto &type = get<SPIRType>(id: type_id);
5124
5125 // Modify the physical type in-place. This is safe since each physical type workaround is a copy.
5126 if (is_array(type))
5127 {
5128 if (type.array.back() > 1)
5129 {
5130 if (!type.array_size_literal.back())
5131 SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
5132 type.array.back() -= 1;
5133 }
5134 else
5135 {
5136 // We have an array of size 1, so we cannot decrement that. Our only option now is to
5137 // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
5138 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
5139 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5140 }
5141 }
5142 else if (is_matrix(type))
5143 {
5144 bool row_major = has_member_decoration(id: ib_type.self, index, decoration: DecorationRowMajor);
5145 if (!row_major)
5146 {
5147 // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
5148 if (type.columns > 2)
5149 {
5150 type.columns--;
5151 }
5152 else if (type.columns == 2)
5153 {
5154 type.columns = 1;
5155 assert(type.array.empty());
5156 type.op = OpTypeArray;
5157 type.array.push_back(t: 1);
5158 type.array_size_literal.push_back(t: true);
5159 }
5160 }
5161 else
5162 {
5163 // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
5164 if (type.vecsize > 2)
5165 {
5166 type.vecsize--;
5167 }
5168 else if (type.vecsize == 2)
5169 {
5170 type.vecsize = type.columns;
5171 type.columns = 1;
5172 assert(type.array.empty());
5173 type.op = OpTypeArray;
5174 type.array.push_back(t: 1);
5175 type.array_size_literal.push_back(t: true);
5176 }
5177 }
5178 }
5179
5180 // This better validate now, or we must fail gracefully.
5181 if (!validate_member_packing_rules_msl(type: ib_type, index))
5182 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
5183}
5184
5185void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
5186{
5187 auto &type = expression_type(id: rhs_expression);
5188
5189 bool lhs_remapped_type = has_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypeID);
5190 bool lhs_packed_type = has_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5191 auto *lhs_e = maybe_get<SPIRExpression>(id: lhs_expression);
5192 auto *rhs_e = maybe_get<SPIRExpression>(id: rhs_expression);
5193
5194 bool transpose = lhs_e && lhs_e->need_transpose;
5195
5196 if (has_decoration(id: lhs_expression, decoration: DecorationBuiltIn) &&
5197 BuiltIn(get_decoration(id: lhs_expression, decoration: DecorationBuiltIn)) == BuiltInSampleMask &&
5198 is_array(type))
5199 {
5200 // When storing an array to SampleMask, we have to remove the array-ness before storing.
5201 statement(ts: to_expression(id: lhs_expression), ts: " = ", ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[0];");
5202 register_write(chain: lhs_expression);
5203 }
5204 else if (!lhs_remapped_type && !lhs_packed_type)
5205 {
5206 // No physical type remapping, and no packed type, so can just emit a store directly.
5207
5208 // We might not be dealing with remapped physical types or packed types,
5209 // but we might be doing a clean store to a row-major matrix.
5210 // In this case, we just flip the transpose state and emit the store; any transpose must come from the RHS expression.
5211 if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
5212 {
5213 lhs_e->need_transpose = false;
5214
5215 if (rhs_e && rhs_e->need_transpose)
5216 {
5217 // Direct copy, but might need to unpack RHS.
5218 // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
5219 rhs_e->need_transpose = false;
5220 statement(ts: to_expression(id: lhs_expression), ts: " = ", ts: to_unpacked_row_major_matrix_expression(id: rhs_expression),
5221 ts: ";");
5222 rhs_e->need_transpose = true;
5223 }
5224 else
5225 statement(ts: to_expression(id: lhs_expression), ts: " = transpose(", ts: to_unpacked_expression(id: rhs_expression), ts: ");");
5226
5227 lhs_e->need_transpose = true;
5228 register_write(chain: lhs_expression);
5229 }
5230 else if (lhs_e && lhs_e->need_transpose)
5231 {
5232 lhs_e->need_transpose = false;
5233
5234 // Storing a column to a row-major matrix. Unroll the write.
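// For illustration: a logical "m[1] = v;" into a row-major m unrolls to
// "m[0][1] = v.x; m[1][1] = v.y; ..." by splicing the component index in front of the
// final subscript of the LHS expression.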
5235 for (uint32_t c = 0; c < type.vecsize; c++)
5236 {
5237 auto lhs_expr = to_dereferenced_expression(id: lhs_expression);
5238 auto column_index = lhs_expr.find_last_of(c: '[');
5239 if (column_index != string::npos)
5240 {
5241 statement(ts&: lhs_expr.insert(pos1: column_index, str: join(ts: '[', ts&: c, ts: ']')), ts: " = ",
5242 ts: to_extract_component_expression(id: rhs_expression, index: c), ts: ";");
5243 }
5244 }
5245 lhs_e->need_transpose = true;
5246 register_write(chain: lhs_expression);
5247 }
5248 else
5249 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
5250 }
5251 else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
5252 {
5253 // Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
5254 // since they are declared as arrays of vectors instead, and we need the fallback path below.
5255 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
5256 }
5257 else
5258 {
5259 // Special handling when storing to a remapped physical type.
5260 // This is mostly to deal with std140 padded matrices or vectors.
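// For illustration: when a matrix has been remapped to a wider physical type (e.g. std140
// padding), the store below is emitted column by column through an address-space cast of
// each destination column rather than as a single matrix assignment.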
5261
5262 TypeID physical_type_id = lhs_remapped_type ?
5263 ID(get_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypeID)) :
5264 type.self;
5265
5266 auto &physical_type = get<SPIRType>(id: physical_type_id);
5267
5268 string cast_addr_space = "thread";
5269 auto *p_var_lhs = maybe_get_backing_variable(chain: lhs_expression);
5270 if (p_var_lhs)
5271 cast_addr_space = get_type_address_space(type: get<SPIRType>(id: p_var_lhs->basetype), id: lhs_expression);
5272
5273 if (is_matrix(type))
5274 {
5275 const char *packed_pfx = lhs_packed_type ? "packed_" : "";
5276
5277 // Packed matrices are stored as arrays of packed vectors, so we need
5278 // to assign the vectors one at a time.
5279 // For row-major matrices, we need to transpose the *right-hand* side,
5280 // not the left-hand side.
5281
5282 // Lots of cases to cover here ...
5283
5284 bool rhs_transpose = rhs_e && rhs_e->need_transpose;
5285 SPIRType write_type = type;
5286 string cast_expr;
5287
5288 // We're dealing with transpose manually.
5289 if (rhs_transpose)
5290 rhs_e->need_transpose = false;
5291
5292 if (transpose)
5293 {
5294 // We're dealing with transpose manually.
5295 lhs_e->need_transpose = false;
5296 write_type.vecsize = type.columns;
5297 write_type.columns = 1;
5298
5299 if (physical_type.columns != type.columns)
5300 cast_expr = join(ts: "(", ts&: cast_addr_space, ts: " ", ts&: packed_pfx, ts: type_to_glsl(type: write_type), ts: "&)");
5301
5302 if (rhs_transpose)
5303 {
5304 // If RHS is also transposed, we can just copy row by row.
5305 for (uint32_t i = 0; i < type.vecsize; i++)
5306 {
5307 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ",
5308 ts: to_unpacked_row_major_matrix_expression(id: rhs_expression), ts: "[", ts&: i, ts: "];");
5309 }
5310 }
5311 else
5312 {
5313 auto vector_type = expression_type(id: rhs_expression);
5314 vector_type.vecsize = vector_type.columns;
5315 vector_type.columns = 1;
5316
5317 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
5318 // so pick out individual components instead.
5319 for (uint32_t i = 0; i < type.vecsize; i++)
5320 {
5321 string rhs_row = type_to_glsl_constructor(type: vector_type) + "(";
5322 for (uint32_t j = 0; j < vector_type.vecsize; j++)
5323 {
5324 rhs_row += join(ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[", ts&: j, ts: "][", ts&: i, ts: "]");
5325 if (j + 1 < vector_type.vecsize)
5326 rhs_row += ", ";
5327 }
5328 rhs_row += ")";
5329
5330 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ", ts&: rhs_row, ts: ";");
5331 }
5332 }
5333
5334 // We're dealing with transpose manually.
5335 lhs_e->need_transpose = true;
5336 }
5337 else
5338 {
5339 write_type.columns = 1;
5340
5341 if (physical_type.vecsize != type.vecsize)
5342 cast_expr = join(ts: "(", ts&: cast_addr_space, ts: " ", ts&: packed_pfx, ts: type_to_glsl(type: write_type), ts: "&)");
5343
5344 if (rhs_transpose)
5345 {
5346 auto vector_type = expression_type(id: rhs_expression);
5347 vector_type.columns = 1;
5348
5349 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
5350 // so pick out individual components instead.
5351 for (uint32_t i = 0; i < type.columns; i++)
5352 {
5353 string rhs_row = type_to_glsl_constructor(type: vector_type) + "(";
5354 for (uint32_t j = 0; j < vector_type.vecsize; j++)
5355 {
5356 // Need to explicitly unpack expression since we've mucked with transpose state.
5357 auto unpacked_expr = to_unpacked_row_major_matrix_expression(id: rhs_expression);
5358 rhs_row += join(ts&: unpacked_expr, ts: "[", ts&: j, ts: "][", ts&: i, ts: "]");
5359 if (j + 1 < vector_type.vecsize)
5360 rhs_row += ", ";
5361 }
5362 rhs_row += ")";
5363
5364 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ", ts&: rhs_row, ts: ";");
5365 }
5366 }
5367 else
5368 {
5369 // Copy column-by-column.
5370 for (uint32_t i = 0; i < type.columns; i++)
5371 {
5372 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ",
5373 ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[", ts&: i, ts: "];");
5374 }
5375 }
5376 }
5377
5378 // We're dealing with transpose manually.
5379 if (rhs_transpose)
5380 rhs_e->need_transpose = true;
5381 }
5382 else if (transpose)
5383 {
5384 lhs_e->need_transpose = false;
5385
5386 SPIRType write_type = type;
5387 write_type.vecsize = 1;
5388 write_type.columns = 1;
5389
5390 // Storing a column to a row-major matrix. Unroll the write.
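			// Illustrative example: writing a vector into logical column 1 of a row-major physical
			// matrix m (a hypothetical member name) becomes one scalar store per component, e.g.
			//   ((device float*)&m[0])[1] = value.x;
			//   ((device float*)&m[1])[1] = value.y;
			//   ...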
5391 for (uint32_t c = 0; c < type.vecsize; c++)
5392 {
5393 auto lhs_expr = to_enclosed_expression(id: lhs_expression);
5394 auto column_index = lhs_expr.find_last_of(c: '[');
5395
5396 // Get rid of any ".data" half8 handling here, we're casting to scalar anyway.
5397 auto end_column_index = lhs_expr.find_last_of(c: ']');
5398 auto end_dot_index = lhs_expr.find_last_of(c: '.');
5399 if (end_dot_index != string::npos && end_dot_index > end_column_index)
5400 lhs_expr.resize(n: end_dot_index);
5401
5402 if (column_index != string::npos)
5403 {
5404 statement(ts: "((", ts&: cast_addr_space, ts: " ", ts: type_to_glsl(type: write_type), ts: "*)&",
5405 ts&: lhs_expr.insert(pos1: column_index, str: join(ts: '[', ts&: c, ts: ']', ts: ")")), ts: " = ",
5406 ts: to_extract_component_expression(id: rhs_expression, index: c), ts: ";");
5407 }
5408 }
5409
5410 lhs_e->need_transpose = true;
5411 }
5412 else if ((is_matrix(type: physical_type) || is_array(type: physical_type)) &&
5413 physical_type.vecsize <= 4 &&
5414 physical_type.vecsize > type.vecsize)
5415 {
5416 assert(type.vecsize >= 1 && type.vecsize <= 3);
5417
5418 // If we have packed types, we cannot use swizzled stores.
5419 // We could technically unroll the store for each element if needed.
5420 // When remapping to a std140 physical type, we always get float4,
5421 // and the packed decoration should always be removed.
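			// Illustrative example: a logical float3 member remapped to a float4 std140 slot is
			// stored by casting the padded member back down to its logical width, e.g.
			//   (device float3&)obj.padded = value;
			// which leaves the fourth (padding) component untouched. (obj.padded is a hypothetical name.)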
5422 assert(!lhs_packed_type);
5423
5424 string lhs = to_dereferenced_expression(id: lhs_expression);
5425 string rhs = to_pointer_expression(id: rhs_expression);
5426
5427 // Unpack the expression so we can store to it with a float or float2.
5428 // It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
5429 lhs = join(ts: "(", ts&: cast_addr_space, ts: " ", ts: type_to_glsl(type), ts: "&)", ts: enclose_expression(expr: lhs));
5430 if (!optimize_read_modify_write(type: expression_type(id: rhs_expression), lhs, rhs))
5431 statement(ts&: lhs, ts: " = ", ts&: rhs, ts: ";");
5432 }
5433 else if (!is_matrix(type))
5434 {
5435 string lhs = to_dereferenced_expression(id: lhs_expression);
5436 string rhs = to_pointer_expression(id: rhs_expression);
5437 if (!optimize_read_modify_write(type: expression_type(id: rhs_expression), lhs, rhs))
5438 statement(ts&: lhs, ts: " = ", ts&: rhs, ts: ";");
5439 }
5440
5441 register_write(chain: lhs_expression);
5442 }
5443}
5444
5445static bool expression_ends_with(const string &expr_str, const std::string &ending)
5446{
5447 if (expr_str.length() >= ending.length())
5448 return (expr_str.compare(pos: expr_str.length() - ending.length(), n: ending.length(), str: ending) == 0);
5449 else
5450 return false;
5451}
5452
5453// Converts the format of the current expression from packed to unpacked,
5454// by wrapping the expression in a constructor of the appropriate type.
5455// Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
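// For example (illustrative only): a packed_float3 member foo.v unpacks to "float3(foo.v)",
// while an element of a std140 array whose physical type is float4 unpacks by swizzling,
// e.g. "foo.v[2].xyz". (foo.v is a hypothetical member name.)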
5456string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
5457 bool packed, bool row_major)
5458{
5459 // Trivial case, nothing to do.
5460 if (physical_type_id == 0 && !packed)
5461 return expr_str;
5462
5463 const SPIRType *physical_type = nullptr;
5464 if (physical_type_id)
5465 physical_type = &get<SPIRType>(id: physical_type_id);
5466
5467 static const char *swizzle_lut[] = {
5468 ".x",
5469 ".xy",
5470 ".xyz",
5471 "",
5472 };
5473
5474 // TODO: Move everything to the template wrapper?
5475 bool uses_std140_wrapper = physical_type && physical_type->vecsize > 4;
5476
5477 if (physical_type && is_vector(type: *physical_type) && is_array(type: *physical_type) &&
5478 !uses_std140_wrapper &&
5479 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, ending: swizzle_lut[type.vecsize - 1]))
5480 {
5481 // std140 array cases for vectors.
5482 assert(type.vecsize >= 1 && type.vecsize <= 3);
5483 return enclose_expression(expr: expr_str) + swizzle_lut[type.vecsize - 1];
5484 }
5485 else if (physical_type && is_matrix(type: *physical_type) && is_vector(type) &&
5486 !uses_std140_wrapper &&
5487 physical_type->vecsize > type.vecsize)
5488 {
5489 // Extract column from padded matrix.
5490 assert(type.vecsize >= 1 && type.vecsize <= 4);
5491 return enclose_expression(expr: expr_str) + swizzle_lut[type.vecsize - 1];
5492 }
5493 else if (is_matrix(type))
5494 {
5495 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
5496 // we can't just pass the array straight to the matrix constructor. We have to
5497 // pass each vector individually, so that they can be unpacked to normal vectors.
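		// Illustrative example: a packed 2-column, 3-row float matrix stored as
		// "packed_float3 m[2]" unpacks to "float2x3(float3(m[0]), float3(m[1]))"
		// (m is a hypothetical member name).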
5498 if (!physical_type)
5499 physical_type = &type;
5500
5501 uint32_t vecsize = type.vecsize;
5502 uint32_t columns = type.columns;
5503 if (row_major)
5504 swap(a&: vecsize, b&: columns);
5505
5506 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
5507
5508 const char *base_type = type.width == 16 ? "half" : "float";
5509 string unpack_expr = join(ts&: base_type, ts&: columns, ts: "x", ts&: vecsize, ts: "(");
5510
5511 const char *load_swiz = "";
5512 const char *data_swiz = physical_vecsize > 4 ? ".data" : "";
5513
5514 if (physical_vecsize != vecsize)
5515 load_swiz = swizzle_lut[vecsize - 1];
5516
5517 for (uint32_t i = 0; i < columns; i++)
5518 {
5519 if (i > 0)
5520 unpack_expr += ", ";
5521
5522 if (packed)
5523 unpack_expr += join(ts&: base_type, ts&: physical_vecsize, ts: "(", ts&: expr_str, ts: "[", ts&: i, ts: "]", ts: ")", ts&: load_swiz);
5524 else
5525 unpack_expr += join(ts&: expr_str, ts: "[", ts&: i, ts: "]", ts&: data_swiz, ts&: load_swiz);
5526 }
5527
5528 unpack_expr += ")";
5529 return unpack_expr;
5530 }
5531 else
5532 {
5533 return join(ts: type_to_glsl(type), ts: "(", ts&: expr_str, ts: ")");
5534 }
5535}
5536
5537// Emits the file header info
5538void CompilerMSL::emit_header()
5539{
5540 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
5541 if (suppress_missing_prototypes)
5542 statement(ts: "#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
5543 if (suppress_incompatible_pointer_types_discard_qualifiers)
5544 statement(ts: "#pragma clang diagnostic ignored \"-Wincompatible-pointer-types-discards-qualifiers\"");
5545
5546 // Disable warning about missing braces for array<T> template to make arrays a value type
5547 if (spv_function_implementations.count(x: SPVFuncImplUnsafeArray) != 0)
5548 statement(ts: "#pragma clang diagnostic ignored \"-Wmissing-braces\"");
5549
5550 for (auto &pragma : pragma_lines)
5551 statement(ts: pragma);
5552
5553 if (!pragma_lines.empty() || suppress_missing_prototypes)
5554 statement(ts: "");
5555
5556 statement(ts: "#include <metal_stdlib>");
5557 statement(ts: "#include <simd/simd.h>");
5558
5559 for (auto &header : header_lines)
5560 statement(ts&: header);
5561
5562 statement(ts: "");
5563 statement(ts: "using namespace metal;");
5564 statement(ts: "");
5565
5566 for (auto &td : typedef_lines)
5567 statement(ts: td);
5568
5569 if (!typedef_lines.empty())
5570 statement(ts: "");
5571}
5572
5573void CompilerMSL::add_pragma_line(const string &line)
5574{
5575 auto rslt = pragma_lines.insert(x: line);
5576 if (rslt.second)
5577 force_recompile();
5578}
5579
5580void CompilerMSL::add_typedef_line(const string &line)
5581{
5582 auto rslt = typedef_lines.insert(x: line);
5583 if (rslt.second)
5584 force_recompile();
5585}
5586
5587 // Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared
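// Roughly, an array such as "float4 colors[3]" is instead declared in the emitted MSL as
// "spvUnsafeArray<float4, 3> colors;", which gives the array the value (copy/assignment)
// semantics that plain C arrays lack. (colors is a hypothetical name.)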
5588void CompilerMSL::emit_custom_templates()
5589{
5590 static const char * const address_spaces[] = {
5591 "thread", "constant", "device", "threadgroup", "threadgroup_imageblock", "ray_data", "object_data"
5592 };
5593
5594 for (const auto &spv_func : spv_function_implementations)
5595 {
5596 switch (spv_func)
5597 {
5598 case SPVFuncImplUnsafeArray:
5599 statement(ts: "template<typename T, size_t Num>");
5600 statement(ts: "struct spvUnsafeArray");
5601 begin_scope();
5602 statement(ts: "T elements[Num ? Num : 1];");
5603 statement(ts: "");
5604 statement(ts: "thread T& operator [] (size_t pos) thread");
5605 begin_scope();
5606 statement(ts: "return elements[pos];");
5607 end_scope();
5608 statement(ts: "constexpr const thread T& operator [] (size_t pos) const thread");
5609 begin_scope();
5610 statement(ts: "return elements[pos];");
5611 end_scope();
5612 statement(ts: "");
5613 statement(ts: "device T& operator [] (size_t pos) device");
5614 begin_scope();
5615 statement(ts: "return elements[pos];");
5616 end_scope();
5617 statement(ts: "constexpr const device T& operator [] (size_t pos) const device");
5618 begin_scope();
5619 statement(ts: "return elements[pos];");
5620 end_scope();
5621 statement(ts: "");
5622 statement(ts: "constexpr const constant T& operator [] (size_t pos) const constant");
5623 begin_scope();
5624 statement(ts: "return elements[pos];");
5625 end_scope();
5626 statement(ts: "");
5627 statement(ts: "threadgroup T& operator [] (size_t pos) threadgroup");
5628 begin_scope();
5629 statement(ts: "return elements[pos];");
5630 end_scope();
5631 statement(ts: "constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
5632 begin_scope();
5633 statement(ts: "return elements[pos];");
5634 end_scope();
5635 if (get_execution_model() == spv::ExecutionModelMeshEXT ||
5636 get_execution_model() == spv::ExecutionModelTaskEXT)
5637 {
5638 statement(ts: "");
5639 statement(ts: "object_data T& operator [] (size_t pos) object_data");
5640 begin_scope();
5641 statement(ts: "return elements[pos];");
5642 end_scope();
5643 statement(ts: "constexpr const object_data T& operator [] (size_t pos) const object_data");
5644 begin_scope();
5645 statement(ts: "return elements[pos];");
5646 end_scope();
5647 }
5648 end_scope_decl();
5649 statement(ts: "");
5650 break;
5651
5652 case SPVFuncImplStorageMatrix:
5653 statement(ts: "template<typename T, int Cols, int Rows=Cols>");
5654 statement(ts: "struct spvStorageMatrix");
5655 begin_scope();
5656 statement(ts: "vec<T, Rows> columns[Cols];");
5657 statement(ts: "");
5658 for (size_t method_idx = 0; method_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++method_idx)
5659 {
5660 // Some address spaces require particular features.
5661 if (method_idx == 4) // threadgroup_imageblock
5662 statement(ts: "#ifdef __HAVE_IMAGEBLOCKS__");
5663 else if (method_idx == 5) // ray_data
5664 statement(ts: "#ifdef __HAVE_RAYTRACING__");
5665 else if (method_idx == 6) // object_data
5666 statement(ts: "#ifdef __HAVE_MESH__");
5667 const string &method_as = address_spaces[method_idx];
5668 statement(ts: "spvStorageMatrix() ", ts: method_as, ts: " = default;");
5669 if (method_idx != 1) // constant
5670 {
5671 statement(ts: method_as, ts: " spvStorageMatrix& operator=(initializer_list<vec<T, Rows>> cols) ",
5672 ts: method_as);
5673 begin_scope();
5674 statement(ts: "size_t i;");
5675 statement(ts: "thread vec<T, Rows>* col;");
5676 statement(ts: "for (i = 0, col = cols.begin(); i < Cols; ++i, ++col)");
5677 statement(ts: " columns[i] = *col;");
5678 statement(ts: "return *this;");
5679 end_scope();
5680 }
5681 statement(ts: "");
5682 for (size_t param_idx = 0; param_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++param_idx)
5683 {
5684 if (param_idx != method_idx)
5685 {
5686 if (param_idx == 4) // threadgroup_imageblock
5687 statement(ts: "#ifdef __HAVE_IMAGEBLOCKS__");
5688 else if (param_idx == 5) // ray_data
5689 statement(ts: "#ifdef __HAVE_RAYTRACING__");
5690 else if (param_idx == 6) // object_data
5691 statement(ts: "#ifdef __HAVE_MESH__");
5692 }
5693 const string &param_as = address_spaces[param_idx];
5694 statement(ts: "spvStorageMatrix(const ", ts: param_as, ts: " matrix<T, Cols, Rows>& m) ", ts: method_as);
5695 begin_scope();
5696 statement(ts: "for (size_t i = 0; i < Cols; ++i)");
5697 statement(ts: " columns[i] = m.columns[i];");
5698 end_scope();
5699 statement(ts: "spvStorageMatrix(const ", ts: param_as, ts: " spvStorageMatrix& m) ", ts: method_as, ts: " = default;");
5700 if (method_idx != 1) // constant
5701 {
5702 statement(ts: method_as, ts: " spvStorageMatrix& operator=(const ", ts: param_as,
5703 ts: " matrix<T, Cols, Rows>& m) ", ts: method_as);
5704 begin_scope();
5705 statement(ts: "for (size_t i = 0; i < Cols; ++i)");
5706 statement(ts: " columns[i] = m.columns[i];");
5707 statement(ts: "return *this;");
5708 end_scope();
5709 statement(ts: method_as, ts: " spvStorageMatrix& operator=(const ", ts: param_as, ts: " spvStorageMatrix& m) ",
5710 ts: method_as, ts: " = default;");
5711 }
5712 if (param_idx != method_idx && param_idx >= 4)
5713 statement(ts: "#endif");
5714 statement(ts: "");
5715 }
5716 statement(ts: "operator matrix<T, Cols, Rows>() const ", ts: method_as);
5717 begin_scope();
5718 statement(ts: "matrix<T, Cols, Rows> m;");
5719 statement(ts: "for (int i = 0; i < Cols; ++i)");
5720 statement(ts: " m.columns[i] = columns[i];");
5721 statement(ts: "return m;");
5722 end_scope();
5723 statement(ts: "");
5724 statement(ts: "vec<T, Rows> operator[](size_t idx) const ", ts: method_as);
5725 begin_scope();
5726 statement(ts: "return columns[idx];");
5727 end_scope();
5728 if (method_idx != 1) // constant
5729 {
5730 statement(ts: method_as, ts: " vec<T, Rows>& operator[](size_t idx) ", ts: method_as);
5731 begin_scope();
5732 statement(ts: "return columns[idx];");
5733 end_scope();
5734 }
5735 if (method_idx >= 4)
5736 statement(ts: "#endif");
5737 statement(ts: "");
5738 }
5739 end_scope_decl();
5740 statement(ts: "");
5741 statement(ts: "template<typename T, int Cols, int Rows>");
5742 statement(ts: "matrix<T, Rows, Cols> transpose(spvStorageMatrix<T, Cols, Rows> m)");
5743 begin_scope();
5744 statement(ts: "return transpose(matrix<T, Cols, Rows>(m));");
5745 end_scope();
5746 statement(ts: "");
5747 statement(ts: "typedef spvStorageMatrix<half, 2, 2> spvStorage_half2x2;");
5748 statement(ts: "typedef spvStorageMatrix<half, 2, 3> spvStorage_half2x3;");
5749 statement(ts: "typedef spvStorageMatrix<half, 2, 4> spvStorage_half2x4;");
5750 statement(ts: "typedef spvStorageMatrix<half, 3, 2> spvStorage_half3x2;");
5751 statement(ts: "typedef spvStorageMatrix<half, 3, 3> spvStorage_half3x3;");
5752 statement(ts: "typedef spvStorageMatrix<half, 3, 4> spvStorage_half3x4;");
5753 statement(ts: "typedef spvStorageMatrix<half, 4, 2> spvStorage_half4x2;");
5754 statement(ts: "typedef spvStorageMatrix<half, 4, 3> spvStorage_half4x3;");
5755 statement(ts: "typedef spvStorageMatrix<half, 4, 4> spvStorage_half4x4;");
5756 statement(ts: "typedef spvStorageMatrix<float, 2, 2> spvStorage_float2x2;");
5757 statement(ts: "typedef spvStorageMatrix<float, 2, 3> spvStorage_float2x3;");
5758 statement(ts: "typedef spvStorageMatrix<float, 2, 4> spvStorage_float2x4;");
5759 statement(ts: "typedef spvStorageMatrix<float, 3, 2> spvStorage_float3x2;");
5760 statement(ts: "typedef spvStorageMatrix<float, 3, 3> spvStorage_float3x3;");
5761 statement(ts: "typedef spvStorageMatrix<float, 3, 4> spvStorage_float3x4;");
5762 statement(ts: "typedef spvStorageMatrix<float, 4, 2> spvStorage_float4x2;");
5763 statement(ts: "typedef spvStorageMatrix<float, 4, 3> spvStorage_float4x3;");
5764 statement(ts: "typedef spvStorageMatrix<float, 4, 4> spvStorage_float4x4;");
5765 statement(ts: "");
5766 break;
5767
5768 default:
5769 break;
5770 }
5771 }
5772}
5773
5774// Emits any needed custom function bodies.
5775 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)),
5776 // otherwise they will cause problems when linked together in a single Metallib.
5777void CompilerMSL::emit_custom_functions()
5778{
5779 // Use when outputting overloaded functions to cover different address spaces.
5780 static const char *texture_addr_spaces[] = { "device", "constant", "thread" };
5781 static uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(char*);
5782
5783 if (spv_function_implementations.count(x: SPVFuncImplArrayCopyMultidim))
5784 spv_function_implementations.insert(x: SPVFuncImplArrayCopy);
5785
5786 if (spv_function_implementations.count(x: SPVFuncImplDynamicImageSampler))
5787 {
5788 // Unfortunately, this one needs a lot of the other functions to compile OK.
5789 if (!msl_options.supports_msl_version(major: 2))
5790 SPIRV_CROSS_THROW(
5791 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
5792 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5793 spv_function_implementations.insert(x: SPVFuncImplTextureSwizzle);
5794 if (msl_options.swizzle_texture_samples)
5795 spv_function_implementations.insert(x: SPVFuncImplGatherSwizzle);
5796 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
5797 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
5798 spv_function_implementations.insert(x: static_cast<SPVFuncImpl>(i));
5799 spv_function_implementations.insert(x: SPVFuncImplExpandITUFullRange);
5800 spv_function_implementations.insert(x: SPVFuncImplExpandITUNarrowRange);
5801 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT709);
5802 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT601);
5803 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT2020);
5804 }
5805
5806 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
5807 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
5808 if (spv_function_implementations.count(x: static_cast<SPVFuncImpl>(i)))
5809 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5810
5811 if (spv_function_implementations.count(x: SPVFuncImplTextureSwizzle) ||
5812 spv_function_implementations.count(x: SPVFuncImplGatherSwizzle) ||
5813 spv_function_implementations.count(x: SPVFuncImplGatherCompareSwizzle))
5814 {
5815 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5816 spv_function_implementations.insert(x: SPVFuncImplGetSwizzle);
5817 }
5818
5819 for (const auto &spv_func : spv_function_implementations)
5820 {
5821 switch (spv_func)
5822 {
5823 case SPVFuncImplMod:
5824 statement(ts: "// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
5825 statement(ts: "template<typename Tx, typename Ty>");
5826 statement(ts: "inline Tx mod(Tx x, Ty y)");
5827 begin_scope();
5828 statement(ts: "return x - y * floor(x / y);");
5829 end_scope();
5830 statement(ts: "");
5831 break;
5832
5833 case SPVFuncImplRadians:
5834 statement(ts: "// Implementation of the GLSL radians() function");
5835 statement(ts: "template<typename T>");
5836 statement(ts: "inline T radians(T d)");
5837 begin_scope();
5838 statement(ts: "return d * T(0.01745329251);");
5839 end_scope();
5840 statement(ts: "");
5841 break;
5842
5843 case SPVFuncImplDegrees:
5844 statement(ts: "// Implementation of the GLSL degrees() function");
5845 statement(ts: "template<typename T>");
5846 statement(ts: "inline T degrees(T r)");
5847 begin_scope();
5848 statement(ts: "return r * T(57.2957795131);");
5849 end_scope();
5850 statement(ts: "");
5851 break;
5852
5853 case SPVFuncImplFindILsb:
5854 statement(ts: "// Implementation of the GLSL findLSB() function");
5855 statement(ts: "template<typename T>");
5856 statement(ts: "inline T spvFindLSB(T x)");
5857 begin_scope();
5858 statement(ts: "return select(ctz(x), T(-1), x == T(0));");
5859 end_scope();
5860 statement(ts: "");
5861 break;
5862
5863 case SPVFuncImplFindUMsb:
5864 statement(ts: "// Implementation of the unsigned GLSL findMSB() function");
5865 statement(ts: "template<typename T>");
5866 statement(ts: "inline T spvFindUMSB(T x)");
5867 begin_scope();
5868 statement(ts: "return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
5869 end_scope();
5870 statement(ts: "");
5871 break;
5872
5873 case SPVFuncImplFindSMsb:
5874 statement(ts: "// Implementation of the signed GLSL findMSB() function");
5875 statement(ts: "template<typename T>");
5876 statement(ts: "inline T spvFindSMSB(T x)");
5877 begin_scope();
5878 statement(ts: "T v = select(x, T(-1) - x, x < T(0));");
5879 statement(ts: "return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
5880 end_scope();
5881 statement(ts: "");
5882 break;
5883
5884 case SPVFuncImplSSign:
5885 statement(ts: "// Implementation of the GLSL sign() function for integer types");
5886 statement(ts: "template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
5887 statement(ts: "inline T sign(T x)");
5888 begin_scope();
5889 statement(ts: "return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
5890 end_scope();
5891 statement(ts: "");
5892 break;
5893
5894 case SPVFuncImplArrayCopy:
5895 case SPVFuncImplArrayCopyMultidim:
5896 {
5897 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
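			// For instance, the first variant below expands to:
			//   template<typename T, uint N>
			//   inline void spvArrayCopyFromConstantToStack(thread T (&dst)[N], constant T (&src)[N]) { ... }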
5898 static const char *function_name_tags[] = {
5899 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
5900 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
5901 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
5902 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
5903 };
5904
5905 static const char *src_address_space[] = {
5906 "constant", "constant", "thread const", "thread const",
5907 "threadgroup const", "threadgroup const", "device const", "constant",
5908 "thread const", "threadgroup const", "device const", "device const",
5909 };
5910
5911 static const char *dst_address_space[] = {
5912 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
5913 "device", "device", "device", "device", "thread", "threadgroup",
5914 };
5915
5916 for (uint32_t variant = 0; variant < 12; variant++)
5917 {
5918 bool is_multidim = spv_func == SPVFuncImplArrayCopyMultidim;
5919 const char* dim = is_multidim ? "[N][M]" : "[N]";
5920 statement(ts: "template<typename T, uint N", ts: is_multidim ? ", uint M>" : ">");
5921 statement(ts: "inline void spvArrayCopy", ts&: function_name_tags[variant], ts: "(",
5922 ts&: dst_address_space[variant], ts: " T (&dst)", ts&: dim, ts: ", ",
5923 ts&: src_address_space[variant], ts: " T (&src)", ts&: dim, ts: ")");
5924 begin_scope();
5925 statement(ts: "for (uint i = 0; i < N; i++)");
5926 begin_scope();
5927 if (is_multidim)
5928 statement(ts: "spvArrayCopy", ts&: function_name_tags[variant], ts: "(dst[i], src[i]);");
5929 else
5930 statement(ts: "dst[i] = src[i];");
5931 end_scope();
5932 end_scope();
5933 statement(ts: "");
5934 }
5935 break;
5936 }
5937
5938 // Support for Metal 2.1's new texture_buffer type.
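		// Illustrative: with a hypothetical texel_buffer_texture_width of 4096, texel 10000 maps to
		// uint2(10000 % 4096, 10000 / 4096), i.e. uint2(1808, 2).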
5939 case SPVFuncImplTexelBufferCoords:
5940 {
5941 if (msl_options.texel_buffer_texture_width > 0)
5942 {
5943 string tex_width_str = convert_to_string(t: msl_options.texel_buffer_texture_width);
5944 statement(ts: "// Returns 2D texture coords corresponding to 1D texel buffer coords");
5945 statement(ts&: force_inline);
5946 statement(ts: "uint2 spvTexelBufferCoord(uint tc)");
5947 begin_scope();
5948 statement(ts: join(ts: "return uint2(tc % ", ts&: tex_width_str, ts: ", tc / ", ts&: tex_width_str, ts: ");"));
5949 end_scope();
5950 statement(ts: "");
5951 }
5952 else
5953 {
5954 statement(ts: "// Returns 2D texture coords corresponding to 1D texel buffer coords");
5955 statement(
5956 ts: "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
5957 statement(ts: "");
5958 }
5959 break;
5960 }
5961
5962 // Emulate texture2D atomic operations
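		// Illustrative: for a hypothetical 30-texel-wide R32Uint texture with a 16-byte linear-texture
		// alignment (4 texels), the macro below rounds the row pitch up to 32 texels, so coordinate
		// (x, y) maps to buffer element 32 * y + x.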
5963 case SPVFuncImplImage2DAtomicCoords:
5964 {
5965 if (msl_options.supports_msl_version(major: 1, minor: 2))
5966 {
5967 statement(ts: "// The required alignment of a linear texture of R32Uint format.");
5968 statement(ts: "constant uint spvLinearTextureAlignmentOverride [[function_constant(",
5969 ts&: msl_options.r32ui_alignment_constant_id, ts: ")]];");
5970 statement(ts: "constant uint spvLinearTextureAlignment = ",
5971 ts: "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
5972 ts: "spvLinearTextureAlignmentOverride : ", ts&: msl_options.r32ui_linear_texture_alignment, ts: ";");
5973 }
5974 else
5975 {
5976 statement(ts: "// The required alignment of a linear texture of R32Uint format.");
5977 statement(ts: "constant uint spvLinearTextureAlignment = ", ts&: msl_options.r32ui_linear_texture_alignment,
5978 ts: ";");
5979 }
5980 statement(ts: "// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
5981 statement(ts: "#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
5982 ts: " spvLinearTextureAlignment / 4 - 1) & ~(",
5983 ts: " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
5984 statement(ts: "");
5985 break;
5986 }
5987
5988 // Fix up gradient vectors when sampling a cube texture for Apple Silicon.
5989 // h/t Alexey Knyazev (https://github.com/KhronosGroup/MoltenVK/issues/2068#issuecomment-1817799067) for the code.
5990 case SPVFuncImplGradientCube:
5991 statement(ts: "static inline gradientcube spvGradientCube(float3 P, float3 dPdx, float3 dPdy)");
5992 begin_scope();
5993 statement(ts: "// Major axis selection");
5994 statement(ts: "float3 absP = abs(P);");
5995 statement(ts: "bool xMajor = absP.x >= max(absP.y, absP.z);");
5996 statement(ts: "bool yMajor = absP.y >= absP.z;");
5997 statement(ts: "float3 Q = xMajor ? P.yzx : (yMajor ? P.xzy : P);");
5998 statement(ts: "float3 dQdx = xMajor ? dPdx.yzx : (yMajor ? dPdx.xzy : dPdx);");
5999 statement(ts: "float3 dQdy = xMajor ? dPdy.yzx : (yMajor ? dPdy.xzy : dPdy);");
6000 statement_no_indent(ts: "");
6001 statement(ts: "// Skip a couple of operations compared to usual projection");
6002 statement(ts: "float4 d = float4(dQdx.xy, dQdy.xy) - (Q.xy / Q.z).xyxy * float4(dQdx.zz, dQdy.zz);");
6003 statement_no_indent(ts: "");
6004 statement(ts: "// Final swizzle to put the intermediate values into non-ignored components");
6005 statement(ts: "// X major: X and Z");
6006 statement(ts: "// Y major: X and Y");
6007 statement(ts: "// Z major: Y and Z");
6008 statement(ts: "return gradientcube(xMajor ? d.xxy : d.xyx, xMajor ? d.zzw : d.zwz);");
6009 end_scope();
6010 statement(ts: "");
6011 break;
6012
6013 // "fadd" intrinsic support
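		// The spvFAdd/spvFSub/spvFMul helpers below are marked [[clang::optnone]] and written in terms
		// of fma() so that fast-math cannot re-associate or contract the operations; this is roughly
		// what SPIR-V decorations such as NoContraction require.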
6014 case SPVFuncImplFAdd:
6015 statement(ts: "template<typename T>");
6016 statement(ts: "[[clang::optnone]] T spvFAdd(T l, T r)");
6017 begin_scope();
6018 statement(ts: "return fma(T(1), l, r);");
6019 end_scope();
6020 statement(ts: "");
6021 break;
6022
6023 // "fsub" intrinsic support
6024 case SPVFuncImplFSub:
6025 statement(ts: "template<typename T>");
6026 statement(ts: "[[clang::optnone]] T spvFSub(T l, T r)");
6027 begin_scope();
6028 statement(ts: "return fma(T(-1), r, l);");
6029 end_scope();
6030 statement(ts: "");
6031 break;
6032
6033 		// "fmul" intrinsic support
6034 case SPVFuncImplFMul:
6035 statement(ts: "template<typename T>");
6036 statement(ts: "[[clang::optnone]] T spvFMul(T l, T r)");
6037 begin_scope();
6038 statement(ts: "return fma(l, r, T(0));");
6039 end_scope();
6040 statement(ts: "");
6041
6042 statement(ts: "template<typename T, int Cols, int Rows>");
6043 statement(ts: "[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
6044 begin_scope();
6045 statement(ts: "vec<T, Cols> res = vec<T, Cols>(0);");
6046 statement(ts: "for (uint i = Rows; i > 0; --i)");
6047 begin_scope();
6048 statement(ts: "vec<T, Cols> tmp(0);");
6049 statement(ts: "for (uint j = 0; j < Cols; ++j)");
6050 begin_scope();
6051 statement(ts: "tmp[j] = m[j][i - 1];");
6052 end_scope();
6053 statement(ts: "res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
6054 end_scope();
6055 statement(ts: "return res;");
6056 end_scope();
6057 statement(ts: "");
6058
6059 statement(ts: "template<typename T, int Cols, int Rows>");
6060 statement(ts: "[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
6061 begin_scope();
6062 statement(ts: "vec<T, Rows> res = vec<T, Rows>(0);");
6063 statement(ts: "for (uint i = Cols; i > 0; --i)");
6064 begin_scope();
6065 statement(ts: "res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
6066 end_scope();
6067 statement(ts: "return res;");
6068 end_scope();
6069 statement(ts: "");
6070
6071 statement(ts: "template<typename T, int LCols, int LRows, int RCols, int RRows>");
6072 statement(ts: "[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
6073 begin_scope();
6074 statement(ts: "matrix<T, RCols, LRows> res;");
6075 statement(ts: "for (uint i = 0; i < RCols; i++)");
6076 begin_scope();
6077 statement(ts: "vec<T, RCols> tmp(0);");
6078 statement(ts: "for (uint j = 0; j < LCols; j++)");
6079 begin_scope();
6080 statement(ts: "tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
6081 end_scope();
6082 statement(ts: "res[i] = tmp;");
6083 end_scope();
6084 statement(ts: "return res;");
6085 end_scope();
6086 statement(ts: "");
6087 break;
6088
6089 case SPVFuncImplQuantizeToF16:
6090 // Ensure fast-math is disabled to match Vulkan results.
6091 // SpvHalfTypeSelector is used to match the half* template type to the float* template type.
6092 // Depending on GPU, MSL does not always flush converted subnormal halfs to zero,
6093 // as required by OpQuantizeToF16, so check for subnormals and flush them to zero.
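		// Illustrative: a float value of 1e-6 converts to a subnormal half, so the select() below
		// replaces it with a signed zero via copysign(H(0), hval), matching the flush-to-zero
		// behaviour described above for OpQuantizeToF16.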
6094 statement(ts: "template <typename F> struct SpvHalfTypeSelector;");
6095 statement(ts: "template <> struct SpvHalfTypeSelector<float> { public: using H = half; };");
6096 statement(ts: "template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };");
6097 statement(ts: "template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>");
6098 statement(ts: "[[clang::optnone]] F spvQuantizeToF16(F fval)");
6099 begin_scope();
6100 statement(ts: "H hval = H(fval);");
6101 statement(ts: "hval = select(copysign(H(0), hval), hval, isnormal(hval) || isinf(hval) || isnan(hval));");
6102 statement(ts: "return F(hval);");
6103 end_scope();
6104 statement(ts: "");
6105 break;
6106
6107 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
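		// Worked example (illustrative): P = float3(1.0, 0.5, -0.5) selects the +X face (layer 0);
		// the code below produces u = 0.75, v = 0.25, i.e. it returns float3(0.75, 0.25, 0).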
6108 case SPVFuncImplCubemapTo2DArrayFace:
6109 statement(ts&: force_inline);
6110 statement(ts: "float3 spvCubemapTo2DArrayFace(float3 P)");
6111 begin_scope();
6112 statement(ts: "float3 Coords = abs(P.xyz);");
6113 statement(ts: "float CubeFace = 0;");
6114 statement(ts: "float ProjectionAxis = 0;");
6115 statement(ts: "float u = 0;");
6116 statement(ts: "float v = 0;");
6117 statement(ts: "if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
6118 begin_scope();
6119 statement(ts: "CubeFace = P.x >= 0 ? 0 : 1;");
6120 statement(ts: "ProjectionAxis = Coords.x;");
6121 statement(ts: "u = P.x >= 0 ? -P.z : P.z;");
6122 statement(ts: "v = -P.y;");
6123 end_scope();
6124 statement(ts: "else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
6125 begin_scope();
6126 statement(ts: "CubeFace = P.y >= 0 ? 2 : 3;");
6127 statement(ts: "ProjectionAxis = Coords.y;");
6128 statement(ts: "u = P.x;");
6129 statement(ts: "v = P.y >= 0 ? P.z : -P.z;");
6130 end_scope();
6131 statement(ts: "else");
6132 begin_scope();
6133 statement(ts: "CubeFace = P.z >= 0 ? 4 : 5;");
6134 statement(ts: "ProjectionAxis = Coords.z;");
6135 statement(ts: "u = P.z >= 0 ? P.x : -P.x;");
6136 statement(ts: "v = -P.y;");
6137 end_scope();
6138 statement(ts: "u = 0.5 * (u/ProjectionAxis + 1);");
6139 statement(ts: "v = 0.5 * (v/ProjectionAxis + 1);");
6140 statement(ts: "return float3(u, v, CubeFace);");
6141 end_scope();
6142 statement(ts: "");
6143 break;
6144
6145 case SPVFuncImplInverse4x4:
6146 statement(ts: "// Returns the determinant of a 2x2 matrix.");
6147 statement(ts&: force_inline);
6148 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
6149 begin_scope();
6150 statement(ts: "return a1 * b2 - b1 * a2;");
6151 end_scope();
6152 statement(ts: "");
6153
6154 statement(ts: "// Returns the determinant of a 3x3 matrix.");
6155 statement(ts&: force_inline);
6156 statement(ts: "float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
6157 "float c2, float c3)");
6158 begin_scope();
6159 statement(ts: "return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
6160 "b2, b3);");
6161 end_scope();
6162 statement(ts: "");
6163 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
6164 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6165 statement(ts&: force_inline);
6166 statement(ts: "float4x4 spvInverse4x4(float4x4 m)");
6167 begin_scope();
6168 statement(ts: "float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
6169 statement_no_indent(ts: "");
6170 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6171 statement(ts: "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
6172 "m[3][3]);");
6173 statement(ts: "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
6174 "m[3][3]);");
6175 statement(ts: "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
6176 "m[3][3]);");
6177 statement(ts: "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
6178 "m[2][3]);");
6179 statement_no_indent(ts: "");
6180 statement(ts: "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
6181 "m[3][3]);");
6182 statement(ts: "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
6183 "m[3][3]);");
6184 statement(ts: "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
6185 "m[3][3]);");
6186 statement(ts: "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
6187 "m[2][3]);");
6188 statement_no_indent(ts: "");
6189 statement(ts: "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
6190 "m[3][3]);");
6191 statement(ts: "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
6192 "m[3][3]);");
6193 statement(ts: "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
6194 "m[3][3]);");
6195 statement(ts: "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
6196 "m[2][3]);");
6197 statement_no_indent(ts: "");
6198 statement(ts: "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
6199 "m[3][2]);");
6200 statement(ts: "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
6201 "m[3][2]);");
6202 statement(ts: "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
6203 "m[3][2]);");
6204 statement(ts: "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
6205 "m[2][2]);");
6206 statement_no_indent(ts: "");
6207 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6208 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
6209 "* m[3][0]);");
6210 statement_no_indent(ts: "");
6211 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6212 			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6213 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6214 end_scope();
6215 statement(ts: "");
6216 break;
6217
6218 case SPVFuncImplInverse3x3:
6219 if (spv_function_implementations.count(x: SPVFuncImplInverse4x4) == 0)
6220 {
6221 statement(ts: "// Returns the determinant of a 2x2 matrix.");
6222 statement(ts&: force_inline);
6223 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
6224 begin_scope();
6225 statement(ts: "return a1 * b2 - b1 * a2;");
6226 end_scope();
6227 statement(ts: "");
6228 }
6229
6230 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
6231 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6232 statement(ts&: force_inline);
6233 statement(ts: "float3x3 spvInverse3x3(float3x3 m)");
6234 begin_scope();
6235 statement(ts: "float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
6236 statement_no_indent(ts: "");
6237 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6238 statement(ts: "adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
6239 statement(ts: "adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
6240 statement(ts: "adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
6241 statement_no_indent(ts: "");
6242 statement(ts: "adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
6243 statement(ts: "adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
6244 statement(ts: "adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
6245 statement_no_indent(ts: "");
6246 statement(ts: "adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
6247 statement(ts: "adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
6248 statement(ts: "adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
6249 statement_no_indent(ts: "");
6250 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6251 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
6252 statement_no_indent(ts: "");
6253 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6254 			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6255 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6256 end_scope();
6257 statement(ts: "");
6258 break;
6259
6260 case SPVFuncImplInverse2x2:
6261 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
6262 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6263 statement(ts&: force_inline);
6264 statement(ts: "float2x2 spvInverse2x2(float2x2 m)");
6265 begin_scope();
6266 statement(ts: "float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
6267 statement_no_indent(ts: "");
6268 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6269 statement(ts: "adj[0][0] = m[1][1];");
6270 statement(ts: "adj[0][1] = -m[0][1];");
6271 statement_no_indent(ts: "");
6272 statement(ts: "adj[1][0] = -m[1][0];");
6273 statement(ts: "adj[1][1] = m[0][0];");
6274 statement_no_indent(ts: "");
6275 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6276 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
6277 statement_no_indent(ts: "");
6278 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6279 			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6280 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6281 end_scope();
6282 statement(ts: "");
6283 break;
6284
6285 case SPVFuncImplForwardArgs:
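			// spvRemoveReference/spvForward below are thread-address-space analogues of std::remove_reference
			// and std::forward, used to perfectly forward arguments (samplers, coordinates, offsets) through
			// the swizzling wrapper functions.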
6286 statement(ts: "template<typename T> struct spvRemoveReference { typedef T type; };");
6287 statement(ts: "template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
6288 statement(ts: "template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
6289 statement(ts: "template<typename T> inline constexpr thread T&& spvForward(thread typename "
6290 "spvRemoveReference<T>::type& x)");
6291 begin_scope();
6292 statement(ts: "return static_cast<thread T&&>(x);");
6293 end_scope();
6294 statement(ts: "template<typename T> inline constexpr thread T&& spvForward(thread typename "
6295 "spvRemoveReference<T>::type&& x)");
6296 begin_scope();
6297 statement(ts: "return static_cast<thread T&&>(x);");
6298 end_scope();
6299 statement(ts: "");
6300 break;
6301
6302 case SPVFuncImplGetSwizzle:
6303 statement(ts: "enum class spvSwizzle : uint");
6304 begin_scope();
6305 statement(ts: "none = 0,");
6306 statement(ts: "zero,");
6307 statement(ts: "one,");
6308 statement(ts: "red,");
6309 statement(ts: "green,");
6310 statement(ts: "blue,");
6311 statement(ts: "alpha");
6312 end_scope_decl();
6313 statement(ts: "");
6314 statement(ts: "template<typename T>");
6315 statement(ts: "inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
6316 begin_scope();
6317 statement(ts: "switch (s)");
6318 begin_scope();
6319 statement(ts: "case spvSwizzle::none:");
6320 statement(ts: " return c;");
6321 statement(ts: "case spvSwizzle::zero:");
6322 statement(ts: " return 0;");
6323 statement(ts: "case spvSwizzle::one:");
6324 statement(ts: " return 1;");
6325 statement(ts: "case spvSwizzle::red:");
6326 statement(ts: " return x.r;");
6327 statement(ts: "case spvSwizzle::green:");
6328 statement(ts: " return x.g;");
6329 statement(ts: "case spvSwizzle::blue:");
6330 statement(ts: " return x.b;");
6331 statement(ts: "case spvSwizzle::alpha:");
6332 statement(ts: " return x.a;");
6333 end_scope();
6334 end_scope();
6335 statement(ts: "");
6336 break;
6337
6338 case SPVFuncImplTextureSwizzle:
6339 statement(ts: "// Wrapper function that swizzles texture samples and fetches.");
6340 statement(ts: "template<typename T>");
6341 statement(ts: "inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
6342 begin_scope();
6343 statement(ts: "if (!s)");
6344 statement(ts: " return x;");
6345 statement(ts: "return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
6346 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
6347 "& 0xFF)), "
6348 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
6349 end_scope();
6350 statement(ts: "");
6351 statement(ts: "template<typename T>");
6352 statement(ts: "inline T spvTextureSwizzle(T x, uint s)");
6353 begin_scope();
6354 statement(ts: "return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
6355 end_scope();
6356 statement(ts: "");
6357 break;
6358
6359 case SPVFuncImplGatherSwizzle:
6360 statement(ts: "// Wrapper function that swizzles texture gathers.");
6361 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6362 "typename... Ts>");
6363 statement(ts: "inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
6364 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
6365 begin_scope();
6366 statement(ts: "if (sw)");
6367 begin_scope();
6368 statement(ts: "switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
6369 begin_scope();
6370 statement(ts: "case spvSwizzle::none:");
6371 statement(ts: " break;");
6372 statement(ts: "case spvSwizzle::zero:");
6373 statement(ts: " return vec<T, 4>(0, 0, 0, 0);");
6374 statement(ts: "case spvSwizzle::one:");
6375 statement(ts: " return vec<T, 4>(1, 1, 1, 1);");
6376 statement(ts: "case spvSwizzle::red:");
6377 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::x);");
6378 statement(ts: "case spvSwizzle::green:");
6379 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::y);");
6380 statement(ts: "case spvSwizzle::blue:");
6381 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::z);");
6382 statement(ts: "case spvSwizzle::alpha:");
6383 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::w);");
6384 end_scope();
6385 end_scope();
6386 // texture::gather insists on its component parameter being a constant
6387 // expression, so we need this silly workaround just to compile the shader.
6388 statement(ts: "switch (c)");
6389 begin_scope();
6390 statement(ts: "case component::x:");
6391 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::x);");
6392 statement(ts: "case component::y:");
6393 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::y);");
6394 statement(ts: "case component::z:");
6395 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::z);");
6396 statement(ts: "case component::w:");
6397 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::w);");
6398 end_scope();
6399 end_scope();
6400 statement(ts: "");
6401 break;
6402
6403 case SPVFuncImplGatherCompareSwizzle:
6404 statement(ts: "// Wrapper function that swizzles depth texture gathers.");
6405 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6406 "typename... Ts>");
6407 statement(ts: "inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
6408 "s, uint sw, Ts... params) ");
6409 begin_scope();
6410 statement(ts: "if (sw)");
6411 begin_scope();
6412 statement(ts: "switch (spvSwizzle(sw & 0xFF))");
6413 begin_scope();
6414 statement(ts: "case spvSwizzle::none:");
6415 statement(ts: "case spvSwizzle::red:");
6416 statement(ts: " break;");
6417 statement(ts: "case spvSwizzle::zero:");
6418 statement(ts: "case spvSwizzle::green:");
6419 statement(ts: "case spvSwizzle::blue:");
6420 statement(ts: "case spvSwizzle::alpha:");
6421 statement(ts: " return vec<T, 4>(0, 0, 0, 0);");
6422 statement(ts: "case spvSwizzle::one:");
6423 statement(ts: " return vec<T, 4>(1, 1, 1, 1);");
6424 end_scope();
6425 end_scope();
6426 statement(ts: "return t.gather_compare(s, spvForward<Ts>(params)...);");
6427 end_scope();
6428 statement(ts: "");
6429 break;
6430
6431 case SPVFuncImplGatherConstOffsets:
6432 // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
6433 for (uint32_t i = 0; i < texture_addr_space_count; i++)
6434 {
6435 statement(ts: "// Wrapper function that processes a ", ts&: texture_addr_spaces[i], ts: " texture gather with a constant offset array.");
6436 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6437 "typename Toff, typename... Tp>");
6438 statement(ts: "inline vec<T, 4> spvGatherConstOffsets(const ", ts&: texture_addr_spaces[i], ts: " Tex<T>& t, sampler s, "
6439 "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
6440 begin_scope();
6441 statement(ts: "vec<T, 4> rslts[4];");
6442 statement(ts: "for (uint i = 0; i < 4; i++)");
6443 begin_scope();
6444 statement(ts: "switch (c)");
6445 begin_scope();
6446 // Work around texture::gather() requiring its component parameter to be a constant expression
6447 statement(ts: "case component::x:");
6448 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
6449 statement(ts: " break;");
6450 statement(ts: "case component::y:");
6451 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
6452 statement(ts: " break;");
6453 statement(ts: "case component::z:");
6454 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
6455 statement(ts: " break;");
6456 statement(ts: "case component::w:");
6457 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
6458 statement(ts: " break;");
6459 end_scope();
6460 end_scope();
6461 // Pull all values from the i0j0 component of each gather footprint
6462 statement(ts: "return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
6463 end_scope();
6464 statement(ts: "");
6465 }
6466 break;
6467
6468 case SPVFuncImplGatherCompareConstOffsets:
6469 // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
6470 for (uint32_t i = 0; i < texture_addr_space_count; i++)
6471 {
6472 				statement(ts: "// Wrapper function that processes a ", ts&: texture_addr_spaces[i], ts: " depth texture gather with a constant offset array.");
6473 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6474 "typename Toff, typename... Tp>");
6475 statement(ts: "inline vec<T, 4> spvGatherCompareConstOffsets(const ", ts&: texture_addr_spaces[i], ts: " Tex<T>& t, sampler s, "
6476 "Toff coffsets, Tp... params)");
6477 begin_scope();
6478 statement(ts: "vec<T, 4> rslts[4];");
6479 statement(ts: "for (uint i = 0; i < 4; i++)");
6480 begin_scope();
6481 statement(ts: " rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
6482 end_scope();
6483 // Pull all values from the i0j0 component of each gather footprint
6484 statement(ts: "return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
6485 end_scope();
6486 statement(ts: "");
6487 }
6488 break;
6489
6490 case SPVFuncImplSubgroupBroadcast:
6491 // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
6492 // them as integers.
6493 statement(ts: "template<typename T>");
6494 statement(ts: "inline T spvSubgroupBroadcast(T value, ushort lane)");
6495 begin_scope();
6496 if (msl_options.use_quadgroup_operation())
6497 statement(ts: "return quad_broadcast(value, lane);");
6498 else
6499 statement(ts: "return simd_broadcast(value, lane);");
6500 end_scope();
6501 statement(ts: "");
6502 statement(ts: "template<>");
6503 statement(ts: "inline bool spvSubgroupBroadcast(bool value, ushort lane)");
6504 begin_scope();
6505 if (msl_options.use_quadgroup_operation())
6506 statement(ts: "return !!quad_broadcast((ushort)value, lane);");
6507 else
6508 statement(ts: "return !!simd_broadcast((ushort)value, lane);");
6509 end_scope();
6510 statement(ts: "");
6511 statement(ts: "template<uint N>");
6512 statement(ts: "inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
6513 begin_scope();
6514 if (msl_options.use_quadgroup_operation())
6515 statement(ts: "return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
6516 else
6517 statement(ts: "return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
6518 end_scope();
6519 statement(ts: "");
6520 break;
6521
6522 case SPVFuncImplSubgroupBroadcastFirst:
6523 statement(ts: "template<typename T>");
6524 statement(ts: "inline T spvSubgroupBroadcastFirst(T value)");
6525 begin_scope();
6526 if (msl_options.use_quadgroup_operation())
6527 statement(ts: "return quad_broadcast_first(value);");
6528 else
6529 statement(ts: "return simd_broadcast_first(value);");
6530 end_scope();
6531 statement(ts: "");
6532 statement(ts: "template<>");
6533 statement(ts: "inline bool spvSubgroupBroadcastFirst(bool value)");
6534 begin_scope();
6535 if (msl_options.use_quadgroup_operation())
6536 statement(ts: "return !!quad_broadcast_first((ushort)value);");
6537 else
6538 statement(ts: "return !!simd_broadcast_first((ushort)value);");
6539 end_scope();
6540 statement(ts: "");
6541 statement(ts: "template<uint N>");
6542 statement(ts: "inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
6543 begin_scope();
6544 if (msl_options.use_quadgroup_operation())
6545 statement(ts: "return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
6546 else
6547 statement(ts: "return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
6548 end_scope();
6549 statement(ts: "");
6550 break;
6551
6552 case SPVFuncImplSubgroupBallot:
6553 statement(ts: "inline uint4 spvSubgroupBallot(bool value)");
6554 begin_scope();
6555 if (msl_options.use_quadgroup_operation())
6556 {
6557 statement(ts: "return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
6558 }
6559 else if (msl_options.is_ios())
6560 {
6561 // The current simd_vote on iOS uses a 32-bit integer-like object.
6562 statement(ts: "return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
6563 }
6564 else
6565 {
6566 statement(ts: "simd_vote vote = simd_ballot(value);");
6567 statement(ts: "// simd_ballot() returns a 64-bit integer-like object, but");
6568 statement(ts: "// SPIR-V callers expect a uint4. We must convert.");
6569 statement(ts: "// FIXME: This won't include higher bits if Apple ever supports");
6570 				statement(ts: "// 128 lanes in a SIMD-group.");
6571 statement(ts: "return uint4(as_type<uint2>((simd_vote::vote_t)vote), 0, 0);");
6572 }
6573 end_scope();
6574 statement(ts: "");
6575 break;
6576
6577 case SPVFuncImplSubgroupBallotBitExtract:
6578 statement(ts: "inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
6579 begin_scope();
6580 statement(ts: "return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
6581 end_scope();
6582 statement(ts: "");
6583 break;
6584
6585 case SPVFuncImplSubgroupBallotFindLSB:
6586 statement(ts: "inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
6587 begin_scope();
6588 if (msl_options.is_ios())
6589 {
6590 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6591 }
6592 else
6593 {
6594 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6595 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6596 }
6597 statement(ts: "ballot &= mask;");
6598 statement(ts: "return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
6599 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
6600 end_scope();
6601 statement(ts: "");
6602 break;
6603
6604 case SPVFuncImplSubgroupBallotFindMSB:
6605 statement(ts: "inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
6606 begin_scope();
6607 if (msl_options.is_ios())
6608 {
6609 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6610 }
6611 else
6612 {
6613 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6614 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6615 }
6616 statement(ts: "ballot &= mask;");
6617 statement(ts: "return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
6618 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
6619 "ballot.z == 0), ballot.w == 0);");
6620 end_scope();
6621 statement(ts: "");
6622 break;
6623
6624 case SPVFuncImplSubgroupBallotBitCount:
6625 statement(ts: "inline uint spvPopCount4(uint4 ballot)");
6626 begin_scope();
6627 statement(ts: "return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
6628 end_scope();
6629 statement(ts: "");
6630 statement(ts: "inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
6631 begin_scope();
6632 if (msl_options.is_ios())
6633 {
6634 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6635 }
6636 else
6637 {
6638 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6639 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6640 }
6641 statement(ts: "return spvPopCount4(ballot & mask);");
6642 end_scope();
6643 statement(ts: "");
6644 statement(ts: "inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
6645 begin_scope();
6646 if (msl_options.is_ios())
6647 {
6648 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
6649 }
6650 else
6651 {
6652 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
6653 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
6654 "uint2(0));");
6655 }
6656 statement(ts: "return spvPopCount4(ballot & mask);");
6657 end_scope();
6658 statement(ts: "");
6659 statement(ts: "inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
6660 begin_scope();
6661 if (msl_options.is_ios())
6662 {
6663 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint3(0));");
6664 }
6665 else
6666 {
6667 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
6668 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
6669 }
6670 statement(ts: "return spvPopCount4(ballot & mask);");
6671 end_scope();
6672 statement(ts: "");
6673 break;
6674
6675 case SPVFuncImplSubgroupAllEqual:
6676 // Metal doesn't provide a function to evaluate this directly, but we can
6677 // implement it by comparing every thread's value against a single thread's
6678 // value (here, the value of the first active thread). By the transitive
6679 // property of equality, if every comparison is true, all values are equal.
6680 statement(ts: "template<typename T>");
6681 statement(ts: "inline bool spvSubgroupAllEqual(T value)");
6682 begin_scope();
6683 if (msl_options.use_quadgroup_operation())
6684 statement(ts: "return quad_all(all(value == quad_broadcast_first(value)));");
6685 else
6686 statement(ts: "return simd_all(all(value == simd_broadcast_first(value)));");
6687 end_scope();
6688 statement(ts: "");
6689 statement(ts: "template<>");
6690 statement(ts: "inline bool spvSubgroupAllEqual(bool value)");
6691 begin_scope();
6692 if (msl_options.use_quadgroup_operation())
6693 statement(ts: "return quad_all(value) || !quad_any(value);");
6694 else
6695 statement(ts: "return simd_all(value) || !simd_any(value);");
6696 end_scope();
6697 statement(ts: "");
6698 statement(ts: "template<uint N>");
6699 statement(ts: "inline bool spvSubgroupAllEqual(vec<bool, N> value)");
6700 begin_scope();
6701 if (msl_options.use_quadgroup_operation())
6702 statement(ts: "return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
6703 else
6704 statement(ts: "return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
6705 end_scope();
6706 statement(ts: "");
6707 break;
6708
6709 case SPVFuncImplSubgroupShuffle:
6710 statement(ts: "template<typename T>");
6711 statement(ts: "inline T spvSubgroupShuffle(T value, ushort lane)");
6712 begin_scope();
6713 if (msl_options.use_quadgroup_operation())
6714 statement(ts: "return quad_shuffle(value, lane);");
6715 else
6716 statement(ts: "return simd_shuffle(value, lane);");
6717 end_scope();
6718 statement(ts: "");
6719 statement(ts: "template<>");
6720 statement(ts: "inline bool spvSubgroupShuffle(bool value, ushort lane)");
6721 begin_scope();
6722 if (msl_options.use_quadgroup_operation())
6723 statement(ts: "return !!quad_shuffle((ushort)value, lane);");
6724 else
6725 statement(ts: "return !!simd_shuffle((ushort)value, lane);");
6726 end_scope();
6727 statement(ts: "");
6728 statement(ts: "template<uint N>");
6729 statement(ts: "inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
6730 begin_scope();
6731 if (msl_options.use_quadgroup_operation())
6732 statement(ts: "return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
6733 else
6734 statement(ts: "return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
6735 end_scope();
6736 statement(ts: "");
6737 break;
6738
6739 case SPVFuncImplSubgroupShuffleXor:
6740 statement(ts: "template<typename T>");
6741 statement(ts: "inline T spvSubgroupShuffleXor(T value, ushort mask)");
6742 begin_scope();
6743 if (msl_options.use_quadgroup_operation())
6744 statement(ts: "return quad_shuffle_xor(value, mask);");
6745 else
6746 statement(ts: "return simd_shuffle_xor(value, mask);");
6747 end_scope();
6748 statement(ts: "");
6749 statement(ts: "template<>");
6750 statement(ts: "inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
6751 begin_scope();
6752 if (msl_options.use_quadgroup_operation())
6753 statement(ts: "return !!quad_shuffle_xor((ushort)value, mask);");
6754 else
6755 statement(ts: "return !!simd_shuffle_xor((ushort)value, mask);");
6756 end_scope();
6757 statement(ts: "");
6758 statement(ts: "template<uint N>");
6759 statement(ts: "inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
6760 begin_scope();
6761 if (msl_options.use_quadgroup_operation())
6762 statement(ts: "return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
6763 else
6764 statement(ts: "return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
6765 end_scope();
6766 statement(ts: "");
6767 break;
6768
6769 case SPVFuncImplSubgroupShuffleUp:
6770 statement(ts: "template<typename T>");
6771 statement(ts: "inline T spvSubgroupShuffleUp(T value, ushort delta)");
6772 begin_scope();
6773 if (msl_options.use_quadgroup_operation())
6774 statement(ts: "return quad_shuffle_up(value, delta);");
6775 else
6776 statement(ts: "return simd_shuffle_up(value, delta);");
6777 end_scope();
6778 statement(ts: "");
6779 statement(ts: "template<>");
6780 statement(ts: "inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
6781 begin_scope();
6782 if (msl_options.use_quadgroup_operation())
6783 statement(ts: "return !!quad_shuffle_up((ushort)value, delta);");
6784 else
6785 statement(ts: "return !!simd_shuffle_up((ushort)value, delta);");
6786 end_scope();
6787 statement(ts: "");
6788 statement(ts: "template<uint N>");
6789 statement(ts: "inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
6790 begin_scope();
6791 if (msl_options.use_quadgroup_operation())
6792 statement(ts: "return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
6793 else
6794 statement(ts: "return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
6795 end_scope();
6796 statement(ts: "");
6797 break;
6798
6799 case SPVFuncImplSubgroupShuffleDown:
6800 statement(ts: "template<typename T>");
6801 statement(ts: "inline T spvSubgroupShuffleDown(T value, ushort delta)");
6802 begin_scope();
6803 if (msl_options.use_quadgroup_operation())
6804 statement(ts: "return quad_shuffle_down(value, delta);");
6805 else
6806 statement(ts: "return simd_shuffle_down(value, delta);");
6807 end_scope();
6808 statement(ts: "");
6809 statement(ts: "template<>");
6810 statement(ts: "inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
6811 begin_scope();
6812 if (msl_options.use_quadgroup_operation())
6813 statement(ts: "return !!quad_shuffle_down((ushort)value, delta);");
6814 else
6815 statement(ts: "return !!simd_shuffle_down((ushort)value, delta);");
6816 end_scope();
6817 statement(ts: "");
6818 statement(ts: "template<uint N>");
6819 statement(ts: "inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
6820 begin_scope();
6821 if (msl_options.use_quadgroup_operation())
6822 statement(ts: "return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
6823 else
6824 statement(ts: "return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
6825 end_scope();
6826 statement(ts: "");
6827 break;
6828
6829 case SPVFuncImplQuadBroadcast:
6830 statement(ts: "template<typename T>");
6831 statement(ts: "inline T spvQuadBroadcast(T value, uint lane)");
6832 begin_scope();
6833 statement(ts: "return quad_broadcast(value, lane);");
6834 end_scope();
6835 statement(ts: "");
6836 statement(ts: "template<>");
6837 statement(ts: "inline bool spvQuadBroadcast(bool value, uint lane)");
6838 begin_scope();
6839 statement(ts: "return !!quad_broadcast((ushort)value, lane);");
6840 end_scope();
6841 statement(ts: "");
6842 statement(ts: "template<uint N>");
6843 statement(ts: "inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
6844 begin_scope();
6845 statement(ts: "return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
6846 end_scope();
6847 statement(ts: "");
6848 break;
6849
6850 case SPVFuncImplQuadSwap:
6851 // We can implement this easily based on the following table giving
6852 // the target lane ID from the direction and current lane ID:
6853 //        Direction
6854 //       | 0 | 1 | 2 |
6855 //   ----+---+---+---+
6856 //   L 0 | 1 | 2 | 3 |
6857 //   a 1 | 0 | 3 | 2 |
6858 //   n 2 | 3 | 0 | 1 |
6859 //   e 3 | 2 | 1 | 0 |
6860 // Notice that target = source ^ (direction + 1).
6861 statement(ts: "template<typename T>");
6862 statement(ts: "inline T spvQuadSwap(T value, uint dir)");
6863 begin_scope();
6864 statement(ts: "return quad_shuffle_xor(value, dir + 1);");
6865 end_scope();
6866 statement(ts: "");
6867 statement(ts: "template<>");
6868 statement(ts: "inline bool spvQuadSwap(bool value, uint dir)");
6869 begin_scope();
6870 statement(ts: "return !!quad_shuffle_xor((ushort)value, dir + 1);");
6871 end_scope();
6872 statement(ts: "");
6873 statement(ts: "template<uint N>");
6874 statement(ts: "inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
6875 begin_scope();
6876 statement(ts: "return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
6877 end_scope();
6878 statement(ts: "");
6879 break;
6880
6881 case SPVFuncImplReflectScalar:
6882 // Metal does not support scalar versions of these functions.
6883 // Ensure fast-math is disabled to match Vulkan results.
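	// For scalars, reflect(i, n) = i - 2 * dot(n, i) * n reduces to i - 2 * i * n * n.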
6884 statement(ts: "template<typename T>");
6885 statement(ts: "[[clang::optnone]] T spvReflect(T i, T n)");
6886 begin_scope();
6887 statement(ts: "return i - T(2) * i * n * n;");
6888 end_scope();
6889 statement(ts: "");
6890 break;
6891
6892 case SPVFuncImplRefractScalar:
6893 // Metal does not support scalar versions of these functions.
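	// Mirrors the vector refract() definition: k = 1 - eta^2 * (1 - dot(n, i)^2); return 0 if k < 0,
	// otherwise eta * i - (eta * dot(n, i) + sqrt(k)) * n. For scalars, dot(n, i) is just n * i.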
6894 statement(ts: "template<typename T>");
6895 statement(ts: "inline T spvRefract(T i, T n, T eta)");
6896 begin_scope();
6897 statement(ts: "T NoI = n * i;");
6898 statement(ts: "T NoI2 = NoI * NoI;");
6899 statement(ts: "T k = T(1) - eta * eta * (T(1) - NoI2);");
6900 statement(ts: "if (k < T(0))");
6901 begin_scope();
6902 statement(ts: "return T(0);");
6903 end_scope();
6904 statement(ts: "else");
6905 begin_scope();
6906 statement(ts: "return eta * i - (eta * NoI + sqrt(k)) * n;");
6907 end_scope();
6908 end_scope();
6909 statement(ts: "");
6910 break;
6911
6912 case SPVFuncImplFaceForwardScalar:
6913 // Metal does not support scalar versions of these functions.
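	// faceforward(n, i, nref) returns n when dot(nref, i) < 0 and -n otherwise; the scalar dot is a plain multiply.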
6914 statement(ts: "template<typename T>");
6915 statement(ts: "inline T spvFaceForward(T n, T i, T nref)");
6916 begin_scope();
6917 statement(ts: "return i * nref < T(0) ? n : -n;");
6918 end_scope();
6919 statement(ts: "");
6920 break;
6921
6922 case SPVFuncImplChromaReconstructNearest2Plane:
6923 statement(ts: "template<typename T, typename... LodOptions>");
6924 statement(ts: "inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
6925 "samp, float2 coord, LodOptions... options)");
6926 begin_scope();
6927 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6928 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6929 statement(ts: "ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
6930 statement(ts: "return ycbcr;");
6931 end_scope();
6932 statement(ts: "");
6933 break;
6934
6935 case SPVFuncImplChromaReconstructNearest3Plane:
6936 statement(ts: "template<typename T, typename... LodOptions>");
6937 statement(ts: "inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
6938 "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6939 begin_scope();
6940 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6941 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6942 statement(ts: "ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6943 statement(ts: "ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6944 statement(ts: "return ycbcr;");
6945 end_scope();
6946 statement(ts: "");
6947 break;
6948
6949 case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
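	// 4:2:2, cosited-even siting: chroma is stored at every other luma column. When the coordinate falls
	// between two chroma samples (fractional position within the chroma plane), average the two neighbours;
	// otherwise sample the co-sited texel directly.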
6950 statement(ts: "template<typename T, typename... LodOptions>");
6951 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
6952 "plane1, sampler samp, float2 coord, LodOptions... options)");
6953 begin_scope();
6954 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6955 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6956 statement(ts: "if (fract(coord.x * plane1.get_width()) != 0.0)");
6957 begin_scope();
6958 statement(ts: "ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6959 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
6960 end_scope();
6961 statement(ts: "else");
6962 begin_scope();
6963 statement(ts: "ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
6964 end_scope();
6965 statement(ts: "return ycbcr;");
6966 end_scope();
6967 statement(ts: "");
6968 break;
6969
6970 case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
6971 statement(ts: "template<typename T, typename... LodOptions>");
6972 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
6973 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6974 begin_scope();
6975 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6976 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6977 statement(ts: "if (fract(coord.x * plane1.get_width()) != 0.0)");
6978 begin_scope();
6979 statement(ts: "ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6980 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
6981 statement(ts: "ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6982 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
6983 end_scope();
6984 statement(ts: "else");
6985 begin_scope();
6986 statement(ts: "ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6987 statement(ts: "ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6988 end_scope();
6989 statement(ts: "return ycbcr;");
6990 end_scope();
6991 statement(ts: "");
6992 break;
6993
6994 case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
6995 statement(ts: "template<typename T, typename... LodOptions>");
6996 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
6997 "plane1, sampler samp, float2 coord, LodOptions... options)");
6998 begin_scope();
6999 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7000 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7001 statement(ts: "int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
7002 statement(ts: "ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7003 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
7004 statement(ts: "return ycbcr;");
7005 end_scope();
7006 statement(ts: "");
7007 break;
7008
7009 case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
7010 statement(ts: "template<typename T, typename... LodOptions>");
7011 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
7012 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
7013 begin_scope();
7014 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7015 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7016 statement(ts: "int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
7017 statement(ts: "ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7018 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
7019 statement(ts: "ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7020 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
7021 statement(ts: "return ycbcr;");
7022 end_scope();
7023 statement(ts: "");
7024 break;
7025
7026 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
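	// For the 4:2:0 linear cases, 'ab' is the fractional position of the reconstructed sample within the
	// 2x2 chroma footprint and serves as the bilinear weights between the four neighbouring chroma texels;
	// the midpoint variants below shift by 0.5 in x and/or y before taking fract() to account for the siting.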
7027 statement(ts: "template<typename T, typename... LodOptions>");
7028 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
7029 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
7030 begin_scope();
7031 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7032 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7033 statement(ts: "float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
7034 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7035 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7036 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7037 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
7038 statement(ts: "return ycbcr;");
7039 end_scope();
7040 statement(ts: "");
7041 break;
7042
7043 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
7044 statement(ts: "template<typename T, typename... LodOptions>");
7045 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
7046 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
7047 begin_scope();
7048 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7049 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7050 statement(ts: "float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
7051 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7052 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7053 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7054 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7055 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7056 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7057 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7058 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7059 statement(ts: "return ycbcr;");
7060 end_scope();
7061 statement(ts: "");
7062 break;
7063
7064 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
7065 statement(ts: "template<typename T, typename... LodOptions>");
7066 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
7067 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
7068 begin_scope();
7069 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7070 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7071 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
7072 "0)) * 0.5);");
7073 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7074 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7075 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7076 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
7077 statement(ts: "return ycbcr;");
7078 end_scope();
7079 statement(ts: "");
7080 break;
7081
7082 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
7083 statement(ts: "template<typename T, typename... LodOptions>");
7084 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
7085 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
7086 begin_scope();
7087 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7088 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7089 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
7090 "0)) * 0.5);");
7091 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7092 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7093 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7094 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7095 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7096 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7097 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7098 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7099 statement(ts: "return ycbcr;");
7100 end_scope();
7101 statement(ts: "");
7102 break;
7103
7104 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
7105 statement(ts: "template<typename T, typename... LodOptions>");
7106 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
7107 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
7108 begin_scope();
7109 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7110 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7111 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
7112 "0.5)) * 0.5);");
7113 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7114 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7115 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7116 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
7117 statement(ts: "return ycbcr;");
7118 end_scope();
7119 statement(ts: "");
7120 break;
7121
7122 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
7123 statement(ts: "template<typename T, typename... LodOptions>");
7124 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
7125 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
7126 begin_scope();
7127 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7128 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7129 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
7130 "0.5)) * 0.5);");
7131 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7132 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7133 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7134 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7135 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7136 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7137 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7138 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7139 statement(ts: "return ycbcr;");
7140 end_scope();
7141 statement(ts: "");
7142 break;
7143
7144 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
7145 statement(ts: "template<typename T, typename... LodOptions>");
7146 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
7147 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
7148 begin_scope();
7149 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7150 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7151 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
7152 "0.5)) * 0.5);");
7153 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7154 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7155 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7156 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
7157 statement(ts: "return ycbcr;");
7158 end_scope();
7159 statement(ts: "");
7160 break;
7161
7162 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
7163 statement(ts: "template<typename T, typename... LodOptions>");
7164 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
7165 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
7166 begin_scope();
7167 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
7168 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
7169 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
7170 "0.5)) * 0.5);");
7171 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
7172 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7173 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7174 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7175 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7176 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7177 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7178 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7179 statement(ts: "return ycbcr;");
7180 end_scope();
7181 statement(ts: "");
7182 break;
7183
7184 case SPVFuncImplExpandITUFullRange:
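	// Full-range Y'CbCr only needs the chroma channels re-centered: subtract 2^(n-1) / (2^n - 1),
	// where n is the number of bits per component.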
7185 statement(ts: "template<typename T>");
7186 statement(ts: "inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
7187 begin_scope();
7188 statement(ts: "ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
7189 statement(ts: "return ycbcr;");
7190 end_scope();
7191 statement(ts: "");
7192 break;
7193
7194 case SPVFuncImplExpandITUNarrowRange:
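	// Narrow (video) range per the ITU conventions: luma spans 219 * 2^(n-8) code values starting at
	// 16 * 2^(n-8), and chroma spans 224 * 2^(n-8) centered on 128 * 2^(n-8).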
7195 statement(ts: "template<typename T>");
7196 statement(ts: "inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
7197 begin_scope();
7198 statement(ts: "ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
7199 statement(ts: "ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
7200 statement(ts: "return ycbcr;");
7201 end_scope();
7202 statement(ts: "");
7203 break;
7204
7205 case SPVFuncImplConvertYCbCrBT709:
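	// The matrix columns give the Y', Cb and Cr contributions to R'G'B'. The constants follow from the
	// BT.709 luma coefficients (Kr = 0.2126, Kg = 0.7152, Kb = 0.0722), e.g. 1.5748 = 2 * (1 - Kr).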
7206 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.1");
7207 statement(ts: "constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
7208 "-0.33480248/0.7152, 0}};");
7209 statement(ts: "");
7210 statement(ts: "template<typename T>");
7211 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
7212 begin_scope();
7213 statement(ts: "vec<T, 4> rgba;");
7214 statement(ts: "rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
7215 statement(ts: "rgba.a = ycbcr.a;");
7216 statement(ts: "return rgba;");
7217 end_scope();
7218 statement(ts: "");
7219 break;
7220
7221 case SPVFuncImplConvertYCbCrBT601:
7222 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.2");
7223 statement(ts: "constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
7224 "-0.419198/0.587, 0}};");
7225 statement(ts: "");
7226 statement(ts: "template<typename T>");
7227 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
7228 begin_scope();
7229 statement(ts: "vec<T, 4> rgba;");
7230 statement(ts: "rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
7231 statement(ts: "rgba.a = ycbcr.a;");
7232 statement(ts: "return rgba;");
7233 end_scope();
7234 statement(ts: "");
7235 break;
7236
7237 case SPVFuncImplConvertYCbCrBT2020:
7238 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.3");
7239 statement(ts: "constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
7240 "-0.38737742/0.6780, 0}};");
7241 statement(ts: "");
7242 statement(ts: "template<typename T>");
7243 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
7244 begin_scope();
7245 statement(ts: "vec<T, 4> rgba;");
7246 statement(ts: "rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
7247 statement(ts: "rgba.a = ycbcr.a;");
7248 statement(ts: "return rgba;");
7249 end_scope();
7250 statement(ts: "");
7251 break;
7252
7253 case SPVFuncImplDynamicImageSampler:
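	// Emits a set of enums mirroring the Vulkan sampler Y'CbCr conversion state, a spvYCbCrSampler that
	// packs that state into a single ushort, and the spvDynamicImageSampler wrapper used to pass up to
	// three planes plus a sampler (and optional swizzle) to functions at run time.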
7254 statement(ts: "enum class spvFormatResolution");
7255 begin_scope();
7256 statement(ts: "_444 = 0,");
7257 statement(ts: "_422,");
7258 statement(ts: "_420");
7259 end_scope_decl();
7260 statement(ts: "");
7261 statement(ts: "enum class spvChromaFilter");
7262 begin_scope();
7263 statement(ts: "nearest = 0,");
7264 statement(ts: "linear");
7265 end_scope_decl();
7266 statement(ts: "");
7267 statement(ts: "enum class spvXChromaLocation");
7268 begin_scope();
7269 statement(ts: "cosited_even = 0,");
7270 statement(ts: "midpoint");
7271 end_scope_decl();
7272 statement(ts: "");
7273 statement(ts: "enum class spvYChromaLocation");
7274 begin_scope();
7275 statement(ts: "cosited_even = 0,");
7276 statement(ts: "midpoint");
7277 end_scope_decl();
7278 statement(ts: "");
7279 statement(ts: "enum class spvYCbCrModelConversion");
7280 begin_scope();
7281 statement(ts: "rgb_identity = 0,");
7282 statement(ts: "ycbcr_identity,");
7283 statement(ts: "ycbcr_bt_709,");
7284 statement(ts: "ycbcr_bt_601,");
7285 statement(ts: "ycbcr_bt_2020");
7286 end_scope_decl();
7287 statement(ts: "");
7288 statement(ts: "enum class spvYCbCrRange");
7289 begin_scope();
7290 statement(ts: "itu_full = 0,");
7291 statement(ts: "itu_narrow");
7292 end_scope_decl();
7293 statement(ts: "");
7294 statement(ts: "struct spvComponentBits");
7295 begin_scope();
7296 statement(ts: "constexpr explicit spvComponentBits(int v) thread : value(v) {}");
7297 statement(ts: "uchar value : 6;");
7298 end_scope_decl();
7299 statement(ts: "// A class corresponding to metal::sampler which holds sampler");
7300 statement(ts: "// Y'CbCr conversion info.");
7301 statement(ts: "struct spvYCbCrSampler");
7302 begin_scope();
7303 statement(ts: "constexpr spvYCbCrSampler() thread : val(build()) {}");
7304 statement(ts: "template<typename... Ts>");
7305 statement(ts: "constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
7306 statement(ts: "constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
7307 statement(ts: "");
7308 statement(ts: "spvFormatResolution get_resolution() const thread");
7309 begin_scope();
7310 statement(ts: "return spvFormatResolution((val & resolution_mask) >> resolution_base);");
7311 end_scope();
7312 statement(ts: "spvChromaFilter get_chroma_filter() const thread");
7313 begin_scope();
7314 statement(ts: "return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
7315 end_scope();
7316 statement(ts: "spvXChromaLocation get_x_chroma_offset() const thread");
7317 begin_scope();
7318 statement(ts: "return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
7319 end_scope();
7320 statement(ts: "spvYChromaLocation get_y_chroma_offset() const thread");
7321 begin_scope();
7322 statement(ts: "return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
7323 end_scope();
7324 statement(ts: "spvYCbCrModelConversion get_ycbcr_model() const thread");
7325 begin_scope();
7326 statement(ts: "return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
7327 end_scope();
7328 statement(ts: "spvYCbCrRange get_ycbcr_range() const thread");
7329 begin_scope();
7330 statement(ts: "return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
7331 end_scope();
7332 statement(ts: "int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
7333 statement(ts: "");
7334 statement(ts: "private:");
7335 statement(ts: "ushort val;");
7336 statement(ts: "");
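	// Packed layout of 'val', low bits to high: resolution (2 bits), chroma filter (2),
	// x/y chroma offsets (1 each), model conversion (3), range (1), bits-per-component (6).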
7337 statement(ts: "constexpr static constant ushort resolution_bits = 2;");
7338 statement(ts: "constexpr static constant ushort chroma_filter_bits = 2;");
7339 statement(ts: "constexpr static constant ushort x_chroma_off_bit = 1;");
7340 statement(ts: "constexpr static constant ushort y_chroma_off_bit = 1;");
7341 statement(ts: "constexpr static constant ushort ycbcr_model_bits = 3;");
7342 statement(ts: "constexpr static constant ushort ycbcr_range_bit = 1;");
7343 statement(ts: "constexpr static constant ushort bpc_bits = 6;");
7344 statement(ts: "");
7345 statement(ts: "constexpr static constant ushort resolution_base = 0;");
7346 statement(ts: "constexpr static constant ushort chroma_filter_base = 2;");
7347 statement(ts: "constexpr static constant ushort x_chroma_off_base = 4;");
7348 statement(ts: "constexpr static constant ushort y_chroma_off_base = 5;");
7349 statement(ts: "constexpr static constant ushort ycbcr_model_base = 6;");
7350 statement(ts: "constexpr static constant ushort ycbcr_range_base = 9;");
7351 statement(ts: "constexpr static constant ushort bpc_base = 10;");
7352 statement(ts: "");
7353 statement(
7354 ts: "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
7355 statement(ts: "constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
7356 "chroma_filter_base;");
7357 statement(ts: "constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
7358 "x_chroma_off_base;");
7359 statement(ts: "constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
7360 "y_chroma_off_base;");
7361 statement(ts: "constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
7362 "ycbcr_model_base;");
7363 statement(ts: "constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
7364 "ycbcr_range_base;");
7365 statement(ts: "constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
7366 statement(ts: "");
7367 statement(ts: "static constexpr ushort build()");
7368 begin_scope();
7369 statement(ts: "return 0;");
7370 end_scope();
7371 statement(ts: "");
7372 statement(ts: "template<typename... Ts>");
7373 statement(ts: "static constexpr ushort build(spvFormatResolution res, Ts... t)");
7374 begin_scope();
7375 statement(ts: "return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
7376 end_scope();
7377 statement(ts: "");
7378 statement(ts: "template<typename... Ts>");
7379 statement(ts: "static constexpr ushort build(spvChromaFilter filt, Ts... t)");
7380 begin_scope();
7381 statement(ts: "return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
7382 end_scope();
7383 statement(ts: "");
7384 statement(ts: "template<typename... Ts>");
7385 statement(ts: "static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
7386 begin_scope();
7387 statement(ts: "return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
7388 end_scope();
7389 statement(ts: "");
7390 statement(ts: "template<typename... Ts>");
7391 statement(ts: "static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
7392 begin_scope();
7393 statement(ts: "return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
7394 end_scope();
7395 statement(ts: "");
7396 statement(ts: "template<typename... Ts>");
7397 statement(ts: "static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
7398 begin_scope();
7399 statement(ts: "return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
7400 end_scope();
7401 statement(ts: "");
7402 statement(ts: "template<typename... Ts>");
7403 statement(ts: "static constexpr ushort build(spvYCbCrRange range, Ts... t)");
7404 begin_scope();
7405 statement(ts: "return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
7406 end_scope();
7407 statement(ts: "");
7408 statement(ts: "template<typename... Ts>");
7409 statement(ts: "static constexpr ushort build(spvComponentBits bpc, Ts... t)");
7410 begin_scope();
7411 statement(ts: "return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
7412 end_scope();
7413 end_scope_decl();
7414 statement(ts: "");
7415 statement(ts: "// A class which can hold up to three textures and a sampler, including");
7416 statement(ts: "// Y'CbCr conversion info, used to pass combined image-samplers");
7417 statement(ts: "// dynamically to functions.");
7418 statement(ts: "template<typename T>");
7419 statement(ts: "struct spvDynamicImageSampler");
7420 begin_scope();
7421 statement(ts: "texture2d<T> plane0;");
7422 statement(ts: "texture2d<T> plane1;");
7423 statement(ts: "texture2d<T> plane2;");
7424 statement(ts: "sampler samp;");
7425 statement(ts: "spvYCbCrSampler ycbcr_samp;");
7426 statement(ts: "uint swizzle = 0;");
7427 statement(ts: "");
7428 if (msl_options.swizzle_texture_samples)
7429 {
7430 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
7431 statement(ts: " plane0(tex), samp(samp), swizzle(sw) {}");
7432 }
7433 else
7434 {
7435 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
7436 statement(ts: " plane0(tex), samp(samp) {}");
7437 }
7438 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
7439 "uint sw) thread :");
7440 statement(ts: " plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
7441 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
7442 statement(ts: " sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
7443 statement(ts: " plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
7444 statement(
7445 ts: "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
7446 statement(ts: " sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
7447 statement(ts: " plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
7448 "swizzle(sw) {}");
7449 statement(ts: "");
7450 // XXX This is really hard to follow... I've left comments to make it a bit easier.
7451 statement(ts: "template<typename... LodOptions>");
7452 statement(ts: "vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
7453 begin_scope();
7454 statement(ts: "if (!is_null_texture(plane1))");
7455 begin_scope();
7456 statement(ts: "if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
7457 statement(ts: " ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
7458 begin_scope();
7459 statement(ts: "if (!is_null_texture(plane2))");
7460 statement(ts: " return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
7461 statement(ts: " spvForward<LodOptions>(options)...);");
7462 statement(
7463 ts: "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
7464 end_scope(); // if (resolution == 444 || chroma_filter == nearest)
7465 statement(ts: "switch (ycbcr_samp.get_resolution())");
7466 begin_scope();
7467 statement(ts: "case spvFormatResolution::_444: break;");
7468 statement(ts: "case spvFormatResolution::_422:");
7469 begin_scope();
7470 statement(ts: "switch (ycbcr_samp.get_x_chroma_offset())");
7471 begin_scope();
7472 statement(ts: "case spvXChromaLocation::cosited_even:");
7473 statement(ts: " if (!is_null_texture(plane2))");
7474 statement(ts: " return spvChromaReconstructLinear422CositedEven(");
7475 statement(ts: " plane0, plane1, plane2, samp,");
7476 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7477 statement(ts: " return spvChromaReconstructLinear422CositedEven(");
7478 statement(ts: " plane0, plane1, samp, coord,");
7479 statement(ts: " spvForward<LodOptions>(options)...);");
7480 statement(ts: "case spvXChromaLocation::midpoint:");
7481 statement(ts: " if (!is_null_texture(plane2))");
7482 statement(ts: " return spvChromaReconstructLinear422Midpoint(");
7483 statement(ts: " plane0, plane1, plane2, samp,");
7484 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7485 statement(ts: " return spvChromaReconstructLinear422Midpoint(");
7486 statement(ts: " plane0, plane1, samp, coord,");
7487 statement(ts: " spvForward<LodOptions>(options)...);");
7488 end_scope(); // switch (x_chroma_offset)
7489 end_scope(); // case 422:
7490 statement(ts: "case spvFormatResolution::_420:");
7491 begin_scope();
7492 statement(ts: "switch (ycbcr_samp.get_x_chroma_offset())");
7493 begin_scope();
7494 statement(ts: "case spvXChromaLocation::cosited_even:");
7495 begin_scope();
7496 statement(ts: "switch (ycbcr_samp.get_y_chroma_offset())");
7497 begin_scope();
7498 statement(ts: "case spvYChromaLocation::cosited_even:");
7499 statement(ts: " if (!is_null_texture(plane2))");
7500 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
7501 statement(ts: " plane0, plane1, plane2, samp,");
7502 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7503 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
7504 statement(ts: " plane0, plane1, samp, coord,");
7505 statement(ts: " spvForward<LodOptions>(options)...);");
7506 statement(ts: "case spvYChromaLocation::midpoint:");
7507 statement(ts: " if (!is_null_texture(plane2))");
7508 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
7509 statement(ts: " plane0, plane1, plane2, samp,");
7510 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7511 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
7512 statement(ts: " plane0, plane1, samp, coord,");
7513 statement(ts: " spvForward<LodOptions>(options)...);");
7514 end_scope(); // switch (y_chroma_offset)
7515 end_scope(); // case x::cosited_even:
7516 statement(ts: "case spvXChromaLocation::midpoint:");
7517 begin_scope();
7518 statement(ts: "switch (ycbcr_samp.get_y_chroma_offset())");
7519 begin_scope();
7520 statement(ts: "case spvYChromaLocation::cosited_even:");
7521 statement(ts: " if (!is_null_texture(plane2))");
7522 statement(ts: " return spvChromaReconstructLinear420XMidpointYCositedEven(");
7523 statement(ts: " plane0, plane1, plane2, samp,");
7524 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7525 statement(ts: " return spvChromaReconstructLinear420XMidpointYCositedEven(");
7526 statement(ts: " plane0, plane1, samp, coord,");
7527 statement(ts: " spvForward<LodOptions>(options)...);");
7528 statement(ts: "case spvYChromaLocation::midpoint:");
7529 statement(ts: " if (!is_null_texture(plane2))");
7530 statement(ts: " return spvChromaReconstructLinear420XMidpointYMidpoint(");
7531 statement(ts: " plane0, plane1, plane2, samp,");
7532 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7533 statement(ts: " return spvChromaReconstructLinear420XMidpointYMidpoint(");
7534 statement(ts: " plane0, plane1, samp, coord,");
7535 statement(ts: " spvForward<LodOptions>(options)...);");
7536 end_scope(); // switch (y_chroma_offset)
7537 end_scope(); // case x::midpoint
7538 end_scope(); // switch (x_chroma_offset)
7539 end_scope(); // case 420:
7540 end_scope(); // switch (resolution)
7541 end_scope(); // if (multiplanar)
7542 statement(ts: "return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
7543 end_scope(); // do_sample()
7544 statement(ts: "template <typename... LodOptions>");
7545 statement(ts: "vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
7546 begin_scope();
7547 statement(
7548 ts: "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
7549 statement(ts: "if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
7550 statement(ts: " return s;");
7551 statement(ts: "");
7552 statement(ts: "switch (ycbcr_samp.get_ycbcr_range())");
7553 begin_scope();
7554 statement(ts: "case spvYCbCrRange::itu_full:");
7555 statement(ts: " s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
7556 statement(ts: " break;");
7557 statement(ts: "case spvYCbCrRange::itu_narrow:");
7558 statement(ts: " s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
7559 statement(ts: " break;");
7560 end_scope();
7561 statement(ts: "");
7562 statement(ts: "switch (ycbcr_samp.get_ycbcr_model())");
7563 begin_scope();
7564 statement(ts: "case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
7565 statement(ts: "case spvYCbCrModelConversion::ycbcr_identity:");
7566 statement(ts: " return s;");
7567 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_709:");
7568 statement(ts: " return spvConvertYCbCrBT709(s);");
7569 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_601:");
7570 statement(ts: " return spvConvertYCbCrBT601(s);");
7571 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_2020:");
7572 statement(ts: " return spvConvertYCbCrBT2020(s);");
7573 end_scope();
7574 end_scope();
7575 statement(ts: "");
7576 // Sampler Y'CbCr conversion forbids offsets.
7577 statement(ts: "vec<T, 4> sample(float2 coord, int2 offset) const thread");
7578 begin_scope();
7579 if (msl_options.swizzle_texture_samples)
7580 statement(ts: "return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
7581 else
7582 statement(ts: "return plane0.sample(samp, coord, offset);");
7583 end_scope();
7584 statement(ts: "template<typename lod_options>");
7585 statement(ts: "vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
7586 begin_scope();
7587 if (msl_options.swizzle_texture_samples)
7588 statement(ts: "return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
7589 else
7590 statement(ts: "return plane0.sample(samp, coord, options, offset);");
7591 end_scope();
7592 statement(ts: "#if __HAVE_MIN_LOD_CLAMP__");
7593 statement(ts: "vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
7594 begin_scope();
7595 statement(ts: "return plane0.sample(samp, coord, b, min_lod, offset);");
7596 end_scope();
7597 statement(
7598 ts: "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
7599 begin_scope();
7600 statement(ts: "return plane0.sample(samp, coord, grad, min_lod, offset);");
7601 end_scope();
7602 statement(ts: "#endif");
7603 statement(ts: "");
7604 // Y'CbCr conversion forbids all operations but sampling.
7605 statement(ts: "vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
7606 begin_scope();
7607 statement(ts: "return plane0.read(coord, lod);");
7608 end_scope();
7609 statement(ts: "");
7610 statement(ts: "vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
7611 begin_scope();
7612 if (msl_options.swizzle_texture_samples)
7613 statement(ts: "return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
7614 else
7615 statement(ts: "return plane0.gather(samp, coord, offset, c);");
7616 end_scope();
7617 end_scope_decl();
7618 statement(ts: "");
7619 break;
7620
7621 case SPVFuncImplRayQueryIntersectionParams:
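	// Translates SPIR-V ray flags into a Metal intersection_params, mapping the opacity, cull-mode and
	// geometry-skip flags one by one; flags with no MSL equivalent are simply ignored.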
7622 statement(ts: "intersection_params spvMakeIntersectionParams(uint flags)");
7623 begin_scope();
7624 statement(ts: "intersection_params ip;");
7625 statement(ts: "if ((flags & ", ts: RayFlagsOpaqueKHRMask, ts: ") != 0)");
7626 statement(ts: " ip.force_opacity(forced_opacity::opaque);");
7627 statement(ts: "if ((flags & ", ts: RayFlagsNoOpaqueKHRMask, ts: ") != 0)");
7628 statement(ts: " ip.force_opacity(forced_opacity::non_opaque);");
7629 statement(ts: "if ((flags & ", ts: RayFlagsTerminateOnFirstHitKHRMask, ts: ") != 0)");
7630 statement(ts: " ip.accept_any_intersection(true);");
7631 // RayFlagsSkipClosestHitShaderKHRMask is not available in MSL
7632 statement(ts: "if ((flags & ", ts: RayFlagsCullBackFacingTrianglesKHRMask, ts: ") != 0)");
7633 statement(ts: " ip.set_triangle_cull_mode(triangle_cull_mode::back);");
7634 statement(ts: "if ((flags & ", ts: RayFlagsCullFrontFacingTrianglesKHRMask, ts: ") != 0)");
7635 statement(ts: " ip.set_triangle_cull_mode(triangle_cull_mode::front);");
7636 statement(ts: "if ((flags & ", ts: RayFlagsCullOpaqueKHRMask, ts: ") != 0)");
7637 statement(ts: " ip.set_opacity_cull_mode(opacity_cull_mode::opaque);");
7638 statement(ts: "if ((flags & ", ts: RayFlagsCullNoOpaqueKHRMask, ts: ") != 0)");
7639 statement(ts: " ip.set_opacity_cull_mode(opacity_cull_mode::non_opaque);");
7640 statement(ts: "if ((flags & ", ts: RayFlagsSkipTrianglesKHRMask, ts: ") != 0)");
7641 statement(ts: " ip.set_geometry_cull_mode(geometry_cull_mode::triangle);");
7642 statement(ts: "if ((flags & ", ts: RayFlagsSkipAABBsKHRMask, ts: ") != 0)");
7643 statement(ts: " ip.set_geometry_cull_mode(geometry_cull_mode::bounding_box);");
7644 statement(ts: "return ip;");
7645 end_scope();
7646 statement(ts: "");
7647 break;
7648
7649 case SPVFuncImplVariableDescriptor:
7650 statement(ts: "template<typename T>");
7651 statement(ts: "struct spvDescriptor");
7652 begin_scope();
7653 statement(ts: "T value;");
7654 end_scope_decl();
7655 statement(ts: "");
7656 break;
7657
7658 case SPVFuncImplVariableSizedDescriptor:
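	// Like spvDescriptor, but carries the buffer pointer together with its length so the generated code
	// can report the size of runtime-sized buffer arrays (see the length() accessor on spvDescriptorArray below).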
7659 statement(ts: "template<typename T>");
7660 statement(ts: "struct spvBufferDescriptor");
7661 begin_scope();
7662 statement(ts: "T value;");
7663 statement(ts: "int length;");
7664 statement(ts: "const device T& operator -> () const device");
7665 begin_scope();
7666 statement(ts: "return value;");
7667 end_scope();
7668 statement(ts: "const device T& operator * () const device");
7669 begin_scope();
7670 statement(ts: "return value;");
7671 end_scope();
7672 end_scope_decl();
7673 statement(ts: "");
7674 break;
7675
7676 case SPVFuncImplVariableDescriptorArray:
7677 if (spv_function_implementations.count(x: SPVFuncImplVariableDescriptor) != 0)
7678 {
7679 statement(ts: "template<typename T>");
7680 statement(ts: "struct spvDescriptorArray");
7681 begin_scope();
7682 statement(ts: "spvDescriptorArray(const device spvDescriptor<T>* ptr) : ptr(&ptr->value)");
7683 begin_scope();
7684 end_scope();
7685 statement(ts: "const device T& operator [] (size_t i) const");
7686 begin_scope();
7687 statement(ts: "return ptr[i];");
7688 end_scope();
7689 statement(ts: "const device T* ptr;");
7690 end_scope_decl();
7691 statement(ts: "");
7692 }
7693 else
7694 {
7695 statement(ts: "template<typename T>");
7696 statement(ts: "struct spvDescriptorArray;");
7697 statement(ts: "");
7698 }
7699
7700 if (msl_options.runtime_array_rich_descriptor &&
7701 spv_function_implementations.count(x: SPVFuncImplVariableSizedDescriptor) != 0)
7702 {
7703 statement(ts: "template<typename T>");
7704 statement(ts: "struct spvDescriptorArray<device T*>");
7705 begin_scope();
7706 statement(ts: "spvDescriptorArray(const device spvBufferDescriptor<device T*>* ptr) : ptr(ptr)");
7707 begin_scope();
7708 end_scope();
7709 statement(ts: "const device T* operator [] (size_t i) const");
7710 begin_scope();
7711 statement(ts: "return ptr[i].value;");
7712 end_scope();
7713 statement(ts: "const int length(int i) const");
7714 begin_scope();
7715 statement(ts: "return ptr[i].length;");
7716 end_scope();
7717 statement(ts: "const device spvBufferDescriptor<device T*>* ptr;");
7718 end_scope_decl();
7719 statement(ts: "");
7720 }
7721 break;
7722
7723 case SPVFuncImplPaddedStd140:
7724 // .data is used in the access chain.
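	// e.g. a std140 'float a[4]' can be emitted as 'spvPaddedStd140<float> a[4]' so every element occupies
	// a full 16-byte stride, and '.data' unwraps the element when the access chain is generated.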
7725 statement(ts: "template <typename T>");
7726 statement(ts: "struct spvPaddedStd140 { alignas(16) T data; };");
7727 statement(ts: "template <typename T, int n>");
7728 statement(ts: "using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];");
7729 statement(ts: "");
7730 break;
7731
7732 case SPVFuncImplReduceAdd:
7733 // Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
7734 // Metal also lacks the other vector reduction builtins, which would have made it possible to write this as a single template.
7735
7736 statement(ts: "template <typename T>");
7737 statement(ts: "T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
7738
7739 statement(ts: "template <typename T>");
7740 statement(ts: "T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
7741
7742 statement(ts: "template <typename T>");
7743 statement(ts: "T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
7744
7745 statement(ts: "");
7746 break;
7747
7748 case SPVFuncImplImageFence:
7749 statement(ts: "template <typename ImageT>");
7750 statement(ts: "void spvImageFence(ImageT img) { img.fence(); }");
7751 statement(ts: "");
7752 break;
7753
7754 case SPVFuncImplTextureCast:
7755 statement(ts: "template <typename T, typename U>");
7756 statement(ts: "T spvTextureCast(U img)");
7757 begin_scope();
7758 // MSL complains if you try to cast the texture itself, but casting the reference type is ... ok? *shrug*
7759 // Gotta do what you gotta do, I suppose.
7760 statement(ts: "return reinterpret_cast<thread const T &>(img);");
7761 end_scope();
7762 statement(ts: "");
7763 break;
7764
7765 case SPVFuncImplMulExtended:
7766 // The compiler may hit an internal error when mulhi() is used directly but, for whatever reason, not when it is wrapped in a helper function.
7767 statement(ts: "template<typename T, typename U, typename V>");
7768 statement(ts: "[[clang::optnone]] T spvMulExtended(V l, V r)");
7769 begin_scope();
7770 statement(ts: "return T{U(l * r), U(mulhi(l, r))};");
7771 end_scope();
7772 statement(ts: "");
7773 break;
7774
7775 case SPVFuncImplSetMeshOutputsEXT:
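	// Only invocation 0 writes the requested vertex and primitive counts into the shared spvMeshSizes pair;
	// the other invocations leave it untouched.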
7776 statement(ts: "void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)");
7777 begin_scope();
7778 statement(ts: "if (gl_LocalInvocationIndex == 0)");
7779 begin_scope();
7780 statement(ts: "spvMeshSizes.x = vertexCount;");
7781 statement(ts: "spvMeshSizes.y = primitiveCount;");
7782 end_scope();
7783 end_scope();
7784 statement(ts: "");
7785 break;
7786
7787 default:
7788 break;
7789 }
7790 }
7791}
7792
7793static string inject_top_level_storage_qualifier(const string &expr, const string &qualifier)
7794{
7795 // Easier to do this through text munging since the qualifier does not exist in the type system at all,
7796 // and plumbing in all that information is not very helpful.
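// For example (illustrative expressions): injecting "constant" into "float4 _42" yields
// "constant float4 _42", while "device float* _42" becomes "device float* constant _42",
// i.e. the qualifier lands after the last '*' or '&' so it applies to the top-level pointer/reference.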
7797 size_t last_reference = expr.find_last_of(c: '&');
7798 size_t last_pointer = expr.find_last_of(c: '*');
7799 size_t last_significant = string::npos;
7800
7801 if (last_reference == string::npos)
7802 last_significant = last_pointer;
7803 else if (last_pointer == string::npos)
7804 last_significant = last_reference;
7805 else
7806 last_significant = max<size_t>(a: last_reference, b: last_pointer);
7807
7808 if (last_significant == string::npos)
7809 return join(ts: qualifier, ts: " ", ts: expr);
7810 else
7811 {
7812 return join(ts: expr.substr(pos: 0, n: last_significant + 1), ts: " ",
7813 ts: qualifier, ts: expr.substr(pos: last_significant + 1, n: string::npos));
7814 }
7815}
7816
7817void CompilerMSL::declare_constant_arrays()
7818{
7819 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
7820
7821 // MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
7822 // global constants so that we can still use these constants in variable expressions.
7823 bool emitted = false;
7824
7825 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7826 if (c.specialization)
7827 return;
7828
7829 auto &type = this->get<SPIRType>(id: c.constant_type);
7830 // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
7831 // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
7832 // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
7833 // link into Metal libraries. This is hacky.
7834 if (is_array(type) && (!fully_inlined || is_scalar(type) || is_vector(type)))
7835 {
7836 add_resource_name(id: c.self);
7837 auto name = to_name(id: c.self);
7838 statement(ts: inject_top_level_storage_qualifier(expr: variable_decl(type, name), qualifier: "constant"),
7839 ts: " = ", ts: constant_expression(c), ts: ";");
7840 emitted = true;
7841 }
7842 });
7843
7844 if (emitted)
7845 statement(ts: "");
7846}
7847
7848// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
7849void CompilerMSL::declare_complex_constant_arrays()
7850{
7851 // If we do not have a fully inlined module, we did not opt in to
7852 // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
7853 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
7854 if (!fully_inlined)
7855 return;
7856
7857 // MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
7858 // global constants so that we can still use these constants in variable expressions.
7859 bool emitted = false;
7860
7861 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7862 if (c.specialization)
7863 return;
7864
7865 auto &type = this->get<SPIRType>(id: c.constant_type);
7866 if (is_array(type) && !(is_scalar(type) || is_vector(type)))
7867 {
7868 add_resource_name(id: c.self);
7869 auto name = to_name(id: c.self);
7870 statement(ts: "", ts: variable_decl(type, name), ts: " = ", ts: constant_expression(c), ts: ";");
7871 emitted = true;
7872 }
7873 });
7874
7875 if (emitted)
7876 statement(ts: "");
7877}
7878
7879void CompilerMSL::emit_resources()
7880{
7881 declare_constant_arrays();
7882
7883 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
7884 emit_interface_block(ib_var_id: stage_out_var_id);
7885 emit_interface_block(ib_var_id: patch_stage_out_var_id);
7886 emit_interface_block(ib_var_id: stage_in_var_id);
7887 emit_interface_block(ib_var_id: patch_stage_in_var_id);
7888
7889 if (get_execution_model() == ExecutionModelMeshEXT)
7890 {
7891 auto &execution = get_entry_point();
7892 const char *topology = "";
7893 if (execution.flags.get(bit: ExecutionModeOutputTrianglesEXT))
7894 topology = "topology::triangle";
7895 else if (execution.flags.get(bit: ExecutionModeOutputLinesEXT))
7896 topology = "topology::line";
7897 else if (execution.flags.get(bit: ExecutionModeOutputPoints))
7898 topology = "topology::point";
7899
7900 const char *per_primitive = mesh_out_per_primitive ? "spvPerPrimitive" : "void";
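// With example limits of 64 vertices / 126 primitives and triangle output, the statement below
// emits roughly: using spvMesh_t = mesh<spvPerVertex, spvPerPrimitive, 64, 126, topology::triangle>;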
7901 statement(ts: "using spvMesh_t = mesh<", ts: "spvPerVertex, ", ts&: per_primitive, ts: ", ", ts&: execution.output_vertices, ts: ", ",
7902 ts&: execution.output_primitives, ts: ", ", ts&: topology, ts: ">;");
7903 statement(ts: "");
7904 }
7905}
7906
7907// Emit declarations for the specialization Metal function constants
7908void CompilerMSL::emit_specialization_constants_and_structs()
7909{
7910 SpecializationConstant wg_x, wg_y, wg_z;
7911 ID workgroup_size_id = get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
7912 bool emitted = false;
7913
7914 unordered_set<uint32_t> declared_structs;
7915 unordered_set<uint32_t> aligned_structs;
7916
7917 // First, we need to deal with scalar block layout.
7918 // It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
7919 // If such a case exists for a struct, we must force all elements of that struct to become packed_ types.
7920 // This makes the struct alignment as small as physically possible.
7921 // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
7922 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t type_id, const SPIRType &type) {
7923 if (type.basetype == SPIRType::Struct &&
7924 has_extended_decoration(id: type_id, decoration: SPIRVCrossDecorationBufferBlockRepacked))
7925 mark_scalar_layout_structs(type);
7926 });
7927
7928 bool builtin_block_type_is_required = is_mesh_shader();
7929 // Very special case. If gl_PerVertex is initialized as an array (tessellation)
7930 // we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT.
7931 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7932 auto &type = this->get<SPIRType>(id: c.constant_type);
7933 if (is_array(type) && has_decoration(id: type.self, decoration: DecorationBlock) && is_builtin_type(type))
7934 builtin_block_type_is_required = true;
7935 });
7936
7937 // Very particular use of the soft loop lock.
7938 // align_struct may need to create custom types on the fly, but we don't care about
7939 // these types for the purpose of iterating over them in ir.ids_for_type and friends.
7940 auto loop_lock = ir.create_loop_soft_lock();
7941
7942 // Physical storage buffer pointers can have cyclical references,
7943 // so emit forward declarations of them before other structs.
7944 // Ignore type_id because we want the underlying struct type from the pointer.
7945 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t /* type_id */, const SPIRType &type) {
7946 if (type.basetype == SPIRType::Struct &&
7947 type.pointer && type.storage == StorageClassPhysicalStorageBuffer &&
7948 declared_structs.count(x: type.self) == 0)
7949 {
7950 statement(ts: "struct ", ts: to_name(id: type.self), ts: ";");
7951 declared_structs.insert(x: type.self);
7952 emitted = true;
7953 }
7954 });
7955 if (emitted)
7956 statement(ts: "");
7957
7958 emitted = false;
7959 declared_structs.clear();
7960
7961 // It is possible to have multiple spec constants that use the same spec constant ID.
7962 // The most common cause of this is defining spec constants in GLSL while also declaring
7963 // the workgroup size to use those spec constants. But, Metal forbids declaring more than
7964 // one variable with the same function constant ID.
7965 // In this case, we must only declare one variable with the [[function_constant(id)]]
7966 // attribute, and use its initializer to initialize all the spec constants with
7967 // that ID.
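// With example names and constant ID 10, the emitted MSL ends up looking roughly like:
//   constant int foo_tmp [[function_constant(10)]];
//   constant int foo = is_function_constant_defined(foo_tmp) ? foo_tmp : 4;
//   constant int bar = is_function_constant_defined(foo_tmp) ? foo_tmp : 4;
// Only the first declaration carries the [[function_constant(...)]] attribute.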
7968 std::unordered_map<uint32_t, ConstantID> unique_func_constants;
7969
7970 for (auto &id_ : ir.ids_for_constant_undef_or_type)
7971 {
7972 auto &id = ir.ids[id_];
7973
7974 if (id.get_type() == TypeConstant)
7975 {
7976 auto &c = id.get<SPIRConstant>();
7977
7978 if (c.self == workgroup_size_id)
7979 {
7980 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
7981 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
7982 // The work group size may be a specialization constant.
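// The statement below emits something like (illustrative size):
//   constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u);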
7983 statement(ts: "constant uint3 ", ts: builtin_to_glsl(builtin: BuiltInWorkgroupSize, storage: StorageClassWorkgroup),
7984 ts: " [[maybe_unused]] = ", ts: constant_expression(c: get<SPIRConstant>(id: workgroup_size_id)), ts: ";");
7985 emitted = true;
7986 }
7987 else if (c.specialization)
7988 {
7989 auto &type = get<SPIRType>(id: c.constant_type);
7990 string sc_type_name = type_to_glsl(type);
7991 add_resource_name(id: c.self);
7992 string sc_name = to_name(id: c.self);
7993
7994 // Function constants are only supported in MSL 1.2 and later.
7995 // If we don't support it just declare the "default" directly.
7996 // This "default" value can be overridden to the true specialization constant by the API user.
7997 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
7998 // so just fall back to macros.
7999 if (msl_options.supports_msl_version(major: 1, minor: 2) && has_decoration(id: c.self, decoration: DecorationSpecId) &&
8000 !c.is_used_as_array_length)
8001 {
8002 // Only scalar, non-composite values can be function constants.
8003 uint32_t constant_id = get_decoration(id: c.self, decoration: DecorationSpecId);
8004 if (!unique_func_constants.count(x: constant_id))
8005 unique_func_constants.insert(x: make_pair(x&: constant_id, y&: c.self));
8006 SPIRType::BaseType sc_tmp_type = expression_type(id: unique_func_constants[constant_id]).basetype;
8007 string sc_tmp_name = to_name(id: unique_func_constants[constant_id]) + "_tmp";
8008 if (unique_func_constants[constant_id] == c.self)
8009 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_tmp_name, ts: " [[function_constant(", ts&: constant_id,
8010 ts: ")]];");
8011 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = is_function_constant_defined(", ts&: sc_tmp_name,
8012 ts: ") ? ", ts: bitcast_expression(target_type: type, expr_type: sc_tmp_type, expr: sc_tmp_name), ts: " : ", ts: constant_expression(c),
8013 ts: ";");
8014 }
8015 else if (has_decoration(id: c.self, decoration: DecorationSpecId))
8016 {
8017 // Fallback to macro overrides.
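// Roughly, with an example constant ID and default value, this emits:
//   #ifndef SPIRV_CROSS_CONSTANT_ID_10
//   #define SPIRV_CROSS_CONSTANT_ID_10 4
//   #endif
//   constant int foo = SPIRV_CROSS_CONSTANT_ID_10;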
8018 c.specialization_constant_macro_name =
8019 constant_value_macro_name(id: get_decoration(id: c.self, decoration: DecorationSpecId));
8020
8021 statement(ts: "#ifndef ", ts&: c.specialization_constant_macro_name);
8022 statement(ts: "#define ", ts&: c.specialization_constant_macro_name, ts: " ", ts: constant_expression(c));
8023 statement(ts: "#endif");
8024 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = ", ts&: c.specialization_constant_macro_name,
8025 ts: ";");
8026 }
8027 else
8028 {
8029 // Composite specialization constants must be built from other specialization constants.
8030 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = ", ts: constant_expression(c), ts: ";");
8031 }
8032 emitted = true;
8033 }
8034 }
8035 else if (id.get_type() == TypeConstantOp)
8036 {
8037 auto &c = id.get<SPIRConstantOp>();
8038 auto &type = get<SPIRType>(id: c.basetype);
8039 add_resource_name(id: c.self);
8040 auto name = to_name(id: c.self);
8041 statement(ts: "constant ", ts: variable_decl(type, name), ts: " = ", ts: constant_op_expression(cop: c), ts: ";");
8042 emitted = true;
8043 }
8044 else if (id.get_type() == TypeType)
8045 {
8046 // Output non-builtin interface structs. These include local function structs
8047 // and structs nested within uniform and read-write buffers.
8048 auto &type = id.get<SPIRType>();
8049 TypeID type_id = type.self;
8050
8051 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
8052 bool is_block =
8053 has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock);
8054
8055 bool is_builtin_block = is_block && is_builtin_type(type);
8056 bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);
8057
8058 // We'll declare this later.
8059 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
8060 is_declarable_struct = false;
8061 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
8062 is_declarable_struct = false;
8063 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
8064 is_declarable_struct = false;
8065 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
8066 is_declarable_struct = false;
8067
8068 // Special case. Declare builtin struct anyways if we need to emit a threadgroup version of it.
8069 if (stage_out_masked_builtin_type_id == type_id)
8070 is_declarable_struct = true;
8071
8072 // Align and emit declarable structs...but avoid declaring each more than once.
8073 if (is_declarable_struct && declared_structs.count(x: type_id) == 0)
8074 {
8075 if (emitted)
8076 statement(ts: "");
8077 emitted = false;
8078
8079 declared_structs.insert(x: type_id);
8080
8081 if (has_extended_decoration(id: type_id, decoration: SPIRVCrossDecorationBufferBlockRepacked))
8082 align_struct(ib_type&: type, aligned_structs);
8083
8084 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
8085 emit_struct(type&: get<SPIRType>(id: type_id));
8086 }
8087 }
8088 else if (id.get_type() == TypeUndef)
8089 {
8090 auto &undef = id.get<SPIRUndef>();
8091 auto &type = get<SPIRType>(id: undef.basetype);
8092 // OpUndef can be void for some reason ...
8093 if (type.basetype == SPIRType::Void)
8094 return;
8095
8096 // Undefined global memory is not allowed in MSL.
8097 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
8098 statement(
8099 ts: inject_top_level_storage_qualifier(expr: variable_decl(type, name: to_name(id: undef.self), id: undef.self), qualifier: "constant"),
8100 ts: " = {};");
8101 emitted = true;
8102 }
8103 }
8104
8105 if (emitted)
8106 statement(ts: "");
8107}
8108
8109void CompilerMSL::emit_binary_ptr_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op)
8110{
8111 bool forward = should_forward(id: op0) && should_forward(id: op1);
8112 emit_op(result_type, result_id, rhs: join(ts: to_ptr_expression(id: op0), ts: " ", ts&: op, ts: " ", ts: to_ptr_expression(id: op1)), forward_rhs: forward);
8113 inherit_expression_dependencies(dst: result_id, source: op0);
8114 inherit_expression_dependencies(dst: result_id, source: op1);
8115}
8116
8117string CompilerMSL::to_ptr_expression(uint32_t id, bool register_expression_read)
8118{
8119 auto *e = maybe_get<SPIRExpression>(id);
8120 auto expr = enclose_expression(expr: e && e->need_transpose ? e->expression : to_expression(id, register_expression_read));
8121 if (!should_dereference(id))
8122 expr = address_of_expression(expr);
8123 return expr;
8124}
8125
8126void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
8127 const char *op)
8128{
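	// Emits "(isunordered(a, b) || a <op> b)" so the comparison also evaluates to true when either
	// operand is NaN, matching SPIR-V's unordered comparison semantics.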
8129 bool forward = should_forward(id: op0) && should_forward(id: op1);
8130 emit_op(result_type, result_id,
8131 rhs: join(ts: "(isunordered(", ts: to_enclosed_unpacked_expression(id: op0), ts: ", ", ts: to_enclosed_unpacked_expression(id: op1),
8132 ts: ") || ", ts: to_enclosed_unpacked_expression(id: op0), ts: " ", ts&: op, ts: " ", ts: to_enclosed_unpacked_expression(id: op1),
8133 ts: ")"),
8134 forward_rhs: forward);
8135
8136 inherit_expression_dependencies(dst: result_id, source: op0);
8137 inherit_expression_dependencies(dst: result_id, source: op1);
8138}
8139
8140bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
8141{
8142 auto &ptr_type = expression_type(id: ptr);
8143 auto &result_type = get<SPIRType>(id: result_type_id);
8144 if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
8145 return false;
8146 if (ptr_type.storage == StorageClassOutput && is_tese_shader())
8147 return false;
8148
8149 if (has_decoration(id: ptr, decoration: DecorationPatch))
8150 return false;
8151 bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
8152
8153 bool flattened_io = variable_storage_requires_stage_io(storage: ptr_type.storage);
8154
8155 bool flat_data_type = flattened_io &&
8156 (is_matrix(type: result_type) || is_array(type: result_type) || result_type.basetype == SPIRType::Struct);
8157
8158 // Edge case, even with multi-patch workgroups, we still need to unroll load
8159 // if we're loading control points directly.
8160 if (ptr_is_io_variable && is_array(type: result_type))
8161 flat_data_type = true;
8162
8163 if (!flat_data_type)
8164 return false;
8165
8166 // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
8167 // Lots of painful code duplication, since we *really* should not unroll these kinds of loads in entry point fixup
8168 // unless we're forced to because the code emits suboptimal OpLoads.
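// As an illustration (hypothetical names and sizes), a per-control-point input declared as
// "float4 foo[3]" ends up unrolled into something like:
//   spvUnsafeArray<float4, 3>({ gl_in[0].foo, gl_in[1].foo, gl_in[2].foo })
// with the exact member expressions produced by the access chains built below.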
8169 string expr;
8170
8171 uint32_t interface_index = get_extended_decoration(id: ptr, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8172 auto *var = maybe_get_backing_variable(chain: ptr);
8173 auto &expr_type = get_pointee_type(type_id: ptr_type.self);
8174
8175 const auto &iface_type = expression_type(id: stage_in_ptr_var_id);
8176
8177 if (!flattened_io)
8178 {
8179 // Simplest case for multi-patch workgroups, just unroll array as-is.
8180 if (interface_index == uint32_t(-1))
8181 return false;
8182
8183 expr += type_to_glsl(type: result_type) + "({ ";
8184 uint32_t num_control_points = to_array_size_literal(type: result_type, index: uint32_t(result_type.array.size()) - 1);
8185
8186 for (uint32_t i = 0; i < num_control_points; i++)
8187 {
8188 const uint32_t indices[2] = { i, interface_index };
8189 AccessChainMeta meta;
8190 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8191 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8192 if (i + 1 < num_control_points)
8193 expr += ", ";
8194 }
8195 expr += " })";
8196 }
8197 else if (result_type.array.size() > 2)
8198 {
8199 SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
8200 }
8201 else if (result_type.array.size() == 2)
8202 {
8203 if (!ptr_is_io_variable)
8204 SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable.");
8205 if (interface_index == uint32_t(-1))
8206 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8207 if (result_type.basetype == SPIRType::Struct || is_matrix(type: result_type))
8208 SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
8209
8210 expr += type_to_glsl(type: result_type) + "({ ";
8211 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 1);
8212 uint32_t base_interface_index = interface_index;
8213
8214 auto &sub_type = get<SPIRType>(id: result_type.parent_type);
8215
8216 for (uint32_t i = 0; i < num_control_points; i++)
8217 {
8218 expr += type_to_glsl(type: sub_type) + "({ ";
8219 interface_index = base_interface_index;
8220 uint32_t array_size = to_array_size_literal(type: result_type, index: 0);
8221 for (uint32_t j = 0; j < array_size; j++, interface_index++)
8222 {
8223 const uint32_t indices[2] = { i, interface_index };
8224
8225 AccessChainMeta meta;
8226 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8227 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8228 if (!is_matrix(type: sub_type) && sub_type.basetype != SPIRType::Struct &&
8229 expr_type.vecsize > sub_type.vecsize)
8230 expr += vector_swizzle(vecsize: sub_type.vecsize, index: 0);
8231
8232 if (j + 1 < array_size)
8233 expr += ", ";
8234 }
8235 expr += " })";
8236 if (i + 1 < num_control_points)
8237 expr += ", ";
8238 }
8239 expr += " })";
8240 }
8241 else if (result_type.basetype == SPIRType::Struct)
8242 {
8243 bool is_array_of_struct = is_array(type: result_type);
8244 if (is_array_of_struct && !ptr_is_io_variable)
8245 SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable.");
8246
8247 uint32_t num_control_points = 1;
8248 if (is_array_of_struct)
8249 {
8250 num_control_points = to_array_size_literal(type: result_type, index: 0);
8251 expr += type_to_glsl(type: result_type) + "({ ";
8252 }
8253
8254 auto &struct_type = is_array_of_struct ? get<SPIRType>(id: result_type.parent_type) : result_type;
8255 assert(struct_type.array.empty());
8256
8257 for (uint32_t i = 0; i < num_control_points; i++)
8258 {
8259 expr += type_to_glsl(type: struct_type) + "{ ";
8260 for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
8261 {
8262 // The base interface index is stored per variable for structs.
8263 if (var)
8264 {
8265 interface_index =
8266 get_extended_member_decoration(type: var->self, index: j, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8267 }
8268
8269 if (interface_index == uint32_t(-1))
8270 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8271
8272 const auto &mbr_type = get<SPIRType>(id: struct_type.member_types[j]);
8273 const auto &expr_mbr_type = get<SPIRType>(id: expr_type.member_types[j]);
8274 if (is_matrix(type: mbr_type) && ptr_type.storage == StorageClassInput)
8275 {
8276 expr += type_to_glsl(type: mbr_type) + "(";
8277 for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
8278 {
8279 if (is_array_of_struct)
8280 {
8281 const uint32_t indices[2] = { i, interface_index };
8282 AccessChainMeta meta;
8283 expr += access_chain_internal(
8284 base: stage_in_ptr_var_id, indices, count: 2,
8285 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8286 }
8287 else
8288 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8289 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8290 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8291
8292 if (k + 1 < mbr_type.columns)
8293 expr += ", ";
8294 }
8295 expr += ")";
8296 }
8297 else if (is_array(type: mbr_type))
8298 {
8299 expr += type_to_glsl(type: mbr_type) + "({ ";
8300 uint32_t array_size = to_array_size_literal(type: mbr_type, index: 0);
8301 for (uint32_t k = 0; k < array_size; k++, interface_index++)
8302 {
8303 if (is_array_of_struct)
8304 {
8305 const uint32_t indices[2] = { i, interface_index };
8306 AccessChainMeta meta;
8307 expr += access_chain_internal(
8308 base: stage_in_ptr_var_id, indices, count: 2,
8309 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8310 }
8311 else
8312 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8313 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8314 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8315
8316 if (k + 1 < array_size)
8317 expr += ", ";
8318 }
8319 expr += " })";
8320 }
8321 else
8322 {
8323 if (is_array_of_struct)
8324 {
8325 const uint32_t indices[2] = { i, interface_index };
8326 AccessChainMeta meta;
8327 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8328 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
8329 meta: &meta);
8330 }
8331 else
8332 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8333 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8334 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8335 }
8336
8337 if (j + 1 < struct_type.member_types.size())
8338 expr += ", ";
8339 }
8340 expr += " }";
8341 if (i + 1 < num_control_points)
8342 expr += ", ";
8343 }
8344 if (is_array_of_struct)
8345 expr += " })";
8346 }
8347 else if (is_matrix(type: result_type))
8348 {
8349 bool is_array_of_matrix = is_array(type: result_type);
8350 if (is_array_of_matrix && !ptr_is_io_variable)
8351 SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable.");
8352 if (interface_index == uint32_t(-1))
8353 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8354
8355 if (is_array_of_matrix)
8356 {
8357 // Loading a matrix from each control point.
8358 uint32_t base_interface_index = interface_index;
8359 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 0);
8360 expr += type_to_glsl(type: result_type) + "({ ";
8361
8362 auto &matrix_type = get_variable_element_type(var: get<SPIRVariable>(id: ptr));
8363
8364 for (uint32_t i = 0; i < num_control_points; i++)
8365 {
8366 interface_index = base_interface_index;
8367 expr += type_to_glsl(type: matrix_type) + "(";
8368 for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
8369 {
8370 const uint32_t indices[2] = { i, interface_index };
8371
8372 AccessChainMeta meta;
8373 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8374 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8375 if (expr_type.vecsize > result_type.vecsize)
8376 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8377 if (j + 1 < result_type.columns)
8378 expr += ", ";
8379 }
8380 expr += ")";
8381 if (i + 1 < num_control_points)
8382 expr += ", ";
8383 }
8384
8385 expr += " })";
8386 }
8387 else
8388 {
8389 expr += type_to_glsl(type: result_type) + "(";
8390 for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
8391 {
8392 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8393 if (expr_type.vecsize > result_type.vecsize)
8394 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8395 if (i + 1 < result_type.columns)
8396 expr += ", ";
8397 }
8398 expr += ")";
8399 }
8400 }
8401 else if (ptr_is_io_variable)
8402 {
8403 assert(is_array(result_type));
8404 assert(result_type.array.size() == 1);
8405 if (interface_index == uint32_t(-1))
8406 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8407
8408 // We're loading an array directly from a global variable.
8409 // This means we're loading one member from each control point.
8410 expr += type_to_glsl(type: result_type) + "({ ";
8411 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 0);
8412
8413 for (uint32_t i = 0; i < num_control_points; i++)
8414 {
8415 const uint32_t indices[2] = { i, interface_index };
8416
8417 AccessChainMeta meta;
8418 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8419 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8420 if (expr_type.vecsize > result_type.vecsize)
8421 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8422
8423 if (i + 1 < num_control_points)
8424 expr += ", ";
8425 }
8426 expr += " })";
8427 }
8428 else
8429 {
8430 // We're loading an array from a concrete control point.
8431 assert(is_array(result_type));
8432 assert(result_type.array.size() == 1);
8433 if (interface_index == uint32_t(-1))
8434 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8435
8436 expr += type_to_glsl(type: result_type) + "({ ";
8437 uint32_t array_size = to_array_size_literal(type: result_type, index: 0);
8438 for (uint32_t i = 0; i < array_size; i++, interface_index++)
8439 {
8440 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8441 if (expr_type.vecsize > result_type.vecsize)
8442 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8443 if (i + 1 < array_size)
8444 expr += ", ";
8445 }
8446 expr += " })";
8447 }
8448
8449 emit_op(result_type: result_type_id, result_id: id, rhs: expr, forward_rhs: false);
8450 register_read(expr: id, chain: ptr, forwarded: false);
8451 return true;
8452}
8453
8454bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
8455{
8456 // If this is a per-vertex output, remap it to the I/O array buffer.
8457
8458 // Any object which did not go through IO flattening shenanigans will go there instead.
8459 // We will unflatten on demand as needed, but not all possible cases can be supported, especially with arrays.
8460
8461 auto *var = maybe_get_backing_variable(chain: ops[2]);
8462 bool patch = false;
8463 bool flat_data = false;
8464 bool ptr_is_chain = false;
8465 bool flatten_composites = false;
8466
8467 bool is_block = false;
8468 bool is_arrayed = false;
8469
8470 if (var)
8471 {
8472 auto &type = get_variable_data_type(var: *var);
8473 is_block = has_decoration(id: type.self, decoration: DecorationBlock);
8474 is_arrayed = !type.array.empty();
8475
8476 flatten_composites = variable_storage_requires_stage_io(storage: var->storage);
8477 patch = has_decoration(id: ops[2], decoration: DecorationPatch) || is_patch_block(type);
8478
8479 // Should match strip_array in add_interface_block.
8480 flat_data = var->storage == StorageClassInput || (var->storage == StorageClassOutput && is_tesc_shader());
8481
8482 // Patch inputs are treated as normal block IO variables, so they don't deal with this path at all.
8483 if (patch && (!is_block || is_arrayed || var->storage == StorageClassInput))
8484 flat_data = false;
8485
8486 // We might have a chained access chain, where
8487 // we first take the access chain to the control point, and then we chain into a member or something similar.
8488 // In this case, we need to skip gl_in/gl_out remapping.
8489 // Also, skip ptr chain for patches.
8490 ptr_is_chain = var->self != ID(ops[2]);
8491 }
8492
8493 bool builtin_variable = false;
8494 bool variable_is_flat = false;
8495
8496 if (var && flat_data)
8497 {
8498 builtin_variable = is_builtin_variable(var: *var);
8499
8500 BuiltIn bi_type = BuiltInMax;
8501 if (builtin_variable && !is_block)
8502 bi_type = BuiltIn(get_decoration(id: var->self, decoration: DecorationBuiltIn));
8503
8504 variable_is_flat = !builtin_variable || is_block ||
8505 bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
8506 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
8507 }
8508
8509 if (variable_is_flat)
8510 {
8511 // If output is masked, it is emitted as a "normal" variable, just go through normal code paths.
8512 // Only check this for the first level of access chain.
8513 // Dealing with this for partial access chains should be possible, but awkward.
8514 if (var->storage == StorageClassOutput && !ptr_is_chain)
8515 {
8516 bool masked = false;
8517 if (is_block)
8518 {
8519 uint32_t relevant_member_index = patch ? 3 : 4;
8520 // FIXME: This won't work properly if the application first access chains into a gl_out element,
8521 // then access chains into the member. Super weird, but theoretically possible ...
8522 if (length > relevant_member_index)
8523 {
8524 uint32_t mbr_idx = get<SPIRConstant>(id: ops[relevant_member_index]).scalar();
8525 masked = is_stage_output_block_member_masked(var: *var, index: mbr_idx, strip_array: true);
8526 }
8527 }
8528 else if (var)
8529 masked = is_stage_output_variable_masked(var: *var);
8530
8531 if (masked)
8532 return false;
8533 }
8534
8535 AccessChainMeta meta;
8536 SmallVector<uint32_t> indices;
8537 uint32_t next_id = ir.increase_bound_by(count: 1);
8538
8539 indices.reserve(count: length - 3 + 1);
8540
8541 uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0);
8542
8543 VariableID stage_var_id;
8544 if (patch)
8545 stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id;
8546 else
8547 stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
8548
8549 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
8550 if (!ptr_is_chain && !patch)
8551 {
8552 // Index into gl_in/gl_out with first array index.
8553 indices.push_back(t: ops[first_non_array_index - 1]);
8554 }
8555
8556 auto &result_ptr_type = get<SPIRType>(id: ops[0]);
8557
8558 uint32_t const_mbr_id = next_id++;
8559 uint32_t index = get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8560
8561 // If we have a pointer chain expression, and we are no longer pointing to a composite
8562 // object, we are in the clear. There is no longer a need to flatten anything.
8563 bool further_access_chain_is_trivial = false;
8564 if (ptr_is_chain && flatten_composites)
8565 {
8566 auto &ptr_type = expression_type(id: ptr);
8567 if (!is_array(type: ptr_type) && !is_matrix(type: ptr_type) && ptr_type.basetype != SPIRType::Struct)
8568 further_access_chain_is_trivial = true;
8569 }
8570
8571 if (!further_access_chain_is_trivial && (flatten_composites || is_block))
8572 {
8573 uint32_t i = first_non_array_index;
8574 auto *type = &get_variable_element_type(var: *var);
8575 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
8576 {
8577 // Maybe this is a struct type in the input class, in which case
8578 // we put it as a decoration on the corresponding member.
8579 uint32_t mbr_idx = get_constant(id: ops[first_non_array_index]).scalar();
8580 index = get_extended_member_decoration(type: var->self, index: mbr_idx,
8581 decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8582 assert(index != uint32_t(-1));
8583 i++;
8584 type = &get<SPIRType>(id: type->member_types[mbr_idx]);
8585 }
8586
8587 // In this case, we're poking into flattened structures and arrays, so now we have to
8588 // combine the following indices. If we encounter a non-constant index,
8589 // we're hosed.
8590 for (; flatten_composites && i < length; ++i)
8591 {
8592 if (!is_array(type: *type) && !is_matrix(type: *type) && type->basetype != SPIRType::Struct)
8593 break;
8594
8595 auto *c = maybe_get<SPIRConstant>(id: ops[i]);
8596 if (!c || c->specialization)
8597 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
8598 "This is currently unsupported.");
8599
8600 // We're in flattened space, so just increment the member index into IO block.
8601 // We can only do this once in the current implementation, so either:
8602 // Struct, Matrix or 1-dimensional array for a control point.
8603 if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput)
8604 {
8605 // Need to consider holes, since individual block members might be masked away.
8606 uint32_t mbr_idx = c->scalar();
8607 for (uint32_t j = 0; j < mbr_idx; j++)
8608 if (!is_stage_output_block_member_masked(var: *var, index: j, strip_array: true))
8609 index++;
8610 }
8611 else
8612 index += c->scalar();
8613
8614 if (type->parent_type)
8615 type = &get<SPIRType>(id: type->parent_type);
8616 else if (type->basetype == SPIRType::Struct)
8617 type = &get<SPIRType>(id: type->member_types[c->scalar()]);
8618 }
8619
8620 // We're not going to emit the actual member name, we let any further OpLoad take care of that.
8621 // Tag the access chain with the member index we're referencing.
8622 auto &result_pointee_type = get_pointee_type(type: result_ptr_type);
8623 bool defer_access_chain = flatten_composites && (is_matrix(type: result_pointee_type) || is_array(type: result_pointee_type) ||
8624 result_pointee_type.basetype == SPIRType::Struct);
8625
8626 if (!defer_access_chain)
8627 {
8628 // Access the appropriate member of gl_in/gl_out.
8629 set<SPIRConstant>(id: const_mbr_id, args: get_uint_type_id(), args&: index, args: false);
8630 indices.push_back(t: const_mbr_id);
8631
8632 // Member index is now irrelevant.
8633 index = uint32_t(-1);
8634
8635 // Append any straggling access chain indices.
8636 if (i < length)
8637 indices.insert(itr: indices.end(), insert_begin: ops + i, insert_end: ops + length);
8638 }
8639 else
8640 {
8641 // We must have consumed the entire access chain if we're deferring it.
8642 assert(i == length);
8643 }
8644
8645 if (index != uint32_t(-1))
8646 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: index);
8647 else
8648 unset_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8649 }
8650 else
8651 {
8652 if (index != uint32_t(-1))
8653 {
8654 set<SPIRConstant>(id: const_mbr_id, args: get_uint_type_id(), args&: index, args: false);
8655 indices.push_back(t: const_mbr_id);
8656 }
8657
8658 // Member index is now irrelevant.
8659 index = uint32_t(-1);
8660 unset_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8661
8662 indices.insert(itr: indices.end(), insert_begin: ops + first_non_array_index, insert_end: ops + length);
8663 }
8664
8665 // We use the pointer to the base of the input/output array here,
8666 // so this is always a pointer chain.
8667 string e;
8668
8669 if (!ptr_is_chain)
8670 {
8671 // This is the start of an access chain, use ptr_chain to index into control point array.
8672 e = access_chain(base: ptr, indices: indices.data(), count: uint32_t(indices.size()), target_type: result_ptr_type, meta: &meta, ptr_chain: !patch);
8673 }
8674 else
8675 {
8676 // If we're accessing a struct, we need to use member indices which are based on the IO block,
8677 // not the actual struct type, so we have to use a split access chain here where
8678 // the first half resolves the control point index, i.e. gl_in[index], and the second half deals with
8679 // looking up the flattened member name.
8680
8681 // However, it is possible that we partially accessed a struct,
8682 // by taking pointer to member inside the control-point array.
8683 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
8684 // One way to check this here is if we have 2 implied read expressions.
8685 // First one is the gl_in/gl_out struct itself, then an index into that array.
8686 // If we have traversed further, we use a normal access chain formulation.
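// In the split formulation, the final expression is assembled from two halves: the existing
// pointer expression provides the control point part (something like "gl_in[i]"), and the
// chain-only access below provides the flattened member part (something like ".gl_Position").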
8687 auto *ptr_expr = maybe_get<SPIRExpression>(id: ptr);
8688 bool split_access_chain_formulation = flatten_composites && ptr_expr &&
8689 ptr_expr->implied_read_expressions.size() == 2 &&
8690 !further_access_chain_is_trivial;
8691
8692 if (split_access_chain_formulation)
8693 {
8694 e = join(ts: to_expression(id: ptr),
8695 ts: access_chain_internal(base: stage_var_id, indices: indices.data(), count: uint32_t(indices.size()),
8696 flags: ACCESS_CHAIN_CHAIN_ONLY_BIT, meta: &meta));
8697 }
8698 else
8699 {
8700 e = access_chain_internal(base: ptr, indices: indices.data(), count: uint32_t(indices.size()), flags: 0, meta: &meta);
8701 }
8702 }
8703
8704 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
8705 // then we'll need to add a swizzle.
8706 // For this, we can't necessarily rely on the type of the base expression, because it might be
8707 // another access chain, and it will therefore already have the "correct" type.
8708 auto *expr_type = &get_variable_data_type(var: *var);
8709 if (has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID))
8710 expr_type = &get<SPIRType>(id: get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID));
8711 for (uint32_t i = 3; i < length; i++)
8712 {
8713 if (!is_array(type: *expr_type) && expr_type->basetype == SPIRType::Struct)
8714 expr_type = &get<SPIRType>(id: expr_type->member_types[get<SPIRConstant>(id: ops[i]).scalar()]);
8715 else
8716 expr_type = &get<SPIRType>(id: expr_type->parent_type);
8717 }
8718 if (!is_array(type: *expr_type) && !is_matrix(type: *expr_type) && expr_type->basetype != SPIRType::Struct &&
8719 expr_type->vecsize > result_ptr_type.vecsize)
8720 e += vector_swizzle(vecsize: result_ptr_type.vecsize, index: 0);
8721
8722 auto &expr = set<SPIRExpression>(id: ops[1], args: std::move(e), args: ops[0], args: should_forward(id: ops[2]));
8723 expr.loaded_from = var->self;
8724 expr.need_transpose = meta.need_transpose;
8725 expr.access_chain = true;
8726
8727 // Mark the result as being packed if necessary.
8728 if (meta.storage_is_packed)
8729 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationPhysicalTypePacked);
8730 if (meta.storage_physical_type != 0)
8731 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationPhysicalTypeID, value: meta.storage_physical_type);
8732 if (meta.storage_is_invariant)
8733 set_decoration(id: ops[1], decoration: DecorationInvariant);
8734 // Save the type we found in case the result is used in another access chain.
8735 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID, value: expr_type->self);
8736
8737 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
8738 // temporary which could be subject to invalidation.
8739 // Need to assume we're forwarded while calling inherit_expression_dependencies.
8740 forwarded_temporaries.insert(x: ops[1]);
8741 // The access chain itself is never forced to a temporary, but its dependencies might.
8742 suppressed_usage_tracking.insert(x: ops[1]);
8743
8744 for (uint32_t i = 2; i < length; i++)
8745 {
8746 inherit_expression_dependencies(dst: ops[1], source: ops[i]);
8747 add_implied_read_expression(e&: expr, source: ops[i]);
8748 }
8749
8750 // If it turns out we have no dependencies, i.e., all indices in the access chain are immutable temporaries,
8751 // then we are not forwarded after all.
8752 if (expr.expression_dependencies.empty())
8753 forwarded_temporaries.erase(x: ops[1]);
8754
8755 return true;
8756 }
8757
8758 // If this is the inner tessellation level, and we're tessellating triangles,
8759 // drop the last index. It isn't an array in this case, so we can't have an
8760 // array reference here. We need to make this ID a variable instead of an
8761 // expression so we don't try to dereference it as a variable pointer.
8762 // Don't do this if the index is a constant 1, though. We need to drop stores
8763 // to that one.
8764 auto *m = ir.find_meta(id: var ? var->self : ID(0));
8765 if (is_tesc_shader() && var && m && m->decoration.builtin_type == BuiltInTessLevelInner &&
8766 is_tessellating_triangles())
8767 {
8768 auto *c = maybe_get<SPIRConstant>(id: ops[3]);
8769 if (c && c->scalar() == 1)
8770 return false;
8771 auto &dest_var = set<SPIRVariable>(id: ops[1], args&: *var);
8772 dest_var.basetype = ops[0];
8773 ir.meta[ops[1]] = ir.meta[ops[2]];
8774 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
8775 return true;
8776 }
8777
8778 return false;
8779}
8780
8781bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
8782{
8783 if (!is_tessellating_triangles())
8784 return false;
8785
8786 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
8787 // four. This is true even if we are tessellating triangles. This allows clients
8788 // to use a single tessellation control shader with multiple tessellation evaluation
8789 // shaders.
8790 // In Metal, however, only the first element of TessLevelInner and the first three
8791 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
8792 // levels must be stored to a dedicated buffer in a particular format that depends
8793 // on the patch type. Therefore, in Triangles mode, any store to the second
8794 // inner level or the fourth outer level must be dropped.
8795 const auto *e = maybe_get<SPIRExpression>(id: id_lhs);
8796 if (!e || !e->access_chain)
8797 return false;
8798 BuiltIn builtin = BuiltIn(get_decoration(id: e->loaded_from, decoration: DecorationBuiltIn));
8799 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
8800 return false;
8801 auto *c = maybe_get<SPIRConstant>(id: e->implied_read_expressions[1]);
8802 if (!c)
8803 return false;
8804 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
8805 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
8806}
8807
8808bool CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
8809 spv::StorageClass storage, bool &is_packed)
8810{
8811 // If there is any risk of writes happening with the access chain in question,
8812 // and there is a risk of concurrent write access to other components,
8813 // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
8814 // The MSL compiler refuses to allow component-level access for any non-packed vector types.
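	// Concretely, the expression is rewritten to "((device <T>*)&<expr>)" (or "threadgroup" for
	// workgroup storage), and is_packed is set so that any following component access uses array
	// indexing on the raw pointer rather than a vector swizzle.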
8815 if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
8816 {
8817 const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
8818 expr = join(ts: "((", ts&: addr_space, ts: " ", ts: type_to_glsl(type), ts: "*)&", ts: enclose_expression(expr), ts: ")");
8819
8820 // Further indexing should happen with packed rules (array index, not swizzle).
8821 is_packed = true;
8822 return true;
8823 }
8824 else
8825 return false;
8826}
8827
8828bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
8829{
8830 auto *var = maybe_get_backing_variable(chain: base);
8831 if (!var || !is_tessellation_shader())
8832 return true;
8833
8834 // We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N.
8835 // Avoid overriding it back to just gl_ClipDistance.
8836 // This can only happen in scenarios where we cannot flatten/unflatten access chains, so the only case
8837 // where this triggers is evaluation shader inputs.
8838 bool redirect_builtin = is_tese_shader() ? var->storage == StorageClassOutput : false;
8839 return redirect_builtin;
8840}
8841
8842// Sets the interface member index for an access chain to a pull-model interpolant.
8843void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
8844{
8845 auto *var = maybe_get_backing_variable(chain: ops[2]);
8846 if (!var || !pull_model_inputs.count(x: var->self))
8847 return;
8848 // Get the base index.
8849 uint32_t interface_index;
8850 auto &var_type = get_variable_data_type(var: *var);
8851 auto &result_type = get<SPIRType>(id: ops[0]);
8852 auto *type = &var_type;
8853 if (has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex))
8854 {
8855 interface_index = get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8856 }
8857 else
8858 {
8859 // Assume an access chain into a struct variable.
8860 assert(var_type.basetype == SPIRType::Struct);
8861 auto &c = get<SPIRConstant>(id: ops[3 + var_type.array.size()]);
8862 interface_index =
8863 get_extended_member_decoration(type: var->self, index: c.scalar(), decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8864 }
8865 // Accumulate indices. We'll have to skip over the one for the struct, if present, because we already accounted
8866 // for that when getting the base index.
8867 for (uint32_t i = 3; i < length; ++i)
8868 {
8869 if (is_vector(type: *type) && !is_array(type: *type) && is_scalar(type: result_type))
8870 {
8871 // We don't want to combine the next index. Actually, we need to save it
8872 // so we know to apply a swizzle to the result of the interpolation.
8873 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterpolantComponentExpr, value: ops[i]);
8874 break;
8875 }
8876
8877 auto *c = maybe_get<SPIRConstant>(id: ops[i]);
8878 if (!c || c->specialization)
8879 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
8880 "interpolation. This is currently unsupported.");
8881
8882 if (type->parent_type)
8883 type = &get<SPIRType>(id: type->parent_type);
8884 else if (type->basetype == SPIRType::Struct)
8885 type = &get<SPIRType>(id: type->member_types[c->scalar()]);
8886
8887 if (!has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex) &&
8888 i - 3 == var_type.array.size())
8889 continue;
8890
8891 interface_index += c->scalar();
8892 }
8893 // Save this to the access chain itself so we can recover it later when calling an interpolation function.
8894 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: interface_index);
8895}
8896
8897
8898// If the physical type of a physical buffer pointer has been changed
8899// to a ulong or ulongn vector, add a cast back to the pointer type.
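// For example (illustrative types): if the address was materialized as a ulong2, the expression
// "addr" becomes "((device Foo*)addr.x)"; for a plain ulong it becomes "((device Foo*)addr)".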
8900void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
8901{
8902 auto *p_physical_type = maybe_get<SPIRType>(id: physical_type);
8903 if (p_physical_type &&
8904 p_physical_type->storage == StorageClassPhysicalStorageBuffer &&
8905 p_physical_type->basetype == to_unsigned_basetype(width: 64))
8906 {
8907 if (p_physical_type->vecsize > 1)
8908 expr += ".x";
8909
8910 expr = join(ts: "((", ts: type_to_glsl(type: *type), ts: ")", ts&: expr, ts: ")");
8911 }
8912}
8913
8914// Override for MSL-specific syntax instructions
8915void CompilerMSL::emit_instruction(const Instruction &instruction)
8916{
8917#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
8918#define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op)
8919 // MSL does care about implicit integer promotion, but those cases are all handled in common code.
8920#define MSL_BOP_CAST(op, type) \
8921 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
8922#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
8923#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
8924#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
8925#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
8926#define MSL_BFOP_CAST(op, type) \
8927 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
8928#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
8929#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
8930
8931 auto ops = stream(instr: instruction);
8932 auto opcode = static_cast<Op>(instruction.op);
8933
8934 opcode = get_remapped_spirv_op(op: opcode);
8935
8936 // If we need to do implicit bitcasts, make sure we do it with the correct type.
8937 uint32_t integer_width = get_integer_width_for_instruction(instr: instruction);
8938 auto int_type = to_signed_basetype(width: integer_width);
8939 auto uint_type = to_unsigned_basetype(width: integer_width);
8940
8941 switch (opcode)
8942 {
8943 case OpLoad:
8944 {
8945 uint32_t id = ops[1];
8946 uint32_t ptr = ops[2];
8947 if (is_tessellation_shader())
8948 {
8949 if (!emit_tessellation_io_load(result_type_id: ops[0], id, ptr))
8950 CompilerGLSL::emit_instruction(instr: instruction);
8951 }
8952 else
8953 {
8954 // Sample mask input for Metal is not an array
8955 if (BuiltIn(get_decoration(id: ptr, decoration: DecorationBuiltIn)) == BuiltInSampleMask)
8956 set_decoration(id, decoration: DecorationBuiltIn, argument: BuiltInSampleMask);
8957 CompilerGLSL::emit_instruction(instr: instruction);
8958 }
8959 break;
8960 }
8961
8962 // Comparisons
8963 case OpIEqual:
8964 MSL_BOP_CAST(==, int_type);
8965 break;
8966
8967 case OpLogicalEqual:
8968 case OpFOrdEqual:
8969 MSL_BOP(==);
8970 break;
8971
8972 case OpINotEqual:
8973 MSL_BOP_CAST(!=, int_type);
8974 break;
8975
8976 case OpLogicalNotEqual:
8977 case OpFOrdNotEqual:
8978 // TODO: Should probably negate the == result here.
8979 // Typically OrdNotEqual comes from GLSL which itself does not really specify what
8980 // happens with NaN.
8981 // Consider fixing this if we run into real issues.
8982 MSL_BOP(!=);
8983 break;
8984
8985 case OpUGreaterThan:
8986 MSL_BOP_CAST(>, uint_type);
8987 break;
8988
8989 case OpSGreaterThan:
8990 MSL_BOP_CAST(>, int_type);
8991 break;
8992
8993 case OpFOrdGreaterThan:
8994 MSL_BOP(>);
8995 break;
8996
8997 case OpUGreaterThanEqual:
8998 MSL_BOP_CAST(>=, uint_type);
8999 break;
9000
9001 case OpSGreaterThanEqual:
9002 MSL_BOP_CAST(>=, int_type);
9003 break;
9004
9005 case OpFOrdGreaterThanEqual:
9006 MSL_BOP(>=);
9007 break;
9008
9009 case OpULessThan:
9010 MSL_BOP_CAST(<, uint_type);
9011 break;
9012
9013 case OpSLessThan:
9014 MSL_BOP_CAST(<, int_type);
9015 break;
9016
9017 case OpFOrdLessThan:
9018 MSL_BOP(<);
9019 break;
9020
9021 case OpULessThanEqual:
9022 MSL_BOP_CAST(<=, uint_type);
9023 break;
9024
9025 case OpSLessThanEqual:
9026 MSL_BOP_CAST(<=, int_type);
9027 break;
9028
9029 case OpFOrdLessThanEqual:
9030 MSL_BOP(<=);
9031 break;
9032
9033 case OpFUnordEqual:
9034 MSL_UNORD_BOP(==);
9035 break;
9036
9037 case OpFUnordNotEqual:
9038 // Not-equal in MSL generates 'une' (unordered not-equal) opcodes to begin with.
9039 // Since unordered not-equal is how != works in C, just inherit that behavior.
9040 MSL_BOP(!=);
9041 break;
9042
9043 case OpFUnordGreaterThan:
9044 MSL_UNORD_BOP(>);
9045 break;
9046
9047 case OpFUnordGreaterThanEqual:
9048 MSL_UNORD_BOP(>=);
9049 break;
9050
9051 case OpFUnordLessThan:
9052 MSL_UNORD_BOP(<);
9053 break;
9054
9055 case OpFUnordLessThanEqual:
9056 MSL_UNORD_BOP(<=);
9057 break;
9058
9059 // Pointer math
9060 case OpPtrEqual:
9061 MSL_PTR_BOP(==);
9062 break;
9063
9064 case OpPtrNotEqual:
9065 MSL_PTR_BOP(!=);
9066 break;
9067
9068 case OpPtrDiff:
9069 MSL_PTR_BOP(-);
9070 break;
9071
9072 // Derivatives
9073 case OpDPdx:
9074 case OpDPdxFine:
9075 case OpDPdxCoarse:
9076 MSL_UFOP(dfdx);
9077 register_control_dependent_expression(expr: ops[1]);
9078 break;
9079
9080 case OpDPdy:
9081 case OpDPdyFine:
9082 case OpDPdyCoarse:
9083 MSL_UFOP(dfdy);
9084 register_control_dependent_expression(expr: ops[1]);
9085 break;
9086
9087 case OpFwidth:
9088 case OpFwidthCoarse:
9089 case OpFwidthFine:
9090 MSL_UFOP(fwidth);
9091 register_control_dependent_expression(expr: ops[1]);
9092 break;
9093
9094 // Bitfield
9095 case OpBitFieldInsert:
9096 {
9097 emit_bitfield_insert_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op3: ops[5], op: "insert_bits", offset_count_type: SPIRType::UInt);
9098 break;
9099 }
9100
9101 case OpBitFieldSExtract:
9102 {
9103 emit_trinary_func_op_bitextract(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op: "extract_bits", expected_result_type: int_type, input_type0: int_type,
9104 input_type1: SPIRType::UInt, input_type2: SPIRType::UInt);
9105 break;
9106 }
9107
9108 case OpBitFieldUExtract:
9109 {
9110 emit_trinary_func_op_bitextract(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op: "extract_bits", expected_result_type: uint_type, input_type0: uint_type,
9111 input_type1: SPIRType::UInt, input_type2: SPIRType::UInt);
9112 break;
9113 }
9114
9115 case OpBitReverse:
9116 // BitReverse does not have issues with sign since result type must match input type.
9117 MSL_UFOP(reverse_bits);
9118 break;
9119
9120 case OpBitCount:
9121 {
9122 auto basetype = expression_type(id: ops[2]).basetype;
9123 emit_unary_func_op_cast(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "popcount", input_type: basetype, expected_result_type: basetype);
9124 break;
9125 }
9126
9127 case OpFRem:
9128 MSL_BFOP(fmod);
9129 break;
9130
9131 case OpFMul:
9132 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
9133 MSL_BFOP(spvFMul);
9134 else
9135 MSL_BOP(*);
9136 break;
9137
9138 case OpFAdd:
9139 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
9140 MSL_BFOP(spvFAdd);
9141 else
9142 MSL_BOP(+);
9143 break;
9144
9145 case OpFSub:
9146 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
9147 MSL_BFOP(spvFSub);
9148 else
9149 MSL_BOP(-);
9150 break;
9151
9152 // Atomics
9153 case OpAtomicExchange:
9154 {
9155 uint32_t result_type = ops[0];
9156 uint32_t id = ops[1];
9157 uint32_t ptr = ops[2];
9158 uint32_t mem_sem = ops[4];
9159 uint32_t val = ops[5];
9160 emit_atomic_func_op(result_type, result_id: id, op: "atomic_exchange", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: val);
9161 break;
9162 }
9163
9164 case OpAtomicCompareExchange:
9165 {
9166 uint32_t result_type = ops[0];
9167 uint32_t id = ops[1];
9168 uint32_t ptr = ops[2];
9169 uint32_t mem_sem_pass = ops[4];
9170 uint32_t mem_sem_fail = ops[5];
9171 uint32_t val = ops[6];
9172 uint32_t comp = ops[7];
9173 emit_atomic_func_op(result_type, result_id: id, op: "atomic_compare_exchange_weak", opcode,
9174 mem_order_1: mem_sem_pass, mem_order_2: mem_sem_fail, has_mem_order_2: true,
9175 op0: ptr, op1: comp, op1_is_pointer: true, op1_is_literal: false, op2: val);
9176 break;
9177 }
9178
9179 case OpAtomicCompareExchangeWeak:
9180 SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
9181
9182 case OpAtomicLoad:
9183 {
9184 uint32_t result_type = ops[0];
9185 uint32_t id = ops[1];
9186 uint32_t ptr = ops[2];
9187 uint32_t mem_sem = ops[4];
9188 check_atomic_image(id: ptr);
9189 emit_atomic_func_op(result_type, result_id: id, op: "atomic_load", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: 0);
9190 break;
9191 }
9192
9193 case OpAtomicStore:
9194 {
9195 uint32_t result_type = expression_type(id: ops[0]).self;
9196 uint32_t id = ops[0];
9197 uint32_t ptr = ops[0];
9198 uint32_t mem_sem = ops[2];
9199 uint32_t val = ops[3];
9200 check_atomic_image(id: ptr);
9201 emit_atomic_func_op(result_type, result_id: id, op: "atomic_store", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: val);
9202 break;
9203 }
9204
9205#define MSL_AFMO_IMPL(op, valsrc, valconst) \
9206 do \
9207 { \
9208 uint32_t result_type = ops[0]; \
9209 uint32_t id = ops[1]; \
9210 uint32_t ptr = ops[2]; \
9211 uint32_t mem_sem = ops[4]; \
9212 uint32_t val = valsrc; \
9213 emit_atomic_func_op(result_type, id, "atomic_fetch_" #op, opcode, \
9214 mem_sem, mem_sem, false, ptr, val, \
9215 false, valconst); \
9216 } while (false)
9217
9218#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
9219#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
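// MSL_AFMO takes the value operand from the instruction (ops[5]); MSL_AFMIO passes a literal 1,
// covering OpAtomicIIncrement/OpAtomicIDecrement, which carry no value operand in SPIR-V.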
9220
9221 case OpAtomicIIncrement:
9222 MSL_AFMIO(add);
9223 break;
9224
9225 case OpAtomicIDecrement:
9226 MSL_AFMIO(sub);
9227 break;
9228
9229 case OpAtomicIAdd:
9230 case OpAtomicFAddEXT:
9231 MSL_AFMO(add);
9232 break;
9233
9234 case OpAtomicISub:
9235 MSL_AFMO(sub);
9236 break;
9237
9238 case OpAtomicSMin:
9239 case OpAtomicUMin:
9240 MSL_AFMO(min);
9241 break;
9242
9243 case OpAtomicSMax:
9244 case OpAtomicUMax:
9245 MSL_AFMO(max);
9246 break;
9247
9248 case OpAtomicAnd:
9249 MSL_AFMO(and);
9250 break;
9251
9252 case OpAtomicOr:
9253 MSL_AFMO(or);
9254 break;
9255
9256 case OpAtomicXor:
9257 MSL_AFMO(xor);
9258 break;
9259
9260 // Images
9261
9262 // Reads == Fetches in Metal
9263 case OpImageRead:
9264 {
9265 // Mark that this shader reads from this image
9266 uint32_t img_id = ops[2];
9267 auto &type = expression_type(id: img_id);
9268 auto *p_var = maybe_get_backing_variable(chain: img_id);
9269 if (type.image.dim != DimSubpassData)
9270 {
9271 if (p_var && has_decoration(id: p_var->self, decoration: DecorationNonReadable))
9272 {
9273 unset_decoration(id: p_var->self, decoration: DecorationNonReadable);
9274 force_recompile();
9275 }
9276 }
9277
9278 // Metal requires explicit fences to break up RAW hazards, even within the same shader invocation
9279 if (msl_options.readwrite_texture_fences && p_var && !has_decoration(id: p_var->self, decoration: DecorationNonWritable))
9280 {
9281 add_spv_func_and_recompile(spv_func: SPVFuncImplImageFence);
9282 // Need to wrap this with a value type,
9283	// since the Metal headers are broken and do not consider the case where the image is a reference.
9284 statement(ts: "spvImageFence(", ts: to_expression(id: img_id), ts: ");");
9285 }
9286
9287 emit_texture_op(i: instruction, sparse: false);
9288 break;
9289 }
9290
9291 // Emulate texture2D atomic operations
9292 case OpImageTexelPointer:
9293 {
9294 // When using the pointer, we need to know which variable it is actually loaded from.
9295 auto *var = maybe_get_backing_variable(chain: ops[2]);
9296 if (var && atomic_image_vars_emulated.count(x: var->self))
9297 {
9298 uint32_t result_type = ops[0];
9299 uint32_t id = ops[1];
9300
9301 std::string coord = to_expression(id: ops[3]);
9302 auto &type = expression_type(id: ops[2]);
9303 if (type.image.dim == Dim2D)
9304 {
9305 coord = join(ts: "spvImage2DAtomicCoord(", ts&: coord, ts: ", ", ts: to_expression(id: ops[2]), ts: ")");
9306 }
9307
9308 auto &e = set<SPIRExpression>(id, args: join(ts: to_expression(id: ops[2]), ts: "_atomic[", ts&: coord, ts: "]"), args&: result_type, args: true);
9309 e.loaded_from = var ? var->self : ID(0);
9310 inherit_expression_dependencies(dst: id, source: ops[3]);
9311 }
9312 else
9313 {
9314 uint32_t result_type = ops[0];
9315 uint32_t id = ops[1];
9316
9317 // Virtual expression. Split this up in the actual image atomic.
9318 // In GLSL and HLSL we are able to resolve the dereference inline, but MSL has
9319 // image.op(coord, ...) syntax.
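			// The '@' separator is split apart again in emit_atomic_func_op() when the image atomic is emitted.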
9320 auto &e =
9321 set<SPIRExpression>(id, args: join(ts: to_expression(id: ops[2]), ts: "@",
9322 ts: bitcast_expression(target_type: SPIRType::UInt, arg: ops[3])),
9323 args&: result_type, args: true);
9324
9325 // When using the pointer, we need to know which variable it is actually loaded from.
9326 e.loaded_from = var ? var->self : ID(0);
9327 inherit_expression_dependencies(dst: id, source: ops[3]);
9328 }
9329 break;
9330 }
9331
9332 case OpImageWrite:
9333 {
9334 uint32_t img_id = ops[0];
9335 uint32_t coord_id = ops[1];
9336 uint32_t texel_id = ops[2];
9337 const uint32_t *opt = &ops[3];
9338 uint32_t length = instruction.length - 3;
9339
9340 // Bypass pointers because we need the real image struct
9341 auto &type = expression_type(id: img_id);
9342 auto &img_type = get<SPIRType>(id: type.self);
9343
9344	// Ensure this image has been marked as being written to, and force a
9345	// recompile so that the image type output will include write access.
9346 auto *p_var = maybe_get_backing_variable(chain: img_id);
9347 if (p_var && has_decoration(id: p_var->self, decoration: DecorationNonWritable))
9348 {
9349 unset_decoration(id: p_var->self, decoration: DecorationNonWritable);
9350 force_recompile();
9351 }
9352
9353 bool forward = false;
9354 uint32_t bias = 0;
9355 uint32_t lod = 0;
9356 uint32_t flags = 0;
9357
9358 if (length)
9359 {
9360 flags = *opt++;
9361 length--;
9362 }
9363
9364 auto test = [&](uint32_t &v, uint32_t flag) {
9365 if (length && (flags & flag))
9366 {
9367 v = *opt++;
9368 length--;
9369 }
9370 };
9371
9372 test(bias, ImageOperandsBiasMask);
9373 test(lod, ImageOperandsLodMask);
9374
9375 auto &texel_type = expression_type(id: texel_id);
9376 auto store_type = texel_type;
9377 store_type.vecsize = 4;
9378
9379 TextureFunctionArguments args = {};
9380 args.base.img = img_id;
9381 args.base.imgtype = &img_type;
9382 args.base.is_fetch = true;
9383 args.coord = coord_id;
9384 args.lod = lod;
9385
9386 string expr;
9387 if (needs_frag_discard_checks())
9388 expr = join(ts: "(", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: " ? ((void)0) : ");
9389 expr += join(ts: to_expression(id: img_id), ts: ".write(",
9390 ts: remap_swizzle(result_type: store_type, input_components: texel_type.vecsize, expr: to_expression(id: texel_id)), ts: ", ",
9391 ts: CompilerMSL::to_function_args(args, p_forward: &forward), ts: ")");
9392 if (needs_frag_discard_checks())
9393 expr += ")";
9394 statement(ts&: expr, ts: ";");
9395
9396 if (p_var && variable_storage_is_aliased(var: *p_var))
9397 flush_all_aliased_variables();
9398
9399 break;
9400 }
9401
9402 case OpImageQuerySize:
9403 case OpImageQuerySizeLod:
9404 {
9405 uint32_t rslt_type_id = ops[0];
9406 auto &rslt_type = get<SPIRType>(id: rslt_type_id);
9407
9408 uint32_t id = ops[1];
9409
9410 uint32_t img_id = ops[2];
9411 string img_exp = to_expression(id: img_id);
9412 auto &img_type = expression_type(id: img_id);
9413 Dim img_dim = img_type.image.dim;
9414 bool img_is_array = img_type.image.arrayed;
9415
9416 if (img_type.basetype != SPIRType::Image)
9417 SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
9418
9419 string lod;
9420 if (opcode == OpImageQuerySizeLod)
9421 {
9422	// The LOD index defaults to zero, so don't bother outputting the level-zero index.
9423 string decl_lod = to_expression(id: ops[3]);
9424 if (decl_lod != "0")
9425 lod = decl_lod;
9426 }
9427
9428 string expr = type_to_glsl(type: rslt_type) + "(";
9429 expr += img_exp + ".get_width(" + lod + ")";
9430
9431 if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
9432 expr += ", " + img_exp + ".get_height(" + lod + ")";
9433
9434 if (img_dim == Dim3D)
9435 expr += ", " + img_exp + ".get_depth(" + lod + ")";
9436
9437 if (img_is_array)
9438 {
9439 expr += ", " + img_exp + ".get_array_size()";
9440 if (img_dim == DimCube && msl_options.emulate_cube_array)
9441 expr += " / 6";
9442 }
9443
9444 expr += ")";
9445
9446 emit_op(result_type: rslt_type_id, result_id: id, rhs: expr, forward_rhs: should_forward(id: img_id));
9447
9448 break;
9449 }
9450
9451 case OpImageQueryLod:
9452 {
9453 if (!msl_options.supports_msl_version(major: 2, minor: 2))
9454 SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
9455 uint32_t result_type = ops[0];
9456 uint32_t id = ops[1];
9457 uint32_t image_id = ops[2];
9458 uint32_t coord_id = ops[3];
9459 emit_uninitialized_temporary_expression(type: result_type, id);
9460
9461 std::string coord_expr = to_expression(id: coord_id);
9462 auto sampler_expr = to_sampler_expression(id: image_id);
9463 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: image_id);
9464 auto image_expr = combined ? to_expression(id: combined->image) : to_expression(id: image_id);
9465 const SPIRType &image_type = expression_type(id: image_id);
9466 const SPIRType &coord_type = expression_type(id: coord_id);
9467
9468 switch (image_type.image.dim)
9469 {
9470 case Dim1D:
9471 if (!msl_options.texture_1D_as_2D)
9472 SPIRV_CROSS_THROW("ImageQueryLod is not supported on 1D textures.");
9473 [[fallthrough]];
9474 case Dim2D:
9475 if (coord_type.vecsize > 2)
9476 coord_expr = enclose_expression(expr: coord_expr) + ".xy";
9477 break;
9478 case DimCube:
9479 case Dim3D:
9480 if (coord_type.vecsize > 3)
9481 coord_expr = enclose_expression(expr: coord_expr) + ".xyz";
9482 break;
9483 default:
9484 SPIRV_CROSS_THROW("Bad image type given to OpImageQueryLod");
9485 }
9486
9487	// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
9488 // the reported LOD based on the sampler. NEAREST miplevel should
9489 // round the LOD, but LINEAR miplevel should not round.
9490 // Let's hope this does not become an issue ...
9491 statement(ts: to_expression(id), ts: ".x = ", ts&: image_expr, ts: ".calculate_clamped_lod(", ts&: sampler_expr, ts: ", ",
9492 ts&: coord_expr, ts: ");");
9493 statement(ts: to_expression(id), ts: ".y = ", ts&: image_expr, ts: ".calculate_unclamped_lod(", ts&: sampler_expr, ts: ", ",
9494 ts&: coord_expr, ts: ");");
9495 register_control_dependent_expression(expr: id);
9496 break;
9497 }
9498
9499#define MSL_ImgQry(qrytype) \
9500 do \
9501 { \
9502 uint32_t rslt_type_id = ops[0]; \
9503 auto &rslt_type = get<SPIRType>(rslt_type_id); \
9504 uint32_t id = ops[1]; \
9505 uint32_t img_id = ops[2]; \
9506 string img_exp = to_expression(img_id); \
9507 string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
9508 emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
9509 } while (false)
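// For example, MSL_ImgQry(mip_levels) expands to an expression like "uint(tex.get_num_mip_levels())",
// with the result type and image expression taken from the instruction (names here are illustrative).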
9510
9511 case OpImageQueryLevels:
9512 MSL_ImgQry(mip_levels);
9513 break;
9514
9515 case OpImageQuerySamples:
9516 MSL_ImgQry(samples);
9517 break;
9518
9519 case OpImage:
9520 {
9521 uint32_t result_type = ops[0];
9522 uint32_t id = ops[1];
9523 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: ops[2]);
9524
9525 if (combined)
9526 {
9527 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: combined->image), forward_rhs: true, suppress_usage_tracking: true);
9528 auto *var = maybe_get_backing_variable(chain: combined->image);
9529 if (var)
9530 e.loaded_from = var->self;
9531 }
9532 else
9533 {
9534 auto *var = maybe_get_backing_variable(chain: ops[2]);
9535 SPIRExpression *e;
9536 if (var && has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler))
9537 e = &emit_op(result_type, result_id: id, rhs: join(ts: to_expression(id: ops[2]), ts: ".plane0"), forward_rhs: true, suppress_usage_tracking: true);
9538 else
9539 e = &emit_op(result_type, result_id: id, rhs: to_expression(id: ops[2]), forward_rhs: true, suppress_usage_tracking: true);
9540 if (var)
9541 e->loaded_from = var->self;
9542 }
9543 break;
9544 }
9545
9546 // Casting
9547 case OpQuantizeToF16:
9548 {
9549 uint32_t result_type = ops[0];
9550 uint32_t id = ops[1];
9551 uint32_t arg = ops[2];
9552 string exp = join(ts: "spvQuantizeToF16(", ts: to_expression(id: arg), ts: ")");
9553 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: arg));
9554 break;
9555 }
9556
9557 case OpInBoundsAccessChain:
9558 case OpAccessChain:
9559 case OpPtrAccessChain:
9560 if (is_tessellation_shader())
9561 {
9562 if (!emit_tessellation_access_chain(ops, length: instruction.length))
9563 CompilerGLSL::emit_instruction(instr: instruction);
9564 }
9565 else
9566 CompilerGLSL::emit_instruction(instr: instruction);
9567 fix_up_interpolant_access_chain(ops, length: instruction.length);
9568 break;
9569
9570 case OpStore:
9571 {
9572 const auto &type = expression_type(id: ops[0]);
9573
9574 if (is_out_of_bounds_tessellation_level(id_lhs: ops[0]))
9575 break;
9576
9577 if (needs_frag_discard_checks() &&
9578 (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
9579 {
9580 // If we're in a continue block, this kludge will make the block too complex
9581 // to emit normally.
9582 assert(current_emitting_block);
9583 auto cont_type = continue_block_type(continue_block: *current_emitting_block);
9584 if (cont_type != SPIRBlock::ContinueNone && cont_type != SPIRBlock::ComplexLoop)
9585 {
9586 current_emitting_block->complex_continue = true;
9587 force_recompile();
9588 }
9589 statement(ts: "if (!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: ")");
9590 begin_scope();
9591 }
9592 if (!maybe_emit_array_assignment(id_lhs: ops[0], id_rhs: ops[1]))
9593 CompilerGLSL::emit_instruction(instr: instruction);
9594 if (needs_frag_discard_checks() &&
9595 (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
9596 end_scope();
9597 break;
9598 }
9599
9600 // Compute barriers
9601 case OpMemoryBarrier:
9602 emit_barrier(id_exe_scope: 0, id_mem_scope: ops[0], id_mem_sem: ops[1]);
9603 break;
9604
9605 case OpControlBarrier:
9606 // In GLSL a memory barrier is often followed by a control barrier.
9607 // But in MSL, memory barriers are also control barriers, so don't
9608 // emit a simple control barrier if a memory barrier has just been emitted.
9609 if (previous_instruction_opcode != OpMemoryBarrier)
9610 emit_barrier(id_exe_scope: ops[0], id_mem_scope: ops[1], id_mem_sem: ops[2]);
9611 break;
9612
9613 case OpOuterProduct:
9614 {
9615 uint32_t result_type = ops[0];
9616 uint32_t id = ops[1];
9617 uint32_t a = ops[2];
9618 uint32_t b = ops[3];
9619
9620 auto &type = get<SPIRType>(id: result_type);
9621 string expr = type_to_glsl_constructor(type);
9622 expr += "(";
9623 for (uint32_t col = 0; col < type.columns; col++)
9624 {
9625 expr += to_enclosed_unpacked_expression(id: a);
9626 expr += " * ";
9627 expr += to_extract_component_expression(id: b, index: col);
9628 if (col + 1 < type.columns)
9629 expr += ", ";
9630 }
9631 expr += ")";
9632 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: a) && should_forward(id: b));
9633 inherit_expression_dependencies(dst: id, source: a);
9634 inherit_expression_dependencies(dst: id, source: b);
9635 break;
9636 }
9637
9638 case OpVectorTimesMatrix:
9639 case OpMatrixTimesVector:
9640 {
9641 if (!msl_options.invariant_float_math && !has_decoration(id: ops[1], decoration: DecorationNoContraction))
9642 {
9643 CompilerGLSL::emit_instruction(instr: instruction);
9644 break;
9645 }
9646
9647 // If the matrix needs transpose, just flip the multiply order.
9648 auto *e = maybe_get<SPIRExpression>(id: ops[opcode == OpMatrixTimesVector ? 2 : 3]);
9649 if (e && e->need_transpose)
9650 {
9651 e->need_transpose = false;
9652 string expr;
9653
9654 if (opcode == OpMatrixTimesVector)
9655 {
9656 expr = join(ts: "spvFMulVectorMatrix(", ts: to_enclosed_unpacked_expression(id: ops[3]), ts: ", ",
9657 ts: to_unpacked_row_major_matrix_expression(id: ops[2]), ts: ")");
9658 }
9659 else
9660 {
9661 expr = join(ts: "spvFMulMatrixVector(", ts: to_unpacked_row_major_matrix_expression(id: ops[3]), ts: ", ",
9662 ts: to_enclosed_unpacked_expression(id: ops[2]), ts: ")");
9663 }
9664
9665 bool forward = should_forward(id: ops[2]) && should_forward(id: ops[3]);
9666 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
9667 e->need_transpose = true;
9668 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9669 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
9670 }
9671 else
9672 {
9673 if (opcode == OpMatrixTimesVector)
9674 MSL_BFOP(spvFMulMatrixVector);
9675 else
9676 MSL_BFOP(spvFMulVectorMatrix);
9677 }
9678 break;
9679 }
9680
9681 case OpMatrixTimesMatrix:
9682 {
9683 if (!msl_options.invariant_float_math && !has_decoration(id: ops[1], decoration: DecorationNoContraction))
9684 {
9685 CompilerGLSL::emit_instruction(instr: instruction);
9686 break;
9687 }
9688
9689 auto *a = maybe_get<SPIRExpression>(id: ops[2]);
9690 auto *b = maybe_get<SPIRExpression>(id: ops[3]);
9691
9692 // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
9693 // a^T * b^T = (b * a)^T.
9694 if (a && b && a->need_transpose && b->need_transpose)
9695 {
9696 a->need_transpose = false;
9697 b->need_transpose = false;
9698
9699 auto expr =
9700 join(ts: "spvFMulMatrixMatrix(", ts: enclose_expression(expr: to_unpacked_row_major_matrix_expression(id: ops[3])), ts: ", ",
9701 ts: enclose_expression(expr: to_unpacked_row_major_matrix_expression(id: ops[2])), ts: ")");
9702
9703 bool forward = should_forward(id: ops[2]) && should_forward(id: ops[3]);
9704 auto &e = emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
9705 e.need_transpose = true;
9706 a->need_transpose = true;
9707 b->need_transpose = true;
9708 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9709 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
9710 }
9711 else
9712 MSL_BFOP(spvFMulMatrixMatrix);
9713
9714 break;
9715 }
9716
9717 case OpIAddCarry:
9718 case OpISubBorrow:
9719 {
9720 uint32_t result_type = ops[0];
9721 uint32_t result_id = ops[1];
9722 uint32_t op0 = ops[2];
9723 uint32_t op1 = ops[3];
9724 auto &type = get<SPIRType>(id: result_type);
9725 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
9726
9727 auto &res_type = get<SPIRType>(id: type.member_types[1]);
9728 if (opcode == OpIAddCarry)
9729 {
9730 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0), ts: " = ",
9731 ts: to_enclosed_unpacked_expression(id: op0), ts: " + ", ts: to_enclosed_unpacked_expression(id: op1), ts: ";");
9732 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 1), ts: " = select(", ts: type_to_glsl(type: res_type),
9733 ts: "(1), ", ts: type_to_glsl(type: res_type), ts: "(0), ", ts: to_unpacked_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0),
9734 ts: " >= max(", ts: to_unpacked_expression(id: op0), ts: ", ", ts: to_unpacked_expression(id: op1), ts: "));");
9735 }
9736 else
9737 {
9738 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0), ts: " = ", ts: to_enclosed_unpacked_expression(id: op0), ts: " - ",
9739 ts: to_enclosed_unpacked_expression(id: op1), ts: ";");
9740 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 1), ts: " = select(", ts: type_to_glsl(type: res_type),
9741 ts: "(1), ", ts: type_to_glsl(type: res_type), ts: "(0), ", ts: to_enclosed_unpacked_expression(id: op0),
9742 ts: " >= ", ts: to_enclosed_unpacked_expression(id: op1), ts: ");");
9743 }
9744 break;
9745 }
9746
9747 case OpUMulExtended:
9748 case OpSMulExtended:
9749 {
9750 uint32_t result_type = ops[0];
9751 uint32_t result_id = ops[1];
9752 uint32_t op0 = ops[2];
9753 uint32_t op1 = ops[3];
9754 auto &type = get<SPIRType>(id: result_type);
9755 auto &op_type = get<SPIRType>(id: type.member_types[0]);
9756 auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
9757 string cast_op0, cast_op1;
9758
9759 binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type: false);
9760 auto expr = join(ts: "spvMulExtended<", ts: type_to_glsl(type), ts: ", ", ts: type_to_glsl(type: op_type), ts: ">(", ts&: cast_op0, ts: ", ", ts&: cast_op1, ts: ")");
9761 emit_op(result_type, result_id, rhs: expr, forward_rhs: true);
9762 break;
9763 }
9764
9765 case OpArrayLength:
9766 {
9767 auto &type = expression_type(id: ops[2]);
9768 uint32_t offset = type_struct_member_offset(type, index: ops[3]);
9769 uint32_t stride = type_struct_member_array_stride(type, index: ops[3]);
9770
9771 auto expr = join(ts: "(", ts: to_buffer_size_expression(id: ops[2]), ts: " - ", ts&: offset, ts: ") / ", ts&: stride);
9772 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: true);
9773 break;
9774 }
9775
9776 // Legacy sub-group stuff ...
9777 case OpSubgroupBallotKHR:
9778 case OpSubgroupFirstInvocationKHR:
9779 case OpSubgroupReadInvocationKHR:
9780 case OpSubgroupAllKHR:
9781 case OpSubgroupAnyKHR:
9782 case OpSubgroupAllEqualKHR:
9783 emit_subgroup_op(i: instruction);
9784 break;
9785
9786 // SPV_INTEL_shader_integer_functions2
9787 case OpUCountLeadingZerosINTEL:
9788 MSL_UFOP(clz);
9789 break;
9790
9791 case OpUCountTrailingZerosINTEL:
9792 MSL_UFOP(ctz);
9793 break;
9794
9795 case OpAbsISubINTEL:
9796 case OpAbsUSubINTEL:
9797 MSL_BFOP(absdiff);
9798 break;
9799
9800 case OpIAddSatINTEL:
9801 case OpUAddSatINTEL:
9802 MSL_BFOP(addsat);
9803 break;
9804
9805 case OpIAverageINTEL:
9806 case OpUAverageINTEL:
9807 MSL_BFOP(hadd);
9808 break;
9809
9810 case OpIAverageRoundedINTEL:
9811 case OpUAverageRoundedINTEL:
9812 MSL_BFOP(rhadd);
9813 break;
9814
9815 case OpISubSatINTEL:
9816 case OpUSubSatINTEL:
9817 MSL_BFOP(subsat);
9818 break;
9819
9820 case OpIMul32x16INTEL:
9821 {
9822 uint32_t result_type = ops[0];
9823 uint32_t id = ops[1];
9824 uint32_t a = ops[2], b = ops[3];
9825 bool forward = should_forward(id: a) && should_forward(id: b);
9826 emit_op(result_type, result_id: id, rhs: join(ts: "int(short(", ts: to_unpacked_expression(id: a), ts: ")) * int(short(", ts: to_unpacked_expression(id: b), ts: "))"), forward_rhs: forward);
9827 inherit_expression_dependencies(dst: id, source: a);
9828 inherit_expression_dependencies(dst: id, source: b);
9829 break;
9830 }
9831
9832 case OpUMul32x16INTEL:
9833 {
9834 uint32_t result_type = ops[0];
9835 uint32_t id = ops[1];
9836 uint32_t a = ops[2], b = ops[3];
9837 bool forward = should_forward(id: a) && should_forward(id: b);
9838 emit_op(result_type, result_id: id, rhs: join(ts: "uint(ushort(", ts: to_unpacked_expression(id: a), ts: ")) * uint(ushort(", ts: to_unpacked_expression(id: b), ts: "))"), forward_rhs: forward);
9839 inherit_expression_dependencies(dst: id, source: a);
9840 inherit_expression_dependencies(dst: id, source: b);
9841 break;
9842 }
9843
9844 // SPV_EXT_demote_to_helper_invocation
9845 case OpDemoteToHelperInvocationEXT:
9846 if (!msl_options.supports_msl_version(major: 2, minor: 3))
9847 SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
9848 CompilerGLSL::emit_instruction(instr: instruction);
9849 break;
9850
9851 case OpIsHelperInvocationEXT:
9852 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
9853 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
9854 else if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
9855 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
9856 emit_op(result_type: ops[0], result_id: ops[1],
9857 rhs: needs_manual_helper_invocation_updates() ? builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput) :
9858 "simd_is_helper_thread()",
9859 forward_rhs: false);
9860 break;
9861
9862 case OpBeginInvocationInterlockEXT:
9863 case OpEndInvocationInterlockEXT:
9864 if (!msl_options.supports_msl_version(major: 2, minor: 0))
9865 SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
9866 break; // Nothing to do in the body
9867
9868 case OpConvertUToAccelerationStructureKHR:
9869 SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL.");
9870 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
9871 SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL.");
9872
9873 case OpRayQueryInitializeKHR:
9874 {
9875 flush_variable_declaration(id: ops[0]);
9876 register_write(chain: ops[0]);
9877 add_spv_func_and_recompile(spv_func: SPVFuncImplRayQueryIntersectionParams);
9878
9879 statement(ts: to_expression(id: ops[0]), ts: ".reset(", ts: "ray(", ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[6]), ts: ", ",
9880 ts: to_expression(id: ops[5]), ts: ", ", ts: to_expression(id: ops[7]), ts: "), ", ts: to_expression(id: ops[1]), ts: ", ", ts: to_expression(id: ops[3]),
9881 ts: ", spvMakeIntersectionParams(", ts: to_expression(id: ops[2]), ts: "));");
9882 break;
9883 }
9884 case OpRayQueryProceedKHR:
9885 {
9886 flush_variable_declaration(id: ops[0]);
9887 register_write(chain: ops[2]);
9888 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".next()"), forward_rhs: false);
9889 break;
9890 }
9891#define MSL_RAY_QUERY_IS_CANDIDATE get<SPIRConstant>(ops[3]).scalar_i32() == 0
9892
9893#define MSL_RAY_QUERY_GET_OP(op, msl_op) \
9894 case OpRayQueryGet##op##KHR: \
9895 flush_variable_declaration(ops[2]); \
9896 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \
9897 break
9898
9899#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op) \
9900 case OpRayQueryGet##op##KHR: \
9901 flush_variable_declaration(ops[2]); \
9902 if (MSL_RAY_QUERY_IS_CANDIDATE) \
9903 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \
9904 else \
9905 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \
9906 break
9907
9908#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op)
9909#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op)
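// ops[3] is the RayQueryIntersection literal: 0 selects the candidate intersection, 1 the committed one.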
9910
9911 MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance);
9912 MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_origin);
9913 MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_direction);
9914 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id);
9915 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id);
9916 MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord);
9917 MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id);
9918 MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id);
9919 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin);
9920 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction);
9921 MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform);
9922 MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform);
9923 MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing);
9924
9925 case OpRayQueryGetIntersectionTypeKHR:
9926 flush_variable_declaration(id: ops[2]);
9927 if (MSL_RAY_QUERY_IS_CANDIDATE)
9928 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: "uint(", ts: to_expression(id: ops[2]), ts: ".get_candidate_intersection_type()) - 1"),
9929 forward_rhs: false);
9930 else
9931 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: "uint(", ts: to_expression(id: ops[2]), ts: ".get_committed_intersection_type())"), forward_rhs: false);
9932 break;
9933 case OpRayQueryGetIntersectionTKHR:
9934 flush_variable_declaration(id: ops[2]);
9935 if (MSL_RAY_QUERY_IS_CANDIDATE)
9936 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".get_candidate_triangle_distance()"), forward_rhs: false);
9937 else
9938 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".get_committed_distance()"), forward_rhs: false);
9939 break;
9940 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
9941 {
9942 flush_variable_declaration(id: ops[0]);
9943 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".is_candidate_non_opaque_bounding_box()"), forward_rhs: false);
9944 break;
9945 }
9946 case OpRayQueryConfirmIntersectionKHR:
9947 flush_variable_declaration(id: ops[0]);
9948 register_write(chain: ops[0]);
9949 statement(ts: to_expression(id: ops[0]), ts: ".commit_triangle_intersection();");
9950 break;
9951 case OpRayQueryGenerateIntersectionKHR:
9952 flush_variable_declaration(id: ops[0]);
9953 register_write(chain: ops[0]);
9954 statement(ts: to_expression(id: ops[0]), ts: ".commit_bounding_box_intersection(", ts: to_expression(id: ops[1]), ts: ");");
9955 break;
9956 case OpRayQueryTerminateKHR:
9957 flush_variable_declaration(id: ops[0]);
9958 register_write(chain: ops[0]);
9959 statement(ts: to_expression(id: ops[0]), ts: ".abort();");
9960 break;
9961#undef MSL_RAY_QUERY_GET_OP
9962#undef MSL_RAY_QUERY_IS_CANDIDATE
9963#undef MSL_RAY_QUERY_IS_OP2
9964#undef MSL_RAY_QUERY_GET_OP2
9965#undef MSL_RAY_QUERY_OP_INNER2
9966
9967 case OpConvertPtrToU:
9968 case OpConvertUToPtr:
9969 case OpBitcast:
9970 {
9971 auto &type = get<SPIRType>(id: ops[0]);
9972 auto &input_type = expression_type(id: ops[2]);
9973
9974 if (opcode != OpBitcast || type.pointer || input_type.pointer)
9975 {
9976 string op;
9977
9978 if (type.vecsize == 1 && input_type.vecsize == 1)
9979 op = join(ts: "reinterpret_cast<", ts: type_to_glsl(type), ts: ">(", ts: to_unpacked_expression(id: ops[2]), ts: ")");
9980 else if (input_type.vecsize == 2)
9981 op = join(ts: "reinterpret_cast<", ts: type_to_glsl(type), ts: ">(as_type<ulong>(", ts: to_unpacked_expression(id: ops[2]), ts: "))");
9982 else
9983 op = join(ts: "as_type<", ts: type_to_glsl(type), ts: ">(reinterpret_cast<ulong>(", ts: to_unpacked_expression(id: ops[2]), ts: "))");
9984
9985 emit_op(result_type: ops[0], result_id: ops[1], rhs: op, forward_rhs: should_forward(id: ops[2]));
9986 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9987 }
9988 else
9989 CompilerGLSL::emit_instruction(instr: instruction);
9990
9991 break;
9992 }
9993
9994 case OpSDot:
9995 case OpUDot:
9996 case OpSUDot:
9997 {
9998 uint32_t result_type = ops[0];
9999 uint32_t id = ops[1];
10000 uint32_t vec1 = ops[2];
10001 uint32_t vec2 = ops[3];
10002
10003 auto &input_type1 = expression_type(id: vec1);
10004 auto &input_type2 = expression_type(id: vec2);
10005
10006 string vec1input, vec2input;
10007 auto input_size = input_type1.vecsize;
10008 if (instruction.length == 5)
10009 {
10010 if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
10011 {
10012 string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4";
10013 vec1input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec1), ts: ")");
10014 type = opcode == OpSDot ? "char4" : "uchar4";
10015 vec2input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec2), ts: ")");
10016 input_size = 4;
10017 }
10018 else
10019	SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit are not supported for integer dot product.");
10020 }
10021 else
10022 {
10023 // Inputs are sign or zero-extended to their target width.
10024 SPIRType::BaseType vec1_expected_type =
10025 opcode != OpUDot ?
10026 to_signed_basetype(width: input_type1.width) :
10027 to_unsigned_basetype(width: input_type1.width);
10028
10029 SPIRType::BaseType vec2_expected_type =
10030 opcode != OpSDot ?
10031 to_unsigned_basetype(width: input_type2.width) :
10032 to_signed_basetype(width: input_type2.width);
10033
10034 vec1input = bitcast_expression(target_type: vec1_expected_type, arg: vec1);
10035 vec2input = bitcast_expression(target_type: vec2_expected_type, arg: vec2);
10036 }
10037
10038 auto &type = get<SPIRType>(id: result_type);
10039
10040 // We'll get the appropriate sign-extend or zero-extend, no matter which type we cast to here.
10041 // The addition in reduce_add is sign-invariant.
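		// For a packed 4x8-bit OpSDot, for example, the result ends up roughly as
		// "reduce_add(int4(as_type<char4>(a)) * int4(as_type<char4>(b)))" (operand names illustrative).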
10042 auto result_type_cast = join(ts: type_to_glsl(type), ts&: input_size);
10043
10044 string exp = join(ts: "reduce_add(",
10045 ts&: result_type_cast, ts: "(", ts&: vec1input, ts: ") * ",
10046 ts&: result_type_cast, ts: "(", ts&: vec2input, ts: "))");
10047
10048 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: vec1) && should_forward(id: vec2));
10049 inherit_expression_dependencies(dst: id, source: vec1);
10050 inherit_expression_dependencies(dst: id, source: vec2);
10051 break;
10052 }
10053
10054 case OpSDotAccSat:
10055 case OpUDotAccSat:
10056 case OpSUDotAccSat:
10057 {
10058 uint32_t result_type = ops[0];
10059 uint32_t id = ops[1];
10060 uint32_t vec1 = ops[2];
10061 uint32_t vec2 = ops[3];
10062 uint32_t acc = ops[4];
10063
10064 auto input_type1 = expression_type(id: vec1);
10065 auto input_type2 = expression_type(id: vec2);
10066
10067 string vec1input, vec2input;
10068 if (instruction.length == 6)
10069 {
10070 if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
10071 {
10072 string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
10073 vec1input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec1), ts: ")");
10074 type = opcode == OpSDotAccSat ? "char4" : "uchar4";
10075 vec2input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec2), ts: ")");
10076 input_type1.vecsize = 4;
10077 input_type2.vecsize = 4;
10078 }
10079 else
10080	SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit are not supported for integer dot product.");
10081 }
10082 else
10083 {
10084 // Inputs are sign or zero-extended to their target width.
10085 SPIRType::BaseType vec1_expected_type =
10086 opcode != OpUDotAccSat ?
10087 to_signed_basetype(width: input_type1.width) :
10088 to_unsigned_basetype(width: input_type1.width);
10089
10090 SPIRType::BaseType vec2_expected_type =
10091 opcode != OpSDotAccSat ?
10092 to_unsigned_basetype(width: input_type2.width) :
10093 to_signed_basetype(width: input_type2.width);
10094
10095 vec1input = bitcast_expression(target_type: vec1_expected_type, arg: vec1);
10096 vec2input = bitcast_expression(target_type: vec2_expected_type, arg: vec2);
10097 }
10098
10099 auto &type = get<SPIRType>(id: result_type);
10100
10101 SPIRType::BaseType pre_saturate_type =
10102 opcode != OpUDotAccSat ?
10103 to_signed_basetype(width: type.width) :
10104 to_unsigned_basetype(width: type.width);
10105
10106 input_type1.basetype = pre_saturate_type;
10107 input_type2.basetype = pre_saturate_type;
10108
10109 string exp = join(ts: type_to_glsl(type), ts: "(addsat(reduce_add(",
10110 ts: type_to_glsl(type: input_type1), ts: "(", ts&: vec1input, ts: ") * ",
10111 ts: type_to_glsl(type: input_type2), ts: "(", ts&: vec2input, ts: ")), ",
10112 ts: bitcast_expression(target_type: pre_saturate_type, arg: acc), ts: "))");
10113
10114 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: vec1) && should_forward(id: vec2));
10115 inherit_expression_dependencies(dst: id, source: vec1);
10116 inherit_expression_dependencies(dst: id, source: vec2);
10117 break;
10118 }
10119
10120 case OpSetMeshOutputsEXT:
10121 {
10122 flush_variable_declaration(id: builtin_mesh_primitive_indices_id);
10123 add_spv_func_and_recompile(spv_func: SPVFuncImplSetMeshOutputsEXT);
10124 statement(ts: "spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, ", ts: to_unpacked_expression(id: ops[0]), ts: ", ", ts: to_unpacked_expression(id: ops[1]), ts: ");");
10125 break;
10126 }
10127
10128 default:
10129 CompilerGLSL::emit_instruction(instr: instruction);
10130 break;
10131 }
10132
10133 previous_instruction_opcode = opcode;
10134}
10135
10136void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
10137{
10138 if (sparse)
10139 SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
10140
10141 if (msl_options.use_framebuffer_fetch_subpasses)
10142 {
10143 auto *ops = stream(instr: i);
10144
10145 uint32_t result_type_id = ops[0];
10146 uint32_t id = ops[1];
10147 uint32_t img = ops[2];
10148
10149 auto &type = expression_type(id: img);
10150 auto &imgtype = get<SPIRType>(id: type.self);
10151
10152 // Use Metal's native frame-buffer fetch API for subpass inputs.
10153 if (imgtype.image.dim == DimSubpassData)
10154 {
10155 // Subpass inputs cannot be invalidated,
10156 // so just forward the expression directly.
10157 string expr = to_expression(id: img);
10158 emit_op(result_type: result_type_id, result_id: id, rhs: expr, forward_rhs: true);
10159 return;
10160 }
10161 }
10162
10163 // Fallback to default implementation
10164 CompilerGLSL::emit_texture_op(i, sparse);
10165}
10166
10167void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
10168{
10169 auto model = get_execution_model();
10170
10171 if (model != ExecutionModelGLCompute && model != ExecutionModelTaskEXT &&
10172 model != ExecutionModelMeshEXT && !is_tesc_shader())
10173 {
10174 return;
10175 }
10176
10177 uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id: id_exe_scope) : uint32_t(ScopeInvocation);
10178 uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id: id_mem_scope) : uint32_t(ScopeInvocation);
10179 // Use the wider of the two scopes (smaller value)
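	// (SPIR-V Scope values grow narrower as they increase: CrossDevice = 0 through Invocation = 4.)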
10180 exe_scope = min(a: exe_scope, b: mem_scope);
10181
10182 if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
10183 // In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop.
10184 return;
10185
10186 string bar_stmt;
10187 if ((msl_options.is_ios() && msl_options.supports_msl_version(major: 1, minor: 2)) || msl_options.supports_msl_version(major: 2))
10188 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
10189 else
10190 bar_stmt = "threadgroup_barrier";
10191 bar_stmt += "(";
10192
10193 uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id: id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
10194
10195 // Use the | operator to combine flags if we can.
10196 if (msl_options.supports_msl_version(major: 1, minor: 2))
10197 {
10198 string mem_flags = "";
10199 // For tesc shaders, this also affects objects in the Output storage class.
10200 // Since in Metal, these are placed in a device buffer, we have to sync device memory here.
10201 if (is_tesc_shader() ||
10202 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
10203 mem_flags += "mem_flags::mem_device";
10204
10205 // Fix tessellation patch function processing
10206 if (is_tesc_shader() || (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
10207 {
10208 if (!mem_flags.empty())
10209 mem_flags += " | ";
10210 mem_flags += "mem_flags::mem_threadgroup";
10211 }
10212 if (mem_sem & MemorySemanticsImageMemoryMask)
10213 {
10214 if (!mem_flags.empty())
10215 mem_flags += " | ";
10216 mem_flags += "mem_flags::mem_texture";
10217 }
10218
10219 if (mem_flags.empty())
10220 mem_flags = "mem_flags::mem_none";
10221
10222 bar_stmt += mem_flags;
10223 }
10224 else
10225 {
10226 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
10227 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
10228 bar_stmt += "mem_flags::mem_device_and_threadgroup";
10229 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
10230 bar_stmt += "mem_flags::mem_device";
10231 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
10232 bar_stmt += "mem_flags::mem_threadgroup";
10233 else if (mem_sem & MemorySemanticsImageMemoryMask)
10234 bar_stmt += "mem_flags::mem_texture";
10235 else
10236 bar_stmt += "mem_flags::mem_none";
10237 }
10238
10239 bar_stmt += ");";
10240
10241 statement(ts&: bar_stmt);
10242
10243 assert(current_emitting_block);
10244 flush_control_dependent_expressions(block: current_emitting_block->self);
10245 flush_all_active_variables();
10246}
10247
10248static bool storage_class_array_is_thread(StorageClass storage)
10249{
10250 switch (storage)
10251 {
10252 case StorageClassInput:
10253 case StorageClassOutput:
10254 case StorageClassGeneric:
10255 case StorageClassFunction:
10256 case StorageClassPrivate:
10257 return true;
10258
10259 default:
10260 return false;
10261 }
10262}
10263
10264bool CompilerMSL::emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id,
10265 StorageClass lhs_storage, StorageClass rhs_storage)
10266{
10267 // Allow Metal to use the array<T> template to make arrays a value type.
10268 // This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
10269 bool lhs_is_thread_storage = storage_class_array_is_thread(storage: lhs_storage);
10270 bool rhs_is_thread_storage = storage_class_array_is_thread(storage: rhs_storage);
10271
10272 bool lhs_is_array_template = lhs_is_thread_storage || lhs_storage == StorageClassWorkgroup;
10273 bool rhs_is_array_template = rhs_is_thread_storage || rhs_storage == StorageClassWorkgroup;
10274
10275 // Special considerations for stage IO variables.
10276 // If the variable is actually backed by non-user visible device storage, we use array templates for those.
10277 //
10278 // Another special consideration is given to thread local variables which happen to have Offset decorations
10279 // applied to them. Block-like types do not use array templates, so we need to force POD path if we detect
10280 // these scenarios. This check isn't perfect since it would be technically possible to mix and match these things,
10281 // and for a fully correct solution we might have to track array template state through access chains as well,
10282 // but for all reasonable use cases, this should suffice.
10283 // This special case should also only apply to Function/Private storage classes.
10284 // We should not check backing variable for temporaries.
10285 auto *lhs_var = maybe_get_backing_variable(chain: lhs_id);
10286 if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(storage: lhs_var->storage))
10287 lhs_is_array_template = true;
10288 else if (lhs_var && lhs_storage != StorageClassGeneric && type_is_block_like(type: get<SPIRType>(id: lhs_var->basetype)))
10289 lhs_is_array_template = false;
10290
10291 auto *rhs_var = maybe_get_backing_variable(chain: rhs_id);
10292 if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(storage: rhs_var->storage))
10293 rhs_is_array_template = true;
10294 else if (rhs_var && rhs_storage != StorageClassGeneric && type_is_block_like(type: get<SPIRType>(id: rhs_var->basetype)))
10295 rhs_is_array_template = false;
10296
10297 // If threadgroup storage qualifiers are *not* used:
10298	// Avoid the spvCopy* wrapper functions; otherwise, the spvUnsafeArray<> template cannot be used with that storage qualifier.
10299 if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array())
10300 {
10301 // Fall back to normal copy path.
10302 return false;
10303 }
10304 else
10305 {
10306 // Ensure the LHS variable has been declared
10307 if (lhs_var)
10308 flush_variable_declaration(id: lhs_var->self);
10309
10310 string lhs;
10311 if (expr)
10312 lhs = expr;
10313 else
10314 lhs = to_expression(id: lhs_id);
10315
10316 // Assignment from an array initializer is fine.
10317 auto &type = expression_type(id: rhs_id);
10318 auto *var = maybe_get_backing_variable(chain: rhs_id);
10319
10320 // Unfortunately, we cannot template on address space in MSL,
10321 // so explicit address space redirection it is ...
10322 bool is_constant = false;
10323 if (ir.ids[rhs_id].get_type() == TypeConstant)
10324 {
10325 is_constant = true;
10326 }
10327 else if (var && var->remapped_variable && var->statically_assigned &&
10328 ir.ids[var->static_expression].get_type() == TypeConstant)
10329 {
10330 is_constant = true;
10331 }
10332 else if (rhs_storage == StorageClassUniform || rhs_storage == StorageClassUniformConstant)
10333 {
10334 is_constant = true;
10335 }
10336
10337	// When an OpLoad triggers an array copy, we cannot easily detect
10338	// it ahead of time since it's
10339 // context dependent. We might have to force a recompile here
10340 // if this is the only use of array copies in our shader.
10341 add_spv_func_and_recompile(spv_func: type.array.size() > 1 ? SPVFuncImplArrayCopyMultidim : SPVFuncImplArrayCopy);
10342
10343 const char *tag = nullptr;
10344 if (lhs_is_thread_storage && is_constant)
10345 tag = "FromConstantToStack";
10346 else if (lhs_storage == StorageClassWorkgroup && is_constant)
10347 tag = "FromConstantToThreadGroup";
10348 else if (lhs_is_thread_storage && rhs_is_thread_storage)
10349 tag = "FromStackToStack";
10350 else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
10351 tag = "FromStackToThreadGroup";
10352 else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
10353 tag = "FromThreadGroupToStack";
10354 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
10355 tag = "FromThreadGroupToThreadGroup";
10356 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
10357 tag = "FromDeviceToDevice";
10358 else if (lhs_storage == StorageClassStorageBuffer && is_constant)
10359 tag = "FromConstantToDevice";
10360 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
10361 tag = "FromThreadGroupToDevice";
10362 else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
10363 tag = "FromStackToDevice";
10364 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
10365 tag = "FromDeviceToThreadGroup";
10366 else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
10367 tag = "FromDeviceToStack";
10368 else
10369 SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
10370
10371 // Pass internal array of spvUnsafeArray<> into wrapper functions
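	// e.g. this emits a call such as "spvArrayCopyFromConstantToStack(dst.elements, src);" when only
	// the LHS uses the spvUnsafeArray<> template (example names are illustrative).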
10372 if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
10373 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ".elements, ", ts: to_expression(id: rhs_id), ts: ".elements);");
10374	else if (lhs_is_array_template && !msl_options.force_native_arrays)
10375 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ".elements, ", ts: to_expression(id: rhs_id), ts: ");");
10376 else if (rhs_is_array_template && !msl_options.force_native_arrays)
10377 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ", ", ts: to_expression(id: rhs_id), ts: ".elements);");
10378 else
10379 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ", ", ts: to_expression(id: rhs_id), ts: ");");
10380 }
10381
10382 return true;
10383}
10384
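// The per-builtin tessellation factor counts below correspond to Metal's tessellation factor layouts
// (MTLTriangleTessellationFactorsHalf / MTLQuadTessellationFactorsHalf): triangles use 1 inner and
// 3 outer factors, quads use 2 inner and 4 outer.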
10385uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
10386{
10387 if (is_tessellating_triangles())
10388 return builtin == BuiltInTessLevelInner ? 1 : 3;
10389 else
10390 return builtin == BuiltInTessLevelInner ? 2 : 4;
10391}
10392
10393// Since MSL does not allow arrays to be copied via simple variable assignment,
10394// if the LHS and RHS represent an assignment of an entire array, it must be
10395// implemented by calling an array copy function.
10396	// Returns whether the array assignment was emitted.
10397bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
10398{
10399 // We only care about assignments of an entire array
10400 auto &type = expression_type(id: id_lhs);
10401 if (!is_array(type: get_pointee_type(type)))
10402 return false;
10403
10404 auto *var = maybe_get<SPIRVariable>(id: id_lhs);
10405
10406 // Is this a remapped, static constant? Don't do anything.
10407 if (var && var->remapped_variable && var->statically_assigned)
10408 return true;
10409
10410 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
10411 {
10412 // Special case, if we end up declaring a variable when assigning the constant array,
10413 // we can avoid the copy by directly assigning the constant expression.
10414 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
10415 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
10416 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
10417 statement(ts: to_expression(id: id_lhs), ts: " = ", ts: constant_expression(c: get<SPIRConstant>(id: id_rhs)), ts: ";");
10418 return true;
10419 }
10420
10421 if (is_tesc_shader() && has_decoration(id: id_lhs, decoration: DecorationBuiltIn))
10422 {
10423 auto builtin = BuiltIn(get_decoration(id: id_lhs, decoration: DecorationBuiltIn));
10424 // Need to manually unroll the array store.
10425 if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)
10426 {
10427 uint32_t array_size = get_physical_tess_level_array_size(builtin);
10428 if (array_size == 1)
10429 statement(ts: to_expression(id: id_lhs), ts: " = half(", ts: to_expression(id: id_rhs), ts: "[0]);");
10430 else
10431 {
10432 for (uint32_t i = 0; i < array_size; i++)
10433 statement(ts: to_expression(id: id_lhs), ts: "[", ts&: i, ts: "] = half(", ts: to_expression(id: id_rhs), ts: "[", ts&: i, ts: "]);");
10434 }
10435 return true;
10436 }
10437 }
10438
10439 auto lhs_storage = get_expression_effective_storage_class(ptr: id_lhs);
10440 auto rhs_storage = get_expression_effective_storage_class(ptr: id_rhs);
10441 if (!emit_array_copy(expr: nullptr, lhs_id: id_lhs, rhs_id: id_rhs, lhs_storage, rhs_storage))
10442 return false;
10443
10444 register_write(chain: id_lhs);
10445
10446 return true;
10447}
10448
10449// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
10450void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, Op opcode,
10451 uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
10452 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
10453{
10454 string exp;
10455
10456 auto &ptr_type = expression_type(id: obj);
10457 auto &type = get_pointee_type(type: ptr_type);
10458 auto expected_type = type.basetype;
10459 if (opcode == OpAtomicUMax || opcode == OpAtomicUMin)
10460 expected_type = to_unsigned_basetype(width: type.width);
10461 else if (opcode == OpAtomicSMax || opcode == OpAtomicSMin)
10462 expected_type = to_signed_basetype(width: type.width);
10463
10464 bool use_native_image_atomic;
10465 if (msl_options.supports_msl_version(major: 3, minor: 1))
10466 use_native_image_atomic = check_atomic_image(id: obj);
10467 else
10468 use_native_image_atomic = false;
10469
10470 if (type.width == 64)
10471 SPIRV_CROSS_THROW("MSL currently does not support 64-bit atomics.");
10472
10473 auto remapped_type = type;
10474 remapped_type.basetype = expected_type;
10475
10476 auto *var = maybe_get_backing_variable(chain: obj);
10477 const auto *res_type = var ? &get<SPIRType>(id: var->basetype) : nullptr;
10478 assert(type.storage != StorageClassImage || res_type);
10479
10480 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
10481
10482 bool check_discard = opcode != OpAtomicLoad && needs_frag_discard_checks() &&
10483 ptr_type.storage != StorageClassWorkgroup;
10484
10485	// Even compare exchange atomics are vec4 on Metal for ... reasons :v
10486 uint32_t vec4_temporary_id = 0;
10487 if (use_native_image_atomic && is_atomic_compare_exchange_strong)
10488 {
10489 uint32_t &tmp_id = extra_sub_expressions[result_id];
10490 if (!tmp_id)
10491 {
10492 tmp_id = ir.increase_bound_by(count: 2);
10493
10494 auto vec4_type = get<SPIRType>(id: result_type);
10495 vec4_type.vecsize = 4;
10496 set<SPIRType>(id: tmp_id + 1, args&: vec4_type);
10497 }
10498
10499 vec4_temporary_id = tmp_id;
10500 }
10501
10502 if (check_discard)
10503 {
10504 if (is_atomic_compare_exchange_strong)
10505 {
10506 // We're already emitting a CAS loop here; a conditional won't hurt.
10507 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
10508 if (vec4_temporary_id)
10509 emit_uninitialized_temporary_expression(type: vec4_temporary_id + 1, id: vec4_temporary_id);
10510 statement(ts: "if (!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: ")");
10511 begin_scope();
10512 }
10513 else
10514 exp = join(ts: "(!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: " ? ");
10515 }
10516
10517 if (use_native_image_atomic)
10518 {
10519 auto obj_expression = to_expression(id: obj);
10520 auto split_index = obj_expression.find_first_of(c: '@');
10521 bool needs_reinterpret = opcode == OpAtomicUMax || opcode == OpAtomicUMin || opcode == OpAtomicSMax || opcode == OpAtomicSMin;
10522 needs_reinterpret &= type.basetype != expected_type;
10523 SPIRVariable *backing_var = nullptr;
10524
10525	// Register the spvTextureCast helper up front rather than waiting for a later pass to force the recompile.
10526 if (needs_reinterpret && (backing_var = maybe_get_backing_variable(chain: obj)))
10527 add_spv_func_and_recompile(spv_func: SPVFuncImplTextureCast);
10528
10529 // Will only be false if we're in "force recompile later" mode.
10530 if (split_index != string::npos)
10531 {
10532 auto coord = obj_expression.substr(pos: split_index + 1);
10533 auto image_expr = obj_expression.substr(pos: 0, n: split_index);
10534
10535 // Handle problem cases with sign where we need signed min/max on a uint image for example.
10536 // It seems to work to cast the texture type itself, even if it is probably wildly outside of spec,
10537 // but SPIR-V requires this to work.
10538 if (needs_reinterpret && backing_var)
10539 {
10540 assert(spv_function_implementations.count(SPVFuncImplTextureCast) && "Should have been added above");
10541
10542 const auto *backing_type = &get<SPIRType>(id: backing_var->basetype);
10543 while (backing_type->op != OpTypeImage)
10544 backing_type = &get<SPIRType>(id: backing_type->parent_type);
10545
10546 auto img_type = *backing_type;
10547 auto tmp_type = type;
10548 tmp_type.basetype = expected_type;
10549 img_type.image.type = ir.increase_bound_by(count: 1);
10550 set<SPIRType>(id: img_type.image.type, args&: tmp_type);
10551
10552 image_expr = join(ts: "spvTextureCast<", ts: type_to_glsl(type: img_type, id: obj), ts: ">(", ts&: image_expr, ts: ")");
10553 }
10554
10555 exp += join(ts&: image_expr, ts: ".", ts&: op, ts: "(");
10556 if (ptr_type.storage == StorageClassImage && res_type->image.arrayed)
10557 {
10558 switch (res_type->image.dim)
10559 {
10560 case Dim1D:
10561 if (msl_options.texture_1D_as_2D)
10562 exp += join(ts: "uint2(", ts&: coord, ts: ".x, 0), ", ts&: coord, ts: ".y");
10563 else
10564 exp += join(ts&: coord, ts: ".x, ", ts&: coord, ts: ".y");
10565
10566 break;
10567 case Dim2D:
10568 exp += join(ts&: coord, ts: ".xy, ", ts&: coord, ts: ".z");
10569 break;
10570 default:
10571 SPIRV_CROSS_THROW("Cannot do atomics on Cube textures.");
10572 }
10573 }
10574 else if (ptr_type.storage == StorageClassImage && res_type->image.dim == Dim1D && msl_options.texture_1D_as_2D)
10575 exp += join(ts: "uint2(", ts&: coord, ts: ", 0)");
10576 else
10577 exp += coord;
10578 }
10579 else
10580 {
10581 exp += obj_expression;
10582 }
10583 }
10584 else
10585 {
10586 exp += string(op) + "_explicit(";
10587 exp += "(";
10588 // Emulate texture2D atomic operations
10589 if (ptr_type.storage == StorageClassImage)
10590 {
10591 auto &flags = ir.get_decoration_bitset(id: var->self);
10592 if (decoration_flags_signal_volatile(flags))
10593 exp += "volatile ";
10594 exp += "device";
10595 }
10596 else if (var && ptr_type.storage != StorageClassPhysicalStorageBuffer)
10597 {
10598 exp += get_argument_address_space(argument: *var);
10599 }
10600 else
10601 {
10602 // Fallback scenario, could happen for raw pointers.
10603 exp += ptr_type.storage == StorageClassWorkgroup ? "threadgroup" : "device";
10604 }
10605
10606 exp += " atomic_";
10607 // For signed and unsigned min/max, we can signal this through the pointer type.
10608 // There is no other way, since C++ does not have explicit signage for atomics.
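// For example, casting the pointer to "device atomic_int*" rather than "device atomic_uint*"
// is what selects the signed overload of atomic_fetch_min/max_explicit.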
10609 exp += type_to_glsl(type: remapped_type);
10610 exp += "*)";
10611
10612 exp += "&";
10613 exp += to_enclosed_expression(id: obj);
10614 }
10615
10616 if (is_atomic_compare_exchange_strong)
10617 {
10618 assert(strcmp(op, "atomic_compare_exchange_weak") == 0);
10619 assert(op2);
10620 assert(has_mem_order_2);
10621 exp += ", &";
10622 exp += to_name(id: vec4_temporary_id ? vec4_temporary_id : result_id);
10623 exp += ", ";
10624 exp += to_expression(id: op2);
10625
10626 if (!use_native_image_atomic)
10627 {
10628 exp += ", ";
10629 exp += get_memory_order(spv_mem_sem: mem_order_1);
10630 exp += ", ";
10631 exp += get_memory_order(spv_mem_sem: mem_order_2);
10632 }
10633 exp += ")";
10634
10635 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
10636 // The MSL function returns false if the atomic write fails OR the comparison test fails,
10637 // so we must validate that it wasn't the comparison test that failed before continuing
10638 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
10639 // The function updates the comparator value from the memory value, so the additional
10640 // comparison test evaluates the memory value against the expected value.
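// Roughly, the emitted MSL looks like this (illustrative names):
//   do { _tmp = comparator; }
//   while (!atomic_compare_exchange_weak_explicit(ptr, &_tmp, value, ...) && _tmp == comparator);
// so the loop only retries when the failure was a spurious weak-exchange failure.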
10641 if (!check_discard)
10642 {
10643 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
10644 if (vec4_temporary_id)
10645 emit_uninitialized_temporary_expression(type: vec4_temporary_id + 1, id: vec4_temporary_id);
10646 }
10647
10648 statement(ts: "do");
10649 begin_scope();
10650
10651 string scalar_expression;
10652 if (vec4_temporary_id)
10653 scalar_expression = join(ts: to_expression(id: vec4_temporary_id), ts: ".x");
10654 else
10655 scalar_expression = to_expression(id: result_id);
10656
10657 statement(ts&: scalar_expression, ts: " = ", ts: to_expression(id: op1), ts: ";");
10658 end_scope_decl(decl: join(ts: "while (!", ts&: exp, ts: " && ", ts&: scalar_expression, ts: " == ", ts: to_enclosed_expression(id: op1), ts: ")"));
10659 if (vec4_temporary_id)
10660 statement(ts: to_expression(id: result_id), ts: " = ", ts&: scalar_expression, ts: ";");
10661
10662 // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined)
10663 if (check_discard)
10664 {
10665 end_scope();
10666 statement(ts: "else");
10667 begin_scope();
10668 statement(ts: to_expression(id: result_id), ts: " = {};");
10669 end_scope();
10670 }
10671 }
10672 else
10673 {
10674 assert(strcmp(op, "atomic_compare_exchange_weak") != 0);
10675
10676 if (op1)
10677 {
10678 exp += ", ";
10679 if (op1_is_literal)
10680 exp += to_string(val: op1);
10681 else
10682 exp += bitcast_expression(target_type: expected_type, arg: op1);
10683 }
10684
10685 if (op2)
10686 exp += ", " + to_expression(id: op2);
10687
10688 if (!use_native_image_atomic)
10689 {
10690 exp += string(", ") + get_memory_order(spv_mem_sem: mem_order_1);
10691 if (has_mem_order_2)
10692 exp += string(", ") + get_memory_order(spv_mem_sem: mem_order_2);
10693 }
10694
10695 exp += ")";
10696
10697 // Metal image atomics return a vec4, so extract the .x component.
10698 if (use_native_image_atomic)
10699 exp += ".x";
10700
10701 // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined)
10702 if (check_discard)
10703 {
10704 exp += " : ";
10705 if (strcmp(s1: op, s2: "atomic_store") != 0)
10706 exp += join(ts: type_to_glsl(type: get<SPIRType>(id: result_type)), ts: "{}");
10707 else
10708 exp += "((void)0)";
10709 exp += ")";
10710 }
10711
10712 if (expected_type != type.basetype)
10713 exp = bitcast_expression(target_type: type, expr_type: expected_type, expr: exp);
10714
10715 if (strcmp(s1: op, s2: "atomic_store") != 0)
10716 emit_op(result_type, result_id, rhs: exp, forward_rhs: false);
10717 else
10718 statement(ts&: exp, ts: ";");
10719 }
10720
10721 flush_all_atomic_capable_variables();
10722}
10723
10724// Metal only supports relaxed memory order for now
10725const char *CompilerMSL::get_memory_order(uint32_t)
10726{
10727 return "memory_order_relaxed";
10728}
10729
10730// Override for MSL-specific extension syntax instructions.
10731// In some cases, deliberately select either the fast or precise versions of the MSL functions to match Vulkan math precision results.
10732void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
10733{
10734 auto op = static_cast<GLSLstd450>(eop);
10735
10736 // If we need to do implicit bitcasts, make sure we do it with the correct type.
10737 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, arguments: args, length: count);
10738 auto int_type = to_signed_basetype(width: integer_width);
10739 auto uint_type = to_unsigned_basetype(width: integer_width);
10740
10741 op = get_remapped_glsl_op(std450_op: op);
10742
10743 auto &restype = get<SPIRType>(id: result_type);
10744
10745 switch (op)
10746 {
10747 case GLSLstd450Sinh:
10748 if (restype.basetype == SPIRType::Half)
10749 {
10750 // MSL does not have an overload for half. Force-cast back to half.
10751 auto expr = join(ts: "half(fast::sinh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10752 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10753 inherit_expression_dependencies(dst: id, source: args[0]);
10754 }
10755 else
10756 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::sinh");
10757 break;
10758 case GLSLstd450Cosh:
10759 if (restype.basetype == SPIRType::Half)
10760 {
10761 // MSL does not have an overload for half. Force-cast back to half.
10762 auto expr = join(ts: "half(fast::cosh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10763 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10764 inherit_expression_dependencies(dst: id, source: args[0]);
10765 }
10766 else
10767 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::cosh");
10768 break;
10769 case GLSLstd450Tanh:
10770 if (restype.basetype == SPIRType::Half)
10771 {
10772 // MSL does not have an overload for half. Force-cast back to half.
10773 auto expr = join(ts: "half(fast::tanh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10774 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10775 inherit_expression_dependencies(dst: id, source: args[0]);
10776 }
10777 else
10778 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "precise::tanh");
10779 break;
10780 case GLSLstd450Atan2:
10781 if (restype.basetype == SPIRType::Half)
10782 {
10783 // MSL does not have an overload for half. Force-cast back to half.
10784 auto expr = join(ts: "half(fast::atan2(", ts: to_unpacked_expression(id: args[0]), ts: ", ", ts: to_unpacked_expression(id: args[1]), ts: "))");
10785 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10786 inherit_expression_dependencies(dst: id, source: args[0]);
10787 inherit_expression_dependencies(dst: id, source: args[1]);
10788 }
10789 else
10790 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::atan2");
10791 break;
10792 case GLSLstd450InverseSqrt:
10793 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rsqrt");
10794 break;
10795 case GLSLstd450RoundEven:
10796 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rint");
10797 break;
10798
10799 case GLSLstd450FindILsb:
10800 {
10801 // In this template version of findLSB, we return T.
10802 auto basetype = expression_type(id: args[0]).basetype;
10803 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindLSB", input_type: basetype, expected_result_type: basetype);
10804 break;
10805 }
10806
10807 case GLSLstd450FindSMsb:
10808 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindSMSB", input_type: int_type, expected_result_type: int_type);
10809 break;
10810
10811 case GLSLstd450FindUMsb:
10812 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindUMSB", input_type: uint_type, expected_result_type: uint_type);
10813 break;
10814
10815 case GLSLstd450PackSnorm4x8:
10816 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_snorm4x8");
10817 break;
10818 case GLSLstd450PackUnorm4x8:
10819 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_unorm4x8");
10820 break;
10821 case GLSLstd450PackSnorm2x16:
10822 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_snorm2x16");
10823 break;
10824 case GLSLstd450PackUnorm2x16:
10825 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_unorm2x16");
10826 break;
10827
10828 case GLSLstd450PackHalf2x16:
10829 {
10830 auto expr = join(ts: "as_type<uint>(half2(", ts: to_expression(id: args[0]), ts: "))");
10831 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10832 inherit_expression_dependencies(dst: id, source: args[0]);
10833 break;
10834 }
10835
10836 case GLSLstd450UnpackSnorm4x8:
10837 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_snorm4x8_to_float");
10838 break;
10839 case GLSLstd450UnpackUnorm4x8:
10840 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_unorm4x8_to_float");
10841 break;
10842 case GLSLstd450UnpackSnorm2x16:
10843 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_snorm2x16_to_float");
10844 break;
10845 case GLSLstd450UnpackUnorm2x16:
10846 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_unorm2x16_to_float");
10847 break;
10848
10849 case GLSLstd450UnpackHalf2x16:
10850 {
10851 auto expr = join(ts: "float2(as_type<half2>(", ts: to_expression(id: args[0]), ts: "))");
10852 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10853 inherit_expression_dependencies(dst: id, source: args[0]);
10854 break;
10855 }
10856
10857 case GLSLstd450PackDouble2x32:
10858 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
10859 break;
10860 case GLSLstd450UnpackDouble2x32:
10861 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
10862 break;
10863
10864 case GLSLstd450MatrixInverse:
10865 {
10866 auto &mat_type = get<SPIRType>(id: result_type);
10867 switch (mat_type.columns)
10868 {
10869 case 2:
10870 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse2x2");
10871 break;
10872 case 3:
10873 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse3x3");
10874 break;
10875 case 4:
10876 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse4x4");
10877 break;
10878 default:
10879 break;
10880 }
10881 break;
10882 }
10883
10884 case GLSLstd450FMin:
10885 // If the result type isn't float, don't bother calling the specific
10886 // precise::/fast:: version. Metal doesn't have those for half and
10887 // double types.
10888 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10889 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "min");
10890 else
10891 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "fast::min");
10892 break;
10893
10894 case GLSLstd450FMax:
10895 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10896 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "max");
10897 else
10898 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "fast::max");
10899 break;
10900
10901 case GLSLstd450FClamp:
10902 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
10903 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10904 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "clamp");
10905 else
10906 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "fast::clamp");
10907 break;
10908
10909 case GLSLstd450NMin:
10910 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10911 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "min");
10912 else
10913 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::min");
10914 break;
10915
10916 case GLSLstd450NMax:
10917 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10918 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "max");
10919 else
10920 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::max");
10921 break;
10922
10923 case GLSLstd450NClamp:
10924 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
10925 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10926 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "clamp");
10927 else
10928 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "precise::clamp");
10929 break;
10930
10931 case GLSLstd450InterpolateAtCentroid:
10932 {
10933 // We can't just emit the expression normally, because the qualified name contains a call to the default
10934 // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
10935 // the base for the method call.
10936 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10937 string component;
10938 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10939 {
10940 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10941 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10942 if (!c || c->specialization)
10943 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10944 else
10945 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10946 }
10947 emit_op(result_type, result_id: id,
10948 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10949 ts: ".interpolate_at_centroid()", ts&: component),
10950 forward_rhs: should_forward(id: args[0]));
10951 break;
10952 }
10953
10954 case GLSLstd450InterpolateAtSample:
10955 {
10956 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10957 string component;
10958 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10959 {
10960 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10961 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10962 if (!c || c->specialization)
10963 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10964 else
10965 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10966 }
10967 emit_op(result_type, result_id: id,
10968 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10969 ts: ".interpolate_at_sample(", ts: to_expression(id: args[1]), ts: ")", ts&: component),
10970 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10971 break;
10972 }
10973
10974 case GLSLstd450InterpolateAtOffset:
10975 {
10976 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10977 string component;
10978 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10979 {
10980 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10981 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10982 if (!c || c->specialization)
10983 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10984 else
10985 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10986 }
10987 // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
10988 // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
10989 // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
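// For example, an offset of float2(0.0) in SPIR-V is emitted as
// interpolate_at_offset(float2(0.0) + 0.4375), shifting from Metal's corner origin toward the pixel center.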
10990 emit_op(result_type, result_id: id,
10991 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10992 ts: ".interpolate_at_offset(", ts: to_expression(id: args[1]), ts: " + 0.4375)", ts&: component),
10993 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10994 break;
10995 }
10996
10997 case GLSLstd450Distance:
10998 // MSL does not support scalar versions here.
10999 if (expression_type(id: args[0]).vecsize == 1)
11000 {
11001 // Equivalent to length(a - b) -> abs(a - b).
11002 emit_op(result_type, result_id: id,
11003 rhs: join(ts: "abs(", ts: to_enclosed_unpacked_expression(id: args[0]), ts: " - ",
11004 ts: to_enclosed_unpacked_expression(id: args[1]), ts: ")"),
11005 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
11006 inherit_expression_dependencies(dst: id, source: args[0]);
11007 inherit_expression_dependencies(dst: id, source: args[1]);
11008 }
11009 else
11010 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11011 break;
11012
11013 case GLSLstd450Length:
11014 // MSL does not support scalar versions, so use abs().
11015 if (expression_type(id: args[0]).vecsize == 1)
11016 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "abs");
11017 else
11018 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11019 break;
11020
11021 case GLSLstd450Normalize:
11022 {
11023 auto &exp_type = expression_type(id: args[0]);
11024 // MSL does not support scalar versions here. For a scalar, a normalized value is just -1 or 1,
11025 // so sign() does the job.
11026 // MSL also has no fast:: overload of normalize for half2 and half3, so use the plain overload for those.
11027 if (exp_type.vecsize == 1)
11028 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "sign");
11029 else if (exp_type.vecsize <= 3 && exp_type.basetype == SPIRType::Half)
11030 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "normalize");
11031 else
11032 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::normalize");
11033 break;
11034 }
11035 case GLSLstd450Reflect:
11036 if (get<SPIRType>(id: result_type).vecsize == 1)
11037 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "spvReflect");
11038 else
11039 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11040 break;
11041
11042 case GLSLstd450Refract:
11043 if (get<SPIRType>(id: result_type).vecsize == 1)
11044 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvRefract");
11045 else
11046 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11047 break;
11048
11049 case GLSLstd450FaceForward:
11050 if (get<SPIRType>(id: result_type).vecsize == 1)
11051 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvFaceForward");
11052 else
11053 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11054 break;
11055
11056 case GLSLstd450Modf:
11057 case GLSLstd450Frexp:
11058 {
11059 // Special case: if the variable is a scalar access chain, we cannot use it directly and must emit a temporary.
11060 // The same applies when the variable is in a storage class other than thread.
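// For example (illustrative names), frexp(x, v.y) with a swizzled pointer argument is lowered roughly as:
//   int _tmp; _res = frexp(x, _tmp); v.y = _tmp;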
11061 auto *ptr = maybe_get<SPIRExpression>(id: args[1]);
11062 auto &type = expression_type(id: args[1]);
11063
11064 bool is_thread_storage = storage_class_array_is_thread(storage: type.storage);
11065 if (type.storage == StorageClassOutput && capture_output_to_buffer)
11066 is_thread_storage = false;
11067
11068 if (!is_thread_storage ||
11069 (ptr && ptr->access_chain && is_scalar(type: expression_type(id: args[1]))))
11070 {
11071 register_call_out_argument(id: args[1]);
11072 forced_temporaries.insert(x: id);
11073
11074 // Need to create temporaries and copy over to access chain after.
11075 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
11076 uint32_t &tmp_id = extra_sub_expressions[id];
11077 if (!tmp_id)
11078 tmp_id = ir.increase_bound_by(count: 1);
11079
11080 uint32_t tmp_type_id = get_pointee_type_id(type_id: expression_type_id(id: args[1]));
11081 emit_uninitialized_temporary_expression(type: tmp_type_id, id: tmp_id);
11082 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: tmp_id, op: eop == GLSLstd450Modf ? "modf" : "frexp");
11083 statement(ts: to_expression(id: args[1]), ts: " = ", ts: to_expression(id: tmp_id), ts: ";");
11084 }
11085 else
11086 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11087 break;
11088 }
11089
11090 case GLSLstd450Pow:
11091 // powr makes x < 0.0 undefined, just like SPIR-V.
11092 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "powr");
11093 break;
11094
11095 default:
11096 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
11097 break;
11098 }
11099}
11100
11101void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
11102 const uint32_t *args, uint32_t count)
11103{
11104 enum AMDShaderTrinaryMinMax
11105 {
11106 FMin3AMD = 1,
11107 UMin3AMD = 2,
11108 SMin3AMD = 3,
11109 FMax3AMD = 4,
11110 UMax3AMD = 5,
11111 SMax3AMD = 6,
11112 FMid3AMD = 7,
11113 UMid3AMD = 8,
11114 SMid3AMD = 9
11115 };
11116
11117 if (!msl_options.supports_msl_version(major: 2, minor: 1))
11118 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
11119
11120 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
11121
11122 switch (op)
11123 {
11124 case FMid3AMD:
11125 case UMid3AMD:
11126 case SMid3AMD:
11127 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "median3");
11128 break;
11129 default:
11130 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, result_id: id, op: eop, args, count);
11131 break;
11132 }
11133}
11134
11135// Emit a structure declaration for the specified interface variable.
11136void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
11137{
11138 if (ib_var_id)
11139 {
11140 auto &ib_var = get<SPIRVariable>(id: ib_var_id);
11141 auto &ib_type = get_variable_data_type(var: ib_var);
11142 //assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
11143 assert(ib_type.basetype == SPIRType::Struct);
11144 emit_struct(type&: ib_type);
11145 }
11146}
11147
11148// Emits the declaration signature of the specified function.
11149// If this is the entry point function, Metal-specific return value and function arguments are added.
11150void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
11151{
11152 if (func.self != ir.default_entry_point)
11153 add_function_overload(func);
11154
11155 local_variable_names = resource_names;
11156 string decl;
11157
11158 processing_entry_point = func.self == ir.default_entry_point;
11159
11160 // Metal helper functions must be static force-inline; otherwise they cause problems when linked together in a single Metallib.
11161 if (!processing_entry_point)
11162 statement(ts&: force_inline);
11163
11164 auto &type = get<SPIRType>(id: func.return_type);
11165
11166 if (!type.array.empty() && msl_options.force_native_arrays)
11167 {
11168 // We cannot return native arrays in MSL, so "return" through an out variable.
11169 decl += "void";
11170 }
11171 else
11172 {
11173 decl += func_type_decl(type);
11174 }
11175
11176 decl += " ";
11177 decl += to_name(id: func.self);
11178 decl += "(";
11179
11180 if (!type.array.empty() && msl_options.force_native_arrays)
11181 {
11182 // Fake array returns by writing to an out array argument instead.
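// For example, a function returning float4[2] is emitted roughly as
//   void foo(thread float4 (&spvReturnValue)[2], ...)
// and callers pass the destination array as an extra argument.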
11183 decl += "thread ";
11184 decl += type_to_glsl(type);
11185 decl += " (&spvReturnValue)";
11186 decl += type_to_array_glsl(type, variable_id: 0);
11187 if (!func.arguments.empty())
11188 decl += ", ";
11189 }
11190
11191 if (processing_entry_point)
11192 {
11193 if (msl_options.argument_buffers)
11194 decl += entry_point_args_argument_buffer(append_comma: !func.arguments.empty());
11195 else
11196 decl += entry_point_args_classic(append_comma: !func.arguments.empty());
11197
11198 // append entry point args to avoid conflicts in local variable names.
11199 local_variable_names.insert(first: resource_names.begin(), last: resource_names.end());
11200
11201 // If entry point function has variables that require early declaration,
11202 // ensure they each have an empty initializer, creating one if needed.
11203 // This is done at this late stage because the initialization expression
11204 // is cleared after each compilation pass.
11205 for (auto var_id : vars_needing_early_declaration)
11206 {
11207 auto &ed_var = get<SPIRVariable>(id: var_id);
11208 ID &initializer = ed_var.initializer;
11209 if (!initializer)
11210 initializer = ir.increase_bound_by(count: 1);
11211
11212 // Do not override proper initializers.
11213 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
11214 set<SPIRExpression>(id: ed_var.initializer, args: "{}", args&: ed_var.basetype, args: true);
11215 }
11216
11217 // Add the `taskPayloadSharedEXT` variable to the entry-point arguments.
11218 for (auto &v : func.local_variables)
11219 {
11220 auto &var = get<SPIRVariable>(id: v);
11221 if (var.storage != StorageClassTaskPayloadWorkgroupEXT)
11222 continue;
11223
11224 add_local_variable_name(id: v);
11225 SPIRFunction::Parameter arg = {};
11226 arg.id = v;
11227 arg.type = var.basetype;
11228 arg.alias_global_variable = true;
11229 decl += join(ts: ", ", ts: argument_decl(arg), ts: " [[payload]]");
11230 }
11231 }
11232
11233 for (auto &arg : func.arguments)
11234 {
11235 uint32_t name_id = arg.id;
11236
11237 auto *var = maybe_get<SPIRVariable>(id: arg.id);
11238 if (var)
11239 {
11240 // If we need to modify the name of the variable, make sure we modify the original variable.
11241 // Our alias is just a shadow variable.
11242 if (arg.alias_global_variable && var->basevariable)
11243 name_id = var->basevariable;
11244
11245 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
11246 }
11247
11248 add_local_variable_name(id: name_id);
11249
11250 decl += argument_decl(arg);
11251
11252 bool is_dynamic_img_sampler = has_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
11253
11254 auto &arg_type = get<SPIRType>(id: arg.type);
11255 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
11256 {
11257 // Manufacture automatic plane args for multiplanar texture
11258 uint32_t planes = 1;
11259 if (auto *constexpr_sampler = find_constexpr_sampler(id: name_id))
11260 if (constexpr_sampler->ycbcr_conversion_enable)
11261 planes = constexpr_sampler->planes;
11262 for (uint32_t i = 1; i < planes; i++)
11263 decl += join(ts: ", ", ts: argument_decl(arg), ts&: plane_name_suffix, ts&: i);
11264
11265 // Manufacture automatic sampler arg for SampledImage texture
11266 if (arg_type.image.dim != DimBuffer)
11267 {
11268 if (arg_type.array.empty() || (var ? is_var_runtime_size_array(var: *var) : is_runtime_size_array(type: arg_type)))
11269 {
11270 decl += join(ts: ", ", ts: sampler_type(type: arg_type, id: arg.id, member: false), ts: " ", ts: to_sampler_expression(id: name_id));
11271 }
11272 else
11273 {
11274 const char *sampler_address_space =
11275 descriptor_address_space(id: name_id,
11276 storage: StorageClassUniformConstant,
11277 plain_address_space: "thread const");
11278 decl += join(ts: ", ", ts&: sampler_address_space, ts: " ", ts: sampler_type(type: arg_type, id: name_id, member: false), ts: "& ",
11279 ts: to_sampler_expression(id: name_id));
11280 }
11281 }
11282 }
11283
11284 // Manufacture automatic swizzle arg.
11285 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type: arg_type) &&
11286 !is_dynamic_img_sampler)
11287 {
11288 bool arg_is_array = !arg_type.array.empty();
11289 decl += join(ts: ", constant uint", ts: arg_is_array ? "* " : "& ", ts: to_swizzle_expression(id: name_id));
11290 }
11291
11292 if (buffer_requires_array_length(id: name_id))
11293 {
11294 bool arg_is_array = !arg_type.array.empty();
11295 decl += join(ts: ", constant uint", ts: arg_is_array ? "* " : "& ", ts: to_buffer_size_expression(id: name_id));
11296 }
11297
11298 if (&arg != &func.arguments.back())
11299 decl += ", ";
11300 }
11301
11302 decl += ")";
11303 statement(ts&: decl);
11304}
11305
11306static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
11307{
11308 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
11309 // use implicit reconstruction.
11310 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
11311}
11312
11313// Returns the texture sampling function string for the specified image and sampling characteristics.
11314string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
11315{
11316 VariableID img = args.base.img;
11317 const MSLConstexprSampler *constexpr_sampler = nullptr;
11318 bool is_dynamic_img_sampler = false;
11319 if (auto *var = maybe_get_backing_variable(chain: img))
11320 {
11321 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
11322 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
11323 }
11324
11325 // Special-case gather. We have to alter the component being looked up in the swizzle case.
11326 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
11327 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
11328 {
11329 bool is_compare = comparison_ids.count(x: img);
11330 add_spv_func_and_recompile(spv_func: is_compare ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
11331 return is_compare ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
11332 }
11333
11334 // Special-case gather with an array of offsets. We have to lower into 4 separate gathers.
11335 if (args.has_array_offsets && !is_dynamic_img_sampler &&
11336 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
11337 {
11338 bool is_compare = comparison_ids.count(x: img);
11339 add_spv_func_and_recompile(spv_func: is_compare ? SPVFuncImplGatherCompareConstOffsets : SPVFuncImplGatherConstOffsets);
11340 add_spv_func_and_recompile(spv_func: SPVFuncImplForwardArgs);
11341 return is_compare ? "spvGatherCompareConstOffsets" : "spvGatherConstOffsets";
11342 }
11343
11344 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img);
11345
11346 // Texture reference
11347 string fname;
11348 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
11349 {
11350 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
11351 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
11352 // 444 images aren't downsampled, so we don't need to do linear filtering.
11353 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
11354 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
11355 {
11356 if (constexpr_sampler->planes == 2)
11357 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructNearest2Plane);
11358 else
11359 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructNearest3Plane);
11360 fname = "spvChromaReconstructNearest";
11361 }
11362 else // Linear with a downsampled format
11363 {
11364 fname = "spvChromaReconstructLinear";
11365 switch (constexpr_sampler->resolution)
11366 {
11367 case MSL_FORMAT_RESOLUTION_444:
11368 assert(false);
11369 break; // not reached
11370 case MSL_FORMAT_RESOLUTION_422:
11371 switch (constexpr_sampler->x_chroma_offset)
11372 {
11373 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11374 if (constexpr_sampler->planes == 2)
11375 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
11376 else
11377 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
11378 fname += "422CositedEven";
11379 break;
11380 case MSL_CHROMA_LOCATION_MIDPOINT:
11381 if (constexpr_sampler->planes == 2)
11382 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
11383 else
11384 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
11385 fname += "422Midpoint";
11386 break;
11387 default:
11388 SPIRV_CROSS_THROW("Invalid chroma location.");
11389 }
11390 break;
11391 case MSL_FORMAT_RESOLUTION_420:
11392 fname += "420";
11393 switch (constexpr_sampler->x_chroma_offset)
11394 {
11395 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11396 switch (constexpr_sampler->y_chroma_offset)
11397 {
11398 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11399 if (constexpr_sampler->planes == 2)
11400 add_spv_func_and_recompile(
11401 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
11402 else
11403 add_spv_func_and_recompile(
11404 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
11405 fname += "XCositedEvenYCositedEven";
11406 break;
11407 case MSL_CHROMA_LOCATION_MIDPOINT:
11408 if (constexpr_sampler->planes == 2)
11409 add_spv_func_and_recompile(
11410 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
11411 else
11412 add_spv_func_and_recompile(
11413 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
11414 fname += "XCositedEvenYMidpoint";
11415 break;
11416 default:
11417 SPIRV_CROSS_THROW("Invalid Y chroma location.");
11418 }
11419 break;
11420 case MSL_CHROMA_LOCATION_MIDPOINT:
11421 switch (constexpr_sampler->y_chroma_offset)
11422 {
11423 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11424 if (constexpr_sampler->planes == 2)
11425 add_spv_func_and_recompile(
11426 spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
11427 else
11428 add_spv_func_and_recompile(
11429 spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
11430 fname += "XMidpointYCositedEven";
11431 break;
11432 case MSL_CHROMA_LOCATION_MIDPOINT:
11433 if (constexpr_sampler->planes == 2)
11434 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
11435 else
11436 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
11437 fname += "XMidpointYMidpoint";
11438 break;
11439 default:
11440 SPIRV_CROSS_THROW("Invalid Y chroma location.");
11441 }
11442 break;
11443 default:
11444 SPIRV_CROSS_THROW("Invalid X chroma location.");
11445 }
11446 break;
11447 default:
11448 SPIRV_CROSS_THROW("Invalid format resolution.");
11449 }
11450 }
11451 }
11452 else
11453 {
11454 fname = to_expression(id: combined ? combined->image : img) + ".";
11455
11456 // Texture function and sampler
11457 if (args.base.is_fetch)
11458 fname += "read";
11459 else if (args.base.is_gather)
11460 fname += "gather";
11461 else
11462 fname += "sample";
11463
11464 if (args.has_dref)
11465 fname += "_compare";
11466 }
11467
11468 return fname;
11469}
11470
11471string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
11472{
11473 SPIRType t { components > 1 ? OpTypeVector : OpTypeFloat };
11474 t.basetype = SPIRType::Float;
11475 t.vecsize = components;
11476 t.columns = 1;
11477 return join(ts: type_to_glsl_constructor(type: t), ts: "(", ts: expr, ts: ")");
11478}
11479
11480static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
11481{
11482 // Double is not supported to begin with, but it doesn't hurt to check for completeness.
11483 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
11484}
11485
11486// Returns the function args for a texture sampling function for the specified image and sampling characteristics.
11487string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
11488{
11489 VariableID img = args.base.img;
11490 auto &imgtype = *args.base.imgtype;
11491 uint32_t lod = args.lod;
11492 uint32_t grad_x = args.grad_x;
11493 uint32_t grad_y = args.grad_y;
11494 uint32_t bias = args.bias;
11495
11496 const MSLConstexprSampler *constexpr_sampler = nullptr;
11497 bool is_dynamic_img_sampler = false;
11498 if (auto *var = maybe_get_backing_variable(chain: img))
11499 {
11500 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
11501 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
11502 }
11503
11504 string farg_str;
11505 bool forward = true;
11506
11507 if (!is_dynamic_img_sampler)
11508 {
11509 // Texture reference (for some cases)
11510 if (needs_chroma_reconstruction(constexpr_sampler))
11511 {
11512 // Multiplanar images need two or three textures.
11513 farg_str += to_expression(id: img);
11514 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
11515 farg_str += join(ts: ", ", ts: to_expression(id: img), ts&: plane_name_suffix, ts&: i);
11516 }
11517 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
11518 msl_options.swizzle_texture_samples && args.base.is_gather)
11519 {
11520 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img);
11521 farg_str += to_expression(id: combined ? combined->image : img);
11522 }
11523
11524 // Gathers with constant offsets call a special function, so include the texture.
11525 if (args.has_array_offsets)
11526 farg_str += to_expression(id: img);
11527
11528 // Sampler reference
11529 if (!args.base.is_fetch)
11530 {
11531 if (!farg_str.empty())
11532 farg_str += ", ";
11533 farg_str += to_sampler_expression(id: img);
11534 }
11535
11536 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
11537 msl_options.swizzle_texture_samples && args.base.is_gather)
11538 {
11539 // Add the swizzle constant from the swizzle buffer.
11540 farg_str += ", " + to_swizzle_expression(id: img);
11541 used_swizzle_buffer = true;
11542 }
11543
11544 // Const offsets gather puts the const offsets before the other args.
11545 if (args.has_array_offsets)
11546 {
11547 forward = forward && should_forward(id: args.offset);
11548 farg_str += ", " + to_unpacked_expression(id: args.offset);
11549 }
11550
11551 // Const offsets gather or swizzled gather puts the component before the other args.
11552 if (args.component && (args.has_array_offsets || msl_options.swizzle_texture_samples))
11553 {
11554 forward = forward && should_forward(id: args.component);
11555 farg_str += ", " + to_component_argument(id: args.component);
11556 }
11557 }
11558
11559 // Texture coordinates
11560 forward = forward && should_forward(id: args.coord);
11561 auto coord_expr = to_enclosed_unpacked_expression(id: args.coord);
11562 auto &coord_type = expression_type(id: args.coord);
11563 bool coord_is_fp = type_is_floating_point(type: coord_type);
11564 bool is_cube_fetch = false;
11565
11566 string tex_coords = coord_expr;
11567 uint32_t alt_coord_component = 0;
11568
11569 switch (imgtype.image.dim)
11570 {
11571
11572 case Dim1D:
11573 if (coord_type.vecsize > 1)
11574 tex_coords = enclose_expression(expr: tex_coords) + ".x";
11575
11576 if (args.base.is_fetch)
11577 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11578 else if (sampling_type_needs_f32_conversion(type: coord_type))
11579 tex_coords = convert_to_f32(expr: tex_coords, components: 1);
11580
11581 if (msl_options.texture_1D_as_2D)
11582 {
11583 if (args.base.is_fetch)
11584 tex_coords = "uint2(" + tex_coords + ", 0)";
11585 else
11586 tex_coords = "float2(" + tex_coords + ", 0.5)";
11587 }
11588
11589 alt_coord_component = 1;
11590 break;
11591
11592 case DimBuffer:
11593 if (coord_type.vecsize > 1)
11594 tex_coords = enclose_expression(expr: tex_coords) + ".x";
11595
11596 if (msl_options.texture_buffer_native)
11597 {
11598 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11599 }
11600 else
11601 {
11602 // Metal texel buffer textures are 2D, so convert the 1D coord to 2D.
11603 // (Native support via Metal 2.1's texture_buffer type is handled by the branch above.)
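// For example, a 1D texel index c becomes spvTexelBufferCoord(c), a helper (emitted elsewhere)
// that maps the linear index to a uint2 coordinate within the 2D texture backing the buffer.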
11604 if (args.base.is_fetch)
11605 {
11606 if (msl_options.texel_buffer_texture_width > 0)
11607 {
11608 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11609 }
11610 else
11611 {
11612 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
11613 to_expression(id: img) + ")";
11614 }
11615 }
11616 }
11617
11618 alt_coord_component = 1;
11619 break;
11620
11621 case DimSubpassData:
11622 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
11623 // this path will not be hit.
11624 tex_coords = "uint2(gl_FragCoord.xy)";
11625 alt_coord_component = 2;
11626 break;
11627
11628 case Dim2D:
11629 if (coord_type.vecsize > 2)
11630 tex_coords = enclose_expression(expr: tex_coords) + ".xy";
11631
11632 if (args.base.is_fetch)
11633 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11634 else if (sampling_type_needs_f32_conversion(type: coord_type))
11635 tex_coords = convert_to_f32(expr: tex_coords, components: 2);
11636
11637 alt_coord_component = 2;
11638 break;
11639
11640 case Dim3D:
11641 if (coord_type.vecsize > 3)
11642 tex_coords = enclose_expression(expr: tex_coords) + ".xyz";
11643
11644 if (args.base.is_fetch)
11645 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11646 else if (sampling_type_needs_f32_conversion(type: coord_type))
11647 tex_coords = convert_to_f32(expr: tex_coords, components: 3);
11648
11649 alt_coord_component = 3;
11650 break;
11651
11652 case DimCube:
11653 if (args.base.is_fetch)
11654 {
11655 is_cube_fetch = true;
11656 tex_coords += ".xy";
11657 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11658 }
11659 else
11660 {
11661 if (coord_type.vecsize > 3)
11662 tex_coords = enclose_expression(expr: tex_coords) + ".xyz";
11663 }
11664
11665 if (sampling_type_needs_f32_conversion(type: coord_type))
11666 tex_coords = convert_to_f32(expr: tex_coords, components: 3);
11667
11668 alt_coord_component = 3;
11669 break;
11670
11671 default:
11672 break;
11673 }
11674
11675 if (args.base.is_fetch && args.offset)
11676 {
11677 // Fetch offsets must be applied directly to the coordinate.
11678 forward = forward && should_forward(id: args.offset);
11679 auto &type = expression_type(id: args.offset);
11680 if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
11681 {
11682 if (type.basetype != SPIRType::UInt)
11683 tex_coords += join(ts: " + uint2(", ts: bitcast_expression(target_type: SPIRType::UInt, arg: args.offset), ts: ", 0)");
11684 else
11685 tex_coords += join(ts: " + uint2(", ts: to_enclosed_unpacked_expression(id: args.offset), ts: ", 0)");
11686 }
11687 else
11688 {
11689 if (type.basetype != SPIRType::UInt)
11690 tex_coords += " + " + bitcast_expression(target_type: SPIRType::UInt, arg: args.offset);
11691 else
11692 tex_coords += " + " + to_enclosed_unpacked_expression(id: args.offset);
11693 }
11694 }
11695
11696 // If projection, use alt coord as divisor
11697 if (args.base.is_proj)
11698 {
11699 if (sampling_type_needs_f32_conversion(type: coord_type))
11700 tex_coords += " / " + convert_to_f32(expr: to_extract_component_expression(id: args.coord, index: alt_coord_component), components: 1);
11701 else
11702 tex_coords += " / " + to_extract_component_expression(id: args.coord, index: alt_coord_component);
11703 }
11704
11705 if (!farg_str.empty())
11706 farg_str += ", ";
11707
11708 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
11709 {
11710 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
11711
11712 if (is_cube_fetch)
11713 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ")";
11714 else
11715 farg_str +=
11716 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
11717 round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: alt_coord_component), coord_is_fp) +
11718 ") * 6u)";
11719
11720 add_spv_func_and_recompile(spv_func: SPVFuncImplCubemapTo2DArrayFace);
11721 }
11722 else
11723 {
11724 farg_str += tex_coords;
11725
11726 // If fetch from cube, add face explicitly
11727 if (is_cube_fetch)
11728 {
11729 // Special case for cube arrays, face and layer are packed in one dimension.
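// For example, a packed index n decodes as face = uint(n) % 6u and layer = uint(n) / 6u,
// matching the two extractions below and in the arrayed case.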
11730 if (imgtype.image.arrayed)
11731 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ") % 6u";
11732 else
11733 farg_str +=
11734 ", uint(" + round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: 2), coord_is_fp) + ")";
11735 }
11736
11737 // If array, use alt coord
11738 if (imgtype.image.arrayed)
11739 {
11740 // Special case for cube arrays, face and layer are packed in one dimension.
11741 if (imgtype.image.dim == DimCube && args.base.is_fetch)
11742 {
11743 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ") / 6u";
11744 }
11745 else
11746 {
11747 farg_str +=
11748 ", uint(" +
11749 round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: alt_coord_component), coord_is_fp) +
11750 ")";
11751 if (imgtype.image.dim == DimSubpassData)
11752 {
11753 if (msl_options.multiview)
11754 farg_str += " + gl_ViewIndex";
11755 else if (msl_options.arrayed_subpass_input)
11756 farg_str += " + gl_Layer";
11757 }
11758 }
11759 }
11760 else if (imgtype.image.dim == DimSubpassData)
11761 {
11762 if (msl_options.multiview)
11763 farg_str += ", gl_ViewIndex";
11764 else if (msl_options.arrayed_subpass_input)
11765 farg_str += ", gl_Layer";
11766 }
11767 }
11768
11769 // Depth compare reference value
11770 if (args.dref)
11771 {
11772 forward = forward && should_forward(id: args.dref);
11773 farg_str += ", ";
11774
11775 auto &dref_type = expression_type(id: args.dref);
11776
11777 string dref_expr;
11778 if (args.base.is_proj)
11779 dref_expr = join(ts: to_enclosed_unpacked_expression(id: args.dref), ts: " / ",
11780 ts: to_extract_component_expression(id: args.coord, index: alt_coord_component));
11781 else
11782 dref_expr = to_unpacked_expression(id: args.dref);
11783
11784 if (sampling_type_needs_f32_conversion(type: dref_type))
11785 dref_expr = convert_to_f32(expr: dref_expr, components: 1);
11786
11787 farg_str += dref_expr;
11788
11789 if (msl_options.is_macos() && (grad_x || grad_y))
11790 {
11791 // For sample_compare, MSL does not support gradient2d on all targets (only iOS, according to the docs).
11792 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
11793 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
11794 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
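// For example (illustrative GLSL), textureGrad(shadowArray, P, vec2(0.0), vec2(0.0))
// ends up emitted with a level(0) qualifier instead of a gradient2d(...) argument.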
11795 bool constant_zero_x = !grad_x || expression_is_constant_null(id: grad_x);
11796 bool constant_zero_y = !grad_y || expression_is_constant_null(id: grad_y);
11797 if (constant_zero_x && constant_zero_y &&
11798 (!imgtype.image.arrayed || !msl_options.sample_dref_lod_array_as_grad))
11799 {
11800 lod = 0;
11801 grad_x = 0;
11802 grad_y = 0;
11803 farg_str += ", level(0)";
11804 }
11805 else if (!msl_options.supports_msl_version(major: 2, minor: 3))
11806 {
11807 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
11808 "supported on macOS prior to MSL 2.3.");
11809 }
11810 }
11811
11812 if (msl_options.is_macos() && bias)
11813 {
11814 // Bias is likewise not supported with sample_compare on macOS.
11815 // Verify it is compile-time zero, and drop the argument.
11816 if (expression_is_constant_null(id: bias))
11817 {
11818 bias = 0;
11819 }
11820 else if (!msl_options.supports_msl_version(major: 2, minor: 3))
11821 {
11822 SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
11823 "on macOS prior to MSL 2.3.");
11824 }
11825 }
11826 }
11827
11828 // LOD Options
11829 // Metal does not support LOD for 1D textures.
11830 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11831 {
11832 forward = forward && should_forward(id: bias);
11833 farg_str += ", bias(" + to_unpacked_expression(id: bias) + ")";
11834 }
11835
11836 // Metal does not support LOD for 1D textures.
11837 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11838 {
11839 forward = forward && should_forward(id: lod);
11840 if (args.base.is_fetch)
11841 {
11842 farg_str += ", " + to_unpacked_expression(id: lod);
11843 }
11844 else if (msl_options.sample_dref_lod_array_as_grad && args.dref && imgtype.image.arrayed)
11845 {
11846 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 3))
11847 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
11848 "supported on macOS prior to MSL 2.3.");
11849 // Some Metal devices have a bug where the LoD is erroneously biased upward
11850 // when using a level() argument. Since this doesn't happen as much with gradient2d(),
11851 // if we perform the LoD calculation in reverse, we can pass a gradient
11852 // instead.
11853 // lod = log2(rhoMax/eta) -> exp2(lod) = rhoMax/eta
11854 // If we make all of the scale factors the same, eta will be 1 and
11855 // exp2(lod) = rho.
11856 // rhoX = dP/dx * extent; rhoY = dP/dy * extent
11857 // Therefore, dP/dx = dP/dy = exp2(lod)/extent.
11858 // (Subtracting 0.5 before exponentiation gives better results.)
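// For example, for a 2D array texture this emits arguments roughly like (illustrative):
//   gradient2d(exp2(lod - 0.5) / float2(w, h), exp2(lod - 0.5) / float2(w, h))
// where w and h come from get_width()/get_height() on the texture.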
11859 string grad_opt, extent, grad_coord;
11860 VariableID base_img = img;
11861 if (auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img))
11862 base_img = combined->image;
11863 switch (imgtype.image.dim)
11864 {
11865 case Dim1D:
11866 grad_opt = "gradient2d";
11867 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width(), 1.0)");
11868 break;
11869 case Dim2D:
11870 grad_opt = "gradient2d";
11871 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width(), ", ts: to_expression(id: base_img), ts: ".get_height())");
11872 break;
11873 case DimCube:
11874 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
11875 {
11876 grad_opt = "gradient2d";
11877 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width())");
11878 }
11879 else
11880 {
11881 if (msl_options.agx_manual_cube_grad_fixup)
11882 {
11883 add_spv_func_and_recompile(spv_func: SPVFuncImplGradientCube);
11884 grad_opt = "spvGradientCube";
11885 grad_coord = tex_coords + ", ";
11886 }
11887 else
11888 {
11889 grad_opt = "gradientcube";
11890 }
11891 extent = join(ts: "float3(", ts: to_expression(id: base_img), ts: ".get_width())");
11892 }
11893 break;
11894 default:
11895 grad_opt = "unsupported_gradient_dimension";
11896 extent = "float3(1.0)";
11897 break;
11898 }
11899 farg_str += join(ts: ", ", ts&: grad_opt, ts: "(", ts&: grad_coord, ts: "exp2(", ts: to_unpacked_expression(id: lod), ts: " - 0.5) / ", ts&: extent,
11900 ts: ", exp2(", ts: to_unpacked_expression(id: lod), ts: " - 0.5) / ", ts&: extent, ts: ")");
11901 }
11902 else
11903 {
11904 farg_str += ", level(" + to_unpacked_expression(id: lod) + ")";
11905 }
11906 }
11907 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
11908 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
11909 {
11910 // The Lod argument is optional in OpImageFetch, but we require an LOD value, so pick 0 as the default.
11911 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
11912 farg_str += ", 0";
11913 }
11914
11915 // Metal does not support LOD for 1D textures.
11916 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11917 {
11918 forward = forward && should_forward(id: grad_x);
11919 forward = forward && should_forward(id: grad_y);
11920 string grad_opt, grad_coord;
11921 switch (imgtype.image.dim)
11922 {
11923 case Dim1D:
11924 case Dim2D:
11925 grad_opt = "gradient2d";
11926 break;
11927 case Dim3D:
11928 grad_opt = "gradient3d";
11929 break;
11930 case DimCube:
11931 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
11932 {
11933 grad_opt = "gradient2d";
11934 }
11935 else if (msl_options.agx_manual_cube_grad_fixup)
11936 {
11937 add_spv_func_and_recompile(spv_func: SPVFuncImplGradientCube);
11938 grad_opt = "spvGradientCube";
11939 grad_coord = tex_coords + ", ";
11940 }
11941 else
11942 {
11943 grad_opt = "gradientcube";
11944 }
11945 break;
11946 default:
11947 grad_opt = "unsupported_gradient_dimension";
11948 break;
11949 }
11950 farg_str += join(ts: ", ", ts&: grad_opt, ts: "(", ts&: grad_coord, ts: to_unpacked_expression(id: grad_x), ts: ", ", ts: to_unpacked_expression(id: grad_y), ts: ")");
11951 }
11952
11953 if (args.min_lod)
11954 {
11955 if (!msl_options.supports_msl_version(major: 2, minor: 2))
11956 SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2 and up.");
11957
11958 forward = forward && should_forward(id: args.min_lod);
11959 farg_str += ", min_lod_clamp(" + to_unpacked_expression(id: args.min_lod) + ")";
11960 }
11961
11962 // Add offsets
11963 string offset_expr;
11964 const SPIRType *offset_type = nullptr;
11965 if (args.offset && !args.base.is_fetch && !args.has_array_offsets)
11966 {
11967 forward = forward && should_forward(id: args.offset);
11968 offset_expr = to_unpacked_expression(id: args.offset);
11969 offset_type = &expression_type(id: args.offset);
11970 }
11971
11972 if (!offset_expr.empty())
11973 {
11974 switch (imgtype.image.dim)
11975 {
11976 case Dim1D:
11977 if (!msl_options.texture_1D_as_2D)
11978 break;
11979 if (offset_type->vecsize > 1)
11980 offset_expr = enclose_expression(expr: offset_expr) + ".x";
11981
11982 farg_str += join(ts: ", int2(", ts&: offset_expr, ts: ", 0)");
11983 break;
11984
11985 case Dim2D:
11986 if (offset_type->vecsize > 2)
11987 offset_expr = enclose_expression(expr: offset_expr) + ".xy";
11988
11989 farg_str += ", " + offset_expr;
11990 break;
11991
11992 case Dim3D:
11993 if (offset_type->vecsize > 3)
11994 offset_expr = enclose_expression(expr: offset_expr) + ".xyz";
11995
11996 farg_str += ", " + offset_expr;
11997 break;
11998
11999 default:
12000 break;
12001 }
12002 }
12003
12004 if (args.component && !args.has_array_offsets)
12005 {
12006 // If a 2D gather has a component argument, ensure it also has an offset argument.
12007 if (imgtype.image.dim == Dim2D && offset_expr.empty())
12008 farg_str += ", int2(0)";
12009
12010 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
12011 {
12012 forward = forward && should_forward(id: args.component);
12013
12014 uint32_t image_var = 0;
12015 if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img))
12016 {
12017 if (const auto *img_var = maybe_get_backing_variable(chain: combined->image))
12018 image_var = img_var->self;
12019 }
12020 else if (const auto *var = maybe_get_backing_variable(chain: img))
12021 {
12022 image_var = var->self;
12023 }
12024
12025 if (image_var == 0 || !is_depth_image(type: expression_type(id: image_var), id: image_var))
12026 farg_str += ", " + to_component_argument(id: args.component);
12027 }
12028 }
12029
12030 if (args.sample)
12031 {
12032 forward = forward && should_forward(id: args.sample);
12033 farg_str += ", ";
12034 farg_str += to_unpacked_expression(id: args.sample);
12035 }
12036
12037 *p_forward = forward;
12038
12039 return farg_str;
12040}
12041
12042 // If the texture coordinates are floating point, invokes the MSL rint() function to round them.
12043string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
12044{
12045 return coord_is_fp ? ("rint(" + tex_coords + ")") : tex_coords;
12046}
12047
12048// Returns a string to use in an image sampling function argument.
12049// The ID must be a scalar constant.
12050string CompilerMSL::to_component_argument(uint32_t id)
12051{
12052 uint32_t component_index = evaluate_constant_u32(id);
12053 switch (component_index)
12054 {
12055 case 0:
12056 return "component::x";
12057 case 1:
12058 return "component::y";
12059 case 2:
12060 return "component::z";
12061 case 3:
12062 return "component::w";
12063
12064 default:
12065 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
12066 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
12067 }
12068}
12069
12070// Establish sampled image as expression object and assign the sampler to it.
12071void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
12072{
12073 set<SPIRCombinedImageSampler>(id: result_id, args&: result_type, args&: image_id, args&: samp_id);
12074}
12075
12076string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
12077 SmallVector<uint32_t> &inherited_expressions)
12078{
12079 auto *ops = stream(instr: i);
12080 uint32_t result_type_id = ops[0];
12081 uint32_t img = ops[2];
12082 auto &result_type = get<SPIRType>(id: result_type_id);
12083 auto op = static_cast<Op>(i.op);
12084 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
12085
12086 // Bypass pointers because we need the real image struct
12087 auto &type = expression_type(id: img);
12088 auto &imgtype = get<SPIRType>(id: type.self);
12089
12090 const MSLConstexprSampler *constexpr_sampler = nullptr;
12091 bool is_dynamic_img_sampler = false;
12092 if (auto *var = maybe_get_backing_variable(chain: img))
12093 {
12094 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
12095 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
12096 }
12097
12098 string expr;
12099 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
12100 {
12101 // If this needs sampler Y'CbCr conversion, we need to do some additional
12102 // processing.
12103 switch (constexpr_sampler->ycbcr_model)
12104 {
12105 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
12106 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
12107 // Default
12108 break;
12109 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
12110 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT709);
12111 expr += "spvConvertYCbCrBT709(";
12112 break;
12113 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
12114 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT601);
12115 expr += "spvConvertYCbCrBT601(";
12116 break;
12117 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
12118 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT2020);
12119 expr += "spvConvertYCbCrBT2020(";
12120 break;
12121 default:
12122 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
12123 }
12124
12125 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
12126 {
12127 switch (constexpr_sampler->ycbcr_range)
12128 {
12129 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
12130 add_spv_func_and_recompile(spv_func: SPVFuncImplExpandITUFullRange);
12131 expr += "spvExpandITUFullRange(";
12132 break;
12133 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
12134 add_spv_func_and_recompile(spv_func: SPVFuncImplExpandITUNarrowRange);
12135 expr += "spvExpandITUNarrowRange(";
12136 break;
12137 default:
12138 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
12139 }
12140 }
12141 }
12142 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(type: imgtype) &&
12143 !is_dynamic_img_sampler)
12144 {
12145 add_spv_func_and_recompile(spv_func: SPVFuncImplTextureSwizzle);
12146 expr += "spvTextureSwizzle(";
12147 }
12148
12149 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
12150
12151 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
12152 {
12153 if (!constexpr_sampler->swizzle_is_identity())
12154 {
12155 static const char swizzle_names[] = "rgba";
12156 if (!constexpr_sampler->swizzle_has_one_or_zero())
12157 {
12158 // If we can, do it inline.
12159 expr += inner_expr + ".";
12160 for (uint32_t c = 0; c < 4; c++)
12161 {
12162 switch (constexpr_sampler->swizzle[c])
12163 {
12164 case MSL_COMPONENT_SWIZZLE_IDENTITY:
12165 expr += swizzle_names[c];
12166 break;
12167 case MSL_COMPONENT_SWIZZLE_R:
12168 case MSL_COMPONENT_SWIZZLE_G:
12169 case MSL_COMPONENT_SWIZZLE_B:
12170 case MSL_COMPONENT_SWIZZLE_A:
12171 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
12172 break;
12173 default:
12174 SPIRV_CROSS_THROW("Invalid component swizzle.");
12175 }
12176 }
12177 }
12178 else
12179 {
12180 // Otherwise, we need to emit a temporary and swizzle that.
12181 uint32_t temp_id = ir.increase_bound_by(count: 1);
12182 emit_op(result_type: result_type_id, result_id: temp_id, rhs: inner_expr, forward_rhs: false);
12183 for (auto &inherit : inherited_expressions)
12184 inherit_expression_dependencies(dst: temp_id, source: inherit);
12185 inherited_expressions.clear();
12186 inherited_expressions.push_back(t: temp_id);
12187
12188 switch (op)
12189 {
12190 case OpImageSampleDrefImplicitLod:
12191 case OpImageSampleImplicitLod:
12192 case OpImageSampleProjImplicitLod:
12193 case OpImageSampleProjDrefImplicitLod:
12194 register_control_dependent_expression(expr: temp_id);
12195 break;
12196
12197 default:
12198 break;
12199 }
12200 expr += type_to_glsl(type: result_type) + "(";
12201 for (uint32_t c = 0; c < 4; c++)
12202 {
12203 switch (constexpr_sampler->swizzle[c])
12204 {
12205 case MSL_COMPONENT_SWIZZLE_IDENTITY:
12206 expr += to_expression(id: temp_id) + "." + swizzle_names[c];
12207 break;
12208 case MSL_COMPONENT_SWIZZLE_ZERO:
12209 expr += "0";
12210 break;
12211 case MSL_COMPONENT_SWIZZLE_ONE:
12212 expr += "1";
12213 break;
12214 case MSL_COMPONENT_SWIZZLE_R:
12215 case MSL_COMPONENT_SWIZZLE_G:
12216 case MSL_COMPONENT_SWIZZLE_B:
12217 case MSL_COMPONENT_SWIZZLE_A:
12218 expr += to_expression(id: temp_id) + "." +
12219 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
12220 break;
12221 default:
12222 SPIRV_CROSS_THROW("Invalid component swizzle.");
12223 }
12224 if (c < 3)
12225 expr += ", ";
12226 }
12227 expr += ")";
12228 }
12229 }
12230 else
12231 expr += inner_expr;
12232 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
12233 {
12234 expr += join(ts: ", ", ts: constexpr_sampler->bpc, ts: ")");
12235 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
12236 expr += ")";
12237 }
12238 }
12239 else
12240 {
12241 expr += inner_expr;
12242 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(type: imgtype) &&
12243 !is_dynamic_img_sampler)
12244 {
12245 // Add the swizzle constant from the swizzle buffer.
12246 expr += ", " + to_swizzle_expression(id: img) + ")";
12247 used_swizzle_buffer = true;
12248 }
12249 }
12250
12251 return expr;
12252}
12253
12254static string create_swizzle(MSLComponentSwizzle swizzle)
12255{
12256 switch (swizzle)
12257 {
12258 case MSL_COMPONENT_SWIZZLE_IDENTITY:
12259 return "spvSwizzle::none";
12260 case MSL_COMPONENT_SWIZZLE_ZERO:
12261 return "spvSwizzle::zero";
12262 case MSL_COMPONENT_SWIZZLE_ONE:
12263 return "spvSwizzle::one";
12264 case MSL_COMPONENT_SWIZZLE_R:
12265 return "spvSwizzle::red";
12266 case MSL_COMPONENT_SWIZZLE_G:
12267 return "spvSwizzle::green";
12268 case MSL_COMPONENT_SWIZZLE_B:
12269 return "spvSwizzle::blue";
12270 case MSL_COMPONENT_SWIZZLE_A:
12271 return "spvSwizzle::alpha";
12272 default:
12273 SPIRV_CROSS_THROW("Invalid component swizzle.");
12274 }
12275}
12276
12277// Returns a string representation of the ID, usable as a function arg.
12278// Manufacture automatic sampler arg for SampledImage texture.
12279string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
12280{
12281 string arg_str;
12282
12283 auto &type = expression_type(id);
12284 bool is_dynamic_img_sampler = has_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
12285 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
12286 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, decoration: SPIRVCrossDecorationDynamicImageSampler);
12287 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
12288 arg_str = join(ts: "spvDynamicImageSampler<", ts: type_to_glsl(type: get<SPIRType>(id: type.image.type)), ts: ">(");
12289
12290 auto *c = maybe_get<SPIRConstant>(id);
12291 if (msl_options.force_native_arrays && c && !get<SPIRType>(id: c->constant_type).array.empty())
12292 {
12293 // If we are passing a constant array directly to a function for some reason,
12294 // the callee will expect an argument in thread const address space
12295 // (since we can only bind to arrays with references in MSL).
12296 // To resolve this, we must emit a copy in this address space.
12297 // This kind of code gen should be rare enough that performance is not a real concern.
12298 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
12299 //
12300 // We risk calling this inside a continue block (invalid code),
12301 // so just create a thread local copy in the current function.
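	// Illustrative sketch: for a hypothetical constant array with SPIR-V ID 42, the call site
	// simply passes "_42_array_copy"; the copy itself is declared on the stack of the calling
	// function via constant_arrays_needed_on_stack, triggering a recompile the first time.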
12302 arg_str = join(ts: "_", ts&: id, ts: "_array_copy");
12303 auto &constants = current_function->constant_arrays_needed_on_stack;
12304 auto itr = find(first: begin(cont&: constants), last: end(cont&: constants), val: ID(id));
12305 if (itr == end(cont&: constants))
12306 {
12307 force_recompile();
12308 constants.push_back(t: id);
12309 }
12310 }
12311 // Dereference pointer variables where needed.
12312 // FIXME: This dereference is actually backwards. We should really just support passing pointer variables between functions.
12313 else if (should_dereference(id))
12314 arg_str += dereference_expression(expression_type: type, expr: CompilerGLSL::to_func_call_arg(arg, id));
12315 else
12316 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
12317
12318 // Need to check the base variable in case we need to apply a qualified alias.
12319 uint32_t var_id = 0;
12320 auto *var = maybe_get<SPIRVariable>(id);
12321 if (var)
12322 var_id = var->basevariable;
12323
12324 if (!arg_is_dynamic_img_sampler)
12325 {
12326 auto *constexpr_sampler = find_constexpr_sampler(id: var_id ? var_id : id);
12327 if (type.basetype == SPIRType::SampledImage)
12328 {
12329 // Manufacture automatic plane args for multiplanar texture
12330 uint32_t planes = 1;
12331 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12332 {
12333 planes = constexpr_sampler->planes;
12334 // If this parameter isn't aliasing a global, then we need to use
12335 // the special "dynamic image-sampler" class to pass it--and we need
12336 // to use it for *every* non-alias parameter, in case a combined
12337 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
12338 // pathological case is so rare that it should never be hit in practice.
12339 if (!arg.alias_global_variable)
12340 add_spv_func_and_recompile(spv_func: SPVFuncImplDynamicImageSampler);
12341 }
12342 for (uint32_t i = 1; i < planes; i++)
12343 arg_str += join(ts: ", ", ts: CompilerGLSL::to_func_call_arg(arg, id), ts&: plane_name_suffix, ts&: i);
12344 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
12345 if (type.image.dim != DimBuffer)
12346 arg_str += ", " + to_sampler_expression(id: var_id ? var_id : id);
12347
12348 // Add sampler Y'CbCr conversion info if we have it
12349 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12350 {
12351 SmallVector<string> samp_args;
12352
12353 switch (constexpr_sampler->resolution)
12354 {
12355 case MSL_FORMAT_RESOLUTION_444:
12356 // Default
12357 break;
12358 case MSL_FORMAT_RESOLUTION_422:
12359 samp_args.push_back(t: "spvFormatResolution::_422");
12360 break;
12361 case MSL_FORMAT_RESOLUTION_420:
12362 samp_args.push_back(t: "spvFormatResolution::_420");
12363 break;
12364 default:
12365 SPIRV_CROSS_THROW("Invalid format resolution.");
12366 }
12367
12368 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
12369 samp_args.push_back(t: "spvChromaFilter::linear");
12370
12371 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
12372 samp_args.push_back(t: "spvXChromaLocation::midpoint");
12373 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
12374 samp_args.push_back(t: "spvYChromaLocation::midpoint");
12375 switch (constexpr_sampler->ycbcr_model)
12376 {
12377 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
12378 // Default
12379 break;
12380 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
12381 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_identity");
12382 break;
12383 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
12384 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_709");
12385 break;
12386 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
12387 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_601");
12388 break;
12389 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
12390 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_2020");
12391 break;
12392 default:
12393 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
12394 }
12395 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
12396 samp_args.push_back(t: "spvYCbCrRange::itu_narrow");
12397 samp_args.push_back(t: join(ts: "spvComponentBits(", ts: constexpr_sampler->bpc, ts: ")"));
12398 arg_str += join(ts: ", spvYCbCrSampler(", ts: merge(list: samp_args), ts: ")");
12399 }
12400 }
12401
12402 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12403 arg_str += join(ts: ", (uint(", ts: create_swizzle(swizzle: constexpr_sampler->swizzle[3]), ts: ") << 24) | (uint(",
12404 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[2]), ts: ") << 16) | (uint(",
12405 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[1]), ts: ") << 8) | uint(",
12406 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[0]), ts: ")");
12407 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
12408 arg_str += ", " + to_swizzle_expression(id: var_id ? var_id : id);
12409
12410 if (buffer_requires_array_length(id: var_id))
12411 arg_str += ", " + to_buffer_size_expression(id: var_id ? var_id : id);
12412
12413 if (is_dynamic_img_sampler)
12414 arg_str += ")";
12415 }
12416
12417 // Emulate texture2D atomic operations
12418 auto *backing_var = maybe_get_backing_variable(chain: var_id);
12419 if (backing_var && atomic_image_vars_emulated.count(x: backing_var->self))
12420 {
12421 arg_str += ", " + to_expression(id: var_id) + "_atomic";
12422 }
12423
12424 return arg_str;
12425}
12426
12427// If the ID represents a sampled image that has been assigned a sampler already,
12428// generate an expression for the sampler, otherwise generate a fake sampler name
12429// by appending a suffix to the expression constructed from the ID.
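// Rough examples, assuming the default sampler_name_suffix ("Smplr"): a combined image "uTex"
// with no assigned sampler yields "uTexSmplr", and an arrayed image expression "uTex[2]"
// yields "uTexSmplr[2]", i.e. the suffix is inserted before the subscript.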
12430string CompilerMSL::to_sampler_expression(uint32_t id)
12431{
12432 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
12433 if (combined && combined->sampler)
12434 return to_expression(id: combined->sampler);
12435
12436 uint32_t expr_id = combined ? uint32_t(combined->image) : id;
12437
12438 // Constexpr samplers are declared as local variables,
12439 // so exclude any qualifier names on the image expression.
12440 if (auto *var = maybe_get_backing_variable(chain: expr_id))
12441 {
12442 uint32_t img_id = var->basevariable ? var->basevariable : VariableID(var->self);
12443 if (find_constexpr_sampler(id: img_id))
12444 return Compiler::to_name(id: img_id) + sampler_name_suffix;
12445 }
12446
12447 auto img_expr = to_expression(id: expr_id);
12448 auto index = img_expr.find_first_of(c: '[');
12449 if (index == string::npos)
12450 return img_expr + sampler_name_suffix;
12451 else
12452 return img_expr.substr(pos: 0, n: index) + sampler_name_suffix + img_expr.substr(pos: index);
12453}
12454
12455string CompilerMSL::to_swizzle_expression(uint32_t id)
12456{
12457 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
12458
12459 auto expr = to_expression(id: combined ? combined->image : VariableID(id));
12460 auto index = expr.find_first_of(c: '[');
12461
12462 // If an image is part of an argument buffer, translate this to a legal identifier.
12463 string::size_type period = 0;
12464 while ((period = expr.find_first_of(c: '.', pos: period)) != string::npos && period < index)
12465 expr[period] = '_';
12466
12467 if (index == string::npos)
12468 return expr + swizzle_name_suffix;
12469 else
12470 {
12471 auto image_expr = expr.substr(pos: 0, n: index);
12472 auto array_expr = expr.substr(pos: index);
12473 return image_expr + swizzle_name_suffix + array_expr;
12474 }
12475}
12476
12477string CompilerMSL::to_buffer_size_expression(uint32_t id)
12478{
12479 auto expr = to_expression(id);
12480 auto index = expr.find_first_of(c: '[');
12481
12482 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
12483 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
12484 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
12485 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
12486 expr = address_of_expression(expr);
12487
12488 // If a buffer is part of an argument buffer, translate this to a legal identifier.
12489 for (auto &c : expr)
12490 if (c == '.')
12491 c = '_';
12492
12493 if (index == string::npos)
12494 return expr + buffer_size_name_suffix;
12495 else
12496 {
12497 auto buffer_expr = expr.substr(pos: 0, n: index);
12498 auto array_expr = expr.substr(pos: index);
12499 if (auto var = maybe_get_backing_variable(chain: id))
12500 {
12501 if (is_var_runtime_size_array(var: *var))
12502 {
12503 if (!msl_options.runtime_array_rich_descriptor)
12504 SPIRV_CROSS_THROW("OpArrayLength requires rich descriptor format");
12505
12506 auto last_pos = array_expr.find_last_of(c: ']');
12507 if (last_pos != std::string::npos)
12508 return buffer_expr + ".length(" + array_expr.substr(pos: 1, n: last_pos - 1) + ")";
12509 }
12510 }
12511 return buffer_expr + buffer_size_name_suffix + array_expr;
12512 }
12513}
12514
12515// Checks whether the type is a Block all of whose members have DecorationPatch.
12516bool CompilerMSL::is_patch_block(const SPIRType &type)
12517{
12518 if (!has_decoration(id: type.self, decoration: DecorationBlock))
12519 return false;
12520
12521 for (uint32_t i = 0; i < type.member_types.size(); i++)
12522 {
12523 if (!has_member_decoration(id: type.self, index: i, decoration: DecorationPatch))
12524 return false;
12525 }
12526
12527 return true;
12528}
12529
12530// Checks whether the ID is a row_major matrix that requires conversion before use
12531bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
12532{
12533 auto *e = maybe_get<SPIRExpression>(id);
12534 if (e)
12535 return e->need_transpose;
12536 else
12537 return has_decoration(id, decoration: DecorationRowMajor);
12538}
12539
12540// Checks whether the member is a row_major matrix that requires conversion before use
12541bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
12542{
12543 return has_member_decoration(id: type.self, index, decoration: DecorationRowMajor);
12544}
12545
12546string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
12547 bool is_packed, bool relaxed)
12548{
12549 if (!is_matrix(type: exp_type))
12550 {
12551 return CompilerGLSL::convert_row_major_matrix(exp_str: std::move(exp_str), exp_type, physical_type_id, is_packed, relaxed);
12552 }
12553 else
12554 {
12555 strip_enclosed_expression(expr&: exp_str);
12556 if (physical_type_id != 0 || is_packed)
12557 exp_str = unpack_expression_type(expr_str: exp_str, type: exp_type, physical_type_id, packed: is_packed, row_major: true);
12558 return join(ts: "transpose(", ts&: exp_str, ts: ")");
12559 }
12560}
12561
12562// Called automatically at the end of the entry point function
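// With a typical output named e.g. "out.gl_Position", the fixups emit roughly:
//   out.gl_Position.z = (out.gl_Position.z + out.gl_Position.w) * 0.5; // Adjust clip-space for Metal
//   out.gl_Position.y = -(out.gl_Position.y);                          // Invert Y-axis for Metal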
12563void CompilerMSL::emit_fixup()
12564{
12565 if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
12566 {
12567 if (options.vertex.fixup_clipspace)
12568 statement(ts&: qual_pos_var_name, ts: ".z = (", ts&: qual_pos_var_name, ts: ".z + ", ts&: qual_pos_var_name,
12569 ts: ".w) * 0.5; // Adjust clip-space for Metal");
12570
12571 if (options.vertex.flip_vert_y)
12572 statement(ts&: qual_pos_var_name, ts: ".y = -(", ts&: qual_pos_var_name, ts: ".y);", ts: " // Invert Y-axis for Metal");
12573 }
12574}
12575
12576// Return a string defining a structure member, with padding and packing.
12577string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
12578 const string &qualifier)
12579{
12580 uint32_t orig_member_type_id = member_type_id;
12581 if (member_is_remapped_physical_type(type, index))
12582 member_type_id = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
12583 auto &physical_type = get<SPIRType>(id: member_type_id);
12584
12585 // If this member is packed, mark it as so.
12586 string pack_pfx;
12587
12588 // Allow Metal to use the array<T> template to make arrays a value type
12589 uint32_t orig_id = 0;
12590 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID))
12591 orig_id = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID);
12592
12593 bool row_major = false;
12594 if (is_matrix(type: physical_type))
12595 row_major = has_member_decoration(id: type.self, index, decoration: DecorationRowMajor);
12596
12597 SPIRType row_major_physical_type { OpTypeMatrix };
12598 const SPIRType *declared_type = &physical_type;
12599
12600 // If a struct is being declared with physical layout,
12601 // do not use array<T> wrappers.
12602 // This avoids a lot of complicated cases with packed vectors and matrices,
12603 // and generally we cannot copy full arrays in and out of buffers into Function
12604 // address space.
12605 // Arrays of resources should also be declared as builtin arrays.
12606 if (has_member_decoration(id: type.self, index, decoration: DecorationOffset))
12607 is_using_builtin_array = true;
12608 else if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary))
12609 is_using_builtin_array = true;
12610
12611 if (member_is_packed_physical_type(type, index))
12612 {
12613 // If we're packing a matrix, output an appropriate typedef
12614 if (physical_type.basetype == SPIRType::Struct)
12615 {
12616 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
12617 }
12618 else if (is_matrix(type: physical_type))
12619 {
12620 uint32_t rows = physical_type.vecsize;
12621 uint32_t cols = physical_type.columns;
12622 pack_pfx = "packed_";
12623 if (row_major)
12624 {
12625 // These are stored transposed.
12626 rows = physical_type.columns;
12627 cols = physical_type.vecsize;
12628 pack_pfx = "packed_rm_";
12629 }
12630 string base_type = physical_type.width == 16 ? "half" : "float";
12631 string td_line = "typedef ";
12632 td_line += "packed_" + base_type + to_string(val: rows);
12633 td_line += " " + pack_pfx;
12634 // Use the actual matrix size here.
12635 td_line += base_type + to_string(val: physical_type.columns) + "x" + to_string(val: physical_type.vecsize);
12636 td_line += "[" + to_string(val: cols) + "]";
12637 td_line += ";";
12638 add_typedef_line(line: td_line);
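			// For instance, a column-major packed float2x3 member (3 rows, 2 columns) produces
			//   typedef packed_float3 packed_float2x3[2];
			// while the row-major variant stores the transpose as packed_float2 packed_rm_float2x3[3].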
12639 }
12640 else if (!is_scalar(type: physical_type)) // scalar type is already packed.
12641 pack_pfx = "packed_";
12642 }
12643 else if (is_matrix(type: physical_type))
12644 {
12645 if (!msl_options.supports_msl_version(major: 3, minor: 0) &&
12646 has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationWorkgroupStruct))
12647 {
12648 pack_pfx = "spvStorage_";
12649 add_spv_func_and_recompile(spv_func: SPVFuncImplStorageMatrix);
12650 // The pack prefix causes problems with array<T> wrappers.
12651 is_using_builtin_array = true;
12652 }
12653 if (row_major)
12654 {
12655 // Need to declare type with flipped vecsize/columns.
12656 row_major_physical_type = physical_type;
12657 swap(a&: row_major_physical_type.vecsize, b&: row_major_physical_type.columns);
12658 declared_type = &row_major_physical_type;
12659 }
12660 }
12661
12662 // iOS Tier 1 argument buffers do not support writable images.
12663 if (physical_type.basetype == SPIRType::Image &&
12664 physical_type.image.sampled == 2 &&
12665 msl_options.is_ios() &&
12666 msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1 &&
12667 !has_decoration(id: orig_id, decoration: DecorationNonWritable))
12668 {
12669 SPIRV_CROSS_THROW("Writable images are not allowed on Tier1 argument buffers on iOS.");
12670 }
12671
12672 // Array information is baked into these types.
12673 string array_type;
12674 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
12675 physical_type.basetype != SPIRType::SampledImage)
12676 {
12677 BuiltIn builtin = BuiltInMax;
12678
12679 // Special handling. In [[stage_out]] or [[stage_in]] blocks,
12680 // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
12681 // template array types to be declared.
12682 bool is_ib_in_out =
12683 ((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
12684 variable_storage_requires_stage_io(storage: StorageClassOutput)) ||
12685 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
12686 variable_storage_requires_stage_io(storage: StorageClassInput))) ||
12687 is_mesh_shader();
12688 if (is_ib_in_out && is_member_builtin(type, index, builtin: &builtin))
12689 is_using_builtin_array = true;
12690 array_type = type_to_array_glsl(type: physical_type, variable_id: orig_id);
12691 }
12692
12693 if (is_mesh_shader())
12694 {
12695 BuiltIn builtin = BuiltInMax;
12696 if (is_member_builtin(type, index, builtin: &builtin))
12697 {
12698 if (builtin == BuiltInPrimitiveShadingRateKHR)
12699 {
12700 // Not supported in Metal 3.0.
12701 is_using_builtin_array = false;
12702 return "";
12703 }
12704
12705 SPIRType metallic_type = *declared_type;
12706 if (builtin == BuiltInCullPrimitiveEXT)
12707 metallic_type.basetype = SPIRType::Boolean;
12708 else if (builtin == BuiltInPrimitiveId || builtin == BuiltInLayer || builtin == BuiltInViewportIndex)
12709 metallic_type.basetype = SPIRType::UInt;
12710
12711 is_using_builtin_array = true;
12712 std::string result;
12713 if (has_member_decoration(id: type.self, index: orig_id, decoration: DecorationBuiltIn))
12714 {
12715 // avoid '_RESERVED_IDENTIFIER_FIXUP_' in variable name
12716 result = join(ts: type_to_glsl(type: metallic_type, id: orig_id, member: false), ts: " ", ts: qualifier,
12717 ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts: member_attribute_qualifier(type, index),
12718 ts&: array_type, ts: ";");
12719 }
12720 else
12721 {
12722 result = join(ts: type_to_glsl(type: metallic_type, id: orig_id, member: false), ts: " ", ts: qualifier,
12723 ts: to_member_name(type, index), ts: member_attribute_qualifier(type, index), ts&: array_type, ts: ";");
12724 }
12725 is_using_builtin_array = false;
12726 return result;
12727 }
12728 }
12729
12730 if (orig_id)
12731 {
12732 auto *data_type = declared_type;
12733 if (is_pointer(type: *data_type))
12734 data_type = &get_pointee_type(type: *data_type);
12735
12736 if (is_array(type: *data_type) && get_resource_array_size(type: *data_type, id: orig_id) == 0)
12737 {
12738 // Hack for declaring an unsized array of resources: declare a dummy sized array by value inline.
12739 // This can then be wrapped in spvDescriptorArray as usual.
12740 array_type = "[1] /* unsized array hack */";
12741 }
12742 }
12743
12744 string decl_type;
12745 if (declared_type->vecsize > 4)
12746 {
12747 auto orig_type = get<SPIRType>(id: orig_member_type_id);
12748 if (is_matrix(type: orig_type) && row_major)
12749 swap(a&: orig_type.vecsize, b&: orig_type.columns);
12750 orig_type.columns = 1;
12751 decl_type = type_to_glsl(type: orig_type, id: orig_id, member: true);
12752
12753 if (declared_type->columns > 1)
12754 decl_type = join(ts: "spvPaddedStd140Matrix<", ts&: decl_type, ts: ", ", ts: declared_type->columns, ts: ">");
12755 else
12756 decl_type = join(ts: "spvPaddedStd140<", ts&: decl_type, ts: ">");
12757 }
12758 else
12759 decl_type = type_to_glsl(type: *declared_type, id: orig_id, member: true);
12760
12761 const char *overlapping_binding_tag =
12762 has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationOverlappingBinding) ?
12763 "// Overlapping binding: " : "";
12764
12765 auto result = join(ts&: overlapping_binding_tag, ts&: pack_pfx, ts&: decl_type, ts: " ", ts: qualifier,
12766 ts: to_member_name(type, index), ts: member_attribute_qualifier(type, index), ts&: array_type, ts: ";");
12767
12768 is_using_builtin_array = false;
12769 return result;
12770}
12771
12772 // Emit a structure member, padding and packing to maintain the correct member alignments.
12773void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
12774 const string &qualifier, uint32_t)
12775{
12776 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
12777 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPaddingTarget))
12778 {
12779 uint32_t pad_len = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPaddingTarget);
12780 statement(ts: "char _m", ts&: index, ts: "_pad", ts: "[", ts&: pad_len, ts: "];");
12781 }
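	// e.g. a member at index 2 needing 8 bytes of leading padding emits "char _m2_pad[8];".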
12782
12783 BuiltIn builtin = BuiltInMax;
12784 if (is_mesh_shader() && is_member_builtin(type, index, builtin: &builtin))
12785 {
12786 if (!has_active_builtin(builtin, storage: StorageClassOutput) && !has_active_builtin(builtin, storage: StorageClassInput))
12787 {
12788 // Do not emit unused builtins in mesh-output blocks
12789 return;
12790 }
12791 }
12792
12793 // Handle HLSL-style 0-based vertex/instance index.
12794 builtin_declaration = true;
12795 statement(ts: to_struct_member(type, member_type_id, index, qualifier));
12796 builtin_declaration = false;
12797}
12798
12799void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
12800{
12801 uint32_t struct_size = get_declared_struct_size_msl(struct_type: type, ignore_alignment: true, ignore_padding: true);
12802 uint32_t target_size = get_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPaddingTarget);
12803 if (target_size < struct_size)
12804 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
12805 else if (target_size > struct_size)
12806 statement(ts: "char _m0_final_padding[", ts: target_size - struct_size, ts: "];");
12807}
12808
12809// Return a MSL qualifier for the specified function attribute member
12810string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
12811{
12812 auto &execution = get_entry_point();
12813
12814 uint32_t mbr_type_id = type.member_types[index];
12815 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
12816
12817 BuiltIn builtin = BuiltInMax;
12818 bool is_builtin = is_member_builtin(type, index, builtin: &builtin);
12819
12820 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary))
12821 {
12822 string quals = join(
12823 ts: " [[id(", ts: get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary), ts: ")");
12824 if (interlocked_resources.count(
12825 x: get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID)))
12826 quals += ", raster_order_group(0)";
12827 quals += "]]";
12828 return quals;
12829 }
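	// e.g. an argument buffer member bound at index 4 receives " [[id(4)]]", or
	// " [[id(4), raster_order_group(0)]]" when it participates in rasterizer order groups.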
12830
12831 // Vertex function inputs
12832 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
12833 {
12834 if (is_builtin)
12835 {
12836 switch (builtin)
12837 {
12838 case BuiltInVertexId:
12839 case BuiltInVertexIndex:
12840 case BuiltInBaseVertex:
12841 case BuiltInInstanceId:
12842 case BuiltInInstanceIndex:
12843 case BuiltInBaseInstance:
12844 if (msl_options.vertex_for_tessellation)
12845 return "";
12846 return string(" [[") + builtin_qualifier(builtin) + "]]";
12847
12848 case BuiltInDrawIndex:
12849 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
12850
12851 default:
12852 return "";
12853 }
12854 }
12855
12856 uint32_t locn;
12857 if (is_builtin)
12858 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
12859 else
12860 locn = get_member_location(type_id: type.self, index);
12861
12862 if (locn != k_unknown_location)
12863 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
12864 }
12865
12866 bool use_semantic_stage_output = is_mesh_shader() || is_tese_shader() ||
12867 (execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation);
12868
12869 // Vertex, mesh and tessellation evaluation function outputs
12870 if ((type.storage == StorageClassOutput || is_mesh_shader()) && use_semantic_stage_output)
12871 {
12872 if (is_builtin)
12873 {
12874 switch (builtin)
12875 {
12876 case BuiltInPointSize:
12877 // Only mark the PointSize builtin if really rendering points.
12878 // Some shaders may include a PointSize builtin even when used to render
12879 // non-point topologies, and Metal will reject this builtin when compiling
12880 // the shader into a render pipeline that uses a non-point topology.
12881 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
12882
12883 case BuiltInViewportIndex:
12884 if (!msl_options.supports_msl_version(major: 2, minor: 0))
12885 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
12886 /* fallthrough */
12887 case BuiltInPosition:
12888 case BuiltInLayer:
12889 case BuiltInCullPrimitiveEXT:
12890 case BuiltInPrimitiveShadingRateKHR:
12891 case BuiltInPrimitiveId:
12892 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12893
12894 case BuiltInClipDistance:
12895 if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12896 return join(ts: " [[user(clip", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12897 else
12898 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12899
12900 case BuiltInCullDistance:
12901 if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12902 return join(ts: " [[user(cull", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12903 else
12904 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12905
12906 default:
12907 return "";
12908 }
12909 }
12910 string loc_qual = member_location_attribute_qualifier(type, index);
12911 if (!loc_qual.empty())
12912 return join(ts: " [[", ts&: loc_qual, ts: "]]");
12913 }
12914
12915 if (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation && type.storage == StorageClassOutput)
12916 {
12917 // For this type of shader, we always arrange for it to capture its
12918 // output to a buffer. For this reason, qualifiers are irrelevant here.
12919 if (is_builtin)
12920 // We still have to assign a location so the output struct will sort correctly.
12921 get_or_allocate_builtin_output_member_location(builtin, type_id: type.self, index);
12922 return "";
12923 }
12924
12925 // Tessellation control function inputs
12926 if (is_tesc_shader() && type.storage == StorageClassInput)
12927 {
12928 if (is_builtin)
12929 {
12930 switch (builtin)
12931 {
12932 case BuiltInInvocationId:
12933 case BuiltInPrimitiveId:
12934 if (msl_options.multi_patch_workgroup)
12935 return "";
12936 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12937 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
12938 case BuiltInSubgroupSize: // FIXME: Should work in any stage
12939 if (msl_options.emulate_subgroups)
12940 return "";
12941 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12942 case BuiltInPatchVertices:
12943 return "";
12944 // Others come from stage input.
12945 default:
12946 break;
12947 }
12948 }
12949 if (msl_options.multi_patch_workgroup)
12950 return "";
12951
12952 uint32_t locn;
12953 if (is_builtin)
12954 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
12955 else
12956 locn = get_member_location(type_id: type.self, index);
12957
12958 if (locn != k_unknown_location)
12959 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
12960 }
12961
12962 // Tessellation control function outputs
12963 if (is_tesc_shader() && type.storage == StorageClassOutput)
12964 {
12965 // For this type of shader, we always arrange for it to capture its
12966 // output to a buffer. For this reason, qualifiers are irrelevant here.
12967 if (is_builtin)
12968 // We still have to assign a location so the output struct will sort correctly.
12969 get_or_allocate_builtin_output_member_location(builtin, type_id: type.self, index);
12970 return "";
12971 }
12972
12973 // Tessellation evaluation function inputs
12974 if (is_tese_shader() && type.storage == StorageClassInput)
12975 {
12976 if (is_builtin)
12977 {
12978 switch (builtin)
12979 {
12980 case BuiltInPrimitiveId:
12981 case BuiltInTessCoord:
12982 return string(" [[") + builtin_qualifier(builtin) + "]]";
12983 case BuiltInPatchVertices:
12984 return "";
12985 // Others come from stage input.
12986 default:
12987 break;
12988 }
12989 }
12990
12991 if (msl_options.raw_buffer_tese_input)
12992 return "";
12993
12994 // The special control point array must not be marked with an attribute.
12995 if (get_type(id: type.member_types[index]).basetype == SPIRType::ControlPointArray)
12996 return "";
12997
12998 uint32_t locn;
12999 if (is_builtin)
13000 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
13001 else
13002 locn = get_member_location(type_id: type.self, index);
13003
13004 if (locn != k_unknown_location)
13005 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
13006 }
13007
13008 // Tessellation evaluation function outputs were handled above.
13009
13010 // Fragment function inputs
13011 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
13012 {
13013 string quals;
13014 if (is_builtin)
13015 {
13016 switch (builtin)
13017 {
13018 case BuiltInViewIndex:
13019 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
13020 break;
13021 /* fallthrough */
13022 case BuiltInFrontFacing:
13023 case BuiltInPointCoord:
13024 case BuiltInFragCoord:
13025 case BuiltInSampleId:
13026 case BuiltInSampleMask:
13027 case BuiltInLayer:
13028 case BuiltInBaryCoordKHR:
13029 case BuiltInBaryCoordNoPerspKHR:
13030 quals = builtin_qualifier(builtin);
13031 break;
13032
13033 case BuiltInClipDistance:
13034 return join(ts: " [[user(clip", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
13035 case BuiltInCullDistance:
13036 return join(ts: " [[user(cull", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
13037
13038 default:
13039 break;
13040 }
13041 }
13042 else
13043 quals = member_location_attribute_qualifier(type, index);
13044
13045 if (builtin == BuiltInBaryCoordKHR && has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
13046 {
13047 // NoPerspective is baked into the builtin type.
13048 SPIRV_CROSS_THROW("NoPerspective decorations are not supported for BaryCoord inputs.");
13049 }
13050
13051 // Don't bother decorating integers with the 'flat' attribute; it's
13052 // the default (in fact, the only option). Also don't bother with the
13053 // FragCoord builtin; it's always noperspective on Metal.
13054 if (!type_is_integral(type: mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
13055 {
13056 if (has_member_decoration(id: type.self, index, decoration: DecorationFlat))
13057 {
13058 if (!quals.empty())
13059 quals += ", ";
13060 quals += "flat";
13061 }
13062 else if (has_member_decoration(id: type.self, index, decoration: DecorationCentroid))
13063 {
13064 if (!quals.empty())
13065 quals += ", ";
13066
13067 if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR)
13068 SPIRV_CROSS_THROW("Centroid interpolation not supported for barycentrics in MSL.");
13069
13070 if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
13071 quals += "centroid_no_perspective";
13072 else
13073 quals += "centroid_perspective";
13074 }
13075 else if (has_member_decoration(id: type.self, index, decoration: DecorationSample))
13076 {
13077 if (!quals.empty())
13078 quals += ", ";
13079
13080 if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR)
13081 SPIRV_CROSS_THROW("Sample interpolation not supported for barycentrics in MSL.");
13082
13083 if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
13084 quals += "sample_no_perspective";
13085 else
13086 quals += "sample_perspective";
13087 }
13088 else if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective) || builtin == BuiltInBaryCoordNoPerspKHR)
13089 {
13090 if (!quals.empty())
13091 quals += ", ";
13092 quals += "center_no_perspective";
13093 }
13094 else if (builtin == BuiltInBaryCoordKHR)
13095 {
13096 if (!quals.empty())
13097 quals += ", ";
13098 quals += "center_perspective";
13099 }
13100 }
13101
13102 if (!quals.empty())
13103 return " [[" + quals + "]]";
13104 }
13105
13106 // Fragment function outputs
13107 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
13108 {
13109 if (is_builtin)
13110 {
13111 switch (builtin)
13112 {
13113 case BuiltInFragStencilRefEXT:
13114 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
13115 // Some shaders may include a FragStencilRef builtin even when used to render
13116 // without a stencil attachment, and Metal will reject this builtin
13117 // when compiling the shader into a render pipeline that does not set
13118 // stencilAttachmentPixelFormat.
13119 if (!msl_options.enable_frag_stencil_ref_builtin)
13120 return "";
13121 if (!msl_options.supports_msl_version(major: 2, minor: 1))
13122 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
13123 return string(" [[") + builtin_qualifier(builtin) + "]]";
13124
13125 case BuiltInFragDepth:
13126 // Ditto FragDepth.
13127 if (!msl_options.enable_frag_depth_builtin)
13128 return "";
13129 /* fallthrough */
13130 case BuiltInSampleMask:
13131 return string(" [[") + builtin_qualifier(builtin) + "]]";
13132
13133 default:
13134 return "";
13135 }
13136 }
13137 uint32_t locn = get_member_location(type_id: type.self, index);
13138 // Metal will likely complain about missing color attachments, too.
13139 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
13140 return "";
13141 if (locn != k_unknown_location && has_member_decoration(id: type.self, index, decoration: DecorationIndex))
13142 return join(ts: " [[color(", ts&: locn, ts: "), index(", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex),
13143 ts: ")]]");
13144 else if (locn != k_unknown_location)
13145 return join(ts: " [[color(", ts&: locn, ts: ")]]");
13146 else if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
13147 return join(ts: " [[index(", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
13148 else
13149 return "";
13150 }
13151
13152 // Compute function inputs
13153 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
13154 {
13155 if (is_builtin)
13156 {
13157 switch (builtin)
13158 {
13159 case BuiltInNumSubgroups:
13160 case BuiltInSubgroupId:
13161 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
13162 case BuiltInSubgroupSize: // FIXME: Should work in any stage
13163 if (msl_options.emulate_subgroups)
13164 break;
13165 /* fallthrough */
13166 case BuiltInGlobalInvocationId:
13167 case BuiltInWorkgroupId:
13168 case BuiltInNumWorkgroups:
13169 case BuiltInLocalInvocationId:
13170 case BuiltInLocalInvocationIndex:
13171 return string(" [[") + builtin_qualifier(builtin) + "]]";
13172
13173 default:
13174 return "";
13175 }
13176 }
13177 }
13178
13179 return "";
13180}
13181
13182// A user-defined output variable is considered to match an input variable in the subsequent
13183// stage if the two variables are declared with the same Location and Component decoration and
13184// match in type and decoration, except that interpolation decorations are not required to match.
13185// For the purposes of interface matching, variables declared without a Component decoration are
13186// considered to have a Component decoration of zero.
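// For illustration, a member with Location 3 and Component 1 yields "user(locn3_1)", which
// callers wrap as " [[user(locn3_1)]]"; a Component of 0 (or none) yields just "user(locn3)".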
13187string CompilerMSL::member_location_attribute_qualifier(const SPIRType &type, uint32_t index)
13188{
13189 string quals;
13190 uint32_t comp;
13191 uint32_t locn = get_member_location(type_id: type.self, index, comp: &comp);
13192 if (locn != k_unknown_location)
13193 {
13194 quals += "user(locn";
13195 quals += convert_to_string(t: locn);
13196 if (comp != k_unknown_component && comp != 0)
13197 {
13198 quals += "_";
13199 quals += convert_to_string(t: comp);
13200 }
13201 quals += ")";
13202 }
13203 return quals;
13204}
13205
13206 // Returns the location decoration of the member with the specified index in the specified type.
13207 // If the location of the member has been explicitly set, that location is returned;
13208 // otherwise k_unknown_location is returned. If comp is non-null, it receives the member's
13209 // Component decoration, or k_unknown_component if there is none.
13210uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const
13211{
13212 if (comp)
13213 {
13214 if (has_member_decoration(id: type_id, index, decoration: DecorationComponent))
13215 *comp = get_member_decoration(id: type_id, index, decoration: DecorationComponent);
13216 else
13217 *comp = k_unknown_component;
13218 }
13219
13220 if (has_member_decoration(id: type_id, index, decoration: DecorationLocation))
13221 return get_member_decoration(id: type_id, index, decoration: DecorationLocation);
13222 else
13223 return k_unknown_location;
13224}
13225
13226uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,
13227 uint32_t type_id, uint32_t index,
13228 uint32_t *comp)
13229{
13230 uint32_t loc = get_member_location(type_id, index, comp);
13231 if (loc != k_unknown_location)
13232 return loc;
13233
13234 if (comp)
13235 *comp = k_unknown_component;
13236
13237 // Late allocation. Find a location which is unused by the application.
13238 // This can happen for built-in inputs in tessellation which are mixed and matched with user inputs.
13239 auto &mbr_type = get<SPIRType>(id: get<SPIRType>(id: type_id).member_types[index]);
13240 uint32_t count = type_to_location_count(type: mbr_type);
13241
13242 loc = 0;
13243
13244 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
13245 for (uint32_t i = 0; i < location_count; i++)
13246 if (location_inputs_in_use.count(x: location + i) != 0)
13247 return true;
13248 return false;
13249 };
13250
13251 while (location_range_in_use(loc, count))
13252 loc++;
13253
13254 set_member_decoration(id: type_id, index, decoration: DecorationLocation, argument: loc);
13255
13256 // Triangle tess level inputs are shared in one packed float4;
13257 // mark both builtins as sharing one location.
13258 if (!msl_options.raw_buffer_tese_input && is_tessellating_triangles() &&
13259 (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
13260 {
13261 builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc;
13262 builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc;
13263 }
13264 else
13265 builtin_to_automatic_input_location[builtin] = loc;
13266
13267 mark_location_as_used_by_shader(location: loc, type: mbr_type, storage: StorageClassInput, fallback: true);
13268 return loc;
13269}
13270
13271uint32_t CompilerMSL::get_or_allocate_builtin_output_member_location(spv::BuiltIn builtin,
13272 uint32_t type_id, uint32_t index,
13273 uint32_t *comp)
13274{
13275 uint32_t loc = get_member_location(type_id, index, comp);
13276 if (loc != k_unknown_location)
13277 return loc;
13278 loc = 0;
13279
13280 if (comp)
13281 *comp = k_unknown_component;
13282
13283 // Late allocation. Find a location which is unused by the application.
13284 // This can happen for built-in outputs in tessellation which are mixed and matched with user outputs.
13285 auto &mbr_type = get<SPIRType>(id: get<SPIRType>(id: type_id).member_types[index]);
13286 uint32_t count = type_to_location_count(type: mbr_type);
13287
13288 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
13289 for (uint32_t i = 0; i < location_count; i++)
13290 if (location_outputs_in_use.count(x: location + i) != 0)
13291 return true;
13292 return false;
13293 };
13294
13295 while (location_range_in_use(loc, count))
13296 loc++;
13297
13298 set_member_decoration(id: type_id, index, decoration: DecorationLocation, argument: loc);
13299
13300 // Triangle tess level outputs are shared in one packed float4;
13301 // mark both builtins as sharing one location.
13302 if (is_tessellating_triangles() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
13303 {
13304 builtin_to_automatic_output_location[BuiltInTessLevelInner] = loc;
13305 builtin_to_automatic_output_location[BuiltInTessLevelOuter] = loc;
13306 }
13307 else
13308 builtin_to_automatic_output_location[builtin] = loc;
13309
13310 mark_location_as_used_by_shader(location: loc, type: mbr_type, storage: StorageClassOutput, fallback: true);
13311 return loc;
13312}
13313
13314// Returns the type declaration for a function, including the
13315// entry type if the current function is the entry point function
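// Illustrative results: a plain vertex shader is prefixed with "vertex", compute and
// tessellation control shaders with "kernel", and a macOS tessellation evaluation shader
// with e.g. "[[ patch(triangle, 3) ]] vertex" (patch kind and vertex count vary per shader).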
13316string CompilerMSL::func_type_decl(SPIRType &type)
13317{
13318 // The regular function return type. If not processing the entry point function, that's all we need
13319 string return_type = type_to_glsl(type) + type_to_array_glsl(type, variable_id: 0);
13320 if (!processing_entry_point)
13321 return return_type;
13322
13323 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
13324 bool ep_should_return_output = !get_is_rasterization_disabled();
13325 if (stage_out_var_id && ep_should_return_output)
13326 return_type = type_to_glsl(type: get_stage_out_struct_type()) + type_to_array_glsl(type, variable_id: 0);
13327
13328 // Prepend an entry type based on the execution model.
13329 string entry_type;
13330 auto &execution = get_entry_point();
13331 switch (execution.model)
13332 {
13333 case ExecutionModelVertex:
13334 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(major: 1, minor: 2))
13335 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13336 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
13337 break;
13338 case ExecutionModelTessellationEvaluation:
13339 if (!msl_options.supports_msl_version(major: 1, minor: 2))
13340 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13341 if (execution.flags.get(bit: ExecutionModeIsolines))
13342 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
13343 if (msl_options.is_ios())
13344 entry_type = join(ts: "[[ patch(", ts: is_tessellating_triangles() ? "triangle" : "quad", ts: ") ]] vertex");
13345 else
13346 entry_type = join(ts: "[[ patch(", ts: is_tessellating_triangles() ? "triangle" : "quad", ts: ", ",
13347 ts&: execution.output_vertices, ts: ") ]] vertex");
13348 break;
13349 case ExecutionModelFragment:
13350 entry_type = uses_explicit_early_fragment_test() ? "[[ early_fragment_tests ]] fragment" : "fragment";
13351 break;
13352 case ExecutionModelTessellationControl:
13353 if (!msl_options.supports_msl_version(major: 1, minor: 2))
13354 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13355 if (execution.flags.get(bit: ExecutionModeIsolines))
13356 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
13357 /* fallthrough */
13358 case ExecutionModelGLCompute:
13359 case ExecutionModelKernel:
13360 entry_type = "kernel";
13361 break;
13362 case ExecutionModelMeshEXT:
13363 entry_type = "[[mesh]]";
13364 break;
13365 case ExecutionModelTaskEXT:
13366 entry_type = "[[object]]";
13367 break;
13368 default:
13369 entry_type = "unknown";
13370 break;
13371 }
13372
13373 return entry_type + " " + return_type;
13374}
13375
13376bool CompilerMSL::is_tesc_shader() const
13377{
13378 return get_execution_model() == ExecutionModelTessellationControl;
13379}
13380
13381bool CompilerMSL::is_tese_shader() const
13382{
13383 return get_execution_model() == ExecutionModelTessellationEvaluation;
13384}
13385
13386bool CompilerMSL::is_mesh_shader() const
13387{
13388 return get_execution_model() == spv::ExecutionModelMeshEXT;
13389}
13390
13391bool CompilerMSL::uses_explicit_early_fragment_test()
13392{
13393 auto &ep_flags = get_entry_point().flags;
13394 return ep_flags.get(bit: ExecutionModeEarlyFragmentTests) || ep_flags.get(bit: ExecutionModePostDepthCoverage);
13395}
13396
13397// In MSL, address space qualifiers are required for all pointer or reference variables
13398string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
13399{
13400 const auto &type = get<SPIRType>(id: argument.basetype);
13401 return get_type_address_space(type, id: argument.self, argument: true);
13402}
13403
13404bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags)
13405{
13406 return flags.get(bit: DecorationVolatile) || flags.get(bit: DecorationCoherent);
13407}
13408
string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
{
	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
	Bitset flags;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var && type.basetype == SPIRType::Struct &&
	    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
		flags = get_buffer_block_flags(id);
	else
		flags = get_decoration_bitset(id);

	const char *addr_space = nullptr;
	switch (type.storage)
	{
	case StorageClassWorkgroup:
		addr_space = "threadgroup";
		break;

	case StorageClassStorageBuffer:
	case StorageClassPhysicalStorageBuffer:
	{
		// For arguments from variable pointers, we use the write count deduction, so
		// we should not assume any constness here. Only for global SSBOs.
		bool readonly = false;
		if (!var || has_decoration(type.self, DecorationBlock))
			readonly = flags.get(DecorationNonWritable);

		addr_space = readonly ? "const device" : "device";
		break;
	}

	case StorageClassUniform:
	case StorageClassUniformConstant:
	case StorageClassPushConstant:
		if (type.basetype == SPIRType::Struct)
		{
			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
			if (ssbo)
				addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
			else
				addr_space = "constant";
		}
		else if (!argument)
		{
			addr_space = "constant";
		}
		else if (type_is_msl_framebuffer_fetch(type))
		{
			// Subpass inputs are passed around by value.
			addr_space = "";
		}
		break;

	case StorageClassFunction:
	case StorageClassGeneric:
		break;

	case StorageClassInput:
		if (is_tesc_shader() && var && var->basevariable == stage_in_ptr_var_id)
			addr_space = msl_options.multi_patch_workgroup ? "const device" : "threadgroup";
		// Don't pass tessellation levels in the device AS; we load and convert them
		// to float manually.
		if (is_tese_shader() && msl_options.raw_buffer_tese_input && var)
		{
			bool is_stage_in = var->basevariable == stage_in_ptr_var_id;
			bool is_patch_stage_in = has_decoration(var->self, DecorationPatch);
			bool is_builtin = has_decoration(var->self, DecorationBuiltIn);
			BuiltIn builtin = (BuiltIn)get_decoration(var->self, DecorationBuiltIn);
			bool is_tess_level = is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner);
			if (is_stage_in || (is_patch_stage_in && !is_tess_level))
				addr_space = "const device";
		}
		if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
			addr_space = "thread";
		break;

	case StorageClassOutput:
		if (capture_output_to_buffer)
		{
			if (var && type.storage == StorageClassOutput)
			{
				bool is_masked = is_stage_output_variable_masked(*var);

				if (is_masked)
				{
					if (is_tessellation_shader())
						addr_space = "threadgroup";
					else
						addr_space = "thread";
				}
				else if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))
					addr_space = "threadgroup";
			}

			if (!addr_space)
				addr_space = "device";
		}

		if (is_mesh_shader())
			addr_space = "threadgroup";
		break;

	case StorageClassTaskPayloadWorkgroupEXT:
		if (is_mesh_shader())
			addr_space = "const object_data";
		else
			addr_space = "object_data";
		break;

	default:
		break;
	}

	if (!addr_space)
	{
		// No address space for plain values.
		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
	}

	if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread"))
		return join("volatile ", addr_space);
	else
		return addr_space;
}
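
// Illustrative mapping (a sketch, not exhaustive; "Foo", "UBO" and the buffer indices are placeholders):
// the string returned above becomes the qualifier on the emitted MSL argument declaration, e.g.
//   StorageBuffer marked NonWritable  -> "const device Foo& foo [[buffer(0)]]"
//   Uniform block (non-BufferBlock)   -> "constant UBO& ubo [[buffer(1)]]"
//   Workgroup variable                -> "threadgroup ..."
//   Plain value argument              -> "thread" (or no qualifier at all for non-pointer values)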
13533
const char *CompilerMSL::to_restrict(uint32_t id, bool space)
{
	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
	Bitset flags;
	if (ir.ids[id].get_type() == TypeVariable)
	{
		uint32_t type_id = expression_type_id(id);
		auto &type = expression_type(id);
		if (type.basetype == SPIRType::Struct &&
		    (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
			flags = get_buffer_block_flags(id);
		else
			flags = get_decoration_bitset(id);
	}
	else
		flags = get_decoration_bitset(id);

	return flags.get(DecorationRestrict) || flags.get(DecorationRestrictPointerEXT) ?
	           (space ? "__restrict " : "__restrict") : "";
}
13554
string CompilerMSL::entry_point_arg_stage_in()
{
	string decl;

	if ((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
	    (is_tese_shader() && msl_options.raw_buffer_tese_input))
		return decl;

	// Stage-in structure
	uint32_t stage_in_id;
	if (is_tese_shader())
		stage_in_id = patch_stage_in_var_id;
	else
		stage_in_id = stage_in_var_id;

	if (stage_in_id)
	{
		auto &var = get<SPIRVariable>(stage_in_id);
		auto &type = get_variable_data_type(var);

		add_resource_name(var.self);
		decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
	}

	return decl;
}
13581
13582// Returns true if this input builtin should be a direct parameter on a shader function parameter list,
13583// and false for builtins that should be passed or calculated some other way.
13584bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
13585{
13586 switch (bi_type)
13587 {
13588 // Vertex function in
13589 case BuiltInVertexId:
13590 case BuiltInVertexIndex:
13591 case BuiltInBaseVertex:
13592 case BuiltInInstanceId:
13593 case BuiltInInstanceIndex:
13594 case BuiltInBaseInstance:
13595 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
13596 // Tess. control function in
13597 case BuiltInPosition:
13598 case BuiltInPointSize:
13599 case BuiltInClipDistance:
13600 case BuiltInCullDistance:
13601 case BuiltInPatchVertices:
13602 return false;
13603 case BuiltInInvocationId:
13604 case BuiltInPrimitiveId:
13605 return !is_tesc_shader() || !msl_options.multi_patch_workgroup;
13606 // Tess. evaluation function in
13607 case BuiltInTessLevelInner:
13608 case BuiltInTessLevelOuter:
13609 return false;
13610 // Fragment function in
13611 case BuiltInSamplePosition:
13612 case BuiltInHelperInvocation:
13613 case BuiltInBaryCoordKHR:
13614 case BuiltInBaryCoordNoPerspKHR:
13615 return false;
13616 case BuiltInViewIndex:
13617 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
13618 msl_options.multiview_layered_rendering;
13619 // Compute function in
13620 case BuiltInSubgroupId:
13621 case BuiltInNumSubgroups:
13622 return !msl_options.emulate_subgroups;
13623 // Any stage function in
13624 case BuiltInDeviceIndex:
13625 case BuiltInSubgroupEqMask:
13626 case BuiltInSubgroupGeMask:
13627 case BuiltInSubgroupGtMask:
13628 case BuiltInSubgroupLeMask:
13629 case BuiltInSubgroupLtMask:
13630 return false;
13631 case BuiltInSubgroupSize:
13632 if (msl_options.fixed_subgroup_size != 0)
13633 return false;
13634 /* fallthrough */
13635 case BuiltInSubgroupLocalInvocationId:
13636 return !msl_options.emulate_subgroups;
13637 default:
13638 return true;
13639 }
13640}
13641
// Returns true if this is a fragment shader that runs per sample, and false otherwise.
bool CompilerMSL::is_sample_rate() const
{
	auto &caps = get_declared_capabilities();
	return get_execution_model() == ExecutionModelFragment &&
	       (msl_options.force_sample_rate_shading ||
	        std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
	        (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input_ms));
}

bool CompilerMSL::is_intersection_query() const
{
	auto &caps = get_declared_capabilities();
	return std::find(caps.begin(), caps.end(), CapabilityRayQueryKHR) != caps.end();
}
13657
13658void CompilerMSL::entry_point_args_builtin(string &ep_args)
13659{
13660 // Builtin variables
13661 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
13662 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t var_id, SPIRVariable &var) {
13663 if (var.storage != StorageClassInput)
13664 return;
13665
13666 auto bi_type = BuiltIn(get_decoration(id: var_id, decoration: DecorationBuiltIn));
13667
13668 // Don't emit SamplePosition as a separate parameter. In the entry
13669 // point, we get that by calling get_sample_position() on the sample ID.
13670 if (is_builtin_variable(var) &&
13671 get_variable_data_type(var).basetype != SPIRType::Struct &&
13672 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
13673 {
13674 // If the builtin is not part of the active input builtin set, don't emit it.
13675 // Relevant for multiple entry-point modules which might declare unused builtins.
13676 if (!active_input_builtins.get(bit: bi_type) || !interface_variable_exists_in_entry_point(id: var_id))
13677 return;
13678
13679 // Remember this variable. We may need to correct its type.
13680 active_builtins.push_back(t: make_pair(x: &var, y&: bi_type));
13681
13682 if (is_direct_input_builtin(bi_type))
13683 {
13684 if (!ep_args.empty())
13685 ep_args += ", ";
13686
13687 // Handle HLSL-style 0-based vertex/instance index.
13688 builtin_declaration = true;
13689
13690 // Handle different MSL gl_TessCoord types. (float2, float3)
13691 if (bi_type == BuiltInTessCoord && get_entry_point().flags.get(bit: ExecutionModeQuads))
13692 ep_args += "float2 " + to_expression(id: var_id) + "In";
13693 else
13694 ep_args += builtin_type_decl(builtin: bi_type, id: var_id) + " " + to_expression(id: var_id);
13695
13696 ep_args += string(" [[") + builtin_qualifier(builtin: bi_type);
13697 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(bit: ExecutionModePostDepthCoverage))
13698 {
13699 if (!msl_options.supports_msl_version(major: 2))
13700 SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
13701 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 3))
13702 SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
13703 ep_args += ", post_depth_coverage";
13704 }
13705 ep_args += "]]";
13706 builtin_declaration = false;
13707 }
13708 }
13709
13710 if (has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInDispatchBase))
13711 {
13712 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
13713 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
13714 // assume we emitted it for a good reason.
13715 assert(msl_options.supports_msl_version(1, 2));
13716 if (!ep_args.empty())
13717 ep_args += ", ";
13718
13719 ep_args += type_to_glsl(type: get_variable_data_type(var)) + " " + to_expression(id: var_id) + " [[grid_origin]]";
13720 }
13721
13722 if (has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInStageInputSize))
13723 {
13724 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
13725 // which holds the number of vertices and instances to draw. If it's present,
13726 // assume we emitted it for a good reason.
13727 assert(msl_options.supports_msl_version(1, 2));
13728 if (!ep_args.empty())
13729 ep_args += ", ";
13730
13731 ep_args += type_to_glsl(type: get_variable_data_type(var)) + " " + to_expression(id: var_id) + " [[grid_size]]";
13732 }
13733 });
13734
13735 // Correct the types of all encountered active builtins. We couldn't do this before
13736 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
13737 // while iterating over IDs.
13738 for (auto &var : active_builtins)
13739 var.first->basetype = ensure_correct_builtin_type(type_id: var.first->basetype, builtin: var.second);
13740
13741 // Handle HLSL-style 0-based vertex/instance index.
13742 if (needs_base_vertex_arg == TriState::Yes)
13743 ep_args += built_in_func_arg(builtin: BuiltInBaseVertex, prefix_comma: !ep_args.empty());
13744
13745 if (needs_base_instance_arg == TriState::Yes)
13746 ep_args += built_in_func_arg(builtin: BuiltInBaseInstance, prefix_comma: !ep_args.empty());
13747
13748 if (capture_output_to_buffer)
13749 {
13750 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
13751 // specially because it needs to be a pointer, not a reference.
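		// E.g. (illustrative only; the struct/variable names and buffer indices come from the
		// entry point name and msl_options, not fixed by this code):
		//   device main0_out* spvOut [[buffer(28)]], device uint* spvIndirectParams [[buffer(29)]]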
13752 if (stage_out_var_id)
13753 {
13754 if (!ep_args.empty())
13755 ep_args += ", ";
13756 ep_args += join(ts: "device ", ts: type_to_glsl(type: get_stage_out_struct_type()), ts: "* ", ts&: output_buffer_var_name,
13757 ts: " [[buffer(", ts&: msl_options.shader_output_buffer_index, ts: ")]]");
13758 }
13759
13760 if (is_tesc_shader())
13761 {
13762 if (!ep_args.empty())
13763 ep_args += ", ";
13764 ep_args +=
13765 join(ts: "constant uint* spvIndirectParams [[buffer(", ts&: msl_options.indirect_params_buffer_index, ts: ")]]");
13766 }
13767 else if (stage_out_var_id &&
13768 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
13769 {
13770 if (!ep_args.empty())
13771 ep_args += ", ";
13772 ep_args +=
13773 join(ts: "device uint* spvIndirectParams [[buffer(", ts&: msl_options.indirect_params_buffer_index, ts: ")]]");
13774 }
13775
13776 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
13777 (active_input_builtins.get(bit: BuiltInVertexIndex) || active_input_builtins.get(bit: BuiltInVertexId)) &&
13778 msl_options.vertex_index_type != Options::IndexType::None)
13779 {
13780 // Add the index buffer so we can set gl_VertexIndex correctly.
13781 if (!ep_args.empty())
13782 ep_args += ", ";
13783 switch (msl_options.vertex_index_type)
13784 {
13785 case Options::IndexType::None:
13786 break;
13787 case Options::IndexType::UInt16:
13788 ep_args += join(ts: "const device ushort* ", ts&: index_buffer_var_name, ts: " [[buffer(",
13789 ts&: msl_options.shader_index_buffer_index, ts: ")]]");
13790 break;
13791 case Options::IndexType::UInt32:
13792 ep_args += join(ts: "const device uint* ", ts&: index_buffer_var_name, ts: " [[buffer(",
13793 ts&: msl_options.shader_index_buffer_index, ts: ")]]");
13794 break;
13795 }
13796 }
13797
13798 // Tessellation control shaders get three additional parameters:
13799 // a buffer to hold the per-patch data, a buffer to hold the per-patch
13800 // tessellation levels, and a block of workgroup memory to hold the
13801 // input control point data.
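		// E.g. (illustrative only; names and indices depend on the entry point and msl_options):
		//   device main0_patchOut* spvPatchOut [[buffer(27)]],
		//   device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(26)]],
		//   threadgroup main0_in* gl_in [[threadgroup(0)]]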
13802 if (is_tesc_shader())
13803 {
13804 if (patch_stage_out_var_id)
13805 {
13806 if (!ep_args.empty())
13807 ep_args += ", ";
13808 ep_args +=
13809 join(ts: "device ", ts: type_to_glsl(type: get_patch_stage_out_struct_type()), ts: "* ", ts&: patch_output_buffer_var_name,
13810 ts: " [[buffer(", ts: convert_to_string(t: msl_options.shader_patch_output_buffer_index), ts: ")]]");
13811 }
13812 if (!ep_args.empty())
13813 ep_args += ", ";
13814 ep_args += join(ts: "device ", ts: get_tess_factor_struct_name(), ts: "* ", ts&: tess_factor_buffer_var_name, ts: " [[buffer(",
13815 ts: convert_to_string(t: msl_options.shader_tess_factor_buffer_index), ts: ")]]");
13816
13817 // Initializer for tess factors must be handled specially since it's never declared as a normal variable.
13818 uint32_t outer_factor_initializer_id = 0;
13819 uint32_t inner_factor_initializer_id = 0;
13820 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
13821 if (!has_decoration(id: var.self, decoration: DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
13822 return;
13823
13824 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
13825 if (builtin == BuiltInTessLevelInner)
13826 inner_factor_initializer_id = var.initializer;
13827 else if (builtin == BuiltInTessLevelOuter)
13828 outer_factor_initializer_id = var.initializer;
13829 });
13830
13831 const SPIRConstant *c = nullptr;
13832
13833 if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(id: outer_factor_initializer_id)))
13834 {
13835 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
13836 entry_func.fixup_hooks_in.push_back(
13837 t: [=]()
13838 {
13839 uint32_t components = is_tessellating_triangles() ? 3 : 4;
13840 for (uint32_t i = 0; i < components; i++)
13841 {
13842 statement(ts: builtin_to_glsl(builtin: BuiltInTessLevelOuter, storage: StorageClassOutput), ts: "[", ts&: i,
13843 ts: "] = ", ts: "half(", ts: to_expression(id: c->subconstants[i]), ts: ");");
13844 }
13845 });
13846 }
13847
13848 if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(id: inner_factor_initializer_id)))
13849 {
13850 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
13851 if (is_tessellating_triangles())
13852 {
13853 entry_func.fixup_hooks_in.push_back(t: [=]() {
13854 statement(ts: builtin_to_glsl(builtin: BuiltInTessLevelInner, storage: StorageClassOutput), ts: " = ", ts: "half(",
13855 ts: to_expression(id: c->subconstants[0]), ts: ");");
13856 });
13857 }
13858 else
13859 {
13860 entry_func.fixup_hooks_in.push_back(t: [=]() {
13861 for (uint32_t i = 0; i < 2; i++)
13862 {
13863 statement(ts: builtin_to_glsl(builtin: BuiltInTessLevelInner, storage: StorageClassOutput), ts: "[", ts&: i, ts: "] = ",
13864 ts: "half(", ts: to_expression(id: c->subconstants[i]), ts: ");");
13865 }
13866 });
13867 }
13868 }
13869
13870 if (stage_in_var_id)
13871 {
13872 if (!ep_args.empty())
13873 ep_args += ", ";
13874 if (msl_options.multi_patch_workgroup)
13875 {
13876 ep_args += join(ts: "device ", ts: type_to_glsl(type: get_stage_in_struct_type()), ts: "* ", ts&: input_buffer_var_name,
13877 ts: " [[buffer(", ts: convert_to_string(t: msl_options.shader_input_buffer_index), ts: ")]]");
13878 }
13879 else
13880 {
13881 ep_args += join(ts: "threadgroup ", ts: type_to_glsl(type: get_stage_in_struct_type()), ts: "* ", ts&: input_wg_var_name,
13882 ts: " [[threadgroup(", ts: convert_to_string(t: msl_options.shader_input_wg_index), ts: ")]]");
13883 }
13884 }
13885 }
13886 }
13887 // Tessellation evaluation shaders get three additional parameters:
13888 // a buffer for the per-patch data, a buffer for the per-patch
13889 // tessellation levels, and a buffer for the control point data.
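	// E.g. (illustrative only; names and indices depend on the entry point and msl_options):
	//   const device main0_patchIn* spvPatchIn [[buffer(...)]],
	//   const device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(...)]],
	//   const device main0_in* spvIn [[buffer(...)]]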
13890 if (is_tese_shader() && msl_options.raw_buffer_tese_input)
13891 {
13892 if (patch_stage_in_var_id)
13893 {
13894 if (!ep_args.empty())
13895 ep_args += ", ";
13896 ep_args +=
13897 join(ts: "const device ", ts: type_to_glsl(type: get_patch_stage_in_struct_type()), ts: "* ", ts&: patch_input_buffer_var_name,
13898 ts: " [[buffer(", ts: convert_to_string(t: msl_options.shader_patch_input_buffer_index), ts: ")]]");
13899 }
13900
13901 if (tess_level_inner_var_id || tess_level_outer_var_id)
13902 {
13903 if (!ep_args.empty())
13904 ep_args += ", ";
13905 ep_args += join(ts: "const device ", ts: get_tess_factor_struct_name(), ts: "* ", ts&: tess_factor_buffer_var_name,
13906 ts: " [[buffer(", ts: convert_to_string(t: msl_options.shader_tess_factor_buffer_index), ts: ")]]");
13907 }
13908
13909 if (stage_in_var_id)
13910 {
13911 if (!ep_args.empty())
13912 ep_args += ", ";
13913 ep_args += join(ts: "const device ", ts: type_to_glsl(type: get_stage_in_struct_type()), ts: "* ", ts&: input_buffer_var_name,
13914 ts: " [[buffer(", ts: convert_to_string(t: msl_options.shader_input_buffer_index), ts: ")]]");
13915 }
13916 }
13917
13918 if (is_mesh_shader())
13919 {
13920 if (!ep_args.empty())
13921 ep_args += ", ";
13922 ep_args += join(ts: "spvMesh_t spvMesh");
13923 }
13924
13925 if (get_execution_model() == ExecutionModelTaskEXT)
13926 {
13927 if (!ep_args.empty())
13928 ep_args += ", ";
13929 ep_args += join(ts: "mesh_grid_properties spvMgp");
13930 }
13931}
13932
13933string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
13934{
13935 string ep_args = entry_point_arg_stage_in();
13936 Bitset claimed_bindings;
13937
13938 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
13939 {
13940 uint32_t id = argument_buffer_ids[i];
13941 if (id == 0)
13942 continue;
13943
13944 add_resource_name(id);
13945 auto &var = get<SPIRVariable>(id);
13946 auto &type = get_variable_data_type(var);
13947
13948 if (!ep_args.empty())
13949 ep_args += ", ";
13950
13951 // Check if the argument buffer binding itself has been remapped.
13952 uint32_t buffer_binding;
13953 auto itr = resource_bindings.find(x: { .model: get_entry_point().model, .desc_set: i, .binding: kArgumentBufferBinding });
13954 if (itr != end(cont&: resource_bindings))
13955 {
13956 buffer_binding = itr->second.first.msl_buffer;
13957 itr->second.second = true;
13958 }
13959 else
13960 {
13961 // As a fallback, directly map desc set <-> binding.
13962 // If that was taken, take the next buffer binding.
13963 if (claimed_bindings.get(bit: i))
13964 buffer_binding = next_metal_resource_index_buffer;
13965 else
13966 buffer_binding = i;
13967 }
13968
13969 claimed_bindings.set(buffer_binding);
13970
13971 ep_args += get_argument_address_space(argument: var) + " ";
13972
13973 if (recursive_inputs.count(x: type.self))
13974 ep_args += string("void* ") + to_restrict(id, space: true) + to_name(id) + "_vp";
13975 else
13976 ep_args += type_to_glsl(type) + "& " + to_restrict(id, space: true) + to_name(id);
13977
13978 ep_args += " [[buffer(" + convert_to_string(t: buffer_binding) + ")]]";
13979
13980 next_metal_resource_index_buffer = max(a: next_metal_resource_index_buffer, b: buffer_binding + 1);
13981 }
13982
13983 entry_point_args_discrete_descriptors(args&: ep_args);
13984 entry_point_args_builtin(ep_args);
13985
13986 if (!ep_args.empty() && append_comma)
13987 ep_args += ", ";
13988
13989 return ep_args;
13990}
13991
13992const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
13993{
13994 // Try by ID.
13995 {
13996 auto itr = constexpr_samplers_by_id.find(x: id);
13997 if (itr != end(cont: constexpr_samplers_by_id))
13998 return &itr->second;
13999 }
14000
14001 // Try by binding.
14002 {
14003 uint32_t desc_set = get_decoration(id, decoration: DecorationDescriptorSet);
14004 uint32_t binding = get_decoration(id, decoration: DecorationBinding);
14005
14006 auto itr = constexpr_samplers_by_binding.find(x: { .desc_set: desc_set, .binding: binding });
14007 if (itr != end(cont: constexpr_samplers_by_binding))
14008 return &itr->second;
14009 }
14010
14011 return nullptr;
14012}
14013
14014void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
14015{
14016 // Output resources, sorted by resource index & type
14017 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
14018 // with different order of buffers can result in issues with buffer assignments inside the driver.
14019 struct Resource
14020 {
14021 SPIRVariable *var;
14022 SPIRVariable *discrete_descriptor_alias;
14023 string name;
14024 SPIRType::BaseType basetype;
14025 uint32_t index;
14026 uint32_t plane;
14027 uint32_t secondary_index;
14028 };
14029
14030 SmallVector<Resource> resources;
14031
14032 entry_point_bindings.clear();
14033 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t var_id, SPIRVariable &var) {
14034 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
14035 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
14036 !is_hidden_variable(var))
14037 {
14038 auto &type = get_variable_data_type(var);
14039 uint32_t desc_set = get_decoration(id: var_id, decoration: DecorationDescriptorSet);
14040
14041 if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
14042 {
14043 if (descriptor_set_is_argument_buffer(desc_set))
14044 {
14045 if (is_var_runtime_size_array(var))
14046 {
14047 // Runtime arrays need to be wrapped in spvDescriptorArray from argument buffer payload.
14048 entry_point_bindings.push_back(t: &var);
14049 // We'll wrap this, so to_name() will always use non-qualified name.
14050 // We'll need the qualified name to create temporary variable instead.
14051 ir.meta[var_id].decoration.qualified_alias_explicit_override = true;
14052 }
14053 return;
14054 }
14055 }
14056
14057 // Handle descriptor aliasing of simple discrete cases.
14058 // We can handle aliasing of buffers by casting pointers.
14059 // The amount of aliasing we can perform for discrete descriptors is very limited.
14060 // For fully mutable-style aliasing, we need argument buffers where we can exploit the fact
14061 // that descriptors are all 8 bytes.
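			// E.g. two Block-decorated buffers that share set 0, binding 1 are declared once as
			//   device void* spvBufferAliasSet0Binding1 [[buffer(N)]]
			// and each aliased variable is then recovered by casting that pointer back to its own
			// struct type (tracked via buffer_aliases_discrete below). Names here are illustrative.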
14062 SPIRVariable *discrete_descriptor_alias = nullptr;
14063 if (var.storage == StorageClassUniform || var.storage == StorageClassStorageBuffer)
14064 {
14065 for (auto &resource : resources)
14066 {
14067 if (get_decoration(id: resource.var->self, decoration: DecorationDescriptorSet) ==
14068 get_decoration(id: var_id, decoration: DecorationDescriptorSet) &&
14069 get_decoration(id: resource.var->self, decoration: DecorationBinding) ==
14070 get_decoration(id: var_id, decoration: DecorationBinding) &&
14071 resource.basetype == SPIRType::Struct && type.basetype == SPIRType::Struct &&
14072 (resource.var->storage == StorageClassUniform ||
14073 resource.var->storage == StorageClassStorageBuffer))
14074 {
14075 discrete_descriptor_alias = resource.var;
14076 // Self-reference marks that we should declare the resource,
14077 // and it's being used as an alias (so we can emit void* instead).
14078 resource.discrete_descriptor_alias = resource.var;
14079 // Need to promote interlocked usage so that the primary declaration is correct.
14080 if (interlocked_resources.count(x: var_id))
14081 interlocked_resources.insert(x: resource.var->self);
14082 break;
14083 }
14084 }
14085 }
14086
14087 const MSLConstexprSampler *constexpr_sampler = nullptr;
14088 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
14089 {
14090 constexpr_sampler = find_constexpr_sampler(id: var_id);
14091 if (constexpr_sampler)
14092 {
14093 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
14094 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
14095 }
14096 }
14097
14098 // Emulate texture2D atomic operations
14099 uint32_t secondary_index = 0;
14100 if (atomic_image_vars_emulated.count(x: var.self))
14101 {
14102 secondary_index = get_metal_resource_index(var, basetype: SPIRType::AtomicCounter, plane: 0);
14103 }
14104
14105 if (type.basetype == SPIRType::SampledImage)
14106 {
14107 add_resource_name(id: var_id);
14108
14109 uint32_t plane_count = 1;
14110 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
14111 plane_count = constexpr_sampler->planes;
14112
14113 entry_point_bindings.push_back(t: &var);
14114 for (uint32_t i = 0; i < plane_count; i++)
14115 resources.push_back(t: {.var: &var, .discrete_descriptor_alias: discrete_descriptor_alias, .name: to_name(id: var_id), .basetype: SPIRType::Image,
14116 .index: get_metal_resource_index(var, basetype: SPIRType::Image, plane: i), .plane: i, .secondary_index: secondary_index });
14117
14118 if (type.image.dim != DimBuffer && !constexpr_sampler)
14119 {
14120 resources.push_back(t: {.var: &var, .discrete_descriptor_alias: discrete_descriptor_alias, .name: to_sampler_expression(id: var_id), .basetype: SPIRType::Sampler,
14121 .index: get_metal_resource_index(var, basetype: SPIRType::Sampler), .plane: 0, .secondary_index: 0 });
14122 }
14123 }
14124 else if (!constexpr_sampler)
14125 {
14126 // constexpr samplers are not declared as resources.
14127 add_resource_name(id: var_id);
14128
14129 // Don't allocate resource indices for aliases.
14130 uint32_t resource_index = ~0u;
14131 if (!discrete_descriptor_alias)
14132 resource_index = get_metal_resource_index(var, basetype: type.basetype);
14133
14134 entry_point_bindings.push_back(t: &var);
14135 resources.push_back(t: {.var: &var, .discrete_descriptor_alias: discrete_descriptor_alias, .name: to_name(id: var_id), .basetype: type.basetype,
14136 .index: resource_index, .plane: 0, .secondary_index: secondary_index });
14137 }
14138 }
14139 });
14140
14141 stable_sort(first: resources.begin(), last: resources.end(),
14142 comp: [](const Resource &lhs, const Resource &rhs)
14143 { return tie(args: lhs.basetype, args: lhs.index) < tie(args: rhs.basetype, args: rhs.index); });
14144
14145 for (auto &r : resources)
14146 {
14147 auto &var = *r.var;
14148 auto &type = get_variable_data_type(var);
14149
14150 uint32_t var_id = var.self;
14151
14152 switch (r.basetype)
14153 {
14154 case SPIRType::Struct:
14155 {
14156 auto &m = ir.meta[type.self];
14157 if (m.members.size() == 0)
14158 break;
14159
14160 if (r.discrete_descriptor_alias)
14161 {
14162 if (r.var == r.discrete_descriptor_alias)
14163 {
14164 auto primary_name = join(ts: "spvBufferAliasSet",
14165 ts: get_decoration(id: var_id, decoration: DecorationDescriptorSet),
14166 ts: "Binding",
14167 ts: get_decoration(id: var_id, decoration: DecorationBinding));
14168
14169 // Declare the primary alias as void*
14170 if (!ep_args.empty())
14171 ep_args += ", ";
14172 ep_args += get_argument_address_space(argument: var) + " void* " + primary_name;
14173 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")";
14174 if (interlocked_resources.count(x: var_id))
14175 ep_args += ", raster_order_group(0)";
14176 ep_args += "]]";
14177 }
14178
14179 buffer_aliases_discrete.push_back(t: r.var->self);
14180 }
14181 else if (!type.array.empty())
14182 {
14183 if (type.array.size() > 1)
14184 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
14185
14186 is_using_builtin_array = true;
14187 if (is_var_runtime_size_array(var))
14188 {
14189 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
14190 if (!ep_args.empty())
14191 ep_args += ", ";
14192 const bool ssbo = has_decoration(id: type.self, decoration: DecorationBufferBlock);
14193 if ((var.storage == spv::StorageClassStorageBuffer || ssbo) &&
14194 msl_options.runtime_array_rich_descriptor)
14195 {
14196 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableSizedDescriptor);
14197 ep_args += "const device spvBufferDescriptor<" + get_argument_address_space(argument: var) + " " +
14198 type_to_glsl(type) + "*>* ";
14199 }
14200 else
14201 {
14202 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
14203 ep_args += "const device spvDescriptor<" + get_argument_address_space(argument: var) + " " +
14204 type_to_glsl(type) + "*>* ";
14205 }
14206 ep_args += to_restrict(id: var_id, space: true) + r.name + "_";
14207 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")";
14208 if (interlocked_resources.count(x: var_id))
14209 ep_args += ", raster_order_group(0)";
14210 ep_args += "]]";
14211 }
14212 else
14213 {
14214 uint32_t array_size = get_resource_array_size(type, id: var_id);
14215 for (uint32_t i = 0; i < array_size; ++i)
14216 {
14217 if (!ep_args.empty())
14218 ep_args += ", ";
14219 ep_args += get_argument_address_space(argument: var) + " " + type_to_glsl(type) + "* " +
14220 to_restrict(id: var_id, space: true) + r.name + "_" + convert_to_string(t: i);
14221 ep_args += " [[buffer(" + convert_to_string(t: r.index + i) + ")";
14222 if (interlocked_resources.count(x: var_id))
14223 ep_args += ", raster_order_group(0)";
14224 ep_args += "]]";
14225 }
14226 }
14227 is_using_builtin_array = false;
14228 }
14229 else
14230 {
14231 if (!ep_args.empty())
14232 ep_args += ", ";
14233 ep_args += get_argument_address_space(argument: var) + " ";
14234
14235 if (recursive_inputs.count(x: type.self))
14236 ep_args += string("void* ") + to_restrict(id: var_id, space: true) + r.name + "_vp";
14237 else
14238 ep_args += type_to_glsl(type) + "& " + to_restrict(id: var_id, space: true) + r.name;
14239
14240 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")";
14241 if (interlocked_resources.count(x: var_id))
14242 ep_args += ", raster_order_group(0)";
14243 ep_args += "]]";
14244 }
14245 break;
14246 }
14247 case SPIRType::Sampler:
14248 if (!ep_args.empty())
14249 ep_args += ", ";
14250 ep_args += sampler_type(type, id: var_id, member: false) + " " + r.name;
14251 if (is_var_runtime_size_array(var))
14252 ep_args += "_ [[buffer(" + convert_to_string(t: r.index) + ")]]";
14253 else
14254 ep_args += " [[sampler(" + convert_to_string(t: r.index) + ")]]";
14255 break;
14256 case SPIRType::Image:
14257 {
14258 if (!ep_args.empty())
14259 ep_args += ", ";
14260
14261 // Use Metal's native frame-buffer fetch API for subpass inputs.
14262 const auto &basetype = get<SPIRType>(id: var.basetype);
14263 if (!type_is_msl_framebuffer_fetch(type: basetype))
14264 {
14265 ep_args += image_type_glsl(type, id: var_id, member: false) + " " + r.name;
14266 if (r.plane > 0)
14267 ep_args += join(ts&: plane_name_suffix, ts&: r.plane);
14268
14269 if (is_var_runtime_size_array(var))
14270 ep_args += "_ [[buffer(" + convert_to_string(t: r.index) + ")";
14271 else
14272 ep_args += " [[texture(" + convert_to_string(t: r.index) + ")";
14273
14274 if (interlocked_resources.count(x: var_id))
14275 ep_args += ", raster_order_group(0)";
14276 ep_args += "]]";
14277 }
14278 else
14279 {
14280 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 3))
14281 SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
14282 ep_args += image_type_glsl(type, id: var_id, member: false) + " " + r.name;
14283 ep_args += " [[color(" + convert_to_string(t: r.index) + ")]]";
14284 }
14285
14286 // Emulate texture2D atomic operations
14287 if (atomic_image_vars_emulated.count(x: var.self))
14288 {
14289 auto &flags = ir.get_decoration_bitset(id: var.self);
14290 const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : "";
14291 ep_args += join(ts: ", ", ts&: cv_flags, ts: "device atomic_", ts: type_to_glsl(type: get<SPIRType>(id: basetype.image.type), id: 0));
14292 ep_args += "* " + r.name + "_atomic";
14293 ep_args += " [[buffer(" + convert_to_string(t: r.secondary_index) + ")";
14294 if (interlocked_resources.count(x: var_id))
14295 ep_args += ", raster_order_group(0)";
14296 ep_args += "]]";
14297 }
14298 break;
14299 }
14300 case SPIRType::AccelerationStructure:
14301 {
14302 if (is_var_runtime_size_array(var))
14303 {
14304 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
14305 const auto &parent_type = get<SPIRType>(id: type.parent_type);
14306 if (!ep_args.empty())
14307 ep_args += ", ";
14308 ep_args += "const device spvDescriptor<" + type_to_glsl(type: parent_type) + ">* " +
14309 to_restrict(id: var_id, space: true) + r.name + "_";
14310 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")]]";
14311 }
14312 else
14313 {
14314 if (!ep_args.empty())
14315 ep_args += ", ";
14316 ep_args += type_to_glsl(type, id: var_id) + " " + r.name;
14317 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")]]";
14318 }
14319 break;
14320 }
14321 default:
14322 if (!ep_args.empty())
14323 ep_args += ", ";
14324 if (!type.pointer)
14325 ep_args += get_type_address_space(type: get<SPIRType>(id: var.basetype), id: var_id) + " " +
14326 type_to_glsl(type, id: var_id) + "& " + r.name;
14327 else
14328 ep_args += type_to_glsl(type, id: var_id) + " " + r.name;
14329 ep_args += " [[buffer(" + convert_to_string(t: r.index) + ")";
14330 if (interlocked_resources.count(x: var_id))
14331 ep_args += ", raster_order_group(0)";
14332 ep_args += "]]";
14333 break;
14334 }
14335 }
14336}
14337
// Returns a string containing a comma-delimited list of args for the entry point function
// This is the "classic" method of MSL 1 when we don't have argument buffer support.
string CompilerMSL::entry_point_args_classic(bool append_comma)
{
	string ep_args = entry_point_arg_stage_in();
	entry_point_args_discrete_descriptors(ep_args);
	entry_point_args_builtin(ep_args);

	if (!ep_args.empty() && append_comma)
		ep_args += ", ";

	return ep_args;
}
14351
14352void CompilerMSL::fix_up_shader_inputs_outputs()
14353{
14354 auto &entry_func = this->get<SPIRFunction>(id: ir.default_entry_point);
14355
14356 // Emit a guard to ensure we don't execute beyond the last vertex.
14357 // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
14358 // tessellation control shaders do, so early returns should be OK. We may need to revisit this
14359 // if it ever becomes possible to use barriers from a vertex shader.
14360 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
14361 {
14362 entry_func.fixup_hooks_in.push_back(t: [this]() {
14363 statement(ts: "if (any(", ts: to_expression(id: builtin_invocation_id_id),
14364 ts: " >= ", ts: to_expression(id: builtin_stage_input_size_id), ts: "))");
14365 statement(ts: " return;");
14366 });
14367 }
14368
14369 if (is_mesh_shader())
14370 {
14371 // If shader doesn't call SetMeshOutputsEXT, nothing should be rendered.
14372 // No need to barrier after this, because only thread 0 writes to this later.
14373 entry_func.fixup_hooks_in.push_back(t: [this]() { statement(ts: "if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;"); });
14374 entry_func.fixup_hooks_out.push_back(t: [this]() { emit_mesh_outputs(); });
14375 }
14376
	// Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
14378 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t, SPIRVariable &var) {
14379 auto &type = get_variable_data_type(var);
14380 uint32_t var_id = var.self;
14381 bool ssbo = has_decoration(id: type.self, decoration: DecorationBufferBlock);
14382
14383 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
14384 {
14385 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
14386 {
14387 entry_func.fixup_hooks_in.push_back(t: [this, &type, &var, var_id]() {
14388 bool is_array_type = !type.array.empty();
14389
14390 uint32_t desc_set = get_decoration(id: var_id, decoration: DecorationDescriptorSet);
14391 if (descriptor_set_is_argument_buffer(desc_set))
14392 {
14393 statement(ts: "constant uint", ts: is_array_type ? "* " : "& ", ts: to_swizzle_expression(id: var_id),
14394 ts: is_array_type ? " = &" : " = ", ts: to_name(id: argument_buffer_ids[desc_set]),
14395 ts: ".spvSwizzleConstants", ts: "[",
14396 ts: convert_to_string(t: get_metal_resource_index(var, basetype: SPIRType::Image)), ts: "];");
14397 }
14398 else
14399 {
14400 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
14401 statement(ts: "constant uint", ts: is_array_type ? "* " : "& ", ts: to_swizzle_expression(id: var_id),
14402 ts: is_array_type ? " = &" : " = ", ts: to_name(id: swizzle_buffer_id), ts: "[",
14403 ts: convert_to_string(t: get_metal_resource_index(var, basetype: SPIRType::Image)), ts: "];");
14404 }
14405 });
14406 }
14407 }
14408 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
14409 !is_hidden_variable(var))
14410 {
14411 if (buffer_requires_array_length(id: var.self))
14412 {
14413 entry_func.fixup_hooks_in.push_back(
14414 t: [this, &type, &var, var_id]()
14415 {
14416 bool is_array_type = !type.array.empty() && !is_var_runtime_size_array(var);
14417
14418 uint32_t desc_set = get_decoration(id: var_id, decoration: DecorationDescriptorSet);
14419 if (descriptor_set_is_argument_buffer(desc_set))
14420 {
14421 statement(ts: "constant uint", ts: is_array_type ? "* " : "& ", ts: to_buffer_size_expression(id: var_id),
14422 ts: is_array_type ? " = &" : " = ", ts: to_name(id: argument_buffer_ids[desc_set]),
14423 ts: ".spvBufferSizeConstants", ts: "[",
14424 ts: convert_to_string(t: get_metal_resource_index(var, basetype: SPIRType::UInt)), ts: "];");
14425 }
14426 else
14427 {
						// If we have an array of buffers, we need to be able to index into it, so take a pointer instead.
14429 statement(ts: "constant uint", ts: is_array_type ? "* " : "& ", ts: to_buffer_size_expression(id: var_id),
14430 ts: is_array_type ? " = &" : " = ", ts: to_name(id: buffer_size_buffer_id), ts: "[",
14431 ts: convert_to_string(t: get_metal_resource_index(var, basetype: type.basetype)), ts: "];");
14432 }
14433 });
14434 }
14435 }
14436
14437 if (!msl_options.argument_buffers &&
14438 msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
14439 (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
14440 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
14441 {
14442 recursive_inputs.insert(x: type.self);
14443 entry_func.fixup_hooks_in.push_back(t: [this, &type, &var, var_id]() {
14444 auto addr_space = get_argument_address_space(argument: var);
14445 auto var_name = to_name(id: var_id);
14446 statement(ts&: addr_space, ts: " auto& ", ts: to_restrict(id: var_id, space: true), ts&: var_name,
14447 ts: " = *(", ts&: addr_space, ts: " ", ts: type_to_glsl(type), ts: "*)", ts&: var_name, ts: "_vp;");
14448 });
14449 }
14450 });
14451
14452 // Builtin variables
14453 ir.for_each_typed_id<SPIRVariable>(op: [this, &entry_func](uint32_t, SPIRVariable &var) {
14454 uint32_t var_id = var.self;
14455 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
14456
14457 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
14458 return;
14459 if (!interface_variable_exists_in_entry_point(id: var.self))
14460 return;
14461
14462 if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bit: bi_type))
14463 {
14464 switch (bi_type)
14465 {
14466 case BuiltInSamplePosition:
14467 entry_func.fixup_hooks_in.push_back(t: [=]() {
14468 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = get_sample_position(",
14469 ts: to_expression(id: builtin_sample_id_id), ts: ");");
14470 });
14471 break;
14472 case BuiltInFragCoord:
14473 if (is_sample_rate())
14474 {
14475 entry_func.fixup_hooks_in.push_back(t: [=]() {
14476 statement(ts: to_expression(id: var_id), ts: ".xy += get_sample_position(",
14477 ts: to_expression(id: builtin_sample_id_id), ts: ") - 0.5;");
14478 });
14479 }
14480 break;
14481 case BuiltInInvocationId:
14482 // This is direct-mapped without multi-patch workgroups.
14483 if (!is_tesc_shader() || !msl_options.multi_patch_workgroup)
14484 break;
14485
14486 entry_func.fixup_hooks_in.push_back(t: [=]() {
14487 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14488 ts: to_expression(id: builtin_invocation_id_id), ts: ".x % ", ts&: this->get_entry_point().output_vertices,
14489 ts: ";");
14490 });
14491 break;
14492 case BuiltInPrimitiveId:
14493 // This is natively supported by fragment and tessellation evaluation shaders.
14494 // In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
14495 if (!is_tesc_shader() || !msl_options.multi_patch_workgroup)
14496 break;
14497
14498 entry_func.fixup_hooks_in.push_back(t: [=]() {
14499 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = min(",
14500 ts: to_expression(id: builtin_invocation_id_id), ts: ".x / ", ts&: this->get_entry_point().output_vertices,
14501 ts: ", spvIndirectParams[1] - 1);");
14502 });
14503 break;
14504 case BuiltInPatchVertices:
14505 if (is_tese_shader())
14506 {
14507 if (msl_options.raw_buffer_tese_input)
14508 {
14509 entry_func.fixup_hooks_in.push_back(
14510 t: [=]() {
14511 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14512 ts&: get_entry_point().output_vertices, ts: ";");
14513 });
14514 }
14515 else
14516 {
14517 entry_func.fixup_hooks_in.push_back(
14518 t: [=]()
14519 {
14520 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14521 ts: to_expression(id: patch_stage_in_var_id), ts: ".gl_in.size();");
14522 });
14523 }
14524 }
14525 else
14526 {
14527 entry_func.fixup_hooks_in.push_back(t: [=]() {
14528 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = spvIndirectParams[0];");
14529 });
14530 }
14531 break;
14532 case BuiltInTessCoord:
14533 if (get_entry_point().flags.get(bit: ExecutionModeQuads))
14534 {
14535 // The entry point will only have a float2 TessCoord variable.
14536 // Pad to float3.
14537 entry_func.fixup_hooks_in.push_back(t: [=]() {
14538 auto name = builtin_to_glsl(builtin: BuiltInTessCoord, storage: StorageClassInput);
14539 statement(ts: "float3 " + name + " = float3(" + name + "In.x, " + name + "In.y, 0.0);");
14540 });
14541 }
14542
14543 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
14544 // MoltenVK will just reverse the winding order instead.
14545 if (msl_options.tess_domain_origin_lower_left && !is_tessellating_triangles())
14546 {
14547 string tc = to_expression(id: var_id);
14548 entry_func.fixup_hooks_in.push_back(t: [=]() { statement(ts: tc, ts: ".y = 1.0 - ", ts: tc, ts: ".y;"); });
14549 }
14550 break;
14551 case BuiltInSubgroupId:
14552 if (!msl_options.emulate_subgroups)
14553 break;
14554 // For subgroup emulation, this is the same as the local invocation index.
14555 entry_func.fixup_hooks_in.push_back(t: [=]() {
14556 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14557 ts: to_expression(id: builtin_local_invocation_index_id), ts: ";");
14558 });
14559 break;
14560 case BuiltInNumSubgroups:
14561 if (!msl_options.emulate_subgroups)
14562 break;
14563 // For subgroup emulation, this is the same as the workgroup size.
14564 entry_func.fixup_hooks_in.push_back(t: [=]() {
14565 auto &type = expression_type(id: builtin_workgroup_size_id);
14566 string size_expr = to_expression(id: builtin_workgroup_size_id);
14567 if (type.vecsize >= 3)
14568 size_expr = join(ts&: size_expr, ts: ".x * ", ts&: size_expr, ts: ".y * ", ts&: size_expr, ts: ".z");
14569 else if (type.vecsize == 2)
14570 size_expr = join(ts&: size_expr, ts: ".x * ", ts&: size_expr, ts: ".y");
14571 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ", ts&: size_expr, ts: ";");
14572 });
14573 break;
14574 case BuiltInSubgroupLocalInvocationId:
14575 if (!msl_options.emulate_subgroups)
14576 break;
14577 // For subgroup emulation, assume subgroups of size 1.
14578 entry_func.fixup_hooks_in.push_back(
14579 t: [=]() { statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = 0;"); });
14580 break;
14581 case BuiltInSubgroupSize:
14582 if (msl_options.emulate_subgroups)
14583 {
14584 // For subgroup emulation, assume subgroups of size 1.
14585 entry_func.fixup_hooks_in.push_back(
14586 t: [=]() { statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = 1;"); });
14587 }
14588 else if (msl_options.fixed_subgroup_size != 0)
14589 {
14590 entry_func.fixup_hooks_in.push_back(t: [=]() {
14591 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14592 ts&: msl_options.fixed_subgroup_size, ts: ";");
14593 });
14594 }
14595 break;
14596 case BuiltInSubgroupEqMask:
14597 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 2))
14598 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
14599 if (!msl_options.supports_msl_version(major: 2, minor: 1))
14600 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
14601 entry_func.fixup_hooks_in.push_back(t: [=]() {
14602 if (msl_options.is_ios())
14603 {
14604 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ", ts: "uint4(1 << ",
14605 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", uint3(0));");
14606 }
14607 else
14608 {
14609 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14610 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " >= 32 ? uint4(0, (1 << (",
14611 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " - 32)), uint2(0)) : uint4(1 << ",
14612 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", uint3(0));");
14613 }
14614 });
14615 break;
14616 case BuiltInSubgroupGeMask:
14617 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 2))
14618 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
14619 if (!msl_options.supports_msl_version(major: 2, minor: 1))
14620 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
14621 if (msl_options.fixed_subgroup_size != 0)
14622 add_spv_func_and_recompile(spv_func: SPVFuncImplSubgroupBallot);
14623 entry_func.fixup_hooks_in.push_back(t: [=]() {
14624 // Case where index < 32, size < 32:
14625 // mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
14626 // mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
14627 // Case where index < 32 but size >= 32:
14628 // mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
14629 // mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
14630 // Case where index >= 32:
14631 // mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
14632 // mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
14633 // This is expressed without branches to avoid divergent
14634 // control flow--hence the complicated min/max expressions.
14635 // This is further complicated by the fact that if you attempt
14636 // to bfi/bfe out-of-bounds on Metal, undefined behavior is the
14637 // result.
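				// Worked example (illustrative): index = 5, size = 16 (single 32-bit word):
				//   mask0 = insert_bits(0u, 0xFFFFFFFF, 5, 16 - 5) = 0x0000FFE0 (bits 5..15 set)
				//   mask1 = insert_bits(0u, 0xFFFFFFFF, 0, 0)      = 0
				// i.e. the GE mask covers invocations 5..15 of the subgroup.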
14638 if (msl_options.fixed_subgroup_size > 32)
14639 {
14640 // Don't use the subgroup size variable with fixed subgroup sizes,
14641 // since the variables could be defined in the wrong order.
14642 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14643 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
14644 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", 32u), (uint)max(32 - (int)",
14645 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14646 ts: ", 0)), insert_bits(0u, 0xFFFFFFFF,"
14647 " (uint)max((int)",
14648 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " - 32, 0), ",
14649 ts&: msl_options.fixed_subgroup_size, ts: " - max(",
14650 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14651 ts: ", 32u)), uint2(0));");
14652 }
14653 else if (msl_options.fixed_subgroup_size != 0)
14654 {
14655 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14656 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
14657 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", ",
14658 ts&: msl_options.fixed_subgroup_size, ts: " - ",
14659 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14660 ts: "), uint3(0));");
14661 }
14662 else if (msl_options.is_ios())
14663 {
14664 // On iOS, the SIMD-group size will currently never exceed 32.
14665 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14666 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
14667 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", ",
14668 ts: to_expression(id: builtin_subgroup_size_id), ts: " - ",
14669 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: "), uint3(0));");
14670 }
14671 else
14672 {
14673 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14674 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
14675 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", 32u), (uint)max(min((int)",
14676 ts: to_expression(id: builtin_subgroup_size_id), ts: ", 32) - (int)",
14677 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14678 ts: ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
14679 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " - 32, 0), (uint)max((int)",
14680 ts: to_expression(id: builtin_subgroup_size_id), ts: " - (int)max(",
14681 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: ", 32u), 0)), uint2(0));");
14682 }
14683 });
14684 break;
14685 case BuiltInSubgroupGtMask:
14686 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 2))
14687 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
14688 if (!msl_options.supports_msl_version(major: 2, minor: 1))
14689 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
14690 add_spv_func_and_recompile(spv_func: SPVFuncImplSubgroupBallot);
14691 entry_func.fixup_hooks_in.push_back(t: [=]() {
14692 // The same logic applies here, except now the index is one
14693 // more than the subgroup invocation ID.
14694 if (msl_options.fixed_subgroup_size > 32)
14695 {
14696 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14697 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
14698 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1, 32u), (uint)max(32 - (int)",
14699 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14700 ts: " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
14701 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1 - 32, 0), ",
14702 ts&: msl_options.fixed_subgroup_size, ts: " - max(",
14703 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14704 ts: " + 1, 32u)), uint2(0));");
14705 }
14706 else if (msl_options.fixed_subgroup_size != 0)
14707 {
14708 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14709 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
14710 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1, ",
14711 ts&: msl_options.fixed_subgroup_size, ts: " - ",
14712 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14713 ts: " - 1), uint3(0));");
14714 }
14715 else if (msl_options.is_ios())
14716 {
14717 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14718 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
14719 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1, ",
14720 ts: to_expression(id: builtin_subgroup_size_id), ts: " - ",
14721 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " - 1), uint3(0));");
14722 }
14723 else
14724 {
14725 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14726 ts: " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
14727 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1, 32u), (uint)max(min((int)",
14728 ts: to_expression(id: builtin_subgroup_size_id), ts: ", 32) - (int)",
14729 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14730 ts: " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
14731 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1 - 32, 0), (uint)max((int)",
14732 ts: to_expression(id: builtin_subgroup_size_id), ts: " - (int)max(",
14733 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1, 32u), 0)), uint2(0));");
14734 }
14735 });
14736 break;
14737 case BuiltInSubgroupLeMask:
14738 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 2))
14739 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
14740 if (!msl_options.supports_msl_version(major: 2, minor: 1))
14741 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
14742 add_spv_func_and_recompile(spv_func: SPVFuncImplSubgroupBallot);
14743 entry_func.fixup_hooks_in.push_back(t: [=]() {
14744 if (msl_options.is_ios())
14745 {
14746 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14747 ts: " = uint4(extract_bits(0xFFFFFFFF, 0, ",
14748 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1), uint3(0));");
14749 }
14750 else
14751 {
14752 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14753 ts: " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
14754 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14755 ts: " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
14756 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " + 1 - 32, 0)), uint2(0));");
14757 }
14758 });
14759 break;
14760 case BuiltInSubgroupLtMask:
14761 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 2))
14762 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
14763 if (!msl_options.supports_msl_version(major: 2, minor: 1))
14764 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
14765 add_spv_func_and_recompile(spv_func: SPVFuncImplSubgroupBallot);
14766 entry_func.fixup_hooks_in.push_back(t: [=]() {
14767 if (msl_options.is_ios())
14768 {
14769 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14770 ts: " = uint4(extract_bits(0xFFFFFFFF, 0, ",
14771 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: "), uint3(0));");
14772 }
14773 else
14774 {
14775 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id),
14776 ts: " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
14777 ts: to_expression(id: builtin_subgroup_invocation_id_id),
14778 ts: ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
14779 ts: to_expression(id: builtin_subgroup_invocation_id_id), ts: " - 32, 0)), uint2(0));");
14780 }
14781 });
14782 break;
14783 case BuiltInViewIndex:
14784 if (!msl_options.multiview)
14785 {
14786 // According to the Vulkan spec, when not running under a multiview
14787 // render pass, ViewIndex is 0.
14788 entry_func.fixup_hooks_in.push_back(t: [=]() {
14789 statement(ts: "const ", ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = 0;");
14790 });
14791 }
14792 else if (msl_options.view_index_from_device_index)
14793 {
14794 // In this case, we take the view index from that of the device we're running on.
14795 entry_func.fixup_hooks_in.push_back(t: [=]() {
14796 statement(ts: "const ", ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14797 ts&: msl_options.device_index, ts: ";");
14798 });
14799 // We actually don't want to set the render_target_array_index here.
14800 // Since every physical device is rendering a different view,
14801 // there's no need for layered rendering here.
14802 }
14803 else if (!msl_options.multiview_layered_rendering)
14804 {
14805 // In this case, the views are rendered one at a time. The view index, then,
14806 // is just the first part of the "view mask".
14807 entry_func.fixup_hooks_in.push_back(t: [=]() {
14808 statement(ts: "const ", ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14809 ts: to_expression(id: view_mask_buffer_id), ts: "[0];");
14810 });
14811 }
14812 else if (get_execution_model() == ExecutionModelFragment)
14813 {
14814 // Because we adjusted the view index in the vertex shader, we have to
14815 // adjust it back here.
14816 entry_func.fixup_hooks_in.push_back(t: [=]() {
14817 statement(ts: to_expression(id: var_id), ts: " += ", ts: to_expression(id: view_mask_buffer_id), ts: "[0];");
14818 });
14819 }
14820 else if (get_execution_model() == ExecutionModelVertex)
14821 {
14822 // Metal provides no special support for multiview, so we smuggle
14823 // the view index in the instance index.
14824 entry_func.fixup_hooks_in.push_back(t: [=]() {
14825 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14826 ts: to_expression(id: view_mask_buffer_id), ts: "[0] + (", ts: to_expression(id: builtin_instance_idx_id),
14827 ts: " - ", ts: to_expression(id: builtin_base_instance_id), ts: ") % ",
14828 ts: to_expression(id: view_mask_buffer_id), ts: "[1];");
14829 statement(ts: to_expression(id: builtin_instance_idx_id), ts: " = (",
14830 ts: to_expression(id: builtin_instance_idx_id), ts: " - ",
14831 ts: to_expression(id: builtin_base_instance_id), ts: ") / ", ts: to_expression(id: view_mask_buffer_id),
14832 ts: "[1] + ", ts: to_expression(id: builtin_base_instance_id), ts: ";");
14833 });
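			// A rough sketch of the emitted fixup (variable names are illustrative only):
			//   uint gl_ViewIndex = spvViewMask[0] + (gl_InstanceIndex - gl_BaseInstance) % spvViewMask[1];
			//   gl_InstanceIndex = (gl_InstanceIndex - gl_BaseInstance) / spvViewMask[1] + gl_BaseInstance;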
14834 // In addition to setting the variable itself, we also need to
14835 // set the render_target_array_index with it on output. We have to
14836 // offset this by the base view index, because Metal isn't in on
14837 // our little game here.
14838 entry_func.fixup_hooks_out.push_back(t: [=]() {
14839 statement(ts: to_expression(id: builtin_layer_id), ts: " = ", ts: to_expression(id: var_id), ts: " - ",
14840 ts: to_expression(id: view_mask_buffer_id), ts: "[0];");
14841 });
14842 }
14843 break;
14844 case BuiltInDeviceIndex:
14845 // Metal pipelines belong to the devices which create them, so we'll
14846 // need to create a MTLPipelineState for every MTLDevice in a grouped
14847 // VkDevice. We can assume, then, that the device index is constant.
14848 entry_func.fixup_hooks_in.push_back(t: [=]() {
14849 statement(ts: "const ", ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14850 ts&: msl_options.device_index, ts: ";");
14851 });
14852 break;
14853 case BuiltInWorkgroupId:
14854 if (!msl_options.dispatch_base || !active_input_builtins.get(bit: BuiltInWorkgroupId))
14855 break;
14856
14857 // The vkCmdDispatchBase() command lets the client set the base value
14858 // of WorkgroupId. Metal has no direct equivalent; we must make this
14859 // adjustment ourselves.
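		// For illustration (names are hypothetical), the emitted hook amounts to:
		//   gl_WorkGroupID += spvDispatchBase;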
14860 entry_func.fixup_hooks_in.push_back(t: [=]() {
14861 statement(ts: to_expression(id: var_id), ts: " += ", ts: to_dereferenced_expression(id: builtin_dispatch_base_id), ts: ";");
14862 });
14863 break;
14864 case BuiltInGlobalInvocationId:
14865 if (!msl_options.dispatch_base || !active_input_builtins.get(bit: BuiltInGlobalInvocationId))
14866 break;
14867
14868 // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
14869 // This needs to be adjusted too.
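		// For illustration (names and the literal workgroup size are hypothetical), the emitted
		// hook amounts to:
		//   gl_GlobalInvocationID += spvDispatchBase * uint3(8, 8, 1);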
14870 entry_func.fixup_hooks_in.push_back(t: [=]() {
14871 auto &execution = this->get_entry_point();
14872 uint32_t workgroup_size_id = execution.workgroup_size.constant;
14873 if (workgroup_size_id)
14874 statement(ts: to_expression(id: var_id), ts: " += ", ts: to_dereferenced_expression(id: builtin_dispatch_base_id),
14875 ts: " * ", ts: to_expression(id: workgroup_size_id), ts: ";");
14876 else
14877 statement(ts: to_expression(id: var_id), ts: " += ", ts: to_dereferenced_expression(id: builtin_dispatch_base_id),
14878 ts: " * uint3(", ts&: execution.workgroup_size.x, ts: ", ", ts&: execution.workgroup_size.y, ts: ", ",
14879 ts&: execution.workgroup_size.z, ts: ");");
14880 });
14881 break;
14882 case BuiltInVertexId:
14883 case BuiltInVertexIndex:
14884 // This is direct-mapped normally.
14885 if (!msl_options.vertex_for_tessellation)
14886 break;
14887
14888 entry_func.fixup_hooks_in.push_back(t: [=]() {
14889 builtin_declaration = true;
14890 switch (msl_options.vertex_index_type)
14891 {
14892 case Options::IndexType::None:
14893 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14894 ts: to_expression(id: builtin_invocation_id_id), ts: ".x + ",
14895 ts: to_expression(id: builtin_dispatch_base_id), ts: ".x;");
14896 break;
14897 case Options::IndexType::UInt16:
14898 case Options::IndexType::UInt32:
14899 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ", ts&: index_buffer_var_name,
14900 ts: "[", ts: to_expression(id: builtin_invocation_id_id), ts: ".x] + ",
14901 ts: to_expression(id: builtin_dispatch_base_id), ts: ".x;");
14902 break;
14903 }
14904 builtin_declaration = false;
14905 });
14906 break;
14907 case BuiltInBaseVertex:
14908 // This is direct-mapped normally.
14909 if (!msl_options.vertex_for_tessellation)
14910 break;
14911
14912 entry_func.fixup_hooks_in.push_back(t: [=]() {
14913 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14914 ts: to_expression(id: builtin_dispatch_base_id), ts: ".x;");
14915 });
14916 break;
14917 case BuiltInInstanceId:
14918 case BuiltInInstanceIndex:
14919 // This is direct-mapped normally.
14920 if (!msl_options.vertex_for_tessellation)
14921 break;
14922
14923 entry_func.fixup_hooks_in.push_back(t: [=]() {
14924 builtin_declaration = true;
14925 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14926 ts: to_expression(id: builtin_invocation_id_id), ts: ".y + ", ts: to_expression(id: builtin_dispatch_base_id),
14927 ts: ".y;");
14928 builtin_declaration = false;
14929 });
14930 break;
14931 case BuiltInBaseInstance:
14932 // This is direct-mapped normally.
14933 if (!msl_options.vertex_for_tessellation)
14934 break;
14935
14936 entry_func.fixup_hooks_in.push_back(t: [=]() {
14937 statement(ts: builtin_type_decl(builtin: bi_type), ts: " ", ts: to_expression(id: var_id), ts: " = ",
14938 ts: to_expression(id: builtin_dispatch_base_id), ts: ".y;");
14939 });
14940 break;
14941 default:
14942 break;
14943 }
14944 }
14945 else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment &&
14946 is_builtin_variable(var) && active_output_builtins.get(bit: bi_type))
14947 {
14948 switch (bi_type)
14949 {
14950 case BuiltInSampleMask:
14951 if (has_additional_fixed_sample_mask())
14952 {
14953 // If the additional fixed sample mask was set, we need to adjust the sample_mask
14954 // output to reflect that. If the shader outputs the sample_mask itself too, we need
14955 // to AND the two masks to get the final one.
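				// For illustration (the mask literal is hypothetical), this emits either
				//   gl_SampleMask &= 0xfu;  // the shader writes the mask itself
				// or
				//   gl_SampleMask = 0xfu;   // the shader does not write the mask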
14956 string op_str = does_shader_write_sample_mask ? " &= " : " = ";
14957 entry_func.fixup_hooks_out.push_back(t: [=]() {
14958 statement(ts: to_expression(id: builtin_sample_mask_id), ts: op_str, ts: additional_fixed_sample_mask_str(), ts: ";");
14959 });
14960 }
14961 break;
14962 case BuiltInFragDepth:
14963 if (msl_options.input_attachment_is_ds_attachment && !writes_to_depth)
14964 {
14965 entry_func.fixup_hooks_out.push_back(t: [=]() {
14966 statement(ts: to_expression(id: builtin_frag_depth_id), ts: " = ", ts: to_expression(id: builtin_frag_coord_id), ts: ".z;");
14967 });
14968 }
14969 break;
14970 default:
14971 break;
14972 }
14973 }
14974 });
14975}
14976
14977// Returns the Metal index of the resource of the specified type as used by the specified variable.
14978uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
14979{
14980 auto &execution = get_entry_point();
14981 auto &var_dec = ir.meta[var.self].decoration;
14982 auto &var_type = get<SPIRType>(id: var.basetype);
14983 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
14984 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
14985
14986 // If a matching binding has been specified, find and use it.
14987 auto itr = resource_bindings.find(x: { .model: execution.model, .desc_set: var_desc_set, .binding: var_binding });
14988
14989 // Atomic helper buffers for image atomics need to use secondary bindings as well.
14990 bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
14991 basetype == SPIRType::AtomicCounter;
14992
14993 auto resource_decoration =
14994 use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
14995
14996 if (plane == 1)
14997 resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
14998 if (plane == 2)
14999 resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
15000
15001 if (itr != end(cont&: resource_bindings))
15002 {
15003 auto &remap = itr->second;
15004 remap.second = true;
15005 switch (basetype)
15006 {
15007 case SPIRType::Image:
15008 set_extended_decoration(id: var.self, decoration: resource_decoration, value: remap.first.msl_texture + plane);
15009 return remap.first.msl_texture + plane;
15010 case SPIRType::Sampler:
15011 set_extended_decoration(id: var.self, decoration: resource_decoration, value: remap.first.msl_sampler);
15012 return remap.first.msl_sampler;
15013 default:
15014 set_extended_decoration(id: var.self, decoration: resource_decoration, value: remap.first.msl_buffer);
15015 return remap.first.msl_buffer;
15016 }
15017 }
15018
15019 // If we have already allocated an index, keep using it.
15020 if (has_extended_decoration(id: var.self, decoration: resource_decoration))
15021 return get_extended_decoration(id: var.self, decoration: resource_decoration);
15022
15023 auto &type = get<SPIRType>(id: var.basetype);
15024
15025 if (type_is_msl_framebuffer_fetch(type))
15026 {
15027 // Frame-buffer fetch gets its fallback resource index from the input attachment index,
15028 // which is then treated as color index.
15029 return get_decoration(id: var.self, decoration: DecorationInputAttachmentIndex);
15030 }
15031 else if (msl_options.enable_decoration_binding)
15032 {
15033 // Allow user to enable decoration binding.
15034 // If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
15035 if (has_decoration(id: var.self, decoration: DecorationBinding))
15036 {
15037 var_binding = get_decoration(id: var.self, decoration: DecorationBinding);
15038 // Avoid emitting sentinel bindings.
15039 if (var_binding < 0x80000000u)
15040 return var_binding;
15041 }
15042 }
15043
15044 // If we did not explicitly remap, allocate bindings on demand.
15045 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
15046
15047 bool allocate_argument_buffer_ids = false;
15048
15049 if (var.storage != StorageClassPushConstant)
15050 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(desc_set: var_desc_set);
15051
15052 uint32_t binding_stride = 1;
15053 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
15054 binding_stride *= to_array_size_literal(type, index: i);
15055
15056 // If a binding has not been specified, revert to incrementing resource indices.
15057 uint32_t resource_index;
15058
15059 if (allocate_argument_buffer_ids)
15060 {
15061 // Allocate from a flat ID binding space.
15062 resource_index = next_metal_resource_ids[var_desc_set];
15063 next_metal_resource_ids[var_desc_set] += binding_stride;
15064 }
15065 else
15066 {
15067 if (is_var_runtime_size_array(var))
15068 {
15069 basetype = SPIRType::Struct;
15070 binding_stride = 1;
15071 }
15072 // Allocate from plain bindings which are allocated per resource type.
15073 switch (basetype)
15074 {
15075 case SPIRType::Image:
15076 resource_index = next_metal_resource_index_texture;
15077 next_metal_resource_index_texture += binding_stride;
15078 break;
15079 case SPIRType::Sampler:
15080 resource_index = next_metal_resource_index_sampler;
15081 next_metal_resource_index_sampler += binding_stride;
15082 break;
15083 default:
15084 resource_index = next_metal_resource_index_buffer;
15085 next_metal_resource_index_buffer += binding_stride;
15086 break;
15087 }
15088 }
15089
15090 set_extended_decoration(id: var.self, decoration: resource_decoration, value: resource_index);
15091 return resource_index;
15092}
15093
15094bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
15095{
15096 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
15097 msl_options.use_framebuffer_fetch_subpasses;
15098}
15099
15100const char *CompilerMSL::descriptor_address_space(uint32_t id, StorageClass storage, const char *plain_address_space) const
15101{
15102 if (msl_options.argument_buffers)
15103 {
15104 bool storage_class_is_descriptor = storage == StorageClassUniform ||
15105 storage == StorageClassStorageBuffer ||
15106 storage == StorageClassUniformConstant;
15107
15108 uint32_t desc_set = get_decoration(id, decoration: DecorationDescriptorSet);
15109 if (storage_class_is_descriptor && descriptor_set_is_argument_buffer(desc_set))
15110 {
15111 // An awkward case where we need to emit *more* address space declarations (yay!).
15112 // An example is where we pass down an array of buffer pointers to leaf functions.
15113 // It's a constant array containing pointers to constants.
15114 // The pointer array is always constant however. E.g.
15115 // device SSBO * constant (&array)[N].
15116 // const device SSBO * constant (&array)[N].
15117 // constant SSBO * constant (&array)[N].
15118 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
15119 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
15120
15121 // If the argument was marked as being in device address space, any pointer to member would
15122 // be const device, not constant.
15123 if (argument_buffer_device_storage_mask & (1u << desc_set))
15124 return "const device";
15125 else
15126 return "constant";
15127 }
15128 }
15129
15130 return plain_address_space;
15131}
15132
15133string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
15134{
15135 auto &var = get<SPIRVariable>(id: arg.id);
15136 auto &type = get_variable_data_type(var);
15137 auto &var_type = get<SPIRType>(id: arg.type);
15138 StorageClass type_storage = var_type.storage;
15139
15140 // If we need to modify the name of the variable, make sure we use the original variable.
15141 // Our alias is just a shadow variable.
15142 uint32_t name_id = var.self;
15143 if (arg.alias_global_variable && var.basevariable)
15144 name_id = var.basevariable;
15145
15146 bool constref = !arg.alias_global_variable && is_pointer(type: var_type) && arg.write_count == 0;
15147	// Framebuffer fetch is a plain value; const looks out of place, but it is not wrong.
15148 if (type_is_msl_framebuffer_fetch(type))
15149 constref = false;
15150 else if (type_storage == StorageClassUniformConstant)
15151 constref = true;
15152
15153 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
15154 type.basetype == SPIRType::Sampler;
15155 bool type_is_tlas = type.basetype == SPIRType::AccelerationStructure;
15156
15157 // For opaque types we handle const later due to descriptor address spaces.
15158 const char *cv_qualifier = (constref && !type_is_image) ? "const " : "";
15159 string decl;
15160
15161 // If this is a combined image-sampler for a 2D image with floating-point type,
15162 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
15163 // for a global, then we need to emit a "dynamic" combined image-sampler.
15164 // Unfortunately, this is necessary to properly support passing around
15165 // combined image-samplers with Y'CbCr conversions on them.
15166 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
15167 type.image.dim == Dim2D && type_is_floating_point(type: get<SPIRType>(id: type.image.type)) &&
15168 spv_function_implementations.count(x: SPVFuncImplDynamicImageSampler);
15169
15170 // Allow Metal to use the array<T> template to make arrays a value type
15171 string address_space = get_argument_address_space(argument: var);
15172 bool builtin = has_decoration(id: var.self, decoration: DecorationBuiltIn);
15173 auto builtin_type = BuiltIn(get_decoration(id: arg.id, decoration: DecorationBuiltIn));
15174
15175 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
15176 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
15177 else if (builtin && builtin_type != spv::BuiltInPrimitiveTriangleIndicesEXT &&
15178 builtin_type != spv::BuiltInPrimitiveLineIndicesEXT && builtin_type != spv::BuiltInPrimitivePointIndicesEXT)
15179 {
15180 // Only use templated array for Clip/Cull distance when feasible.
15181		// In other scenarios, we need to override array length for tess levels (if used as outputs),
15182 // or we need to emit the expected type for builtins (uint vs int).
15183 auto storage = get<SPIRType>(id: var.basetype).storage;
15184
15185 if (storage == StorageClassInput &&
15186 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
15187 {
15188 is_using_builtin_array = false;
15189 }
15190 else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance)
15191 {
15192 is_using_builtin_array = true;
15193 }
15194
15195 if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) &&
15196 !is_stage_output_builtin_masked(builtin: builtin_type))
15197 is_using_builtin_array = true;
15198
15199 if (is_using_builtin_array)
15200 decl = join(ts&: cv_qualifier, ts: builtin_type_decl(builtin: builtin_type, id: arg.id));
15201 else
15202 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
15203 }
15204 else if (is_var_runtime_size_array(var))
15205 {
15206 const auto *parent_type = &get<SPIRType>(id: type.parent_type);
15207 auto type_name = type_to_glsl(type: *parent_type, id: arg.id);
15208 if (type.basetype == SPIRType::AccelerationStructure)
15209 decl = join(ts: "spvDescriptorArray<", ts&: type_name, ts: ">");
15210 else if (type_is_image)
15211 decl = join(ts: "spvDescriptorArray<", ts&: cv_qualifier, ts&: type_name, ts: ">");
15212 else
15213 decl = join(ts: "spvDescriptorArray<", ts&: address_space, ts: " ", ts&: type_name, ts: "*>");
15214 address_space = "const";
15215 }
15216 else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type))
15217 {
15218 is_using_builtin_array = true;
15219 decl += join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id), ts: "*");
15220 }
15221 else if (is_dynamic_img_sampler)
15222 {
15223 decl = join(ts&: cv_qualifier, ts: "spvDynamicImageSampler<", ts: type_to_glsl(type: get<SPIRType>(id: type.image.type)), ts: ">");
15224 // Mark the variable so that we can handle passing it to another function.
15225 set_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
15226 }
15227 else
15228 {
15229		// The type is a pointer type; we need to emit the cv_qualifier late.
15230 if (is_pointer(type))
15231 {
15232 decl = type_to_glsl(type, id: arg.id);
15233 if (*cv_qualifier != '\0')
15234 decl += join(ts: " ", ts&: cv_qualifier);
15235 }
15236 else
15237 {
15238 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
15239 }
15240 }
15241
15242 if (!builtin && !is_pointer(type: var_type) &&
15243 (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
15244 {
15245 // If the argument is a pure value and not an opaque type, we will pass by value.
15246 if (msl_options.force_native_arrays && is_array(type))
15247 {
15248 // We are receiving an array by value. This is problematic.
15249 // We cannot be sure of the target address space since we are supposed to receive a copy,
15250 // but this is not possible with MSL without some extra work.
15251 // We will have to assume we're getting a reference in thread address space.
15252 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
15253 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
15254 // non-constant arrays, but we can create thread const from constant.
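			// The resulting parameter then looks roughly like (type and name are illustrative):
			//   thread const float (&foo)[4]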
15255 decl = string("thread const ") + decl;
15256 decl += " (&";
15257 const char *restrict_kw = to_restrict(id: name_id, space: true);
15258 if (*restrict_kw)
15259 {
15260 decl += " ";
15261 decl += restrict_kw;
15262 }
15263 decl += to_expression(id: name_id);
15264 decl += ")";
15265 decl += type_to_array_glsl(type, variable_id: name_id);
15266 }
15267 else
15268 {
15269 if (!address_space.empty())
15270 decl = join(ts&: address_space, ts: " ", ts&: decl);
15271 decl += " ";
15272 decl += to_expression(id: name_id);
15273 }
15274 }
15275 else if (is_array(type) && !type_is_image)
15276 {
15277 // Arrays of opaque types are special cased.
15278 if (!address_space.empty())
15279 decl = join(ts&: address_space, ts: " ", ts&: decl);
15280
15281 // spvDescriptorArray absorbs the address space inside the template.
15282 if (!is_var_runtime_size_array(var))
15283 {
15284 const char *argument_buffer_space = descriptor_address_space(id: name_id, storage: type_storage, plain_address_space: nullptr);
15285 if (argument_buffer_space)
15286 {
15287 decl += " ";
15288 decl += argument_buffer_space;
15289 }
15290 }
15291
15292 // Special case, need to override the array size here if we're using tess level as an argument.
15293 if (is_tesc_shader() && builtin &&
15294 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
15295 {
15296 uint32_t array_size = get_physical_tess_level_array_size(builtin: builtin_type);
15297 if (array_size == 1)
15298 {
15299 decl += " &";
15300 decl += to_expression(id: name_id);
15301 }
15302 else
15303 {
15304 decl += " (&";
15305 decl += to_expression(id: name_id);
15306 decl += ")";
15307 decl += join(ts: "[", ts&: array_size, ts: "]");
15308 }
15309 }
15310 else if (is_var_runtime_size_array(var))
15311 {
15312 decl += " " + to_expression(id: name_id);
15313 }
15314 else
15315 {
15316 auto array_size_decl = type_to_array_glsl(type, variable_id: name_id);
15317 if (array_size_decl.empty())
15318 decl += "& ";
15319 else
15320 decl += " (&";
15321
15322 const char *restrict_kw = to_restrict(id: name_id, space: true);
15323 if (*restrict_kw)
15324 {
15325 decl += " ";
15326 decl += restrict_kw;
15327 }
15328 decl += to_expression(id: name_id);
15329
15330 if (!array_size_decl.empty())
15331 {
15332 decl += ")";
15333 decl += array_size_decl;
15334 }
15335 }
15336 }
15337 else if (!type_is_image && !type_is_tlas &&
15338 (!pull_model_inputs.count(x: var.basevariable) || type.basetype == SPIRType::Struct))
15339 {
15340 // If this is going to be a reference to a variable pointer, the address space
15341 // for the reference has to go before the '&', but after the '*'.
15342 if (!address_space.empty())
15343 {
15344 if (is_pointer(type))
15345 {
15346 if (*cv_qualifier == '\0')
15347 decl += ' ';
15348 decl += join(ts&: address_space, ts: " ");
15349 }
15350 else
15351 decl = join(ts&: address_space, ts: " ", ts&: decl);
15352 }
15353 decl += "&";
15354 decl += " ";
15355 decl += to_restrict(id: name_id, space: true);
15356 decl += to_expression(id: name_id);
15357 }
15358 else if (type_is_image || type_is_tlas)
15359 {
15360 if (is_var_runtime_size_array(var))
15361 {
15362 decl = address_space + " " + decl + " " + to_expression(id: name_id);
15363 }
15364 else if (type.array.empty())
15365 {
15366 // For non-arrayed types we can just pass opaque descriptors by value.
15367 // This fixes problems if descriptors are passed by value from argument buffers and plain descriptors
15368 // in same shader.
15369 // There is no address space we can actually use, but value will work.
15370 // This will break if applications attempt to pass down descriptor arrays as arguments, but
15371 // fortunately that is extremely unlikely ...
15372 decl += " ";
15373 decl += to_expression(id: name_id);
15374 }
15375 else
15376 {
15377 const char *img_address_space = descriptor_address_space(id: name_id, storage: type_storage, plain_address_space: "thread const");
15378 decl = join(ts&: img_address_space, ts: " ", ts&: decl);
15379 decl += "& ";
15380 decl += to_expression(id: name_id);
15381 }
15382 }
15383 else
15384 {
15385 if (!address_space.empty())
15386 decl = join(ts&: address_space, ts: " ", ts&: decl);
15387 decl += " ";
15388 decl += to_expression(id: name_id);
15389 }
15390
15391 // Emulate texture2D atomic operations
15392 auto *backing_var = maybe_get_backing_variable(chain: name_id);
15393 if (backing_var && atomic_image_vars_emulated.count(x: backing_var->self))
15394 {
15395 auto &flags = ir.get_decoration_bitset(id: backing_var->self);
15396 const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : "";
15397 decl += join(ts: ", ", ts&: cv_flags, ts: "device atomic_", ts: type_to_glsl(type: get<SPIRType>(id: var_type.image.type), id: 0));
15398 decl += "* " + to_expression(id: name_id) + "_atomic";
15399 }
15400
15401 is_using_builtin_array = false;
15402
15403 return decl;
15404}
15405
15406// If we're currently in the entry point function, and the object
15407// has a qualified name, use it, otherwise use the standard name.
15408string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
15409{
15410 if (current_function && (current_function->self == ir.default_entry_point))
15411 {
15412 auto *m = ir.find_meta(id);
15413 if (m && !m->decoration.qualified_alias_explicit_override && !m->decoration.qualified_alias.empty())
15414 return m->decoration.qualified_alias;
15415 }
15416 return Compiler::to_name(id, allow_alias);
15417}
15418
15419// Appends the name of the member to the variable qualifier string, except for Builtins.
15420string CompilerMSL::append_member_name(const string &qualifier, const SPIRType &type, uint32_t index)
15421{
15422 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
15423 BuiltIn builtin = BuiltInMax;
15424 if (is_member_builtin(type, index, builtin: &builtin))
15425 return builtin_to_glsl(builtin, storage: type.storage);
15426
15427 // Strip any underscore prefix from member name
15428 string mbr_name = to_member_name(type, index);
15429 size_t startPos = mbr_name.find_first_not_of(s: "_");
15430 mbr_name = (startPos != string::npos) ? mbr_name.substr(pos: startPos) : "";
15431 return join(ts: qualifier, ts: "_", ts&: mbr_name);
15432}
15433
15434// Ensures that the specified name is permanently usable by prepending a prefix
15435// if the first chars are _ and a digit, which indicate a transient name.
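// For example (the prefix is whatever the caller passes in), a transient name such as
// "_53" becomes "m_53" with pfx == "m", while an ordinary name like "color" is returned unchanged.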
15436string CompilerMSL::ensure_valid_name(string name, string pfx)
15437{
15438 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
15439}
15440
15441const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
15442{
15443 static const unordered_set<string> keywords = {
15444 "kernel",
15445 "vertex",
15446 "fragment",
15447 "compute",
15448 "constant",
15449 "device",
15450 "bias",
15451 "level",
15452 "gradient2d",
15453 "gradientcube",
15454 "gradient3d",
15455 "min_lod_clamp",
15456 "assert",
15457 "VARIABLE_TRACEPOINT",
15458 "STATIC_DATA_TRACEPOINT",
15459 "STATIC_DATA_TRACEPOINT_V",
15460 "METAL_ALIGN",
15461 "METAL_ASM",
15462 "METAL_CONST",
15463 "METAL_DEPRECATED",
15464 "METAL_ENABLE_IF",
15465 "METAL_FUNC",
15466 "METAL_INTERNAL",
15467 "METAL_NON_NULL_RETURN",
15468 "METAL_NORETURN",
15469 "METAL_NOTHROW",
15470 "METAL_PURE",
15471 "METAL_UNAVAILABLE",
15472 "METAL_IMPLICIT",
15473 "METAL_EXPLICIT",
15474 "METAL_CONST_ARG",
15475 "METAL_ARG_UNIFORM",
15476 "METAL_ZERO_ARG",
15477 "METAL_VALID_LOD_ARG",
15478 "METAL_VALID_LEVEL_ARG",
15479 "METAL_VALID_STORE_ORDER",
15480 "METAL_VALID_LOAD_ORDER",
15481 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
15482 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
15483 "METAL_VALID_RENDER_TARGET",
15484 "is_function_constant_defined",
15485 "CHAR_BIT",
15486 "SCHAR_MAX",
15487 "SCHAR_MIN",
15488 "UCHAR_MAX",
15489 "CHAR_MAX",
15490 "CHAR_MIN",
15491 "USHRT_MAX",
15492 "SHRT_MAX",
15493 "SHRT_MIN",
15494 "UINT_MAX",
15495 "INT_MAX",
15496 "INT_MIN",
15497 "FLT_DIG",
15498 "FLT_MANT_DIG",
15499 "FLT_MAX_10_EXP",
15500 "FLT_MAX_EXP",
15501 "FLT_MIN_10_EXP",
15502 "FLT_MIN_EXP",
15503 "FLT_RADIX",
15504 "FLT_MAX",
15505 "FLT_MIN",
15506 "FLT_EPSILON",
15507 "FP_ILOGB0",
15508 "FP_ILOGBNAN",
15509 "MAXFLOAT",
15510 "HUGE_VALF",
15511 "INFINITY",
15512 "NAN",
15513 "M_E_F",
15514 "M_LOG2E_F",
15515 "M_LOG10E_F",
15516 "M_LN2_F",
15517 "M_LN10_F",
15518 "M_PI_F",
15519 "M_PI_2_F",
15520 "M_PI_4_F",
15521 "M_1_PI_F",
15522 "M_2_PI_F",
15523 "M_2_SQRTPI_F",
15524 "M_SQRT2_F",
15525 "M_SQRT1_2_F",
15526 "HALF_DIG",
15527 "HALF_MANT_DIG",
15528 "HALF_MAX_10_EXP",
15529 "HALF_MAX_EXP",
15530 "HALF_MIN_10_EXP",
15531 "HALF_MIN_EXP",
15532 "HALF_RADIX",
15533 "HALF_MAX",
15534 "HALF_MIN",
15535 "HALF_EPSILON",
15536 "MAXHALF",
15537 "HUGE_VALH",
15538 "M_E_H",
15539 "M_LOG2E_H",
15540 "M_LOG10E_H",
15541 "M_LN2_H",
15542 "M_LN10_H",
15543 "M_PI_H",
15544 "M_PI_2_H",
15545 "M_PI_4_H",
15546 "M_1_PI_H",
15547 "M_2_PI_H",
15548 "M_2_SQRTPI_H",
15549 "M_SQRT2_H",
15550 "M_SQRT1_2_H",
15551 "DBL_DIG",
15552 "DBL_MANT_DIG",
15553 "DBL_MAX_10_EXP",
15554 "DBL_MAX_EXP",
15555 "DBL_MIN_10_EXP",
15556 "DBL_MIN_EXP",
15557 "DBL_RADIX",
15558 "DBL_MAX",
15559 "DBL_MIN",
15560 "DBL_EPSILON",
15561 "HUGE_VAL",
15562 "M_E",
15563 "M_LOG2E",
15564 "M_LOG10E",
15565 "M_LN2",
15566 "M_LN10",
15567 "M_PI",
15568 "M_PI_2",
15569 "M_PI_4",
15570 "M_1_PI",
15571 "M_2_PI",
15572 "M_2_SQRTPI",
15573 "M_SQRT2",
15574 "M_SQRT1_2",
15575 "quad_broadcast",
15576 "thread",
15577 "threadgroup",
15578 };
15579
15580 return keywords;
15581}
15582
15583const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
15584{
15585 static const unordered_set<string> illegal_func_names = {
15586 "main",
15587 "saturate",
15588 "assert",
15589 "fmin3",
15590 "fmax3",
15591 "divide",
15592 "fmod",
15593 "median3",
15594 "VARIABLE_TRACEPOINT",
15595 "STATIC_DATA_TRACEPOINT",
15596 "STATIC_DATA_TRACEPOINT_V",
15597 "METAL_ALIGN",
15598 "METAL_ASM",
15599 "METAL_CONST",
15600 "METAL_DEPRECATED",
15601 "METAL_ENABLE_IF",
15602 "METAL_FUNC",
15603 "METAL_INTERNAL",
15604 "METAL_NON_NULL_RETURN",
15605 "METAL_NORETURN",
15606 "METAL_NOTHROW",
15607 "METAL_PURE",
15608 "METAL_UNAVAILABLE",
15609 "METAL_IMPLICIT",
15610 "METAL_EXPLICIT",
15611 "METAL_CONST_ARG",
15612 "METAL_ARG_UNIFORM",
15613 "METAL_ZERO_ARG",
15614 "METAL_VALID_LOD_ARG",
15615 "METAL_VALID_LEVEL_ARG",
15616 "METAL_VALID_STORE_ORDER",
15617 "METAL_VALID_LOAD_ORDER",
15618 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
15619 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
15620 "METAL_VALID_RENDER_TARGET",
15621 "is_function_constant_defined",
15622 "CHAR_BIT",
15623 "SCHAR_MAX",
15624 "SCHAR_MIN",
15625 "UCHAR_MAX",
15626 "CHAR_MAX",
15627 "CHAR_MIN",
15628 "USHRT_MAX",
15629 "SHRT_MAX",
15630 "SHRT_MIN",
15631 "UINT_MAX",
15632 "INT_MAX",
15633 "INT_MIN",
15634 "FLT_DIG",
15635 "FLT_MANT_DIG",
15636 "FLT_MAX_10_EXP",
15637 "FLT_MAX_EXP",
15638 "FLT_MIN_10_EXP",
15639 "FLT_MIN_EXP",
15640 "FLT_RADIX",
15641 "FLT_MAX",
15642 "FLT_MIN",
15643 "FLT_EPSILON",
15644 "FP_ILOGB0",
15645 "FP_ILOGBNAN",
15646 "MAXFLOAT",
15647 "HUGE_VALF",
15648 "INFINITY",
15649 "NAN",
15650 "M_E_F",
15651 "M_LOG2E_F",
15652 "M_LOG10E_F",
15653 "M_LN2_F",
15654 "M_LN10_F",
15655 "M_PI_F",
15656 "M_PI_2_F",
15657 "M_PI_4_F",
15658 "M_1_PI_F",
15659 "M_2_PI_F",
15660 "M_2_SQRTPI_F",
15661 "M_SQRT2_F",
15662 "M_SQRT1_2_F",
15663 "HALF_DIG",
15664 "HALF_MANT_DIG",
15665 "HALF_MAX_10_EXP",
15666 "HALF_MAX_EXP",
15667 "HALF_MIN_10_EXP",
15668 "HALF_MIN_EXP",
15669 "HALF_RADIX",
15670 "HALF_MAX",
15671 "HALF_MIN",
15672 "HALF_EPSILON",
15673 "MAXHALF",
15674 "HUGE_VALH",
15675 "M_E_H",
15676 "M_LOG2E_H",
15677 "M_LOG10E_H",
15678 "M_LN2_H",
15679 "M_LN10_H",
15680 "M_PI_H",
15681 "M_PI_2_H",
15682 "M_PI_4_H",
15683 "M_1_PI_H",
15684 "M_2_PI_H",
15685 "M_2_SQRTPI_H",
15686 "M_SQRT2_H",
15687 "M_SQRT1_2_H",
15688 "DBL_DIG",
15689 "DBL_MANT_DIG",
15690 "DBL_MAX_10_EXP",
15691 "DBL_MAX_EXP",
15692 "DBL_MIN_10_EXP",
15693 "DBL_MIN_EXP",
15694 "DBL_RADIX",
15695 "DBL_MAX",
15696 "DBL_MIN",
15697 "DBL_EPSILON",
15698 "HUGE_VAL",
15699 "M_E",
15700 "M_LOG2E",
15701 "M_LOG10E",
15702 "M_LN2",
15703 "M_LN10",
15704 "M_PI",
15705 "M_PI_2",
15706 "M_PI_4",
15707 "M_1_PI",
15708 "M_2_PI",
15709 "M_2_SQRTPI",
15710 "M_SQRT2",
15711 "M_SQRT1_2",
15712 };
15713
15714 return illegal_func_names;
15715}
15716
15717// Replace all names that match MSL keywords or Metal Standard Library functions.
15718void CompilerMSL::replace_illegal_names()
15719{
15720 // FIXME: MSL and GLSL are doing two different things here.
15721 // Agree on convention and remove this override.
15722 auto &keywords = get_reserved_keyword_set();
15723 auto &illegal_func_names = get_illegal_func_names();
15724
15725 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t self, SPIRVariable &) {
15726 auto *meta = ir.find_meta(id: self);
15727 if (!meta)
15728 return;
15729
15730 auto &dec = meta->decoration;
15731 if (keywords.find(x: dec.alias) != end(cont: keywords))
15732 dec.alias += "0";
15733 });
15734
15735 ir.for_each_typed_id<SPIRFunction>(op: [&](uint32_t self, SPIRFunction &) {
15736 auto *meta = ir.find_meta(id: self);
15737 if (!meta)
15738 return;
15739
15740 auto &dec = meta->decoration;
15741 if (illegal_func_names.find(x: dec.alias) != end(cont: illegal_func_names))
15742 dec.alias += "0";
15743 });
15744
15745 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t self, SPIRType &) {
15746 auto *meta = ir.find_meta(id: self);
15747 if (!meta)
15748 return;
15749
15750 for (auto &mbr_dec : meta->members)
15751 if (keywords.find(x: mbr_dec.alias) != end(cont: keywords))
15752 mbr_dec.alias += "0";
15753 });
15754
15755 CompilerGLSL::replace_illegal_names();
15756}
15757
15758void CompilerMSL::replace_illegal_entry_point_names()
15759{
15760 auto &illegal_func_names = get_illegal_func_names();
15761
15762	// It is important to do this before we fix up identifiers,
15763 // since if ep_name is reserved, we will need to fix that up,
15764 // and then copy alias back into entry.name after the fixup.
15765 for (auto &entry : ir.entry_points)
15766 {
15767 // Change both the entry point name and the alias, to keep them synced.
15768 string &ep_name = entry.second.name;
15769 if (illegal_func_names.find(x: ep_name) != end(cont: illegal_func_names))
15770 ep_name += "0";
15771
15772 ir.meta[entry.first].decoration.alias = ep_name;
15773 }
15774}
15775
15776void CompilerMSL::sync_entry_point_aliases_and_names()
15777{
15778 for (auto &entry : ir.entry_points)
15779 entry.second.name = ir.meta[entry.first].decoration.alias;
15780}
15781
15782string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved)
15783{
15784 auto *var = maybe_get_backing_variable(chain: base);
15785 // If this is a buffer array, we have to dereference the buffer pointers.
15786 // Otherwise, if this is a pointer expression, dereference it.
15787
15788 bool declared_as_pointer = false;
15789
15790 if (var)
15791 {
15792 // Only allow -> dereference for block types. This is so we get expressions like
15793 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
15794 const bool is_block =
15795 has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock);
15796
15797 bool is_buffer_variable =
15798 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
15799 declared_as_pointer = is_buffer_variable && is_array(type: get_pointee_type(type_id: var->basetype));
15800 }
15801
15802 if (declared_as_pointer || (!ptr_chain_is_resolved && should_dereference(id: base)))
15803 return join(ts: "->", ts: to_member_name(type, index));
15804 else
15805 return join(ts: ".", ts: to_member_name(type, index));
15806}
15807
15808string CompilerMSL::to_qualifiers_glsl(uint32_t id)
15809{
15810 string quals;
15811
15812 auto *var = maybe_get<SPIRVariable>(id);
15813 auto &type = expression_type(id);
15814
15815 if (type.storage == StorageClassTaskPayloadWorkgroupEXT)
15816 quals += "object_data ";
15817
15818 if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(variable: *var, storage: StorageClassWorkgroup)))
15819 quals += "threadgroup ";
15820
15821 return quals;
15822}
15823
15824// The optional id parameter indicates the object whose type we are trying
15825// to find the description for. Most type descriptions do not depend on a
15826// specific object's use of that type.
15827string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member)
15828{
15829 string type_name;
15830
15831 // Pointer?
15832 if (is_pointer(type) || type_is_array_of_pointers(type))
15833 {
15834 assert(type.pointer_depth > 0);
15835
15836 const char *restrict_kw;
15837
15838 auto type_address_space = get_type_address_space(type, id);
15839 const auto *p_parent_type = &get<SPIRType>(id: type.parent_type);
15840
15841 // If we're wrapping buffer descriptors in a spvDescriptorArray, we'll have to handle it as a special case.
15842 if (member && id)
15843 {
15844 auto &var = get<SPIRVariable>(id);
15845 if (is_var_runtime_size_array(var) && is_runtime_size_array(type: *p_parent_type))
15846 {
15847 const bool ssbo = has_decoration(id: p_parent_type->self, decoration: DecorationBufferBlock);
15848 bool buffer_desc =
15849 (var.storage == StorageClassStorageBuffer || ssbo) &&
15850 msl_options.runtime_array_rich_descriptor;
15851
15852 const char *wrapper_type = buffer_desc ? "spvBufferDescriptor" : "spvDescriptor";
15853 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
15854 add_spv_func_and_recompile(spv_func: buffer_desc ? SPVFuncImplVariableSizedDescriptor : SPVFuncImplVariableDescriptor);
15855
15856 type_name = join(ts&: wrapper_type, ts: "<", ts&: type_address_space, ts: " ", ts: type_to_glsl(type: *p_parent_type, id), ts: " *>");
15857 return type_name;
15858 }
15859 }
15860
15861 // Work around C pointer qualifier rules. If glsl_type is a pointer type as well
15862 // we'll need to emit the address space to the right.
15863 // We could always go this route, but it makes the code unnatural.
15864 // Prefer emitting thread T *foo over T thread* foo since it's more readable,
15865 // but we'll have to emit thread T * thread * T constant bar; for example.
15866 if (is_pointer(type) && is_pointer(type: *p_parent_type))
15867 type_name = join(ts: type_to_glsl(type: *p_parent_type, id), ts: " ", ts&: type_address_space, ts: " ");
15868 else
15869 {
15870 // Since this is not a pointer-to-pointer, ensure we've dug down to the base type.
15871 // Some situations chain pointers even though they are not formally pointers-of-pointers.
15872 while (is_pointer(type: *p_parent_type))
15873 p_parent_type = &get<SPIRType>(id: p_parent_type->parent_type);
15874
15875 // If we're emitting BDA, just use the templated type.
15876			// Emitting builtin arrays needs a lot of cooperation with other code to ensure
15877 // the C-style nesting works right.
15878 // FIXME: This is somewhat of a hack.
15879 bool old_is_using_builtin_array = is_using_builtin_array;
15880 if (is_physical_pointer(type))
15881 is_using_builtin_array = false;
15882
15883 type_name = join(ts&: type_address_space, ts: " ", ts: type_to_glsl(type: *p_parent_type, id));
15884
15885 is_using_builtin_array = old_is_using_builtin_array;
15886 }
15887
15888 switch (type.basetype)
15889 {
15890 case SPIRType::Image:
15891 case SPIRType::SampledImage:
15892 case SPIRType::Sampler:
15893 // These are handles.
15894 break;
15895 default:
15896 // Anything else can be a raw pointer.
15897 type_name += "*";
15898 restrict_kw = to_restrict(id, space: false);
15899 if (*restrict_kw)
15900 {
15901 type_name += " ";
15902 type_name += restrict_kw;
15903 }
15904 break;
15905 }
15906 return type_name;
15907 }
15908
15909 switch (type.basetype)
15910 {
15911 case SPIRType::Struct:
15912 // Need OpName lookup here to get a "sensible" name for a struct.
15913 // Allow Metal to use the array<T> template to make arrays a value type
15914 type_name = to_name(id: type.self);
15915 break;
15916
15917 case SPIRType::Image:
15918 case SPIRType::SampledImage:
15919 return image_type_glsl(type, id, member);
15920
15921 case SPIRType::Sampler:
15922 return sampler_type(type, id, member);
15923
15924 case SPIRType::Void:
15925 return "void";
15926
15927 case SPIRType::AtomicCounter:
15928 return "atomic_uint";
15929
15930 case SPIRType::ControlPointArray:
15931 return join(ts: "patch_control_point<", ts: type_to_glsl(type: get<SPIRType>(id: type.parent_type), id), ts: ">");
15932
15933 case SPIRType::Interpolant:
15934 return join(ts: "interpolant<", ts: type_to_glsl(type: get<SPIRType>(id: type.parent_type), id), ts: ", interpolation::",
15935 ts: has_decoration(id: type.self, decoration: DecorationNoPerspective) ? "no_perspective" : "perspective", ts: ">");
15936
15937 // Scalars
15938 case SPIRType::Boolean:
15939 {
15940 auto *var = maybe_get_backing_variable(chain: id);
15941 if (var && var->basevariable)
15942 var = &get<SPIRVariable>(id: var->basevariable);
15943
15944 // Need to special-case threadgroup booleans. They are supposed to be logical
15945 // storage, but MSL compilers will sometimes crash if you use threadgroup bool.
15946		// Work around this by using 16-bit types instead and fix up loads and stores of this data.
15947 if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup || member)
15948 type_name = "short";
15949 else
15950 type_name = "bool";
15951 break;
15952 }
15953
15954 case SPIRType::Char:
15955 case SPIRType::SByte:
15956 type_name = "char";
15957 break;
15958 case SPIRType::UByte:
15959 type_name = "uchar";
15960 break;
15961 case SPIRType::Short:
15962 type_name = "short";
15963 break;
15964 case SPIRType::UShort:
15965 type_name = "ushort";
15966 break;
15967 case SPIRType::Int:
15968 type_name = "int";
15969 break;
15970 case SPIRType::UInt:
15971 type_name = "uint";
15972 break;
15973 case SPIRType::Int64:
15974 if (!msl_options.supports_msl_version(major: 2, minor: 2))
15975 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
15976 type_name = "long";
15977 break;
15978 case SPIRType::UInt64:
15979 if (!msl_options.supports_msl_version(major: 2, minor: 2))
15980 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
15981 type_name = "ulong";
15982 break;
15983 case SPIRType::Half:
15984 type_name = "half";
15985 break;
15986 case SPIRType::Float:
15987 type_name = "float";
15988 break;
15989 case SPIRType::Double:
15990 type_name = "double"; // Currently unsupported
15991 break;
15992 case SPIRType::AccelerationStructure:
15993 if (msl_options.supports_msl_version(major: 2, minor: 4))
15994 type_name = "raytracing::acceleration_structure<raytracing::instancing>";
15995 else if (msl_options.supports_msl_version(major: 2, minor: 3))
15996 type_name = "raytracing::instance_acceleration_structure";
15997 else
15998 SPIRV_CROSS_THROW("Acceleration Structure Type is supported in MSL 2.3 and above.");
15999 break;
16000 case SPIRType::RayQuery:
16001 return "raytracing::intersection_query<raytracing::instancing, raytracing::triangle_data>";
16002 case SPIRType::MeshGridProperties:
16003 return "mesh_grid_properties";
16004
16005 default:
16006 return "unknown_type";
16007 }
16008
16009 // Matrix?
16010 if (type.columns > 1)
16011 {
16012 auto *var = maybe_get_backing_variable(chain: id);
16013 if (var && var->basevariable)
16014 var = &get<SPIRVariable>(id: var->basevariable);
16015
16016 // Need to special-case threadgroup matrices. Due to an oversight, Metal's
16017 // matrix struct prior to Metal 3 lacks constructors in the threadgroup AS,
16018 // preventing us from default-constructing or initializing matrices in threadgroup storage.
16019 // Work around this by using our own type as storage.
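		// For illustration, a threadgroup float3x3 would then be emitted as spvStorage_float3x3.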
16020 if (((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup) &&
16021 !msl_options.supports_msl_version(major: 3, minor: 0))
16022 {
16023 add_spv_func_and_recompile(spv_func: SPVFuncImplStorageMatrix);
16024 type_name = "spvStorage_" + type_name;
16025 }
16026
16027 type_name += to_string(val: type.columns) + "x";
16028 }
16029
16030 // Vector or Matrix?
16031 if (type.vecsize > 1)
16032 type_name += to_string(val: type.vecsize);
16033
16034 if (type.array.empty() || using_builtin_array())
16035 {
16036 return type_name;
16037 }
16038 else
16039 {
16040 // Allow Metal to use the array<T> template to make arrays a value type
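		// For illustration, a one-dimensional float[4] declared this way becomes spvUnsafeArray<float, 4>.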
16041 add_spv_func_and_recompile(spv_func: SPVFuncImplUnsafeArray);
16042 string res;
16043 string sizes;
16044
16045 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
16046 {
16047 res += "spvUnsafeArray<";
16048 sizes += ", ";
16049 sizes += to_array_size(type, index: i);
16050 sizes += ">";
16051 }
16052
16053 res += type_name + sizes;
16054 return res;
16055 }
16056}
16057
16058string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
16059{
16060 return type_to_glsl(type, id, member: false);
16061}
16062
16063string CompilerMSL::type_to_array_glsl(const SPIRType &type, uint32_t variable_id)
16064{
16065 // Allow Metal to use the array<T> template to make arrays a value type
16066 switch (type.basetype)
16067 {
16068 case SPIRType::AtomicCounter:
16069 case SPIRType::ControlPointArray:
16070 case SPIRType::RayQuery:
16071 return CompilerGLSL::type_to_array_glsl(type, variable_id);
16072
16073 default:
16074 if (type_is_array_of_pointers(type) || using_builtin_array())
16075 {
16076 const SPIRVariable *var = variable_id ? &get<SPIRVariable>(id: variable_id) : nullptr;
16077 if (var && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer) &&
16078 is_array(type: get_variable_data_type(var: *var)))
16079 {
16080 return join(ts: "[", ts: get_resource_array_size(type, id: variable_id), ts: "]");
16081 }
16082 else
16083 return CompilerGLSL::type_to_array_glsl(type, variable_id);
16084 }
16085 else
16086 return "";
16087 }
16088}
16089
16090string CompilerMSL::constant_op_expression(const SPIRConstantOp &cop)
16091{
16092 switch (cop.opcode)
16093 {
16094 case OpQuantizeToF16:
16095 add_spv_func_and_recompile(spv_func: SPVFuncImplQuantizeToF16);
16096 return join(ts: "spvQuantizeToF16(", ts: to_expression(id: cop.arguments[0]), ts: ")");
16097 default:
16098 return CompilerGLSL::constant_op_expression(cop);
16099 }
16100}
16101
16102bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const
16103{
16104 if (variable.storage == storage)
16105 return true;
16106
16107 if (storage == StorageClassWorkgroup)
16108 {
16109 // Specially masked IO block variable.
16110 // Normally, we will never access IO blocks directly here.
16111		// The only scenario in which that should occur is with a masked IO block.
16112 if (is_tesc_shader() && variable.storage == StorageClassOutput &&
16113 has_decoration(id: get<SPIRType>(id: variable.basetype).self, decoration: DecorationBlock))
16114 {
16115 return true;
16116 }
16117
16118 if (is_mesh_shader())
16119 return variable.storage == StorageClassOutput;
16120
16121 return variable.storage == StorageClassOutput && is_tesc_shader() && is_stage_output_variable_masked(var: variable);
16122 }
16123 else if (storage == StorageClassStorageBuffer)
16124 {
16125 // These builtins are passed directly; we don't want to use remapping
16126 // for them.
16127 auto builtin = (BuiltIn)get_decoration(id: variable.self, decoration: DecorationBuiltIn);
16128 if (is_tese_shader() && is_builtin_variable(var: variable) && (builtin == BuiltInTessCoord || builtin == BuiltInPrimitiveId))
16129 return false;
16130
16131 // We won't be able to catch writes to control point outputs here since variable
16132 // refers to a function local pointer.
16133 // This is fine, as there cannot be concurrent writers to that memory anyways,
16134 // so we just ignore that case.
16135
16136 return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) &&
16137 !variable_storage_requires_stage_io(storage: variable.storage) &&
16138 (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(var: variable));
16139 }
16140 else
16141 {
16142 return false;
16143 }
16144}
16145
16146// GCC workaround of lambdas calling protected funcs
16147std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
16148{
16149 return CompilerGLSL::variable_decl(type, name, id);
16150}
16151
16152std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member)
16153{
16154 auto *var = maybe_get<SPIRVariable>(id);
16155 if (var && var->basevariable)
16156 {
16157 // Check against the base variable, and not a fake ID which might have been generated for this variable.
16158 id = var->basevariable;
16159 }
16160
16161 if (!type.array.empty())
16162 {
16163 if (!msl_options.supports_msl_version(major: 2))
16164 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
16165
16166 if (type.array.size() > 1)
16167 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
16168
16169 // Arrays of samplers in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
16170 // If we have a runtime array, it could be a variable-count descriptor set binding.
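		// For illustration, a fixed-size binding of four samplers becomes array<sampler, 4>.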
16171 auto &parent = get<SPIRType>(id: get_pointee_type(type).parent_type);
16172 uint32_t array_size = get_resource_array_size(type, id);
16173
16174 if (array_size == 0)
16175 {
16176 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
16177 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
16178
16179 const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray";
16180 if (member)
16181 descriptor_wrapper = "spvDescriptor";
16182 return join(ts&: descriptor_wrapper, ts: "<", ts: sampler_type(type: parent, id, member: false), ts: ">",
16183 ts: processing_entry_point ? "*" : "");
16184 }
16185 else
16186 {
16187 return join(ts: "array<", ts: sampler_type(type: parent, id, member: false), ts: ", ", ts&: array_size, ts: ">");
16188 }
16189 }
16190 else
16191 return "sampler";
16192}
16193
16194// Returns an MSL string describing the SPIR-V image type
16195string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool member)
16196{
16197 auto *var = maybe_get<SPIRVariable>(id);
16198 if (var && var->basevariable)
16199 {
16200 // For comparison images, check against the base variable,
16201 // and not the fake ID which might have been generated for this variable.
16202 id = var->basevariable;
16203 }
16204
16205 if (!type.array.empty())
16206 {
16207 uint32_t major = 2, minor = 0;
16208 if (msl_options.is_ios())
16209 {
16210 major = 1;
16211 minor = 2;
16212 }
16213 if (!msl_options.supports_msl_version(major, minor))
16214 {
16215 if (msl_options.is_ios())
16216 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
16217 else
16218 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
16219 }
16220
16221 if (type.array.size() > 1)
16222 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
16223
16224 // Arrays of images in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
16225 // If we have a runtime array, it could be a variable-count descriptor set binding.
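		// For illustration, a fixed-size binding of eight sampled 2D float textures becomes array<texture2d<float>, 8>.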
16226 auto &parent = get<SPIRType>(id: get_pointee_type(type).parent_type);
16227 uint32_t array_size = get_resource_array_size(type, id);
16228
16229 if (array_size == 0)
16230 {
16231 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
16232 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
16233 const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray";
16234 if (member)
16235 {
16236 descriptor_wrapper = "spvDescriptor";
16237 // This requires a specialized wrapper type that packs image and sampler side by side.
16238 // It is possible in theory.
16239 if (type.basetype == SPIRType::SampledImage)
16240 SPIRV_CROSS_THROW("Argument buffer runtime array currently not supported for combined image sampler.");
16241 }
16242 return join(ts&: descriptor_wrapper, ts: "<", ts: image_type_glsl(type: parent, id, member: false), ts: ">",
16243 ts: processing_entry_point ? "*" : "");
16244 }
16245 else
16246 {
16247 return join(ts: "array<", ts: image_type_glsl(type: parent, id, member: false), ts: ", ", ts&: array_size, ts: ">");
16248 }
16249 }
16250
16251 string img_type_name;
16252
16253 auto &img_type = type.image;
16254
16255 if (is_depth_image(type, id))
16256 {
16257 switch (img_type.dim)
16258 {
16259 case Dim1D:
16260 case Dim2D:
16261 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
16262 {
16263 // Use a native Metal 1D texture
16264 img_type_name += "depth1d_unsupported_by_metal";
16265 break;
16266 }
16267
16268 if (img_type.ms && img_type.arrayed)
16269 {
16270 if (!msl_options.supports_msl_version(major: 2, minor: 1))
16271 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
16272 img_type_name += "depth2d_ms_array";
16273 }
16274 else if (img_type.ms)
16275 img_type_name += "depth2d_ms";
16276 else if (img_type.arrayed)
16277 img_type_name += "depth2d_array";
16278 else
16279 img_type_name += "depth2d";
16280 break;
16281 case Dim3D:
16282 img_type_name += "depth3d_unsupported_by_metal";
16283 break;
16284 case DimCube:
16285 if (!msl_options.emulate_cube_array)
16286 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
16287 else
16288 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
16289 break;
16290 default:
16291 img_type_name += "unknown_depth_texture_type";
16292 break;
16293 }
16294 }
16295 else
16296 {
16297 switch (img_type.dim)
16298 {
16299 case DimBuffer:
16300 if (img_type.ms || img_type.arrayed)
16301 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
16302
16303 if (msl_options.texture_buffer_native)
16304 {
16305 if (!msl_options.supports_msl_version(major: 2, minor: 1))
16306 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
16307 img_type_name = "texture_buffer";
16308 }
16309 else
16310 img_type_name += "texture2d";
16311 break;
16312 case Dim1D:
16313 case Dim2D:
16314 case DimSubpassData:
16315 {
16316 bool subpass_array =
16317 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
16318 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
16319 {
16320 // Use a native Metal 1D texture
16321 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
16322 break;
16323 }
16324
16325 // Use Metal's native frame-buffer fetch API for subpass inputs.
16326 if (type_is_msl_framebuffer_fetch(type))
16327 {
16328 auto img_type_4 = get<SPIRType>(id: img_type.type);
16329 img_type_4.vecsize = 4;
16330 return type_to_glsl(type: img_type_4);
16331 }
16332 if (img_type.ms && (img_type.arrayed || subpass_array))
16333 {
16334 if (!msl_options.supports_msl_version(major: 2, minor: 1))
16335 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
16336 img_type_name += "texture2d_ms_array";
16337 }
16338 else if (img_type.ms)
16339 img_type_name += "texture2d_ms";
16340 else if (img_type.arrayed || subpass_array)
16341 img_type_name += "texture2d_array";
16342 else
16343 img_type_name += "texture2d";
16344 break;
16345 }
16346 case Dim3D:
16347 img_type_name += "texture3d";
16348 break;
16349 case DimCube:
16350 if (!msl_options.emulate_cube_array)
16351 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
16352 else
16353 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
16354 break;
16355 default:
16356 img_type_name += "unknown_texture_type";
16357 break;
16358 }
16359 }
16360
16361 // Append the pixel type
16362 img_type_name += "<";
16363 img_type_name += type_to_glsl(type: get<SPIRType>(id: img_type.type));
16364
16365 // For unsampled images, append the sample/read/write access qualifier.
16366	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
16367 // Otherwise it may be set based on whether the image is read from or written to within the shader.
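	// Illustrative examples of the assembled names: a sampled depth array becomes "depth2d_array<float>",
	// a storage image that is both read and written becomes "texture2d<float, access::read_write>", and
	// one decorated NonReadable becomes "texture2d<float, access::write>".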
16368 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
16369 {
16370 switch (img_type.access)
16371 {
16372 case AccessQualifierReadOnly:
16373 img_type_name += ", access::read";
16374 break;
16375
16376 case AccessQualifierWriteOnly:
16377 img_type_name += ", access::write";
16378 break;
16379
16380 case AccessQualifierReadWrite:
16381 img_type_name += ", access::read_write";
16382 break;
16383
16384 default:
16385 {
16386 auto *p_var = maybe_get_backing_variable(chain: id);
16387 if (p_var && p_var->basevariable)
16388 p_var = maybe_get<SPIRVariable>(id: p_var->basevariable);
16389 if (p_var && !has_decoration(id: p_var->self, decoration: DecorationNonWritable))
16390 {
16391 img_type_name += ", access::";
16392
16393 if (!has_decoration(id: p_var->self, decoration: DecorationNonReadable))
16394 img_type_name += "read_";
16395
16396 img_type_name += "write";
16397 }
16398 break;
16399 }
16400 }
16401 }
16402
16403 img_type_name += ">";
16404
16405 return img_type_name;
16406}
16407
16408void CompilerMSL::emit_subgroup_op(const Instruction &i)
16409{
16410 const uint32_t *ops = stream(instr: i);
16411 auto op = static_cast<Op>(i.op);
16412
16413 if (msl_options.emulate_subgroups)
16414 {
16415 // In this mode, only the GroupNonUniform cap is supported. The only op
16416 // we need to handle, then, is OpGroupNonUniformElect.
16417 if (op != OpGroupNonUniformElect)
16418 SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
16419 // In this mode, the subgroup size is assumed to be one, so every invocation
16420 // is elected.
16421 emit_op(result_type: ops[0], result_id: ops[1], rhs: "true", forward_rhs: true);
16422 return;
16423 }
16424
16425 // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
16426 // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
16427 // 10.13 (2.0), with full support in 10.14 (2.1).
16428 // Note that Apple GPUs before A13 make no distinction between a quad-group
16429 // and a SIMD-group; all SIMD-groups are quad-groups on those.
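	// As a rough guide to the mapping below: OpGroupNonUniformElect becomes simd_is_first() (or
	// quad_is_first() when quadgroup operations are used), while ops such as OpGroupNonUniformBroadcast
	// and OpSubgroupReadInvocationKHR lower to calls to spvSubgroup* helper functions emitted elsewhere.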
16430 if (!msl_options.supports_msl_version(major: 2))
16431 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
16432
16433 // If we need to do implicit bitcasts, make sure we do it with the correct type.
16434 uint32_t integer_width = get_integer_width_for_instruction(instr: i);
16435 auto int_type = to_signed_basetype(width: integer_width);
16436 auto uint_type = to_unsigned_basetype(width: integer_width);
16437
16438 if (msl_options.is_ios() && (!msl_options.supports_msl_version(major: 2, minor: 3) || !msl_options.ios_use_simdgroup_functions))
16439 {
16440 switch (op)
16441 {
16442 default:
16443 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
16444 case OpGroupNonUniformBroadcastFirst:
16445 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16446 SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
16447 break;
16448 case OpGroupNonUniformElect:
16449 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16450 SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
16451 break;
16452 case OpGroupNonUniformAny:
16453 case OpGroupNonUniformAll:
16454 case OpGroupNonUniformAllEqual:
16455 case OpGroupNonUniformBallot:
16456 case OpGroupNonUniformInverseBallot:
16457 case OpGroupNonUniformBallotBitExtract:
16458 case OpGroupNonUniformBallotFindLSB:
16459 case OpGroupNonUniformBallotFindMSB:
16460 case OpGroupNonUniformBallotBitCount:
16461 case OpSubgroupBallotKHR:
16462 case OpSubgroupAllKHR:
16463 case OpSubgroupAnyKHR:
16464 case OpSubgroupAllEqualKHR:
16465 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16466				SPIRV_CROSS_THROW("Ballot ops on iOS require Metal 2.2 and up.");
16467 break;
16468 case OpGroupNonUniformBroadcast:
16469 case OpGroupNonUniformShuffle:
16470 case OpGroupNonUniformShuffleXor:
16471 case OpGroupNonUniformShuffleUp:
16472 case OpGroupNonUniformShuffleDown:
16473 case OpGroupNonUniformQuadSwap:
16474 case OpGroupNonUniformQuadBroadcast:
16475 case OpSubgroupReadInvocationKHR:
16476 break;
16477 }
16478 }
16479
16480 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
16481 {
16482 switch (op)
16483 {
16484 default:
16485 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
16486 case OpGroupNonUniformBroadcast:
16487 case OpGroupNonUniformShuffle:
16488 case OpGroupNonUniformShuffleXor:
16489 case OpGroupNonUniformShuffleUp:
16490 case OpGroupNonUniformShuffleDown:
16491 case OpSubgroupReadInvocationKHR:
16492 break;
16493 }
16494 }
16495
16496 uint32_t op_idx = 0;
16497 uint32_t result_type = ops[op_idx++];
16498 uint32_t id = ops[op_idx++];
16499
16500 Scope scope;
16501 switch (op)
16502 {
16503 case OpSubgroupBallotKHR:
16504 case OpSubgroupFirstInvocationKHR:
16505 case OpSubgroupReadInvocationKHR:
16506 case OpSubgroupAllKHR:
16507 case OpSubgroupAnyKHR:
16508 case OpSubgroupAllEqualKHR:
16509 // These earlier instructions don't have the scope operand.
16510 scope = ScopeSubgroup;
16511 break;
16512 default:
16513 scope = static_cast<Scope>(evaluate_constant_u32(id: ops[op_idx++]));
16514 break;
16515 }
16516 if (scope != ScopeSubgroup)
16517 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
16518
16519 switch (op)
16520 {
16521 case OpGroupNonUniformElect:
16522 if (msl_options.use_quadgroup_operation())
16523 emit_op(result_type, result_id: id, rhs: "quad_is_first()", forward_rhs: false);
16524 else
16525 emit_op(result_type, result_id: id, rhs: "simd_is_first()", forward_rhs: false);
16526 break;
16527
16528 case OpGroupNonUniformBroadcast:
16529 case OpSubgroupReadInvocationKHR:
16530 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupBroadcast");
16531 break;
16532
16533 case OpGroupNonUniformBroadcastFirst:
16534 case OpSubgroupFirstInvocationKHR:
16535 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupBroadcastFirst");
16536 break;
16537
16538 case OpGroupNonUniformBallot:
16539 case OpSubgroupBallotKHR:
16540 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupBallot");
16541 break;
16542
16543 case OpGroupNonUniformInverseBallot:
16544 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id, op: "spvSubgroupBallotBitExtract");
16545 break;
16546
16547 case OpGroupNonUniformBallotBitExtract:
16548 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupBallotBitExtract");
16549 break;
16550
16551 case OpGroupNonUniformBallotFindLSB:
16552 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotFindLSB");
16553 break;
16554
16555 case OpGroupNonUniformBallotFindMSB:
16556 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotFindMSB");
16557 break;
16558
16559 case OpGroupNonUniformBallotBitCount:
16560 {
16561 auto operation = static_cast<GroupOperation>(ops[op_idx++]);
16562 switch (operation)
16563 {
16564 case GroupOperationReduce:
16565 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotBitCount");
16566 break;
16567 case GroupOperationInclusiveScan:
16568 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id,
16569 op: "spvSubgroupBallotInclusiveBitCount");
16570 break;
16571 case GroupOperationExclusiveScan:
16572 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id,
16573 op: "spvSubgroupBallotExclusiveBitCount");
16574 break;
16575 default:
16576 SPIRV_CROSS_THROW("Invalid BitCount operation.");
16577 }
16578 break;
16579 }
16580
16581 case OpGroupNonUniformShuffle:
16582 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffle");
16583 break;
16584
16585 case OpGroupNonUniformShuffleXor:
16586 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleXor");
16587 break;
16588
16589 case OpGroupNonUniformShuffleUp:
16590 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleUp");
16591 break;
16592
16593 case OpGroupNonUniformShuffleDown:
16594 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleDown");
16595 break;
16596
16597 case OpGroupNonUniformAll:
16598 case OpSubgroupAllKHR:
16599 if (msl_options.use_quadgroup_operation())
16600 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "quad_all");
16601 else
16602 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "simd_all");
16603 break;
16604
16605 case OpGroupNonUniformAny:
16606 case OpSubgroupAnyKHR:
16607 if (msl_options.use_quadgroup_operation())
16608 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "quad_any");
16609 else
16610 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "simd_any");
16611 break;
16612
16613 case OpGroupNonUniformAllEqual:
16614 case OpSubgroupAllEqualKHR:
16615 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupAllEqual");
16616 break;
16617
16618 // clang-format off
16619#define MSL_GROUP_OP(op, msl_op) \
16620case OpGroupNonUniform##op: \
16621 { \
16622 auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
16623 if (operation == GroupOperationReduce) \
16624 emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
16625 else if (operation == GroupOperationInclusiveScan) \
16626 emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_inclusive_" #msl_op); \
16627 else if (operation == GroupOperationExclusiveScan) \
16628 emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
16629 else if (operation == GroupOperationClusteredReduce) \
16630 { \
16631 /* Only cluster sizes of 4 are supported. */ \
16632 uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
16633 if (cluster_size != 4) \
16634 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
16635 emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
16636 } \
16637 else \
16638 SPIRV_CROSS_THROW("Invalid group operation."); \
16639 break; \
16640 }
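	// For example, MSL_GROUP_OP(FAdd, sum) expands to a case that emits simd_sum() for Reduce,
	// simd_prefix_inclusive_sum()/simd_prefix_exclusive_sum() for the scans, and quad_sum() for a
	// ClusteredReduce with cluster size 4.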
16641 MSL_GROUP_OP(FAdd, sum)
16642 MSL_GROUP_OP(FMul, product)
16643 MSL_GROUP_OP(IAdd, sum)
16644 MSL_GROUP_OP(IMul, product)
16645#undef MSL_GROUP_OP
16646 // The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
16647
16648#define MSL_GROUP_OP(op, msl_op) \
16649case OpGroupNonUniform##op: \
16650 { \
16651 auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
16652 if (operation == GroupOperationReduce) \
16653 emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
16654 else if (operation == GroupOperationInclusiveScan) \
16655 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
16656 else if (operation == GroupOperationExclusiveScan) \
16657 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
16658 else if (operation == GroupOperationClusteredReduce) \
16659 { \
16660 /* Only cluster sizes of 4 are supported. */ \
16661 uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
16662 if (cluster_size != 4) \
16663 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
16664 emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
16665 } \
16666 else \
16667 SPIRV_CROSS_THROW("Invalid group operation."); \
16668 break; \
16669 }
16670
16671#define MSL_GROUP_OP_CAST(op, msl_op, type) \
16672case OpGroupNonUniform##op: \
16673 { \
16674 auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
16675 if (operation == GroupOperationReduce) \
16676 emit_unary_func_op_cast(result_type, id, ops[op_idx], "simd_" #msl_op, type, type); \
16677 else if (operation == GroupOperationInclusiveScan) \
16678 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
16679 else if (operation == GroupOperationExclusiveScan) \
16680 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
16681 else if (operation == GroupOperationClusteredReduce) \
16682 { \
16683 /* Only cluster sizes of 4 are supported. */ \
16684 uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
16685 if (cluster_size != 4) \
16686 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
16687 emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \
16688 } \
16689 else \
16690 SPIRV_CROSS_THROW("Invalid group operation."); \
16691 break; \
16692 }
16693
16694 MSL_GROUP_OP(FMin, min)
16695 MSL_GROUP_OP(FMax, max)
16696 MSL_GROUP_OP_CAST(SMin, min, int_type)
16697 MSL_GROUP_OP_CAST(SMax, max, int_type)
16698 MSL_GROUP_OP_CAST(UMin, min, uint_type)
16699 MSL_GROUP_OP_CAST(UMax, max, uint_type)
16700 MSL_GROUP_OP(BitwiseAnd, and)
16701 MSL_GROUP_OP(BitwiseOr, or)
16702 MSL_GROUP_OP(BitwiseXor, xor)
16703 MSL_GROUP_OP(LogicalAnd, and)
16704 MSL_GROUP_OP(LogicalOr, or)
16705 MSL_GROUP_OP(LogicalXor, xor)
16706 // clang-format on
16707#undef MSL_GROUP_OP
16708#undef MSL_GROUP_OP_CAST
16709
16710 case OpGroupNonUniformQuadSwap:
16711 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvQuadSwap");
16712 break;
16713
16714 case OpGroupNonUniformQuadBroadcast:
16715 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvQuadBroadcast");
16716 break;
16717
16718 default:
16719 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
16720 }
16721
16722 register_control_dependent_expression(expr: id);
16723}
16724
16725string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
16726{
16727 if (out_type.basetype == in_type.basetype)
16728 return "";
16729
16730 assert(out_type.basetype != SPIRType::Boolean);
16731 assert(in_type.basetype != SPIRType::Boolean);
16732
16733 bool integral_cast = type_is_integral(type: out_type) && type_is_integral(type: in_type) && (out_type.vecsize == in_type.vecsize);
16734 bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);
16735
16736 // Bitcasting can only be used between types of the same overall size.
16737 // And always formally cast between integers, because it's trivial, and also
16738 // because Metal can internally cast the results of some integer ops to a larger
16739 // size (eg. short shift right becomes int), which means chaining integer ops
16740 // together may introduce size variations that SPIR-V doesn't know about.
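	// For example, a float -> uint bitcast yields "as_type<uint>", while an int -> uint cast of the
	// same vector size is emitted as a plain "uint" conversion.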
16741 if (same_size_cast && !integral_cast)
16742 return "as_type<" + type_to_glsl(type: out_type) + ">";
16743 else
16744 return type_to_glsl(type: out_type);
16745}
16746
16747bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
16748{
16749 // This is handled from the outside where we deal with PtrToU/UToPtr and friends.
16750 return false;
16751}
16752
16753// Returns an MSL string identifying the name of a SPIR-V builtin.
16754// Output builtins are qualified with the name of the stage out structure.
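// For example, an output Position builtin referenced in the entry function is typically emitted as
// "<stage_out_var_name>.gl_Position" (with the default stage-out name, roughly "out.gl_Position").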
16755string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
16756{
16757 switch (builtin)
16758 {
16759 // Handle HLSL-style 0-based vertex/instance index.
16760 // Override GLSL compiler strictness
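	// For example, with enable_base_index_zero (and base-vertex support), a non-declaration use of
	// VertexIndex is rewritten to "(gl_VertexIndex - gl_BaseVertex)" so that indexing starts at zero.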
16761 case BuiltInVertexId:
16762 ensure_builtin(storage: StorageClassInput, builtin: BuiltInVertexId);
16763 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(major: 1, minor: 1) &&
16764 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16765 {
16766 if (builtin_declaration)
16767 {
16768 if (needs_base_vertex_arg != TriState::No)
16769 needs_base_vertex_arg = TriState::Yes;
16770 return "gl_VertexID";
16771 }
16772 else
16773 {
16774 ensure_builtin(storage: StorageClassInput, builtin: BuiltInBaseVertex);
16775 return "(gl_VertexID - gl_BaseVertex)";
16776 }
16777 }
16778 else
16779 {
16780 return "gl_VertexID";
16781 }
16782 case BuiltInInstanceId:
16783 ensure_builtin(storage: StorageClassInput, builtin: BuiltInInstanceId);
16784 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(major: 1, minor: 1) &&
16785 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16786 {
16787 if (builtin_declaration)
16788 {
16789 if (needs_base_instance_arg != TriState::No)
16790 needs_base_instance_arg = TriState::Yes;
16791 return "gl_InstanceID";
16792 }
16793 else
16794 {
16795 ensure_builtin(storage: StorageClassInput, builtin: BuiltInBaseInstance);
16796 return "(gl_InstanceID - gl_BaseInstance)";
16797 }
16798 }
16799 else
16800 {
16801 return "gl_InstanceID";
16802 }
16803 case BuiltInVertexIndex:
16804 ensure_builtin(storage: StorageClassInput, builtin: BuiltInVertexIndex);
16805 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(major: 1, minor: 1) &&
16806 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16807 {
16808 if (builtin_declaration)
16809 {
16810 if (needs_base_vertex_arg != TriState::No)
16811 needs_base_vertex_arg = TriState::Yes;
16812 return "gl_VertexIndex";
16813 }
16814 else
16815 {
16816 ensure_builtin(storage: StorageClassInput, builtin: BuiltInBaseVertex);
16817 return "(gl_VertexIndex - gl_BaseVertex)";
16818 }
16819 }
16820 else
16821 {
16822 return "gl_VertexIndex";
16823 }
16824 case BuiltInInstanceIndex:
16825 ensure_builtin(storage: StorageClassInput, builtin: BuiltInInstanceIndex);
16826 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(major: 1, minor: 1) &&
16827 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16828 {
16829 if (builtin_declaration)
16830 {
16831 if (needs_base_instance_arg != TriState::No)
16832 needs_base_instance_arg = TriState::Yes;
16833 return "gl_InstanceIndex";
16834 }
16835 else
16836 {
16837 ensure_builtin(storage: StorageClassInput, builtin: BuiltInBaseInstance);
16838 return "(gl_InstanceIndex - gl_BaseInstance)";
16839 }
16840 }
16841 else
16842 {
16843 return "gl_InstanceIndex";
16844 }
16845 case BuiltInBaseVertex:
16846 if (msl_options.supports_msl_version(major: 1, minor: 1) &&
16847 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16848 {
16849 needs_base_vertex_arg = TriState::No;
16850 return "gl_BaseVertex";
16851 }
16852 else
16853 {
16854 SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
16855 }
16856 case BuiltInBaseInstance:
16857 if (msl_options.supports_msl_version(major: 1, minor: 1) &&
16858 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
16859 {
16860 needs_base_instance_arg = TriState::No;
16861 return "gl_BaseInstance";
16862 }
16863 else
16864 {
16865 SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
16866 }
16867 case BuiltInDrawIndex:
16868 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
16869
16870	// When used in the entry function, output builtins are qualified with the output struct name.
16871	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
16872	// Also don't do this for tessellation control shaders.
16873 case BuiltInViewportIndex:
16874 if (!msl_options.supports_msl_version(major: 2, minor: 0))
16875 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
16876 /* fallthrough */
16877 case BuiltInFragDepth:
16878 case BuiltInFragStencilRefEXT:
16879 if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
16880 (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
16881 break;
16882 /* fallthrough */
16883 case BuiltInPosition:
16884 case BuiltInPointSize:
16885 case BuiltInClipDistance:
16886 case BuiltInCullDistance:
16887 case BuiltInLayer:
16888 if (is_tesc_shader())
16889 break;
16890 if (is_mesh_shader())
16891 break;
16892 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
16893 !is_stage_output_builtin_masked(builtin))
16894 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
16895 break;
16896
16897 case BuiltInSampleMask:
16898 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
16899 (has_additional_fixed_sample_mask() || needs_sample_id))
16900 {
16901 string samp_mask_in;
16902 samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage);
16903 if (has_additional_fixed_sample_mask())
16904 samp_mask_in += " & " + additional_fixed_sample_mask_str();
16905 if (needs_sample_id)
16906 samp_mask_in += " & (1 << gl_SampleID)";
16907 samp_mask_in += ")";
16908 return samp_mask_in;
16909 }
16910 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
16911 !is_stage_output_builtin_masked(builtin))
16912 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
16913 break;
16914
16915 case BuiltInBaryCoordKHR:
16916 case BuiltInBaryCoordNoPerspKHR:
16917 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
16918 return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
16919 break;
16920
16921 case BuiltInTessLevelOuter:
16922 if (is_tesc_shader() && storage != StorageClassInput && current_function &&
16923 (current_function->self == ir.default_entry_point))
16924 {
16925 return join(ts&: tess_factor_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
16926 ts: "].edgeTessellationFactor");
16927 }
16928 break;
16929
16930 case BuiltInTessLevelInner:
16931 if (is_tesc_shader() && storage != StorageClassInput && current_function &&
16932 (current_function->self == ir.default_entry_point))
16933 {
16934 return join(ts&: tess_factor_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
16935 ts: "].insideTessellationFactor");
16936 }
16937 break;
16938
16939 case BuiltInHelperInvocation:
16940 if (needs_manual_helper_invocation_updates())
16941 break;
16942 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
16943 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
16944 else if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
16945 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
16946 // In SPIR-V 1.6 with Volatile HelperInvocation, we cannot emit a fixup early.
16947 return "simd_is_helper_thread()";
16948
16949 case BuiltInPrimitiveId:
16950 return "gl_PrimitiveID";
16951
16952 default:
16953 break;
16954 }
16955
16956 return CompilerGLSL::builtin_to_glsl(builtin, storage);
16957}
16958
16959// Returns an MSL string attribute qualifier for a SPIR-V builtin
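// For example, BuiltInFragCoord yields "position", BuiltInSampleId yields "sample_id", and an
// invariant BuiltInPosition yields "position, invariant".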
16960string CompilerMSL::builtin_qualifier(BuiltIn builtin)
16961{
16962 auto &execution = get_entry_point();
16963
16964 switch (builtin)
16965 {
16966 // Vertex function in
16967 case BuiltInVertexId:
16968 return "vertex_id";
16969 case BuiltInVertexIndex:
16970 return "vertex_id";
16971 case BuiltInBaseVertex:
16972 return "base_vertex";
16973 case BuiltInInstanceId:
16974 return "instance_id";
16975 case BuiltInInstanceIndex:
16976 return "instance_id";
16977 case BuiltInBaseInstance:
16978 return "base_instance";
16979 case BuiltInDrawIndex:
16980 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
16981
16982 // Vertex function out
16983 case BuiltInClipDistance:
16984 return "clip_distance";
16985 case BuiltInCullDistance:
16986 return "cull_distance";
16987 case BuiltInPointSize:
16988 return "point_size";
16989 case BuiltInPosition:
16990 if (position_invariant)
16991 {
16992 if (!msl_options.supports_msl_version(major: 2, minor: 1))
16993 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
16994 return "position, invariant";
16995 }
16996 else
16997 return "position";
16998 case BuiltInLayer:
16999 return "render_target_array_index";
17000 case BuiltInViewportIndex:
17001 if (!msl_options.supports_msl_version(major: 2, minor: 0))
17002 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
17003 return "viewport_array_index";
17004
17005 // Tess. control function in
17006 case BuiltInInvocationId:
17007 if (msl_options.multi_patch_workgroup)
17008 {
17009 // Shouldn't be reached.
17010 SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
17011 }
17012 return "thread_index_in_threadgroup";
17013 case BuiltInPatchVertices:
17014 // Shouldn't be reached.
17015 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
17016 case BuiltInPrimitiveId:
17017 switch (execution.model)
17018 {
17019 case ExecutionModelTessellationControl:
17020 if (msl_options.multi_patch_workgroup)
17021 {
17022 // Shouldn't be reached.
17023 SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
17024 }
17025 return "threadgroup_position_in_grid";
17026 case ExecutionModelTessellationEvaluation:
17027 return "patch_id";
17028 case ExecutionModelFragment:
17029 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
17030 SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
17031 else if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 2))
17032 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
17033 return "primitive_id";
17034 case ExecutionModelMeshEXT:
17035 return "primitive_id";
17036 default:
17037 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
17038 }
17039
17040 // Tess. control function out
17041 case BuiltInTessLevelOuter:
17042 case BuiltInTessLevelInner:
17043 // Shouldn't be reached.
17044 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
17045
17046 // Tess. evaluation function in
17047 case BuiltInTessCoord:
17048 return "position_in_patch";
17049
17050 // Fragment function in
17051 case BuiltInFrontFacing:
17052 return "front_facing";
17053 case BuiltInPointCoord:
17054 return "point_coord";
17055 case BuiltInFragCoord:
17056 return "position";
17057 case BuiltInSampleId:
17058 return "sample_id";
17059 case BuiltInSampleMask:
17060 return "sample_mask";
17061 case BuiltInSamplePosition:
17062 // Shouldn't be reached.
17063 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
17064 case BuiltInViewIndex:
17065 if (execution.model != ExecutionModelFragment && execution.model != ExecutionModelMeshEXT)
17066 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
17067 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
17068 // so we can get it from there.
17069 return "render_target_array_index";
17070
17071 // Fragment function out
17072 case BuiltInFragDepth:
17073 if (execution.flags.get(bit: ExecutionModeDepthGreater))
17074 return "depth(greater)";
17075 else if (execution.flags.get(bit: ExecutionModeDepthLess))
17076 return "depth(less)";
17077 else
17078 return "depth(any)";
17079
17080 case BuiltInFragStencilRefEXT:
17081 return "stencil";
17082
17083 // Compute function in
17084 case BuiltInGlobalInvocationId:
17085 return "thread_position_in_grid";
17086
17087 case BuiltInWorkgroupId:
17088 return "threadgroup_position_in_grid";
17089
17090 case BuiltInNumWorkgroups:
17091 return "threadgroups_per_grid";
17092
17093 case BuiltInLocalInvocationId:
17094 return "thread_position_in_threadgroup";
17095
17096 case BuiltInLocalInvocationIndex:
17097 return "thread_index_in_threadgroup";
17098
17099 case BuiltInSubgroupSize:
17100 if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
17101 // Shouldn't be reached.
17102 SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
17103 if (execution.model == ExecutionModelFragment)
17104 {
17105 if (!msl_options.supports_msl_version(major: 2, minor: 2))
17106 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
17107 return "threads_per_simdgroup";
17108 }
17109 else
17110 {
17111			// thread_execution_width is an alias for threads_per_simdgroup. It has been available since MSL 1.0,
17112			// but it is not available in fragment functions.
17113 return "thread_execution_width";
17114 }
17115
17116 case BuiltInNumSubgroups:
17117 if (msl_options.emulate_subgroups)
17118 // Shouldn't be reached.
17119 SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
17120 if (!msl_options.supports_msl_version(major: 2))
17121 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
17122 return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
17123
17124 case BuiltInSubgroupId:
17125 if (msl_options.emulate_subgroups)
17126 // Shouldn't be reached.
17127 SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
17128 if (!msl_options.supports_msl_version(major: 2))
17129 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
17130 return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
17131
17132 case BuiltInSubgroupLocalInvocationId:
17133 if (msl_options.emulate_subgroups)
17134 // Shouldn't be reached.
17135 SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
17136 if (execution.model == ExecutionModelFragment)
17137 {
17138 if (!msl_options.supports_msl_version(major: 2, minor: 2))
17139 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
17140 return "thread_index_in_simdgroup";
17141 }
17142 else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
17143 execution.model == ExecutionModelTessellationControl ||
17144 (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
17145 {
17146 // We are generating a Metal kernel function.
17147 if (!msl_options.supports_msl_version(major: 2))
17148 SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
17149 return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
17150 }
17151 else
17152 SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
17153
17154 case BuiltInSubgroupEqMask:
17155 case BuiltInSubgroupGeMask:
17156 case BuiltInSubgroupGtMask:
17157 case BuiltInSubgroupLeMask:
17158 case BuiltInSubgroupLtMask:
17159 // Shouldn't be reached.
17160 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
17161
17162 case BuiltInBaryCoordKHR:
17163 case BuiltInBaryCoordNoPerspKHR:
17164 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
17165 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
17166 else if (!msl_options.supports_msl_version(major: 2, minor: 2))
17167 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
17168 return "barycentric_coord";
17169
17170 case BuiltInCullPrimitiveEXT:
17171 return "primitive_culled";
17172
17173 default:
17174 return "unsupported-built-in";
17175 }
17176}
17177
17178// Returns an MSL string type declaration for a SPIR-V builtin
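// For example, BuiltInPosition is declared as "float4", BuiltInGlobalInvocationId as "uint3",
// and BuiltInFrontFacing as "bool".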
17179string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
17180{
17181 switch (builtin)
17182 {
17183 // Vertex function in
17184 case BuiltInVertexId:
17185 return "uint";
17186 case BuiltInVertexIndex:
17187 return "uint";
17188 case BuiltInBaseVertex:
17189 return "uint";
17190 case BuiltInInstanceId:
17191 return "uint";
17192 case BuiltInInstanceIndex:
17193 return "uint";
17194 case BuiltInBaseInstance:
17195 return "uint";
17196 case BuiltInDrawIndex:
17197 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
17198
17199 // Vertex function out
17200 case BuiltInClipDistance:
17201 case BuiltInCullDistance:
17202 return "float";
17203 case BuiltInPointSize:
17204 return "float";
17205 case BuiltInPosition:
17206 return "float4";
17207 case BuiltInLayer:
17208 return "uint";
17209 case BuiltInViewportIndex:
17210 if (!msl_options.supports_msl_version(major: 2, minor: 0))
17211 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
17212 return "uint";
17213
17214 // Tess. control function in
17215 case BuiltInInvocationId:
17216 return "uint";
17217 case BuiltInPatchVertices:
17218 return "uint";
17219 case BuiltInPrimitiveId:
17220 return "uint";
17221
17222 // Tess. control function out
17223 case BuiltInTessLevelInner:
17224 if (is_tese_shader())
17225 return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float2";
17226 return "half";
17227 case BuiltInTessLevelOuter:
17228 if (is_tese_shader())
17229 return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float4";
17230 return "half";
17231
17232 // Tess. evaluation function in
17233 case BuiltInTessCoord:
17234 return "float3";
17235
17236 // Fragment function in
17237 case BuiltInFrontFacing:
17238 return "bool";
17239 case BuiltInPointCoord:
17240 return "float2";
17241 case BuiltInFragCoord:
17242 return "float4";
17243 case BuiltInSampleId:
17244 return "uint";
17245 case BuiltInSampleMask:
17246 return "uint";
17247 case BuiltInSamplePosition:
17248 return "float2";
17249 case BuiltInViewIndex:
17250 return "uint";
17251
17252 case BuiltInHelperInvocation:
17253 return "bool";
17254
17255 case BuiltInBaryCoordKHR:
17256 case BuiltInBaryCoordNoPerspKHR:
17257 // Use the type as declared, can be 1, 2 or 3 components.
17258 return type_to_glsl(type: get_variable_data_type(var: get<SPIRVariable>(id)));
17259
17260 // Fragment function out
17261 case BuiltInFragDepth:
17262 return "float";
17263
17264 case BuiltInFragStencilRefEXT:
17265 return "uint";
17266
17267 // Compute function in
17268 case BuiltInGlobalInvocationId:
17269 case BuiltInLocalInvocationId:
17270 case BuiltInNumWorkgroups:
17271 case BuiltInWorkgroupId:
17272 return "uint3";
17273 case BuiltInLocalInvocationIndex:
17274 case BuiltInNumSubgroups:
17275 case BuiltInSubgroupId:
17276 case BuiltInSubgroupSize:
17277 case BuiltInSubgroupLocalInvocationId:
17278 return "uint";
17279 case BuiltInSubgroupEqMask:
17280 case BuiltInSubgroupGeMask:
17281 case BuiltInSubgroupGtMask:
17282 case BuiltInSubgroupLeMask:
17283 case BuiltInSubgroupLtMask:
17284 return "uint4";
17285
17286 case BuiltInDeviceIndex:
17287 return "int";
17288
17289 case BuiltInPrimitivePointIndicesEXT:
17290 return "uint";
17291 case BuiltInPrimitiveLineIndicesEXT:
17292 return "uint2";
17293 case BuiltInPrimitiveTriangleIndicesEXT:
17294 return "uint3";
17295
17296 default:
17297 return "unsupported-built-in-type";
17298 }
17299}
17300
17301// Returns the declaration of a built-in argument to a function
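// For example, for BuiltInSampleId with prefix_comma = true this produces roughly
// ", uint gl_SampleID [[sample_id]]".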
17302string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
17303{
17304 string bi_arg;
17305 if (prefix_comma)
17306 bi_arg += ", ";
17307
17308 // Handle HLSL-style 0-based vertex/instance index.
17309 builtin_declaration = true;
17310 bi_arg += builtin_type_decl(builtin);
17311 bi_arg += string(" ") + builtin_to_glsl(builtin, storage: StorageClassInput);
17312 bi_arg += string(" [[") + builtin_qualifier(builtin) + string("]]");
17313 builtin_declaration = false;
17314
17315 return bi_arg;
17316}
17317
17318const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
17319{
17320 if (member_is_remapped_physical_type(type, index))
17321 return get<SPIRType>(id: get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID));
17322 else
17323 return get<SPIRType>(id: type.member_types[index]);
17324}
17325
17326SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
17327{
17328 SPIRType type = get_physical_member_type(type: ib_type, index);
17329 uint32_t loc = get_member_decoration(id: ib_type.self, index, decoration: DecorationLocation);
17330 uint32_t cmp = get_member_decoration(id: ib_type.self, index, decoration: DecorationComponent);
17331 auto p_va = inputs_by_location.find(x: {.location: loc, .component: cmp});
17332 if (p_va != end(cont: inputs_by_location) && p_va->second.vecsize > type.vecsize)
17333 type.vecsize = p_va->second.vecsize;
17334
17335 return type;
17336}
17337
17338uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
17339{
17340 // Array stride in MSL is always size * array_size. sizeof(float3) == 16,
17341 // unlike GLSL and HLSL where array stride would be 16 and size 12.
17342
17343 // We could use parent type here and recurse, but that makes creating physical type remappings
17344 // far more complicated. We'd rather just create the final type, and ignore having to create the entire type
17345 // hierarchy in order to compute this value, so make a temporary type on the stack.
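	// For example, an unpacked float3 member declared as a 4-element array has an element stride of
	// 16 bytes (not 12), for a declared size of 4 * 16 = 64 bytes.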
17346
17347 auto basic_type = type;
17348 basic_type.array.clear();
17349 basic_type.array_size_literal.clear();
17350 uint32_t value_size = get_declared_type_size_msl(type: basic_type, packed: is_packed, row_major);
17351
17352 uint32_t dimensions = uint32_t(type.array.size());
17353 assert(dimensions > 0);
17354 dimensions--;
17355
17356 // Multiply together every dimension, except the last one.
17357 for (uint32_t dim = 0; dim < dimensions; dim++)
17358 {
17359 uint32_t array_size = to_array_size_literal(type, index: dim);
17360 value_size *= max<uint32_t>(a: array_size, b: 1u);
17361 }
17362
17363 return value_size;
17364}
17365
17366uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
17367{
17368 return get_declared_type_array_stride_msl(type: get_physical_member_type(type, index),
17369 is_packed: member_is_packed_physical_type(type, index),
17370 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17371}
17372
17373uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
17374{
17375 return get_declared_type_array_stride_msl(type: get_presumed_input_type(ib_type: type, index), is_packed: false,
17376 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17377}
17378
17379uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
17380{
17381 // For packed matrices, we just use the size of the vector type.
17382 // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
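	// For example, a packed column-major float3x3 has a matrix stride of 3 * 4 = 12 bytes, whereas
	// the unpacked form uses the float3 alignment of 16 bytes.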
17383 if (packed)
17384 return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
17385 else
17386 return get_declared_type_alignment_msl(type, packed: false, row_major);
17387}
17388
17389uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
17390{
17391 return get_declared_type_matrix_stride_msl(type: get_physical_member_type(type, index),
17392 packed: member_is_packed_physical_type(type, index),
17393 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17394}
17395
17396uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
17397{
17398 return get_declared_type_matrix_stride_msl(type: get_presumed_input_type(ib_type: type, index), packed: false,
17399 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17400}
17401
17402uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
17403 bool ignore_padding) const
17404{
17405 // If we have a target size, that is the declared size as well.
17406 if (!ignore_padding && has_extended_decoration(id: struct_type.self, decoration: SPIRVCrossDecorationPaddingTarget))
17407 return get_extended_decoration(id: struct_type.self, decoration: SPIRVCrossDecorationPaddingTarget);
17408
17409 if (struct_type.member_types.empty())
17410 return 0;
17411
17412 uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
17413
17414 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
17415 uint32_t alignment = 1;
17416
17417 if (!ignore_alignment)
17418 {
17419 for (uint32_t i = 0; i < mbr_cnt; i++)
17420 {
17421 uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, index: i);
17422 alignment = max(a: alignment, b: mbr_alignment);
17423 }
17424 }
17425
17426	// The last member is always matched to its final Offset decoration, but the size of the struct in MSL
17427	// depends on the member's physical MSL size, and the total size is then rounded up to the struct alignment.
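	// For example, if the last member is a float at SPIR-V offset 20 and the struct alignment is 16,
	// the raw size is 20 + 4 = 24 bytes, which is then rounded up to 32.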
17428 uint32_t spirv_offset = type_struct_member_offset(type: struct_type, index: mbr_cnt - 1);
17429 uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, index: mbr_cnt - 1);
17430 msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
17431 return msl_size;
17432}
17433
17434uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
17435{
17436 // This should only be relevant for plain types such as scalars and vectors?
17437 // If we're pointing to a struct, it will recursively pick up packed/row-major state.
17438 return get_declared_type_size_msl(type, packed: false, row_major: false);
17439}
17440
17441// Returns the byte size of a type as it would be declared in MSL (also used for struct member sizes).
17442uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
17443{
17444 // Pointers take 8 bytes each
17445 // Match both pointer and array-of-pointer here.
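	// For example, a single PhysicalStorageBuffer pointer occupies 8 bytes, and an array of 4 such
	// pointers occupies 4 * 8 = 32 bytes.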
17446 if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
17447 {
17448 uint32_t type_size = 8;
17449
17450 // Work our way through potentially layered arrays,
17451 // stopping when we hit a pointer that is not also an array.
17452 int32_t dim_idx = (int32_t)type.array.size() - 1;
17453 auto *p_type = &type;
17454 while (!is_pointer(type: *p_type) && dim_idx >= 0)
17455 {
17456 type_size *= to_array_size_literal(type: *p_type, index: dim_idx);
17457 p_type = &get<SPIRType>(id: p_type->parent_type);
17458 dim_idx--;
17459 }
17460
17461 return type_size;
17462 }
17463
17464 switch (type.basetype)
17465 {
17466 case SPIRType::Unknown:
17467 case SPIRType::Void:
17468 case SPIRType::AtomicCounter:
17469 case SPIRType::Image:
17470 case SPIRType::SampledImage:
17471 case SPIRType::Sampler:
17472 SPIRV_CROSS_THROW("Querying size of opaque object.");
17473
17474 default:
17475 {
17476 if (!type.array.empty())
17477 {
17478 uint32_t array_size = to_array_size_literal(type);
17479 return get_declared_type_array_stride_msl(type, is_packed, row_major) * max<uint32_t>(a: array_size, b: 1u);
17480 }
17481
17482 if (type.basetype == SPIRType::Struct)
17483 return get_declared_struct_size_msl(struct_type: type);
17484
17485 if (is_packed)
17486 {
17487 return type.vecsize * type.columns * (type.width / 8);
17488 }
17489 else
17490 {
17491 // An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
17492 uint32_t vecsize = type.vecsize;
17493 uint32_t columns = type.columns;
17494
17495 if (row_major && columns > 1)
17496 swap(a&: vecsize, b&: columns);
17497
17498 if (vecsize == 3)
17499 vecsize = 4;
17500
17501 return vecsize * columns * (type.width / 8);
17502 }
17503 }
17504 }
17505}
17506
17507uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
17508{
17509 return get_declared_type_size_msl(type: get_physical_member_type(type, index),
17510 is_packed: member_is_packed_physical_type(type, index),
17511 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17512}
17513
17514uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
17515{
17516 return get_declared_type_size_msl(type: get_presumed_input_type(ib_type: type, index), is_packed: false,
17517 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17518}
17519
17520// Returns the byte alignment of a type.
17521uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
17522{
17523 // Pointers align on multiples of 8 bytes.
17524 // Deliberately ignore array-ness here. It's not relevant for alignment.
17525 if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
17526 return 8;
17527
17528 switch (type.basetype)
17529 {
17530 case SPIRType::Unknown:
17531 case SPIRType::Void:
17532 case SPIRType::AtomicCounter:
17533 case SPIRType::Image:
17534 case SPIRType::SampledImage:
17535 case SPIRType::Sampler:
17536 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
17537
17538 case SPIRType::Double:
17539 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
17540
17541 case SPIRType::Struct:
17542 {
17543 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
17544 uint32_t alignment = 1;
17545 for (uint32_t i = 0; i < type.member_types.size(); i++)
17546 alignment = max(a: alignment, b: uint32_t(get_declared_struct_member_alignment_msl(struct_type: type, index: i)));
17547 return alignment;
17548 }
17549
17550 default:
17551 {
17552 if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(major: 2, minor: 3))
17553 SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above.");
17554 if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(major: 2, minor: 3))
17555 SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above.");
17556		// The alignment of a packed type is the same as the underlying component or column size.
17557		// The alignment of an unpacked type is the same as the vector size.
17558		// A 3-element vector is aligned the same as a 4-element one (including packed types, using the column count for row-major matrices).
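		// For example, packed_float3 aligns to 4 bytes (scalar alignment), while an unpacked float3
		// aligns to 16 bytes, the same as float4.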
17559 if (is_packed)
17560 {
17561 // If we have packed_T and friends, the alignment is always scalar.
17562 return type.width / 8;
17563 }
17564 else
17565 {
17566 // This is the general rule for MSL. Size == alignment.
17567 uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
17568 return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
17569 }
17570 }
17571 }
17572}
17573
17574uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
17575{
17576 return get_declared_type_alignment_msl(type: get_physical_member_type(type, index),
17577 is_packed: member_is_packed_physical_type(type, index),
17578 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17579}
17580
17581uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
17582{
17583 return get_declared_type_alignment_msl(type: get_presumed_input_type(ib_type: type, index), is_packed: false,
17584 row_major: has_member_decoration(id: type.self, index, decoration: DecorationRowMajor));
17585}
17586
17587bool CompilerMSL::skip_argument(uint32_t) const
17588{
17589 return false;
17590}
17591
17592void CompilerMSL::analyze_sampled_image_usage()
17593{
17594 if (msl_options.swizzle_texture_samples)
17595 {
17596 SampledImageScanner scanner(*this);
17597 traverse_all_reachable_opcodes(block: get<SPIRFunction>(id: ir.default_entry_point), handler&: scanner);
17598 }
17599}
17600
17601bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
17602{
17603 switch (opcode)
17604 {
17605 case OpLoad:
17606 case OpImage:
17607 case OpSampledImage:
17608 {
17609 if (length < 3)
17610 return false;
17611
17612 uint32_t result_type = args[0];
17613 auto &type = compiler.get<SPIRType>(id: result_type);
17614 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
17615 return true;
17616
17617 uint32_t id = args[1];
17618 compiler.set<SPIRExpression>(id, args: "", args&: result_type, args: true);
17619 break;
17620 }
17621 case OpImageSampleExplicitLod:
17622 case OpImageSampleProjExplicitLod:
17623 case OpImageSampleDrefExplicitLod:
17624 case OpImageSampleProjDrefExplicitLod:
17625 case OpImageSampleImplicitLod:
17626 case OpImageSampleProjImplicitLod:
17627 case OpImageSampleDrefImplicitLod:
17628 case OpImageSampleProjDrefImplicitLod:
17629 case OpImageFetch:
17630 case OpImageGather:
17631 case OpImageDrefGather:
17632 compiler.has_sampled_images =
17633 compiler.has_sampled_images || compiler.is_sampled_image_type(type: compiler.expression_type(id: args[2]));
17634 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
17635 break;
17636 default:
17637 break;
17638 }
17639 return true;
17640}
17641
17642// If a needed custom function wasn't added before, add it and force a recompile.
17643void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
17644{
17645 if (spv_function_implementations.count(x: spv_func) == 0)
17646 {
17647 spv_function_implementations.insert(x: spv_func);
17648 suppress_missing_prototypes = true;
17649 force_recompile();
17650 }
17651}
17652
17653bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
17654{
17655 // Since MSL exists in a single execution scope, function prototype declarations are not
17656 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
17657 // function implementation or as indicated by the presence of OpFunctionCall), then set
17658 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
17659
17660	// Mark if the input requires the implementation of a SPIR-V function that does not exist in Metal.
17661 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
17662 if (spv_func != SPVFuncImplNone)
17663 {
17664 compiler.spv_function_implementations.insert(x: spv_func);
17665 suppress_missing_prototypes = true;
17666 }
17667
17668 switch (opcode)
17669 {
17670
17671 case OpFunctionCall:
17672 suppress_missing_prototypes = true;
17673 break;
17674
17675 case OpDemoteToHelperInvocationEXT:
17676 uses_discard = true;
17677 break;
17678
17679 // Emulate texture2D atomic operations
17680 case OpImageTexelPointer:
17681 {
17682 if (!compiler.msl_options.supports_msl_version(major: 3, minor: 1))
17683 {
17684 auto *var = compiler.maybe_get_backing_variable(chain: args[2]);
17685 image_pointers_emulated[args[1]] = var ? var->self : ID(0);
17686 }
17687 break;
17688 }
17689
17690 case OpImageWrite:
17691 uses_image_write = true;
17692 break;
17693
17694 case OpStore:
17695 check_resource_write(var_id: args[0]);
17696 break;
17697
17698 // Emulate texture2D atomic operations
17699 case OpAtomicExchange:
17700 case OpAtomicCompareExchange:
17701 case OpAtomicCompareExchangeWeak:
17702 case OpAtomicIIncrement:
17703 case OpAtomicIDecrement:
17704 case OpAtomicIAdd:
17705 case OpAtomicFAddEXT:
17706 case OpAtomicISub:
17707 case OpAtomicSMin:
17708 case OpAtomicUMin:
17709 case OpAtomicSMax:
17710 case OpAtomicUMax:
17711 case OpAtomicAnd:
17712 case OpAtomicOr:
17713 case OpAtomicXor:
17714 {
17715 uses_atomics = true;
17716 auto it = image_pointers_emulated.find(x: args[2]);
17717 if (it != image_pointers_emulated.end())
17718 {
17719 uses_image_write = true;
17720 compiler.atomic_image_vars_emulated.insert(x: it->second);
17721 }
17722 else
17723 check_resource_write(var_id: args[2]);
17724 break;
17725 }
17726
17727 case OpAtomicStore:
17728 {
17729 uses_atomics = true;
17730 auto it = image_pointers_emulated.find(x: args[0]);
17731 if (it != image_pointers_emulated.end())
17732 {
17733 compiler.atomic_image_vars_emulated.insert(x: it->second);
17734 uses_image_write = true;
17735 }
17736 else
17737 check_resource_write(var_id: args[0]);
17738 break;
17739 }
17740
17741 case OpAtomicLoad:
17742 {
17743 uses_atomics = true;
17744 auto it = image_pointers_emulated.find(x: args[2]);
17745 if (it != image_pointers_emulated.end())
17746 {
17747 compiler.atomic_image_vars_emulated.insert(x: it->second);
17748 }
17749 break;
17750 }
17751
17752 case OpGroupNonUniformInverseBallot:
17753 needs_subgroup_invocation_id = true;
17754 break;
17755
17756 case OpGroupNonUniformBallotFindLSB:
17757 case OpGroupNonUniformBallotFindMSB:
17758 needs_subgroup_size = true;
17759 break;
17760
17761 case OpGroupNonUniformBallotBitCount:
17762 if (args[3] == GroupOperationReduce)
17763 needs_subgroup_size = true;
17764 else
17765 needs_subgroup_invocation_id = true;
17766 break;
17767
17768 case OpArrayLength:
17769 {
17770 auto *var = compiler.maybe_get_backing_variable(chain: args[2]);
17771 if (var != nullptr)
17772 {
17773 if (!compiler.is_var_runtime_size_array(var: *var))
17774 compiler.buffers_requiring_array_length.insert(x: var->self);
17775 }
17776 break;
17777 }
17778
17779 case OpInBoundsAccessChain:
17780 case OpAccessChain:
17781 case OpPtrAccessChain:
17782 {
17783		// A later OpArrayLength might need to know whether it is taking the ArrayLength of an array of SSBOs.
17784 uint32_t result_type = args[0];
17785 uint32_t id = args[1];
17786 uint32_t ptr = args[2];
17787
17788 compiler.set<SPIRExpression>(id, args: "", args&: result_type, args: true);
17789 compiler.register_read(expr: id, chain: ptr, forwarded: true);
17790 compiler.ir.ids[id].set_allow_type_rewrite();
17791 break;
17792 }
17793
17794 case OpExtInst:
17795 {
17796 uint32_t extension_set = args[2];
17797 if (compiler.get<SPIRExtension>(id: extension_set).ext == SPIRExtension::GLSL)
17798 {
17799 auto op_450 = static_cast<GLSLstd450>(args[3]);
17800 switch (op_450)
17801 {
17802 case GLSLstd450InterpolateAtCentroid:
17803 case GLSLstd450InterpolateAtSample:
17804 case GLSLstd450InterpolateAtOffset:
17805 {
17806 if (!compiler.msl_options.supports_msl_version(major: 2, minor: 3))
17807 SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
17808 // Fragment varyings used with pull-model interpolation need special handling,
17809 // due to the way pull-model interpolation works in Metal.
17810 auto *var = compiler.maybe_get_backing_variable(chain: args[4]);
17811 if (var)
17812 {
17813 compiler.pull_model_inputs.insert(x: var->self);
17814 auto &var_type = compiler.get_variable_element_type(var: *var);
17815 // In addition, if this variable has a 'Sample' decoration, we need the sample ID
17816 // in order to do default interpolation.
17817 if (compiler.has_decoration(id: var->self, decoration: DecorationSample))
17818 {
17819 needs_sample_id = true;
17820 }
17821 else if (var_type.basetype == SPIRType::Struct)
17822 {
17823 // Now we need to check each member and see if it has this decoration.
17824 for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
17825 {
17826 if (compiler.has_member_decoration(id: var_type.self, index: i, decoration: DecorationSample))
17827 {
17828 needs_sample_id = true;
17829 break;
17830 }
17831 }
17832 }
17833 }
17834 break;
17835 }
17836 default:
17837 break;
17838 }
17839 }
17840 break;
17841 }
17842
17843 case OpIsHelperInvocationEXT:
17844 if (compiler.needs_manual_helper_invocation_updates())
17845 needs_helper_invocation = true;
17846 break;
17847
17848 default:
17849 break;
17850 }
17851
17852 // If it has one, keep track of the instruction's result type, mapped by ID
17853 uint32_t result_type, result_id;
17854 if (compiler.instruction_to_result_type(result_type, result_id, op: opcode, args, length))
17855 result_types[result_id] = result_type;
17856
17857 return true;
17858}
17859
17860// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
17861void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
17862{
17863 auto *p_var = compiler.maybe_get_backing_variable(chain: var_id);
17864 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
17865 if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
17866 uses_buffer_write = true;
17867}
17868
17869// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
17870CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
17871{
17872 switch (opcode)
17873 {
17874 case OpFMod:
17875 return SPVFuncImplMod;
17876
17877 case OpFAdd:
17878 case OpFSub:
17879 if (compiler.msl_options.invariant_float_math ||
17880 compiler.has_decoration(id: args[1], decoration: DecorationNoContraction))
17881 {
17882 return opcode == OpFAdd ? SPVFuncImplFAdd : SPVFuncImplFSub;
17883 }
17884 break;
17885
17886 case OpFMul:
17887 case OpOuterProduct:
17888 case OpMatrixTimesVector:
17889 case OpVectorTimesMatrix:
17890 case OpMatrixTimesMatrix:
17891 if (compiler.msl_options.invariant_float_math ||
17892 compiler.has_decoration(id: args[1], decoration: DecorationNoContraction))
17893 {
17894 return SPVFuncImplFMul;
17895 }
17896 break;
17897
17898 case OpQuantizeToF16:
17899 return SPVFuncImplQuantizeToF16;
17900
17901 case OpTypeArray:
17902 {
17903 // Allow Metal to use the array<T> template to make arrays a value type
17904 return SPVFuncImplUnsafeArray;
17905 }
17906
17907 // Emulate texture2D atomic operations
17908 case OpAtomicExchange:
17909 case OpAtomicCompareExchange:
17910 case OpAtomicCompareExchangeWeak:
17911 case OpAtomicIIncrement:
17912 case OpAtomicIDecrement:
17913 case OpAtomicIAdd:
17914 case OpAtomicFAddEXT:
17915 case OpAtomicISub:
17916 case OpAtomicSMin:
17917 case OpAtomicUMin:
17918 case OpAtomicSMax:
17919 case OpAtomicUMax:
17920 case OpAtomicAnd:
17921 case OpAtomicOr:
17922 case OpAtomicXor:
17923 case OpAtomicLoad:
17924 case OpAtomicStore:
17925 {
17926 auto it = image_pointers_emulated.find(x: args[opcode == OpAtomicStore ? 0 : 2]);
17927 if (it != image_pointers_emulated.end())
17928 {
17929 uint32_t tid = compiler.get<SPIRVariable>(id: it->second).basetype;
17930 if (tid && compiler.get<SPIRType>(id: tid).image.dim == Dim2D)
17931 return SPVFuncImplImage2DAtomicCoords;
17932 }
17933 break;
17934 }
17935
17936 case OpImageFetch:
17937 case OpImageRead:
17938 case OpImageWrite:
17939 {
17940 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
17941 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
17942 if (tid && compiler.get<SPIRType>(id: tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
17943 return SPVFuncImplTexelBufferCoords;
17944 break;
17945 }
17946
17947 case OpExtInst:
17948 {
17949 uint32_t extension_set = args[2];
17950 if (compiler.get<SPIRExtension>(id: extension_set).ext == SPIRExtension::GLSL)
17951 {
17952 auto op_450 = static_cast<GLSLstd450>(args[3]);
17953 switch (op_450)
17954 {
17955 case GLSLstd450Radians:
17956 return SPVFuncImplRadians;
17957 case GLSLstd450Degrees:
17958 return SPVFuncImplDegrees;
17959 case GLSLstd450FindILsb:
17960 return SPVFuncImplFindILsb;
17961 case GLSLstd450FindSMsb:
17962 return SPVFuncImplFindSMsb;
17963 case GLSLstd450FindUMsb:
17964 return SPVFuncImplFindUMsb;
17965 case GLSLstd450SSign:
17966 return SPVFuncImplSSign;
17967 case GLSLstd450Reflect:
17968 {
17969 auto &type = compiler.get<SPIRType>(id: args[0]);
17970 if (type.vecsize == 1)
17971 return SPVFuncImplReflectScalar;
17972 break;
17973 }
17974 case GLSLstd450Refract:
17975 {
17976 auto &type = compiler.get<SPIRType>(id: args[0]);
17977 if (type.vecsize == 1)
17978 return SPVFuncImplRefractScalar;
17979 break;
17980 }
17981 case GLSLstd450FaceForward:
17982 {
17983 auto &type = compiler.get<SPIRType>(id: args[0]);
17984 if (type.vecsize == 1)
17985 return SPVFuncImplFaceForwardScalar;
17986 break;
17987 }
17988 case GLSLstd450MatrixInverse:
17989 {
17990 auto &mat_type = compiler.get<SPIRType>(id: args[0]);
17991 switch (mat_type.columns)
17992 {
17993 case 2:
17994 return SPVFuncImplInverse2x2;
17995 case 3:
17996 return SPVFuncImplInverse3x3;
17997 case 4:
17998 return SPVFuncImplInverse4x4;
17999 default:
18000 break;
18001 }
18002 break;
18003 }
18004 default:
18005 break;
18006 }
18007 }
18008 break;
18009 }
18010
18011 case OpGroupNonUniformBroadcast:
18012 case OpSubgroupReadInvocationKHR:
18013 return SPVFuncImplSubgroupBroadcast;
18014
18015 case OpGroupNonUniformBroadcastFirst:
18016 case OpSubgroupFirstInvocationKHR:
18017 return SPVFuncImplSubgroupBroadcastFirst;
18018
18019 case OpGroupNonUniformBallot:
18020 case OpSubgroupBallotKHR:
18021 return SPVFuncImplSubgroupBallot;
18022
18023 case OpGroupNonUniformInverseBallot:
18024 case OpGroupNonUniformBallotBitExtract:
18025 return SPVFuncImplSubgroupBallotBitExtract;
18026
18027 case OpGroupNonUniformBallotFindLSB:
18028 return SPVFuncImplSubgroupBallotFindLSB;
18029
18030 case OpGroupNonUniformBallotFindMSB:
18031 return SPVFuncImplSubgroupBallotFindMSB;
18032
18033 case OpGroupNonUniformBallotBitCount:
18034 return SPVFuncImplSubgroupBallotBitCount;
18035
18036 case OpGroupNonUniformAllEqual:
18037 case OpSubgroupAllEqualKHR:
18038 return SPVFuncImplSubgroupAllEqual;
18039
18040 case OpGroupNonUniformShuffle:
18041 return SPVFuncImplSubgroupShuffle;
18042
18043 case OpGroupNonUniformShuffleXor:
18044 return SPVFuncImplSubgroupShuffleXor;
18045
18046 case OpGroupNonUniformShuffleUp:
18047 return SPVFuncImplSubgroupShuffleUp;
18048
18049 case OpGroupNonUniformShuffleDown:
18050 return SPVFuncImplSubgroupShuffleDown;
18051
18052 case OpGroupNonUniformQuadBroadcast:
18053 return SPVFuncImplQuadBroadcast;
18054
18055 case OpGroupNonUniformQuadSwap:
18056 return SPVFuncImplQuadSwap;
18057
18058 case OpSDot:
18059 case OpUDot:
18060 case OpSUDot:
18061 case OpSDotAccSat:
18062 case OpUDotAccSat:
18063 case OpSUDotAccSat:
18064 return SPVFuncImplReduceAdd;
18065
18066 case OpSMulExtended:
18067 case OpUMulExtended:
18068 return SPVFuncImplMulExtended;
18069
18070 default:
18071 break;
18072 }
18073 return SPVFuncImplNone;
18074}
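
// Illustrative note: each SPVFuncImpl value returned above triggers emission of a matching
// helper function in the MSL preamble (the exact helper text lives in emit_custom_functions()).
// As a hedged sketch, SPVFuncImplFindILsb is assumed to correspond to a helper roughly like:
//
//   template<typename T>
//   inline T spvFindLSB(T x)
//   {
//       return select(ctz(x), T(-1), x == T(0));
//   }
//
// so the GLSL.std.450 FindILsb instruction can be lowered to a single helper call.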
18075
18076// Sort both type and meta member content based on builtin status (put builtins at end),
18077// then by the required sorting aspect.
18078void CompilerMSL::MemberSorter::sort()
18079{
18080 // Create a temporary array of consecutive member indices and sort it based on how
18081 // the members should be reordered, based on builtin and sorting aspect meta info.
18082 size_t mbr_cnt = type.member_types.size();
18083 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
18084 std::iota(first: mbr_idxs.begin(), last: mbr_idxs.end(), value: 0); // Fill with consecutive indices
18085 std::stable_sort(first: mbr_idxs.begin(), last: mbr_idxs.end(), comp: *this); // Sort member indices based on sorting aspect
18086
18087 bool sort_is_identity = true;
18088 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
18089 {
18090 if (mbr_idx != mbr_idxs[mbr_idx])
18091 {
18092 sort_is_identity = false;
18093 break;
18094 }
18095 }
18096
18097 if (sort_is_identity)
18098 return;
18099
18100 if (meta.members.size() < type.member_types.size())
18101 {
18102 // This should never trigger in normal circumstances, but to be safe.
18103 meta.members.resize(new_size: type.member_types.size());
18104 }
18105
18106 // Move type and meta member info to the order defined by the sorted member indices.
18107 // This is done by creating temporary copies of both member types and meta, and then
18108 // copying back to the original content at the sorted indices.
18109 auto mbr_types_cpy = type.member_types;
18110 auto mbr_meta_cpy = meta.members;
18111 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
18112 {
18113 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
18114 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
18115 }
18116
18117 // If we're sorting by Offset, this might affect user code which accesses a buffer block.
18118 // We will need to redirect member indices from defined index to sorted index using reverse lookup.
18119 if (sort_aspect == SortAspect::Offset)
18120 {
18121 type.member_type_index_redirection.resize(new_size: mbr_cnt);
18122 for (uint32_t map_idx = 0; map_idx < mbr_cnt; map_idx++)
18123 type.member_type_index_redirection[mbr_idxs[map_idx]] = map_idx;
18124 }
18125}
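
// Worked example of the Offset sort above: for member offsets { 16, 0, 8 }, the sorted
// index list mbr_idxs becomes { 1, 2, 0 }, the members are reordered to offsets
// { 0, 8, 16 }, and member_type_index_redirection maps the original indices { 0, 1, 2 }
// to the sorted positions { 2, 0, 1 }, so existing member accesses can be redirected.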
18126
18127bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
18128{
18129 auto &mbr_meta1 = meta.members[mbr_idx1];
18130 auto &mbr_meta2 = meta.members[mbr_idx2];
18131
18132 if (sort_aspect == LocationThenBuiltInType)
18133 {
18134 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
18135 if (mbr_meta1.builtin != mbr_meta2.builtin)
18136 return mbr_meta2.builtin;
18137 else if (mbr_meta1.builtin)
18138 return mbr_meta1.builtin_type < mbr_meta2.builtin_type;
18139 else if (mbr_meta1.location == mbr_meta2.location)
18140 return mbr_meta1.component < mbr_meta2.component;
18141 else
18142 return mbr_meta1.location < mbr_meta2.location;
18143 }
18144 else
18145 return mbr_meta1.offset < mbr_meta2.offset;
18146}
18147
18148CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
18149 : type(t)
18150 , meta(m)
18151 , sort_aspect(sa)
18152{
18153 // Ensure enough meta info is available
18154 meta.members.resize(new_size: max(a: type.member_types.size(), b: meta.members.size()));
18155}
18156
18157void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
18158{
18159 auto &type = get<SPIRType>(id: get<SPIRVariable>(id).basetype);
18160 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
18161 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
18162 if (!type.array.empty())
18163 SPIRV_CROSS_THROW("Can not remap array of samplers.");
18164 constexpr_samplers_by_id[id] = sampler;
18165}
18166
18167void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
18168 const MSLConstexprSampler &sampler)
18169{
18170 constexpr_samplers_by_binding[{ .desc_set: desc_set, .binding: binding }] = sampler;
18171}
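
// Minimal usage sketch for the two remap entry points above (field and enum names are
// assumptions based on the public MSL interface, not verified here):
//
//   MSLConstexprSampler samp;
//   samp.min_filter = MSL_SAMPLER_FILTER_LINEAR;
//   samp.mag_filter = MSL_SAMPLER_FILTER_LINEAR;
//   compiler.remap_constexpr_sampler_by_binding(0, 1, samp);
//
// After this, the sampler at descriptor set 0, binding 1 is emitted as a constexpr sampler
// in the MSL source instead of being passed in as a runtime resource.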
18172
18173void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
18174{
18175 bool is_packed = has_extended_decoration(id: source_id, decoration: SPIRVCrossDecorationPhysicalTypePacked);
18176 auto *source_expr = maybe_get<SPIRExpression>(id: source_id);
18177 auto *var = maybe_get_backing_variable(chain: source_id);
18178 const SPIRType *var_type = nullptr, *phys_type = nullptr;
18179
18180 if (uint32_t phys_id = get_extended_decoration(id: source_id, decoration: SPIRVCrossDecorationPhysicalTypeID))
18181 phys_type = &get<SPIRType>(id: phys_id);
18182 else
18183 phys_type = &expr_type;
18184
18185 if (var)
18186 {
18187 source_id = var->self;
18188 var_type = &get_variable_data_type(var: *var);
18189 }
18190
18191 bool rewrite_boolean_load =
18192 expr_type.basetype == SPIRType::Boolean &&
18193 (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct));
18194
18195 // Type fixups for workgroup variables if they are booleans.
18196 if (rewrite_boolean_load)
18197 {
18198 if (is_array(type: expr_type))
18199 expr = to_rerolled_array_expression(parent_type: expr_type, expr, type: expr_type);
18200 else
18201 expr = join(ts: type_to_glsl(type: expr_type), ts: "(", ts&: expr, ts: ")");
18202 }
18203
18204 // Type fixups for workgroup variables if they are matrices.
18205 // Don't do fixup for packed types; those are handled specially.
18206 // FIXME: Maybe use a type like spvStorageMatrix for packed matrices?
18207 if (!msl_options.supports_msl_version(major: 3, minor: 0) && var &&
18208 (var->storage == StorageClassWorkgroup ||
18209 (var_type->basetype == SPIRType::Struct &&
18210 has_extended_decoration(id: var_type->self, decoration: SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) &&
18211 expr_type.columns > 1)
18212 {
18213 SPIRType matrix_type = *phys_type;
18214 if (source_expr && source_expr->need_transpose)
18215 swap(a&: matrix_type.vecsize, b&: matrix_type.columns);
18216 matrix_type.array.clear();
18217 matrix_type.array_size_literal.clear();
18218 expr = join(ts: type_to_glsl(type: matrix_type), ts: "(", ts&: expr, ts: ")");
18219 }
18220
18221 // Only interested in standalone builtin variables in the switch below.
18222 if (!has_decoration(id: source_id, decoration: DecorationBuiltIn))
18223 {
18224 // If the backing variable does not match our expected sign, we can fix it up here.
18225 // See ensure_correct_input_type().
18226 if (var && var->storage == StorageClassInput)
18227 {
18228 auto &base_type = get<SPIRType>(id: var->basetype);
18229 if (base_type.basetype != SPIRType::Struct && expr_type.basetype != base_type.basetype)
18230 expr = join(ts: type_to_glsl(type: expr_type), ts: "(", ts&: expr, ts: ")");
18231 }
18232 return;
18233 }
18234
18235 auto builtin = static_cast<BuiltIn>(get_decoration(id: source_id, decoration: DecorationBuiltIn));
18236 auto expected_type = expr_type.basetype;
18237 auto expected_width = expr_type.width;
18238 switch (builtin)
18239 {
18240 case BuiltInGlobalInvocationId:
18241 case BuiltInLocalInvocationId:
18242 case BuiltInWorkgroupId:
18243 case BuiltInLocalInvocationIndex:
18244 case BuiltInWorkgroupSize:
18245 case BuiltInNumWorkgroups:
18246 case BuiltInLayer:
18247 case BuiltInViewportIndex:
18248 case BuiltInFragStencilRefEXT:
18249 case BuiltInPrimitiveId:
18250 case BuiltInSubgroupSize:
18251 case BuiltInSubgroupLocalInvocationId:
18252 case BuiltInViewIndex:
18253 case BuiltInVertexIndex:
18254 case BuiltInInstanceIndex:
18255 case BuiltInBaseInstance:
18256 case BuiltInBaseVertex:
18257 case BuiltInSampleMask:
18258 expected_type = SPIRType::UInt;
18259 expected_width = 32;
18260 break;
18261
18262 case BuiltInTessLevelInner:
18263 case BuiltInTessLevelOuter:
18264 if (is_tesc_shader())
18265 {
18266 expected_type = SPIRType::Half;
18267 expected_width = 16;
18268 }
18269 break;
18270
18271 default:
18272 break;
18273 }
18274
18275 if (is_array(type: expr_type) && builtin == BuiltInSampleMask)
18276 {
18277 // Needs special handling.
18278 auto wrap_expr = join(ts: type_to_glsl(type: expr_type), ts: "({ ");
18279 wrap_expr += join(ts: type_to_glsl(type: get<SPIRType>(id: expr_type.parent_type)), ts: "(", ts&: expr, ts: ")");
18280 wrap_expr += " })";
18281 expr = std::move(wrap_expr);
18282 }
18283 else if (expected_type != expr_type.basetype)
18284 {
18285 if (is_array(type: expr_type) && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
18286 {
18287 // Triggers when loading TessLevel directly as an array.
18288 // Need explicit padding + cast.
18289 auto wrap_expr = join(ts: type_to_glsl(type: expr_type), ts: "({ ");
18290
18291 uint32_t array_size = get_physical_tess_level_array_size(builtin);
18292 for (uint32_t i = 0; i < array_size; i++)
18293 {
18294 if (array_size > 1)
18295 wrap_expr += join(ts: "float(", ts&: expr, ts: "[", ts&: i, ts: "])");
18296 else
18297 wrap_expr += join(ts: "float(", ts&: expr, ts: ")");
18298 if (i + 1 < array_size)
18299 wrap_expr += ", ";
18300 }
18301
18302 if (is_tessellating_triangles())
18303 wrap_expr += ", 0.0";
18304
18305 wrap_expr += " })";
18306 expr = std::move(wrap_expr);
18307 }
18308 else
18309 {
18310 // These are of different widths, so we cannot do a straight bitcast.
18311 if (expected_width != expr_type.width)
18312 expr = join(ts: type_to_glsl(type: expr_type), ts: "(", ts&: expr, ts: ")");
18313 else
18314 expr = bitcast_expression(target_type: expr_type, expr_type: expected_type, expr);
18315 }
18316 }
18317}
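
// Illustrative effect of the builtin fix-up above (output shown is approximate): loading a
// builtin the shader declares as signed, e.g. gl_VertexIndex as int, wraps the expression
// in a sign cast such as "int(gl_VertexIndex)", while loading TessLevelOuter as float[4]
// in a tessellation control shader rebuilds the array element by element from the
// half-typed Metal levels, padding with 0.0 when tessellating triangles.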
18318
18319void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
18320{
18321 bool is_packed = has_extended_decoration(id: target_id, decoration: SPIRVCrossDecorationPhysicalTypePacked);
18322 auto *target_expr = maybe_get<SPIRExpression>(id: target_id);
18323 auto *var = maybe_get_backing_variable(chain: target_id);
18324 const SPIRType *var_type = nullptr, *phys_type = nullptr;
18325
18326 if (uint32_t phys_id = get_extended_decoration(id: target_id, decoration: SPIRVCrossDecorationPhysicalTypeID))
18327 phys_type = &get<SPIRType>(id: phys_id);
18328 else
18329 phys_type = &expr_type;
18330
18331 if (var)
18332 {
18333 target_id = var->self;
18334 var_type = &get_variable_data_type(var: *var);
18335 }
18336
18337 bool rewrite_boolean_store =
18338 expr_type.basetype == SPIRType::Boolean &&
18339 (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct));
18340
18341 // Type fixups for workgroup variables or struct members if they are booleans.
18342 if (rewrite_boolean_store)
18343 {
18344 if (is_array(type: expr_type))
18345 {
18346 expr = to_rerolled_array_expression(parent_type: *var_type, expr, type: expr_type);
18347 }
18348 else
18349 {
18350 auto short_type = expr_type;
18351 short_type.basetype = SPIRType::Short;
18352 expr = join(ts: type_to_glsl(type: short_type), ts: "(", ts&: expr, ts: ")");
18353 }
18354 }
18355
18356 // Type fixups for workgroup variables if they are matrices.
18357 // Don't do fixup for packed types; those are handled specially.
18358 // FIXME: Maybe use a type like spvStorageMatrix for packed matrices?
18359 if (!msl_options.supports_msl_version(major: 3, minor: 0) && var &&
18360 (var->storage == StorageClassWorkgroup ||
18361 (var_type->basetype == SPIRType::Struct &&
18362 has_extended_decoration(id: var_type->self, decoration: SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) &&
18363 expr_type.columns > 1)
18364 {
18365 SPIRType matrix_type = *phys_type;
18366 if (target_expr && target_expr->need_transpose)
18367 swap(a&: matrix_type.vecsize, b&: matrix_type.columns);
18368 expr = join(ts: "spvStorage_", ts: type_to_glsl(type: matrix_type), ts: "(", ts&: expr, ts: ")");
18369 }
18370
18371 // Only interested in standalone builtin variables.
18372 if (!has_decoration(id: target_id, decoration: DecorationBuiltIn))
18373 return;
18374
18375 auto builtin = static_cast<BuiltIn>(get_decoration(id: target_id, decoration: DecorationBuiltIn));
18376 auto expected_type = expr_type.basetype;
18377 auto expected_width = expr_type.width;
18378 switch (builtin)
18379 {
18380 case BuiltInLayer:
18381 case BuiltInViewportIndex:
18382 case BuiltInFragStencilRefEXT:
18383 case BuiltInPrimitiveId:
18384 case BuiltInViewIndex:
18385 expected_type = SPIRType::UInt;
18386 expected_width = 32;
18387 break;
18388
18389 case BuiltInTessLevelInner:
18390 case BuiltInTessLevelOuter:
18391 expected_type = SPIRType::Half;
18392 expected_width = 16;
18393 break;
18394
18395 default:
18396 break;
18397 }
18398
18399 if (expected_type != expr_type.basetype)
18400 {
18401 if (expected_width != expr_type.width)
18402 {
18403 // These are of different widths, so we cannot do a straight bitcast.
18404 auto type = expr_type;
18405 type.basetype = expected_type;
18406 type.width = expected_width;
18407 expr = join(ts: type_to_glsl(type), ts: "(", ts&: expr, ts: ")");
18408 }
18409 else
18410 {
18411 auto type = expr_type;
18412 type.basetype = expected_type;
18413 expr = bitcast_expression(target_type: type, expr_type: expr_type.basetype, expr);
18414 }
18415 }
18416}
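
// Store-side counterpart of the above (approximate): writing a signed gl_Layer value to
// Metal's unsigned render-target index is emitted as a cast to uint, and writes to the
// tessellation levels are narrowed to half before they reach the Metal factor buffers.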
18417
18418string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
18419{
18420	// We risk getting an array initializer here with MSL if we have an array.
18421 // FIXME: We cannot handle non-constant arrays being initialized.
18422 // We will need to inject spvArrayCopy here somehow ...
18423 auto &type = get<SPIRType>(id: var.basetype);
18424 string expr;
18425 if (ir.ids[var.initializer].get_type() == TypeConstant &&
18426 (!type.array.empty() || type.basetype == SPIRType::Struct))
18427 expr = constant_expression(c: get<SPIRConstant>(id: var.initializer));
18428 else
18429 expr = CompilerGLSL::to_initializer_expression(var);
18430 // If the initializer has more vector components than the variable, add a swizzle.
18431 // FIXME: This can't handle arrays or structs.
18432 auto &init_type = expression_type(id: var.initializer);
18433 if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
18434 expr = enclose_expression(expr: expr + vector_swizzle(vecsize: type.vecsize, index: 0));
18435 return expr;
18436}
18437
18438string CompilerMSL::to_zero_initialized_expression(uint32_t)
18439{
18440 return "{}";
18441}
18442
18443bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
18444{
18445 if (!msl_options.argument_buffers)
18446 return false;
18447 if (desc_set >= kMaxArgumentBuffers)
18448 return false;
18449
18450 return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
18451}
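
// Example for descriptor_set_is_argument_buffer(): with argument buffers enabled and
// argument_buffer_discrete_mask == 0x2, descriptor set 0 is emitted as an argument buffer
// while descriptor set 1 keeps discrete [[buffer]]/[[texture]] bindings.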
18452
18453bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
18454{
18455 // iOS Tier 1 argument buffers do not support writable images.
18456 // When the argument buffer is encoded, we don't know whether this image will have a
18457 // NonWritable decoration, so just use discrete arguments for all storage images on iOS.
18458 bool is_supported_type = !(type.basetype == SPIRType::Image &&
18459 type.image.sampled == 2 &&
18460 msl_options.is_ios() &&
18461 msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1);
18462 return is_supported_type && !type_is_msl_framebuffer_fetch(type);
18463}
18464
18465void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &aliased_var,
18466 const SPIRVariable &base_var)
18467{
18468 // To deal with buffer <-> image aliasing, we need to perform an unholy UB ritual.
18469 // A texture type in Metal 3.0 is a pointer. However, we cannot simply cast a pointer to texture.
18470 // What we *can* do is to cast pointer-to-pointer to pointer-to-texture.
18471
18472 // We need to explicitly reach into the descriptor buffer lvalue, not any spvDescriptorArray wrapper.
18473 auto *var_meta = ir.find_meta(id: base_var.self);
18474 bool old_explicit_qualifier = var_meta && var_meta->decoration.qualified_alias_explicit_override;
18475 if (var_meta)
18476 var_meta->decoration.qualified_alias_explicit_override = false;
18477 auto unqualified_name = to_name(id: base_var.self, allow_alias: false);
18478 if (var_meta)
18479 var_meta->decoration.qualified_alias_explicit_override = old_explicit_qualifier;
18480
18481 // For non-arrayed buffers, we have already performed a de-reference.
18482 // We need a proper lvalue to cast, so strip away the de-reference.
18483 if (unqualified_name.size() > 2 && unqualified_name[0] == '(' && unqualified_name[1] == '*')
18484 {
18485 unqualified_name.erase(first: unqualified_name.begin(), last: unqualified_name.begin() + 2);
18486 unqualified_name.pop_back();
18487 }
18488
18489 string name;
18490
18491 auto &var_type = get<SPIRType>(id: aliased_var.basetype);
18492 auto &data_type = get_variable_data_type(var: aliased_var);
18493 string descriptor_storage = descriptor_address_space(id: aliased_var.self, storage: aliased_var.storage, plain_address_space: "");
18494
18495 if (aliased_var.storage == StorageClassUniformConstant)
18496 {
18497 if (is_var_runtime_size_array(var: aliased_var))
18498 {
18499 // This becomes a plain pointer to spvDescriptor.
18500 name = join(ts: "reinterpret_cast<", ts&: descriptor_storage, ts: " ",
18501 ts: type_to_glsl(type: get_variable_data_type(var: aliased_var), id: aliased_var.self, member: true), ts: ">(&",
18502 ts&: unqualified_name, ts: ")");
18503 }
18504 else
18505 {
18506 name = join(ts: "reinterpret_cast<", ts&: descriptor_storage, ts: " ",
18507 ts: type_to_glsl(type: get_variable_data_type(var: aliased_var), id: aliased_var.self, member: true), ts: " &>(",
18508 ts&: unqualified_name, ts: ");");
18509 }
18510 }
18511 else
18512 {
18513 // Buffer types.
18514 bool old_is_using_builtin_array = is_using_builtin_array;
18515 is_using_builtin_array = true;
18516
18517 bool needs_post_cast_deref = !is_array(type: data_type);
18518 string ref_type = needs_post_cast_deref ? "&" : join(ts: "(&)", ts: type_to_array_glsl(type: var_type, variable_id: aliased_var.self));
18519
18520 if (is_var_runtime_size_array(var: aliased_var))
18521 {
18522 name = join(ts: "reinterpret_cast<",
18523 ts: type_to_glsl(type: var_type, id: aliased_var.self, member: true), ts: " ", ts&: descriptor_storage, ts: " *>(&",
18524 ts&: unqualified_name, ts: ")");
18525 }
18526 else
18527 {
18528 name = join(ts: needs_post_cast_deref ? "*" : "", ts: "reinterpret_cast<",
18529 ts: type_to_glsl(type: var_type, id: aliased_var.self, member: true), ts: " ", ts&: descriptor_storage, ts: " ",
18530 ts&: ref_type,
18531 ts: ">(", ts&: unqualified_name, ts: ");");
18532 }
18533
18534 if (needs_post_cast_deref)
18535 descriptor_storage = get_type_address_space(type: var_type, id: aliased_var.self, argument: false);
18536
18537		// These kinds of ridiculous casts trigger warnings in the compiler. Just ignore them.
18538 if (!suppress_incompatible_pointer_types_discard_qualifiers)
18539 {
18540 suppress_incompatible_pointer_types_discard_qualifiers = true;
18541 force_recompile_guarantee_forward_progress();
18542 }
18543
18544 is_using_builtin_array = old_is_using_builtin_array;
18545 }
18546
18547 if (!is_var_runtime_size_array(var: aliased_var))
18548 {
18549 // Lower to temporary, so drop the qualification.
18550 set_qualified_name(id: aliased_var.self, name: "");
18551 statement(ts&: descriptor_storage, ts: " auto &", ts: to_name(id: aliased_var.self), ts: " = ", ts&: name);
18552 }
18553 else
18554 {
18555 // This alias may have already been used to emit an entry point declaration. If there is a mismatch, we need a recompile.
18556 // Moving this code to be run earlier will also conflict,
18557 // because we need the qualified alias for the base resource,
18558 // so forcing recompile until things sync up is the least invasive method for now.
18559 if (ir.meta[aliased_var.self].decoration.qualified_alias != name)
18560 force_recompile();
18561
18562 // This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted.
18563 set_qualified_name(id: aliased_var.self, name);
18564 }
18565}
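
// The statement emitted above ends up in the entry point prologue roughly as (names are
// hypothetical; the exact spelling depends on address space and array-ness):
//
//   device auto &aliasedSSBO = *reinterpret_cast<SSBO device &>(spvDescriptorSet0.m_baseResource);
//
// i.e. the aliased resource is re-materialized by casting the lvalue of whichever resource
// actually occupies that [[id(N)]] slot.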
18566
18567void CompilerMSL::analyze_argument_buffers()
18568{
18569 // Gather all used resources and sort them out into argument buffers.
18570 // Each argument buffer corresponds to a descriptor set in SPIR-V.
18571 // The [[id(N)]] values used correspond to the resource mapping we have for MSL.
18572	// Otherwise, the binding number is used, but this is generally not safe for some types like
18573 // combined image samplers and arrays of resources. Metal needs different indices here,
18574 // while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
18575 // you will need to use the remapping from the API.
18576 for (auto &id : argument_buffer_ids)
18577 id = 0;
18578
18579 // Output resources, sorted by resource index & type.
18580 struct Resource
18581 {
18582 SPIRVariable *var;
18583 string name;
18584 SPIRType::BaseType basetype;
18585 uint32_t index;
18586 uint32_t plane_count;
18587 uint32_t plane;
18588 uint32_t overlapping_var_id;
18589 };
18590 SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
18591 SmallVector<uint32_t> inline_block_vars;
18592
18593 bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
18594 bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
18595 bool needs_buffer_sizes = false;
18596
18597 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t self, SPIRVariable &var) {
18598 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
18599 var.storage == StorageClassStorageBuffer) &&
18600 !is_hidden_variable(var))
18601 {
18602 uint32_t desc_set = get_decoration(id: self, decoration: DecorationDescriptorSet);
18603			// Ignore this set if it is not backed by an argument buffer (e.g. a discrete or push descriptor set).
18604 if (!descriptor_set_is_argument_buffer(desc_set))
18605 return;
18606
18607 uint32_t var_id = var.self;
18608 auto &type = get_variable_data_type(var);
18609
18610 if (desc_set >= kMaxArgumentBuffers)
18611 SPIRV_CROSS_THROW("Descriptor set index is out of range.");
18612
18613 const MSLConstexprSampler *constexpr_sampler = nullptr;
18614 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
18615 {
18616 constexpr_sampler = find_constexpr_sampler(id: var_id);
18617 if (constexpr_sampler)
18618 {
18619 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
18620 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
18621 }
18622 }
18623
18624 uint32_t binding = get_decoration(id: var_id, decoration: DecorationBinding);
18625 if (type.basetype == SPIRType::SampledImage)
18626 {
18627 add_resource_name(id: var_id);
18628
18629 uint32_t plane_count = 1;
18630 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
18631 plane_count = constexpr_sampler->planes;
18632
18633 for (uint32_t i = 0; i < plane_count; i++)
18634 {
18635 uint32_t image_resource_index = get_metal_resource_index(var, basetype: SPIRType::Image, plane: i);
18636 resources_in_set[desc_set].push_back(
18637 t: { .var: &var, .name: to_name(id: var_id), .basetype: SPIRType::Image, .index: image_resource_index, .plane_count: plane_count, .plane: i, .overlapping_var_id: 0 });
18638 }
18639
18640 if (type.image.dim != DimBuffer && !constexpr_sampler)
18641 {
18642 uint32_t sampler_resource_index = get_metal_resource_index(var, basetype: SPIRType::Sampler);
18643 resources_in_set[desc_set].push_back(
18644 t: { .var: &var, .name: to_sampler_expression(id: var_id), .basetype: SPIRType::Sampler, .index: sampler_resource_index, .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18645 }
18646 }
18647 else if (inline_uniform_blocks.count(x: SetBindingPair{ .desc_set: desc_set, .binding: binding }))
18648 {
18649 inline_block_vars.push_back(t: var_id);
18650 }
18651 else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
18652 {
18653 // constexpr samplers are not declared as resources.
18654 // Inline uniform blocks are always emitted at the end.
18655 add_resource_name(id: var_id);
18656
18657 uint32_t resource_index = get_metal_resource_index(var, basetype: type.basetype);
18658
18659 resources_in_set[desc_set].push_back(
18660 t: { .var: &var, .name: to_name(id: var_id), .basetype: type.basetype, .index: resource_index, .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18661
18662 // Emulate texture2D atomic operations
18663 if (atomic_image_vars_emulated.count(x: var.self))
18664 {
18665 uint32_t buffer_resource_index = get_metal_resource_index(var, basetype: SPIRType::AtomicCounter, plane: 0);
18666 resources_in_set[desc_set].push_back(
18667 t: { .var: &var, .name: to_name(id: var_id) + "_atomic", .basetype: SPIRType::Struct, .index: buffer_resource_index, .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18668 }
18669 }
18670
18671 // Check if this descriptor set needs a swizzle buffer.
18672 if (needs_swizzle_buffer_def && is_sampled_image_type(type))
18673 set_needs_swizzle_buffer[desc_set] = true;
18674 else if (buffer_requires_array_length(id: var_id))
18675 {
18676 set_needs_buffer_sizes[desc_set] = true;
18677 needs_buffer_sizes = true;
18678 }
18679 }
18680 });
18681
18682 if (needs_swizzle_buffer_def || needs_buffer_sizes)
18683 {
18684 uint32_t uint_ptr_type_id = 0;
18685
18686 // We might have to add a swizzle buffer resource to the set.
18687 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
18688 {
18689 if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
18690 continue;
18691
18692 if (uint_ptr_type_id == 0)
18693 {
18694 uint_ptr_type_id = ir.increase_bound_by(count: 1);
18695
18696 // Create a buffer to hold extra data, including the swizzle constants.
18697 SPIRType uint_type_pointer = get_uint_type();
18698 uint_type_pointer.op = OpTypePointer;
18699 uint_type_pointer.pointer = true;
18700 uint_type_pointer.pointer_depth++;
18701 uint_type_pointer.parent_type = get_uint_type_id();
18702 uint_type_pointer.storage = StorageClassUniform;
18703 set<SPIRType>(id: uint_ptr_type_id, args&: uint_type_pointer);
18704 set_decoration(id: uint_ptr_type_id, decoration: DecorationArrayStride, argument: 4);
18705 }
18706
18707 if (set_needs_swizzle_buffer[desc_set])
18708 {
18709 uint32_t var_id = ir.increase_bound_by(count: 1);
18710 auto &var = set<SPIRVariable>(id: var_id, args&: uint_ptr_type_id, args: StorageClassUniformConstant);
18711 set_name(id: var_id, name: "spvSwizzleConstants");
18712 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: desc_set);
18713 set_decoration(id: var_id, decoration: DecorationBinding, argument: kSwizzleBufferBinding);
18714 resources_in_set[desc_set].push_back(
18715 t: { .var: &var, .name: to_name(id: var_id), .basetype: SPIRType::UInt, .index: get_metal_resource_index(var, basetype: SPIRType::UInt), .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18716 }
18717
18718 if (set_needs_buffer_sizes[desc_set])
18719 {
18720 uint32_t var_id = ir.increase_bound_by(count: 1);
18721 auto &var = set<SPIRVariable>(id: var_id, args&: uint_ptr_type_id, args: StorageClassUniformConstant);
18722 set_name(id: var_id, name: "spvBufferSizeConstants");
18723 set_decoration(id: var_id, decoration: DecorationDescriptorSet, argument: desc_set);
18724 set_decoration(id: var_id, decoration: DecorationBinding, argument: kBufferSizeBufferBinding);
18725 resources_in_set[desc_set].push_back(
18726 t: { .var: &var, .name: to_name(id: var_id), .basetype: SPIRType::UInt, .index: get_metal_resource_index(var, basetype: SPIRType::UInt), .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18727 }
18728 }
18729 }
18730
18731 // Now add inline uniform blocks.
18732 for (uint32_t var_id : inline_block_vars)
18733 {
18734 auto &var = get<SPIRVariable>(id: var_id);
18735 uint32_t desc_set = get_decoration(id: var_id, decoration: DecorationDescriptorSet);
18736 add_resource_name(id: var_id);
18737 resources_in_set[desc_set].push_back(
18738 t: { .var: &var, .name: to_name(id: var_id), .basetype: SPIRType::Struct, .index: get_metal_resource_index(var, basetype: SPIRType::Struct), .plane_count: 1, .plane: 0, .overlapping_var_id: 0 });
18739 }
18740
18741 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
18742 {
18743 auto &resources = resources_in_set[desc_set];
18744 if (resources.empty())
18745 continue;
18746
18747 assert(descriptor_set_is_argument_buffer(desc_set));
18748
18749 uint32_t next_id = ir.increase_bound_by(count: 3);
18750 uint32_t type_id = next_id + 1;
18751 uint32_t ptr_type_id = next_id + 2;
18752 argument_buffer_ids[desc_set] = next_id;
18753
18754 auto &buffer_type = set<SPIRType>(id: type_id, args: OpTypeStruct);
18755
18756 buffer_type.basetype = SPIRType::Struct;
18757
18758 if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
18759 {
18760 buffer_type.storage = StorageClassStorageBuffer;
18761 // Make sure the argument buffer gets marked as const device.
18762 set_decoration(id: next_id, decoration: DecorationNonWritable);
18763 // Need to mark the type as a Block to enable this.
18764 set_decoration(id: type_id, decoration: DecorationBlock);
18765 }
18766 else
18767 buffer_type.storage = StorageClassUniform;
18768
18769 auto buffer_type_name = join(ts: "spvDescriptorSetBuffer", ts&: desc_set);
18770 set_name(id: type_id, name: buffer_type_name);
18771
18772 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args: OpTypePointer);
18773 ptr_type = buffer_type;
18774 ptr_type.op = spv::OpTypePointer;
18775 ptr_type.pointer = true;
18776 ptr_type.pointer_depth++;
18777 ptr_type.parent_type = type_id;
18778
18779 uint32_t buffer_variable_id = next_id;
18780 auto &buffer_var = set<SPIRVariable>(id: buffer_variable_id, args&: ptr_type_id, args: StorageClassUniform);
18781 auto buffer_name = join(ts: "spvDescriptorSet", ts&: desc_set);
18782 set_name(id: buffer_variable_id, name: buffer_name);
18783
18784		// IDs must be emitted in ID order.
18785 stable_sort(first: begin(cont&: resources), last: end(cont&: resources), comp: [&](const Resource &lhs, const Resource &rhs) -> bool {
18786 return tie(args: lhs.index, args: lhs.basetype) < tie(args: rhs.index, args: rhs.basetype);
18787 });
18788
18789 for (size_t i = 0; i < resources.size() - 1; i++)
18790 {
18791 auto &r1 = resources[i];
18792 auto &r2 = resources[i + 1];
18793
18794 if (r1.index == r2.index)
18795 {
18796 if (r1.overlapping_var_id)
18797 r2.overlapping_var_id = r1.overlapping_var_id;
18798 else
18799 r2.overlapping_var_id = r1.var->self;
18800
18801 set_extended_decoration(id: r2.var->self, decoration: SPIRVCrossDecorationOverlappingBinding, value: r2.overlapping_var_id);
18802 }
18803 }
18804
18805 uint32_t member_index = 0;
18806 uint32_t next_arg_buff_index = 0;
18807 uint32_t prev_was_scalar_on_array_offset = 0;
18808 for (auto &resource : resources)
18809 {
18810 auto &var = *resource.var;
18811 auto &type = get_variable_data_type(var);
18812
18813 if (is_var_runtime_size_array(var) && (argument_buffer_device_storage_mask & (1u << desc_set)) == 0)
18814 SPIRV_CROSS_THROW("Runtime sized variables must be in device storage argument buffers.");
18815
18816 // If needed, synthesize and add padding members.
18817 // member_index and next_arg_buff_index are incremented when padding members are added.
18818 if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0)
18819 {
18820 auto rez_bind = get_argument_buffer_resource(desc_set, arg_idx: next_arg_buff_index - prev_was_scalar_on_array_offset);
18821 rez_bind.count -= prev_was_scalar_on_array_offset;
18822
18823 while (resource.index > next_arg_buff_index)
18824 {
18825 switch (rez_bind.basetype)
18826 {
18827 case SPIRType::Void:
18828 case SPIRType::Boolean:
18829 case SPIRType::SByte:
18830 case SPIRType::UByte:
18831 case SPIRType::Short:
18832 case SPIRType::UShort:
18833 case SPIRType::Int:
18834 case SPIRType::UInt:
18835 case SPIRType::Int64:
18836 case SPIRType::UInt64:
18837 case SPIRType::AtomicCounter:
18838 case SPIRType::Half:
18839 case SPIRType::Float:
18840 case SPIRType::Double:
18841 add_argument_buffer_padding_buffer_type(struct_type&: buffer_type, mbr_idx&: member_index, arg_buff_index&: next_arg_buff_index, rez_bind);
18842 break;
18843 case SPIRType::Image:
18844 add_argument_buffer_padding_image_type(struct_type&: buffer_type, mbr_idx&: member_index, arg_buff_index&: next_arg_buff_index, rez_bind);
18845 break;
18846 case SPIRType::Sampler:
18847 add_argument_buffer_padding_sampler_type(struct_type&: buffer_type, mbr_idx&: member_index, arg_buff_index&: next_arg_buff_index, rez_bind);
18848 break;
18849 case SPIRType::SampledImage:
18850 if (next_arg_buff_index == rez_bind.msl_sampler)
18851 add_argument_buffer_padding_sampler_type(struct_type&: buffer_type, mbr_idx&: member_index, arg_buff_index&: next_arg_buff_index, rez_bind);
18852 else
18853 add_argument_buffer_padding_image_type(struct_type&: buffer_type, mbr_idx&: member_index, arg_buff_index&: next_arg_buff_index, rez_bind);
18854 break;
18855 default:
18856 break;
18857 }
18858
18859 // After padding, retrieve the resource again. It will either be more padding, or the actual resource.
18860 rez_bind = get_argument_buffer_resource(desc_set, arg_idx: next_arg_buff_index);
18861 prev_was_scalar_on_array_offset = 0;
18862 }
18863
18864 uint32_t count = rez_bind.count;
18865
18866 // If the current resource is an array in the descriptor, but is a scalar
18867 // in the shader, only the first element will be consumed. The next pass
18868 // will add a padding member to consume the remaining array elements.
18869 if (count > 1 && type.array.empty())
18870 count = prev_was_scalar_on_array_offset = 1;
18871
18872 // Adjust the number of slots consumed by current member itself.
18873 next_arg_buff_index += resource.plane_count * count;
18874 }
18875
18876 string mbr_name = ensure_valid_name(name: resource.name, pfx: "m");
18877 if (resource.plane > 0)
18878 mbr_name += join(ts&: plane_name_suffix, ts&: resource.plane);
18879 set_member_name(id: buffer_type.self, index: member_index, name: mbr_name);
18880
18881 if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
18882 {
18883 // Have to synthesize a sampler type here.
18884
18885 bool type_is_array = !type.array.empty();
18886 uint32_t sampler_type_id = ir.increase_bound_by(count: type_is_array ? 2 : 1);
18887 auto &new_sampler_type = set<SPIRType>(id: sampler_type_id, args: OpTypeSampler);
18888 new_sampler_type.basetype = SPIRType::Sampler;
18889 new_sampler_type.storage = StorageClassUniformConstant;
18890
18891 if (type_is_array)
18892 {
18893 uint32_t sampler_type_array_id = sampler_type_id + 1;
18894 auto &sampler_type_array = set<SPIRType>(id: sampler_type_array_id, args: OpTypeArray);
18895 sampler_type_array = new_sampler_type;
18896 sampler_type_array.array = type.array;
18897 sampler_type_array.array_size_literal = type.array_size_literal;
18898 sampler_type_array.parent_type = sampler_type_id;
18899 buffer_type.member_types.push_back(t: sampler_type_array_id);
18900 }
18901 else
18902 buffer_type.member_types.push_back(t: sampler_type_id);
18903 }
18904 else
18905 {
18906 uint32_t binding = get_decoration(id: var.self, decoration: DecorationBinding);
18907 SetBindingPair pair = { .desc_set: desc_set, .binding: binding };
18908
18909 if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
18910 resource.basetype == SPIRType::SampledImage)
18911 {
18912 // Drop pointer information when we emit the resources into a struct.
18913 buffer_type.member_types.push_back(t: get_variable_data_type_id(var));
18914 if (has_extended_decoration(id: var.self, decoration: SPIRVCrossDecorationOverlappingBinding))
18915 {
18916 if (!msl_options.supports_msl_version(major: 3, minor: 0))
18917 SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+.");
18918
18919 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
18920 entry_func.fixup_hooks_in.push_back(t: [this, resource]() {
18921 emit_argument_buffer_aliased_descriptor(aliased_var: *resource.var, base_var: this->get<SPIRVariable>(id: resource.overlapping_var_id));
18922 });
18923 }
18924 else if (resource.plane == 0)
18925 {
18926 set_qualified_name(id: var.self, name: join(ts: to_name(id: buffer_variable_id), ts: ".", ts&: mbr_name));
18927 }
18928 }
18929 else if (buffers_requiring_dynamic_offset.count(x: pair))
18930 {
18931 // Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
18932 buffer_type.member_types.push_back(t: var.basetype);
18933 buffers_requiring_dynamic_offset[pair].second = var.self;
18934 }
18935 else if (inline_uniform_blocks.count(x: pair))
18936 {
18937 // Put the buffer block itself into the argument buffer.
18938 buffer_type.member_types.push_back(t: get_variable_data_type_id(var));
18939 set_qualified_name(id: var.self, name: join(ts: to_name(id: buffer_variable_id), ts: ".", ts&: mbr_name));
18940 }
18941 else if (atomic_image_vars_emulated.count(x: var.self))
18942 {
18943 // Emulate texture2D atomic operations.
18944 // Don't set the qualified name: it's already set for this variable,
18945 // and the code that references the buffer manually appends "_atomic"
18946 // to the name.
18947 uint32_t offset = ir.increase_bound_by(count: 2);
18948 uint32_t atomic_type_id = offset;
18949 uint32_t type_ptr_id = offset + 1;
18950
18951 SPIRType atomic_type { OpTypeInt };
18952 atomic_type.basetype = SPIRType::AtomicCounter;
18953 atomic_type.width = 32;
18954 atomic_type.vecsize = 1;
18955 set<SPIRType>(id: atomic_type_id, args&: atomic_type);
18956
18957 atomic_type.op = OpTypePointer;
18958 atomic_type.pointer = true;
18959 atomic_type.pointer_depth++;
18960 atomic_type.parent_type = atomic_type_id;
18961 atomic_type.storage = StorageClassStorageBuffer;
18962 auto &atomic_ptr_type = set<SPIRType>(id: type_ptr_id, args&: atomic_type);
18963 atomic_ptr_type.self = atomic_type_id;
18964
18965 buffer_type.member_types.push_back(t: type_ptr_id);
18966 }
18967 else
18968 {
18969 buffer_type.member_types.push_back(t: var.basetype);
18970 if (has_extended_decoration(id: var.self, decoration: SPIRVCrossDecorationOverlappingBinding))
18971 {
18972 // Casting raw pointers is fine since their ABI is fixed, but anything opaque is deeply questionable on Metal 2.
18973 if (get<SPIRVariable>(id: resource.overlapping_var_id).storage == StorageClassUniformConstant &&
18974 !msl_options.supports_msl_version(major: 3, minor: 0))
18975 {
18976 SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+.");
18977 }
18978
18979 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
18980
18981 entry_func.fixup_hooks_in.push_back(t: [this, resource]() {
18982 emit_argument_buffer_aliased_descriptor(aliased_var: *resource.var, base_var: this->get<SPIRVariable>(id: resource.overlapping_var_id));
18983 });
18984 }
18985 else if (type.array.empty())
18986 set_qualified_name(id: var.self, name: join(ts: "(*", ts: to_name(id: buffer_variable_id), ts: ".", ts&: mbr_name, ts: ")"));
18987 else
18988 set_qualified_name(id: var.self, name: join(ts: to_name(id: buffer_variable_id), ts: ".", ts&: mbr_name));
18989 }
18990 }
18991
18992 set_extended_member_decoration(type: buffer_type.self, index: member_index, decoration: SPIRVCrossDecorationResourceIndexPrimary,
18993 value: resource.index);
18994 set_extended_member_decoration(type: buffer_type.self, index: member_index, decoration: SPIRVCrossDecorationInterfaceOrigID,
18995 value: var.self);
18996 if (has_extended_decoration(id: var.self, decoration: SPIRVCrossDecorationOverlappingBinding))
18997 set_extended_member_decoration(type: buffer_type.self, index: member_index, decoration: SPIRVCrossDecorationOverlappingBinding);
18998 member_index++;
18999 }
19000
19001 if (msl_options.replace_recursive_inputs && type_contains_recursion(type: buffer_type))
19002 {
19003 recursive_inputs.insert(x: type_id);
19004 auto &entry_func = this->get<SPIRFunction>(id: ir.default_entry_point);
19005 auto addr_space = get_argument_address_space(argument: buffer_var);
19006 entry_func.fixup_hooks_in.push_back(t: [this, addr_space, buffer_name, buffer_type_name]() {
19007 statement(ts: addr_space, ts: " auto& ", ts: buffer_name, ts: " = *(", ts: addr_space, ts: " ", ts: buffer_type_name, ts: "*)", ts: buffer_name, ts: "_vp;");
19008 });
19009 }
19010 }
19011}
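
// The net result of analyze_argument_buffers() is that each qualifying descriptor set is
// later declared in MSL roughly as (illustrative only; member names follow the shader):
//
//   struct spvDescriptorSetBuffer0
//   {
//       texture2d<float> uTexture [[id(0)]];
//       sampler uTextureSmplr [[id(1)]];
//       constant UBO* uUBO [[id(2)]];
//   };
//
// which is then bound to the entry point as a single [[buffer(N)]] argument.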
19012
19013// Returns the app-provided resource binding for the descriptor set whose resource index
19014// matches the given argument buffer index.
19015// This is a two-step lookup: first look up the resource binding number from the argument buffer index,
19016// then look up the resource binding using that binding number.
19017const MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) const
19018{
19019 auto stage = get_entry_point().model;
19020 StageSetBinding arg_idx_tuple = { .model: stage, .desc_set: desc_set, .binding: arg_idx };
19021 auto arg_itr = resource_arg_buff_idx_to_binding_number.find(x: arg_idx_tuple);
19022 if (arg_itr != end(cont: resource_arg_buff_idx_to_binding_number))
19023 {
19024 StageSetBinding bind_tuple = { .model: stage, .desc_set: desc_set, .binding: arg_itr->second };
19025 auto bind_itr = resource_bindings.find(x: bind_tuple);
19026 if (bind_itr != end(cont: resource_bindings))
19027 return bind_itr->second.first;
19028 }
19029 SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
19030 "elements, all descriptor set resources must be supplied with a base type by the app.");
19031}
19032
19033// Adds an argument buffer padding argument buffer type as one or more members of the struct type at the member index.
19034// Metal does not support arrays of buffers, so these are emitted as multiple struct members.
19035void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
19036 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
19037{
19038 if (!argument_buffer_padding_buffer_type_id)
19039 {
19040 uint32_t buff_type_id = ir.increase_bound_by(count: 2);
19041 auto &buff_type = set<SPIRType>(id: buff_type_id, args: OpNop);
19042 buff_type.basetype = rez_bind.basetype;
19043 buff_type.storage = StorageClassUniformConstant;
19044
19045 uint32_t ptr_type_id = buff_type_id + 1;
19046 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args: OpTypePointer);
19047 ptr_type = buff_type;
19048 ptr_type.op = spv::OpTypePointer;
19049 ptr_type.pointer = true;
19050 ptr_type.pointer_depth++;
19051 ptr_type.parent_type = buff_type_id;
19052
19053 argument_buffer_padding_buffer_type_id = ptr_type_id;
19054 }
19055
19056 add_argument_buffer_padding_type(mbr_type_id: argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, count: rez_bind.count);
19057}
19058
19059// Adds an argument buffer padding argument image type as a member of the struct type at the member index.
19060void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
19061 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
19062{
19063 if (!argument_buffer_padding_image_type_id)
19064 {
19065 uint32_t base_type_id = ir.increase_bound_by(count: 2);
19066 auto &base_type = set<SPIRType>(id: base_type_id, args: OpTypeFloat);
19067 base_type.basetype = SPIRType::Float;
19068 base_type.width = 32;
19069
19070 uint32_t img_type_id = base_type_id + 1;
19071 auto &img_type = set<SPIRType>(id: img_type_id, args: OpTypeImage);
19072 img_type.basetype = SPIRType::Image;
19073 img_type.storage = StorageClassUniformConstant;
19074
19075 img_type.image.type = base_type_id;
19076 img_type.image.dim = Dim2D;
19077 img_type.image.depth = false;
19078 img_type.image.arrayed = false;
19079 img_type.image.ms = false;
19080 img_type.image.sampled = 1;
19081 img_type.image.format = ImageFormatUnknown;
19082 img_type.image.access = AccessQualifierMax;
19083
19084 argument_buffer_padding_image_type_id = img_type_id;
19085 }
19086
19087 add_argument_buffer_padding_type(mbr_type_id: argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, count: rez_bind.count);
19088}
19089
19090// Adds an argument buffer padding argument sampler type as a member of the struct type at the member index.
19091void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
19092 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
19093{
19094 if (!argument_buffer_padding_sampler_type_id)
19095 {
19096 uint32_t samp_type_id = ir.increase_bound_by(count: 1);
19097 auto &samp_type = set<SPIRType>(id: samp_type_id, args: OpTypeSampler);
19098 samp_type.basetype = SPIRType::Sampler;
19099 samp_type.storage = StorageClassUniformConstant;
19100
19101 argument_buffer_padding_sampler_type_id = samp_type_id;
19102 }
19103
19104 add_argument_buffer_padding_type(mbr_type_id: argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, count: rez_bind.count);
19105}
19106
19107// Adds the argument buffer padding argument type as a member of the struct type at the member index.
19108// Advances both arg_buff_index and mbr_idx to next argument slots.
19109void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
19110 uint32_t &arg_buff_index, uint32_t count)
19111{
19112 uint32_t type_id = mbr_type_id;
19113 if (count > 1)
19114 {
19115 uint32_t ary_type_id = ir.increase_bound_by(count: 1);
19116 auto &ary_type = set<SPIRType>(id: ary_type_id, args&: get<SPIRType>(id: type_id));
19117 ary_type.op = OpTypeArray;
19118 ary_type.array.push_back(t: count);
19119 ary_type.array_size_literal.push_back(t: true);
19120 ary_type.parent_type = type_id;
19121 type_id = ary_type_id;
19122 }
19123
19124 set_member_name(id: struct_type.self, index: mbr_idx, name: join(ts: "_m", ts&: arg_buff_index, ts: "_pad"));
19125 set_extended_member_decoration(type: struct_type.self, index: mbr_idx, decoration: SPIRVCrossDecorationResourceIndexPrimary, value: arg_buff_index);
19126 struct_type.member_types.push_back(t: type_id);
19127
19128 arg_buff_index += count;
19129 mbr_idx++;
19130}
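
// Padding example: if the app-side layout reserves a 3-element texture array at [[id(4)]]
// but the shader's next used resource sits at [[id(7)]], a single member named "_m4_pad"
// with array size 3 is synthesized to consume ids 4-6, so the real member still lands on
// [[id(7)]].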
19131
19132void CompilerMSL::activate_argument_buffer_resources()
19133{
19134 // For ABI compatibility, force-enable all resources which are part of argument buffers.
19135 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t self, const SPIRVariable &) {
19136 if (!has_decoration(id: self, decoration: DecorationDescriptorSet))
19137 return;
19138
19139 uint32_t desc_set = get_decoration(id: self, decoration: DecorationDescriptorSet);
19140 if (descriptor_set_is_argument_buffer(desc_set))
19141 add_active_interface_variable(var_id: self);
19142 });
19143}
19144
19145bool CompilerMSL::using_builtin_array() const
19146{
19147 return msl_options.force_native_arrays || is_using_builtin_array;
19148}
19149
19150void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
19151{
19152 sampler_name_suffix = suffix;
19153}
19154
19155const char *CompilerMSL::get_combined_sampler_suffix() const
19156{
19157 return sampler_name_suffix.c_str();
19158}
19159
19160void CompilerMSL::emit_block_hints(const SPIRBlock &)
19161{
19162}
19163
19164void CompilerMSL::emit_mesh_entry_point()
19165{
19166 auto &ep = get_entry_point();
19167 auto &f = get<SPIRFunction>(id: ir.default_entry_point);
19168
19169 const uint32_t func_id = ir.increase_bound_by(count: 3);
19170 const uint32_t block_id = func_id + 1;
19171 const uint32_t ret_id = func_id + 2;
19172 auto &wrapped_main = set<SPIRFunction>(id: func_id, args&: f.return_type, args&: f.function_type);
19173
19174 wrapped_main.blocks.push_back(t: block_id);
19175 wrapped_main.entry_block = block_id;
19176
19177 auto &wrapped_entry = set<SPIRBlock>(block_id);
19178 wrapped_entry.terminator = SPIRBlock::Return;
19179
19180 // Push call to original 'main'
19181 Instruction ix = {};
19182 ix.op = OpFunctionCall;
19183 ix.offset = uint32_t(ir.spirv.size());
19184 ix.length = 3;
19185
19186 ir.spirv.push_back(x: f.return_type);
19187 ir.spirv.push_back(x: ret_id);
19188 ir.spirv.push_back(x: ep.self);
19189
19190 wrapped_entry.ops.push_back(t: ix);
19191
19192	// Replace the entry point with the new one.
19193 SPIREntryPoint proxy_ep = ep;
19194 proxy_ep.self = func_id;
19195 ir.entry_points.insert(x: std::make_pair(x: func_id, y&: proxy_ep));
19196 ir.meta[func_id] = ir.meta[ir.default_entry_point];
19197 ir.meta[ir.default_entry_point].decoration.alias.clear();
19198
19199 ir.default_entry_point = func_id;
19200}
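
// In effect, the original SPIR-V entry point is demoted to an ordinary function, and a
// synthetic wrapper whose single block just calls it becomes the new default entry point,
// inheriting the original entry point's name and metadata.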
19201
19202void CompilerMSL::emit_mesh_outputs()
19203{
19204 auto &mode = get_entry_point();
19205
19206	// Predefined thread count, or zero if a specialization constant is in use.
19207 uint32_t num_invocations = 0;
19208 if (mode.workgroup_size.id_x == 0 && mode.workgroup_size.id_y == 0 && mode.workgroup_size.id_z == 0)
19209 num_invocations = mode.workgroup_size.x * mode.workgroup_size.y * mode.workgroup_size.z;
19210
19211 statement(ts: "threadgroup_barrier(mem_flags::mem_threadgroup);");
19212 statement(ts: "if (spvMeshSizes.y == 0)");
19213 begin_scope();
19214 statement(ts: "return;");
19215 end_scope();
19216 statement(ts: "spvMesh.set_primitive_count(spvMeshSizes.y);");
19217
19218 statement(ts: "const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);");
19219
19220 if (mesh_out_per_vertex != 0)
19221 {
19222 auto &type_vert = get<SPIRType>(id: mesh_out_per_vertex);
19223
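		// If the known invocation count cannot cover every output vertex (or is unknown
		// because it depends on specialization constants, i.e. num_invocations == 0),
		// stride a loop over the vertex count; otherwise each thread writes at most one vertex.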
19224 if (num_invocations < mode.output_vertices)
19225 {
19226 statement(ts: "for (uint spvVI = gl_LocalInvocationIndex; spvVI < spvMeshSizes.x; spvVI += spvThreadCount)");
19227 }
19228 else
19229 {
19230 statement(ts: "const uint spvVI = gl_LocalInvocationIndex;");
19231 statement(ts: "if (gl_LocalInvocationIndex < spvMeshSizes.x)");
19232 }
19233
19234 begin_scope();
19235
19236 statement(ts: "spvPerVertex spvV = {};");
19237 for (uint32_t index = 0; index < uint32_t(type_vert.member_types.size()); ++index)
19238 {
19239 uint32_t orig_var = get_extended_member_decoration(type: type_vert.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID);
19240 uint32_t orig_id = get_extended_member_decoration(type: type_vert.self, index, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
19241
19242			// Clip/cull distances are a special case; members with no originating variable are skipped here.
19243 if (orig_var == 0 && orig_id == (~0u))
19244 continue;
19245
19246 auto &orig = get<SPIRVariable>(id: orig_var);
19247 auto &orig_type = get<SPIRType>(id: orig.basetype);
19248
19249 // FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc.
19250
19251 BuiltIn builtin = BuiltInMax;
19252 std::string access;
19253 if (orig_type.basetype == SPIRType::Struct)
19254 {
19255 if (has_member_decoration(id: orig_type.self, index: orig_id, decoration: DecorationBuiltIn))
19256 builtin = BuiltIn(get_member_decoration(id: orig_type.self, index: orig_id, decoration: DecorationBuiltIn));
19257
19258 switch (builtin)
19259 {
19260 case BuiltInPosition:
19261 case BuiltInPointSize:
19262 case BuiltInClipDistance:
19263 case BuiltInCullDistance:
19264 access = "." + builtin_to_glsl(builtin, storage: StorageClassOutput);
19265 break;
19266 default:
19267 access = "." + to_member_name(type: orig_type, index: orig_id);
19268 break;
19269 }
19270
19271 if (has_member_decoration(id: type_vert.self, index, decoration: DecorationIndex))
19272 {
19273 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
19274 const uint32_t orig_index = get_member_decoration(id: type_vert.self, index, decoration: DecorationIndex);
19275 access += "[" + to_string(val: orig_index) + "]";
19276 statement(ts: "spvV.", ts: builtin_to_glsl(builtin, storage: StorageClassOutput), ts: "[", ts: orig_index, ts: "] = ", ts: to_name(id: orig_var), ts: "[spvVI]", ts&: access, ts: ";");
19277 }
19278 }
19279
19280 statement(ts: "spvV.", ts: to_member_name(type: type_vert, index), ts: " = ", ts: to_name(id: orig_var), ts: "[spvVI]", ts&: access, ts: ";");
19281 if (options.vertex.flip_vert_y && builtin == BuiltInPosition)
19282 {
19283 statement(ts: "spvV.", ts: to_member_name(type: type_vert, index), ts: ".y = -(", ts: "spvV.",
19284 ts: to_member_name(type: type_vert, index), ts: ".y);", ts: " // Invert Y-axis for Metal");
19285 }
19286 }
19287 statement(ts: "spvMesh.set_vertex(spvVI, spvV);");
19288 end_scope();
19289 }
19290
19291 if (mesh_out_per_primitive != 0 || builtin_mesh_primitive_indices_id != 0)
19292 {
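		// Same loop-versus-single-iteration choice as for vertices, this time driven by
		// the primitive count.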
19293 if (num_invocations < mode.output_primitives)
19294 {
19295 statement(ts: "for (uint spvPI = gl_LocalInvocationIndex; spvPI < spvMeshSizes.y; spvPI += spvThreadCount)");
19296 }
19297 else
19298 {
19299 statement(ts: "const uint spvPI = gl_LocalInvocationIndex;");
19300 statement(ts: "if (gl_LocalInvocationIndex < spvMeshSizes.y)");
19301 }
19302
19303 // FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc.
19304
19305 begin_scope();
19306
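		// Primitive indices are written according to the output topology: three indices
		// per triangle, two per line, one per point.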
19307 if (builtin_mesh_primitive_indices_id != 0)
19308 {
19309 if (mode.flags.get(bit: ExecutionModeOutputTrianglesEXT))
19310 {
19311 statement(ts: "spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);");
19312 statement(ts: "spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);");
19313 statement(ts: "spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);");
19314 }
19315 else if (mode.flags.get(bit: ExecutionModeOutputLinesEXT))
19316 {
19317 statement(ts: "spvMesh.set_index(spvPI * 2u + 0u, gl_PrimitiveLineIndicesEXT[spvPI].x);");
19318 statement(ts: "spvMesh.set_index(spvPI * 2u + 1u, gl_PrimitiveLineIndicesEXT[spvPI].y);");
19319 }
19320 else
19321 {
19322 statement(ts: "spvMesh.set_index(spvPI, gl_PrimitivePointIndicesEXT[spvPI]);");
19323 }
19324 }
19325
19326 if (mesh_out_per_primitive != 0)
19327 {
19328 auto &type_prim = get<SPIRType>(id: mesh_out_per_primitive);
19329 statement(ts: "spvPerPrimitive spvP = {};");
19330 for (uint32_t index = 0; index < uint32_t(type_prim.member_types.size()); ++index)
19331 {
19332 uint32_t orig_var =
19333 get_extended_member_decoration(type: type_prim.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID);
19334 uint32_t orig_id =
19335 get_extended_member_decoration(type: type_prim.self, index, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
19336 auto &orig = get<SPIRVariable>(id: orig_var);
19337 auto &orig_type = get<SPIRType>(id: orig.basetype);
19338
19339 BuiltIn builtin = BuiltInMax;
19340 std::string access;
19341 if (orig_type.basetype == SPIRType::Struct)
19342 {
19343 if (has_member_decoration(id: orig_type.self, index: orig_id, decoration: DecorationBuiltIn))
19344 builtin = BuiltIn(get_member_decoration(id: orig_type.self, index: orig_id, decoration: DecorationBuiltIn));
19345
19346 switch (builtin)
19347 {
19348 case BuiltInPrimitiveId:
19349 case BuiltInLayer:
19350 case BuiltInViewportIndex:
19351 case BuiltInCullPrimitiveEXT:
19352 case BuiltInPrimitiveShadingRateKHR:
19353 access = "." + builtin_to_glsl(builtin, storage: StorageClassOutput);
19354 break;
19355 default:
19356 access = "." + to_member_name(type: orig_type, index: orig_id);
19357 }
19358 }
19359 statement(ts: "spvP.", ts: to_member_name(type: type_prim, index), ts: " = ", ts: to_name(id: orig_var), ts: "[spvPI]", ts&: access, ts: ";");
19360 }
19361 statement(ts: "spvMesh.set_primitive(spvPI, spvP);");
19362 }
19363
19364 end_scope();
19365 }
19366}
19367
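// Translates OpEmitMeshTasksEXT in a task (object) shader: forward the group counts
// to the mesh grid properties object and terminate the invocation. Illustrative
// generated MSL (the operand expressions depend on the surrounding code):
//   spvMgp.set_threadgroups_per_grid(uint3(gx, gy, gz));
//   return;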
19368void CompilerMSL::emit_mesh_tasks(SPIRBlock &block)
19369{
19370 // GLSL: Once this instruction is called, the workgroup must be terminated immediately, and the mesh shaders are launched.
19371	// TODO: find a reliable and clean way of terminating the shader.
19372 flush_variable_declaration(id: builtin_task_grid_id);
19373 statement(ts: "spvMgp.set_threadgroups_per_grid(uint3(", ts: to_unpacked_expression(id: block.mesh.groups[0]), ts: ", ",
19374 ts: to_unpacked_expression(id: block.mesh.groups[1]), ts: ", ", ts: to_unpacked_expression(id: block.mesh.groups[2]), ts: "));");
19375	// This is only correct if EmitMeshTasks is called from the shader's entry function.
19376	// If it is not, the only viable solutions would be:
19377 // - Caller ensures the SPIR-V is inlined, then this always holds true.
19378 // - Pass down a "should terminate" bool to leaf functions and chain return (horrible and disgusting, let's not).
19379 statement(ts: "return;");
19380}
19381
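// Formats msl_options.additional_fixed_sample_mask as a hexadecimal literal,
// e.g. a mask of 0xffffffff yields the string "0xffffffff".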
19382string CompilerMSL::additional_fixed_sample_mask_str() const
19383{
19384 char print_buffer[32];
19385#ifdef _MSC_VER
19386	// snprintf does not exist or is buggy on older MSVC versions, some of
19387	// which are also used by MinGW. Use sprintf instead and disable the
19388	// corresponding warning.
19389#pragma warning(push)
19390#pragma warning(disable : 4996)
19391#endif
19392#if _WIN32
19393 sprintf(print_buffer, "0x%x", msl_options.additional_fixed_sample_mask);
19394#else
19395 snprintf(s: print_buffer, maxlen: sizeof(print_buffer), format: "0x%x", msl_options.additional_fixed_sample_mask);
19396#endif
19397#ifdef _MSC_VER
19398#pragma warning(pop)
19399#endif
19400 return print_buffer;
19401}
19402
