1use super::*;
2
/// Post-increments the counter behind `ib`.
///
/// The counter is advanced by one and the value it held *before* the
/// increment is returned.
fn inc(ib: &mut u32) -> u32 {
    let current = *ib;
    *ib = current + 1;
    current
}
7
8mod isnan_isinf;
10mod shared;
11
12use isnan_isinf::*;
13use shared::*;
14
15pub fn isnanisinfpatch(in_spv: &[u32]) -> Result<Vec<u32>, ()> {
19 let spv = in_spv.to_owned();
20
21 let mut instruction_bound = spv[SPV_HEADER_INSTRUCTION_BOUND_OFFSET];
22 let magic_number = spv[SPV_HEADER_MAGIC_NUM_OFFSET];
23
24 let spv_header = spv[0..SPV_HEADER_LENGTH].to_owned();
25
26 assert_eq!(magic_number, SPV_HEADER_MAGIC);
27
28 let mut instruction_inserts = vec![];
29 let word_inserts = vec![];
30
31 let spv = spv.into_iter().skip(SPV_HEADER_LENGTH).collect::<Vec<_>>();
32 let mut new_spv = spv.clone();
33
34 let mut op_function_idxs = vec![];
36 let mut op_load_idxs = vec![];
37 let mut op_type_pointer_idxs = vec![];
38 let mut op_is_nan_idxs = vec![];
39 let mut op_is_inf_idxs = vec![];
40 let mut op_type_bool_idxs = vec![];
41 let mut op_type_int_idxs = vec![];
42 let mut op_type_float_idxs = vec![];
43 let mut op_type_vector_idxs = vec![];
44
45 let mut spv_idx = 0;
46 while spv_idx < spv.len() {
47 let op = spv[spv_idx];
48 let word_count = hiword(op);
49 let instruction = loword(op);
50
51 match instruction {
52 SPV_INSTRUCTION_OP_FUNCTION => op_function_idxs.push(spv_idx),
53 SPV_INSTRUCTION_OP_LOAD => op_load_idxs.push(spv_idx),
54 SPV_INSTRUCTION_OP_TYPE_POINTER => op_type_pointer_idxs.push(spv_idx),
55 SPV_INSTRUCTION_OP_IS_NAN => op_is_nan_idxs.push(spv_idx),
56 SPV_INSTRUCTION_OP_IS_INF => op_is_inf_idxs.push(spv_idx),
57 SPV_INSTRUCTION_OP_TYPE_BOOL => op_type_bool_idxs.push(spv_idx),
58 SPV_INSTRUCTION_OP_TYPE_INT => op_type_int_idxs.push(spv_idx),
59 SPV_INSTRUCTION_OP_TYPE_FLOAT => op_type_float_idxs.push(spv_idx),
60 SPV_INSTRUCTION_OP_TYPE_VECTOR => op_type_vector_idxs.push(spv_idx),
61 _ => {}
62 }
63
64 spv_idx += word_count as usize;
65 }
66
67 if op_is_nan_idxs.is_empty() && op_is_inf_idxs.is_empty() {
68 return Ok(in_spv.to_vec());
69 }
70 let Some(last_op_type_pointer) = op_type_pointer_idxs.last() else {
71 return Ok(in_spv.to_vec());
72 };
73 let last_op_type_pointer = *last_op_type_pointer;
74
75 let get_float_type_width = |id| {
77 op_type_float_idxs
78 .iter()
79 .find_map(|idx| (spv[idx + 1] == id).then_some(spv[idx + 2]))
80 };
81
82 let get_underlying_vector_type = |id| {
83 op_type_vector_idxs.iter().find_map(|idx| {
84 let result_id = spv[idx + 1];
85 let component_type = spv[idx + 2];
86 let component_count = spv[idx + 3];
87
88 (result_id == id).then_some((component_type, component_count as usize))
89 })
90 };
91
92 let mut header_insert = InstructionInsert {
95 previous_spv_idx: last_op_type_pointer,
96 instruction: vec![],
97 };
98
99 let ensure_type_int =
100 |header: &mut Vec<u32>, instruction_bound: &mut u32, template_width: u32| {
101 if let Some(idx) = op_type_int_idxs.iter().find(|&&ty_idx| {
102 let width = spv[ty_idx + 2];
103 let signedness = spv[ty_idx + 2];
104
105 signedness == SPV_SIGNEDNESS_UNSIGNED && width == template_width
106 }) {
107 spv[idx + 1]
108 } else {
109 let new_id = *instruction_bound;
110 *instruction_bound += 1;
111 header.append(&mut vec![
112 encode_word(4, SPV_INSTRUCTION_OP_TYPE_INT),
113 new_id,
114 template_width,
115 SPV_SIGNEDNESS_UNSIGNED,
116 ]);
117 new_id
118 }
119 };
120 let ensure_type_ptr = |header: &mut Vec<u32>, instruction_bound: &mut u32, template_id: u32| {
121 if let Some(tp_idx) = op_type_pointer_idxs
122 .iter()
123 .find(|&&tp_idx| template_id == spv[tp_idx + 3])
124 {
125 spv[tp_idx + 1]
126 } else {
127 let new_id = *instruction_bound;
128 *instruction_bound += 1;
129 header.append(&mut vec![
130 encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
131 new_id,
132 SPV_STORAGE_CLASS_FUNCTION,
133 template_id,
134 ]);
135 new_id
136 }
137 };
138
139 let uint32_id = ensure_type_int(&mut header_insert.instruction, &mut instruction_bound, 32);
143 let uint32_ptr_id = ensure_type_ptr(
144 &mut header_insert.instruction,
145 &mut instruction_bound,
146 uint32_id,
147 );
148 let shared_type_inputs_32 = NanInfSharedTypeInputs {
149 uint_id: uint32_id,
150 ptr_uint_id: uint32_ptr_id,
151 };
152
153 let (shared_constants_32, mut constants_spv_32) =
165 nan_inf_shared_constants(&mut instruction_bound, shared_type_inputs_32);
166 header_insert.instruction.append(&mut constants_spv_32);
167 let mut fn_type_defs = HashMap::new();
175 let mut fn_defs = HashMap::new();
177
178 let mut desc_to_idx: HashMap<_, Vec<usize>> = HashMap::new();
179 let fn_set: HashSet<(IsNanOrIsInf, NanInfSharedFunctionInputs, u32, Option<usize>)> =
180 op_is_nan_idxs
181 .iter()
182 .map(|v| (IsNanOrIsInf::IsNan, v))
183 .chain(op_is_inf_idxs.iter().map(|v| (IsNanOrIsInf::IsInf, v)))
184 .map(|(ty, op_idx)| {
185 let input_id = spv[op_idx + 3];
186
187 let float_ty_id = trace_previous_intermediate_id(&spv, input_id, *op_idx)
198 .expect("OpIsNan/Inf's argument is not defined?");
199 let original_float_type_id = float_ty_id;
200 let (underlying_float_ty_id, float_component_count) =
201 get_underlying_vector_type(float_ty_id)
202 .map(|(a, b)| (a, Some(b)))
203 .unwrap_or((float_ty_id, None));
204 let pointer_float_ty_id = op_type_pointer_idxs
205 .iter()
206 .find_map(|&tp_idx| {
207 let result_id = spv[tp_idx + 1];
208 let underlying_type_id = spv[tp_idx + 3];
209
210 (underlying_type_id == underlying_float_ty_id).then_some(result_id)
211 })
212 .unwrap_or_else(|| {
213 let new_id = instruction_bound;
214 instruction_bound += 1;
215 header_insert.instruction.append(&mut vec![
216 encode_word(4, SPV_INSTRUCTION_OP_TYPE_POINTER),
217 new_id,
218 SPV_STORAGE_CLASS_FUNCTION,
219 underlying_float_ty_id,
220 ]);
221 new_id
222 });
223 let bool_ty_id = spv[op_idx + 1];
224 let (underlying_bool_ty_id, bool_component_count) =
225 get_underlying_vector_type(bool_ty_id)
226 .map(|(a, b)| (a, Some(b)))
227 .unwrap_or((bool_ty_id, None));
228 assert!(bool_component_count == float_component_count);
229
230 let ret = (
231 ty,
232 NanInfSharedFunctionInputs {
233 bool_id: underlying_bool_ty_id,
234 float_id: underlying_float_ty_id,
235 ptr_float_id: pointer_float_ty_id,
236 },
237 original_float_type_id,
238 bool_component_count,
239 );
240 desc_to_idx.entry(ret).or_default().push(*op_idx);
241 ret
242 })
243 .collect::<HashSet<_, _>>();
244
245 let mut function_definition_words = vec![];
246
247 struct PatchEntry {
248 fn_id: u32,
249 input: NanInfSharedFunctionInputs,
250 original_float_type_id: u32,
251 bool_component_count: Option<usize>,
252 }
253 let mut patch_map: HashMap<usize, PatchEntry> = HashMap::new();
254 for (ty, input, original_float_type_id, component_count) in fn_set {
255 let (fn_type, mut spv) = nan_inf_fn_type(&mut instruction_bound, input);
256 let fn_type = if let Some(existing_fn_type) = fn_type_defs.get(&input).copied() {
257 existing_fn_type
258 } else {
259 header_insert.instruction.append(&mut spv);
260 fn_type_defs.insert(input, fn_type);
261 fn_type
262 };
263
264 let (selected_type_inputs, selected_constants) =
265 match get_float_type_width(input.float_id).expect("Our OpTypeFloat dispeared?") {
266 32 => (shared_type_inputs_32, shared_constants_32),
267 n => panic!(
269 "Float width {} not supported for isnan/isinf substitution",
270 n
271 ),
272 };
273
274 let (fn_id, mut spv) = is_nan_is_inf_spv(
275 &mut instruction_bound,
276 ty,
277 selected_type_inputs,
278 input,
279 fn_type,
280 selected_constants,
281 );
282 let fn_id = if let Some(existing_fn_id) = fn_defs.get(&(ty, input, selected_type_inputs)) {
283 *existing_fn_id
284 } else {
285 function_definition_words.append(&mut spv);
286 fn_defs.insert((ty, input, selected_type_inputs), fn_id);
287 fn_id
288 };
289
290 let key = (ty, input, original_float_type_id, component_count);
291 for op_idx in &desc_to_idx[&key] {
292 patch_map.insert(
293 *op_idx,
294 PatchEntry {
295 fn_id,
296 input,
297 original_float_type_id,
298 bool_component_count: component_count,
299 },
300 );
301 }
302 }
303
304 let mut indexing_constant_instructions = InstructionInsert {
307 previous_spv_idx: last_op_type_pointer,
308 instruction: vec![],
309 };
310
311 let max_components = patch_map
312 .values()
313 .filter_map(|v| v.bool_component_count)
314 .max()
315 .unwrap_or(0);
316
317 let mut index_ids = vec![];
319 for n in 0..max_components {
320 let index_id = instruction_bound;
321 instruction_bound += 1;
322 index_ids.push(index_id);
323
324 indexing_constant_instructions.instruction.append(&mut vec![
325 encode_word(4, SPV_INSTRUCTION_OP_CONSTANT),
326 uint32_id,
327 index_id,
328 n as u32,
329 ]);
330 }
331
332 instruction_inserts.push(indexing_constant_instructions);
333
334 for &op_idx in op_is_nan_idxs.iter().chain(op_is_inf_idxs.iter()) {
336 let result_type_id = spv[op_idx + 1];
337 let result_id = spv[op_idx + 2];
338 let x = spv[op_idx + 3];
339 let PatchEntry {
340 fn_id,
341 input,
342 original_float_type_id,
343 bool_component_count,
344 } = patch_map[&op_idx];
345
346 for i in 0..4 {
347 new_spv[op_idx + i] = encode_word(1, SPV_INSTRUCTION_OP_NOP);
348 }
349
350 let mut temp_variable_instructions = InstructionInsert {
353 previous_spv_idx: get_function_label_index_of_instruction_index(&spv, op_idx),
354 instruction: vec![],
355 };
356 let param_id = instruction_bound;
357 instruction_bound += 1;
358 temp_variable_instructions.instruction.append(&mut vec![
359 encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
360 input.ptr_float_id,
361 param_id,
362 SPV_STORAGE_CLASS_FUNCTION,
363 ]);
364
365 if let Some(component_count) = bool_component_count {
366 let mut new_instructions = InstructionInsert {
367 previous_spv_idx: op_idx,
368 instruction: vec![],
369 };
370
371 let float_vector_type_pointer_id = op_type_pointer_idxs
373 .iter()
374 .find_map(|idx| (spv[idx + 3] == original_float_type_id).then_some(spv[idx + 1]))
375 .expect("This vector type has no type pointer?");
376 let temp_vector_id = instruction_bound;
377 instruction_bound += 1;
378 temp_variable_instructions.instruction.append(&mut vec![
379 encode_word(4, SPV_INSTRUCTION_OP_VARIABLE),
380 float_vector_type_pointer_id,
381 temp_vector_id,
382 SPV_STORAGE_CLASS_FUNCTION,
383 ]);
384
385 let mut component_results = (0..component_count)
386 .map(|n| {
387 let accessed_id = instruction_bound;
388 instruction_bound += 1;
389 let loaded_id = instruction_bound;
390 instruction_bound += 1;
391 let fn_result_id = instruction_bound;
392 instruction_bound += 1;
393 new_instructions.instruction.append(&mut vec![
394 encode_word(3, SPV_INSTRUCTION_OP_STORE),
395 temp_vector_id,
396 x,
397 encode_word(5, SPV_INSTRUCTION_OP_ACCESS_CHAIN),
398 input.ptr_float_id,
399 accessed_id,
400 temp_vector_id,
401 index_ids[n],
402 encode_word(4, SPV_INSTRUCTION_OP_LOAD),
403 input.float_id,
404 loaded_id,
405 accessed_id,
406 encode_word(3, SPV_INSTRUCTION_OP_STORE),
407 param_id,
408 loaded_id,
409 encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
410 input.bool_id,
411 fn_result_id,
412 fn_id,
413 param_id,
414 ]);
415 fn_result_id
416 })
417 .collect::<Vec<u32>>();
418
419 new_instructions.instruction.append(&mut vec![
420 encode_word(
421 3 + component_count as u16,
422 SPV_INSTRUCTION_OP_COMPOSITE_CONSTRUCT,
423 ),
424 result_type_id,
425 result_id,
426 ]);
427 new_instructions.instruction.append(&mut component_results);
428 instruction_inserts.push(new_instructions);
429 } else {
430 let new_instructions = InstructionInsert {
431 previous_spv_idx: op_idx,
432 instruction: vec![
433 encode_word(3, SPV_INSTRUCTION_OP_STORE),
434 param_id,
435 x,
436 encode_word(5, SPV_INSTRUCTION_OP_FUNCTION_CALL),
437 result_type_id,
438 result_id,
439 fn_id,
440 param_id,
441 ],
442 };
443 instruction_inserts.push(new_instructions);
444 }
445
446 instruction_inserts.push(temp_variable_instructions);
447 }
448
449 instruction_inserts.insert(0, header_insert);
451 insert_new_instructions(&spv, &mut new_spv, &word_inserts, &instruction_inserts);
452 new_spv.append(&mut function_definition_words);
453
454 prune_noops(&mut new_spv);
456
457 Ok(fuse_final(spv_header, new_spv, instruction_bound))
459}