Skip to main content

spirv_webgpu_transform/
pruneunuseddref.rs

1use super::*;
2
3pub fn pruneunuseddref(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
4    let spv = in_spv.to_owned();
5
6    let instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
7    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
8
9    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
10
11    assert_eq!(magic_number, SPV_HEADER_MAGIC);
12
13    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
14    let mut new_spv = spv.clone();
15
16    // 1. Find locations instructions we need
17    let mut op_type_pointer_idxs = vec![];
18    let mut op_type_image_idxs = vec![];
19    let mut op_variable_idxs = vec![];
20    let mut op_load_idxs = vec![];
21    let mut op_function_parameter_idxs = vec![];
22    let mut op_function_call_idxs = vec![];
23    let mut op_decorate_idxs = vec![];
24    let mut op_name_idxs = vec![];
25
26    let mut op_type_sampler_id_map = HashSet::new();
27    let mut op_sampled_image_id_map = HashSet::new();
28
29    let mut spv_idx = 0;
30    while spv_idx < spv.len() {
31        let op = spv[spv_idx];
32        let word_count = hiword(op);
33        let instruction = loword(op);
34
35        match instruction {
36            SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
37            SPV_INSTRUCTION_OP_TYPE_IMAGE => op_type_image_idxs.push(spv_idx),
38            SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
39            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
40            SPV_INSTRUCTION_OP_FUNCTION_PARAMETER => op_function_parameter_idxs.push(spv_idx),
41            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
42            SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
43            SPV_INSTRUCTION_OP_NAME => op_name_idxs.push(spv_idx),
44
45            SPV_INSTRUCTION_OP_TYPE_SAMPLER => {
46                let result_id = spv[spv_idx + 1];
47                op_type_sampler_id_map.insert(result_id);
48            }
49            SPV_INSTRUCTION_OP_SAMPLED_IMAGE => {
50                let image_id = spv[spv_idx + 3];
51                let sampler_id = spv[spv_idx + 4];
52                op_sampled_image_id_map.insert(image_id);
53                op_sampled_image_id_map.insert(sampler_id);
54            }
55            _ => {}
56        }
57
58        spv_idx += word_count as usize;
59    }
60
61    // 2. Find all OpTypePointer to OpTypeImage and OpTypeSampler
62    let image_type_pointers_map = op_type_pointer_idxs
63        .iter()
64        .filter_map(|&tp_idx| {
65            let result_id = spv[tp_idx + 1];
66            let underlying_type_id = spv[tp_idx + 3];
67            op_type_image_idxs
68                .iter()
69                .any(|ti_idx| {
70                    let type_id = spv[ti_idx + 1];
71                    let image_sampled = spv[ti_idx + 7];
72
73                    // `!= 2` filters for storage textures which shouldn't be pruned.
74                    image_sampled != 2 && type_id == underlying_type_id
75                })
76                .then_some(result_id)
77        })
78        .collect::<HashSet<_>>();
79    let sampler_type_pointers_map = op_type_pointer_idxs
80        .iter()
81        .filter_map(|&tp_idx| {
82            let result_id = spv[tp_idx + 1];
83            let underlying_type_id = spv[tp_idx + 3];
84            op_type_sampler_id_map
85                .contains(&underlying_type_id)
86                .then_some(result_id)
87        })
88        .collect::<HashSet<_>>();
89
90    // 3. Final all OpVariable to OpTypePointers
91    let variable_result_map = op_variable_idxs
92        .iter()
93        .filter_map(|&idx| {
94            let tp_id = spv[idx + 1];
95            let result_id = spv[idx + 2];
96
97            (image_type_pointers_map.contains(&tp_id) || sampler_type_pointers_map.contains(&tp_id))
98                .then_some((result_id, idx))
99        })
100        .collect::<HashMap<_, _>>();
101
102    // 4. Find all OpLoad to OpSampledImage
103    let loaded_idxs = op_load_idxs
104        .iter()
105        .filter(|&idx| {
106            let result_id = spv[idx + 2];
107            op_sampled_image_id_map.contains(&result_id)
108        })
109        .collect::<Vec<_>>();
110
111    let mut used_variable_idxs = loaded_idxs
112        .iter()
113        .filter_map(|&&load_idx| {
114            let pointer = spv[load_idx + 3];
115            variable_result_map
116                .contains_key(&pointer)
117                .then_some(pointer)
118        })
119        .collect::<HashSet<_>>();
120
121    // 5. Final all OpFunctionParameter to OpLoad
122    let function_parameter_idxs = op_function_parameter_idxs.iter().filter(|&fp_idx| {
123        let result_id = spv[fp_idx + 2];
124        op_load_idxs.iter().any(|&l_idx| {
125            let pointer = spv[l_idx + 3];
126            pointer == result_id
127        })
128    });
129
130    // 6. Trace Variables from OpFunctionParameter
131    for &fp_idx in function_parameter_idxs {
132        let entry = get_function_from_parameter(&spv, fp_idx);
133        let variables = trace_function_argument_to_variables(TraceFunctionArgumentToVariablesIn {
134            spv: &spv,
135            op_variable_idxs: &op_variable_idxs,
136            op_function_parameter_idxs: &op_function_parameter_idxs,
137            op_function_call_idxs: &op_function_call_idxs,
138            entry,
139            traced_function_call_idxs: &mut vec![],
140        });
141        for variable_idx in variables {
142            let variable_result_id = spv[variable_idx + 2];
143            used_variable_idxs.insert(variable_result_id);
144        }
145    }
146
147    // 8. Remove unused variables
148    let unused_variable_idxs = variable_result_map
149        .iter()
150        .filter_map(|(id, &idx)| (!used_variable_idxs.contains(id)).then_some(idx))
151        .collect::<Vec<_>>();
152
153    // 9. Find OpDecorate / OpName to OpVariable
154    let unused_decorate_idxs = op_decorate_idxs
155        .iter()
156        .filter(|&idx| {
157            let target = spv[idx + 1];
158            unused_variable_idxs.iter().any(|&v_idx| {
159                let result_id = spv[v_idx + 2];
160                target == result_id
161            })
162        })
163        .copied()
164        .collect::<Vec<_>>();
165
166    let unused_name_idxs = op_name_idxs
167        .iter()
168        .filter(|&idx| {
169            let target = spv[idx + 1];
170            unused_variable_idxs.iter().any(|&v_idx| {
171                let result_id = spv[v_idx + 2];
172                target == result_id
173            })
174        })
175        .copied()
176        .collect::<Vec<_>>();
177
178    // 9. Remove instructions
179    for spv_idx in unused_variable_idxs
180        .into_iter()
181        .chain(unused_decorate_idxs)
182        .chain(unused_name_idxs)
183    {
184        let op = spv[spv_idx];
185        let word_count = hiword(op) as usize;
186
187        new_spv[spv_idx..spv_idx + word_count].fill(encode_word(1, SPV_INSTRUCTION_OP_NOP));
188    }
189    prune_noops(&mut new_spv);
190
191    // 10. Write New Header and New Code
192    Ok(fuse_final(spv_header, new_spv, instruction_bound))
193}