Skip to main content

spirv_webgpu_transform/
mirrorpatch.rs

1use super::*;
2
3type LeftRightOutput = (Option<Vec<u32>>, Option<Vec<u32>>);
4
5/// Reflect one set of patched uniforms onto another shader with the same underlying set of
6/// uniforms.
7/// This is important for ensuring patched vertex and fragment shaders have the same layout.
8/// This is because some transformations occur based off what instructions appear, so as a result,
9/// the vertex and fragment shader may have a different layout after a set of transformations
10pub fn mirrorpatch(
11    left_spv: &[u32],
12    left_corrections: &mut Option<CorrectionMap>,
13    right_spv: &[u32],
14    right_corrections: &mut Option<CorrectionMap>,
15) -> Result<LeftRightOutput, ()> {
16    if left_corrections.is_none() && right_corrections.is_none() {
17        return Ok((None, None));
18    }
19
20    let mut left_affected_decorations = vec![];
21    let mut right_affected_decorations = vec![];
22
23    let mut left_instruction_bound = left_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
24    let mut right_instruction_bound = right_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
25
26    let left_corrections_map = left_corrections
27        .as_ref()
28        .map(|correction_map| correction_map.sets.clone())
29        .unwrap_or_default();
30    let right_corrections_map = right_corrections
31        .as_ref()
32        .map(|correction_map| correction_map.sets.clone())
33        .unwrap_or_default();
34
35    let mut scan_set_idxs = left_corrections_map
36        .keys()
37        .chain(right_corrections_map.keys())
38        .copied()
39        .collect::<Vec<_>>();
40
41    scan_set_idxs.dedup();
42
43    for set_idx in scan_set_idxs {
44        let left_bindings = left_corrections_map
45            .get(&set_idx)
46            .cloned()
47            .map(|v| v.bindings)
48            .unwrap_or_default();
49        let right_bindings = right_corrections_map
50            .get(&set_idx)
51            .cloned()
52            .map(|v| v.bindings)
53            .unwrap_or_default();
54
55        for (left_binding_idx, l) in left_bindings.iter() {
56            let r = right_bindings
57                .get(left_binding_idx)
58                .cloned()
59                .unwrap_or_default();
60
61            push_affected_decorations(
62                &mut right_affected_decorations,
63                &mut right_instruction_bound,
64                set_idx,
65                *left_binding_idx,
66                l,
67                &r,
68            );
69        }
70
71        for (right_binding_idx, r) in right_bindings.iter() {
72            let l = left_bindings
73                .get(right_binding_idx)
74                .cloned()
75                .unwrap_or_default();
76
77            push_affected_decorations(
78                &mut left_affected_decorations,
79                &mut left_instruction_bound,
80                set_idx,
81                *right_binding_idx,
82                r,
83                &l,
84            );
85        }
86    }
87
88    let l = (!left_affected_decorations.is_empty())
89        .then(|| {
90            patch_spv_decorations(
91                left_spv,
92                left_corrections,
93                left_instruction_bound,
94                &left_affected_decorations,
95            )
96        })
97        .transpose()?;
98    let r = (!right_affected_decorations.is_empty())
99        .then(|| {
100            patch_spv_decorations(
101                right_spv,
102                right_corrections,
103                right_instruction_bound,
104                &right_affected_decorations,
105            )
106        })
107        .transpose()?;
108    Ok((l, r))
109}
110
111#[derive(Debug, Clone, Copy, PartialEq, Eq)]
112struct NewVariable {
113    set: u32,
114    binding: u32,
115    new_res_id: u32,
116    correction_type: CorrectionType,
117}
118
119fn patch_spv_decorations(
120    in_spv: &[u32],
121    corrections: &mut Option<CorrectionMap>,
122    new_instruction_bound: u32,
123    affected_decorations: &[NewVariable],
124) -> Result<Vec<u32>, ()> {
125    let spv = in_spv.to_owned();
126
127    let instruction_bound = new_instruction_bound;
128    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
129    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
130
131    assert_eq!(magic_number, SPV_HEADER_MAGIC);
132
133    let mut instruction_inserts: Vec<InstructionInsert> = vec![];
134
135    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
136    let mut new_spv = spv.clone();
137
138    // 1. Find locations instructions we need
139    let mut op_decorate_idxs = vec![];
140    let mut op_variable_idxs = vec![];
141    let mut spv_idx = 0;
142    while spv_idx < spv.len() {
143        let op = spv[spv_idx];
144        let word_count = hiword(op);
145        let instruction = loword(op);
146
147        if instruction == SPV_INSTRUCTION_OP_DECORATE {
148            op_decorate_idxs.push(spv_idx)
149        }
150        if instruction == SPV_INSTRUCTION_OP_VARIABLE {
151            op_variable_idxs.push(spv_idx)
152        }
153
154        spv_idx += word_count as usize;
155    }
156    let first_op_deocrate_idx = op_decorate_idxs.first().copied();
157
158    // 2. Convert and insert new variables
159    let mut cached_original_variable_idxs = HashMap::new();
160    let affected_decorations = affected_decorations
161        .iter()
162        .map(|affected| {
163            // Given a set binding, find the original variable
164            let NewVariable {
165                set,
166                binding,
167                new_res_id,
168                correction_type,
169            } = *affected;
170            let original_variable_idx =
171                *if let Some(idx) = cached_original_variable_idxs.get(&(set, binding)) {
172                    idx
173                } else {
174                    let Some(original_variable_id) = op_decorate_idxs.iter().find_map(|&d_idx| {
175                        let target_id = spv[d_idx + 1];
176                        let decoration_id = spv[d_idx + 2];
177                        let decoration_value = spv[d_idx + 3];
178                        (decoration_id == SPV_DECORATION_DESCRIPTOR_SET
179                            && decoration_value == set
180                            && op_decorate_idxs.iter().any(|&idx| {
181                                let binding_target_id = spv[idx + 1];
182                                let decoration_id = spv[idx + 2];
183                                let decoration_value = spv[idx + 3];
184                                decoration_id == SPV_DECORATION_BINDING
185                                    && decoration_value == binding
186                                    && target_id == binding_target_id
187                            }))
188                        .then_some(target_id)
189                    }) else {
190                        // If there are no OpDecorates, no patching needs to be done.
191                        return Err(in_spv.to_vec());
192                    };
193                    let idx = op_variable_idxs
194                        .iter()
195                        .find(|&idx| spv[idx + 2] == original_variable_id)
196                        .unwrap();
197                    cached_original_variable_idxs.insert((set, binding), idx);
198                    idx
199                };
200
201            // Copy the original variable instruction and substitute new variable id
202            let original_variable_id = spv[original_variable_idx + 2];
203            let mut new_variable = Vec::new();
204            let word_count = hiword(spv[original_variable_idx]);
205            new_variable.extend_from_slice(
206                &spv[original_variable_idx..original_variable_idx + word_count as usize],
207            );
208            new_variable[2] = new_res_id;
209            instruction_inserts.push(InstructionInsert {
210                previous_spv_idx: original_variable_idx,
211                instruction: new_variable,
212            });
213
214            // Convert into affected decoration
215            Ok(AffectedDecoration {
216                original_res_id: original_variable_id,
217                new_res_id,
218                correction_type,
219            })
220        })
221        .collect::<Result<Vec<_>, _>>();
222
223    let affected_decorations = match affected_decorations {
224        Ok(d) => d,
225        Err(spv) => return Ok(spv),
226    };
227
228    // 3. Insert new OpDecorate
229    let DecorateOut {
230        descriptor_sets_to_correct,
231    } = util::decorate(DecorateIn {
232        spv: &spv,
233        instruction_inserts: &mut instruction_inserts,
234        first_op_deocrate_idx,
235        op_decorate_idxs: &op_decorate_idxs,
236        affected_decorations: &affected_decorations,
237        corrections,
238    });
239
240    // 4. Insert New Instructions
241    insert_new_instructions(&spv, &mut new_spv, &[], &instruction_inserts);
242
243    // 5. Correct OpDecorate Bindings
244    util::correct_decorate(CorrectDecorateIn {
245        new_spv: &mut new_spv,
246        descriptor_sets_to_correct,
247    });
248
249    // 6. Remove Instructions that have been Whited Out.
250    prune_noops(&mut new_spv);
251
252    // 7. Write New Header and New Code
253    Ok(fuse_final(spv_header, new_spv, instruction_bound))
254}
255
256fn push_affected_decorations(
257    new_variables: &mut Vec<NewVariable>,
258    instruction_bound: &mut u32,
259    set: u32,
260    binding: u32,
261    l: &CorrectionBinding,
262    r: &CorrectionBinding,
263) {
264    let mut ll = l
265        .corrections
266        .iter()
267        .map(Some)
268        .enumerate()
269        .collect::<Vec<_>>();
270
271    for r_correction in r.corrections.iter() {
272        let idx_ty = ll
273            .iter()
274            .find(|(_, correction)| Some(r_correction) == correction.as_ref().copied())
275            .copied();
276        if let Some((idx, _)) = idx_ty {
277            ll[idx].1 = None;
278        }
279    }
280
281    let mut offset = 0;
282    for (_, op) in ll {
283        if let Some(correction) = op {
284            *instruction_bound += 1;
285            let new_res_id = *instruction_bound - 1;
286            new_variables.push(NewVariable {
287                set,
288                binding: binding + offset,
289                new_res_id,
290                correction_type: *correction,
291            });
292        } else {
293            offset += 1;
294        }
295    }
296}
297
#[test]
fn test_push_affected_decorations() {
    let l = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefComparison,
        ],
    };

    let r = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefComparison,
        ],
    };

    // Every correction `l` has beyond `r` is mirrored. Result ids are taken from the bound, and
    // the binding shifts past each correction that both sides share.
    let mut affected = vec![];
    let mut bound = 0;
    push_affected_decorations(&mut affected, &mut bound, 0, 0, &l, &r);
    assert_eq!(
        affected,
        vec![
            NewVariable {
                set: 0,
                binding: 0,
                new_res_id: 0,
                correction_type: CorrectionType::SplitCombined,
            },
            NewVariable {
                set: 0,
                binding: 1,
                new_res_id: 1,
                correction_type: CorrectionType::SplitDrefRegular,
            },
            NewVariable {
                set: 0,
                binding: 1,
                new_res_id: 2,
                correction_type: CorrectionType::SplitCombined,
            },
        ]
    );
    assert_eq!(bound, 3);

    // Every correction in `r` also appears in `l`, so mirroring the other way adds nothing.
    let mut affected = vec![];
    let mut bound = 0;
    push_affected_decorations(&mut affected, &mut bound, 0, 0, &r, &l);
    assert_eq!(affected, vec![]);
    assert_eq!(bound, 0);

    // Against an empty binding every correction is mirrored. Ids continue from the existing
    // bound, and the requested set/binding are passed through unchanged.
    let mut affected = vec![];
    let mut bound = 10;
    push_affected_decorations(
        &mut affected,
        &mut bound,
        2,
        5,
        &l,
        &CorrectionBinding::default(),
    );
    let expected = l
        .corrections
        .iter()
        .zip(10..)
        .map(|(&correction_type, new_res_id)| NewVariable {
            set: 2,
            binding: 5,
            new_res_id,
            correction_type,
        })
        .collect::<Vec<_>>();
    assert_eq!(affected, expected);
    assert_eq!(bound, 15);
}