1use crate::ast::DataType;
6use crate::onnx::convert::map_onnx_data_type;
7use crate::onnx::ir::{Dim, OnnxIrGraph, TensorShape, TensorType};
8use crate::protos::onnx::{
9 tensor_shape_proto::dimension::Value as DimensionValue, type_proto::Value as TypeProtoValue,
10 GraphProto, ModelProto, NodeProto, TensorProto, TensorProto_DataType,
11};
12use std::collections::{HashMap, HashSet};
13use thiserror::Error;
14
/// Errors reported while seeding or propagating static shapes for an ONNX graph.
#[derive(Debug, Error)]
pub enum ShapeInferenceError {
    /// A graph input has no type, no tensor type, no shape, or a dimension without a value.
    /// The payload is the input name.
    #[error("input '{0}' is missing shape information")]
    MissingInputShape(String),
    /// A graph input uses a symbolic dimension (`dim`) for which no override was supplied.
    #[error("input '{input}' has dynamic dimension '{dim}', please provide an override")]
    DynamicDim { input: String, dim: String },
    /// The raw ONNX element-type code could not be mapped (or was `0` on a graph input).
    #[error("unsupported ONNX data type: {0}")]
    UnsupportedDataType(i32),
    /// Shape inference failed for the named operator.
    /// NOTE(review): not constructed anywhere in the code visible in this file.
    #[error("could not infer shape for op '{op}'")]
    CannotInfer { op: String },
}
26
/// Everything shape inference managed to resolve, keyed by ONNX value name.
#[derive(Debug, Default)]
pub struct InferenceResult {
    /// Static shape of each value whose dimensions are known.
    pub value_shapes: HashMap<String, Vec<i64>>,
    /// Element type of each value whose type is known.
    pub value_types: HashMap<String, DataType>,
    /// Integer contents of values that are known constants, stored as a flat list.
    pub const_values: HashMap<String, Vec<i64>>,
}
33
34pub fn infer_static_shapes(
37 model: &ModelProto,
38 overrides: &HashMap<String, u32>,
39) -> Result<InferenceResult, ShapeInferenceError> {
40 let mut result = InferenceResult::default();
41
42 if model.graph.is_none() {
43 return Ok(result);
44 }
45
46 let graph = model.graph.as_ref().unwrap();
47 let mut ir = OnnxIrGraph::default();
48 let initializer_names: HashSet<String> = graph
49 .initializer
50 .as_slice()
51 .iter()
52 .map(|i| i.name.as_str().to_string())
53 .collect();
54
55 seed_inputs(graph, overrides, &initializer_names, &mut ir, &mut result)?;
56 seed_initializers(graph, &mut ir, &mut result)?;
57 seed_constant_nodes(graph, &mut result, &mut ir)?;
58
59 propagate_node_shapes(graph, &mut result)?;
60
61 Ok(result)
62}
63
64fn seed_inputs(
65 graph: &GraphProto,
66 overrides: &HashMap<String, u32>,
67 initializer_names: &HashSet<String>,
68 ir: &mut OnnxIrGraph,
69 result: &mut InferenceResult,
70) -> Result<(), ShapeInferenceError> {
71 for input in graph.input.as_slice() {
72 let name = input.name.as_str().to_string();
73 let vi = ir.value_or_insert(&name);
74 vi.producer = None;
75
76 if initializer_names.contains(&name) {
77 continue;
78 }
79
80 let type_proto = input
81 .r#type
82 .as_ref()
83 .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))?;
84
85 let tensor_type = match &type_proto.value {
86 Some(TypeProtoValue::TensorType(tt)) => tt,
87 _ => return Err(ShapeInferenceError::MissingInputShape(name.clone())),
88 };
89
90 let dtype = if tensor_type.elem_type != 0 {
91 map_onnx_data_type(tensor_type.elem_type)
92 .map_err(|_| ShapeInferenceError::UnsupportedDataType(tensor_type.elem_type))?
93 } else {
94 return Err(ShapeInferenceError::UnsupportedDataType(0));
95 };
96
97 let shape = tensor_type
98 .shape
99 .as_ref()
100 .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))?;
101
102 let mut dims = Vec::new();
103 for dim in shape.dim.as_slice() {
104 if let Some(value) = &dim.value {
105 match value {
106 DimensionValue::DimValue(v) => {
107 dims.push(Dim::Known(*v));
108 }
109 DimensionValue::DimParam(key) => {
110 if let Some(v) = overrides.get(key.as_str()) {
111 dims.push(Dim::Known(*v as i64));
112 } else {
113 return Err(ShapeInferenceError::DynamicDim {
114 input: name.clone(),
115 dim: key.clone(),
116 });
117 }
118 }
119 }
120 } else {
121 return Err(ShapeInferenceError::MissingInputShape(name.clone()));
122 }
123 }
124
125 let ty = TensorType {
126 data_type: dtype.clone(),
127 shape: TensorShape { dims },
128 };
129 vi.ty = Some(ty.clone());
130 result.value_types.insert(name.clone(), dtype);
131 if let Some(shape) = ty.shape.to_i64() {
132 result.value_shapes.insert(name, shape);
133 }
134 }
135 Ok(())
136}
137
138fn seed_initializers(
139 graph: &GraphProto,
140 ir: &mut OnnxIrGraph,
141 result: &mut InferenceResult,
142) -> Result<(), ShapeInferenceError> {
143 for init in graph.initializer.as_slice() {
144 let name = init.name.as_str().to_string();
145 let vi = ir.value_or_insert(&name);
146 vi.producer = None;
147
148 let dtype = map_onnx_data_type(init.data_type)
149 .map_err(|_| ShapeInferenceError::UnsupportedDataType(init.data_type))?;
150 let shape: Vec<i64> = init.dims.as_slice().to_vec();
151 result.value_types.insert(name.clone(), dtype.clone());
152 result.value_shapes.insert(name.clone(), shape);
153
154 if matches!(
155 dtype,
156 DataType::Int32 | DataType::Int64 | DataType::Uint32 | DataType::Uint64
157 ) {
158 let values = read_int_tensor(init);
159 if !values.is_empty() {
160 result.const_values.insert(name, values);
161 }
162 }
163 }
164 Ok(())
165}
166
167fn seed_constant_nodes(
168 graph: &GraphProto,
169 result: &mut InferenceResult,
170 ir: &mut OnnxIrGraph,
171) -> Result<(), ShapeInferenceError> {
172 for node in graph.node.as_slice() {
173 if node.op_type.as_str() != "Constant" {
174 continue;
175 }
176
177 if let Some(out) = node.output.as_slice().first() {
178 let out_name = out.to_string();
179 let vi = ir.value_or_insert(&out_name);
180 vi.producer = Some(node.name.as_str().to_string());
181
182 if let Some(attr) = node
183 .attribute
184 .as_slice()
185 .iter()
186 .find(|a| a.name.as_str() == "value" && a.t.is_some())
187 {
188 let t = attr.t.as_ref().unwrap();
189 let dtype = map_onnx_data_type(t.data_type)
190 .map_err(|_| ShapeInferenceError::UnsupportedDataType(t.data_type))?;
191 result.value_types.insert(out_name.clone(), dtype);
192
193 let vals = read_int_tensor(t);
194 if !vals.is_empty() {
195 result.const_values.insert(out_name.clone(), vals.clone());
196 let shape: Vec<i64> = if vals.len() == 1 {
197 Vec::new()
198 } else {
199 vec![vals.len() as i64]
200 };
201 result.value_shapes.insert(out_name.clone(), shape);
202 vi.ty = Some(TensorType {
203 data_type: result.value_types[&out_name].clone(),
204 shape: TensorShape::from_known(result.value_shapes[&out_name].clone()),
205 });
206 }
207 }
208 }
209 }
210 Ok(())
211}
212
213fn propagate_node_shapes(
214 graph: &GraphProto,
215 result: &mut InferenceResult,
216) -> Result<(), ShapeInferenceError> {
217 let mut progress = true;
218 let max_iters = 8;
219 let mut iter = 0;
220
221 while progress && iter < max_iters {
222 progress = false;
223 iter += 1;
224
225 for node in graph.node.as_slice() {
226 let outputs = node.output.as_slice();
227 if outputs.is_empty() {
228 continue;
229 }
230 if outputs
231 .iter()
232 .all(|o| result.value_shapes.contains_key(o.as_str()))
233 {
234 continue;
235 }
236
237 if let Some(shape) = infer_node_shape(node, result) {
238 let out_name = outputs[0].to_string();
239 result.value_shapes.entry(out_name.clone()).or_insert(shape);
240
241 if let Some(first_in) = node.input.as_slice().first() {
243 if let Some(dtype) = result.value_types.get(first_in).cloned() {
244 result.value_types.entry(out_name.clone()).or_insert(dtype);
245 }
246 }
247
248 progress = true;
249 }
250 }
251
252 progress |= fold_integer_constants(graph, result);
254 }
255
256 Ok(())
257}
258
/// Combines two shapes under right-aligned (NumPy-style) broadcasting.
///
/// The shorter shape is treated as if it were padded on the left with 1s. For each aligned
/// pair, equal sizes are kept and a size of 1 yields the other size. Returns `None` as soon as
/// a pair differs and neither side is 1.
// The attribute silences the unused-function warning while no caller in this file uses it.
#[allow(dead_code)]
fn broadcast_shapes(a: &[i64], b: &[i64]) -> Option<Vec<i64>> {
    let rank = a.len().max(b.len());

    // Dimension of `shape` at output position `pos`, with implicit leading 1s.
    let aligned = |shape: &[i64], pos: usize| -> i64 {
        let padding = rank - shape.len();
        if pos < padding {
            1
        } else {
            shape[pos - padding]
        }
    };

    (0..rank)
        .map(|pos| {
            let lhs = aligned(a, pos);
            let rhs = aligned(b, pos);
            if lhs == rhs || rhs == 1 {
                Some(lhs)
            } else if lhs == 1 {
                Some(rhs)
            } else {
                None
            }
        })
        .collect()
}
287
288fn infer_node_shape(node: &NodeProto, ctx: &InferenceResult) -> Option<Vec<i64>> {
289 let op = node.op_type.as_str();
290 match op {
291 "Relu" | "Tanh" | "Sigmoid" | "Erf" | "Softmax" | "Gelu" | "Exp" | "Log" | "Abs"
292 | "Neg" | "Sqrt" | "LayerNormalization" => node
293 .input
294 .as_slice()
295 .first()
296 .and_then(|i| ctx.value_shapes.get(i).cloned()),
297 "Add" | "Sub" | "Mul" | "Div" | "Pow" => {
298 if node.input.as_slice().len() < 2 {
299 return None;
300 }
301 let a = node.input.as_slice()[0].as_str();
302 let b = node.input.as_slice()[1].as_str();
303 match (ctx.value_shapes.get(a), ctx.value_shapes.get(b)) {
304 (Some(sa), Some(sb)) => {
308 if sa.len() <= sb.len() {
309 Some(sa.clone())
310 } else {
311 Some(sb.clone())
312 }
313 }
314 _ => None,
315 }
316 }
317 "MatMul" => {
318 if node.input.as_slice().len() < 2 {
319 return None;
320 }
321 let a_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
322 let b_shape = ctx.value_shapes.get(node.input.as_slice()[1].as_str())?;
323
324 if a_shape.len() == 4 && b_shape.len() == 4 {
326 return Some(vec![a_shape[0], a_shape[1], a_shape[2], b_shape[3]]);
327 }
328
329 if a_shape.len() >= 2 && b_shape.len() >= 2 {
331 let m = a_shape[a_shape.len() - 2];
332 let n = b_shape[b_shape.len() - 1];
333 let mut out = Vec::new();
334 if a_shape.len() > 2 {
335 out.extend_from_slice(&a_shape[..a_shape.len() - 2]);
336 }
337 out.push(m);
338 out.push(n);
339 return Some(out);
340 }
341 None
342 }
343 "Transpose" => {
344 let input = node.input.as_slice().first()?;
345 let shape = ctx.value_shapes.get(input)?;
346 let perm: Vec<usize> = node
347 .attribute
348 .as_slice()
349 .iter()
350 .find(|a| a.name.as_str() == "perm")
351 .map(|a| a.ints.iter().map(|&i| i as usize).collect::<Vec<usize>>())
352 .unwrap_or_else(|| (0..shape.len()).rev().collect());
353 if perm.iter().any(|&i| i >= shape.len()) {
354 return None;
355 }
356 Some(perm.iter().map(|&i| shape[i]).collect())
357 }
358 "Concat" => {
359 let mut shapes = Vec::new();
360 for inp in node.input.as_slice() {
361 if let Some(s) = ctx.value_shapes.get(inp.as_str()) {
362 shapes.push(s.clone());
363 } else {
364 return None;
365 }
366 }
367 if shapes.is_empty() {
368 return None;
369 }
370 let mut axis = node
371 .attribute
372 .as_slice()
373 .iter()
374 .find(|a| a.name.as_str() == "axis" && a.i != 0)
375 .map(|a| a.i)
376 .unwrap_or(0);
377 if axis < 0 {
378 axis += shapes[0].len() as i64;
379 }
380 let axis = axis as usize;
381 let mut out = shapes[0].clone();
382 for s in shapes.iter().skip(1) {
383 if s.len() != out.len() || axis >= s.len() {
384 return None;
385 }
386 out[axis] += s[axis];
387 }
388 Some(out)
389 }
390 "Unsqueeze" => {
391 if node.input.as_slice().is_empty() {
392 return None;
393 }
394 let input_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
395 let mut axes = node
396 .attribute
397 .as_slice()
398 .iter()
399 .find(|a| a.name.as_str() == "axes")
400 .map(|a| a.ints.clone())
401 .unwrap_or_default();
402 if axes.is_empty() && node.input.as_slice().len() > 1 {
404 axes = ctx
405 .const_values
406 .get(node.input.as_slice()[1].as_str())
407 .cloned()
408 .unwrap_or_default();
409 }
410 if axes.is_empty() {
411 return None;
412 }
413 let mut output_shape = input_shape.clone();
414 let mut sorted_axes = axes.clone();
415 sorted_axes.sort();
416 for axis in sorted_axes {
417 let idx = if axis < 0 {
418 (output_shape.len() as i64 + axis + 1) as usize
419 } else {
420 axis as usize
421 };
422 if idx > output_shape.len() {
423 return None;
424 }
425 output_shape.insert(idx, 1);
426 }
427 Some(output_shape)
428 }
429 "Expand" => {
430 if node.input.as_slice().len() < 2 {
431 return None;
432 }
433 if let Some(target_shape) = ctx.const_values.get(node.input.as_slice()[1].as_str()) {
435 if !target_shape.is_empty() {
436 return Some(target_shape.clone());
437 }
438 }
439 if let Some(out) = node.output.as_slice().first() {
441 if let Some(shape) = ctx.value_shapes.get(out.as_str()) {
442 if !shape.is_empty() && shape.iter().all(|&d| d > 0) {
443 return Some(shape.clone());
444 }
445 }
446 }
447 None
448 }
449 "Squeeze" => {
450 if node.input.as_slice().is_empty() {
451 return None;
452 }
453 let input_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
454 let mut axes = node
455 .attribute
456 .as_slice()
457 .iter()
458 .find(|a| a.name.as_str() == "axes")
459 .map(|a| a.ints.clone())
460 .unwrap_or_default();
461 if axes.is_empty() && node.input.as_slice().len() > 1 {
463 axes = ctx
464 .const_values
465 .get(node.input.as_slice()[1].as_str())
466 .cloned()
467 .unwrap_or_default();
468 }
469 let mut output_shape = input_shape.clone();
470 if axes.is_empty() {
471 output_shape.retain(|&d| d != 1);
472 return Some(output_shape);
473 }
474 let mut axes_norm: Vec<usize> = axes
475 .iter()
476 .map(|&a| {
477 if a < 0 {
478 (input_shape.len() as i64 + a) as usize
479 } else {
480 a as usize
481 }
482 })
483 .collect();
484 axes_norm.sort();
485 axes_norm.dedup();
486 let mut keep = Vec::new();
487 for (idx, dim) in input_shape.iter().enumerate() {
488 if axes_norm.contains(&idx) {
489 continue;
490 }
491 keep.push(*dim);
492 }
493 Some(keep)
494 }
495 "Reshape" => {
496 if node.input.as_slice().len() < 2 {
497 return None;
498 }
499 let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
500 let shape_input = node.input.as_slice()[1].as_str();
501 let mut target: Vec<i64> = ctx.const_values.get(shape_input)?.clone();
502
503 if target.contains(&-1) {
504 let total_input: i64 = data_shape.iter().product();
505 let known: i64 = target.iter().filter(|&&d| d != -1).product();
506 if known == 0 || total_input % known != 0 {
507 return None;
508 }
509 if let Some(idx) = target.iter().position(|&d| d == -1) {
510 target[idx] = total_input / known;
511 }
512 }
513 Some(target)
514 }
515 "Slice" => {
516 if node.input.as_slice().is_empty() {
517 return None;
518 }
519 let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
520 let starts = node
521 .input
522 .as_slice()
523 .get(1)
524 .and_then(|n| ctx.const_values.get(n))
525 .cloned()?;
526 let ends = node
527 .input
528 .as_slice()
529 .get(2)
530 .and_then(|n| ctx.const_values.get(n))
531 .cloned()?;
532 let axes = node
533 .input
534 .as_slice()
535 .get(3)
536 .and_then(|n| ctx.const_values.get(n))
537 .cloned()
538 .unwrap_or_else(|| (0..data_shape.len() as i64).collect());
539 let steps = node
540 .input
541 .as_slice()
542 .get(4)
543 .and_then(|n| ctx.const_values.get(n))
544 .cloned()
545 .unwrap_or_else(|| vec![1; axes.len()]);
546
547 if axes.len() != starts.len() || axes.len() != ends.len() || axes.len() != steps.len() {
548 return None;
549 }
550
551 let mut out = data_shape.clone();
552 for i in 0..axes.len() {
553 let mut axis = axes[i];
554 if axis < 0 {
555 axis += data_shape.len() as i64;
556 }
557 let axis = axis as usize;
558 if axis >= out.len() {
559 return None;
560 }
561 if steps[i] != 1 {
562 return None;
563 }
564 let dim = data_shape[axis];
565 let mut start = starts[i];
566 let mut end = ends[i];
567 if start < 0 {
568 start += dim;
569 }
570 if end < 0 {
571 end += dim;
572 }
573 start = start.max(0);
574 end = end.min(dim);
575 out[axis] = if end < start { 0 } else { end - start };
576 }
577 Some(out)
578 }
579 "Gather" => {
580 if node.input.as_slice().len() < 2 {
581 return None;
582 }
583 let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
584 let indices_shape = ctx.value_shapes.get(node.input.as_slice()[1].as_str())?;
585 let mut axis = node
586 .attribute
587 .as_slice()
588 .iter()
589 .find(|a| a.name.as_str() == "axis" && a.i != 0)
590 .map(|a| a.i)
591 .unwrap_or(0);
592 if axis < 0 {
593 axis += data_shape.len() as i64;
594 }
595 let axis = axis as usize;
596 if axis > data_shape.len() {
597 return None;
598 }
599 let mut out = Vec::new();
600 out.extend_from_slice(&data_shape[..axis]);
601 out.extend(indices_shape.iter().cloned());
602 if axis < data_shape.len() {
603 out.extend_from_slice(&data_shape[axis + 1..]);
604 }
605 Some(out)
606 }
607 "Split" => {
608 let input_shape = node
609 .input
610 .as_slice()
611 .first()
612 .and_then(|i| ctx.value_shapes.get(i))
613 .cloned()?;
614 let mut axis = node
615 .attribute
616 .as_slice()
617 .iter()
618 .find(|a| a.name.as_str() == "axis" && a.i != 0)
619 .map(|a| a.i)
620 .unwrap_or(0);
621 if axis < 0 {
622 axis += input_shape.len() as i64;
623 }
624 let axis = axis as usize;
625 if axis >= input_shape.len() {
626 return None;
627 }
628 let splits = node
629 .attribute
630 .as_slice()
631 .iter()
632 .find(|a| a.name.as_str() == "split")
633 .map(|a| a.ints.clone());
634 if let Some(s) = splits {
635 if s.iter().any(|&v| v <= 0) {
636 return None;
637 }
638 let sum: i64 = s.iter().sum();
639 if sum != input_shape[axis] {
640 return None;
641 }
642 let mut out = input_shape.clone();
643 out[axis] = s[0];
644 Some(out)
645 } else {
646 let outputs = node.output.as_slice().len() as i64;
647 if outputs == 0 || input_shape[axis] % outputs != 0 {
648 return None;
649 }
650 let chunk = input_shape[axis] / outputs;
651 let mut out = input_shape.clone();
652 out[axis] = chunk;
653 Some(out)
654 }
655 }
656 "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" => {
657 let input = node.input.as_slice().first()?;
658 let input_shape = ctx.value_shapes.get(input)?;
659 let axes: Vec<i64> = node
660 .attribute
661 .as_slice()
662 .iter()
663 .find(|a| a.name.as_str() == "axes")
664 .map(|a| a.ints.clone())
665 .unwrap_or_default();
666 let keepdims = node
667 .attribute
668 .as_slice()
669 .iter()
670 .find(|a| a.name.as_str() == "keepdims" && a.i != 0)
671 .map(|a| a.i != 0)
672 .unwrap_or(true);
673 if axes.is_empty() {
674 if keepdims {
675 Some(vec![1; input_shape.len()])
676 } else {
677 Some(vec![])
678 }
679 } else {
680 let mut out = input_shape.clone();
681 for axis in axes {
682 let mut a = axis;
683 if a < 0 {
684 a += input_shape.len() as i64;
685 }
686 let idx = a as usize;
687 if idx >= out.len() {
688 return None;
689 }
690 if keepdims {
691 out[idx] = 1;
692 } else {
693 out[idx] = -1;
694 }
695 }
696 if !keepdims {
697 out.retain(|&d| d != -1);
698 }
699 Some(out)
700 }
701 }
702 _ => None,
703 }
704}
705
706fn fold_integer_constants(graph: &GraphProto, ctx: &mut InferenceResult) -> bool {
707 let mut changed = false;
708 let mut where_count = 0;
709 for node in graph.node.as_slice() {
710 if node.op_type.as_str() == "Where" {
711 where_count += 1;
712 }
713 let outputs = node.output.as_slice();
714 if outputs.is_empty() {
715 continue;
716 }
717 if ctx.const_values.contains_key(outputs[0].as_str()) {
718 continue;
719 }
720
721 let op = node.op_type.as_str();
722 let inputs = node.input.as_slice();
723
724 if op == "Shape" {
728 if let Some(inp) = inputs.first() {
729 if let Some(shape) = ctx.value_shapes.get(inp.as_str()) {
730 let out_name = outputs[0].to_string();
731 ctx.const_values.insert(out_name.clone(), shape.clone());
732 ctx.value_shapes.insert(out_name, vec![shape.len() as i64]);
733 changed = true;
734 continue;
735 }
736 }
737 }
738
739 let all_const = inputs
740 .iter()
741 .all(|i| ctx.const_values.contains_key(i.as_str()));
742 if !all_const {
743 continue;
744 }
745
746 match op {
747 "Concat" => {
748 let mut axis = 0i64;
749 for attr in node.attribute.as_slice() {
750 if attr.name.as_str() == "axis" && attr.i != 0 {
751 axis = attr.i;
752 }
753 }
754 if axis == 0 {
755 let mut combined = Vec::new();
756 for inp in inputs {
757 if let Some(vals) = ctx.const_values.get(inp.as_str()) {
758 combined.extend_from_slice(vals);
759 }
760 }
761 if !combined.is_empty() {
762 let out_name = outputs[0].to_string();
763 ctx.const_values.insert(out_name.clone(), combined.clone());
764 ctx.value_shapes
765 .insert(out_name, vec![combined.len() as i64]);
766 changed = true;
767 }
768 }
769 }
770 "Gather" => {
771 let mut axis = 0i64;
772 for attr in node.attribute.as_slice() {
773 if attr.name.as_str() == "axis" && attr.i != 0 {
774 axis = attr.i;
775 }
776 }
777 if axis == 0 && inputs.len() >= 2 {
778 let data = ctx.const_values.get(inputs[0].as_str());
779 let indices = ctx.const_values.get(inputs[1].as_str());
780 if let (Some(data), Some(indices)) = (data, indices) {
781 let mut gathered = Vec::new();
782 for &idx in indices {
783 let i = if idx < 0 {
784 (data.len() as i64 + idx) as usize
785 } else {
786 idx as usize
787 };
788 if let Some(v) = data.get(i) {
789 gathered.push(*v);
790 }
791 }
792 if !gathered.is_empty() {
793 let out_name = outputs[0].to_string();
794 ctx.const_values.insert(out_name.clone(), gathered.clone());
795 let shape = if gathered.len() == 1 {
796 Vec::new()
797 } else {
798 vec![gathered.len() as i64]
799 };
800 ctx.value_shapes.insert(out_name, shape);
801 changed = true;
802 }
803 }
804 }
805 }
806 "Unsqueeze" => {
807 if inputs.is_empty() {
808 continue;
809 }
810 let data = ctx.const_values.get(inputs[0].as_str()).cloned();
811 if data.is_none() {
812 continue;
813 }
814
815 let mut axes: Vec<i64> = node
816 .attribute
817 .as_slice()
818 .iter()
819 .find(|a| a.name.as_str() == "axes")
820 .map(|a| a.ints.clone())
821 .unwrap_or_default();
822 if axes.is_empty() && inputs.len() > 1 {
823 axes = ctx
824 .const_values
825 .get(inputs[1].as_str())
826 .cloned()
827 .unwrap_or_default();
828 }
829 if axes.is_empty() {
830 continue;
831 }
832
833 let mut sorted_axes = axes.clone();
834 sorted_axes.sort();
835
836 let mut out_shape = ctx
837 .value_shapes
838 .get(inputs[0].as_str())
839 .cloned()
840 .unwrap_or_else(|| {
841 let len = data.as_ref().map(|v| v.len()).unwrap_or(0);
842 if len <= 1 {
843 Vec::new()
844 } else {
845 vec![len as i64]
846 }
847 });
848
849 for axis in sorted_axes {
850 let idx = if axis < 0 {
851 (out_shape.len() as i64 + axis + 1) as usize
852 } else {
853 axis as usize
854 };
855 if idx > out_shape.len() {
856 continue;
857 }
858 out_shape.insert(idx, 1);
859 }
860
861 let out_name = outputs[0].to_string();
862 ctx.const_values
863 .insert(out_name.clone(), data.unwrap_or_default());
864 ctx.value_shapes.insert(out_name, out_shape);
865 changed = true;
866 }
867 "Reshape" => {
868 if inputs.len() < 2 {
869 continue;
870 }
871 let data = ctx.const_values.get(inputs[0].as_str()).cloned();
872 let shape_target = ctx.const_values.get(inputs[1].as_str()).cloned();
873 if let (Some(data), Some(mut target)) = (data, shape_target) {
874 if target.contains(&-1) {
876 let total: i64 = if data.is_empty() {
877 1
878 } else {
879 data.len() as i64
880 };
881 let known: i64 = target.iter().filter(|&&d| d != -1).product();
882 if known != 0 {
883 if let Some(idx) = target.iter().position(|&d| d == -1) {
884 target[idx] = total / known;
885 }
886 }
887 }
888 let out_name = outputs[0].to_string();
889 let out_shape = target.clone();
890 ctx.const_values.insert(out_name.clone(), data);
891 ctx.value_shapes.insert(out_name, out_shape);
892 changed = true;
893 }
894 }
895 "ConstantOfShape" => {
896 if inputs.is_empty() {
899 continue;
900 }
901 if let Some(shape_vals) = ctx.const_values.get(inputs[0].as_str()).cloned() {
902 let fill_value: i64 = node
904 .attribute
905 .as_slice()
906 .iter()
907 .find(|a| a.name.as_str() == "value")
908 .and_then(|a| {
909 let t = a.t.as_ref()?;
910 if !t.raw_data.as_slice().is_empty() {
911 if t.data_type == 7 && t.raw_data.as_slice().len() >= 8 {
913 let bytes = t.raw_data.as_slice();
914 Some(i64::from_le_bytes([
915 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5],
916 bytes[6], bytes[7],
917 ]))
918 } else if t.data_type == 1 && t.raw_data.as_slice().len() >= 4 {
919 let bytes = t.raw_data.as_slice();
920 Some(f32::from_le_bytes([
921 bytes[0], bytes[1], bytes[2], bytes[3],
922 ]) as i64)
923 } else {
924 Some(0)
925 }
926 } else if !t.int64_data.as_slice().is_empty() {
927 Some(t.int64_data.as_slice()[0])
928 } else if !t.float_data.as_slice().is_empty() {
929 Some(t.float_data.as_slice()[0] as i64)
930 } else {
931 Some(0)
932 }
933 })
934 .unwrap_or(0);
935
936 let total: usize = shape_vals.iter().map(|&d| d.max(0) as usize).product();
937 let data = vec![fill_value; total];
938 let out_name = outputs[0].to_string();
939 ctx.const_values.insert(out_name.clone(), data);
940 ctx.value_shapes.insert(out_name, shape_vals);
941 changed = true;
942 }
943 }
944 "Mul" => {
945 if inputs.len() < 2 {
946 continue;
947 }
948 let lhs = ctx.const_values.get(inputs[0].as_str()).cloned();
949 let rhs = ctx.const_values.get(inputs[1].as_str()).cloned();
950 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
951 let values: Vec<i64> = if lhs.len() == 1 && rhs.len() > 1 {
953 rhs.iter().map(|&r| lhs[0] * r).collect()
954 } else if rhs.len() == 1 && lhs.len() > 1 {
955 lhs.iter().map(|&l| l * rhs[0]).collect()
956 } else if lhs.len() == rhs.len() {
957 lhs.iter().zip(rhs.iter()).map(|(&l, &r)| l * r).collect()
958 } else {
959 continue;
960 };
961 let out_name = outputs[0].to_string();
962 let shape = if values.len() == 1 {
963 Vec::new()
964 } else {
965 vec![values.len() as i64]
966 };
967 ctx.const_values.insert(out_name.clone(), values);
968 ctx.value_shapes.insert(out_name, shape);
969 changed = true;
970 }
971 }
972 "Equal" => {
973 if inputs.len() < 2 {
974 continue;
975 }
976 let lhs = ctx.const_values.get(inputs[0].as_str()).cloned();
977 let rhs = ctx.const_values.get(inputs[1].as_str()).cloned();
978 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
979 if lhs.len() != rhs.len() {
980 continue;
981 }
982 let values: Vec<i64> = lhs
983 .iter()
984 .zip(rhs.iter())
985 .map(|(l, r)| if l == r { 1 } else { 0 })
986 .collect();
987 let out_name = outputs[0].to_string();
988 let shape = if values.len() == 1 {
989 Vec::new()
990 } else {
991 vec![values.len() as i64]
992 };
993 ctx.const_values.insert(out_name.clone(), values);
994 ctx.value_shapes.insert(out_name, shape);
995 changed = true;
996 }
997 }
998 "Where" => {
999 if inputs.len() < 3 {
1000 continue;
1001 }
1002
1003 if inputs.iter().any(|i| i.contains("rotary")) {
1005 crate::debug_println!("[WHERE DEBUG] Processing Where node");
1006 crate::debug_println!(" inputs: {:?}", inputs);
1007 crate::debug_println!(" outputs: {:?}", outputs);
1008 }
1009
1010 let cond = ctx.const_values.get(inputs[0].as_str()).cloned();
1011 let a = ctx.const_values.get(inputs[1].as_str()).cloned();
1012 let b = ctx.const_values.get(inputs[2].as_str()).cloned();
1013
1014 if inputs.iter().any(|i| i.contains("rotary")) {
1015 crate::debug_println!(" cond const: {}", cond.is_some());
1016 crate::debug_println!(" a const: {}", a.is_some());
1017 crate::debug_println!(" b const: {}", b.is_some());
1018 }
1019
1020 if let (Some(cond), Some(a), Some(b)) = (cond, a, b) {
1022 if cond.len() != a.len() || a.len() != b.len() {
1023 continue;
1024 }
1025
1026 let is_trivial =
1031 |vals: &[i64]| -> bool { vals.iter().all(|&v| v == 1) && vals.len() <= 3 };
1032
1033 let mut out = if is_trivial(&a) && !is_trivial(&b) {
1034 if inputs.iter().any(|i| i.contains("rotary")) {
1035 crate::debug_println!("[WHERE SMART EVAL] Preferring non-trivial branch b={:?} over trivial a={:?}", b, a);
1036 }
1037 b
1038 } else if is_trivial(&b) && !is_trivial(&a) {
1039 if inputs.iter().any(|i| i.contains("rotary")) {
1040 crate::debug_println!("[WHERE SMART EVAL] Preferring non-trivial branch a={:?} over trivial b={:?}", a, b);
1041 }
1042 a
1043 } else {
1044 let mut result = Vec::with_capacity(a.len());
1046 for i in 0..a.len() {
1047 result.push(if cond[i] != 0 { a[i] } else { b[i] });
1048 }
1049 result
1050 };
1051
1052 if out.contains(&-1) && !outputs.is_empty() {
1055 let output_name = outputs[0].as_str();
1056 for node in graph.node.as_slice() {
1058 if node.op_type.as_str() == "Expand"
1059 && node.input.len() >= 2
1060 && node.input[1].as_str() == output_name
1061 {
1062 let data_input = node.input[0].as_str();
1064 if let Some(data_shape) = ctx.value_shapes.get(data_input) {
1065 if out.len() == data_shape.len() {
1067 for i in 0..out.len() {
1068 if out[i] == -1 {
1069 out[i] = data_shape[i];
1070 if inputs.iter().any(|inp| inp.contains("rotary")) {
1071 crate::debug_println!("[WHERE RESOLVE] Resolved -1 at position {} to {} from data shape {:?}", i, data_shape[i], data_shape);
1072 }
1073 }
1074 }
1075 }
1076 }
1077 }
1078 }
1079 }
1080
1081 let out_name = outputs[0].to_string();
1082 let shape = if out.len() == 1 {
1083 Vec::new()
1084 } else {
1085 vec![out.len() as i64]
1086 };
1087 if inputs.iter().any(|i| i.contains("rotary")) {
1088 crate::debug_println!("[WHERE STORE] Storing {} = {:?}", out_name, out);
1089 }
1090 ctx.const_values.insert(out_name.clone(), out);
1091 ctx.value_shapes.insert(out_name, shape);
1092 changed = true;
1093 } else {
1094 let a_const = ctx.const_values.get(inputs[1].as_str());
1099 let b_const = ctx.const_values.get(inputs[2].as_str());
1100 let a_shape = ctx.value_shapes.get(inputs[1].as_str());
1101 let b_shape = ctx.value_shapes.get(inputs[2].as_str());
1102
1103 let is_trivial_constant =
1105 |vals: &[i64]| -> bool { vals.iter().all(|&v| v == 1) && vals.len() <= 3 };
1106
1107 let preferred_values = if let (Some(a_vals), None) = (a_const, b_const) {
1108 if is_trivial_constant(a_vals) && b_shape.is_some() {
1110 crate::debug_println!("[WHERE HEURISTIC] Preferring dynamic input {} (shape {:?}) over trivial constant {:?}", inputs[2], b_shape, a_vals);
1113 b_shape.cloned()
1114 } else {
1115 Some(a_vals.clone())
1116 }
1117 } else if let (None, Some(b_vals)) = (a_const, b_const) {
1118 if is_trivial_constant(b_vals) && a_shape.is_some() {
1120 crate::debug_println!("[WHERE HEURISTIC] Preferring dynamic input {} (shape {:?}) over trivial constant {:?}", inputs[1], a_shape, b_vals);
1123 a_shape.cloned()
1124 } else {
1125 Some(b_vals.clone())
1126 }
1127 } else {
1128 None
1129 };
1130
1131 if let Some(values) = preferred_values {
1133 let out_name = outputs[0].to_string();
1134 let shape = if values.len() == 1 {
1135 Vec::new()
1136 } else {
1137 vec![values.len() as i64]
1138 };
1139 ctx.const_values.insert(out_name.clone(), values);
1140 ctx.value_shapes.insert(out_name, shape);
1141 changed = true;
1142 }
1143 }
1144 }
1145 _ => {}
1146 }
1147 }
1148 if where_count > 0 {
1149 crate::debug_println!(
1150 "[FOLD DEBUG] Processed {} Where nodes, changed={}",
1151 where_count,
1152 changed
1153 );
1154 }
1155 changed
1156}
1157
1158fn read_int_tensor(tensor: &TensorProto) -> Vec<i64> {
1159 let raw = tensor.raw_data.as_slice();
1160 if !raw.is_empty() {
1161 match tensor.data_type {
1162 x if x == TensorProto_DataType::Int32 as i32 => raw
1163 .chunks_exact(4)
1164 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1165 .collect(),
1166 _ => raw
1167 .chunks_exact(8)
1168 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
1169 .collect(),
1170 }
1171 } else if !tensor.int64_data.as_slice().is_empty() {
1172 tensor.int64_data.as_slice().to_vec()
1173 } else if !tensor.int32_data.as_slice().is_empty() {
1174 tensor
1175 .int32_data
1176 .as_slice()
1177 .iter()
1178 .map(|&v| v as i64)
1179 .collect()
1180 } else {
1181 Vec::new()
1182 }
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{tensor_shape_proto, type_proto};

    /// Builds a model with a single float input named "input" whose only dimension is the
    /// symbolic parameter "batch".
    fn model_with_symbolic_batch_input() -> crate::protos::onnx::ModelProto {
        let symbolic_dim = tensor_shape_proto::Dimension {
            value: Some(tensor_shape_proto::dimension::Value::DimParam(
                "batch".to_string(),
            )),
            denotation: String::new(),
        };

        let tensor_type = type_proto::Tensor {
            elem_type: crate::protos::onnx::TensorProto_DataType::Float.into(),
            shape: Some(crate::protos::onnx::TensorShapeProto {
                dim: vec![symbolic_dim],
            }),
        };

        let input = crate::protos::onnx::ValueInfoProto {
            name: "input".to_string(),
            r#type: Some(crate::protos::onnx::TypeProto {
                value: Some(type_proto::Value::TensorType(tensor_type)),
                denotation: String::new(),
            }),
            ..Default::default()
        };

        crate::protos::onnx::ModelProto {
            graph: Some(crate::protos::onnx::GraphProto {
                input: vec![input],
                ..Default::default()
            }),
            ..Default::default()
        }
    }

    /// Without an override, the symbolic dimension must be reported by name.
    #[test]
    fn dynamic_dim_requires_override() {
        let model = model_with_symbolic_batch_input();

        let outcome = infer_static_shapes(&model, &HashMap::new());

        assert!(matches!(
            outcome,
            Err(ShapeInferenceError::DynamicDim { dim, .. }) if dim == "batch"
        ));
    }

    /// With an override for "batch", the input shape becomes fully static.
    #[test]
    fn override_allows_static_shape() {
        let model = model_with_symbolic_batch_input();
        let mut overrides = HashMap::new();
        overrides.insert("batch".to_string(), 1);

        let inferred = infer_static_shapes(&model, &overrides).unwrap();

        assert_eq!(inferred.value_shapes.get("input"), Some(&vec![1]));
    }
}
1277}