Skip to main content

spirv_webgpu_transform/
storagecubepatch.rs

1use super::*;
2
/// Post-increment: return the current value of `ib` and advance it by one.
///
/// Used to hand out fresh SPIR-V result ids from the module's id bound.
fn inc(ib: &mut u32) -> u32 {
    let current = *ib;
    *ib = current + 1;
    current
}
7
8mod image_cube_direction_to_arrayed;
9use image_cube_direction_to_arrayed::*;
10
/// An image access instruction that may need its coordinate rewritten,
/// tagged with a payload (in this file: the word index where the instruction starts).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ImageOperation<T> {
    /// `OpImageFetch`
    Fetch(T),
    /// `OpImageRead`
    Read(T),
    /// `OpImageWrite`
    Write(T),
}

impl<T: Copy> ImageOperation<T> {
    /// Return the payload carried by any variant.
    fn get(&self) -> T {
        match *self {
            Self::Fetch(payload) | Self::Read(payload) | Self::Write(payload) => payload,
        }
    }

    /// Word offset of the `Image` operand, relative to the instruction's first word.
    ///
    /// Fetch/Read are laid out as `op, result type, result id, image, coordinate, ...`;
    /// Write is laid out as `op, image, coordinate, texel, ...`.
    fn image_offset(&self) -> usize {
        match self {
            Self::Write(_) => 1,
            Self::Fetch(_) | Self::Read(_) => 3,
        }
    }

