Skip to main content

spirv_webgpu_transform/
immediatespatch.rs

1use super::*;
2
3mod layout;
4mod type_registry;
5
6use layout::*;
7use type_registry::*;
8
9/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
10/// Does not produce any side effects or corrections.
11pub fn immediatespatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
12    let spv = in_spv.to_owned();
13
14    let instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
15    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
16
17    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
18
19    assert_eq!(magic_number, SPV_HEADER_MAGIC);
20
21    let mut instruction_inserts = vec![];
22    let word_inserts = vec![];
23
24    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
25    let mut new_spv = spv.clone();
26
27    // 1. Find locations of instructions we need
28    let mut op_variable_idxs = vec![];
29    let mut op_type_pointer_idxs = vec![];
30    let mut op_type_struct_idxs = vec![];
31    let mut op_type_array_idxs = vec![];
32    let mut op_type_matrix_idxs = vec![];
33    let mut op_type_vector_idxs = vec![];
34    let mut op_type_float_idxs = vec![];
35    let mut op_type_int_idxs = vec![];
36    let mut op_constant_idxs = vec![];
37    let mut op_decorate_idxs = vec![];
38    let mut op_member_decorate_idxs = vec![];
39
40    let mut spv_idx = 0;
41    while spv_idx < spv.len() {
42        let op = spv[spv_idx];
43        let word_count = hiword(op);
44        let instruction = loword(op);
45
46        match instruction {
47            SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
48            SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
49            SPV_INSTRUCTION_OP_TYPE_STRUCT => op_type_struct_idxs.push(spv_idx),
50            SPV_INSTRUCTION_OP_TYPE_ARRAY => op_type_array_idxs.push(spv_idx),
51            SPV_INSTRUCTION_OP_TYPE_MATRIX => op_type_matrix_idxs.push(spv_idx),
52            SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
53            SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
54            SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
55            SPV_INSTRUCTION_OP_CONSTANT => op_constant_idxs.push(spv_idx),
56            SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
57            SPV_INSTRUCTION_OP_MEMBER_DECORATE => op_member_decorate_idxs.push(spv_idx),
58            _ => {}
59        }
60
61        spv_idx += word_count as usize;
62    }
63
64    // 2. Find all `OpVariable` that is a `PushConstant`
65    let pc_variables = op_variable_idxs
66        .iter()
67        .filter_map(|&v_idx| {
68            let result_type_id = spv[v_idx + 1];
69            let result_id = spv[v_idx + 2];
70            let storage_class = spv[v_idx + 3];
71            (storage_class == SPV_STORAGE_CLASS_PUSH_CONSTANT).then_some((
72                v_idx,
73                result_type_id,
74                result_id,
75            ))
76        })
77        .collect::<Vec<_>>();
78
79    if pc_variables.is_empty() {
80        return Ok(in_spv.to_vec());
81    }
82
83    // 3. Find underlying type of variables
84    let block_struct_ids = pc_variables
85        .iter()
86        .map(|&(_, ptr_id, _)| {
87            op_type_pointer_idxs
88                .iter()
89                .find_map(|&tp_idx| {
90                    let result_id = spv[tp_idx + 1];
91                    let underlying_type_id = spv[tp_idx + 3];
92                    (result_id == ptr_id).then_some(underlying_type_id)
93                })
94                .expect("OpVariable PushConstant referenced an undefined OpTypePointer")
95        })
96        .collect::<Vec<_>>();
97
98    // 4. Build a registry of every relevant OpType*
99    let type_registry = build_type_registry(BuildTypeRegistryIn {
100        spv: &spv,
101        op_type_float_idxs: &op_type_float_idxs,
102        op_type_int_idxs: &op_type_int_idxs,
103        op_type_vector_idxs: &op_type_vector_idxs,
104        op_type_matrix_idxs: &op_type_matrix_idxs,
105        op_type_array_idxs: &op_type_array_idxs,
106        op_type_struct_idxs: &op_type_struct_idxs,
107        op_constant_idxs: &op_constant_idxs,
108    });
109
110    // 5. Rewrite Offset / ArrayStride / MatrixStride decoration
111    for &block_struct_id in &block_struct_ids {
112        relayout_type_recursive(
113            &spv,
114            &mut new_spv,
115            block_struct_id,
116            &type_registry,
117            &op_decorate_idxs,
118            &op_member_decorate_idxs,
119        );
120    }
121
122    // 6. Correct OpTypePointer and OpVariable PushConstant -> Uniform
123    // TODO: I believe having two of the same OpTypePointer is a validation error
124    for &tp_idx in &op_type_pointer_idxs {
125        let storage_class = spv[tp_idx + 2];
126        if storage_class == SPV_STORAGE_CLASS_PUSH_CONSTANT {
127            new_spv[tp_idx + 2] = SPV_STORAGE_CLASS_UNIFORM;
128        }
129    }
130    for &(v_idx, _, _) in &pc_variables {
131        new_spv[v_idx + 3] = SPV_STORAGE_CLASS_UNIFORM;
132    }
133
134    // 7. Place new uniforms in the set after the last set.
135    let first_op_decorate_idx = op_decorate_idxs.first().copied();
136    let next_set = op_decorate_idxs
137        .iter()
138        .filter_map(|&d_idx| {
139            let decoration_id = spv[d_idx + 2];
140            let decoration_value = spv[d_idx + 3];
141            (decoration_id == SPV_DECORATION_DESCRIPTOR_SET).then_some(decoration_value)
142        })
143        .max()
144        .map(|max| max + 1)
145        .unwrap_or(0);
146
147    for (binding_idx, &(_, _, var_id)) in pc_variables.iter().enumerate() {
148        instruction_inserts.push(InstructionInsert {
149            previous_spv_idx: first_op_decorate_idx
150                .expect("Push constant block has no OpDecorate (missing Block decoration?)"),
151            instruction: vec![
152                encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
153                var_id,
154                SPV_DECORATION_DESCRIPTOR_SET,
155                next_set,
156                encode_word(4, SPV_INSTRUCTION_OP_DECORATE),
157                var_id,
158                SPV_DECORATION_BINDING,
159                binding_idx as u32,
160            ],
161        });
162    }
163
164    // 8. Insert New Instructions
165    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
166
167    // 9. Remove Instructions that have been Whited Out.
168    prune_noops(&mut new_spv);
169
170    // 10. Write New Header and New Code
171    Ok(fuse_final(spv_header, new_spv, instruction_bound))
172}
173
174// Recursively patch Offset / ArrayStride / MatrixStride decorations using our type registry.
175fn relayout_type_recursive(
176    spv: &[u32],
177    new_spv: &mut [u32],
178    type_id: u32,
179    registry: &TypeRegistry,
180    op_decorate_idxs: &[usize],
181    op_member_decorate_idxs: &[usize],
182) {
183    let ty = match registry.get(&type_id) {
184        Some(t) => t,
185        None => return,
186    };
187
188    match &ty.kind {
189        TypeKind::Struct { members } => {
190            let layout = layout_struct(members, LayoutRule::Std140);
191
192            for (i, new_offset) in layout.member_offsets.iter().enumerate() {
193                patch_member_decoration_literal(
194                    spv,
195                    new_spv,
196                    op_member_decorate_idxs,
197                    type_id,
198                    i as u32,
199                    SPV_DECORATION_OFFSET,
200                    *new_offset,
201                );
202            }
203
204            for (i, member) in members.iter().enumerate() {
205                let matrix_kind = match &member.kind {
206                    TypeKind::Matrix { .. } => Some(&member.kind),
207                    TypeKind::Array { element, .. } => match &element.kind {
208                        TypeKind::Matrix { .. } => Some(&element.kind),
209                        _ => None,
210                    },
211                    _ => None,
212                };
213                if let Some(TypeKind::Matrix { column, .. }) = matrix_kind {
214                    let col_count = column_vec_count(column);
215                    let scalar_w = column_scalar_width(column);
216                    let new_stride = matrix_stride(col_count, scalar_w, LayoutRule::Std140);
217                    patch_member_decoration_literal(
218                        spv,
219                        new_spv,
220                        op_member_decorate_idxs,
221                        type_id,
222                        i as u32,
223                        SPV_DECORATION_MATRIX_STRIDE,
224                        new_stride,
225                    );
226                }
227                relayout_type_recursive(
228                    spv,
229                    new_spv,
230                    member.id,
231                    registry,
232                    op_decorate_idxs,
233                    op_member_decorate_idxs,
234                );
235            }
236        }
237
238        TypeKind::Array { element, .. } => {
239            let new_stride = array_stride(&element.kind, LayoutRule::Std140);
240            for &d_idx in op_decorate_idxs {
241                let target_id = spv[d_idx + 1];
242                let decoration_id = spv[d_idx + 2];
243                if target_id == type_id && decoration_id == SPV_DECORATION_ARRAY_STRIDE {
244                    new_spv[d_idx + 3] = new_stride;
245                }
246            }
247            // Ensure arrays of arrays and array of structs are updated.
248            relayout_type_recursive(
249                spv,
250                new_spv,
251                element.id,
252                registry,
253                op_decorate_idxs,
254                op_member_decorate_idxs,
255            );
256        }
257        // No effect from scalars, vectors, and matrices.
258        _ => {}
259    }
260}
261
/// Overwrite the literal of every matching `OpMemberDecorate` instruction.
///
/// The instruction starting at `idx` (taken from `op_member_decorate_idxs`)
/// matches when its operand words `spv[idx + 1..idx + 4]` equal
/// `[target_id, member, decoration]`. For each match, `new_spv[idx + 4]` (the
/// word after the decoration) is set to `new_value`. Matching always reads
/// the original `spv`, never `new_spv`.
///
/// Panics if `spv` has no word at `idx + 3` for some listed index, or if
/// `new_spv` has no word at `idx + 4` for a matching one.
fn patch_member_decoration_literal(
    spv: &[u32],
    new_spv: &mut [u32],
    op_member_decorate_idxs: &[usize],
    target_id: u32,
    member: u32,
    decoration: u32,
    new_value: u32,
) {
    let wanted = [target_id, member, decoration];
    let literal_slots = op_member_decorate_idxs
        .iter()
        .copied()
        .filter(|&idx| spv[idx + 1..idx + 4] == wanted)
        .map(|idx| idx + 4);

    for slot in literal_slots {
        new_spv[slot] = new_value;
    }
}