Skip to main content

spirv_webgpu_transform/
splitdref.rs

1use super::*;
2
/// Which family of image operation consumed a loaded sampled image.
///
/// `drefsplitter` tags every image sample/gather instruction it finds with one
/// of these so it can later detect resources that are used by both families.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OperationVariant {
    /// The resource was used by a non-depth-comparison sample or gather
    /// instruction (`OpImageSample*` / `OpImageGather` without `Dref`).
    Regular,
    /// The resource was used by a depth-comparison instruction
    /// (`OpImageSample*Dref*` / `OpImage*DrefGather`).
    Dref,
}
8
/// How a variable that needs patching was discovered, which decides how its
/// dependent `OpLoad`s get rewritten.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoadType {
    /// The `OpVariable` is loaded directly; its `OpLoad`s are patched in place.
    Variable,
    /// The `OpVariable` was reached by tracing a function parameter back
    /// through `OpFunctionCall`s; function types, parameters and the loads of
    /// those parameters are patched instead.
    FunctionArgument,
}
14
/// Whether a function parameter is known to be used by both regular and Dref
/// operations, or only might be once its call sites are taken into account.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MixState {
    /// The parameter's id is itself among the ids used by both operation
    /// families.
    Mixed,
    /// The parameter's id is used by image operations but is not mixed on its
    /// own; it is only kept for patching if the variables it traces back to
    /// turn out to be mixed.
    PotentiallyMixed,
}
20
/// Marker trait restricting the payload of [`PatchObjectType`] to the two
/// integer types this module stores in it.
///
/// It has no methods; it only exists as a bound. In this file `u32` payloads
/// are ids read out of the SPIR-V word stream and `usize` payloads are indices
/// into that stream.
trait IsIndexOrId {}
impl IsIndexOrId for u32 {}
impl IsIndexOrId for usize {}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
26enum PatchObjectType<T: IsIndexOrId> {
27    Sampler(T),
28    Image(T),
29}
30
31impl<T> PatchObjectType<T>
32where
33    T: IsIndexOrId,
34{
35    fn next<N: IsIndexOrId>(self, next_id: N) -> PatchObjectType<N> {
36        match self {
37            PatchObjectType::Sampler(_) => PatchObjectType::Sampler(next_id),
38            PatchObjectType::Image(_) => PatchObjectType::Image(next_id),
39        }
40    }
41
42    fn inner(self) -> T {
43        match self {
44            PatchObjectType::Sampler(v) => v,
45            PatchObjectType::Image(v) => v,
46        }
47    }
48}
49
50/// Perform the operation on a `Vec<u32>`.
51/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
52/// Either update the existing `corrections` or create a new one.
53pub fn drefsplitter(
54    in_spv: &[u32],
55    corrections: &mut Option<CorrectionMap>,
56) -> Result<Vec<u32>, ()> {
57    let spv = in_spv.to_owned();
58
59    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
60    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
61
62    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
63
64    assert_eq!(magic_number, SPV_HEADER_MAGIC);
65
66    let mut instruction_inserts: Vec<InstructionInsert> = vec![];
67    let mut word_inserts: Vec<WordInsert> = vec![];
68
69    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
70    let mut new_spv = spv.clone();
71
72    // 1. Find locations instructions we need
73    let mut op_dref_operation_idxs = vec![];
74    let mut op_sampled_operation_idxs = vec![];
75    let mut op_sampled_image_idxs = vec![];
76    let mut op_load_idxs = vec![];
77    let mut op_variable_idxs = vec![];
78    let mut op_decorate_idxs = vec![];
79    let mut op_type_image_idxs = vec![];
80    let mut op_type_pointer_idxs = vec![];
81    let mut op_type_function_idxs = vec![];
82    let mut op_function_idxs = vec![];
83    let mut op_function_call_idxs = vec![];
84    let mut op_function_parameter_idxs = vec![];
85
86    let mut first_op_type_sampler_id = None;
87    let mut first_op_type_pointer_sampler_id = None;
88
89    let mut spv_idx = 0;
90    while spv_idx < spv.len() {
91        let op = spv[spv_idx];
92        let word_count = hiword(op);
93        let instruction = loword(op);
94
95        match instruction {
96            SPV_INSTRUCTION_OP_SAMPLED_IMAGE => op_sampled_image_idxs.push(spv_idx),
97            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
98            SPV_INSTRUCTION_OP_IMAGE_SAMPLE_DREF_IMPLICIT_LOD
99            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_DREF_EXPLICIT_LOD
100            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_DREF_IMPLICIT_LOD
101            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_DREF_EXPLICIT_LOD
102            | SPV_INSTRUCTION_OP_IMAGE_DREF_GATHER
103            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_DREF_IMPLICIT_LOD
104            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_DREF_EXPLICIT_LOD
105            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_DREF_GATHER => op_dref_operation_idxs.push(spv_idx),
106            SPV_INSTRUCTION_OP_IMAGE_SAMPLE_IMPLICIT_LOD
107            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_EXPLICIT_LOD
108            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_IMPLICIT_LOD
109            | SPV_INSTRUCTION_OP_IMAGE_SAMPLE_PROJ_EXPLICIT_LOD
110            | SPV_INSTRUCTION_OP_IMAGE_GATHER
111            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_IMPLICIT_LOD
112            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_SAMPLE_EXPLICIT_LOD
113            | SPV_INSTRUCTION_OP_IMAGE_SPARSE_GATHER => op_sampled_operation_idxs.push(spv_idx),
114            SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
115            SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
116            SPV_INSTRUCTION_OP_TYPE_IMAGE => op_type_image_idxs.push(spv_idx),
117            SPV_INSTRUCTION_OP_TYPE_SAMPLER => {
118                first_op_type_sampler_id.get_or_insert(spv[spv_idx + 1]);
119            }
120            SPV_INSTRUCTION_OP_TYPE_POINTER => {
121                if first_op_type_sampler_id == Some(spv[spv_idx + 3])
122                    && spv[spv_idx + 2] == SPV_STORAGE_CLASS_UNIFORM_CONSTANT
123                {
124                    first_op_type_pointer_sampler_id = Some(spv[spv_idx + 1]);
125                }
126                op_type_pointer_idxs.push(spv_idx)
127            }
128            SPV_INSTRUCTION_OP_TYPE_FUNCTION => op_type_function_idxs.push(spv_idx),
129            SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
130            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
131            SPV_INSTRUCTION_OP_FUNCTION_PARAMETER => op_function_parameter_idxs.push(spv_idx),
132            _ => {}
133        }
134
135        spv_idx += word_count as usize;
136    }
137
138    let first_op_deocrate_idx = op_decorate_idxs.first().copied();
139
140    // If there is no OpTypeSampler, either this is invalid, or we do not need to do any patching at all.
141    let (Some(first_op_type_sampler_id), Some(first_op_type_pointer_sampler_id)) =
142        (first_op_type_sampler_id, first_op_type_pointer_sampler_id)
143    else {
144        return Ok(in_spv.to_vec());
145    };
146
147    // 2. Collect all the loaded sampled images of both operation types
148    // Conveniently, the offset for this value is always +3 for all of these operations
149    let loaded_sampled_image_ids = op_sampled_operation_idxs
150        .iter()
151        .map(|idx| (spv[idx + 3], OperationVariant::Regular))
152        .chain(
153            op_dref_operation_idxs
154                .iter()
155                .map(|idx| (spv[idx + 3], OperationVariant::Dref)),
156        )
157        .collect::<Vec<_>>();
158
159    // 3. Backtrace to find the OpSampledImage that resulted in our loaded sampled images
160    let loaded_variable_ids = op_sampled_image_idxs
161        .iter()
162        .filter_map(|idx| {
163            let sampled_result_id = spv[idx + 2];
164            let loaded_image_id = spv[idx + 3];
165            let loaded_sampler_id = spv[idx + 4];
166            loaded_sampled_image_ids
167                .iter()
168                .find_map(|(id, ty)| (*id == sampled_result_id).then_some(ty))
169                .map(|ty| {
170                    [
171                        (PatchObjectType::Image(loaded_image_id), ty),
172                        (PatchObjectType::Sampler(loaded_sampler_id), ty),
173                    ]
174                })
175        })
176        .flatten()
177        .collect::<Vec<_>>();
178
179    // 4. Backtrack to find the OpLoad that resulted in our loaded images
180    let object_ids = op_load_idxs
181        .iter()
182        .filter_map(|idx| {
183            let loaded_result_id = spv[idx + 2];
184            let original_image_or_sampler = spv[idx + 3];
185            loaded_variable_ids
186                .iter()
187                .find_map(|(id, ty)| (id.inner() == loaded_result_id).then_some((id, ty)))
188                .map(|(id, ty)| (id.next(original_image_or_sampler), idx, ty))
189        })
190        .collect::<Vec<_>>();
191
192    // 5. Find the images that mismatch operations
193    let mut object_flags = HashMap::new();
194    let mut patch_object_id_to_loads = HashMap::new();
195
196    for (id, load_idx, ty) in object_ids.iter().copied() {
197        let entry = object_flags.entry(id).or_insert((false, false));
198
199        match ty {
200            OperationVariant::Regular => entry.0 = true,
201            OperationVariant::Dref => {
202                entry.1 = true;
203            }
204        }
205        patch_object_id_to_loads
206            .entry(id)
207            .or_insert(vec![])
208            .push((load_idx, ty));
209    }
210
211    let mixed_object_ids = object_flags
212        .iter()
213        .filter_map(|(&id, &(uses_regular, uses_dref))| (uses_regular && uses_dref).then_some(id))
214        .collect::<Vec<_>>();
215
216    // 6. Mix object flags from across contexts (OpVariable + OpFunctionArgument)
217    // See `test_cross_dref.spv` and `test_hidden3_dref.spv`.
218    let mut aggregate_flags: HashMap<PatchObjectType<usize>, (bool, bool)> = HashMap::new();
219    for (&id, &flags) in object_flags.iter() {
220        if let Some(&v_idx) = op_variable_idxs
221            .iter()
222            .find(|&&idx| spv[idx + 2] == id.inner())
223        {
224            let entry = aggregate_flags
225                .entry(id.next(v_idx))
226                .or_insert((false, false));
227            entry.0 |= flags.0;
228            entry.1 |= flags.1;
229        }
230    }
231    for (&id, &flags) in object_flags.iter() {
232        if let Some(&fp_idx) = op_function_parameter_idxs
233            .iter()
234            .find(|&&idx| spv[idx + 2] == id.inner())
235        {
236            let entry = get_function_from_parameter(&spv, fp_idx);
237            let mut traced = vec![];
238            let variables =
239                trace_function_argument_to_variables(TraceFunctionArgumentToVariablesIn {
240                    spv: &spv,
241                    op_variable_idxs: &op_variable_idxs,
242                    op_function_parameter_idxs: &op_function_parameter_idxs,
243                    op_function_call_idxs: &op_function_call_idxs,
244                    entry,
245                    traced_function_call_idxs: &mut traced,
246                });
247            for v_idx in variables {
248                if let Some(entry) = aggregate_flags.get_mut(&id.next(v_idx)) {
249                    entry.0 |= flags.0;
250                    entry.1 |= flags.1;
251                }
252            }
253        }
254    }
255    let aggregate_mixed_variables = aggregate_flags
256        .into_iter()
257        .filter_map(|(g, (uses_regular, uses_dref))| (uses_regular && uses_dref).then_some(g))
258        .collect::<HashSet<_>>();
259
260    // 6. Find the OpVariable of the mismatched images
261    let patch_variable_idxs = op_variable_idxs
262        .iter()
263        .filter_map(|idx: &usize| {
264            let result_id = spv[*idx + 2];
265            mixed_object_ids
266                .iter()
267                .find(|id| id.inner() == result_id)
268                .map(|id| id.next(*idx))
269        })
270        .collect::<Vec<_>>();
271
272    // 7. Find OpFunctionParameter of ~~the mismatched~~ ALL images operations
273    // Later, we can keep the ones that trace to mismatched global variables
274    let patch_function_parameter_idxs = op_function_parameter_idxs
275        .iter()
276        .filter_map(|idx: &usize| {
277            let result_id = spv[*idx + 2];
278            mixed_object_ids
279                .iter()
280                .find_map(|id| {
281                    (id.inner() == result_id).then_some((id.next(*idx), MixState::Mixed))
282                })
283                .or(object_ids.iter().find_map(|(id, _, _)| {
284                    (id.inner() == result_id).then_some((id.next(*idx), MixState::PotentiallyMixed))
285                }))
286        })
287        .collect::<Vec<_>>();
288
289    // 8. Find the OpVariable that eventually reaches OpFunctionCall of our OpFunctions
290    // Because functions may be deeply nested, we'll have to account for other OpFunctionCalls
291    let function_patch_variables_with_calls = patch_function_parameter_idxs
292        .iter()
293        .map(|&(idx, mix_state)| {
294            let mut traced_function_calls = vec![];
295            let entry = get_function_from_parameter(&spv, idx.inner());
296            let variables =
297                trace_function_argument_to_variables(TraceFunctionArgumentToVariablesIn {
298                    spv: &spv,
299                    op_variable_idxs: &op_variable_idxs,
300                    op_function_parameter_idxs: &op_function_parameter_idxs,
301                    op_function_call_idxs: &op_function_call_idxs,
302                    entry,
303                    traced_function_call_idxs: &mut traced_function_calls,
304                });
305            (
306                variables
307                    .into_iter()
308                    .map(|v| idx.next(v))
309                    .collect::<Vec<_>>(),
310                traced_function_calls,
311                mix_state,
312            )
313        })
314        .collect::<Vec<_>>();
315
316    let function_patch_variables_with_calls = function_patch_variables_with_calls
317        .iter()
318        .cloned()
319        .filter_map(|(variables, calls, mix_state)| match mix_state {
320            MixState::Mixed => Some((variables, calls)),
321            MixState::PotentiallyMixed => (function_patch_variables_with_calls.iter().any(
322                |(mixed_variables, _, mix_state)| {
323                    *mix_state == MixState::Mixed
324                        && mixed_variables
325                            .iter()
326                            .any(|va| variables.iter().any(|vb| va == vb))
327                },
328            ) || patch_variable_idxs
329                .iter()
330                .any(|idx| variables.iter().any(|va| va == idx))
331                || variables
332                    .iter()
333                    .any(|va| aggregate_mixed_variables.contains(va)))
334            .then_some((variables, calls)),
335        })
336        .collect::<Vec<_>>();
337
338    let mut patch_variable_idxs = patch_variable_idxs
339        .iter()
340        .copied()
341        .map(|idx| (idx, LoadType::Variable))
342        .collect::<Vec<_>>();
343
344    let mut function_patch_variable_set = HashSet::new();
345    for (variables, _) in function_patch_variables_with_calls.iter() {
346        for variable in variables {
347            if !function_patch_variable_set.contains(variable) {
348                patch_variable_idxs.push((*variable, LoadType::FunctionArgument));
349                function_patch_variable_set.insert(*variable);
350            }
351        }
352    }
353
354    // 9. Find OpTypePointer that resulted in OpVariable
355    let patch_variable_idxs = patch_variable_idxs.into_iter().map(|(variable_idx, lty)| {
356        let type_pointer_id = spv[variable_idx.inner() + 1];
357        let maybe_tp_idx = op_type_pointer_idxs.iter().find(|&tp_idx| {
358            let tp_id = spv[tp_idx + 1];
359            type_pointer_id == tp_id
360        });
361        (variable_idx, lty, maybe_tp_idx.copied())
362    });
363
364    // 10. Find OpTypeImage that resulted in OpTypePointer
365    //    We also want to create an complement OpTypeImage (depth=!depth) (without duplicates) and
366    //    a respective OpTypePointer ~~and OpTypeSampledImage pair~~ (also no duplicates).
367    let mut existing_type_pointers_from_type_image = HashMap::new();
368    let mut existing_type_images_from_complement_instruction = HashMap::new();
369
370    let patch_variable_idxs = patch_variable_idxs
371        .map(|(variable_idx, lty, tp_idx)| {
372            match variable_idx {
373                v @ PatchObjectType::Sampler(variable_idx) => {
374                    (
375                        v.next(variable_idx),
376                        lty,
377                        first_op_type_sampler_id,
378                        first_op_type_pointer_sampler_id,
379                        first_op_type_sampler_id,
380                        // From the perspective of a SPIRV sampler variable, this doesn't matter
381                        OperationVariant::Dref,
382                    )
383                }
384                v @ PatchObjectType::Image(variable_idx) => {
385                    let variable_result_id = spv[variable_idx];
386                    let image_type_id = if let Some(tp_idx) = tp_idx {
387                        // type_image_id
388                        spv[tp_idx + 3]
389                    } else if let Some(load_idxs) =
390                        patch_object_id_to_loads.get(&PatchObjectType::Image(variable_result_id))
391                        && let Some(&(load_idx, _)) = load_idxs.first()
392                    {
393                        // We don't have a type pointer, let's find the OpTypeImage via our original OpLoad!
394                        // load_type_result_id
395                        spv[load_idx + 1]
396                    } else {
397                        unreachable!(
398                            "Our OpVariable image id should always point back to a OpLoad id"
399                        );
400                    };
401
402                    // Grab the existing type image
403                    let (ti_idx, ti_id) = op_type_image_idxs
404                        .iter()
405                        .find_map(|&ti_idx| {
406                            let result_id = spv[ti_idx + 1];
407                            (result_id == image_type_id).then_some((ti_idx, result_id))
408                        })
409                        .unwrap();
410
411                    // Try to find an type image with the complement properties or (re-)create one
412                    let ti_word_count = hiword(spv[ti_idx]) as usize;
413                    let mut ti_complement = spv[ti_idx + 2..ti_idx + ti_word_count].to_vec();
414                    let complement_ty = match ti_complement[2] {
415                        0 | 2 => {
416                            ti_complement[2] = 1;
417                            OperationVariant::Dref
418                        }
419                        1 => {
420                            ti_complement[2] = 0;
421                            OperationVariant::Regular
422                        }
423                        _ => panic!("depth field on valid spv can only be 0, 1, or 2"),
424                    };
425
426                    let mut new_instructions = vec![];
427
428                    let complement_ti_id = existing_type_images_from_complement_instruction
429                        .get(&ti_complement)
430                        .copied()
431                        .or(op_type_image_idxs.iter().find_map(|&idx| {
432                            let word_count = hiword(spv[idx]) as usize;
433                            let result_id = spv[idx + 1];
434                            // To have a consistent instruction ordering, we white-out the existing OpTypeImage
435                            if ti_complement == spv[idx + 2..idx + word_count] {
436                                for it in new_spv.iter_mut().skip(idx).take(word_count) {
437                                    *it = encode_word(1, SPV_INSTRUCTION_OP_NOP);
438                                }
439                                Some(result_id)
440                            } else {
441                                None
442                            }
443                        }));
444                    let complement_ti_id = {
445                        let new_type_image_id = complement_ti_id.unwrap_or_else(|| {
446                            instruction_bound += 1;
447                            instruction_bound - 1
448                        });
449                        if !existing_type_images_from_complement_instruction
450                            .contains_key(&ti_complement)
451                        {
452                            let mut new_instruction = vec![
453                                encode_word(
454                                    (ti_complement.len() + 2) as u16,
455                                    SPV_INSTRUCTION_OP_TYPE_IMAGE,
456                                ),
457                                new_type_image_id,
458                            ];
459                            existing_type_images_from_complement_instruction
460                                .insert(ti_complement.clone(), new_type_image_id);
461                            new_instruction.append(&mut ti_complement);
462                            drop(ti_complement);
463                            new_instructions.append(&mut new_instruction);
464                        }
465                        new_type_image_id
466                    };
467
468                    // Try to find a type id for complement type image or create one
469                    let complement_tp_id = existing_type_pointers_from_type_image
470                        .get(&complement_ti_id)
471                        .copied()
472                        .or(op_type_pointer_idxs.iter().find_map(|&idx| {
473                            let result_id = spv[idx + 1];
474                            let type_id = spv[idx + 3];
475                            if type_id == complement_ti_id {
476                                existing_type_pointers_from_type_image
477                                    .insert(complement_ti_id, result_id);
478                                Some(result_id)
479                            } else {
480                                None
481                            }
482                        }))
483                        .unwrap_or_else(|| {
484                            let new_type_pointer_id = instruction_bound;
485                            instruction_bound += 1;
486                            let mut new_instruction = vec![
487                                encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
488                                new_type_pointer_id,
489                                SPV_STORAGE_CLASS_UNIFORM_CONSTANT,
490                                complement_ti_id,
491                            ];
492                            new_instructions.append(&mut new_instruction);
493                            existing_type_pointers_from_type_image
494                                .insert(complement_ti_id, new_type_pointer_id);
495                            new_type_pointer_id
496                        });
497
498                    instruction_inserts.push(InstructionInsert {
499                        previous_spv_idx: ti_idx,
500                        instruction: new_instructions,
501                    });
502
503                    (
504                        v.next(variable_idx),
505                        lty,
506                        ti_id,
507                        complement_tp_id,
508                        complement_ti_id,
509                        complement_ty,
510                    )
511                }
512            }
513        })
514        .collect::<Vec<_>>();
515
516    // 11. New OpVariable with a new_id, patch old OpLoads, and new depth=1 OpTypeImage.
517    // Map new function arguments to the correct instructions.
518    // NOTE: GENERALLY, with glslc, each OpImage* will get its own OpLoad, so we don't need to
519    // check that its result isn't used for both regular and dref operations!
520    let mut affected_variables = Vec::new();
521
522    // There may be a shared OpTypeFunction but not shared OpFunctionParameter
523    let mut patched_function_types = HashMap::new();
524    let mut patched_function_parameters = HashSet::new();
525
526    // We may patch ourselves a new OpTypeFunction multiple times.
527    // Maps function type id and function index to our new type.
528    let mut defered_new_function_types: HashMap<(u32, usize), InstructionInsert> = HashMap::new();
529
530    for (
531        variable_idx_typed,
532        lty,
533        original_ti_id,
534        complement_tp_id,
535        complement_ti_id,
536        complement_ty,
537    ) in patch_variable_idxs
538    {
539        let variable_idx = variable_idx_typed.inner();
540        // OpVariable
541        let word_count = hiword(spv[variable_idx]);
542        let new_variable_id = instruction_bound;
543        instruction_bound += 1;
544        let mut new_variable = Vec::new();
545        new_variable.extend_from_slice(&spv[variable_idx..variable_idx + word_count as usize]);
546        new_variable[1] = complement_tp_id;
547        new_variable[2] = new_variable_id;
548        instruction_inserts.push(InstructionInsert {
549            previous_spv_idx: variable_idx,
550            instruction: new_variable,
551        });
552
553        affected_variables.push(AffectedDecoration {
554            original_res_id: spv[variable_idx + 2],
555            new_res_ids: vec![new_variable_id],
556            correction_type: match complement_ty {
557                OperationVariant::Regular => CorrectionType::SplitDrefRegular,
558                OperationVariant::Dref => CorrectionType::SplitDrefComparison,
559            },
560        });
561
562        // OpLoad
563        match lty {
564            LoadType::Variable => {
565                let old_variable_id = spv[variable_idx + 2];
566                if let Some(op_load_idxs) =
567                    patch_object_id_to_loads.get(&variable_idx_typed.next(old_variable_id))
568                {
569                    for &(op_load_idx, ty) in op_load_idxs {
570                        if **ty == complement_ty {
571                            new_spv[op_load_idx + 1] = complement_ti_id;
572                            new_spv[op_load_idx + 3] = new_variable_id;
573                        } else {
574                            new_spv[op_load_idx + 1] = original_ti_id;
575                            new_spv[op_load_idx + 3] = old_variable_id;
576                        };
577                    }
578                }
579            }
580            LoadType::FunctionArgument => {
581                let mut function_id_and_index_to_new_parameter_id = HashMap::new();
582
583                // Patch function types, definition parameter, and final loads
584                for (variables, calls) in function_patch_variables_with_calls.iter() {
585                    if variables.contains(&variable_idx_typed.next(variable_idx)) {
586                        for &call in calls.iter().rev() {
587                            let function_id = spv[call.call_parameter.function_idx + 2];
588                            let type_function_id = spv[call.call_parameter.function_idx + 4];
589                            if !patched_function_parameters.contains(&(
590                                call.call_parameter.parameter_instruction_idx,
591                                spv[call.call_parameter.function_idx + 2],
592                            )) {
593                                let Some(type_function_idx) =
594                                    op_type_function_idxs.iter().find(|&&idx| {
595                                        let result_id = spv[idx + 1];
596                                        type_function_id == result_id
597                                    })
598                                else {
599                                    panic!(
600                                        "OpTypeFunction does not exist for function {}, type {}",
601                                        function_id, type_function_id
602                                    );
603                                };
604
605                                // To allow multiple patching we can either patch by taking
606                                // an instruction directly from the code, or by patching a
607                                // function we have already began
608                                // We save patching the function's type for later in case of
609                                // duplicate OpTypeFunction
610                                {
611                                    let (new_type_function_id, type_instruction_type_info) =
612                                        if let Some(new_function_type) = defered_new_function_types
613                                            .get_mut(&(
614                                                type_function_id,
615                                                call.call_parameter.function_idx,
616                                            ))
617                                        {
618                                            new_function_type.instruction.insert(
619                                                3 + call.call_parameter.parameter_instruction_idx
620                                                    + 1
621                                                    + 1,
622                                                complement_tp_id,
623                                            );
624                                            new_function_type.instruction[0] = encode_word(
625                                                new_function_type.instruction.len() as u16,
626                                                SPV_INSTRUCTION_OP_TYPE_FUNCTION,
627                                            );
628                                            (
629                                                new_function_type.instruction[1],
630                                                new_function_type.instruction[2..].to_vec(),
631                                            )
632                                        } else {
633                                            let new_function_type_id = instruction_bound;
634                                            instruction_bound += 1;
635
636                                            let word_count = hiword(spv[*type_function_idx]);
637                                            let mut type_function = Vec::new();
638                                            type_function.extend_from_slice(
639                                                &spv[*type_function_idx
640                                                    ..*type_function_idx + word_count as usize],
641                                            );
642                                            type_function[0] = encode_word(
643                                                word_count + 1,
644                                                SPV_INSTRUCTION_OP_TYPE_FUNCTION,
645                                            );
646                                            type_function[1] = new_function_type_id;
647                                            type_function.insert(
648                                                3 + call.call_parameter.parameter_instruction_idx
649                                                    + 1,
650                                                complement_tp_id,
651                                            );
652
653                                            let type_instruction_type_info =
654                                                type_function[2..].to_vec();
655
656                                            defered_new_function_types.insert(
657                                                (
658                                                    type_function_id,
659                                                    call.call_parameter.function_idx,
660                                                ),
661                                                InstructionInsert {
662                                                    previous_spv_idx: *type_function_idx,
663                                                    instruction: type_function,
664                                                },
665                                            );
666                                            (new_function_type_id, type_instruction_type_info)
667                                        };
668                                    let entry = patched_function_types
669                                        .entry(type_instruction_type_info)
670                                        .or_insert((new_type_function_id, vec![]));
671                                    entry
672                                        .1
673                                        .push((type_function_id, call.call_parameter.function_idx));
674                                }
675
676                                // Patch function parameter
677                                let new_parameter_id = instruction_bound;
678                                instruction_bound += 1;
679                                instruction_inserts.push(InstructionInsert {
680                                    previous_spv_idx: call.call_parameter.parameter_idx,
681                                    instruction: vec![
682                                        encode_word(3, SPV_INSTRUCTION_OP_FUNCTION_PARAMETER),
683                                        complement_tp_id,
684                                        new_parameter_id,
685                                    ],
686                                });
687
688                                // Use our new parameters to patch dependent OpLoads
689                                for load_idx in op_load_idxs.iter() {
690                                    let result_id = spv[load_idx + 2];
691                                    let ptr_id = spv[load_idx + 3];
692                                    let parameter_result_id =
693                                        spv[call.call_parameter.parameter_idx + 2];
694
695                                    // TODO: OPT Someone else can come by and rearrange these silly data
696                                    // structures later.
697                                    if ptr_id == parameter_result_id {
698                                        let ty = loaded_variable_ids
699                                            .iter()
700                                            .find_map(|&(id, ty)| {
701                                                (id.inner() == result_id).then_some(ty)
702                                            })
703                                            .unwrap();
704                                        if *ty == complement_ty {
705                                            new_spv[load_idx + 1] = complement_ti_id;
706                                            new_spv[load_idx + 3] = new_parameter_id;
707                                        }
708                                    }
709                                }
710
711                                let function_id = spv[call.function_call_idx + 3];
712                                function_id_and_index_to_new_parameter_id.insert(
713                                    (function_id, call.call_parameter.parameter_instruction_idx),
714                                    new_parameter_id,
715                                );
716                                patched_function_parameters.insert((
717                                    call.call_parameter.parameter_instruction_idx,
718                                    function_id,
719                                ));
720                            }
721                        }
722                    }
723                }
724
725                // Patch function calls that call other functions
726                for (variables, calls) in function_patch_variables_with_calls.iter() {
727                    if variables.contains(&variable_idx_typed.next(variable_idx)) {
728                        for &call in calls.iter().rev() {
729                            let function_idx = get_function_index_of_instruction_index(
730                                &spv,
731                                call.function_call_idx,
732                            );
733                            let function_id = spv[function_idx + 2];
734                            if let Some(parameter_word) = function_id_and_index_to_new_parameter_id
735                                .get(&(function_id, call.call_parameter.parameter_instruction_idx))
736                            {
737                                word_inserts.push(WordInsert {
738                                    idx: call.function_call_idx
739                                        + 4
740                                        + call.call_parameter.parameter_instruction_idx,
741                                    word: *parameter_word,
742                                    head_idx: call.function_call_idx,
743                                });
744                            } else {
745                                word_inserts.push(WordInsert {
746                                    idx: call.function_call_idx
747                                        + 4
748                                        + call.call_parameter.parameter_instruction_idx,
749                                    word: new_variable_id,
750                                    head_idx: call.function_call_idx,
751                                });
752                            }
753                        }
754                    }
755                }
756            }
757        }
758
759        // OpSampledImage
760        // NOTE: We did not patch in a new OpSampledImage and OpTypeSampledImage.
761        // Thankfully, neither `spirv-val`, `naga`, nor `tint` seems to care.
762    }
763
764    // 12. Remove duplicate function types and patch them into OpFunction
765    for (_, (new_type_function_id, functions)) in patched_function_types.into_iter() {
766        for (idx, &(type_function_id, function_idx)) in functions.iter().enumerate() {
767            if idx != 0 {
768                defered_new_function_types.remove(&(type_function_id, function_idx));
769            }
770            new_spv[function_idx + 4] = new_type_function_id;
771        }
772    }
773
774    // We now insert our new function types
775    for (_, new_instruction) in defered_new_function_types {
776        instruction_inserts.push(new_instruction)
777    }
778
779    // 13. Insert new OpDecorate
780    let DecorateOut {
781        descriptor_sets_to_correct,
782    } = util::decorate(DecorateIn {
783        spv: &spv,
784        instruction_inserts: &mut instruction_inserts,
785        first_op_deocrate_idx,
786        op_decorate_idxs: &op_decorate_idxs,
787        affected_decorations: &affected_variables,
788        corrections,
789    });
790
791    // 14. Insert New Instructions
792    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
793
794    // 15. Correct OpDecorate Bindings
795    util::correct_decorate(CorrectDecorateIn {
796        new_spv: &mut new_spv,
797        descriptor_sets_to_correct,
798    });
799
800    // 16. Remove Instructions that have been Whited Out.
801    prune_noops(&mut new_spv);
802
803    // 17. Write New Header and New Code
804    Ok(fuse_final(spv_header, new_spv, instruction_bound))
805}