    /// Word offset of the `Coordinate` operand (always right after the image operand).
    fn coordinate_offset(&self) -> usize {
        match self {
            Self::Write(_) => 2,
            Self::Fetch(_) | Self::Read(_) => 4,
        }
    }
}
44
45/// Perform the operation on a `Vec<u32>`.
46/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
47/// Does not produce any side effects or corrections.
48pub fn storagecubepatch(
49    in_spv: &[u32],
50    corrections: &mut Option<CorrectionMap>,
51) -> Result<Vec<u32>, ()> {
52    let spv = in_spv.to_owned();
53
54    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
55    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
56
57    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
58
59    assert_eq!(magic_number, SPV_HEADER_MAGIC);
60
61    let mut instruction_inserts = vec![];
62    let word_inserts = vec![];
63
64    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
65    let mut new_spv = spv.clone();
66
67    // 1. Find locations instructions we need
68    let mut op_type_image_idxs = vec![];
69    let mut op_type_pointer_idxs = vec![];
70    let mut op_variable_idxs = vec![];
71    let mut op_load_idxs = vec![];
72    let mut op_type_int_idxs = vec![];
73    let mut op_type_bool_idxs = vec![];
74    let mut op_type_vector_idxs = vec![];
75    let mut op_ext_inst_import_idxs = vec![];
76    let mut op_function_parameter_idxs = vec![];
77    let mut op_function_call_idxs = vec![];
78    let mut op_decorate_idxs = vec![];
79
80    let mut image_operation_idxs = vec![];
81
82    let mut spv_idx = 0;
83    while spv_idx < spv.len() {
84        let op = spv[spv_idx];
85        let word_count = hiword(op);
86        let instruction = loword(op);
87
88        match instruction {
89            SPV_INSTRUCTION_OP_TYPE_IMAGE => op_type_image_idxs.push(spv_idx),
90            SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
91            SPV_INSTRUCTION_OP_VARIABLE => op_variable_idxs.push(spv_idx),
92            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
93            SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
94            SPV_INSTRUCTION_OP_TYPE_BOOL => op_type_bool_idxs.push(spv_idx),
95            SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
96            SPV_INSTRUCTION_OP_EXT_INST_IMPORT => op_ext_inst_import_idxs.push(spv_idx),
97            SPV_INSTRUCTION_OP_FUNCTION_PARAMETER => op_function_parameter_idxs.push(spv_idx),
98            SPV_INSTRUCTION_OP_FUNCTION_CALL => op_function_call_idxs.push(spv_idx),
99            SPV_INSTRUCTION_OP_DECORATE => op_decorate_idxs.push(spv_idx),
100            SPV_INSTRUCTION_OP_IMAGE_FETCH => {
101                image_operation_idxs.push(ImageOperation::Fetch(spv_idx))
102            }
103            SPV_INSTRUCTION_OP_IMAGE_READ => {
104                image_operation_idxs.push(ImageOperation::Read(spv_idx))
105            }
106            SPV_INSTRUCTION_OP_IMAGE_WRITE => {
107                image_operation_idxs.push(ImageOperation::Write(spv_idx))
108            }
109            _ => {}
110        }
111
112        spv_idx += word_count as usize;
113    }
114
115    if op_type_vector_idxs.is_empty()
116        || op_type_image_idxs.is_empty()
117        || image_operation_idxs.is_empty()
118    {
119        return Ok(in_spv.to_vec());
120    }
121
122    let header_position = last_of_indices!(
123        op_type_int_idxs,
124        op_type_bool_idxs,
125        op_type_vector_idxs,
126        op_type_pointer_idxs
127    );
128
129    // 2. Insert Required Types
130    let mut header_insert = InstructionInsert {
131        previous_spv_idx: header_position.unwrap(),
132        instruction: vec![],
133    };
134
135    let bool_id = ensure_type_bool(
136        &spv,
137        &op_type_bool_idxs,
138        &mut instruction_bound,
139        &mut header_insert.instruction,
140    );
141    let bool_ptr_id = ensure_type_pointer(
142        &spv,
143        &op_type_pointer_idxs,
144        &mut instruction_bound,
145        &mut header_insert.instruction,
146        SPV_STORAGE_CLASS_FUNCTION,
147        bool_id,
148    );
149    let int32_id = ensure_type_int(
150        &spv,
151        &op_type_int_idxs,
152        &mut instruction_bound,
153        &mut header_insert.instruction,
154        32,
155        SPV_SIGNEDNESS_SIGNED,
156    );
157    let int32_ptr_id = ensure_type_pointer(
158        &spv,
159        &op_type_pointer_idxs,
160        &mut instruction_bound,
161        &mut header_insert.instruction,
162        SPV_STORAGE_CLASS_FUNCTION,
163        int32_id,
164    );
165    let v3int32_id = ensure_type_vector(
166        &spv,
167        &op_type_vector_idxs,
168        &mut instruction_bound,
169        &mut header_insert.instruction,
170        int32_id,
171        3,
172    );
173    let v3int32_ptr_id = ensure_type_pointer(
174        &spv,
175        &op_type_pointer_idxs,
176        &mut instruction_bound,
177        &mut header_insert.instruction,
178        SPV_STORAGE_CLASS_FUNCTION,
179        v3int32_id,
180    );
181    let v2int32_id = ensure_type_vector(
182        &spv,
183        &op_type_vector_idxs,
184        &mut instruction_bound,
185        &mut header_insert.instruction,
186        int32_id,
187        2,
188    );
189    let v2int32_ptr_id = ensure_type_pointer(
190        &spv,
191        &op_type_pointer_idxs,
192        &mut instruction_bound,
193        &mut header_insert.instruction,
194        SPV_STORAGE_CLASS_FUNCTION,
195        v2int32_id,
196    );
197    let glsl_std_id = ensure_ext_inst_import(
198        &spv,
199        &op_ext_inst_import_idxs,
200        &mut instruction_bound,
201        &mut header_insert.instruction,
202        |s| s.starts_with("GLSL.std."),
203        "GLSL.std.450",
204    );
205
206    // We only need to validate the bool_id, ptr_int_id
207    let type_inputs = CubeDirectionTypeInputs {
208        int_id: int32_id,
209        v3int_id: v3int32_id,
210        v2int_id: v2int32_id,
211        bool_id,
212        ptr_v3int_id: v3int32_ptr_id,
213        ptr_int_id: int32_ptr_id,
214        ptr_bool_id: bool_ptr_id,
215        ptr_v2int_id: v2int32_ptr_id,
216    };
217    let (function_type_id, mut function_type_spv) =
218        image_cube_direction_to_arrayed_fn_type(&mut instruction_bound, type_inputs);
219    header_insert.instruction.append(&mut function_type_spv);
220
221    // 3. Find / Insert Required Constants
222    let (shared_constants, mut constants_spv) =
223        image_cube_direction_to_arrayed_constants_spv(&mut instruction_bound, int32_id);
224    header_insert.instruction.append(&mut constants_spv);
225
226    // 4. Insert Function Type and Definition
227    let (function_id, function_spv) = image_cube_direction_to_arrayed_spv(
228        &mut instruction_bound,
229        type_inputs,
230        function_type_id,
231        shared_constants,
232        glsl_std_id,
233    );
234    let mut function_definition_words = function_spv;
235
236    // 5. Find OpTypeImage, change Cube -> 2D
237    let type_image_ids = op_type_image_idxs
238        .iter()
239        .filter_map(|idx| {
240            let result_id = spv[idx + 1];
241            let dim = spv[idx + 3];
242            let arrayed = spv[idx + 5];
243            // 0: unknown, 1: sampling, 2: read/write
244            let sampled = spv[idx + 7];
245
246            if dim == SPV_DIMENSION_CUBE && arrayed == 1 && sampled == 2 {
247                panic!("imageCubeArray is not supported");
248            }
249
250            // imageCube => image2DArray
251            if dim == SPV_DIMENSION_CUBE && sampled == 2 {
252                new_spv[idx + 3] = SPV_DIMENSION_2D;
253                new_spv[idx + 5] = 1;
254                Some(result_id)
255            } else {
256                None
257            }
258        })
259        .collect::<Vec<_>>();
260
261    // 6. Find OpTypePointer -> OpVariable / OpFunctionParameter -> OpLoad
262    let type_pointer_ids = op_type_pointer_idxs
263        .iter()
264        .filter_map(|idx| {
265            let result_id = spv[idx + 1];
266            let underlying_type_id = spv[idx + 3];
267
268            type_image_ids
269                .contains(&underlying_type_id)
270                .then_some(result_id)
271        })
272        .collect::<Vec<_>>();
273    let loadable_ids = op_variable_idxs
274        .iter()
275        // Yes, offsets 1 and 2 are identical
276        .chain(op_function_parameter_idxs.iter())
277        .filter_map(|idx| {
278            let result_id = spv[idx + 2];
279            let result_type_id = spv[idx + 1];
280            type_pointer_ids
281                .contains(&result_type_id)
282                .then_some(result_id)
283        })
284        .collect::<Vec<_>>();
285    let loaded_ids = op_load_idxs
286        .iter()
287        .filter_map(|idx| {
288            let result_id = spv[idx + 2];
289            let pointer_id = spv[idx + 3];
290
291            loadable_ids.contains(&pointer_id).then_some(result_id)
292        })
293        .collect::<Vec<_>>();
294
295    // 7. Find and Patch OpImage{Fetch, Read, Write}
296    for operation_with_idx in image_operation_idxs.iter() {
297        let op_idx = operation_with_idx.get();
298        let op_word_count = hiword(spv[op_idx]) as usize;
299        let image_id = spv[op_idx + operation_with_idx.image_offset()];
300        let coord_id = spv[op_idx + operation_with_idx.coordinate_offset()];
301
302        if loaded_ids.contains(&image_id) {
303            // Inject new temp variable to store coordinate
304            // TODO: OPT further reduce the number of temp variables by sharing then within the same functions
305            let temp_id = inc(&mut instruction_bound);
306            instruction_inserts.push(InstructionInsert {
307                previous_spv_idx: get_function_label_index_of_instruction_index(&spv, op_idx),
308                instruction: vec![
309                    encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
310                    type_inputs.ptr_v3int_id,
311                    temp_id,
312                    SPV_STORAGE_CLASS_FUNCTION,
313                ],
314            });
315            // Store existing coordinate, pass to our function, create new instruction
316            let output_id = inc(&mut instruction_bound);
317            let mut new_instructions = vec![
318                encode_word(3, SPV_INSTRUCTION_OP_STORE),
319                temp_id,
320                coord_id,
321                encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
322                type_inputs.v3int_id,
323                output_id,
324                function_id,
325                temp_id,
326            ];
327            let cl = new_instructions.len();
328            new_instructions.extend_from_slice(&spv[op_idx..op_idx + op_word_count]);
329            new_instructions[cl + operation_with_idx.coordinate_offset()] = output_id;
330            instruction_inserts.push(InstructionInsert {
331                previous_spv_idx: op_idx,
332                instruction: new_instructions,
333            });
334
335            new_spv[op_idx..op_idx + op_word_count].fill(encode_word(1, SPV_INSTRUCTION_OP_NOP));
336        }
337    }
338
339    // 8. Fill Correction Map
340    decorate(DecorateIn {
341        spv: &spv,
342        instruction_inserts: &mut vec![],
343        first_op_deocrate_idx: op_decorate_idxs.first().copied(),
344        op_decorate_idxs: &op_decorate_idxs,
345        affected_decorations: &loadable_ids
346            .iter()
347            .map(|id| AffectedDecoration {
348                original_res_id: *id,
349                new_res_ids: vec![*id],
350                correction_type: CorrectionType::ConvertStorageCube,
351            })
352            .collect::<Vec<_>>(),
353        corrections,
354    });
355
356    // 9. Insert New Instructions
357    instruction_inserts.insert(0, header_insert);
358    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
359    new_spv.append(&mut function_definition_words);
360
361    // 10. Remove Instructions that have been Whited Out.
362    prune_noops(&mut new_spv);
363
364    // 11. Write New Header and New Code
365    Ok(fuse_final(spv_header, new_spv, instruction_bound))
366}