Skip to main content

spirv_webgpu_transform/
isnanisinfpatch.rs

1use super::*;
2
/// Post-increment helper: yields the value currently stored behind `ib`
/// and leaves `ib` advanced by one.
fn inc(ib: &mut u32) -> u32 {
    let current = *ib;
    *ib += 1;
    current
}
7
8// Someone should make a rust-spirv dsl macro
9mod isnan_isinf;
10mod shared;
11
12use isnan_isinf::*;
13use shared::*;
14
15/// Perform the operation on a `Vec<u32>`.
16/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
17/// Does not produce any side effects or corrections.
18pub fn isnanisinfpatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
19    let spv = in_spv.to_owned();
20
21    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
22    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
23
24    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
25
26    assert_eq!(magic_number, SPV_HEADER_MAGIC);
27
28    let mut instruction_inserts = vec![];
29    let word_inserts = vec![];
30
31    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
32    let mut new_spv = spv.clone();
33
34    // 1. Find locations instructions we need
35    let mut op_function_idxs = vec![];
36    let mut op_load_idxs = vec![];
37    let mut op_type_pointer_idxs = vec![];
38    let mut op_is_nan_idxs = vec![];
39    let mut op_is_inf_idxs = vec![];
40    let mut op_type_bool_idxs = vec![];
41    let mut op_type_int_idxs = vec![];
42    let mut op_type_float_idxs = vec![];
43    let mut op_type_vector_idxs = vec![];
44
45    let mut spv_idx = 0;
46    while spv_idx < spv.len() {
47        let op = spv[spv_idx];
48        let word_count = hiword(op);
49        let instruction = loword(op);
50
51        match instruction {
52            SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
53            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
54            SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
55            SPV_INSTRUCTION_OP_IS_NAN => op_is_nan_idxs.push(spv_idx),
56            SPV_INSTRUCTION_OP_IS_INF => op_is_inf_idxs.push(spv_idx),
57            SPV_INSTRUCTION_OP_TYPE_BOOL => op_type_bool_idxs.push(spv_idx),
58            SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
59            SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
60            SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
61            _ => {}
62        }
63
64        spv_idx += word_count as usize;
65    }
66
67    if op_is_nan_idxs.is_empty() && op_is_inf_idxs.is_empty() {
68        return Ok(in_spv.to_vec());
69    }
70    let header_position = last_of_indices!(
71        op_type_int_idxs,
72        op_type_bool_idxs,
73        op_type_float_idxs,
74        op_type_vector_idxs,
75        op_type_pointer_idxs
76    );
77
78    // 2. Useful closures
79    let get_float_type_width = |id| {
80        op_type_float_idxs
81            .iter()
82            .find_map(|idx| (spv[idx + 1] == id).then_some(spv[idx + 2]))
83    };
84
85    let get_underlying_vector_type = |id| {
86        op_type_vector_idxs.iter().find_map(|idx| {
87            let result_id = spv[idx + 1];
88            let component_type = spv[idx + 2];
89            let component_count = spv[idx + 3];
90
91            (result_id == id).then_some((component_type, component_count as usize))
92        })
93    };
94
95    // 3. Insert shared uint definitions and shared constants
96    // Since there are only two main float widths, we will include both for simplicity
97    let mut header_insert = InstructionInsert {
98        previous_spv_idx: header_position.unwrap(),
99        instruction: vec![],
100    };
101
102    // NOTE: 64-bit isnan and isinf doesn't actually make much sense, so I will just leave it.
103    // In standard glsl, you cannot write that kind of substitution because there is no `uint64_t` nor `doubleBitsToUint`.
104
105    let uint32_id = ensure_type_int(
106        &spv,
107        &op_type_int_idxs,
108        &mut instruction_bound,
109        &mut header_insert.instruction,
110        32,
111        SPV_SIGNEDNESS_UNSIGNED,
112    );
113    let uint32_ptr_id = ensure_type_pointer(
114        &spv,
115        &op_type_pointer_idxs,
116        &mut instruction_bound,
117        &mut header_insert.instruction,
118        SPV_STORAGE_CLASS_FUNCTION,
119        uint32_id,
120    );
121    let shared_type_inputs_32 = NanInfSharedTypeInputs {
122        uint_id: uint32_id,
123        ptr_uint_id: uint32_ptr_id,
124    };
125
126    // let uint64_id = ensure_type_int(&mut header_insert.instruction, &mut instruction_bound, 64);
127    // let uint64_ptr_id = ensure_type_ptr(
128    //     &mut header_insert.instruction,
129    //     &mut instruction_bound,
130    //     uint64_id,
131    // );
132    // let shared_type_inputs_64 = NanInfSharedTypeInputs {
133    //     uint_id: uint64_id,
134    //     ptr_uint_id: uint64_ptr_id,
135    // };
136
137    let (shared_constants_32, mut constants_spv_32) =
138        nan_inf_shared_constants_spv(&mut instruction_bound, shared_type_inputs_32);
139    header_insert.instruction.append(&mut constants_spv_32);
140    // let (shared_constants_64, mut constants_spv_64) =
141    //     nan_inf_shared_constants(&mut instruction_bound, shared_type_inputs_64);
142    // header_insert.instruction.append(&mut constants_spv_64);
143
144    // 4. Insert shared isnan / isinf declaration and definitions
145
146    // SPIR-V doesn't like duplicate `OpTypeFunction`
147    let mut fn_type_defs = HashMap::new();
148    // We should create the proper number of functions too.
149    let mut fn_defs = HashMap::new();
150
151    let mut desc_to_idx: HashMap<_, Vec<usize>> = HashMap::new();
152    let fn_set: HashSet<(IsNanOrIsInf, NanInfSharedFunctionInputs, u32, Option<usize>)> =
153        op_is_nan_idxs
154            .iter()
155            .map(|v| (IsNanOrIsInf::IsNan, v))
156            .chain(op_is_inf_idxs.iter().map(|v| (IsNanOrIsInf::IsInf, v)))
157            .map(|(ty, op_idx)| {
158                let input_id = spv[op_idx + 3];
159
160                // We actually cannot rely on loads to get the types of immediate values
161                // let load_idx = op_load_idxs
162                //     .iter()
163                //     .find(|&load_idx| {
164                //         let load_result_id = spv[load_idx + 2];
165                //         load_result_id == input_id
166                //     })
167                //     .expect("OpIsNan/Inf not accompanied by OpLoad?");
168                // let float_ty_id = spv[load_idx + 1];
169
170                let float_ty_id = trace_previous_intermediate_id(&spv, input_id, *op_idx)
171                    .expect("OpIsNan/Inf's argument is not defined?");
172                let original_float_type_id = float_ty_id;
173                let (underlying_float_ty_id, float_component_count) =
174                    get_underlying_vector_type(float_ty_id)
175                        .map(|(a, b)| (a, Some(b)))
176                        .unwrap_or((float_ty_id, None));
177                let pointer_float_ty_id = op_type_pointer_idxs
178                    .iter()
179                    .find_map(|&tp_idx| {
180                        let result_id = spv[tp_idx + 1];
181                        let underlying_type_id = spv[tp_idx + 3];
182
183                        (underlying_type_id == underlying_float_ty_id).then_some(result_id)
184                    })
185                    .unwrap_or_else(|| {
186                        let new_id = instruction_bound;
187                        instruction_bound += 1;
188                        header_insert.instruction.append(&mut vec![
189                            encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
190                            new_id,
191                            SPV_STORAGE_CLASS_FUNCTION,
192                            underlying_float_ty_id,
193                        ]);
194                        new_id
195                    });
196                let bool_ty_id = spv[op_idx + 1];
197                let (underlying_bool_ty_id, bool_component_count) =
198                    get_underlying_vector_type(bool_ty_id)
199                        .map(|(a, b)| (a, Some(b)))
200                        .unwrap_or((bool_ty_id, None));
201                assert!(bool_component_count == float_component_count);
202
203                let ret = (
204                    ty,
205                    NanInfSharedFunctionInputs {
206                        bool_id: underlying_bool_ty_id,
207                        float_id: underlying_float_ty_id,
208                        ptr_float_id: pointer_float_ty_id,
209                    },
210                    original_float_type_id,
211                    bool_component_count,
212                );
213                desc_to_idx.entry(ret).or_default().push(*op_idx);
214                ret
215            })
216            .collect::<HashSet<_, _>>();
217
218    let mut function_definition_words = vec![];
219
220    struct PatchEntry {
221        fn_id: u32,
222        input: NanInfSharedFunctionInputs,
223        original_float_type_id: u32,
224        bool_component_count: Option<usize>,
225    }
226    let mut patch_map: HashMap<usize, PatchEntry> = HashMap::new();
227    for (ty, input, original_float_type_id, component_count) in fn_set {
228        let (fn_type, mut spv) = nan_inf_fn_type_spv(&mut instruction_bound, input);
229        let fn_type = if let Some(existing_fn_type) = fn_type_defs.get(&input).copied() {
230            existing_fn_type
231        } else {
232            header_insert.instruction.append(&mut spv);
233            fn_type_defs.insert(input, fn_type);
234            fn_type
235        };
236
237        let (selected_type_inputs, selected_constants) =
238            match get_float_type_width(input.float_id).expect("Our OpTypeFloat dispeared?") {
239                32 => (shared_type_inputs_32, shared_constants_32),
240                // 64 => (shared_type_inputs_64, shared_constants_64),
241                n => panic!(
242                    "Float width {} not supported for isnan/isinf substitution",
243                    n
244                ),
245            };
246
247        let (fn_id, mut spv) = is_nan_is_inf_spv(
248            &mut instruction_bound,
249            ty,
250            selected_type_inputs,
251            input,
252            fn_type,
253            selected_constants,
254        );
255        let fn_id = if let Some(existing_fn_id) = fn_defs.get(&(ty, input, selected_type_inputs)) {
256            *existing_fn_id
257        } else {
258            function_definition_words.append(&mut spv);
259            fn_defs.insert((ty, input, selected_type_inputs), fn_id);
260            fn_id
261        };
262
263        let key = (ty, input, original_float_type_id, component_count);
264        for op_idx in &desc_to_idx[&key] {
265            patch_map.insert(
266                *op_idx,
267                PatchEntry {
268                    fn_id,
269                    input,
270                    original_float_type_id,
271                    bool_component_count: component_count,
272                },
273            );
274        }
275    }
276
277    // 5. Insert additional temp variables and indexing constants for vectored cases
278    // We will create the shared data used to generate these variables here
279    let mut indexing_constant_instructions = InstructionInsert {
280        previous_spv_idx: header_position.unwrap(),
281        instruction: vec![],
282    };
283
284    let max_components = patch_map
285        .values()
286        .filter_map(|v| v.bool_component_count)
287        .max()
288        .unwrap_or(0);
289
290    // We need constants for indexing
291    let mut index_ids = vec![];
292    for n in 0..max_components {
293        let index_id = instruction_bound;
294        instruction_bound += 1;
295        index_ids.push(index_id);
296
297        indexing_constant_instructions.instruction.append(&mut vec![
298            encode_word(4, SPV_INSTRUCTION_OP_CONSTANT),
299            uint32_id,
300            index_id,
301            n as u32,
302        ]);
303    }
304
305    instruction_inserts.push(indexing_constant_instructions);
306
307    // 6. Insert and patch isnan / isinf usage
308    for &op_idx in op_is_nan_idxs.iter().chain(op_is_inf_idxs.iter()) {
309        let result_type_id = spv[op_idx + 1];
310        let result_id = spv[op_idx + 2];
311        let x = spv[op_idx + 3];
312        let PatchEntry {
313            fn_id,
314            input,
315            original_float_type_id,
316            bool_component_count,
317        } = patch_map[&op_idx];
318
319        for i in 0..4 {
320            new_spv[op_idx + i] = encode_word(1, SPV_INSTRUCTION_OP_NOP);
321        }
322
323        // Both patch implementations need a temp variable
324        // TODO: OPT further reduce the number of temp variables by sharing then within the same functions
325        let mut temp_variable_instructions = InstructionInsert {
326            previous_spv_idx: get_function_label_index_of_instruction_index(&spv, op_idx),
327            instruction: vec![],
328        };
329        let param_id = instruction_bound;
330        instruction_bound += 1;
331        temp_variable_instructions.instruction.append(&mut vec![
332            encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
333            input.ptr_float_id,
334            param_id,
335            SPV_STORAGE_CLASS_FUNCTION,
336        ]);
337
338        if let Some(component_count) = bool_component_count {
339            let mut new_instructions = InstructionInsert {
340                previous_spv_idx: op_idx,
341                instruction: vec![],
342            };
343
344            // We need a temp variable for the vector itself
345            let float_vector_type_pointer_id = op_type_pointer_idxs
346                .iter()
347                .find_map(|idx| (spv[idx + 3] == original_float_type_id).then_some(spv[idx + 1]))
348                .expect("This vector type has no type pointer?");
349            let temp_vector_id = instruction_bound;
350            instruction_bound += 1;
351            temp_variable_instructions.instruction.append(&mut vec![
352                encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
353                float_vector_type_pointer_id,
354                temp_vector_id,
355                SPV_STORAGE_CLASS_FUNCTION,
356            ]);
357
358            let mut component_results = (0..component_count)
359                .map(|n| {
360                    let accessed_id = instruction_bound;
361                    instruction_bound += 1;
362                    let loaded_id = instruction_bound;
363                    instruction_bound += 1;
364                    let fn_result_id = instruction_bound;
365                    instruction_bound += 1;
366                    new_instructions.instruction.append(&mut vec![
367                        encode_word(3, SPV_INSTRUCTION_OP_STORE),
368                        temp_vector_id,
369                        x,
370                        encode_word(5, SPV_INSTRUCTION_OP_ACCESS_CHAIN),
371                        input.ptr_float_id,
372                        accessed_id,
373                        temp_vector_id,
374                        index_ids[n],
375                        encode_word(4, SPV_INSTRUCTION_OP_LOAD),
376                        input.float_id,
377                        loaded_id,
378                        accessed_id,
379                        encode_word(3, SPV_INSTRUCTION_OP_STORE),
380                        param_id,
381                        loaded_id,
382                        encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
383                        input.bool_id,
384                        fn_result_id,
385                        fn_id,
386                        param_id,
387                    ]);
388                    fn_result_id
389                })
390                .collect::<Vec<u32>>();
391
392            new_instructions.instruction.append(&mut vec![
393                encode_word(
394                    3 + component_count as u16,
395                    SPV_INSTRUCTION_OP_COMPOSITE_CONSTRUCT,
396                ),
397                result_type_id,
398                result_id,
399            ]);
400            new_instructions.instruction.append(&mut component_results);
401            instruction_inserts.push(new_instructions);
402        } else {
403            let new_instructions = InstructionInsert {
404                previous_spv_idx: op_idx,
405                instruction: vec![
406                    encode_word(3, SPV_INSTRUCTION_OP_STORE),
407                    param_id,
408                    x,
409                    encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
410                    result_type_id,
411                    result_id,
412                    fn_id,
413                    param_id,
414                ],
415            };
416            instruction_inserts.push(new_instructions);
417        }
418
419        instruction_inserts.push(temp_variable_instructions);
420    }
421
422    // 7. Insert New Instructions
423    instruction_inserts.insert(0, header_insert);
424    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
425    new_spv.append(&mut function_definition_words);
426
427    // 8. Remove Instructions that have been Whited Out.
428    prune_noops(&mut new_spv);
429
430    // 9. Write New Header and New Code
431    Ok(fuse_final(spv_header, new_spv, instruction_bound))
432}