spirv_webgpu_transform/util/
ensure.rs1use super::*;
2
/// Evaluates to the largest `usize` found across one or more index collections.
///
/// Each argument is an expression whose `.iter()` yields `&usize` (for example a
/// `Vec<usize>`, an array, or a slice). Every argument is evaluated exactly once, in order.
/// The result is `Option<usize>`: `None` when all collections are empty, otherwise
/// `Some(max)`.
///
/// `#[macro_export]` macros expand inside the caller's crate. For that reason `Option`,
/// `Some`, `None` and `Ord::max` are spelled with absolute `::std` paths, so a caller-side
/// item with the same name cannot hijack the expansion.
#[macro_export]
macro_rules! last_of_indices {
    ( $( $v:expr ),+ $(,)? ) => {{
        let mut max_val: ::std::option::Option<usize> = ::std::option::Option::None;

        $(
            for &x in $v.iter() {
                max_val = ::std::option::Option::Some(match max_val {
                    ::std::option::Option::Some(current) => ::std::cmp::Ord::max(current, x),
                    ::std::option::Option::None => x,
                });
            }
        )+

        max_val
    }};
}
20
21pub fn ensure_ext_inst_import<F: Fn(&str) -> bool>(
22 spv: &[u32],
23 op_ext_inst_import_idxs: &[usize],
24 instruction_bound: &mut u32,
25 header: &mut Vec<u32>,
26 filter: F,
27 template: &str,
28) -> u32 {
29 if let Some(idx) = op_ext_inst_import_idxs.iter().find(|&&idx| {
30 let word_count = hiword(spv[idx]) as usize;
31 let extension = literal_to_string_le(&spv[idx + 2..idx + word_count])
32 .expect("Invalid string in OpExtInstImport");
33 filter(&extension)
34 }) {
35 spv[idx + 1]
36 } else {
37 let mut ext = string_to_literal_le(template);
38 let new_id = *instruction_bound;
39 *instruction_bound += 1;
40 header.append(&mut vec![
41 encode_word(2 + ext.len() as u16, SPV_INSTRUCTION_OP_EXT_INST_IMPORT),
42 new_id,
43 ]);
44 header.append(&mut ext);
45 new_id
46 }
47}
48
49pub fn ensure_type_bool(
50 spv: &[u32],
51 op_type_bool_idxs: &[usize],
52 instruction_bound: &mut u32,
53 header: &mut Vec<u32>,
54) -> u32 {
55 if let Some(idx) = op_type_bool_idxs.first() {
56 spv[idx + 1]
57 } else {
58 let new_id = *instruction_bound;
59 *instruction_bound += 1;
60 header.append(&mut vec![
61 encode_word(2, SPV_INSTRUCTION_OP_TYPE_BOOL),
62 new_id,
63 ]);
64 new_id
65 }
66}
67
68pub fn ensure_type_int(
69 spv: &[u32],
70 op_type_int_idxs: &[usize],
71 instruction_bound: &mut u32,
72 header: &mut Vec<u32>,
73 template_width: u32,
74 template_signedness: u32,
75) -> u32 {
76 if let Some(idx) = op_type_int_idxs.iter().find(|&&ty_idx| {
77 let width = spv[ty_idx + 2];
78 let signedness = spv[ty_idx + 3];
79
80 width == template_width && signedness == template_signedness
81 }) {
82 spv[idx + 1]
83 } else {
84 let new_id = *instruction_bound;
85 *instruction_bound += 1;
86 header.append(&mut vec![
87 encode_word(4, SPV_INSTRUCTION_OP_TYPE_INT),
88 new_id,
89 template_width,
90 template_signedness,
91 ]);
92 new_id
93 }
94}
95
96pub fn ensure_type_vector(
97 spv: &[u32],
98 op_type_vector_idxs: &[usize],
99 instruction_bound: &mut u32,
100 header: &mut Vec<u32>,
101 template_component_type_id: u32,
102 template_component_count: u32,
103) -> u32 {
104 if let Some(idx) = op_type_vector_idxs.iter().find(|&&ty_idx| {
105 let component_type_id = spv[ty_idx + 2];
106 let component_count = spv[ty_idx + 3];
107
108 component_type_id == template_component_type_id
109 && component_count == template_component_count
110 }) {
111 spv[idx + 1]
112 } else {
113 let new_id = *instruction_bound;
114 *instruction_bound += 1;
115 header.append(&mut vec![
116 encode_word(4, SPV_INSTRUCTION_OP_TYPE_VECTOR),
117 new_id,
118 template_component_type_id,
119 template_component_count,
120 ]);
121 new_id
122 }
123}
124
125pub fn ensure_type_pointer(
126 spv: &[u32],
127 op_type_pointer_idxs: &[usize],
128 instruction_bound: &mut u32,
129 header: &mut Vec<u32>,
130 template_storage_class: u32,
131 template_underlying_type_id: u32,
132) -> u32 {
133 if let Some(tp_idx) = op_type_pointer_idxs.iter().find(|&&tp_idx| {
134 let storage_class = spv[tp_idx + 2];
135 let underlying_type_id = spv[tp_idx + 3];
136 storage_class == template_storage_class && template_underlying_type_id == underlying_type_id
137 }) {
138 spv[tp_idx + 1]
139 } else {
140 let new_id = *instruction_bound;
141 *instruction_bound += 1;
142 header.append(&mut vec![
143 encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
144 new_id,
145 template_storage_class,
146 template_underlying_type_id,
147 ]);
148 new_id
149 }
150}