Skip to main content

spirv_webgpu_transform/
splitcombined.rs

1use super::*;
2
3mod function_call;
4mod function_parameter;
5mod load;
6mod type_function;
7mod type_pointer;
8mod variable;
9
10use function_call::*;
11use function_parameter::*;
12use load::*;
13use type_function::*;
14use type_pointer::*;
15use variable::*;
16
17/// Perform the operation on a `Vec<u32>`.
18/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`
19/// Either update the existing `corrections` or create a new one.
20pub fn combimgsampsplitter(
21    in_spv: &[u32],
22    corrections: &mut Option<CorrectionMap>,
23) -> Result<Vec<u32>, ()> {
24    let spv = in_spv.to_owned();
25
26    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
27    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
28
29    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
30
31    assert_eq!(magic_number, SPV_HEADER_MAGIC);
32
33    let mut instruction_inserts = vec![];
34    let mut word_inserts = vec![];
35
36    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
37    let mut new_spv = spv.clone();
38
39    let mut op_type_sampler_idx = None;
40    let mut first_op_deocrate_idx = None;
41    let mut first_op_type_void_idx = None;
42
43    let mut op_type_image_idxs = vec![];
44    let mut op_type_sampled_image_idxs = vec![];
45    let mut op_type_pointer_idxs = vec![];
46    let mut op_variables_idxs = vec![];
47    let mut op_loads_idxs = vec![];
48    let mut op_decorate_idxs = vec![];
49    let mut op_type_function_idxs = vec![];
50    let mut op_function_parameter_idxs = vec![];
51    let mut op_function_call_idxs = vec![];
52
53    // 1. Find locations instructions we need
54    let mut spv_idx = 0;
55    while spv_idx < spv.len() {
56        let op = spv[spv_idx];
57        let word_count = hiword(op);
58        let instruction = loword(op);
59
60        match instruction {
61            SPV_INSTRUCTION_OP_TYPE_VOID => {
62                first_op_type_void_idx = Some(spv_idx);
63            }
64            SPV_INSTRUCTION_OP_TYPE_SAMPLER => {
65                op_type_sampler_idx = Some(spv_idx);
66                new_spv[spv_idx] = encode_word(word_count, SPV_INSTRUCTION_OP_NOP);
67            }
68            SPV_INSTRUCTION_OP_TYPE_IMAGE => {
69                op_type_image_idxs.push(spv_idx);
70            }
71            SPV_INSTRUCTION_OP_TYPE_SAMPLED_IMAGE => op_type_sampled_image_idxs.push(spv_idx),
72            SPV_INSTRUCTION_OP_TYPE_POINTER => {
73                if spv[spv_idx + 2] == SPV_STORAGE_CLASS_UNIFORM_CONSTANT {
74                    op_type_pointer_idxs.push(spv_idx);
75                }
76            }
77            SPV_INSTRUCTION_OP_VARIABLE => op_variables_idxs.push(spv_idx),
78            SPV_INSTRUCTION_OP_LOAD => op_loads_idxs.push(spv_idx),
79            SPV_INSTRUCTION_OP_DECORATE => {
80                op_decorate_idxs.push(spv_idx);
81                first_op_deocrate_idx.get_or_insert(spv_idx);
82            }
83            SPV_INSTRUCTION_OP_TYPE_FUNCTION => op_type_function_idxs.push(spv_idx),
84            SPV_INSTRUCTION_OP_FUNCTION_PARAMETER => op_function_parameter_idxs.push(spv_idx),
85            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
86
87            _ => {}
88        }
89
90        spv_idx += word_count as usize;
91    }
92
93    // 2. Insert OpTypeSampler and respective OpTypePointer if neccessary
94
95    // - If there has been no OpTypeImage, there will be nothing to do
96    if op_type_image_idxs.is_empty() {
97        return Ok(in_spv.to_vec());
98    };
99
100    let op_type_sampler_res_id = if let Some(idx) = op_type_sampler_idx {
101        spv[idx + 1]
102    } else {
103        let op_type_sampler_res_id = instruction_bound;
104        instruction_bound += 1;
105        op_type_sampler_res_id
106    };
107
108    let op_type_pointer_sampler_res_id = instruction_bound;
109    instruction_bound += 1;
110    instruction_inserts.push(InstructionInsert {
111        // Let's avoid trouble and just insert after OpTypeVoid.
112        // previous_spv_idx: op_type_image_idx,
113        previous_spv_idx: first_op_type_void_idx.unwrap(),
114        instruction: vec![
115            encode_word(2, SPV_INSTRUCTION_OP_TYPE_SAMPLER),
116            op_type_sampler_res_id,
117            encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
118            op_type_pointer_sampler_res_id,
119            SPV_STORAGE_CLASS_UNIFORM_CONSTANT,
120            op_type_sampler_res_id,
121        ],
122    });
123
124    // 3. OpTypePointer
125    let tp_res = type_pointer(TypePointerIn {
126        spv: &spv,
127        new_spv: &mut new_spv,
128
129        op_type_pointer_idxs: &op_type_pointer_idxs,
130        op_type_sampled_image_idxs: &op_type_sampled_image_idxs,
131    });
132
133    // 4. OpVariable
134    let v_res = variable(VariableIn {
135        spv: &spv,
136        instruction_bound: &mut instruction_bound,
137        instruction_inserts: &mut instruction_inserts,
138        op_type_pointer_sampler_res_id,
139        op_variables_idxs: &op_variables_idxs,
140        tp_res: &tp_res,
141    });
142
143    // 5. OpTypeFunction
144    type_function(TypeFunctionIn {
145        spv: &spv,
146        word_inserts: &mut word_inserts,
147        op_type_pointer_sampler_res_id,
148        op_type_function_idxs: &op_type_function_idxs,
149        tp_res: &tp_res,
150    });
151
152    // 6. OpFunctionParameter
153    let parameter_res = function_parameter(FunctionParameterIn {
154        spv: &spv,
155        instruction_bound: &mut instruction_bound,
156        instruction_inserts: &mut instruction_inserts,
157        op_type_pointer_sampler_res_id,
158        op_function_parameter_idxs: &op_function_parameter_idxs,
159        tp_res: &tp_res,
160    });
161
162    // 7. OpFunctionCall
163    function_call(FunctionCallIn {
164        spv: &spv,
165        word_inserts: &mut word_inserts,
166        op_function_call_idxs: &op_function_call_idxs,
167        v_res: &v_res,
168        parameter_res: &parameter_res,
169    });
170
171    // 8. OpLoad
172    load(LoadIn {
173        spv: &spv,
174        new_spv: &mut new_spv,
175        instruction_bound: &mut instruction_bound,
176        instruction_inserts: &mut instruction_inserts,
177        op_type_sampler_res_id,
178        op_loads_idxs: &op_loads_idxs,
179        v_res: &v_res,
180        parameter_res: &parameter_res,
181    });
182
183    // 9. OpDecorate
184    let DecorateOut {
185        descriptor_sets_to_correct,
186    } = util::decorate(DecorateIn {
187        spv: &spv,
188        instruction_inserts: &mut instruction_inserts,
189        first_op_deocrate_idx,
190        op_decorate_idxs: &op_decorate_idxs,
191        affected_decorations: &v_res
192            .iter()
193            .map(
194                |VariableOut {
195                     v_res_id,
196                     new_sampler_v_res_id,
197                     ..
198                 }| {
199                    AffectedDecoration {
200                        original_res_id: *v_res_id,
201                        new_res_id: *new_sampler_v_res_id,
202                        correction_type: CorrectionType::SplitCombined,
203                    }
204                },
205            )
206            .collect::<Vec<_>>(),
207        corrections,
208    });
209
210    // 10. Insert New Instructions
211    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
212
213    // 11. Correct OpDecorate Bindings
214    util::correct_decorate(CorrectDecorateIn {
215        new_spv: &mut new_spv,
216        descriptor_sets_to_correct,
217    });
218
219    // 12. Remove Instructions that have been Whited Out.
220    prune_noops(&mut new_spv);
221
222    // 13. Write New Header and New Code
223    Ok(fuse_final(spv_header, new_spv, instruction_bound))
224}