spirv_webgpu_transform/
mirrorpatch.rs

1use super::*;
2
/// Patched SPIR-V for the `(left, right)` shaders handed to [`mirrorpatch`]; an element is `None`
/// when that shader did not need to change.
type LeftRightOutput = (
    Option<Vec<u32>>, // left shader
    Option<Vec<u32>>, // right shader
);
4
5/// Reflect one set of patched uniforms onto another shader with the same underlying set of
6/// uniforms.
7/// This is important for ensuring patched vertex and fragment shaders have the same layout.
8/// This is because some transformations occur based off what instructions appear, so as a result,
9/// the vertex and fragment shader may have a different layout after a set of transformations
10pub fn mirrorpatch(
11    left_spv: &[u32],
12    left_corrections: &mut Option<CorrectionMap>,
13    right_spv: &[u32],
14    right_corrections: &mut Option<CorrectionMap>,
15) -> Result<LeftRightOutput, ()> {
16    if left_corrections.is_none() && right_corrections.is_none() {
17        return Ok((None, None));
18    }
19
20    let mut left_affected_decorations = vec![];
21    let mut right_affected_decorations = vec![];
22
23    let mut left_instruction_bound = left_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
24    let mut right_instruction_bound = right_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
25
26    let left_corrections_map = left_corrections
27        .as_ref()
28        .map(|correction_map| correction_map.sets.clone())
29        .unwrap_or_default();
30    let right_corrections_map = right_corrections
31        .as_ref()
32        .map(|correction_map| correction_map.sets.clone())
33        .unwrap_or_default();
34
35    let mut scan_set_idxs = left_corrections_map
36        .keys()
37        .chain(right_corrections_map.keys())
38        .copied()
39        .collect::<Vec<_>>();
40
41    scan_set_idxs.dedup();
42
43    for set_idx in scan_set_idxs {
44        let left_bindings = left_corrections_map
45            .get(&set_idx)
46            .cloned()
47            .map(|v| v.bindings)
48            .unwrap_or_default();
49        let right_bindings = right_corrections_map
50            .get(&set_idx)
51            .cloned()
52            .map(|v| v.bindings)
53            .unwrap_or_default();
54
55        for (left_binding_idx, l) in left_bindings.iter() {
56            let r = right_bindings
57                .get(left_binding_idx)
58                .cloned()
59                .unwrap_or_default();
60
61            push_affected_decorations(
62                &mut right_affected_decorations,
63                &mut right_instruction_bound,
64                set_idx,
65                *left_binding_idx,
66                l,
67                &r,
68            );
69        }
70
71        for (right_binding_idx, r) in right_bindings.iter() {
72            let l = left_bindings
73                .get(right_binding_idx)
74                .cloned()
75                .unwrap_or_default();
76
77            push_affected_decorations(
78                &mut left_affected_decorations,
79                &mut left_instruction_bound,
80                set_idx,
81                *right_binding_idx,
82                r,
83                &l,
84            );
85        }
86    }
87
88    let l = (!left_affected_decorations.is_empty())
89        .then(|| {
90            patch_spv_decorations(
91                left_spv,
92                left_corrections,
93                left_instruction_bound,
94                &left_affected_decorations,
95            )
96        })
97        .transpose()?;
98    let r = (!right_affected_decorations.is_empty())
99        .then(|| {
100            patch_spv_decorations(
101                right_spv,
102                right_corrections,
103                right_instruction_bound,
104                &right_affected_decorations,
105            )
106        })
107        .transpose()?;
108    Ok((l, r))
109}
110
111#[derive(Debug, Clone, Copy, PartialEq, Eq)]
112struct NewVariable {
113    set: u32,
114    binding: u32,
115    new_res_id: u32,
116    correction_type: CorrectionType,
117}
118
119fn patch_spv_decorations(
120    in_spv: &[u32],
121    corrections: &mut Option<CorrectionMap>,
122    new_instruction_bound: u32,
123    affected_decorations: &[NewVariable],
124) -> Result<Vec<u32>, ()> {
125    let spv = in_spv.to_owned();
126
127    let instruction_bound = new_instruction_bound;
128    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
129    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
130
131    assert_eq!(magic_number, SPV_HEADER_MAGIC);
132
133    let mut instruction_inserts: Vec<InstructionInsert> = vec![];
134
135    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
136    let mut new_spv = spv.clone();
137
138    // 1. Find locations instructions we need
139    let mut op_decorate_idxs = vec![];
140    let mut op_variable_idxs = vec![];
141    let mut spv_idx = 0;
142    while spv_idx < spv.len() {
143        let op = spv[spv_idx];
144        let word_count = hiword(op);
145        let instruction = loword(op);
146
147        if instruction == SPV_INSTRUCTION_OP_DECORATE {
148            op_decorate_idxs.push(spv_idx)
149        }
150        if instruction == SPV_INSTRUCTION_OP_VARIABLE {
151            op_variable_idxs.push(spv_idx)
152        }
153
154        spv_idx += word_count as usize;
155    }
156    let first_op_deocrate_idx = op_decorate_idxs.first().copied();
157
158    // 2. Convert and insert new variables
159    let mut cached_original_variable_idxs = HashMap::new();
160    let affected_decorations = affected_decorations
161        .iter()
162        .map(|affected| {
163            // Given a set binding, find the original variable
164            let NewVariable {
165                set,
166                binding,
167                new_res_id,
168                correction_type,
169            } = *affected;
170            let original_variable_idx =
171                *if let Some(idx) = cached_original_variable_idxs.get(&(set, binding)) {
172                    idx
173                } else {
174                    let original_variable_id = op_decorate_idxs
175                        .iter()
176                        .find_map(|&d_idx| {
177                            let target_id = spv[d_idx + 1];
178                            let decoration_id = spv[d_idx + 2];
179                            let decoration_value = spv[d_idx + 3];
180                            (decoration_id == SPV_DECORATION_DESCRIPTOR_SET
181                                && decoration_value == set
182                                && op_decorate_idxs.iter().any(|&idx| {
183                                    let binding_target_id = spv[idx + 1];
184                                    let decoration_id = spv[idx + 2];
185                                    let decoration_value = spv[idx + 3];
186                                    decoration_id == SPV_DECORATION_BINDING
187                                        && decoration_value == binding
188                                        && target_id == binding_target_id
189                                }))
190                            .then_some(target_id)
191                        })
192                        .unwrap();
193                    let idx = op_variable_idxs
194                        .iter()
195                        .find(|&idx| spv[idx + 2] == original_variable_id)
196                        .unwrap();
197                    cached_original_variable_idxs.insert((set, binding), idx);
198                    idx
199                };
200
201            // Copy the original variable instruction and substitute new variable id
202            let original_variable_id = spv[original_variable_idx + 2];
203            let mut new_variable = Vec::new();
204            let word_count = hiword(spv[original_variable_idx]);
205            new_variable.extend_from_slice(
206                &spv[original_variable_idx..original_variable_idx + word_count as usize],
207            );
208            new_variable[2] = new_res_id;
209            instruction_inserts.push(InstructionInsert {
210                previous_spv_idx: original_variable_idx,
211                instruction: new_variable,
212            });
213
214            // Convert into affected decoration
215            AffectedDecoration {
216                original_res_id: original_variable_id,
217                new_res_id,
218                correction_type,
219            }
220        })
221        .collect::<Vec<_>>();
222
223    // 3. Insert new OpDecorate
224    let DecorateOut {
225        descriptor_sets_to_correct,
226    } = util::decorate(DecorateIn {
227        spv: &spv,
228        instruction_inserts: &mut instruction_inserts,
229        first_op_deocrate_idx,
230        op_decorate_idxs: &op_decorate_idxs,
231        affected_decorations: &affected_decorations,
232        corrections,
233    });
234
235    // 4. Insert New Instructions
236    insert_new_instructions(&spv, &mut new_spv, &[], &instruction_inserts);
237
238    // 5. Correct OpDecorate Bindings
239    util::correct_decorate(CorrectDecorateIn {
240        new_spv: &mut new_spv,
241        descriptor_sets_to_correct,
242    });
243
244    // 6. Remove Instructions that have been Whited Out.
245    prune_noops(&mut new_spv);
246
247    // 7. Write New Header and New Code
248    Ok(fuse_final(spv_header, new_spv, instruction_bound))
249}
250
251fn push_affected_decorations(
252    new_variables: &mut Vec<NewVariable>,
253    instruction_bound: &mut u32,
254    set: u32,
255    binding: u32,
256    l: &CorrectionBinding,
257    r: &CorrectionBinding,
258) {
259    let mut ll = l
260        .corrections
261        .iter()
262        .map(Some)
263        .enumerate()
264        .collect::<Vec<_>>();
265
266    for r_correction in r.corrections.iter() {
267        let idx_ty = ll
268            .iter()
269            .find(|(_, correction)| Some(r_correction) == correction.as_ref().copied())
270            .copied();
271        if let Some((idx, _)) = idx_ty {
272            ll[idx].1 = None;
273        }
274    }
275
276    let mut offset = 0;
277    for (_, op) in ll {
278        if let Some(correction) = op {
279            *instruction_bound += 1;
280            let new_res_id = *instruction_bound - 1;
281            new_variables.push(NewVariable {
282                set,
283                binding: binding + offset,
284                new_res_id,
285                correction_type: *correction,
286            });
287        } else {
288            offset += 1;
289        }
290    }
291}
292
/// `push_affected_decorations` emits only the corrections of `l` that `r` lacks, offsets each
/// binding by the number of matched corrections preceding it, and hands out consecutive result
/// ids starting at (and advancing) the current instruction bound.
#[test]
fn test_push_affected_decorations() {
    let l = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefComparison,
        ],
    };

    let r = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefComparison,
        ],
    };

    let mut affected = vec![];
    let mut instruction_bound = 0;
    push_affected_decorations(&mut affected, &mut instruction_bound, 0, 0, &l, &r);
    assert_eq!(
        affected,
        vec![
            NewVariable {
                set: 0,
                binding: 0,
                new_res_id: 0,
                correction_type: CorrectionType::SplitCombined,
            },
            NewVariable {
                set: 0,
                binding: 1,
                new_res_id: 1,
                correction_type: CorrectionType::SplitDrefRegular,
            },
            NewVariable {
                set: 0,
                binding: 1,
                new_res_id: 2,
                correction_type: CorrectionType::SplitCombined,
            },
        ]
    );
    // One id is consumed per emitted variable.
    assert_eq!(instruction_bound, 3);

    // Every correction of `r` also appears in `l`, so nothing needs mirroring.
    let mut affected = vec![];
    let mut instruction_bound = 0;
    push_affected_decorations(&mut affected, &mut instruction_bound, 0, 0, &r, &l);
    assert_eq!(affected, vec![]);
    assert_eq!(instruction_bound, 0);

    // Ids continue from a non-zero bound, and `set`/`binding` are carried through unchanged
    // when nothing on the other side matches.
    let empty = CorrectionBinding {
        corrections: vec![],
    };
    let mut affected = vec![];
    let mut instruction_bound = 7;
    push_affected_decorations(&mut affected, &mut instruction_bound, 2, 5, &r, &empty);
    assert_eq!(
        affected,
        vec![
            NewVariable {
                set: 2,
                binding: 5,
                new_res_id: 7,
                correction_type: CorrectionType::SplitDrefRegular,
            },
            NewVariable {
                set: 2,
                binding: 5,
                new_res_id: 8,
                correction_type: CorrectionType::SplitDrefComparison,
            },
        ]
    );
    assert_eq!(instruction_bound, 9);
}
341}