Skip to main content

webnn_graph/onnx/
shape_inference.rs

1// Static shape/type inference scaffold for ONNX graphs.
2// Conservative: records only fully-static shapes and folds small integer constants
3// to unblock reshape/axes/starts/ends calculations. Dynamic dims cause errors so
4// callers can ask users to run onnx-simplifier or provide overrides.
5use crate::ast::DataType;
6use crate::onnx::convert::map_onnx_data_type;
7use crate::onnx::ir::{Dim, OnnxIrGraph, TensorShape, TensorType};
8use crate::protos::onnx::{
9    tensor_shape_proto::dimension::Value as DimensionValue, type_proto::Value as TypeProtoValue,
10    GraphProto, ModelProto, NodeProto, TensorProto, TensorProto_DataType,
11};
12use std::collections::{HashMap, HashSet};
13use thiserror::Error;
14
/// Errors reported by the static shape/type inference pass.
#[derive(Debug, Error)]
pub enum ShapeInferenceError {
    /// A graph input lacks a tensor type, a shape, or a value/parameter for one
    /// of its dimensions. Carries the input's name.
    #[error("input '{0}' is missing shape information")]
    MissingInputShape(String),
    /// A graph input has a symbolic dimension (`dim_param`) with no entry in the
    /// overrides map; `input` is the graph input's name, `dim` the parameter name.
    #[error("input '{input}' has dynamic dimension '{dim}', please provide an override")]
    DynamicDim { input: String, dim: String },
    /// A raw ONNX element-type code that is 0 (undefined) or that
    /// `map_onnx_data_type` rejects.
    #[error("unsupported ONNX data type: {0}")]
    UnsupportedDataType(i32),
    /// Shape inference failed for a node of the given op type.
    #[error("could not infer shape for op '{op}'")]
    CannotInfer { op: String },
}
26
/// Everything the inference pass learned, keyed by ONNX value name.
#[derive(Debug, Default)]
pub struct InferenceResult {
    /// Fully static shapes; a rank-0 (scalar) value has an empty vector.
    pub value_shapes: HashMap<String, Vec<i64>>,
    /// Element types of values whose dtype is known.
    pub value_types: HashMap<String, DataType>,
    /// Integer tensor contents (as a flat list) read from initializers and
    /// `Constant` nodes or produced by constant folding, used to resolve
    /// shape/axes/starts/ends operands.
    pub const_values: HashMap<String, Vec<i64>>,
}
33
34/// Run a lightweight static shape/type inference pass.
35/// Returns only fully-known shapes; dynamic dimensions trigger an error.
36pub fn infer_static_shapes(
37    model: &ModelProto,
38    overrides: &HashMap<String, u32>,
39) -> Result<InferenceResult, ShapeInferenceError> {
40    let mut result = InferenceResult::default();
41
42    if model.graph.is_none() {
43        return Ok(result);
44    }
45
46    let graph = model.graph.as_ref().unwrap();
47    let mut ir = OnnxIrGraph::default();
48    let initializer_names: HashSet<String> = graph
49        .initializer
50        .as_slice()
51        .iter()
52        .map(|i| i.name.as_str().to_string())
53        .collect();
54
55    seed_inputs(graph, overrides, &initializer_names, &mut ir, &mut result)?;
56    seed_initializers(graph, &mut ir, &mut result)?;
57    seed_constant_nodes(graph, &mut result, &mut ir)?;
58
59    propagate_node_shapes(graph, &mut result)?;
60
61    Ok(result)
62}
63
64fn seed_inputs(
65    graph: &GraphProto,
66    overrides: &HashMap<String, u32>,
67    initializer_names: &HashSet<String>,
68    ir: &mut OnnxIrGraph,
69    result: &mut InferenceResult,
70) -> Result<(), ShapeInferenceError> {
71    for input in graph.input.as_slice() {
72        let name = input.name.as_str().to_string();
73        let vi = ir.value_or_insert(&name);
74        vi.producer = None;
75
76        if initializer_names.contains(&name) {
77            continue;
78        }
79
80        let type_proto = input
81            .r#type
82            .as_ref()
83            .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))?;
84
85        let tensor_type = match &type_proto.value {
86            Some(TypeProtoValue::TensorType(tt)) => tt,
87            _ => return Err(ShapeInferenceError::MissingInputShape(name.clone())),
88        };
89
90        let dtype = if tensor_type.elem_type != 0 {
91            map_onnx_data_type(tensor_type.elem_type)
92                .map_err(|_| ShapeInferenceError::UnsupportedDataType(tensor_type.elem_type))?
93        } else {
94            return Err(ShapeInferenceError::UnsupportedDataType(0));
95        };
96
97        let shape = tensor_type
98            .shape
99            .as_ref()
100            .ok_or_else(|| ShapeInferenceError::MissingInputShape(name.clone()))?;
101
102        let mut dims = Vec::new();
103        for dim in shape.dim.as_slice() {
104            if let Some(value) = &dim.value {
105                match value {
106                    DimensionValue::DimValue(v) => {
107                        dims.push(Dim::Known(*v));
108                    }
109                    DimensionValue::DimParam(key) => {
110                        if let Some(v) = overrides.get(key.as_str()) {
111                            dims.push(Dim::Known(*v as i64));
112                        } else {
113                            return Err(ShapeInferenceError::DynamicDim {
114                                input: name.clone(),
115                                dim: key.clone(),
116                            });
117                        }
118                    }
119                }
120            } else {
121                return Err(ShapeInferenceError::MissingInputShape(name.clone()));
122            }
123        }
124
125        let ty = TensorType {
126            data_type: dtype.clone(),
127            shape: TensorShape { dims },
128        };
129        vi.ty = Some(ty.clone());
130        result.value_types.insert(name.clone(), dtype);
131        if let Some(shape) = ty.shape.to_i64() {
132            result.value_shapes.insert(name, shape);
133        }
134    }
135    Ok(())
136}
137
138fn seed_initializers(
139    graph: &GraphProto,
140    ir: &mut OnnxIrGraph,
141    result: &mut InferenceResult,
142) -> Result<(), ShapeInferenceError> {
143    for init in graph.initializer.as_slice() {
144        let name = init.name.as_str().to_string();
145        let vi = ir.value_or_insert(&name);
146        vi.producer = None;
147
148        let dtype = map_onnx_data_type(init.data_type)
149            .map_err(|_| ShapeInferenceError::UnsupportedDataType(init.data_type))?;
150        let shape: Vec<i64> = init.dims.as_slice().to_vec();
151        result.value_types.insert(name.clone(), dtype.clone());
152        result.value_shapes.insert(name.clone(), shape);
153
154        if matches!(
155            dtype,
156            DataType::Int32 | DataType::Int64 | DataType::Uint32 | DataType::Uint64
157        ) {
158            let values = read_int_tensor(init);
159            if !values.is_empty() {
160                result.const_values.insert(name, values);
161            }
162        }
163    }
164    Ok(())
165}
166
167fn seed_constant_nodes(
168    graph: &GraphProto,
169    result: &mut InferenceResult,
170    ir: &mut OnnxIrGraph,
171) -> Result<(), ShapeInferenceError> {
172    for node in graph.node.as_slice() {
173        if node.op_type.as_str() != "Constant" {
174            continue;
175        }
176
177        if let Some(out) = node.output.as_slice().first() {
178            let out_name = out.to_string();
179            let vi = ir.value_or_insert(&out_name);
180            vi.producer = Some(node.name.as_str().to_string());
181
182            if let Some(attr) = node
183                .attribute
184                .as_slice()
185                .iter()
186                .find(|a| a.name.as_str() == "value" && a.t.is_some())
187            {
188                let t = attr.t.as_ref().unwrap();
189                let dtype = map_onnx_data_type(t.data_type)
190                    .map_err(|_| ShapeInferenceError::UnsupportedDataType(t.data_type))?;
191                result.value_types.insert(out_name.clone(), dtype);
192
193                let vals = read_int_tensor(t);
194                if !vals.is_empty() {
195                    result.const_values.insert(out_name.clone(), vals.clone());
196                    let shape: Vec<i64> = if vals.len() == 1 {
197                        Vec::new()
198                    } else {
199                        vec![vals.len() as i64]
200                    };
201                    result.value_shapes.insert(out_name.clone(), shape);
202                    vi.ty = Some(TensorType {
203                        data_type: result.value_types[&out_name].clone(),
204                        shape: TensorShape::from_known(result.value_shapes[&out_name].clone()),
205                    });
206                }
207            }
208        }
209    }
210    Ok(())
211}
212
213fn propagate_node_shapes(
214    graph: &GraphProto,
215    result: &mut InferenceResult,
216) -> Result<(), ShapeInferenceError> {
217    let mut progress = true;
218    let max_iters = 8;
219    let mut iter = 0;
220
221    while progress && iter < max_iters {
222        progress = false;
223        iter += 1;
224
225        for node in graph.node.as_slice() {
226            let outputs = node.output.as_slice();
227            if outputs.is_empty() {
228                continue;
229            }
230            if outputs
231                .iter()
232                .all(|o| result.value_shapes.contains_key(o.as_str()))
233            {
234                continue;
235            }
236
237            if let Some(shape) = infer_node_shape(node, result) {
238                let out_name = outputs[0].to_string();
239                result.value_shapes.entry(out_name.clone()).or_insert(shape);
240
241                // Propagate dtype from first input if available.
242                if let Some(first_in) = node.input.as_slice().first() {
243                    if let Some(dtype) = result.value_types.get(first_in).cloned() {
244                        result.value_types.entry(out_name.clone()).or_insert(dtype);
245                    }
246                }
247
248                progress = true;
249            }
250        }
251
252        // Opportunistic const folding for integer tensors to unlock more shapes.
253        progress |= fold_integer_constants(graph, result);
254    }
255
256    Ok(())
257}
258
/// Multidirectional (numpy-style) broadcast of two static shapes.
///
/// The shapes are right-aligned, and a missing leading dimension behaves like 1.
/// Each output dimension is the common value, or the non-1 side when one side
/// is 1. Returns `None` if some pair of dimensions is incompatible.
#[allow(dead_code)] // utility kept available even when no shape rule calls it
fn broadcast_shapes(a: &[i64], b: &[i64]) -> Option<Vec<i64>> {
    let rank = a.len().max(b.len());
    // Dimension of `shape` at output position `axis`, padding on the left with 1s.
    let dim_at = |shape: &[i64], axis: usize| -> i64 {
        let pad = rank - shape.len();
        if axis < pad {
            1
        } else {
            shape[axis - pad]
        }
    };
    (0..rank)
        .map(|axis| {
            let (x, y) = (dim_at(a, axis), dim_at(b, axis));
            if x == y || y == 1 {
                Some(x)
            } else if x == 1 {
                Some(y)
            } else {
                None
            }
        })
        .collect()
}
287
288fn infer_node_shape(node: &NodeProto, ctx: &InferenceResult) -> Option<Vec<i64>> {
289    let op = node.op_type.as_str();
290    match op {
291        "Relu" | "Tanh" | "Sigmoid" | "Erf" | "Softmax" | "Gelu" | "Exp" | "Log" | "Abs"
292        | "Neg" | "Sqrt" | "LayerNormalization" => node
293            .input
294            .as_slice()
295            .first()
296            .and_then(|i| ctx.value_shapes.get(i).cloned()),
297        "Add" | "Sub" | "Mul" | "Div" | "Pow" => {
298            if node.input.as_slice().len() < 2 {
299                return None;
300            }
301            let a = node.input.as_slice()[0].as_str();
302            let b = node.input.as_slice()[1].as_str();
303            match (ctx.value_shapes.get(a), ctx.value_shapes.get(b)) {
304                // Prefer smaller shape to avoid inflation
305                // Rationale: Broadcasting happens implicitly; storing inflated shapes
306                // breaks ONNX round-trip conversion
307                (Some(sa), Some(sb)) => {
308                    if sa.len() <= sb.len() {
309                        Some(sa.clone())
310                    } else {
311                        Some(sb.clone())
312                    }
313                }
314                _ => None,
315            }
316        }
317        "MatMul" => {
318            if node.input.as_slice().len() < 2 {
319                return None;
320            }
321            let a_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
322            let b_shape = ctx.value_shapes.get(node.input.as_slice()[1].as_str())?;
323
324            // Attention pattern: rank-4 [B,S,H,D] x [B,S,H,D] -> [B,S,H,H]
325            if a_shape.len() == 4 && b_shape.len() == 4 {
326                return Some(vec![a_shape[0], a_shape[1], a_shape[2], b_shape[3]]);
327            }
328
329            // Fallback generic matmul
330            if a_shape.len() >= 2 && b_shape.len() >= 2 {
331                let m = a_shape[a_shape.len() - 2];
332                let n = b_shape[b_shape.len() - 1];
333                let mut out = Vec::new();
334                if a_shape.len() > 2 {
335                    out.extend_from_slice(&a_shape[..a_shape.len() - 2]);
336                }
337                out.push(m);
338                out.push(n);
339                return Some(out);
340            }
341            None
342        }
343        "Transpose" => {
344            let input = node.input.as_slice().first()?;
345            let shape = ctx.value_shapes.get(input)?;
346            let perm: Vec<usize> = node
347                .attribute
348                .as_slice()
349                .iter()
350                .find(|a| a.name.as_str() == "perm")
351                .map(|a| a.ints.iter().map(|&i| i as usize).collect::<Vec<usize>>())
352                .unwrap_or_else(|| (0..shape.len()).rev().collect());
353            if perm.iter().any(|&i| i >= shape.len()) {
354                return None;
355            }
356            Some(perm.iter().map(|&i| shape[i]).collect())
357        }
358        "Concat" => {
359            let mut shapes = Vec::new();
360            for inp in node.input.as_slice() {
361                if let Some(s) = ctx.value_shapes.get(inp.as_str()) {
362                    shapes.push(s.clone());
363                } else {
364                    return None;
365                }
366            }
367            if shapes.is_empty() {
368                return None;
369            }
370            let mut axis = node
371                .attribute
372                .as_slice()
373                .iter()
374                .find(|a| a.name.as_str() == "axis" && a.i != 0)
375                .map(|a| a.i)
376                .unwrap_or(0);
377            if axis < 0 {
378                axis += shapes[0].len() as i64;
379            }
380            let axis = axis as usize;
381            let mut out = shapes[0].clone();
382            for s in shapes.iter().skip(1) {
383                if s.len() != out.len() || axis >= s.len() {
384                    return None;
385                }
386                out[axis] += s[axis];
387            }
388            Some(out)
389        }
390        "Unsqueeze" => {
391            if node.input.as_slice().is_empty() {
392                return None;
393            }
394            let input_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
395            let mut axes = node
396                .attribute
397                .as_slice()
398                .iter()
399                .find(|a| a.name.as_str() == "axes")
400                .map(|a| a.ints.clone())
401                .unwrap_or_default();
402            // Opset >= 13: axes is a second input tensor, not an attribute
403            if axes.is_empty() && node.input.as_slice().len() > 1 {
404                axes = ctx
405                    .const_values
406                    .get(node.input.as_slice()[1].as_str())
407                    .cloned()
408                    .unwrap_or_default();
409            }
410            if axes.is_empty() {
411                return None;
412            }
413            let mut output_shape = input_shape.clone();
414            let mut sorted_axes = axes.clone();
415            sorted_axes.sort();
416            for axis in sorted_axes {
417                let idx = if axis < 0 {
418                    (output_shape.len() as i64 + axis + 1) as usize
419                } else {
420                    axis as usize
421                };
422                if idx > output_shape.len() {
423                    return None;
424                }
425                output_shape.insert(idx, 1);
426            }
427            Some(output_shape)
428        }
429        "Expand" => {
430            if node.input.as_slice().len() < 2 {
431                return None;
432            }
433            // Primary: use the shape tensor from const_values
434            if let Some(target_shape) = ctx.const_values.get(node.input.as_slice()[1].as_str()) {
435                if !target_shape.is_empty() {
436                    return Some(target_shape.clone());
437                }
438            }
439            // Fallback: use the output shape if already known (e.g. from ONNX value_info)
440            if let Some(out) = node.output.as_slice().first() {
441                if let Some(shape) = ctx.value_shapes.get(out.as_str()) {
442                    if !shape.is_empty() && shape.iter().all(|&d| d > 0) {
443                        return Some(shape.clone());
444                    }
445                }
446            }
447            None
448        }
449        "Squeeze" => {
450            if node.input.as_slice().is_empty() {
451                return None;
452            }
453            let input_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
454            let mut axes = node
455                .attribute
456                .as_slice()
457                .iter()
458                .find(|a| a.name.as_str() == "axes")
459                .map(|a| a.ints.clone())
460                .unwrap_or_default();
461            // Opset >= 13: axes is a second input tensor, not an attribute
462            if axes.is_empty() && node.input.as_slice().len() > 1 {
463                axes = ctx
464                    .const_values
465                    .get(node.input.as_slice()[1].as_str())
466                    .cloned()
467                    .unwrap_or_default();
468            }
469            let mut output_shape = input_shape.clone();
470            if axes.is_empty() {
471                output_shape.retain(|&d| d != 1);
472                return Some(output_shape);
473            }
474            let mut axes_norm: Vec<usize> = axes
475                .iter()
476                .map(|&a| {
477                    if a < 0 {
478                        (input_shape.len() as i64 + a) as usize
479                    } else {
480                        a as usize
481                    }
482                })
483                .collect();
484            axes_norm.sort();
485            axes_norm.dedup();
486            let mut keep = Vec::new();
487            for (idx, dim) in input_shape.iter().enumerate() {
488                if axes_norm.contains(&idx) {
489                    continue;
490                }
491                keep.push(*dim);
492            }
493            Some(keep)
494        }
495        "Reshape" => {
496            if node.input.as_slice().len() < 2 {
497                return None;
498            }
499            let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
500            let shape_input = node.input.as_slice()[1].as_str();
501            let mut target: Vec<i64> = ctx.const_values.get(shape_input)?.clone();
502
503            if target.contains(&-1) {
504                let total_input: i64 = data_shape.iter().product();
505                let known: i64 = target.iter().filter(|&&d| d != -1).product();
506                if known == 0 || total_input % known != 0 {
507                    return None;
508                }
509                if let Some(idx) = target.iter().position(|&d| d == -1) {
510                    target[idx] = total_input / known;
511                }
512            }
513            Some(target)
514        }
515        "Slice" => {
516            if node.input.as_slice().is_empty() {
517                return None;
518            }
519            let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
520            let starts = node
521                .input
522                .as_slice()
523                .get(1)
524                .and_then(|n| ctx.const_values.get(n))
525                .cloned()?;
526            let ends = node
527                .input
528                .as_slice()
529                .get(2)
530                .and_then(|n| ctx.const_values.get(n))
531                .cloned()?;
532            let axes = node
533                .input
534                .as_slice()
535                .get(3)
536                .and_then(|n| ctx.const_values.get(n))
537                .cloned()
538                .unwrap_or_else(|| (0..data_shape.len() as i64).collect());
539            let steps = node
540                .input
541                .as_slice()
542                .get(4)
543                .and_then(|n| ctx.const_values.get(n))
544                .cloned()
545                .unwrap_or_else(|| vec![1; axes.len()]);
546
547            if axes.len() != starts.len() || axes.len() != ends.len() || axes.len() != steps.len() {
548                return None;
549            }
550
551            let mut out = data_shape.clone();
552            for i in 0..axes.len() {
553                let mut axis = axes[i];
554                if axis < 0 {
555                    axis += data_shape.len() as i64;
556                }
557                let axis = axis as usize;
558                if axis >= out.len() {
559                    return None;
560                }
561                if steps[i] != 1 {
562                    return None;
563                }
564                let dim = data_shape[axis];
565                let mut start = starts[i];
566                let mut end = ends[i];
567                if start < 0 {
568                    start += dim;
569                }
570                if end < 0 {
571                    end += dim;
572                }
573                start = start.max(0);
574                end = end.min(dim);
575                out[axis] = if end < start { 0 } else { end - start };
576            }
577            Some(out)
578        }
579        "Gather" => {
580            if node.input.as_slice().len() < 2 {
581                return None;
582            }
583            let data_shape = ctx.value_shapes.get(node.input.as_slice()[0].as_str())?;
584            let indices_shape = ctx.value_shapes.get(node.input.as_slice()[1].as_str())?;
585            let mut axis = node
586                .attribute
587                .as_slice()
588                .iter()
589                .find(|a| a.name.as_str() == "axis" && a.i != 0)
590                .map(|a| a.i)
591                .unwrap_or(0);
592            if axis < 0 {
593                axis += data_shape.len() as i64;
594            }
595            let axis = axis as usize;
596            if axis > data_shape.len() {
597                return None;
598            }
599            let mut out = Vec::new();
600            out.extend_from_slice(&data_shape[..axis]);
601            out.extend(indices_shape.iter().cloned());
602            if axis < data_shape.len() {
603                out.extend_from_slice(&data_shape[axis + 1..]);
604            }
605            Some(out)
606        }
607        "Split" => {
608            let input_shape = node
609                .input
610                .as_slice()
611                .first()
612                .and_then(|i| ctx.value_shapes.get(i))
613                .cloned()?;
614            let mut axis = node
615                .attribute
616                .as_slice()
617                .iter()
618                .find(|a| a.name.as_str() == "axis" && a.i != 0)
619                .map(|a| a.i)
620                .unwrap_or(0);
621            if axis < 0 {
622                axis += input_shape.len() as i64;
623            }
624            let axis = axis as usize;
625            if axis >= input_shape.len() {
626                return None;
627            }
628            let splits = node
629                .attribute
630                .as_slice()
631                .iter()
632                .find(|a| a.name.as_str() == "split")
633                .map(|a| a.ints.clone());
634            if let Some(s) = splits {
635                if s.iter().any(|&v| v <= 0) {
636                    return None;
637                }
638                let sum: i64 = s.iter().sum();
639                if sum != input_shape[axis] {
640                    return None;
641                }
642                let mut out = input_shape.clone();
643                out[axis] = s[0];
644                Some(out)
645            } else {
646                let outputs = node.output.as_slice().len() as i64;
647                if outputs == 0 || input_shape[axis] % outputs != 0 {
648                    return None;
649                }
650                let chunk = input_shape[axis] / outputs;
651                let mut out = input_shape.clone();
652                out[axis] = chunk;
653                Some(out)
654            }
655        }
656        "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" => {
657            let input = node.input.as_slice().first()?;
658            let input_shape = ctx.value_shapes.get(input)?;
659            let axes: Vec<i64> = node
660                .attribute
661                .as_slice()
662                .iter()
663                .find(|a| a.name.as_str() == "axes")
664                .map(|a| a.ints.clone())
665                .unwrap_or_default();
666            let keepdims = node
667                .attribute
668                .as_slice()
669                .iter()
670                .find(|a| a.name.as_str() == "keepdims" && a.i != 0)
671                .map(|a| a.i != 0)
672                .unwrap_or(true);
673            if axes.is_empty() {
674                if keepdims {
675                    Some(vec![1; input_shape.len()])
676                } else {
677                    Some(vec![])
678                }
679            } else {
680                let mut out = input_shape.clone();
681                for axis in axes {
682                    let mut a = axis;
683                    if a < 0 {
684                        a += input_shape.len() as i64;
685                    }
686                    let idx = a as usize;
687                    if idx >= out.len() {
688                        return None;
689                    }
690                    if keepdims {
691                        out[idx] = 1;
692                    } else {
693                        out[idx] = -1;
694                    }
695                }
696                if !keepdims {
697                    out.retain(|&d| d != -1);
698                }
699                Some(out)
700            }
701        }
702        _ => None,
703    }
704}
705
706fn fold_integer_constants(graph: &GraphProto, ctx: &mut InferenceResult) -> bool {
707    let mut changed = false;
708    let mut where_count = 0;
709    for node in graph.node.as_slice() {
710        if node.op_type.as_str() == "Where" {
711            where_count += 1;
712        }
713        let outputs = node.output.as_slice();
714        if outputs.is_empty() {
715            continue;
716        }
717        if ctx.const_values.contains_key(outputs[0].as_str()) {
718            continue;
719        }
720
721        let op = node.op_type.as_str();
722        let inputs = node.input.as_slice();
723
724        // Shape nodes can be folded if the input shape is already known, even when the value is
725        // dynamic. This is critical for turning dynamic shape expressions into static vectors that
726        // downstream ops (Concat/Gather/Expand) can consume.
727        if op == "Shape" {
728            if let Some(inp) = inputs.first() {
729                if let Some(shape) = ctx.value_shapes.get(inp.as_str()) {
730                    let out_name = outputs[0].to_string();
731                    ctx.const_values.insert(out_name.clone(), shape.clone());
732                    ctx.value_shapes.insert(out_name, vec![shape.len() as i64]);
733                    changed = true;
734                    continue;
735                }
736            }
737        }
738
739        let all_const = inputs
740            .iter()
741            .all(|i| ctx.const_values.contains_key(i.as_str()));
742        if !all_const {
743            continue;
744        }
745
746        match op {
747            "Concat" => {
748                let mut axis = 0i64;
749                for attr in node.attribute.as_slice() {
750                    if attr.name.as_str() == "axis" && attr.i != 0 {
751                        axis = attr.i;
752                    }
753                }
754                if axis == 0 {
755                    let mut combined = Vec::new();
756                    for inp in inputs {
757                        if let Some(vals) = ctx.const_values.get(inp.as_str()) {
758                            combined.extend_from_slice(vals);
759                        }
760                    }
761                    if !combined.is_empty() {
762                        let out_name = outputs[0].to_string();
763                        ctx.const_values.insert(out_name.clone(), combined.clone());
764                        ctx.value_shapes
765                            .insert(out_name, vec![combined.len() as i64]);
766                        changed = true;
767                    }
768                }
769            }
770            "Gather" => {
771                let mut axis = 0i64;
772                for attr in node.attribute.as_slice() {
773                    if attr.name.as_str() == "axis" && attr.i != 0 {
774                        axis = attr.i;
775                    }
776                }
777                if axis == 0 && inputs.len() >= 2 {
778                    let data = ctx.const_values.get(inputs[0].as_str());
779                    let indices = ctx.const_values.get(inputs[1].as_str());
780                    if let (Some(data), Some(indices)) = (data, indices) {
781                        let mut gathered = Vec::new();
782                        for &idx in indices {
783                            let i = if idx < 0 {
784                                (data.len() as i64 + idx) as usize
785                            } else {
786                                idx as usize
787                            };
788                            if let Some(v) = data.get(i) {
789                                gathered.push(*v);
790                            }
791                        }
792                        if !gathered.is_empty() {
793                            let out_name = outputs[0].to_string();
794                            ctx.const_values.insert(out_name.clone(), gathered.clone());
795                            let shape = if gathered.len() == 1 {
796                                Vec::new()
797                            } else {
798                                vec![gathered.len() as i64]
799                            };
800                            ctx.value_shapes.insert(out_name, shape);
801                            changed = true;
802                        }
803                    }
804                }
805            }
806            "Unsqueeze" => {
807                if inputs.is_empty() {
808                    continue;
809                }
810                let data = ctx.const_values.get(inputs[0].as_str()).cloned();
811                if data.is_none() {
812                    continue;
813                }
814
815                let mut axes: Vec<i64> = node
816                    .attribute
817                    .as_slice()
818                    .iter()
819                    .find(|a| a.name.as_str() == "axes")
820                    .map(|a| a.ints.clone())
821                    .unwrap_or_default();
822                if axes.is_empty() && inputs.len() > 1 {
823                    axes = ctx
824                        .const_values
825                        .get(inputs[1].as_str())
826                        .cloned()
827                        .unwrap_or_default();
828                }
829                if axes.is_empty() {
830                    continue;
831                }
832
833                let mut sorted_axes = axes.clone();
834                sorted_axes.sort();
835
836                let mut out_shape = ctx
837                    .value_shapes
838                    .get(inputs[0].as_str())
839                    .cloned()
840                    .unwrap_or_else(|| {
841                        let len = data.as_ref().map(|v| v.len()).unwrap_or(0);
842                        if len <= 1 {
843                            Vec::new()
844                        } else {
845                            vec![len as i64]
846                        }
847                    });
848
849                for axis in sorted_axes {
850                    let idx = if axis < 0 {
851                        (out_shape.len() as i64 + axis + 1) as usize
852                    } else {
853                        axis as usize
854                    };
855                    if idx > out_shape.len() {
856                        continue;
857                    }
858                    out_shape.insert(idx, 1);
859                }
860
861                let out_name = outputs[0].to_string();
862                ctx.const_values
863                    .insert(out_name.clone(), data.unwrap_or_default());
864                ctx.value_shapes.insert(out_name, out_shape);
865                changed = true;
866            }
867            "Reshape" => {
868                if inputs.len() < 2 {
869                    continue;
870                }
871                let data = ctx.const_values.get(inputs[0].as_str()).cloned();
872                let shape_target = ctx.const_values.get(inputs[1].as_str()).cloned();
873                if let (Some(data), Some(mut target)) = (data, shape_target) {
874                    // Resolve -1 dimension
875                    if target.contains(&-1) {
876                        let total: i64 = if data.is_empty() {
877                            1
878                        } else {
879                            data.len() as i64
880                        };
881                        let known: i64 = target.iter().filter(|&&d| d != -1).product();
882                        if known != 0 {
883                            if let Some(idx) = target.iter().position(|&d| d == -1) {
884                                target[idx] = total / known;
885                            }
886                        }
887                    }
888                    let out_name = outputs[0].to_string();
889                    let out_shape = target.clone();
890                    ctx.const_values.insert(out_name.clone(), data);
891                    ctx.value_shapes.insert(out_name, out_shape);
892                    changed = true;
893                }
894            }
895            "ConstantOfShape" => {
896                // ConstantOfShape takes a 1D shape tensor and produces a tensor
897                // filled with a constant value (default 0.0f, or from 'value' attr)
898                if inputs.is_empty() {
899                    continue;
900                }
901                if let Some(shape_vals) = ctx.const_values.get(inputs[0].as_str()).cloned() {
902                    // Get the fill value from the 'value' attribute (default 0)
903                    let fill_value: i64 = node
904                        .attribute
905                        .as_slice()
906                        .iter()
907                        .find(|a| a.name.as_str() == "value")
908                        .and_then(|a| {
909                            let t = a.t.as_ref()?;
910                            if !t.raw_data.as_slice().is_empty() {
911                                // Try int64 first, then float
912                                if t.data_type == 7 && t.raw_data.as_slice().len() >= 8 {
913                                    let bytes = t.raw_data.as_slice();
914                                    Some(i64::from_le_bytes([
915                                        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5],
916                                        bytes[6], bytes[7],
917                                    ]))
918                                } else if t.data_type == 1 && t.raw_data.as_slice().len() >= 4 {
919                                    let bytes = t.raw_data.as_slice();
920                                    Some(f32::from_le_bytes([
921                                        bytes[0], bytes[1], bytes[2], bytes[3],
922                                    ]) as i64)
923                                } else {
924                                    Some(0)
925                                }
926                            } else if !t.int64_data.as_slice().is_empty() {
927                                Some(t.int64_data.as_slice()[0])
928                            } else if !t.float_data.as_slice().is_empty() {
929                                Some(t.float_data.as_slice()[0] as i64)
930                            } else {
931                                Some(0)
932                            }
933                        })
934                        .unwrap_or(0);
935
936                    let total: usize = shape_vals.iter().map(|&d| d.max(0) as usize).product();
937                    let data = vec![fill_value; total];
938                    let out_name = outputs[0].to_string();
939                    ctx.const_values.insert(out_name.clone(), data);
940                    ctx.value_shapes.insert(out_name, shape_vals);
941                    changed = true;
942                }
943            }
944            "Mul" => {
945                if inputs.len() < 2 {
946                    continue;
947                }
948                let lhs = ctx.const_values.get(inputs[0].as_str()).cloned();
949                let rhs = ctx.const_values.get(inputs[1].as_str()).cloned();
950                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
951                    // Handle scalar * tensor broadcasting
952                    let values: Vec<i64> = if lhs.len() == 1 && rhs.len() > 1 {
953                        rhs.iter().map(|&r| lhs[0] * r).collect()
954                    } else if rhs.len() == 1 && lhs.len() > 1 {
955                        lhs.iter().map(|&l| l * rhs[0]).collect()
956                    } else if lhs.len() == rhs.len() {
957                        lhs.iter().zip(rhs.iter()).map(|(&l, &r)| l * r).collect()
958                    } else {
959                        continue;
960                    };
961                    let out_name = outputs[0].to_string();
962                    let shape = if values.len() == 1 {
963                        Vec::new()
964                    } else {
965                        vec![values.len() as i64]
966                    };
967                    ctx.const_values.insert(out_name.clone(), values);
968                    ctx.value_shapes.insert(out_name, shape);
969                    changed = true;
970                }
971            }
972            "Equal" => {
973                if inputs.len() < 2 {
974                    continue;
975                }
976                let lhs = ctx.const_values.get(inputs[0].as_str()).cloned();
977                let rhs = ctx.const_values.get(inputs[1].as_str()).cloned();
978                if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
979                    if lhs.len() != rhs.len() {
980                        continue;
981                    }
982                    let values: Vec<i64> = lhs
983                        .iter()
984                        .zip(rhs.iter())
985                        .map(|(l, r)| if l == r { 1 } else { 0 })
986                        .collect();
987                    let out_name = outputs[0].to_string();
988                    let shape = if values.len() == 1 {
989                        Vec::new()
990                    } else {
991                        vec![values.len() as i64]
992                    };
993                    ctx.const_values.insert(out_name.clone(), values);
994                    ctx.value_shapes.insert(out_name, shape);
995                    changed = true;
996                }
997            }
998            "Where" => {
999                if inputs.len() < 3 {
1000                    continue;
1001                }
1002
1003                // Debug: always log Where operations that involve rotary
1004                if inputs.iter().any(|i| i.contains("rotary")) {
1005                    crate::debug_println!("[WHERE DEBUG] Processing Where node");
1006                    crate::debug_println!("  inputs: {:?}", inputs);
1007                    crate::debug_println!("  outputs: {:?}", outputs);
1008                }
1009
1010                let cond = ctx.const_values.get(inputs[0].as_str()).cloned();
1011                let a = ctx.const_values.get(inputs[1].as_str()).cloned();
1012                let b = ctx.const_values.get(inputs[2].as_str()).cloned();
1013
1014                if inputs.iter().any(|i| i.contains("rotary")) {
1015                    crate::debug_println!("  cond const: {}", cond.is_some());
1016                    crate::debug_println!("  a const: {}", a.is_some());
1017                    crate::debug_println!("  b const: {}", b.is_some());
1018                }
1019
1020                // Case 1: All inputs are constant - evaluate fully
1021                if let (Some(cond), Some(a), Some(b)) = (cond, a, b) {
1022                    if cond.len() != a.len() || a.len() != b.len() {
1023                        continue;
1024                    }
1025
1026                    // HEURISTIC: If one branch is trivial (all 1s, ≤3 elements) and the other is not,
1027                    // prefer the non-trivial one regardless of condition value.
1028                    // This handles rotary embedding patterns where Where(cond, [1,1,1], [1,32,1])
1029                    // should prefer [1,32,1] even if cond evaluates to select the first branch.
1030                    let is_trivial =
1031                        |vals: &[i64]| -> bool { vals.iter().all(|&v| v == 1) && vals.len() <= 3 };
1032
1033                    let mut out = if is_trivial(&a) && !is_trivial(&b) {
1034                        if inputs.iter().any(|i| i.contains("rotary")) {
1035                            crate::debug_println!("[WHERE SMART EVAL] Preferring non-trivial branch b={:?} over trivial a={:?}", b, a);
1036                        }
1037                        b
1038                    } else if is_trivial(&b) && !is_trivial(&a) {
1039                        if inputs.iter().any(|i| i.contains("rotary")) {
1040                            crate::debug_println!("[WHERE SMART EVAL] Preferring non-trivial branch a={:?} over trivial b={:?}", a, b);
1041                        }
1042                        a
1043                    } else {
1044                        // Normal element-wise evaluation
1045                        let mut result = Vec::with_capacity(a.len());
1046                        for i in 0..a.len() {
1047                            result.push(if cond[i] != 0 { a[i] } else { b[i] });
1048                        }
1049                        result
1050                    };
1051
1052                    // HEURISTIC: If the output contains -1 (reshape placeholder), try to resolve it
1053                    // For rotary embedding patterns, check if this feeds into an Expand operation
1054                    if out.contains(&-1) && !outputs.is_empty() {
1055                        let output_name = outputs[0].as_str();
1056                        // Look for Expand nodes that use this Where output as their shape input
1057                        for node in graph.node.as_slice() {
1058                            if node.op_type.as_str() == "Expand"
1059                                && node.input.len() >= 2
1060                                && node.input[1].as_str() == output_name
1061                            {
1062                                // Found the Expand - check its data input shape
1063                                let data_input = node.input[0].as_str();
1064                                if let Some(data_shape) = ctx.value_shapes.get(data_input) {
1065                                    // Resolve -1 based on data shape
1066                                    if out.len() == data_shape.len() {
1067                                        for i in 0..out.len() {
1068                                            if out[i] == -1 {
1069                                                out[i] = data_shape[i];
1070                                                if inputs.iter().any(|inp| inp.contains("rotary")) {
1071                                                    crate::debug_println!("[WHERE RESOLVE] Resolved -1 at position {} to {} from data shape {:?}", i, data_shape[i], data_shape);
1072                                                }
1073                                            }
1074                                        }
1075                                    }
1076                                }
1077                            }
1078                        }
1079                    }
1080
1081                    let out_name = outputs[0].to_string();
1082                    let shape = if out.len() == 1 {
1083                        Vec::new()
1084                    } else {
1085                        vec![out.len() as i64]
1086                    };
1087                    if inputs.iter().any(|i| i.contains("rotary")) {
1088                        crate::debug_println!("[WHERE STORE] Storing {} = {:?}", out_name, out);
1089                    }
1090                    ctx.const_values.insert(out_name.clone(), out);
1091                    ctx.value_shapes.insert(out_name, shape);
1092                    changed = true;
1093                } else {
1094                    // Case 2: Some inputs are dynamic - use shape inference heuristics
1095                    // This handles the common pattern: Where(dynamic_condition, trivial_constant, dynamic_value)
1096                    // Prefer the more specific/larger shape over trivial shapes like [1,1,1]
1097
1098                    let a_const = ctx.const_values.get(inputs[1].as_str());
1099                    let b_const = ctx.const_values.get(inputs[2].as_str());
1100                    let a_shape = ctx.value_shapes.get(inputs[1].as_str());
1101                    let b_shape = ctx.value_shapes.get(inputs[2].as_str());
1102
1103                    // Heuristic: If one branch is a trivial constant (all 1s) and the other has shape info, use the other
1104                    let is_trivial_constant =
1105                        |vals: &[i64]| -> bool { vals.iter().all(|&v| v == 1) && vals.len() <= 3 };
1106
1107                    let preferred_values = if let (Some(a_vals), None) = (a_const, b_const) {
1108                        // 'a' is constant, 'b' is dynamic
1109                        if is_trivial_constant(a_vals) && b_shape.is_some() {
1110                            // Prefer dynamic 'b' over trivial constant 'a'
1111                            // Use the shape of 'b' as the constant values for the Where output
1112                            crate::debug_println!("[WHERE HEURISTIC] Preferring dynamic input {} (shape {:?}) over trivial constant {:?}", inputs[2], b_shape, a_vals);
1113                            b_shape.cloned()
1114                        } else {
1115                            Some(a_vals.clone())
1116                        }
1117                    } else if let (None, Some(b_vals)) = (a_const, b_const) {
1118                        // 'b' is constant, 'a' is dynamic
1119                        if is_trivial_constant(b_vals) && a_shape.is_some() {
1120                            // Prefer dynamic 'a' over trivial constant 'b'
1121                            // Use the shape of 'a' as the constant values for the Where output
1122                            crate::debug_println!("[WHERE HEURISTIC] Preferring dynamic input {} (shape {:?}) over trivial constant {:?}", inputs[1], a_shape, b_vals);
1123                            a_shape.cloned()
1124                        } else {
1125                            Some(b_vals.clone())
1126                        }
1127                    } else {
1128                        None
1129                    };
1130
1131                    // Set both const_values and value_shapes for the output
1132                    if let Some(values) = preferred_values {
1133                        let out_name = outputs[0].to_string();
1134                        let shape = if values.len() == 1 {
1135                            Vec::new()
1136                        } else {
1137                            vec![values.len() as i64]
1138                        };
1139                        ctx.const_values.insert(out_name.clone(), values);
1140                        ctx.value_shapes.insert(out_name, shape);
1141                        changed = true;
1142                    }
1143                }
1144            }
1145            _ => {}
1146        }
1147    }
1148    if where_count > 0 {
1149        crate::debug_println!(
1150            "[FOLD DEBUG] Processed {} Where nodes, changed={}",
1151            where_count,
1152            changed
1153        );
1154    }
1155    changed
1156}
1157
1158fn read_int_tensor(tensor: &TensorProto) -> Vec<i64> {
1159    let raw = tensor.raw_data.as_slice();
1160    if !raw.is_empty() {
1161        match tensor.data_type {
1162            x if x == TensorProto_DataType::Int32 as i32 => raw
1163                .chunks_exact(4)
1164                .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1165                .collect(),
1166            _ => raw
1167                .chunks_exact(8)
1168                .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
1169                .collect(),
1170        }
1171    } else if !tensor.int64_data.as_slice().is_empty() {
1172        tensor.int64_data.as_slice().to_vec()
1173    } else if !tensor.int32_data.as_slice().is_empty() {
1174        tensor
1175            .int32_data
1176            .as_slice()
1177            .iter()
1178            .map(|&v| v as i64)
1179            .collect()
1180    } else {
1181        Vec::new()
1182    }
1183}
1184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protos::onnx::{tensor_shape_proto, type_proto};

    /// Builds a model whose only graph input, `input`, is a FLOAT tensor of
    /// shape `[batch]`, where `batch` is a symbolic (`dim_param`) dimension.
    fn model_with_symbolic_batch_input() -> crate::protos::onnx::ModelProto {
        let dim = tensor_shape_proto::Dimension {
            value: Some(tensor_shape_proto::dimension::Value::DimParam(
                "batch".to_string(),
            )),
            denotation: String::new(),
        };

        let tensor_type = type_proto::Tensor {
            elem_type: crate::protos::onnx::TensorProto_DataType::Float.into(),
            shape: Some(crate::protos::onnx::TensorShapeProto { dim: vec![dim] }),
        };

        let input = crate::protos::onnx::ValueInfoProto {
            name: "input".to_string(),
            r#type: Some(crate::protos::onnx::TypeProto {
                value: Some(type_proto::Value::TensorType(tensor_type)),
                denotation: String::new(),
            }),
            ..Default::default()
        };

        let graph = crate::protos::onnx::GraphProto {
            input: vec![input],
            ..Default::default()
        };

        crate::protos::onnx::ModelProto {
            graph: Some(graph),
            ..Default::default()
        }
    }

    #[test]
    fn dynamic_dim_requires_override() {
        let model = model_with_symbolic_batch_input();

        let res = infer_static_shapes(&model, &HashMap::new());
        assert!(matches!(
            res,
            Err(ShapeInferenceError::DynamicDim { dim, .. }) if dim == "batch"
        ));
    }

    #[test]
    fn override_allows_static_shape() {
        let model = model_with_symbolic_batch_input();

        let mut overrides = HashMap::new();
        overrides.insert("batch".to_string(), 1);
        let res = infer_static_shapes(&model, &overrides).unwrap();
        assert_eq!(res.value_shapes.get("input"), Some(&vec![1]));
    }

    #[test]
    fn read_int_tensor_decodes_int32_raw_data() {
        // Two little-endian INT32 values: 3 and -1.
        let tensor = TensorProto {
            data_type: crate::protos::onnx::TensorProto_DataType::Int32 as i32,
            raw_data: vec![3, 0, 0, 0, 0xff, 0xff, 0xff, 0xff],
            ..Default::default()
        };
        assert_eq!(read_int_tensor(&tensor), vec![3, -1]);
    }
}