1use super::*;
2
3mod layout;
4mod type_registry;
5
6use layout::*;
7use type_registry::*;
8
9pub fn immediatespatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
12 let spv = in_spv.to_owned();
13
14 let instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
15 let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
16
17 let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
18
19 assert_eq!(magic_number, SPV_HEADER_MAGIC);
20
21 let mut instruction_inserts = vec![];
22 let word_inserts = vec![];
23
24 let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
25 let mut new_spv = spv.clone();
26
27 let mut op_variable_idxs = vec![];
29 let mut op_type_pointer_idxs = vec![];
30 let mut op_type_struct_idxs = vec![];
31 let mut op_type_array_idxs = vec![];
32 let mut op_type_matrix_idxs = vec![];
33 let mut op_type_vector_idxs = vec![];
34 let mut op_type_float_idxs = vec![];
35 let mut op_type_int_idxs = vec![];
36 let mut op_constant_idxs = vec![];
37 let mut op_decorate_idxs = vec![];
38 let mut op_member_decorate_idxs = vec![];
39
40 let mut spv_idx = 0;
41 while spv_idx < spv.len() {
42 let op = spv[spv_idx];
43 let word_count = hiword(op);
44 let instruction = loword(op);
45
46 match instruction {
47 SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
48 SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
49 SPV_INSTRUCTION_OP_TYPE_STRUCT => op_type_struct_idxs.push(spv_idx),
50 SPV_INSTRUCTION_OP_TYPE_ARRAY => op_type_array_idxs.push(spv_idx),
51 SPV_INSTRUCTION_OP_TYPE_MATRIX => op_type_matrix_idxs.push(spv_idx),
52 SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
53 SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
54 SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
55 SPV_INSTRUCTION_OP_CONSTANT => op_constant_idxs.push(spv_idx),
56 SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
57 SPV_INSTRUCTION_OP_MEMBER_DECORATE => op_member_decorate_idxs.push(spv_idx),
58 _ => {}
59 }
60
61 spv_idx += word_count as usize;
62 }
63
64 let pc_variables = op_variable_idxs
66 .iter()
67 .filter_map(|&v_idx| {
68 let result_type_id = spv[v_idx + 1];
69 let result_id = spv[v_idx + 2];
70 let storage_class = spv[v_idx + 3];
71 (storage_class == SPV_STORAGE_CLASS_PUSH_CONSTANT).then_some((
72 v_idx,
73 result_type_id,
74 result_id,
75 ))
76 })
77 .collect::<Vec<_>>();
78
79 if pc_variables.is_empty() {
80 return Ok(in_spv.to_vec());
81 }
82
83 let block_struct_ids = pc_variables
85 .iter()
86 .map(|&(_, ptr_id, _)| {
87 op_type_pointer_idxs
88 .iter()
89 .find_map(|&tp_idx| {
90 let result_id = spv[tp_idx + 1];
91 let underlying_type_id = spv[tp_idx + 3];
92 (result_id == ptr_id).then_some(underlying_type_id)
93 })
94 .expect("OpVariable PushConstant referenced an undefined OpTypePointer")
95 })
96 .collect::<Vec<_>>();
97
98 let type_registry = build_type_registry(BuildTypeRegistryIn {
100 spv: &spv,
101 op_type_float_idxs: &op_type_float_idxs,
102 op_type_int_idxs: &op_type_int_idxs,
103 op_type_vector_idxs: &op_type_vector_idxs,
104 op_type_matrix_idxs: &op_type_matrix_idxs,
105 op_type_array_idxs: &op_type_array_idxs,
106 op_type_struct_idxs: &op_type_struct_idxs,
107 op_constant_idxs: &op_constant_idxs,
108 });
109
110 for &block_struct_id in &block_struct_ids {
112 relayout_type_recursive(
113 &spv,
114 &mut new_spv,
115 block_struct_id,
116 &type_registry,
117 &op_decorate_idxs,
118 &op_member_decorate_idxs,
119 );
120 }
121
122 for &tp_idx in &op_type_pointer_idxs {
125 let storage_class = spv[tp_idx + 2];
126 if storage_class == SPV_STORAGE_CLASS_PUSH_CONSTANT {
127 new_spv[tp_idx + 2] = SPV_STORAGE_CLASS_UNIFORM;
128 }
129 }
130 for &(v_idx, _, _) in &pc_variables {
131 new_spv[v_idx + 3] = SPV_STORAGE_CLASS_UNIFORM;
132 }
133
134 let first_op_decorate_idx = op_decorate_idxs.first().copied();
136 let next_set = op_decorate_idxs
137 .iter()
138 .filter_map(|&d_idx| {
139 let decoration_id = spv[d_idx + 2];
140 let decoration_value = spv[d_idx + 3];
141 (decoration_id == SPV_DECORATION_DESCRIPTOR_SET).then_some(decoration_value)
142 })
143 .max()
144 .map(|max| max + 1)
145 .unwrap_or(0);
146
147 for (binding_idx, &(_, _, var_id)) in pc_variables.iter().enumerate() {
148 instruction_inserts.push(InstructionInsert {
149 previous_spv_idx: first_op_decorate_idx
150 .expect("Push constant block has no OpDecorate (missing Block decoration?)"),
151 instruction: vec![
152 encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
153 var_id,
154 SPV_DECORATION_DESCRIPTOR_SET,
155 next_set,
156 encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
157 var_id,
158 SPV_DECORATION_BINDING,
159 binding_idx as u32,
160 ],
161 });
162 }
163
164 insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
166
167 prune_noops(&mut new_spv);
169
170 Ok(fuse_final(spv_header, new_spv, instruction_bound))
172}
173
174fn relayout_type_recursive(
176 spv: &[u32],
177 new_spv: &mut [u32],
178 type_id: u32,
179 registry: &TypeRegistry,
180 op_decorate_idxs: &[usize],
181 op_member_decorate_idxs: &[usize],
182) {
183 let ty = match registry.get(&type_id) {
184 Some(t) => t,
185 None => return,
186 };
187
188 match &ty.kind {
189 TypeKind::Struct { members } => {
190 let layout = layout_struct(members, LayoutRule::Std140);
191
192 for (i, new_offset) in layout.member_offsets.iter().enumerate() {
193 patch_member_decoration_literal(
194 spv,
195 new_spv,
196 op_member_decorate_idxs,
197 type_id,
198 i as u32,
199 SPV_DECORATION_OFFSET,
200 *new_offset,
201 );
202 }
203
204 for (i, member) in members.iter().enumerate() {
205 let matrix_kind = match &member.kind {
206 TypeKind::Matrix { .. } => Some(&member.kind),
207 TypeKind::Array { element, .. } => match &element.kind {
208 TypeKind::Matrix { .. } => Some(&element.kind),
209 _ => None,
210 },
211 _ => None,
212 };
213 if let Some(TypeKind::Matrix { column, .. }) = matrix_kind {
214 let col_count = column_vec_count(column);
215 let scalar_w = column_scalar_width(column);
216 let new_stride = matrix_stride(col_count, scalar_w, LayoutRule::Std140);
217 patch_member_decoration_literal(
218 spv,
219 new_spv,
220 op_member_decorate_idxs,
221 type_id,
222 i as u32,
223 SPV_DECORATION_MATRIX_STRIDE,
224 new_stride,
225 );
226 }
227 relayout_type_recursive(
228 spv,
229 new_spv,
230 member.id,
231 registry,
232 op_decorate_idxs,
233 op_member_decorate_idxs,
234 );
235 }
236 }
237
238 TypeKind::Array { element, .. } => {
239 let new_stride = array_stride(&element.kind, LayoutRule::Std140);
240 for &d_idx in op_decorate_idxs {
241 let target_id = spv[d_idx + 1];
242 let decoration_id = spv[d_idx + 2];
243 if target_id == type_id && decoration_id == SPV_DECORATION_ARRAY_STRIDE {
244 new_spv[d_idx + 3] = new_stride;
245 }
246 }
247 relayout_type_recursive(
249 spv,
250 new_spv,
251 element.id,
252 registry,
253 op_decorate_idxs,
254 op_member_decorate_idxs,
255 );
256 }
257 _ => {}
259 }
260}
261
/// Overwrites the first literal operand of every `OpMemberDecorate` that
/// applies `decoration` to member `member` of struct `target_id`.
///
/// `op_member_decorate_idxs` holds the word offsets of the `OpMemberDecorate`
/// instructions. Matching reads the original words in `spv`, and the new value
/// is written to word 4 of the same instruction in `new_spv`. Every matching
/// instruction is patched, not just the first one.
fn patch_member_decoration_literal(
    spv: &[u32],
    new_spv: &mut [u32],
    op_member_decorate_idxs: &[usize],
    target_id: u32,
    member: u32,
    decoration: u32,
    new_value: u32,
) {
    for idx in op_member_decorate_idxs.iter().copied() {
        // Words 1..=3 of OpMemberDecorate: structure type id, member index,
        // decoration kind.
        let is_match = match &spv[idx + 1..idx + 4] {
            &[struct_id, member_index, decoration_kind] => {
                struct_id == target_id && member_index == member && decoration_kind == decoration
            }
            _ => false,
        };
        if is_match {
            // Word 4 holds the decoration's literal operand.
            new_spv[idx + 4] = new_value;
        }
    }
}