Skip to main content

spirv_webgpu_transform/util/
ensure.rs

1use super::*;
2
/// Evaluates to the largest `usize` found across one or more index collections.
///
/// Each argument is an expression with an `.iter()` method yielding `&usize` (for example a
/// `Vec<usize>`, a slice or an array). Every argument is evaluated exactly once, left to right.
/// The result is `Some(max)` over all elements of all arguments, or `None` when every collection
/// is empty. A trailing comma is accepted.
#[macro_export]
macro_rules! last_of_indices {
    ( $( $v:expr ),+ $(,)? ) => {{
        let max_val: Option<usize> = None;
        // Fold each collection into the running maximum. Chaining the previous result keeps it
        // in the candidate set, so an empty collection leaves the maximum unchanged.
        $(
            let max_val = $v.iter().copied().chain(max_val).max();
        )+
        max_val
    }};
}
20
21pub fn ensure_ext_inst_import<F: Fn(&str) -> bool>(
22    spv: &[u32],
23    op_ext_inst_import_idxs: &[usize],
24    instruction_bound: &mut u32,
25    header: &mut Vec<u32>,
26    filter: F,
27    template: &str,
28) -> u32 {
29    if let Some(idx) = op_ext_inst_import_idxs.iter().find(|&&idx| {
30        let word_count = hiword(spv[idx]) as usize;
31        let extension = literal_to_string_le(&spv[idx + 2..idx + word_count])
32            .expect("Invalid string in OpExtInstImport");
33        filter(&extension)
34    }) {
35        spv[idx + 1]
36    } else {
37        let mut ext = string_to_literal_le(template);
38        let new_id = *instruction_bound;
39        *instruction_bound += 1;
40        header.append(&mut vec![
41            encode_word(2 + ext.len() as u16, SPV_INSTRUCTION_OP_EXT_INST_IMPORT),
42            new_id,
43        ]);
44        header.append(&mut ext);
45        new_id
46    }
47}
48
49pub fn ensure_type_bool(
50    spv: &[u32],
51    op_type_bool_idxs: &[usize],
52    instruction_bound: &mut u32,
53    header: &mut Vec<u32>,
54) -> u32 {
55    if let Some(idx) = op_type_bool_idxs.first() {
56        spv[idx + 1]
57    } else {
58        let new_id = *instruction_bound;
59        *instruction_bound += 1;
60        header.append(&mut vec![
61            encode_word(2, SPV_INSTRUCTION_OP_TYPE_BOOL),
62            new_id,
63        ]);
64        new_id
65    }
66}
67
68pub fn ensure_type_int(
69    spv: &[u32],
70    op_type_int_idxs: &[usize],
71    instruction_bound: &mut u32,
72    header: &mut Vec<u32>,
73    template_width: u32,
74    template_signedness: u32,
75) -> u32 {
76    if let Some(idx) = op_type_int_idxs.iter().find(|&&ty_idx| {
77        let width = spv[ty_idx + 2];
78        let signedness = spv[ty_idx + 3];
79
80        width == template_width && signedness == template_signedness
81    }) {
82        spv[idx + 1]
83    } else {
84        let new_id = *instruction_bound;
85        *instruction_bound += 1;
86        header.append(&mut vec![
87            encode_word(4, SPV_INSTRUCTION_OP_TYPE_INT),
88            new_id,
89            template_width,
90            template_signedness,
91        ]);
92        new_id
93    }
94}
95
96pub fn ensure_type_vector(
97    spv: &[u32],
98    op_type_vector_idxs: &[usize],
99    instruction_bound: &mut u32,
100    header: &mut Vec<u32>,
101    template_component_type_id: u32,
102    template_component_count: u32,
103) -> u32 {
104    if let Some(idx) = op_type_vector_idxs.iter().find(|&&ty_idx| {
105        let component_type_id = spv[ty_idx + 2];
106        let component_count = spv[ty_idx + 3];
107
108        component_type_id == template_component_type_id
109            && component_count == template_component_count
110    }) {
111        spv[idx + 1]
112    } else {
113        let new_id = *instruction_bound;
114        *instruction_bound += 1;
115        header.append(&mut vec![
116            encode_word(4, SPV_INSTRUCTION_OP_TYPE_VECTOR),
117            new_id,
118            template_component_type_id,
119            template_component_count,
120        ]);
121        new_id
122    }
123}
124
125pub fn ensure_type_pointer(
126    spv: &[u32],
127    op_type_pointer_idxs: &[usize],
128    instruction_bound: &mut u32,
129    header: &mut Vec<u32>,
130    template_storage_class: u32,
131    template_underlying_type_id: u32,
132) -> u32 {
133    if let Some(tp_idx) = op_type_pointer_idxs.iter().find(|&&tp_idx| {
134        let storage_class = spv[tp_idx + 2];
135        let underlying_type_id = spv[tp_idx + 3];
136        storage_class == template_storage_class && template_underlying_type_id == underlying_type_id
137    }) {
138        spv[tp_idx + 1]
139    } else {
140        let new_id = *instruction_bound;
141        *instruction_bound += 1;
142        header.append(&mut vec![
143            encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
144            new_id,
145            template_storage_class,
146            template_underlying_type_id,
147        ]);
148        new_id
149    }
150}