Skip to main content

spirv_webgpu_transform/
isnanisinfpatch.rs

1use super::*;
2
/// Post-increment: hands back the value `ib` currently holds, then bumps it by one.
///
/// Handy for allocating consecutive ids from a running counter.
fn inc(ib: &mut u32) -> u32 {
    let current = *ib;
    *ib = current + 1;
    current
}
7
8// Someone should make a rust-spirv dsl macro
9mod isnan_isinf;
10mod shared;
11
12use isnan_isinf::*;
13use shared::*;
14
15/// Perform the operation on a `Vec<u32>`.
16/// Use [u8_slice_to_u32_vec] to convert a `&[u8]` into a `Vec<u32>`.
17/// Does not produce any side effects or corrections.
18pub fn isnanisinfpatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
19    let spv = in_spv.to_owned();
20
21    let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
22    let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
23
24    let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
25
26    assert_eq!(magic_number, SPV_HEADER_MAGIC);
27
28    let mut instruction_inserts = vec![];
29    let word_inserts = vec![];
30
31    let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
32    let mut new_spv = spv.clone();
33
34    // 1. Find locations instructions we need
35    let mut op_function_idxs = vec![];
36    let mut op_load_idxs = vec![];
37    let mut op_type_pointer_idxs = vec![];
38    let mut op_is_nan_idxs = vec![];
39    let mut op_is_inf_idxs = vec![];
40    let mut op_type_bool_idxs = vec![];
41    let mut op_type_int_idxs = vec![];
42    let mut op_type_float_idxs = vec![];
43    let mut op_type_vector_idxs = vec![];
44
45    let mut spv_idx = 0;
46    while spv_idx < spv.len() {
47        let op = spv[spv_idx];
48        let word_count = hiword(op);
49        let instruction = loword(op);
50
51        match instruction {
52            SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
53            SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
54            SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
55            SPV_INSTRUCTION_OP_IS_NAN => op_is_nan_idxs.push(spv_idx),
56            SPV_INSTRUCTION_OP_IS_INF => op_is_inf_idxs.push(spv_idx),
57            SPV_INSTRUCTION_OP_TYPE_BOOL => op_type_bool_idxs.push(spv_idx),
58            SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
59            SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
60            SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
61            _ => {}
62        }
63
64        spv_idx += word_count as usize;
65    }
66
67    if op_is_nan_idxs.is_empty() && op_is_inf_idxs.is_empty() {
68        return Ok(in_spv.to_vec());
69    }
70    let Some(last_op_type_pointer) = op_type_pointer_idxs.last() else {
71        return Ok(in_spv.to_vec());
72    };
73    let last_op_type_pointer = *last_op_type_pointer;
74
75    // 2. Useful closures
76    let get_float_type_width = |id| {
77        op_type_float_idxs
78            .iter()
79            .find_map(|idx| (spv[idx + 1] == id).then_some(spv[idx + 2]))
80    };
81
82    let get_underlying_vector_type = |id| {
83        op_type_vector_idxs.iter().find_map(|idx| {
84            let result_id = spv[idx + 1];
85            let component_type = spv[idx + 2];
86            let component_count = spv[idx + 3];
87
88            (result_id == id).then_some((component_type, component_count as usize))
89        })
90    };
91
92    // 3. Insert shared uint definitions and shared constants
93    // Since there are only two main float widths, we will include both for simplicity
94    let mut header_insert = InstructionInsert {
95        previous_spv_idx: last_op_type_pointer,
96        instruction: vec![],
97    };
98
99    let ensure_type_int =
100        |header: &mut Vec<u32>, instruction_bound: &mut u32, template_width: u32| {
101            if let Some(idx) = op_type_int_idxs.iter().find(|&&ty_idx| {
102                let width = spv[ty_idx + 2];
103                let signedness = spv[ty_idx + 2];
104
105                signedness == SPV_SIGNEDNESS_UNSIGNED && width == template_width
106            }) {
107                spv[idx + 1]
108            } else {
109                let new_id = *instruction_bound;
110                *instruction_bound += 1;
111                header.append(&mut vec![
112                    encode_word(4, SPV_INSTRUCTION_OP_TYPE_INT),
113                    new_id,
114                    template_width,
115                    SPV_SIGNEDNESS_UNSIGNED,
116                ]);
117                new_id
118            }
119        };
120    let ensure_type_ptr = |header: &mut Vec<u32>, instruction_bound: &mut u32, template_id: u32| {
121        if let Some(tp_idx) = op_type_pointer_idxs
122            .iter()
123            .find(|&&tp_idx| template_id == spv[tp_idx + 3])
124        {
125            spv[tp_idx + 1]
126        } else {
127            let new_id = *instruction_bound;
128            *instruction_bound += 1;
129            header.append(&mut vec![
130                encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
131                new_id,
132                SPV_STORAGE_CLASS_FUNCTION,
133                template_id,
134            ]);
135            new_id
136        }
137    };
138
139    // NOTE: 64-bit isnan and isinf doesn't actually make much sense, so I will just leave it.
140    // In standard glsl, you cannot write that kind of substitution because there is no `uint64_t` nor `doubleBitsToUint`.
141
142    let uint32_id = ensure_type_int(&mut header_insert.instruction, &mut instruction_bound, 32);
143    let uint32_ptr_id = ensure_type_ptr(
144        &mut header_insert.instruction,
145        &mut instruction_bound,
146        uint32_id,
147    );
148    let shared_type_inputs_32 = NanInfSharedTypeInputs {
149        uint_id: uint32_id,
150        ptr_uint_id: uint32_ptr_id,
151    };
152
153    // let uint64_id = ensure_type_int(&mut header_insert.instruction, &mut instruction_bound, 64);
154    // let uint64_ptr_id = ensure_type_ptr(
155    //     &mut header_insert.instruction,
156    //     &mut instruction_bound,
157    //     uint64_id,
158    // );
159    // let shared_type_inputs_64 = NanInfSharedTypeInputs {
160    //     uint_id: uint64_id,
161    //     ptr_uint_id: uint64_ptr_id,
162    // };
163
164    let (shared_constants_32, mut constants_spv_32) =
165        nan_inf_shared_constants(&mut instruction_bound, shared_type_inputs_32);
166    header_insert.instruction.append(&mut constants_spv_32);
167    // let (shared_constants_64, mut constants_spv_64) =
168    //     nan_inf_shared_constants(&mut instruction_bound, shared_type_inputs_64);
169    // header_insert.instruction.append(&mut constants_spv_64);
170
171    // 4. Insert shared isnan / isinf declaration and definitions
172
173    // SPIR-V doesn't like duplicate `OpTypeFunction`
174    let mut fn_type_defs = HashMap::new();
175    // We should create the proper number of functions too.
176    let mut fn_defs = HashMap::new();
177
178    let mut desc_to_idx: HashMap<_, Vec<usize>> = HashMap::new();
179    let fn_set: HashSet<(IsNanOrIsInf, NanInfSharedFunctionInputs, u32, Option<usize>)> =
180        op_is_nan_idxs
181            .iter()
182            .map(|v| (IsNanOrIsInf::IsNan, v))
183            .chain(op_is_inf_idxs.iter().map(|v| (IsNanOrIsInf::IsInf, v)))
184            .map(|(ty, op_idx)| {
185                let input_id = spv[op_idx + 3];
186
187                // We actually cannot rely on loads to get the types of immediate values
188                // let load_idx = op_load_idxs
189                //     .iter()
190                //     .find(|&load_idx| {
191                //         let load_result_id = spv[load_idx + 2];
192                //         load_result_id == input_id
193                //     })
194                //     .expect("OpIsNan/Inf not accompanied by OpLoad?");
195                // let float_ty_id = spv[load_idx + 1];
196
197                let float_ty_id = trace_previous_intermediate_id(&spv, input_id, *op_idx)
198                    .expect("OpIsNan/Inf's argument is not defined?");
199                let original_float_type_id = float_ty_id;
200                let (underlying_float_ty_id, float_component_count) =
201                    get_underlying_vector_type(float_ty_id)
202                        .map(|(a, b)| (a, Some(b)))
203                        .unwrap_or((float_ty_id, None));
204                let pointer_float_ty_id = op_type_pointer_idxs
205                    .iter()
206                    .find_map(|&tp_idx| {
207                        let result_id = spv[tp_idx + 1];
208                        let underlying_type_id = spv[tp_idx + 3];
209
210                        (underlying_type_id == underlying_float_ty_id).then_some(result_id)
211                    })
212                    .unwrap_or_else(|| {
213                        let new_id = instruction_bound;
214                        instruction_bound += 1;
215                        header_insert.instruction.append(&mut vec![
216                            encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
217                            new_id,
218                            SPV_STORAGE_CLASS_FUNCTION,
219                            underlying_float_ty_id,
220                        ]);
221                        new_id
222                    });
223                let bool_ty_id = spv[op_idx + 1];
224                let (underlying_bool_ty_id, bool_component_count) =
225                    get_underlying_vector_type(bool_ty_id)
226                        .map(|(a, b)| (a, Some(b)))
227                        .unwrap_or((bool_ty_id, None));
228                assert!(bool_component_count == float_component_count);
229
230                let ret = (
231                    ty,
232                    NanInfSharedFunctionInputs {
233                        bool_id: underlying_bool_ty_id,
234                        float_id: underlying_float_ty_id,
235                        ptr_float_id: pointer_float_ty_id,
236                    },
237                    original_float_type_id,
238                    bool_component_count,
239                );
240                desc_to_idx.entry(ret).or_default().push(*op_idx);
241                ret
242            })
243            .collect::<HashSet<_, _>>();
244
245    let mut function_definition_words = vec![];
246
247    struct PatchEntry {
248        fn_id: u32,
249        input: NanInfSharedFunctionInputs,
250        original_float_type_id: u32,
251        bool_component_count: Option<usize>,
252    }
253    let mut patch_map: HashMap<usize, PatchEntry> = HashMap::new();
254    for (ty, input, original_float_type_id, component_count) in fn_set {
255        let (fn_type, mut spv) = nan_inf_fn_type(&mut instruction_bound, input);
256        let fn_type = if let Some(existing_fn_type) = fn_type_defs.get(&input).copied() {
257            existing_fn_type
258        } else {
259            header_insert.instruction.append(&mut spv);
260            fn_type_defs.insert(input, fn_type);
261            fn_type
262        };
263
264        let (selected_type_inputs, selected_constants) =
265            match get_float_type_width(input.float_id).expect("Our OpTypeFloat dispeared?") {
266                32 => (shared_type_inputs_32, shared_constants_32),
267                // 64 => (shared_type_inputs_64, shared_constants_64),
268                n => panic!(
269                    "Float width {} not supported for isnan/isinf substitution",
270                    n
271                ),
272            };
273
274        let (fn_id, mut spv) = is_nan_is_inf_spv(
275            &mut instruction_bound,
276            ty,
277            selected_type_inputs,
278            input,
279            fn_type,
280            selected_constants,
281        );
282        let fn_id = if let Some(existing_fn_id) = fn_defs.get(&(ty, input, selected_type_inputs)) {
283            *existing_fn_id
284        } else {
285            function_definition_words.append(&mut spv);
286            fn_defs.insert((ty, input, selected_type_inputs), fn_id);
287            fn_id
288        };
289
290        let key = (ty, input, original_float_type_id, component_count);
291        for op_idx in &desc_to_idx[&key] {
292            patch_map.insert(
293                *op_idx,
294                PatchEntry {
295                    fn_id,
296                    input,
297                    original_float_type_id,
298                    bool_component_count: component_count,
299                },
300            );
301        }
302    }
303
304    // 5. Insert additional temp variables and indexing constants for vectored cases
305    // We will create the shared data used to generate these variables here
306    let mut indexing_constant_instructions = InstructionInsert {
307        previous_spv_idx: last_op_type_pointer,
308        instruction: vec![],
309    };
310
311    let max_components = patch_map
312        .values()
313        .filter_map(|v| v.bool_component_count)
314        .max()
315        .unwrap_or(0);
316
317    // We need constants for indexing
318    let mut index_ids = vec![];
319    for n in 0..max_components {
320        let index_id = instruction_bound;
321        instruction_bound += 1;
322        index_ids.push(index_id);
323
324        indexing_constant_instructions.instruction.append(&mut vec![
325            encode_word(4, SPV_INSTRUCTION_OP_CONSTANT),
326            uint32_id,
327            index_id,
328            n as u32,
329        ]);
330    }
331
332    instruction_inserts.push(indexing_constant_instructions);
333
334    // 6. Insert and patch isnan / isinf usage
335    for &op_idx in op_is_nan_idxs.iter().chain(op_is_inf_idxs.iter()) {
336        let result_type_id = spv[op_idx + 1];
337        let result_id = spv[op_idx + 2];
338        let x = spv[op_idx + 3];
339        let PatchEntry {
340            fn_id,
341            input,
342            original_float_type_id,
343            bool_component_count,
344        } = patch_map[&op_idx];
345
346        for i in 0..4 {
347            new_spv[op_idx + i] = encode_word(1, SPV_INSTRUCTION_OP_NOP);
348        }
349
350        // Both patch implementations need a temp variable
351        // TODO: OPT further reduce the number of temp variables by sharing then within the same functions
352        let mut temp_variable_instructions = InstructionInsert {
353            previous_spv_idx: get_function_label_index_of_instruction_index(&spv, op_idx),
354            instruction: vec![],
355        };
356        let param_id = instruction_bound;
357        instruction_bound += 1;
358        temp_variable_instructions.instruction.append(&mut vec![
359            encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
360            input.ptr_float_id,
361            param_id,
362            SPV_STORAGE_CLASS_FUNCTION,
363        ]);
364
365        if let Some(component_count) = bool_component_count {
366            let mut new_instructions = InstructionInsert {
367                previous_spv_idx: op_idx,
368                instruction: vec![],
369            };
370
371            // We need a temp variable for the vector itself
372            let float_vector_type_pointer_id = op_type_pointer_idxs
373                .iter()
374                .find_map(|idx| (spv[idx + 3] == original_float_type_id).then_some(spv[idx + 1]))
375                .expect("This vector type has no type pointer?");
376            let temp_vector_id = instruction_bound;
377            instruction_bound += 1;
378            temp_variable_instructions.instruction.append(&mut vec![
379                encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
380                float_vector_type_pointer_id,
381                temp_vector_id,
382                SPV_STORAGE_CLASS_FUNCTION,
383            ]);
384
385            let mut component_results = (0..component_count)
386                .map(|n| {
387                    let accessed_id = instruction_bound;
388                    instruction_bound += 1;
389                    let loaded_id = instruction_bound;
390                    instruction_bound += 1;
391                    let fn_result_id = instruction_bound;
392                    instruction_bound += 1;
393                    new_instructions.instruction.append(&mut vec![
394                        encode_word(3, SPV_INSTRUCTION_OP_STORE),
395                        temp_vector_id,
396                        x,
397                        encode_word(5, SPV_INSTRUCTION_OP_ACCESS_CHAIN),
398                        input.ptr_float_id,
399                        accessed_id,
400                        temp_vector_id,
401                        index_ids[n],
402                        encode_word(4, SPV_INSTRUCTION_OP_LOAD),
403                        input.float_id,
404                        loaded_id,
405                        accessed_id,
406                        encode_word(3, SPV_INSTRUCTION_OP_STORE),
407                        param_id,
408                        loaded_id,
409                        encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
410                        input.bool_id,
411                        fn_result_id,
412                        fn_id,
413                        param_id,
414                    ]);
415                    fn_result_id
416                })
417                .collect::<Vec<u32>>();
418
419            new_instructions.instruction.append(&mut vec![
420                encode_word(
421                    3 + component_count as u16,
422                    SPV_INSTRUCTION_OP_COMPOSITE_CONSTRUCT,
423                ),
424                result_type_id,
425                result_id,
426            ]);
427            new_instructions.instruction.append(&mut component_results);
428            instruction_inserts.push(new_instructions);
429        } else {
430            let new_instructions = InstructionInsert {
431                previous_spv_idx: op_idx,
432                instruction: vec![
433                    encode_word(3, SPV_INSTRUCTION_OP_STORE),
434                    param_id,
435                    x,
436                    encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
437                    result_type_id,
438                    result_id,
439                    fn_id,
440                    param_id,
441                ],
442            };
443            instruction_inserts.push(new_instructions);
444        }
445
446        instruction_inserts.push(temp_variable_instructions);
447    }
448
449    // 7. Insert New Instructions
450    instruction_inserts.insert(0, header_insert);
451    insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
452    new_spv.append(&mut function_definition_words);
453
454    // 8. Remove Instructions that have been Whited Out.
455    prune_noops(&mut new_spv);
456
457    // 9. Write New Header and New Code
458    Ok(fuse_final(spv_header, new_spv, instruction_bound))
459}