/*
 * Copyright 2016-2021 The Brenwill Workshop Ltd.
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
static const char *force_inline = "static inline __attribute__((always_inline))";

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(std::move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

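// Records the app-provided interface for a shader input, keyed by location/component and,
// when one is given, by builtin. add_msl_shader_output below is the mirror image for
// outputs. Illustrative sketch only (just the fields this file references; the values are
// hypothetical):
//
//     MSLShaderInterfaceVariable attr = {};
//     attr.location = 0;               // vertex attribute at location 0
//     attr.builtin = spv::BuiltInMax;  // i.e. not a builtin
//     compiler.add_msl_shader_input(attr);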
void CompilerMSL::add_msl_shader_input(const MSLShaderInterfaceVariable &si)
{
	inputs_by_location[{si.location, si.component}] = si;
	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
		inputs_by_builtin[si.builtin] = si;
}

void CompilerMSL::add_msl_shader_output(const MSLShaderInterfaceVariable &so)
{
	outputs_by_location[{so.location, so.component}] = so;
	if (so.builtin != BuiltInMax && !outputs_by_builtin.count(so.builtin))
		outputs_by_builtin[so.builtin] = so;
}

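// Maps a SPIR-V (stage, descriptor set, binding) tuple to explicit MSL resource indices.
// Hypothetical sketch for a fragment-stage texture at set 0, binding 1 (field names follow
// the MSLResourceBinding members used below; the index values are illustrative):
//
//     MSLResourceBinding rb = {};
//     rb.stage = spv::ExecutionModelFragment;
//     rb.desc_set = 0;
//     rb.binding = 1;
//     rb.basetype = SPIRType::Image; // required if pad_argument_buffer_resources is enabled
//     rb.msl_texture = 0;            // becomes [[texture(0)]]
//     compiler.add_msl_resource_binding(rb);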
void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };

	// If we might need to pad argument buffer members to positionally align
	// arg buffer indexes, also maintain a lookup by argument buffer index.
	if (msl_options.pad_argument_buffer_resources)
	{
		StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, k_unknown_component };

#define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \
	arg_idx_tuple.binding = binding.msl_##rez; \
	resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding

		switch (binding.basetype)
		{
		case SPIRType::Void:
		case SPIRType::Boolean:
		case SPIRType::SByte:
		case SPIRType::UByte:
		case SPIRType::Short:
		case SPIRType::UShort:
		case SPIRType::Int:
		case SPIRType::UInt:
		case SPIRType::Int64:
		case SPIRType::UInt64:
		case SPIRType::AtomicCounter:
		case SPIRType::Half:
		case SPIRType::Float:
		case SPIRType::Double:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer);
			break;
		case SPIRType::Image:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			break;
		case SPIRType::Sampler:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		case SPIRType::SampledImage:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		default:
			SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, "
			                  "all descriptor set resources must be supplied with a base type by the app.");
		}
#undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP
	}
}

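// Registers a buffer whose byte offset is supplied at draw time. 'index' is this buffer's
// slot within the implicit spvDynamicOffsets array that emit_entry_point_declarations()
// indexes when rebasing the buffer pointer.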
void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
{
	SetBindingPair pair = { desc_set, binding };
	buffers_requiring_dynamic_offset[pair] = { index, 0 };
}

void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	inline_uniform_blocks.insert(pair);
}

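// Excludes a descriptor set from argument buffer translation, so its resources are bound
// as discrete [[buffer]]/[[texture]]/[[sampler]] arguments instead. Only sets below
// kMaxArgumentBuffers can be represented in the bitmask.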
void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
	if (desc_set < kMaxArgumentBuffers)
		argument_buffer_discrete_mask |= 1u << desc_set;
}

void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
{
	if (desc_set < kMaxArgumentBuffers)
	{
		if (device_storage)
			argument_buffer_device_storage_mask |= 1u << desc_set;
		else
			argument_buffer_device_storage_mask &= ~(1u << desc_set);
	}
}

bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
{
	// Don't report internal location allocations to app.
	return location_inputs_in_use.count(location) != 0 &&
	       location_inputs_in_use_fallback.count(location) == 0;
}

bool CompilerMSL::is_msl_shader_output_used(uint32_t location)
{
	// Don't report internal location allocations to app.
	return location_outputs_in_use.count(location) != 0 &&
	       location_outputs_in_use_fallback.count(location) == 0;
}

uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const
{
	auto itr = builtin_to_automatic_input_location.find(builtin);
	if (itr == builtin_to_automatic_input_location.end())
		return k_unknown_location;
	else
		return itr->second;
}

uint32_t CompilerMSL::get_automatic_builtin_output_location(spv::BuiltIn builtin) const
{
	auto itr = builtin_to_automatic_output_location.find(builtin);
	if (itr == builtin_to_automatic_output_location.end())
		return k_unknown_location;
	else
		return itr->second;
}

bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

bool CompilerMSL::is_var_runtime_size_array(const SPIRVariable &var) const
{
	auto &type = get_variable_data_type(var);
	return is_runtime_size_array(type) && get_resource_array_size(type, var.self) == 0;
}

// Returns the size of the array of resources used by the variable with the specified type and id.
// The size is first retrieved from the type, but in the case of runtime array sizing,
// the size is retrieved from the resource binding added using add_msl_resource_binding().
uint32_t CompilerMSL::get_resource_array_size(const SPIRType &type, uint32_t id) const
{
	uint32_t array_size = to_array_size_literal(type);

	// If we have argument buffers, we need to honor the ABI by using the correct array size
	// from the layout. Only use shader declared size if we're not using argument buffers.
	uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
	if (!descriptor_set_is_argument_buffer(desc_set) && array_size)
		return array_size;

	StageSetBinding tuple = { get_entry_point().model, desc_set,
		                      get_decoration(id, DecorationBinding) };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) ? itr->second.first.count : array_size;
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}

void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}

bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
	return (builtin == BuiltInSampleMask);
}

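// Scans the entry point's interface for builtins that the MSL translation will need but
// that the SPIR-V module may not declare (e.g. gl_FragCoord for subpass inputs,
// gl_SampleID for sample positions, vertex/instance indices for tessellation capture),
// and synthesizes any that are missing as implicit interface variables.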
void CompilerMSL::build_implicit_builtins()
{
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
	                          !msl_options.vertex_for_tessellation;
	bool need_tesc_params = is_tesc_shader();
	bool need_tese_params = is_tese_shader() && msl_options.raw_buffer_tese_input;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
	                      msl_options.multiview_layered_rendering &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	bool need_dispatch_base =
	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
	bool need_vertex_base_params =
	    need_grid_params &&
	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);
	bool force_frag_depth_passthrough =
	    get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input &&
	    msl_options.enable_frag_depth_builtin && msl_options.input_attachment_is_ds_attachment;

	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
	    needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || needs_helper_invocation ||
	    has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough)
	{
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;
		bool has_layer = false;
		bool has_helper_invocation = false;
		bool has_local_invocation_index = false;
		bool has_workgroup_size = false;
		bool has_frag_depth = false;
		uint32_t workgroup_id_type = 0;

		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
				return;
			if (!interface_variable_exists_in_entry_point(var.self))
				return;
			if (!has_decoration(var.self, DecorationBuiltIn))
				return;

			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

			if (var.storage == StorageClassOutput)
			{
				if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask)
				{
					builtin_sample_mask_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
					does_shader_write_sample_mask = true;
				}

				if (force_frag_depth_passthrough && builtin == BuiltInFragDepth)
				{
					builtin_frag_depth_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInFragDepth, var.self);
					has_frag_depth = true;
				}
			}

			if (var.storage != StorageClassInput)
				return;

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
			{
				switch (builtin)
				{
				case BuiltInFragCoord:
					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
					builtin_frag_coord_id = var.self;
					has_frag_coord = true;
					break;
				case BuiltInLayer:
					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
					builtin_layer_id = var.self;
					has_layer = true;
					break;
				case BuiltInViewIndex:
					if (!msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					builtin_view_idx_id = var.self;
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params && builtin == BuiltInInvocationId)
			{
				builtin_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
				has_invocation_id = true;
			}

			if ((need_tesc_params || need_tese_params) && builtin == BuiltInPrimitiveId)
			{
				builtin_primitive_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
				has_primitive_id = true;
			}

			if (need_tese_params && builtin == BuiltInTessLevelOuter)
			{
				tess_level_outer_var_id = var.self;
			}

			if (need_tese_params && builtin == BuiltInTessLevelInner)
			{
				tess_level_inner_var_id = var.self;
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
				has_subgroup_invocation_id = true;
			}

			if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				switch (builtin)
				{
				case BuiltInInstanceIndex:
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				case BuiltInViewIndex:
					builtin_view_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if (needs_helper_invocation && builtin == BuiltInHelperInvocation)
			{
				builtin_helper_invocation_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var.self);
				has_helper_invocation = true;
			}

			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
			{
				builtin_local_invocation_index_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
				has_local_invocation_index = true;
			}

			if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
			{
				builtin_workgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
				has_workgroup_size = true;
			}

			// The base workgroup needs to have the same type and vector size
			// as the workgroup or invocation ID, so keep track of the type that
			// was used.
			if (need_dispatch_base && workgroup_id_type == 0 &&
			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
				workgroup_id_type = var.basetype;
		});

		// Use Metal's native frame-buffer fetch API for subpass inputs.
		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
		    (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
		{
			if (!has_frag_coord)
			{
				uint32_t offset = ir.increase_bound_by(3);
				uint32_t type_id = offset;
				uint32_t type_ptr_id = offset + 1;
				uint32_t var_id = offset + 2;

				// Create gl_FragCoord.
				SPIRType vec4_type { OpTypeVector };
				vec4_type.basetype = SPIRType::Float;
				vec4_type.width = 32;
				vec4_type.vecsize = 4;
				set<SPIRType>(type_id, vec4_type);

				SPIRType vec4_type_ptr = vec4_type;
				vec4_type_ptr.op = OpTypePointer;
				vec4_type_ptr.pointer = true;
				vec4_type_ptr.pointer_depth++;
				vec4_type_ptr.parent_type = type_id;
				vec4_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
				ptr_type.self = type_id;

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
				builtin_frag_coord_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
			}

			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_Layer.
				SPIRType uint_type_ptr = get_uint_type();
				uint_type_ptr.op = OpTypePointer;
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
			}

			if (!has_view_idx && msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_ViewIndex.
				SPIRType uint_type_ptr = get_uint_type();
				uint_type_ptr.op = OpTypePointer;
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (!has_sample_id && (need_sample_pos || needs_sample_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SampleID.
			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
			}

			if (!has_base_instance) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview)
			{
				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
				// gl_Layer is an output in vertex-pipeline shaders.
				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
				SPIRType uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.op = OpTypePointer;
				uint_type_ptr_out.pointer = true;
				uint_type_ptr_out.pointer_depth++;
				uint_type_ptr_out.parent_type = get_uint_type_id();
				uint_type_ptr_out.storage = StorageClassOutput;
				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
				ptr_out_type.self = get_uint_type_id();
				uint32_t var_id = type_ptr_out_id + 1;
				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
		    (need_tese_params && !has_primitive_id) || need_grid_params)
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if ((need_tesc_params && msl_options.multi_patch_workgroup) || need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_GlobalInvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
			}
			else if (need_tesc_params && !has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if ((need_tesc_params || need_tese_params) && !has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}

			if (need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
				get_entry_point().interface_variables.push_back(var_id);
				set_name(var_id, "spvStageInputSize");
				builtin_stage_input_size_id = var_id;
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupSize.
			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}

		if (need_dispatch_base || need_vertex_base_params)
		{
			if (workgroup_id_type == 0)
				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
			uint32_t var_id;
			if (msl_options.supports_msl_version(1, 2))
			{
				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
				// to convey this information and save a buffer slot.
				uint32_t offset = ir.increase_bound_by(1);
				var_id = offset;

				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
				get_entry_point().interface_variables.push_back(var_id);
			}
			else
			{
				// Otherwise, we need to fall back to a good ol' fashioned buffer.
				uint32_t offset = ir.increase_bound_by(2);
				var_id = offset;
				uint32_t type_id = offset + 1;

				SPIRType var_type = get<SPIRType>(workgroup_id_type);
				var_type.storage = StorageClassUniform;
				set<SPIRType>(type_id, var_type);

				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
				// This should never match anything.
				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
				                        msl_options.indirect_params_buffer_index);
			}
			set_name(var_id, "spvDispatchBase");
			builtin_dispatch_base_id = var_id;
		}

		if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t var_id = offset + 1;

			// Create gl_SampleMask.
			SPIRType uint_type_ptr_out = get_uint_type();
			uint_type_ptr_out.op = OpTypePointer;
			uint_type_ptr_out.pointer = true;
			uint_type_ptr_out.pointer_depth++;
			uint_type_ptr_out.parent_type = get_uint_type_id();
			uint_type_ptr_out.storage = StorageClassOutput;

			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
			ptr_out_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, offset, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
			builtin_sample_mask_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
		}

		if (!has_helper_invocation && needs_helper_invocation)
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_HelperInvocation.
			SPIRType bool_type { OpTypeBool };
			bool_type.basetype = SPIRType::Boolean;
			bool_type.width = 8;
			bool_type.vecsize = 1;
			set<SPIRType>(type_id, bool_type);

			SPIRType bool_type_ptr_in = bool_type;
			bool_type_ptr_in.op = spv::OpTypePointer;
			bool_type_ptr_in.pointer = true;
			bool_type_ptr_in.pointer_depth++;
			bool_type_ptr_in.parent_type = type_id;
			bool_type_ptr_in.storage = StorageClassInput;

			auto &ptr_in_type = set<SPIRType>(type_ptr_id, bool_type_ptr_in);
			ptr_in_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInHelperInvocation);
			builtin_helper_invocation_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInHelperInvocation, var_id);
		}

		if (need_local_invocation_index && !has_local_invocation_index)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_LocalInvocationIndex.
			SPIRType uint_type_ptr = get_uint_type();
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
			builtin_local_invocation_index_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
		}

		if (need_workgroup_size && !has_workgroup_size)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_WorkgroupSize.
			uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
			SPIRType uint_type_ptr = get<SPIRType>(type_id);
			uint_type_ptr.op = OpTypePointer;
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
			builtin_workgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
		}

		if (!has_frag_depth && force_frag_depth_passthrough)
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_FragDepth
			SPIRType float_type { OpTypeFloat };
			float_type.basetype = SPIRType::Float;
			float_type.width = 32;
			float_type.vecsize = 1;
			set<SPIRType>(type_id, float_type);

			SPIRType float_type_ptr_in = float_type;
			float_type_ptr_in.op = spv::OpTypePointer;
			float_type_ptr_in.pointer = true;
			float_type_ptr_in.pointer_depth++;
			float_type_ptr_in.parent_type = type_id;
			float_type_ptr_in.storage = StorageClassOutput;

			auto &ptr_in_type = set<SPIRType>(type_ptr_id, float_type_ptr_in);
			ptr_in_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInFragDepth);
			builtin_frag_depth_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInFragDepth, var_id);
			active_output_builtins.set(BuiltInFragDepth);
		}
	}

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (needs_buffer_size_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}

	if (!buffers_requiring_dynamic_offset.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvDynamicOffsets");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
		                        msl_options.dynamic_offsets_buffer_index);
		dynamic_offsets_buffer_id = var_id;
	}

	// If we're returning a struct from a vertex-like entry point, we must return a position attribute.
	bool need_position = (get_execution_model() == ExecutionModelVertex || is_tese_shader()) &&
	                     !capture_output_to_buffer && !get_is_rasterization_disabled() &&
	                     !active_output_builtins.get(BuiltInPosition);

	if (need_position)
	{
		// If we can get away with returning void from entry point, we don't need to care.
		// If there is at least one other stage output, we need to return [[position]],
		// so we need to create one if it doesn't appear in the SPIR-V. Before adding the
		// implicit variable, check if it actually exists already, but just has not been used
		// or initialized, and if so, mark it as active, and do not create the implicit variable.
		bool has_output = false;
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(var.self))
			{
				has_output = true;

				// Check if the var is the Position builtin
				if (has_decoration(var.self, DecorationBuiltIn) && get_decoration(var.self, DecorationBuiltIn) == BuiltInPosition)
					active_output_builtins.set(BuiltInPosition);

				// If the var is a struct, check if any of its members is the Position builtin
				auto &var_type = get_variable_element_type(var);
				if (var_type.basetype == SPIRType::Struct)
				{
					auto mbr_cnt = var_type.member_types.size();
					for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
					{
						auto builtin = BuiltInMax;
						bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
						if (is_builtin && builtin == BuiltInPosition)
							active_output_builtins.set(BuiltInPosition);
					}
				}
			}
		});
		need_position = has_output && !active_output_builtins.get(BuiltInPosition);
	}

	if (need_position)
	{
		uint32_t offset = ir.increase_bound_by(3);
		uint32_t type_id = offset;
		uint32_t type_ptr_id = offset + 1;
		uint32_t var_id = offset + 2;

		// Create gl_Position.
		SPIRType vec4_type { OpTypeVector };
		vec4_type.basetype = SPIRType::Float;
		vec4_type.width = 32;
		vec4_type.vecsize = 4;
		set<SPIRType>(type_id, vec4_type);

		SPIRType vec4_type_ptr = vec4_type;
		vec4_type_ptr.op = OpTypePointer;
		vec4_type_ptr.pointer = true;
		vec4_type_ptr.pointer_depth++;
		vec4_type_ptr.parent_type = type_id;
		vec4_type_ptr.storage = StorageClassOutput;
		auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
		ptr_type.self = type_id;

		set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
		set_decoration(var_id, DecorationBuiltIn, BuiltInPosition);
		mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id);
	}
}

// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
// If not, it marks it as active and forces a recompilation.
// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	// At this point, the specified builtin variable must have already been declared in the entry point.
	// If not, mark as active and force recompile.
	if (active_builtins != nullptr && !active_builtins->get(builtin))
	{
		active_builtins->set(builtin);
		force_recompile();
	}
}

void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	assert(active_builtins != nullptr);
	active_builtins->set(builtin);

	auto &var = get_entry_point().interface_variables;
	if (find(begin(var), end(var), VariableID(id)) == end(var))
		var.push_back(id);
}

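// Builds a variable of type "pointer to an unbounded array of uint" in the uniform address
// space, as used by the implicit swizzle-constant, buffer-size, view-mask and
// dynamic-offset helper buffers. Returns the id of the new variable.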
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	uint32_t offset = ir.increase_bound_by(3);
	uint32_t type_ptr_id = offset;
	uint32_t type_ptr_ptr_id = offset + 1;
	uint32_t var_id = offset + 2;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type_pointer = get_uint_type();
	uint_type_pointer.op = OpTypePointer;
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth++;
	uint_type_pointer.parent_type = get_uint_type_id();
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}

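// Returns an MSL address-mode argument for a constexpr sampler declaration; 'prefix'
// selects the per-axis form. For example, create_sampler_address("s_",
// MSL_SAMPLER_ADDRESS_REPEAT) yields "s_address::repeat".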
static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		return join(prefix, "address::clamp_to_edge");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		return join(prefix, "address::clamp_to_zero");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		return join(prefix, "address::clamp_to_border");
	case MSL_SAMPLER_ADDRESS_REPEAT:
		return join(prefix, "address::repeat");
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		return join(prefix, "address::mirrored_repeat");
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
}

SPIRType &CompilerMSL::get_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(stage_out_var_id);
	return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
	return get_variable_data_type(so_var);
}

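// Metal defines separate tessellation-factor structs for triangle and quad patches;
// return the name matching the current tessellation mode.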
std::string CompilerMSL::get_tess_factor_struct_name()
{
	if (is_tessellating_triangles())
		return "MTLTriangleTessellationFactorsHalf";
	return "MTLQuadTessellationFactorsHalf";
}

SPIRType &CompilerMSL::get_uint_type()
{
	return get<SPIRType>(get_uint_type_id());
}

uint32_t CompilerMSL::get_uint_type_id()
{
	if (uint_type_id != 0)
		return uint_type_id;

	uint_type_id = ir.increase_bound_by(1);

	SPIRType type { OpTypeInt };
	type.basetype = SPIRType::UInt;
	type.width = 32;
	set<SPIRType>(uint_type_id, type);
	return uint_type_id;
}

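// Emits declarations that must appear at the top of the entry point function body:
// complex constant arrays, constexpr samplers, dynamically offset buffers, runtime-sized
// descriptor arrays, and buffer aliases.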
1211void CompilerMSL::emit_entry_point_declarations()
1212{
1213 // FIXME: Get test coverage here ...
1214 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
1215 declare_complex_constant_arrays();
1216
1217 // Emit constexpr samplers here.
1218 for (auto &samp : constexpr_samplers_by_id)
1219 {
1220 auto &var = get<SPIRVariable>(id: samp.first);
1221 auto &type = get<SPIRType>(id: var.basetype);
1222 if (type.basetype == SPIRType::Sampler)
1223 add_resource_name(id: samp.first);
1224
1225 SmallVector<string> args;
1226 auto &s = samp.second;
1227
1228 if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
1229 args.push_back(t: "coord::pixel");
1230
1231 if (s.min_filter == s.mag_filter)
1232 {
1233 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1234 args.push_back(t: "filter::linear");
1235 }
1236 else
1237 {
1238 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1239 args.push_back(t: "min_filter::linear");
1240 if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
1241 args.push_back(t: "mag_filter::linear");
1242 }
1243
1244 switch (s.mip_filter)
1245 {
1246 case MSL_SAMPLER_MIP_FILTER_NONE:
1247 // Default
1248 break;
1249 case MSL_SAMPLER_MIP_FILTER_NEAREST:
1250 args.push_back(t: "mip_filter::nearest");
1251 break;
1252 case MSL_SAMPLER_MIP_FILTER_LINEAR:
1253 args.push_back(t: "mip_filter::linear");
1254 break;
1255 default:
1256 SPIRV_CROSS_THROW("Invalid mip filter.");
1257 }
1258
1259 if (s.s_address == s.t_address && s.s_address == s.r_address)
1260 {
1261 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1262 args.push_back(t: create_sampler_address(prefix: "", addr: s.s_address));
1263 }
1264 else
1265 {
1266 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1267 args.push_back(t: create_sampler_address(prefix: "s_", addr: s.s_address));
1268 if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1269 args.push_back(t: create_sampler_address(prefix: "t_", addr: s.t_address));
1270 if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1271 args.push_back(t: create_sampler_address(prefix: "r_", addr: s.r_address));
1272 }
1273
1274 if (s.compare_enable)
1275 {
1276 switch (s.compare_func)
1277 {
1278 case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
1279 args.push_back(t: "compare_func::always");
1280 break;
1281 case MSL_SAMPLER_COMPARE_FUNC_NEVER:
1282 args.push_back(t: "compare_func::never");
1283 break;
1284 case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
1285 args.push_back(t: "compare_func::equal");
1286 break;
1287 case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
1288 args.push_back(t: "compare_func::not_equal");
1289 break;
1290 case MSL_SAMPLER_COMPARE_FUNC_LESS:
1291 args.push_back(t: "compare_func::less");
1292 break;
1293 case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
1294 args.push_back(t: "compare_func::less_equal");
1295 break;
1296 case MSL_SAMPLER_COMPARE_FUNC_GREATER:
1297 args.push_back(t: "compare_func::greater");
1298 break;
1299 case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
1300 args.push_back(t: "compare_func::greater_equal");
1301 break;
1302 default:
1303 SPIRV_CROSS_THROW("Invalid sampler compare function.");
1304 }
1305 }
1306
1307 if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
1308 s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
1309 {
1310 switch (s.border_color)
1311 {
1312 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
1313 args.push_back(t: "border_color::opaque_black");
1314 break;
1315 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
1316 args.push_back(t: "border_color::opaque_white");
1317 break;
1318 case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
1319 args.push_back(t: "border_color::transparent_black");
1320 break;
1321 default:
1322 SPIRV_CROSS_THROW("Invalid sampler border color.");
1323 }
1324 }
1325
1326 if (s.anisotropy_enable)
1327 args.push_back(t: join(ts: "max_anisotropy(", ts&: s.max_anisotropy, ts: ")"));
1328 if (s.lod_clamp_enable)
1329 {
1330 args.push_back(t: join(ts: "lod_clamp(", ts: format_float(value: s.lod_clamp_min), ts: ", ", ts: format_float(value: s.lod_clamp_max), ts: ")"));
1331 }
1332
1333 // If we would emit no arguments, then omit the parentheses entirely. Otherwise,
1334 // we'll wind up with a "most vexing parse" situation.
1335 if (args.empty())
1336 statement(ts: "constexpr sampler ",
1337 ts: type.basetype == SPIRType::SampledImage ? to_sampler_expression(id: samp.first) : to_name(id: samp.first),
1338 ts: ";");
1339 else
1340 statement(ts: "constexpr sampler ",
1341 ts: type.basetype == SPIRType::SampledImage ? to_sampler_expression(id: samp.first) : to_name(id: samp.first),
1342 ts: "(", ts: merge(list: args), ts: ");");
1343 }
1344
1345 // Emit dynamic buffers here.
1346 for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
1347 {
1348 if (!dynamic_buffer.second.second)
1349 {
1350 // Could happen if no buffer was used at requested binding point.
1351 continue;
1352 }
1353
1354 const auto &var = get<SPIRVariable>(id: dynamic_buffer.second.second);
1355 uint32_t var_id = var.self;
1356 const auto &type = get_variable_data_type(var);
1357 string name = to_name(id: var.self);
1358 uint32_t desc_set = get_decoration(id: var.self, decoration: DecorationDescriptorSet);
1359 uint32_t arg_id = argument_buffer_ids[desc_set];
1360 uint32_t base_index = dynamic_buffer.second.first;
1361
1362 if (is_array(type))
1363 {
1364 if (!type.array[type.array.size() - 1])
1365 SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
1366
1367 is_using_builtin_array = true;
1368 statement(ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ", ts: to_restrict(id: var_id, space: true), ts&: name,
1369 ts: type_to_array_glsl(type, variable_id: var_id), ts: " =");
1370
1371 uint32_t array_size = to_array_size_literal(type);
1372 begin_scope();
1373
1374 for (uint32_t i = 0; i < array_size; i++)
1375 {
1376 statement(ts: "(", ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ",
1377 ts: to_restrict(id: var_id, space: false), ts: ")((", ts: get_argument_address_space(argument: var), ts: " char* ",
1378 ts: to_restrict(id: var_id, space: false), ts: ")", ts: to_name(id: arg_id), ts: ".", ts: ensure_valid_name(name, pfx: "m"),
1379 ts: "[", ts&: i, ts: "]", ts: " + ", ts: to_name(id: dynamic_offsets_buffer_id), ts: "[", ts: base_index + i, ts: "]),");
1380 }
1381
1382 end_scope_decl();
1383 statement_no_indent(ts: "");
1384 is_using_builtin_array = false;
1385 }
1386 else
1387 {
1388 statement(ts: get_argument_address_space(argument: var), ts: " auto& ", ts: to_restrict(id: var_id, space: true), ts&: name, ts: " = *(",
1389 ts: get_argument_address_space(argument: var), ts: " ", ts: type_to_glsl(type), ts: "* ", ts: to_restrict(id: var_id, space: false), ts: ")((",
1390 ts: get_argument_address_space(argument: var), ts: " char* ", ts: to_restrict(id: var_id, space: false), ts: ")", ts: to_name(id: arg_id), ts: ".",
1391 ts: ensure_valid_name(name, pfx: "m"), ts: " + ", ts: to_name(id: dynamic_offsets_buffer_id), ts: "[", ts&: base_index, ts: "]);");
1392 }
1393 }
1394
	bool has_runtime_array_declaration = false;
	for (SPIRVariable *arg : entry_point_bindings)
	{
		const auto &var = *arg;
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		const string name = to_name(var.self);

		if (is_var_runtime_size_array(var))
		{
			if (msl_options.argument_buffers_tier < Options::ArgumentBuffersTier::Tier2)
			{
				SPIRV_CROSS_THROW("Unsized array of descriptors requires argument buffer tier 2");
			}

			string resource_name;
			if (descriptor_set_is_argument_buffer(get_decoration(var.self, DecorationDescriptorSet)))
				resource_name = ir.meta[var.self].decoration.qualified_alias;
			else
				resource_name = name + "_";

			switch (type.basetype)
			{
			case SPIRType::Image:
			case SPIRType::Sampler:
			case SPIRType::AccelerationStructure:
				statement("spvDescriptorArray<", type_to_glsl(buffer_type, var.self), "> ", name, " {", resource_name, "};");
				break;
			case SPIRType::SampledImage:
				statement("spvDescriptorArray<", type_to_glsl(buffer_type, var.self), "> ", name, " {", resource_name, "};");
				// Unsupported with argument buffer for now.
				statement("spvDescriptorArray<sampler> ", name, "Smplr {", name, "Smplr_};");
				break;
			case SPIRType::Struct:
				statement("spvDescriptorArray<", get_argument_address_space(var), " ", type_to_glsl(buffer_type), "*> ",
				          name, " {", resource_name, "};");
				break;
			default:
				break;
			}
			has_runtime_array_declaration = true;
		}
		else if (!type.array.empty() && type.basetype == SPIRType::Struct)
		{
			// Emit only buffer arrays here.
			statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ",
			          to_restrict(var.self, true), name, "[] =");
			begin_scope();
			uint32_t array_size = get_resource_array_size(type, var.self);
			for (uint32_t i = 0; i < array_size; ++i)
				statement(name, "_", i, ",");
			end_scope_decl();
			statement_no_indent("");
		}
	}

	if (has_runtime_array_declaration)
		statement_no_indent("");

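	// Discrete (non-argument-buffer) resources that alias the same binding are
	// declared once under a shared spvBufferAliasSetNBindingM name and then cast
	// back to each variable's own type, e.g. (illustrative):
	//   device auto& foo = *(device Foo*)spvBufferAliasSet0Binding1;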
	// Emit buffer aliases here.
	for (auto &var_id : buffer_aliases_discrete)
	{
		const auto &var = get<SPIRVariable>(var_id);
		const auto &type = get_variable_data_type(var);
		auto addr_space = get_argument_address_space(var);
		auto name = to_name(var_id);

		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
		uint32_t desc_binding = get_decoration(var_id, DecorationBinding);
		auto alias_name = join("spvBufferAliasSet", desc_set, "Binding", desc_binding);

		statement(addr_space, " auto& ", to_restrict(var_id, true),
		          name,
		          " = *(", addr_space, " ", type_to_glsl(type), "*)", alias_name, ";");
	}
	// Discrete descriptors are processed in entry point emission every compiler iteration.
	buffer_aliases_discrete.clear();

	for (auto &var_pair : buffer_aliases_argument)
	{
		uint32_t var_id = var_pair.first;
		uint32_t alias_id = var_pair.second;

		const auto &var = get<SPIRVariable>(var_id);
		const auto &type = get_variable_data_type(var);
		auto addr_space = get_argument_address_space(var);

		if (type.array.empty())
		{
			statement(addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ",
			          type_to_glsl(type), "&)", ir.meta[alias_id].decoration.qualified_alias, ";");
		}
		else
		{
			const char *desc_addr_space = descriptor_address_space(var_id, var.storage, "thread");

			// Esoteric type cast. Reference to array of pointers.
			// Auto here defers to UBO or SSBO. The address space of the reference needs to refer to the
			// address space of the argument buffer itself, which is usually constant, but can be const device for
			// large argument buffers.
			is_using_builtin_array = true;
			statement(desc_addr_space, " auto& ", to_restrict(var_id, true), to_name(var_id), " = (", addr_space, " ",
			          type_to_glsl(type), "* ", desc_addr_space, " (&)",
			          type_to_array_glsl(type, var_id), ")", ir.meta[alias_id].decoration.qualified_alias, ";");
			is_using_builtin_array = false;
		}
	}

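	// Disabled fragment outputs are still declared as ordinary local variables so
	// that stores to them remain valid; they are simply never wired up to an
	// attachment in the output struct.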
	// Emit disabled fragment outputs.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(CompilerGLSL::variable_decl(var), ";");
		var.deferred_declaration = false;
	}
}

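// Compile the SPIR-V IR to MSL source. Configures the underlying GLSL backend
// for Metal semantics, runs the analysis passes, builds the stage interface
// blocks, and then emits the shader in a loop until no further recompilation
// is forced.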
string CompilerMSL::compile()
{
	replace_illegal_entry_point_names();
	ir.fixup_reserved_names();

	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.boolean_mix_function = "select";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.allow_truncated_access_chain = true;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;
	backend.supports_empty_struct = true;
	backend.support_64bit_switch = true;
	backend.boolean_in_struct_remapped_type = SPIRType::Short;

	// Allow Metal to use the array<T> template unless we force it off.
	backend.can_return_array = !msl_options.force_native_arrays;
	backend.array_is_value_type = !msl_options.force_native_arrays;
	// Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
	backend.array_is_value_type_in_buffer_blocks = false;
	backend.support_pointer_to_pointer = true;
	backend.implicit_c_integer_promotion_rules = true;

	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than in the constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_anonymous_struct_names();
	fixup_type_alias();
	replace_illegal_names();
	sync_entry_point_aliases_and_names();

	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	analyze_interlocked_resource_usage();
	preprocess_op_codes();
	build_implicit_builtins();

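	// When HelperInvocation must be tracked manually, discarding also has to flip
	// the local gl_HelperInvocation copy so that later reads observe the demotion;
	// the discard literal therefore becomes a comma expression such as
	// "gl_HelperInvocation = true, discard_fragment()".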
	if (needs_manual_helper_invocation_updates() &&
	    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
	{
		string builtin_helper_invocation = builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput);
		string discard_expr = join(builtin_helper_invocation, " = true, discard_fragment()");
		if (msl_options.force_fragment_with_side_effects_execution)
			discard_expr = join("!", builtin_helper_invocation, " ? (", discard_expr, ") : (void)0");
		backend.discard_literal = discard_expr;
		backend.demote_literal = discard_expr;
	}
	else
	{
		backend.discard_literal = "discard_fragment()";
		backend.demote_literal = "discard_fragment()";
	}

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (msl_options.force_active_argument_buffer_resources)
		activate_argument_buffer_resources();

	if (swizzle_buffer_id)
		add_active_interface_variable(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		add_active_interface_variable(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		add_active_interface_variable(view_mask_buffer_id);
	if (dynamic_offsets_buffer_id)
		add_active_interface_variable(dynamic_offsets_buffer_id);
	if (builtin_layer_id)
		add_active_interface_variable(builtin_layer_id);
	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
		add_active_interface_variable(builtin_dispatch_base_id);
	if (builtin_sample_mask_id)
		add_active_interface_variable(builtin_sample_mask_id);
	if (builtin_frag_depth_id)
		add_active_interface_variable(builtin_frag_depth_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure out. is declared at top of entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (is_tese_shader())
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	if (is_tesc_shader())
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

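	// Emission may discover that more implicit resources or renames are needed,
	// which forces another full pass from a clean slate; loop until the output
	// stabilizes.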
	uint32_t pass_count = 0;
	do
	{
		reset(pass_count);

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_custom_templates();
		emit_custom_functions();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}

// Register the need to output any custom functions.
void CompilerMSL::preprocess_op_codes()
{
	OpCodePreprocessor preproc(*this);
	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);

	suppress_missing_prototypes = preproc.suppress_missing_prototypes;

	if (preproc.uses_atomics)
	{
		add_header_line("#include <metal_atomic>");
		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
	}

	// Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
	// resources must disable rasterization and return void.
	if ((preproc.uses_buffer_write && !msl_options.supports_msl_version(2, 1)) ||
	    (preproc.uses_image_write && !msl_options.supports_msl_version(2, 2)))
		is_rasterization_disabled = true;

	// Tessellation control shaders are run as compute functions in Metal, and so
	// must capture their output to a buffer.
	if (is_tesc_shader() || (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
	{
		is_rasterization_disabled = true;
		capture_output_to_buffer = true;
	}

	if (preproc.needs_subgroup_invocation_id)
		needs_subgroup_invocation_id = true;
	if (preproc.needs_subgroup_size)
		needs_subgroup_size = true;
	// build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
	// before then so that gl_SampleID will get added; so we also need to check if
	// that function would add gl_FragCoord.
	if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
	                          (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
		needs_sample_id = true;
	if (preproc.needs_helper_invocation)
		needs_helper_invocation = true;

	// OpKill is removed by the parser, so we need to identify those by inspecting
	// blocks.
	ir.for_each_typed_id<SPIRBlock>([&preproc](uint32_t, SPIRBlock &block) {
		if (block.terminator == SPIRBlock::Kill)
			preproc.uses_discard = true;
	});

	// Fragment shaders that both write to storage resources and discard fragments
	// need checks on the writes, to work around Metal allowing these writes despite
	// the fragment being dead. We may also need to force Metal to execute fragment
	// shaders instead of discarding them prematurely.
	if (preproc.uses_discard && (preproc.uses_buffer_write || preproc.uses_image_write))
	{
		bool should_enable = (msl_options.check_discarded_frag_stores || msl_options.force_fragment_with_side_effects_execution);
		frag_shader_needs_discard_checks |= msl_options.check_discarded_frag_stores;
		needs_helper_invocation |= should_enable;
		// Fragment discard store checks imply manual HelperInvocation updates.
		msl_options.manual_helper_invocation_updates |= should_enable;
	}

	if (is_intersection_query())
	{
		add_header_line("#if __METAL_VERSION__ >= 230");
		add_header_line("#include <metal_raytracing>");
		add_header_line("using namespace metal::raytracing;");
		add_header_line("#endif");
	}
}

// Move the Private and Workgroup global variables to the entry function.
// Non-constant variables cannot have global scope in Metal.
void CompilerMSL::localize_global_variables()
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto iter = global_variables.begin();
	while (iter != global_variables.end())
	{
		uint32_t v_id = *iter;
		auto &var = get<SPIRVariable>(v_id);
		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
		{
			if (!variable_is_lut(var))
				entry_func.add_local_variable(v_id);
			iter = global_variables.erase(iter);
		}
		else
			iter++;
	}
}

// For any global variable accessed directly by a function,
// extract that variable and add it as an argument to that function.
void CompilerMSL::extract_global_variables_from_functions()
{
	// Uniforms
	unordered_set<uint32_t> global_var_ids;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		// Some builtins resolve directly to a function call which does not need any declared variables.
		// Skip these.
		if (var.storage == StorageClassInput && has_decoration(var.self, DecorationBuiltIn))
		{
			auto bi_type = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
			if (bi_type == BuiltInHelperInvocation && !needs_manual_helper_invocation_updates())
				return;
			if (bi_type == BuiltInHelperInvocation && needs_manual_helper_invocation_updates())
			{
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
				// Make sure this is declared and initialized.
				// Force this to have the proper name.
				set_name(var.self, builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput));
				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
				entry_func.add_local_variable(var.self);
				vars_needing_early_declaration.push_back(var.self);
				entry_func.fixup_hooks_in.push_back([this, &var]()
				                                    { statement(to_name(var.self), " = simd_is_helper_thread();"); });
			}
		}

		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
		    var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		    var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
		{
			global_var_ids.insert(var.self);
		}
	});

	// Local vars that are declared in the main function and accessed directly by a function
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	for (auto &var : entry_func.local_variables)
		if (get<SPIRVariable>(var).storage != StorageClassFunction)
			global_var_ids.insert(var);

	std::set<uint32_t> added_arg_ids;
	unordered_set<uint32_t> processed_func_ids;
	extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
}

// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and add the arg to the added_arg_ids collection.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return function global variables
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				// Use Metal's native frame-buffer fetch API for subpass inputs.
				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
				    (!msl_options.use_framebuffer_fetch_subpasses))
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
					if (msl_options.multiview)
					{
						// Implicitly reads gl_ViewIndex.
						assert(builtin_view_idx_id != 0);
						added_arg_ids.insert(builtin_view_idx_id);
					}
					else if (msl_options.arrayed_subpass_input)
					{
						// Implicitly reads gl_Layer.
						assert(builtin_layer_id != 0);
						added_arg_ids.insert(builtin_layer_id);
					}
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
				{
					added_arg_ids.insert(base_id);

					if (msl_options.input_attachment_is_ds_attachment && base_id == builtin_frag_depth_id)
						writes_to_depth = true;
				}

				uint32_t rvalue_id = ops[1];
				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
					added_arg_ids.insert(rvalue_id);

				if (needs_frag_discard_checks())
					added_arg_ids.insert(builtin_helper_invocation_id);

				break;
			}

			case OpSelect:
			{
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			case OpAtomicExchange:
			case OpAtomicCompareExchange:
			case OpAtomicStore:
			case OpAtomicIIncrement:
			case OpAtomicIDecrement:
			case OpAtomicIAdd:
			case OpAtomicFAddEXT:
			case OpAtomicISub:
			case OpAtomicSMin:
			case OpAtomicUMin:
			case OpAtomicSMax:
			case OpAtomicUMax:
			case OpAtomicAnd:
			case OpAtomicOr:
			case OpAtomicXor:
			case OpImageWrite:
			{
				if (needs_frag_discard_checks())
					added_arg_ids.insert(builtin_helper_invocation_id);
				uint32_t ptr = 0;
				if (op == OpAtomicStore || op == OpImageWrite)
					ptr = ops[0];
				else
					ptr = ops[2];
				if (global_var_ids.find(ptr) != global_var_ids.end())
					added_arg_ids.insert(ptr);
				break;
			}

			// Emulate texture2D atomic operations
			case OpImageTexelPointer:
			{
				// When using the pointer, we need to know which variable it is actually loaded from.
				uint32_t base_id = ops[2];
				auto *var = maybe_get_backing_variable(base_id);
				if (var)
				{
					if (atomic_image_vars_emulated.count(var->self) &&
					    !get<SPIRType>(var->basetype).array.empty())
					{
						SPIRV_CROSS_THROW(
						    "Cannot emulate array of storage images with atomics. Use MSL 3.1 for native support.");
					}

					if (global_var_ids.find(base_id) != global_var_ids.end())
						added_arg_ids.insert(base_id);
				}
				break;
			}

			case OpExtInst:
			{
				uint32_t extension_set = ops[2];
				if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
				{
					auto op_450 = static_cast<GLSLstd450>(ops[3]);
					switch (op_450)
					{
					case GLSLstd450InterpolateAtCentroid:
					case GLSLstd450InterpolateAtSample:
					case GLSLstd450InterpolateAtOffset:
					{
						// For these, we really need the stage-in block. It is theoretically possible to pass the
						// interpolant object, but a) doing so would require us to create an entirely new variable
						// with Interpolant type, and b) if we have a struct or array, handling all the members and
						// elements could get unwieldy fast.
						added_arg_ids.insert(stage_in_var_id);
						break;
					}

					case GLSLstd450Modf:
					case GLSLstd450Frexp:
					{
						uint32_t base_id = ops[5];
						if (global_var_ids.find(base_id) != global_var_ids.end())
							added_arg_ids.insert(base_id);
						break;
					}

					default:
						break;
					}
				}
				break;
			}

			case OpGroupNonUniformInverseBallot:
			{
				added_arg_ids.insert(builtin_subgroup_invocation_id_id);
				break;
			}

			case OpGroupNonUniformBallotFindLSB:
			case OpGroupNonUniformBallotFindMSB:
			{
				added_arg_ids.insert(builtin_subgroup_size_id);
				break;
			}

			case OpGroupNonUniformBallotBitCount:
			{
				auto operation = static_cast<GroupOperation>(ops[3]);
				switch (operation)
				{
				case GroupOperationReduce:
					added_arg_ids.insert(builtin_subgroup_size_id);
					break;
				case GroupOperationInclusiveScan:
				case GroupOperationExclusiveScan:
					added_arg_ids.insert(builtin_subgroup_invocation_id_id);
					break;
				default:
					break;
				}
				break;
			}

			case OpDemoteToHelperInvocation:
				if (needs_manual_helper_invocation_updates() &&
				    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
					added_arg_ids.insert(builtin_helper_invocation_id);
				break;

			case OpIsHelperInvocationEXT:
				if (needs_manual_helper_invocation_updates())
					added_arg_ids.insert(builtin_helper_invocation_id);
				break;

			case OpRayQueryInitializeKHR:
			case OpRayQueryProceedKHR:
			case OpRayQueryTerminateKHR:
			case OpRayQueryGenerateIntersectionKHR:
			case OpRayQueryConfirmIntersectionKHR:
			{
				// Ray queries access memory directly; we may need to pass the object down if it uses the Private storage class.
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			case OpRayQueryGetRayTMinKHR:
			case OpRayQueryGetRayFlagsKHR:
			case OpRayQueryGetWorldRayOriginKHR:
			case OpRayQueryGetWorldRayDirectionKHR:
			case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
			case OpRayQueryGetIntersectionTypeKHR:
			case OpRayQueryGetIntersectionTKHR:
			case OpRayQueryGetIntersectionInstanceCustomIndexKHR:
			case OpRayQueryGetIntersectionInstanceIdKHR:
			case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
			case OpRayQueryGetIntersectionGeometryIndexKHR:
			case OpRayQueryGetIntersectionPrimitiveIndexKHR:
			case OpRayQueryGetIntersectionBarycentricsKHR:
			case OpRayQueryGetIntersectionFrontFaceKHR:
			case OpRayQueryGetIntersectionObjectRayDirectionKHR:
			case OpRayQueryGetIntersectionObjectRayOriginKHR:
			case OpRayQueryGetIntersectionObjectToWorldKHR:
			case OpRayQueryGetIntersectionWorldToObjectKHR:
			{
				// Ray queries access memory directly; we may need to pass the object down if it uses the Private storage class.
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			default:
				break;
			}

			if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill &&
			    (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
				added_arg_ids.insert(builtin_helper_invocation_id);

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	if (func_id != ir.default_entry_point)
	{
		bool control_point_added_in = false;
		bool control_point_added_out = false;
		bool patch_added_in = false;
		bool patch_added_out = false;

		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

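			// Decide whether this argument should be replaced by the aggregated
			// stage-in/stage-out array (gl_in/gl_out or the patch block) rather
			// than being passed through as its own parameter.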
			bool is_patch = has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type);
			bool is_block = has_decoration(p_type->self, DecorationBlock);
			bool is_control_point_storage =
			    !is_patch && ((is_tessellation_shader() && var.storage == StorageClassInput) ||
			                  (is_tesc_shader() && var.storage == StorageClassOutput));
			bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput;
			bool is_builtin = is_builtin_variable(var);
			bool variable_is_stage_io =
			    !is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			    bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
			    p_type->basetype == SPIRType::Struct;
			bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) &&
			                                        variable_is_stage_io;

			// If output is masked it is not considered part of the global stage IO interface.
			if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput)
				is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var);

			if (is_redirected_to_global_stage_io)
			{
				// Tessellation control shaders see inputs and per-point outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-point inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				std::string name;
				if (is_patch)
					name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name;
				else
					name = var.storage == StorageClassInput ? "gl_in" : "gl_out";

				if (var.storage == StorageClassOutput && has_decoration(p_type->self, DecorationBlock))
				{
					// If we're redirecting a block, we might still need to access the original block
					// variable if we're masking some members.
					for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++)
					{
						if (is_stage_output_block_member_masked(var, mbr_idx, true))
						{
							func.add_parameter(var.basetype, var.self, true);
							break;
						}
					}
				}

				if (var.storage == StorageClassInput)
				{
					auto &added_in = is_patch ? patch_added_in : control_point_added_in;
					if (added_in)
						continue;
					arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					auto &added_out = is_patch ? patch_added_out : control_point_added_out;
					if (added_out)
						continue;
					arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id;
					added_out = true;
				}

				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
				if (is_tese_shader() && msl_options.raw_buffer_tese_input && var.storage == StorageClassInput)
					set_decoration(next_id, DecorationNonWritable);
			}
			else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
			{
				// Get the pointee type
				type_id = get_pointee_type_id(type_id);
				p_type = &get<SPIRType>(type_id);

				uint32_t mbr_idx = 0;
				for (auto &mbr_type_id : p_type->member_types)
				{
					BuiltIn builtin = BuiltInMax;
					is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
					if (is_builtin && has_active_builtin(builtin, var.storage))
					{
						// Add an arg variable with the same type and decorations as the member
						uint32_t next_ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = next_ids + 0;
						uint32_t var_id = next_ids + 1;

						// Make sure we have an actual pointer type,
						// so that we will get the appropriate address space when declaring these builtins.
						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
						ptr.self = mbr_type_id;
						ptr.storage = var.storage;
						ptr.pointer = true;
						ptr.pointer_depth++;
						ptr.parent_type = mbr_type_id;

						func.add_parameter(mbr_type_id, var_id, true);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
					}
					mbr_idx++;
				}
			}
			else
			{
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				// Ensure the new variable has all the same meta info
				ir.meta[next_id] = ir.meta[arg_id];
			}
		}
	}
}

// For all variables that are some form of non-input-output interface block, mark that all the structs
// that are recursively contained within the type referenced by that variable should be packed tightly.
void CompilerMSL::mark_packable_structs()
{
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
		{
			auto &type = this->get<SPIRType>(var.basetype);
			if (type.pointer &&
			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
				mark_as_packable(type);
		}

		if (var.storage == StorageClassWorkgroup)
		{
			auto *type = &this->get<SPIRType>(var.basetype);
			if (type->basetype == SPIRType::Struct)
				mark_as_workgroup_struct(*type);
		}
	});

	// Physical storage buffer pointers can appear outside of the context of a variable, if the address
	// is calculated from a ulong or uvec2 and cast to a pointer, so check if they need to be packed too.
	ir.for_each_typed_id<SPIRType>([&](uint32_t, SPIRType &type) {
		if (type.basetype == SPIRType::Struct && type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
			mark_as_packable(type);
	});
}

// If the specified type is a struct, it and any nested structs
// are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
void CompilerMSL::mark_as_packable(SPIRType &type)
{
	// If this is not the base type (e.g. it's a pointer or array), tunnel down
	if (type.parent_type)
	{
		mark_as_packable(get<SPIRType>(type.parent_type));
		return;
	}

	// Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
	if (type.basetype == SPIRType::Struct && !has_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked))
	{
		set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);

		// Recurse
		uint32_t mbr_cnt = uint32_t(type.member_types.size());
		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
		{
			uint32_t mbr_type_id = type.member_types[mbr_idx];
			auto &mbr_type = get<SPIRType>(mbr_type_id);
			mark_as_packable(mbr_type);
			if (mbr_type.type_alias)
			{
				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
				mark_as_packable(mbr_type_alias);
			}
		}
	}
}

// If the specified type is a struct, it and any nested structs
// are marked as used with workgroup storage using the SPIRVCrossDecorationWorkgroupStruct decoration.
void CompilerMSL::mark_as_workgroup_struct(SPIRType &type)
{
	// If this is not the base type (e.g. it's a pointer or array), tunnel down
	if (type.parent_type)
	{
		mark_as_workgroup_struct(get<SPIRType>(type.parent_type));
		return;
	}

	// Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
	if (type.basetype == SPIRType::Struct && !has_extended_decoration(type.self, SPIRVCrossDecorationWorkgroupStruct))
	{
		set_extended_decoration(type.self, SPIRVCrossDecorationWorkgroupStruct);

		// Recurse
		uint32_t mbr_cnt = uint32_t(type.member_types.size());
		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
		{
			uint32_t mbr_type_id = type.member_types[mbr_idx];
			auto &mbr_type = get<SPIRType>(mbr_type_id);
			mark_as_workgroup_struct(mbr_type);
			if (mbr_type.type_alias)
			{
				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
				mark_as_workgroup_struct(mbr_type_alias);
			}
		}
	}
}

// If a shader input exists at the location, it is marked as being used by this shader
void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type,
                                                  StorageClass storage, bool fallback)
{
	uint32_t count = type_to_location_count(type);
	switch (storage)
	{
	case StorageClassInput:
		for (uint32_t i = 0; i < count; i++)
		{
			location_inputs_in_use.insert(location + i);
			if (fallback)
				location_inputs_in_use_fallback.insert(location + i);
		}
		break;
	case StorageClassOutput:
		for (uint32_t i = 0; i < count; i++)
		{
			location_outputs_in_use.insert(location + i);
			if (fallback)
				location_outputs_in_use_fallback.insert(location + i);
		}
		break;
	default:
		return;
	}
}
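// Returns the number of components the client expects at the given fragment
// output location; defaults to 4 components when no explicit count was set.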
uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
{
	auto itr = fragment_output_components.find(location);
	if (itr == end(fragment_output_components))
		return 4;
	else
		return itr->second;
}

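// Builds (and registers in the IR) a copy of the given type widened to the
// requested number of vector components, rebuilding any array and pointer
// types that wrapped the original scalar or vector.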
uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
{
	assert(components > 1);
	uint32_t new_type_id = ir.increase_bound_by(1);
	const auto *p_old_type = &get<SPIRType>(type_id);
	const SPIRType *old_ptr_t = nullptr;
	const SPIRType *old_array_t = nullptr;

	if (is_pointer(*p_old_type))
	{
		old_ptr_t = p_old_type;
		p_old_type = &get_pointee_type(*old_ptr_t);
	}

	if (is_array(*p_old_type))
	{
		old_array_t = p_old_type;
		p_old_type = &get_type(old_array_t->parent_type);
	}

	auto *type = &set<SPIRType>(new_type_id, *p_old_type);
	assert(is_scalar(*type) || is_vector(*type));
	type->op = OpTypeVector;
	type->vecsize = components;
	if (basetype != SPIRType::Unknown)
		type->basetype = basetype;
	type->self = new_type_id;
	// We want parent type to point to the scalar type.
	type->parent_type = is_scalar(*p_old_type) ? TypeID(p_old_type->self) : p_old_type->parent_type;
	assert(is_scalar(get<SPIRType>(type->parent_type)));
	type->array.clear();
	type->array_size_literal.clear();
	type->pointer = false;

	if (old_array_t)
	{
		uint32_t array_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(array_type_id, *type);
		type->op = OpTypeArray;
		type->parent_type = new_type_id;
		type->array = old_array_t->array;
		type->array_size_literal = old_array_t->array_size_literal;
		new_type_id = array_type_id;
	}

	if (old_ptr_t)
	{
		uint32_t ptr_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(ptr_type_id, *type);
		type->op = OpTypePointer;
		type->parent_type = new_type_id;
		type->storage = old_ptr_t->storage;
		type->pointer = true;
		type->pointer_depth++;
		new_type_id = ptr_type_id;
	}

	return new_type_id;
}

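// Builds a wrapper type with basetype Interpolant around the given type; such
// members are declared with Metal's pull-model interpolant<T, P> template.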
uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
{
	uint32_t new_type_id = ir.increase_bound_by(1);
	SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
	type.basetype = SPIRType::Interpolant;
	type.parent_type = type_id;
	// In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
	// Add this decoration so we know which argument to pass to the template.
	if (is_noperspective)
		set_decoration(new_type_id, DecorationNoPerspective);
	return new_type_id;
}

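// Handles variables that share an interface location with other variables via
// the Component decoration. The shared location is declared in the IO block as
// a single m_location_N member; here we only declare the variable locally and
// emit fixup assignments that swizzle the right components in or out.
// Returns true if the variable was handled this way.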
bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref,
                                                            SPIRVariable &var,
                                                            const SPIRType &type,
                                                            InterfaceBlockMeta &meta)
{
	// Deal with Component decorations.
	const InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
	uint32_t location = ~0u;
	if (has_decoration(var.self, DecorationLocation))
	{
		location = get_decoration(var.self, DecorationLocation);
		auto location_meta_itr = meta.location_meta.find(location);
		if (location_meta_itr != end(meta.location_meta))
			location_meta = &location_meta_itr->second;
	}

	// Check if we need to pad fragment output to match a certain number of components.
	if (location_meta)
	{
		bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
		                           msl_options.pad_fragment_output_components &&
		                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;

		auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
		uint32_t start_component = get_decoration(var.self, DecorationComponent);
		uint32_t type_components = type.vecsize;
		uint32_t num_components = location_meta->num_components;

		if (pad_fragment_output)
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation);
			num_components = max<uint32_t>(num_components, get_target_components_for_fragment_location(locn));
		}

		// We have already declared an IO block member as m_location_N.
		// Just emit an early-declared variable and fixup as needed.
		// Arrays need to be unrolled here since each location might need a different number of components.
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		if (var.storage == StorageClassInput)
		{
			entry_func.fixup_hooks_in.push_back([=, &type, &var]() {
				if (!type.array.empty())
				{
					uint32_t array_size = to_array_size_literal(type);
					for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
					{
						statement(to_name(var.self), "[", loc_off, "]", " = ", ib_var_ref,
						          ".m_location_", location + loc_off,
						          vector_swizzle(type_components, start_component), ";");
					}
				}
				else
				{
					statement(to_name(var.self), " = ", ib_var_ref, ".m_location_", location,
					          vector_swizzle(type_components, start_component), ";");
				}
			});
		}
		else
		{
			entry_func.fixup_hooks_out.push_back([=, &type, &var]() {
				if (!type.array.empty())
				{
					uint32_t array_size = to_array_size_literal(type);
					for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
					{
						statement(ib_var_ref, ".m_location_", location + loc_off,
						          vector_swizzle(type_components, start_component), " = ",
						          to_name(var.self), "[", loc_off, "];");
					}
				}
				else
				{
					statement(ib_var_ref, ".m_location_", location,
					          vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
				}
			});
		}
		return true;
	}
	else
		return false;
}

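// Adds a scalar or vector variable to the stage interface block: the type is
// appended as a new member, the variable's qualified alias is redirected to
// that member, and location/component/builtin/interpolation decorations are
// copied over. Padded fragment outputs and pull-model inputs get fixup hooks
// instead of a plain alias.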
void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                        SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
{
	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	// Add a reference to the variable type to the interface struct.
	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
	var.basetype = type_id;

	type_id = get_pointee_type_id(var.basetype);
	if (meta.strip_array && is_array(get<SPIRType>(type_id)))
		type_id = get<SPIRType>(type_id).parent_type;
	auto &type = get<SPIRType>(type_id);
	uint32_t target_components = 0;
	uint32_t type_components = type.vecsize;

	bool padded_output = false;
	bool padded_input = false;
	uint32_t start_component = 0;

	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

	if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta))
		return;

	bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
	                           msl_options.pad_fragment_output_components &&
	                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;

	if (pad_fragment_output)
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		target_components = get_target_components_for_fragment_location(locn);
		if (type_components < target_components)
		{
			// Make a new type here.
			type_id = build_extended_vector_type(type_id, target_components);
			padded_output = true;
		}
	}

	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
		ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
	else
		ib_type.member_types.push_back(type_id);

	// Give the member a name
	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

	// Update the original variable reference to include the structure reference
	string qual_var_name = ib_var_ref + "." + mbr_name;
	// If using pull-model interpolation, need to add a call to the correct interpolation method.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
	{
		if (is_centroid)
			qual_var_name += ".interpolate_at_centroid()";
		else if (is_sample)
			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
		else
			qual_var_name += ".interpolate_at_center()";
	}

	if (padded_output || padded_input)
	{
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		if (padded_output)
		{
			entry_func.fixup_hooks_out.push_back([=, &var]() {
				statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
				          ";");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
				          ";");
			});
		}
	}
	else if (!meta.strip_array)
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

	if (var.storage == StorageClassOutput && var.initializer != ID(0))
	{
		if (padded_output || padded_input)
		{
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
		}
		else
		{
			if (meta.strip_array)
			{
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
					auto invocation = to_tesc_invocation_id();
					statement(to_expression(stage_out_ptr_var_id), "[",
					          invocation, "].",
					          to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
					          invocation, "];");
				});
			}
			else
			{
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(qual_var_name, " = ", to_expression(var.initializer), ";");
				});
			}
		}
	}

	// Copy the variable location from the original variable to the member
	if (get_decoration_bitset(var.self).get(DecorationLocation))
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		uint32_t comp = get_decoration(var.self, DecorationComponent);
		if (storage == StorageClassInput)
		{
			type_id = ensure_correct_input_type(var.basetype, locn, comp, 0, meta.strip_array);
			var.basetype = type_id;

			type_id = get_pointee_type_id(type_id);
			if (meta.strip_array && is_array(get<SPIRType>(type_id)))
				type_id = get<SPIRType>(type_id).parent_type;
			if (pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		if (comp)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
		mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
	}
	else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin))
	{
		uint32_t locn = inputs_by_builtin[builtin].location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, type, storage);
	}
	else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(builtin))
	{
		uint32_t locn = outputs_by_builtin[builtin].location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, type, storage);
	}

	if (get_decoration_bitset(var.self).get(DecorationComponent))
	{
		uint32_t component = get_decoration(var.self, DecorationComponent);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
	}

	if (get_decoration_bitset(var.self).get(DecorationIndex))
	{
		uint32_t index = get_decoration(var.self, DecorationIndex);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
	}

	// Mark the member as builtin if needed
	if (is_builtin)
	{
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
		if (builtin == BuiltInPosition && storage == StorageClassOutput)
			qual_pos_var_name = qual_var_name;
	}

	// Copy interpolation decorations if needed
	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
	{
		if (is_flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (is_noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (is_centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (is_sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
}

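// Like add_plain_variable_to_interface_block, but for matrices and arrays:
// the composite is flattened into one interface-block member per column or
// element (named <var>_<i>), with per-element fixup copies between the
// flattened members and the local composite variable.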
2742void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2743 SPIRType &ib_type, SPIRVariable &var,
2744 InterfaceBlockMeta &meta)
2745{
2746 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
2747 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2748 uint32_t elem_cnt = 0;
2749
2750 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type: var_type, meta))
2751 return;
2752
2753 if (is_matrix(type: var_type))
2754 {
2755 if (is_array(type: var_type))
2756 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2757
2758 elem_cnt = var_type.columns;
2759 }
2760 else if (is_array(type: var_type))
2761 {
2762 if (var_type.array.size() != 1)
2763 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2764
2765 elem_cnt = to_array_size_literal(type: var_type);
2766 }
2767
2768 bool is_builtin = is_builtin_variable(var);
2769 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
2770 bool is_flat = has_decoration(id: var.self, decoration: DecorationFlat);
2771 bool is_noperspective = has_decoration(id: var.self, decoration: DecorationNoPerspective);
2772 bool is_centroid = has_decoration(id: var.self, decoration: DecorationCentroid);
2773 bool is_sample = has_decoration(id: var.self, decoration: DecorationSample);
2774
2775 auto *usable_type = &var_type;
2776 if (usable_type->pointer)
2777 usable_type = &get<SPIRType>(id: usable_type->parent_type);
2778 while (is_array(type: *usable_type) || is_matrix(type: *usable_type))
2779 usable_type = &get<SPIRType>(id: usable_type->parent_type);
2780
2781 // If a builtin, force it to have the proper name.
2782 if (is_builtin)
2783 set_name(id: var.self, name: builtin_to_glsl(builtin, storage: StorageClassFunction));
2784
2785 bool flatten_from_ib_var = false;
2786 string flatten_from_ib_mbr_name;
2787
2788 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2789 {
2790 // Also declare [[clip_distance]] attribute here.
2791 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2792 ib_type.member_types.push_back(t: get_variable_data_type_id(var));
2793 set_member_decoration(id: ib_type.self, index: clip_array_mbr_idx, decoration: DecorationBuiltIn, argument: BuiltInClipDistance);
2794
2795 flatten_from_ib_mbr_name = builtin_to_glsl(builtin: BuiltInClipDistance, storage: StorageClassOutput);
2796 set_member_name(id: ib_type.self, index: clip_array_mbr_idx, name: flatten_from_ib_mbr_name);
2797
2798 // When we flatten, we flatten directly from the "out" struct,
2799 // not from a function variable.
2800 flatten_from_ib_var = true;
2801
2802 if (!msl_options.enable_clip_distance_user_varying)
2803 return;
2804 }
2805 else if (!meta.strip_array)
2806 {
2807 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2808 entry_func.add_local_variable(id: var.self);
2809 // We need to declare the variable early and at entry-point scope.
2810 vars_needing_early_declaration.push_back(t: var.self);
2811 }
2812
2813 for (uint32_t i = 0; i < elem_cnt; i++)
2814 {
2815 // Add a reference to the variable type to the interface struct.
2816 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2817
2818 uint32_t target_components = 0;
2819 bool padded_output = false;
2820 uint32_t type_id = usable_type->self;
2821
2822 // Check if we need to pad fragment output to match a certain number of components.
2823 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation) && msl_options.pad_fragment_output_components &&
2824 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
2825 {
2826 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation) + i;
2827 target_components = get_target_components_for_fragment_location(location: locn);
2828 if (usable_type->vecsize < target_components)
2829 {
// Make a new, wider vector type padded to the target component count.
2831 type_id = build_extended_vector_type(type_id: usable_type->self, components: target_components);
2832 padded_output = true;
2833 }
2834 }
2835
2836 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2837 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: get_pointee_type_id(type_id), is_noperspective));
2838 else
2839 ib_type.member_types.push_back(t: get_pointee_type_id(type_id));
2840
2841 // Give the member a name
2842 string mbr_name = ensure_valid_name(name: join(ts: to_expression(id: var.self), ts: "_", ts&: i), pfx: "m");
2843 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
2844
2845 // There is no qualified alias since we need to flatten the internal array on return.
2846 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation))
2847 {
2848 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation) + i;
2849 uint32_t comp = get_decoration(id: var.self, decoration: DecorationComponent);
2850 if (storage == StorageClassInput)
2851 {
2852 var.basetype = ensure_correct_input_type(type_id: var.basetype, location: locn, component: comp, num_components: 0, strip_array: meta.strip_array);
2853 uint32_t mbr_type_id = ensure_correct_input_type(type_id: usable_type->self, location: locn, component: comp, num_components: 0, strip_array: meta.strip_array);
2854 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
2855 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
2856 else
2857 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2858 }
2859 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2860 if (comp)
2861 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: comp);
2862 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2863 }
2864 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
2865 {
2866 uint32_t locn = inputs_by_builtin[builtin].location + i;
2867 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2868 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2869 }
2870 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
2871 {
2872 uint32_t locn = outputs_by_builtin[builtin].location + i;
2873 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
2874 mark_location_as_used_by_shader(location: locn, type: *usable_type, storage);
2875 }
2876 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2877 {
2878 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
2879 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
2880 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: i);
2881 }
2882
2883 if (get_decoration_bitset(id: var.self).get(bit: DecorationIndex))
2884 {
2885 uint32_t index = get_decoration(id: var.self, decoration: DecorationIndex);
2886 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: index);
2887 }
2888
2889 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
2890 {
2891 // Copy interpolation decorations if needed
2892 if (is_flat)
2893 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
2894 if (is_noperspective)
2895 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
2896 if (is_centroid)
2897 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
2898 if (is_sample)
2899 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
2900 }
2901
2902 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
2903
2904 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2905 if (!meta.strip_array)
2906 {
2907 switch (storage)
2908 {
2909 case StorageClassInput:
2910 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
2911 if (pull_model_inputs.count(x: var.self))
2912 {
2913 string lerp_call;
2914 if (is_centroid)
2915 lerp_call = ".interpolate_at_centroid()";
2916 else if (is_sample)
2917 lerp_call = join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
2918 else
2919 lerp_call = ".interpolate_at_center()";
2920 statement(ts: to_name(id: var.self), ts: "[", ts: i, ts: "] = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts&: lerp_call, ts: ";");
2921 }
2922 else
2923 {
2924 statement(ts: to_name(id: var.self), ts: "[", ts: i, ts: "] = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts: ";");
2925 }
2926 });
2927 break;
2928
2929 case StorageClassOutput:
2930 entry_func.fixup_hooks_out.push_back(t: [=, &var]() {
2931 if (padded_output)
2932 {
2933 auto &padded_type = this->get<SPIRType>(id: type_id);
2934 statement(
2935 ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ",
2936 ts: remap_swizzle(result_type: padded_type, input_components: usable_type->vecsize, expr: join(ts: to_name(id: var.self), ts: "[", ts: i, ts: "]")),
2937 ts: ";");
2938 }
2939 else if (flatten_from_ib_var)
2940 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: ib_var_ref, ts: ".", ts: flatten_from_ib_mbr_name, ts: "[", ts: i,
2941 ts: "];");
2942 else
2943 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: to_name(id: var.self), ts: "[", ts: i, ts: "];");
2944 });
2945 break;
2946
2947 default:
2948 break;
2949 }
2950 }
2951 }
2952}
2953
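// Flattens one composite member (matrix, array, or nested struct) of an I/O block into
// the interface struct, recursing into nested structs and assigning consecutive
// locations to the flattened members as it goes.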
2954void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage,
2955 const string &ib_var_ref, SPIRType &ib_type,
2956 SPIRVariable &var, SPIRType &var_type,
2957 uint32_t mbr_idx, InterfaceBlockMeta &meta,
2958 const string &mbr_name_qual,
2959 const string &var_chain_qual,
2960 uint32_t &location, uint32_t &var_mbr_idx,
2961 const Bitset &interpolation_qual)
2962{
2963 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
2964
2965 BuiltIn builtin = BuiltInMax;
2966 bool is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
2967 bool is_flat = interpolation_qual.get(bit: DecorationFlat) ||
2968 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationFlat) ||
2969 has_decoration(id: var.self, decoration: DecorationFlat);
2970 bool is_noperspective = interpolation_qual.get(bit: DecorationNoPerspective) ||
2971 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationNoPerspective) ||
2972 has_decoration(id: var.self, decoration: DecorationNoPerspective);
2973 bool is_centroid = interpolation_qual.get(bit: DecorationCentroid) ||
2974 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationCentroid) ||
2975 has_decoration(id: var.self, decoration: DecorationCentroid);
2976 bool is_sample = interpolation_qual.get(bit: DecorationSample) ||
2977 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationSample) ||
2978 has_decoration(id: var.self, decoration: DecorationSample);
2979
2980 Bitset inherited_qual;
2981 if (is_flat)
2982 inherited_qual.set(DecorationFlat);
2983 if (is_noperspective)
2984 inherited_qual.set(DecorationNoPerspective);
2985 if (is_centroid)
2986 inherited_qual.set(DecorationCentroid);
2987 if (is_sample)
2988 inherited_qual.set(DecorationSample);
2989
2990 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2991 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
2992
2993 bool mbr_is_indexable = false;
2994 uint32_t elem_cnt = 1;
2995 if (is_matrix(type: mbr_type))
2996 {
2997 if (is_array(type: mbr_type))
2998 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2999
3000 mbr_is_indexable = true;
3001 elem_cnt = mbr_type.columns;
3002 }
3003 else if (is_array(type: mbr_type))
3004 {
3005 if (mbr_type.array.size() != 1)
3006 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
3007
3008 mbr_is_indexable = true;
3009 elem_cnt = to_array_size_literal(type: mbr_type);
3010 }
3011
3012 auto *usable_type = &mbr_type;
3013 if (usable_type->pointer)
3014 usable_type = &get<SPIRType>(id: usable_type->parent_type);
3015 while (is_array(type: *usable_type) || is_matrix(type: *usable_type))
3016 usable_type = &get<SPIRType>(id: usable_type->parent_type);
3017
3018 bool flatten_from_ib_var = false;
3019 string flatten_from_ib_mbr_name;
3020
3021 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
3022 {
3023 // Also declare [[clip_distance]] attribute here.
3024 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
3025 ib_type.member_types.push_back(t: mbr_type_id);
3026 set_member_decoration(id: ib_type.self, index: clip_array_mbr_idx, decoration: DecorationBuiltIn, argument: BuiltInClipDistance);
3027
3028 flatten_from_ib_mbr_name = builtin_to_glsl(builtin: BuiltInClipDistance, storage: StorageClassOutput);
3029 set_member_name(id: ib_type.self, index: clip_array_mbr_idx, name: flatten_from_ib_mbr_name);
3030
3031 // When we flatten, we flatten directly from the "out" struct,
3032 // not from a function variable.
3033 flatten_from_ib_var = true;
3034
3035 if (!msl_options.enable_clip_distance_user_varying)
3036 return;
3037 }
3038
3039 // Recursively handle nested structures.
3040 if (mbr_type.basetype == SPIRType::Struct)
3041 {
3042 for (uint32_t i = 0; i < elem_cnt; i++)
3043 {
3044 string mbr_name = append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx) + (mbr_is_indexable ? join(ts: "_", ts&: i) : "");
3045 string var_chain = join(ts: var_chain_qual, ts: ".", ts: to_member_name(type: var_type, index: mbr_idx), ts: (mbr_is_indexable ? join(ts: "[", ts&: i, ts: "]") : ""));
3046 uint32_t sub_mbr_cnt = uint32_t(mbr_type.member_types.size());
3047 for (uint32_t sub_mbr_idx = 0; sub_mbr_idx < sub_mbr_cnt; sub_mbr_idx++)
3048 {
3049 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3050 var, var_type&: mbr_type, mbr_idx: sub_mbr_idx,
3051 meta, mbr_name_qual: mbr_name, var_chain_qual: var_chain,
3052 location, var_mbr_idx, interpolation_qual: inherited_qual);
// FIXME: Recursive structs combined with tessellation break here.
3054 var_mbr_idx++;
3055 }
3056 }
3057 return;
3058 }
3059
3060 for (uint32_t i = 0; i < elem_cnt; i++)
3061 {
3062 // Add a reference to the variable type to the interface struct.
3063 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3064 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3065 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: usable_type->self, is_noperspective));
3066 else
3067 ib_type.member_types.push_back(t: usable_type->self);
3068
3069 // Give the member a name
3070 string mbr_name = ensure_valid_name(name: append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx) + (mbr_is_indexable ? join(ts: "_", ts&: i) : ""), pfx: "m");
3071 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3072
// Once we determine the location of the first member within nested structures,
// from a variable of the topmost structure, the remaining flattened members of
// the nested structures will have consecutive location values. At this point,
// we've recursively tunnelled into structs, arrays, and matrices, and are
// down to a single location for each member.
3078 if (!is_builtin && location != UINT32_MAX)
3079 {
3080 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3081 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3082 location++;
3083 }
3084 else if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation))
3085 {
3086 location = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation) + i;
3087 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3088 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3089 location++;
3090 }
3091 else if (has_decoration(id: var.self, decoration: DecorationLocation))
3092 {
3093 location = get_accumulated_member_location(var, mbr_idx, strip_array: meta.strip_array) + i;
3094 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3095 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3096 location++;
3097 }
3098 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
3099 {
3100 location = inputs_by_builtin[builtin].location + i;
3101 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3102 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3103 location++;
3104 }
3105 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
3106 {
3107 location = outputs_by_builtin[builtin].location + i;
3108 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3109 mark_location_as_used_by_shader(location, type: *usable_type, storage);
3110 location++;
3111 }
3112 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
3113 {
3114 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
3115 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3116 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationIndex, argument: i);
3117 }
3118
3119 if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent))
3120 SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays is not supported.");
3121
3122 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
3123 {
3124 // Copy interpolation decorations if needed
3125 if (is_flat)
3126 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
3127 if (is_noperspective)
3128 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
3129 if (is_centroid)
3130 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
3131 if (is_sample)
3132 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
3133 }
3134
3135 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
3136 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: var_mbr_idx);
3137
3138 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
3139 if (!meta.strip_array && meta.allow_local_declaration)
3140 {
3141 string var_chain = join(ts: var_chain_qual, ts: ".", ts: to_member_name(type: var_type, index: mbr_idx), ts: (mbr_is_indexable ? join(ts: "[", ts&: i, ts: "]") : ""));
3142 switch (storage)
3143 {
3144 case StorageClassInput:
3145 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
3146 string lerp_call;
3147 if (pull_model_inputs.count(x: var.self))
3148 {
3149 if (is_centroid)
3150 lerp_call = ".interpolate_at_centroid()";
3151 else if (is_sample)
3152 lerp_call = join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
3153 else
3154 lerp_call = ".interpolate_at_center()";
3155 }
3156 statement(ts: var_chain, ts: " = ", ts: ib_var_ref, ts: ".", ts: mbr_name, ts&: lerp_call, ts: ";");
3157 });
3158 break;
3159
3160 case StorageClassOutput:
3161 entry_func.fixup_hooks_out.push_back(t: [=]() {
3162 if (flatten_from_ib_var)
3163 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: ib_var_ref, ts: ".", ts: flatten_from_ib_mbr_name, ts: "[", ts: i, ts: "];");
3164 else
3165 statement(ts: ib_var_ref, ts: ".", ts: mbr_name, ts: " = ", ts: var_chain, ts: ";");
3166 });
3167 break;
3168
3169 default:
3170 break;
3171 }
3172 }
3173 }
3174}
3175
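// Adds a single non-composite member of an I/O block to the interface struct,
// wiring up location/component/builtin decorations and the fixup statements that
// copy the member to or from the block variable.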
3176void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage,
3177 const string &ib_var_ref, SPIRType &ib_type,
3178 SPIRVariable &var, SPIRType &var_type,
3179 uint32_t mbr_idx, InterfaceBlockMeta &meta,
3180 const string &mbr_name_qual,
3181 const string &var_chain_qual,
3182 uint32_t &location, uint32_t &var_mbr_idx)
3183{
3184 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3185
3186 BuiltIn builtin = BuiltInMax;
3187 bool is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
3188 bool is_flat =
3189 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationFlat) || has_decoration(id: var.self, decoration: DecorationFlat);
3190 bool is_noperspective = has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationNoPerspective) ||
3191 has_decoration(id: var.self, decoration: DecorationNoPerspective);
3192 bool is_centroid = has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationCentroid) ||
3193 has_decoration(id: var.self, decoration: DecorationCentroid);
3194 bool is_sample =
3195 has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationSample) || has_decoration(id: var.self, decoration: DecorationSample);
3196
3197 // Add a reference to the member to the interface struct.
3198 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
3199 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3200 mbr_type_id = ensure_correct_builtin_type(type_id: mbr_type_id, builtin);
3201 var_type.member_types[mbr_idx] = mbr_type_id;
3202 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3203 ib_type.member_types.push_back(t: build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective));
3204 else
3205 ib_type.member_types.push_back(t: mbr_type_id);
3206
3207 // Give the member a name
3208 string mbr_name = ensure_valid_name(name: append_member_name(qualifier: mbr_name_qual, type: var_type, index: mbr_idx), pfx: "m");
3209 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3210
3211 // Update the original variable reference to include the structure reference
3212 string qual_var_name = ib_var_ref + "." + mbr_name;
3213 // If using pull-model interpolation, need to add a call to the correct interpolation method.
3214 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3215 {
3216 if (is_centroid)
3217 qual_var_name += ".interpolate_at_centroid()";
3218 else if (is_sample)
3219 qual_var_name += join(ts: ".interpolate_at_sample(", ts: to_expression(id: builtin_sample_id_id), ts: ")");
3220 else
3221 qual_var_name += ".interpolate_at_center()";
3222 }
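// e.g. (illustrative) with pull-model interpolation, a centroid-qualified input member
// is read as something like "in.m_foo.interpolate_at_centroid()".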
3223
3224 bool flatten_stage_out = false;
3225 string var_chain = var_chain_qual + "." + to_member_name(type: var_type, index: mbr_idx);
3226 if (is_builtin && !meta.strip_array)
3227 {
// For the builtin gl_PerVertex, we cannot treat it as a block anyway,
// so redirect to the qualified name.
3230 set_member_qualified_name(type_id: var_type.self, index: mbr_idx, name: qual_var_name);
3231 }
3232 else if (!meta.strip_array && meta.allow_local_declaration)
3233 {
3234 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
3235 switch (storage)
3236 {
3237 case StorageClassInput:
3238 entry_func.fixup_hooks_in.push_back(t: [=]() {
3239 statement(ts: var_chain, ts: " = ", ts: qual_var_name, ts: ";");
3240 });
3241 break;
3242
3243 case StorageClassOutput:
3244 flatten_stage_out = true;
3245 entry_func.fixup_hooks_out.push_back(t: [=]() {
3246 statement(ts: qual_var_name, ts: " = ", ts: var_chain, ts: ";");
3247 });
3248 break;
3249
3250 default:
3251 break;
3252 }
3253 }
3254
// Once we determine the location of the first member within nested structures,
// from a variable of the topmost structure, the remaining flattened members of
// the nested structures will have consecutive location values. At this point,
// we've recursively tunnelled into structs, arrays, and matrices, and are
// down to a single location for each member.
3260 if (!is_builtin && location != UINT32_MAX)
3261 {
3262 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3263 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3264 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3265 }
3266 else if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation))
3267 {
3268 location = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationLocation);
3269 uint32_t comp = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent);
3270 if (storage == StorageClassInput)
3271 {
3272 mbr_type_id = ensure_correct_input_type(type_id: mbr_type_id, location, component: comp, num_components: 0, strip_array: meta.strip_array);
3273 var_type.member_types[mbr_idx] = mbr_type_id;
3274 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3275 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
3276 else
3277 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
3278 }
3279 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3280 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3281 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3282 }
3283 else if (has_decoration(id: var.self, decoration: DecorationLocation))
3284 {
3285 location = get_accumulated_member_location(var, mbr_idx, strip_array: meta.strip_array);
3286 if (storage == StorageClassInput)
3287 {
3288 mbr_type_id = ensure_correct_input_type(type_id: mbr_type_id, location, component: 0, num_components: 0, strip_array: meta.strip_array);
3289 var_type.member_types[mbr_idx] = mbr_type_id;
3290 if (storage == StorageClassInput && pull_model_inputs.count(x: var.self))
3291 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id: mbr_type_id, is_noperspective);
3292 else
3293 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
3294 }
3295 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3296 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3297 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3298 }
3299 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(x: builtin))
3300 {
3301 location = inputs_by_builtin[builtin].location;
3302 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3303 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3304 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3305 }
3306 else if (is_builtin && capture_output_to_buffer && storage == StorageClassOutput && outputs_by_builtin.count(x: builtin))
3307 {
3308 location = outputs_by_builtin[builtin].location;
3309 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
3310 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: mbr_type_id), storage);
3311 location += type_to_location_count(type: get<SPIRType>(id: mbr_type_id));
3312 }
3313
// Copy the component decoration, if present.
3315 if (has_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent))
3316 {
3317 uint32_t comp = get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationComponent);
3318 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationComponent, argument: comp);
3319 }
3320
3321 // Mark the member as builtin if needed
3322 if (is_builtin)
3323 {
3324 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3325 if (builtin == BuiltInPosition && storage == StorageClassOutput)
3326 qual_pos_var_name = qual_var_name;
3327 }
3328
3329 const SPIRConstant *c = nullptr;
3330 if (!flatten_stage_out && var.storage == StorageClassOutput &&
3331 var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(id: var.initializer)))
3332 {
3333 if (meta.strip_array)
3334 {
3335 entry_func.fixup_hooks_in.push_back(t: [=, &var]() {
3336 auto &type = this->get<SPIRType>(id: var.basetype);
3337 uint32_t index = get_extended_member_decoration(type: var.self, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
3338
3339 auto invocation = to_tesc_invocation_id();
3340 auto constant_chain = join(ts: to_expression(id: var.initializer), ts: "[", ts&: invocation, ts: "]");
3341 statement(ts: to_expression(id: stage_out_ptr_var_id), ts: "[",
3342 ts&: invocation, ts: "].",
3343 ts: to_member_name(type: ib_type, index), ts: " = ",
3344 ts&: constant_chain, ts: ".", ts: to_member_name(type, index: mbr_idx), ts: ";");
3345 });
3346 }
3347 else
3348 {
3349 entry_func.fixup_hooks_in.push_back(t: [=]() {
3350 statement(ts: qual_var_name, ts: " = ", ts: constant_expression(
3351 c: this->get<SPIRConstant>(id: c->subconstants[mbr_idx])), ts: ";");
3352 });
3353 }
3354 }
3355
3356 if (storage != StorageClassInput || !pull_model_inputs.count(x: var.self))
3357 {
3358 // Copy interpolation decorations if needed
3359 if (is_flat)
3360 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
3361 if (is_noperspective)
3362 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
3363 if (is_centroid)
3364 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
3365 if (is_sample)
3366 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
3367 }
3368
3369 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceOrigID, value: var.self);
3370 set_extended_member_decoration(type: ib_type.self, index: ib_mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: var_mbr_idx);
3371}
3372
3373// In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
3374// But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
3375// individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
3376// levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
3377// float2 containing the inner levels.
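// For example (illustrative), the resulting stage-in members look roughly like:
//     float4 gl_TessLevel [[attribute(N)]];       // triangles: outer levels in xyz, inner in w
//     float4 gl_TessLevelOuter [[attribute(N)]];  // quads
//     float2 gl_TessLevelInner [[attribute(M)]];  // quads
// where the attribute indices depend on the assigned locations.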
3378void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
3379 SPIRVariable &var)
3380{
3381 auto &var_type = get_variable_element_type(var);
3382
3383 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3384 bool triangles = is_tessellating_triangles();
3385 string mbr_name;
3386
3387 // Add a reference to the variable type to the interface struct.
3388 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3389
3390 const auto mark_locations = [&](const SPIRType &new_var_type) {
3391 if (get_decoration_bitset(id: var.self).get(bit: DecorationLocation))
3392 {
3393 uint32_t locn = get_decoration(id: var.self, decoration: DecorationLocation);
3394 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
3395 mark_location_as_used_by_shader(location: locn, type: new_var_type, storage: StorageClassInput);
3396 }
3397 else if (inputs_by_builtin.count(x: builtin))
3398 {
3399 uint32_t locn = inputs_by_builtin[builtin].location;
3400 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: locn);
3401 mark_location_as_used_by_shader(location: locn, type: new_var_type, storage: StorageClassInput);
3402 }
3403 };
3404
3405 if (triangles)
3406 {
3407 // Triangles are tricky, because we want only one member in the struct.
3408 mbr_name = "gl_TessLevel";
3409
3410 // If we already added the other one, we can skip this step.
3411 if (!added_builtin_tess_level)
3412 {
3413 uint32_t type_id = build_extended_vector_type(type_id: var_type.self, components: 4);
3414
3415 ib_type.member_types.push_back(t: type_id);
3416
3417 // Give the member a name
3418 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3419
// We cannot decorate the member as both TessLevelOuter and TessLevelInner, but the
// important part is that it's marked as a builtin so we can get automatic
// attribute assignment if needed.
3422 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3423
3424 mark_locations(var_type);
3425 added_builtin_tess_level = true;
3426 }
3427 }
3428 else
3429 {
3430 mbr_name = builtin_to_glsl(builtin, storage: StorageClassFunction);
3431
3432 uint32_t type_id = build_extended_vector_type(type_id: var_type.self, components: builtin == BuiltInTessLevelOuter ? 4 : 2);
3433
3434 uint32_t ptr_type_id = ir.increase_bound_by(count: 1);
3435 auto &new_var_type = set<SPIRType>(id: ptr_type_id, args&: get<SPIRType>(id: type_id));
3436 new_var_type.pointer = true;
3437 new_var_type.pointer_depth++;
3438 new_var_type.storage = StorageClassInput;
3439 new_var_type.parent_type = type_id;
3440
3441 ib_type.member_types.push_back(t: type_id);
3442
3443 // Give the member a name
3444 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: mbr_name);
3445 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationBuiltIn, argument: builtin);
3446
3447 mark_locations(new_var_type);
3448 }
3449
3450 add_tess_level_input(base_ref: ib_var_ref, mbr_name, var);
3451}
3452
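// Emits fixup statements that unpack the packed tessellation-factor vectors into the
// gl_TessLevelOuter/gl_TessLevelInner arrays the shader actually reads.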
3453void CompilerMSL::add_tess_level_input(const std::string &base_ref, const std::string &mbr_name, SPIRVariable &var)
3454{
3455 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3456 BuiltIn builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3457
3458 // Force the variable to have the proper name.
3459 string var_name = builtin_to_glsl(builtin, storage: StorageClassFunction);
3460 set_name(id: var.self, name: var_name);
3461
3462 // We need to declare the variable early and at entry-point scope.
3463 entry_func.add_local_variable(id: var.self);
3464 vars_needing_early_declaration.push_back(t: var.self);
3465 bool triangles = is_tessellating_triangles();
3466
3467 if (builtin == BuiltInTessLevelOuter)
3468 {
3469 entry_func.fixup_hooks_in.push_back(
3470 t: [=]()
3471 {
3472 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[0];");
3473 statement(ts: var_name, ts: "[1] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[1];");
3474 statement(ts: var_name, ts: "[2] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[2];");
3475 if (!triangles)
3476 statement(ts: var_name, ts: "[3] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[3];");
3477 });
3478 }
3479 else
3480 {
3481 entry_func.fixup_hooks_in.push_back(t: [=]() {
3482 if (triangles)
3483 {
3484 if (msl_options.raw_buffer_tese_input)
3485 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: ";");
3486 else
3487 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[3];");
3488 }
3489 else
3490 {
3491 statement(ts: var_name, ts: "[0] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[0];");
3492 statement(ts: var_name, ts: "[1] = ", ts: base_ref, ts: ".", ts: mbr_name, ts: "[1];");
3493 }
3494 });
3495 }
3496}
3497
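// Returns whether I/O variables in the given storage class go through a real
// stage_in/stage_out struct, as opposed to being captured to or read from raw buffers.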
3498bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const
3499{
3500 if (storage == StorageClassOutput)
3501 return !capture_output_to_buffer;
3502 else if (storage == StorageClassInput)
3503 return !(is_tesc_shader() && msl_options.multi_patch_workgroup) &&
3504 !(is_tese_shader() && msl_options.raw_buffer_tese_input);
3505 else
3506 return false;
3507}
3508
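// Returns the expression for the current tessellation control invocation.
// With multi-patch workgroups this is derived from the global invocation ID;
// e.g. with output_vertices == 4, global IDs 0..7 map to invocations
// 0,1,2,3,0,1,2,3 across two consecutive patches.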
3509string CompilerMSL::to_tesc_invocation_id()
3510{
3511 if (msl_options.multi_patch_workgroup)
3512 {
3513 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
3514 // not the TC invocation ID.
3515 return join(ts: to_expression(id: builtin_invocation_id_id), ts: ".x % ", ts&: get_entry_point().output_vertices);
3516 }
3517 else
3518 return builtin_to_glsl(builtin: BuiltInInvocationId, storage: StorageClassInput);
3519}
3520
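// Emits a masked stage variable as a plain local variable (or, for multi-patch
// tessellation control, as a per-patch threadgroup slice), since it must not appear
// in the stage interface but may still be accessed by the shader.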
3521void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array)
3522{
3523 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3524 bool threadgroup_storage = variable_decl_is_remapped_storage(variable: masked_var, storage: StorageClassWorkgroup);
3525
3526 if (threadgroup_storage && msl_options.multi_patch_workgroup)
3527 {
// We need one threadgroup block per patch, so emulate this by declaring an array of
// blocks and binding each patch to its own slice.
3529 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var]() {
3530 auto &type = get_variable_data_type(var: masked_var);
3531 add_local_variable_name(id: masked_var.self);
3532
3533 const uint32_t max_control_points_per_patch = 32u;
3534 uint32_t max_num_instances =
3535 (max_control_points_per_patch + get_entry_point().output_vertices - 1u) /
3536 get_entry_point().output_vertices;
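// i.e. ceil(32 / output_vertices); e.g. 8 slices when each patch has 4 control points.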
3537 statement(ts: "threadgroup ", ts: type_to_glsl(type), ts: " ",
3538 ts: "spvStorage", ts: to_name(id: masked_var.self), ts: "[", ts&: max_num_instances, ts: "]",
3539 ts: type_to_array_glsl(type, variable_id: 0), ts: ";");
3540
3541 // Assign a threadgroup slice to each PrimitiveID.
3542 // We assume here that workgroup size is rounded to 32,
3543 // since that's the maximum number of control points per patch.
3544 // We cannot size the array based on fixed dispatch parameters,
3545 // since Metal does not allow that. :(
3546 // FIXME: We will likely need an option to support passing down target workgroup size,
3547 // so we can emit appropriate size here.
3548 statement(ts: "threadgroup auto ",
3549 ts: "&", ts: to_name(id: masked_var.self),
3550 ts: " = spvStorage", ts: to_name(id: masked_var.self), ts: "[",
3551 ts: "(", ts: to_expression(id: builtin_invocation_id_id), ts: ".x / ",
3552 ts&: get_entry_point().output_vertices, ts: ") % ",
3553 ts&: max_num_instances, ts: "];");
3554 });
3555 }
3556 else
3557 {
3558 entry_func.add_local_variable(id: masked_var.self);
3559 }
3560
3561 if (!threadgroup_storage)
3562 {
3563 vars_needing_early_declaration.push_back(t: masked_var.self);
3564 }
3565 else if (masked_var.initializer)
3566 {
3567 // Cannot directly initialize threadgroup variables. Need fixup hooks.
3568 ID initializer = masked_var.initializer;
3569 if (strip_array)
3570 {
3571 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var, initializer]() {
3572 auto invocation = to_tesc_invocation_id();
3573 statement(ts: to_expression(id: masked_var.self), ts: "[",
3574 ts&: invocation, ts: "] = ",
3575 ts: to_expression(id: initializer), ts: "[",
3576 ts&: invocation, ts: "];");
3577 });
3578 }
3579 else
3580 {
3581 entry_func.fixup_hooks_in.push_back(t: [this, &masked_var, initializer]() {
3582 statement(ts: to_expression(id: masked_var.self), ts: " = ", ts: to_expression(id: initializer), ts: ";");
3583 });
3584 }
3585 }
3586}
3587
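// Top-level dispatcher that routes a stage I/O variable to the appropriate plain,
// composite, or per-member handler above, dealing with masked outputs, I/O blocks,
// and tessellation-level builtins along the way.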
3588void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
3589 SPIRVariable &var, InterfaceBlockMeta &meta)
3590{
3591 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
3592 // Tessellation control I/O variables and tessellation evaluation per-point inputs are
3593 // usually declared as arrays. In these cases, we want to add the element type to the
3594 // interface block, since in Metal it's the interface block itself which is arrayed.
3595 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
3596 bool is_builtin = is_builtin_variable(var);
3597 auto builtin = BuiltIn(get_decoration(id: var.self, decoration: DecorationBuiltIn));
3598 bool is_block = has_decoration(id: var_type.self, decoration: DecorationBlock);
3599
3600 // If stage variables are masked out, emit them as plain variables instead.
3601 // For builtins, we query them one by one later.
// IO blocks are not masked here; we need to mask them per-member instead.
3603 if (storage == StorageClassOutput && is_stage_output_variable_masked(var))
3604 {
// If we ignore an output, we must still emit it, since it might be used by the app.
// Instead, just emit it as an early declaration.
3607 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3608 return;
3609 }
3610
3611 if (storage == StorageClassInput && has_decoration(id: var.self, decoration: DecorationPerVertexKHR))
3612 SPIRV_CROSS_THROW("PerVertexKHR decoration is not supported in MSL.");
3613
3614 // If variable names alias, they will end up with wrong names in the interface struct, because
3615 // there might be aliases in the member name cache and there would be a mismatch in fixup_in code.
3616 // Make sure to register the variables as unique resource names ahead of time.
3617 // This would normally conflict with the name cache when emitting local variables,
3618 // but this happens in the setup stage, before we hit compilation loops.
3619 // The name cache is cleared before we actually emit code, so this is safe.
3620 add_resource_name(id: var.self);
3621
3622 if (var_type.basetype == SPIRType::Struct)
3623 {
3624 bool block_requires_flattening =
3625 variable_storage_requires_stage_io(storage) || (is_block && var_type.array.empty());
3626 bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration;
3627
3628 if (needs_local_declaration)
3629 {
3630 // For I/O blocks or structs, we will need to pass the block itself around
3631 // to functions if they are used globally in leaf functions.
3632 // Rather than passing down member by member,
3633 // we unflatten I/O blocks while running the shader,
3634 // and pass the actual struct type down to leaf functions.
3635 // We then unflatten inputs, and flatten outputs in the "fixup" stages.
3636 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3637 }
3638
3639 if (!block_requires_flattening)
3640 {
3641 // In Metal tessellation shaders, the interface block itself is arrayed. This makes things
3642 // very complicated, since stage-in structures in MSL don't support nested structures.
3643 // Luckily, for stage-out when capturing output, we can avoid this and just add
3644 // composite members directly, because the stage-out structure is stored to a buffer,
3645 // not returned.
3646 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3647 }
3648 else
3649 {
3650 bool masked_block = false;
3651 uint32_t location = UINT32_MAX;
3652 uint32_t var_mbr_idx = 0;
3653 uint32_t elem_cnt = 1;
3654 if (is_matrix(type: var_type))
3655 {
3656 if (is_array(type: var_type))
3657 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
3658
3659 elem_cnt = var_type.columns;
3660 }
3661 else if (is_array(type: var_type))
3662 {
3663 if (var_type.array.size() != 1)
3664 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
3665
3666 elem_cnt = to_array_size_literal(type: var_type);
3667 }
3668
3669 for (uint32_t elem_idx = 0; elem_idx < elem_cnt; elem_idx++)
3670 {
3671 // Flatten the struct members into the interface struct
3672 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3673 {
3674 builtin = BuiltInMax;
3675 is_builtin = is_member_builtin(type: var_type, index: mbr_idx, builtin: &builtin);
3676 auto &mbr_type = get<SPIRType>(id: var_type.member_types[mbr_idx]);
3677
3678 if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, index: mbr_idx, strip_array: meta.strip_array))
3679 {
3680 location = UINT32_MAX; // Skip this member and resolve location again on next var member
3681
3682 if (is_block)
3683 masked_block = true;
3684
// Non-builtin block output variables are just ignored, since the shader will still
// access the block variable as-is. They're just not flattened.
3687 if (is_builtin && !meta.strip_array)
3688 {
3689 // Emit a fake variable instead.
3690 uint32_t ids = ir.increase_bound_by(count: 2);
3691 uint32_t ptr_type_id = ids + 0;
3692 uint32_t var_id = ids + 1;
3693
3694 auto ptr_type = mbr_type;
3695 ptr_type.pointer = true;
3696 ptr_type.pointer_depth++;
3697 ptr_type.parent_type = var_type.member_types[mbr_idx];
3698 ptr_type.storage = StorageClassOutput;
3699
3700 uint32_t initializer = 0;
3701 if (var.initializer)
3702 if (auto *c = maybe_get<SPIRConstant>(id: var.initializer))
3703 initializer = c->subconstants[mbr_idx];
3704
3705 set<SPIRType>(id: ptr_type_id, args&: ptr_type);
3706 set<SPIRVariable>(id: var_id, args&: ptr_type_id, args: StorageClassOutput, args&: initializer);
3707 entry_func.add_local_variable(id: var_id);
3708 vars_needing_early_declaration.push_back(t: var_id);
3709 set_name(id: var_id, name: builtin_to_glsl(builtin, storage: StorageClassOutput));
3710 set_decoration(id: var_id, decoration: DecorationBuiltIn, argument: builtin);
3711 }
3712 }
3713 else if (!is_builtin || has_active_builtin(builtin, storage))
3714 {
3715 bool is_composite_type = is_matrix(type: mbr_type) || is_array(type: mbr_type) || mbr_type.basetype == SPIRType::Struct;
3716 bool attribute_load_store =
3717 storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3718 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3719
3720 // Clip/CullDistance always need to be declared as user attributes.
3721 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3722 is_builtin = false;
3723
3724 const string var_name = to_name(id: var.self);
3725 string mbr_name_qual = var_name;
3726 string var_chain_qual = var_name;
3727 if (elem_cnt > 1)
3728 {
3729 mbr_name_qual += join(ts: "_", ts&: elem_idx);
3730 var_chain_qual += join(ts: "[", ts&: elem_idx, ts: "]");
3731 }
3732
3733 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3734 {
3735 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3736 var, var_type, mbr_idx, meta,
3737 mbr_name_qual, var_chain_qual,
3738 location, var_mbr_idx, interpolation_qual: {});
3739 }
3740 else
3741 {
3742 add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type,
3743 var, var_type, mbr_idx, meta,
3744 mbr_name_qual, var_chain_qual,
3745 location, var_mbr_idx);
3746 }
3747 }
3748 var_mbr_idx++;
3749 }
3750 }
3751
3752 // If we're redirecting a block, we might still need to access the original block
3753 // variable if we're masking some members.
3754 if (masked_block && !needs_local_declaration && (!is_builtin_variable(var) || is_tesc_shader()))
3755 {
3756 if (is_builtin_variable(var))
3757 {
3758 // Ensure correct names for the block members if we're actually going to
3759 // declare gl_PerVertex.
3760 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3761 {
3762 set_member_name(id: var_type.self, index: mbr_idx, name: builtin_to_glsl(
3763 builtin: BuiltIn(get_member_decoration(id: var_type.self, index: mbr_idx, decoration: DecorationBuiltIn)),
3764 storage: StorageClassOutput));
3765 }
3766
3767 set_name(id: var_type.self, name: "gl_PerVertex");
3768 set_name(id: var.self, name: "gl_out_masked");
3769 stage_out_masked_builtin_type_id = var_type.self;
3770 }
3771 emit_local_masked_variable(masked_var: var, strip_array: meta.strip_array);
3772 }
3773 }
3774 }
3775 else if (is_tese_shader() && storage == StorageClassInput && !meta.strip_array && is_builtin &&
3776 (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
3777 {
3778 add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
3779 }
3780 else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
3781 type_is_integral(type: var_type) || type_is_floating_point(type: var_type))
3782 {
3783 if (!is_builtin || has_active_builtin(builtin, storage))
3784 {
3785 bool is_composite_type = is_matrix(type: var_type) || is_array(type: var_type);
3786 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3787 bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3788
// Clip/CullDistance always need to be declared as user attributes.
3790 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3791 is_builtin = false;
3792
// MSL does not allow matrices or arrays in input or output variables, so we need to handle them specially.
3794 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3795 {
3796 add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3797 }
3798 else
3799 {
3800 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3801 }
3802 }
3803 }
3804}
3805
3806// Fix up the mapping of variables to interface member indices, which is used to compile access chains
3807// for per-vertex variables in a tessellation control shader.
3808void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
3809{
// Only needed for tessellation shaders and pull-model interpolants.
// We need to redirect interface indices back to the variables themselves.
// For structs, each member of the struct needs a separate instance.
3813 if (!is_tesc_shader() && !(is_tese_shader() && storage == StorageClassInput) &&
3814 !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
3815 !pull_model_inputs.empty()))
3816 return;
3817
3818 auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
3819 for (uint32_t i = 0; i < mbr_cnt; i++)
3820 {
3821 uint32_t var_id = get_extended_member_decoration(type: ib_type_id, index: i, decoration: SPIRVCrossDecorationInterfaceOrigID);
3822 if (!var_id)
3823 continue;
3824 auto &var = get<SPIRVariable>(id: var_id);
3825
3826 auto &type = get_variable_element_type(var);
3827
3828 bool flatten_composites = variable_storage_requires_stage_io(storage: var.storage);
3829 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
3830
3831 uint32_t mbr_idx = uint32_t(-1);
3832 if (type.basetype == SPIRType::Struct && (flatten_composites || is_block))
3833 mbr_idx = get_extended_member_decoration(type: ib_type_id, index: i, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
3834
3835 if (mbr_idx != uint32_t(-1))
3836 {
3837 // Only set the lowest InterfaceMemberIndex for each variable member.
3838 // IB struct members will be emitted in-order w.r.t. interface member index.
3839 if (!has_extended_member_decoration(type: var_id, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex))
3840 set_extended_member_decoration(type: var_id, index: mbr_idx, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: i);
3841 }
3842 else
3843 {
3844 // Only set the lowest InterfaceMemberIndex for each variable.
3845 // IB struct members will be emitted in-order w.r.t. interface member index.
3846 if (!has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationInterfaceMemberIndex))
3847 set_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: i);
3848 }
3849 }
3850}
3851
3852// Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
3853// Returns the ID of the newly added variable, or zero if no variable was added.
3854uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
3855{
3856 // Accumulate the variables that should appear in the interface struct.
3857 SmallVector<SPIRVariable *> vars;
3858 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
3859 bool has_seen_barycentric = false;
3860
3861 InterfaceBlockMeta meta;
3862
// Varying interfaces between stages which use the "user()" attribute can be dealt with
// without explicit packing and unpacking of components. For any variables which link against
// the runtime in some way (vertex attributes, fragment output, etc.), we need to pack and
// unpack components explicitly.
3866 bool pack_components =
3867 (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
3868 (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
3869 (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
3870
3871 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t var_id, SPIRVariable &var) {
3872 if (var.storage != storage)
3873 return;
3874
3875 auto &type = this->get<SPIRType>(id: var.basetype);
3876
3877 bool is_builtin = is_builtin_variable(var);
3878 bool is_block = has_decoration(id: type.self, decoration: DecorationBlock);
3879
3880 auto bi_type = BuiltInMax;
3881 bool builtin_is_gl_in_out = false;
3882 if (is_builtin && !is_block)
3883 {
3884 bi_type = BuiltIn(get_decoration(id: var_id, decoration: DecorationBuiltIn));
3885 builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3886 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
3887 }
3888
3889 if (is_builtin && is_block)
3890 builtin_is_gl_in_out = true;
3891
3892 uint32_t location = get_decoration(id: var_id, decoration: DecorationLocation);
3893
3894 bool builtin_is_stage_in_out = builtin_is_gl_in_out ||
3895 bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
3896 bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR ||
3897 bi_type == BuiltInFragDepth ||
3898 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask;
3899
3900 // These builtins are part of the stage in/out structs.
3901 bool is_interface_block_builtin =
3902 builtin_is_stage_in_out || (is_tese_shader() && !msl_options.raw_buffer_tese_input &&
3903 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
3904
3905 bool is_active = interface_variable_exists_in_entry_point(id: var.self);
3906 if (is_builtin && is_active)
3907 {
// Only emit the builtin if it's active in this entry point. The interface variable list might lie.
3909 if (is_block)
3910 {
3911 // If any builtin is active, the block is active.
3912 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3913 for (uint32_t i = 0; !is_active && i < mbr_cnt; i++)
3914 is_active = has_active_builtin(builtin: BuiltIn(get_member_decoration(id: type.self, index: i, decoration: DecorationBuiltIn)), storage);
3915 }
3916 else
3917 {
3918 is_active = has_active_builtin(builtin: bi_type, storage);
3919 }
3920 }
3921
3922 bool filter_patch_decoration = (has_decoration(id: var_id, decoration: DecorationPatch) || is_patch_block(type)) == patch;
3923
3924 bool hidden = is_hidden_variable(var, include_builtins: incl_builtins);
3925
// ClipDistance is never hidden; we need to emulate it when used as an input.
3927 if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance)
3928 hidden = false;
3929
// It's not enough to simply avoid marking fragment outputs if the pipeline won't
// accept them. We can't put them in the struct at all; otherwise, the compiler
// complains that the outputs weren't explicitly marked.
3933 // Frag depth and stencil outputs are incompatible with explicit early fragment tests.
3934 // In GLSL, depth and stencil outputs are just ignored when explicit early fragment tests are required.
3935 // In Metal, it's a compilation error, so we need to exclude them from the output struct.
3936 if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
3937 ((is_builtin && ((bi_type == BuiltInFragDepth && (!msl_options.enable_frag_depth_builtin || uses_explicit_early_fragment_test())) ||
3938 (bi_type == BuiltInFragStencilRefEXT && (!msl_options.enable_frag_stencil_ref_builtin || uses_explicit_early_fragment_test())))) ||
3939 (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
3940 {
3941 hidden = true;
3942 disabled_frag_outputs.push_back(t: var_id);
3943 // If a builtin, force it to have the proper name, and mark it as not part of the output struct.
3944 if (is_builtin)
3945 {
3946 set_name(id: var_id, name: builtin_to_glsl(builtin: bi_type, storage: StorageClassFunction));
3947 mask_stage_output_by_builtin(builtin: bi_type);
3948 }
3949 }
3950
3951 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
3952 if (is_active && (bi_type == BuiltInBaryCoordKHR || bi_type == BuiltInBaryCoordNoPerspKHR))
3953 {
3954 if (has_seen_barycentric)
3955 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
3956 has_seen_barycentric = true;
3957 hidden = false;
3958 }
3959
3960 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
3961 (!is_builtin || is_interface_block_builtin))
3962 {
3963 vars.push_back(t: &var);
3964
3965 if (!is_builtin)
3966 {
// Need to deal specially with DecorationComponent.
// Multiple variables can alias the same Location, so we try to make sure
// each location is declared only once. We will swizzle data in and out to make this work.
// This is only relevant for vertex inputs and fragment outputs.
// Technically tessellation as well, but it is too complicated to support.
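// For example (illustrative), two fragment outputs aliasing location 0:
//     layout(location = 0, component = 0) out vec2 a;
//     layout(location = 0, component = 2) out vec2 b;
// share one float4 interface member, with 'a' and 'b' swizzled in and out.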
3972 uint32_t component = get_decoration(id: var_id, decoration: DecorationComponent);
3973 if (component != 0)
3974 {
3975 if (is_tessellation_shader())
3976 SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
3977 else if (pack_components)
3978 {
3979 uint32_t array_size = 1;
3980 if (!type.array.empty())
3981 array_size = to_array_size_literal(type);
3982
3983 for (uint32_t location_offset = 0; location_offset < array_size; location_offset++)
3984 {
3985 auto &location_meta = meta.location_meta[location + location_offset];
3986 location_meta.num_components = max<uint32_t>(a: location_meta.num_components, b: component + type.vecsize);
3987
// For variables sharing a location, decorations and base type must match.
3989 location_meta.base_type_id = type.self;
3990 location_meta.flat = has_decoration(id: var.self, decoration: DecorationFlat);
3991 location_meta.noperspective = has_decoration(id: var.self, decoration: DecorationNoPerspective);
3992 location_meta.centroid = has_decoration(id: var.self, decoration: DecorationCentroid);
3993 location_meta.sample = has_decoration(id: var.self, decoration: DecorationSample);
3994 }
3995 }
3996 }
3997 }
3998 }
3999
4000 if (is_tese_shader() && msl_options.raw_buffer_tese_input && patch && storage == StorageClassInput &&
4001 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner))
4002 {
4003 // In this case, we won't add the builtin to the interface struct,
4004 // but we still need the hook to run to populate the arrays.
4005 string base_ref = join(ts&: tess_factor_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: "]");
4006 const char *mbr_name =
4007 bi_type == BuiltInTessLevelOuter ? "edgeTessellationFactor" : "insideTessellationFactor";
4008 add_tess_level_input(base_ref, mbr_name, var);
4009 if (inputs_by_builtin.count(x: bi_type))
4010 {
4011 uint32_t locn = inputs_by_builtin[bi_type].location;
4012 mark_location_as_used_by_shader(location: locn, type, storage: StorageClassInput);
4013 }
4014 }
4015 });
4016
	// If no variables qualify, leave.
	// Exception: for patch input in a tessellation evaluation shader, the per-vertex stage
	// inputs are included in a special patch control point array, so the block is still needed.
4020 if (vars.empty() &&
4021 !(!msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch && stage_in_var_id))
4022 return 0;
4023
	// Add a new typed variable for this interface structure.
	// The initializer expression is allocated here, but populated when the function
	// declaration is emitted, because it is cleared after each compilation pass.
4027 uint32_t next_id = ir.increase_bound_by(count: 3);
4028 uint32_t ib_type_id = next_id++;
4029 auto &ib_type = set<SPIRType>(id: ib_type_id, args: OpTypeStruct);
4030 ib_type.basetype = SPIRType::Struct;
4031 ib_type.storage = storage;
4032 set_decoration(id: ib_type_id, decoration: DecorationBlock);
4033
4034 uint32_t ib_var_id = next_id++;
4035 auto &var = set<SPIRVariable>(id: ib_var_id, args&: ib_type_id, args&: storage, args: 0);
4036 var.initializer = next_id++;
4037
4038 string ib_var_ref;
4039 auto &entry_func = get<SPIRFunction>(id: ir.default_entry_point);
4040 switch (storage)
4041 {
4042 case StorageClassInput:
4043 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
4044 switch (get_execution_model())
4045 {
4046 case ExecutionModelTessellationControl:
4047 // Add a hook to populate the shared workgroup memory containing the gl_in array.
4048 entry_func.fixup_hooks_in.push_back(t: [=]() {
4049 // Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
4050 if (msl_options.multi_patch_workgroup)
4051 {
4052 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
4053 // not the TC invocation ID.
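					// A sketch of what this emits, assuming the default entry point name "main0",
					// input buffer "spvIn", and 4 output vertices (all names illustrative):
					//   device main0_in* gl_in = &spvIn[min(gl_GlobalInvocationID.x / 4, spvIndirectParams[1] - 1) * spvIndirectParams[0]];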
4054 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_in = &",
4055 ts&: input_buffer_var_name, ts: "[min(", ts: to_expression(id: builtin_invocation_id_id), ts: ".x / ",
4056 ts&: get_entry_point().output_vertices,
4057 ts: ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
4058 }
4059 else
4060 {
4061 // It's safe to use InvocationId here because it's directly mapped to a
4062 // Metal builtin, and therefore doesn't need a hook.
4063 statement(ts: "if (", ts: to_expression(id: builtin_invocation_id_id), ts: " < spvIndirectParams[0])");
4064 statement(ts: " ", ts&: input_wg_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4065 ts: "] = ", ts: ib_var_ref, ts: ";");
4066 statement(ts: "threadgroup_barrier(mem_flags::mem_threadgroup);");
4067 statement(ts: "if (", ts: to_expression(id: builtin_invocation_id_id),
4068 ts: " >= ", ts&: get_entry_point().output_vertices, ts: ")");
4069 statement(ts: " return;");
4070 }
4071 });
4072 break;
4073 case ExecutionModelTessellationEvaluation:
4074 if (!msl_options.raw_buffer_tese_input)
4075 break;
4076 if (patch)
4077 {
4078 entry_func.fixup_hooks_in.push_back(
4079 t: [=]()
4080 {
4081 statement(ts: "const device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4082 ts: " = ", ts&: patch_input_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
4083 ts: "];");
4084 });
4085 }
4086 else
4087 {
4088 entry_func.fixup_hooks_in.push_back(
4089 t: [=]()
4090 {
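					// E.g., with illustrative default names and 4 output vertices:
					//   const device main0_in* gl_in = &spvIn[gl_PrimitiveID * 4];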
4091 statement(ts: "const device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_in = &",
4092 ts&: input_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: " * ",
4093 ts&: get_entry_point().output_vertices, ts: "];");
4094 });
4095 }
4096 break;
4097 default:
4098 break;
4099 }
4100 break;
4101
4102 case StorageClassOutput:
4103 {
4104 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
4105
		// Add the output interface struct as a local variable to the entry function.
		// If the entry point should return the output struct, set the entry function
		// to return the output interface struct, otherwise to return nothing.
		// Watch out for the rare case where the terminator of the last entry point block is a
		// Kill, instead of a Return. Based on SPIR-V's block-domination rules, we assume that
		// a function containing a Kill will also have a terminating Return somewhere, unless
		// the Kill is in the last block.
		// Indicate the output var requires early initialization.
4113 bool ep_should_return_output = !get_is_rasterization_disabled();
4114 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
4115 if (!capture_output_to_buffer)
4116 {
4117 entry_func.add_local_variable(id: ib_var_id);
4118 for (auto &blk_id : entry_func.blocks)
4119 {
4120 auto &blk = get<SPIRBlock>(id: blk_id);
4121 if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back()))
4122 blk.return_value = rtn_id;
4123 }
4124 vars_needing_early_declaration.push_back(t: ib_var_id);
4125 }
4126 else
4127 {
4128 switch (get_execution_model())
4129 {
4130 case ExecutionModelVertex:
4131 case ExecutionModelTessellationEvaluation:
4132 // Instead of declaring a struct variable to hold the output and then
4133 // copying that to the output buffer, we'll declare the output variable
4134 // as a reference to the final output element in the buffer. Then we can
4135 // avoid the extra copy.
4136 entry_func.fixup_hooks_in.push_back(t: [=]() {
4137 if (stage_out_var_id)
4138 {
4139 // The first member of the indirect buffer is always the number of vertices
4140 // to draw.
4141 // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
4142 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
4143 {
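						// Emits roughly the following (buffer and builtin names are illustrative defaults):
						//   device main0_out& out = spvOut[gl_GlobalInvocationID.y * spvStageInputSize.x + gl_GlobalInvocationID.x];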
4144 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4145 ts: " = ", ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4146 ts: ".y * ", ts: to_expression(id: builtin_stage_input_size_id), ts: ".x + ",
4147 ts: to_expression(id: builtin_invocation_id_id), ts: ".x];");
4148 }
4149 else if (msl_options.enable_base_index_zero)
4150 {
4151 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4152 ts: " = ", ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_instance_idx_id),
4153 ts: " * spvIndirectParams[0] + ", ts: to_expression(id: builtin_vertex_idx_id), ts: "];");
4154 }
4155 else
4156 {
4157 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4158 ts: " = ", ts&: output_buffer_var_name, ts: "[(", ts: to_expression(id: builtin_instance_idx_id),
4159 ts: " - ", ts: to_expression(id: builtin_base_instance_id), ts: ") * spvIndirectParams[0] + ",
4160 ts: to_expression(id: builtin_vertex_idx_id), ts: " - ",
4161 ts: to_expression(id: builtin_base_vertex_id), ts: "];");
4162 }
4163 }
4164 });
4165 break;
4166 case ExecutionModelTessellationControl:
4167 if (msl_options.multi_patch_workgroup)
4168 {
4169 // We cannot use PrimitiveId here, because the hook may not have run yet.
4170 if (patch)
4171 {
4172 entry_func.fixup_hooks_in.push_back(t: [=]() {
4173 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4174 ts: " = ", ts&: patch_output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id),
4175 ts: ".x / ", ts&: get_entry_point().output_vertices, ts: "];");
4176 });
4177 }
4178 else
4179 {
4180 entry_func.fixup_hooks_in.push_back(t: [=]() {
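						// Emits roughly the following, rounding the invocation index down to the start
						// of this patch's vertex range (illustrative names, 4 output vertices):
						//   device main0_out* gl_out = &spvOut[gl_GlobalInvocationID.x - gl_GlobalInvocationID.x % 4];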
4181 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_out = &",
4182 ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_invocation_id_id), ts: ".x - ",
4183 ts: to_expression(id: builtin_invocation_id_id), ts: ".x % ",
4184 ts&: get_entry_point().output_vertices, ts: "];");
4185 });
4186 }
4187 }
4188 else
4189 {
4190 if (patch)
4191 {
4192 entry_func.fixup_hooks_in.push_back(t: [=]() {
4193 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "& ", ts: ib_var_ref,
4194 ts: " = ", ts&: patch_output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id),
4195 ts: "];");
4196 });
4197 }
4198 else
4199 {
4200 entry_func.fixup_hooks_in.push_back(t: [=]() {
4201 statement(ts: "device ", ts: to_name(id: ir.default_entry_point), ts: "_", ts: ib_var_ref, ts: "* gl_out = &",
4202 ts&: output_buffer_var_name, ts: "[", ts: to_expression(id: builtin_primitive_id_id), ts: " * ",
4203 ts&: get_entry_point().output_vertices, ts: "];");
4204 });
4205 }
4206 }
4207 break;
4208 default:
4209 break;
4210 }
4211 }
4212 break;
4213 }
4214
4215 default:
4216 break;
4217 }
4218
4219 set_name(id: ib_type_id, name: to_name(id: ir.default_entry_point) + "_" + ib_var_ref);
4220 set_name(id: ib_var_id, name: ib_var_ref);
4221
4222 for (auto *p_var : vars)
4223 {
4224 bool strip_array = (is_tesc_shader() || (is_tese_shader() && storage == StorageClassInput)) && !patch;
4225
4226 // Fixing up flattened stores in TESC is impossible since the memory is group shared either via
4227 // device (not masked) or threadgroup (masked) storage classes and it's race condition city.
4228 meta.strip_array = strip_array;
4229 meta.allow_local_declaration = !strip_array && !(is_tesc_shader() && storage == StorageClassOutput);
4230 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: *p_var, meta);
4231 }
4232
4233 if (((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
4234 (is_tese_shader() && msl_options.raw_buffer_tese_input)) &&
4235 storage == StorageClassInput)
4236 {
4237 // For tessellation inputs, add all outputs from the previous stage to ensure
4238 // the struct containing them is the correct size and layout.
4239 for (auto &input : inputs_by_location)
4240 {
4241 if (location_inputs_in_use.count(x: input.first.location) != 0)
4242 continue;
4243
4244 if (patch != (input.second.rate == MSL_SHADER_VARIABLE_RATE_PER_PATCH))
4245 continue;
4246
4247 // Tessellation levels have their own struct, so there's no need to add them here.
4248 if (input.second.builtin == BuiltInTessLevelOuter || input.second.builtin == BuiltInTessLevelInner)
4249 continue;
4250
4251 // Create a fake variable to put at the location.
4252 uint32_t offset = ir.increase_bound_by(count: 5);
4253 uint32_t type_id = offset;
4254 uint32_t vec_type_id = offset + 1;
4255 uint32_t array_type_id = offset + 2;
4256 uint32_t ptr_type_id = offset + 3;
4257 uint32_t var_id = offset + 4;
4258
4259 SPIRType type { OpTypeInt };
4260 switch (input.second.format)
4261 {
4262 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4263 case MSL_SHADER_VARIABLE_FORMAT_ANY16:
4264 type.basetype = SPIRType::UShort;
4265 type.width = 16;
4266 break;
4267 case MSL_SHADER_VARIABLE_FORMAT_ANY32:
4268 default:
4269 type.basetype = SPIRType::UInt;
4270 type.width = 32;
4271 break;
4272 }
4273 set<SPIRType>(id: type_id, args&: type);
4274 if (input.second.vecsize > 1)
4275 {
4276 type.op = OpTypeVector;
4277 type.vecsize = input.second.vecsize;
4278 set<SPIRType>(id: vec_type_id, args&: type);
4279 type_id = vec_type_id;
4280 }
4281
4282 type.op = OpTypeArray;
4283 type.array.push_back(t: 0);
4284 type.array_size_literal.push_back(t: true);
4285 type.parent_type = type_id;
4286 set<SPIRType>(id: array_type_id, args&: type);
4287 type.self = type_id;
4288
4289 type.op = OpTypePointer;
4290 type.pointer = true;
4291 type.pointer_depth++;
4292 type.parent_type = array_type_id;
4293 type.storage = storage;
4294 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: type);
4295 ptr_type.self = array_type_id;
4296
4297 auto &fake_var = set<SPIRVariable>(id: var_id, args&: ptr_type_id, args&: storage);
4298 set_decoration(id: var_id, decoration: DecorationLocation, argument: input.first.location);
4299 if (input.first.component)
4300 set_decoration(id: var_id, decoration: DecorationComponent, argument: input.first.component);
4301
4302 meta.strip_array = true;
4303 meta.allow_local_declaration = false;
4304 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: fake_var, meta);
4305 }
4306 }
4307
4308 if (capture_output_to_buffer && storage == StorageClassOutput)
4309 {
4310 // For captured output, add all inputs from the next stage to ensure
4311 // the struct containing them is the correct size and layout. This is
4312 // necessary for certain implicit builtins that may nonetheless be read,
4313 // even when they aren't written.
4314 for (auto &output : outputs_by_location)
4315 {
4316 if (location_outputs_in_use.count(x: output.first.location) != 0)
4317 continue;
4318
4319 // Create a fake variable to put at the location.
4320 uint32_t offset = ir.increase_bound_by(count: 5);
4321 uint32_t type_id = offset;
4322 uint32_t vec_type_id = offset + 1;
4323 uint32_t array_type_id = offset + 2;
4324 uint32_t ptr_type_id = offset + 3;
4325 uint32_t var_id = offset + 4;
4326
4327 SPIRType type { OpTypeInt };
4328 switch (output.second.format)
4329 {
4330 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4331 case MSL_SHADER_VARIABLE_FORMAT_ANY16:
4332 type.basetype = SPIRType::UShort;
4333 type.width = 16;
4334 break;
4335 case MSL_SHADER_VARIABLE_FORMAT_ANY32:
4336 default:
4337 type.basetype = SPIRType::UInt;
4338 type.width = 32;
4339 break;
4340 }
4341 set<SPIRType>(id: type_id, args&: type);
4342 if (output.second.vecsize > 1)
4343 {
4344 type.op = OpTypeVector;
4345 type.vecsize = output.second.vecsize;
4346 set<SPIRType>(id: vec_type_id, args&: type);
4347 type_id = vec_type_id;
4348 }
4349
4350 if (is_tesc_shader())
4351 {
4352 type.op = OpTypeArray;
4353 type.array.push_back(t: 0);
4354 type.array_size_literal.push_back(t: true);
4355 type.parent_type = type_id;
4356 set<SPIRType>(id: array_type_id, args&: type);
4357 }
4358
4359 type.op = OpTypePointer;
4360 type.pointer = true;
4361 type.pointer_depth++;
4362 type.parent_type = is_tesc_shader() ? array_type_id : type_id;
4363 type.storage = storage;
4364 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: type);
4365 ptr_type.self = type.parent_type;
4366
4367 auto &fake_var = set<SPIRVariable>(id: var_id, args&: ptr_type_id, args&: storage);
4368 set_decoration(id: var_id, decoration: DecorationLocation, argument: output.first.location);
4369 if (output.first.component)
4370 set_decoration(id: var_id, decoration: DecorationComponent, argument: output.first.component);
4371
4372 meta.strip_array = true;
4373 meta.allow_local_declaration = false;
4374 add_variable_to_interface_block(storage, ib_var_ref, ib_type, var&: fake_var, meta);
4375 }
4376 }
4377
	// When multiple variables need to access the same location,
	// unroll the locations one by one, flattening the output or input as necessary.
4380 for (auto &loc : meta.location_meta)
4381 {
4382 uint32_t location = loc.first;
4383 auto &location_meta = loc.second;
4384
4385 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
4386 uint32_t type_id = build_extended_vector_type(type_id: location_meta.base_type_id, components: location_meta.num_components);
4387 ib_type.member_types.push_back(t: type_id);
4388
4389 set_member_name(id: ib_type.self, index: ib_mbr_idx, name: join(ts: "m_location_", ts&: location));
4390 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationLocation, argument: location);
4391 mark_location_as_used_by_shader(location, type: get<SPIRType>(id: type_id), storage);
4392
4393 if (location_meta.flat)
4394 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationFlat);
4395 if (location_meta.noperspective)
4396 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationNoPerspective);
4397 if (location_meta.centroid)
4398 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationCentroid);
4399 if (location_meta.sample)
4400 set_member_decoration(id: ib_type.self, index: ib_mbr_idx, decoration: DecorationSample);
4401 }
4402
4403 // Sort the members of the structure by their locations.
4404 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType);
4405 member_sorter.sort();
4406
4407 // The member indices were saved to the original variables, but after the members
4408 // were sorted, those indices are now likely incorrect. Fix those up now.
4409 fix_up_interface_member_indices(storage, ib_type_id);
4410
4411 // For patch inputs, add one more member, holding the array of control point data.
4412 if (is_tese_shader() && !msl_options.raw_buffer_tese_input && storage == StorageClassInput && patch &&
4413 stage_in_var_id)
4414 {
4415 uint32_t pcp_type_id = ir.increase_bound_by(count: 1);
4416 auto &pcp_type = set<SPIRType>(id: pcp_type_id, args&: ib_type);
4417 pcp_type.basetype = SPIRType::ControlPointArray;
4418 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
4419 pcp_type.storage = storage;
4420 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
4421 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
4422 ib_type.member_types.push_back(t: pcp_type_id);
4423 set_member_name(id: ib_type.self, index: mbr_idx, name: "gl_in");
4424 }
4425
4426 if (storage == StorageClassInput)
4427 set_decoration(id: ib_var_id, decoration: DecorationNonWritable);
4428
4429 return ib_var_id;
4430}
4431
4432uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
4433{
4434 if (!ib_var_id)
4435 return 0;
4436
4437 uint32_t ib_ptr_var_id;
4438 uint32_t next_id = ir.increase_bound_by(count: 3);
4439 auto &ib_type = expression_type(id: ib_var_id);
4440 if (is_tesc_shader() || (is_tese_shader() && msl_options.raw_buffer_tese_input))
4441 {
4442 // Tessellation control per-vertex I/O is presented as an array, so we must
4443 // do the same with our struct here.
4444 uint32_t ib_ptr_type_id = next_id++;
4445 auto &ib_ptr_type = set<SPIRType>(id: ib_ptr_type_id, args: ib_type);
4446 ib_ptr_type.op = OpTypePointer;
4447 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
4448 ib_ptr_type.pointer = true;
4449 ib_ptr_type.pointer_depth++;
4450 ib_ptr_type.storage = storage == StorageClassInput ?
4451 ((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
4452 (is_tese_shader() && msl_options.raw_buffer_tese_input) ?
4453 StorageClassStorageBuffer :
4454 StorageClassWorkgroup) :
4455 StorageClassStorageBuffer;
4456 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
4457 // To ensure that get_variable_data_type() doesn't strip off the pointer,
4458 // which we need, use another pointer.
4459 uint32_t ib_ptr_ptr_type_id = next_id++;
4460 auto &ib_ptr_ptr_type = set<SPIRType>(id: ib_ptr_ptr_type_id, args&: ib_ptr_type);
4461 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
4462 ib_ptr_ptr_type.type_alias = ib_type.self;
4463 ib_ptr_ptr_type.storage = StorageClassFunction;
4464 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
4465
4466 ib_ptr_var_id = next_id;
4467 set<SPIRVariable>(id: ib_ptr_var_id, args&: ib_ptr_ptr_type_id, args: StorageClassFunction, args: 0);
4468 set_name(id: ib_ptr_var_id, name: storage == StorageClassInput ? "gl_in" : "gl_out");
4469 if (storage == StorageClassInput)
4470 set_decoration(id: ib_ptr_var_id, decoration: DecorationNonWritable);
4471 }
4472 else
4473 {
4474 // Tessellation evaluation per-vertex inputs are also presented as arrays.
4475 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
4476 // which is a container that can be used to access the control point data.
4477 // To represent this, a special 'ControlPointArray' type has been added to the
4478 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
4479 // backend (i.e. this one).
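		// In the generated MSL, reads then go through the patch stage-in struct, e.g.
		// patchIn.gl_in[i].some_member (assuming the default "patchIn" name; see the
		// qualified alias set below).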
4480 uint32_t pcp_type_id = next_id++;
4481 auto &pcp_type = set<SPIRType>(id: pcp_type_id, args: ib_type);
4482 pcp_type.basetype = SPIRType::ControlPointArray;
4483 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
4484 pcp_type.storage = storage;
4485 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
4486
4487 ib_ptr_var_id = next_id;
4488 set<SPIRVariable>(id: ib_ptr_var_id, args&: pcp_type_id, args&: storage, args: 0);
4489 set_name(id: ib_ptr_var_id, name: "gl_in");
4490 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(ts&: patch_stage_in_var_name, ts: ".gl_in");
4491 }
4492 return ib_ptr_var_id;
4493}
4494
// Ensure that the type is compatible with the builtin.
// If it is, simply return the given type ID.
// Otherwise, create a new type, and return its ID.
4498uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
4499{
4500 auto &type = get<SPIRType>(id: type_id);
4501 auto &pointee_type = get_pointee_type(type);
4502
4503 if ((builtin == BuiltInSampleMask && is_array(type: pointee_type)) ||
4504 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
4505 pointee_type.basetype != SPIRType::UInt))
4506 {
4507 uint32_t next_id = ir.increase_bound_by(count: is_pointer(type) ? 2 : 1);
4508 uint32_t base_type_id = next_id++;
4509 auto &base_type = set<SPIRType>(id: base_type_id, args: OpTypeInt);
4510 base_type.basetype = SPIRType::UInt;
4511 base_type.width = 32;
4512
4513 if (!is_pointer(type))
4514 return base_type_id;
4515
4516 uint32_t ptr_type_id = next_id++;
4517 auto &ptr_type = set<SPIRType>(id: ptr_type_id, args&: base_type);
4518 ptr_type.op = spv::OpTypePointer;
4519 ptr_type.pointer = true;
4520 ptr_type.pointer_depth++;
4521 ptr_type.storage = type.storage;
4522 ptr_type.parent_type = base_type_id;
4523 return ptr_type_id;
4524 }
4525
4526 return type_id;
4527}
4528
4529// Ensure that the type is compatible with the shader input.
4530// If it is, simply return the given type ID.
4531// Otherwise, create a new type, and return its ID.
4532uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t component, uint32_t num_components, bool strip_array)
4533{
4534 auto &type = get<SPIRType>(id: type_id);
4535
4536 uint32_t max_array_dimensions = strip_array ? 1 : 0;
4537
4538 // Struct and array types must match exactly.
4539 if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions)
4540 return type_id;
4541
4542 auto p_va = inputs_by_location.find(x: {.location: location, .component: component});
4543 if (p_va == end(cont&: inputs_by_location))
4544 {
4545 if (num_components > type.vecsize)
4546 return build_extended_vector_type(type_id, components: num_components);
4547 else
4548 return type_id;
4549 }
4550
4551 if (num_components == 0)
4552 num_components = p_va->second.vecsize;
4553
4554 switch (p_va->second.format)
4555 {
4556 case MSL_SHADER_VARIABLE_FORMAT_UINT8:
4557 {
4558 switch (type.basetype)
4559 {
4560 case SPIRType::UByte:
4561 case SPIRType::UShort:
4562 case SPIRType::UInt:
4563 if (num_components > type.vecsize)
4564 return build_extended_vector_type(type_id, components: num_components);
4565 else
4566 return type_id;
4567
4568 case SPIRType::Short:
4569 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4570 basetype: SPIRType::UShort);
4571 case SPIRType::Int:
4572 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4573 basetype: SPIRType::UInt);
4574
4575 default:
4576 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
4577 }
4578 }
4579
4580 case MSL_SHADER_VARIABLE_FORMAT_UINT16:
4581 {
4582 switch (type.basetype)
4583 {
4584 case SPIRType::UShort:
4585 case SPIRType::UInt:
4586 if (num_components > type.vecsize)
4587 return build_extended_vector_type(type_id, components: num_components);
4588 else
4589 return type_id;
4590
4591 case SPIRType::Int:
4592 return build_extended_vector_type(type_id, components: num_components > type.vecsize ? num_components : type.vecsize,
4593 basetype: SPIRType::UInt);
4594
4595 default:
4596 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
4597 }
4598 }
4599
4600 default:
4601 if (num_components > type.vecsize)
4602 type_id = build_extended_vector_type(type_id, components: num_components);
4603 break;
4604 }
4605
4606 return type_id;
4607}
4608
4609void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
4610{
4611 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
4612 if (has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPhysicalTypePacked))
4613 return;
4614
4615 set_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4616
4617 // Problem case! Struct needs to be placed at an awkward alignment.
4618 // Mark every member of the child struct as packed.
4619 uint32_t mbr_cnt = uint32_t(type.member_types.size());
4620 for (uint32_t i = 0; i < mbr_cnt; i++)
4621 {
4622 auto &mbr_type = get<SPIRType>(id: type.member_types[i]);
4623 if (mbr_type.basetype == SPIRType::Struct)
4624 {
4625 // Recursively mark structs as packed.
4626 auto *struct_type = &mbr_type;
4627 while (!struct_type->array.empty())
4628 struct_type = &get<SPIRType>(id: struct_type->parent_type);
4629 mark_struct_members_packed(type: *struct_type);
4630 }
4631 else if (!is_scalar(type: mbr_type))
4632 set_extended_member_decoration(type: type.self, index: i, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4633 }
4634}
4635
4636void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
4637{
4638 uint32_t mbr_cnt = uint32_t(type.member_types.size());
4639 for (uint32_t i = 0; i < mbr_cnt; i++)
4640 {
4641 // Handle possible recursion when a struct contains a pointer to its own type nested somewhere.
4642 auto &mbr_type = get<SPIRType>(id: type.member_types[i]);
4643 if (mbr_type.basetype == SPIRType::Struct && !(mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer))
4644 {
4645 auto *struct_type = &mbr_type;
4646 while (!struct_type->array.empty())
4647 struct_type = &get<SPIRType>(id: struct_type->parent_type);
4648
4649 if (has_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPhysicalTypePacked))
4650 continue;
4651
4652 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(struct_type: type, index: i);
4653 uint32_t msl_size = get_declared_struct_member_size_msl(struct_type: type, index: i);
4654 uint32_t spirv_offset = type_struct_member_offset(type, index: i);
4655 uint32_t spirv_offset_next;
4656 if (i + 1 < mbr_cnt)
4657 spirv_offset_next = type_struct_member_offset(type, index: i + 1);
4658 else
4659 spirv_offset_next = spirv_offset + msl_size;
4660
4661 // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
4662 // and the next member will be placed at offset 12.
4663 bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
4664 bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
4665 uint32_t array_stride = 0;
4666 bool struct_needs_explicit_padding = false;
4667
			// Verify that, if a struct is used as an array, the ArrayStride matches the effective size of the struct.
4669 if (!mbr_type.array.empty())
4670 {
4671 array_stride = type_struct_member_array_stride(type, index: i);
4672 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
4673 for (uint32_t dim = 0; dim < dimensions; dim++)
4674 {
4675 uint32_t array_size = to_array_size_literal(type: mbr_type, index: dim);
4676 array_stride /= max<uint32_t>(a: array_size, b: 1u);
4677 }
4678
4679 // Set expected struct size based on ArrayStride.
4680 struct_needs_explicit_padding = true;
4681
4682 // If struct size is larger than array stride, we might be able to fit, if we tightly pack.
4683 if (get_declared_struct_size_msl(struct_type: *struct_type) > array_stride)
4684 struct_is_too_large = true;
4685 }
4686
4687 if (struct_is_misaligned || struct_is_too_large)
4688 mark_struct_members_packed(type: *struct_type);
4689 mark_scalar_layout_structs(type: *struct_type);
4690
4691 if (struct_needs_explicit_padding)
4692 {
4693 msl_size = get_declared_struct_size_msl(struct_type: *struct_type, ignore_alignment: true, ignore_padding: true);
4694 if (array_stride < msl_size)
4695 {
4696 SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
4697 }
4698 else
4699 {
4700 if (has_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget))
4701 {
4702 if (array_stride !=
4703 get_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget))
4704 SPIRV_CROSS_THROW(
4705 "A struct is used with different array strides. Cannot express this in MSL.");
4706 }
4707 else
4708 set_extended_decoration(id: struct_type->self, decoration: SPIRVCrossDecorationPaddingTarget, value: array_stride);
4709 }
4710 }
4711 }
4712 }
4713}
4714
4715// Sort the members of the struct type by offset, and pack and then pad members where needed
4716// to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
4717// occurs first, followed by padding, because packing a member reduces both its size and its
4718// natural alignment, possibly requiring a padding member to be added ahead of it.
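// For example, a SPIR-V block { vec3 a; float b; } with offsets 0 and 12 must become
// { packed_float3 a; float b; } in MSL: a native float3 is 16-byte aligned and 16 bytes
// in size, which would push 'b' to offset 16 instead of the required 12.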
4719void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
4720{
4721 // We align structs recursively, so stop any redundant work.
4722 ID &ib_type_id = ib_type.self;
4723 if (aligned_structs.count(x: ib_type_id))
4724 return;
4725 aligned_structs.insert(x: ib_type_id);
4726
4727 // Sort the members of the interface structure by their offset.
4728 // They should already be sorted per SPIR-V spec anyway.
4729 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
4730 member_sorter.sort();
4731
4732 auto mbr_cnt = uint32_t(ib_type.member_types.size());
4733
4734 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4735 {
4736 // Pack any dependent struct types before we pack a parent struct.
4737 auto &mbr_type = get<SPIRType>(id: ib_type.member_types[mbr_idx]);
4738 if (mbr_type.basetype == SPIRType::Struct)
4739 align_struct(ib_type&: mbr_type, aligned_structs);
4740 }
4741
4742 // Test the alignment of each member, and if a member should be closer to the previous
4743 // member than the default spacing expects, it is likely that the previous member is in
4744 // a packed format. If so, and the previous member is packable, pack it.
4745 // For example ... this applies to any 3-element vector that is followed by a scalar.
4746 uint32_t msl_offset = 0;
4747 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4748 {
4749 // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
4750 // offsets, array strides and matrix strides.
4751 ensure_member_packing_rules_msl(ib_type, index: mbr_idx);
4752
4753 // Align current offset to the current member's default alignment. If the member was packed, it will observe
4754 // the updated alignment here.
4755 uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(struct_type: ib_type, index: mbr_idx) - 1;
4756 uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
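		// E.g. for a member with 16-byte alignment, msl_align_mask is 15, and
		// (msl_offset + 15) & ~15 rounds the offset up to the next multiple of 16.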
4757
4758 // Fetch the member offset as declared in the SPIRV.
4759 uint32_t spirv_mbr_offset = get_member_decoration(id: ib_type_id, index: mbr_idx, decoration: DecorationOffset);
4760 if (spirv_mbr_offset > aligned_msl_offset)
4761 {
			// Since MSL and SPIR-V have slightly different struct member alignment and
			// size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
			// away than C-packing expects, add an inert padding member before the member.
4765 uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
4766 set_extended_member_decoration(type: ib_type_id, index: mbr_idx, decoration: SPIRVCrossDecorationPaddingTarget, value: padding_bytes);
4767
4768 // Re-align as a sanity check that aligning post-padding matches up.
4769 msl_offset += padding_bytes;
4770 aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4771 }
4772 else if (spirv_mbr_offset < aligned_msl_offset)
4773 {
4774 // This should not happen, but deal with unexpected scenarios.
4775 // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
4776 SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
4777 }
4778
4779 assert(aligned_msl_offset == spirv_mbr_offset);
4780
4781 // Increment the current offset to be positioned immediately after the current member.
4782 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
4783 if (mbr_idx + 1 < mbr_cnt)
4784 msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(struct_type: ib_type, index: mbr_idx);
4785 }
4786}
4787
4788bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
4789{
4790 auto &mbr_type = get<SPIRType>(id: type.member_types[index]);
4791 uint32_t spirv_offset = get_member_decoration(id: type.self, index, decoration: DecorationOffset);
4792
4793 if (index + 1 < type.member_types.size())
4794 {
4795 // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
4796 // we *must* perform some kind of remapping, no way getting around it.
4797 // We can always pad after this member if necessary, so that case is fine.
4798 uint32_t spirv_offset_next = get_member_decoration(id: type.self, index: index + 1, decoration: DecorationOffset);
4799 assert(spirv_offset_next >= spirv_offset);
4800 uint32_t maximum_size = spirv_offset_next - spirv_offset;
4801 uint32_t msl_mbr_size = get_declared_struct_member_size_msl(struct_type: type, index);
4802 if (msl_mbr_size > maximum_size)
4803 return false;
4804 }
4805
4806 if (!mbr_type.array.empty())
4807 {
4808 // If we have an array type, array stride must match exactly with SPIR-V.
4809
4810 // An exception to this requirement is if we have one array element.
4811 // This comes from DX scalar layout workaround.
4812 // If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
4813 // In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
4814 bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
4815
4816 if (!relax_array_stride)
4817 {
4818 uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
4819 uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(struct_type: type, index);
4820 if (spirv_array_stride != msl_array_stride)
4821 return false;
4822 }
4823 }
4824
4825 if (is_matrix(type: mbr_type))
4826 {
4827 // Need to check MatrixStride as well.
4828 uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
4829 uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(struct_type: type, index);
4830 if (spirv_matrix_stride != msl_matrix_stride)
4831 return false;
4832 }
4833
4834 // Now, we check alignment.
4835 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(struct_type: type, index);
4836 if ((spirv_offset % msl_alignment) != 0)
4837 return false;
4838
4839 // We're in the clear.
4840 return true;
4841}
4842
4843// Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
4844// If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
4845// In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
4846void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
4847{
4848 if (validate_member_packing_rules_msl(type: ib_type, index))
4849 return;
4850
4851 // We failed validation.
4852 // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
4853 // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
4854 // that struct alignment == max alignment of all members and struct size depends on this alignment.
4855 // Can't repack structs, but can repack pointers to structs.
4856 auto &mbr_type = get<SPIRType>(id: ib_type.member_types[index]);
4857 bool is_buff_ptr = mbr_type.pointer && mbr_type.storage == StorageClassPhysicalStorageBuffer;
4858 if (mbr_type.basetype == SPIRType::Struct && !is_buff_ptr)
4859 SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
4860
4861 // Perform remapping here.
4862 // There is nothing to be gained by using packed scalars, so don't attempt it.
4863 if (!is_scalar(type: ib_type))
4864 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4865
4866 // Try validating again, now with packed.
4867 if (validate_member_packing_rules_msl(type: ib_type, index))
4868 return;
4869
4870 // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
4871 // A lot of work goes here ...
4872 // We will need remapping on Load and Store to translate the types between Logical and Physical.
4873
	// First, we check if we have a small-vector std140 array.
	// We detect this when we have an array of vectors whose ArrayStride spans more elements than the vector contains.
4876 if (!mbr_type.array.empty() && !is_matrix(type: mbr_type))
4877 {
4878 uint32_t array_stride = type_struct_member_array_stride(type: ib_type, index);
4879
4880 // Hack off array-of-arrays until we find the array stride per element we must have to make it work.
4881 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
4882 for (uint32_t dim = 0; dim < dimensions; dim++)
4883 array_stride /= max<uint32_t>(a: to_array_size_literal(type: mbr_type, index: dim), b: 1u);
4884
4885 // Pointers are 8 bytes
4886 uint32_t mbr_width_in_bytes = is_buff_ptr ? 8 : (mbr_type.width / 8);
4887 uint32_t elems_per_stride = array_stride / mbr_width_in_bytes;
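		// E.g. a float2 array with an ArrayStride of 16 yields elems_per_stride == 4, so the
		// physical type below becomes a float4 array, and loads swizzle back down to .xy.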
4888
4889 if (elems_per_stride == 3)
4890 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4891 else if (elems_per_stride > 4 && elems_per_stride != 8)
4892 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4893
4894 if (elems_per_stride == 8)
4895 {
4896 if (mbr_type.width == 16)
4897 add_spv_func_and_recompile(spv_func: SPVFuncImplPaddedStd140);
4898 else
4899 SPIRV_CROSS_THROW("Unexpected type in std140 wide array resolve.");
4900 }
4901
4902 auto physical_type = mbr_type;
4903 physical_type.vecsize = elems_per_stride;
4904 physical_type.parent_type = 0;
4905
4906 // If this is a physical buffer pointer, replace type with a ulongn vector.
4907 if (is_buff_ptr)
4908 {
4909 physical_type.width = 64;
4910 physical_type.basetype = to_unsigned_basetype(width: physical_type.width);
4911 physical_type.pointer = false;
			physical_type.pointer_depth = 0;
4913 physical_type.forward_pointer = false;
4914 }
4915
4916 uint32_t type_id = ir.increase_bound_by(count: 1);
4917 set<SPIRType>(id: type_id, args&: physical_type);
4918 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID, value: type_id);
4919 set_decoration(id: type_id, decoration: DecorationArrayStride, argument: array_stride);
4920
4921 // Remove packed_ for vectors of size 1, 2 and 4.
4922 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4923 }
4924 else if (is_matrix(type: mbr_type))
4925 {
4926 // MatrixStride might be std140-esque.
4927 uint32_t matrix_stride = type_struct_member_matrix_stride(type: ib_type, index);
4928
4929 uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
4930
4931 if (elems_per_stride == 3)
4932 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4933 else if (elems_per_stride > 4 && elems_per_stride != 8)
4934 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4935
4936 if (elems_per_stride == 8)
4937 {
4938 if (mbr_type.basetype != SPIRType::Half)
4939 SPIRV_CROSS_THROW("Unexpected type in std140 wide matrix stride resolve.");
4940 add_spv_func_and_recompile(spv_func: SPVFuncImplPaddedStd140);
4941 }
4942
4943 bool row_major = has_member_decoration(id: ib_type.self, index, decoration: DecorationRowMajor);
4944 auto physical_type = mbr_type;
4945 physical_type.parent_type = 0;
4946
4947 if (row_major)
4948 physical_type.columns = elems_per_stride;
4949 else
4950 physical_type.vecsize = elems_per_stride;
4951 uint32_t type_id = ir.increase_bound_by(count: 1);
4952 set<SPIRType>(id: type_id, args&: physical_type);
4953 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID, value: type_id);
4954
4955 // Remove packed_ for vectors of size 1, 2 and 4.
4956 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4957 }
4958 else
4959 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
4960
4961 // Try validating again, now with physical type remapping.
4962 if (validate_member_packing_rules_msl(type: ib_type, index))
4963 return;
4964
4965 // We might have a particular odd scalar layout case where the last element of an array
4966 // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
4967 // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
4968 // so we hack around it by declaring the offending array or matrix with one less array size/col/row,
4969 // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
4970 // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
4971
4972 // E.g. we might observe a physical layout of:
4973 // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
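	// With the workaround below, 'a' in that example would be declared as a one-element
	// array (e.g. float4 a[1]), and accesses to a[1] deliberately spill into the padding
	// bytes that follow it.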
4974 uint32_t type_id = get_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
4975 auto &type = get<SPIRType>(id: type_id);
4976
4977 // Modify the physical type in-place. This is safe since each physical type workaround is a copy.
4978 if (is_array(type))
4979 {
4980 if (type.array.back() > 1)
4981 {
4982 if (!type.array_size_literal.back())
4983 SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
4984 type.array.back() -= 1;
4985 }
4986 else
4987 {
4988 // We have an array of size 1, so we cannot decrement that. Our only option now is to
4989 // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
4990 unset_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
4991 set_extended_member_decoration(type: ib_type.self, index, decoration: SPIRVCrossDecorationPhysicalTypePacked);
4992 }
4993 }
4994 else if (is_matrix(type))
4995 {
4996 bool row_major = has_member_decoration(id: ib_type.self, index, decoration: DecorationRowMajor);
4997 if (!row_major)
4998 {
4999 // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
5000 if (type.columns > 2)
5001 {
5002 type.columns--;
5003 }
5004 else if (type.columns == 2)
5005 {
5006 type.columns = 1;
5007 assert(type.array.empty());
5008 type.op = OpTypeArray;
5009 type.array.push_back(t: 1);
5010 type.array_size_literal.push_back(t: true);
5011 }
5012 }
5013 else
5014 {
5015 // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
5016 if (type.vecsize > 2)
5017 {
5018 type.vecsize--;
5019 }
5020 else if (type.vecsize == 2)
5021 {
5022 type.vecsize = type.columns;
5023 type.columns = 1;
5024 assert(type.array.empty());
5025 type.op = OpTypeArray;
5026 type.array.push_back(t: 1);
5027 type.array_size_literal.push_back(t: true);
5028 }
5029 }
5030 }
5031
5032 // This better validate now, or we must fail gracefully.
5033 if (!validate_member_packing_rules_msl(type: ib_type, index))
5034 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
5035}
5036
5037void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
5038{
5039 auto &type = expression_type(id: rhs_expression);
5040
5041 bool lhs_remapped_type = has_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypeID);
5042 bool lhs_packed_type = has_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypePacked);
5043 auto *lhs_e = maybe_get<SPIRExpression>(id: lhs_expression);
5044 auto *rhs_e = maybe_get<SPIRExpression>(id: rhs_expression);
5045
5046 bool transpose = lhs_e && lhs_e->need_transpose;
5047
5048 if (has_decoration(id: lhs_expression, decoration: DecorationBuiltIn) &&
5049 BuiltIn(get_decoration(id: lhs_expression, decoration: DecorationBuiltIn)) == BuiltInSampleMask &&
5050 is_array(type))
5051 {
5052 // Storing an array to SampleMask, have to remove the array-ness before storing.
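		// E.g. this emits 'gl_SampleMask = <rhs>[0];', since SPIR-V models gl_SampleMask as an array type.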
5053 statement(ts: to_expression(id: lhs_expression), ts: " = ", ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[0];");
5054 register_write(chain: lhs_expression);
5055 }
5056 else if (!lhs_remapped_type && !lhs_packed_type)
5057 {
5058 // No physical type remapping, and no packed type, so can just emit a store directly.
5059
		// We might not be dealing with remapped physical types or packed types,
		// but we might be doing a clean store to a row-major matrix.
		// In this case, we just flip transpose states and emit the store; any transpose must then be carried by the RHS expression.
5063 if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
5064 {
5065 lhs_e->need_transpose = false;
5066
5067 if (rhs_e && rhs_e->need_transpose)
5068 {
5069 // Direct copy, but might need to unpack RHS.
5070 // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
5071 rhs_e->need_transpose = false;
5072 statement(ts: to_expression(id: lhs_expression), ts: " = ", ts: to_unpacked_row_major_matrix_expression(id: rhs_expression),
5073 ts: ";");
5074 rhs_e->need_transpose = true;
5075 }
5076 else
5077 statement(ts: to_expression(id: lhs_expression), ts: " = transpose(", ts: to_unpacked_expression(id: rhs_expression), ts: ");");
5078
5079 lhs_e->need_transpose = true;
5080 register_write(chain: lhs_expression);
5081 }
5082 else if (lhs_e && lhs_e->need_transpose)
5083 {
5084 lhs_e->need_transpose = false;
5085
5086 // Storing a column to a row-major matrix. Unroll the write.
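			// E.g. storing vector 'v' to column 1 of a row-major matrix 'm' unrolls to
			// one scalar store per component: m[0][1] = v.x; m[1][1] = v.y; ...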
5087 for (uint32_t c = 0; c < type.vecsize; c++)
5088 {
5089 auto lhs_expr = to_dereferenced_expression(id: lhs_expression);
5090 auto column_index = lhs_expr.find_last_of(c: '[');
5091 if (column_index != string::npos)
5092 {
5093 statement(ts&: lhs_expr.insert(pos1: column_index, str: join(ts: '[', ts&: c, ts: ']')), ts: " = ",
5094 ts: to_extract_component_expression(id: rhs_expression, index: c), ts: ";");
5095 }
5096 }
5097 lhs_e->need_transpose = true;
5098 register_write(chain: lhs_expression);
5099 }
5100 else
5101 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
5102 }
5103 else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
5104 {
5105 // Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
5106 // since they are declared as array of vectors instead, and we need the fallback path below.
5107 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
5108 }
5109 else
5110 {
5111 // Special handling when storing to a remapped physical type.
5112 // This is mostly to deal with std140 padded matrices or vectors.
5113
5114 TypeID physical_type_id = lhs_remapped_type ?
5115 ID(get_extended_decoration(id: lhs_expression, decoration: SPIRVCrossDecorationPhysicalTypeID)) :
5116 type.self;
5117
5118 auto &physical_type = get<SPIRType>(id: physical_type_id);
5119
5120 string cast_addr_space = "thread";
5121 auto *p_var_lhs = maybe_get_backing_variable(chain: lhs_expression);
5122 if (p_var_lhs)
5123 cast_addr_space = get_type_address_space(type: get<SPIRType>(id: p_var_lhs->basetype), id: lhs_expression);
5124
5125 if (is_matrix(type))
5126 {
5127 const char *packed_pfx = lhs_packed_type ? "packed_" : "";
5128
5129 // Packed matrices are stored as arrays of packed vectors, so we need
5130 // to assign the vectors one at a time.
5131 // For row-major matrices, we need to transpose the *right-hand* side,
5132 // not the left-hand side.
5133
5134 // Lots of cases to cover here ...
5135
5136 bool rhs_transpose = rhs_e && rhs_e->need_transpose;
5137 SPIRType write_type = type;
5138 string cast_expr;
5139
5140 // We're dealing with transpose manually.
5141 if (rhs_transpose)
5142 rhs_e->need_transpose = false;
5143
5144 if (transpose)
5145 {
5146 // We're dealing with transpose manually.
5147 lhs_e->need_transpose = false;
5148 write_type.vecsize = type.columns;
5149 write_type.columns = 1;
5150
5151 if (physical_type.columns != type.columns)
5152 cast_expr = join(ts: "(", ts&: cast_addr_space, ts: " ", ts&: packed_pfx, ts: type_to_glsl(type: write_type), ts: "&)");
5153
5154 if (rhs_transpose)
5155 {
5156 // If RHS is also transposed, we can just copy row by row.
5157 for (uint32_t i = 0; i < type.vecsize; i++)
5158 {
5159 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ",
5160 ts: to_unpacked_row_major_matrix_expression(id: rhs_expression), ts: "[", ts&: i, ts: "];");
5161 }
5162 }
5163 else
5164 {
5165 auto vector_type = expression_type(id: rhs_expression);
5166 vector_type.vecsize = vector_type.columns;
5167 vector_type.columns = 1;
5168
5169 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
5170 // so pick out individual components instead.
5171 for (uint32_t i = 0; i < type.vecsize; i++)
5172 {
5173 string rhs_row = type_to_glsl_constructor(type: vector_type) + "(";
5174 for (uint32_t j = 0; j < vector_type.vecsize; j++)
5175 {
5176 rhs_row += join(ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[", ts&: j, ts: "][", ts&: i, ts: "]");
5177 if (j + 1 < vector_type.vecsize)
5178 rhs_row += ", ";
5179 }
5180 rhs_row += ")";
5181
5182 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ", ts&: rhs_row, ts: ";");
5183 }
5184 }
5185
5186 // We're dealing with transpose manually.
5187 lhs_e->need_transpose = true;
5188 }
5189 else
5190 {
5191 write_type.columns = 1;
5192
5193 if (physical_type.vecsize != type.vecsize)
5194 cast_expr = join(ts: "(", ts&: cast_addr_space, ts: " ", ts&: packed_pfx, ts: type_to_glsl(type: write_type), ts: "&)");
5195
5196 if (rhs_transpose)
5197 {
5198 auto vector_type = expression_type(id: rhs_expression);
5199 vector_type.columns = 1;
5200
5201 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
5202 // so pick out individual components instead.
5203 for (uint32_t i = 0; i < type.columns; i++)
5204 {
5205 string rhs_row = type_to_glsl_constructor(type: vector_type) + "(";
5206 for (uint32_t j = 0; j < vector_type.vecsize; j++)
5207 {
5208 // Need to explicitly unpack expression since we've mucked with transpose state.
5209 auto unpacked_expr = to_unpacked_row_major_matrix_expression(id: rhs_expression);
5210 rhs_row += join(ts&: unpacked_expr, ts: "[", ts&: j, ts: "][", ts&: i, ts: "]");
5211 if (j + 1 < vector_type.vecsize)
5212 rhs_row += ", ";
5213 }
5214 rhs_row += ")";
5215
5216 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ", ts&: rhs_row, ts: ";");
5217 }
5218 }
5219 else
5220 {
5221 // Copy column-by-column.
5222 for (uint32_t i = 0; i < type.columns; i++)
5223 {
5224 statement(ts&: cast_expr, ts: to_enclosed_expression(id: lhs_expression), ts: "[", ts&: i, ts: "]", ts: " = ",
5225 ts: to_enclosed_unpacked_expression(id: rhs_expression), ts: "[", ts&: i, ts: "];");
5226 }
5227 }
5228 }
5229
5230 // We're dealing with transpose manually.
5231 if (rhs_transpose)
5232 rhs_e->need_transpose = true;
5233 }
5234 else if (transpose)
5235 {
5236 lhs_e->need_transpose = false;
5237
5238 SPIRType write_type = type;
5239 write_type.vecsize = 1;
5240 write_type.columns = 1;
5241
5242 // Storing a column to a row-major matrix. Unroll the write.
5243 for (uint32_t c = 0; c < type.vecsize; c++)
5244 {
5245 auto lhs_expr = to_enclosed_expression(id: lhs_expression);
5246 auto column_index = lhs_expr.find_last_of(c: '[');
5247
				// Get rid of any ".data" half8 handling here; we're casting to scalar anyway.
5249 auto end_column_index = lhs_expr.find_last_of(c: ']');
5250 auto end_dot_index = lhs_expr.find_last_of(c: '.');
5251 if (end_dot_index != string::npos && end_dot_index > end_column_index)
5252 lhs_expr.resize(n: end_dot_index);
5253
5254 if (column_index != string::npos)
5255 {
5256 statement(ts: "((", ts&: cast_addr_space, ts: " ", ts: type_to_glsl(type: write_type), ts: "*)&",
5257 ts&: lhs_expr.insert(pos1: column_index, str: join(ts: '[', ts&: c, ts: ']', ts: ")")), ts: " = ",
5258 ts: to_extract_component_expression(id: rhs_expression, index: c), ts: ";");
5259 }
5260 }
5261
5262 lhs_e->need_transpose = true;
5263 }
5264 else if ((is_matrix(type: physical_type) || is_array(type: physical_type)) &&
5265 physical_type.vecsize <= 4 &&
5266 physical_type.vecsize > type.vecsize)
5267 {
5268 assert(type.vecsize >= 1 && type.vecsize <= 3);
5269
5270 // If we have packed types, we cannot use swizzled stores.
5271 // We could technically unroll the store for each element if needed.
5272 // When remapping to a std140 physical type, we always get float4,
5273 // and the packed decoration should always be removed.
5274 assert(!lhs_packed_type);
5275
5276 string lhs = to_dereferenced_expression(id: lhs_expression);
5277 string rhs = to_pointer_expression(id: rhs_expression);
5278
5279 // Unpack the expression so we can store to it with a float or float2.
5280 // It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
5281 lhs = join(ts: "(", ts&: cast_addr_space, ts: " ", ts: type_to_glsl(type), ts: "&)", ts: enclose_expression(expr: lhs));
5282 if (!optimize_read_modify_write(type: expression_type(id: rhs_expression), lhs, rhs))
5283 statement(ts&: lhs, ts: " = ", ts&: rhs, ts: ";");
5284 }
5285 else if (!is_matrix(type))
5286 {
5287 string lhs = to_dereferenced_expression(id: lhs_expression);
5288 string rhs = to_pointer_expression(id: rhs_expression);
5289 if (!optimize_read_modify_write(type: expression_type(id: rhs_expression), lhs, rhs))
5290 statement(ts&: lhs, ts: " = ", ts&: rhs, ts: ";");
5291 }
5292
5293 register_write(chain: lhs_expression);
5294 }
5295}
5296
5297static bool expression_ends_with(const string &expr_str, const std::string &ending)
5298{
5299 if (expr_str.length() >= ending.length())
5300 return (expr_str.compare(pos: expr_str.length() - ending.length(), n: ending.length(), str: ending) == 0);
5301 else
5302 return false;
5303}
5304
5305// Converts the format of the current expression from packed to unpacked,
5306// by wrapping the expression in a constructor of the appropriate type.
5307// Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
5308string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
5309 bool packed, bool row_major)
5310{
5311 // Trivial case, nothing to do.
5312 if (physical_type_id == 0 && !packed)
5313 return expr_str;
5314
5315 const SPIRType *physical_type = nullptr;
5316 if (physical_type_id)
5317 physical_type = &get<SPIRType>(id: physical_type_id);
5318
5319 static const char *swizzle_lut[] = {
5320 ".x",
5321 ".xy",
5322 ".xyz",
5323 "",
5324 };
5325
5326 // TODO: Move everything to the template wrapper?
5327 bool uses_std140_wrapper = physical_type && physical_type->vecsize > 4;
5328
5329 if (physical_type && is_vector(type: *physical_type) && is_array(type: *physical_type) &&
5330 !uses_std140_wrapper &&
5331 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, ending: swizzle_lut[type.vecsize - 1]))
5332 {
5333 // std140 array cases for vectors.
5334 assert(type.vecsize >= 1 && type.vecsize <= 3);
5335 return enclose_expression(expr: expr_str) + swizzle_lut[type.vecsize - 1];
5336 }
5337 else if (physical_type && is_matrix(type: *physical_type) && is_vector(type) &&
5338 !uses_std140_wrapper &&
5339 physical_type->vecsize > type.vecsize)
5340 {
5341 // Extract column from padded matrix.
5342 assert(type.vecsize >= 1 && type.vecsize <= 4);
5343 return enclose_expression(expr: expr_str) + swizzle_lut[type.vecsize - 1];
5344 }
5345 else if (is_matrix(type))
5346 {
5347 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
5348 // we can't just pass the array straight to the matrix constructor. We have to
5349 // pass each vector individually, so that they can be unpacked to normal vectors.
5350 if (!physical_type)
5351 physical_type = &type;
5352
5353 uint32_t vecsize = type.vecsize;
5354 uint32_t columns = type.columns;
5355 if (row_major)
5356 swap(a&: vecsize, b&: columns);
5357
5358 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
5359
5360 const char *base_type = type.width == 16 ? "half" : "float";
5361 string unpack_expr = join(ts&: base_type, ts&: columns, ts: "x", ts&: vecsize, ts: "(");
5362
5363 const char *load_swiz = "";
5364 const char *data_swiz = physical_vecsize > 4 ? ".data" : "";
5365
5366 if (physical_vecsize != vecsize)
5367 load_swiz = swizzle_lut[vecsize - 1];
5368
5369 for (uint32_t i = 0; i < columns; i++)
5370 {
5371 if (i > 0)
5372 unpack_expr += ", ";
5373
5374 if (packed)
5375 unpack_expr += join(ts&: base_type, ts&: physical_vecsize, ts: "(", ts&: expr_str, ts: "[", ts&: i, ts: "]", ts: ")", ts&: load_swiz);
5376 else
5377 unpack_expr += join(ts&: expr_str, ts: "[", ts&: i, ts: "]", ts&: data_swiz, ts&: load_swiz);
5378 }
5379
5380 unpack_expr += ")";
5381 return unpack_expr;
5382 }
5383 else
5384 {
5385 return join(ts: type_to_glsl(type), ts: "(", ts&: expr_str, ts: ")");
5386 }
5387}
5388
5389// Emits the file header info
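// A typical emitted header looks like this (a sketch; the exact pragma, include,
// and typedef lines depend on the options and functions in play):
//
//     #pragma clang diagnostic ignored "-Wmissing-prototypes"
//
//     #include <metal_stdlib>
//     #include <simd/simd.h>
//
//     using namespace metal;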
5390void CompilerMSL::emit_header()
5391{
5392 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
5393 if (suppress_missing_prototypes)
5394 statement(ts: "#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
5395 if (suppress_incompatible_pointer_types_discard_qualifiers)
5396 statement(ts: "#pragma clang diagnostic ignored \"-Wincompatible-pointer-types-discards-qualifiers\"");
5397
5398 // Disable warning about missing braces for array<T> template to make arrays a value type
5399 if (spv_function_implementations.count(x: SPVFuncImplUnsafeArray) != 0)
5400 statement(ts: "#pragma clang diagnostic ignored \"-Wmissing-braces\"");
5401
5402 for (auto &pragma : pragma_lines)
5403 statement(ts: pragma);
5404
5405	if (!pragma_lines.empty() || suppress_missing_prototypes || suppress_incompatible_pointer_types_discard_qualifiers)
5406 statement(ts: "");
5407
5408 statement(ts: "#include <metal_stdlib>");
5409 statement(ts: "#include <simd/simd.h>");
5410
5411 for (auto &header : header_lines)
5412 statement(ts&: header);
5413
5414 statement(ts: "");
5415 statement(ts: "using namespace metal;");
5416 statement(ts: "");
5417
5418 for (auto &td : typedef_lines)
5419 statement(ts: td);
5420
5421 if (!typedef_lines.empty())
5422 statement(ts: "");
5423}
5424
5425void CompilerMSL::add_pragma_line(const string &line)
5426{
5427 auto rslt = pragma_lines.insert(x: line);
5428 if (rslt.second)
5429 force_recompile();
5430}
5431
5432void CompilerMSL::add_typedef_line(const string &line)
5433{
5434 auto rslt = typedef_lines.insert(x: line);
5435 if (rslt.second)
5436 force_recompile();
5437}
5438
5439// Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared
5440void CompilerMSL::emit_custom_templates()
5441{
5442 static const char * const address_spaces[] = {
5443 "thread", "constant", "device", "threadgroup", "threadgroup_imageblock", "ray_data", "object_data"
5444 };
5445
5446 for (const auto &spv_func : spv_function_implementations)
5447 {
5448 switch (spv_func)
5449 {
5450 case SPVFuncImplUnsafeArray:
5451 statement(ts: "template<typename T, size_t Num>");
5452 statement(ts: "struct spvUnsafeArray");
5453 begin_scope();
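			// "Num ? Num : 1" guards the Num == 0 corner case: zero-length arrays
			// are ill-formed in MSL, so the emitted member is padded to one element.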
5454 statement(ts: "T elements[Num ? Num : 1];");
5455 statement(ts: "");
5456 statement(ts: "thread T& operator [] (size_t pos) thread");
5457 begin_scope();
5458 statement(ts: "return elements[pos];");
5459 end_scope();
5460 statement(ts: "constexpr const thread T& operator [] (size_t pos) const thread");
5461 begin_scope();
5462 statement(ts: "return elements[pos];");
5463 end_scope();
5464 statement(ts: "");
5465 statement(ts: "device T& operator [] (size_t pos) device");
5466 begin_scope();
5467 statement(ts: "return elements[pos];");
5468 end_scope();
5469 statement(ts: "constexpr const device T& operator [] (size_t pos) const device");
5470 begin_scope();
5471 statement(ts: "return elements[pos];");
5472 end_scope();
5473 statement(ts: "");
5474 statement(ts: "constexpr const constant T& operator [] (size_t pos) const constant");
5475 begin_scope();
5476 statement(ts: "return elements[pos];");
5477 end_scope();
5478 statement(ts: "");
5479 statement(ts: "threadgroup T& operator [] (size_t pos) threadgroup");
5480 begin_scope();
5481 statement(ts: "return elements[pos];");
5482 end_scope();
5483 statement(ts: "constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
5484 begin_scope();
5485 statement(ts: "return elements[pos];");
5486 end_scope();
5487 end_scope_decl();
5488 statement(ts: "");
5489 break;
5490
5491 case SPVFuncImplStorageMatrix:
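			// The nested loops below generate, for each (method, parameter) address-space
			// pair, converting constructors and assignment operators; a sketch of one
			// emitted constructor (method "device", parameter "thread"):
			//     spvStorageMatrix(const thread matrix<T, Cols, Rows>& m) device
			//     {
			//         for (size_t i = 0; i < Cols; ++i)
			//             columns[i] = m.columns[i];
			//     }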
5492 statement(ts: "template<typename T, int Cols, int Rows=Cols>");
5493 statement(ts: "struct spvStorageMatrix");
5494 begin_scope();
5495 statement(ts: "vec<T, Rows> columns[Cols];");
5496 statement(ts: "");
5497 for (size_t method_idx = 0; method_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++method_idx)
5498 {
5499 // Some address spaces require particular features.
5500 if (method_idx == 4) // threadgroup_imageblock
5501 statement(ts: "#ifdef __HAVE_IMAGEBLOCKS__");
5502 else if (method_idx == 5) // ray_data
5503 statement(ts: "#ifdef __HAVE_RAYTRACING__");
5504 else if (method_idx == 6) // object_data
5505 statement(ts: "#ifdef __HAVE_MESH__");
5506 const string &method_as = address_spaces[method_idx];
5507 statement(ts: "spvStorageMatrix() ", ts: method_as, ts: " = default;");
5508 if (method_idx != 1) // constant
5509 {
5510 statement(ts: method_as, ts: " spvStorageMatrix& operator=(initializer_list<vec<T, Rows>> cols) ",
5511 ts: method_as);
5512 begin_scope();
5513 statement(ts: "size_t i;");
5514 statement(ts: "thread vec<T, Rows>* col;");
5515 statement(ts: "for (i = 0, col = cols.begin(); i < Cols; ++i, ++col)");
5516 statement(ts: " columns[i] = *col;");
5517 statement(ts: "return *this;");
5518 end_scope();
5519 }
5520 statement(ts: "");
5521 for (size_t param_idx = 0; param_idx < sizeof(address_spaces) / sizeof(address_spaces[0]); ++param_idx)
5522 {
5523 if (param_idx != method_idx)
5524 {
5525 if (param_idx == 4) // threadgroup_imageblock
5526 statement(ts: "#ifdef __HAVE_IMAGEBLOCKS__");
5527 else if (param_idx == 5) // ray_data
5528 statement(ts: "#ifdef __HAVE_RAYTRACING__");
5529 else if (param_idx == 6) // object_data
5530 statement(ts: "#ifdef __HAVE_MESH__");
5531 }
5532 const string &param_as = address_spaces[param_idx];
5533 statement(ts: "spvStorageMatrix(const ", ts: param_as, ts: " matrix<T, Cols, Rows>& m) ", ts: method_as);
5534 begin_scope();
5535 statement(ts: "for (size_t i = 0; i < Cols; ++i)");
5536 statement(ts: " columns[i] = m.columns[i];");
5537 end_scope();
5538 statement(ts: "spvStorageMatrix(const ", ts: param_as, ts: " spvStorageMatrix& m) ", ts: method_as, ts: " = default;");
5539 if (method_idx != 1) // constant
5540 {
5541 statement(ts: method_as, ts: " spvStorageMatrix& operator=(const ", ts: param_as,
5542 ts: " matrix<T, Cols, Rows>& m) ", ts: method_as);
5543 begin_scope();
5544 statement(ts: "for (size_t i = 0; i < Cols; ++i)");
5545 statement(ts: " columns[i] = m.columns[i];");
5546 statement(ts: "return *this;");
5547 end_scope();
5548 statement(ts: method_as, ts: " spvStorageMatrix& operator=(const ", ts: param_as, ts: " spvStorageMatrix& m) ",
5549 ts: method_as, ts: " = default;");
5550 }
5551 if (param_idx != method_idx && param_idx >= 4)
5552 statement(ts: "#endif");
5553 statement(ts: "");
5554 }
5555 statement(ts: "operator matrix<T, Cols, Rows>() const ", ts: method_as);
5556 begin_scope();
5557 statement(ts: "matrix<T, Cols, Rows> m;");
5558 statement(ts: "for (int i = 0; i < Cols; ++i)");
5559 statement(ts: " m.columns[i] = columns[i];");
5560 statement(ts: "return m;");
5561 end_scope();
5562 statement(ts: "");
5563 statement(ts: "vec<T, Rows> operator[](size_t idx) const ", ts: method_as);
5564 begin_scope();
5565 statement(ts: "return columns[idx];");
5566 end_scope();
5567 if (method_idx != 1) // constant
5568 {
5569 statement(ts: method_as, ts: " vec<T, Rows>& operator[](size_t idx) ", ts: method_as);
5570 begin_scope();
5571 statement(ts: "return columns[idx];");
5572 end_scope();
5573 }
5574 if (method_idx >= 4)
5575 statement(ts: "#endif");
5576 statement(ts: "");
5577 }
5578 end_scope_decl();
5579 statement(ts: "");
5580 statement(ts: "template<typename T, int Cols, int Rows>");
5581 statement(ts: "matrix<T, Rows, Cols> transpose(spvStorageMatrix<T, Cols, Rows> m)");
5582 begin_scope();
5583 statement(ts: "return transpose(matrix<T, Cols, Rows>(m));");
5584 end_scope();
5585 statement(ts: "");
5586 statement(ts: "typedef spvStorageMatrix<half, 2, 2> spvStorage_half2x2;");
5587 statement(ts: "typedef spvStorageMatrix<half, 2, 3> spvStorage_half2x3;");
5588 statement(ts: "typedef spvStorageMatrix<half, 2, 4> spvStorage_half2x4;");
5589 statement(ts: "typedef spvStorageMatrix<half, 3, 2> spvStorage_half3x2;");
5590 statement(ts: "typedef spvStorageMatrix<half, 3, 3> spvStorage_half3x3;");
5591 statement(ts: "typedef spvStorageMatrix<half, 3, 4> spvStorage_half3x4;");
5592 statement(ts: "typedef spvStorageMatrix<half, 4, 2> spvStorage_half4x2;");
5593 statement(ts: "typedef spvStorageMatrix<half, 4, 3> spvStorage_half4x3;");
5594 statement(ts: "typedef spvStorageMatrix<half, 4, 4> spvStorage_half4x4;");
5595 statement(ts: "typedef spvStorageMatrix<float, 2, 2> spvStorage_float2x2;");
5596 statement(ts: "typedef spvStorageMatrix<float, 2, 3> spvStorage_float2x3;");
5597 statement(ts: "typedef spvStorageMatrix<float, 2, 4> spvStorage_float2x4;");
5598 statement(ts: "typedef spvStorageMatrix<float, 3, 2> spvStorage_float3x2;");
5599 statement(ts: "typedef spvStorageMatrix<float, 3, 3> spvStorage_float3x3;");
5600 statement(ts: "typedef spvStorageMatrix<float, 3, 4> spvStorage_float3x4;");
5601 statement(ts: "typedef spvStorageMatrix<float, 4, 2> spvStorage_float4x2;");
5602 statement(ts: "typedef spvStorageMatrix<float, 4, 3> spvStorage_float4x3;");
5603 statement(ts: "typedef spvStorageMatrix<float, 4, 4> spvStorage_float4x4;");
5604 statement(ts: "");
5605 break;
5606
5607 default:
5608 break;
5609 }
5610 }
5611}
5612
5613// Emits any needed custom function bodies.
5614// Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline))
5615// otherwise they will cause problems when linked together in a single Metallib.
5616void CompilerMSL::emit_custom_functions()
5617{
5618 if (spv_function_implementations.count(x: SPVFuncImplArrayCopyMultidim))
5619 spv_function_implementations.insert(x: SPVFuncImplArrayCopy);
5620
5621 if (spv_function_implementations.count(x: SPVFuncImplDynamicImageSampler))
5622 {
5623 // Unfortunately, this one needs a lot of the other functions to compile OK.
5624 if (!msl_options.supports_msl_version(major: 2))
5625 SPIRV_CROSS_THROW(
5626 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
5627 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5628 spv_function_implementations.insert(x: SPVFuncImplTextureSwizzle);
5629 if (msl_options.swizzle_texture_samples)
5630 spv_function_implementations.insert(x: SPVFuncImplGatherSwizzle);
5631 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
5632 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
5633 spv_function_implementations.insert(x: static_cast<SPVFuncImpl>(i));
5634 spv_function_implementations.insert(x: SPVFuncImplExpandITUFullRange);
5635 spv_function_implementations.insert(x: SPVFuncImplExpandITUNarrowRange);
5636 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT709);
5637 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT601);
5638 spv_function_implementations.insert(x: SPVFuncImplConvertYCbCrBT2020);
5639 }
5640
5641 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
5642 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
5643 if (spv_function_implementations.count(x: static_cast<SPVFuncImpl>(i)))
5644 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5645
5646 if (spv_function_implementations.count(x: SPVFuncImplTextureSwizzle) ||
5647 spv_function_implementations.count(x: SPVFuncImplGatherSwizzle) ||
5648 spv_function_implementations.count(x: SPVFuncImplGatherCompareSwizzle))
5649 {
5650 spv_function_implementations.insert(x: SPVFuncImplForwardArgs);
5651 spv_function_implementations.insert(x: SPVFuncImplGetSwizzle);
5652 }
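	// The set is now closed under dependencies: requesting, say, SPVFuncImplGatherSwizzle
	// alone also pulls in SPVFuncImplForwardArgs and SPVFuncImplGetSwizzle above.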
5653
5654 for (const auto &spv_func : spv_function_implementations)
5655 {
5656 switch (spv_func)
5657 {
5658 case SPVFuncImplMod:
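			// GLSL mod() is a floored modulo (result takes the sign of y), while Metal's
			// fmod() truncates (sign of x): e.g. mod(-1.0, 3.0) == 2.0, fmod(-1.0, 3.0) == -1.0.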
5659 statement(ts: "// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
5660 statement(ts: "template<typename Tx, typename Ty>");
5661 statement(ts: "inline Tx mod(Tx x, Ty y)");
5662 begin_scope();
5663 statement(ts: "return x - y * floor(x / y);");
5664 end_scope();
5665 statement(ts: "");
5666 break;
5667
5668 case SPVFuncImplRadians:
5669 statement(ts: "// Implementation of the GLSL radians() function");
5670 statement(ts: "template<typename T>");
5671 statement(ts: "inline T radians(T d)");
5672 begin_scope();
5673 statement(ts: "return d * T(0.01745329251);");
5674 end_scope();
5675 statement(ts: "");
5676 break;
5677
5678 case SPVFuncImplDegrees:
5679 statement(ts: "// Implementation of the GLSL degrees() function");
5680 statement(ts: "template<typename T>");
5681 statement(ts: "inline T degrees(T r)");
5682 begin_scope();
5683 statement(ts: "return r * T(57.2957795131);");
5684 end_scope();
5685 statement(ts: "");
5686 break;
5687
5688 case SPVFuncImplFindILsb:
5689 statement(ts: "// Implementation of the GLSL findLSB() function");
5690 statement(ts: "template<typename T>");
5691 statement(ts: "inline T spvFindLSB(T x)");
5692 begin_scope();
5693 statement(ts: "return select(ctz(x), T(-1), x == T(0));");
5694 end_scope();
5695 statement(ts: "");
5696 break;
5697
5698 case SPVFuncImplFindUMsb:
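			// clz(T(0)) evaluates to the bit width of T, so e.g. spvFindUMSB(1u) ==
			// 32 - (31 + 1) == 0, and spvFindUMSB(0u) == T(-1).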
5699 statement(ts: "// Implementation of the unsigned GLSL findMSB() function");
5700 statement(ts: "template<typename T>");
5701 statement(ts: "inline T spvFindUMSB(T x)");
5702 begin_scope();
5703 statement(ts: "return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
5704 end_scope();
5705 statement(ts: "");
5706 break;
5707
5708 case SPVFuncImplFindSMsb:
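			// Negative inputs are bit-flipped first so the MSB of the magnitude is
			// found: e.g. spvFindSMSB(-2) == 0, and spvFindSMSB(0) == spvFindSMSB(-1) == -1.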
5709 statement(ts: "// Implementation of the signed GLSL findMSB() function");
5710 statement(ts: "template<typename T>");
5711 statement(ts: "inline T spvFindSMSB(T x)");
5712 begin_scope();
5713 statement(ts: "T v = select(x, T(-1) - x, x < T(0));");
5714 statement(ts: "return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
5715 end_scope();
5716 statement(ts: "");
5717 break;
5718
5719 case SPVFuncImplSSign:
5720 statement(ts: "// Implementation of the GLSL sign() function for integer types");
5721 statement(ts: "template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
5722 statement(ts: "inline T sign(T x)");
5723 begin_scope();
5724 statement(ts: "return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
5725 end_scope();
5726 statement(ts: "");
5727 break;
5728
5729 case SPVFuncImplArrayCopy:
5730 case SPVFuncImplArrayCopyMultidim:
5731 {
5732 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
5733 static const char *function_name_tags[] = {
5734 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
5735 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
5736 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
5737 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
5738 };
5739
5740 static const char *src_address_space[] = {
5741 "constant", "constant", "thread const", "thread const",
5742 "threadgroup const", "threadgroup const", "device const", "constant",
5743 "thread const", "threadgroup const", "device const", "device const",
5744 };
5745
5746 static const char *dst_address_space[] = {
5747 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
5748 "device", "device", "device", "device", "thread", "threadgroup",
5749 };
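			// A sketch of one generated variant (variant 0, non-multidim):
			//     template<typename T, uint N>
			//     inline void spvArrayCopyFromConstantToStack(thread T (&dst)[N], constant T (&src)[N])
			//     {
			//         for (uint i = 0; i < N; i++)
			//         {
			//             dst[i] = src[i];
			//         }
			//     }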
5750
5751 for (uint32_t variant = 0; variant < 12; variant++)
5752 {
5753 bool is_multidim = spv_func == SPVFuncImplArrayCopyMultidim;
5754				const char *dim = is_multidim ? "[N][M]" : "[N]";
5755 statement(ts: "template<typename T, uint N", ts: is_multidim ? ", uint M>" : ">");
5756 statement(ts: "inline void spvArrayCopy", ts&: function_name_tags[variant], ts: "(",
5757 ts&: dst_address_space[variant], ts: " T (&dst)", ts&: dim, ts: ", ",
5758 ts&: src_address_space[variant], ts: " T (&src)", ts&: dim, ts: ")");
5759 begin_scope();
5760 statement(ts: "for (uint i = 0; i < N; i++)");
5761 begin_scope();
5762 if (is_multidim)
5763 statement(ts: "spvArrayCopy", ts&: function_name_tags[variant], ts: "(dst[i], src[i]);");
5764 else
5765 statement(ts: "dst[i] = src[i];");
5766 end_scope();
5767 end_scope();
5768 statement(ts: "");
5769 }
5770 break;
5771 }
5772
5773 // Support for Metal 2.1's new texture_buffer type.
5774 case SPVFuncImplTexelBufferCoords:
5775 {
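			// e.g. with a fixed texel-buffer texture width of 4096, texel 5000 maps to
			// uint2(5000 % 4096, 5000 / 4096) == uint2(904, 1).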
5776 if (msl_options.texel_buffer_texture_width > 0)
5777 {
5778 string tex_width_str = convert_to_string(t: msl_options.texel_buffer_texture_width);
5779 statement(ts: "// Returns 2D texture coords corresponding to 1D texel buffer coords");
5780 statement(ts&: force_inline);
5781 statement(ts: "uint2 spvTexelBufferCoord(uint tc)");
5782 begin_scope();
5783 statement(ts: join(ts: "return uint2(tc % ", ts&: tex_width_str, ts: ", tc / ", ts&: tex_width_str, ts: ");"));
5784 end_scope();
5785 statement(ts: "");
5786 }
5787 else
5788 {
5789 statement(ts: "// Returns 2D texture coords corresponding to 1D texel buffer coords");
5790 statement(
5791 ts: "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
5792 statement(ts: "");
5793 }
5794 break;
5795 }
5796
5797 // Emulate texture2D atomic operations
5798 case SPVFuncImplImage2DAtomicCoords:
5799 {
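			// The emulation buffer pads each texture row to spvLinearTextureAlignment
			// bytes; e.g. with 64-byte alignment (16 uints) and a 100-texel-wide R32Uint
			// texture, texel (x, y) maps to index ((100 + 15) & ~15) * y + x == 112 * y + x.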
5800 if (msl_options.supports_msl_version(major: 1, minor: 2))
5801 {
5802 statement(ts: "// The required alignment of a linear texture of R32Uint format.");
5803 statement(ts: "constant uint spvLinearTextureAlignmentOverride [[function_constant(",
5804 ts&: msl_options.r32ui_alignment_constant_id, ts: ")]];");
5805 statement(ts: "constant uint spvLinearTextureAlignment = ",
5806 ts: "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
5807 ts: "spvLinearTextureAlignmentOverride : ", ts&: msl_options.r32ui_linear_texture_alignment, ts: ";");
5808 }
5809 else
5810 {
5811 statement(ts: "// The required alignment of a linear texture of R32Uint format.");
5812 statement(ts: "constant uint spvLinearTextureAlignment = ", ts&: msl_options.r32ui_linear_texture_alignment,
5813 ts: ";");
5814 }
5815 statement(ts: "// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
5816 statement(ts: "#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
5817 ts: " spvLinearTextureAlignment / 4 - 1) & ~(",
5818 ts: " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
5819 statement(ts: "");
5820 break;
5821 }
5822
5823 // Fix up gradient vectors when sampling a cube texture for Apple Silicon.
5824 // h/t Alexey Knyazev (https://github.com/KhronosGroup/MoltenVK/issues/2068#issuecomment-1817799067) for the code.
5825 case SPVFuncImplGradientCube:
5826 statement(ts: "static inline gradientcube spvGradientCube(float3 P, float3 dPdx, float3 dPdy)");
5827 begin_scope();
5828 statement(ts: "// Major axis selection");
5829 statement(ts: "float3 absP = abs(P);");
5830 statement(ts: "bool xMajor = absP.x >= max(absP.y, absP.z);");
5831 statement(ts: "bool yMajor = absP.y >= absP.z;");
5832 statement(ts: "float3 Q = xMajor ? P.yzx : (yMajor ? P.xzy : P);");
5833 statement(ts: "float3 dQdx = xMajor ? dPdx.yzx : (yMajor ? dPdx.xzy : dPdx);");
5834 statement(ts: "float3 dQdy = xMajor ? dPdy.yzx : (yMajor ? dPdy.xzy : dPdy);");
5835 statement_no_indent(ts: "");
5836 statement(ts: "// Skip a couple of operations compared to usual projection");
5837 statement(ts: "float4 d = float4(dQdx.xy, dQdy.xy) - (Q.xy / Q.z).xyxy * float4(dQdx.zz, dQdy.zz);");
5838 statement_no_indent(ts: "");
5839 statement(ts: "// Final swizzle to put the intermediate values into non-ignored components");
5840 statement(ts: "// X major: X and Z");
5841 statement(ts: "// Y major: X and Y");
5842 statement(ts: "// Z major: Y and Z");
5843 statement(ts: "return gradientcube(xMajor ? d.xxy : d.xyx, xMajor ? d.zzw : d.zwz);");
5844 end_scope();
5845 statement(ts: "");
5846 break;
5847
5848 // "fadd" intrinsic support
5849 case SPVFuncImplFAdd:
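			// [[clang::optnone]] plus the fma() form prevents fast-math from contracting
			// or re-associating the operation, preserving the IEEE semantics (NaN/Inf
			// propagation, signed zero) that OpFAdd/OpFSub/OpFMul require.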
5850 statement(ts: "template<typename T>");
5851 statement(ts: "[[clang::optnone]] T spvFAdd(T l, T r)");
5852 begin_scope();
5853 statement(ts: "return fma(T(1), l, r);");
5854 end_scope();
5855 statement(ts: "");
5856 break;
5857
5858 // "fsub" intrinsic support
5859 case SPVFuncImplFSub:
5860 statement(ts: "template<typename T>");
5861 statement(ts: "[[clang::optnone]] T spvFSub(T l, T r)");
5862 begin_scope();
5863 statement(ts: "return fma(T(-1), r, l);");
5864 end_scope();
5865 statement(ts: "");
5866 break;
5867
5868		// "fmul" intrinsic support
5869 case SPVFuncImplFMul:
5870 statement(ts: "template<typename T>");
5871 statement(ts: "[[clang::optnone]] T spvFMul(T l, T r)");
5872 begin_scope();
5873 statement(ts: "return fma(l, r, T(0));");
5874 end_scope();
5875 statement(ts: "");
5876
5877 statement(ts: "template<typename T, int Cols, int Rows>");
5878 statement(ts: "[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
5879 begin_scope();
5880 statement(ts: "vec<T, Cols> res = vec<T, Cols>(0);");
5881 statement(ts: "for (uint i = Rows; i > 0; --i)");
5882 begin_scope();
5883 statement(ts: "vec<T, Cols> tmp(0);");
5884 statement(ts: "for (uint j = 0; j < Cols; ++j)");
5885 begin_scope();
5886 statement(ts: "tmp[j] = m[j][i - 1];");
5887 end_scope();
5888 statement(ts: "res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
5889 end_scope();
5890 statement(ts: "return res;");
5891 end_scope();
5892 statement(ts: "");
5893
5894 statement(ts: "template<typename T, int Cols, int Rows>");
5895 statement(ts: "[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
5896 begin_scope();
5897 statement(ts: "vec<T, Rows> res = vec<T, Rows>(0);");
5898 statement(ts: "for (uint i = Cols; i > 0; --i)");
5899 begin_scope();
5900 statement(ts: "res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
5901 end_scope();
5902 statement(ts: "return res;");
5903 end_scope();
5904 statement(ts: "");
5905
5906 statement(ts: "template<typename T, int LCols, int LRows, int RCols, int RRows>");
5907 statement(ts: "[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
5908 begin_scope();
5909 statement(ts: "matrix<T, RCols, LRows> res;");
5910 statement(ts: "for (uint i = 0; i < RCols; i++)");
5911 begin_scope();
5912			statement(ts: "vec<T, LRows> tmp(0);");
5913 statement(ts: "for (uint j = 0; j < LCols; j++)");
5914 begin_scope();
5915			statement(ts: "tmp = fma(vec<T, LRows>(r[i][j]), l[j], tmp);");
5916 end_scope();
5917 statement(ts: "res[i] = tmp;");
5918 end_scope();
5919 statement(ts: "return res;");
5920 end_scope();
5921 statement(ts: "");
5922 break;
5923
5924 case SPVFuncImplQuantizeToF16:
5925 // Ensure fast-math is disabled to match Vulkan results.
5926 // SpvHalfTypeSelector is used to match the half* template type to the float* template type.
5927			// Depending on GPU, MSL does not always flush converted subnormal halves to zero,
5928 // as required by OpQuantizeToF16, so check for subnormals and flush them to zero.
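			// e.g. spvQuantizeToF16(1.0e-7f) returns 0.0f (1e-7f converts to a subnormal
			// half, which is flushed), while spvQuantizeToF16(1.0f) returns exactly 1.0f.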
5929 statement(ts: "template <typename F> struct SpvHalfTypeSelector;");
5930 statement(ts: "template <> struct SpvHalfTypeSelector<float> { public: using H = half; };");
5931 statement(ts: "template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };");
5932 statement(ts: "template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>");
5933 statement(ts: "[[clang::optnone]] F spvQuantizeToF16(F fval)");
5934 begin_scope();
5935 statement(ts: "H hval = H(fval);");
5936 statement(ts: "hval = select(copysign(H(0), hval), hval, isnormal(hval) || isinf(hval) || isnan(hval));");
5937 statement(ts: "return F(hval);");
5938 end_scope();
5939 statement(ts: "");
5940 break;
5941
5942 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
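		// e.g. P = float3(1, 0, 0) selects face 0 (+X) and lands at the face center:
		// the function returns float3(0.5, 0.5, 0).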
5943 case SPVFuncImplCubemapTo2DArrayFace:
5944 statement(ts&: force_inline);
5945 statement(ts: "float3 spvCubemapTo2DArrayFace(float3 P)");
5946 begin_scope();
5947 statement(ts: "float3 Coords = abs(P.xyz);");
5948 statement(ts: "float CubeFace = 0;");
5949 statement(ts: "float ProjectionAxis = 0;");
5950 statement(ts: "float u = 0;");
5951 statement(ts: "float v = 0;");
5952 statement(ts: "if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
5953 begin_scope();
5954 statement(ts: "CubeFace = P.x >= 0 ? 0 : 1;");
5955 statement(ts: "ProjectionAxis = Coords.x;");
5956 statement(ts: "u = P.x >= 0 ? -P.z : P.z;");
5957 statement(ts: "v = -P.y;");
5958 end_scope();
5959 statement(ts: "else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
5960 begin_scope();
5961 statement(ts: "CubeFace = P.y >= 0 ? 2 : 3;");
5962 statement(ts: "ProjectionAxis = Coords.y;");
5963 statement(ts: "u = P.x;");
5964 statement(ts: "v = P.y >= 0 ? P.z : -P.z;");
5965 end_scope();
5966 statement(ts: "else");
5967 begin_scope();
5968 statement(ts: "CubeFace = P.z >= 0 ? 4 : 5;");
5969 statement(ts: "ProjectionAxis = Coords.z;");
5970 statement(ts: "u = P.z >= 0 ? P.x : -P.x;");
5971 statement(ts: "v = -P.y;");
5972 end_scope();
5973 statement(ts: "u = 0.5 * (u/ProjectionAxis + 1);");
5974 statement(ts: "v = 0.5 * (v/ProjectionAxis + 1);");
5975 statement(ts: "return float3(u, v, CubeFace);");
5976 end_scope();
5977 statement(ts: "");
5978 break;
5979
5980 case SPVFuncImplInverse4x4:
5981 statement(ts: "// Returns the determinant of a 2x2 matrix.");
5982 statement(ts&: force_inline);
5983 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
5984 begin_scope();
5985 statement(ts: "return a1 * b2 - b1 * a2;");
5986 end_scope();
5987 statement(ts: "");
5988
5989 statement(ts: "// Returns the determinant of a 3x3 matrix.");
5990 statement(ts&: force_inline);
5991 statement(ts: "float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
5992 "float c2, float c3)");
5993 begin_scope();
5994 statement(ts: "return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
5995 "b2, b3);");
5996 end_scope();
5997 statement(ts: "");
5998 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5999 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6000 statement(ts&: force_inline);
6001 statement(ts: "float4x4 spvInverse4x4(float4x4 m)");
6002 begin_scope();
6003 statement(ts: "float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
6004 statement_no_indent(ts: "");
6005 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6006 statement(ts: "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
6007 "m[3][3]);");
6008 statement(ts: "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
6009 "m[3][3]);");
6010 statement(ts: "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
6011 "m[3][3]);");
6012 statement(ts: "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
6013 "m[2][3]);");
6014 statement_no_indent(ts: "");
6015 statement(ts: "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
6016 "m[3][3]);");
6017 statement(ts: "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
6018 "m[3][3]);");
6019 statement(ts: "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
6020 "m[3][3]);");
6021 statement(ts: "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
6022 "m[2][3]);");
6023 statement_no_indent(ts: "");
6024 statement(ts: "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
6025 "m[3][3]);");
6026 statement(ts: "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
6027 "m[3][3]);");
6028 statement(ts: "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
6029 "m[3][3]);");
6030 statement(ts: "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
6031 "m[2][3]);");
6032 statement_no_indent(ts: "");
6033 statement(ts: "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
6034 "m[3][2]);");
6035 statement(ts: "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
6036 "m[3][2]);");
6037 statement(ts: "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
6038 "m[3][2]);");
6039 statement(ts: "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
6040 "m[2][2]);");
6041 statement_no_indent(ts: "");
6042 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6043 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
6044 "* m[3][0]);");
6045 statement_no_indent(ts: "");
6046 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6047			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6048 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6049 end_scope();
6050 statement(ts: "");
6051 break;
6052
6053 case SPVFuncImplInverse3x3:
6054 if (spv_function_implementations.count(x: SPVFuncImplInverse4x4) == 0)
6055 {
6056 statement(ts: "// Returns the determinant of a 2x2 matrix.");
6057 statement(ts&: force_inline);
6058 statement(ts: "float spvDet2x2(float a1, float a2, float b1, float b2)");
6059 begin_scope();
6060 statement(ts: "return a1 * b2 - b1 * a2;");
6061 end_scope();
6062 statement(ts: "");
6063 }
6064
6065 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
6066 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6067 statement(ts&: force_inline);
6068 statement(ts: "float3x3 spvInverse3x3(float3x3 m)");
6069 begin_scope();
6070 statement(ts: "float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
6071 statement_no_indent(ts: "");
6072 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6073 statement(ts: "adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
6074 statement(ts: "adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
6075 statement(ts: "adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
6076 statement_no_indent(ts: "");
6077 statement(ts: "adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
6078 statement(ts: "adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
6079 statement(ts: "adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
6080 statement_no_indent(ts: "");
6081 statement(ts: "adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
6082 statement(ts: "adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
6083 statement(ts: "adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
6084 statement_no_indent(ts: "");
6085 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6086 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
6087 statement_no_indent(ts: "");
6088 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6089			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6090 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6091 end_scope();
6092 statement(ts: "");
6093 break;
6094
6095 case SPVFuncImplInverse2x2:
6096 statement(ts: "// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
6097 statement(ts: "// adjoint and dividing by the determinant. The contents of the matrix are changed.");
6098 statement(ts&: force_inline);
6099 statement(ts: "float2x2 spvInverse2x2(float2x2 m)");
6100 begin_scope();
6101 statement(ts: "float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
6102 statement_no_indent(ts: "");
6103 statement(ts: "// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
6104 statement(ts: "adj[0][0] = m[1][1];");
6105 statement(ts: "adj[0][1] = -m[0][1];");
6106 statement_no_indent(ts: "");
6107 statement(ts: "adj[1][0] = -m[1][0];");
6108 statement(ts: "adj[1][1] = m[0][0];");
6109 statement_no_indent(ts: "");
6110 statement(ts: "// Calculate the determinant as a combination of the cofactors of the first row.");
6111 statement(ts: "float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
6112 statement_no_indent(ts: "");
6113 statement(ts: "// Divide the classical adjoint matrix by the determinant.");
6114			statement(ts: "// If determinant is zero, matrix is not invertible, so leave it unchanged.");
6115 statement(ts: "return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
6116 end_scope();
6117 statement(ts: "");
6118 break;
6119
6120 case SPVFuncImplForwardArgs:
6121 statement(ts: "template<typename T> struct spvRemoveReference { typedef T type; };");
6122 statement(ts: "template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
6123 statement(ts: "template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
6124 statement(ts: "template<typename T> inline constexpr thread T&& spvForward(thread typename "
6125 "spvRemoveReference<T>::type& x)");
6126 begin_scope();
6127 statement(ts: "return static_cast<thread T&&>(x);");
6128 end_scope();
6129 statement(ts: "template<typename T> inline constexpr thread T&& spvForward(thread typename "
6130 "spvRemoveReference<T>::type&& x)");
6131 begin_scope();
6132 statement(ts: "return static_cast<thread T&&>(x);");
6133 end_scope();
6134 statement(ts: "");
6135 break;
6136
6137 case SPVFuncImplGetSwizzle:
6138 statement(ts: "enum class spvSwizzle : uint");
6139 begin_scope();
6140 statement(ts: "none = 0,");
6141 statement(ts: "zero,");
6142 statement(ts: "one,");
6143 statement(ts: "red,");
6144 statement(ts: "green,");
6145 statement(ts: "blue,");
6146 statement(ts: "alpha");
6147 end_scope_decl();
6148 statement(ts: "");
6149 statement(ts: "template<typename T>");
6150 statement(ts: "inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
6151 begin_scope();
6152 statement(ts: "switch (s)");
6153 begin_scope();
6154 statement(ts: "case spvSwizzle::none:");
6155 statement(ts: " return c;");
6156 statement(ts: "case spvSwizzle::zero:");
6157 statement(ts: " return 0;");
6158 statement(ts: "case spvSwizzle::one:");
6159 statement(ts: " return 1;");
6160 statement(ts: "case spvSwizzle::red:");
6161 statement(ts: " return x.r;");
6162 statement(ts: "case spvSwizzle::green:");
6163 statement(ts: " return x.g;");
6164 statement(ts: "case spvSwizzle::blue:");
6165 statement(ts: " return x.b;");
6166 statement(ts: "case spvSwizzle::alpha:");
6167 statement(ts: " return x.a;");
6168 end_scope();
6169 end_scope();
6170 statement(ts: "");
6171 break;
6172
6173 case SPVFuncImplTextureSwizzle:
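			// The packed constant "s" holds one spvSwizzle per byte (x in bits 0-7,
			// y in 8-15, z in 16-23, w in 24-31); e.g. a BGRA view of an RGBA texture
			// uses s == 0x06030405 (alpha, red, green, blue from high byte to low).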
6174 statement(ts: "// Wrapper function that swizzles texture samples and fetches.");
6175 statement(ts: "template<typename T>");
6176 statement(ts: "inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
6177 begin_scope();
6178 statement(ts: "if (!s)");
6179 statement(ts: " return x;");
6180 statement(ts: "return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
6181 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
6182 "& 0xFF)), "
6183 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
6184 end_scope();
6185 statement(ts: "");
6186 statement(ts: "template<typename T>");
6187 statement(ts: "inline T spvTextureSwizzle(T x, uint s)");
6188 begin_scope();
6189 statement(ts: "return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
6190 end_scope();
6191 statement(ts: "");
6192 break;
6193
6194 case SPVFuncImplGatherSwizzle:
6195 statement(ts: "// Wrapper function that swizzles texture gathers.");
6196 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6197 "typename... Ts>");
6198 statement(ts: "inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
6199 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
6200 begin_scope();
6201 statement(ts: "if (sw)");
6202 begin_scope();
6203 statement(ts: "switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
6204 begin_scope();
6205 statement(ts: "case spvSwizzle::none:");
6206 statement(ts: " break;");
6207 statement(ts: "case spvSwizzle::zero:");
6208 statement(ts: " return vec<T, 4>(0, 0, 0, 0);");
6209 statement(ts: "case spvSwizzle::one:");
6210 statement(ts: " return vec<T, 4>(1, 1, 1, 1);");
6211 statement(ts: "case spvSwizzle::red:");
6212 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::x);");
6213 statement(ts: "case spvSwizzle::green:");
6214 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::y);");
6215 statement(ts: "case spvSwizzle::blue:");
6216 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::z);");
6217 statement(ts: "case spvSwizzle::alpha:");
6218 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::w);");
6219 end_scope();
6220 end_scope();
6221 // texture::gather insists on its component parameter being a constant
6222 // expression, so we need this silly workaround just to compile the shader.
6223 statement(ts: "switch (c)");
6224 begin_scope();
6225 statement(ts: "case component::x:");
6226 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::x);");
6227 statement(ts: "case component::y:");
6228 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::y);");
6229 statement(ts: "case component::z:");
6230 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::z);");
6231 statement(ts: "case component::w:");
6232 statement(ts: " return t.gather(s, spvForward<Ts>(params)..., component::w);");
6233 end_scope();
6234 end_scope();
6235 statement(ts: "");
6236 break;
6237
6238 case SPVFuncImplGatherCompareSwizzle:
6239 statement(ts: "// Wrapper function that swizzles depth texture gathers.");
6240 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6241 "typename... Ts>");
6242 statement(ts: "inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
6243		           "s, uint sw, Ts... params)");
6244 begin_scope();
6245 statement(ts: "if (sw)");
6246 begin_scope();
6247 statement(ts: "switch (spvSwizzle(sw & 0xFF))");
6248 begin_scope();
6249 statement(ts: "case spvSwizzle::none:");
6250 statement(ts: "case spvSwizzle::red:");
6251 statement(ts: " break;");
6252 statement(ts: "case spvSwizzle::zero:");
6253 statement(ts: "case spvSwizzle::green:");
6254 statement(ts: "case spvSwizzle::blue:");
6255 statement(ts: "case spvSwizzle::alpha:");
6256 statement(ts: " return vec<T, 4>(0, 0, 0, 0);");
6257 statement(ts: "case spvSwizzle::one:");
6258 statement(ts: " return vec<T, 4>(1, 1, 1, 1);");
6259 end_scope();
6260 end_scope();
6261 statement(ts: "return t.gather_compare(s, spvForward<Ts>(params)...);");
6262 end_scope();
6263 statement(ts: "");
6264 break;
6265
6266 case SPVFuncImplGatherConstOffsets:
6267 statement(ts: "// Wrapper function that processes a texture gather with a constant offset array.");
6268 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6269 "typename Toff, typename... Tp>");
6270 statement(ts: "inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, "
6271 "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
6272 begin_scope();
6273 statement(ts: "vec<T, 4> rslts[4];");
6274 statement(ts: "for (uint i = 0; i < 4; i++)");
6275 begin_scope();
6276 statement(ts: "switch (c)");
6277 begin_scope();
6278 // Work around texture::gather() requiring its component parameter to be a constant expression
6279 statement(ts: "case component::x:");
6280 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
6281 statement(ts: " break;");
6282 statement(ts: "case component::y:");
6283 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
6284 statement(ts: " break;");
6285 statement(ts: "case component::z:");
6286 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
6287 statement(ts: " break;");
6288 statement(ts: "case component::w:");
6289 statement(ts: " rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
6290 statement(ts: " break;");
6291 end_scope();
6292 end_scope();
6293 // Pull all values from the i0j0 component of each gather footprint
6294 statement(ts: "return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
6295 end_scope();
6296 statement(ts: "");
6297 break;
6298
6299 case SPVFuncImplGatherCompareConstOffsets:
6300 statement(ts: "// Wrapper function that processes a texture gather with a constant offset array.");
6301 statement(ts: "template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
6302 "typename Toff, typename... Tp>");
6303 statement(ts: "inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, "
6304 "Toff coffsets, Tp... params)");
6305 begin_scope();
6306 statement(ts: "vec<T, 4> rslts[4];");
6307 statement(ts: "for (uint i = 0; i < 4; i++)");
6308 begin_scope();
6309 statement(ts: " rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
6310 end_scope();
6311 // Pull all values from the i0j0 component of each gather footprint
6312 statement(ts: "return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
6313 end_scope();
6314 statement(ts: "");
6315 break;
6316
6317 case SPVFuncImplSubgroupBroadcast:
6318 // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
6319 // them as integers.
6320 statement(ts: "template<typename T>");
6321 statement(ts: "inline T spvSubgroupBroadcast(T value, ushort lane)");
6322 begin_scope();
6323 if (msl_options.use_quadgroup_operation())
6324 statement(ts: "return quad_broadcast(value, lane);");
6325 else
6326 statement(ts: "return simd_broadcast(value, lane);");
6327 end_scope();
6328 statement(ts: "");
6329 statement(ts: "template<>");
6330 statement(ts: "inline bool spvSubgroupBroadcast(bool value, ushort lane)");
6331 begin_scope();
6332 if (msl_options.use_quadgroup_operation())
6333 statement(ts: "return !!quad_broadcast((ushort)value, lane);");
6334 else
6335 statement(ts: "return !!simd_broadcast((ushort)value, lane);");
6336 end_scope();
6337 statement(ts: "");
6338 statement(ts: "template<uint N>");
6339 statement(ts: "inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
6340 begin_scope();
6341 if (msl_options.use_quadgroup_operation())
6342 statement(ts: "return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
6343 else
6344 statement(ts: "return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
6345 end_scope();
6346 statement(ts: "");
6347 break;
6348
6349 case SPVFuncImplSubgroupBroadcastFirst:
6350 statement(ts: "template<typename T>");
6351 statement(ts: "inline T spvSubgroupBroadcastFirst(T value)");
6352 begin_scope();
6353 if (msl_options.use_quadgroup_operation())
6354 statement(ts: "return quad_broadcast_first(value);");
6355 else
6356 statement(ts: "return simd_broadcast_first(value);");
6357 end_scope();
6358 statement(ts: "");
6359 statement(ts: "template<>");
6360 statement(ts: "inline bool spvSubgroupBroadcastFirst(bool value)");
6361 begin_scope();
6362 if (msl_options.use_quadgroup_operation())
6363 statement(ts: "return !!quad_broadcast_first((ushort)value);");
6364 else
6365 statement(ts: "return !!simd_broadcast_first((ushort)value);");
6366 end_scope();
6367 statement(ts: "");
6368 statement(ts: "template<uint N>");
6369 statement(ts: "inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
6370 begin_scope();
6371 if (msl_options.use_quadgroup_operation())
6372 statement(ts: "return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
6373 else
6374 statement(ts: "return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
6375 end_scope();
6376 statement(ts: "");
6377 break;
6378
6379 case SPVFuncImplSubgroupBallot:
6380 statement(ts: "inline uint4 spvSubgroupBallot(bool value)");
6381 begin_scope();
6382 if (msl_options.use_quadgroup_operation())
6383 {
6384 statement(ts: "return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
6385 }
6386 else if (msl_options.is_ios())
6387 {
6388 // The current simd_vote on iOS uses a 32-bit integer-like object.
6389 statement(ts: "return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
6390 }
6391 else
6392 {
6393 statement(ts: "simd_vote vote = simd_ballot(value);");
6394 statement(ts: "// simd_ballot() returns a 64-bit integer-like object, but");
6395 statement(ts: "// SPIR-V callers expect a uint4. We must convert.");
6396 statement(ts: "// FIXME: This won't include higher bits if Apple ever supports");
6397				statement(ts: "// 128 lanes in a SIMD-group.");
6398 statement(ts: "return uint4(as_type<uint2>((simd_vote::vote_t)vote), 0, 0);");
6399 }
6400 end_scope();
6401 statement(ts: "");
6402 break;
6403
6404 case SPVFuncImplSubgroupBallotBitExtract:
6405 statement(ts: "inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
6406 begin_scope();
6407 statement(ts: "return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
6408 end_scope();
6409 statement(ts: "");
6410 break;
6411
6412 case SPVFuncImplSubgroupBallotFindLSB:
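		// e.g. ballot == uint4(0, 8, 0, 0) with gl_SubgroupSize == 64: ballot.x is 0,
		// so the result is 32 + ctz(8u) == 35.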
6413 statement(ts: "inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
6414 begin_scope();
6415 if (msl_options.is_ios())
6416 {
6417 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6418 }
6419 else
6420 {
6421 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6422 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6423 }
6424 statement(ts: "ballot &= mask;");
6425 statement(ts: "return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
6426 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
6427 end_scope();
6428 statement(ts: "");
6429 break;
6430
6431 case SPVFuncImplSubgroupBallotFindMSB:
6432 statement(ts: "inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
6433 begin_scope();
6434 if (msl_options.is_ios())
6435 {
6436 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6437 }
6438 else
6439 {
6440 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6441 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6442 }
6443 statement(ts: "ballot &= mask;");
6444 statement(ts: "return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
6445 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
6446 "ballot.z == 0), ballot.w == 0);");
6447 end_scope();
6448 statement(ts: "");
6449 break;
6450
6451 case SPVFuncImplSubgroupBallotBitCount:
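		// The masks clamp the ballot to the active lanes: e.g. for the inclusive and
		// exclusive variants with gl_SubgroupInvocationID == 5, the low 6 and low 5
		// bits of the first word survive, respectively.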
6452 statement(ts: "inline uint spvPopCount4(uint4 ballot)");
6453 begin_scope();
6454 statement(ts: "return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
6455 end_scope();
6456 statement(ts: "");
6457 statement(ts: "inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
6458 begin_scope();
6459 if (msl_options.is_ios())
6460 {
6461 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
6462 }
6463 else
6464 {
6465 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
6466 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
6467 }
6468 statement(ts: "return spvPopCount4(ballot & mask);");
6469 end_scope();
6470 statement(ts: "");
6471 statement(ts: "inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
6472 begin_scope();
6473 if (msl_options.is_ios())
6474 {
6475 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
6476 }
6477 else
6478 {
6479 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
6480 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
6481 "uint2(0));");
6482 }
6483 statement(ts: "return spvPopCount4(ballot & mask);");
6484 end_scope();
6485 statement(ts: "");
6486 statement(ts: "inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
6487 begin_scope();
6488 if (msl_options.is_ios())
6489 {
6490 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
6491 }
6492 else
6493 {
6494 statement(ts: "uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
6495 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
6496 }
6497 statement(ts: "return spvPopCount4(ballot & mask);");
6498 end_scope();
6499 statement(ts: "");
6500 break;
6501
6502 case SPVFuncImplSubgroupAllEqual:
6503		// Metal doesn't provide a function to evaluate this directly, but we can
6504		// implement it by comparing every thread's value to one thread's value
6505 // (in this case, the value of the first active thread). Then, by the transitive
6506 // property of equality, if all comparisons return true, then they are all equal.
6507 statement(ts: "template<typename T>");
6508 statement(ts: "inline bool spvSubgroupAllEqual(T value)");
6509 begin_scope();
6510 if (msl_options.use_quadgroup_operation())
6511 statement(ts: "return quad_all(all(value == quad_broadcast_first(value)));");
6512 else
6513 statement(ts: "return simd_all(all(value == simd_broadcast_first(value)));");
6514 end_scope();
6515 statement(ts: "");
6516 statement(ts: "template<>");
6517 statement(ts: "inline bool spvSubgroupAllEqual(bool value)");
6518 begin_scope();
6519 if (msl_options.use_quadgroup_operation())
6520 statement(ts: "return quad_all(value) || !quad_any(value);");
6521 else
6522 statement(ts: "return simd_all(value) || !simd_any(value);");
6523 end_scope();
6524 statement(ts: "");
6525 statement(ts: "template<uint N>");
6526 statement(ts: "inline bool spvSubgroupAllEqual(vec<bool, N> value)");
6527 begin_scope();
6528 if (msl_options.use_quadgroup_operation())
6529 statement(ts: "return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
6530 else
6531 statement(ts: "return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
6532 end_scope();
6533 statement(ts: "");
6534 break;
6535
6536 case SPVFuncImplSubgroupShuffle:
6537 statement(ts: "template<typename T>");
6538 statement(ts: "inline T spvSubgroupShuffle(T value, ushort lane)");
6539 begin_scope();
6540 if (msl_options.use_quadgroup_operation())
6541 statement(ts: "return quad_shuffle(value, lane);");
6542 else
6543 statement(ts: "return simd_shuffle(value, lane);");
6544 end_scope();
6545 statement(ts: "");
6546 statement(ts: "template<>");
6547 statement(ts: "inline bool spvSubgroupShuffle(bool value, ushort lane)");
6548 begin_scope();
6549 if (msl_options.use_quadgroup_operation())
6550 statement(ts: "return !!quad_shuffle((ushort)value, lane);");
6551 else
6552 statement(ts: "return !!simd_shuffle((ushort)value, lane);");
6553 end_scope();
6554 statement(ts: "");
6555 statement(ts: "template<uint N>");
6556 statement(ts: "inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
6557 begin_scope();
6558 if (msl_options.use_quadgroup_operation())
6559 statement(ts: "return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
6560 else
6561 statement(ts: "return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
6562 end_scope();
6563 statement(ts: "");
6564 break;
6565
6566 case SPVFuncImplSubgroupShuffleXor:
6567 statement(ts: "template<typename T>");
6568 statement(ts: "inline T spvSubgroupShuffleXor(T value, ushort mask)");
6569 begin_scope();
6570 if (msl_options.use_quadgroup_operation())
6571 statement(ts: "return quad_shuffle_xor(value, mask);");
6572 else
6573 statement(ts: "return simd_shuffle_xor(value, mask);");
6574 end_scope();
6575 statement(ts: "");
6576 statement(ts: "template<>");
6577 statement(ts: "inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
6578 begin_scope();
6579 if (msl_options.use_quadgroup_operation())
6580 statement(ts: "return !!quad_shuffle_xor((ushort)value, mask);");
6581 else
6582 statement(ts: "return !!simd_shuffle_xor((ushort)value, mask);");
6583 end_scope();
6584 statement(ts: "");
6585 statement(ts: "template<uint N>");
6586 statement(ts: "inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
6587 begin_scope();
6588 if (msl_options.use_quadgroup_operation())
6589 statement(ts: "return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
6590 else
6591 statement(ts: "return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
6592 end_scope();
6593 statement(ts: "");
6594 break;
6595
6596 case SPVFuncImplSubgroupShuffleUp:
6597 statement(ts: "template<typename T>");
6598 statement(ts: "inline T spvSubgroupShuffleUp(T value, ushort delta)");
6599 begin_scope();
6600 if (msl_options.use_quadgroup_operation())
6601 statement(ts: "return quad_shuffle_up(value, delta);");
6602 else
6603 statement(ts: "return simd_shuffle_up(value, delta);");
6604 end_scope();
6605 statement(ts: "");
6606 statement(ts: "template<>");
6607 statement(ts: "inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
6608 begin_scope();
6609 if (msl_options.use_quadgroup_operation())
6610 statement(ts: "return !!quad_shuffle_up((ushort)value, delta);");
6611 else
6612 statement(ts: "return !!simd_shuffle_up((ushort)value, delta);");
6613 end_scope();
6614 statement(ts: "");
6615 statement(ts: "template<uint N>");
6616 statement(ts: "inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
6617 begin_scope();
6618 if (msl_options.use_quadgroup_operation())
6619 statement(ts: "return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
6620 else
6621 statement(ts: "return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
6622 end_scope();
6623 statement(ts: "");
6624 break;
6625
6626 case SPVFuncImplSubgroupShuffleDown:
6627 statement(ts: "template<typename T>");
6628 statement(ts: "inline T spvSubgroupShuffleDown(T value, ushort delta)");
6629 begin_scope();
6630 if (msl_options.use_quadgroup_operation())
6631 statement(ts: "return quad_shuffle_down(value, delta);");
6632 else
6633 statement(ts: "return simd_shuffle_down(value, delta);");
6634 end_scope();
6635 statement(ts: "");
6636 statement(ts: "template<>");
6637 statement(ts: "inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
6638 begin_scope();
6639 if (msl_options.use_quadgroup_operation())
6640 statement(ts: "return !!quad_shuffle_down((ushort)value, delta);");
6641 else
6642 statement(ts: "return !!simd_shuffle_down((ushort)value, delta);");
6643 end_scope();
6644 statement(ts: "");
6645 statement(ts: "template<uint N>");
6646 statement(ts: "inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
6647 begin_scope();
6648 if (msl_options.use_quadgroup_operation())
6649 statement(ts: "return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
6650 else
6651 statement(ts: "return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
6652 end_scope();
6653 statement(ts: "");
6654 break;
6655
6656 case SPVFuncImplQuadBroadcast:
6657 statement(ts: "template<typename T>");
6658 statement(ts: "inline T spvQuadBroadcast(T value, uint lane)");
6659 begin_scope();
6660 statement(ts: "return quad_broadcast(value, lane);");
6661 end_scope();
6662 statement(ts: "");
6663 statement(ts: "template<>");
6664 statement(ts: "inline bool spvQuadBroadcast(bool value, uint lane)");
6665 begin_scope();
6666 statement(ts: "return !!quad_broadcast((ushort)value, lane);");
6667 end_scope();
6668 statement(ts: "");
6669 statement(ts: "template<uint N>");
6670 statement(ts: "inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
6671 begin_scope();
6672 statement(ts: "return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
6673 end_scope();
6674 statement(ts: "");
6675 break;
6676
6677 case SPVFuncImplQuadSwap:
6678 // We can implement this easily based on the following table giving
6679 // the target lane ID from the direction and current lane ID:
6680 // Direction
6681 // | 0 | 1 | 2 |
6682 // ---+---+---+---+
6683 // L 0 | 1 2 3
6684 // a 1 | 0 3 2
6685 // n 2 | 3 0 1
6686 // e 3 | 2 1 0
6687 // Notice that target = source ^ (direction + 1).
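			// (Check for direction 1: lanes 0,1,2,3 XOR 2 give 2,3,0,1, matching the middle column.)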
6688 statement(ts: "template<typename T>");
6689 statement(ts: "inline T spvQuadSwap(T value, uint dir)");
6690 begin_scope();
6691 statement(ts: "return quad_shuffle_xor(value, dir + 1);");
6692 end_scope();
6693 statement(ts: "");
6694 statement(ts: "template<>");
6695 statement(ts: "inline bool spvQuadSwap(bool value, uint dir)");
6696 begin_scope();
6697 statement(ts: "return !!quad_shuffle_xor((ushort)value, dir + 1);");
6698 end_scope();
6699 statement(ts: "");
6700 statement(ts: "template<uint N>");
6701 statement(ts: "inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
6702 begin_scope();
6703 statement(ts: "return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
6704 end_scope();
6705 statement(ts: "");
6706 break;
6707
6708 case SPVFuncImplReflectScalar:
6709 // Metal does not support scalar versions of these functions.
6710 // Ensure fast-math is disabled to match Vulkan results.
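			// For scalars, reflect(i, n) = i - 2 * dot(n, i) * n degenerates to i - 2 * i * n * n.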
6711 statement(ts: "template<typename T>");
6712 statement(ts: "[[clang::optnone]] T spvReflect(T i, T n)");
6713 begin_scope();
6714 statement(ts: "return i - T(2) * i * n * n;");
6715 end_scope();
6716 statement(ts: "");
6717 break;
6718
6719 case SPVFuncImplRefractScalar:
6720 // Metal does not support scalar versions of these functions.
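			// As in GLSL, k < 0 signals total internal reflection, for which the result is defined to be 0.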
6721 statement(ts: "template<typename T>");
6722 statement(ts: "inline T spvRefract(T i, T n, T eta)");
6723 begin_scope();
6724 statement(ts: "T NoI = n * i;");
6725 statement(ts: "T NoI2 = NoI * NoI;");
6726 statement(ts: "T k = T(1) - eta * eta * (T(1) - NoI2);");
6727 statement(ts: "if (k < T(0))");
6728 begin_scope();
6729 statement(ts: "return T(0);");
6730 end_scope();
6731 statement(ts: "else");
6732 begin_scope();
6733 statement(ts: "return eta * i - (eta * NoI + sqrt(k)) * n;");
6734 end_scope();
6735 end_scope();
6736 statement(ts: "");
6737 break;
6738
6739 case SPVFuncImplFaceForwardScalar:
6740 // Metal does not support scalar versions of these functions.
6741 statement(ts: "template<typename T>");
6742 statement(ts: "inline T spvFaceForward(T n, T i, T nref)");
6743 begin_scope();
6744 statement(ts: "return i * nref < T(0) ? n : -n;");
6745 end_scope();
6746 statement(ts: "");
6747 break;
6748
6749 case SPVFuncImplChromaReconstructNearest2Plane:
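			// Plane 0 carries luma (-> ycbcr.g) and plane 1 carries interleaved CbCr (-> ycbcr.br);
			// nearest reconstruction simply samples the lower-resolution chroma plane directly.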
6750 statement(ts: "template<typename T, typename... LodOptions>");
6751 statement(ts: "inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
6752 "samp, float2 coord, LodOptions... options)");
6753 begin_scope();
6754 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6755 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6756 statement(ts: "ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
6757 statement(ts: "return ycbcr;");
6758 end_scope();
6759 statement(ts: "");
6760 break;
6761
6762 case SPVFuncImplChromaReconstructNearest3Plane:
6763 statement(ts: "template<typename T, typename... LodOptions>");
6764 statement(ts: "inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
6765 "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6766 begin_scope();
6767 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6768 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6769 statement(ts: "ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6770 statement(ts: "ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6771 statement(ts: "return ycbcr;");
6772 end_scope();
6773 statement(ts: "");
6774 break;
6775
6776 case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
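			// Cosited-even siting aligns chroma samples with even luma samples. If the coordinate
			// does not land exactly on a chroma texel (fract != 0), average the two nearest chroma
			// samples with equal weight; otherwise the aligned sample can be used as-is.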
6777 statement(ts: "template<typename T, typename... LodOptions>");
6778 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
6779 "plane1, sampler samp, float2 coord, LodOptions... options)");
6780 begin_scope();
6781 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6782 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6783 statement(ts: "if (fract(coord.x * plane1.get_width()) != 0.0)");
6784 begin_scope();
6785 statement(ts: "ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6786 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
6787 end_scope();
6788 statement(ts: "else");
6789 begin_scope();
6790 statement(ts: "ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
6791 end_scope();
6792 statement(ts: "return ycbcr;");
6793 end_scope();
6794 statement(ts: "");
6795 break;
6796
6797 case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
6798 statement(ts: "template<typename T, typename... LodOptions>");
6799 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
6800 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6801 begin_scope();
6802 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6803 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6804 statement(ts: "if (fract(coord.x * plane1.get_width()) != 0.0)");
6805 begin_scope();
6806 statement(ts: "ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6807 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
6808 statement(ts: "ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6809 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
6810 end_scope();
6811 statement(ts: "else");
6812 begin_scope();
6813 statement(ts: "ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6814 statement(ts: "ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6815 end_scope();
6816 statement(ts: "return ycbcr;");
6817 end_scope();
6818 statement(ts: "");
6819 break;
6820
6821 case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
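			// With midpoint siting, chroma samples sit halfway between luma samples, so the second
			// sample is taken from the neighbor ahead (+1) or behind (-1) depending on parity and
			// blended in with weight 0.25.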
6822 statement(ts: "template<typename T, typename... LodOptions>");
6823 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
6824 "plane1, sampler samp, float2 coord, LodOptions... options)");
6825 begin_scope();
6826 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6827 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6828 statement(ts: "int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
6829 statement(ts: "ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6830 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
6831 statement(ts: "return ycbcr;");
6832 end_scope();
6833 statement(ts: "");
6834 break;
6835
6836 case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
6837 statement(ts: "template<typename T, typename... LodOptions>");
6838 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
6839 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6840 begin_scope();
6841 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6842 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6843 statement(ts: "int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
6844 statement(ts: "ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6845 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
6846 statement(ts: "ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6847 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
6848 statement(ts: "return ycbcr;");
6849 end_scope();
6850 statement(ts: "");
6851 break;
6852
6853 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
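			// The 420 linear cases below reconstruct chroma bilinearly from a 2x2 neighborhood.
			// ab holds the x/y blend weights: fract(round(coord * luma_dims) * 0.5) is 0.0 when the
			// nearest luma texel index is even (aligned with a chroma sample) and 0.5 when it is odd.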
6854 statement(ts: "template<typename T, typename... LodOptions>");
6855 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
6856 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
6857 begin_scope();
6858 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6859 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6860 statement(ts: "float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
6861 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6862 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6863 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6864 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
6865 statement(ts: "return ycbcr;");
6866 end_scope();
6867 statement(ts: "");
6868 break;
6869
6870 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
6871 statement(ts: "template<typename T, typename... LodOptions>");
6872 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
6873 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6874 begin_scope();
6875 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6876 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6877 statement(ts: "float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
6878 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6879 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6880 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6881 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6882 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6883 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6884 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6885 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6886 statement(ts: "return ycbcr;");
6887 end_scope();
6888 statement(ts: "");
6889 break;
6890
6891 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
6892 statement(ts: "template<typename T, typename... LodOptions>");
6893 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
6894 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
6895 begin_scope();
6896 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6897 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6898 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6899 "0)) * 0.5);");
6900 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6901 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6902 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6903 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
6904 statement(ts: "return ycbcr;");
6905 end_scope();
6906 statement(ts: "");
6907 break;
6908
6909 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
6910 statement(ts: "template<typename T, typename... LodOptions>");
6911 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
6912 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6913 begin_scope();
6914 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6915 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6916 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6917 "0)) * 0.5);");
6918 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6919 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6920 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6921 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6922 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6923 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6924 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6925 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6926 statement(ts: "return ycbcr;");
6927 end_scope();
6928 statement(ts: "");
6929 break;
6930
6931 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
6932 statement(ts: "template<typename T, typename... LodOptions>");
6933 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
6934 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
6935 begin_scope();
6936 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6937 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6938 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
6939 "0.5)) * 0.5);");
6940 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6941 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6942 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6943 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
6944 statement(ts: "return ycbcr;");
6945 end_scope();
6946 statement(ts: "");
6947 break;
6948
6949 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
6950 statement(ts: "template<typename T, typename... LodOptions>");
6951 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
6952 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6953 begin_scope();
6954 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6955 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6956 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
6957 "0.5)) * 0.5);");
6958 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6959 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6960 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6961 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6962 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6963 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6964 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6965 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6966 statement(ts: "return ycbcr;");
6967 end_scope();
6968 statement(ts: "");
6969 break;
6970
6971 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
6972 statement(ts: "template<typename T, typename... LodOptions>");
6973 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
6974 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
6975 begin_scope();
6976 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6977 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6978 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6979 "0.5)) * 0.5);");
6980 statement(ts: "ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6981 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6982 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6983 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
6984 statement(ts: "return ycbcr;");
6985 end_scope();
6986 statement(ts: "");
6987 break;
6988
6989 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
6990 statement(ts: "template<typename T, typename... LodOptions>");
6991 statement(ts: "inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
6992 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6993 begin_scope();
6994 statement(ts: "vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6995 statement(ts: "ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6996 statement(ts: "float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6997 "0.5)) * 0.5);");
6998 statement(ts: "ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6999 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7000 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7001 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7002 statement(ts: "ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
7003 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
7004 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
7005 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
7006 statement(ts: "return ycbcr;");
7007 end_scope();
7008 statement(ts: "");
7009 break;
7010
7011 case SPVFuncImplExpandITUFullRange:
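			// Full-range ITU encoding only recenters chroma: subtract 2^(n-1)/(2^n - 1)
			// (e.g. 128/255 for 8-bit) from Cb/Cr and leave luma untouched.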
7012 statement(ts: "template<typename T>");
7013 statement(ts: "inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
7014 begin_scope();
7015 statement(ts: "ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
7016 statement(ts: "return ycbcr;");
7017 end_scope();
7018 statement(ts: "");
7019 break;
7020
7021 case SPVFuncImplExpandITUNarrowRange:
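			// Narrow ("video") range: at 8 bits luma spans [16, 235] and chroma [16, 240],
			// scaled by 2^(n-8) for deeper encodings. For n = 8 the lines below reduce to
			// (Y * 255 - 16) / 219 and (C * 255 - 128) / 224.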
7022 statement(ts: "template<typename T>");
7023 statement(ts: "inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
7024 begin_scope();
7025 statement(ts: "ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
7026 statement(ts: "ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
7027 statement(ts: "return ycbcr;");
7028 end_scope();
7029 statement(ts: "");
7030 break;
7031
7032 case SPVFuncImplConvertYCbCrBT709:
7033 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.1");
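			// Columns give the Y', Cb and Cr contributions to (R, G, B). With Kr = 0.2126 and
			// Kb = 0.0722 (so Kg = 0.7152): R adds 2(1 - Kr) = 1.5748 of Cr, B adds
			// 2(1 - Kb) = 1.8556 of Cb, and G subtracts 2*Kb*(1 - Kb)/Kg of Cb and
			// 2*Kr*(1 - Kr)/Kg of Cr. BT.601 and BT.2020 below follow the same pattern
			// with their own Kr/Kb.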
7034 statement(ts: "constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
7035 "-0.33480248/0.7152, 0}};");
7036 statement(ts: "");
7037 statement(ts: "template<typename T>");
7038 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
7039 begin_scope();
7040 statement(ts: "vec<T, 4> rgba;");
7041 statement(ts: "rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
7042 statement(ts: "rgba.a = ycbcr.a;");
7043 statement(ts: "return rgba;");
7044 end_scope();
7045 statement(ts: "");
7046 break;
7047
7048 case SPVFuncImplConvertYCbCrBT601:
7049 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.2");
7050 statement(ts: "constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
7051 "-0.419198/0.587, 0}};");
7052 statement(ts: "");
7053 statement(ts: "template<typename T>");
7054 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
7055 begin_scope();
7056 statement(ts: "vec<T, 4> rgba;");
7057 statement(ts: "rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
7058 statement(ts: "rgba.a = ycbcr.a;");
7059 statement(ts: "return rgba;");
7060 end_scope();
7061 statement(ts: "");
7062 break;
7063
7064 case SPVFuncImplConvertYCbCrBT2020:
7065 statement(ts: "// cf. Khronos Data Format Specification, section 15.1.3");
7066 statement(ts: "constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
7067 "-0.38737742/0.6780, 0}};");
7068 statement(ts: "");
7069 statement(ts: "template<typename T>");
7070 statement(ts: "inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
7071 begin_scope();
7072 statement(ts: "vec<T, 4> rgba;");
7073 statement(ts: "rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
7074 statement(ts: "rgba.a = ycbcr.a;");
7075 statement(ts: "return rgba;");
7076 end_scope();
7077 statement(ts: "");
7078 break;
7079
7080 case SPVFuncImplDynamicImageSampler:
7081 statement(ts: "enum class spvFormatResolution");
7082 begin_scope();
7083 statement(ts: "_444 = 0,");
7084 statement(ts: "_422,");
7085 statement(ts: "_420");
7086 end_scope_decl();
7087 statement(ts: "");
7088 statement(ts: "enum class spvChromaFilter");
7089 begin_scope();
7090 statement(ts: "nearest = 0,");
7091 statement(ts: "linear");
7092 end_scope_decl();
7093 statement(ts: "");
7094 statement(ts: "enum class spvXChromaLocation");
7095 begin_scope();
7096 statement(ts: "cosited_even = 0,");
7097 statement(ts: "midpoint");
7098 end_scope_decl();
7099 statement(ts: "");
7100 statement(ts: "enum class spvYChromaLocation");
7101 begin_scope();
7102 statement(ts: "cosited_even = 0,");
7103 statement(ts: "midpoint");
7104 end_scope_decl();
7105 statement(ts: "");
7106 statement(ts: "enum class spvYCbCrModelConversion");
7107 begin_scope();
7108 statement(ts: "rgb_identity = 0,");
7109 statement(ts: "ycbcr_identity,");
7110 statement(ts: "ycbcr_bt_709,");
7111 statement(ts: "ycbcr_bt_601,");
7112 statement(ts: "ycbcr_bt_2020");
7113 end_scope_decl();
7114 statement(ts: "");
7115 statement(ts: "enum class spvYCbCrRange");
7116 begin_scope();
7117 statement(ts: "itu_full = 0,");
7118 statement(ts: "itu_narrow");
7119 end_scope_decl();
7120 statement(ts: "");
7121 statement(ts: "struct spvComponentBits");
7122 begin_scope();
7123 statement(ts: "constexpr explicit spvComponentBits(int v) thread : value(v) {}");
7124 statement(ts: "uchar value : 6;");
7125 end_scope_decl();
7126 statement(ts: "// A class corresponding to metal::sampler which holds sampler");
7127 statement(ts: "// Y'CbCr conversion info.");
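			// Packed layout of val, low to high bits: resolution [0:1], chroma filter [2:3],
			// x/y chroma offsets [4] and [5], model [6:8], range [9], bits-per-component [10:15].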
7128 statement(ts: "struct spvYCbCrSampler");
7129 begin_scope();
7130 statement(ts: "constexpr spvYCbCrSampler() thread : val(build()) {}");
7131 statement(ts: "template<typename... Ts>");
7132 statement(ts: "constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
7133 statement(ts: "constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
7134 statement(ts: "");
7135 statement(ts: "spvFormatResolution get_resolution() const thread");
7136 begin_scope();
7137 statement(ts: "return spvFormatResolution((val & resolution_mask) >> resolution_base);");
7138 end_scope();
7139 statement(ts: "spvChromaFilter get_chroma_filter() const thread");
7140 begin_scope();
7141 statement(ts: "return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
7142 end_scope();
7143 statement(ts: "spvXChromaLocation get_x_chroma_offset() const thread");
7144 begin_scope();
7145 statement(ts: "return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
7146 end_scope();
7147 statement(ts: "spvYChromaLocation get_y_chroma_offset() const thread");
7148 begin_scope();
7149 statement(ts: "return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
7150 end_scope();
7151 statement(ts: "spvYCbCrModelConversion get_ycbcr_model() const thread");
7152 begin_scope();
7153 statement(ts: "return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
7154 end_scope();
7155 statement(ts: "spvYCbCrRange get_ycbcr_range() const thread");
7156 begin_scope();
7157 statement(ts: "return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
7158 end_scope();
7159 statement(ts: "int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
7160 statement(ts: "");
7161 statement(ts: "private:");
7162 statement(ts: "ushort val;");
7163 statement(ts: "");
7164 statement(ts: "constexpr static constant ushort resolution_bits = 2;");
7165 statement(ts: "constexpr static constant ushort chroma_filter_bits = 2;");
7166 statement(ts: "constexpr static constant ushort x_chroma_off_bit = 1;");
7167 statement(ts: "constexpr static constant ushort y_chroma_off_bit = 1;");
7168 statement(ts: "constexpr static constant ushort ycbcr_model_bits = 3;");
7169 statement(ts: "constexpr static constant ushort ycbcr_range_bit = 1;");
7170 statement(ts: "constexpr static constant ushort bpc_bits = 6;");
7171 statement(ts: "");
7172 statement(ts: "constexpr static constant ushort resolution_base = 0;");
7173 statement(ts: "constexpr static constant ushort chroma_filter_base = 2;");
7174 statement(ts: "constexpr static constant ushort x_chroma_off_base = 4;");
7175 statement(ts: "constexpr static constant ushort y_chroma_off_base = 5;");
7176 statement(ts: "constexpr static constant ushort ycbcr_model_base = 6;");
7177 statement(ts: "constexpr static constant ushort ycbcr_range_base = 9;");
7178 statement(ts: "constexpr static constant ushort bpc_base = 10;");
7179 statement(ts: "");
7180 statement(
7181 ts: "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
7182 statement(ts: "constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
7183 "chroma_filter_base;");
7184 statement(ts: "constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
7185 "x_chroma_off_base;");
7186 statement(ts: "constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
7187 "y_chroma_off_base;");
7188 statement(ts: "constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
7189 "ycbcr_model_base;");
7190 statement(ts: "constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
7191 "ycbcr_range_base;");
7192 statement(ts: "constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
7193 statement(ts: "");
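			// Each build() overload packs its own field, then recurses over the remaining
			// arguments and masks its field out of the result, so the settings can be
			// supplied in any order.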
7194 statement(ts: "static constexpr ushort build()");
7195 begin_scope();
7196 statement(ts: "return 0;");
7197 end_scope();
7198 statement(ts: "");
7199 statement(ts: "template<typename... Ts>");
7200 statement(ts: "static constexpr ushort build(spvFormatResolution res, Ts... t)");
7201 begin_scope();
7202 statement(ts: "return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
7203 end_scope();
7204 statement(ts: "");
7205 statement(ts: "template<typename... Ts>");
7206 statement(ts: "static constexpr ushort build(spvChromaFilter filt, Ts... t)");
7207 begin_scope();
7208 statement(ts: "return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
7209 end_scope();
7210 statement(ts: "");
7211 statement(ts: "template<typename... Ts>");
7212 statement(ts: "static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
7213 begin_scope();
7214 statement(ts: "return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
7215 end_scope();
7216 statement(ts: "");
7217 statement(ts: "template<typename... Ts>");
7218 statement(ts: "static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
7219 begin_scope();
7220 statement(ts: "return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
7221 end_scope();
7222 statement(ts: "");
7223 statement(ts: "template<typename... Ts>");
7224 statement(ts: "static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
7225 begin_scope();
7226 statement(ts: "return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
7227 end_scope();
7228 statement(ts: "");
7229 statement(ts: "template<typename... Ts>");
7230 statement(ts: "static constexpr ushort build(spvYCbCrRange range, Ts... t)");
7231 begin_scope();
7232 statement(ts: "return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
7233 end_scope();
7234 statement(ts: "");
7235 statement(ts: "template<typename... Ts>");
7236 statement(ts: "static constexpr ushort build(spvComponentBits bpc, Ts... t)");
7237 begin_scope();
7238 statement(ts: "return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
7239 end_scope();
7240 end_scope_decl();
7241 statement(ts: "");
7242 statement(ts: "// A class which can hold up to three textures and a sampler, including");
7243 statement(ts: "// Y'CbCr conversion info, used to pass combined image-samplers");
7244 statement(ts: "// dynamically to functions.");
7245 statement(ts: "template<typename T>");
7246 statement(ts: "struct spvDynamicImageSampler");
7247 begin_scope();
7248 statement(ts: "texture2d<T> plane0;");
7249 statement(ts: "texture2d<T> plane1;");
7250 statement(ts: "texture2d<T> plane2;");
7251 statement(ts: "sampler samp;");
7252 statement(ts: "spvYCbCrSampler ycbcr_samp;");
7253 statement(ts: "uint swizzle = 0;");
7254 statement(ts: "");
7255 if (msl_options.swizzle_texture_samples)
7256 {
7257 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
7258 statement(ts: " plane0(tex), samp(samp), swizzle(sw) {}");
7259 }
7260 else
7261 {
7262 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
7263 statement(ts: " plane0(tex), samp(samp) {}");
7264 }
7265 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
7266 "uint sw) thread :");
7267 statement(ts: " plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
7268 statement(ts: "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
7269 statement(ts: " sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
7270 statement(ts: " plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
7271 statement(
7272 ts: "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
7273 statement(ts: " sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
7274 statement(ts: " plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
7275 "swizzle(sw) {}");
7276 statement(ts: "");
7277 // XXX This is really hard to follow... I've left comments to make it a bit easier.
7278 statement(ts: "template<typename... LodOptions>");
7279 statement(ts: "vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
7280 begin_scope();
7281 statement(ts: "if (!is_null_texture(plane1))");
7282 begin_scope();
7283 statement(ts: "if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
7284 statement(ts: " ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
7285 begin_scope();
7286 statement(ts: "if (!is_null_texture(plane2))");
7287 statement(ts: " return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
7288 statement(ts: " spvForward<LodOptions>(options)...);");
7289 statement(
7290 ts: "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
7291			end_scope(); // if (resolution == 444 || chroma_filter == nearest)
7292 statement(ts: "switch (ycbcr_samp.get_resolution())");
7293 begin_scope();
7294 statement(ts: "case spvFormatResolution::_444: break;");
7295 statement(ts: "case spvFormatResolution::_422:");
7296 begin_scope();
7297 statement(ts: "switch (ycbcr_samp.get_x_chroma_offset())");
7298 begin_scope();
7299 statement(ts: "case spvXChromaLocation::cosited_even:");
7300 statement(ts: " if (!is_null_texture(plane2))");
7301 statement(ts: " return spvChromaReconstructLinear422CositedEven(");
7302 statement(ts: " plane0, plane1, plane2, samp,");
7303 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7304 statement(ts: " return spvChromaReconstructLinear422CositedEven(");
7305 statement(ts: " plane0, plane1, samp, coord,");
7306 statement(ts: " spvForward<LodOptions>(options)...);");
7307 statement(ts: "case spvXChromaLocation::midpoint:");
7308 statement(ts: " if (!is_null_texture(plane2))");
7309 statement(ts: " return spvChromaReconstructLinear422Midpoint(");
7310 statement(ts: " plane0, plane1, plane2, samp,");
7311 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7312 statement(ts: " return spvChromaReconstructLinear422Midpoint(");
7313 statement(ts: " plane0, plane1, samp, coord,");
7314 statement(ts: " spvForward<LodOptions>(options)...);");
7315 end_scope(); // switch (x_chroma_offset)
7316 end_scope(); // case 422:
7317 statement(ts: "case spvFormatResolution::_420:");
7318 begin_scope();
7319 statement(ts: "switch (ycbcr_samp.get_x_chroma_offset())");
7320 begin_scope();
7321 statement(ts: "case spvXChromaLocation::cosited_even:");
7322 begin_scope();
7323 statement(ts: "switch (ycbcr_samp.get_y_chroma_offset())");
7324 begin_scope();
7325 statement(ts: "case spvYChromaLocation::cosited_even:");
7326 statement(ts: " if (!is_null_texture(plane2))");
7327 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
7328 statement(ts: " plane0, plane1, plane2, samp,");
7329 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7330 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
7331 statement(ts: " plane0, plane1, samp, coord,");
7332 statement(ts: " spvForward<LodOptions>(options)...);");
7333 statement(ts: "case spvYChromaLocation::midpoint:");
7334 statement(ts: " if (!is_null_texture(plane2))");
7335 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
7336 statement(ts: " plane0, plane1, plane2, samp,");
7337 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7338 statement(ts: " return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
7339 statement(ts: " plane0, plane1, samp, coord,");
7340 statement(ts: " spvForward<LodOptions>(options)...);");
7341 end_scope(); // switch (y_chroma_offset)
7342 end_scope(); // case x::cosited_even:
7343 statement(ts: "case spvXChromaLocation::midpoint:");
7344 begin_scope();
7345 statement(ts: "switch (ycbcr_samp.get_y_chroma_offset())");
7346 begin_scope();
7347 statement(ts: "case spvYChromaLocation::cosited_even:");
7348 statement(ts: " if (!is_null_texture(plane2))");
7349 statement(ts: " return spvChromaReconstructLinear420XMidpointYCositedEven(");
7350 statement(ts: " plane0, plane1, plane2, samp,");
7351 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7352 statement(ts: " return spvChromaReconstructLinear420XMidpointYCositedEven(");
7353 statement(ts: " plane0, plane1, samp, coord,");
7354 statement(ts: " spvForward<LodOptions>(options)...);");
7355 statement(ts: "case spvYChromaLocation::midpoint:");
7356 statement(ts: " if (!is_null_texture(plane2))");
7357 statement(ts: " return spvChromaReconstructLinear420XMidpointYMidpoint(");
7358 statement(ts: " plane0, plane1, plane2, samp,");
7359 statement(ts: " coord, spvForward<LodOptions>(options)...);");
7360 statement(ts: " return spvChromaReconstructLinear420XMidpointYMidpoint(");
7361 statement(ts: " plane0, plane1, samp, coord,");
7362 statement(ts: " spvForward<LodOptions>(options)...);");
7363 end_scope(); // switch (y_chroma_offset)
7364 end_scope(); // case x::midpoint
7365 end_scope(); // switch (x_chroma_offset)
7366 end_scope(); // case 420:
7367 end_scope(); // switch (resolution)
7368 end_scope(); // if (multiplanar)
7369 statement(ts: "return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
7370 end_scope(); // do_sample()
7371			statement(ts: "template<typename... LodOptions>");
7372 statement(ts: "vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
7373 begin_scope();
7374 statement(
7375 ts: "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
7376 statement(ts: "if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
7377 statement(ts: " return s;");
7378 statement(ts: "");
7379 statement(ts: "switch (ycbcr_samp.get_ycbcr_range())");
7380 begin_scope();
7381 statement(ts: "case spvYCbCrRange::itu_full:");
7382 statement(ts: " s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
7383 statement(ts: " break;");
7384 statement(ts: "case spvYCbCrRange::itu_narrow:");
7385 statement(ts: " s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
7386 statement(ts: " break;");
7387 end_scope();
7388 statement(ts: "");
7389 statement(ts: "switch (ycbcr_samp.get_ycbcr_model())");
7390 begin_scope();
7391 statement(ts: "case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
7392 statement(ts: "case spvYCbCrModelConversion::ycbcr_identity:");
7393 statement(ts: " return s;");
7394 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_709:");
7395 statement(ts: " return spvConvertYCbCrBT709(s);");
7396 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_601:");
7397 statement(ts: " return spvConvertYCbCrBT601(s);");
7398 statement(ts: "case spvYCbCrModelConversion::ycbcr_bt_2020:");
7399 statement(ts: " return spvConvertYCbCrBT2020(s);");
7400 end_scope();
7401 end_scope();
7402 statement(ts: "");
7403 // Sampler Y'CbCr conversion forbids offsets.
7404 statement(ts: "vec<T, 4> sample(float2 coord, int2 offset) const thread");
7405 begin_scope();
7406 if (msl_options.swizzle_texture_samples)
7407 statement(ts: "return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
7408 else
7409 statement(ts: "return plane0.sample(samp, coord, offset);");
7410 end_scope();
7411 statement(ts: "template<typename lod_options>");
7412 statement(ts: "vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
7413 begin_scope();
7414 if (msl_options.swizzle_texture_samples)
7415 statement(ts: "return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
7416 else
7417 statement(ts: "return plane0.sample(samp, coord, options, offset);");
7418 end_scope();
7419 statement(ts: "#if __HAVE_MIN_LOD_CLAMP__");
7420 statement(ts: "vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
7421 begin_scope();
7422 statement(ts: "return plane0.sample(samp, coord, b, min_lod, offset);");
7423 end_scope();
7424 statement(
7425 ts: "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
7426 begin_scope();
7427 statement(ts: "return plane0.sample(samp, coord, grad, min_lod, offset);");
7428 end_scope();
7429 statement(ts: "#endif");
7430 statement(ts: "");
7431 // Y'CbCr conversion forbids all operations but sampling.
7432 statement(ts: "vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
7433 begin_scope();
7434 statement(ts: "return plane0.read(coord, lod);");
7435 end_scope();
7436 statement(ts: "");
7437 statement(ts: "vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
7438 begin_scope();
7439 if (msl_options.swizzle_texture_samples)
7440 statement(ts: "return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
7441 else
7442 statement(ts: "return plane0.gather(samp, coord, offset, c);");
7443 end_scope();
7444 end_scope_decl();
7445 statement(ts: "");
7446 break;
7447
7448 case SPVFuncImplRayQueryIntersectionParams:
7449 statement(ts: "intersection_params spvMakeIntersectionParams(uint flags)");
7450 begin_scope();
7451 statement(ts: "intersection_params ip;");
7452 statement(ts: "if ((flags & ", ts: RayFlagsOpaqueKHRMask, ts: ") != 0)");
7453 statement(ts: " ip.force_opacity(forced_opacity::opaque);");
7454 statement(ts: "if ((flags & ", ts: RayFlagsNoOpaqueKHRMask, ts: ") != 0)");
7455 statement(ts: " ip.force_opacity(forced_opacity::non_opaque);");
7456 statement(ts: "if ((flags & ", ts: RayFlagsTerminateOnFirstHitKHRMask, ts: ") != 0)");
7457 statement(ts: " ip.accept_any_intersection(true);");
7458 // RayFlagsSkipClosestHitShaderKHRMask is not available in MSL
7459 statement(ts: "if ((flags & ", ts: RayFlagsCullBackFacingTrianglesKHRMask, ts: ") != 0)");
7460 statement(ts: " ip.set_triangle_cull_mode(triangle_cull_mode::back);");
7461 statement(ts: "if ((flags & ", ts: RayFlagsCullFrontFacingTrianglesKHRMask, ts: ") != 0)");
7462 statement(ts: " ip.set_triangle_cull_mode(triangle_cull_mode::front);");
7463 statement(ts: "if ((flags & ", ts: RayFlagsCullOpaqueKHRMask, ts: ") != 0)");
7464 statement(ts: " ip.set_opacity_cull_mode(opacity_cull_mode::opaque);");
7465 statement(ts: "if ((flags & ", ts: RayFlagsCullNoOpaqueKHRMask, ts: ") != 0)");
7466 statement(ts: " ip.set_opacity_cull_mode(opacity_cull_mode::non_opaque);");
7467 statement(ts: "if ((flags & ", ts: RayFlagsSkipTrianglesKHRMask, ts: ") != 0)");
7468 statement(ts: " ip.set_geometry_cull_mode(geometry_cull_mode::triangle);");
7469 statement(ts: "if ((flags & ", ts: RayFlagsSkipAABBsKHRMask, ts: ") != 0)");
7470 statement(ts: " ip.set_geometry_cull_mode(geometry_cull_mode::bounding_box);");
7471 statement(ts: "return ip;");
7472 end_scope();
7473 statement(ts: "");
7474 break;
7475
7476 case SPVFuncImplVariableDescriptor:
7477 statement(ts: "template<typename T>");
7478 statement(ts: "struct spvDescriptor");
7479 begin_scope();
7480 statement(ts: "T value;");
7481 end_scope_decl();
7482 statement(ts: "");
7483 break;
7484
7485 case SPVFuncImplVariableSizedDescriptor:
7486 statement(ts: "template<typename T>");
7487 statement(ts: "struct spvBufferDescriptor");
7488 begin_scope();
7489 statement(ts: "T value;");
7490 statement(ts: "int length;");
7491 statement(ts: "const device T& operator -> () const device");
7492 begin_scope();
7493 statement(ts: "return value;");
7494 end_scope();
7495 statement(ts: "const device T& operator * () const device");
7496 begin_scope();
7497 statement(ts: "return value;");
7498 end_scope();
7499 end_scope_decl();
7500 statement(ts: "");
7501 break;
7502
7503 case SPVFuncImplVariableDescriptorArray:
7504 if (spv_function_implementations.count(x: SPVFuncImplVariableDescriptor) != 0)
7505 {
7506 statement(ts: "template<typename T>");
7507 statement(ts: "struct spvDescriptorArray");
7508 begin_scope();
7509 statement(ts: "spvDescriptorArray(const device spvDescriptor<T>* ptr) : ptr(&ptr->value)");
7510 begin_scope();
7511 end_scope();
7512 statement(ts: "const device T& operator [] (size_t i) const");
7513 begin_scope();
7514 statement(ts: "return ptr[i];");
7515 end_scope();
7516 statement(ts: "const device T* ptr;");
7517 end_scope_decl();
7518 statement(ts: "");
7519 }
7520 else
7521 {
7522 statement(ts: "template<typename T>");
7523 statement(ts: "struct spvDescriptorArray;");
7524 statement(ts: "");
7525 }
7526
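		// The "rich" variant below pairs each buffer pointer with its length, so the
		// generated code can answer buffer length queries on individual elements of a
		// runtime descriptor array.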
7527 if (msl_options.runtime_array_rich_descriptor &&
7528 spv_function_implementations.count(x: SPVFuncImplVariableSizedDescriptor) != 0)
7529 {
7530 statement(ts: "template<typename T>");
7531 statement(ts: "struct spvDescriptorArray<device T*>");
7532 begin_scope();
7533 statement(ts: "spvDescriptorArray(const device spvBufferDescriptor<device T*>* ptr) : ptr(ptr)");
7534 begin_scope();
7535 end_scope();
7536 statement(ts: "const device T* operator [] (size_t i) const");
7537 begin_scope();
7538 statement(ts: "return ptr[i].value;");
7539 end_scope();
7540			statement(ts: "int length(int i) const");
7541 begin_scope();
7542 statement(ts: "return ptr[i].length;");
7543 end_scope();
7544 statement(ts: "const device spvBufferDescriptor<device T*>* ptr;");
7545 end_scope_decl();
7546 statement(ts: "");
7547 }
7548 break;
7549
7550 case SPVFuncImplPaddedStd140:
7551			// .data is used in access chains.
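			// std140 rounds array element and matrix column strides up to 16 bytes;
			// alignas(16) reproduces that stride in MSL, e.g. each column of an std140
			// mat3 occupies a full 16 bytes.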
7552 statement(ts: "template <typename T>");
7553 statement(ts: "struct spvPaddedStd140 { alignas(16) T data; };");
7554 statement(ts: "template <typename T, int n>");
7555 statement(ts: "using spvPaddedStd140Matrix = spvPaddedStd140<T>[n];");
7556 statement(ts: "");
7557 break;
7558
7559 case SPVFuncImplReduceAdd:
7560 // Metal doesn't support __builtin_reduce_add or simd_reduce_add, so we need this.
7561			// Metal also doesn't support the other vector builtins, which would have made it possible to write this as a single template.
7562
7563 statement(ts: "template <typename T>");
7564 statement(ts: "T reduce_add(vec<T, 2> v) { return v.x + v.y; }");
7565
7566 statement(ts: "template <typename T>");
7567 statement(ts: "T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }");
7568
7569 statement(ts: "template <typename T>");
7570 statement(ts: "T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }");
7571
7572 statement(ts: "");
7573 break;
7574
7575 case SPVFuncImplImageFence:
7576 statement(ts: "template <typename ImageT>");
7577 statement(ts: "void spvImageFence(ImageT img) { img.fence(); }");
7578 statement(ts: "");
7579 break;
7580
7581 case SPVFuncImplTextureCast:
7582 statement(ts: "template <typename T, typename U>");
7583 statement(ts: "T spvTextureCast(U img)");
7584 begin_scope();
7585 // MSL complains if you try to cast the texture itself, but casting the reference type is ... ok? *shrug*
7586			// Gotta do what you gotta do, I suppose.
7587 statement(ts: "return reinterpret_cast<thread const T &>(img);");
7588 end_scope();
7589 statement(ts: "");
7590 break;
7591
7592 default:
7593 break;
7594 }
7595 }
7596}
7597
7598static string inject_top_level_storage_qualifier(const string &expr, const string &qualifier)
7599{
7600 // Easier to do this through text munging since the qualifier does not exist in the type system at all,
7601 // and plumbing in all that information is not very helpful.
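	// For example, "float4 c[3]" has no '*' or '&' and becomes "constant float4 c[3]",
	// while "device float* p" becomes "device float* constant p": the qualifier applies
	// to the outermost pointer or reference.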
7602 size_t last_reference = expr.find_last_of(c: '&');
7603 size_t last_pointer = expr.find_last_of(c: '*');
7604 size_t last_significant = string::npos;
7605
7606 if (last_reference == string::npos)
7607 last_significant = last_pointer;
7608 else if (last_pointer == string::npos)
7609 last_significant = last_reference;
7610 else
7611 last_significant = max<size_t>(a: last_reference, b: last_pointer);
7612
7613 if (last_significant == string::npos)
7614 return join(ts: qualifier, ts: " ", ts: expr);
7615 else
7616 {
7617 return join(ts: expr.substr(pos: 0, n: last_significant + 1), ts: " ",
7618 ts: qualifier, ts: expr.substr(pos: last_significant + 1, n: string::npos));
7619 }
7620}
7621
7622void CompilerMSL::declare_constant_arrays()
7623{
7624 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
7625
7626	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
7627	// global constants directly, which lets us use constants as variable expressions.
7628 bool emitted = false;
7629
7630 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7631 if (c.specialization)
7632 return;
7633
7634 auto &type = this->get<SPIRType>(id: c.constant_type);
7635		// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
7636 // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
7637 // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
7638 // link into Metal libraries. This is hacky.
7639 if (is_array(type) && (!fully_inlined || is_scalar(type) || is_vector(type)))
7640 {
7641 add_resource_name(id: c.self);
7642 auto name = to_name(id: c.self);
7643 statement(ts: inject_top_level_storage_qualifier(expr: variable_decl(type, name), qualifier: "constant"),
7644 ts: " = ", ts: constant_expression(c), ts: ";");
7645 emitted = true;
7646 }
7647 });
7648
7649 if (emitted)
7650 statement(ts: "");
7651}
7652
7653// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries
7654void CompilerMSL::declare_complex_constant_arrays()
7655{
7656 // If we do not have a fully inlined module, we did not opt in to
7657 // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
7658 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
7659 if (!fully_inlined)
7660 return;
7661
7662	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
7663	// global constants directly, which lets us use constants as variable expressions.
7664 bool emitted = false;
7665
7666 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7667 if (c.specialization)
7668 return;
7669
7670 auto &type = this->get<SPIRType>(id: c.constant_type);
7671 if (is_array(type) && !(is_scalar(type) || is_vector(type)))
7672 {
7673 add_resource_name(id: c.self);
7674 auto name = to_name(id: c.self);
7675 statement(ts: "", ts: variable_decl(type, name), ts: " = ", ts: constant_expression(c), ts: ";");
7676 emitted = true;
7677 }
7678 });
7679
7680 if (emitted)
7681 statement(ts: "");
7682}
7683
7684void CompilerMSL::emit_resources()
7685{
7686 declare_constant_arrays();
7687
7688 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
7689 emit_interface_block(ib_var_id: stage_out_var_id);
7690 emit_interface_block(ib_var_id: patch_stage_out_var_id);
7691 emit_interface_block(ib_var_id: stage_in_var_id);
7692 emit_interface_block(ib_var_id: patch_stage_in_var_id);
7693}
7694
7695// Emit declarations for the specialization Metal function constants
7696void CompilerMSL::emit_specialization_constants_and_structs()
7697{
7698 SpecializationConstant wg_x, wg_y, wg_z;
7699 ID workgroup_size_id = get_work_group_size_specialization_constants(x&: wg_x, y&: wg_y, z&: wg_z);
7700 bool emitted = false;
7701
7702 unordered_set<uint32_t> declared_structs;
7703 unordered_set<uint32_t> aligned_structs;
7704
7705 // First, we need to deal with scalar block layout.
7706	// A struct may have to be placed at an alignment which does not match its innate alignment.
7707	// If such a case exists for a struct, we must force all elements of the struct to become packed_ types.
7708 // This makes the struct alignment as small as physically possible.
7709 // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
7710 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t type_id, const SPIRType &type) {
7711 if (type.basetype == SPIRType::Struct &&
7712 has_extended_decoration(id: type_id, decoration: SPIRVCrossDecorationBufferBlockRepacked))
7713 mark_scalar_layout_structs(type);
7714 });
7715
7716 bool builtin_block_type_is_required = false;
7717	// Very special case. If gl_PerVertex is initialized as an array (tessellation),
7718	// we may have to emit the gl_PerVertex struct type so that we can emit a constant LUT.
7719 ir.for_each_typed_id<SPIRConstant>(op: [&](uint32_t, SPIRConstant &c) {
7720 auto &type = this->get<SPIRType>(id: c.constant_type);
7721 if (is_array(type) && has_decoration(id: type.self, decoration: DecorationBlock) && is_builtin_type(type))
7722 builtin_block_type_is_required = true;
7723 });
7724
7725 // Very particular use of the soft loop lock.
7726	// align_struct may need to create custom types on the fly, but we don't care about
7727	// these types for the purpose of iterating over them in ir.ids_for_type and friends.
7728 auto loop_lock = ir.create_loop_soft_lock();
7729
7730 // Physical storage buffer pointers can have cyclical references,
7731 // so emit forward declarations of them before other structs.
7732 // Ignore type_id because we want the underlying struct type from the pointer.
7733 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t /* type_id */, const SPIRType &type) {
7734 if (type.basetype == SPIRType::Struct &&
7735 type.pointer && type.storage == StorageClassPhysicalStorageBuffer &&
7736 declared_structs.count(x: type.self) == 0)
7737 {
7738 statement(ts: "struct ", ts: to_name(id: type.self), ts: ";");
7739 declared_structs.insert(x: type.self);
7740 emitted = true;
7741 }
7742 });
7743 if (emitted)
7744 statement(ts: "");
7745
7746 emitted = false;
7747 declared_structs.clear();
7748
7749 // It is possible to have multiple spec constants that use the same spec constant ID.
7750 // The most common cause of this is defining spec constants in GLSL while also declaring
7751 // the workgroup size to use those spec constants. But, Metal forbids declaring more than
7752	// the workgroup size to use those spec constants. But Metal forbids declaring more than
7753 // In this case, we must only declare one variable with the [[function_constant(id)]]
7754 // attribute, and use its initializer to initialize all the spec constants with
7755 // that ID.
7756 std::unordered_map<uint32_t, ConstantID> unique_func_constants;
7757
7758 for (auto &id_ : ir.ids_for_constant_undef_or_type)
7759 {
7760 auto &id = ir.ids[id_];
7761
7762 if (id.get_type() == TypeConstant)
7763 {
7764 auto &c = id.get<SPIRConstant>();
7765
7766 if (c.self == workgroup_size_id)
7767 {
7768 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
7769 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
7770 // The work group size may be a specialization constant.
7771 statement(ts: "constant uint3 ", ts: builtin_to_glsl(builtin: BuiltInWorkgroupSize, storage: StorageClassWorkgroup),
7772 ts: " [[maybe_unused]] = ", ts: constant_expression(c: get<SPIRConstant>(id: workgroup_size_id)), ts: ";");
7773 emitted = true;
7774 }
7775 else if (c.specialization)
7776 {
7777 auto &type = get<SPIRType>(id: c.constant_type);
7778 string sc_type_name = type_to_glsl(type);
7779 add_resource_name(id: c.self);
7780 string sc_name = to_name(id: c.self);
7781
7782 // Function constants are only supported in MSL 1.2 and later.
				// If we don't support it, just declare the "default" directly.
7784 // This "default" value can be overridden to the true specialization constant by the API user.
7785 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
7786 // so just fall back to macros.
7787 if (msl_options.supports_msl_version(major: 1, minor: 2) && has_decoration(id: c.self, decoration: DecorationSpecId) &&
7788 !c.is_used_as_array_length)
7789 {
7790 // Only scalar, non-composite values can be function constants.
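					// Roughly, the emitted MSL follows this pattern:
					//   constant T name_tmp [[function_constant(N)]];
					//   constant T name = is_function_constant_defined(name_tmp) ? name_tmp : <default>;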
7791 uint32_t constant_id = get_decoration(id: c.self, decoration: DecorationSpecId);
7792 if (!unique_func_constants.count(x: constant_id))
7793 unique_func_constants.insert(x: make_pair(x&: constant_id, y&: c.self));
7794 SPIRType::BaseType sc_tmp_type = expression_type(id: unique_func_constants[constant_id]).basetype;
7795 string sc_tmp_name = to_name(id: unique_func_constants[constant_id]) + "_tmp";
7796 if (unique_func_constants[constant_id] == c.self)
7797 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_tmp_name, ts: " [[function_constant(", ts&: constant_id,
7798 ts: ")]];");
7799 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = is_function_constant_defined(", ts&: sc_tmp_name,
7800 ts: ") ? ", ts: bitcast_expression(target_type: type, expr_type: sc_tmp_type, expr: sc_tmp_name), ts: " : ", ts: constant_expression(c),
7801 ts: ";");
7802 }
7803 else if (has_decoration(id: c.self, decoration: DecorationSpecId))
7804 {
7805 // Fallback to macro overrides.
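					// Roughly, the emitted MSL defines an overridable macro via constant_value_macro_name():
					//   #ifndef <MACRO>
					//   #define <MACRO> <default>
					//   #endif
					//   constant T name = <MACRO>;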
7806 c.specialization_constant_macro_name =
7807 constant_value_macro_name(id: get_decoration(id: c.self, decoration: DecorationSpecId));
7808
7809 statement(ts: "#ifndef ", ts&: c.specialization_constant_macro_name);
7810 statement(ts: "#define ", ts&: c.specialization_constant_macro_name, ts: " ", ts: constant_expression(c));
7811 statement(ts: "#endif");
7812 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = ", ts&: c.specialization_constant_macro_name,
7813 ts: ";");
7814 }
7815 else
7816 {
7817 // Composite specialization constants must be built from other specialization constants.
7818 statement(ts: "constant ", ts&: sc_type_name, ts: " ", ts&: sc_name, ts: " = ", ts: constant_expression(c), ts: ";");
7819 }
7820 emitted = true;
7821 }
7822 }
7823 else if (id.get_type() == TypeConstantOp)
7824 {
7825 auto &c = id.get<SPIRConstantOp>();
7826 auto &type = get<SPIRType>(id: c.basetype);
7827 add_resource_name(id: c.self);
7828 auto name = to_name(id: c.self);
7829 statement(ts: "constant ", ts: variable_decl(type, name), ts: " = ", ts: constant_op_expression(cop: c), ts: ";");
7830 emitted = true;
7831 }
7832 else if (id.get_type() == TypeType)
7833 {
7834 // Output non-builtin interface structs. These include local function structs
7835 // and structs nested within uniform and read-write buffers.
7836 auto &type = id.get<SPIRType>();
7837 TypeID type_id = type.self;
7838
7839 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
7840 bool is_block =
7841 has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock);
7842
7843 bool is_builtin_block = is_block && is_builtin_type(type);
7844 bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);
7845
7846 // We'll declare this later.
7847 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
7848 is_declarable_struct = false;
7849 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
7850 is_declarable_struct = false;
7851 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
7852 is_declarable_struct = false;
7853 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
7854 is_declarable_struct = false;
7855
			// Special case. Declare the builtin struct anyway if we need to emit a threadgroup version of it.
7857 if (stage_out_masked_builtin_type_id == type_id)
7858 is_declarable_struct = true;
7859
7860 // Align and emit declarable structs...but avoid declaring each more than once.
7861 if (is_declarable_struct && declared_structs.count(x: type_id) == 0)
7862 {
7863 if (emitted)
7864 statement(ts: "");
7865 emitted = false;
7866
7867 declared_structs.insert(x: type_id);
7868
7869 if (has_extended_decoration(id: type_id, decoration: SPIRVCrossDecorationBufferBlockRepacked))
7870 align_struct(ib_type&: type, aligned_structs);
7871
7872 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
7873 emit_struct(type&: get<SPIRType>(id: type_id));
7874 }
7875 }
7876 else if (id.get_type() == TypeUndef)
7877 {
7878 auto &undef = id.get<SPIRUndef>();
7879 auto &type = get<SPIRType>(id: undef.basetype);
7880 // OpUndef can be void for some reason ...
7881 if (type.basetype == SPIRType::Void)
7882 return;
7883
7884 // Undefined global memory is not allowed in MSL.
7885 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
7886 statement(
7887 ts: inject_top_level_storage_qualifier(expr: variable_decl(type, name: to_name(id: undef.self), id: undef.self), qualifier: "constant"),
7888 ts: " = {};");
7889 emitted = true;
7890 }
7891 }
7892
7893 if (emitted)
7894 statement(ts: "");
7895}
7896
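// Emits a binary operation between two pointer expressions.
// Used for OpPtrEqual, OpPtrNotEqual and OpPtrDiff, where the operands must be compared
// as raw pointers rather than as dereferenced values.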
7897void CompilerMSL::emit_binary_ptr_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op)
7898{
7899 bool forward = should_forward(id: op0) && should_forward(id: op1);
7900 emit_op(result_type, result_id, rhs: join(ts: to_ptr_expression(id: op0), ts: " ", ts&: op, ts: " ", ts: to_ptr_expression(id: op1)), forward_rhs: forward);
7901 inherit_expression_dependencies(dst: result_id, source: op0);
7902 inherit_expression_dependencies(dst: result_id, source: op1);
7903}
7904
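// Returns an expression for the given ID suitable for pointer comparisons and arithmetic.
// If the expression would not otherwise be dereferenced as a pointer, its address is taken.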
7905string CompilerMSL::to_ptr_expression(uint32_t id, bool register_expression_read)
7906{
7907 auto *e = maybe_get<SPIRExpression>(id);
7908 auto expr = enclose_expression(expr: e && e->need_transpose ? e->expression : to_expression(id, register_expression_read));
7909 if (!should_dereference(id))
7910 expr = address_of_expression(expr);
7911 return expr;
7912}
7913
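// Emits an unordered floating-point comparison by checking isunordered() explicitly;
// e.g. OpFUnordGreaterThan lowers to (isunordered(a, b) || a > b).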
7914void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
7915 const char *op)
7916{
7917 bool forward = should_forward(id: op0) && should_forward(id: op1);
7918 emit_op(result_type, result_id,
7919 rhs: join(ts: "(isunordered(", ts: to_enclosed_unpacked_expression(id: op0), ts: ", ", ts: to_enclosed_unpacked_expression(id: op1),
7920 ts: ") || ", ts: to_enclosed_unpacked_expression(id: op0), ts: " ", ts&: op, ts: " ", ts: to_enclosed_unpacked_expression(id: op1),
7921 ts: ")"),
7922 forward_rhs: forward);
7923
7924 inherit_expression_dependencies(dst: result_id, source: op0);
7925 inherit_expression_dependencies(dst: result_id, source: op1);
7926}
7927
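// Attempts to emit a load from a tessellation stage-IO variable. Composites may have been
// flattened into individual interface members, in which case the load must be unrolled into
// a composite constructor gathering each member (and control point) separately.
// Returns true if the load was handled here; false falls back to the generic path.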
7928bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
7929{
7930 auto &ptr_type = expression_type(id: ptr);
7931 auto &result_type = get<SPIRType>(id: result_type_id);
7932 if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
7933 return false;
7934 if (ptr_type.storage == StorageClassOutput && is_tese_shader())
7935 return false;
7936
7937 if (has_decoration(id: ptr, decoration: DecorationPatch))
7938 return false;
7939 bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
7940
7941 bool flattened_io = variable_storage_requires_stage_io(storage: ptr_type.storage);
7942
7943 bool flat_data_type = flattened_io &&
7944 (is_matrix(type: result_type) || is_array(type: result_type) || result_type.basetype == SPIRType::Struct);
7945
	// Edge case: even with multi-patch workgroups, we still need to unroll the load
	// if we're loading control points directly.
7948 if (ptr_is_io_variable && is_array(type: result_type))
7949 flat_data_type = true;
7950
7951 if (!flat_data_type)
7952 return false;
7953
7954 // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
	// Lots of painful code duplication here, since we *really* should not unroll these kinds of loads
	// in entry point fixup; we are only forced to do so when the code emits suboptimal OpLoads.
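	// Sketch of the unrolled form when loading one member across all control points:
	//   T({ gl_in[0].member, gl_in[1].member, ..., gl_in[N-1].member })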
7957 string expr;
7958
7959 uint32_t interface_index = get_extended_decoration(id: ptr, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
7960 auto *var = maybe_get_backing_variable(chain: ptr);
7961 auto &expr_type = get_pointee_type(type_id: ptr_type.self);
7962
7963 const auto &iface_type = expression_type(id: stage_in_ptr_var_id);
7964
7965 if (!flattened_io)
7966 {
7967 // Simplest case for multi-patch workgroups, just unroll array as-is.
7968 if (interface_index == uint32_t(-1))
7969 return false;
7970
7971 expr += type_to_glsl(type: result_type) + "({ ";
7972 uint32_t num_control_points = to_array_size_literal(type: result_type, index: uint32_t(result_type.array.size()) - 1);
7973
7974 for (uint32_t i = 0; i < num_control_points; i++)
7975 {
7976 const uint32_t indices[2] = { i, interface_index };
7977 AccessChainMeta meta;
7978 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
7979 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
7980 if (i + 1 < num_control_points)
7981 expr += ", ";
7982 }
7983 expr += " })";
7984 }
7985 else if (result_type.array.size() > 2)
7986 {
7987 SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
7988 }
7989 else if (result_type.array.size() == 2)
7990 {
7991 if (!ptr_is_io_variable)
			SPIRV_CROSS_THROW("An array-of-array must be loaded directly from an IO variable.");
7993 if (interface_index == uint32_t(-1))
7994 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7995 if (result_type.basetype == SPIRType::Struct || is_matrix(type: result_type))
7996 SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
7997
7998 expr += type_to_glsl(type: result_type) + "({ ";
7999 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 1);
8000 uint32_t base_interface_index = interface_index;
8001
8002 auto &sub_type = get<SPIRType>(id: result_type.parent_type);
8003
8004 for (uint32_t i = 0; i < num_control_points; i++)
8005 {
8006 expr += type_to_glsl(type: sub_type) + "({ ";
8007 interface_index = base_interface_index;
8008 uint32_t array_size = to_array_size_literal(type: result_type, index: 0);
8009 for (uint32_t j = 0; j < array_size; j++, interface_index++)
8010 {
8011 const uint32_t indices[2] = { i, interface_index };
8012
8013 AccessChainMeta meta;
8014 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8015 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8016 if (!is_matrix(type: sub_type) && sub_type.basetype != SPIRType::Struct &&
8017 expr_type.vecsize > sub_type.vecsize)
8018 expr += vector_swizzle(vecsize: sub_type.vecsize, index: 0);
8019
8020 if (j + 1 < array_size)
8021 expr += ", ";
8022 }
8023 expr += " })";
8024 if (i + 1 < num_control_points)
8025 expr += ", ";
8026 }
8027 expr += " })";
8028 }
8029 else if (result_type.basetype == SPIRType::Struct)
8030 {
8031 bool is_array_of_struct = is_array(type: result_type);
8032 if (is_array_of_struct && !ptr_is_io_variable)
			SPIRV_CROSS_THROW("Loading an array of structs must come directly from an IO variable.");
8034
8035 uint32_t num_control_points = 1;
8036 if (is_array_of_struct)
8037 {
8038 num_control_points = to_array_size_literal(type: result_type, index: 0);
8039 expr += type_to_glsl(type: result_type) + "({ ";
8040 }
8041
8042 auto &struct_type = is_array_of_struct ? get<SPIRType>(id: result_type.parent_type) : result_type;
8043 assert(struct_type.array.empty());
8044
8045 for (uint32_t i = 0; i < num_control_points; i++)
8046 {
8047 expr += type_to_glsl(type: struct_type) + "{ ";
8048 for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
8049 {
8050 // The base interface index is stored per variable for structs.
8051 if (var)
8052 {
8053 interface_index =
8054 get_extended_member_decoration(type: var->self, index: j, decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8055 }
8056
8057 if (interface_index == uint32_t(-1))
8058 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8059
8060 const auto &mbr_type = get<SPIRType>(id: struct_type.member_types[j]);
8061 const auto &expr_mbr_type = get<SPIRType>(id: expr_type.member_types[j]);
8062 if (is_matrix(type: mbr_type) && ptr_type.storage == StorageClassInput)
8063 {
8064 expr += type_to_glsl(type: mbr_type) + "(";
8065 for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
8066 {
8067 if (is_array_of_struct)
8068 {
8069 const uint32_t indices[2] = { i, interface_index };
8070 AccessChainMeta meta;
8071 expr += access_chain_internal(
8072 base: stage_in_ptr_var_id, indices, count: 2,
8073 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8074 }
8075 else
8076 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8077 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8078 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8079
8080 if (k + 1 < mbr_type.columns)
8081 expr += ", ";
8082 }
8083 expr += ")";
8084 }
8085 else if (is_array(type: mbr_type))
8086 {
8087 expr += type_to_glsl(type: mbr_type) + "({ ";
8088 uint32_t array_size = to_array_size_literal(type: mbr_type, index: 0);
8089 for (uint32_t k = 0; k < array_size; k++, interface_index++)
8090 {
8091 if (is_array_of_struct)
8092 {
8093 const uint32_t indices[2] = { i, interface_index };
8094 AccessChainMeta meta;
8095 expr += access_chain_internal(
8096 base: stage_in_ptr_var_id, indices, count: 2,
8097 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8098 }
8099 else
8100 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8101 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8102 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8103
8104 if (k + 1 < array_size)
8105 expr += ", ";
8106 }
8107 expr += " })";
8108 }
8109 else
8110 {
8111 if (is_array_of_struct)
8112 {
8113 const uint32_t indices[2] = { i, interface_index };
8114 AccessChainMeta meta;
8115 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8116 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
8117 meta: &meta);
8118 }
8119 else
8120 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8121 if (expr_mbr_type.vecsize > mbr_type.vecsize)
8122 expr += vector_swizzle(vecsize: mbr_type.vecsize, index: 0);
8123 }
8124
8125 if (j + 1 < struct_type.member_types.size())
8126 expr += ", ";
8127 }
8128 expr += " }";
8129 if (i + 1 < num_control_points)
8130 expr += ", ";
8131 }
8132 if (is_array_of_struct)
8133 expr += " })";
8134 }
8135 else if (is_matrix(type: result_type))
8136 {
8137 bool is_array_of_matrix = is_array(type: result_type);
8138 if (is_array_of_matrix && !ptr_is_io_variable)
			SPIRV_CROSS_THROW("Loading an array of matrices must come directly from an IO variable.");
8140 if (interface_index == uint32_t(-1))
8141 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8142
8143 if (is_array_of_matrix)
8144 {
8145 // Loading a matrix from each control point.
8146 uint32_t base_interface_index = interface_index;
8147 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 0);
8148 expr += type_to_glsl(type: result_type) + "({ ";
8149
8150 auto &matrix_type = get_variable_element_type(var: get<SPIRVariable>(id: ptr));
8151
8152 for (uint32_t i = 0; i < num_control_points; i++)
8153 {
8154 interface_index = base_interface_index;
8155 expr += type_to_glsl(type: matrix_type) + "(";
8156 for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
8157 {
8158 const uint32_t indices[2] = { i, interface_index };
8159
8160 AccessChainMeta meta;
8161 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8162 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8163 if (expr_type.vecsize > result_type.vecsize)
8164 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8165 if (j + 1 < result_type.columns)
8166 expr += ", ";
8167 }
8168 expr += ")";
8169 if (i + 1 < num_control_points)
8170 expr += ", ";
8171 }
8172
8173 expr += " })";
8174 }
8175 else
8176 {
8177 expr += type_to_glsl(type: result_type) + "(";
8178 for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
8179 {
8180 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8181 if (expr_type.vecsize > result_type.vecsize)
8182 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8183 if (i + 1 < result_type.columns)
8184 expr += ", ";
8185 }
8186 expr += ")";
8187 }
8188 }
8189 else if (ptr_is_io_variable)
8190 {
8191 assert(is_array(result_type));
8192 assert(result_type.array.size() == 1);
8193 if (interface_index == uint32_t(-1))
8194 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8195
8196 // We're loading an array directly from a global variable.
8197 // This means we're loading one member from each control point.
8198 expr += type_to_glsl(type: result_type) + "({ ";
8199 uint32_t num_control_points = to_array_size_literal(type: result_type, index: 0);
8200
8201 for (uint32_t i = 0; i < num_control_points; i++)
8202 {
8203 const uint32_t indices[2] = { i, interface_index };
8204
8205 AccessChainMeta meta;
8206 expr += access_chain_internal(base: stage_in_ptr_var_id, indices, count: 2,
8207 flags: ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, meta: &meta);
8208 if (expr_type.vecsize > result_type.vecsize)
8209 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8210
8211 if (i + 1 < num_control_points)
8212 expr += ", ";
8213 }
8214 expr += " })";
8215 }
8216 else
8217 {
8218 // We're loading an array from a concrete control point.
8219 assert(is_array(result_type));
8220 assert(result_type.array.size() == 1);
8221 if (interface_index == uint32_t(-1))
8222 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
8223
8224 expr += type_to_glsl(type: result_type) + "({ ";
8225 uint32_t array_size = to_array_size_literal(type: result_type, index: 0);
8226 for (uint32_t i = 0; i < array_size; i++, interface_index++)
8227 {
8228 expr += to_expression(id: ptr) + "." + to_member_name(type: iface_type, index: interface_index);
8229 if (expr_type.vecsize > result_type.vecsize)
8230 expr += vector_swizzle(vecsize: result_type.vecsize, index: 0);
8231 if (i + 1 < array_size)
8232 expr += ", ";
8233 }
8234 expr += " })";
8235 }
8236
8237 emit_op(result_type: result_type_id, result_id: id, rhs: expr, forward_rhs: false);
8238 register_read(expr: id, chain: ptr, forwarded: false);
8239 return true;
8240}
8241
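// Attempts to remap an access chain into a tessellation stage-IO variable onto the gl_in/gl_out
// I/O arrays (or the patch stage variables), translating member references into flattened
// interface member indices. Returns true if the access chain was handled here.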
8242bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
8243{
8244 // If this is a per-vertex output, remap it to the I/O array buffer.
8245
8246 // Any object which did not go through IO flattening shenanigans will go there instead.
	// We will unflatten on demand as needed, but not all possible cases can be supported, especially with arrays.
8248
8249 auto *var = maybe_get_backing_variable(chain: ops[2]);
8250 bool patch = false;
8251 bool flat_data = false;
8252 bool ptr_is_chain = false;
8253 bool flatten_composites = false;
8254
8255 bool is_block = false;
8256 bool is_arrayed = false;
8257
8258 if (var)
8259 {
8260 auto &type = get_variable_data_type(var: *var);
8261 is_block = has_decoration(id: type.self, decoration: DecorationBlock);
8262 is_arrayed = !type.array.empty();
8263
8264 flatten_composites = variable_storage_requires_stage_io(storage: var->storage);
8265 patch = has_decoration(id: ops[2], decoration: DecorationPatch) || is_patch_block(type);
8266
8267 // Should match strip_array in add_interface_block.
8268 flat_data = var->storage == StorageClassInput || (var->storage == StorageClassOutput && is_tesc_shader());
8269
8270 // Patch inputs are treated as normal block IO variables, so they don't deal with this path at all.
8271 if (patch && (!is_block || is_arrayed || var->storage == StorageClassInput))
8272 flat_data = false;
8273
8274 // We might have a chained access chain, where
8275 // we first take the access chain to the control point, and then we chain into a member or something similar.
8276 // In this case, we need to skip gl_in/gl_out remapping.
8277 // Also, skip ptr chain for patches.
8278 ptr_is_chain = var->self != ID(ops[2]);
8279 }
8280
8281 bool builtin_variable = false;
8282 bool variable_is_flat = false;
8283
8284 if (var && flat_data)
8285 {
8286 builtin_variable = is_builtin_variable(var: *var);
8287
8288 BuiltIn bi_type = BuiltInMax;
8289 if (builtin_variable && !is_block)
8290 bi_type = BuiltIn(get_decoration(id: var->self, decoration: DecorationBuiltIn));
8291
8292 variable_is_flat = !builtin_variable || is_block ||
8293 bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
8294 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
8295 }
8296
8297 if (variable_is_flat)
8298 {
8299 // If output is masked, it is emitted as a "normal" variable, just go through normal code paths.
8300 // Only check this for the first level of access chain.
8301 // Dealing with this for partial access chains should be possible, but awkward.
8302 if (var->storage == StorageClassOutput && !ptr_is_chain)
8303 {
8304 bool masked = false;
8305 if (is_block)
8306 {
8307 uint32_t relevant_member_index = patch ? 3 : 4;
				// FIXME: This won't work properly if the application first access chains into a gl_out element,
				// then access chains into the member. Super weird, but theoretically possible ...
8310 if (length > relevant_member_index)
8311 {
8312 uint32_t mbr_idx = get<SPIRConstant>(id: ops[relevant_member_index]).scalar();
8313 masked = is_stage_output_block_member_masked(var: *var, index: mbr_idx, strip_array: true);
8314 }
8315 }
8316 else if (var)
8317 masked = is_stage_output_variable_masked(var: *var);
8318
8319 if (masked)
8320 return false;
8321 }
8322
8323 AccessChainMeta meta;
8324 SmallVector<uint32_t> indices;
8325 uint32_t next_id = ir.increase_bound_by(count: 1);
8326
8327 indices.reserve(count: length - 3 + 1);
8328
8329 uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0);
8330
8331 VariableID stage_var_id;
8332 if (patch)
8333 stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id;
8334 else
8335 stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
8336
8337 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
8338 if (!ptr_is_chain && !patch)
8339 {
8340 // Index into gl_in/gl_out with first array index.
8341 indices.push_back(t: ops[first_non_array_index - 1]);
8342 }
8343
8344 auto &result_ptr_type = get<SPIRType>(id: ops[0]);
8345
8346 uint32_t const_mbr_id = next_id++;
8347 uint32_t index = get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8348
8349 // If we have a pointer chain expression, and we are no longer pointing to a composite
8350 // object, we are in the clear. There is no longer a need to flatten anything.
8351 bool further_access_chain_is_trivial = false;
8352 if (ptr_is_chain && flatten_composites)
8353 {
8354 auto &ptr_type = expression_type(id: ptr);
8355 if (!is_array(type: ptr_type) && !is_matrix(type: ptr_type) && ptr_type.basetype != SPIRType::Struct)
8356 further_access_chain_is_trivial = true;
8357 }
8358
8359 if (!further_access_chain_is_trivial && (flatten_composites || is_block))
8360 {
8361 uint32_t i = first_non_array_index;
8362 auto *type = &get_variable_element_type(var: *var);
8363 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
8364 {
8365 // Maybe this is a struct type in the input class, in which case
8366 // we put it as a decoration on the corresponding member.
8367 uint32_t mbr_idx = get_constant(id: ops[first_non_array_index]).scalar();
8368 index = get_extended_member_decoration(type: var->self, index: mbr_idx,
8369 decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8370 assert(index != uint32_t(-1));
8371 i++;
8372 type = &get<SPIRType>(id: type->member_types[mbr_idx]);
8373 }
8374
8375 // In this case, we're poking into flattened structures and arrays, so now we have to
8376 // combine the following indices. If we encounter a non-constant index,
8377 // we're hosed.
8378 for (; flatten_composites && i < length; ++i)
8379 {
8380 if (!is_array(type: *type) && !is_matrix(type: *type) && type->basetype != SPIRType::Struct)
8381 break;
8382
8383 auto *c = maybe_get<SPIRConstant>(id: ops[i]);
8384 if (!c || c->specialization)
8385 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
8386 "This is currently unsupported.");
8387
8388 // We're in flattened space, so just increment the member index into IO block.
8389 // We can only do this once in the current implementation, so either:
8390 // Struct, Matrix or 1-dimensional array for a control point.
8391 if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput)
8392 {
8393 // Need to consider holes, since individual block members might be masked away.
8394 uint32_t mbr_idx = c->scalar();
8395 for (uint32_t j = 0; j < mbr_idx; j++)
8396 if (!is_stage_output_block_member_masked(var: *var, index: j, strip_array: true))
8397 index++;
8398 }
8399 else
8400 index += c->scalar();
8401
8402 if (type->parent_type)
8403 type = &get<SPIRType>(id: type->parent_type);
8404 else if (type->basetype == SPIRType::Struct)
8405 type = &get<SPIRType>(id: type->member_types[c->scalar()]);
8406 }
8407
			// We're not going to emit the actual member name; we let any further OpLoad take care of that.
8409 // Tag the access chain with the member index we're referencing.
8410 auto &result_pointee_type = get_pointee_type(type: result_ptr_type);
8411 bool defer_access_chain = flatten_composites && (is_matrix(type: result_pointee_type) || is_array(type: result_pointee_type) ||
8412 result_pointee_type.basetype == SPIRType::Struct);
8413
8414 if (!defer_access_chain)
8415 {
8416 // Access the appropriate member of gl_in/gl_out.
8417 set<SPIRConstant>(id: const_mbr_id, args: get_uint_type_id(), args&: index, args: false);
8418 indices.push_back(t: const_mbr_id);
8419
8420 // Member index is now irrelevant.
8421 index = uint32_t(-1);
8422
8423 // Append any straggling access chain indices.
8424 if (i < length)
8425 indices.insert(itr: indices.end(), insert_begin: ops + i, insert_end: ops + length);
8426 }
8427 else
8428 {
8429 // We must have consumed the entire access chain if we're deferring it.
8430 assert(i == length);
8431 }
8432
8433 if (index != uint32_t(-1))
8434 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: index);
8435 else
8436 unset_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8437 }
8438 else
8439 {
8440 if (index != uint32_t(-1))
8441 {
8442 set<SPIRConstant>(id: const_mbr_id, args: get_uint_type_id(), args&: index, args: false);
8443 indices.push_back(t: const_mbr_id);
8444 }
8445
8446 // Member index is now irrelevant.
8447 index = uint32_t(-1);
8448 unset_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8449
8450 indices.insert(itr: indices.end(), insert_begin: ops + first_non_array_index, insert_end: ops + length);
8451 }
8452
8453 // We use the pointer to the base of the input/output array here,
8454 // so this is always a pointer chain.
8455 string e;
8456
8457 if (!ptr_is_chain)
8458 {
8459 // This is the start of an access chain, use ptr_chain to index into control point array.
8460 e = access_chain(base: ptr, indices: indices.data(), count: uint32_t(indices.size()), target_type: result_ptr_type, meta: &meta, ptr_chain: !patch);
8461 }
8462 else
8463 {
			// If we're accessing a struct, we need to use member indices which are based on the IO block,
			// not the actual struct type, so we have to use a split access chain here: the first half
			// resolves the control point index, i.e. gl_in[index], and the second half deals with
			// looking up the flattened member name.
8468
8469 // However, it is possible that we partially accessed a struct,
8470 // by taking pointer to member inside the control-point array.
8471 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
8472 // One way to check this here is if we have 2 implied read expressions.
8473 // First one is the gl_in/gl_out struct itself, then an index into that array.
8474 // If we have traversed further, we use a normal access chain formulation.
8475 auto *ptr_expr = maybe_get<SPIRExpression>(id: ptr);
8476 bool split_access_chain_formulation = flatten_composites && ptr_expr &&
8477 ptr_expr->implied_read_expressions.size() == 2 &&
8478 !further_access_chain_is_trivial;
8479
8480 if (split_access_chain_formulation)
8481 {
8482 e = join(ts: to_expression(id: ptr),
8483 ts: access_chain_internal(base: stage_var_id, indices: indices.data(), count: uint32_t(indices.size()),
8484 flags: ACCESS_CHAIN_CHAIN_ONLY_BIT, meta: &meta));
8485 }
8486 else
8487 {
8488 e = access_chain_internal(base: ptr, indices: indices.data(), count: uint32_t(indices.size()), flags: 0, meta: &meta);
8489 }
8490 }
8491
8492 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
8493 // then we'll need to add a swizzle.
8494 // For this, we can't necessarily rely on the type of the base expression, because it might be
8495 // another access chain, and it will therefore already have the "correct" type.
8496 auto *expr_type = &get_variable_data_type(var: *var);
8497 if (has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID))
8498 expr_type = &get<SPIRType>(id: get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID));
8499 for (uint32_t i = 3; i < length; i++)
8500 {
8501 if (!is_array(type: *expr_type) && expr_type->basetype == SPIRType::Struct)
8502 expr_type = &get<SPIRType>(id: expr_type->member_types[get<SPIRConstant>(id: ops[i]).scalar()]);
8503 else
8504 expr_type = &get<SPIRType>(id: expr_type->parent_type);
8505 }
8506 if (!is_array(type: *expr_type) && !is_matrix(type: *expr_type) && expr_type->basetype != SPIRType::Struct &&
8507 expr_type->vecsize > result_ptr_type.vecsize)
8508 e += vector_swizzle(vecsize: result_ptr_type.vecsize, index: 0);
8509
8510 auto &expr = set<SPIRExpression>(id: ops[1], args: std::move(e), args: ops[0], args: should_forward(id: ops[2]));
8511 expr.loaded_from = var->self;
8512 expr.need_transpose = meta.need_transpose;
8513 expr.access_chain = true;
8514
8515 // Mark the result as being packed if necessary.
8516 if (meta.storage_is_packed)
8517 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationPhysicalTypePacked);
8518 if (meta.storage_physical_type != 0)
8519 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationPhysicalTypeID, value: meta.storage_physical_type);
8520 if (meta.storage_is_invariant)
8521 set_decoration(id: ops[1], decoration: DecorationInvariant);
8522 // Save the type we found in case the result is used in another access chain.
8523 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationTessIOOriginalInputTypeID, value: expr_type->self);
8524
8525 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
8526 // temporary which could be subject to invalidation.
		// Need to assume we're forwarded while calling inherit_expression_dependencies.
8528 forwarded_temporaries.insert(x: ops[1]);
8529 // The access chain itself is never forced to a temporary, but its dependencies might.
8530 suppressed_usage_tracking.insert(x: ops[1]);
8531
8532 for (uint32_t i = 2; i < length; i++)
8533 {
8534 inherit_expression_dependencies(dst: ops[1], source: ops[i]);
8535 add_implied_read_expression(e&: expr, source: ops[i]);
8536 }
8537
		// If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
		// then we're not actually forwarded.
8540 if (expr.expression_dependencies.empty())
8541 forwarded_temporaries.erase(x: ops[1]);
8542
8543 return true;
8544 }
8545
8546 // If this is the inner tessellation level, and we're tessellating triangles,
8547 // drop the last index. It isn't an array in this case, so we can't have an
8548 // array reference here. We need to make this ID a variable instead of an
8549 // expression so we don't try to dereference it as a variable pointer.
8550 // Don't do this if the index is a constant 1, though. We need to drop stores
8551 // to that one.
8552 auto *m = ir.find_meta(id: var ? var->self : ID(0));
8553 if (is_tesc_shader() && var && m && m->decoration.builtin_type == BuiltInTessLevelInner &&
8554 is_tessellating_triangles())
8555 {
8556 auto *c = maybe_get<SPIRConstant>(id: ops[3]);
8557 if (c && c->scalar() == 1)
8558 return false;
8559 auto &dest_var = set<SPIRVariable>(id: ops[1], args&: *var);
8560 dest_var.basetype = ops[0];
8561 ir.meta[ops[1]] = ir.meta[ops[2]];
8562 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
8563 return true;
8564 }
8565
8566 return false;
8567}
8568
8569bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
8570{
8571 if (!is_tessellating_triangles())
8572 return false;
8573
8574 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
8575 // four. This is true even if we are tessellating triangles. This allows clients
8576 // to use a single tessellation control shader with multiple tessellation evaluation
8577 // shaders.
8578 // In Metal, however, only the first element of TessLevelInner and the first three
8579 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
8580 // levels must be stored to a dedicated buffer in a particular format that depends
8581 // on the patch type. Therefore, in Triangles mode, any store to the second
8582 // inner level or the fourth outer level must be dropped.
8583 const auto *e = maybe_get<SPIRExpression>(id: id_lhs);
8584 if (!e || !e->access_chain)
8585 return false;
8586 BuiltIn builtin = BuiltIn(get_decoration(id: e->loaded_from, decoration: DecorationBuiltIn));
8587 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
8588 return false;
8589 auto *c = maybe_get<SPIRConstant>(id: e->implied_read_expressions[1]);
8590 if (!c)
8591 return false;
8592 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
8593 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
8594}
8595
8596bool CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
8597 spv::StorageClass storage, bool &is_packed)
8598{
8599 // If there is any risk of writes happening with the access chain in question,
8600 // and there is a risk of concurrent write access to other components,
8601 // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
8602 // The MSL compiler refuses to allow component-level access for any non-packed vector types.
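	// e.g. the chain base is rewritten as ((device T*)&expr), so only the exact scalars
	// in question are ever touched.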
8603 if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
8604 {
8605 const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
8606 expr = join(ts: "((", ts&: addr_space, ts: " ", ts: type_to_glsl(type), ts: "*)&", ts: enclose_expression(expr), ts: ")");
8607
8608 // Further indexing should happen with packed rules (array index, not swizzle).
8609 is_packed = true;
8610 return true;
8611 }
8612 else
8613 return false;
8614}
8615
8616bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
8617{
8618 auto *var = maybe_get_backing_variable(chain: base);
8619 if (!var || !is_tessellation_shader())
8620 return true;
8621
8622 // We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N.
8623 // Avoid overriding it back to just gl_ClipDistance.
	// This can only happen in scenarios where we cannot flatten/unflatten access chains, so the only case
8625 // where this triggers is evaluation shader inputs.
8626 bool redirect_builtin = is_tese_shader() ? var->storage == StorageClassOutput : false;
8627 return redirect_builtin;
8628}
8629
8630// Sets the interface member index for an access chain to a pull-model interpolant.
8631void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
8632{
8633 auto *var = maybe_get_backing_variable(chain: ops[2]);
8634 if (!var || !pull_model_inputs.count(x: var->self))
8635 return;
8636 // Get the base index.
8637 uint32_t interface_index;
8638 auto &var_type = get_variable_data_type(var: *var);
8639 auto &result_type = get<SPIRType>(id: ops[0]);
8640 auto *type = &var_type;
8641 if (has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex))
8642 {
8643 interface_index = get_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8644 }
8645 else
8646 {
8647 // Assume an access chain into a struct variable.
8648 assert(var_type.basetype == SPIRType::Struct);
8649 auto &c = get<SPIRConstant>(id: ops[3 + var_type.array.size()]);
8650 interface_index =
8651 get_extended_member_decoration(type: var->self, index: c.scalar(), decoration: SPIRVCrossDecorationInterfaceMemberIndex);
8652 }
	// Accumulate indices. We'll have to skip over the one for the struct, if present, because we already
	// accounted for it when getting the base index.
8655 for (uint32_t i = 3; i < length; ++i)
8656 {
8657 if (is_vector(type: *type) && !is_array(type: *type) && is_scalar(type: result_type))
8658 {
8659 // We don't want to combine the next index. Actually, we need to save it
8660 // so we know to apply a swizzle to the result of the interpolation.
8661 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterpolantComponentExpr, value: ops[i]);
8662 break;
8663 }
8664
8665 auto *c = maybe_get<SPIRConstant>(id: ops[i]);
8666 if (!c || c->specialization)
8667 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
8668 "interpolation. This is currently unsupported.");
8669
8670 if (type->parent_type)
8671 type = &get<SPIRType>(id: type->parent_type);
8672 else if (type->basetype == SPIRType::Struct)
8673 type = &get<SPIRType>(id: type->member_types[c->scalar()]);
8674
8675 if (!has_extended_decoration(id: ops[2], decoration: SPIRVCrossDecorationInterfaceMemberIndex) &&
8676 i - 3 == var_type.array.size())
8677 continue;
8678
8679 interface_index += c->scalar();
8680 }
8681 // Save this to the access chain itself so we can recover it later when calling an interpolation function.
8682 set_extended_decoration(id: ops[1], decoration: SPIRVCrossDecorationInterfaceMemberIndex, value: interface_index);
8683}
8684
8686// If the physical type of a physical buffer pointer has been changed
8687// to a ulong or ulongn vector, add a cast back to the pointer type.
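// e.g. a pointer carried in a ulong2 'p' would be emitted roughly as ((device T*)p.x).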
8688void CompilerMSL::check_physical_type_cast(std::string &expr, const SPIRType *type, uint32_t physical_type)
8689{
8690 auto *p_physical_type = maybe_get<SPIRType>(id: physical_type);
8691 if (p_physical_type &&
8692 p_physical_type->storage == StorageClassPhysicalStorageBuffer &&
8693 p_physical_type->basetype == to_unsigned_basetype(width: 64))
8694 {
8695 if (p_physical_type->vecsize > 1)
8696 expr += ".x";
8697
8698 expr = join(ts: "((", ts: type_to_glsl(type: *type), ts: ")", ts&: expr, ts: ")");
8699 }
8700}
8701
8702// Override for MSL-specific syntax instructions
8703void CompilerMSL::emit_instruction(const Instruction &instruction)
8704{
8705#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
8706#define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op)
8707 // MSL does care about implicit integer promotion, but those cases are all handled in common code.
8708#define MSL_BOP_CAST(op, type) \
8709 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false)
8710#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
8711#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
8712#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
8713#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
8714#define MSL_BFOP_CAST(op, type) \
8715 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
8716#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
8717#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
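	// The helpers above expand to the corresponding emit_* calls, e.g. MSL_BOP(==) emits "a == b"
	// and MSL_UNORD_BOP(==) emits the isunordered() form.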
8718
8719 auto ops = stream(instr: instruction);
8720 auto opcode = static_cast<Op>(instruction.op);
8721
8722 opcode = get_remapped_spirv_op(op: opcode);
8723
8724 // If we need to do implicit bitcasts, make sure we do it with the correct type.
8725 uint32_t integer_width = get_integer_width_for_instruction(instr: instruction);
8726 auto int_type = to_signed_basetype(width: integer_width);
8727 auto uint_type = to_unsigned_basetype(width: integer_width);
8728
8729 switch (opcode)
8730 {
8731 case OpLoad:
8732 {
8733 uint32_t id = ops[1];
8734 uint32_t ptr = ops[2];
8735 if (is_tessellation_shader())
8736 {
8737 if (!emit_tessellation_io_load(result_type_id: ops[0], id, ptr))
8738 CompilerGLSL::emit_instruction(instr: instruction);
8739 }
8740 else
8741 {
8742 // Sample mask input for Metal is not an array
8743 if (BuiltIn(get_decoration(id: ptr, decoration: DecorationBuiltIn)) == BuiltInSampleMask)
8744 set_decoration(id, decoration: DecorationBuiltIn, argument: BuiltInSampleMask);
8745 CompilerGLSL::emit_instruction(instr: instruction);
8746 }
8747 break;
8748 }
8749
8750 // Comparisons
8751 case OpIEqual:
8752 MSL_BOP_CAST(==, int_type);
8753 break;
8754
8755 case OpLogicalEqual:
8756 case OpFOrdEqual:
8757 MSL_BOP(==);
8758 break;
8759
8760 case OpINotEqual:
8761 MSL_BOP_CAST(!=, int_type);
8762 break;
8763
8764 case OpLogicalNotEqual:
8765 case OpFOrdNotEqual:
8766 // TODO: Should probably negate the == result here.
8767 // Typically OrdNotEqual comes from GLSL which itself does not really specify what
8768 // happens with NaN.
8769 // Consider fixing this if we run into real issues.
8770 MSL_BOP(!=);
8771 break;
8772
8773 case OpUGreaterThan:
8774 MSL_BOP_CAST(>, uint_type);
8775 break;
8776
8777 case OpSGreaterThan:
8778 MSL_BOP_CAST(>, int_type);
8779 break;
8780
8781 case OpFOrdGreaterThan:
8782 MSL_BOP(>);
8783 break;
8784
8785 case OpUGreaterThanEqual:
8786 MSL_BOP_CAST(>=, uint_type);
8787 break;
8788
8789 case OpSGreaterThanEqual:
8790 MSL_BOP_CAST(>=, int_type);
8791 break;
8792
8793 case OpFOrdGreaterThanEqual:
8794 MSL_BOP(>=);
8795 break;
8796
8797 case OpULessThan:
8798 MSL_BOP_CAST(<, uint_type);
8799 break;
8800
8801 case OpSLessThan:
8802 MSL_BOP_CAST(<, int_type);
8803 break;
8804
8805 case OpFOrdLessThan:
8806 MSL_BOP(<);
8807 break;
8808
8809 case OpULessThanEqual:
8810 MSL_BOP_CAST(<=, uint_type);
8811 break;
8812
8813 case OpSLessThanEqual:
8814 MSL_BOP_CAST(<=, int_type);
8815 break;
8816
8817 case OpFOrdLessThanEqual:
8818 MSL_BOP(<=);
8819 break;
8820
8821 case OpFUnordEqual:
8822 MSL_UNORD_BOP(==);
8823 break;
8824
8825 case OpFUnordNotEqual:
		// Not-equal in MSL generates 'une' (unordered not-equal) opcodes to begin with.
		// Since unordered not-equal is how it works in C, just inherit that behavior.
8828 MSL_BOP(!=);
8829 break;
8830
8831 case OpFUnordGreaterThan:
8832 MSL_UNORD_BOP(>);
8833 break;
8834
8835 case OpFUnordGreaterThanEqual:
8836 MSL_UNORD_BOP(>=);
8837 break;
8838
8839 case OpFUnordLessThan:
8840 MSL_UNORD_BOP(<);
8841 break;
8842
8843 case OpFUnordLessThanEqual:
8844 MSL_UNORD_BOP(<=);
8845 break;
8846
8847 // Pointer math
8848 case OpPtrEqual:
8849 MSL_PTR_BOP(==);
8850 break;
8851
8852 case OpPtrNotEqual:
8853 MSL_PTR_BOP(!=);
8854 break;
8855
8856 case OpPtrDiff:
8857 MSL_PTR_BOP(-);
8858 break;
8859
8860 // Derivatives
8861 case OpDPdx:
8862 case OpDPdxFine:
8863 case OpDPdxCoarse:
8864 MSL_UFOP(dfdx);
8865 register_control_dependent_expression(expr: ops[1]);
8866 break;
8867
8868 case OpDPdy:
8869 case OpDPdyFine:
8870 case OpDPdyCoarse:
8871 MSL_UFOP(dfdy);
8872 register_control_dependent_expression(expr: ops[1]);
8873 break;
8874
8875 case OpFwidth:
8876 case OpFwidthCoarse:
8877 case OpFwidthFine:
8878 MSL_UFOP(fwidth);
8879 register_control_dependent_expression(expr: ops[1]);
8880 break;
8881
8882 // Bitfield
8883 case OpBitFieldInsert:
8884 {
8885 emit_bitfield_insert_op(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op3: ops[5], op: "insert_bits", offset_count_type: SPIRType::UInt);
8886 break;
8887 }
8888
8889 case OpBitFieldSExtract:
8890 {
8891 emit_trinary_func_op_bitextract(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op: "extract_bits", expected_result_type: int_type, input_type0: int_type,
8892 input_type1: SPIRType::UInt, input_type2: SPIRType::UInt);
8893 break;
8894 }
8895
8896 case OpBitFieldUExtract:
8897 {
8898 emit_trinary_func_op_bitextract(result_type: ops[0], result_id: ops[1], op0: ops[2], op1: ops[3], op2: ops[4], op: "extract_bits", expected_result_type: uint_type, input_type0: uint_type,
8899 input_type1: SPIRType::UInt, input_type2: SPIRType::UInt);
8900 break;
8901 }
8902
8903 case OpBitReverse:
8904 // BitReverse does not have issues with sign since result type must match input type.
8905 MSL_UFOP(reverse_bits);
8906 break;
8907
8908 case OpBitCount:
8909 {
8910 auto basetype = expression_type(id: ops[2]).basetype;
8911 emit_unary_func_op_cast(result_type: ops[0], result_id: ops[1], op0: ops[2], op: "popcount", input_type: basetype, expected_result_type: basetype);
8912 break;
8913 }
8914
8915 case OpFRem:
8916 MSL_BFOP(fmod);
8917 break;
8918
8919 case OpFMul:
8920 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
8921 MSL_BFOP(spvFMul);
8922 else
8923 MSL_BOP(*);
8924 break;
8925
8926 case OpFAdd:
8927 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
8928 MSL_BFOP(spvFAdd);
8929 else
8930 MSL_BOP(+);
8931 break;
8932
8933 case OpFSub:
8934 if (msl_options.invariant_float_math || has_decoration(id: ops[1], decoration: DecorationNoContraction))
8935 MSL_BFOP(spvFSub);
8936 else
8937 MSL_BOP(-);
8938 break;
8939
8940 // Atomics
8941 case OpAtomicExchange:
8942 {
8943 uint32_t result_type = ops[0];
8944 uint32_t id = ops[1];
8945 uint32_t ptr = ops[2];
8946 uint32_t mem_sem = ops[4];
8947 uint32_t val = ops[5];
8948 emit_atomic_func_op(result_type, result_id: id, op: "atomic_exchange", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: val);
8949 break;
8950 }
8951
8952 case OpAtomicCompareExchange:
8953 {
8954 uint32_t result_type = ops[0];
8955 uint32_t id = ops[1];
8956 uint32_t ptr = ops[2];
8957 uint32_t mem_sem_pass = ops[4];
8958 uint32_t mem_sem_fail = ops[5];
8959 uint32_t val = ops[6];
8960 uint32_t comp = ops[7];
8961 emit_atomic_func_op(result_type, result_id: id, op: "atomic_compare_exchange_weak", opcode,
8962 mem_order_1: mem_sem_pass, mem_order_2: mem_sem_fail, has_mem_order_2: true,
8963 op0: ptr, op1: comp, op1_is_pointer: true, op1_is_literal: false, op2: val);
8964 break;
8965 }
8966
8967 case OpAtomicCompareExchangeWeak:
8968 SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
8969
8970 case OpAtomicLoad:
8971 {
8972 uint32_t result_type = ops[0];
8973 uint32_t id = ops[1];
8974 uint32_t ptr = ops[2];
8975 uint32_t mem_sem = ops[4];
8976 check_atomic_image(id: ptr);
8977 emit_atomic_func_op(result_type, result_id: id, op: "atomic_load", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: 0);
8978 break;
8979 }
8980
8981 case OpAtomicStore:
8982 {
8983 uint32_t result_type = expression_type(id: ops[0]).self;
8984 uint32_t id = ops[0];
8985 uint32_t ptr = ops[0];
8986 uint32_t mem_sem = ops[2];
8987 uint32_t val = ops[3];
8988 check_atomic_image(id: ptr);
8989 emit_atomic_func_op(result_type, result_id: id, op: "atomic_store", opcode, mem_order_1: mem_sem, mem_order_2: mem_sem, has_mem_order_2: false, op0: ptr, op1: val);
8990 break;
8991 }
8992
8993#define MSL_AFMO_IMPL(op, valsrc, valconst) \
8994 do \
8995 { \
8996 uint32_t result_type = ops[0]; \
8997 uint32_t id = ops[1]; \
8998 uint32_t ptr = ops[2]; \
8999 uint32_t mem_sem = ops[4]; \
9000 uint32_t val = valsrc; \
9001 emit_atomic_func_op(result_type, id, "atomic_fetch_" #op, opcode, \
9002 mem_sem, mem_sem, false, ptr, val, \
9003 false, valconst); \
9004 } while (false)
9005
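// MSL_AFMO emits an atomic fetch-modify op taking its value operand from ops[5];
// MSL_AFMIO covers OpAtomicIIncrement/OpAtomicIDecrement, which use an implicit constant 1.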
9006#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
9007#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
9008
9009 case OpAtomicIIncrement:
9010 MSL_AFMIO(add);
9011 break;
9012
9013 case OpAtomicIDecrement:
9014 MSL_AFMIO(sub);
9015 break;
9016
9017 case OpAtomicIAdd:
9018 case OpAtomicFAddEXT:
9019 MSL_AFMO(add);
9020 break;
9021
9022 case OpAtomicISub:
9023 MSL_AFMO(sub);
9024 break;
9025
9026 case OpAtomicSMin:
9027 case OpAtomicUMin:
9028 MSL_AFMO(min);
9029 break;
9030
9031 case OpAtomicSMax:
9032 case OpAtomicUMax:
9033 MSL_AFMO(max);
9034 break;
9035
9036 case OpAtomicAnd:
9037 MSL_AFMO(and);
9038 break;
9039
9040 case OpAtomicOr:
9041 MSL_AFMO(or);
9042 break;
9043
9044 case OpAtomicXor:
9045 MSL_AFMO(xor);
9046 break;
9047
9048 // Images
9049
9050 // Reads == Fetches in Metal
9051 case OpImageRead:
9052 {
9053 // Mark that this shader reads from this image
9054 uint32_t img_id = ops[2];
9055 auto &type = expression_type(id: img_id);
9056 auto *p_var = maybe_get_backing_variable(chain: img_id);
9057 if (type.image.dim != DimSubpassData)
9058 {
9059 if (p_var && has_decoration(id: p_var->self, decoration: DecorationNonReadable))
9060 {
9061 unset_decoration(id: p_var->self, decoration: DecorationNonReadable);
9062 force_recompile();
9063 }
9064 }
9065
9066 // Metal requires explicit fences to break up RAW hazards, even within the same shader invocation
9067 if (msl_options.readwrite_texture_fences && p_var && !has_decoration(id: p_var->self, decoration: DecorationNonWritable))
9068 {
9069 add_spv_func_and_recompile(spv_func: SPVFuncImplImageFence);
9070 // Need to wrap this with a value type,
			// since the Metal headers are broken and do not consider the case where the image is a reference.
9072 statement(ts: "spvImageFence(", ts: to_expression(id: img_id), ts: ");");
9073 }
9074
9075 emit_texture_op(i: instruction, sparse: false);
9076 break;
9077 }
9078
9079 // Emulate texture2D atomic operations
9080 case OpImageTexelPointer:
9081 {
9082 // When using the pointer, we need to know which variable it is actually loaded from.
9083 auto *var = maybe_get_backing_variable(chain: ops[2]);
9084 if (var && atomic_image_vars_emulated.count(x: var->self))
9085 {
9086 uint32_t result_type = ops[0];
9087 uint32_t id = ops[1];
9088
9089 std::string coord = to_expression(id: ops[3]);
9090 auto &type = expression_type(id: ops[2]);
9091 if (type.image.dim == Dim2D)
9092 {
9093 coord = join(ts: "spvImage2DAtomicCoord(", ts&: coord, ts: ", ", ts: to_expression(id: ops[2]), ts: ")");
9094 }
9095
9096 auto &e = set<SPIRExpression>(id, args: join(ts: to_expression(id: ops[2]), ts: "_atomic[", ts&: coord, ts: "]"), args&: result_type, args: true);
9097 e.loaded_from = var ? var->self : ID(0);
9098 inherit_expression_dependencies(dst: id, source: ops[3]);
9099 }
9100 else
9101 {
9102 uint32_t result_type = ops[0];
9103 uint32_t id = ops[1];
9104
9105 // Virtual expression. Split this up in the actual image atomic.
9106 // In GLSL and HLSL we are able to resolve the dereference inline, but MSL has
9107 // image.op(coord, ...) syntax.
9108 auto &e =
9109 set<SPIRExpression>(id, args: join(ts: to_expression(id: ops[2]), ts: "@",
9110 ts: bitcast_expression(target_type: SPIRType::UInt, arg: ops[3])),
9111 args&: result_type, args: true);
9112
9113 // When using the pointer, we need to know which variable it is actually loaded from.
9114 e.loaded_from = var ? var->self : ID(0);
9115 inherit_expression_dependencies(dst: id, source: ops[3]);
9116 }
9117 break;
9118 }
9119
9120 case OpImageWrite:
9121 {
9122 uint32_t img_id = ops[0];
9123 uint32_t coord_id = ops[1];
9124 uint32_t texel_id = ops[2];
9125 const uint32_t *opt = &ops[3];
9126 uint32_t length = instruction.length - 3;
9127
9128 // Bypass pointers because we need the real image struct
9129 auto &type = expression_type(id: img_id);
9130 auto &img_type = get<SPIRType>(id: type.self);
9131
9132 // Ensure this image has been marked as being written to and force a
		// recompile so that the image type output will include write access.
9134 auto *p_var = maybe_get_backing_variable(chain: img_id);
9135 if (p_var && has_decoration(id: p_var->self, decoration: DecorationNonWritable))
9136 {
9137 unset_decoration(id: p_var->self, decoration: DecorationNonWritable);
9138 force_recompile();
9139 }
9140
9141 bool forward = false;
9142 uint32_t bias = 0;
9143 uint32_t lod = 0;
9144 uint32_t flags = 0;
9145
9146 if (length)
9147 {
9148 flags = *opt++;
9149 length--;
9150 }
9151
9152 auto test = [&](uint32_t &v, uint32_t flag) {
9153 if (length && (flags & flag))
9154 {
9155 v = *opt++;
9156 length--;
9157 }
9158 };
9159
9160 test(bias, ImageOperandsBiasMask);
9161 test(lod, ImageOperandsLodMask);
9162
9163 auto &texel_type = expression_type(id: texel_id);
9164 auto store_type = texel_type;
9165 store_type.vecsize = 4;
9166
9167 TextureFunctionArguments args = {};
9168 args.base.img = img_id;
9169 args.base.imgtype = &img_type;
9170 args.base.is_fetch = true;
9171 args.coord = coord_id;
9172 args.lod = lod;
9173
9174 string expr;
9175 if (needs_frag_discard_checks())
9176 expr = join(ts: "(", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: " ? ((void)0) : ");
9177 expr += join(ts: to_expression(id: img_id), ts: ".write(",
9178 ts: remap_swizzle(result_type: store_type, input_components: texel_type.vecsize, expr: to_expression(id: texel_id)), ts: ", ",
9179 ts: CompilerMSL::to_function_args(args, p_forward: &forward), ts: ")");
9180 if (needs_frag_discard_checks())
9181 expr += ")";
9182 statement(ts&: expr, ts: ";");
9183
9184 if (p_var && variable_storage_is_aliased(var: *p_var))
9185 flush_all_aliased_variables();
9186
9187 break;
9188 }
9189
9190 case OpImageQuerySize:
9191 case OpImageQuerySizeLod:
9192 {
9193 uint32_t rslt_type_id = ops[0];
9194 auto &rslt_type = get<SPIRType>(id: rslt_type_id);
9195
9196 uint32_t id = ops[1];
9197
9198 uint32_t img_id = ops[2];
9199 string img_exp = to_expression(id: img_id);
9200 auto &img_type = expression_type(id: img_id);
9201 Dim img_dim = img_type.image.dim;
9202 bool img_is_array = img_type.image.arrayed;
9203
9204 if (img_type.basetype != SPIRType::Image)
9205 SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
9206
9207 string lod;
9208 if (opcode == OpImageQuerySizeLod)
9209 {
			// LOD index defaults to zero, so don't bother outputting the level-zero index.
9211 string decl_lod = to_expression(id: ops[3]);
9212 if (decl_lod != "0")
9213 lod = decl_lod;
9214 }
9215
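		// For example, a 2D array texture with a uint3 result builds up:
		//   uint3(tex.get_width(), tex.get_height(), tex.get_array_size())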
9216 string expr = type_to_glsl(type: rslt_type) + "(";
9217 expr += img_exp + ".get_width(" + lod + ")";
9218
9219 if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
9220 expr += ", " + img_exp + ".get_height(" + lod + ")";
9221
9222 if (img_dim == Dim3D)
9223 expr += ", " + img_exp + ".get_depth(" + lod + ")";
9224
9225 if (img_is_array)
9226 {
9227 expr += ", " + img_exp + ".get_array_size()";
9228 if (img_dim == DimCube && msl_options.emulate_cube_array)
9229 expr += " / 6";
9230 }
9231
9232 expr += ")";
9233
9234 emit_op(result_type: rslt_type_id, result_id: id, rhs: expr, forward_rhs: should_forward(id: img_id));
9235
9236 break;
9237 }
9238
9239 case OpImageQueryLod:
9240 {
9241 if (!msl_options.supports_msl_version(major: 2, minor: 2))
9242 SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
9243 uint32_t result_type = ops[0];
9244 uint32_t id = ops[1];
9245 uint32_t image_id = ops[2];
9246 uint32_t coord_id = ops[3];
9247 emit_uninitialized_temporary_expression(type: result_type, id);
9248
9249 auto sampler_expr = to_sampler_expression(id: image_id);
9250 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: image_id);
9251 auto image_expr = combined ? to_expression(id: combined->image) : to_expression(id: image_id);
9252
9253		// TODO: It is unclear whether calculate_clamped_lod also conditionally rounds
9254 // the reported LOD based on the sampler. NEAREST miplevel should
9255 // round the LOD, but LINEAR miplevel should not round.
9256 // Let's hope this does not become an issue ...
9257 statement(ts: to_expression(id), ts: ".x = ", ts&: image_expr, ts: ".calculate_clamped_lod(", ts&: sampler_expr, ts: ", ",
9258 ts: to_expression(id: coord_id), ts: ");");
9259 statement(ts: to_expression(id), ts: ".y = ", ts&: image_expr, ts: ".calculate_unclamped_lod(", ts&: sampler_expr, ts: ", ",
9260 ts: to_expression(id: coord_id), ts: ");");
9261 register_control_dependent_expression(expr: id);
9262 break;
9263 }
9264
9265#define MSL_ImgQry(qrytype) \
9266 do \
9267 { \
9268 uint32_t rslt_type_id = ops[0]; \
9269 auto &rslt_type = get<SPIRType>(rslt_type_id); \
9270 uint32_t id = ops[1]; \
9271 uint32_t img_id = ops[2]; \
9272 string img_exp = to_expression(img_id); \
9273 string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
9274 emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
9275 } while (false)
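// For example, MSL_ImgQry(mip_levels) emits an expression along the lines of
// "uint(tex.get_num_mip_levels())" for the OpImageQueryLevels result.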
9276
9277 case OpImageQueryLevels:
9278 MSL_ImgQry(mip_levels);
9279 break;
9280
9281 case OpImageQuerySamples:
9282 MSL_ImgQry(samples);
9283 break;
9284
9285 case OpImage:
9286 {
9287 uint32_t result_type = ops[0];
9288 uint32_t id = ops[1];
9289 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: ops[2]);
9290
9291 if (combined)
9292 {
9293 auto &e = emit_op(result_type, result_id: id, rhs: to_expression(id: combined->image), forward_rhs: true, suppress_usage_tracking: true);
9294 auto *var = maybe_get_backing_variable(chain: combined->image);
9295 if (var)
9296 e.loaded_from = var->self;
9297 }
9298 else
9299 {
9300 auto *var = maybe_get_backing_variable(chain: ops[2]);
9301 SPIRExpression *e;
9302 if (var && has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler))
9303 e = &emit_op(result_type, result_id: id, rhs: join(ts: to_expression(id: ops[2]), ts: ".plane0"), forward_rhs: true, suppress_usage_tracking: true);
9304 else
9305 e = &emit_op(result_type, result_id: id, rhs: to_expression(id: ops[2]), forward_rhs: true, suppress_usage_tracking: true);
9306 if (var)
9307 e->loaded_from = var->self;
9308 }
9309 break;
9310 }
9311
9312 // Casting
9313 case OpQuantizeToF16:
9314 {
9315 uint32_t result_type = ops[0];
9316 uint32_t id = ops[1];
9317 uint32_t arg = ops[2];
9318 string exp = join(ts: "spvQuantizeToF16(", ts: to_expression(id: arg), ts: ")");
9319 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: arg));
9320 break;
9321 }
9322
9323 case OpInBoundsAccessChain:
9324 case OpAccessChain:
9325 case OpPtrAccessChain:
9326 if (is_tessellation_shader())
9327 {
9328 if (!emit_tessellation_access_chain(ops, length: instruction.length))
9329 CompilerGLSL::emit_instruction(instr: instruction);
9330 }
9331 else
9332 CompilerGLSL::emit_instruction(instr: instruction);
9333 fix_up_interpolant_access_chain(ops, length: instruction.length);
9334 break;
9335
9336 case OpStore:
9337 {
9338 const auto &type = expression_type(id: ops[0]);
9339
9340 if (is_out_of_bounds_tessellation_level(id_lhs: ops[0]))
9341 break;
9342
9343 if (needs_frag_discard_checks() &&
9344 (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
9345 {
9346 // If we're in a continue block, this kludge will make the block too complex
9347 // to emit normally.
9348 assert(current_emitting_block);
9349 auto cont_type = continue_block_type(continue_block: *current_emitting_block);
9350 if (cont_type != SPIRBlock::ContinueNone && cont_type != SPIRBlock::ComplexLoop)
9351 {
9352 current_emitting_block->complex_continue = true;
9353 force_recompile();
9354 }
9355 statement(ts: "if (!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: ")");
9356 begin_scope();
9357 }
9358 if (!maybe_emit_array_assignment(id_lhs: ops[0], id_rhs: ops[1]))
9359 CompilerGLSL::emit_instruction(instr: instruction);
9360 if (needs_frag_discard_checks() &&
9361 (type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform))
9362 end_scope();
9363 break;
9364 }
9365
9366 // Compute barriers
9367 case OpMemoryBarrier:
9368 emit_barrier(id_exe_scope: 0, id_mem_scope: ops[0], id_mem_sem: ops[1]);
9369 break;
9370
9371 case OpControlBarrier:
9372		// In GLSL, a memory barrier is often followed by a control barrier.
9373 // But in MSL, memory barriers are also control barriers, so don't
9374 // emit a simple control barrier if a memory barrier has just been emitted.
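		// For example, GLSL's "memoryBarrierShared(); barrier();" sequence collapses
		// into a single "threadgroup_barrier(mem_flags::mem_threadgroup);" in MSL.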
9375 if (previous_instruction_opcode != OpMemoryBarrier)
9376 emit_barrier(id_exe_scope: ops[0], id_mem_scope: ops[1], id_mem_sem: ops[2]);
9377 break;
9378
9379 case OpOuterProduct:
9380 {
9381 uint32_t result_type = ops[0];
9382 uint32_t id = ops[1];
9383 uint32_t a = ops[2];
9384 uint32_t b = ops[3];
9385
9386 auto &type = get<SPIRType>(id: result_type);
9387 string expr = type_to_glsl_constructor(type);
9388 expr += "(";
9389 for (uint32_t col = 0; col < type.columns; col++)
9390 {
9391 expr += to_enclosed_unpacked_expression(id: a);
9392 expr += " * ";
9393 expr += to_extract_component_expression(id: b, index: col);
9394 if (col + 1 < type.columns)
9395 expr += ", ";
9396 }
9397 expr += ")";
9398 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: a) && should_forward(id: b));
9399 inherit_expression_dependencies(dst: id, source: a);
9400 inherit_expression_dependencies(dst: id, source: b);
9401 break;
9402 }
9403
9404 case OpVectorTimesMatrix:
9405 case OpMatrixTimesVector:
9406 {
9407 if (!msl_options.invariant_float_math && !has_decoration(id: ops[1], decoration: DecorationNoContraction))
9408 {
9409 CompilerGLSL::emit_instruction(instr: instruction);
9410 break;
9411 }
9412
9413 // If the matrix needs transpose, just flip the multiply order.
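		// The stored expression is the transpose of the logical matrix, and
		// M * v == v * transpose(M), so swapping the operand order consumes it directly.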
9414 auto *e = maybe_get<SPIRExpression>(id: ops[opcode == OpMatrixTimesVector ? 2 : 3]);
9415 if (e && e->need_transpose)
9416 {
9417 e->need_transpose = false;
9418 string expr;
9419
9420 if (opcode == OpMatrixTimesVector)
9421 {
9422 expr = join(ts: "spvFMulVectorMatrix(", ts: to_enclosed_unpacked_expression(id: ops[3]), ts: ", ",
9423 ts: to_unpacked_row_major_matrix_expression(id: ops[2]), ts: ")");
9424 }
9425 else
9426 {
9427 expr = join(ts: "spvFMulMatrixVector(", ts: to_unpacked_row_major_matrix_expression(id: ops[3]), ts: ", ",
9428 ts: to_enclosed_unpacked_expression(id: ops[2]), ts: ")");
9429 }
9430
9431 bool forward = should_forward(id: ops[2]) && should_forward(id: ops[3]);
9432 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
9433 e->need_transpose = true;
9434 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9435 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
9436 }
9437 else
9438 {
9439 if (opcode == OpMatrixTimesVector)
9440 MSL_BFOP(spvFMulMatrixVector);
9441 else
9442 MSL_BFOP(spvFMulVectorMatrix);
9443 }
9444 break;
9445 }
9446
9447 case OpMatrixTimesMatrix:
9448 {
9449 if (!msl_options.invariant_float_math && !has_decoration(id: ops[1], decoration: DecorationNoContraction))
9450 {
9451 CompilerGLSL::emit_instruction(instr: instruction);
9452 break;
9453 }
9454
9455 auto *a = maybe_get<SPIRExpression>(id: ops[2]);
9456 auto *b = maybe_get<SPIRExpression>(id: ops[3]);
9457
9458 // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
9459 // a^T * b^T = (b * a)^T.
9460 if (a && b && a->need_transpose && b->need_transpose)
9461 {
9462 a->need_transpose = false;
9463 b->need_transpose = false;
9464
9465 auto expr =
9466 join(ts: "spvFMulMatrixMatrix(", ts: enclose_expression(expr: to_unpacked_row_major_matrix_expression(id: ops[3])), ts: ", ",
9467 ts: enclose_expression(expr: to_unpacked_row_major_matrix_expression(id: ops[2])), ts: ")");
9468
9469 bool forward = should_forward(id: ops[2]) && should_forward(id: ops[3]);
9470 auto &e = emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: forward);
9471 e.need_transpose = true;
9472 a->need_transpose = true;
9473 b->need_transpose = true;
9474 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9475 inherit_expression_dependencies(dst: ops[1], source: ops[3]);
9476 }
9477 else
9478 MSL_BFOP(spvFMulMatrixMatrix);
9479
9480 break;
9481 }
9482
9483 case OpIAddCarry:
9484 case OpISubBorrow:
9485 {
9486 uint32_t result_type = ops[0];
9487 uint32_t result_id = ops[1];
9488 uint32_t op0 = ops[2];
9489 uint32_t op1 = ops[3];
9490 auto &type = get<SPIRType>(id: result_type);
9491 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
9492
9493 auto &res_type = get<SPIRType>(id: type.member_types[1]);
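		// Unsigned addition overflows iff the wrapped sum is smaller than both operands,
		// so a carry of 1 is selected exactly when the sum is less than max(op0, op1).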
9494 if (opcode == OpIAddCarry)
9495 {
9496 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0), ts: " = ",
9497 ts: to_enclosed_unpacked_expression(id: op0), ts: " + ", ts: to_enclosed_unpacked_expression(id: op1), ts: ";");
9498 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 1), ts: " = select(", ts: type_to_glsl(type: res_type),
9499 ts: "(1), ", ts: type_to_glsl(type: res_type), ts: "(0), ", ts: to_unpacked_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0),
9500 ts: " >= max(", ts: to_unpacked_expression(id: op0), ts: ", ", ts: to_unpacked_expression(id: op1), ts: "));");
9501 }
9502 else
9503 {
9504 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0), ts: " = ", ts: to_enclosed_unpacked_expression(id: op0), ts: " - ",
9505 ts: to_enclosed_unpacked_expression(id: op1), ts: ";");
9506 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 1), ts: " = select(", ts: type_to_glsl(type: res_type),
9507 ts: "(1), ", ts: type_to_glsl(type: res_type), ts: "(0), ", ts: to_enclosed_unpacked_expression(id: op0),
9508 ts: " >= ", ts: to_enclosed_unpacked_expression(id: op1), ts: ");");
9509 }
9510 break;
9511 }
9512
9513 case OpUMulExtended:
9514 case OpSMulExtended:
9515 {
9516 uint32_t result_type = ops[0];
9517 uint32_t result_id = ops[1];
9518 uint32_t op0 = ops[2];
9519 uint32_t op1 = ops[3];
9520 auto &type = get<SPIRType>(id: result_type);
9521 auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
9522 string cast_op0, cast_op1;
9523
9524 binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type: false);
9525 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
9526 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 0), ts: " = ", ts&: cast_op0, ts: " * ", ts&: cast_op1, ts: ";");
9527 statement(ts: to_expression(id: result_id), ts: ".", ts: to_member_name(type, index: 1), ts: " = mulhi(", ts&: cast_op0, ts: ", ", ts&: cast_op1, ts: ");");
9528 break;
9529 }
9530
9531 case OpArrayLength:
9532 {
9533 auto &type = expression_type(id: ops[2]);
9534 uint32_t offset = type_struct_member_offset(type, index: ops[3]);
9535 uint32_t stride = type_struct_member_array_stride(type, index: ops[3]);
9536
9537 auto expr = join(ts: "(", ts: to_buffer_size_expression(id: ops[2]), ts: " - ", ts&: offset, ts: ") / ", ts&: stride);
9538 emit_op(result_type: ops[0], result_id: ops[1], rhs: expr, forward_rhs: true);
9539 break;
9540 }
9541
9542 // Legacy sub-group stuff ...
9543 case OpSubgroupBallotKHR:
9544 case OpSubgroupFirstInvocationKHR:
9545 case OpSubgroupReadInvocationKHR:
9546 case OpSubgroupAllKHR:
9547 case OpSubgroupAnyKHR:
9548 case OpSubgroupAllEqualKHR:
9549 emit_subgroup_op(i: instruction);
9550 break;
9551
9552 // SPV_INTEL_shader_integer_functions2
9553 case OpUCountLeadingZerosINTEL:
9554 MSL_UFOP(clz);
9555 break;
9556
9557 case OpUCountTrailingZerosINTEL:
9558 MSL_UFOP(ctz);
9559 break;
9560
9561 case OpAbsISubINTEL:
9562 case OpAbsUSubINTEL:
9563 MSL_BFOP(absdiff);
9564 break;
9565
9566 case OpIAddSatINTEL:
9567 case OpUAddSatINTEL:
9568 MSL_BFOP(addsat);
9569 break;
9570
9571 case OpIAverageINTEL:
9572 case OpUAverageINTEL:
9573 MSL_BFOP(hadd);
9574 break;
9575
9576 case OpIAverageRoundedINTEL:
9577 case OpUAverageRoundedINTEL:
9578 MSL_BFOP(rhadd);
9579 break;
9580
9581 case OpISubSatINTEL:
9582 case OpUSubSatINTEL:
9583 MSL_BFOP(subsat);
9584 break;
9585
9586 case OpIMul32x16INTEL:
9587 {
9588 uint32_t result_type = ops[0];
9589 uint32_t id = ops[1];
9590 uint32_t a = ops[2], b = ops[3];
9591 bool forward = should_forward(id: a) && should_forward(id: b);
9592 emit_op(result_type, result_id: id, rhs: join(ts: "int(short(", ts: to_unpacked_expression(id: a), ts: ")) * int(short(", ts: to_unpacked_expression(id: b), ts: "))"), forward_rhs: forward);
9593 inherit_expression_dependencies(dst: id, source: a);
9594 inherit_expression_dependencies(dst: id, source: b);
9595 break;
9596 }
9597
9598 case OpUMul32x16INTEL:
9599 {
9600 uint32_t result_type = ops[0];
9601 uint32_t id = ops[1];
9602 uint32_t a = ops[2], b = ops[3];
9603 bool forward = should_forward(id: a) && should_forward(id: b);
9604 emit_op(result_type, result_id: id, rhs: join(ts: "uint(ushort(", ts: to_unpacked_expression(id: a), ts: ")) * uint(ushort(", ts: to_unpacked_expression(id: b), ts: "))"), forward_rhs: forward);
9605 inherit_expression_dependencies(dst: id, source: a);
9606 inherit_expression_dependencies(dst: id, source: b);
9607 break;
9608 }
9609
9610 // SPV_EXT_demote_to_helper_invocation
9611 case OpDemoteToHelperInvocationEXT:
9612 if (!msl_options.supports_msl_version(major: 2, minor: 3))
9613 SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
9614 CompilerGLSL::emit_instruction(instr: instruction);
9615 break;
9616
9617 case OpIsHelperInvocationEXT:
9618 if (msl_options.is_ios() && !msl_options.supports_msl_version(major: 2, minor: 3))
9619 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
9620 else if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
9621 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
9622 emit_op(result_type: ops[0], result_id: ops[1],
9623 rhs: needs_manual_helper_invocation_updates() ? builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput) :
9624 "simd_is_helper_thread()",
9625 forward_rhs: false);
9626 break;
9627
9628 case OpBeginInvocationInterlockEXT:
9629 case OpEndInvocationInterlockEXT:
9630 if (!msl_options.supports_msl_version(major: 2, minor: 0))
9631 SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
9632 break; // Nothing to do in the body
9633
9634 case OpConvertUToAccelerationStructureKHR:
9635 SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL.");
9636 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
9637 SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL.");
9638
9639 case OpRayQueryInitializeKHR:
9640 {
9641 flush_variable_declaration(id: ops[0]);
9642 register_write(chain: ops[0]);
9643 add_spv_func_and_recompile(spv_func: SPVFuncImplRayQueryIntersectionParams);
9644
9645 statement(ts: to_expression(id: ops[0]), ts: ".reset(", ts: "ray(", ts: to_expression(id: ops[4]), ts: ", ", ts: to_expression(id: ops[6]), ts: ", ",
9646 ts: to_expression(id: ops[5]), ts: ", ", ts: to_expression(id: ops[7]), ts: "), ", ts: to_expression(id: ops[1]), ts: ", ", ts: to_expression(id: ops[3]),
9647 ts: ", spvMakeIntersectionParams(", ts: to_expression(id: ops[2]), ts: "));");
9648 break;
9649 }
9650 case OpRayQueryProceedKHR:
9651 {
9652 flush_variable_declaration(id: ops[0]);
9653 register_write(chain: ops[2]);
9654 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".next()"), forward_rhs: false);
9655 break;
9656 }
9657#define MSL_RAY_QUERY_IS_CANDIDATE get<SPIRConstant>(ops[3]).scalar_i32() == 0
9658
9659#define MSL_RAY_QUERY_GET_OP(op, msl_op) \
9660 case OpRayQueryGet##op##KHR: \
9661 flush_variable_declaration(ops[2]); \
9662 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \
9663 break
9664
9665#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op) \
9666 case OpRayQueryGet##op##KHR: \
9667 flush_variable_declaration(ops[2]); \
9668 if (MSL_RAY_QUERY_IS_CANDIDATE) \
9669 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \
9670 else \
9671 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \
9672 break
9673
9674#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op)
9675#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op)
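// For example, MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance) expands to a case for
// OpRayQueryGetRayTMinKHR that emits "rayQuery.get_ray_min_distance()".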
9676
9677 MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance);
9678 MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_origin);
9679 MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_direction);
9680 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id);
9681 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id);
9682 MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord);
9683 MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id);
9684 MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id);
9685 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin);
9686 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction);
9687 MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform);
9688 MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform);
9689 MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing);
9690
9691 case OpRayQueryGetIntersectionTypeKHR:
9692 flush_variable_declaration(id: ops[2]);
9693 if (MSL_RAY_QUERY_IS_CANDIDATE)
9694 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: "uint(", ts: to_expression(id: ops[2]), ts: ".get_candidate_intersection_type()) - 1"),
9695 forward_rhs: false);
9696 else
9697 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: "uint(", ts: to_expression(id: ops[2]), ts: ".get_committed_intersection_type())"), forward_rhs: false);
9698 break;
9699 case OpRayQueryGetIntersectionTKHR:
9700 flush_variable_declaration(id: ops[2]);
9701 if (MSL_RAY_QUERY_IS_CANDIDATE)
9702 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".get_candidate_triangle_distance()"), forward_rhs: false);
9703 else
9704 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".get_committed_distance()"), forward_rhs: false);
9705 break;
9706 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
9707 {
9708 flush_variable_declaration(id: ops[0]);
9709 emit_op(result_type: ops[0], result_id: ops[1], rhs: join(ts: to_expression(id: ops[2]), ts: ".is_candidate_non_opaque_bounding_box()"), forward_rhs: false);
9710 break;
9711 }
9712 case OpRayQueryConfirmIntersectionKHR:
9713 flush_variable_declaration(id: ops[0]);
9714 register_write(chain: ops[0]);
9715 statement(ts: to_expression(id: ops[0]), ts: ".commit_triangle_intersection();");
9716 break;
9717 case OpRayQueryGenerateIntersectionKHR:
9718 flush_variable_declaration(id: ops[0]);
9719 register_write(chain: ops[0]);
9720 statement(ts: to_expression(id: ops[0]), ts: ".commit_bounding_box_intersection(", ts: to_expression(id: ops[1]), ts: ");");
9721 break;
9722 case OpRayQueryTerminateKHR:
9723 flush_variable_declaration(id: ops[0]);
9724 register_write(chain: ops[0]);
9725 statement(ts: to_expression(id: ops[0]), ts: ".abort();");
9726 break;
9727#undef MSL_RAY_QUERY_GET_OP
9728#undef MSL_RAY_QUERY_IS_CANDIDATE
9729#undef MSL_RAY_QUERY_IS_OP2
9730#undef MSL_RAY_QUERY_GET_OP2
9731#undef MSL_RAY_QUERY_OP_INNER2
9732
9733 case OpConvertPtrToU:
9734 case OpConvertUToPtr:
9735 case OpBitcast:
9736 {
9737 auto &type = get<SPIRType>(id: ops[0]);
9738 auto &input_type = expression_type(id: ops[2]);
9739
9740 if (opcode != OpBitcast || type.pointer || input_type.pointer)
9741 {
9742 string op;
9743
9744 if (type.vecsize == 1 && input_type.vecsize == 1)
9745 op = join(ts: "reinterpret_cast<", ts: type_to_glsl(type), ts: ">(", ts: to_unpacked_expression(id: ops[2]), ts: ")");
9746 else if (input_type.vecsize == 2)
9747 op = join(ts: "reinterpret_cast<", ts: type_to_glsl(type), ts: ">(as_type<ulong>(", ts: to_unpacked_expression(id: ops[2]), ts: "))");
9748 else
9749 op = join(ts: "as_type<", ts: type_to_glsl(type), ts: ">(reinterpret_cast<ulong>(", ts: to_unpacked_expression(id: ops[2]), ts: "))");
9750
9751 emit_op(result_type: ops[0], result_id: ops[1], rhs: op, forward_rhs: should_forward(id: ops[2]));
9752 inherit_expression_dependencies(dst: ops[1], source: ops[2]);
9753 }
9754 else
9755 CompilerGLSL::emit_instruction(instr: instruction);
9756
9757 break;
9758 }
9759
9760 case OpSDot:
9761 case OpUDot:
9762 case OpSUDot:
9763 {
9764 uint32_t result_type = ops[0];
9765 uint32_t id = ops[1];
9766 uint32_t vec1 = ops[2];
9767 uint32_t vec2 = ops[3];
9768
9769 auto &input_type1 = expression_type(id: vec1);
9770 auto &input_type2 = expression_type(id: vec2);
9771
9772 string vec1input, vec2input;
9773 auto input_size = input_type1.vecsize;
9774 if (instruction.length == 5)
9775 {
9776 if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
9777 {
9778 string type = opcode == OpSDot || opcode == OpSUDot ? "char4" : "uchar4";
9779 vec1input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec1), ts: ")");
9780 type = opcode == OpSDot ? "char4" : "uchar4";
9781 vec2input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec2), ts: ")");
9782 input_size = 4;
9783 }
9784 else
9785				SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product are not supported.");
9786 }
9787 else
9788 {
9789			// Inputs are sign- or zero-extended to their target width.
9790 SPIRType::BaseType vec1_expected_type =
9791 opcode != OpUDot ?
9792 to_signed_basetype(width: input_type1.width) :
9793 to_unsigned_basetype(width: input_type1.width);
9794
9795 SPIRType::BaseType vec2_expected_type =
9796 opcode != OpSDot ?
9797 to_unsigned_basetype(width: input_type2.width) :
9798 to_signed_basetype(width: input_type2.width);
9799
9800 vec1input = bitcast_expression(target_type: vec1_expected_type, arg: vec1);
9801 vec2input = bitcast_expression(target_type: vec2_expected_type, arg: vec2);
9802 }
9803
9804 auto &type = get<SPIRType>(id: result_type);
9805
9806 // We'll get the appropriate sign-extend or zero-extend, no matter which type we cast to here.
9807 // The addition in reduce_add is sign-invariant.
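		// For example, an OpSDot on packed 4x8-bit inputs with an int result becomes:
		//   reduce_add(int4(as_type<char4>(a)) * int4(as_type<char4>(b)))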
9808 auto result_type_cast = join(ts: type_to_glsl(type), ts&: input_size);
9809
9810 string exp = join(ts: "reduce_add(",
9811 ts&: result_type_cast, ts: "(", ts&: vec1input, ts: ") * ",
9812 ts&: result_type_cast, ts: "(", ts&: vec2input, ts: "))");
9813
9814 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: vec1) && should_forward(id: vec2));
9815 inherit_expression_dependencies(dst: id, source: vec1);
9816 inherit_expression_dependencies(dst: id, source: vec2);
9817 break;
9818 }
9819
9820 case OpSDotAccSat:
9821 case OpUDotAccSat:
9822 case OpSUDotAccSat:
9823 {
9824 uint32_t result_type = ops[0];
9825 uint32_t id = ops[1];
9826 uint32_t vec1 = ops[2];
9827 uint32_t vec2 = ops[3];
9828 uint32_t acc = ops[4];
9829
9830 auto input_type1 = expression_type(id: vec1);
9831 auto input_type2 = expression_type(id: vec2);
9832
9833 string vec1input, vec2input;
9834 if (instruction.length == 6)
9835 {
9836 if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
9837 {
9838 string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
9839 vec1input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec1), ts: ")");
9840 type = opcode == OpSDotAccSat ? "char4" : "uchar4";
9841 vec2input = join(ts: "as_type<", ts&: type, ts: ">(", ts: to_expression(id: vec2), ts: ")");
9842 input_type1.vecsize = 4;
9843 input_type2.vecsize = 4;
9844 }
9845 else
9846				SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product are not supported.");
9847 }
9848 else
9849 {
9850			// Inputs are sign- or zero-extended to their target width.
9851 SPIRType::BaseType vec1_expected_type =
9852 opcode != OpUDotAccSat ?
9853 to_signed_basetype(width: input_type1.width) :
9854 to_unsigned_basetype(width: input_type1.width);
9855
9856 SPIRType::BaseType vec2_expected_type =
9857 opcode != OpSDotAccSat ?
9858 to_unsigned_basetype(width: input_type2.width) :
9859 to_signed_basetype(width: input_type2.width);
9860
9861 vec1input = bitcast_expression(target_type: vec1_expected_type, arg: vec1);
9862 vec2input = bitcast_expression(target_type: vec2_expected_type, arg: vec2);
9863 }
9864
9865 auto &type = get<SPIRType>(id: result_type);
9866
9867 SPIRType::BaseType pre_saturate_type =
9868 opcode != OpUDotAccSat ?
9869 to_signed_basetype(width: type.width) :
9870 to_unsigned_basetype(width: type.width);
9871
9872 input_type1.basetype = pre_saturate_type;
9873 input_type2.basetype = pre_saturate_type;
9874
9875 string exp = join(ts: type_to_glsl(type), ts: "(addsat(reduce_add(",
9876 ts: type_to_glsl(type: input_type1), ts: "(", ts&: vec1input, ts: ") * ",
9877 ts: type_to_glsl(type: input_type2), ts: "(", ts&: vec2input, ts: ")), ",
9878 ts: bitcast_expression(target_type: pre_saturate_type, arg: acc), ts: "))");
9879
9880 emit_op(result_type, result_id: id, rhs: exp, forward_rhs: should_forward(id: vec1) && should_forward(id: vec2));
9881 inherit_expression_dependencies(dst: id, source: vec1);
9882 inherit_expression_dependencies(dst: id, source: vec2);
9883 break;
9884 }
9885
9886 default:
9887 CompilerGLSL::emit_instruction(instr: instruction);
9888 break;
9889 }
9890
9891 previous_instruction_opcode = opcode;
9892}
9893
9894void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
9895{
9896 if (sparse)
9897 SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
9898
9899 if (msl_options.use_framebuffer_fetch_subpasses)
9900 {
9901 auto *ops = stream(instr: i);
9902
9903 uint32_t result_type_id = ops[0];
9904 uint32_t id = ops[1];
9905 uint32_t img = ops[2];
9906
9907 auto &type = expression_type(id: img);
9908 auto &imgtype = get<SPIRType>(id: type.self);
9909
9910 // Use Metal's native frame-buffer fetch API for subpass inputs.
9911 if (imgtype.image.dim == DimSubpassData)
9912 {
9913 // Subpass inputs cannot be invalidated,
9914 // so just forward the expression directly.
9915 string expr = to_expression(id: img);
9916 emit_op(result_type: result_type_id, result_id: id, rhs: expr, forward_rhs: true);
9917 return;
9918 }
9919 }
9920
9921 // Fallback to default implementation
9922 CompilerGLSL::emit_texture_op(i, sparse);
9923}
9924
9925void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
9926{
9927 if (get_execution_model() != ExecutionModelGLCompute && !is_tesc_shader())
9928 return;
9929
9930 uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id: id_exe_scope) : uint32_t(ScopeInvocation);
9931 uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id: id_mem_scope) : uint32_t(ScopeInvocation);
9932 // Use the wider of the two scopes (smaller value)
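	// SPIR-V scopes widen as the enum value shrinks:
	// CrossDevice = 0, Device = 1, Workgroup = 2, Subgroup = 3, Invocation = 4.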
9933 exe_scope = min(a: exe_scope, b: mem_scope);
9934
9935 if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
9936		// In this case, we assume a "subgroup" size of 1. The barrier is then a no-op.
9937 return;
9938
9939 string bar_stmt;
9940 if ((msl_options.is_ios() && msl_options.supports_msl_version(major: 1, minor: 2)) || msl_options.supports_msl_version(major: 2))
9941 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
9942 else
9943 bar_stmt = "threadgroup_barrier";
9944 bar_stmt += "(";
9945
9946 uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id: id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
9947
9948 // Use the | operator to combine flags if we can.
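	// For example, a barrier covering both device and threadgroup memory emits:
	//   threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup);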
9949 if (msl_options.supports_msl_version(major: 1, minor: 2))
9950 {
9951 string mem_flags = "";
9952 // For tesc shaders, this also affects objects in the Output storage class.
9953		// Since Metal places these in a device buffer, we have to sync device memory here.
9954 if (is_tesc_shader() ||
9955 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
9956 mem_flags += "mem_flags::mem_device";
9957
9958 // Fix tessellation patch function processing
9959 if (is_tesc_shader() || (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
9960 {
9961 if (!mem_flags.empty())
9962 mem_flags += " | ";
9963 mem_flags += "mem_flags::mem_threadgroup";
9964 }
9965 if (mem_sem & MemorySemanticsImageMemoryMask)
9966 {
9967 if (!mem_flags.empty())
9968 mem_flags += " | ";
9969 mem_flags += "mem_flags::mem_texture";
9970 }
9971
9972 if (mem_flags.empty())
9973 mem_flags = "mem_flags::mem_none";
9974
9975 bar_stmt += mem_flags;
9976 }
9977 else
9978 {
9979 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
9980 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
9981 bar_stmt += "mem_flags::mem_device_and_threadgroup";
9982 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
9983 bar_stmt += "mem_flags::mem_device";
9984 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
9985 bar_stmt += "mem_flags::mem_threadgroup";
9986 else if (mem_sem & MemorySemanticsImageMemoryMask)
9987 bar_stmt += "mem_flags::mem_texture";
9988 else
9989 bar_stmt += "mem_flags::mem_none";
9990 }
9991
9992 bar_stmt += ");";
9993
9994 statement(ts&: bar_stmt);
9995
9996 assert(current_emitting_block);
9997 flush_control_dependent_expressions(block: current_emitting_block->self);
9998 flush_all_active_variables();
9999}
10000
10001static bool storage_class_array_is_thread(StorageClass storage)
10002{
10003 switch (storage)
10004 {
10005 case StorageClassInput:
10006 case StorageClassOutput:
10007 case StorageClassGeneric:
10008 case StorageClassFunction:
10009 case StorageClassPrivate:
10010 return true;
10011
10012 default:
10013 return false;
10014 }
10015}
10016
10017bool CompilerMSL::emit_array_copy(const char *expr, uint32_t lhs_id, uint32_t rhs_id,
10018 StorageClass lhs_storage, StorageClass rhs_storage)
10019{
10020 // Allow Metal to use the array<T> template to make arrays a value type.
10021 // This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
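	// As a rough illustration: a thread-scope float[4] becomes spvUnsafeArray<float, 4>,
	// which copies by plain assignment, whereas threadgroup arrays must go through the
	// spvArrayCopy* helpers emitted below.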
10022 bool lhs_is_thread_storage = storage_class_array_is_thread(storage: lhs_storage);
10023 bool rhs_is_thread_storage = storage_class_array_is_thread(storage: rhs_storage);
10024
10025 bool lhs_is_array_template = lhs_is_thread_storage || lhs_storage == StorageClassWorkgroup;
10026 bool rhs_is_array_template = rhs_is_thread_storage || rhs_storage == StorageClassWorkgroup;
10027
10028 // Special considerations for stage IO variables.
10029	// If the variable is actually backed by non-user-visible device storage, we use array templates for those.
10030 //
10031 // Another special consideration is given to thread local variables which happen to have Offset decorations
10032 // applied to them. Block-like types do not use array templates, so we need to force POD path if we detect
10033 // these scenarios. This check isn't perfect since it would be technically possible to mix and match these things,
10034 // and for a fully correct solution we might have to track array template state through access chains as well,
10035 // but for all reasonable use cases, this should suffice.
10036 // This special case should also only apply to Function/Private storage classes.
10037 // We should not check backing variable for temporaries.
10038 auto *lhs_var = maybe_get_backing_variable(chain: lhs_id);
10039 if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(storage: lhs_var->storage))
10040 lhs_is_array_template = true;
10041 else if (lhs_var && lhs_storage != StorageClassGeneric && type_is_block_like(type: get<SPIRType>(id: lhs_var->basetype)))
10042 lhs_is_array_template = false;
10043
10044 auto *rhs_var = maybe_get_backing_variable(chain: rhs_id);
10045 if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(storage: rhs_var->storage))
10046 rhs_is_array_template = true;
10047 else if (rhs_var && rhs_storage != StorageClassGeneric && type_is_block_like(type: get<SPIRType>(id: rhs_var->basetype)))
10048 rhs_is_array_template = false;
10049
10050	// If threadgroup storage qualifiers are *not* used:
10051	// avoid the spvCopy* wrapper functions; otherwise, the spvUnsafeArray<> template cannot be used with that storage qualifier.
10052 if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array())
10053 {
10054 // Fall back to normal copy path.
10055 return false;
10056 }
10057 else
10058 {
10059 // Ensure the LHS variable has been declared
10060 if (lhs_var)
10061 flush_variable_declaration(id: lhs_var->self);
10062
10063 string lhs;
10064 if (expr)
10065 lhs = expr;
10066 else
10067 lhs = to_expression(id: lhs_id);
10068
10069 // Assignment from an array initializer is fine.
10070 auto &type = expression_type(id: rhs_id);
10071 auto *var = maybe_get_backing_variable(chain: rhs_id);
10072
10073 // Unfortunately, we cannot template on address space in MSL,
10074 // so explicit address space redirection it is ...
10075 bool is_constant = false;
10076 if (ir.ids[rhs_id].get_type() == TypeConstant)
10077 {
10078 is_constant = true;
10079 }
10080 else if (var && var->remapped_variable && var->statically_assigned &&
10081 ir.ids[var->static_expression].get_type() == TypeConstant)
10082 {
10083 is_constant = true;
10084 }
10085 else if (rhs_storage == StorageClassUniform || rhs_storage == StorageClassUniformConstant)
10086 {
10087 is_constant = true;
10088 }
10089
10090 // For the case where we have OpLoad triggering an array copy,
10091 // we cannot easily detect this case ahead of time since it's
10092 // context dependent. We might have to force a recompile here
10093 // if this is the only use of array copies in our shader.
10094 add_spv_func_and_recompile(spv_func: type.array.size() > 1 ? SPVFuncImplArrayCopyMultidim : SPVFuncImplArrayCopy);
10095
10096 const char *tag = nullptr;
10097 if (lhs_is_thread_storage && is_constant)
10098 tag = "FromConstantToStack";
10099 else if (lhs_storage == StorageClassWorkgroup && is_constant)
10100 tag = "FromConstantToThreadGroup";
10101 else if (lhs_is_thread_storage && rhs_is_thread_storage)
10102 tag = "FromStackToStack";
10103 else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
10104 tag = "FromStackToThreadGroup";
10105 else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
10106 tag = "FromThreadGroupToStack";
10107 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
10108 tag = "FromThreadGroupToThreadGroup";
10109 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
10110 tag = "FromDeviceToDevice";
10111 else if (lhs_storage == StorageClassStorageBuffer && is_constant)
10112 tag = "FromConstantToDevice";
10113 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
10114 tag = "FromThreadGroupToDevice";
10115 else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
10116 tag = "FromStackToDevice";
10117 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
10118 tag = "FromDeviceToThreadGroup";
10119 else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
10120 tag = "FromDeviceToStack";
10121 else
10122 SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
10123
10124 // Pass internal array of spvUnsafeArray<> into wrapper functions
10125 if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
10126 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ".elements, ", ts: to_expression(id: rhs_id), ts: ".elements);");
10127		else if (lhs_is_array_template && !msl_options.force_native_arrays)
10128 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ".elements, ", ts: to_expression(id: rhs_id), ts: ");");
10129 else if (rhs_is_array_template && !msl_options.force_native_arrays)
10130 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ", ", ts: to_expression(id: rhs_id), ts: ".elements);");
10131 else
10132 statement(ts: "spvArrayCopy", ts&: tag, ts: "(", ts&: lhs, ts: ", ", ts: to_expression(id: rhs_id), ts: ");");
10133 }
10134
10135 return true;
10136}
10137
10138uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
10139{
10140 if (is_tessellating_triangles())
10141 return builtin == BuiltInTessLevelInner ? 1 : 3;
10142 else
10143 return builtin == BuiltInTessLevelInner ? 2 : 4;
10144}
10145
10146// Since MSL does not allow arrays to be copied via simple variable assignment,
10147// if the LHS and RHS represent an assignment of an entire array, it must be
10148// implemented by calling an array copy function.
10149// Returns whether the array assignment was emitted.
10150bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
10151{
10152 // We only care about assignments of an entire array
10153 auto &type = expression_type(id: id_lhs);
10154 if (!is_array(type: get_pointee_type(type)))
10155 return false;
10156
10157 auto *var = maybe_get<SPIRVariable>(id: id_lhs);
10158
10159 // Is this a remapped, static constant? Don't do anything.
10160 if (var && var->remapped_variable && var->statically_assigned)
10161 return true;
10162
10163 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
10164 {
10165 // Special case, if we end up declaring a variable when assigning the constant array,
10166 // we can avoid the copy by directly assigning the constant expression.
10167 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
10168 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
10169 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
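		// The direct assignment then looks something like:
		//   lut = spvUnsafeArray<float, 3>({ 1.0, 2.0, 3.0 });
		// instead of a spvArrayCopy call against the constant's .elements.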
10170 statement(ts: to_expression(id: id_lhs), ts: " = ", ts: constant_expression(c: get<SPIRConstant>(id: id_rhs)), ts: ";");
10171 return true;
10172 }
10173
10174 if (is_tesc_shader() && has_decoration(id: id_lhs, decoration: DecorationBuiltIn))
10175 {
10176 auto builtin = BuiltIn(get_decoration(id: id_lhs, decoration: DecorationBuiltIn));
10177 // Need to manually unroll the array store.
10178 if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)
10179 {
10180 uint32_t array_size = get_physical_tess_level_array_size(builtin);
10181 if (array_size == 1)
10182 statement(ts: to_expression(id: id_lhs), ts: " = half(", ts: to_expression(id: id_rhs), ts: "[0]);");
10183 else
10184 {
10185 for (uint32_t i = 0; i < array_size; i++)
10186 statement(ts: to_expression(id: id_lhs), ts: "[", ts&: i, ts: "] = half(", ts: to_expression(id: id_rhs), ts: "[", ts&: i, ts: "]);");
10187 }
10188 return true;
10189 }
10190 }
10191
10192 auto lhs_storage = get_expression_effective_storage_class(ptr: id_lhs);
10193 auto rhs_storage = get_expression_effective_storage_class(ptr: id_rhs);
10194 if (!emit_array_copy(expr: nullptr, lhs_id: id_lhs, rhs_id: id_rhs, lhs_storage, rhs_storage))
10195 return false;
10196
10197 register_write(chain: id_lhs);
10198
10199 return true;
10200}
10201
10202// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers.
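// For example, an OpAtomicIAdd on a device buffer member may emit something like
//   atomic_fetch_add_explicit((device atomic_int*)&buf.counter, val, memory_order_relaxed)
// where "buf.counter" stands in for the actual buffer expression.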
10203void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, Op opcode,
10204 uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
10205 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
10206{
10207 string exp;
10208
10209 auto &ptr_type = expression_type(id: obj);
10210 auto &type = get_pointee_type(type: ptr_type);
10211 auto expected_type = type.basetype;
10212 if (opcode == OpAtomicUMax || opcode == OpAtomicUMin)
10213 expected_type = to_unsigned_basetype(width: type.width);
10214 else if (opcode == OpAtomicSMax || opcode == OpAtomicSMin)
10215 expected_type = to_signed_basetype(width: type.width);
10216
10217 bool use_native_image_atomic;
10218 if (msl_options.supports_msl_version(major: 3, minor: 1))
10219 use_native_image_atomic = check_atomic_image(id: obj);
10220 else
10221 use_native_image_atomic = false;
10222
10223 if (type.width == 64)
10224 SPIRV_CROSS_THROW("MSL currently does not support 64-bit atomics.");
10225
10226 auto remapped_type = type;
10227 remapped_type.basetype = expected_type;
10228
10229 auto *var = maybe_get_backing_variable(chain: obj);
10230 const auto *res_type = var ? &get<SPIRType>(id: var->basetype) : nullptr;
10231 assert(type.storage != StorageClassImage || res_type);
10232
10233 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
10234
10235 bool check_discard = opcode != OpAtomicLoad && needs_frag_discard_checks() &&
10236 ptr_type.storage != StorageClassWorkgroup;
10237
10238	// Even compare exchange atomics are vec4 on Metal for ... reasons :v
10239 uint32_t vec4_temporary_id = 0;
10240 if (use_native_image_atomic && is_atomic_compare_exchange_strong)
10241 {
10242 uint32_t &tmp_id = extra_sub_expressions[result_id];
10243 if (!tmp_id)
10244 {
10245 tmp_id = ir.increase_bound_by(count: 2);
10246
10247 auto vec4_type = get<SPIRType>(id: result_type);
10248 vec4_type.vecsize = 4;
10249 set<SPIRType>(id: tmp_id + 1, args&: vec4_type);
10250 }
10251
10252 vec4_temporary_id = tmp_id;
10253 }
10254
10255 if (check_discard)
10256 {
10257 if (is_atomic_compare_exchange_strong)
10258 {
10259 // We're already emitting a CAS loop here; a conditional won't hurt.
10260 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
10261 if (vec4_temporary_id)
10262 emit_uninitialized_temporary_expression(type: vec4_temporary_id + 1, id: vec4_temporary_id);
10263 statement(ts: "if (!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: ")");
10264 begin_scope();
10265 }
10266 else
10267 exp = join(ts: "(!", ts: builtin_to_glsl(builtin: BuiltInHelperInvocation, storage: StorageClassInput), ts: " ? ");
10268 }
10269
10270 if (use_native_image_atomic)
10271 {
10272 auto obj_expression = to_expression(id: obj);
10273 auto split_index = obj_expression.find_first_of(c: '@');
10274
10275 // Will only be false if we're in "force recompile later" mode.
10276 if (split_index != string::npos)
10277 {
10278 auto coord = obj_expression.substr(pos: split_index + 1);
10279 auto image_expr = obj_expression.substr(pos: 0, n: split_index);
10280
10281			// Handle sign-mismatch cases where we need, for example, signed min/max on a uint image.
10282 // It seems to work to cast the texture type itself, even if it is probably wildly outside of spec,
10283 // but SPIR-V requires this to work.
10284 if ((opcode == OpAtomicUMax || opcode == OpAtomicUMin ||
10285 opcode == OpAtomicSMax || opcode == OpAtomicSMin) &&
10286 type.basetype != expected_type)
10287 {
10288 auto *backing_var = maybe_get_backing_variable(chain: obj);
10289 if (backing_var)
10290 {
10291 add_spv_func_and_recompile(spv_func: SPVFuncImplTextureCast);
10292
10293 const auto *backing_type = &get<SPIRType>(id: backing_var->basetype);
10294 while (backing_type->op != OpTypeImage)
10295 backing_type = &get<SPIRType>(id: backing_type->parent_type);
10296
10297 auto img_type = *backing_type;
10298 auto tmp_type = type;
10299 tmp_type.basetype = expected_type;
10300 img_type.image.type = ir.increase_bound_by(count: 1);
10301 set<SPIRType>(id: img_type.image.type, args&: tmp_type);
10302
10303 image_expr = join(ts: "spvTextureCast<", ts: type_to_glsl(type: img_type, id: obj), ts: ">(", ts&: image_expr, ts: ")");
10304 }
10305 }
10306
10307 exp += join(ts&: image_expr, ts: ".", ts&: op, ts: "(");
10308 if (ptr_type.storage == StorageClassImage && res_type->image.arrayed)
10309 {
10310 switch (res_type->image.dim)
10311 {
10312 case Dim1D:
10313 if (msl_options.texture_1D_as_2D)
10314 exp += join(ts: "uint2(", ts&: coord, ts: ".x, 0), ", ts&: coord, ts: ".y");
10315 else
10316 exp += join(ts&: coord, ts: ".x, ", ts&: coord, ts: ".y");
10317
10318 break;
10319 case Dim2D:
10320 exp += join(ts&: coord, ts: ".xy, ", ts&: coord, ts: ".z");
10321 break;
10322 default:
10323 SPIRV_CROSS_THROW("Cannot do atomics on Cube textures.");
10324 }
10325 }
10326 else if (ptr_type.storage == StorageClassImage && res_type->image.dim == Dim1D && msl_options.texture_1D_as_2D)
10327 exp += join(ts: "uint2(", ts&: coord, ts: ", 0)");
10328 else
10329 exp += coord;
10330 }
10331 else
10332 {
10333 exp += obj_expression;
10334 }
10335 }
10336 else
10337 {
10338 exp += string(op) + "_explicit(";
10339 exp += "(";
10340 // Emulate texture2D atomic operations
10341 if (ptr_type.storage == StorageClassImage)
10342 {
10343 auto &flags = ir.get_decoration_bitset(id: var->self);
10344 if (decoration_flags_signal_volatile(flags))
10345 exp += "volatile ";
10346 exp += "device";
10347 }
10348 else if (var && ptr_type.storage != StorageClassPhysicalStorageBuffer)
10349 {
10350 exp += get_argument_address_space(argument: *var);
10351 }
10352 else
10353 {
10354 // Fallback scenario, could happen for raw pointers.
10355 exp += ptr_type.storage == StorageClassWorkgroup ? "threadgroup" : "device";
10356 }
10357
10358 exp += " atomic_";
10359 // For signed and unsigned min/max, we can signal this through the pointer type.
10360 // There is no other way, since C++ does not have explicit signage for atomics.
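		// For example, OpAtomicSMin on a uint object casts the pointer to
		// "device atomic_int*" so that the comparison is performed as signed.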
10361 exp += type_to_glsl(type: remapped_type);
10362 exp += "*)";
10363
10364 exp += "&";
10365 exp += to_enclosed_expression(id: obj);
10366 }
10367
10368 if (is_atomic_compare_exchange_strong)
10369 {
10370 assert(strcmp(op, "atomic_compare_exchange_weak") == 0);
10371 assert(op2);
10372 assert(has_mem_order_2);
10373 exp += ", &";
10374 exp += to_name(id: vec4_temporary_id ? vec4_temporary_id : result_id);
10375 exp += ", ";
10376 exp += to_expression(id: op2);
10377
10378 if (!use_native_image_atomic)
10379 {
10380 exp += ", ";
10381 exp += get_memory_order(spv_mem_sem: mem_order_1);
10382 exp += ", ";
10383 exp += get_memory_order(spv_mem_sem: mem_order_2);
10384 }
10385 exp += ")";
10386
10387 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
10388 // The MSL function returns false if the atomic write fails OR the comparison test fails,
10389 // so we must validate that it wasn't the comparison test that failed before continuing
10390 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
10391 // The function updates the comparator value from the memory value, so the additional
10392 // comparison test evaluates the memory value against the expected value.
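		// The emitted loop ends up looking roughly like:
		//   do { expected = comparator; }
		//   while (!atomic_compare_exchange_weak_explicit(ptr, &expected, value, ...) &&
		//          expected == comparator);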
10393 if (!check_discard)
10394 {
10395 emit_uninitialized_temporary_expression(type: result_type, id: result_id);
10396 if (vec4_temporary_id)
10397 emit_uninitialized_temporary_expression(type: vec4_temporary_id + 1, id: vec4_temporary_id);
10398 }
10399
10400 statement(ts: "do");
10401 begin_scope();
10402
10403 string scalar_expression;
10404 if (vec4_temporary_id)
10405 scalar_expression = join(ts: to_expression(id: vec4_temporary_id), ts: ".x");
10406 else
10407 scalar_expression = to_expression(id: result_id);
10408
10409 statement(ts&: scalar_expression, ts: " = ", ts: to_expression(id: op1), ts: ";");
10410 end_scope_decl(decl: join(ts: "while (!", ts&: exp, ts: " && ", ts&: scalar_expression, ts: " == ", ts: to_enclosed_expression(id: op1), ts: ")"));
10411 if (vec4_temporary_id)
10412 statement(ts: to_expression(id: result_id), ts: " = ", ts&: scalar_expression, ts: ";");
10413
10414 // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined)
10415 if (check_discard)
10416 {
10417 end_scope();
10418 statement(ts: "else");
10419 begin_scope();
10420 statement(ts: to_expression(id: result_id), ts: " = {};");
10421 end_scope();
10422 }
10423 }
10424 else
10425 {
10426 assert(strcmp(op, "atomic_compare_exchange_weak") != 0);
10427
10428 if (op1)
10429 {
10430 exp += ", ";
10431 if (op1_is_literal)
10432 exp += to_string(val: op1);
10433 else
10434 exp += bitcast_expression(target_type: expected_type, arg: op1);
10435 }
10436
10437 if (op2)
10438 exp += ", " + to_expression(id: op2);
10439
10440 if (!use_native_image_atomic)
10441 {
10442 exp += string(", ") + get_memory_order(spv_mem_sem: mem_order_1);
10443 if (has_mem_order_2)
10444 exp += string(", ") + get_memory_order(spv_mem_sem: mem_order_2);
10445 }
10446
10447 exp += ")";
10448
10449		// For some particular reason, image atomics return vec4 in Metal ...
10450 if (use_native_image_atomic)
10451 exp += ".x";
10452
10453 // Vulkan: (section 9.29: ... and values returned by atomic instructions in helper invocations are undefined)
10454 if (check_discard)
10455 {
10456 exp += " : ";
10457 if (strcmp(s1: op, s2: "atomic_store") != 0)
10458 exp += join(ts: type_to_glsl(type: get<SPIRType>(id: result_type)), ts: "{}");
10459 else
10460 exp += "((void)0)";
10461 exp += ")";
10462 }
10463
10464 if (expected_type != type.basetype)
10465 exp = bitcast_expression(target_type: type, expr_type: expected_type, expr: exp);
10466
10467 if (strcmp(s1: op, s2: "atomic_store") != 0)
10468 emit_op(result_type, result_id, rhs: exp, forward_rhs: false);
10469 else
10470 statement(ts&: exp, ts: ";");
10471 }
10472
10473 flush_all_atomic_capable_variables();
10474}
10475
10476// Metal only supports relaxed memory order for now
10477const char *CompilerMSL::get_memory_order(uint32_t)
10478{
10479 return "memory_order_relaxed";
10480}
10481
10482// Override for MSL-specific extension syntax instructions.
10483// In some cases, deliberately select either the fast or precise versions of the MSL functions to match Vulkan math precision results.
10484void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
10485{
10486 auto op = static_cast<GLSLstd450>(eop);
10487
10488 // If we need to do implicit bitcasts, make sure we do it with the correct type.
10489 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, arguments: args, length: count);
10490 auto int_type = to_signed_basetype(width: integer_width);
10491 auto uint_type = to_unsigned_basetype(width: integer_width);
10492
10493 op = get_remapped_glsl_op(std450_op: op);
10494
10495 auto &restype = get<SPIRType>(id: result_type);
10496
10497 switch (op)
10498 {
10499 case GLSLstd450Sinh:
10500 if (restype.basetype == SPIRType::Half)
10501 {
10502 // MSL does not have overload for half. Force-cast back to half.
10503 auto expr = join(ts: "half(fast::sinh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10504 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10505 inherit_expression_dependencies(dst: id, source: args[0]);
10506 }
10507 else
10508 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::sinh");
10509 break;
10510 case GLSLstd450Cosh:
10511 if (restype.basetype == SPIRType::Half)
10512 {
10513 // MSL does not have overload for half. Force-cast back to half.
10514 auto expr = join(ts: "half(fast::cosh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10515 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10516 inherit_expression_dependencies(dst: id, source: args[0]);
10517 }
10518 else
10519 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::cosh");
10520 break;
10521 case GLSLstd450Tanh:
10522 if (restype.basetype == SPIRType::Half)
10523 {
10524 // MSL does not have overload for half. Force-cast back to half.
10525 auto expr = join(ts: "half(fast::tanh(", ts: to_unpacked_expression(id: args[0]), ts: "))");
10526 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10527 inherit_expression_dependencies(dst: id, source: args[0]);
10528 }
10529 else
10530 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "precise::tanh");
10531 break;
10532 case GLSLstd450Atan2:
10533 if (restype.basetype == SPIRType::Half)
10534 {
10535 // MSL does not have overload for half. Force-cast back to half.
10536 auto expr = join(ts: "half(fast::atan2(", ts: to_unpacked_expression(id: args[0]), ts: ", ", ts: to_unpacked_expression(id: args[1]), ts: "))");
10537 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10538 inherit_expression_dependencies(dst: id, source: args[0]);
10539 inherit_expression_dependencies(dst: id, source: args[1]);
10540 }
10541 else
10542 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::atan2");
10543 break;
10544 case GLSLstd450InverseSqrt:
10545 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rsqrt");
10546 break;
10547 case GLSLstd450RoundEven:
10548 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "rint");
10549 break;
10550
10551 case GLSLstd450FindILsb:
10552 {
10553 // In this template version of findLSB, we return T.
10554 auto basetype = expression_type(id: args[0]).basetype;
10555 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindLSB", input_type: basetype, expected_result_type: basetype);
10556 break;
10557 }
10558
10559 case GLSLstd450FindSMsb:
10560 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindSMSB", input_type: int_type, expected_result_type: int_type);
10561 break;
10562
10563 case GLSLstd450FindUMsb:
10564 emit_unary_func_op_cast(result_type, result_id: id, op0: args[0], op: "spvFindUMSB", input_type: uint_type, expected_result_type: uint_type);
10565 break;
10566
10567 case GLSLstd450PackSnorm4x8:
10568 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_snorm4x8");
10569 break;
10570 case GLSLstd450PackUnorm4x8:
10571 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_unorm4x8");
10572 break;
10573 case GLSLstd450PackSnorm2x16:
10574 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_snorm2x16");
10575 break;
10576 case GLSLstd450PackUnorm2x16:
10577 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "pack_float_to_unorm2x16");
10578 break;
10579
10580 case GLSLstd450PackHalf2x16:
10581 {
10582 auto expr = join(ts: "as_type<uint>(half2(", ts: to_expression(id: args[0]), ts: "))");
10583 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10584 inherit_expression_dependencies(dst: id, source: args[0]);
10585 break;
10586 }
10587
10588 case GLSLstd450UnpackSnorm4x8:
10589 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_snorm4x8_to_float");
10590 break;
10591 case GLSLstd450UnpackUnorm4x8:
10592 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_unorm4x8_to_float");
10593 break;
10594 case GLSLstd450UnpackSnorm2x16:
10595 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_snorm2x16_to_float");
10596 break;
10597 case GLSLstd450UnpackUnorm2x16:
10598 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unpack_unorm2x16_to_float");
10599 break;
10600
10601 case GLSLstd450UnpackHalf2x16:
10602 {
10603 auto expr = join(ts: "float2(as_type<half2>(", ts: to_expression(id: args[0]), ts: "))");
10604 emit_op(result_type, result_id: id, rhs: expr, forward_rhs: should_forward(id: args[0]));
10605 inherit_expression_dependencies(dst: id, source: args[0]);
10606 break;
10607 }
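	// Taken together (illustrative): packHalf2x16(v) lowers to as_type<uint>(half2(v)),
	// and unpackHalf2x16(u) to float2(as_type<half2>(u)). as_type<> is MSL's bit-cast.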
10608
10609 case GLSLstd450PackDouble2x32:
10610 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
10611 break;
10612 case GLSLstd450UnpackDouble2x32:
10613 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
10614 break;
10615
10616 case GLSLstd450MatrixInverse:
10617 {
10618 auto &mat_type = get<SPIRType>(id: result_type);
10619 switch (mat_type.columns)
10620 {
10621 case 2:
10622 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse2x2");
10623 break;
10624 case 3:
10625 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse3x3");
10626 break;
10627 case 4:
10628 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "spvInverse4x4");
10629 break;
10630 default:
10631 break;
10632 }
10633 break;
10634 }
10635
10636 case GLSLstd450FMin:
10637 // If the result type isn't float, don't bother calling the specific
10638 // precise::/fast:: version. Metal doesn't have those for half and
10639 // double types.
10640 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10641 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "min");
10642 else
10643 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "fast::min");
10644 break;
10645
10646 case GLSLstd450FMax:
10647 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10648 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "max");
10649 else
10650 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "fast::max");
10651 break;
10652
10653 case GLSLstd450FClamp:
10654 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
10655 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10656 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "clamp");
10657 else
10658 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "fast::clamp");
10659 break;
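	// A possible shape for the TODO above (untested sketch): when args[1] and args[2]
	// are the compile-time constants 0.0 and 1.0, emit saturate(x) instead of clamp().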
10660
10661 case GLSLstd450NMin:
10662 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10663 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "min");
10664 else
10665 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::min");
10666 break;
10667
10668 case GLSLstd450NMax:
10669 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10670 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "max");
10671 else
10672 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "precise::max");
10673 break;
10674
10675 case GLSLstd450NClamp:
10676 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
10677 if (get<SPIRType>(id: result_type).basetype != SPIRType::Float)
10678 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "clamp");
10679 else
10680 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "precise::clamp");
10681 break;
10682
10683 case GLSLstd450InterpolateAtCentroid:
10684 {
10685 // We can't just emit the expression normally, because the qualified name contains a call to the default
10686 // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
10687 // the base for the method call.
10688 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10689 string component;
10690 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10691 {
10692 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10693 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10694 if (!c || c->specialization)
10695 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10696 else
10697 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10698 }
10699 emit_op(result_type, result_id: id,
10700 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10701 ts: ".interpolate_at_centroid()", ts&: component),
10702 forward_rhs: should_forward(id: args[0]));
10703 break;
10704 }
10705
10706 case GLSLstd450InterpolateAtSample:
10707 {
10708 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10709 string component;
10710 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10711 {
10712 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10713 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10714 if (!c || c->specialization)
10715 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10716 else
10717 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10718 }
10719 emit_op(result_type, result_id: id,
10720 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10721 ts: ".interpolate_at_sample(", ts: to_expression(id: args[1]), ts: ")", ts&: component),
10722 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10723 break;
10724 }
10725
10726 case GLSLstd450InterpolateAtOffset:
10727 {
10728 uint32_t interface_index = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterfaceMemberIndex);
10729 string component;
10730 if (has_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr))
10731 {
10732 uint32_t index_expr = get_extended_decoration(id: args[0], decoration: SPIRVCrossDecorationInterpolantComponentExpr);
10733 auto *c = maybe_get<SPIRConstant>(id: index_expr);
10734 if (!c || c->specialization)
10735 component = join(ts: "[", ts: to_expression(id: index_expr), ts: "]");
10736 else
10737 component = join(ts: ".", ts: index_to_swizzle(index: c->scalar()));
10738 }
10739 // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
10740 // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
10741 // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
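		// Illustrative result: GLSL's interpolateAtOffset(v, off) becomes roughly
		//   in.v.interpolate_at_offset(off + 0.4375)
		// with an optional component index or swizzle appended by the logic above.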
10742 emit_op(result_type, result_id: id,
10743 rhs: join(ts: to_name(id: stage_in_var_id), ts: ".", ts: to_member_name(type: get_stage_in_struct_type(), index: interface_index),
10744 ts: ".interpolate_at_offset(", ts: to_expression(id: args[1]), ts: " + 0.4375)", ts&: component),
10745 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10746 break;
10747 }
10748
10749 case GLSLstd450Distance:
10750 // MSL does not support scalar versions here.
10751 if (expression_type(id: args[0]).vecsize == 1)
10752 {
10753 // Equivalent to length(a - b) -> abs(a - b).
10754 emit_op(result_type, result_id: id,
10755 rhs: join(ts: "abs(", ts: to_enclosed_unpacked_expression(id: args[0]), ts: " - ",
10756 ts: to_enclosed_unpacked_expression(id: args[1]), ts: ")"),
10757 forward_rhs: should_forward(id: args[0]) && should_forward(id: args[1]));
10758 inherit_expression_dependencies(dst: id, source: args[0]);
10759 inherit_expression_dependencies(dst: id, source: args[1]);
10760 }
10761 else
10762 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10763 break;
10764
10765 case GLSLstd450Length:
10766 // MSL does not support scalar versions, so use abs().
10767 if (expression_type(id: args[0]).vecsize == 1)
10768 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "abs");
10769 else
10770 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10771 break;
10772
10773 case GLSLstd450Normalize:
10774 {
10775 auto &exp_type = expression_type(id: args[0]);
		// MSL does not support a scalar normalize(); since normalizing a valid
		// (non-zero) scalar just returns -1 or 1, sign() does the job.
		// MSL also has no fast::normalize() overload for half2/half3, so use the
		// plain normalize() for those.
10779 if (exp_type.vecsize == 1)
10780 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "sign");
10781 else if (exp_type.vecsize <= 3 && exp_type.basetype == SPIRType::Half)
10782 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "normalize");
10783 else
10784 emit_unary_func_op(result_type, result_id: id, op0: args[0], op: "fast::normalize");
10785 break;
10786 }
10787 case GLSLstd450Reflect:
10788 if (get<SPIRType>(id: result_type).vecsize == 1)
10789 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "spvReflect");
10790 else
10791 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10792 break;
10793
10794 case GLSLstd450Refract:
10795 if (get<SPIRType>(id: result_type).vecsize == 1)
10796 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvRefract");
10797 else
10798 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10799 break;
10800
10801 case GLSLstd450FaceForward:
10802 if (get<SPIRType>(id: result_type).vecsize == 1)
10803 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "spvFaceForward");
10804 else
10805 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10806 break;
10807
10808 case GLSLstd450Modf:
10809 case GLSLstd450Frexp:
10810 {
		// Special cases: if the variable is a scalar access chain, we cannot take its
		// address directly, so we have to emit a temporary. The same applies when the
		// variable is in a storage class other than thread.
10813 auto *ptr = maybe_get<SPIRExpression>(id: args[1]);
10814 auto &type = expression_type(id: args[1]);
10815
10816 bool is_thread_storage = storage_class_array_is_thread(storage: type.storage);
10817 if (type.storage == StorageClassOutput && capture_output_to_buffer)
10818 is_thread_storage = false;
10819
10820 if (!is_thread_storage ||
10821 (ptr && ptr->access_chain && is_scalar(type: expression_type(id: args[1]))))
10822 {
10823 register_call_out_argument(id: args[1]);
10824 forced_temporaries.insert(x: id);
10825
10826 // Need to create temporaries and copy over to access chain after.
10827 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
10828 uint32_t &tmp_id = extra_sub_expressions[id];
10829 if (!tmp_id)
10830 tmp_id = ir.increase_bound_by(count: 1);
10831
10832 uint32_t tmp_type_id = get_pointee_type_id(type_id: expression_type_id(id: args[1]));
10833 emit_uninitialized_temporary_expression(type: tmp_type_id, id: tmp_id);
10834 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: tmp_id, op: eop == GLSLstd450Modf ? "modf" : "frexp");
10835 statement(ts: to_expression(id: args[1]), ts: " = ", ts: to_expression(id: tmp_id), ts: ";");
10836 }
10837 else
10838 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10839 break;
10840 }
10841
10842 case GLSLstd450Pow:
10843 // powr makes x < 0.0 undefined, just like SPIR-V.
10844 emit_binary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op: "powr");
10845 break;
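	// Hedged note: powr() is typically specified as exp2(y * log2(x)), which is where
	// the undefined behavior for x < 0.0 comes from.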
10846
10847 default:
10848 CompilerGLSL::emit_glsl_op(result_type, result_id: id, op: eop, args, count);
10849 break;
10850 }
10851}
10852
10853void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
10854 const uint32_t *args, uint32_t count)
10855{
10856 enum AMDShaderTrinaryMinMax
10857 {
10858 FMin3AMD = 1,
10859 UMin3AMD = 2,
10860 SMin3AMD = 3,
10861 FMax3AMD = 4,
10862 UMax3AMD = 5,
10863 SMax3AMD = 6,
10864 FMid3AMD = 7,
10865 UMid3AMD = 8,
10866 SMid3AMD = 9
10867 };
10868
10869 if (!msl_options.supports_msl_version(major: 2, minor: 1))
10870 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
10871
10872 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
10873
10874 switch (op)
10875 {
10876 case FMid3AMD:
10877 case UMid3AMD:
10878 case SMid3AMD:
10879 emit_trinary_func_op(result_type, result_id: id, op0: args[0], op1: args[1], op2: args[2], op: "median3");
10880 break;
10881 default:
10882 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, result_id: id, op: eop, args, count);
10883 break;
10884 }
10885}
10886
10887// Emit a structure declaration for the specified interface variable.
10888void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
10889{
10890 if (ib_var_id)
10891 {
10892 auto &ib_var = get<SPIRVariable>(id: ib_var_id);
10893 auto &ib_type = get_variable_data_type(var: ib_var);
10894 //assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
10895 assert(ib_type.basetype == SPIRType::Struct);
10896 emit_struct(type&: ib_type);
10897 }
10898}
10899
10900// Emits the declaration signature of the specified function.
10901// If this is the entry point function, Metal-specific return value and function arguments are added.
10902void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
10903{
10904 if (func.self != ir.default_entry_point)
10905 add_function_overload(func);
10906
10907 local_variable_names = resource_names;
10908 string decl;
10909
10910 processing_entry_point = func.self == ir.default_entry_point;
10911
	// Metal helper functions must be static force-inline; otherwise they will cause problems when linked together in a single Metallib.
10913 if (!processing_entry_point)
10914 statement(ts&: force_inline);
10915
10916 auto &type = get<SPIRType>(id: func.return_type);
10917
10918 if (!type.array.empty() && msl_options.force_native_arrays)
10919 {
10920 // We cannot return native arrays in MSL, so "return" through an out variable.
10921 decl += "void";
10922 }
10923 else
10924 {
10925 decl += func_type_decl(type);
10926 }
10927
10928 decl += " ";
10929 decl += to_name(id: func.self);
10930 decl += "(";
10931
10932 if (!type.array.empty() && msl_options.force_native_arrays)
10933 {
		// Fake array returns by writing through an out array reference instead.
10935 decl += "thread ";
10936 decl += type_to_glsl(type);
10937 decl += " (&spvReturnValue)";
10938 decl += type_to_array_glsl(type, variable_id: 0);
10939 if (!func.arguments.empty())
10940 decl += ", ";
10941 }
10942
10943 if (processing_entry_point)
10944 {
10945 if (msl_options.argument_buffers)
10946 decl += entry_point_args_argument_buffer(append_comma: !func.arguments.empty());
10947 else
10948 decl += entry_point_args_classic(append_comma: !func.arguments.empty());
10949
		// Append entry point args to avoid conflicts in local variable names.
10951 local_variable_names.insert(first: resource_names.begin(), last: resource_names.end());
10952
		// If the entry point function has variables that require early declaration,
10954 // ensure they each have an empty initializer, creating one if needed.
10955 // This is done at this late stage because the initialization expression
10956 // is cleared after each compilation pass.
10957 for (auto var_id : vars_needing_early_declaration)
10958 {
10959 auto &ed_var = get<SPIRVariable>(id: var_id);
10960 ID &initializer = ed_var.initializer;
10961 if (!initializer)
10962 initializer = ir.increase_bound_by(count: 1);
10963
10964 // Do not override proper initializers.
10965 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
10966 set<SPIRExpression>(id: ed_var.initializer, args: "{}", args&: ed_var.basetype, args: true);
10967 }
10968 }
10969
10970 for (auto &arg : func.arguments)
10971 {
10972 uint32_t name_id = arg.id;
10973
10974 auto *var = maybe_get<SPIRVariable>(id: arg.id);
10975 if (var)
10976 {
10977 // If we need to modify the name of the variable, make sure we modify the original variable.
10978 // Our alias is just a shadow variable.
10979 if (arg.alias_global_variable && var->basevariable)
10980 name_id = var->basevariable;
10981
10982 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
10983 }
10984
10985 add_local_variable_name(id: name_id);
10986
10987 decl += argument_decl(arg);
10988
10989 bool is_dynamic_img_sampler = has_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
10990
10991 auto &arg_type = get<SPIRType>(id: arg.type);
10992 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
10993 {
10994 // Manufacture automatic plane args for multiplanar texture
10995 uint32_t planes = 1;
10996 if (auto *constexpr_sampler = find_constexpr_sampler(id: name_id))
10997 if (constexpr_sampler->ycbcr_conversion_enable)
10998 planes = constexpr_sampler->planes;
10999 for (uint32_t i = 1; i < planes; i++)
11000 decl += join(ts: ", ", ts: argument_decl(arg), ts&: plane_name_suffix, ts&: i);
11001
11002 // Manufacture automatic sampler arg for SampledImage texture
11003 if (arg_type.image.dim != DimBuffer)
11004 {
11005 if (arg_type.array.empty() || (var ? is_var_runtime_size_array(var: *var) : is_runtime_size_array(type: arg_type)))
11006 {
11007 decl += join(ts: ", ", ts: sampler_type(type: arg_type, id: arg.id, member: false), ts: " ", ts: to_sampler_expression(id: name_id));
11008 }
11009 else
11010 {
11011 const char *sampler_address_space =
11012 descriptor_address_space(id: name_id,
11013 storage: StorageClassUniformConstant,
11014 plain_address_space: "thread const");
11015 decl += join(ts: ", ", ts&: sampler_address_space, ts: " ", ts: sampler_type(type: arg_type, id: name_id, member: false), ts: "& ",
11016 ts: to_sampler_expression(id: name_id));
11017 }
11018 }
11019 }
11020
11021 // Manufacture automatic swizzle arg.
11022 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type: arg_type) &&
11023 !is_dynamic_img_sampler)
11024 {
11025 bool arg_is_array = !arg_type.array.empty();
11026 decl += join(ts: ", constant uint", ts: arg_is_array ? "* " : "& ", ts: to_swizzle_expression(id: name_id));
11027 }
11028
11029 if (buffer_requires_array_length(id: name_id))
11030 {
11031 bool arg_is_array = !arg_type.array.empty();
11032 decl += join(ts: ", constant uint", ts: arg_is_array ? "* " : "& ", ts: to_buffer_size_expression(id: name_id));
11033 }
11034
11035 if (&arg != &func.arguments.back())
11036 decl += ", ";
11037 }
11038
11039 decl += ")";
11040 statement(ts&: decl);
11041}
11042
11043static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
11044{
11045 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
11046 // use implicit reconstruction.
11047 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
11048}
11049
11050// Returns the texture sampling function string for the specified image and sampling characteristics.
11051string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
11052{
11053 VariableID img = args.base.img;
11054 const MSLConstexprSampler *constexpr_sampler = nullptr;
11055 bool is_dynamic_img_sampler = false;
11056 if (auto *var = maybe_get_backing_variable(chain: img))
11057 {
11058 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
11059 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
11060 }
11061
11062 // Special-case gather. We have to alter the component being looked up in the swizzle case.
11063 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
11064 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
11065 {
11066 bool is_compare = comparison_ids.count(x: img);
11067 add_spv_func_and_recompile(spv_func: is_compare ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
11068 return is_compare ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
11069 }
11070
11071 // Special-case gather with an array of offsets. We have to lower into 4 separate gathers.
11072 if (args.has_array_offsets && !is_dynamic_img_sampler &&
11073 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
11074 {
11075 bool is_compare = comparison_ids.count(x: img);
11076 add_spv_func_and_recompile(spv_func: is_compare ? SPVFuncImplGatherCompareConstOffsets : SPVFuncImplGatherConstOffsets);
11077 add_spv_func_and_recompile(spv_func: SPVFuncImplForwardArgs);
11078 return is_compare ? "spvGatherCompareConstOffsets" : "spvGatherConstOffsets";
11079 }
11080
11081 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img);
11082
11083 // Texture reference
11084 string fname;
11085 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
11086 {
11087 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
11088 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
11089 // 444 images aren't downsampled, so we don't need to do linear filtering.
11090 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
11091 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
11092 {
11093 if (constexpr_sampler->planes == 2)
11094 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructNearest2Plane);
11095 else
11096 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructNearest3Plane);
11097 fname = "spvChromaReconstructNearest";
11098 }
11099 else // Linear with a downsampled format
11100 {
11101 fname = "spvChromaReconstructLinear";
11102 switch (constexpr_sampler->resolution)
11103 {
11104 case MSL_FORMAT_RESOLUTION_444:
11105 assert(false);
11106 break; // not reached
11107 case MSL_FORMAT_RESOLUTION_422:
11108 switch (constexpr_sampler->x_chroma_offset)
11109 {
11110 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11111 if (constexpr_sampler->planes == 2)
11112 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
11113 else
11114 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
11115 fname += "422CositedEven";
11116 break;
11117 case MSL_CHROMA_LOCATION_MIDPOINT:
11118 if (constexpr_sampler->planes == 2)
11119 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
11120 else
11121 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
11122 fname += "422Midpoint";
11123 break;
11124 default:
11125 SPIRV_CROSS_THROW("Invalid chroma location.");
11126 }
11127 break;
11128 case MSL_FORMAT_RESOLUTION_420:
11129 fname += "420";
11130 switch (constexpr_sampler->x_chroma_offset)
11131 {
11132 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11133 switch (constexpr_sampler->y_chroma_offset)
11134 {
11135 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11136 if (constexpr_sampler->planes == 2)
11137 add_spv_func_and_recompile(
11138 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
11139 else
11140 add_spv_func_and_recompile(
11141 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
11142 fname += "XCositedEvenYCositedEven";
11143 break;
11144 case MSL_CHROMA_LOCATION_MIDPOINT:
11145 if (constexpr_sampler->planes == 2)
11146 add_spv_func_and_recompile(
11147 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
11148 else
11149 add_spv_func_and_recompile(
11150 spv_func: SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
11151 fname += "XCositedEvenYMidpoint";
11152 break;
11153 default:
11154 SPIRV_CROSS_THROW("Invalid Y chroma location.");
11155 }
11156 break;
11157 case MSL_CHROMA_LOCATION_MIDPOINT:
11158 switch (constexpr_sampler->y_chroma_offset)
11159 {
11160 case MSL_CHROMA_LOCATION_COSITED_EVEN:
11161 if (constexpr_sampler->planes == 2)
11162 add_spv_func_and_recompile(
11163 spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
11164 else
11165 add_spv_func_and_recompile(
11166 spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
11167 fname += "XMidpointYCositedEven";
11168 break;
11169 case MSL_CHROMA_LOCATION_MIDPOINT:
11170 if (constexpr_sampler->planes == 2)
11171 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
11172 else
11173 add_spv_func_and_recompile(spv_func: SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
11174 fname += "XMidpointYMidpoint";
11175 break;
11176 default:
11177 SPIRV_CROSS_THROW("Invalid Y chroma location.");
11178 }
11179 break;
11180 default:
11181 SPIRV_CROSS_THROW("Invalid X chroma location.");
11182 }
11183 break;
11184 default:
11185 SPIRV_CROSS_THROW("Invalid format resolution.");
11186 }
11187 }
11188 }
11189 else
11190 {
11191 fname = to_expression(id: combined ? combined->image : img) + ".";
11192
11193 // Texture function and sampler
11194 if (args.base.is_fetch)
11195 fname += "read";
11196 else if (args.base.is_gather)
11197 fname += "gather";
11198 else
11199 fname += "sample";
11200
11201 if (args.has_dref)
11202 fname += "_compare";
11203 }
11204
11205 return fname;
11206}
11207
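// Wraps an expression in a float constructor of the given width, e.g. (illustrative)
// convert_to_f32("h", 2) yields "float2(h)". Used to widen half/double sampling
// operands, since Metal's sampling functions generally take float parameters.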
11208string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
11209{
11210 SPIRType t { components > 1 ? OpTypeVector : OpTypeFloat };
11211 t.basetype = SPIRType::Float;
11212 t.vecsize = components;
11213 t.columns = 1;
11214 return join(ts: type_to_glsl_constructor(type: t), ts: "(", ts: expr, ts: ")");
11215}
11216
11217static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
11218{
	// Double is not supported to begin with, but it doesn't hurt to check for completeness.
11220 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
11221}
11222
11223// Returns the function args for a texture sampling function for the specified image and sampling characteristics.
11224string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
11225{
11226 VariableID img = args.base.img;
11227 auto &imgtype = *args.base.imgtype;
11228 uint32_t lod = args.lod;
11229 uint32_t grad_x = args.grad_x;
11230 uint32_t grad_y = args.grad_y;
11231 uint32_t bias = args.bias;
11232
11233 const MSLConstexprSampler *constexpr_sampler = nullptr;
11234 bool is_dynamic_img_sampler = false;
11235 if (auto *var = maybe_get_backing_variable(chain: img))
11236 {
11237 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
11238 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
11239 }
11240
11241 string farg_str;
11242 bool forward = true;
11243
11244 if (!is_dynamic_img_sampler)
11245 {
11246 // Texture reference (for some cases)
11247 if (needs_chroma_reconstruction(constexpr_sampler))
11248 {
11249 // Multiplanar images need two or three textures.
11250 farg_str += to_expression(id: img);
11251 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
11252 farg_str += join(ts: ", ", ts: to_expression(id: img), ts&: plane_name_suffix, ts&: i);
11253 }
11254 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
11255 msl_options.swizzle_texture_samples && args.base.is_gather)
11256 {
11257 auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img);
11258 farg_str += to_expression(id: combined ? combined->image : img);
11259 }
11260
11261 // Gathers with constant offsets call a special function, so include the texture.
11262 if (args.has_array_offsets)
11263 farg_str += to_expression(id: img);
11264
11265 // Sampler reference
11266 if (!args.base.is_fetch)
11267 {
11268 if (!farg_str.empty())
11269 farg_str += ", ";
11270 farg_str += to_sampler_expression(id: img);
11271 }
11272
11273 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
11274 msl_options.swizzle_texture_samples && args.base.is_gather)
11275 {
11276 // Add the swizzle constant from the swizzle buffer.
11277 farg_str += ", " + to_swizzle_expression(id: img);
11278 used_swizzle_buffer = true;
11279 }
11280
11281 // Const offsets gather puts the const offsets before the other args.
11282 if (args.has_array_offsets)
11283 {
11284 forward = forward && should_forward(id: args.offset);
11285 farg_str += ", " + to_expression(id: args.offset);
11286 }
11287
11288 // Const offsets gather or swizzled gather puts the component before the other args.
11289 if (args.component && (args.has_array_offsets || msl_options.swizzle_texture_samples))
11290 {
11291 forward = forward && should_forward(id: args.component);
11292 farg_str += ", " + to_component_argument(id: args.component);
11293 }
11294 }
11295
11296 // Texture coordinates
11297 forward = forward && should_forward(id: args.coord);
11298 auto coord_expr = to_enclosed_expression(id: args.coord);
11299 auto &coord_type = expression_type(id: args.coord);
11300 bool coord_is_fp = type_is_floating_point(type: coord_type);
11301 bool is_cube_fetch = false;
11302
11303 string tex_coords = coord_expr;
11304 uint32_t alt_coord_component = 0;
11305
11306 switch (imgtype.image.dim)
11307 {
11308
11309 case Dim1D:
11310 if (coord_type.vecsize > 1)
11311 tex_coords = enclose_expression(expr: tex_coords) + ".x";
11312
11313 if (args.base.is_fetch)
11314 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11315 else if (sampling_type_needs_f32_conversion(type: coord_type))
11316 tex_coords = convert_to_f32(expr: tex_coords, components: 1);
11317
11318 if (msl_options.texture_1D_as_2D)
11319 {
11320 if (args.base.is_fetch)
11321 tex_coords = "uint2(" + tex_coords + ", 0)";
11322 else
11323 tex_coords = "float2(" + tex_coords + ", 0.5)";
11324 }
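		// When a 1D texture is emulated as 2D, y = 0.5 samples the center of the single
		// row when sampling, and y = 0 addresses that row directly for fetches.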
11325
11326 alt_coord_component = 1;
11327 break;
11328
11329 case DimBuffer:
11330 if (coord_type.vecsize > 1)
11331 tex_coords = enclose_expression(expr: tex_coords) + ".x";
11332
11333 if (msl_options.texture_buffer_native)
11334 {
11335 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11336 }
11337 else
11338 {
			// Without native texture_buffer support (a Metal 2.1 feature handled in the
			// branch above), texel buffers are backed by 2D textures, so convert the 1D
			// coordinate to 2D.
11341 if (args.base.is_fetch)
11342 {
11343 if (msl_options.texel_buffer_texture_width > 0)
11344 {
11345 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11346 }
11347 else
11348 {
11349 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
11350 to_expression(id: img) + ")";
11351 }
11352 }
11353 }
11354
11355 alt_coord_component = 1;
11356 break;
11357
11358 case DimSubpassData:
11359 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
11360 // this path will not be hit.
11361 tex_coords = "uint2(gl_FragCoord.xy)";
11362 alt_coord_component = 2;
11363 break;
11364
11365 case Dim2D:
11366 if (coord_type.vecsize > 2)
11367 tex_coords = enclose_expression(expr: tex_coords) + ".xy";
11368
11369 if (args.base.is_fetch)
11370 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11371 else if (sampling_type_needs_f32_conversion(type: coord_type))
11372 tex_coords = convert_to_f32(expr: tex_coords, components: 2);
11373
11374 alt_coord_component = 2;
11375 break;
11376
11377 case Dim3D:
11378 if (coord_type.vecsize > 3)
11379 tex_coords = enclose_expression(expr: tex_coords) + ".xyz";
11380
11381 if (args.base.is_fetch)
11382 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11383 else if (sampling_type_needs_f32_conversion(type: coord_type))
11384 tex_coords = convert_to_f32(expr: tex_coords, components: 3);
11385
11386 alt_coord_component = 3;
11387 break;
11388
11389 case DimCube:
11390 if (args.base.is_fetch)
11391 {
11392 is_cube_fetch = true;
11393 tex_coords += ".xy";
11394 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
11395 }
11396 else
11397 {
11398 if (coord_type.vecsize > 3)
11399 tex_coords = enclose_expression(expr: tex_coords) + ".xyz";
11400 }
11401
11402 if (sampling_type_needs_f32_conversion(type: coord_type))
11403 tex_coords = convert_to_f32(expr: tex_coords, components: 3);
11404
11405 alt_coord_component = 3;
11406 break;
11407
11408 default:
11409 break;
11410 }
11411
11412 if (args.base.is_fetch && args.offset)
11413 {
11414 // Fetch offsets must be applied directly to the coordinate.
11415 forward = forward && should_forward(id: args.offset);
11416 auto &type = expression_type(id: args.offset);
11417 if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
11418 {
11419 if (type.basetype != SPIRType::UInt)
11420 tex_coords += join(ts: " + uint2(", ts: bitcast_expression(target_type: SPIRType::UInt, arg: args.offset), ts: ", 0)");
11421 else
11422 tex_coords += join(ts: " + uint2(", ts: to_enclosed_expression(id: args.offset), ts: ", 0)");
11423 }
11424 else
11425 {
11426 if (type.basetype != SPIRType::UInt)
11427 tex_coords += " + " + bitcast_expression(target_type: SPIRType::UInt, arg: args.offset);
11428 else
11429 tex_coords += " + " + to_enclosed_expression(id: args.offset);
11430 }
11431 }
11432
11433 // If projection, use alt coord as divisor
11434 if (args.base.is_proj)
11435 {
11436 if (sampling_type_needs_f32_conversion(type: coord_type))
11437 tex_coords += " / " + convert_to_f32(expr: to_extract_component_expression(id: args.coord, index: alt_coord_component), components: 1);
11438 else
11439 tex_coords += " / " + to_extract_component_expression(id: args.coord, index: alt_coord_component);
11440 }
11441
11442 if (!farg_str.empty())
11443 farg_str += ", ";
11444
11445 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
11446 {
11447 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
11448
11449 if (is_cube_fetch)
11450 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ")";
11451 else
11452 farg_str +=
11453 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
11454 round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: alt_coord_component), coord_is_fp) +
11455 ") * 6u)";
11456
11457 add_spv_func_and_recompile(spv_func: SPVFuncImplCubemapTo2DArrayFace);
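		// The emulated cube array packs face and layer into one slice index as
		// face + 6 * layer; e.g. face 2 of layer 3 maps to slice 20.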
11458 }
11459 else
11460 {
11461 farg_str += tex_coords;
11462
11463 // If fetch from cube, add face explicitly
11464 if (is_cube_fetch)
11465 {
			// Special case for cube arrays: face and layer are packed into one dimension.
11467 if (imgtype.image.arrayed)
11468 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ") % 6u";
11469 else
11470 farg_str +=
11471 ", uint(" + round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: 2), coord_is_fp) + ")";
11472 }
11473
11474 // If array, use alt coord
11475 if (imgtype.image.arrayed)
11476 {
			// Special case for cube arrays: face and layer are packed into one dimension.
11478 if (imgtype.image.dim == DimCube && args.base.is_fetch)
11479 {
11480 farg_str += ", uint(" + to_extract_component_expression(id: args.coord, index: 2) + ") / 6u";
11481 }
11482 else
11483 {
11484 farg_str +=
11485 ", uint(" +
11486 round_fp_tex_coords(tex_coords: to_extract_component_expression(id: args.coord, index: alt_coord_component), coord_is_fp) +
11487 ")";
11488 if (imgtype.image.dim == DimSubpassData)
11489 {
11490 if (msl_options.multiview)
11491 farg_str += " + gl_ViewIndex";
11492 else if (msl_options.arrayed_subpass_input)
11493 farg_str += " + gl_Layer";
11494 }
11495 }
11496 }
11497 else if (imgtype.image.dim == DimSubpassData)
11498 {
11499 if (msl_options.multiview)
11500 farg_str += ", gl_ViewIndex";
11501 else if (msl_options.arrayed_subpass_input)
11502 farg_str += ", gl_Layer";
11503 }
11504 }
11505
11506 // Depth compare reference value
11507 if (args.dref)
11508 {
11509 forward = forward && should_forward(id: args.dref);
11510 farg_str += ", ";
11511
11512 auto &dref_type = expression_type(id: args.dref);
11513
11514 string dref_expr;
11515 if (args.base.is_proj)
11516 dref_expr = join(ts: to_enclosed_expression(id: args.dref), ts: " / ",
11517 ts: to_extract_component_expression(id: args.coord, index: alt_coord_component));
11518 else
11519 dref_expr = to_expression(id: args.dref);
11520
11521 if (sampling_type_needs_f32_conversion(type: dref_type))
11522 dref_expr = convert_to_f32(expr: dref_expr, components: 1);
11523
11524 farg_str += dref_expr;
11525
11526 if (msl_options.is_macos() && (grad_x || grad_y))
11527 {
			// For sample compare, MSL does not support gradient2d on all targets (apparently only on iOS, according to the docs).
11529 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
11530 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
11531 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
11532 bool constant_zero_x = !grad_x || expression_is_constant_null(id: grad_x);
11533 bool constant_zero_y = !grad_y || expression_is_constant_null(id: grad_y);
11534 if (constant_zero_x && constant_zero_y &&
11535 (!imgtype.image.arrayed || !msl_options.sample_dref_lod_array_as_grad))
11536 {
11537 lod = 0;
11538 grad_x = 0;
11539 grad_y = 0;
11540 farg_str += ", level(0)";
11541 }
11542 else if (!msl_options.supports_msl_version(major: 2, minor: 3))
11543 {
11544 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
11545 "supported on macOS prior to MSL 2.3.");
11546 }
11547 }
11548
11549 if (msl_options.is_macos() && bias)
11550 {
			// Bias is likewise not supported on macOS with sample_compare.
11552 // Verify it is compile-time zero, and drop the argument.
11553 if (expression_is_constant_null(id: bias))
11554 {
11555 bias = 0;
11556 }
11557 else if (!msl_options.supports_msl_version(major: 2, minor: 3))
11558 {
11559 SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
11560 "on macOS prior to MSL 2.3.");
11561 }
11562 }
11563 }
11564
11565 // LOD Options
11566 // Metal does not support LOD for 1D textures.
11567 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11568 {
11569 forward = forward && should_forward(id: bias);
11570 farg_str += ", bias(" + to_expression(id: bias) + ")";
11571 }
11572
11573 // Metal does not support LOD for 1D textures.
11574 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11575 {
11576 forward = forward && should_forward(id: lod);
11577 if (args.base.is_fetch)
11578 {
11579 farg_str += ", " + to_expression(id: lod);
11580 }
11581 else if (msl_options.sample_dref_lod_array_as_grad && args.dref && imgtype.image.arrayed)
11582 {
11583 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 3))
11584 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
11585 "supported on macOS prior to MSL 2.3.");
11586 // Some Metal devices have a bug where the LoD is erroneously biased upward
11587 // when using a level() argument. Since this doesn't happen as much with gradient2d(),
11588 // if we perform the LoD calculation in reverse, we can pass a gradient
11589 // instead.
11590 // lod = log2(rhoMax/eta) -> exp2(lod) = rhoMax/eta
11591 // If we make all of the scale factors the same, eta will be 1 and
11592 // exp2(lod) = rho.
11593 // rhoX = dP/dx * extent; rhoY = dP/dy * extent
11594 // Therefore, dP/dx = dP/dy = exp2(lod)/extent.
11595 // (Subtracting 0.5 before exponentiation gives better results.)
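			// Worked example (illustrative): for lod = 4.0 on a 256x256 array texture,
			// both gradients become exp2(4.0 - 0.5) / 256 ≈ 0.0442, and the hardware
			// recovers log2(0.0442 * 256) = 3.5, i.e. the requested LoD minus the
			// 0.5 fudge factor.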
11596 string grad_opt, extent, grad_coord;
11597 VariableID base_img = img;
11598 if (auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img))
11599 base_img = combined->image;
11600 switch (imgtype.image.dim)
11601 {
11602 case Dim1D:
11603 grad_opt = "gradient2d";
11604 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width(), 1.0)");
11605 break;
11606 case Dim2D:
11607 grad_opt = "gradient2d";
11608 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width(), ", ts: to_expression(id: base_img), ts: ".get_height())");
11609 break;
11610 case DimCube:
11611 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
11612 {
11613 grad_opt = "gradient2d";
11614 extent = join(ts: "float2(", ts: to_expression(id: base_img), ts: ".get_width())");
11615 }
11616 else
11617 {
11618 if (msl_options.agx_manual_cube_grad_fixup)
11619 {
11620 add_spv_func_and_recompile(spv_func: SPVFuncImplGradientCube);
11621 grad_opt = "spvGradientCube";
11622 grad_coord = tex_coords + ", ";
11623 }
11624 else
11625 {
11626 grad_opt = "gradientcube";
11627 }
11628 extent = join(ts: "float3(", ts: to_expression(id: base_img), ts: ".get_width())");
11629 }
11630 break;
11631 default:
11632 grad_opt = "unsupported_gradient_dimension";
11633 extent = "float3(1.0)";
11634 break;
11635 }
11636 farg_str += join(ts: ", ", ts&: grad_opt, ts: "(", ts&: grad_coord, ts: "exp2(", ts: to_expression(id: lod), ts: " - 0.5) / ", ts&: extent,
11637 ts: ", exp2(", ts: to_expression(id: lod), ts: " - 0.5) / ", ts&: extent, ts: ")");
11638 }
11639 else
11640 {
11641 farg_str += ", level(" + to_expression(id: lod) + ")";
11642 }
11643 }
11644 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
11645 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
11646 {
		// The Lod argument is optional in OpImageFetch, but Metal requires one; pick 0 as the default.
		// Check the sampled type as well, because is_fetch is also used for OpImageRead in MSL.
11649 farg_str += ", 0";
11650 }
11651
11652 // Metal does not support LOD for 1D textures.
11653 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
11654 {
11655 forward = forward && should_forward(id: grad_x);
11656 forward = forward && should_forward(id: grad_y);
11657 string grad_opt, grad_coord;
11658 switch (imgtype.image.dim)
11659 {
11660 case Dim1D:
11661 case Dim2D:
11662 grad_opt = "gradient2d";
11663 break;
11664 case Dim3D:
11665 grad_opt = "gradient3d";
11666 break;
11667 case DimCube:
11668 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
11669 {
11670 grad_opt = "gradient2d";
11671 }
11672 else if (msl_options.agx_manual_cube_grad_fixup)
11673 {
11674 add_spv_func_and_recompile(spv_func: SPVFuncImplGradientCube);
11675 grad_opt = "spvGradientCube";
11676 grad_coord = tex_coords + ", ";
11677 }
11678 else
11679 {
11680 grad_opt = "gradientcube";
11681 }
11682 break;
11683 default:
11684 grad_opt = "unsupported_gradient_dimension";
11685 break;
11686 }
11687 farg_str += join(ts: ", ", ts&: grad_opt, ts: "(", ts&: grad_coord, ts: to_expression(id: grad_x), ts: ", ", ts: to_expression(id: grad_y), ts: ")");
11688 }
11689
11690 if (args.min_lod)
11691 {
11692 if (!msl_options.supports_msl_version(major: 2, minor: 2))
11693 SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");
11694
11695 forward = forward && should_forward(id: args.min_lod);
11696 farg_str += ", min_lod_clamp(" + to_expression(id: args.min_lod) + ")";
11697 }
11698
11699 // Add offsets
11700 string offset_expr;
11701 const SPIRType *offset_type = nullptr;
11702 if (args.offset && !args.base.is_fetch && !args.has_array_offsets)
11703 {
11704 forward = forward && should_forward(id: args.offset);
11705 offset_expr = to_expression(id: args.offset);
11706 offset_type = &expression_type(id: args.offset);
11707 }
11708
11709 if (!offset_expr.empty())
11710 {
11711 switch (imgtype.image.dim)
11712 {
11713 case Dim1D:
11714 if (!msl_options.texture_1D_as_2D)
11715 break;
11716 if (offset_type->vecsize > 1)
11717 offset_expr = enclose_expression(expr: offset_expr) + ".x";
11718
11719 farg_str += join(ts: ", int2(", ts&: offset_expr, ts: ", 0)");
11720 break;
11721
11722 case Dim2D:
11723 if (offset_type->vecsize > 2)
11724 offset_expr = enclose_expression(expr: offset_expr) + ".xy";
11725
11726 farg_str += ", " + offset_expr;
11727 break;
11728
11729 case Dim3D:
11730 if (offset_type->vecsize > 3)
11731 offset_expr = enclose_expression(expr: offset_expr) + ".xyz";
11732
11733 farg_str += ", " + offset_expr;
11734 break;
11735
11736 default:
11737 break;
11738 }
11739 }
11740
11741 if (args.component && !args.has_array_offsets)
11742 {
		// If a 2D gather has a component argument, ensure it also has an offset argument.
11744 if (imgtype.image.dim == Dim2D && offset_expr.empty())
11745 farg_str += ", int2(0)";
11746
11747 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
11748 {
11749 forward = forward && should_forward(id: args.component);
11750
11751 uint32_t image_var = 0;
11752 if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(id: img))
11753 {
11754 if (const auto *img_var = maybe_get_backing_variable(chain: combined->image))
11755 image_var = img_var->self;
11756 }
11757 else if (const auto *var = maybe_get_backing_variable(chain: img))
11758 {
11759 image_var = var->self;
11760 }
11761
11762 if (image_var == 0 || !is_depth_image(type: expression_type(id: image_var), id: image_var))
11763 farg_str += ", " + to_component_argument(id: args.component);
11764 }
11765 }
11766
11767 if (args.sample)
11768 {
11769 forward = forward && should_forward(id: args.sample);
11770 farg_str += ", ";
11771 farg_str += to_expression(id: args.sample);
11772 }
11773
11774 *p_forward = forward;
11775
11776 return farg_str;
11777}
11778
// If the texture coordinates are floating point, invokes the MSL rint() function to round them.
11780string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
11781{
11782 return coord_is_fp ? ("rint(" + tex_coords + ")") : tex_coords;
11783}
11784
11785// Returns a string to use in an image sampling function argument.
11786// The ID must be a scalar constant.
11787string CompilerMSL::to_component_argument(uint32_t id)
11788{
11789 uint32_t component_index = evaluate_constant_u32(id);
11790 switch (component_index)
11791 {
11792 case 0:
11793 return "component::x";
11794 case 1:
11795 return "component::y";
11796 case 2:
11797 return "component::z";
11798 case 3:
11799 return "component::w";
11800
11801 default:
11802 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
11803 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
11804 }
11805}
11806
11807// Establish sampled image as expression object and assign the sampler to it.
11808void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
11809{
11810 set<SPIRCombinedImageSampler>(id: result_id, args&: result_type, args&: image_id, args&: samp_id);
11811}
11812
11813string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
11814 SmallVector<uint32_t> &inherited_expressions)
11815{
11816 auto *ops = stream(instr: i);
11817 uint32_t result_type_id = ops[0];
11818 uint32_t img = ops[2];
11819 auto &result_type = get<SPIRType>(id: result_type_id);
11820 auto op = static_cast<Op>(i.op);
11821 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
11822
11823 // Bypass pointers because we need the real image struct
11824 auto &type = expression_type(id: img);
11825 auto &imgtype = get<SPIRType>(id: type.self);
11826
11827 const MSLConstexprSampler *constexpr_sampler = nullptr;
11828 bool is_dynamic_img_sampler = false;
11829 if (auto *var = maybe_get_backing_variable(chain: img))
11830 {
11831 constexpr_sampler = find_constexpr_sampler(id: var->basevariable ? var->basevariable : VariableID(var->self));
11832 is_dynamic_img_sampler = has_extended_decoration(id: var->self, decoration: SPIRVCrossDecorationDynamicImageSampler);
11833 }
11834
11835 string expr;
11836 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
11837 {
11838 // If this needs sampler Y'CbCr conversion, we need to do some additional
11839 // processing.
11840 switch (constexpr_sampler->ycbcr_model)
11841 {
11842 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
11843 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
11844 // Default
11845 break;
11846 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
11847 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT709);
11848 expr += "spvConvertYCbCrBT709(";
11849 break;
11850 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
11851 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT601);
11852 expr += "spvConvertYCbCrBT601(";
11853 break;
11854 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
11855 add_spv_func_and_recompile(spv_func: SPVFuncImplConvertYCbCrBT2020);
11856 expr += "spvConvertYCbCrBT2020(";
11857 break;
11858 default:
11859 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
11860 }
11861
11862 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
11863 {
11864 switch (constexpr_sampler->ycbcr_range)
11865 {
11866 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
11867 add_spv_func_and_recompile(spv_func: SPVFuncImplExpandITUFullRange);
11868 expr += "spvExpandITUFullRange(";
11869 break;
11870 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
11871 add_spv_func_and_recompile(spv_func: SPVFuncImplExpandITUNarrowRange);
11872 expr += "spvExpandITUNarrowRange(";
11873 break;
11874 default:
11875 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
11876 }
11877 }
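		// At this point (illustrative), expr holds a prefix such as
		// "spvConvertYCbCrBT709(spvExpandITUNarrowRange(", which the sampled
		// expression and the closing arguments appended below complete.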
11878 }
11879 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(type: imgtype) &&
11880 !is_dynamic_img_sampler)
11881 {
11882 add_spv_func_and_recompile(spv_func: SPVFuncImplTextureSwizzle);
11883 expr += "spvTextureSwizzle(";
11884 }
11885
11886 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
11887
11888 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
11889 {
11890 if (!constexpr_sampler->swizzle_is_identity())
11891 {
11892 static const char swizzle_names[] = "rgba";
11893 if (!constexpr_sampler->swizzle_has_one_or_zero())
11894 {
11895 // If we can, do it inline.
11896 expr += inner_expr + ".";
11897 for (uint32_t c = 0; c < 4; c++)
11898 {
11899 switch (constexpr_sampler->swizzle[c])
11900 {
11901 case MSL_COMPONENT_SWIZZLE_IDENTITY:
11902 expr += swizzle_names[c];
11903 break;
11904 case MSL_COMPONENT_SWIZZLE_R:
11905 case MSL_COMPONENT_SWIZZLE_G:
11906 case MSL_COMPONENT_SWIZZLE_B:
11907 case MSL_COMPONENT_SWIZZLE_A:
11908 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
11909 break;
11910 default:
11911 SPIRV_CROSS_THROW("Invalid component swizzle.");
11912 }
11913 }
11914 }
11915 else
11916 {
11917 // Otherwise, we need to emit a temporary and swizzle that.
11918 uint32_t temp_id = ir.increase_bound_by(count: 1);
11919 emit_op(result_type: result_type_id, result_id: temp_id, rhs: inner_expr, forward_rhs: false);
11920 for (auto &inherit : inherited_expressions)
11921 inherit_expression_dependencies(dst: temp_id, source: inherit);
11922 inherited_expressions.clear();
11923 inherited_expressions.push_back(t: temp_id);
11924
11925 switch (op)
11926 {
11927 case OpImageSampleDrefImplicitLod:
11928 case OpImageSampleImplicitLod:
11929 case OpImageSampleProjImplicitLod:
11930 case OpImageSampleProjDrefImplicitLod:
11931 register_control_dependent_expression(expr: temp_id);
11932 break;
11933
11934 default:
11935 break;
11936 }
11937 expr += type_to_glsl(type: result_type) + "(";
11938 for (uint32_t c = 0; c < 4; c++)
11939 {
11940 switch (constexpr_sampler->swizzle[c])
11941 {
11942 case MSL_COMPONENT_SWIZZLE_IDENTITY:
11943 expr += to_expression(id: temp_id) + "." + swizzle_names[c];
11944 break;
11945 case MSL_COMPONENT_SWIZZLE_ZERO:
11946 expr += "0";
11947 break;
11948 case MSL_COMPONENT_SWIZZLE_ONE:
11949 expr += "1";
11950 break;
11951 case MSL_COMPONENT_SWIZZLE_R:
11952 case MSL_COMPONENT_SWIZZLE_G:
11953 case MSL_COMPONENT_SWIZZLE_B:
11954 case MSL_COMPONENT_SWIZZLE_A:
11955 expr += to_expression(id: temp_id) + "." +
11956 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
11957 break;
11958 default:
11959 SPIRV_CROSS_THROW("Invalid component swizzle.");
11960 }
11961 if (c < 3)
11962 expr += ", ";
11963 }
11964 expr += ")";
11965 }
11966 }
11967 else
11968 expr += inner_expr;
11969 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
11970 {
11971 expr += join(ts: ", ", ts: constexpr_sampler->bpc, ts: ")");
11972 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
11973 expr += ")";
11974 }
11975 }
11976 else
11977 {
11978 expr += inner_expr;
11979 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(type: imgtype) &&
11980 !is_dynamic_img_sampler)
11981 {
11982 // Add the swizzle constant from the swizzle buffer.
11983 expr += ", " + to_swizzle_expression(id: img) + ")";
11984 used_swizzle_buffer = true;
11985 }
11986 }
11987
11988 return expr;
11989}
11990
11991static string create_swizzle(MSLComponentSwizzle swizzle)
11992{
11993 switch (swizzle)
11994 {
11995 case MSL_COMPONENT_SWIZZLE_IDENTITY:
11996 return "spvSwizzle::none";
11997 case MSL_COMPONENT_SWIZZLE_ZERO:
11998 return "spvSwizzle::zero";
11999 case MSL_COMPONENT_SWIZZLE_ONE:
12000 return "spvSwizzle::one";
12001 case MSL_COMPONENT_SWIZZLE_R:
12002 return "spvSwizzle::red";
12003 case MSL_COMPONENT_SWIZZLE_G:
12004 return "spvSwizzle::green";
12005 case MSL_COMPONENT_SWIZZLE_B:
12006 return "spvSwizzle::blue";
12007 case MSL_COMPONENT_SWIZZLE_A:
12008 return "spvSwizzle::alpha";
12009 default:
12010 SPIRV_CROSS_THROW("Invalid component swizzle.");
12011 }
12012}
12013
12014// Returns a string representation of the ID, usable as a function arg.
12015// Manufacture automatic sampler arg for SampledImage texture.
12016string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
12017{
12018 string arg_str;
12019
12020 auto &type = expression_type(id);
12021 bool is_dynamic_img_sampler = has_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
12022 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
12023 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, decoration: SPIRVCrossDecorationDynamicImageSampler);
12024 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
12025 arg_str = join(ts: "spvDynamicImageSampler<", ts: type_to_glsl(type: get<SPIRType>(id: type.image.type)), ts: ">(");
12026
12027 auto *c = maybe_get<SPIRConstant>(id);
12028 if (msl_options.force_native_arrays && c && !get<SPIRType>(id: c->constant_type).array.empty())
12029 {
12030 // If we are passing a constant array directly to a function for some reason,
12031 // the callee will expect an argument in thread const address space
12032 // (since we can only bind to arrays with references in MSL).
12033 // To resolve this, we must emit a copy in this address space.
12034 // This kind of code gen should be rare enough that performance is not a real concern.
12035 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
12036 //
12037 // We risk calling this inside a continue block (invalid code),
12038 // so just create a thread local copy in the current function.
12039 arg_str = join(ts: "_", ts&: id, ts: "_array_copy");
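		// Illustrative: for ID 42 this produces "_42_array_copy"; the thread-local
		// copy itself is declared when the enclosing function body is emitted.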
12040 auto &constants = current_function->constant_arrays_needed_on_stack;
12041 auto itr = find(first: begin(cont&: constants), last: end(cont&: constants), val: ID(id));
12042 if (itr == end(cont&: constants))
12043 {
12044 force_recompile();
12045 constants.push_back(t: id);
12046 }
12047 }
12048 // Dereference pointer variables where needed.
12049 // FIXME: This dereference is actually backwards. We should really just support passing pointer variables between functions.
12050 else if (should_dereference(id))
12051 arg_str += dereference_expression(expression_type: type, expr: CompilerGLSL::to_func_call_arg(arg, id));
12052 else
12053 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
12054
12055 // Need to check the base variable in case we need to apply a qualified alias.
12056 uint32_t var_id = 0;
12057 auto *var = maybe_get<SPIRVariable>(id);
12058 if (var)
12059 var_id = var->basevariable;
12060
12061 if (!arg_is_dynamic_img_sampler)
12062 {
12063 auto *constexpr_sampler = find_constexpr_sampler(id: var_id ? var_id : id);
12064 if (type.basetype == SPIRType::SampledImage)
12065 {
12066 // Manufacture automatic plane args for multiplanar texture
12067 uint32_t planes = 1;
12068 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12069 {
12070 planes = constexpr_sampler->planes;
12071 // If this parameter isn't aliasing a global, then we need to use
12072 // the special "dynamic image-sampler" class to pass it--and we need
12073 // to use it for *every* non-alias parameter, in case a combined
12074 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
12075 // pathological case is so rare that it should never be hit in practice.
12076 if (!arg.alias_global_variable)
12077 add_spv_func_and_recompile(spv_func: SPVFuncImplDynamicImageSampler);
12078 }
12079 for (uint32_t i = 1; i < planes; i++)
12080 arg_str += join(ts: ", ", ts: CompilerGLSL::to_func_call_arg(arg, id), ts&: plane_name_suffix, ts&: i);
12081 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
12082 if (type.image.dim != DimBuffer)
12083 arg_str += ", " + to_sampler_expression(id: var_id ? var_id : id);
12084
12085 // Add sampler Y'CbCr conversion info if we have it
12086 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12087 {
12088 SmallVector<string> samp_args;
12089
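 // E.g. a 4:2:0 sampler with linear chroma filtering and 8 bits per component
 // collects {"spvFormatResolution::_420", "spvChromaFilter::linear",
 // "spvComponentBits(8)"} below, appended as ", spvYCbCrSampler(...)".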
12090 switch (constexpr_sampler->resolution)
12091 {
12092 case MSL_FORMAT_RESOLUTION_444:
12093 // Default
12094 break;
12095 case MSL_FORMAT_RESOLUTION_422:
12096 samp_args.push_back(t: "spvFormatResolution::_422");
12097 break;
12098 case MSL_FORMAT_RESOLUTION_420:
12099 samp_args.push_back(t: "spvFormatResolution::_420");
12100 break;
12101 default:
12102 SPIRV_CROSS_THROW("Invalid format resolution.");
12103 }
12104
12105 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
12106 samp_args.push_back(t: "spvChromaFilter::linear");
12107
12108 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
12109 samp_args.push_back(t: "spvXChromaLocation::midpoint");
12110 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
12111 samp_args.push_back(t: "spvYChromaLocation::midpoint");
12112 switch (constexpr_sampler->ycbcr_model)
12113 {
12114 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
12115 // Default
12116 break;
12117 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
12118 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_identity");
12119 break;
12120 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
12121 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_709");
12122 break;
12123 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
12124 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_601");
12125 break;
12126 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
12127 samp_args.push_back(t: "spvYCbCrModelConversion::ycbcr_bt_2020");
12128 break;
12129 default:
12130 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
12131 }
12132 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
12133 samp_args.push_back(t: "spvYCbCrRange::itu_narrow");
12134 samp_args.push_back(t: join(ts: "spvComponentBits(", ts: constexpr_sampler->bpc, ts: ")"));
12135 arg_str += join(ts: ", spvYCbCrSampler(", ts: merge(list: samp_args), ts: ")");
12136 }
12137 }
12138
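 // The swizzle constant below is packed into a single uint:
 // bits 0-7 hold the red swizzle, 8-15 green, 16-23 blue, 24-31 alpha.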
12139 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
12140 arg_str += join(ts: ", (uint(", ts: create_swizzle(swizzle: constexpr_sampler->swizzle[3]), ts: ") << 24) | (uint(",
12141 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[2]), ts: ") << 16) | (uint(",
12142 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[1]), ts: ") << 8) | uint(",
12143 ts: create_swizzle(swizzle: constexpr_sampler->swizzle[0]), ts: ")");
12144 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
12145 arg_str += ", " + to_swizzle_expression(id: var_id ? var_id : id);
12146
12147 if (buffer_requires_array_length(id: var_id))
12148 arg_str += ", " + to_buffer_size_expression(id: var_id ? var_id : id);
12149
12150 if (is_dynamic_img_sampler)
12151 arg_str += ")";
12152 }
12153
12154 // Emulate texture2D atomic operations
12155 auto *backing_var = maybe_get_backing_variable(chain: var_id);
12156 if (backing_var && atomic_image_vars_emulated.count(x: backing_var->self))
12157 {
12158 arg_str += ", " + to_expression(id: var_id) + "_atomic";
12159 }
12160
12161 return arg_str;
12162}
12163
12164// If the ID represents a sampled image that has been assigned a sampler already,
12165// generate an expression for the sampler, otherwise generate a fake sampler name
12166// by appending a suffix to the expression constructed from the ID.
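// E.g. "tex" becomes "texSmplr", and "texArr[i]" becomes "texArrSmplr[i]"
// (assuming the default "Smplr" suffix).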
12167string CompilerMSL::to_sampler_expression(uint32_t id)
12168{
12169 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
12170 auto expr = to_expression(id: combined ? combined->image : VariableID(id));
12171 auto index = expr.find_first_of(c: '[');
12172
12173 uint32_t samp_id = 0;
12174 if (combined)
12175 samp_id = combined->sampler;
12176
12177 if (index == string::npos)
12178 return samp_id ? to_expression(id: samp_id) : expr + sampler_name_suffix;
12179 else
12180 {
12181 auto image_expr = expr.substr(pos: 0, n: index);
12182 auto array_expr = expr.substr(pos: index);
12183 return samp_id ? to_expression(id: samp_id) : (image_expr + sampler_name_suffix + array_expr);
12184 }
12185}
12186
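// Returns an expression naming the swizzle constant for the given image,
// analogous to to_sampler_expression: e.g. "tex" becomes "texSwzl" (assuming
// the default "Swzl" suffix). Any '.' from argument buffer member access is
// rewritten to '_' so the result stays a legal identifier.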
12187string CompilerMSL::to_swizzle_expression(uint32_t id)
12188{
12189 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
12190
12191 auto expr = to_expression(id: combined ? combined->image : VariableID(id));
12192 auto index = expr.find_first_of(c: '[');
12193
12194 // If an image is part of an argument buffer, translate this to a legal identifier.
12195 string::size_type period = 0;
12196 while ((period = expr.find_first_of(c: '.', pos: period)) != string::npos && period < index)
12197 expr[period] = '_';
12198
12199 if (index == string::npos)
12200 return expr + swizzle_name_suffix;
12201 else
12202 {
12203 auto image_expr = expr.substr(pos: 0, n: index);
12204 auto array_expr = expr.substr(pos: index);
12205 return image_expr + swizzle_name_suffix + array_expr;
12206 }
12207}
12208
12209string CompilerMSL::to_buffer_size_expression(uint32_t id)
12210{
12211 auto expr = to_expression(id);
12212 auto index = expr.find_first_of(c: '[');
12213
12214 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
12215 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
12216 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
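 // E.g. "(*spvDescriptorSet0.foo)" is turned back into something like
 // "spvDescriptorSet0.foo"; the '.' is then rewritten to '_' below, giving a
 // legal identifier such as "spvDescriptorSet0_foo" for the size variable.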
12217 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
12218 expr = address_of_expression(expr);
12219
12220 // If a buffer is part of an argument buffer, translate this to a legal identifier.
12221 for (auto &c : expr)
12222 if (c == '.')
12223 c = '_';
12224
12225 if (index == string::npos)
12226 return expr + buffer_size_name_suffix;
12227 else
12228 {
12229 auto buffer_expr = expr.substr(pos: 0, n: index);
12230 auto array_expr = expr.substr(pos: index);
12231 if (auto var = maybe_get_backing_variable(chain: id))
12232 {
12233 if (is_var_runtime_size_array(var: *var))
12234 {
12235 if (!msl_options.runtime_array_rich_descriptor)
12236 SPIRV_CROSS_THROW("OpArrayLength requires rich descriptor format");
12237
12238 auto last_pos = array_expr.find_last_of(c: ']');
12239 if (last_pos != std::string::npos)
12240 return buffer_expr + ".length(" + array_expr.substr(pos: 1, n: last_pos - 1) + ")";
12241 }
12242 }
12243 return buffer_expr + buffer_size_name_suffix + array_expr;
12244 }
12245}
12246
12247// Checks whether the type is a Block all of whose members have DecorationPatch.
12248bool CompilerMSL::is_patch_block(const SPIRType &type)
12249{
12250 if (!has_decoration(id: type.self, decoration: DecorationBlock))
12251 return false;
12252
12253 for (uint32_t i = 0; i < type.member_types.size(); i++)
12254 {
12255 if (!has_member_decoration(id: type.self, index: i, decoration: DecorationPatch))
12256 return false;
12257 }
12258
12259 return true;
12260}
12261
12262// Checks whether the ID is a row_major matrix that requires conversion before use
12263bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
12264{
12265 auto *e = maybe_get<SPIRExpression>(id);
12266 if (e)
12267 return e->need_transpose;
12268 else
12269 return has_decoration(id, decoration: DecorationRowMajor);
12270}
12271
12272// Checks whether the member is a row_major matrix that requires conversion before use
12273bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
12274{
12275 return has_member_decoration(id: type.self, index, decoration: DecorationRowMajor);
12276}
12277
12278string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
12279 bool is_packed, bool relaxed)
12280{
12281 if (!is_matrix(type: exp_type))
12282 {
12283 return CompilerGLSL::convert_row_major_matrix(exp_str: std::move(exp_str), exp_type, physical_type_id, is_packed, relaxed);
12284 }
12285 else
12286 {
12287 strip_enclosed_expression(expr&: exp_str);
12288 if (physical_type_id != 0 || is_packed)
12289 exp_str = unpack_expression_type(expr_str: exp_str, type: exp_type, physical_type_id, packed: is_packed, row_major: true);
12290 return join(ts: "transpose(", ts&: exp_str, ts: ")");
12291 }
12292}
12293
12294// Called automatically at the end of the entry point function
12295void CompilerMSL::emit_fixup()
12296{
12297 if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
12298 {
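 // fixup_clipspace remaps z from OpenGL-style [-w, w] clip space to
 // Metal's [0, w]: z' = (z + w) * 0.5.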
12299 if (options.vertex.fixup_clipspace)
12300 statement(ts&: qual_pos_var_name, ts: ".z = (", ts&: qual_pos_var_name, ts: ".z + ", ts&: qual_pos_var_name,
12301 ts: ".w) * 0.5; // Adjust clip-space for Metal");
12302
12303 if (options.vertex.flip_vert_y)
12304 statement(ts&: qual_pos_var_name, ts: ".y = -(", ts&: qual_pos_var_name, ts: ".y);", ts: " // Invert Y-axis for Metal");
12305 }
12306}
12307
12308// Return a string defining a structure member, with padding and packing.
12309string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
12310 const string &qualifier)
12311{
12312 uint32_t orig_member_type_id = member_type_id;
12313 if (member_is_remapped_physical_type(type, index))
12314 member_type_id = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPhysicalTypeID);
12315 auto &physical_type = get<SPIRType>(id: member_type_id);
12316
12317 // If this member is packed, mark it as so.
12318 string pack_pfx;
12319
12320 // Allow Metal to use the array<T> template to make arrays a value type
12321 uint32_t orig_id = 0;
12322 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID))
12323 orig_id = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID);
12324
12325 bool row_major = false;
12326 if (is_matrix(type: physical_type))
12327 row_major = has_member_decoration(id: type.self, index, decoration: DecorationRowMajor);
12328
12329 SPIRType row_major_physical_type { OpTypeMatrix };
12330 const SPIRType *declared_type = &physical_type;
12331
12332 // If a struct is being declared with physical layout,
12333 // do not use array<T> wrappers.
12334 // This avoids a lot of complicated cases with packed vectors and matrices,
12335 // and generally we cannot copy full arrays in and out of buffers into Function
12336 // address space.
12337 // Array of resources should also be declared as builtin arrays.
12338 if (has_member_decoration(id: type.self, index, decoration: DecorationOffset))
12339 is_using_builtin_array = true;
12340 else if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary))
12341 is_using_builtin_array = true;
12342
12343 if (member_is_packed_physical_type(type, index))
12344 {
12345 // If we're packing a matrix, output an appropriate typedef
12346 if (physical_type.basetype == SPIRType::Struct)
12347 {
12348 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
12349 }
12350 else if (is_matrix(type: physical_type))
12351 {
12352 uint32_t rows = physical_type.vecsize;
12353 uint32_t cols = physical_type.columns;
12354 pack_pfx = "packed_";
12355 if (row_major)
12356 {
12357 // These are stored transposed.
12358 rows = physical_type.columns;
12359 cols = physical_type.vecsize;
12360 pack_pfx = "packed_rm_";
12361 }
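 // E.g. a packed, column-major float2x3 member (vecsize 3, columns 2) emits
 // "typedef packed_float3 packed_float2x3[2];"; the row-major variant emits
 // "typedef packed_float2 packed_rm_float2x3[3];".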
12362 string base_type = physical_type.width == 16 ? "half" : "float";
12363 string td_line = "typedef ";
12364 td_line += "packed_" + base_type + to_string(val: rows);
12365 td_line += " " + pack_pfx;
12366 // Use the actual matrix size here.
12367 td_line += base_type + to_string(val: physical_type.columns) + "x" + to_string(val: physical_type.vecsize);
12368 td_line += "[" + to_string(val: cols) + "]";
12369 td_line += ";";
12370 add_typedef_line(line: td_line);
12371 }
12372 else if (!is_scalar(type: physical_type)) // scalar type is already packed.
12373 pack_pfx = "packed_";
12374 }
12375 else if (is_matrix(type: physical_type))
12376 {
12377 if (!msl_options.supports_msl_version(major: 3, minor: 0) &&
12378 has_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationWorkgroupStruct))
12379 {
12380 pack_pfx = "spvStorage_";
12381 add_spv_func_and_recompile(spv_func: SPVFuncImplStorageMatrix);
12382 // The pack prefix causes problems with array<T> wrappers.
12383 is_using_builtin_array = true;
12384 }
12385 if (row_major)
12386 {
12387 // Need to declare type with flipped vecsize/columns.
12388 row_major_physical_type = physical_type;
12389 swap(a&: row_major_physical_type.vecsize, b&: row_major_physical_type.columns);
12390 declared_type = &row_major_physical_type;
12391 }
12392 }
12393
12394 // iOS Tier 1 argument buffers do not support writable images.
12395 if (physical_type.basetype == SPIRType::Image &&
12396 physical_type.image.sampled == 2 &&
12397 msl_options.is_ios() &&
12398 msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1 &&
12399 !has_decoration(id: orig_id, decoration: DecorationNonWritable))
12400 {
12401 SPIRV_CROSS_THROW("Writable images are not allowed on Tier1 argument buffers on iOS.");
12402 }
12403
12404 // Array information is baked into these types.
12405 string array_type;
12406 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
12407 physical_type.basetype != SPIRType::SampledImage)
12408 {
12409 BuiltIn builtin = BuiltInMax;
12410
12411 // Special handling. In [[stage_out]] or [[stage_in]] blocks,
12412 // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
12413 // template array types to be declared.
12414 bool is_ib_in_out =
12415 ((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
12416 variable_storage_requires_stage_io(storage: StorageClassOutput)) ||
12417 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
12418 variable_storage_requires_stage_io(storage: StorageClassInput)));
12419 if (is_ib_in_out && is_member_builtin(type, index, builtin: &builtin))
12420 is_using_builtin_array = true;
12421 array_type = type_to_array_glsl(type: physical_type, variable_id: orig_id);
12422 }
12423
12424 if (orig_id)
12425 {
12426 auto *data_type = declared_type;
12427 if (is_pointer(type: *data_type))
12428 data_type = &get_pointee_type(type: *data_type);
12429
12430 if (is_array(type: *data_type) && get_resource_array_size(type: *data_type, id: orig_id) == 0)
12431 {
12432 // Hack for declaring unsized array of resources. Need to declare dummy sized array by value inline.
12433 // This can then be wrapped in spvDescriptorArray as usual.
12434 array_type = "[1] /* unsized array hack */";
12435 }
12436 }
12437
12438 string decl_type;
12439 if (declared_type->vecsize > 4)
12440 {
12441 auto orig_type = get<SPIRType>(id: orig_member_type_id);
12442 if (is_matrix(type: orig_type) && row_major)
12443 swap(a&: orig_type.vecsize, b&: orig_type.columns);
12444 orig_type.columns = 1;
12445 decl_type = type_to_glsl(type: orig_type, id: orig_id, member: true);
12446
12447 if (declared_type->columns > 1)
12448 decl_type = join(ts: "spvPaddedStd140Matrix<", ts&: decl_type, ts: ", ", ts: declared_type->columns, ts: ">");
12449 else
12450 decl_type = join(ts: "spvPaddedStd140<", ts&: decl_type, ts: ">");
12451 }
12452 else
12453 decl_type = type_to_glsl(type: *declared_type, id: orig_id, member: true);
12454
12455 const char *overlapping_binding_tag =
12456 has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationOverlappingBinding) ?
12457 "// Overlapping binding: " : "";
12458
12459 auto result = join(ts&: overlapping_binding_tag, ts&: pack_pfx, ts&: decl_type, ts: " ", ts: qualifier,
12460 ts: to_member_name(type, index), ts: member_attribute_qualifier(type, index), ts&: array_type, ts: ";");
12461
12462 is_using_builtin_array = false;
12463 return result;
12464}
12465
12466// Emit a structure member, padding and packing to maintain the correct member alignments.
12467void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
12468 const string &qualifier, uint32_t)
12469{
12470 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
12471 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPaddingTarget))
12472 {
12473 uint32_t pad_len = get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationPaddingTarget);
12474 statement(ts: "char _m", ts&: index, ts: "_pad", ts: "[", ts&: pad_len, ts: "];");
12475 }
12476
12477 // Handle HLSL-style 0-based vertex/instance index.
12478 builtin_declaration = true;
12479 statement(ts: to_struct_member(type, member_type_id, index, qualifier));
12480 builtin_declaration = false;
12481}
12482
12483void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
12484{
12485 uint32_t struct_size = get_declared_struct_size_msl(struct_type: type, ignore_alignment: true, ignore_padding: true);
12486 uint32_t target_size = get_extended_decoration(id: type.self, decoration: SPIRVCrossDecorationPaddingTarget);
12487 if (target_size < struct_size)
12488 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
12489 else if (target_size > struct_size)
12490 statement(ts: "char _m0_final_padding[", ts: target_size - struct_size, ts: "];");
12491}
12492
12493// Return an MSL qualifier for the specified function attribute member
12494string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
12495{
12496 auto &execution = get_entry_point();
12497
12498 uint32_t mbr_type_id = type.member_types[index];
12499 auto &mbr_type = get<SPIRType>(id: mbr_type_id);
12500
12501 BuiltIn builtin = BuiltInMax;
12502 bool is_builtin = is_member_builtin(type, index, builtin: &builtin);
12503
12504 if (has_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary))
12505 {
12506 string quals = join(
12507 ts: " [[id(", ts: get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationResourceIndexPrimary), ts: ")");
12508 if (interlocked_resources.count(
12509 x: get_extended_member_decoration(type: type.self, index, decoration: SPIRVCrossDecorationInterfaceOrigID)))
12510 quals += ", raster_order_group(0)";
12511 quals += "]]";
12512 return quals;
12513 }
12514
12515 // Vertex function inputs
12516 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
12517 {
12518 if (is_builtin)
12519 {
12520 switch (builtin)
12521 {
12522 case BuiltInVertexId:
12523 case BuiltInVertexIndex:
12524 case BuiltInBaseVertex:
12525 case BuiltInInstanceId:
12526 case BuiltInInstanceIndex:
12527 case BuiltInBaseInstance:
12528 if (msl_options.vertex_for_tessellation)
12529 return "";
12530 return string(" [[") + builtin_qualifier(builtin) + "]]";
12531
12532 case BuiltInDrawIndex:
12533 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
12534
12535 default:
12536 return "";
12537 }
12538 }
12539
12540 uint32_t locn;
12541 if (is_builtin)
12542 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
12543 else
12544 locn = get_member_location(type_id: type.self, index);
12545
12546 if (locn != k_unknown_location)
12547 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
12548 }
12549
12550 // Vertex and tessellation evaluation function outputs
12551 if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader()) &&
12552 type.storage == StorageClassOutput)
12553 {
12554 if (is_builtin)
12555 {
12556 switch (builtin)
12557 {
12558 case BuiltInPointSize:
12559 // Only mark the PointSize builtin if really rendering points.
12560 // Some shaders may include a PointSize builtin even when used to render
12561 // non-point topologies, and Metal will reject this builtin when compiling
12562 // the shader into a render pipeline that uses a non-point topology.
12563 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
12564
12565 case BuiltInViewportIndex:
12566 if (!msl_options.supports_msl_version(major: 2, minor: 0))
12567 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
12568 /* fallthrough */
12569 case BuiltInPosition:
12570 case BuiltInLayer:
12571 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12572
12573 case BuiltInClipDistance:
12574 if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12575 return join(ts: " [[user(clip", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12576 else
12577 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12578
12579 case BuiltInCullDistance:
12580 if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12581 return join(ts: " [[user(cull", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12582 else
12583 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12584
12585 default:
12586 return "";
12587 }
12588 }
12589 string loc_qual = member_location_attribute_qualifier(type, index);
12590 if (!loc_qual.empty())
12591 return join(ts: " [[", ts&: loc_qual, ts: "]]");
12592 }
12593
12594 if (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation && type.storage == StorageClassOutput)
12595 {
12596 // For this type of shader, we always arrange for it to capture its
12597 // output to a buffer. For this reason, qualifiers are irrelevant here.
12598 if (is_builtin)
12599 // We still have to assign a location so the output struct will sort correctly.
12600 get_or_allocate_builtin_output_member_location(builtin, type_id: type.self, index);
12601 return "";
12602 }
12603
12604 // Tessellation control function inputs
12605 if (is_tesc_shader() && type.storage == StorageClassInput)
12606 {
12607 if (is_builtin)
12608 {
12609 switch (builtin)
12610 {
12611 case BuiltInInvocationId:
12612 case BuiltInPrimitiveId:
12613 if (msl_options.multi_patch_workgroup)
12614 return "";
12615 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12616 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
12617 case BuiltInSubgroupSize: // FIXME: Should work in any stage
12618 if (msl_options.emulate_subgroups)
12619 return "";
12620 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
12621 case BuiltInPatchVertices:
12622 return "";
12623 // Others come from stage input.
12624 default:
12625 break;
12626 }
12627 }
12628 if (msl_options.multi_patch_workgroup)
12629 return "";
12630
12631 uint32_t locn;
12632 if (is_builtin)
12633 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
12634 else
12635 locn = get_member_location(type_id: type.self, index);
12636
12637 if (locn != k_unknown_location)
12638 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
12639 }
12640
12641 // Tessellation control function outputs
12642 if (is_tesc_shader() && type.storage == StorageClassOutput)
12643 {
12644 // For this type of shader, we always arrange for it to capture its
12645 // output to a buffer. For this reason, qualifiers are irrelevant here.
12646 if (is_builtin)
12647 // We still have to assign a location so the output struct will sort correctly.
12648 get_or_allocate_builtin_output_member_location(builtin, type_id: type.self, index);
12649 return "";
12650 }
12651
12652 // Tessellation evaluation function inputs
12653 if (is_tese_shader() && type.storage == StorageClassInput)
12654 {
12655 if (is_builtin)
12656 {
12657 switch (builtin)
12658 {
12659 case BuiltInPrimitiveId:
12660 case BuiltInTessCoord:
12661 return string(" [[") + builtin_qualifier(builtin) + "]]";
12662 case BuiltInPatchVertices:
12663 return "";
12664 // Others come from stage input.
12665 default:
12666 break;
12667 }
12668 }
12669
12670 if (msl_options.raw_buffer_tese_input)
12671 return "";
12672
12673 // The special control point array must not be marked with an attribute.
12674 if (get_type(id: type.member_types[index]).basetype == SPIRType::ControlPointArray)
12675 return "";
12676
12677 uint32_t locn;
12678 if (is_builtin)
12679 locn = get_or_allocate_builtin_input_member_location(builtin, type_id: type.self, index);
12680 else
12681 locn = get_member_location(type_id: type.self, index);
12682
12683 if (locn != k_unknown_location)
12684 return string(" [[attribute(") + convert_to_string(t: locn) + ")]]";
12685 }
12686
12687 // Tessellation evaluation function outputs were handled above.
12688
12689 // Fragment function inputs
12690 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
12691 {
12692 string quals;
12693 if (is_builtin)
12694 {
12695 switch (builtin)
12696 {
12697 case BuiltInViewIndex:
12698 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
12699 break;
12700 /* fallthrough */
12701 case BuiltInFrontFacing:
12702 case BuiltInPointCoord:
12703 case BuiltInFragCoord:
12704 case BuiltInSampleId:
12705 case BuiltInSampleMask:
12706 case BuiltInLayer:
12707 case BuiltInBaryCoordKHR:
12708 case BuiltInBaryCoordNoPerspKHR:
12709 quals = builtin_qualifier(builtin);
12710 break;
12711
12712 case BuiltInClipDistance:
12713 return join(ts: " [[user(clip", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12714 case BuiltInCullDistance:
12715 return join(ts: " [[user(cull", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12716
12717 default:
12718 break;
12719 }
12720 }
12721 else
12722 quals = member_location_attribute_qualifier(type, index);
12723
12724 if (builtin == BuiltInBaryCoordKHR || builtin == BuiltInBaryCoordNoPerspKHR)
12725 {
12726 if (has_member_decoration(id: type.self, index, decoration: DecorationFlat) ||
12727 has_member_decoration(id: type.self, index, decoration: DecorationCentroid) ||
12728 has_member_decoration(id: type.self, index, decoration: DecorationSample) ||
12729 has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
12730 {
12731 // NoPerspective is baked into the builtin type.
12732 SPIRV_CROSS_THROW(
12733 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
12734 }
12735 }
12736
12737 // Don't bother decorating integers with the 'flat' attribute; it's
12738 // the default (in fact, the only option). Also don't bother with the
12739 // FragCoord builtin; it's always noperspective on Metal.
12740 if (!type_is_integral(type: mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
12741 {
12742 if (has_member_decoration(id: type.self, index, decoration: DecorationFlat))
12743 {
12744 if (!quals.empty())
12745 quals += ", ";
12746 quals += "flat";
12747 }
12748 else if (has_member_decoration(id: type.self, index, decoration: DecorationCentroid))
12749 {
12750 if (!quals.empty())
12751 quals += ", ";
12752 if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
12753 quals += "centroid_no_perspective";
12754 else
12755 quals += "centroid_perspective";
12756 }
12757 else if (has_member_decoration(id: type.self, index, decoration: DecorationSample))
12758 {
12759 if (!quals.empty())
12760 quals += ", ";
12761 if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
12762 quals += "sample_no_perspective";
12763 else
12764 quals += "sample_perspective";
12765 }
12766 else if (has_member_decoration(id: type.self, index, decoration: DecorationNoPerspective))
12767 {
12768 if (!quals.empty())
12769 quals += ", ";
12770 quals += "center_no_perspective";
12771 }
12772 }
12773
12774 if (!quals.empty())
12775 return " [[" + quals + "]]";
12776 }
12777
12778 // Fragment function outputs
12779 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
12780 {
12781 if (is_builtin)
12782 {
12783 switch (builtin)
12784 {
12785 case BuiltInFragStencilRefEXT:
12786 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
12787 // Some shaders may include a FragStencilRef builtin even when used to render
12788 // without a stencil attachment, and Metal will reject this builtin
12789 // when compiling the shader into a render pipeline that does not set
12790 // stencilAttachmentPixelFormat.
12791 if (!msl_options.enable_frag_stencil_ref_builtin)
12792 return "";
12793 if (!msl_options.supports_msl_version(major: 2, minor: 1))
12794 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
12795 return string(" [[") + builtin_qualifier(builtin) + "]]";
12796
12797 case BuiltInFragDepth:
12798 // Ditto FragDepth.
12799 if (!msl_options.enable_frag_depth_builtin)
12800 return "";
12801 /* fallthrough */
12802 case BuiltInSampleMask:
12803 return string(" [[") + builtin_qualifier(builtin) + "]]";
12804
12805 default:
12806 return "";
12807 }
12808 }
12809 uint32_t locn = get_member_location(type_id: type.self, index);
12810 // Emit no color attribute for locations masked out by enable_frag_output_mask; Metal will likely complain about missing color attachments, too.
12811 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
12812 return "";
12813 if (locn != k_unknown_location && has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12814 return join(ts: " [[color(", ts&: locn, ts: "), index(", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex),
12815 ts: ")]]");
12816 else if (locn != k_unknown_location)
12817 return join(ts: " [[color(", ts&: locn, ts: ")]]");
12818 else if (has_member_decoration(id: type.self, index, decoration: DecorationIndex))
12819 return join(ts: " [[index(", ts: get_member_decoration(id: type.self, index, decoration: DecorationIndex), ts: ")]]");
12820 else
12821 return "";
12822 }
12823
12824 // Compute function inputs
12825 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
12826 {
12827 if (is_builtin)
12828 {
12829 switch (builtin)
12830 {
12831 case BuiltInNumSubgroups:
12832 case BuiltInSubgroupId:
12833 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
12834 case BuiltInSubgroupSize: // FIXME: Should work in any stage
12835 if (msl_options.emulate_subgroups)
12836 break;
12837 /* fallthrough */
12838 case BuiltInGlobalInvocationId:
12839 case BuiltInWorkgroupId:
12840 case BuiltInNumWorkgroups:
12841 case BuiltInLocalInvocationId:
12842 case BuiltInLocalInvocationIndex:
12843 return string(" [[") + builtin_qualifier(builtin) + "]]";
12844
12845 default:
12846 return "";
12847 }
12848 }
12849 }
12850
12851 return "";
12852}
12853
12854// A user-defined output variable is considered to match an input variable in the subsequent
12855// stage if the two variables are declared with the same Location and Component decoration and
12856// match in type and decoration, except that interpolation decorations are not required to match.
12857// For the purposes of interface matching, variables declared without a Component decoration are
12858// considered to have a Component decoration of zero.
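// E.g. a member at location 3, component 1 yields "user(locn3_1)"; with
// component 0 (or no Component decoration) it is simply "user(locn3)".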
12859string CompilerMSL::member_location_attribute_qualifier(const SPIRType &type, uint32_t index)
12860{
12861 string quals;
12862 uint32_t comp;
12863 uint32_t locn = get_member_location(type_id: type.self, index, comp: &comp);
12864 if (locn != k_unknown_location)
12865 {
12866 quals += "user(locn";
12867 quals += convert_to_string(t: locn);
12868 if (comp != k_unknown_component && comp != 0)
12869 {
12870 quals += "_";
12871 quals += convert_to_string(t: comp);
12872 }
12873 quals += ")";
12874 }
12875 return quals;
12876}
12877
12878// Returns the location decoration of the member with the specified index in the specified type.
12879// If the location of the member has been explicitly set, that location is used. If not,
12880// this function returns k_unknown_location, and callers that need a concrete location must
12881// allocate one (see get_or_allocate_builtin_input/output_member_location below).
12882uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const
12883{
12884 if (comp)
12885 {
12886 if (has_member_decoration(id: type_id, index, decoration: DecorationComponent))
12887 *comp = get_member_decoration(id: type_id, index, decoration: DecorationComponent);
12888 else
12889 *comp = k_unknown_component;
12890 }
12891
12892 if (has_member_decoration(id: type_id, index, decoration: DecorationLocation))
12893 return get_member_decoration(id: type_id, index, decoration: DecorationLocation);
12894 else
12895 return k_unknown_location;
12896}
12897
12898uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,
12899 uint32_t type_id, uint32_t index,
12900 uint32_t *comp)
12901{
12902 uint32_t loc = get_member_location(type_id, index, comp);
12903 if (loc != k_unknown_location)
12904 return loc;
12905
12906 if (comp)
12907 *comp = k_unknown_component;
12908
12909 // Late allocation. Find a location which is unused by the application.
12910 // This can happen for built-in inputs in tessellation which are mixed and matched with user inputs.
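 // E.g. if user inputs already occupy locations 0-3 and this builtin needs
 // two locations, the scan below settles on locations 4 and 5.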
12911 auto &mbr_type = get<SPIRType>(id: get<SPIRType>(id: type_id).member_types[index]);
12912 uint32_t count = type_to_location_count(type: mbr_type);
12913
12914 loc = 0;
12915
12916 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
12917 for (uint32_t i = 0; i < location_count; i++)
12918 if (location_inputs_in_use.count(x: location + i) != 0)
12919 return true;
12920 return false;
12921 };
12922
12923 while (location_range_in_use(loc, count))
12924 loc++;
12925
12926 set_member_decoration(id: type_id, index, decoration: DecorationLocation, argument: loc);
12927
12928 // Triangle tess level inputs are shared in one packed float4;
12929 // mark both builtins as sharing one location.
12930 if (!msl_options.raw_buffer_tese_input && is_tessellating_triangles() &&
12931 (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
12932 {
12933 builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc;
12934 builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc;
12935 }
12936 else
12937 builtin_to_automatic_input_location[builtin] = loc;
12938
12939 mark_location_as_used_by_shader(location: loc, type: mbr_type, storage: StorageClassInput, fallback: true);
12940 return loc;
12941}
12942
12943uint32_t CompilerMSL::get_or_allocate_builtin_output_member_location(spv::BuiltIn builtin,
12944 uint32_t type_id, uint32_t index,
12945 uint32_t *comp)
12946{
12947 uint32_t loc = get_member_location(type_id, index, comp);
12948 if (loc != k_unknown_location)
12949 return loc;
12950 loc = 0;
12951
12952 if (comp)
12953 *comp = k_unknown_component;
12954
12955 // Late allocation. Find a location which is unused by the application.
12956 // This can happen for built-in outputs in tessellation which are mixed and matched with user inputs.
12957 auto &mbr_type = get<SPIRType>(id: get<SPIRType>(id: type_id).member_types[index]);
12958 uint32_t count = type_to_location_count(type: mbr_type);
12959
12960 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
12961 for (uint32_t i = 0; i < location_count; i++)
12962 if (location_outputs_in_use.count(x: location + i) != 0)
12963 return true;
12964 return false;
12965 };
12966
12967 while (location_range_in_use(loc, count))
12968 loc++;
12969
12970 set_member_decoration(id: type_id, index, decoration: DecorationLocation, argument: loc);
12971
12972 // Triangle tess level inputs are shared in one packed float4;
12973 // mark both builtins as sharing one location.
12974 if (is_tessellating_triangles() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
12975 {
12976 builtin_to_automatic_output_location[BuiltInTessLevelInner] = loc;
12977 builtin_to_automatic_output_location[BuiltInTessLevelOuter] = loc;
12978 }
12979 else
12980 builtin_to_automatic_output_location[builtin] = loc;
12981
12982 mark_location_as_used_by_shader(location: loc, type: mbr_type, storage: StorageClassOutput, fallback: true);
12983 return loc;
12984}
12985
12986// Returns the type declaration for a function, including the
12987// entry type if the current function is the entry point function
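// E.g. a tessellation control entry point is declared "kernel", while a
// tessellation evaluation one becomes "[[ patch(triangle, N) ]] vertex",
// N being the output vertex count (omitted on iOS).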
12988string CompilerMSL::func_type_decl(SPIRType &type)
12989{
12990 // The regular function return type. If not processing the entry point function, that's all we need
12991 string return_type = type_to_glsl(type) + type_to_array_glsl(type, variable_id: 0);
12992 if (!processing_entry_point)
12993 return return_type;
12994
12995 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
12996 bool ep_should_return_output = !get_is_rasterization_disabled();
12997 if (stage_out_var_id && ep_should_return_output)
12998 return_type = type_to_glsl(type: get_stage_out_struct_type()) + type_to_array_glsl(type, variable_id: 0);
12999
13000 // Prepend an entry type, based on the execution model
13001 string entry_type;
13002 auto &execution = get_entry_point();
13003 switch (execution.model)
13004 {
13005 case ExecutionModelVertex:
13006 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(major: 1, minor: 2))
13007 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13008 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
13009 break;
13010 case ExecutionModelTessellationEvaluation:
13011 if (!msl_options.supports_msl_version(major: 1, minor: 2))
13012 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13013 if (execution.flags.get(bit: ExecutionModeIsolines))
13014 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
13015 if (msl_options.is_ios())
13016 entry_type = join(ts: "[[ patch(", ts: is_tessellating_triangles() ? "triangle" : "quad", ts: ") ]] vertex");
13017 else
13018 entry_type = join(ts: "[[ patch(", ts: is_tessellating_triangles() ? "triangle" : "quad", ts: ", ",
13019 ts&: execution.output_vertices, ts: ") ]] vertex");
13020 break;
13021 case ExecutionModelFragment:
13022 entry_type = uses_explicit_early_fragment_test() ? "[[ early_fragment_tests ]] fragment" : "fragment";
13023 break;
13024 case ExecutionModelTessellationControl:
13025 if (!msl_options.supports_msl_version(major: 1, minor: 2))
13026 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
13027 if (execution.flags.get(bit: ExecutionModeIsolines))
13028 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
13029 /* fallthrough */
13030 case ExecutionModelGLCompute:
13031 case ExecutionModelKernel:
13032 entry_type = "kernel";
13033 break;
13034 default:
13035 entry_type = "unknown";
13036 break;
13037 }
13038
13039 return entry_type + " " + return_type;
13040}
13041
13042bool CompilerMSL::is_tesc_shader() const
13043{
13044 return get_execution_model() == ExecutionModelTessellationControl;
13045}
13046
13047bool CompilerMSL::is_tese_shader() const
13048{
13049 return get_execution_model() == ExecutionModelTessellationEvaluation;
13050}
13051
13052bool CompilerMSL::uses_explicit_early_fragment_test()
13053{
13054 auto &ep_flags = get_entry_point().flags;
13055 return ep_flags.get(bit: ExecutionModeEarlyFragmentTests) || ep_flags.get(bit: ExecutionModePostDepthCoverage);
13056}
13057
13058// In MSL, address space qualifiers are required for all pointer or reference variables
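// E.g. SSBOs map to "device" (or "const device" when read-only), UBOs and
// push constants to "constant", and workgroup variables to "threadgroup".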
13059string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
13060{
13061 const auto &type = get<SPIRType>(id: argument.basetype);
13062 return get_type_address_space(type, id: argument.self, argument: true);
13063}
13064
13065bool CompilerMSL::decoration_flags_signal_volatile(const Bitset &flags)
13066{
13067 return flags.get(bit: DecorationVolatile) || flags.get(bit: DecorationCoherent);
13068}
13069
13070string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
13071{
13072 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
13073 Bitset flags;
13074 auto *var = maybe_get<SPIRVariable>(id);
13075 if (var && type.basetype == SPIRType::Struct &&
13076 (has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock)))
13077 flags = get_buffer_block_flags(id);
13078 else
13079 flags = get_decoration_bitset(id);
13080
13081 const char *addr_space = nullptr;
13082 switch (type.storage)
13083 {
13084 case StorageClassWorkgroup:
13085 addr_space = "threadgroup";
13086 break;
13087
13088 case StorageClassStorageBuffer:
13089 case StorageClassPhysicalStorageBuffer:
13090 {
13091 // For arguments from variable pointers, we use the write count deduction, so
13092 // we should not assume any constness here. Only for global SSBOs.
13093 bool readonly = false;
13094 if (!var || has_decoration(id: type.self, decoration: DecorationBlock))
13095 readonly = flags.get(bit: DecorationNonWritable);
13096
13097 addr_space = readonly ? "const device" : "device";
13098 break;
13099 }
13100
13101 case StorageClassUniform:
13102 case StorageClassUniformConstant:
13103 case StorageClassPushConstant:
13104 if (type.basetype == SPIRType::Struct)
13105 {
13106 bool ssbo = has_decoration(id: type.self, decoration: DecorationBufferBlock);
13107 if (ssbo)
13108 addr_space = flags.get(bit: DecorationNonWritable) ? "const device" : "device";
13109 else
13110 addr_space = "constant";
13111 }
13112 else if (!argument)
13113 {
13114 addr_space = "constant";
13115 }
13116 else if (type_is_msl_framebuffer_fetch(type))
13117 {
13118 // Subpass inputs are passed around by value.
13119 addr_space = "";
13120 }
13121 break;
13122
13123 case StorageClassFunction:
13124 case StorageClassGeneric:
13125 break;
13126
13127 case StorageClassInput:
13128 if (is_tesc_shader() && var && var->basevariable == stage_in_ptr_var_id)
13129 addr_space = msl_options.multi_patch_workgroup ? "const device" : "threadgroup";
13130 // Don't pass tessellation levels in the device address space; we load and convert them
13131 // to float manually.
13132 if (is_tese_shader() && msl_options.raw_buffer_tese_input && var)
13133 {
13134 bool is_stage_in = var->basevariable == stage_in_ptr_var_id;
13135 bool is_patch_stage_in = has_decoration(id: var->self, decoration: DecorationPatch);
13136 bool is_builtin = has_decoration(id: var->self, decoration: DecorationBuiltIn);
13137 BuiltIn builtin = (BuiltIn)get_decoration(id: var->self, decoration: DecorationBuiltIn);
13138 bool is_tess_level = is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner);
13139 if (is_stage_in || (is_patch_stage_in && !is_tess_level))
13140 addr_space = "const device";
13141 }
13142 if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
13143 addr_space = "thread";
13144 break;
13145
13146 case StorageClassOutput:
13147 if (capture_output_to_buffer)
13148 {
13149 if (var && type.storage == StorageClassOutput)
13150 {
13151 bool is_masked = is_stage_output_variable_masked(var: *var);
13152
13153 if (is_masked)
13154 {
13155 if (is_tessellation_shader())
13156 addr_space = "threadgroup";
13157 else
13158 addr_space = "thread";
13159 }
13160 else if (variable_decl_is_remapped_storage(variable: *var, storage: StorageClassWorkgroup))
13161 addr_space = "threadgroup";
13162 }
13163
13164 if (!addr_space)
13165 addr_space = "device";
13166 }
13167 break;
13168
13169 default:
13170 break;
13171 }
13172
13173 if (!addr_space)
13174 {
13175 // No address space for plain values.
13176 addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
13177 }
13178
13179 return join(ts: decoration_flags_signal_volatile(flags) ? "volatile " : "", ts&: addr_space);
13180}
13181
13182const char *CompilerMSL::to_restrict(uint32_t id, bool space)
13183{
13184 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
13185 Bitset flags;
13186 if (ir.ids[id].get_type() == TypeVariable)
13187 {
13188 uint32_t type_id = expression_type_id(id);
13189 auto &type = expression_type(id);
13190 if (type.basetype == SPIRType::Struct &&
13191 (has_decoration(id: type_id, decoration: DecorationBlock) || has_decoration(id: type_id, decoration: DecorationBufferBlock)))
13192 flags = get_buffer_block_flags(id);
13193 else
13194 flags = get_decoration_bitset(id);
13195 }
13196 else
13197 flags = get_decoration_bitset(id);
13198
13199 return flags.get(bit: DecorationRestrict) || flags.get(bit: DecorationRestrictPointerEXT) ?
13200 (space ? "__restrict " : "__restrict") : "";
13201}
13202
13203string CompilerMSL::entry_point_arg_stage_in()
13204{
13205 string decl;
13206
13207 if ((is_tesc_shader() && msl_options.multi_patch_workgroup) ||
13208 (is_tese_shader() && msl_options.raw_buffer_tese_input))
13209 return decl;
13210
13211 // Stage-in structure
13212 uint32_t stage_in_id;
13213 if (is_tese_shader())
13214 stage_in_id = patch_stage_in_var_id;
13215 else
13216 stage_in_id = stage_in_var_id;
13217
13218 if (stage_in_id)
13219 {
13220 auto &var = get<SPIRVariable>(id: stage_in_id);
13221 auto &type = get_variable_data_type(var);
13222
13223 add_resource_name(id: var.self);
13224 decl = join(ts: type_to_glsl(type), ts: " ", ts: to_name(id: var.self), ts: " [[stage_in]]");
13225 }
13226
13227 return decl;
13228}
13229
13230// Returns true if this input builtin should be a direct parameter on a shader function parameter list,
13231// and false for builtins that should be passed or calculated some other way.
13232bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
13233{
13234 switch (bi_type)
13235 {
13236 // Vertex function in
13237 case BuiltInVertexId:
13238 case BuiltInVertexIndex:
13239 case BuiltInBaseVertex:
13240 case BuiltInInstanceId:
13241 case BuiltInInstanceIndex:
13242 case BuiltInBaseInstance:
13243 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
13244 // Tess. control function in
13245 case BuiltInPosition:
13246 case BuiltInPointSize:
13247 case BuiltInClipDistance:
13248 case BuiltInCullDistance:
13249 case BuiltInPatchVertices:
13250 return false;
13251 case BuiltInInvocationId:
13252 case BuiltInPrimitiveId:
13253 return !is_tesc_shader() || !msl_options.multi_patch_workgroup;
13254 // Tess. evaluation function in
13255 case BuiltInTessLevelInner:
13256 case BuiltInTessLevelOuter:
13257 return false;
13258 // Fragment function in
13259 case BuiltInSamplePosition:
13260 case BuiltInHelperInvocation:
13261 case BuiltInBaryCoordKHR:
13262 case BuiltInBaryCoordNoPerspKHR:
13263 return false;
13264 case BuiltInViewIndex:
13265 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
13266 msl_options.multiview_layered_rendering;
13267 // Compute function in
13268 case BuiltInSubgroupId:
13269 case BuiltInNumSubgroups:
13270 return !msl_options.emulate_subgroups;
13271 // Any stage function in
13272 case BuiltInDeviceIndex:
13273 case BuiltInSubgroupEqMask:
13274 case BuiltInSubgroupGeMask:
13275 case BuiltInSubgroupGtMask:
13276 case BuiltInSubgroupLeMask:
13277 case BuiltInSubgroupLtMask:
13278 return false;
13279 case BuiltInSubgroupSize:
13280 if (msl_options.fixed_subgroup_size != 0)
13281 return false;
13282 /* fallthrough */
13283 case BuiltInSubgroupLocalInvocationId:
13284 return !msl_options.emulate_subgroups;
13285 default:
13286 return true;
13287 }
13288}
13289
13290// Returns true if this is a fragment shader that runs per sample, and false otherwise.
13291bool CompilerMSL::is_sample_rate() const
13292{
13293 auto &caps = get_declared_capabilities();
13294 return get_execution_model() == ExecutionModelFragment &&
13295 (msl_options.force_sample_rate_shading ||
13296 std::find(first: caps.begin(), last: caps.end(), val: CapabilitySampleRateShading) != caps.end() ||
13297 (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input_ms));
13298}
13299
13300bool CompilerMSL::is_intersection_query() const
13301{
13302 auto &caps = get_declared_capabilities();
13303 return std::find(first: caps.begin(), last: caps.end(), val: CapabilityRayQueryKHR) != caps.end();
13304}
13305
13306void CompilerMSL::entry_point_args_builtin(string &ep_args)
13307{
13308 // Builtin variables
13309 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
13310 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t var_id, SPIRVariable &var) {
13311 if (var.storage != StorageClassInput)
13312 return;
13313
13314 auto bi_type = BuiltIn(get_decoration(id: var_id, decoration: DecorationBuiltIn));
13315
13316 // Don't emit SamplePosition as a separate parameter. In the entry
13317 // point, we get that by calling get_sample_position() on the sample ID.
13318 if (is_builtin_variable(var) &&
13319 get_variable_data_type(var).basetype != SPIRType::Struct &&
13320 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
13321 {
13322 // If the builtin is not part of the active input builtin set, don't emit it.
13323 // Relevant for multiple entry-point modules which might declare unused builtins.
13324 if (!active_input_builtins.get(bit: bi_type) || !interface_variable_exists_in_entry_point(id: var_id))
13325 return;
13326
13327 // Remember this variable. We may need to correct its type.
13328 active_builtins.push_back(t: make_pair(x: &var, y&: bi_type));
13329
13330 if (is_direct_input_builtin(bi_type))
13331 {
13332 if (!ep_args.empty())
13333 ep_args += ", ";
13334
13335 // Handle HLSL-style 0-based vertex/instance index.
13336 builtin_declaration = true;
13337
13338 // Handle different MSL gl_TessCoord types (float2, float3).
13339 if (bi_type == BuiltInTessCoord && get_entry_point().flags.get(bit: ExecutionModeQuads))
13340 ep_args += "float2 " + to_expression(id: var_id) + "In";
13341 else
13342 ep_args += builtin_type_decl(builtin: bi_type, id: var_id) + " " + to_expression(id: var_id);
13343
13344 ep_args += string(" [[") + builtin_qualifier(builtin: bi_type);
13345 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(bit: ExecutionModePostDepthCoverage))
13346 {
13347 if (!msl_options.supports_msl_version(major: 2))
13348 SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
13349 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 3))
13350 SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
13351 ep_args += ", post_depth_coverage";
13352 }
13353 ep_args += "]]";
13354 builtin_declaration = false;
13355 }
13356 }
13357
13358 if (has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInDispatchBase))
13359 {
13360 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
13361 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
13362 // assume we emitted it for a good reason.
13363 assert(msl_options.supports_msl_version(1, 2));
13364 if (!ep_args.empty())
13365 ep_args += ", ";
13366
13367 ep_args += type_to_glsl(type: get_variable_data_type(var)) + " " + to_expression(id: var_id) + " [[grid_origin]]";
13368 }
13369
13370 if (has_extended_decoration(id: var_id, decoration: SPIRVCrossDecorationBuiltInStageInputSize))
13371 {
13372 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
13373 // which holds the number of vertices and instances to draw. If it's present,
13374 // assume we emitted it for a good reason.
13375 assert(msl_options.supports_msl_version(1, 2));
13376 if (!ep_args.empty())
13377 ep_args += ", ";
13378
13379 ep_args += type_to_glsl(type: get_variable_data_type(var)) + " " + to_expression(id: var_id) + " [[grid_size]]";
13380 }
13381 });
13382
13383 // Correct the types of all encountered active builtins. We couldn't do this before
13384 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
13385 // while iterating over IDs.
13386 for (auto &var : active_builtins)
13387 var.first->basetype = ensure_correct_builtin_type(type_id: var.first->basetype, builtin: var.second);
13388
13389 // Handle HLSL-style 0-based vertex/instance index.
13390 if (needs_base_vertex_arg == TriState::Yes)
13391 ep_args += built_in_func_arg(builtin: BuiltInBaseVertex, prefix_comma: !ep_args.empty());
13392
13393 if (needs_base_instance_arg == TriState::Yes)
13394 ep_args += built_in_func_arg(builtin: BuiltInBaseInstance, prefix_comma: !ep_args.empty());
13395
13396 if (capture_output_to_buffer)
13397 {
13398 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
13399 // specially because it needs to be a pointer, not a reference.
13400 if (stage_out_var_id)
13401 {
13402 if (!ep_args.empty())
13403 ep_args += ", ";
13404 ep_args += join(ts: "device ", ts: type_to_glsl(type: get_stage_out_struct_type()), ts: "* ", ts&: output_buffer_var_name,
13405 ts: " [[buffer(", ts&: msl_options.shader_output_buffer_index, ts: ")]]");
13406 }
13407
13408 if (is_tesc_shader())
13409 {
13410 if (!ep_args.empty())
13411 ep_args += ", ";
13412 ep_args +=
13413 join(ts: "constant uint* spvIndirectParams [[buffer(", ts&: msl_options.indirect_params_buffer_index, ts: ")]]");
13414 }
13415 else if (stage_out_var_id &&
13416 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
13417 {
13418 if (!ep_args.empty())
13419 ep_args += ", ";
13420 ep_args +=
13421 join(ts: "device uint* spvIndirectParams [[buffer(", ts&: msl_options.indirect_params_buffer_index, ts: ")]]");
13422 }
13423
13424 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
13425 (active_input_builtins.get(bit: BuiltInVertexIndex) || active_input_builtins.get(bit: BuiltInVertexId)) &&
13426 msl_options.vertex_index_type != Options::IndexType::None)
13427 {
13428 // Add the index buffer so we can set gl_VertexIndex correctly.
13429 if (!ep_args.empty())
13430 ep_args += ", ";
13431 switch (msl_options.vertex_index_type)
13432 {
13433 case Options::IndexType::None:
13434 break;
13435 case Options::IndexType::UInt16:
13436 ep_args += join(ts: "const device ushort* ", ts&: index_buffer_var_name, ts: " [[buffer(",
13437 ts&: msl_options.shader_index_buffer_index, ts: ")]]");
13438 break;
13439 case Options::IndexType::UInt32:
13440 ep_args += join(ts: "const device uint* ", ts&: index_buffer_var_name, ts: " [[buffer(",
13441 ts&: msl_options.shader_index_buffer_index, ts: ")]]");
13442 break;
13443 }
13444 }

		// Tessellation control shaders get three additional parameters:
		// a buffer to hold the per-patch data, a buffer to hold the per-patch
		// tessellation levels, and a block of workgroup memory to hold the
		// input control point data.
		if (is_tesc_shader())
		{
			if (patch_stage_out_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args +=
				    join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
				         " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
			}
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
			                convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");

			// Initializer for tess factors must be handled specially since it's never declared as a normal variable.
			uint32_t outer_factor_initializer_id = 0;
			uint32_t inner_factor_initializer_id = 0;
			ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
				if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
					return;

				BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
				if (builtin == BuiltInTessLevelInner)
					inner_factor_initializer_id = var.initializer;
				else if (builtin == BuiltInTessLevelOuter)
					outer_factor_initializer_id = var.initializer;
			});

			const SPIRConstant *c = nullptr;

			if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
			{
				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
				entry_func.fixup_hooks_in.push_back(
				    [=]()
				    {
					    uint32_t components = is_tessellating_triangles() ? 3 : 4;
					    for (uint32_t i = 0; i < components; i++)
					    {
						    statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i,
						              "] = ", "half(", to_expression(c->subconstants[i]), ");");
					    }
				    });
			}

			if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
			{
				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
				if (is_tessellating_triangles())
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
						          to_expression(c->subconstants[0]), ");");
					});
				}
				else
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						for (uint32_t i = 0; i < 2; i++)
						{
							statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
							          "half(", to_expression(c->subconstants[i]), ");");
						}
					});
				}
			}
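			// For constant tess factors, the hooks above emit half() assignments to the
			// (possibly remapped) tess-level outputs, roughly:
			//   gl_TessLevelOuter[0] = half(2.0);
			//   gl_TessLevelInner = half(2.0);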

			if (stage_in_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				if (msl_options.multi_patch_workgroup)
				{
					ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
					                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
				}
				else
				{
					ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
					                " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
				}
			}
		}
	}
	// Tessellation evaluation shaders get three additional parameters:
	// a buffer for the per-patch data, a buffer for the per-patch
	// tessellation levels, and a buffer for the control point data.
	if (is_tese_shader() && msl_options.raw_buffer_tese_input)
	{
		if (patch_stage_in_var_id)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args +=
			    join("const device ", type_to_glsl(get_patch_stage_in_struct_type()), "* ", patch_input_buffer_var_name,
			         " [[buffer(", convert_to_string(msl_options.shader_patch_input_buffer_index), ")]]");
		}

		if (tess_level_inner_var_id || tess_level_outer_var_id)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("const device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name,
			                " [[buffer(", convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
		}

		if (stage_in_var_id)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("const device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
			                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
		}
	}
}

string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
{
	string ep_args = entry_point_arg_stage_in();
	Bitset claimed_bindings;

	for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
	{
		uint32_t id = argument_buffer_ids[i];
		if (id == 0)
			continue;

		add_resource_name(id);
		auto &var = get<SPIRVariable>(id);
		auto &type = get_variable_data_type(var);

		if (!ep_args.empty())
			ep_args += ", ";

		// Check if the argument buffer binding itself has been remapped.
		uint32_t buffer_binding;
		auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
		if (itr != end(resource_bindings))
		{
			buffer_binding = itr->second.first.msl_buffer;
			itr->second.second = true;
		}
		else
		{
			// As a fallback, directly map desc set <-> binding.
			// If that was taken, take the next buffer binding.
			if (claimed_bindings.get(i))
				buffer_binding = next_metal_resource_index_buffer;
			else
				buffer_binding = i;
		}

		claimed_bindings.set(buffer_binding);

		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";

		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
	}
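
	// At this point ep_args holds the stage-in arguments plus one reference per
	// argument buffer, e.g. (illustrative; type and variable names are generated):
	//   constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]]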

	entry_point_args_discrete_descriptors(ep_args);
	entry_point_args_builtin(ep_args);

	if (!ep_args.empty() && append_comma)
		ep_args += ", ";

	return ep_args;
}

const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
{
	// Try by ID.
	{
		auto itr = constexpr_samplers_by_id.find(id);
		if (itr != end(constexpr_samplers_by_id))
			return &itr->second;
	}

	// Try by binding.
	{
		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
		uint32_t binding = get_decoration(id, DecorationBinding);

		auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
		if (itr != end(constexpr_samplers_by_binding))
			return &itr->second;
	}

	return nullptr;
}

void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
{
	// Output resources, sorted by resource index & type
	// We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
	// with different order of buffers can result in issues with buffer assignments inside the driver.
	struct Resource
	{
		SPIRVariable *var;
		SPIRVariable *discrete_descriptor_alias;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
		uint32_t plane;
		uint32_t secondary_index;
	};

	SmallVector<Resource> resources;

	entry_point_bindings.clear();
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			auto &type = get_variable_data_type(var);
			uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);

			if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
			{
				if (descriptor_set_is_argument_buffer(desc_set))
				{
					if (is_var_runtime_size_array(var))
					{
						// Runtime arrays need to be wrapped in spvDescriptorArray from argument buffer payload.
						entry_point_bindings.push_back(&var);
						// We'll wrap this, so to_name() will always use non-qualified name.
						// We'll need the qualified name to create temporary variable instead.
						ir.meta[var_id].decoration.qualified_alias_explicit_override = true;
					}
					return;
				}
			}

			// Handle descriptor aliasing of simple discrete cases.
			// We can handle aliasing of buffers by casting pointers.
			// The amount of aliasing we can perform for discrete descriptors is very limited.
			// For fully mutable-style aliasing, we need argument buffers where we can exploit the fact
			// that descriptors are all 8 bytes.
			SPIRVariable *discrete_descriptor_alias = nullptr;
			if (var.storage == StorageClassUniform || var.storage == StorageClassStorageBuffer)
			{
				for (auto &resource : resources)
				{
					if (get_decoration(resource.var->self, DecorationDescriptorSet) ==
					        get_decoration(var_id, DecorationDescriptorSet) &&
					    get_decoration(resource.var->self, DecorationBinding) ==
					        get_decoration(var_id, DecorationBinding) &&
					    resource.basetype == SPIRType::Struct && type.basetype == SPIRType::Struct &&
					    (resource.var->storage == StorageClassUniform ||
					     resource.var->storage == StorageClassStorageBuffer))
					{
						discrete_descriptor_alias = resource.var;
						// Self-reference marks that we should declare the resource,
						// and it's being used as an alias (so we can emit void* instead).
						resource.discrete_descriptor_alias = resource.var;
						// Need to promote interlocked usage so that the primary declaration is correct.
						if (interlocked_resources.count(var_id))
							interlocked_resources.insert(resource.var->self);
						break;
					}
				}
			}

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			// Emulate texture2D atomic operations
			uint32_t secondary_index = 0;
			if (atomic_image_vars_emulated.count(var.self))
			{
				secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
			}

			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				entry_point_bindings.push_back(&var);
				for (uint32_t i = 0; i < plane_count; i++)
					resources.push_back({ &var, discrete_descriptor_alias, to_name(var_id), SPIRType::Image,
					                      get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });

				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					resources.push_back({ &var, discrete_descriptor_alias, to_sampler_expression(var_id), SPIRType::Sampler,
					                      get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
				}
			}
			else if (!constexpr_sampler)
			{
				// constexpr samplers are not declared as resources.
				add_resource_name(var_id);

				// Don't allocate resource indices for aliases.
				uint32_t resource_index = ~0u;
				if (!discrete_descriptor_alias)
					resource_index = get_metal_resource_index(var, type.basetype);

				entry_point_bindings.push_back(&var);
				resources.push_back({ &var, discrete_descriptor_alias, to_name(var_id), type.basetype,
				                      resource_index, 0, secondary_index });
			}
		}
	});

	stable_sort(resources.begin(), resources.end(),
	            [](const Resource &lhs, const Resource &rhs)
	            { return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index); });
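
	// The loop below declares each resource. For a pair of buffers aliased at the same
	// set/binding, the primary declaration is emitted as an untyped pointer, roughly
	// (set/binding/index illustrative):
	//   device void* spvBufferAliasSet0Binding1 [[buffer(2)]]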

	for (auto &r : resources)
	{
		auto &var = *r.var;
		auto &type = get_variable_data_type(var);

		uint32_t var_id = var.self;

		switch (r.basetype)
		{
		case SPIRType::Struct:
		{
			auto &m = ir.meta[type.self];
			if (m.members.size() == 0)
				break;

			if (r.discrete_descriptor_alias)
			{
				if (r.var == r.discrete_descriptor_alias)
				{
					auto primary_name = join("spvBufferAliasSet",
					                         get_decoration(var_id, DecorationDescriptorSet),
					                         "Binding",
					                         get_decoration(var_id, DecorationBinding));

					// Declare the primary alias as void*
					if (!ep_args.empty())
						ep_args += ", ";
					ep_args += get_argument_address_space(var) + " void* " + primary_name;
					ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
					if (interlocked_resources.count(var_id))
						ep_args += ", raster_order_group(0)";
					ep_args += "]]";
				}

				buffer_aliases_discrete.push_back(r.var->self);
			}
			else if (!type.array.empty())
			{
				if (type.array.size() > 1)
					SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");

				is_using_builtin_array = true;
				if (is_var_runtime_size_array(var))
				{
					add_spv_func_and_recompile(SPVFuncImplVariableDescriptorArray);
					if (!ep_args.empty())
						ep_args += ", ";
					const bool ssbo = has_decoration(type.self, DecorationBufferBlock);
					if ((var.storage == spv::StorageClassStorageBuffer || ssbo) &&
					    msl_options.runtime_array_rich_descriptor)
					{
						add_spv_func_and_recompile(SPVFuncImplVariableSizedDescriptor);
						ep_args += "const device spvBufferDescriptor<" + get_argument_address_space(var) + " " +
						           type_to_glsl(type) + "*>* ";
					}
					else
					{
						ep_args += "const device spvDescriptor<" + get_argument_address_space(var) + " " +
						           type_to_glsl(type) + "*>* ";
					}
					ep_args += to_restrict(var_id, true) + r.name + "_";
					ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
					if (interlocked_resources.count(var_id))
						ep_args += ", raster_order_group(0)";
					ep_args += "]]";
				}
				else
				{
					uint32_t array_size = get_resource_array_size(type, var_id);
					for (uint32_t i = 0; i < array_size; ++i)
					{
						if (!ep_args.empty())
							ep_args += ", ";
						ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " +
						           to_restrict(var_id, true) + r.name + "_" + convert_to_string(i);
						ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
						if (interlocked_resources.count(var_id))
							ep_args += ", raster_order_group(0)";
						ep_args += "]]";
					}
				}
				is_using_builtin_array = false;
			}
			else
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args += get_argument_address_space(var) + " ";

				if (recursive_inputs.count(type.self))
					ep_args += string("void* ") + to_restrict(var_id, true) + r.name + "_vp";
				else
					ep_args += type_to_glsl(type) + "& " + to_restrict(var_id, true) + r.name;

				ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			break;
		}
		case SPIRType::Sampler:
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += sampler_type(type, var_id, false) + " " + r.name;
			if (is_var_runtime_size_array(var))
				ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")]]";
			else
				ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
			break;
		case SPIRType::Image:
		{
			if (!ep_args.empty())
				ep_args += ", ";

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			const auto &basetype = get<SPIRType>(var.basetype);
			if (!type_is_msl_framebuffer_fetch(basetype))
			{
				ep_args += image_type_glsl(type, var_id, false) + " " + r.name;
				if (r.plane > 0)
					ep_args += join(plane_name_suffix, r.plane);

				if (is_var_runtime_size_array(var))
					ep_args += "_ [[buffer(" + convert_to_string(r.index) + ")";
				else
					ep_args += " [[texture(" + convert_to_string(r.index) + ")";

				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			else
			{
				if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
					SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
				ep_args += image_type_glsl(type, var_id, false) + " " + r.name;
				ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
			}

			// Emulate texture2D atomic operations
			if (atomic_image_vars_emulated.count(var.self))
			{
				auto &flags = ir.get_decoration_bitset(var.self);
				const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : "";
				ep_args += join(", ", cv_flags, "device atomic_", type_to_glsl(get<SPIRType>(basetype.image.type), 0));
				ep_args += "* " + r.name + "_atomic";
				ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			break;
		}
		case SPIRType::AccelerationStructure:
		{
			if (is_var_runtime_size_array(var))
			{
				add_spv_func_and_recompile(SPVFuncImplVariableDescriptor);
				const auto &parent_type = get<SPIRType>(type.parent_type);
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args += "const device spvDescriptor<" + type_to_glsl(parent_type) + ">* " +
				           to_restrict(var_id, true) + r.name + "_";
				ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
			}
			else
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args += type_to_glsl(type, var_id) + " " + r.name;
				ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
			}
			break;
		}
		default:
			if (!ep_args.empty())
				ep_args += ", ";
			if (!type.pointer)
				ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
				           type_to_glsl(type, var_id) + "& " + r.name;
			else
				ep_args += type_to_glsl(type, var_id) + " " + r.name;
			ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
			if (interlocked_resources.count(var_id))
				ep_args += ", raster_order_group(0)";
			ep_args += "]]";
			break;
		}
	}
}

// Returns a string containing a comma-delimited list of args for the entry point function
// This is the "classic" method of MSL 1 when we don't have argument buffer support.
string CompilerMSL::entry_point_args_classic(bool append_comma)
{
	string ep_args = entry_point_arg_stage_in();
	entry_point_args_discrete_descriptors(ep_args);
	entry_point_args_builtin(ep_args);

	if (!ep_args.empty() && append_comma)
		ep_args += ", ";

	return ep_args;
}

void CompilerMSL::fix_up_shader_inputs_outputs()
{
	auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);

	// Emit a guard to ensure we don't execute beyond the last vertex.
	// Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
	// tessellation control shaders do, so early returns should be OK. We may need to revisit this
	// if it ever becomes possible to use barriers from a vertex shader.
	if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
	{
		entry_func.fixup_hooks_in.push_back([this]() {
			statement("if (any(", to_expression(builtin_invocation_id_id),
			          " >= ", to_expression(builtin_stage_input_size_id), "))");
			statement(" return;");
		});
	}
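	// The guard emitted above looks roughly like (builtin names are illustrative):
	//   if (any(gl_GlobalInvocationID >= spvStageInputSize))
	//    return;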

	// Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		auto &type = get_variable_data_type(var);
		uint32_t var_id = var.self;
		bool ssbo = has_decoration(type.self, DecorationBufferBlock);

		if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
		{
			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
			{
				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
					bool is_array_type = !type.array.empty();

					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
					if (descriptor_set_is_argument_buffer(desc_set))
					{
						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
						          ".spvSwizzleConstants", "[",
						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					}
					else
					{
						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					}
				});
			}
		}
		else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
		         !is_hidden_variable(var))
		{
			if (buffer_requires_array_length(var.self))
			{
				entry_func.fixup_hooks_in.push_back(
				    [this, &type, &var, var_id]()
				    {
					    bool is_array_type = !type.array.empty() && !is_var_runtime_size_array(var);

					    uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
					    if (descriptor_set_is_argument_buffer(desc_set))
					    {
						    statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
						              is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
						              ".spvBufferSizeConstants", "[",
						              convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					    }
					    else
					    {
						    // If we have an array of buffers, we need to be able to index into it, so take a pointer instead.
						    statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
						              is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
						              convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
					    }
				    });
			}
		}

		if (msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
		    (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
		{
			recursive_inputs.insert(type.self);
			entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
				auto addr_space = get_argument_address_space(var);
				auto var_name = to_name(var_id);
				statement(addr_space, " auto& ", to_restrict(var_id, true), var_name,
				          " = *(", addr_space, " ", type_to_glsl(type), "*)", var_name, "_vp;");
			});
		}
	});
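	// A buffer-size hook added above emits, e.g. for a discrete SSBO that needs its
	// length (names and index are illustrative):
	//   constant uint& ssboBufferSize = spvBufferSizeConstants[0];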

	// Builtin variables
	ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
		uint32_t var_id = var.self;
		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;

		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
			return;
		if (!interface_variable_exists_in_entry_point(var.self))
			return;

		if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
		{
			switch (bi_type)
			{
			case BuiltInSamplePosition:
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
					          to_expression(builtin_sample_id_id), ");");
				});
				break;
			case BuiltInFragCoord:
				if (is_sample_rate())
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(to_expression(var_id), ".xy += get_sample_position(",
						          to_expression(builtin_sample_id_id), ") - 0.5;");
					});
				}
				break;
			case BuiltInInvocationId:
				// This is direct-mapped without multi-patch workgroups.
				if (!is_tesc_shader() || !msl_options.multi_patch_workgroup)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
					          ";");
				});
				break;
			case BuiltInPrimitiveId:
				// This is natively supported by fragment and tessellation evaluation shaders.
				// In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
				if (!is_tesc_shader() || !msl_options.multi_patch_workgroup)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
					          to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
					          ", spvIndirectParams[1] - 1);");
				});
				break;
			case BuiltInPatchVertices:
				if (is_tese_shader())
				{
					if (msl_options.raw_buffer_tese_input)
					{
						entry_func.fixup_hooks_in.push_back(
						    [=]() {
							    statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
							              get_entry_point().output_vertices, ";");
						    });
					}
					else
					{
						entry_func.fixup_hooks_in.push_back(
						    [=]()
						    {
							    statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
							              to_expression(patch_stage_in_var_id), ".gl_in.size();");
						    });
					}
				}
				else
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
					});
				}
				break;
			case BuiltInTessCoord:
				if (get_entry_point().flags.get(ExecutionModeQuads))
				{
					// The entry point will only have a float2 TessCoord variable.
					// Pad to float3.
					entry_func.fixup_hooks_in.push_back([=]() {
						auto name = builtin_to_glsl(BuiltInTessCoord, StorageClassInput);
						statement("float3 " + name + " = float3(" + name + "In.x, " + name + "In.y, 0.0);");
					});
				}

				// Emit a fixup to account for the shifted domain. Don't do this for triangles;
				// MoltenVK will just reverse the winding order instead.
				if (msl_options.tess_domain_origin_lower_left && !is_tessellating_triangles())
				{
					string tc = to_expression(var_id);
					entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
				}
				break;
			case BuiltInSubgroupId:
				if (!msl_options.emulate_subgroups)
					break;
				// For subgroup emulation, this is the same as the local invocation index.
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_local_invocation_index_id), ";");
				});
				break;
			case BuiltInNumSubgroups:
				if (!msl_options.emulate_subgroups)
					break;
				// For subgroup emulation, this is the same as the workgroup size.
				entry_func.fixup_hooks_in.push_back([=]() {
					auto &type = expression_type(builtin_workgroup_size_id);
					string size_expr = to_expression(builtin_workgroup_size_id);
					if (type.vecsize >= 3)
						size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
					else if (type.vecsize == 2)
						size_expr = join(size_expr, ".x * ", size_expr, ".y");
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
				});
				break;
			case BuiltInSubgroupLocalInvocationId:
				if (!msl_options.emulate_subgroups)
					break;
				// For subgroup emulation, assume subgroups of size 1.
				entry_func.fixup_hooks_in.push_back(
				    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
				break;
			case BuiltInSubgroupSize:
				if (msl_options.emulate_subgroups)
				{
					// For subgroup emulation, assume subgroups of size 1.
					entry_func.fixup_hooks_in.push_back(
					    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
				}
				else if (msl_options.fixed_subgroup_size != 0)
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          msl_options.fixed_subgroup_size, ";");
					});
				}
				break;
			case BuiltInSubgroupEqMask:
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					if (msl_options.is_ios())
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
					}
					else
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
						          to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
					}
				});
				break;
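			// Illustrative macOS expansion of the EqMask hook above (builtin names
			// depend on the shader):
			//   uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32
			//       ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0))
			//       : uint4(1 << gl_SubgroupInvocationID, uint3(0));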
			case BuiltInSubgroupGeMask:
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				if (msl_options.fixed_subgroup_size != 0)
					add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
				entry_func.fixup_hooks_in.push_back([=]() {
					// Case where index < 32, size < 32:
					// mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
					// mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
					// Case where index < 32 but size >= 32:
					// mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
					// mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
					// Case where index >= 32:
					// mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
					// mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
					// This is expressed without branches to avoid divergent
					// control flow--hence the complicated min/max expressions.
					// This is further complicated by the fact that if you attempt
					// to bfi/bfe out-of-bounds on Metal, undefined behavior is the
					// result.
					if (msl_options.fixed_subgroup_size > 32)
					{
						// Don't use the subgroup size variable with fixed subgroup sizes,
						// since the variables could be defined in the wrong order.
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
						          to_expression(builtin_subgroup_invocation_id_id),
						          ", 0)), insert_bits(0u, 0xFFFFFFFF,"
						          " (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
						          msl_options.fixed_subgroup_size, " - max(",
						          to_expression(builtin_subgroup_invocation_id_id),
						          ", 32u)), uint2(0));");
					}
					else if (msl_options.fixed_subgroup_size != 0)
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
						          to_expression(builtin_subgroup_invocation_id_id), ", ",
						          msl_options.fixed_subgroup_size, " - ",
						          to_expression(builtin_subgroup_invocation_id_id),
						          "), uint3(0));");
					}
					else if (msl_options.is_ios())
					{
						// On iOS, the SIMD-group size will currently never exceed 32.
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
						          to_expression(builtin_subgroup_invocation_id_id), ", ",
						          to_expression(builtin_subgroup_size_id), " - ",
						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
					}
					else
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
						          to_expression(builtin_subgroup_invocation_id_id),
						          ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
						          to_expression(builtin_subgroup_size_id), " - (int)max(",
						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
					}
				});
				break;
			case BuiltInSubgroupGtMask:
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
				entry_func.fixup_hooks_in.push_back([=]() {
					// The same logic applies here, except now the index is one
					// more than the subgroup invocation ID.
					if (msl_options.fixed_subgroup_size > 32)
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
						          to_expression(builtin_subgroup_invocation_id_id),
						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
						          msl_options.fixed_subgroup_size, " - max(",
						          to_expression(builtin_subgroup_invocation_id_id),
						          " + 1, 32u)), uint2(0));");
					}
					else if (msl_options.fixed_subgroup_size != 0)
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
						          msl_options.fixed_subgroup_size, " - ",
						          to_expression(builtin_subgroup_invocation_id_id),
						          " - 1), uint3(0));");
					}
					else if (msl_options.is_ios())
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
						          to_expression(builtin_subgroup_size_id), " - ",
						          to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
					}
					else
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
						          to_expression(builtin_subgroup_invocation_id_id),
						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
						          to_expression(builtin_subgroup_size_id), " - (int)max(",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
					}
				});
				break;
			case BuiltInSubgroupLeMask:
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
				entry_func.fixup_hooks_in.push_back([=]() {
					if (msl_options.is_ios())
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
					}
					else
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
						          to_expression(builtin_subgroup_invocation_id_id),
						          " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
					}
				});
				break;
			case BuiltInSubgroupLtMask:
				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
				entry_func.fixup_hooks_in.push_back([=]() {
					if (msl_options.is_ios())
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
					}
					else
					{
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
						          to_expression(builtin_subgroup_invocation_id_id),
						          ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
					}
				});
				break;
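			// The Le/Lt hooks above build prefix masks from extract_bits(); e.g. on
			// macOS (builtin names are illustrative):
			//   uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)),
			//       extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));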
			case BuiltInViewIndex:
				if (!msl_options.multiview)
				{
					// According to the Vulkan spec, when not running under a multiview
					// render pass, ViewIndex is 0.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
					});
				}
				else if (msl_options.view_index_from_device_index)
				{
					// In this case, we take the view index from that of the device we're running on.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          msl_options.device_index, ";");
					});
					// We actually don't want to set the render_target_array_index here.
					// Since every physical device is rendering a different view,
					// there's no need for layered rendering here.
				}
				else if (!msl_options.multiview_layered_rendering)
				{
					// In this case, the views are rendered one at a time. The view index, then,
					// is just the first part of the "view mask".
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(view_mask_buffer_id), "[0];");
					});
				}
				else if (get_execution_model() == ExecutionModelFragment)
				{
					// Because we adjusted the view index in the vertex shader, we have to
					// adjust it back here.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
					});
				}
				else if (get_execution_model() == ExecutionModelVertex)
				{
					// Metal provides no special support for multiview, so we smuggle
					// the view index in the instance index.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
						          " - ", to_expression(builtin_base_instance_id), ") % ",
						          to_expression(view_mask_buffer_id), "[1];");
						statement(to_expression(builtin_instance_idx_id), " = (",
						          to_expression(builtin_instance_idx_id), " - ",
						          to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
						          "[1] + ", to_expression(builtin_base_instance_id), ";");
					});
					// In addition to setting the variable itself, we also need to
					// set the render_target_array_index with it on output. We have to
					// offset this by the base view index, because Metal isn't in on
					// our little game here.
					entry_func.fixup_hooks_out.push_back([=]() {
						statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
						          to_expression(view_mask_buffer_id), "[0];");
					});
				}
				break;
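			// Worked example of the instance-index smuggling above: with a view-mask
			// buffer of { 0, 2 } (base view 0, two views), base instance 0 and raw
			// instance index 5, the shader sees ViewIndex = 0 + 5 % 2 = 1 and
			// gl_InstanceIndex = 5 / 2 + 0 = 2.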
			case BuiltInDeviceIndex:
				// Metal pipelines belong to the devices which create them, so we'll
				// need to create a MTLPipelineState for every MTLDevice in a grouped
				// VkDevice. We can assume, then, that the device index is constant.
				entry_func.fixup_hooks_in.push_back([=]() {
					statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          msl_options.device_index, ";");
				});
				break;
			case BuiltInWorkgroupId:
				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
					break;

				// The vkCmdDispatchBase() command lets the client set the base value
				// of WorkgroupId. Metal has no direct equivalent; we must make this
				// adjustment ourselves.
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
				});
				break;
			case BuiltInGlobalInvocationId:
				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
					break;

				// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
				// This needs to be adjusted too.
				entry_func.fixup_hooks_in.push_back([=]() {
					auto &execution = this->get_entry_point();
					uint32_t workgroup_size_id = execution.workgroup_size.constant;
					if (workgroup_size_id)
						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
						          " * ", to_expression(workgroup_size_id), ";");
					else
						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
						          " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
						          execution.workgroup_size.z, ");");
				});
				break;
			case BuiltInVertexId:
			case BuiltInVertexIndex:
				// This is direct-mapped normally.
				if (!msl_options.vertex_for_tessellation)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					builtin_declaration = true;
					switch (msl_options.vertex_index_type)
					{
					case Options::IndexType::None:
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(builtin_invocation_id_id), ".x + ",
						          to_expression(builtin_dispatch_base_id), ".x;");
						break;
					case Options::IndexType::UInt16:
					case Options::IndexType::UInt32:
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
						          "[", to_expression(builtin_invocation_id_id), ".x] + ",
						          to_expression(builtin_dispatch_base_id), ".x;");
						break;
					}
					builtin_declaration = false;
				});
				break;
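			// With an index buffer bound, the hook above emits roughly (names are
			// illustrative):
			//   uint gl_VertexIndex = indexBuffer[gl_GlobalInvocationID.x] + spvDispatchBase.x;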
			case BuiltInBaseVertex:
				// This is direct-mapped normally.
				if (!msl_options.vertex_for_tessellation)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_dispatch_base_id), ".x;");
				});
				break;
			case BuiltInInstanceId:
			case BuiltInInstanceIndex:
				// This is direct-mapped normally.
				if (!msl_options.vertex_for_tessellation)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					builtin_declaration = true;
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
					          ".y;");
					builtin_declaration = false;
				});
				break;
			case BuiltInBaseInstance:
				// This is direct-mapped normally.
				if (!msl_options.vertex_for_tessellation)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_dispatch_base_id), ".y;");
				});
				break;
			default:
				break;
			}
		}
		else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment &&
		         is_builtin_variable(var) && active_output_builtins.get(bi_type))
		{
			switch (bi_type)
			{
			case BuiltInSampleMask:
				if (has_additional_fixed_sample_mask())
				{
					// If the additional fixed sample mask was set, we need to adjust the sample_mask
					// output to reflect that. If the shader outputs the sample_mask itself too, we need
					// to AND the two masks to get the final one.
					string op_str = does_shader_write_sample_mask ? " &= " : " = ";
					entry_func.fixup_hooks_out.push_back([=]() {
						statement(to_expression(builtin_sample_mask_id), op_str, additional_fixed_sample_mask_str(), ";");
					});
				}
				break;
			case BuiltInFragDepth:
				if (msl_options.input_attachment_is_ds_attachment && !writes_to_depth)
				{
					entry_func.fixup_hooks_out.push_back([=]() {
						statement(to_expression(builtin_frag_depth_id), " = ", to_expression(builtin_frag_coord_id), ".z;");
					});
				}
				break;
			default:
				break;
			}
		}
	});
}

// Returns the Metal index of the resource of the specified type as used by the specified variable.
uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
{
	auto &execution = get_entry_point();
	auto &var_dec = ir.meta[var.self].decoration;
	auto &var_type = get<SPIRType>(var.basetype);
	uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
	uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;

	// If a matching binding has been specified, find and use it.
	auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });

	// Atomic helper buffers for image atomics need to use secondary bindings as well.
	bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
	                             basetype == SPIRType::AtomicCounter;

	auto resource_decoration =
	    use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;

	if (plane == 1)
		resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
	if (plane == 2)
		resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;

	if (itr != end(resource_bindings))
	{
		auto &remap = itr->second;
		remap.second = true;
		switch (basetype)
		{
		case SPIRType::Image:
			set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
			return remap.first.msl_texture + plane;
		case SPIRType::Sampler:
			set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
			return remap.first.msl_sampler;
		default:
			set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
			return remap.first.msl_buffer;
		}
	}

	// If we have already allocated an index, keep using it.
	if (has_extended_decoration(var.self, resource_decoration))
		return get_extended_decoration(var.self, resource_decoration);

	auto &type = get<SPIRType>(var.basetype);

	if (type_is_msl_framebuffer_fetch(type))
	{
		// Frame-buffer fetch gets its fallback resource index from the input attachment index,
		// which is then treated as color index.
		return get_decoration(var.self, DecorationInputAttachmentIndex);
	}
	else if (msl_options.enable_decoration_binding)
	{
		// Allow user to enable decoration binding.
		// If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
		if (has_decoration(var.self, DecorationBinding))
		{
			var_binding = get_decoration(var.self, DecorationBinding);
			// Avoid emitting sentinel bindings.
			if (var_binding < 0x80000000u)
				return var_binding;
		}
	}

	// If we did not explicitly remap, allocate bindings on demand.
	// We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.

	bool allocate_argument_buffer_ids = false;

	if (var.storage != StorageClassPushConstant)
		allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);

	uint32_t binding_stride = 1;
	for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
		binding_stride *= to_array_size_literal(type, i);

	// If a binding has not been specified, revert to incrementing resource indices.
	uint32_t resource_index;

	if (allocate_argument_buffer_ids)
	{
		// Allocate from a flat ID binding space.
		resource_index = next_metal_resource_ids[var_desc_set];
		next_metal_resource_ids[var_desc_set] += binding_stride;
	}
	else
	{
		if (is_var_runtime_size_array(var))
		{
			basetype = SPIRType::Struct;
			binding_stride = 1;
		}
		// Allocate from plain bindings which are allocated per resource type.
		switch (basetype)
		{
		case SPIRType::Image:
			resource_index = next_metal_resource_index_texture;
			next_metal_resource_index_texture += binding_stride;
			break;
		case SPIRType::Sampler:
			resource_index = next_metal_resource_index_sampler;
			next_metal_resource_index_sampler += binding_stride;
			break;
		default:
			resource_index = next_metal_resource_index_buffer;
			next_metal_resource_index_buffer += binding_stride;
			break;
		}
	}
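	// Worked example of the fallback allocation: an unmapped array of 4 textures
	// advances the texture counter by binding_stride = 4 and occupies [[texture(t)]]
	// through [[texture(t + 3)]], while a resource in an argument-buffer set draws
	// IDs from that set's flat counter instead.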

	set_extended_decoration(var.self, resource_decoration, resource_index);
	return resource_index;
}
14711
14712bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
14713{
14714 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
14715 msl_options.use_framebuffer_fetch_subpasses;
14716}
14717
14718const char *CompilerMSL::descriptor_address_space(uint32_t id, StorageClass storage, const char *plain_address_space) const
14719{
14720 if (msl_options.argument_buffers)
14721 {
14722 bool storage_class_is_descriptor = storage == StorageClassUniform ||
14723 storage == StorageClassStorageBuffer ||
14724 storage == StorageClassUniformConstant;
14725
14726 uint32_t desc_set = get_decoration(id, decoration: DecorationDescriptorSet);
14727 if (storage_class_is_descriptor && descriptor_set_is_argument_buffer(desc_set))
14728 {
14729 // An awkward case where we need to emit *more* address space declarations (yay!).
14730 // An example is where we pass down an array of buffer pointers to leaf functions.
14731 // It's a constant array containing pointers to constants.
14732 // The pointer array is always constant however. E.g.
14733 // device SSBO * constant (&array)[N].
14734 // const device SSBO * constant (&array)[N].
14735 // constant SSBO * constant (&array)[N].
14736			// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
14737			// we emit the buffer array on the stack instead, and that seems to work just fine.
14738
14739 // If the argument was marked as being in device address space, any pointer to member would
14740 // be const device, not constant.
14741 if (argument_buffer_device_storage_mask & (1u << desc_set))
14742 return "const device";
14743 else
14744 return "constant";
14745 }
14746 }
14747
14748 return plain_address_space;
14749}
14750
14751string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
14752{
14753 auto &var = get<SPIRVariable>(id: arg.id);
14754 auto &type = get_variable_data_type(var);
14755 auto &var_type = get<SPIRType>(id: arg.type);
14756 StorageClass type_storage = var_type.storage;
14757
14758 // If we need to modify the name of the variable, make sure we use the original variable.
14759 // Our alias is just a shadow variable.
14760 uint32_t name_id = var.self;
14761 if (arg.alias_global_variable && var.basevariable)
14762 name_id = var.basevariable;
14763
14764 bool constref = !arg.alias_global_variable && is_pointer(type: var_type) && arg.write_count == 0;
14765	// Framebuffer fetch is a plain value; const looks out of place, but it is not wrong.
14766 if (type_is_msl_framebuffer_fetch(type))
14767 constref = false;
14768 else if (type_storage == StorageClassUniformConstant)
14769 constref = true;
14770
14771 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
14772 type.basetype == SPIRType::Sampler;
14773 bool type_is_tlas = type.basetype == SPIRType::AccelerationStructure;
14774
14775 // For opaque types we handle const later due to descriptor address spaces.
14776 const char *cv_qualifier = (constref && !type_is_image) ? "const " : "";
14777 string decl;
14778
14779 // If this is a combined image-sampler for a 2D image with floating-point type,
14780 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
14781 // for a global, then we need to emit a "dynamic" combined image-sampler.
14782 // Unfortunately, this is necessary to properly support passing around
14783 // combined image-samplers with Y'CbCr conversions on them.
14784 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
14785 type.image.dim == Dim2D && type_is_floating_point(type: get<SPIRType>(id: type.image.type)) &&
14786 spv_function_implementations.count(x: SPVFuncImplDynamicImageSampler);
14787
14788 // Allow Metal to use the array<T> template to make arrays a value type
14789 string address_space = get_argument_address_space(argument: var);
14790 bool builtin = has_decoration(id: var.self, decoration: DecorationBuiltIn);
14791 auto builtin_type = BuiltIn(get_decoration(id: arg.id, decoration: DecorationBuiltIn));
14792
14793 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
14794 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
14795 else if (builtin)
14796 {
14797 // Only use templated array for Clip/Cull distance when feasible.
14798		// In other scenarios, we need to override the array length for tess levels (if used as outputs),
14799 // or we need to emit the expected type for builtins (uint vs int).
14800 auto storage = get<SPIRType>(id: var.basetype).storage;
14801
14802 if (storage == StorageClassInput &&
14803 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
14804 {
14805 is_using_builtin_array = false;
14806 }
14807 else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance)
14808 {
14809 is_using_builtin_array = true;
14810 }
14811
14812 if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) &&
14813 !is_stage_output_builtin_masked(builtin: builtin_type))
14814 is_using_builtin_array = true;
14815
14816 if (is_using_builtin_array)
14817 decl = join(ts&: cv_qualifier, ts: builtin_type_decl(builtin: builtin_type, id: arg.id));
14818 else
14819 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
14820 }
14821 else if (is_var_runtime_size_array(var))
14822 {
14823 const auto *parent_type = &get<SPIRType>(id: type.parent_type);
14824 auto type_name = type_to_glsl(type: *parent_type, id: arg.id);
14825 if (type.basetype == SPIRType::AccelerationStructure)
14826 decl = join(ts: "spvDescriptorArray<", ts&: type_name, ts: ">");
14827 else if (type_is_image)
14828 decl = join(ts: "spvDescriptorArray<", ts&: cv_qualifier, ts&: type_name, ts: ">");
14829 else
14830 decl = join(ts: "spvDescriptorArray<", ts&: address_space, ts: " ", ts&: type_name, ts: "*>");
14831 address_space = "const";
14832 }
14833 else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type))
14834 {
14835 is_using_builtin_array = true;
14836 decl += join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id), ts: "*");
14837 }
14838 else if (is_dynamic_img_sampler)
14839 {
14840 decl = join(ts&: cv_qualifier, ts: "spvDynamicImageSampler<", ts: type_to_glsl(type: get<SPIRType>(id: type.image.type)), ts: ">");
14841 // Mark the variable so that we can handle passing it to another function.
14842 set_extended_decoration(id: arg.id, decoration: SPIRVCrossDecorationDynamicImageSampler);
14843 }
14844 else
14845 {
14846		// The type is a pointer type, so we need to emit the cv_qualifier late.
14847 if (is_pointer(type))
14848 {
14849 decl = type_to_glsl(type, id: arg.id);
14850 if (*cv_qualifier != '\0')
14851 decl += join(ts: " ", ts&: cv_qualifier);
14852 }
14853 else
14854 {
14855 decl = join(ts&: cv_qualifier, ts: type_to_glsl(type, id: arg.id));
14856 }
14857 }
14858
14859 if (!builtin && !is_pointer(type: var_type) &&
14860 (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
14861 {
14862 // If the argument is a pure value and not an opaque type, we will pass by value.
14863 if (msl_options.force_native_arrays && is_array(type))
14864 {
14865 // We are receiving an array by value. This is problematic.
14866 // We cannot be sure of the target address space since we are supposed to receive a copy,
14867 // but this is not possible with MSL without some extra work.
14868 // We will have to assume we're getting a reference in thread address space.
14869 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
14870 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
14871 // non-constant arrays, but we can create thread const from constant.
14872 decl = string("thread const ") + decl;
14873 decl += " (&";
14874 const char *restrict_kw = to_restrict(id: name_id, space: true);
14875 if (*restrict_kw)
14876 {
14877 decl += " ";
14878 decl += restrict_kw;
14879 }
14880 decl += to_expression(id: name_id);
14881 decl += ")";
14882 decl += type_to_array_glsl(type, variable_id: name_id);
14883 }
14884 else
14885 {
14886 if (!address_space.empty())
14887 decl = join(ts&: address_space, ts: " ", ts&: decl);
14888 decl += " ";
14889 decl += to_expression(id: name_id);
14890 }
14891 }
14892 else if (is_array(type) && !type_is_image)
14893 {
14894 // Arrays of opaque types are special cased.
14895 if (!address_space.empty())
14896 decl = join(ts&: address_space, ts: " ", ts&: decl);
14897
14898 // spvDescriptorArray absorbs the address space inside the template.
14899 if (!is_var_runtime_size_array(var))
14900 {
14901 const char *argument_buffer_space = descriptor_address_space(id: name_id, storage: type_storage, plain_address_space: nullptr);
14902 if (argument_buffer_space)
14903 {
14904 decl += " ";
14905 decl += argument_buffer_space;
14906 }
14907 }
14908
14909 // Special case, need to override the array size here if we're using tess level as an argument.
14910 if (is_tesc_shader() && builtin &&
14911 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
14912 {
14913 uint32_t array_size = get_physical_tess_level_array_size(builtin: builtin_type);
14914 if (array_size == 1)
14915 {
14916 decl += " &";
14917 decl += to_expression(id: name_id);
14918 }
14919 else
14920 {
14921 decl += " (&";
14922 decl += to_expression(id: name_id);
14923 decl += ")";
14924 decl += join(ts: "[", ts&: array_size, ts: "]");
14925 }
14926 }
14927 else if (is_var_runtime_size_array(var))
14928 {
14929 decl += " " + to_expression(id: name_id);
14930 }
14931 else
14932 {
14933 auto array_size_decl = type_to_array_glsl(type, variable_id: name_id);
14934 if (array_size_decl.empty())
14935 decl += "& ";
14936 else
14937 decl += " (&";
14938
14939 const char *restrict_kw = to_restrict(id: name_id, space: true);
14940 if (*restrict_kw)
14941 {
14942 decl += " ";
14943 decl += restrict_kw;
14944 }
14945 decl += to_expression(id: name_id);
14946
14947 if (!array_size_decl.empty())
14948 {
14949 decl += ")";
14950 decl += array_size_decl;
14951 }
14952 }
14953 }
14954 else if (!type_is_image && !type_is_tlas &&
14955 (!pull_model_inputs.count(x: var.basevariable) || type.basetype == SPIRType::Struct))
14956 {
14957 // If this is going to be a reference to a variable pointer, the address space
14958 // for the reference has to go before the '&', but after the '*'.
14959 if (!address_space.empty())
14960 {
14961 if (is_pointer(type))
14962 {
14963 if (*cv_qualifier == '\0')
14964 decl += ' ';
14965 decl += join(ts&: address_space, ts: " ");
14966 }
14967 else
14968 decl = join(ts&: address_space, ts: " ", ts&: decl);
14969 }
14970 decl += "&";
14971 decl += " ";
14972 decl += to_restrict(id: name_id, space: true);
14973 decl += to_expression(id: name_id);
14974 }
14975 else if (type_is_image || type_is_tlas)
14976 {
14977 if (is_var_runtime_size_array(var))
14978 {
14979 decl = address_space + " " + decl + " " + to_expression(id: name_id);
14980 }
14981 else if (type.array.empty())
14982 {
14983 // For non-arrayed types we can just pass opaque descriptors by value.
14984 // This fixes problems if descriptors are passed by value from argument buffers and plain descriptors
14985 // in same shader.
14986 // There is no address space we can actually use, but value will work.
14987 // This will break if applications attempt to pass down descriptor arrays as arguments, but
14988 // fortunately that is extremely unlikely ...
14989 decl += " ";
14990 decl += to_expression(id: name_id);
14991 }
14992 else
14993 {
14994 const char *img_address_space = descriptor_address_space(id: name_id, storage: type_storage, plain_address_space: "thread const");
14995 decl = join(ts&: img_address_space, ts: " ", ts&: decl);
14996 decl += "& ";
14997 decl += to_expression(id: name_id);
14998 }
14999 }
15000 else
15001 {
15002 if (!address_space.empty())
15003 decl = join(ts&: address_space, ts: " ", ts&: decl);
15004 decl += " ";
15005 decl += to_expression(id: name_id);
15006 }
15007
15008 // Emulate texture2D atomic operations
15009 auto *backing_var = maybe_get_backing_variable(chain: name_id);
15010 if (backing_var && atomic_image_vars_emulated.count(x: backing_var->self))
15011 {
15012 auto &flags = ir.get_decoration_bitset(id: backing_var->self);
15013 const char *cv_flags = decoration_flags_signal_volatile(flags) ? "volatile " : "";
15014 decl += join(ts: ", ", ts&: cv_flags, ts: "device atomic_", ts: type_to_glsl(type: get<SPIRType>(id: var_type.image.type), id: 0));
15015 decl += "* " + to_expression(id: name_id) + "_atomic";
15016 }
15017
15018 is_using_builtin_array = false;
15019
15020 return decl;
15021}
15022
15023// If we're currently in the entry point function, and the object
15024// has a qualified name, use it; otherwise use the standard name.
15025string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
15026{
15027 if (current_function && (current_function->self == ir.default_entry_point))
15028 {
15029 auto *m = ir.find_meta(id);
15030 if (m && !m->decoration.qualified_alias_explicit_override && !m->decoration.qualified_alias.empty())
15031 return m->decoration.qualified_alias;
15032 }
15033 return Compiler::to_name(id, allow_alias);
15034}
15035
15036// Appends the name of the member to the variable qualifier string, except for Builtins.
15037string CompilerMSL::append_member_name(const string &qualifier, const SPIRType &type, uint32_t index)
15038{
15039 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
15040 BuiltIn builtin = BuiltInMax;
15041 if (is_member_builtin(type, index, builtin: &builtin))
15042 return builtin_to_glsl(builtin, storage: type.storage);
15043
15044 // Strip any underscore prefix from member name
15045 string mbr_name = to_member_name(type, index);
15046 size_t startPos = mbr_name.find_first_not_of(s: "_");
15047 mbr_name = (startPos != string::npos) ? mbr_name.substr(pos: startPos) : "";
15048 return join(ts: qualifier, ts: "_", ts&: mbr_name);
15049}
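// E.g. (illustrative): with qualifier "in" and a member named "_m3",
// the stripped name is "m3" and the result is "in_m3".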
15050
15051// Ensures that the specified name is permanently usable by prepending a prefix
15052// if the first characters are an underscore followed by a digit, which indicates a transient name.
15053string CompilerMSL::ensure_valid_name(string name, string pfx)
15054{
15055 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
15056}
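// E.g. (illustrative): ensure_valid_name("_12", "m") returns "m_12",
// while "_m12" and "foo" are returned unchanged.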
15057
15058const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
15059{
15060 static const unordered_set<string> keywords = {
15061 "kernel",
15062 "vertex",
15063 "fragment",
15064 "compute",
15065 "constant",
15066 "device",
15067 "bias",
15068 "level",
15069 "gradient2d",
15070 "gradientcube",
15071 "gradient3d",
15072 "min_lod_clamp",
15073 "assert",
15074 "VARIABLE_TRACEPOINT",
15075 "STATIC_DATA_TRACEPOINT",
15076 "STATIC_DATA_TRACEPOINT_V",
15077 "METAL_ALIGN",
15078 "METAL_ASM",
15079 "METAL_CONST",
15080 "METAL_DEPRECATED",
15081 "METAL_ENABLE_IF",
15082 "METAL_FUNC",
15083 "METAL_INTERNAL",
15084 "METAL_NON_NULL_RETURN",
15085 "METAL_NORETURN",
15086 "METAL_NOTHROW",
15087 "METAL_PURE",
15088 "METAL_UNAVAILABLE",
15089 "METAL_IMPLICIT",
15090 "METAL_EXPLICIT",
15091 "METAL_CONST_ARG",
15092 "METAL_ARG_UNIFORM",
15093 "METAL_ZERO_ARG",
15094 "METAL_VALID_LOD_ARG",
15095 "METAL_VALID_LEVEL_ARG",
15096 "METAL_VALID_STORE_ORDER",
15097 "METAL_VALID_LOAD_ORDER",
15098 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
15099 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
15100 "METAL_VALID_RENDER_TARGET",
15101 "is_function_constant_defined",
15102 "CHAR_BIT",
15103 "SCHAR_MAX",
15104 "SCHAR_MIN",
15105 "UCHAR_MAX",
15106 "CHAR_MAX",
15107 "CHAR_MIN",
15108 "USHRT_MAX",
15109 "SHRT_MAX",
15110 "SHRT_MIN",
15111 "UINT_MAX",
15112 "INT_MAX",
15113 "INT_MIN",
15114 "FLT_DIG",
15115 "FLT_MANT_DIG",
15116 "FLT_MAX_10_EXP",
15117 "FLT_MAX_EXP",
15118 "FLT_MIN_10_EXP",
15119 "FLT_MIN_EXP",
15120 "FLT_RADIX",
15121 "FLT_MAX",
15122 "FLT_MIN",
15123 "FLT_EPSILON",
15124 "FP_ILOGB0",
15125 "FP_ILOGBNAN",
15126 "MAXFLOAT",
15127 "HUGE_VALF",
15128 "INFINITY",
15129 "NAN",
15130 "M_E_F",
15131 "M_LOG2E_F",
15132 "M_LOG10E_F",
15133 "M_LN2_F",
15134 "M_LN10_F",
15135 "M_PI_F",
15136 "M_PI_2_F",
15137 "M_PI_4_F",
15138 "M_1_PI_F",
15139 "M_2_PI_F",
15140 "M_2_SQRTPI_F",
15141 "M_SQRT2_F",
15142 "M_SQRT1_2_F",
15143 "HALF_DIG",
15144 "HALF_MANT_DIG",
15145 "HALF_MAX_10_EXP",
15146 "HALF_MAX_EXP",
15147 "HALF_MIN_10_EXP",
15148 "HALF_MIN_EXP",
15149 "HALF_RADIX",
15150 "HALF_MAX",
15151 "HALF_MIN",
15152 "HALF_EPSILON",
15153 "MAXHALF",
15154 "HUGE_VALH",
15155 "M_E_H",
15156 "M_LOG2E_H",
15157 "M_LOG10E_H",
15158 "M_LN2_H",
15159 "M_LN10_H",
15160 "M_PI_H",
15161 "M_PI_2_H",
15162 "M_PI_4_H",
15163 "M_1_PI_H",
15164 "M_2_PI_H",
15165 "M_2_SQRTPI_H",
15166 "M_SQRT2_H",
15167 "M_SQRT1_2_H",
15168 "DBL_DIG",
15169 "DBL_MANT_DIG",
15170 "DBL_MAX_10_EXP",
15171 "DBL_MAX_EXP",
15172 "DBL_MIN_10_EXP",
15173 "DBL_MIN_EXP",
15174 "DBL_RADIX",
15175 "DBL_MAX",
15176 "DBL_MIN",
15177 "DBL_EPSILON",
15178 "HUGE_VAL",
15179 "M_E",
15180 "M_LOG2E",
15181 "M_LOG10E",
15182 "M_LN2",
15183 "M_LN10",
15184 "M_PI",
15185 "M_PI_2",
15186 "M_PI_4",
15187 "M_1_PI",
15188 "M_2_PI",
15189 "M_2_SQRTPI",
15190 "M_SQRT2",
15191 "M_SQRT1_2",
15192 "quad_broadcast",
15193 "thread",
15194 "threadgroup",
15195 };
15196
15197 return keywords;
15198}
15199
15200const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
15201{
15202 static const unordered_set<string> illegal_func_names = {
15203 "main",
15204 "saturate",
15205 "assert",
15206 "fmin3",
15207 "fmax3",
15208 "divide",
15209 "median3",
15210 "VARIABLE_TRACEPOINT",
15211 "STATIC_DATA_TRACEPOINT",
15212 "STATIC_DATA_TRACEPOINT_V",
15213 "METAL_ALIGN",
15214 "METAL_ASM",
15215 "METAL_CONST",
15216 "METAL_DEPRECATED",
15217 "METAL_ENABLE_IF",
15218 "METAL_FUNC",
15219 "METAL_INTERNAL",
15220 "METAL_NON_NULL_RETURN",
15221 "METAL_NORETURN",
15222 "METAL_NOTHROW",
15223 "METAL_PURE",
15224 "METAL_UNAVAILABLE",
15225 "METAL_IMPLICIT",
15226 "METAL_EXPLICIT",
15227 "METAL_CONST_ARG",
15228 "METAL_ARG_UNIFORM",
15229 "METAL_ZERO_ARG",
15230 "METAL_VALID_LOD_ARG",
15231 "METAL_VALID_LEVEL_ARG",
15232 "METAL_VALID_STORE_ORDER",
15233 "METAL_VALID_LOAD_ORDER",
15234 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
15235 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
15236 "METAL_VALID_RENDER_TARGET",
15237 "is_function_constant_defined",
15238 "CHAR_BIT",
15239 "SCHAR_MAX",
15240 "SCHAR_MIN",
15241 "UCHAR_MAX",
15242 "CHAR_MAX",
15243 "CHAR_MIN",
15244 "USHRT_MAX",
15245 "SHRT_MAX",
15246 "SHRT_MIN",
15247 "UINT_MAX",
15248 "INT_MAX",
15249 "INT_MIN",
15250 "FLT_DIG",
15251 "FLT_MANT_DIG",
15252 "FLT_MAX_10_EXP",
15253 "FLT_MAX_EXP",
15254 "FLT_MIN_10_EXP",
15255 "FLT_MIN_EXP",
15256 "FLT_RADIX",
15257 "FLT_MAX",
15258 "FLT_MIN",
15259 "FLT_EPSILON",
15260 "FP_ILOGB0",
15261 "FP_ILOGBNAN",
15262 "MAXFLOAT",
15263 "HUGE_VALF",
15264 "INFINITY",
15265 "NAN",
15266 "M_E_F",
15267 "M_LOG2E_F",
15268 "M_LOG10E_F",
15269 "M_LN2_F",
15270 "M_LN10_F",
15271 "M_PI_F",
15272 "M_PI_2_F",
15273 "M_PI_4_F",
15274 "M_1_PI_F",
15275 "M_2_PI_F",
15276 "M_2_SQRTPI_F",
15277 "M_SQRT2_F",
15278 "M_SQRT1_2_F",
15279 "HALF_DIG",
15280 "HALF_MANT_DIG",
15281 "HALF_MAX_10_EXP",
15282 "HALF_MAX_EXP",
15283 "HALF_MIN_10_EXP",
15284 "HALF_MIN_EXP",
15285 "HALF_RADIX",
15286 "HALF_MAX",
15287 "HALF_MIN",
15288 "HALF_EPSILON",
15289 "MAXHALF",
15290 "HUGE_VALH",
15291 "M_E_H",
15292 "M_LOG2E_H",
15293 "M_LOG10E_H",
15294 "M_LN2_H",
15295 "M_LN10_H",
15296 "M_PI_H",
15297 "M_PI_2_H",
15298 "M_PI_4_H",
15299 "M_1_PI_H",
15300 "M_2_PI_H",
15301 "M_2_SQRTPI_H",
15302 "M_SQRT2_H",
15303 "M_SQRT1_2_H",
15304 "DBL_DIG",
15305 "DBL_MANT_DIG",
15306 "DBL_MAX_10_EXP",
15307 "DBL_MAX_EXP",
15308 "DBL_MIN_10_EXP",
15309 "DBL_MIN_EXP",
15310 "DBL_RADIX",
15311 "DBL_MAX",
15312 "DBL_MIN",
15313 "DBL_EPSILON",
15314 "HUGE_VAL",
15315 "M_E",
15316 "M_LOG2E",
15317 "M_LOG10E",
15318 "M_LN2",
15319 "M_LN10",
15320 "M_PI",
15321 "M_PI_2",
15322 "M_PI_4",
15323 "M_1_PI",
15324 "M_2_PI",
15325 "M_2_SQRTPI",
15326 "M_SQRT2",
15327 "M_SQRT1_2",
15328 };
15329
15330 return illegal_func_names;
15331}
15332
15333// Replace all names that match MSL keywords or Metal Standard Library functions.
15334void CompilerMSL::replace_illegal_names()
15335{
15336 // FIXME: MSL and GLSL are doing two different things here.
15337 // Agree on convention and remove this override.
15338 auto &keywords = get_reserved_keyword_set();
15339 auto &illegal_func_names = get_illegal_func_names();
15340
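	// E.g. (illustrative): a variable named "kernel" or a function named
	// "saturate" is renamed by appending "0", giving "kernel0" and "saturate0".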
15341 ir.for_each_typed_id<SPIRVariable>(op: [&](uint32_t self, SPIRVariable &) {
15342 auto *meta = ir.find_meta(id: self);
15343 if (!meta)
15344 return;
15345
15346 auto &dec = meta->decoration;
15347 if (keywords.find(x: dec.alias) != end(cont: keywords))
15348 dec.alias += "0";
15349 });
15350
15351 ir.for_each_typed_id<SPIRFunction>(op: [&](uint32_t self, SPIRFunction &) {
15352 auto *meta = ir.find_meta(id: self);
15353 if (!meta)
15354 return;
15355
15356 auto &dec = meta->decoration;
15357 if (illegal_func_names.find(x: dec.alias) != end(cont: illegal_func_names))
15358 dec.alias += "0";
15359 });
15360
15361 ir.for_each_typed_id<SPIRType>(op: [&](uint32_t self, SPIRType &) {
15362 auto *meta = ir.find_meta(id: self);
15363 if (!meta)
15364 return;
15365
15366 for (auto &mbr_dec : meta->members)
15367 if (keywords.find(x: mbr_dec.alias) != end(cont: keywords))
15368 mbr_dec.alias += "0";
15369 });
15370
15371 CompilerGLSL::replace_illegal_names();
15372}
15373
15374void CompilerMSL::replace_illegal_entry_point_names()
15375{
15376 auto &illegal_func_names = get_illegal_func_names();
15377
15378	// It is important to do this before we fix up identifiers,
15379 // since if ep_name is reserved, we will need to fix that up,
15380 // and then copy alias back into entry.name after the fixup.
15381 for (auto &entry : ir.entry_points)
15382 {
15383 // Change both the entry point name and the alias, to keep them synced.
15384 string &ep_name = entry.second.name;
15385 if (illegal_func_names.find(x: ep_name) != end(cont: illegal_func_names))
15386 ep_name += "0";
15387
15388 ir.meta[entry.first].decoration.alias = ep_name;
15389 }
15390}
15391
15392void CompilerMSL::sync_entry_point_aliases_and_names()
15393{
15394 for (auto &entry : ir.entry_points)
15395 entry.second.name = ir.meta[entry.first].decoration.alias;
15396}
15397
15398string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain_is_resolved)
15399{
15400 auto *var = maybe_get_backing_variable(chain: base);
15401 // If this is a buffer array, we have to dereference the buffer pointers.
15402 // Otherwise, if this is a pointer expression, dereference it.
15403
15404 bool declared_as_pointer = false;
15405
15406 if (var)
15407 {
15408 // Only allow -> dereference for block types. This is so we get expressions like
15409 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
15410 const bool is_block =
15411 has_decoration(id: type.self, decoration: DecorationBlock) || has_decoration(id: type.self, decoration: DecorationBufferBlock);
15412
15413 bool is_buffer_variable =
15414 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
15415 declared_as_pointer = is_buffer_variable && is_array(type: get_pointee_type(type_id: var->basetype));
15416 }
15417
15418 if (declared_as_pointer || (!ptr_chain_is_resolved && should_dereference(id: base)))
15419 return join(ts: "->", ts: to_member_name(type, index));
15420 else
15421 return join(ts: ".", ts: to_member_name(type, index));
15422}
15423
15424string CompilerMSL::to_qualifiers_glsl(uint32_t id)
15425{
15426 string quals;
15427
15428 auto *var = maybe_get<SPIRVariable>(id);
15429 auto &type = expression_type(id);
15430
15431 if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(variable: *var, storage: StorageClassWorkgroup)))
15432 quals += "threadgroup ";
15433
15434 return quals;
15435}
15436
15437// The optional id parameter indicates the object whose type we are trying
15438// to find the description for. Most type descriptions do not depend on
15439// a specific object's use of that type.
15440string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member)
15441{
15442 string type_name;
15443
15444 // Pointer?
15445 if (is_pointer(type) || type_is_array_of_pointers(type))
15446 {
15447 assert(type.pointer_depth > 0);
15448
15449 const char *restrict_kw;
15450
15451 auto type_address_space = get_type_address_space(type, id);
15452 const auto *p_parent_type = &get<SPIRType>(id: type.parent_type);
15453
15454 // If we're wrapping buffer descriptors in a spvDescriptorArray, we'll have to handle it as a special case.
15455 if (member && id)
15456 {
15457 auto &var = get<SPIRVariable>(id);
15458 if (is_var_runtime_size_array(var) && is_runtime_size_array(type: *p_parent_type))
15459 {
15460 const bool ssbo = has_decoration(id: p_parent_type->self, decoration: DecorationBufferBlock);
15461 bool buffer_desc =
15462 (var.storage == StorageClassStorageBuffer || ssbo) &&
15463 msl_options.runtime_array_rich_descriptor;
15464
15465 const char *wrapper_type = buffer_desc ? "spvBufferDescriptor" : "spvDescriptor";
15466 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
15467 add_spv_func_and_recompile(spv_func: buffer_desc ? SPVFuncImplVariableSizedDescriptor : SPVFuncImplVariableDescriptor);
15468
15469 type_name = join(ts&: wrapper_type, ts: "<", ts&: type_address_space, ts: " ", ts: type_to_glsl(type: *p_parent_type, id), ts: " *>");
15470 return type_name;
15471 }
15472 }
15473
15474 // Work around C pointer qualifier rules. If glsl_type is a pointer type as well
15475 // we'll need to emit the address space to the right.
15476 // We could always go this route, but it makes the code unnatural.
15477 // Prefer emitting thread T *foo over T thread* foo since it's more readable,
15478		// but for pointer-to-pointer types we have to emit the address space to the right, e.g. thread T * constant *bar;
15479 if (is_pointer(type) && is_pointer(type: *p_parent_type))
15480 type_name = join(ts: type_to_glsl(type: *p_parent_type, id), ts: " ", ts&: type_address_space, ts: " ");
15481 else
15482 {
15483 // Since this is not a pointer-to-pointer, ensure we've dug down to the base type.
15484 // Some situations chain pointers even though they are not formally pointers-of-pointers.
15485 while (is_pointer(type: *p_parent_type))
15486 p_parent_type = &get<SPIRType>(id: p_parent_type->parent_type);
15487
15488 // If we're emitting BDA, just use the templated type.
15489			// Emitting builtin arrays needs a lot of cooperation with other code to ensure
15490 // the C-style nesting works right.
15491 // FIXME: This is somewhat of a hack.
15492 bool old_is_using_builtin_array = is_using_builtin_array;
15493 if (is_physical_pointer(type))
15494 is_using_builtin_array = false;
15495
15496 type_name = join(ts&: type_address_space, ts: " ", ts: type_to_glsl(type: *p_parent_type, id));
15497
15498 is_using_builtin_array = old_is_using_builtin_array;
15499 }
15500
15501 switch (type.basetype)
15502 {
15503 case SPIRType::Image:
15504 case SPIRType::SampledImage:
15505 case SPIRType::Sampler:
15506 // These are handles.
15507 break;
15508 default:
15509 // Anything else can be a raw pointer.
15510 type_name += "*";
15511 restrict_kw = to_restrict(id, space: false);
15512 if (*restrict_kw)
15513 {
15514 type_name += " ";
15515 type_name += restrict_kw;
15516 }
15517 break;
15518 }
15519 return type_name;
15520 }
15521
15522 switch (type.basetype)
15523 {
15524 case SPIRType::Struct:
15525 // Need OpName lookup here to get a "sensible" name for a struct.
15526 // Allow Metal to use the array<T> template to make arrays a value type
15527 type_name = to_name(id: type.self);
15528 break;
15529
15530 case SPIRType::Image:
15531 case SPIRType::SampledImage:
15532 return image_type_glsl(type, id, member);
15533
15534 case SPIRType::Sampler:
15535 return sampler_type(type, id, member);
15536
15537 case SPIRType::Void:
15538 return "void";
15539
15540 case SPIRType::AtomicCounter:
15541 return "atomic_uint";
15542
15543 case SPIRType::ControlPointArray:
15544 return join(ts: "patch_control_point<", ts: type_to_glsl(type: get<SPIRType>(id: type.parent_type), id), ts: ">");
15545
15546 case SPIRType::Interpolant:
15547 return join(ts: "interpolant<", ts: type_to_glsl(type: get<SPIRType>(id: type.parent_type), id), ts: ", interpolation::",
15548 ts: has_decoration(id: type.self, decoration: DecorationNoPerspective) ? "no_perspective" : "perspective", ts: ">");
15549
15550 // Scalars
15551 case SPIRType::Boolean:
15552 {
15553 auto *var = maybe_get_backing_variable(chain: id);
15554 if (var && var->basevariable)
15555 var = &get<SPIRVariable>(id: var->basevariable);
15556
15557 // Need to special-case threadgroup booleans. They are supposed to be logical
15558 // storage, but MSL compilers will sometimes crash if you use threadgroup bool.
15559		// Work around this by using 16-bit types instead and fixing up on load-store to this data.
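		// E.g. (illustrative): a threadgroup bool is declared as threadgroup
		// short, with conversions applied on every load and store.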
15560 if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup || member)
15561 type_name = "short";
15562 else
15563 type_name = "bool";
15564 break;
15565 }
15566
15567 case SPIRType::Char:
15568 case SPIRType::SByte:
15569 type_name = "char";
15570 break;
15571 case SPIRType::UByte:
15572 type_name = "uchar";
15573 break;
15574 case SPIRType::Short:
15575 type_name = "short";
15576 break;
15577 case SPIRType::UShort:
15578 type_name = "ushort";
15579 break;
15580 case SPIRType::Int:
15581 type_name = "int";
15582 break;
15583 case SPIRType::UInt:
15584 type_name = "uint";
15585 break;
15586 case SPIRType::Int64:
15587 if (!msl_options.supports_msl_version(major: 2, minor: 2))
15588 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
15589 type_name = "long";
15590 break;
15591 case SPIRType::UInt64:
15592 if (!msl_options.supports_msl_version(major: 2, minor: 2))
15593 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
15594 type_name = "ulong";
15595 break;
15596 case SPIRType::Half:
15597 type_name = "half";
15598 break;
15599 case SPIRType::Float:
15600 type_name = "float";
15601 break;
15602 case SPIRType::Double:
15603 type_name = "double"; // Currently unsupported
15604 break;
15605 case SPIRType::AccelerationStructure:
15606 if (msl_options.supports_msl_version(major: 2, minor: 4))
15607 type_name = "raytracing::acceleration_structure<raytracing::instancing>";
15608 else if (msl_options.supports_msl_version(major: 2, minor: 3))
15609 type_name = "raytracing::instance_acceleration_structure";
15610 else
15611 SPIRV_CROSS_THROW("Acceleration Structure Type is supported in MSL 2.3 and above.");
15612 break;
15613 case SPIRType::RayQuery:
15614 return "raytracing::intersection_query<raytracing::instancing, raytracing::triangle_data>";
15615
15616 default:
15617 return "unknown_type";
15618 }
15619
15620 // Matrix?
15621 if (type.columns > 1)
15622 {
15623 auto *var = maybe_get_backing_variable(chain: id);
15624 if (var && var->basevariable)
15625 var = &get<SPIRVariable>(id: var->basevariable);
15626
15627 // Need to special-case threadgroup matrices. Due to an oversight, Metal's
15628 // matrix struct prior to Metal 3 lacks constructors in the threadgroup AS,
15629 // preventing us from default-constructing or initializing matrices in threadgroup storage.
15630 // Work around this by using our own type as storage.
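		// E.g. (illustrative): a threadgroup float3x3 is declared as
		// spvStorage_float3x3 when targeting MSL versions before 3.0.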
15631 if (((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup) &&
15632 !msl_options.supports_msl_version(major: 3, minor: 0))
15633 {
15634 add_spv_func_and_recompile(spv_func: SPVFuncImplStorageMatrix);
15635 type_name = "spvStorage_" + type_name;
15636 }
15637
15638 type_name += to_string(val: type.columns) + "x";
15639 }
15640
15641 // Vector or Matrix?
15642 if (type.vecsize > 1)
15643 type_name += to_string(val: type.vecsize);
15644
15645 if (type.array.empty() || using_builtin_array())
15646 {
15647 return type_name;
15648 }
15649 else
15650 {
15651 // Allow Metal to use the array<T> template to make arrays a value type
15652 add_spv_func_and_recompile(spv_func: SPVFuncImplUnsafeArray);
15653 string res;
15654 string sizes;
15655
15656 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
15657 {
15658 res += "spvUnsafeArray<";
15659 sizes += ", ";
15660 sizes += to_array_size(type, index: i);
15661 sizes += ">";
15662 }
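		// E.g. (illustrative): for float foo[2], the loop above builds
		// res = "spvUnsafeArray<" and sizes = ", 2>", giving "spvUnsafeArray<float, 2>".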
15663
15664 res += type_name + sizes;
15665 return res;
15666 }
15667}
15668
15669string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
15670{
15671 return type_to_glsl(type, id, member: false);
15672}
15673
15674string CompilerMSL::type_to_array_glsl(const SPIRType &type, uint32_t variable_id)
15675{
15676 // Allow Metal to use the array<T> template to make arrays a value type
15677 switch (type.basetype)
15678 {
15679 case SPIRType::AtomicCounter:
15680 case SPIRType::ControlPointArray:
15681 case SPIRType::RayQuery:
15682 return CompilerGLSL::type_to_array_glsl(type, variable_id);
15683
15684 default:
15685 if (type_is_array_of_pointers(type) || using_builtin_array())
15686 {
15687 const SPIRVariable *var = variable_id ? &get<SPIRVariable>(id: variable_id) : nullptr;
15688 if (var && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer) &&
15689 is_array(type: get_variable_data_type(var: *var)))
15690 {
15691 return join(ts: "[", ts: get_resource_array_size(type, id: variable_id), ts: "]");
15692 }
15693 else
15694 return CompilerGLSL::type_to_array_glsl(type, variable_id);
15695 }
15696 else
15697 return "";
15698 }
15699}
15700
15701string CompilerMSL::constant_op_expression(const SPIRConstantOp &cop)
15702{
15703 switch (cop.opcode)
15704 {
15705 case OpQuantizeToF16:
15706 add_spv_func_and_recompile(spv_func: SPVFuncImplQuantizeToF16);
15707 return join(ts: "spvQuantizeToF16(", ts: to_expression(id: cop.arguments[0]), ts: ")");
15708 default:
15709 return CompilerGLSL::constant_op_expression(cop);
15710 }
15711}
15712
15713bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const
15714{
15715 if (variable.storage == storage)
15716 return true;
15717
15718 if (storage == StorageClassWorkgroup)
15719 {
15720 // Specially masked IO block variable.
15721 // Normally, we will never access IO blocks directly here.
15722		// The only scenario in which that should occur is with a masked IO block.
15723 if (is_tesc_shader() && variable.storage == StorageClassOutput &&
15724 has_decoration(id: get<SPIRType>(id: variable.basetype).self, decoration: DecorationBlock))
15725 {
15726 return true;
15727 }
15728
15729 return variable.storage == StorageClassOutput && is_tesc_shader() && is_stage_output_variable_masked(var: variable);
15730 }
15731 else if (storage == StorageClassStorageBuffer)
15732 {
15733 // These builtins are passed directly; we don't want to use remapping
15734 // for them.
15735 auto builtin = (BuiltIn)get_decoration(id: variable.self, decoration: DecorationBuiltIn);
15736 if (is_tese_shader() && is_builtin_variable(var: variable) && (builtin == BuiltInTessCoord || builtin == BuiltInPrimitiveId))
15737 return false;
15738
15739 // We won't be able to catch writes to control point outputs here since variable
15740 // refers to a function local pointer.
15741 // This is fine, as there cannot be concurrent writers to that memory anyways,
15742 // so we just ignore that case.
15743
15744 return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) &&
15745 !variable_storage_requires_stage_io(storage: variable.storage) &&
15746 (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(var: variable));
15747 }
15748 else
15749 {
15750 return false;
15751 }
15752}
15753
15754// GCC workaround for lambdas calling protected functions
15755std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
15756{
15757 return CompilerGLSL::variable_decl(type, name, id);
15758}
15759
15760std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member)
15761{
15762 auto *var = maybe_get<SPIRVariable>(id);
15763 if (var && var->basevariable)
15764 {
15765 // Check against the base variable, and not a fake ID which might have been generated for this variable.
15766 id = var->basevariable;
15767 }
15768
15769 if (!type.array.empty())
15770 {
15771 if (!msl_options.supports_msl_version(major: 2))
15772 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
15773
15774 if (type.array.size() > 1)
15775 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
15776
15777		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax à la C++11 std::array.
15778 // If we have a runtime array, it could be a variable-count descriptor set binding.
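		// E.g. (illustrative): a binding declared as sampler s[4] becomes
		// "array<sampler, 4>" below, while a runtime-sized array takes the
		// spvDescriptor wrapper path instead.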
15779 auto &parent = get<SPIRType>(id: get_pointee_type(type).parent_type);
15780 uint32_t array_size = get_resource_array_size(type, id);
15781
15782 if (array_size == 0)
15783 {
15784 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
15785 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
15786
15787 const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray";
15788 if (member)
15789 descriptor_wrapper = "spvDescriptor";
15790 return join(ts&: descriptor_wrapper, ts: "<", ts: sampler_type(type: parent, id, member: false), ts: ">",
15791 ts: processing_entry_point ? "*" : "");
15792 }
15793 else
15794 {
15795 return join(ts: "array<", ts: sampler_type(type: parent, id, member: false), ts: ", ", ts&: array_size, ts: ">");
15796 }
15797 }
15798 else
15799 return "sampler";
15800}
15801
15802// Returns an MSL string describing the SPIR-V image type
15803string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id, bool member)
15804{
15805 auto *var = maybe_get<SPIRVariable>(id);
15806 if (var && var->basevariable)
15807 {
15808 // For comparison images, check against the base variable,
15809 // and not the fake ID which might have been generated for this variable.
15810 id = var->basevariable;
15811 }
15812
15813 if (!type.array.empty())
15814 {
15815 uint32_t major = 2, minor = 0;
15816 if (msl_options.is_ios())
15817 {
15818 major = 1;
15819 minor = 2;
15820 }
15821 if (!msl_options.supports_msl_version(major, minor))
15822 {
15823 if (msl_options.is_ios())
15824 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
15825 else
15826 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
15827 }
15828
15829 if (type.array.size() > 1)
15830 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
15831
15832		// Arrays of images in MSL must be declared with a special array<T, N> syntax à la C++11 std::array.
15833 // If we have a runtime array, it could be a variable-count descriptor set binding.
15834 auto &parent = get<SPIRType>(id: get_pointee_type(type).parent_type);
15835 uint32_t array_size = get_resource_array_size(type, id);
15836
15837 if (array_size == 0)
15838 {
15839 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptor);
15840 add_spv_func_and_recompile(spv_func: SPVFuncImplVariableDescriptorArray);
15841 const char *descriptor_wrapper = processing_entry_point ? "const device spvDescriptor" : "const spvDescriptorArray";
15842 if (member)
15843 {
15844 descriptor_wrapper = "spvDescriptor";
15845 // This requires a specialized wrapper type that packs image and sampler side by side.
15846 // It is possible in theory.
15847 if (type.basetype == SPIRType::SampledImage)
15848 SPIRV_CROSS_THROW("Argument buffer runtime array currently not supported for combined image sampler.");
15849 }
15850 return join(ts&: descriptor_wrapper, ts: "<", ts: image_type_glsl(type: parent, id, member: false), ts: ">",
15851 ts: processing_entry_point ? "*" : "");
15852 }
15853 else
15854 {
15855 return join(ts: "array<", ts: image_type_glsl(type: parent, id, member: false), ts: ", ", ts&: array_size, ts: ">");
15856 }
15857 }
15858
15859 string img_type_name;
15860
15861 auto &img_type = type.image;
15862
15863 if (is_depth_image(type, id))
15864 {
15865 switch (img_type.dim)
15866 {
15867 case Dim1D:
15868 case Dim2D:
15869 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
15870 {
15871 // Use a native Metal 1D texture
15872 img_type_name += "depth1d_unsupported_by_metal";
15873 break;
15874 }
15875
15876 if (img_type.ms && img_type.arrayed)
15877 {
15878 if (!msl_options.supports_msl_version(major: 2, minor: 1))
15879 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
15880 img_type_name += "depth2d_ms_array";
15881 }
15882 else if (img_type.ms)
15883 img_type_name += "depth2d_ms";
15884 else if (img_type.arrayed)
15885 img_type_name += "depth2d_array";
15886 else
15887 img_type_name += "depth2d";
15888 break;
15889 case Dim3D:
15890 img_type_name += "depth3d_unsupported_by_metal";
15891 break;
15892 case DimCube:
15893 if (!msl_options.emulate_cube_array)
15894 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
15895 else
15896 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
15897 break;
15898 default:
15899 img_type_name += "unknown_depth_texture_type";
15900 break;
15901 }
15902 }
15903 else
15904 {
15905 switch (img_type.dim)
15906 {
15907 case DimBuffer:
15908 if (img_type.ms || img_type.arrayed)
15909 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
15910
15911 if (msl_options.texture_buffer_native)
15912 {
15913 if (!msl_options.supports_msl_version(major: 2, minor: 1))
15914 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
15915 img_type_name = "texture_buffer";
15916 }
15917 else
15918 img_type_name += "texture2d";
15919 break;
15920 case Dim1D:
15921 case Dim2D:
15922 case DimSubpassData:
15923 {
15924 bool subpass_array =
15925 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
15926 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
15927 {
15928 // Use a native Metal 1D texture
15929 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
15930 break;
15931 }
15932
15933 // Use Metal's native frame-buffer fetch API for subpass inputs.
15934 if (type_is_msl_framebuffer_fetch(type))
15935 {
15936 auto img_type_4 = get<SPIRType>(id: img_type.type);
15937 img_type_4.vecsize = 4;
15938 return type_to_glsl(type: img_type_4);
15939 }
15940 if (img_type.ms && (img_type.arrayed || subpass_array))
15941 {
15942 if (!msl_options.supports_msl_version(major: 2, minor: 1))
15943 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
15944 img_type_name += "texture2d_ms_array";
15945 }
15946 else if (img_type.ms)
15947 img_type_name += "texture2d_ms";
15948 else if (img_type.arrayed || subpass_array)
15949 img_type_name += "texture2d_array";
15950 else
15951 img_type_name += "texture2d";
15952 break;
15953 }
15954 case Dim3D:
15955 img_type_name += "texture3d";
15956 break;
15957 case DimCube:
15958 if (!msl_options.emulate_cube_array)
15959 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
15960 else
15961 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
15962 break;
15963 default:
15964 img_type_name += "unknown_texture_type";
15965 break;
15966 }
15967 }
15968
15969 // Append the pixel type
15970 img_type_name += "<";
15971 img_type_name += type_to_glsl(type: get<SPIRType>(id: img_type.type));
15972
15973 // For unsampled images, append the sample/read/write access qualifier.
15974	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
15975 // Otherwise it may be set based on whether the image is read from or written to within the shader.
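	// E.g. (illustrative): a storage image with the ReadOnly qualifier becomes
	// texture2d<float, access::read>, while one that is both read and written
	// in the shader becomes texture2d<float, access::read_write>.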
15976 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
15977 {
15978 switch (img_type.access)
15979 {
15980 case AccessQualifierReadOnly:
15981 img_type_name += ", access::read";
15982 break;
15983
15984 case AccessQualifierWriteOnly:
15985 img_type_name += ", access::write";
15986 break;
15987
15988 case AccessQualifierReadWrite:
15989 img_type_name += ", access::read_write";
15990 break;
15991
15992 default:
15993 {
15994 auto *p_var = maybe_get_backing_variable(chain: id);
15995 if (p_var && p_var->basevariable)
15996 p_var = maybe_get<SPIRVariable>(id: p_var->basevariable);
15997 if (p_var && !has_decoration(id: p_var->self, decoration: DecorationNonWritable))
15998 {
15999 img_type_name += ", access::";
16000
16001 if (!has_decoration(id: p_var->self, decoration: DecorationNonReadable))
16002 img_type_name += "read_";
16003
16004 img_type_name += "write";
16005 }
16006 break;
16007 }
16008 }
16009 }
16010
16011 img_type_name += ">";
16012
16013 return img_type_name;
16014}
16015
16016void CompilerMSL::emit_subgroup_op(const Instruction &i)
16017{
16018 const uint32_t *ops = stream(instr: i);
16019 auto op = static_cast<Op>(i.op);
16020
16021 if (msl_options.emulate_subgroups)
16022 {
16023 // In this mode, only the GroupNonUniform cap is supported. The only op
16024 // we need to handle, then, is OpGroupNonUniformElect.
16025 if (op != OpGroupNonUniformElect)
16026 SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
16027 // In this mode, the subgroup size is assumed to be one, so every invocation
16028 // is elected.
16029 emit_op(result_type: ops[0], result_id: ops[1], rhs: "true", forward_rhs: true);
16030 return;
16031 }
16032
16033 // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
16034 // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
16035 // 10.13 (2.0), with full support in 10.14 (2.1).
16036 // Note that Apple GPUs before A13 make no distinction between a quad-group
16037 // and a SIMD-group; all SIMD-groups are quad-groups on those.
16038 if (!msl_options.supports_msl_version(major: 2))
16039 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
16040
16041 // If we need to do implicit bitcasts, make sure we do it with the correct type.
16042 uint32_t integer_width = get_integer_width_for_instruction(instr: i);
16043 auto int_type = to_signed_basetype(width: integer_width);
16044 auto uint_type = to_unsigned_basetype(width: integer_width);
16045
16046 if (msl_options.is_ios() && (!msl_options.supports_msl_version(major: 2, minor: 3) || !msl_options.ios_use_simdgroup_functions))
16047 {
16048 switch (op)
16049 {
16050 default:
16051 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
16052 case OpGroupNonUniformBroadcastFirst:
16053 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16054 SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
16055 break;
16056 case OpGroupNonUniformElect:
16057 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16058 SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
16059 break;
16060 case OpGroupNonUniformAny:
16061 case OpGroupNonUniformAll:
16062 case OpGroupNonUniformAllEqual:
16063 case OpGroupNonUniformBallot:
16064 case OpGroupNonUniformInverseBallot:
16065 case OpGroupNonUniformBallotBitExtract:
16066 case OpGroupNonUniformBallotFindLSB:
16067 case OpGroupNonUniformBallotFindMSB:
16068 case OpGroupNonUniformBallotBitCount:
16069 case OpSubgroupBallotKHR:
16070 case OpSubgroupAllKHR:
16071 case OpSubgroupAnyKHR:
16072 case OpSubgroupAllEqualKHR:
16073 if (!msl_options.supports_msl_version(major: 2, minor: 2))
16074				SPIRV_CROSS_THROW("Ballot ops on iOS require Metal 2.2 and up.");
16075 break;
16076 case OpGroupNonUniformBroadcast:
16077 case OpGroupNonUniformShuffle:
16078 case OpGroupNonUniformShuffleXor:
16079 case OpGroupNonUniformShuffleUp:
16080 case OpGroupNonUniformShuffleDown:
16081 case OpGroupNonUniformQuadSwap:
16082 case OpGroupNonUniformQuadBroadcast:
16083 case OpSubgroupReadInvocationKHR:
16084 break;
16085 }
16086 }
16087
16088 if (msl_options.is_macos() && !msl_options.supports_msl_version(major: 2, minor: 1))
16089 {
16090 switch (op)
16091 {
16092 default:
16093 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
16094 case OpGroupNonUniformBroadcast:
16095 case OpGroupNonUniformShuffle:
16096 case OpGroupNonUniformShuffleXor:
16097 case OpGroupNonUniformShuffleUp:
16098 case OpGroupNonUniformShuffleDown:
16099 case OpSubgroupReadInvocationKHR:
16100 break;
16101 }
16102 }
16103
16104 uint32_t op_idx = 0;
16105 uint32_t result_type = ops[op_idx++];
16106 uint32_t id = ops[op_idx++];
16107
16108 Scope scope;
16109 switch (op)
16110 {
16111 case OpSubgroupBallotKHR:
16112 case OpSubgroupFirstInvocationKHR:
16113 case OpSubgroupReadInvocationKHR:
16114 case OpSubgroupAllKHR:
16115 case OpSubgroupAnyKHR:
16116 case OpSubgroupAllEqualKHR:
16117 // These earlier instructions don't have the scope operand.
16118 scope = ScopeSubgroup;
16119 break;
16120 default:
16121 scope = static_cast<Scope>(evaluate_constant_u32(id: ops[op_idx++]));
16122 break;
16123 }
16124 if (scope != ScopeSubgroup)
16125 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
16126
16127 switch (op)
16128 {
16129 case OpGroupNonUniformElect:
16130 if (msl_options.use_quadgroup_operation())
16131 emit_op(result_type, result_id: id, rhs: "quad_is_first()", forward_rhs: false);
16132 else
16133 emit_op(result_type, result_id: id, rhs: "simd_is_first()", forward_rhs: false);
16134 break;
16135
16136 case OpGroupNonUniformBroadcast:
16137 case OpSubgroupReadInvocationKHR:
16138 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupBroadcast");
16139 break;
16140
16141 case OpGroupNonUniformBroadcastFirst:
16142 case OpSubgroupFirstInvocationKHR:
16143 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupBroadcastFirst");
16144 break;
16145
16146 case OpGroupNonUniformBallot:
16147 case OpSubgroupBallotKHR:
16148 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupBallot");
16149 break;
16150
16151 case OpGroupNonUniformInverseBallot:
16152 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id, op: "spvSubgroupBallotBitExtract");
16153 break;
16154
16155 case OpGroupNonUniformBallotBitExtract:
16156 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupBallotBitExtract");
16157 break;
16158
16159 case OpGroupNonUniformBallotFindLSB:
16160 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotFindLSB");
16161 break;
16162
16163 case OpGroupNonUniformBallotFindMSB:
16164 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotFindMSB");
16165 break;
16166
16167 case OpGroupNonUniformBallotBitCount:
16168 {
16169 auto operation = static_cast<GroupOperation>(ops[op_idx++]);
16170 switch (operation)
16171 {
16172 case GroupOperationReduce:
16173 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_size_id, op: "spvSubgroupBallotBitCount");
16174 break;
16175 case GroupOperationInclusiveScan:
16176 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id,
16177 op: "spvSubgroupBallotInclusiveBitCount");
16178 break;
16179 case GroupOperationExclusiveScan:
16180 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: builtin_subgroup_invocation_id_id,
16181 op: "spvSubgroupBallotExclusiveBitCount");
16182 break;
16183 default:
16184 SPIRV_CROSS_THROW("Invalid BitCount operation.");
16185 }
16186 break;
16187 }
16188
16189 case OpGroupNonUniformShuffle:
16190 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffle");
16191 break;
16192
16193 case OpGroupNonUniformShuffleXor:
16194 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleXor");
16195 break;
16196
16197 case OpGroupNonUniformShuffleUp:
16198 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleUp");
16199 break;
16200
16201 case OpGroupNonUniformShuffleDown:
16202 emit_binary_func_op(result_type, result_id: id, op0: ops[op_idx], op1: ops[op_idx + 1], op: "spvSubgroupShuffleDown");
16203 break;
16204
16205 case OpGroupNonUniformAll:
16206 case OpSubgroupAllKHR:
16207 if (msl_options.use_quadgroup_operation())
16208 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "quad_all");
16209 else
16210 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "simd_all");
16211 break;
16212
16213 case OpGroupNonUniformAny:
16214 case OpSubgroupAnyKHR:
16215 if (msl_options.use_quadgroup_operation())
16216 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "quad_any");
16217 else
16218 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "simd_any");
16219 break;
16220
16221 case OpGroupNonUniformAllEqual:
16222 case OpSubgroupAllEqualKHR:
16223 emit_unary_func_op(result_type, result_id: id, op0: ops[op_idx], op: "spvSubgroupAllEqual");
16224 break;
16225
16226 // clang-format off
16227#define MSL_GROUP_OP(op, msl_op) \
16228case OpGroupNonUniform##op: \
16229 { \
16230 auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
16231 if (operation == GroupOperationReduce) \
16232 emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
16233 else if (operation == GroupOperationInclusiveScan) \
16234 emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_inclusive_" #msl_op); \
16235 else if (operation == GroupOperationExclusiveScan) \
16236 emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
16237 else if (operation == GroupOperationClusteredReduce) \
16238 { \
16239 /* Only cluster sizes of 4 are supported. */ \
16240 uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
16241 if (cluster_size != 4) \
16242 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
16243 emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
16244 } \
16245 else \
16246 SPIRV_CROSS_THROW("Invalid group operation."); \
16247 break; \
16248 }
16249 MSL_GROUP_OP(FAdd, sum)
16250 MSL_GROUP_OP(FMul, product)
16251 MSL_GROUP_OP(IAdd, sum)
16252 MSL_GROUP_OP(IMul, product)
16253#undef MSL_GROUP_OP
	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.

#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

#define MSL_GROUP_OP_CAST(op, msl_op, type) \
case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op_cast(result_type, id, ops[op_idx], "simd_" #msl_op, type, type); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

	MSL_GROUP_OP(FMin, min)
	MSL_GROUP_OP(FMax, max)
	MSL_GROUP_OP_CAST(SMin, min, int_type)
	MSL_GROUP_OP_CAST(SMax, max, int_type)
	MSL_GROUP_OP_CAST(UMin, min, uint_type)
	MSL_GROUP_OP_CAST(UMax, max, uint_type)
	MSL_GROUP_OP(BitwiseAnd, and)
	MSL_GROUP_OP(BitwiseOr, or)
	MSL_GROUP_OP(BitwiseXor, xor)
	MSL_GROUP_OP(LogicalAnd, and)
	MSL_GROUP_OP(LogicalOr, or)
	MSL_GROUP_OP(LogicalXor, xor)
	// clang-format on
#undef MSL_GROUP_OP
#undef MSL_GROUP_OP_CAST

	case OpGroupNonUniformQuadSwap:
		emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadSwap");
		break;

	case OpGroupNonUniformQuadBroadcast:
		emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadBroadcast");
		break;

	default:
		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
	}

	register_control_dependent_expression(id);
}

string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
	if (out_type.basetype == in_type.basetype)
		return "";

	assert(out_type.basetype != SPIRType::Boolean);
	assert(in_type.basetype != SPIRType::Boolean);

	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize);
	bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);

	// Bitcasting can only be used between types of the same overall size.
	// And always formally cast between integers, because it's trivial, and also
	// because Metal can internally cast the results of some integer ops to a larger
	// size (e.g. a short shift right becomes an int), which means chaining integer ops
	// together may introduce size variations that SPIR-V doesn't know about.
	if (same_size_cast && !integral_cast)
		return "as_type<" + type_to_glsl(out_type) + ">";
	else
		return type_to_glsl(out_type);
}
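
// Illustrative results of the rules above: uint2 -> float2 is a same-size,
// non-integral cast and yields "as_type<float2>"; int -> uint is an integral
// cast of equal vecsize and yields the functional cast "uint"; short -> int
// differs in size and also yields a functional cast, "int".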

bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
{
	// This is handled from the outside where we deal with PtrToU/UToPtr and friends.
	return false;
}

// Returns an MSL string identifying the name of a SPIR-V builtin.
// Output builtins are qualified with the name of the stage out structure.
string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
{
	switch (builtin)
	{
	// Handle HLSL-style 0-based vertex/instance index.
	// Override GLSL compiler strictness.
	case BuiltInVertexId:
		ensure_builtin(StorageClassInput, BuiltInVertexId);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_vertex_arg != TriState::No)
					needs_base_vertex_arg = TriState::Yes;
				return "gl_VertexID";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
				return "(gl_VertexID - gl_BaseVertex)";
			}
		}
		else
		{
			return "gl_VertexID";
		}
	case BuiltInInstanceId:
		ensure_builtin(StorageClassInput, BuiltInInstanceId);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_instance_arg != TriState::No)
					needs_base_instance_arg = TriState::Yes;
				return "gl_InstanceID";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceID - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceID";
		}
	case BuiltInVertexIndex:
		ensure_builtin(StorageClassInput, BuiltInVertexIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_vertex_arg != TriState::No)
					needs_base_vertex_arg = TriState::Yes;
				return "gl_VertexIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
				return "(gl_VertexIndex - gl_BaseVertex)";
			}
		}
		else
		{
			return "gl_VertexIndex";
		}
	case BuiltInInstanceIndex:
		ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_instance_arg != TriState::No)
					needs_base_instance_arg = TriState::Yes;
				return "gl_InstanceIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceIndex - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceIndex";
		}
	case BuiltInBaseVertex:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_vertex_arg = TriState::No;
			return "gl_BaseVertex";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInBaseInstance:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_instance_arg = TriState::No;
			return "gl_BaseInstance";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// When used in the entry function, output builtins are qualified with the output struct name.
	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
	// Also don't do this for tessellation control shaders.
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		/* fallthrough */
	case BuiltInFragDepth:
	case BuiltInFragStencilRefEXT:
		if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
		    (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
			break;
		/* fallthrough */
	case BuiltInPosition:
	case BuiltInPointSize:
	case BuiltInClipDistance:
	case BuiltInCullDistance:
	case BuiltInLayer:
		if (is_tesc_shader())
			break;
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
		    !is_stage_output_builtin_masked(builtin))
			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
		break;

	case BuiltInSampleMask:
		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
		    (has_additional_fixed_sample_mask() || needs_sample_id))
		{
			string samp_mask_in;
			samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage);
			if (has_additional_fixed_sample_mask())
				samp_mask_in += " & " + additional_fixed_sample_mask_str();
			if (needs_sample_id)
				samp_mask_in += " & (1 << gl_SampleID)";
			samp_mask_in += ")";
			return samp_mask_in;
		}
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
		    !is_stage_output_builtin_masked(builtin))
			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
		break;

	case BuiltInBaryCoordKHR:
	case BuiltInBaryCoordNoPerspKHR:
		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
		break;

	case BuiltInTessLevelOuter:
		if (is_tesc_shader() && storage != StorageClassInput && current_function &&
		    (current_function->self == ir.default_entry_point))
		{
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].edgeTessellationFactor");
		}
		break;

	case BuiltInTessLevelInner:
		if (is_tesc_shader() && storage != StorageClassInput && current_function &&
		    (current_function->self == ir.default_entry_point))
		{
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].insideTessellationFactor");
		}
		break;

	case BuiltInHelperInvocation:
		if (needs_manual_helper_invocation_updates())
			break;
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
			SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
		// In SPIR-V 1.6 with Volatile HelperInvocation, we cannot emit a fixup early.
		return "simd_is_helper_thread()";

	default:
		break;
	}

	return CompilerGLSL::builtin_to_glsl(builtin, storage);
}
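
// Illustrative result: in the entry function of a vertex shader, BuiltInPosition
// resolves to "out.gl_Position" (assuming the default stage-out variable name
// "out"); in any other function it remains plain "gl_Position".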

// Returns an MSL string attribute qualifier for a SPIR-V builtin.
string CompilerMSL::builtin_qualifier(BuiltIn builtin)
{
	auto &execution = get_entry_point();

	switch (builtin)
	{
	// Vertex function in
	case BuiltInVertexId:
		return "vertex_id";
	case BuiltInVertexIndex:
		return "vertex_id";
	case BuiltInBaseVertex:
		return "base_vertex";
	case BuiltInInstanceId:
		return "instance_id";
	case BuiltInInstanceIndex:
		return "instance_id";
	case BuiltInBaseInstance:
		return "base_instance";
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// Vertex function out
	case BuiltInClipDistance:
		return "clip_distance";
	case BuiltInPointSize:
		return "point_size";
	case BuiltInPosition:
		if (position_invariant)
		{
			if (!msl_options.supports_msl_version(2, 1))
				SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
			return "position, invariant";
		}
		else
			return "position";
	case BuiltInLayer:
		return "render_target_array_index";
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		return "viewport_array_index";

	// Tess. control function in
	case BuiltInInvocationId:
		if (msl_options.multi_patch_workgroup)
		{
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
		}
		return "thread_index_in_threadgroup";
	case BuiltInPatchVertices:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
	case BuiltInPrimitiveId:
		switch (execution.model)
		{
		case ExecutionModelTessellationControl:
			if (msl_options.multi_patch_workgroup)
			{
				// Shouldn't be reached.
				SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
			}
			return "threadgroup_position_in_grid";
		case ExecutionModelTessellationEvaluation:
			return "patch_id";
		case ExecutionModelFragment:
			if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
				SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
			return "primitive_id";
		default:
			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
		}

	// Tess. control function out
	case BuiltInTessLevelOuter:
	case BuiltInTessLevelInner:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");

	// Tess. evaluation function in
	case BuiltInTessCoord:
		return "position_in_patch";

	// Fragment function in
	case BuiltInFrontFacing:
		return "front_facing";
	case BuiltInPointCoord:
		return "point_coord";
	case BuiltInFragCoord:
		return "position";
	case BuiltInSampleId:
		return "sample_id";
	case BuiltInSampleMask:
		return "sample_mask";
	case BuiltInSamplePosition:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
	case BuiltInViewIndex:
		if (execution.model != ExecutionModelFragment)
			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
		// so we can get it from there.
		return "render_target_array_index";

	// Fragment function out
	case BuiltInFragDepth:
		if (execution.flags.get(ExecutionModeDepthGreater))
			return "depth(greater)";
		else if (execution.flags.get(ExecutionModeDepthLess))
			return "depth(less)";
		else
			return "depth(any)";

	case BuiltInFragStencilRefEXT:
		return "stencil";

	// Compute function in
	case BuiltInGlobalInvocationId:
		return "thread_position_in_grid";

	case BuiltInWorkgroupId:
		return "threadgroup_position_in_grid";

	case BuiltInNumWorkgroups:
		return "threadgroups_per_grid";

	case BuiltInLocalInvocationId:
		return "thread_position_in_threadgroup";

	case BuiltInLocalInvocationIndex:
		return "thread_index_in_threadgroup";

	case BuiltInSubgroupSize:
		if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
			return "threads_per_simdgroup";
		}
		else
		{
			// thread_execution_width is an alias for threads_per_simdgroup, and it has been
			// available since MSL 1.0, though not in fragment functions.
			return "thread_execution_width";
		}

	case BuiltInNumSubgroups:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";

	case BuiltInSubgroupId:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";

	case BuiltInSubgroupLocalInvocationId:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
			return "thread_index_in_simdgroup";
		}
		else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
		         execution.model == ExecutionModelTessellationControl ||
		         (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
		{
			// We are generating a Metal kernel function.
			if (!msl_options.supports_msl_version(2))
				SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
			return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
		}
		else
			SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");

	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");

	case BuiltInBaryCoordKHR:
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_perspective";

	case BuiltInBaryCoordNoPerspKHR:
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_no_perspective";

	default:
		return "unsupported-built-in";
	}
}

// Returns an MSL string type declaration for a SPIR-V builtin.
string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
{
	switch (builtin)
	{
	// Vertex function in
	case BuiltInVertexId:
		return "uint";
	case BuiltInVertexIndex:
		return "uint";
	case BuiltInBaseVertex:
		return "uint";
	case BuiltInInstanceId:
		return "uint";
	case BuiltInInstanceIndex:
		return "uint";
	case BuiltInBaseInstance:
		return "uint";
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// Vertex function out
	case BuiltInClipDistance:
	case BuiltInCullDistance:
		return "float";
	case BuiltInPointSize:
		return "float";
	case BuiltInPosition:
		return "float4";
	case BuiltInLayer:
		return "uint";
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		return "uint";

	// Tess. control function in
	case BuiltInInvocationId:
		return "uint";
	case BuiltInPatchVertices:
		return "uint";
	case BuiltInPrimitiveId:
		return "uint";

	// Tess. control function out
	case BuiltInTessLevelInner:
		if (is_tese_shader())
			return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float2";
		return "half";
	case BuiltInTessLevelOuter:
		if (is_tese_shader())
			return (msl_options.raw_buffer_tese_input || is_tessellating_triangles()) ? "float" : "float4";
		return "half";

	// Tess. evaluation function in
	case BuiltInTessCoord:
		return "float3";

	// Fragment function in
	case BuiltInFrontFacing:
		return "bool";
	case BuiltInPointCoord:
		return "float2";
	case BuiltInFragCoord:
		return "float4";
	case BuiltInSampleId:
		return "uint";
	case BuiltInSampleMask:
		return "uint";
	case BuiltInSamplePosition:
		return "float2";
	case BuiltInViewIndex:
		return "uint";

	case BuiltInHelperInvocation:
		return "bool";

	case BuiltInBaryCoordKHR:
	case BuiltInBaryCoordNoPerspKHR:
		// Use the type as declared; it can have 1, 2 or 3 components.
		return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));

	// Fragment function out
	case BuiltInFragDepth:
		return "float";

	case BuiltInFragStencilRefEXT:
		return "uint";

	// Compute function in
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInNumWorkgroups:
	case BuiltInWorkgroupId:
		return "uint3";
	case BuiltInLocalInvocationIndex:
	case BuiltInNumSubgroups:
	case BuiltInSubgroupId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
		return "uint";
	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		return "uint4";

	case BuiltInDeviceIndex:
		return "int";

	default:
		return "unsupported-built-in-type";
	}
}

// Returns the declaration of a built-in argument to a function.
string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
{
	string bi_arg;
	if (prefix_comma)
		bi_arg += ", ";

	// Handle HLSL-style 0-based vertex/instance index.
	builtin_declaration = true;
	bi_arg += builtin_type_decl(builtin);
	bi_arg += string(" ") + builtin_to_glsl(builtin, StorageClassInput);
	bi_arg += string(" [[") + builtin_qualifier(builtin) + string("]]");
	builtin_declaration = false;

	return bi_arg;
}
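
// For example, built_in_func_arg(BuiltInSampleId, true) produces the argument
// declaration ", uint gl_SampleID [[sample_id]]".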

const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
{
	if (member_is_remapped_physical_type(type, index))
		return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
	else
		return get<SPIRType>(type.member_types[index]);
}

SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
{
	SPIRType type = get_physical_member_type(ib_type, index);
	uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
	uint32_t cmp = get_member_decoration(ib_type.self, index, DecorationComponent);
	auto p_va = inputs_by_location.find({ loc, cmp });
	if (p_va != end(inputs_by_location) && p_va->second.vecsize > type.vecsize)
		type.vecsize = p_va->second.vecsize;

	return type;
}

uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	// Array stride in MSL is always size * array_size. sizeof(float3) == 16,
	// unlike GLSL and HLSL, where the array stride would be 16 and the size 12.

	// We could use the parent type here and recurse, but that makes creating physical type remappings
	// far more complicated. We'd rather just create the final type, and ignore having to create the entire type
	// hierarchy in order to compute this value, so make a temporary type on the stack.

	auto basic_type = type;
	basic_type.array.clear();
	basic_type.array_size_literal.clear();
	uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);

	uint32_t dimensions = uint32_t(type.array.size());
	assert(dimensions > 0);
	dimensions--;

	// Multiply together every dimension, except the last one.
	for (uint32_t dim = 0; dim < dimensions; dim++)
	{
		uint32_t array_size = to_array_size_literal(type, dim);
		value_size *= max<uint32_t>(array_size, 1u);
	}

	return value_size;
}
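
// Worked example: for "float3 v[8]", the stride is sizeof(float3) == 16 bytes,
// even though a float3 only holds 12 bytes of data; the loop over the outer
// dimensions only contributes for arrays of arrays.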

uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
	                                          member_is_packed_physical_type(type, index),
	                                          has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
	                                          has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
{
	// For packed matrices, we just use the size of the vector type.
	// Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
	if (packed)
		return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
	else
		return get_declared_type_alignment_msl(type, false, row_major);
}

uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
	                                           member_is_packed_physical_type(type, index),
	                                           has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
	                                           has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
                                                   bool ignore_padding) const
{
	// If we have a target size, that is the declared size as well.
	if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
		return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);

	if (struct_type.member_types.empty())
		return 0;

	uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());

	// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
	uint32_t alignment = 1;

	if (!ignore_alignment)
	{
		for (uint32_t i = 0; i < mbr_cnt; i++)
		{
			uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
			alignment = max(alignment, mbr_alignment);
		}
	}

	// The last member is always matched to the final Offset decoration, but the size of the
	// struct depends on the physical size of that member in MSL, and the size of the struct
	// itself is then rounded up to the struct alignment.
	uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
	uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
	msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
	return msl_size;
}
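
// Worked example: for struct { float4 a; float b; }, the last member sits at
// offset 16 with a declared size of 4, giving 20 bytes; the struct alignment is
// 16 (from the float4 member), so the declared size rounds up to 32.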

// Returns the byte size of a type as it would be declared in MSL.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	// Pointers take 8 bytes each.
	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
	{
		uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);

		// Work our way through potentially layered arrays,
		// stopping when we hit a pointer that is not also an array.
		int32_t dim_idx = (int32_t)type.array.size() - 1;
		auto *p_type = &type;
		while (!is_pointer(*p_type) && dim_idx >= 0)
		{
			type_size *= to_array_size_literal(*p_type, dim_idx);
			p_type = &get<SPIRType>(p_type->parent_type);
			dim_idx--;
		}

		return type_size;
	}

	switch (type.basetype)
	{
	case SPIRType::Unknown:
	case SPIRType::Void:
	case SPIRType::AtomicCounter:
	case SPIRType::Image:
	case SPIRType::SampledImage:
	case SPIRType::Sampler:
		SPIRV_CROSS_THROW("Querying size of opaque object.");

	default:
	{
		if (!type.array.empty())
		{
			uint32_t array_size = to_array_size_literal(type);
			return get_declared_type_array_stride_msl(type, is_packed, row_major) * max<uint32_t>(array_size, 1u);
		}

		if (type.basetype == SPIRType::Struct)
			return get_declared_struct_size_msl(type);

		if (is_packed)
		{
			return type.vecsize * type.columns * (type.width / 8);
		}
		else
		{
			// An unpacked 3-element vector or matrix column is the same memory size as a 4-element one.
			uint32_t vecsize = type.vecsize;
			uint32_t columns = type.columns;

			if (row_major && columns > 1)
				swap(vecsize, columns);

			if (vecsize == 3)
				vecsize = 4;

			return vecsize * columns * (type.width / 8);
		}
	}
	}
}
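
// Illustrative sizes under the rules above: a packed_float3 is 3 * 4 = 12 bytes,
// while an unpacked float3 is promoted to 4 elements and occupies 16 bytes.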

uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_size_msl(get_physical_member_type(type, index),
	                                  member_is_packed_physical_type(type, index),
	                                  has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
	                                  has_member_decoration(type.self, index, DecorationRowMajor));
}

// Returns the byte alignment of a type.
uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	// Pointers align on multiples of 8 bytes.
	if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
		return 8 * (type.vecsize == 3 ? 4 : type.vecsize);

	switch (type.basetype)
	{
	case SPIRType::Unknown:
	case SPIRType::Void:
	case SPIRType::AtomicCounter:
	case SPIRType::Image:
	case SPIRType::SampledImage:
	case SPIRType::Sampler:
		SPIRV_CROSS_THROW("Querying alignment of opaque object.");

	case SPIRType::Double:
		SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");

	case SPIRType::Struct:
	{
		// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
		uint32_t alignment = 1;
		for (uint32_t i = 0; i < type.member_types.size(); i++)
			alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
		return alignment;
	}

	default:
	{
		if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above.");
		if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above.");
		// The alignment of a packed type is the same as the underlying component or column size.
		// The alignment of an unpacked type is the same as the vector size.
		// The alignment of a 3-element vector is the same as that of a 4-element vector
		// (including packed types used as matrix columns).
		if (is_packed)
		{
			// If we have packed_T and friends, the alignment is always scalar.
			return type.width / 8;
		}
		else
		{
			// This is the general rule for MSL. Size == alignment.
			uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
			return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
		}
	}
	}
}
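
// Illustrative alignments: packed_float3 aligns to its scalar component, i.e.
// 4 bytes, while an unpacked float3 is padded to 4 elements and aligns to 16.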

uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_alignment_msl(get_physical_member_type(type, index),
	                                       member_is_packed_physical_type(type, index),
	                                       has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
	                                       has_member_decoration(type.self, index, DecorationRowMajor));
}

bool CompilerMSL::skip_argument(uint32_t) const
{
	return false;
}

void CompilerMSL::analyze_sampled_image_usage()
{
	if (msl_options.swizzle_texture_samples)
	{
		SampledImageScanner scanner(*this);
		traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
	}
}

bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
{
	switch (opcode)
	{
	case OpLoad:
	case OpImage:
	case OpSampledImage:
	{
		if (length < 3)
			return false;

		uint32_t result_type = args[0];
		auto &type = compiler.get<SPIRType>(result_type);
		if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
			return true;

		uint32_t id = args[1];
		compiler.set<SPIRExpression>(id, "", result_type, true);
		break;
	}
	case OpImageSampleExplicitLod:
	case OpImageSampleProjExplicitLod:
	case OpImageSampleDrefExplicitLod:
	case OpImageSampleProjDrefExplicitLod:
	case OpImageSampleImplicitLod:
	case OpImageSampleProjImplicitLod:
	case OpImageSampleDrefImplicitLod:
	case OpImageSampleProjDrefImplicitLod:
	case OpImageFetch:
	case OpImageGather:
	case OpImageDrefGather:
		compiler.has_sampled_images =
		    compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
		compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
		break;
	default:
		break;
	}
	return true;
}

// If a needed custom function wasn't added before, add it and force a recompile.
void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
{
	if (spv_function_implementations.count(spv_func) == 0)
	{
		spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
		force_recompile();
	}
}

bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
{
	// Since MSL exists in a single execution scope, function prototype declarations are not
	// needed, and clutter the output. If secondary functions are output (either as a SPIR-V
	// function implementation or as indicated by the presence of OpFunctionCall), then set
	// suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.

	// Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal.
	SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
	if (spv_func != SPVFuncImplNone)
	{
		compiler.spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
	}

	switch (opcode)
	{
	case OpFunctionCall:
		suppress_missing_prototypes = true;
		break;

	case OpDemoteToHelperInvocationEXT:
		uses_discard = true;
		break;

	// Emulate texture2D atomic operations
	case OpImageTexelPointer:
	{
		if (!compiler.msl_options.supports_msl_version(3, 1))
		{
			auto *var = compiler.maybe_get_backing_variable(args[2]);
			image_pointers_emulated[args[1]] = var ? var->self : ID(0);
		}
		break;
	}

	case OpImageWrite:
		uses_image_write = true;
		break;

	case OpStore:
		check_resource_write(args[0]);
		break;

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicFAddEXT:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	{
		uses_atomics = true;
		auto it = image_pointers_emulated.find(args[2]);
		if (it != image_pointers_emulated.end())
		{
			uses_image_write = true;
			compiler.atomic_image_vars_emulated.insert(it->second);
		}
		else
			check_resource_write(args[2]);
		break;
	}

	case OpAtomicStore:
	{
		uses_atomics = true;
		auto it = image_pointers_emulated.find(args[0]);
		if (it != image_pointers_emulated.end())
		{
			compiler.atomic_image_vars_emulated.insert(it->second);
			uses_image_write = true;
		}
		else
			check_resource_write(args[0]);
		break;
	}

	case OpAtomicLoad:
	{
		uses_atomics = true;
		auto it = image_pointers_emulated.find(args[2]);
		if (it != image_pointers_emulated.end())
		{
			compiler.atomic_image_vars_emulated.insert(it->second);
		}
		break;
	}

	case OpGroupNonUniformInverseBallot:
		needs_subgroup_invocation_id = true;
		break;

	case OpGroupNonUniformBallotFindLSB:
	case OpGroupNonUniformBallotFindMSB:
		needs_subgroup_size = true;
		break;

	case OpGroupNonUniformBallotBitCount:
		if (args[3] == GroupOperationReduce)
			needs_subgroup_size = true;
		else
			needs_subgroup_invocation_id = true;
		break;

	case OpArrayLength:
	{
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		if (var != nullptr)
		{
			if (!compiler.is_var_runtime_size_array(*var))
				compiler.buffers_requiring_array_length.insert(var->self);
		}
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
	{
		// OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
		uint32_t result_type = args[0];
		uint32_t id = args[1];
		uint32_t ptr = args[2];

		compiler.set<SPIRExpression>(id, "", result_type, true);
		compiler.register_read(id, ptr, true);
		compiler.ir.ids[id].set_allow_type_rewrite();
		break;
	}

	case OpExtInst:
	{
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450InterpolateAtCentroid:
			case GLSLstd450InterpolateAtSample:
			case GLSLstd450InterpolateAtOffset:
			{
				if (!compiler.msl_options.supports_msl_version(2, 3))
					SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
				// Fragment varyings used with pull-model interpolation need special handling,
				// due to the way pull-model interpolation works in Metal.
				auto *var = compiler.maybe_get_backing_variable(args[4]);
				if (var)
				{
					compiler.pull_model_inputs.insert(var->self);
					auto &var_type = compiler.get_variable_element_type(*var);
					// In addition, if this variable has a 'Sample' decoration, we need the sample ID
					// in order to do default interpolation.
					if (compiler.has_decoration(var->self, DecorationSample))
					{
						needs_sample_id = true;
					}
					else if (var_type.basetype == SPIRType::Struct)
					{
						// Now we need to check each member and see if it has this decoration.
						for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
						{
							if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
							{
								needs_sample_id = true;
								break;
							}
						}
					}
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	case OpIsHelperInvocationEXT:
		if (compiler.needs_manual_helper_invocation_updates())
			needs_helper_invocation = true;
		break;

	default:
		break;
	}

	// If it has one, keep track of the instruction's result type, mapped by ID.
	uint32_t result_type, result_id;
	if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
		result_types[result_id] = result_type;

	return true;
}

// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
{
	auto *p_var = compiler.maybe_get_backing_variable(var_id);
	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
	if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
		uses_buffer_write = true;
}

// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
{
	switch (opcode)
	{
	case OpFMod:
		return SPVFuncImplMod;

	case OpFAdd:
	case OpFSub:
		if (compiler.msl_options.invariant_float_math ||
		    compiler.has_decoration(args[1], DecorationNoContraction))
		{
			return opcode == OpFAdd ? SPVFuncImplFAdd : SPVFuncImplFSub;
		}
		break;

	case OpFMul:
	case OpOuterProduct:
	case OpMatrixTimesVector:
	case OpVectorTimesMatrix:
	case OpMatrixTimesMatrix:
		if (compiler.msl_options.invariant_float_math ||
		    compiler.has_decoration(args[1], DecorationNoContraction))
		{
			return SPVFuncImplFMul;
		}
		break;

	case OpQuantizeToF16:
		return SPVFuncImplQuantizeToF16;

	case OpTypeArray:
	{
		// Allow Metal to use the array<T> template to make arrays a value type
		return SPVFuncImplUnsafeArray;
	}

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicFAddEXT:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	case OpAtomicLoad:
	case OpAtomicStore:
	{
		auto it = image_pointers_emulated.find(args[opcode == OpAtomicStore ? 0 : 2]);
		if (it != image_pointers_emulated.end())
		{
			uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
			if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
				return SPVFuncImplImage2DAtomicCoords;
		}
		break;
	}

	case OpImageFetch:
	case OpImageRead:
	case OpImageWrite:
	{
		// Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
		uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
		if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
			return SPVFuncImplTexelBufferCoords;
		break;
	}

	case OpExtInst:
	{
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450Radians:
				return SPVFuncImplRadians;
			case GLSLstd450Degrees:
				return SPVFuncImplDegrees;
			case GLSLstd450FindILsb:
				return SPVFuncImplFindILsb;
			case GLSLstd450FindSMsb:
				return SPVFuncImplFindSMsb;
			case GLSLstd450FindUMsb:
				return SPVFuncImplFindUMsb;
			case GLSLstd450SSign:
				return SPVFuncImplSSign;
			case GLSLstd450Reflect:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplReflectScalar;
				break;
			}
			case GLSLstd450Refract:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplRefractScalar;
				break;
			}
			case GLSLstd450FaceForward:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplFaceForwardScalar;
				break;
			}
			case GLSLstd450MatrixInverse:
			{
				auto &mat_type = compiler.get<SPIRType>(args[0]);
				switch (mat_type.columns)
				{
				case 2:
					return SPVFuncImplInverse2x2;
				case 3:
					return SPVFuncImplInverse3x3;
				case 4:
					return SPVFuncImplInverse4x4;
				default:
					break;
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	case OpGroupNonUniformBroadcast:
	case OpSubgroupReadInvocationKHR:
		return SPVFuncImplSubgroupBroadcast;

	case OpGroupNonUniformBroadcastFirst:
	case OpSubgroupFirstInvocationKHR:
		return SPVFuncImplSubgroupBroadcastFirst;

	case OpGroupNonUniformBallot:
	case OpSubgroupBallotKHR:
		return SPVFuncImplSubgroupBallot;

	case OpGroupNonUniformInverseBallot:
	case OpGroupNonUniformBallotBitExtract:
		return SPVFuncImplSubgroupBallotBitExtract;

	case OpGroupNonUniformBallotFindLSB:
		return SPVFuncImplSubgroupBallotFindLSB;

	case OpGroupNonUniformBallotFindMSB:
		return SPVFuncImplSubgroupBallotFindMSB;

	case OpGroupNonUniformBallotBitCount:
		return SPVFuncImplSubgroupBallotBitCount;

	case OpGroupNonUniformAllEqual:
	case OpSubgroupAllEqualKHR:
		return SPVFuncImplSubgroupAllEqual;

	case OpGroupNonUniformShuffle:
		return SPVFuncImplSubgroupShuffle;

	case OpGroupNonUniformShuffleXor:
		return SPVFuncImplSubgroupShuffleXor;

	case OpGroupNonUniformShuffleUp:
		return SPVFuncImplSubgroupShuffleUp;

	case OpGroupNonUniformShuffleDown:
		return SPVFuncImplSubgroupShuffleDown;

	case OpGroupNonUniformQuadBroadcast:
		return SPVFuncImplQuadBroadcast;

	case OpGroupNonUniformQuadSwap:
		return SPVFuncImplQuadSwap;

	case OpSDot:
	case OpUDot:
	case OpSUDot:
	case OpSDotAccSat:
	case OpUDotAccSat:
	case OpSUDotAccSat:
		return SPVFuncImplReduceAdd;

	default:
		break;
	}
	return SPVFuncImplNone;
}

// Sort both type and meta member content based on builtin status (put builtins at end),
// then by the required sorting aspect.
void CompilerMSL::MemberSorter::sort()
{
	// Create a temporary array of consecutive member indices and sort it based on how
	// the members should be reordered, based on builtin and sorting aspect meta info.
	size_t mbr_cnt = type.member_types.size();
	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
	std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
	std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect

	bool sort_is_identity = true;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		if (mbr_idx != mbr_idxs[mbr_idx])
		{
			sort_is_identity = false;
			break;
		}
	}

	if (sort_is_identity)
		return;

	if (meta.members.size() < type.member_types.size())
	{
		// This should never trigger in normal circumstances, but to be safe.
		meta.members.resize(type.member_types.size());
	}

	// Move type and meta member info to the order defined by the sorted member indices.
	// This is done by creating temporary copies of both member types and meta, and then
	// copying back to the original content at the sorted indices.
	auto mbr_types_cpy = type.member_types;
	auto mbr_meta_cpy = meta.members;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
	}

	// If we're sorting by Offset, this might affect user code which accesses a buffer block.
	// We will need to redirect member indices from defined index to sorted index using reverse lookup.
	if (sort_aspect == SortAspect::Offset)
	{
		type.member_type_index_redirection.resize(mbr_cnt);
		for (uint32_t map_idx = 0; map_idx < mbr_cnt; map_idx++)
			type.member_type_index_redirection[mbr_idxs[map_idx]] = map_idx;
	}
}
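
// Worked example: members {A, B, C} with offsets {8, 0, 4} sort to {B, C, A}
// (mbr_idxs == {1, 2, 0}), so the reverse lookup yields
// member_type_index_redirection == {2, 0, 1}, mapping each original member
// index to its sorted position.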

bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
{
	auto &mbr_meta1 = meta.members[mbr_idx1];
	auto &mbr_meta2 = meta.members[mbr_idx2];

	if (sort_aspect == LocationThenBuiltInType)
	{
		// Sort first by builtin status (put builtins at end), then by the sorting aspect.
		if (mbr_meta1.builtin != mbr_meta2.builtin)
			return mbr_meta2.builtin;
		else if (mbr_meta1.builtin)
			return mbr_meta1.builtin_type < mbr_meta2.builtin_type;
		else if (mbr_meta1.location == mbr_meta2.location)
			return mbr_meta1.component < mbr_meta2.component;
		else
			return mbr_meta1.location < mbr_meta2.location;
	}
	else
		return mbr_meta1.offset < mbr_meta2.offset;
}

CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
    : type(t)
    , meta(m)
    , sort_aspect(sa)
{
	// Ensure enough meta info is available
	meta.members.resize(max(type.member_types.size(), meta.members.size()));
}

void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
{
	auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
	if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
		SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
	if (!type.array.empty())
		SPIRV_CROSS_THROW("Cannot remap an array of samplers.");
	constexpr_samplers_by_id[id] = sampler;
}

void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
                                                     const MSLConstexprSampler &sampler)
{
	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}

void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
	bool is_packed = has_extended_decoration(source_id, SPIRVCrossDecorationPhysicalTypePacked);
	auto *source_expr = maybe_get<SPIRExpression>(source_id);
	auto *var = maybe_get_backing_variable(source_id);
	const SPIRType *var_type = nullptr, *phys_type = nullptr;

	if (uint32_t phys_id = get_extended_decoration(source_id, SPIRVCrossDecorationPhysicalTypeID))
		phys_type = &get<SPIRType>(phys_id);
	else
		phys_type = &expr_type;

	if (var)
	{
		source_id = var->self;
		var_type = &get_variable_data_type(*var);
	}

	bool rewrite_boolean_load =
	    expr_type.basetype == SPIRType::Boolean &&
	    (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct));

	// Type fixups for workgroup variables if they are booleans.
	if (rewrite_boolean_load)
	{
		if (is_array(expr_type))
			expr = to_rerolled_array_expression(expr_type, expr, expr_type);
		else
			expr = join(type_to_glsl(expr_type), "(", expr, ")");
	}

	// Type fixups for workgroup variables if they are matrices.
	// Don't do fixup for packed types; those are handled specially.
	// FIXME: Maybe use a type like spvStorageMatrix for packed matrices?
	if (!msl_options.supports_msl_version(3, 0) && var &&
	    (var->storage == StorageClassWorkgroup ||
	     (var_type->basetype == SPIRType::Struct &&
	      has_extended_decoration(var_type->self, SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) &&
	    expr_type.columns > 1)
	{
		SPIRType matrix_type = *phys_type;
		if (source_expr && source_expr->need_transpose)
			swap(matrix_type.vecsize, matrix_type.columns);
		matrix_type.array.clear();
		matrix_type.array_size_literal.clear();
		expr = join(type_to_glsl(matrix_type), "(", expr, ")");
	}

	// Only interested in standalone builtin variables in the switch below.
	if (!has_decoration(source_id, DecorationBuiltIn))
	{
		// If the backing variable does not match our expected sign, we can fix it up here.
		// See ensure_correct_input_type().
		if (var && var->storage == StorageClassInput)
		{
			auto &base_type = get<SPIRType>(var->basetype);
			if (base_type.basetype != SPIRType::Struct && expr_type.basetype != base_type.basetype)
				expr = join(type_to_glsl(expr_type), "(", expr, ")");
		}
		return;
	}

	auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	auto expected_width = expr_type.width;
	switch (builtin)
	{
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInWorkgroupId:
	case BuiltInLocalInvocationIndex:
	case BuiltInWorkgroupSize:
	case BuiltInNumWorkgroups:
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
	case BuiltInViewIndex:
	case BuiltInVertexIndex:
	case BuiltInInstanceIndex:
	case BuiltInBaseInstance:
	case BuiltInBaseVertex:
	case BuiltInSampleMask:
		expected_type = SPIRType::UInt;
		expected_width = 32;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		if (is_tesc_shader())
		{
			expected_type = SPIRType::Half;
			expected_width = 16;
		}
		break;

	default:
		break;
	}

	if (is_array(expr_type) && builtin == BuiltInSampleMask)
	{
		// Needs special handling.
		auto wrap_expr = join(type_to_glsl(expr_type), "({ ");
		wrap_expr += join(type_to_glsl(get<SPIRType>(expr_type.parent_type)), "(", expr, ")");
		wrap_expr += " })";
		expr = std::move(wrap_expr);
	}
	else if (expected_type != expr_type.basetype)
	{
		if (is_array(expr_type) && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
		{
			// Triggers when loading TessLevel directly as an array.
			// Need explicit padding + cast.
			auto wrap_expr = join(type_to_glsl(expr_type), "({ ");

			uint32_t array_size = get_physical_tess_level_array_size(builtin);
			for (uint32_t i = 0; i < array_size; i++)
			{
				if (array_size > 1)
					wrap_expr += join("float(", expr, "[", i, "])");
				else
					wrap_expr += join("float(", expr, ")");
				if (i + 1 < array_size)
					wrap_expr += ", ";
			}

			if (is_tessellating_triangles())
				wrap_expr += ", 0.0";

			wrap_expr += " })";
			expr = std::move(wrap_expr);
		}
		else
		{
			// These are of different widths, so we cannot do a straight bitcast.
			if (expected_width != expr_type.width)
				expr = join(type_to_glsl(expr_type), "(", expr, ")");
			else
				expr = bitcast_expression(expr_type, expected_type, expr);
		}
	}
}
17900
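// Illustrative note (hypothetical shader, not from this file): gl_ViewIndex is
// declared as a signed int in SPIR-V, but the MSL builtin backing it is a uint.
// For a same-width sign mismatch the load fixup above goes through
// bitcast_expression(); a width mismatch instead falls back to a plain value
// cast such as "int(...)".
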
void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
	bool is_packed = has_extended_decoration(target_id, SPIRVCrossDecorationPhysicalTypePacked);
	auto *target_expr = maybe_get<SPIRExpression>(target_id);
	auto *var = maybe_get_backing_variable(target_id);
	const SPIRType *var_type = nullptr, *phys_type = nullptr;

	if (uint32_t phys_id = get_extended_decoration(target_id, SPIRVCrossDecorationPhysicalTypeID))
		phys_type = &get<SPIRType>(phys_id);
	else
		phys_type = &expr_type;

	if (var)
	{
		target_id = var->self;
		var_type = &get_variable_data_type(*var);
	}

	bool rewrite_boolean_store =
	    expr_type.basetype == SPIRType::Boolean &&
	    (var && (var->storage == StorageClassWorkgroup || var_type->basetype == SPIRType::Struct));

	// Type fixups for workgroup variables or struct members if they are booleans.
	if (rewrite_boolean_store)
	{
		if (is_array(expr_type))
		{
			expr = to_rerolled_array_expression(*var_type, expr, expr_type);
		}
		else
		{
			auto short_type = expr_type;
			short_type.basetype = SPIRType::Short;
			expr = join(type_to_glsl(short_type), "(", expr, ")");
		}
	}

	// Type fixups for workgroup variables if they are matrices.
	// Don't do fixup for packed types; those are handled specially.
	// FIXME: Maybe use a type like spvStorageMatrix for packed matrices?
	if (!msl_options.supports_msl_version(3, 0) && var &&
	    (var->storage == StorageClassWorkgroup ||
	     (var_type->basetype == SPIRType::Struct &&
	      has_extended_decoration(var_type->self, SPIRVCrossDecorationWorkgroupStruct) && !is_packed)) &&
	    expr_type.columns > 1)
	{
		SPIRType matrix_type = *phys_type;
		if (target_expr && target_expr->need_transpose)
			swap(matrix_type.vecsize, matrix_type.columns);
		expr = join("spvStorage_", type_to_glsl(matrix_type), "(", expr, ")");
	}

	// Only interested in standalone builtin variables.
	if (!has_decoration(target_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	auto expected_width = expr_type.width;
	switch (builtin)
	{
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInViewIndex:
		expected_type = SPIRType::UInt;
		expected_width = 32;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		expected_type = SPIRType::Half;
		expected_width = 16;
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
	{
		if (expected_width != expr_type.width)
		{
			// These are of different widths, so we cannot do a straight bitcast.
			auto type = expr_type;
			type.basetype = expected_type;
			type.width = expected_width;
			expr = join(type_to_glsl(type), "(", expr, ")");
		}
		else
		{
			auto type = expr_type;
			type.basetype = expected_type;
			expr = bitcast_expression(type, expr_type.basetype, expr);
		}
	}
}

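// Illustrative note (hypothetical names): MSL does not allow bool in threadgroup
// memory, so such workgroup variables are declared with a short-based type. A store
// of a SPIR-V bool2 value is then rewritten by the fixup above along the lines of
//   sharedFlags = short2(flags);
// with the matching bool2(...) cast applied again by the corresponding load fixup.
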
string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
{
	// With MSL, we risk getting an array initializer here if the variable is an array.
	// FIXME: We cannot handle non-constant arrays being initialized.
	// We will need to inject spvArrayCopy here somehow ...
	auto &type = get<SPIRType>(var.basetype);
	string expr;
	if (ir.ids[var.initializer].get_type() == TypeConstant &&
	    (!type.array.empty() || type.basetype == SPIRType::Struct))
		expr = constant_expression(get<SPIRConstant>(var.initializer));
	else
		expr = CompilerGLSL::to_initializer_expression(var);
	// If the initializer has more vector components than the variable, add a swizzle.
	// FIXME: This cannot handle arrays or structs.
	auto &init_type = expression_type(var.initializer);
	if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
		expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
	return expr;
}

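// Illustrative note (hypothetical values): for a variable declared as float3 but
// initialized with a float4 constant, the swizzle branch above yields
//   float3 v = (float4(1.0, 2.0, 3.0, 4.0).xyz);
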
string CompilerMSL::to_zero_initialized_expression(uint32_t)
{
	return "{}";
}

bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
{
	if (!msl_options.argument_buffers)
		return false;
	if (desc_set >= kMaxArgumentBuffers)
		return false;

	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
}

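// Illustrative note (hypothetical option values): with argument_buffers enabled and
// argument_buffer_discrete_mask == 0x2, only bit 1 is set, so descriptor set 1 stays
// a discrete set while sets 0 and 2..7 become argument buffers:
//   set 1: (0x2 & (1u << 1)) != 0 -> discrete bindings
//   set 0: (0x2 & (1u << 0)) == 0 -> argument buffer
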
bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
{
	// iOS Tier 1 argument buffers do not support writable images.
	// When the argument buffer is encoded, we don't know whether this image will have a
	// NonWritable decoration, so just use discrete arguments for all storage images on iOS.
	bool is_supported_type = !(type.basetype == SPIRType::Image &&
	                           type.image.sampled == 2 &&
	                           msl_options.is_ios() &&
	                           msl_options.argument_buffers_tier <= Options::ArgumentBuffersTier::Tier1);
	return is_supported_type && !type_is_msl_framebuffer_fetch(type);
}

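// Illustrative note: a storage image (type.image.sampled == 2, e.g. a writable
// image2D in GLSL terms) compiled for iOS with Tier 1 argument buffers fails the
// check above, so it is bound as a discrete [[texture(n)]] argument instead of
// being placed inside the argument buffer.
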
void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &aliased_var,
                                                          const SPIRVariable &base_var)
{
	// To deal with buffer <-> image aliasing, we need to perform an unholy UB ritual.
	// A texture type in Metal 3.0 is a pointer. However, we cannot simply cast a pointer to texture.
	// What we *can* do is to cast pointer-to-pointer to pointer-to-texture.

	// We need to explicitly reach into the descriptor buffer lvalue, not any spvDescriptorArray wrapper.
	auto *var_meta = ir.find_meta(base_var.self);
	bool old_explicit_qualifier = var_meta && var_meta->decoration.qualified_alias_explicit_override;
	if (var_meta)
		var_meta->decoration.qualified_alias_explicit_override = false;
	auto unqualified_name = to_name(base_var.self, false);
	if (var_meta)
		var_meta->decoration.qualified_alias_explicit_override = old_explicit_qualifier;

	// For non-arrayed buffers, we have already performed a de-reference.
	// We need a proper lvalue to cast, so strip away the de-reference.
	if (unqualified_name.size() > 2 && unqualified_name[0] == '(' && unqualified_name[1] == '*')
	{
		unqualified_name.erase(unqualified_name.begin(), unqualified_name.begin() + 2);
		unqualified_name.pop_back();
	}

	string name;

	auto &var_type = get<SPIRType>(aliased_var.basetype);
	auto &data_type = get_variable_data_type(aliased_var);
	string descriptor_storage = descriptor_address_space(aliased_var.self, aliased_var.storage, "");

	if (aliased_var.storage == StorageClassUniformConstant)
	{
		if (is_var_runtime_size_array(aliased_var))
		{
			// This becomes a plain pointer to spvDescriptor.
			name = join("reinterpret_cast<", descriptor_storage, " ",
			            type_to_glsl(get_variable_data_type(aliased_var), aliased_var.self, true), ">(&",
			            unqualified_name, ")");
		}
		else
		{
			name = join("reinterpret_cast<", descriptor_storage, " ",
			            type_to_glsl(get_variable_data_type(aliased_var), aliased_var.self, true), " &>(",
			            unqualified_name, ");");
		}
	}
	else
	{
		// Buffer types.
		bool old_is_using_builtin_array = is_using_builtin_array;
		is_using_builtin_array = true;

		bool needs_post_cast_deref = !is_array(data_type);
		string ref_type = needs_post_cast_deref ? "&" : join("(&)", type_to_array_glsl(var_type, aliased_var.self));

		if (is_var_runtime_size_array(aliased_var))
		{
			name = join("reinterpret_cast<",
			            type_to_glsl(var_type, aliased_var.self, true), " ", descriptor_storage, " *>(&",
			            unqualified_name, ")");
		}
		else
		{
			name = join(needs_post_cast_deref ? "*" : "", "reinterpret_cast<",
			            type_to_glsl(var_type, aliased_var.self, true), " ", descriptor_storage, " ",
			            ref_type,
			            ">(", unqualified_name, ");");
		}

		if (needs_post_cast_deref)
			descriptor_storage = get_type_address_space(var_type, aliased_var.self, false);

		// These kinds of ridiculous casts trigger warnings in the compiler. Just ignore them.
		if (!suppress_incompatible_pointer_types_discard_qualifiers)
		{
			suppress_incompatible_pointer_types_discard_qualifiers = true;
			force_recompile_guarantee_forward_progress();
		}

		is_using_builtin_array = old_is_using_builtin_array;
	}

	if (!is_var_runtime_size_array(aliased_var))
	{
		// Lower the alias to a temporary, so drop the qualification.
		set_qualified_name(aliased_var.self, "");
		statement(descriptor_storage, " auto &", to_name(aliased_var.self), " = ", name);
	}
	else
	{
		// This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted.
		set_qualified_name(aliased_var.self, name);
	}
}

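// Illustrative note (hypothetical names): for a non-arrayed buffer aliased over
// another binding, the statement() call above emits a line shaped roughly like
//   device auto &aliasedSSBO = *reinterpret_cast<SSBO device &>(spvDescriptorSet0.m_base);
// i.e. the descriptor lvalue itself is reinterpreted rather than its
// already-dereferenced pointer, as explained at the top of this function.
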
void CompilerMSL::analyze_argument_buffers()
{
	// Gather all used resources and sort them out into argument buffers.
	// Each argument buffer corresponds to a descriptor set in SPIR-V.
	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
	// Otherwise, the binding number is used, but this is generally not safe for some types
	// like combined image samplers and arrays of resources. Metal needs different indices here,
	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
	// you will need to use the remapping from the API.
	for (auto &id : argument_buffer_ids)
		id = 0;

	// Output resources, sorted by resource index & type.
	struct Resource
	{
		SPIRVariable *var;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
		uint32_t plane;
		uint32_t overlapping_var_id;
	};
	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
	SmallVector<uint32_t> inline_block_vars;

	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
	bool needs_buffer_sizes = false;

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
			// Ignore if it's part of a push descriptor set.
			if (!descriptor_set_is_argument_buffer(desc_set))
				return;

			uint32_t var_id = var.self;
			auto &type = get_variable_data_type(var);

			if (desc_set >= kMaxArgumentBuffers)
				SPIRV_CROSS_THROW("Descriptor set index is out of range.");

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			uint32_t binding = get_decoration(var_id, DecorationBinding);
			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				for (uint32_t i = 0; i < plane_count; i++)
				{
					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
					resources_in_set[desc_set].push_back(
					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i, 0 });
				}

				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
					resources_in_set[desc_set].push_back(
					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0, 0 });
				}
			}
			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
			{
				inline_block_vars.push_back(var_id);
			}
			else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
			{
				// constexpr samplers are not declared as resources.
				// Inline uniform blocks are always emitted at the end.
				add_resource_name(var_id);

				uint32_t resource_index = get_metal_resource_index(var, type.basetype);

				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), type.basetype, resource_index, 0, 0 });

				// Emulate texture2D atomic operations.
				if (atomic_image_vars_emulated.count(var.self))
				{
					uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
					resources_in_set[desc_set].push_back(
					    { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0, 0 });
				}
			}

			// Check if this descriptor set needs a swizzle buffer.
			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
				set_needs_swizzle_buffer[desc_set] = true;
			else if (buffer_requires_array_length(var_id))
			{
				set_needs_buffer_sizes[desc_set] = true;
				needs_buffer_sizes = true;
			}
		}
	});

	if (needs_swizzle_buffer_def || needs_buffer_sizes)
	{
		uint32_t uint_ptr_type_id = 0;

		// We might have to add a swizzle buffer resource to the set.
		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
		{
			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
				continue;

			if (uint_ptr_type_id == 0)
			{
				uint_ptr_type_id = ir.increase_bound_by(1);

				// Create a buffer to hold extra data, including the swizzle constants.
				SPIRType uint_type_pointer = get_uint_type();
				uint_type_pointer.op = OpTypePointer;
				uint_type_pointer.pointer = true;
				uint_type_pointer.pointer_depth++;
				uint_type_pointer.parent_type = get_uint_type_id();
				uint_type_pointer.storage = StorageClassUniform;
				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
			}

			if (set_needs_swizzle_buffer[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvSwizzleConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
			}

			if (set_needs_buffer_sizes[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvBufferSizeConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
			}
		}
	}

	// Now add inline uniform blocks.
	for (uint32_t var_id : inline_block_vars)
	{
		auto &var = get<SPIRVariable>(var_id);
		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
		add_resource_name(var_id);
		resources_in_set[desc_set].push_back(
		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0, 0 });
	}

	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
	{
		auto &resources = resources_in_set[desc_set];
		if (resources.empty())
			continue;

		assert(descriptor_set_is_argument_buffer(desc_set));

		uint32_t next_id = ir.increase_bound_by(3);
		uint32_t type_id = next_id + 1;
		uint32_t ptr_type_id = next_id + 2;
		argument_buffer_ids[desc_set] = next_id;

		auto &buffer_type = set<SPIRType>(type_id, OpTypeStruct);

		buffer_type.basetype = SPIRType::Struct;

		if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
		{
			buffer_type.storage = StorageClassStorageBuffer;
			// Make sure the argument buffer gets marked as const device.
			set_decoration(next_id, DecorationNonWritable);
			// Need to mark the type as a Block to enable this.
			set_decoration(type_id, DecorationBlock);
		}
		else
			buffer_type.storage = StorageClassUniform;

		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));

		auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
		ptr_type = buffer_type;
		ptr_type.op = spv::OpTypePointer;
		ptr_type.pointer = true;
		ptr_type.pointer_depth++;
		ptr_type.parent_type = type_id;

		uint32_t buffer_variable_id = next_id;
		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));

		// IDs must be emitted in ID order.
		stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
		});

		for (size_t i = 0; i < resources.size() - 1; i++)
		{
			auto &r1 = resources[i];
			auto &r2 = resources[i + 1];

			if (r1.index == r2.index)
			{
				if (r1.overlapping_var_id)
					r2.overlapping_var_id = r1.overlapping_var_id;
				else
					r2.overlapping_var_id = r1.var->self;

				set_extended_decoration(r2.var->self, SPIRVCrossDecorationOverlappingBinding, r2.overlapping_var_id);
			}
		}

		uint32_t member_index = 0;
		uint32_t next_arg_buff_index = 0;
		for (auto &resource : resources)
		{
			auto &var = *resource.var;
			auto &type = get_variable_data_type(var);

			if (is_var_runtime_size_array(var) && (argument_buffer_device_storage_mask & (1u << desc_set)) == 0)
				SPIRV_CROSS_THROW("Runtime sized variables must be in device storage argument buffers.");

			// If needed, synthesize and add padding members.
			// member_index and next_arg_buff_index are incremented when padding members are added.
			if (msl_options.pad_argument_buffer_resources && resource.overlapping_var_id == 0)
			{
				auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
				while (resource.index > next_arg_buff_index)
				{
					switch (rez_bind.basetype)
					{
					case SPIRType::Void:
					case SPIRType::Boolean:
					case SPIRType::SByte:
					case SPIRType::UByte:
					case SPIRType::Short:
					case SPIRType::UShort:
					case SPIRType::Int:
					case SPIRType::UInt:
					case SPIRType::Int64:
					case SPIRType::UInt64:
					case SPIRType::AtomicCounter:
					case SPIRType::Half:
					case SPIRType::Float:
					case SPIRType::Double:
						add_argument_buffer_padding_buffer_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
						break;
					case SPIRType::Image:
						add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
						break;
					case SPIRType::Sampler:
						add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
						break;
					case SPIRType::SampledImage:
						if (next_arg_buff_index == rez_bind.msl_sampler)
							add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
						else
							add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
						break;
					default:
						break;
					}

					// After padding, retrieve the resource again. It will either be more padding, or the actual resource.
					rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
				}

				// Adjust the number of slots consumed by the current member itself.
				// Use the count value from the app, instead of the shader, in case the
				// shader is only accessing part, or even one element, of the array.
				next_arg_buff_index += rez_bind.count;
			}

			string mbr_name = ensure_valid_name(resource.name, "m");
			if (resource.plane > 0)
				mbr_name += join(plane_name_suffix, resource.plane);
			set_member_name(buffer_type.self, member_index, mbr_name);

			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
			{
				// Have to synthesize a sampler type here.

				bool type_is_array = !type.array.empty();
				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
				auto &new_sampler_type = set<SPIRType>(sampler_type_id, OpTypeSampler);
				new_sampler_type.basetype = SPIRType::Sampler;
				new_sampler_type.storage = StorageClassUniformConstant;

				if (type_is_array)
				{
					uint32_t sampler_type_array_id = sampler_type_id + 1;
					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id, OpTypeArray);
					sampler_type_array = new_sampler_type;
					sampler_type_array.array = type.array;
					sampler_type_array.array_size_literal = type.array_size_literal;
					sampler_type_array.parent_type = sampler_type_id;
					buffer_type.member_types.push_back(sampler_type_array_id);
				}
				else
					buffer_type.member_types.push_back(sampler_type_id);
			}
			else
			{
				uint32_t binding = get_decoration(var.self, DecorationBinding);
				SetBindingPair pair = { desc_set, binding };

				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
				    resource.basetype == SPIRType::SampledImage)
				{
					// Drop pointer information when we emit the resources into a struct.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding))
					{
						if (!msl_options.supports_msl_version(3, 0))
							SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+.");

						auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
						entry_func.fixup_hooks_in.push_back([this, resource]() {
							emit_argument_buffer_aliased_descriptor(*resource.var, this->get<SPIRVariable>(resource.overlapping_var_id));
						});
					}
					else if (resource.plane == 0)
					{
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
					}
				}
				else if (buffers_requiring_dynamic_offset.count(pair))
				{
					// Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
					buffer_type.member_types.push_back(var.basetype);
					buffers_requiring_dynamic_offset[pair].second = var.self;
				}
				else if (inline_uniform_blocks.count(pair))
				{
					// Put the buffer block itself into the argument buffer.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (atomic_image_vars_emulated.count(var.self))
				{
					// Emulate texture2D atomic operations.
					// Don't set the qualified name: it's already set for this variable,
					// and the code that references the buffer manually appends "_atomic"
					// to the name.
					uint32_t offset = ir.increase_bound_by(2);
					uint32_t atomic_type_id = offset;
					uint32_t type_ptr_id = offset + 1;

					SPIRType atomic_type { OpTypeInt };
					atomic_type.basetype = SPIRType::AtomicCounter;
					atomic_type.width = 32;
					atomic_type.vecsize = 1;
					set<SPIRType>(atomic_type_id, atomic_type);

					atomic_type.op = OpTypePointer;
					atomic_type.pointer = true;
					atomic_type.pointer_depth++;
					atomic_type.parent_type = atomic_type_id;
					atomic_type.storage = StorageClassStorageBuffer;
					auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
					atomic_ptr_type.self = atomic_type_id;

					buffer_type.member_types.push_back(type_ptr_id);
				}
				else
				{
					buffer_type.member_types.push_back(var.basetype);
					if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding))
					{
						// Casting raw pointers is fine since their ABI is fixed, but anything opaque is deeply questionable on Metal 2.
						if (get<SPIRVariable>(resource.overlapping_var_id).storage == StorageClassUniformConstant &&
						    !msl_options.supports_msl_version(3, 0))
						{
							SPIRV_CROSS_THROW("Full mutable aliasing of argument buffer descriptors only works on Metal 3+.");
						}

						auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

						entry_func.fixup_hooks_in.push_back([this, resource]() {
							emit_argument_buffer_aliased_descriptor(*resource.var, this->get<SPIRVariable>(resource.overlapping_var_id));
						});
					}
					else if (type.array.empty())
						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
					else
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
			}

			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
			                               resource.index);
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
			                               var.self);
			if (has_extended_decoration(var.self, SPIRVCrossDecorationOverlappingBinding))
				set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
			member_index++;
		}
	}
}

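// Illustrative note (hypothetical shader): for a set containing a sampled texture
// at [[id(0)]], a sampler at [[id(1)]] and a uniform buffer at [[id(2)]], the pass
// above produces a struct along the lines of
//   struct spvDescriptorSetBuffer0
//   {
//       texture2d<float> uTex [[id(0)]];
//       sampler uTexSmplr [[id(1)]];
//       constant UBO* uUBO [[id(2)]];
//   };
// which the entry point then receives as a single [[buffer(N)]] argument.
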
// Returns the resource binding of the app-provided resource in the given descriptor set
// whose resource index matches the given argument buffer index.
// This is a two-step lookup: first look up the resource binding number from the argument buffer index,
// then look up the resource binding using that binding number.
const MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx) const
{
	auto stage = get_entry_point().model;
	StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx };
	auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple);
	if (arg_itr != end(resource_arg_buff_idx_to_binding_number))
	{
		StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second };
		auto bind_itr = resource_bindings.find(bind_tuple);
		if (bind_itr != end(resource_bindings))
			return bind_itr->second.first;
	}
	SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
	                  "elements, all descriptor set resources must be supplied with a base type by the app.");
}

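// Illustrative note (hypothetical bindings): if the app registered an
// MSLResourceBinding for (set 0, binding 7) with msl_buffer = 3, then
// get_argument_buffer_resource(0, 3) first resolves argument buffer index 3 to
// binding number 7 via resource_arg_buff_idx_to_binding_number, then returns the
// binding registered for (set 0, binding 7) in resource_bindings.
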
// Adds an argument buffer padding buffer type as one or more members of the struct type at the member index.
// Metal does not support arrays of buffers, so these are emitted as multiple struct members.
void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
                                                          uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
{
	if (!argument_buffer_padding_buffer_type_id)
	{
		uint32_t buff_type_id = ir.increase_bound_by(2);
		auto &buff_type = set<SPIRType>(buff_type_id, OpNop);
		buff_type.basetype = rez_bind.basetype;
		buff_type.storage = StorageClassUniformConstant;

		uint32_t ptr_type_id = buff_type_id + 1;
		auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
		ptr_type = buff_type;
		ptr_type.op = spv::OpTypePointer;
		ptr_type.pointer = true;
		ptr_type.pointer_depth++;
		ptr_type.parent_type = buff_type_id;

		argument_buffer_padding_buffer_type_id = ptr_type_id;
	}

	add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
}

// Adds an argument buffer padding image type as a member of the struct type at the member index.
void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
                                                         uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
{
	if (!argument_buffer_padding_image_type_id)
	{
		uint32_t base_type_id = ir.increase_bound_by(2);
		auto &base_type = set<SPIRType>(base_type_id, OpTypeFloat);
		base_type.basetype = SPIRType::Float;
		base_type.width = 32;

		uint32_t img_type_id = base_type_id + 1;
		auto &img_type = set<SPIRType>(img_type_id, OpTypeImage);
		img_type.basetype = SPIRType::Image;
		img_type.storage = StorageClassUniformConstant;

		img_type.image.type = base_type_id;
		img_type.image.dim = Dim2D;
		img_type.image.depth = false;
		img_type.image.arrayed = false;
		img_type.image.ms = false;
		img_type.image.sampled = 1;
		img_type.image.format = ImageFormatUnknown;
		img_type.image.access = AccessQualifierMax;

		argument_buffer_padding_image_type_id = img_type_id;
	}

	add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
}

// Adds an argument buffer padding sampler type as a member of the struct type at the member index.
void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
                                                           uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
{
	if (!argument_buffer_padding_sampler_type_id)
	{
		uint32_t samp_type_id = ir.increase_bound_by(1);
		auto &samp_type = set<SPIRType>(samp_type_id, OpTypeSampler);
		samp_type.basetype = SPIRType::Sampler;
		samp_type.storage = StorageClassUniformConstant;

		argument_buffer_padding_sampler_type_id = samp_type_id;
	}

	add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
}

// Adds the argument buffer padding type as a member of the struct type at the member index.
// Advances both arg_buff_index and mbr_idx to the next argument slots.
void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
                                                   uint32_t &arg_buff_index, uint32_t count)
{
	uint32_t type_id = mbr_type_id;
	if (count > 1)
	{
		uint32_t ary_type_id = ir.increase_bound_by(1);
		auto &ary_type = set<SPIRType>(ary_type_id, get<SPIRType>(type_id));
		ary_type.op = OpTypeArray;
		ary_type.array.push_back(count);
		ary_type.array_size_literal.push_back(true);
		ary_type.parent_type = type_id;
		type_id = ary_type_id;
	}

	set_member_name(struct_type.self, mbr_idx, join("_m", arg_buff_index, "_pad"));
	set_extended_member_decoration(struct_type.self, mbr_idx, SPIRVCrossDecorationResourceIndexPrimary, arg_buff_index);
	struct_type.member_types.push_back(type_id);

	arg_buff_index += count;
	mbr_idx++;
}

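// Illustrative note (hypothetical layout): padding a gap of four buffer slots
// starting at argument buffer index 0 adds a single member named "_m0_pad",
// declared as an array of four pointers (since count > 1), and advances
// arg_buff_index from 0 to 4 so that the next real resource lands at [[id(4)]].
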
void CompilerMSL::activate_argument_buffer_resources()
{
	// For ABI compatibility, force-enable all resources which are part of argument buffers.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
		if (!has_decoration(self, DecorationDescriptorSet))
			return;

		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
		if (descriptor_set_is_argument_buffer(desc_set))
			add_active_interface_variable(self);
	});
}

bool CompilerMSL::using_builtin_array() const
{
	return msl_options.force_native_arrays || is_using_builtin_array;
}

void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
{
	sampler_name_suffix = suffix;
}

const char *CompilerMSL::get_combined_sampler_suffix() const
{
	return sampler_name_suffix.c_str();
}

void CompilerMSL::emit_block_hints(const SPIRBlock &)
{
}

string CompilerMSL::additional_fixed_sample_mask_str() const
{
	char print_buffer[32];
#ifdef _MSC_VER
	// snprintf does not exist or is buggy on older MSVC versions, some of
	// which are used by MinGW. Use sprintf instead and disable the
	// corresponding warning.
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
#if _WIN32
	sprintf(print_buffer, "0x%x", msl_options.additional_fixed_sample_mask);
#else
	snprintf(print_buffer, sizeof(print_buffer), "0x%x", msl_options.additional_fixed_sample_mask);
#endif
#ifdef _MSC_VER
#pragma warning(pop)
#endif
	return print_buffer;
}