Skip to main content

spirv_webgpu_transform/
splitcombined.rs

1use super::*;
2
3mod function_call;
4mod function_parameter;
5mod load;
6mod type_function;
7mod type_pointer;
8mod variable;
9
10use function_call::*;
11use function_parameter::*;
12use load::*;
13use type_function::*;
14use type_pointer::*;
15use variable::*;
16
17/// Perform the operation on a `Vec<u32>`.
18/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`
19/// Either update the existing `corrections` or create a new one.
20pub fn combimgsampsplitter(
21    in_spv: &[u32],
22    corrections: &mut Option<CorrectionMap>,
23) -> Result<Vec<u32>, ()> {
24    let spv = in_spv.to_owned();
25
26    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
27    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
28
29    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
30
31    assert_eq!(magic_number, SPV_HEADER_MAGIC);
32
33    let mut instruction_inserts = vec![];
34    let mut word_inserts = vec![];
35
36    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
37    let mut new_spv = spv.clone();
38
39    let mut op_type_sampler_idx = None;
40    let mut first_op_deocrate_idx = None;
41    let mut first_op_type_void_idx = None;
42
43    let mut op_type_image_idxs = vec![];
44    let mut op_type_sampled_image_idxs = vec![];
45    let mut op_type_pointer_idxs = vec![];
46    let mut op_variables_idxs = vec![];
47    let mut op_loads_idxs = vec![];
48    let mut op_decorate_idxs = vec![];
49    let mut op_type_function_idxs = vec![];
50    let mut op_function_parameter_idxs = vec![];
51    let mut op_function_call_idxs = vec![];
52
53    // 1. Find locations instructions we need
54    let mut spv_idx = 0;
55    while spv_idx < spv.len() {
56        let op = spv[spv_idx];
57        let word_count = hiword(op);
58        let instruction = loword(op);
59
60        match instruction {
61            SPV_INSTRUCTION_OP_TYPE_VOID => {
62                first_op_type_void_idx = Some(spv_idx);
63            }
64            SPV_INSTRUCTION_OP_TYPE_SAMPLER => {
65                op_type_sampler_idx = Some(spv_idx);
66                new_spv[spv_idx] = encode_word(word_count, SPV_INSTRUCTION_OP_NOP);
67            }
68            SPV_INSTRUCTION_OP_TYPE_IMAGE => {
69                op_type_image_idxs.push(spv_idx);
70            }
71            SPV_INSTRUCTION_OP_TYPE_SAMPLED_IMAGE => op_type_sampled_image_idxs.push(spv_idx),
72            SPV_INSTRUCTION_OP_TYPE_POINTER => {
73                // This should probably go elsewhere.
74                #[allow(clippy::collapsible_match)]
75                if spv[spv_idx + 2] == SPV_STORAGE_CLASS_UNIFORM_CONSTANT {
76                    op_type_pointer_idxs.push(spv_idx);
77                }
78            }
79            SPV_INSTRUCTION_OP_VARIABLE => op_variables_idxs.push(spv_idx),
80            SPV_INSTRUCTION_OP_LOAD => op_loads_idxs.push(spv_idx),
81            SPV_INSTRUCTION_OP_DECORATE => {
82                op_decorate_idxs.push(spv_idx);
83                first_op_deocrate_idx.get_or_insert(spv_idx);
84            }
85            SPV_INSTRUCTION_OP_TYPE_FUNCTION => op_type_function_idxs.push(spv_idx),
86            SPV_INSTRUCTION_OP_FUNCTION_PARAMETER => op_function_parameter_idxs.push(spv_idx),
87            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
88
89            _ => {}
90        }
91
92        spv_idx += word_count as usize;
93    }
94
95    // 2. Insert OpTypeSampler and respective OpTypePointer if neccessary
96
97    // - If there has been no OpTypeImage, there will be nothing to do
98    if op_type_image_idxs.is_empty() {
99        return Ok(in_spv.to_vec());
100    };
101
102    let op_type_sampler_res_id = if let Some(idx) = op_type_sampler_idx {
103        spv[idx + 1]
104    } else {
105        let op_type_sampler_res_id = instruction_bound;
106        instruction_bound += 1;
107        op_type_sampler_res_id
108    };
109
110    let op_type_pointer_sampler_res_id = instruction_bound;
111    instruction_bound += 1;
112    instruction_inserts.push(InstructionInsert {
113        // Let's avoid trouble and just insert after OpTypeVoid.
114        // previous_spv_idx: op_type_image_idx,
115        previous_spv_idx: first_op_type_void_idx.unwrap(),
116        instruction: vec![
117            encode_word(2, SPV_INSTRUCTION_OP_TYPE_SAMPLER),
118            op_type_sampler_res_id,
119            encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
120            op_type_pointer_sampler_res_id,
121            SPV_STORAGE_CLASS_UNIFORM_CONSTANT,
122            op_type_sampler_res_id,
123        ],
124    });
125
126    // 3. OpTypePointer
127    let tp_res = type_pointer(TypePointerIn {
128        spv: &spv,
129        new_spv: &mut new_spv,
130
131        op_type_pointer_idxs: &op_type_pointer_idxs,
132        op_type_sampled_image_idxs: &op_type_sampled_image_idxs,
133    });
134
135    // 4. OpVariable
136    let v_res = variable(VariableIn {
137        spv: &spv,
138        instruction_bound: &mut instruction_bound,
139        instruction_inserts: &mut instruction_inserts,
140        op_type_pointer_sampler_res_id,
141        op_variables_idxs: &op_variables_idxs,
142        tp_res: &tp_res,
143    });
144
145    // 5. OpTypeFunction
146    type_function(TypeFunctionIn {
147        spv: &spv,
148        word_inserts: &mut word_inserts,
149        op_type_pointer_sampler_res_id,
150        op_type_function_idxs: &op_type_function_idxs,
151        tp_res: &tp_res,
152    });
153
154    // 6. OpFunctionParameter
155    let parameter_res = function_parameter(FunctionParameterIn {
156        spv: &spv,
157        instruction_bound: &mut instruction_bound,
158        instruction_inserts: &mut instruction_inserts,
159        op_type_pointer_sampler_res_id,
160        op_function_parameter_idxs: &op_function_parameter_idxs,
161        tp_res: &tp_res,
162    });
163
164    // 7. OpFunctionCall
165    function_call(FunctionCallIn {
166        spv: &spv,
167        word_inserts: &mut word_inserts,
168        op_function_call_idxs: &op_function_call_idxs,
169        v_res: &v_res,
170        parameter_res: &parameter_res,
171    });
172
173    // 8. OpLoad
174    load(LoadIn {
175        spv: &spv,
176        new_spv: &mut new_spv,
177        instruction_bound: &mut instruction_bound,
178        instruction_inserts: &mut instruction_inserts,
179        op_type_sampler_res_id,
180        op_loads_idxs: &op_loads_idxs,
181        v_res: &v_res,
182        parameter_res: &parameter_res,
183    });
184
185    // 9. OpDecorate
186    let DecorateOut {
187        descriptor_sets_to_correct,
188    } = util::decorate(DecorateIn {
189        spv: &spv,
190        instruction_inserts: &mut instruction_inserts,
191        first_op_deocrate_idx,
192        op_decorate_idxs: &op_decorate_idxs,
193        affected_decorations: &v_res
194            .iter()
195            .map(
196                |VariableOut {
197                     v_res_id,
198                     new_sampler_v_res_id,
199                     ..
200                 }| {
201                    AffectedDecoration {
202                        original_res_id: *v_res_id,
203                        new_res_ids: vec![*new_sampler_v_res_id],
204                        correction_type: CorrectionType::SplitCombined,
205                    }
206                },
207            )
208            .collect::<Vec<_>>(),
209        corrections,
210    });
211
212    // 10. Insert New Instructions
213    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
214
215    // 11. Correct OpDecorate Bindings
216    util::correct_decorate(CorrectDecorateIn {
217        new_spv: &mut new_spv,
218        descriptor_sets_to_correct,
219    });
220
221    // 12. Remove Instructions that have been Whited Out.
222    prune_noops(&mut new_spv);
223
224    // 13. Write New Header and New Code
225    Ok(fuse_final(spv_header, new_spv, instruction_bound))
226}