spirv_webgpu_transform/
splitdref.rs

1use super::*;
2
/// Which family of image instruction consumes a sampled image.
///
/// The pass sorts every sampling/gather instruction it finds into one of
/// these two buckets, and an image or sampler that ends up used from both
/// buckets is what gets split.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OperationVariant {
    /// A non-comparison operation, e.g. `OpImageSampleImplicitLod` or `OpImageGather`.
    Regular,
    /// A depth-reference comparison operation, e.g. `OpImageSampleDrefImplicitLod`
    /// or `OpImageDrefGather`.
    Dref,
}
8
/// How a variable selected for patching is connected to the `OpLoad`s that read it.
///
/// This decides whether the pass can rewrite the loads directly, or must also
/// thread a new parameter through the function(s) the variable is passed into.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LoadType {
    /// The `OpLoad`s read the `OpVariable` directly.
    Variable,
    /// The variable was found by tracing an `OpFunctionParameter` back through
    /// `OpFunctionCall`s, so the loads go through a function argument.
    FunctionArgument,
}
14
/// Classification of an `OpFunctionParameter` that feeds an image operation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MixState {
    /// The parameter's id is loaded and used by both regular and dref operations.
    Mixed,
    /// The parameter feeds some image operation, but it is only kept if the
    /// variables it traces back to overlap a `Mixed` entry or a mixed global
    /// variable; otherwise it is filtered out.
    PotentiallyMixed,
}
20
/// Marker for the integer kinds a [`PatchObjectType`] payload is allowed to carry
/// in this pass: SPIR-V words/ids (`u32`) and indices into the word stream (`usize`).
trait IsIndexOrId {}
impl IsIndexOrId for u32 {}
impl IsIndexOrId for usize {}

/// Tags a value as belonging to either a sampler or an image object.
///
/// The payload is an id or an instruction index. The sampler/image tag is carried
/// along as the pass walks from image operations back through loads to variables,
/// types, and function parameters.
///
/// The `IsIndexOrId` restriction is intentionally placed on the `impl` rather than
/// on the type definition. Bounds on a type definition are not implied elsewhere,
/// so they would force every mention of `PatchObjectType<T>` to restate them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum PatchObjectType<T> {
    /// The payload refers to a sampler.
    Sampler(T),
    /// The payload refers to an image.
    Image(T),
}

