1use super::*;
2
3type LeftRightOutput = (Option<Vec<u32>>, Option<Vec<u32>>);
4
5pub fn mirrorpatch(
11 left_spv: &[u32],
12 left_corrections: &mut Option<CorrectionMap>,
13 right_spv: &[u32],
14 right_corrections: &mut Option<CorrectionMap>,
15) -> Result<LeftRightOutput, ()> {
16 if left_corrections.is_none() && right_corrections.is_none() {
17 return Ok((None, None));
18 }
19
20 let mut left_affected_decorations = vec![];
21 let mut right_affected_decorations = vec![];
22
23 let mut left_instruction_bound = left_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
24 let mut right_instruction_bound = right_spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
25
26 let left_corrections_map = left_corrections
27 .as_ref()
28 .map(|correction_map| correction_map.sets.clone())
29 .unwrap_or_default();
30 let right_corrections_map = right_corrections
31 .as_ref()
32 .map(|correction_map| correction_map.sets.clone())
33 .unwrap_or_default();
34
35 let mut scan_set_idxs = left_corrections_map
36 .keys()
37 .chain(right_corrections_map.keys())
38 .copied()
39 .collect::<Vec<_>>();
40
41 scan_set_idxs.dedup();
42
43 for set_idx in scan_set_idxs {
44 let left_bindings = left_corrections_map
45 .get(&set_idx)
46 .cloned()
47 .map(|v| v.bindings)
48 .unwrap_or_default();
49 let right_bindings = right_corrections_map
50 .get(&set_idx)
51 .cloned()
52 .map(|v| v.bindings)
53 .unwrap_or_default();
54
55 for (left_binding_idx, l) in left_bindings.iter() {
56 let r = right_bindings
57 .get(left_binding_idx)
58 .cloned()
59 .unwrap_or_default();
60
61 push_affected_decorations(
62 &mut right_affected_decorations,
63 &mut right_instruction_bound,
64 set_idx,
65 *left_binding_idx,
66 l,
67 &r,
68 );
69 }
70
71 for (right_binding_idx, r) in right_bindings.iter() {
72 let l = left_bindings
73 .get(right_binding_idx)
74 .cloned()
75 .unwrap_or_default();
76
77 push_affected_decorations(
78 &mut left_affected_decorations,
79 &mut left_instruction_bound,
80 set_idx,
81 *right_binding_idx,
82 r,
83 &l,
84 );
85 }
86 }
87
88 let l = (!left_affected_decorations.is_empty())
89 .then(|| {
90 patch_spv_decorations(
91 left_spv,
92 left_corrections,
93 left_instruction_bound,
94 &left_affected_decorations,
95 )
96 })
97 .transpose()?;
98 let r = (!right_affected_decorations.is_empty())
99 .then(|| {
100 patch_spv_decorations(
101 right_spv,
102 right_corrections,
103 right_instruction_bound,
104 &right_affected_decorations,
105 )
106 })
107 .transpose()?;
108 Ok((l, r))
109}
110
/// A variable that must be added to a module: a copy of the variable decorated with descriptor
/// set `set` and binding `binding`, re-issued under the fresh result id `new_res_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct NewVariable {
    /// Descriptor set used to locate the variable being copied.
    set: u32,
    /// Binding used to locate the variable being copied (already shifted by the offset computed
    /// in `push_affected_decorations`).
    binding: u32,
    /// Result id for the copy, allocated from the module's instruction bound.
    new_res_id: u32,
    /// Correction this copy stands for; forwarded to `util::decorate` via `AffectedDecoration`.
    correction_type: CorrectionType,
}
118
119fn patch_spv_decorations(
120 in_spv: &[u32],
121 corrections: &mut Option<CorrectionMap>,
122 new_instruction_bound: u32,
123 affected_decorations: &[NewVariable],
124) -> Result<Vec<u32>, ()> {
125 let spv = in_spv.to_owned();
126
127 let instruction_bound = new_instruction_bound;
128 let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
129 let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
130
131 assert_eq!(magic_number, SPV_HEADER_MAGIC);
132
133 let mut instruction_inserts: Vec<InstructionInsert> = vec![];
134
135 let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
136 let mut new_spv = spv.clone();
137
138 let mut op_decorate_idxs = vec![];
140 let mut op_variable_idxs = vec![];
141 let mut spv_idx = 0;
142 while spv_idx < spv.len() {
143 let op = spv[spv_idx];
144 let word_count = hiword(op);
145 let instruction = loword(op);
146
147 if instruction == SPV_INSTRUCTION_OP_DECORATE {
148 op_decorate_idxs.push(spv_idx)
149 }
150 if instruction == SPV_INSTRUCTION_OP_VARIABLE {
151 op_variable_idxs.push(spv_idx)
152 }
153
154 spv_idx += word_count as usize;
155 }
156 let first_op_deocrate_idx = op_decorate_idxs.first().copied();
157
158 let mut cached_original_variable_idxs = HashMap::new();
160 let affected_decorations = affected_decorations
161 .iter()
162 .map(|affected| {
163 let NewVariable {
165 set,
166 binding,
167 new_res_id,
168 correction_type,
169 } = *affected;
170 let original_variable_idx =
171 *if let Some(idx) = cached_original_variable_idxs.get(&(set, binding)) {
172 idx
173 } else {
174 let original_variable_id = op_decorate_idxs
175 .iter()
176 .find_map(|&d_idx| {
177 let target_id = spv[d_idx + 1];
178 let decoration_id = spv[d_idx + 2];
179 let decoration_value = spv[d_idx + 3];
180 (decoration_id == SPV_DECORATION_DESCRIPTOR_SET
181 && decoration_value == set
182 && op_decorate_idxs.iter().any(|&idx| {
183 let binding_target_id = spv[idx + 1];
184 let decoration_id = spv[idx + 2];
185 let decoration_value = spv[idx + 3];
186 decoration_id == SPV_DECORATION_BINDING
187 && decoration_value == binding
188 && target_id == binding_target_id
189 }))
190 .then_some(target_id)
191 })
192 .unwrap();
193 let idx = op_variable_idxs
194 .iter()
195 .find(|&idx| spv[idx + 2] == original_variable_id)
196 .unwrap();
197 cached_original_variable_idxs.insert((set, binding), idx);
198 idx
199 };
200
201 let original_variable_id = spv[original_variable_idx + 2];
203 let mut new_variable = Vec::new();
204 let word_count = hiword(spv[original_variable_idx]);
205 new_variable.extend_from_slice(
206 &spv[original_variable_idx..original_variable_idx + word_count as usize],
207 );
208 new_variable[2] = new_res_id;
209 instruction_inserts.push(InstructionInsert {
210 previous_spv_idx: original_variable_idx,
211 instruction: new_variable,
212 });
213
214 AffectedDecoration {
216 original_res_id: original_variable_id,
217 new_res_id,
218 correction_type,
219 }
220 })
221 .collect::<Vec<_>>();
222
223 let DecorateOut {
225 descriptor_sets_to_correct,
226 } = util::decorate(DecorateIn {
227 spv: &spv,
228 instruction_inserts: &mut instruction_inserts,
229 first_op_deocrate_idx,
230 op_decorate_idxs: &op_decorate_idxs,
231 affected_decorations: &affected_decorations,
232 corrections,
233 });
234
235 insert_new_instructions(&spv, &mut new_spv, &[], &instruction_inserts);
237
238 util::correct_decorate(CorrectDecorateIn {
240 new_spv: &mut new_spv,
241 descriptor_sets_to_correct,
242 });
243
244 prune_noops(&mut new_spv);
246
247 Ok(fuse_final(spv_header, new_spv, instruction_bound))
249}
250
251fn push_affected_decorations(
252 new_variables: &mut Vec<NewVariable>,
253 instruction_bound: &mut u32,
254 set: u32,
255 binding: u32,
256 l: &CorrectionBinding,
257 r: &CorrectionBinding,
258) {
259 let mut ll = l
260 .corrections
261 .iter()
262 .map(Some)
263 .enumerate()
264 .collect::<Vec<_>>();
265
266 for r_correction in r.corrections.iter() {
267 let idx_ty = ll
268 .iter()
269 .find(|(_, correction)| Some(r_correction) == correction.as_ref().copied())
270 .copied();
271 if let Some((idx, _)) = idx_ty {
272 ll[idx].1 = None;
273 }
274 }
275
276 let mut offset = 0;
277 for (_, op) in ll {
278 if let Some(correction) = op {
279 *instruction_bound += 1;
280 let new_res_id = *instruction_bound - 1;
281 new_variables.push(NewVariable {
282 set,
283 binding: binding + offset,
284 new_res_id,
285 correction_type: *correction,
286 });
287 } else {
288 offset += 1;
289 }
290 }
291}
292
/// Pins the matching and binding-offset behaviour of `push_affected_decorations` in both
/// directions.
#[test]
fn test_push_affected_decorations() {
    let left = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitCombined,
            CorrectionType::SplitDrefComparison,
        ],
    };
    let right = CorrectionBinding {
        corrections: vec![
            CorrectionType::SplitDrefRegular,
            CorrectionType::SplitDrefComparison,
        ],
    };

    // `right` cancels the first SplitDrefRegular and the SplitDrefComparison of `left`; the three
    // survivors get consecutive ids, and bindings shift past each cancelled entry.
    let expected: Vec<NewVariable> = [
        (0, 0, CorrectionType::SplitCombined),
        (1, 1, CorrectionType::SplitDrefRegular),
        (1, 2, CorrectionType::SplitCombined),
    ]
    .iter()
    .map(|&(binding, new_res_id, correction_type)| NewVariable {
        set: 0,
        binding,
        new_res_id,
        correction_type,
    })
    .collect();

    let mut left_only = Vec::new();
    let mut bound = 0;
    push_affected_decorations(&mut left_only, &mut bound, 0, 0, &left, &right);
    assert_eq!(left_only, expected);

    // Everything in `right` also appears in `left`, so nothing is missing on that side.
    let mut right_only = Vec::new();
    let mut bound = 0;
    push_affected_decorations(&mut right_only, &mut bound, 0, 0, &right, &left);
    assert!(right_only.is_empty());
}