1use super::*;
2
/// Post-increments the counter behind `ib`.
///
/// Returns the value the counter held *before* the call and leaves the counter one higher,
/// i.e. it behaves like `ib++` in C.
fn inc(ib: &mut u32) -> u32 {
    let previous = *ib;
    *ib += 1;
    previous
}
7
8mod isnan_isinf;
10mod shared;
11
12use isnan_isinf::*;
13use shared::*;
14
15pub fn isnanisinfpatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
19 let spv = in_spv.to_owned();
20
21 let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
22 let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
23
24 let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
25
26 assert_eq!(magic_number, SPV_HEADER_MAGIC);
27
28 let mut instruction_inserts = vec![];
29 let word_inserts = vec![];
30
31 let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
32 let mut new_spv = spv.clone();
33
34 let mut op_function_idxs = vec![];
36 let mut op_load_idxs = vec![];
37 let mut op_type_pointer_idxs = vec![];
38 let mut op_is_nan_idxs = vec![];
39 let mut op_is_inf_idxs = vec![];
40 let mut op_type_bool_idxs = vec![];
41 let mut op_type_int_idxs = vec![];
42 let mut op_type_float_idxs = vec![];
43 let mut op_type_vector_idxs = vec![];
44
45 let mut spv_idx = 0;
46 while spv_idx < spv.len() {
47 let op = spv[spv_idx];
48 let word_count = hiword(op);
49 let instruction = loword(op);
50
51 match instruction {
52 SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
53 SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
54 SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
55 SPV_INSTRUCTION_OP_IS_NAN => op_is_nan_idxs.push(spv_idx),
56 SPV_INSTRUCTION_OP_IS_INF => op_is_inf_idxs.push(spv_idx),
57 SPV_INSTRUCTION_OP_TYPE_BOOL => op_type_bool_idxs.push(spv_idx),
58 SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
59 SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
60 SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
61 _ => {}
62 }
63
64 spv_idx += word_count as usize;
65 }
66
67 if op_is_nan_idxs.is_empty() && op_is_inf_idxs.is_empty() {
68 return Ok(in_spv.to_vec());
69 }
70 let header_position = last_of_indices!(
71 op_type_int_idxs,
72 op_type_bool_idxs,
73 op_type_float_idxs,
74 op_type_vector_idxs,
75 op_type_pointer_idxs
76 );
77
78 let get_float_type_width = |id| {
80 op_type_float_idxs
81 .iter()
82 .find_map(|idx| (spv[idx + 1] == id).then_some(spv[idx + 2]))
83 };
84
85 let get_underlying_vector_type = |id| {
86 op_type_vector_idxs.iter().find_map(|idx| {
87 let result_id = spv[idx + 1];
88 let component_type = spv[idx + 2];
89 let component_count = spv[idx + 3];
90
91 (result_id == id).then_some((component_type, component_count as usize))
92 })
93 };
94
95 let mut header_insert = InstructionInsert {
98 previous_spv_idx: header_position.unwrap(),
99 instruction: vec![],
100 };
101
102 let uint32_id = ensure_type_int(
106 &spv,
107 &op_type_int_idxs,
108 &mut instruction_bound,
109 &mut header_insert.instruction,
110 32,
111 SPV_SIGNEDNESS_UNSIGNED,
112 );
113 let uint32_ptr_id = ensure_type_pointer(
114 &spv,
115 &op_type_pointer_idxs,
116 &mut instruction_bound,
117 &mut header_insert.instruction,
118 SPV_STORAGE_CLASS_FUNCTION,
119 uint32_id,
120 );
121 let shared_type_inputs_32 = NanInfSharedTypeInputs {
122 uint_id: uint32_id,
123 ptr_uint_id: uint32_ptr_id,
124 };
125
126 let (shared_constants_32, mut constants_spv_32) =
138 nan_inf_shared_constants_spv(&mut instruction_bound, shared_type_inputs_32);
139 header_insert.instruction.append(&mut constants_spv_32);
140 let mut fn_type_defs = HashMap::new();
148 let mut fn_defs = HashMap::new();
150
151 let mut desc_to_idx: HashMap<_, Vec<usize>> = HashMap::new();
152 let fn_set: HashSet<(IsNanOrIsInf, NanInfSharedFunctionInputs, u32, Option<usize>)> =
153 op_is_nan_idxs
154 .iter()
155 .map(|v| (IsNanOrIsInf::IsNan, v))
156 .chain(op_is_inf_idxs.iter().map(|v| (IsNanOrIsInf::IsInf, v)))
157 .map(|(ty, op_idx)| {
158 let input_id = spv[op_idx + 3];
159
160 let float_ty_id = trace_previous_intermediate_id(&spv, input_id, *op_idx)
171 .expect("OpIsNan/Inf's argument is not defined?");
172 let original_float_type_id = float_ty_id;
173 let (underlying_float_ty_id, float_component_count) =
174 get_underlying_vector_type(float_ty_id)
175 .map(|(a, b)| (a, Some(b)))
176 .unwrap_or((float_ty_id, None));
177 let pointer_float_ty_id = op_type_pointer_idxs
178 .iter()
179 .find_map(|&tp_idx| {
180 let result_id = spv[tp_idx + 1];
181 let underlying_type_id = spv[tp_idx + 3];
182
183 (underlying_type_id == underlying_float_ty_id).then_some(result_id)
184 })
185 .unwrap_or_else(|| {
186 let new_id = instruction_bound;
187 instruction_bound += 1;
188 header_insert.instruction.append(&mut vec![
189 encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
190 new_id,
191 SPV_STORAGE_CLASS_FUNCTION,
192 underlying_float_ty_id,
193 ]);
194 new_id
195 });
196 let bool_ty_id = spv[op_idx + 1];
197 let (underlying_bool_ty_id, bool_component_count) =
198 get_underlying_vector_type(bool_ty_id)
199 .map(|(a, b)| (a, Some(b)))
200 .unwrap_or((bool_ty_id, None));
201 assert!(bool_component_count == float_component_count);
202
203 let ret = (
204 ty,
205 NanInfSharedFunctionInputs {
206 bool_id: underlying_bool_ty_id,
207 float_id: underlying_float_ty_id,
208 ptr_float_id: pointer_float_ty_id,
209 },
210 original_float_type_id,
211 bool_component_count,
212 );
213 desc_to_idx.entry(ret).or_default().push(*op_idx);
214 ret
215 })
216 .collect::<HashSet<_, _>>();
217
218 let mut function_definition_words = vec![];
219
220 struct PatchEntry {
221 fn_id: u32,
222 input: NanInfSharedFunctionInputs,
223 original_float_type_id: u32,
224 bool_component_count: Option<usize>,
225 }
226 let mut patch_map: HashMap<usize, PatchEntry> = HashMap::new();
227 for (ty, input, original_float_type_id, component_count) in fn_set {
228 let (fn_type, mut spv) = nan_inf_fn_type_spv(&mut instruction_bound, input);
229 let fn_type = if let Some(existing_fn_type) = fn_type_defs.get(&input).copied() {
230 existing_fn_type
231 } else {
232 header_insert.instruction.append(&mut spv);
233 fn_type_defs.insert(input, fn_type);
234 fn_type
235 };
236
237 let (selected_type_inputs, selected_constants) =
238 match get_float_type_width(input.float_id).expect("Our OpTypeFloat dispeared?") {
239 32 => (shared_type_inputs_32, shared_constants_32),
240 n => panic!(
242 "Float width {} not supported for isnan/isinf substitution",
243 n
244 ),
245 };
246
247 let (fn_id, mut spv) = is_nan_is_inf_spv(
248 &mut instruction_bound,
249 ty,
250 selected_type_inputs,
251 input,
252 fn_type,
253 selected_constants,
254 );
255 let fn_id = if let Some(existing_fn_id) = fn_defs.get(&(ty, input, selected_type_inputs)) {
256 *existing_fn_id
257 } else {
258 function_definition_words.append(&mut spv);
259 fn_defs.insert((ty, input, selected_type_inputs), fn_id);
260 fn_id
261 };
262
263 let key = (ty, input, original_float_type_id, component_count);
264 for op_idx in &desc_to_idx[&key] {
265 patch_map.insert(
266 *op_idx,
267 PatchEntry {
268 fn_id,
269 input,
270 original_float_type_id,
271 bool_component_count: component_count,
272 },
273 );
274 }
275 }
276
277 let mut indexing_constant_instructions = InstructionInsert {
280 previous_spv_idx: header_position.unwrap(),
281 instruction: vec![],
282 };
283
284 let max_components = patch_map
285 .values()
286 .filter_map(|v| v.bool_component_count)
287 .max()
288 .unwrap_or(0);
289
290 let mut index_ids = vec![];
292 for n in 0..max_components {
293 let index_id = instruction_bound;
294 instruction_bound += 1;
295 index_ids.push(index_id);
296
297 indexing_constant_instructions.instruction.append(&mut vec![
298 encode_word(4, SPV_INSTRUCTION_OP_CONSTANT),
299 uint32_id,
300 index_id,
301 n as u32,
302 ]);
303 }
304
305 instruction_inserts.push(indexing_constant_instructions);
306
307 for &op_idx in op_is_nan_idxs.iter().chain(op_is_inf_idxs.iter()) {
309 let result_type_id = spv[op_idx + 1];
310 let result_id = spv[op_idx + 2];
311 let x = spv[op_idx + 3];
312 let PatchEntry {
313 fn_id,
314 input,
315 original_float_type_id,
316 bool_component_count,
317 } = patch_map[&op_idx];
318
319 for i in 0..4 {
320 new_spv[op_idx + i] = encode_word(1, SPV_INSTRUCTION_OP_NOP);
321 }
322
323 let mut temp_variable_instructions = InstructionInsert {
326 previous_spv_idx: get_function_label_index_of_instruction_index(&spv, op_idx),
327 instruction: vec![],
328 };
329 let param_id = instruction_bound;
330 instruction_bound += 1;
331 temp_variable_instructions.instruction.append(&mut vec![
332 encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
333 input.ptr_float_id,
334 param_id,
335 SPV_STORAGE_CLASS_FUNCTION,
336 ]);
337
338 if let Some(component_count) = bool_component_count {
339 let mut new_instructions = InstructionInsert {
340 previous_spv_idx: op_idx,
341 instruction: vec![],
342 };
343
344 let float_vector_type_pointer_id = op_type_pointer_idxs
346 .iter()
347 .find_map(|idx| (spv[idx + 3] == original_float_type_id).then_some(spv[idx + 1]))
348 .expect("This vector type has no type pointer?");
349 let temp_vector_id = instruction_bound;
350 instruction_bound += 1;
351 temp_variable_instructions.instruction.append(&mut vec![
352 encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
353 float_vector_type_pointer_id,
354 temp_vector_id,
355 SPV_STORAGE_CLASS_FUNCTION,
356 ]);
357
358 let mut component_results = (0..component_count)
359 .map(|n| {
360 let accessed_id = instruction_bound;
361 instruction_bound += 1;
362 let loaded_id = instruction_bound;
363 instruction_bound += 1;
364 let fn_result_id = instruction_bound;
365 instruction_bound += 1;
366 new_instructions.instruction.append(&mut vec![
367 encode_word(3, SPV_INSTRUCTION_OP_STORE),
368 temp_vector_id,
369 x,
370 encode_word(5, SPV_INSTRUCTION_OP_ACCESS_CHAIN),
371 input.ptr_float_id,
372 accessed_id,
373 temp_vector_id,
374 index_ids[n],
375 encode_word(4, SPV_INSTRUCTION_OP_LOAD),
376 input.float_id,
377 loaded_id,
378 accessed_id,
379 encode_word(3, SPV_INSTRUCTION_OP_STORE),
380 param_id,
381 loaded_id,
382 encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
383 input.bool_id,
384 fn_result_id,
385 fn_id,
386 param_id,
387 ]);
388 fn_result_id
389 })
390 .collect::<Vec<u32>>();
391
392 new_instructions.instruction.append(&mut vec![
393 encode_word(
394 3 + component_count as u16,
395 SPV_INSTRUCTION_OP_COMPOSITE_CONSTRUCT,
396 ),
397 result_type_id,
398 result_id,
399 ]);
400 new_instructions.instruction.append(&mut component_results);
401 instruction_inserts.push(new_instructions);
402 } else {
403 let new_instructions = InstructionInsert {
404 previous_spv_idx: op_idx,
405 instruction: vec![
406 encode_word(3, SPV_INSTRUCTION_OP_STORE),
407 param_id,
408 x,
409 encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
410 result_type_id,
411 result_id,
412 fn_id,
413 param_id,
414 ],
415 };
416 instruction_inserts.push(new_instructions);
417 }
418
419 instruction_inserts.push(temp_variable_instructions);
420 }
421
422 instruction_inserts.insert(0, header_insert);
424 insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
425 new_spv.append(&mut function_definition_words);
426
427 prune_noops(&mut new_spv);
429
430 Ok(fuse_final(spv_header, new_spv, instruction_bound))
432}