impl<T> PatchObjectType<T>
where
    T: IsIndexOrId,
{
    /// Returns a value of the same variant (sampler or image) that carries `next_id`.
    ///
    /// The current payload is discarded; only the variant is preserved. The payload
    /// kind may change, e.g. from a `u32` id to a `usize` index.
    fn next<N: IsIndexOrId>(self, next_id: N) -> PatchObjectType<N> {
        match self {
            PatchObjectType::Sampler(_) => PatchObjectType::Sampler(next_id),
            PatchObjectType::Image(_) => PatchObjectType::Image(next_id),
        }
    }

    /// Returns the payload regardless of which variant holds it.
    fn inner(self) -> T {
        match self {
            PatchObjectType::Sampler(v) | PatchObjectType::Image(v) => v,
        }
    }
}
49
50/// Perform the operation on a `Vec<u32>`.
51/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
52/// Either update the existing `corrections` or create a new one.
53pub fn drefsplitter(
54    in_spv: &[u32],
55    corrections: &mut Option<CorrectionMap>,
56) -> Result<Vec<u32>, ()> {
57    let spv = in_spv.to_owned();
58
59    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
60    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
61
62    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
63
64    assert_eq!(magic_number, SPV_HEADER_MAGIC);
65
66    let mut instruction_inserts: Vec<InstructionInsert> = vec![];
67    let mut word_inserts: Vec<WordInsert> = vec![];
68
69    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
70    let mut new_spv = spv.clone();
71
72    // 1. Find locations instructions we need
73    let mut op_dref_operation_idxs = vec![];
74    let mut op_sampled_operation_idxs = vec![];
75    let mut op_sampled_image_idxs = vec![];
76    let mut op_load_idxs = vec![];
77    let mut op_variable_idxs = vec![];
78    let mut op_decorate_idxs = vec![];
79    let mut op_type_image_idxs = vec![];
80    let mut op_type_pointer_idxs = vec![];
81    let mut op_type_function_idxs = vec![];
82    let mut op_function_idxs = vec![];
83    let mut op_function_call_idxs = vec![];
84    let mut op_function_parameter_idxs = vec![];
85
86    let mut first_op_type_sampler_id = None;
87    let mut first_op_type_pointer_sampler_id = None;
88
89    let mut spv_idx = 0;
90    while spv_idx < spv.len() {
91        let op = spv[spv_idx];
92        let word_count = hiword(op);
93        let instruction = loword(op);
94
95        match instruction {
96            SPV_INSTRUCTION_OP_SAMPLED_IMAGE => op_sampled_image_idxs.push(spv_idx),
97            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
98            SPV_INSTRUCTION_OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD
99            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD
100            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_DREF_IMPLICIT_LOD
101            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_DREF_EXPLICIT_LOD
102            | SPV_INSTRUCTION_OP_IMAGE_DREF_GATHER
103            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_DREF_IMPLICIT_LOD
104            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_DREF_EXPLICIT_LOD
105            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_DREF_GATHER => op_dref_operation_idxs.push(spv_idx),
106            SPV_INSTRUCTION_OP_IMAGE_SAMPLE_IMPLICIT_LOD
107            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_EXPLICIT_LOD
108            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_IMPLICIT_LOD
109            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_EXPLICIT_LOD
110            | SPV_INSTRUCTION_OP_IMAGE_GATHER
111            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_IMPLICIT_LOD
112            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_EXPLICIT_LOD
113            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_GATHER => op_sampled_operation_idxs.push(spv_idx),
114            SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
115            SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
116            SPV_INSTRUCTION_OP_TYPE_IMAGE => op_type_image_idxs.push(spv_idx),
117            SPV_INSTRUCTION_OP_TYPE_SAMPLER => {
118                first_op_type_sampler_id.get_or_insert(spv[spv_idx + 1]);
119            }
120            SPV_INSTRUCTION_OP_TYPE_POINTER => {
121                if first_op_type_sampler_id == Some(spv[spv_idx + 3])
122                    && spv[spv_idx + 2] == SPV_STORAGE_CLASS_UNIFORM_CONSTANT
123                {
124                    first_op_type_pointer_sampler_id = Some(spv[spv_idx + 1]);
125                }
126                op_type_pointer_idxs.push(spv_idx)
127            }
128            SPV_INSTRUCTION_OP_TYPE_FUNCTION => op_type_function_idxs.push(spv_idx),
129            SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
130            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
131            SPV_INSTRUCTION_OP_FUNCTION_PARAMTER => op_function_parameter_idxs.push(spv_idx),
132            _ => {}
133        }
134
135        spv_idx += word_count as usize;
136    }
137
138    let first_op_deocrate_idx = op_decorate_idxs.first().copied();
139
140    // If there is no OpTypeSampler, either this is invalid, or we do not need to do any patching at all.
141    let (Some(first_op_type_sampler_id), Some(first_op_type_pointer_sampler_id)) =
142        (first_op_type_sampler_id, first_op_type_pointer_sampler_id)
143    else {
144        return Ok(in_spv.to_vec());
145    };
146
147    // 2. Collect all the loaded sampled images of both operation types
148    // Conveniently, the offset for this value is always +3 for all of these operations
149    let loaded_sampled_image_ids = op_sampled_operation_idxs
150        .iter()
151        .map(|idx| (spv[idx + 3], OperationVariant::Regular))
152        .chain(
153            op_dref_operation_idxs
154                .iter()
155                .map(|idx| (spv[idx + 3], OperationVariant::Dref)),
156        )
157        .collect::<Vec<_>>();
158
159    // 3. Backtrace to find the OpSampledImage that resulted in our loaded sampled images
160    let loaded_variable_ids = op_sampled_image_idxs
161        .iter()
162        .filter_map(|idx| {
163            let sampled_result_id = spv[idx + 2];
164            let loaded_image_id = spv[idx + 3];
165            let loaded_sampler_id = spv[idx + 4];
166            loaded_sampled_image_ids
167                .iter()
168                .find_map(|(id, ty)| (*id == sampled_result_id).then_some(ty))
169                .map(|ty| {
170                    [
171                        (PatchObjectType::Image(loaded_image_id), ty),
172                        (PatchObjectType::Sampler(loaded_sampler_id), ty),
173                    ]
174                })
175        })
176        .flatten()
177        .collect::<Vec<_>>();
178
179    // 4. Backtrack to find the OpLoad that resulted in our loaded images
180    let object_ids = op_load_idxs
181        .iter()
182        .filter_map(|idx| {
183            let loaded_result_id = spv[idx + 2];
184            let original_image_or_sampler = spv[idx + 3];
185            loaded_variable_ids
186                .iter()
187                .find_map(|(id, ty)| (id.inner() == loaded_result_id).then_some((id, ty)))
188                .map(|(id, ty)| (id.next(original_image_or_sampler), idx, ty))
189        })
190        .collect::<Vec<_>>();
191
192    // 5. Find the images that mismatch operations
193    let mut mixed_object_ids = HashMap::new();
194    let mut patch_object_id_to_loads = HashMap::new();
195
196    for (id, load_idx, ty) in object_ids.iter().copied() {
197        let entry = mixed_object_ids.entry(id).or_insert((false, false));
198
199        match ty {
200            OperationVariant::Regular => entry.0 = true,
201            OperationVariant::Dref => {
202                entry.1 = true;
203            }
204        }
205        patch_object_id_to_loads
206            .entry(id)
207            .or_insert(vec![])
208            .push((load_idx, ty));
209    }
210
211    let mixed_object_ids = mixed_object_ids
212        .into_iter()
213        .filter_map(|(id, (uses_regular, uses_dref))| (uses_regular && uses_dref).then_some(id))
214        .collect::<Vec<_>>();
215
216    // 6. Find the OpVariable of the mismatched images
217    let patch_variable_idxs = op_variable_idxs
218        .iter()
219        .filter_map(|idx: &usize| {
220            let result_id = spv[*idx + 2];
221            mixed_object_ids
222                .iter()
223                .find(|id| id.inner() == result_id)
224                .map(|id| id.next(*idx))
225        })
226        .collect::<Vec<_>>();
227
228    // 7. Find OpFunctionParameter of ~~the mismatched~~ ALL images operations
229    // Later, we can keep the ones that trace to mismatched global variables
230    let patch_function_parameter_idxs = op_function_parameter_idxs
231        .iter()
232        .filter_map(|idx: &usize| {
233            let result_id = spv[*idx + 2];
234            mixed_object_ids
235                .iter()
236                .find_map(|id| {
237                    (id.inner() == result_id).then_some((id.next(*idx), MixState::Mixed))
238                })
239                .or(object_ids.iter().find_map(|(id, _, _)| {
240                    (id.inner() == result_id).then_some((id.next(*idx), MixState::PotentiallyMixed))
241                }))
242        })
243        .collect::<Vec<_>>();
244
245    // 8. Find the OpVariable that eventually reaches OpFunctionCall of our OpFunctions
246    // Because functions may be deeply nested, we'll have to account for other OpFunctionCalls
247    let function_patch_variables_with_calls = patch_function_parameter_idxs
248        .iter()
249        .map(|&(idx, mix_state)| {
250            let mut traced_function_calls = vec![];
251            let entry = get_function_from_parameter(&spv, idx.inner());
252            let variables =
253                trace_function_argument_to_variables(TraceFunctionArgumentToVariablesIn {
254                    spv: &spv,
255                    op_variable_idxs: &op_variable_idxs,
256                    op_function_parameter_idxs: &op_function_parameter_idxs,
257                    op_function_call_idxs: &op_function_call_idxs,
258                    entry,
259                    traced_function_call_idxs: &mut traced_function_calls,
260                });
261            (
262                variables
263                    .into_iter()
264                    .map(|v| idx.next(v))
265                    .collect::<Vec<_>>(),
266                traced_function_calls,
267                mix_state,
268            )
269        })
270        .collect::<Vec<_>>();
271
272    // Filter out PotentiallyMixed parameters that don't relate to any Mixed function parameters or
273    // mixed variables
274    // TODO: This cannot handle mixing between different contexts, see `test_hidden3_dref.frag`
275    let function_patch_variables_with_calls = function_patch_variables_with_calls
276        .iter()
277        .cloned()
278        .filter_map(|(variables, calls, mix_state)| match mix_state {
279            MixState::Mixed => Some((variables, calls)),
280            MixState::PotentiallyMixed => (function_patch_variables_with_calls.iter().any(
281                |(mixed_variables, _, mix_state)| {
282                    *mix_state == MixState::Mixed
283                        && mixed_variables
284                            .iter()
285                            .any(|va| variables.iter().any(|vb| va == vb))
286                },
287            ) || patch_variable_idxs
288                .iter()
289                .any(|idx| variables.iter().any(|va| va == idx)))
290            .then_some((variables, calls)),
291        })
292        .collect::<Vec<_>>();
293
294    let mut patch_variable_idxs = patch_variable_idxs
295        .iter()
296        .copied()
297        .map(|idx| (idx, LoadType::Variable))
298        .collect::<Vec<_>>();
299
300    let mut function_patch_variable_set = HashSet::new();
301    for (variables, _) in function_patch_variables_with_calls.iter() {
302        for variable in variables {
303            if !function_patch_variable_set.contains(variable) {
304                patch_variable_idxs.push((*variable, LoadType::FunctionArgument));
305                function_patch_variable_set.insert(*variable);
306            }
307        }
308    }
309
310    // 9. Find OpTypePointer that resulted in OpVariable
311    let patch_variable_idxs = patch_variable_idxs.into_iter().map(|(variable_idx, lty)| {
312        let type_pointer_id = spv[variable_idx.inner() + 1];
313        let maybe_tp_idx = op_type_pointer_idxs.iter().find(|&tp_idx| {
314            let tp_id = spv[tp_idx + 1];
315            type_pointer_id == tp_id
316        });
317        (variable_idx, lty, maybe_tp_idx.copied())
318    });
319
320    // 10. Find OpTypeImage that resulted in OpTypePointer
321    //    We also want to create an complement OpTypeImage (depth=!depth) (without duplicates) and
322    //    a respective OpTypePointer ~~and OpTypeSampledImage pair~~ (also no duplicates).
323    let mut existing_type_pointers_from_type_image = HashMap::new();
324    let mut existing_type_images_from_complement_instruction = HashMap::new();
325
326    let patch_variable_idxs = patch_variable_idxs
327        .map(|(variable_idx, lty, tp_idx)| {
328            match variable_idx {
329                v @ PatchObjectType::Sampler(variable_idx) => {
330                    (
331                        v.next(variable_idx),
332                        lty,
333                        first_op_type_sampler_id,
334                        first_op_type_pointer_sampler_id,
335                        first_op_type_sampler_id,
336                        // From the perspective of a SPIRV sampler variable, this doesn't matter
337                        OperationVariant::Dref,
338                    )
339                }
340                v @ PatchObjectType::Image(variable_idx) => {
341                    let variable_result_id = spv[variable_idx];
342                    let image_type_id = if let Some(tp_idx) = tp_idx {
343                        // type_image_id
344                        spv[tp_idx + 3]
345                    } else if let Some(load_idxs) =
346                        patch_object_id_to_loads.get(&PatchObjectType::Image(variable_result_id))
347                        && let Some(&(load_idx, _)) = load_idxs.first()
348                    {
349                        // We don't have a type pointer, let's find the OpTypeImage via our original OpLoad!
350                        // load_type_result_id
351                        spv[load_idx + 1]
352                    } else {
353                        unreachable!(
354                            "Our OpVariable image id should always point back to a OpLoad id"
355                        );
356                    };
357
358                    // Grab the existing type image
359                    let (ti_idx, ti_id) = op_type_image_idxs
360                        .iter()
361                        .find_map(|&ti_idx| {
362                            let result_id = spv[ti_idx + 1];
363                            (result_id == image_type_id).then_some((ti_idx, result_id))
364                        })
365                        .unwrap();
366
367                    // Try to find an type image with the complement properties or (re-)create one
368                    let ti_word_count = hiword(spv[ti_idx]) as usize;
369                    let mut ti_complement = spv[ti_idx + 2..ti_idx + ti_word_count].to_vec();
370                    let complement_ty = match ti_complement[2] {
371                        0 | 2 => {
372                            ti_complement[2] = 1;
373                            OperationVariant::Dref
374                        }
375                        1 => {
376                            ti_complement[2] = 0;
377                            OperationVariant::Regular
378                        }
379                        _ => panic!("depth field on valid spv can only be 0, 1, or 2"),
380                    };
381
382                    let mut new_instructions = vec![];
383
384                    let complement_ti_id = existing_type_images_from_complement_instruction
385                        .get(&ti_complement)
386                        .copied()
387                        .or(op_type_image_idxs.iter().find_map(|&idx| {
388                            let word_count = hiword(spv[idx]) as usize;
389                            let result_id = spv[idx + 1];
390                            // To have a consistent instruction ordering, we white-out the existing OpTypeImage
391                            if ti_complement == spv[idx + 2..idx + word_count] {
392                                for it in new_spv.iter_mut().skip(idx).take(word_count) {
393                                    *it = encode_word(1, SPV_INSTRUCTION_OP_NOP);
394                                }
395                                Some(result_id)
396                            } else {
397                                None
398                            }
399                        }));
400                    let complement_ti_id = {
401                        let new_type_image_id = complement_ti_id.unwrap_or_else(|| {
402                            instruction_bound += 1;
403                            instruction_bound - 1
404                        });
405                        if !existing_type_images_from_complement_instruction
406                            .contains_key(&ti_complement)
407                        {
408                            let mut new_instruction = vec![
409                                encode_word(
410                                    (ti_complement.len() + 2) as u16,
411                                    SPV_INSTRUCTION_OP_TYPE_IMAGE,
412                                ),
413                                new_type_image_id,
414                            ];
415                            existing_type_images_from_complement_instruction
416                                .insert(ti_complement.clone(), new_type_image_id);
417                            new_instruction.append(&mut ti_complement);
418                            drop(ti_complement);
419                            new_instructions.append(&mut new_instruction);
420                        }
421                        new_type_image_id
422                    };
423
424                    // Try to find a type id for complement type image or create one
425                    let complement_tp_id = existing_type_pointers_from_type_image
426                        .get(&complement_ti_id)
427                        .copied()
428                        .or(op_type_pointer_idxs.iter().find_map(|&idx| {
429                            let result_id = spv[idx + 1];
430                            let type_id = spv[idx + 3];
431                            if type_id == complement_ti_id {
432                                existing_type_pointers_from_type_image
433                                    .insert(complement_ti_id, result_id);
434                                Some(result_id)
435                            } else {
436                                None
437                            }
438                        }))
439                        .unwrap_or_else(|| {
440                            let new_type_pointer_id = instruction_bound;
441                            instruction_bound += 1;
442                            let mut new_instruction = vec![
443                                encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
444                                new_type_pointer_id,
445                                SPV_STORAGE_CLASS_UNIFORM_CONSTANT,
446                                complement_ti_id,
447                            ];
448                            new_instructions.append(&mut new_instruction);
449                            existing_type_pointers_from_type_image
450                                .insert(complement_ti_id, new_type_pointer_id);
451                            new_type_pointer_id
452                        });
453
454                    instruction_inserts.push(InstructionInsert {
455                        previous_spv_idx: ti_idx,
456                        instruction: new_instructions,
457                    });
458
459                    (
460                        v.next(variable_idx),
461                        lty,
462                        ti_id,
463                        complement_tp_id,
464                        complement_ti_id,
465                        complement_ty,
466                    )
467                }
468            }
469        })
470        .collect::<Vec<_>>();
471
472    // 11. New OpVariable with a new_id, patch old OpLoads, and new depth=1 OpTypeImage.
473    // Map new function arguments to the correct instructions.
474    // NOTE: GENERALLY, with glslc, each OpImage* will get its own OpLoad, so we don't need to
475    // check that its result isn't used for both regular and dref operations!
476    let mut affected_variables = Vec::new();
477
478    // There may be a shared OpTypeFunction but not shared OpFunctionParameter
479    let mut patched_function_types = HashMap::new();
480    let mut patched_function_parameters = HashSet::new();
481
482    // We may patch ourselves a new OpTypeFunction multiple times.
483    // Maps function type id and function index to our new type.
484    let mut defered_new_function_types: HashMap<(u32, usize), InstructionInsert> = HashMap::new();
485
486    for (
487        variable_idx_typed,
488        lty,
489        original_ti_id,
490        complement_tp_id,
491        complement_ti_id,
492        complement_ty,
493    ) in patch_variable_idxs
494    {
495        let variable_idx = variable_idx_typed.inner();
496        // OpVariable
497        let word_count = hiword(spv[variable_idx]);
498        let new_variable_id = instruction_bound;
499        instruction_bound += 1;
500        let mut new_variable = Vec::new();
501        new_variable.extend_from_slice(&spv[variable_idx..variable_idx + word_count as usize]);
502        new_variable[1] = complement_tp_id;
503        new_variable[2] = new_variable_id;
504        instruction_inserts.push(InstructionInsert {
505            previous_spv_idx: variable_idx,
506            instruction: new_variable,
507        });
508
509        affected_variables.push(util::DecorationVariable {
510            original_res_id: spv[variable_idx + 2],
511            new_res_id: new_variable_id,
512            correction_type: match complement_ty {
513                OperationVariant::Regular => CorrectionType::SplitDrefComparison,
514                OperationVariant::Dref => CorrectionType::SplitDrefRegular,
515            },
516        });
517
518        // OpLoad
519        match lty {
520            LoadType::Variable => {
521                let old_variable_id = spv[variable_idx + 2];
522                if let Some(op_load_idxs) =
523                    patch_object_id_to_loads.get(&variable_idx_typed.next(old_variable_id))
524                {
525                    for &(op_load_idx, ty) in op_load_idxs {
526                        if **ty == complement_ty {
527                            new_spv[op_load_idx + 1] = complement_ti_id;
528                            new_spv[op_load_idx + 3] = new_variable_id;
529                        } else {
530                            new_spv[op_load_idx + 1] = original_ti_id;
531                            new_spv[op_load_idx + 3] = old_variable_id;
532                        };
533                    }
534                }
535            }
536            LoadType::FunctionArgument => {
537                let mut function_id_and_index_to_new_parameter_id = HashMap::new();
538
539                // Patch function types, definition parameter, and final loads
540                for (variables, calls) in function_patch_variables_with_calls.iter() {
541                    if variables.contains(&variable_idx_typed.next(variable_idx)) {
542                        for &call in calls.iter().rev() {
543                            let function_id = spv[call.call_parameter.function_idx + 2];
544                            let type_function_id = spv[call.call_parameter.function_idx + 4];
545                            if !patched_function_parameters.contains(&(
546                                call.call_parameter.parameter_instruction_idx,
547                                spv[call.call_parameter.function_idx + 2],
548                            )) {
549                                let Some(type_function_idx) =
550                                    op_type_function_idxs.iter().find(|&&idx| {
551                                        let result_id = spv[idx + 1];
552                                        type_function_id == result_id
553                                    })
554                                else {
555                                    panic!(
556                                        "OpTypeFunction does not exist for function {}, type {}",
557                                        function_id, type_function_id
558                                    );
559                                };
560
561                                // To allow multiple patching we can either patch by taking
562                                // an instruction directly from the code, or by patching a
563                                // function we have already began
564                                // We save patching the function's type for later in case of
565                                // duplicate OpTypeFunction
566                                {
567                                    let (new_type_function_id, type_instruction_type_info) =
568                                        if let Some(new_function_type) = defered_new_function_types
569                                            .get_mut(&(
570                                                type_function_id,
571                                                call.call_parameter.function_idx,
572                                            ))
573                                        {
574                                            new_function_type.instruction.insert(
575                                                3 + call.call_parameter.parameter_instruction_idx
576                                                    + 1
577                                                    + 1,
578                                                complement_tp_id,
579                                            );
580                                            new_function_type.instruction[0] = encode_word(
581                                                new_function_type.instruction.len() as u16,
582                                                SPV_INSTRUCTION_OP_TYPE_FUNCTION,
583                                            );
584                                            (
585                                                new_function_type.instruction[1],
586                                                new_function_type.instruction[2..].to_vec(),
587                                            )
588                                        } else {
589                                            let new_function_type_id = instruction_bound;
590                                            instruction_bound += 1;
591
592                                            let word_count = hiword(spv[*type_function_idx]);
593                                            let mut type_function = Vec::new();
594                                            type_function.extend_from_slice(
595                                                &spv[*type_function_idx
596                                                    ..*type_function_idx + word_count as usize],
597                                            );
598                                            type_function[0] = encode_word(
599                                                word_count + 1,
600                                                SPV_INSTRUCTION_OP_TYPE_FUNCTION,
601                                            );
602                                            type_function[1] = new_function_type_id;
603                                            type_function.insert(
604                                                3 + call.call_parameter.parameter_instruction_idx
605                                                    + 1,
606                                                complement_tp_id,
607                                            );
608
609                                            let type_instruction_type_info =
610                                                type_function[2..].to_vec();
611
612                                            defered_new_function_types.insert(
613                                                (
614                                                    type_function_id,
615                                                    call.call_parameter.function_idx,
616                                                ),
617                                                InstructionInsert {
618                                                    previous_spv_idx: *type_function_idx,
619                                                    instruction: type_function,
620                                                },
621                                            );
622                                            (new_function_type_id, type_instruction_type_info)
623                                        };
624                                    let entry = patched_function_types
625                                        .entry(type_instruction_type_info)
626                                        .or_insert((new_type_function_id, vec![]));
627                                    entry
628                                        .1
629                                        .push((type_function_id, call.call_parameter.function_idx));
630                                }
631
632                                // Patch function parameter
633                                let new_parameter_id = instruction_bound;
634                                instruction_bound += 1;
635                                instruction_inserts.push(InstructionInsert {
636                                    previous_spv_idx: call.call_parameter.parameter_idx,
637                                    instruction: vec![
638                                        encode_word(3, SPV_INSTRUCTION_OP_FUNCTION_PARAMTER),
639                                        complement_tp_id,
640                                        new_parameter_id,
641                                    ],
642                                });
643
644                                // Use our new parameters to patch dependent OpLoads
645                                for load_idx in op_load_idxs.iter() {
646                                    let result_id = spv[load_idx + 2];
647                                    let ptr_id = spv[load_idx + 3];
648                                    let parameter_result_id =
649                                        spv[call.call_parameter.parameter_idx + 2];
650
651                                    // TODO: OPT Someone else can come by and rearrange these silly data
652                                    // structures later.
653                                    if ptr_id == parameter_result_id {
654                                        let ty = loaded_variable_ids
655                                            .iter()
656                                            .find_map(|&(id, ty)| {
657                                                (id.inner() == result_id).then_some(ty)
658                                            })
659                                            .unwrap();
660                                        if *ty == complement_ty {
661                                            new_spv[load_idx + 1] = complement_ti_id;
662                                            new_spv[load_idx + 3] = new_parameter_id;
663                                        }
664                                    }
665                                }
666
667                                let function_id = spv[call.function_call_idx + 3];
668                                function_id_and_index_to_new_parameter_id.insert(
669                                    (function_id, call.call_parameter.parameter_instruction_idx),
670                                    new_parameter_id,
671                                );
672                                patched_function_parameters.insert((
673                                    call.call_parameter.parameter_instruction_idx,
674                                    function_id,
675                                ));
676                            }
677                        }
678                    }
679                }
680
681                // Patch function calls that call other functions
682                for (variables, calls) in function_patch_variables_with_calls.iter() {
683                    if variables.contains(&variable_idx_typed.next(variable_idx)) {
684                        for &call in calls.iter().rev() {
685                            let function_idx = get_function_index_of_instruction_index(
686                                &spv,
687                                call.function_call_idx,
688                            );
689                            let function_id = spv[function_idx + 2];
690                            if let Some(parameter_word) = function_id_and_index_to_new_parameter_id
691                                .get(&(function_id, call.call_parameter.parameter_instruction_idx))
692                            {
693                                word_inserts.push(WordInsert {
694                                    idx: call.function_call_idx
695                                        + 4
696                                        + call.call_parameter.parameter_instruction_idx,
697                                    word: *parameter_word,
698                                    head_idx: call.function_call_idx,
699                                });
700                            } else {
701                                word_inserts.push(WordInsert {
702                                    idx: call.function_call_idx
703                                        + 4
704                                        + call.call_parameter.parameter_instruction_idx,
705                                    word: new_variable_id,
706                                    head_idx: call.function_call_idx,
707                                });
708                            }
709                        }
710                    }
711                }
712            }
713        }
714
715        // OpSampledImage
716        // NOTE: We did not patch in a new OpSampledImage and OpTypeSampledImage.
717        // Thankfully, it seems that `spirv-val`, `naga`, nor `tint` seem to care.
718    }
719
720    // 12. Remove duplicate function types and patch them into OpFunction
721    for (_, (new_type_function_id, functions)) in patched_function_types.into_iter() {
722        for (idx, &(type_function_id, function_idx)) in functions.iter().enumerate() {
723            if idx != 0 {
724                defered_new_function_types.remove(&(type_function_id, function_idx));
725            }
726            new_spv[function_idx + 4] = new_type_function_id;
727        }
728    }
729
730    // We now insert our new function types
731    for (_, new_instruction) in defered_new_function_types {
732        instruction_inserts.push(new_instruction)
733    }
734
735    // 13. Insert new OpDecorate
736    let DecorateOut {
737        descriptor_sets_to_correct,
738    } = util::decorate(DecorateIn {
739        spv: &spv,
740        instruction_inserts: &mut instruction_inserts,
741        first_op_deocrate_idx,
742        op_decorate_idxs: &op_decorate_idxs,
743        affected_variables: &affected_variables,
744        corrections,
745    });
746
747    // 14. Insert New Instructions
748    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
749
750    // 15. Correct OpDecorate Bindings
751    util::correct_decorate(CorrectDecorateIn {
752        new_spv: &mut new_spv,
753        descriptor_sets_to_correct,
754    });
755
756    // 16. Remove Instructions that have been Whited Out.
757    prune_noops(&mut new_spv);
758
759    // 17. Write New Header and New Code
760    Ok(fuse_final(spv_header, new_spv, instruction_bound))
761}