Skip to main content

webnn_graph/onnx/
convert.rs

1// Main ONNX to WebNN conversion logic
2
3use crate::ast::{DataType, Dimension, DynamicDimension, GraphJson};
4use crate::protos::onnx::{
5    tensor_shape_proto::dimension::Value as DimensionValue, type_proto::Value as TypeProtoValue,
6    ModelProto, TensorProto, TensorProto_DataType,
7};
8use prost::Message;
9use serde_json::Value as JsonValue;
10use std::collections::{BTreeMap, HashMap, HashSet};
11use std::fs;
12use std::path::Path;
13use thiserror::Error;
14use webnn_onnx_utils::{data_types as utils_data_types, identifiers};
15
/// Lowest ONNX opset version accepted by the converter.
const MIN_SUPPORTED_OPSET: i64 = 11;
/// Highest ONNX opset version accepted by the converter.
const MAX_SUPPORTED_OPSET: i64 = 18;
18
/// Errors reported by the ONNX -> WebNN conversion code in this module.
#[derive(Debug, Error)]
pub enum OnnxError {
    /// An I/O failure, e.g. while reading the ONNX file. Converted from
    /// `std::io::Error` via `#[from]`, so `?` works directly on I/O results.
    #[error("failed to read ONNX file: {0}")]
    IoError(#[from] std::io::Error),

    /// The ONNX protobuf could not be decoded; the payload describes the failure.
    #[error("failed to parse ONNX protobuf: {0}")]
    ProtobufError(String),

    /// The model imports opset `version` for `domain`, which is not supported.
    /// NOTE(review): presumably checked against MIN_SUPPORTED_OPSET..=MAX_SUPPORTED_OPSET
    /// at the call site — confirm.
    #[error("unsupported ONNX opset version {version} for domain '{domain}'")]
    UnsupportedOpset { domain: String, version: i64 },

    /// Node `node` uses operator type `op`, which the converter does not handle.
    #[error("unsupported operator: {op} (node: {node})")]
    UnsupportedOp { op: String, node: String },

    /// An `op` node lacks the required attribute `attr`.
    #[error("missing required attribute: {attr} in {op}")]
    MissingAttribute { attr: String, op: String },

    /// A tensor shape could not be interpreted; the payload describes why.
    #[error("invalid tensor shape: {0}")]
    InvalidShape(String),

    /// An element type could not be mapped between ONNX and WebNN. Converted from
    /// the shared utilities' `ConversionError` via `#[from]`.
    #[error("type conversion error: {0}")]
    TypeConversion(#[from] webnn_onnx_utils::error::ConversionError),

    /// Shape inference failed; the payload identifies the node.
    #[error("shape inference failed for node: {0}")]
    ShapeInference(String),
}
45
/// Sanitize ONNX identifiers for WebNN DSL compatibility.
///
/// Replaces problematic characters that would confuse the parser. This is a thin
/// wrapper over `webnn_onnx_utils::identifiers::sanitize_for_webnn`; the exact
/// character mapping is defined there, not in this module.
pub fn sanitize_identifier(name: &str) -> String {
    identifiers::sanitize_for_webnn(name)
}
51
52/// Convert ONNX data type code to WebNN DataType using shared utilities
53pub(crate) fn map_onnx_data_type(onnx_type: i32) -> Result<DataType, OnnxError> {
54    if onnx_type == TensorProto_DataType::Bool as i32 {
55        return Ok(DataType::Uint8);
56    }
57
58    let utils_dtype = utils_data_types::onnx_to_webnn(onnx_type)?;
59    Ok(match utils_dtype {
60        utils_data_types::DataType::Float32 => DataType::Float32,
61        utils_data_types::DataType::Float16 => DataType::Float16,
62        utils_data_types::DataType::Int32 => DataType::Int32,
63        utils_data_types::DataType::Uint32 => DataType::Uint32,
64        utils_data_types::DataType::Int64 => DataType::Int64,
65        utils_data_types::DataType::Uint64 => DataType::Uint64,
66        utils_data_types::DataType::Int8 => DataType::Int8,
67        utils_data_types::DataType::Uint8 => DataType::Uint8,
68    })
69}
70
71/// Infer output shape for an ONNX node based on its operation type and inputs
72fn infer_shape(
73    node: &crate::protos::onnx::NodeProto,
74    value_shapes: &HashMap<String, Vec<i64>>,
75    initializers: &HashMap<String, &TensorProto>,
76    const_values: &HashMap<String, Vec<i64>>,
77) -> Option<Vec<i64>> {
78    let op = node.op_type.as_str();
79
80    match op {
81        // Unary operations that preserve shape
82        "Cast" | "Relu" | "Tanh" | "Sigmoid" | "Erf" | "Softmax" | "Gelu" | "Exp" | "Log"
83        | "Abs" | "Neg" | "Sqrt" | "LayerNormalization" | "Trilu" => {
84            let ins = node.input.as_slice();
85            if ins.is_empty() {
86                return None;
87            }
88            value_shapes.get(ins[0].as_str()).cloned()
89        }
90
91        // Binary operations with NumPy-style broadcasting semantics.
92        "Add" | "Sub" | "Mul" | "Div" | "Pow" => {
93            let ins = node.input.as_slice();
94            if ins.len() < 2 {
95                return None;
96            }
97
98            let shape_a = value_shapes.get(ins[0].as_str());
99            let shape_b = value_shapes.get(ins[1].as_str());
100
101            match (shape_a, shape_b) {
102                (Some(a), Some(b)) => {
103                    let rank = a.len().max(b.len());
104                    let mut out_rev = Vec::with_capacity(rank);
105                    for i in 0..rank {
106                        let da = a.get(a.len().wrapping_sub(1 + i)).copied().unwrap_or(1);
107                        let db = b.get(b.len().wrapping_sub(1 + i)).copied().unwrap_or(1);
108                        if da == db || da == 1 {
109                            out_rev.push(db);
110                        } else if db == 1 {
111                            out_rev.push(da);
112                        } else {
113                            return None;
114                        }
115                    }
116                    out_rev.reverse();
117                    Some(out_rev)
118                }
119                (Some(a), None) => Some(a.clone()),
120                (None, Some(b)) => Some(b.clone()),
121                (None, None) => None,
122            }
123        }
124
125        // MatMul (2D matrix multiplication)
126        "MatMul" => {
127            let ins = node.input.as_slice();
128            if ins.len() < 2 {
129                return None;
130            }
131
132            let a_shape = value_shapes.get(ins[0].as_str())?;
133            let b_shape = value_shapes.get(ins[1].as_str())?;
134
135            // Handle 2D case: [M, K] @ [K, N] -> [M, N]
136            if a_shape.len() >= 2 && b_shape.len() >= 2 {
137                let m = a_shape[a_shape.len() - 2];
138                let n = b_shape[b_shape.len() - 1];
139
140                // For higher-dim inputs, preserve batch dimensions
141                if a_shape.len() == 2 && b_shape.len() == 2 {
142                    return Some(vec![m, n]);
143                } else if a_shape.len() > 2 {
144                    let mut result = a_shape[..a_shape.len() - 2].to_vec();
145                    result.push(m);
146                    result.push(n);
147                    return Some(result);
148                }
149            }
150            None
151        }
152
153        // Transpose preserves shape with permuted dimensions
154        "Transpose" => {
155            let ins = node.input.as_slice();
156            if ins.is_empty() {
157                return None;
158            }
159            let input_shape = value_shapes.get(ins[0].as_str())?;
160
161            // Get perm attribute
162            let perm: Vec<usize> = node
163                .attribute
164                .as_slice()
165                .iter()
166                .find(|a| a.name.as_str() == "perm")
167                .map(|a| a.ints.iter().map(|&i| i as usize).collect::<Vec<usize>>())
168                .unwrap_or_else(|| (0..input_shape.len()).rev().collect());
169
170            // Apply permutation
171            Some(perm.iter().map(|&i| input_shape[i]).collect())
172        }
173
174        // Reduce operations
175        "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" => {
176            let ins = node.input.as_slice();
177            if ins.is_empty() {
178                return None;
179            }
180            let input_shape = value_shapes.get(ins[0].as_str())?;
181
182            // Check keepdims attribute (default is 1/true)
183            let keepdims = node
184                .attribute
185                .as_slice()
186                .iter()
187                .find(|a| a.name.as_str() == "keepdims")
188                .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
189                .unwrap_or(true);
190
191            // Get axes attribute
192            let axes: Vec<i64> = node
193                .attribute
194                .as_slice()
195                .iter()
196                .find(|a| a.name.as_str() == "axes")
197                .map(|a| a.ints.clone())
198                .unwrap_or_default();
199
200            if axes.is_empty() {
201                // Reduce all dimensions
202                if keepdims {
203                    Some(vec![1; input_shape.len()])
204                } else {
205                    Some(vec![])
206                }
207            } else {
208                // Reduce specific axes
209                let mut output_shape = input_shape.clone();
210                for &axis in &axes {
211                    let idx = if axis < 0 {
212                        (input_shape.len() as i64 + axis) as usize
213                    } else {
214                        axis as usize
215                    };
216                    if idx < output_shape.len() {
217                        if keepdims {
218                            output_shape[idx] = 1;
219                        } else {
220                            output_shape[idx] = -1; // Mark for removal
221                        }
222                    }
223                }
224                if !keepdims {
225                    output_shape.retain(|&d| d != -1);
226                }
227                Some(output_shape)
228            }
229        }
230
231        // Gemm (generalized matrix multiplication)
232        "Gemm" => {
233            let ins = node.input.as_slice();
234            if ins.len() < 2 {
235                return None;
236            }
237
238            let a_shape = value_shapes.get(ins[0].as_str())?;
239            let b_shape = value_shapes.get(ins[1].as_str())?;
240
241            if a_shape.len() != 2 || b_shape.len() != 2 {
242                return None;
243            }
244
245            // Check transA and transB attributes
246            let trans_a = node
247                .attribute
248                .as_slice()
249                .iter()
250                .find(|a| a.name.as_str() == "transA")
251                .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
252                .unwrap_or(false);
253
254            let trans_b = node
255                .attribute
256                .as_slice()
257                .iter()
258                .find(|a| a.name.as_str() == "transB")
259                .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
260                .unwrap_or(false);
261
262            let m = if trans_a { a_shape[1] } else { a_shape[0] };
263            let n = if trans_b { b_shape[0] } else { b_shape[1] };
264
265            Some(vec![m, n])
266        }
267
268        "Gather" => {
269            let ins = node.input.as_slice();
270            if ins.len() < 2 {
271                return None;
272            }
273
274            let data_shape = value_shapes.get(ins[0].as_str())?;
275            let indices_shape = value_shapes.get(ins[1].as_str())?;
276
277            let mut axis = node
278                .attribute
279                .as_slice()
280                .iter()
281                .find(|a| a.name.as_str() == "axis")
282                .and_then(|a| if a.i != 0 { Some(a.i) } else { None })
283                .unwrap_or(0);
284
285            if axis < 0 {
286                axis += data_shape.len() as i64;
287            }
288
289            let axis_usize = axis as usize;
290            if axis_usize > data_shape.len() {
291                return None;
292            }
293
294            let mut output = Vec::new();
295            output.extend_from_slice(&data_shape[..axis_usize]);
296            output.extend(indices_shape.iter().cloned());
297            if axis_usize < data_shape.len() {
298                output.extend_from_slice(&data_shape[axis_usize + 1..]);
299            }
300            Some(output)
301        }
302
303        "Unsqueeze" => {
304            let ins = node.input.as_slice();
305            if ins.is_empty() {
306                return None;
307            }
308
309            let input_shape = value_shapes.get(ins[0].as_str())?.clone();
310            let mut axes: Vec<i64> = node
311                .attribute
312                .as_slice()
313                .iter()
314                .find(|a| a.name.as_str() == "axes")
315                .map(|a| a.ints.clone())
316                .unwrap_or_default();
317
318            if axes.is_empty() {
319                return None;
320            }
321
322            axes.sort();
323            let mut output_shape = input_shape;
324            for axis in axes {
325                let idx = if axis < 0 {
326                    (output_shape.len() as i64 + axis + 1) as usize
327                } else {
328                    axis as usize
329                };
330                if idx <= output_shape.len() {
331                    output_shape.insert(idx, 1);
332                }
333            }
334            Some(output_shape)
335        }
336
337        "Concat" => {
338            let mut shapes = Vec::new();
339            for inp in node.input.as_slice() {
340                let shape = value_shapes.get(inp.as_str())?;
341                shapes.push(shape.clone());
342            }
343
344            if shapes.is_empty() {
345                return None;
346            }
347
348            let mut axis = node
349                .attribute
350                .as_slice()
351                .iter()
352                .find(|a| a.name.as_str() == "axis")
353                .and_then(|a| if a.i != 0 { Some(a.i) } else { None })
354                .unwrap_or(0);
355
356            if axis < 0 {
357                axis += shapes[0].len() as i64;
358            }
359            let axis_usize = axis as usize;
360
361            let mut output = shapes[0].clone();
362            for shape in shapes.iter().skip(1) {
363                if shape.len() != output.len() || axis_usize >= shape.len() {
364                    return None;
365                }
366                output[axis_usize] += shape[axis_usize];
367            }
368            Some(output)
369        }
370
371        "Reshape" => {
372            let ins = node.input.as_slice();
373            if ins.len() < 2 {
374                return None;
375            }
376
377            let input_shape = value_shapes.get(ins[0].as_str())?;
378            let shape_input = ins[1].as_str();
379            let mut target: Vec<i64> = if let Some(values) = const_values.get(shape_input) {
380                values.clone()
381            } else if let Some(shape_tensor) = initializers.get(shape_input) {
382                if !shape_tensor.raw_data.as_slice().is_empty() {
383                    if shape_tensor.data_type == TensorProto_DataType::Int32 as i32 {
384                        shape_tensor
385                            .raw_data
386                            .as_slice()
387                            .chunks_exact(4)
388                            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
389                            .collect()
390                    } else {
391                        shape_tensor
392                            .raw_data
393                            .as_slice()
394                            .chunks_exact(8)
395                            .map(|c| {
396                                i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
397                            })
398                            .collect()
399                    }
400                } else if !shape_tensor.int64_data.as_slice().is_empty() {
401                    shape_tensor.int64_data.as_slice().to_vec()
402                } else if !shape_tensor.int32_data.as_slice().is_empty() {
403                    shape_tensor
404                        .int32_data
405                        .as_slice()
406                        .iter()
407                        .map(|&v| v as i64)
408                        .collect()
409                } else {
410                    Vec::new()
411                }
412            } else {
413                Vec::new()
414            };
415
416            if target.is_empty() {
417                return None;
418            }
419
420            if target.contains(&-1) {
421                let total_input: i64 = input_shape.iter().product();
422                let known: i64 = target.iter().filter(|&&d| d != -1).product();
423                if known == 0 || total_input % known != 0 {
424                    return None;
425                }
426                if let Some(idx) = target.iter().position(|&d| d == -1) {
427                    target[idx] = total_input / known;
428                }
429            }
430
431            Some(target)
432        }
433
434        "Slice" => {
435            let ins = node.input.as_slice();
436            if ins.is_empty() {
437                return None;
438            }
439
440            let input_shape = value_shapes.get(ins[0].as_str())?;
441
442            let read_ints = |name: Option<&String>| -> Option<Vec<i64>> {
443                if let Some(n) = name {
444                    if let Some(v) = const_values.get(n) {
445                        return Some(v.clone());
446                    }
447                    if let Some(t) = initializers.get(n) {
448                        let raw = t.raw_data.as_slice();
449                        if !raw.is_empty() {
450                            if t.data_type == TensorProto_DataType::Int32 as i32 {
451                                return Some(
452                                    raw.chunks_exact(4)
453                                        .map(|c| {
454                                            i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64
455                                        })
456                                        .collect(),
457                                );
458                            } else {
459                                return Some(
460                                    raw.chunks_exact(8)
461                                        .map(|c| {
462                                            i64::from_le_bytes([
463                                                c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7],
464                                            ])
465                                        })
466                                        .collect(),
467                                );
468                            }
469                        } else if !t.int64_data.as_slice().is_empty() {
470                            return Some(t.int64_data.as_slice().to_vec());
471                        } else if !t.int32_data.as_slice().is_empty() {
472                            return Some(
473                                t.int32_data.as_slice().iter().map(|&v| v as i64).collect(),
474                            );
475                        }
476                    }
477                }
478                None
479            };
480
481            let starts = read_ints(ins.get(1))?;
482            let ends = read_ints(ins.get(2))?;
483            let axes =
484                read_ints(ins.get(3)).unwrap_or_else(|| (0..input_shape.len() as i64).collect());
485            let steps = read_ints(ins.get(4)).unwrap_or_else(|| vec![1; axes.len()]);
486
487            if axes.len() != starts.len() || axes.len() != ends.len() || axes.len() != steps.len() {
488                return None;
489            }
490
491            let mut output = input_shape.clone();
492            for i in 0..axes.len() {
493                let axis = if axes[i] < 0 {
494                    (input_shape.len() as i64 + axes[i]) as usize
495                } else {
496                    axes[i] as usize
497                };
498                if axis >= output.len() {
499                    return None;
500                }
501
502                let step = steps[i];
503                if step != 1 {
504                    return None;
505                }
506
507                let dim = input_shape[axis];
508                let mut start = starts[i];
509                let mut end = ends[i];
510
511                if start < 0 {
512                    start += dim;
513                }
514                if end < 0 {
515                    end += dim;
516                }
517
518                start = start.max(0);
519                end = end.min(dim);
520
521                if end < start {
522                    output[axis] = 0;
523                } else {
524                    output[axis] = end - start;
525                }
526            }
527
528            Some(output)
529        }
530
531        _ => None,
532    }
533}
534
/// Number of elements in a tensor of the given `shape` (1 for a scalar `[]`).
///
/// Returns `None` if any dimension is negative (unknown) or if the product
/// overflows `usize`. Overflow is reported rather than saturated, so callers
/// never see a bogus `usize::MAX` count (e.g. as a `Vec::with_capacity` hint).
fn shape_numel(shape: &[i64]) -> Option<usize> {
    shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(usize::try_from(d).ok()?))
}
543
544fn const_shape_for_folding(
545    name: &str,
546    values: &[i64],
547    value_shapes: &HashMap<String, Vec<i64>>,
548) -> Vec<i64> {
549    if let Some(shape) = value_shapes.get(name) {
550        if shape_numel(shape) == Some(values.len()) {
551            return shape.clone();
552        }
553    }
554
555    if values.len() == 1 {
556        Vec::new()
557    } else {
558        vec![values.len() as i64]
559    }
560}
561
/// NumPy-style broadcast of two fully-known shapes.
///
/// Shapes are aligned from the right, and the shorter one is left-padded with 1s.
/// Returns `None` if any dimension is non-positive or a pair of dimensions is
/// neither equal nor contains a 1.
fn broadcast_shape(shape_a: &[i64], shape_b: &[i64]) -> Option<Vec<i64>> {
    let rank = shape_a.len().max(shape_b.len());
    // Dimension `axis` of `shape` after left-padding it to `rank` with 1s.
    let padded = |shape: &[i64], axis: usize| -> i64 {
        let pad = rank - shape.len();
        if axis < pad {
            1
        } else {
            shape[axis - pad]
        }
    };

    let mut result = Vec::with_capacity(rank);
    for axis in 0..rank {
        let lhs = padded(shape_a, axis);
        let rhs = padded(shape_b, axis);
        if lhs <= 0 || rhs <= 0 {
            return None;
        }
        let dim = if lhs == rhs || lhs == 1 {
            rhs
        } else if rhs == 1 {
            lhs
        } else {
            return None;
        };
        result.push(dim);
    }
    Some(result)
}
588
/// Map a row-major linear index into the broadcast output to the linear index of
/// the corresponding element in an operand of shape `in_shape`.
///
/// A scalar operand (`[]`) always maps to 0. Returns `None` if the operand has
/// more dimensions than the output, or if any involved dimension is zero or
/// negative. Index arithmetic saturates instead of overflowing.
fn linear_index_for_broadcast_operand(
    out_linear_idx: usize,
    out_shape: &[i64],
    in_shape: &[i64],
) -> Option<usize> {
    if in_shape.is_empty() {
        return Some(0);
    }
    // Number of leading output axes that the operand does not have.
    let lead = out_shape.len().checked_sub(in_shape.len())?;

    let mut remaining = out_linear_idx;
    let mut index = 0usize;
    let mut stride = 1usize;
    // Walk output axes from innermost to outermost, peeling off coordinates.
    for (axis, &raw_out) in out_shape.iter().enumerate().rev() {
        let out_dim = usize::try_from(raw_out).ok().filter(|&d| d != 0)?;
        let coord = remaining % out_dim;
        remaining /= out_dim;

        if let Some(in_axis) = axis.checked_sub(lead) {
            let in_dim = usize::try_from(in_shape[in_axis]).ok().filter(|&d| d != 0)?;
            // A size-1 operand axis is broadcast: it always reads coordinate 0.
            let in_coord = if in_dim == 1 { 0 } else { coord };
            index = index.saturating_add(in_coord.saturating_mul(stride));
            stride = stride.saturating_mul(in_dim);
        }
    }
    Some(index)
}
631
632fn fold_binary_const_i64(
633    op_type: &str,
634    a_values: &[i64],
635    b_values: &[i64],
636    a_shape: &[i64],
637    b_shape: &[i64],
638) -> Option<(Vec<i64>, Vec<i64>)> {
639    let out_shape = broadcast_shape(a_shape, b_shape)?;
640    let out_numel = shape_numel(&out_shape)?;
641
642    let mut out_values = Vec::with_capacity(out_numel);
643    for out_idx in 0..out_numel {
644        let a_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, a_shape)?;
645        let b_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, b_shape)?;
646        let av = *a_values.get(a_idx)?;
647        let bv = *b_values.get(b_idx)?;
648        let v = match op_type {
649            "Add" => av + bv,
650            "Sub" => av - bv,
651            "Mul" => av * bv,
652            "Div" => {
653                if bv == 0 {
654                    return None;
655                }
656                av / bv
657            }
658            "Equal" => {
659                if av == bv {
660                    1
661                } else {
662                    0
663                }
664            }
665            _ => return None,
666        };
667        out_values.push(v);
668    }
669
670    Some((out_values, out_shape))
671}
672
673fn value_shape_dims_for<'a>(
674    name: &str,
675    value_shape_dims: &'a HashMap<String, Vec<Dimension>>,
676) -> Option<&'a [Dimension]> {
677    let sanitized = sanitize_identifier(name);
678    let trimmed = name.trim_start_matches('/');
679    value_shape_dims
680        .get(name)
681        .or_else(|| value_shape_dims.get(&sanitized))
682        .or_else(|| value_shape_dims.get(trimmed))
683        .map(Vec::as_slice)
684}
685
686fn dims_contain_dynamic(dims: &[Dimension]) -> bool {
687    dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
688}
689
/// Split a dynamic-dimension name into `(base, offset)`.
///
/// Recognizes a trailing `+ N` (checked first, at the last `+`) or `- N` (at the
/// last `-`), where `N` parses as an `i64`. If neither applies, returns the
/// trimmed name with offset 0.
pub(crate) fn parse_dynamic_dim_expr(dim_name: &str) -> (String, i64) {
    let expr = dim_name.trim();
    // Split at the last `sep`, keeping the split only if the tail is an integer.
    let trailing_offset = |sep: char, sign: i64| {
        expr.rsplit_once(sep).and_then(|(base, tail)| {
            tail.trim()
                .parse::<i64>()
                .ok()
                .map(|n| (base.trim().to_string(), sign * n))
        })
    };
    trailing_offset('+', 1)
        .or_else(|| trailing_offset('-', -1))
        .unwrap_or_else(|| (expr.to_string(), 0))
}
704
/// Render `base` with an integer offset: `"base + N"`, `"base - N"`, or just
/// `"base"` when the offset is 0.
///
/// Uses `unsigned_abs` so an offset of `i64::MIN` formats correctly instead of
/// overflowing (which would panic in debug builds).
pub(crate) fn format_dynamic_dim_expr(base: &str, offset: i64) -> String {
    if offset > 0 {
        format!("{base} + {offset}")
    } else if offset < 0 {
        format!("{base} - {}", offset.unsigned_abs())
    } else {
        base.to_string()
    }
}
714
/// Parse a sum/difference of symbolic terms and integer constants, e.g.
/// `"a + b - 3"`, into `(term -> coefficient, constant)`.
///
/// Tokens are split on `+`/`-` and whitespace. Any non-integer token is an opaque
/// term name. Consecutive `-` signs cancel (`"a - -2"` gives constant `+2`).
/// Terms whose coefficients cancel to zero are dropped.
///
/// Returns `None` in these cases:
/// * the expression is empty or contains no terms;
/// * it ends with a dangling operator;
/// * the arithmetic overflows `i64`.
fn parse_additive_dynamic_dim_expr(dim_name: &str) -> Option<(BTreeMap<String, i64>, i64)> {
    let expr = dim_name.trim();
    if expr.is_empty() {
        return None;
    }

    let normalized = expr.replace('+', " + ").replace('-', " - ");
    let mut terms: BTreeMap<String, i64> = BTreeMap::new();
    let mut constant = 0i64;
    let mut sign = 1i64;
    let mut saw_term = false;
    // True while an operator has been read but its operand has not.
    let mut pending_operator = false;

    for token in normalized.split_whitespace() {
        match token {
            "+" => pending_operator = true,
            "-" => {
                sign = -sign;
                pending_operator = true;
            }
            _ => {
                saw_term = true;
                if let Ok(value) = token.parse::<i64>() {
                    constant = constant.checked_add(sign.checked_mul(value)?)?;
                } else {
                    let coeff = terms.entry(token.to_string()).or_insert(0);
                    *coeff = coeff.checked_add(sign)?;
                }
                sign = 1;
                pending_operator = false;
            }
        }
    }

    if !saw_term || pending_operator {
        return None;
    }

    terms.retain(|_, coeff| *coeff != 0);
    Some((terms, constant))
}
750
/// Render `(term -> coefficient, constant)` back into an additive expression.
///
/// A term with coefficient `k` is repeated `|k|` times (`a + a` for 2). A
/// leading negative term is written as `"- name"`, and the constant is appended
/// last. Returns `None` when there are no terms and the constant is 0.
/// Magnitudes use `unsigned_abs`, so `i64::MIN` does not overflow.
fn format_additive_dynamic_dim_expr(
    terms: &BTreeMap<String, i64>,
    constant: i64,
) -> Option<String> {
    if terms.is_empty() && constant == 0 {
        return None;
    }

    let mut out = String::new();
    for (name, &coeff) in terms {
        let joiner = if coeff < 0 { " - " } else { " + " };
        for _ in 0..coeff.unsigned_abs() {
            if out.is_empty() {
                if coeff < 0 {
                    out.push_str("- ");
                }
            } else {
                out.push_str(joiner);
            }
            out.push_str(name);
        }
    }

    if constant != 0 {
        if out.is_empty() {
            out.push_str(&constant.to_string());
        } else {
            out.push_str(if constant < 0 { " - " } else { " + " });
            out.push_str(&constant.unsigned_abs().to_string());
        }
    }

    Some(out)
}
791
/// True if `dim_name` is a bare symbol, or a symbol with a single trailing
/// integer offset (`"n + 1"`, `"n - 2"`).
///
/// Empty names and anything containing `*` or `/` are rejected. If a `+` is
/// present, only the split at the last `+` is considered; otherwise the split
/// at the last `-`.
fn is_runtime_resolvable_dynamic_dim_expr(dim_name: &str) -> bool {
    let expr = dim_name.trim();
    if expr.is_empty() || expr.contains(|c: char| c == '*' || c == '/') {
        return false;
    }
    match expr.rsplit_once('+').or_else(|| expr.rsplit_once('-')) {
        Some((base, offset)) => !base.trim().is_empty() && offset.trim().parse::<i64>().is_ok(),
        None => true,
    }
}
805
806fn shift_dynamic_dimension(dim: &DynamicDimension, delta: i64) -> Option<DynamicDimension> {
807    let (base, offset) = parse_dynamic_dim_expr(&dim.name);
808    let name = format_dynamic_dim_expr(&base, offset.checked_add(delta)?);
809    let shifted_max = (dim.max_size as i64).checked_add(delta)?.max(0);
810    let max_size = u32::try_from(shifted_max).ok()?;
811    Some(DynamicDimension { name, max_size })
812}
813
814pub(crate) fn dynamic_scalar_dimension_for_value(
815    name: &str,
816    value_shape_dims: &HashMap<String, Vec<Dimension>>,
817) -> Option<DynamicDimension> {
818    let dims = value_shape_dims_for(name, value_shape_dims)?;
819    if dims.len() != 1 {
820        return None;
821    }
822    match &dims[0] {
823        Dimension::Dynamic(dim) => Some(dim.clone()),
824        Dimension::Static(_) => None,
825    }
826}
827
828fn dimension_vector_for_value(
829    name: &str,
830    const_values: &HashMap<String, Vec<i64>>,
831    value_shape_dims: &HashMap<String, Vec<Dimension>>,
832) -> Option<Vec<Dimension>> {
833    if let Some(dims) = value_shape_dims_for(name, value_shape_dims) {
834        return Some(dims.to_vec());
835    }
836    let values = const_values.get(name)?;
837    values
838        .iter()
839        .map(|&v| u32::try_from(v).ok().map(Dimension::Static))
840        .collect()
841}
842
843fn is_trivial_static_dimension_vector(dims: &[Dimension]) -> bool {
844    dims.len() <= 3 && dims.iter().all(|d| matches!(d, Dimension::Static(1)))
845}
846
847fn combine_binary_dimension(
848    op_type: &str,
849    dynamic: &DynamicDimension,
850    static_value: i64,
851    dynamic_on_lhs: bool,
852) -> Option<Dimension> {
853    match op_type {
854        "Add" => shift_dynamic_dimension(dynamic, static_value).map(Dimension::Dynamic),
855        "Sub" if dynamic_on_lhs => {
856            shift_dynamic_dimension(dynamic, -static_value).map(Dimension::Dynamic)
857        }
858        "Mul" if static_value == 0 => Some(Dimension::Static(0)),
859        "Mul" if static_value == 1 => Some(Dimension::Dynamic(dynamic.clone())),
860        "Mul" if static_value > 1 => Some(Dimension::Dynamic(DynamicDimension {
861            name: if dynamic_on_lhs {
862                format!("{} * {}", dynamic.name, static_value)
863            } else {
864                format!("{} * {}", static_value, dynamic.name)
865            },
866            max_size: dynamic.max_size.saturating_mul(static_value as u32),
867        })),
868        "Div" if dynamic_on_lhs && static_value == 1 => Some(Dimension::Dynamic(dynamic.clone())),
869        "Div" if dynamic_on_lhs && static_value > 1 => Some(Dimension::Dynamic(DynamicDimension {
870            name: format!("{} / {}", dynamic.name, static_value),
871            max_size: dynamic.max_size / (static_value as u32),
872        })),
873        _ => None,
874    }
875}
876
877fn combine_dynamic_dimensions(
878    op_type: &str,
879    lhs: &DynamicDimension,
880    rhs: &DynamicDimension,
881    lhs_value: i64,
882    rhs_value: i64,
883) -> Option<Dimension> {
884    match op_type {
885        "Add" | "Sub" => {
886            let (mut terms, mut constant) = parse_additive_dynamic_dim_expr(&lhs.name)?;
887            let (rhs_terms, rhs_constant) = parse_additive_dynamic_dim_expr(&rhs.name)?;
888            let rhs_sign = if op_type == "Add" { 1 } else { -1 };
889
890            for (name, coeff) in rhs_terms {
891                *terms.entry(name).or_insert(0) += rhs_sign * coeff;
892            }
893            constant += rhs_sign * rhs_constant;
894            terms.retain(|_, coeff| *coeff != 0);
895
896            let value = if op_type == "Add" {
897                lhs_value.checked_add(rhs_value)?
898            } else {
899                lhs_value.checked_sub(rhs_value)?
900            };
901            if terms.is_empty() {
902                return u32::try_from(value).ok().map(Dimension::Static);
903            }
904
905            let name = format_additive_dynamic_dim_expr(&terms, constant)?;
906            let max_size = u32::try_from(value).ok()?;
907            Some(Dimension::Dynamic(DynamicDimension { name, max_size }))
908        }
909        _ => None,
910    }
911}
912
913fn fold_binary_dynamic_dims(
914    op_type: &str,
915    a_values: &[i64],
916    b_values: &[i64],
917    a_shape: &[i64],
918    b_shape: &[i64],
919    a_dims: Option<&[Dimension]>,
920    b_dims: Option<&[Dimension]>,
921) -> Option<Vec<Dimension>> {
922    let out_shape = broadcast_shape(a_shape, b_shape)?;
923    let out_numel = shape_numel(&out_shape)?;
924    let mut out_dims = Vec::with_capacity(out_numel);
925    let mut has_dynamic = false;
926
927    for out_idx in 0..out_numel {
928        let a_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, a_shape)?;
929        let b_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, b_shape)?;
930        let av = *a_values.get(a_idx)?;
931        let bv = *b_values.get(b_idx)?;
932        let a_dim = a_dims.and_then(|dims| dims.get(a_idx));
933        let b_dim = b_dims.and_then(|dims| dims.get(b_idx));
934
935        let out_dim = match (a_dim, b_dim) {
936            (Some(Dimension::Dynamic(dynamic)), Some(Dimension::Static(_)))
937            | (Some(Dimension::Dynamic(dynamic)), None) => {
938                let dim = combine_binary_dimension(op_type, dynamic, bv, true)?;
939                has_dynamic |= matches!(dim, Dimension::Dynamic(_));
940                dim
941            }
942            (Some(Dimension::Static(_)), Some(Dimension::Dynamic(dynamic)))
943            | (None, Some(Dimension::Dynamic(dynamic))) => {
944                let dim = combine_binary_dimension(op_type, dynamic, av, false)?;
945                has_dynamic |= matches!(dim, Dimension::Dynamic(_));
946                dim
947            }
948            (Some(Dimension::Dynamic(a_dynamic)), Some(Dimension::Dynamic(b_dynamic))) => {
949                let dim = combine_dynamic_dimensions(op_type, a_dynamic, b_dynamic, av, bv)?;
950                has_dynamic |= matches!(dim, Dimension::Dynamic(_));
951                dim
952            }
953            _ => {
954                let value = match op_type {
955                    "Add" => av + bv,
956                    "Sub" => av - bv,
957                    "Mul" => av * bv,
958                    "Div" => {
959                        if bv == 0 {
960                            return None;
961                        }
962                        av / bv
963                    }
964                    _ => return None,
965                };
966                Dimension::Static(u32::try_from(value).ok()?)
967            }
968        };
969
970        out_dims.push(out_dim);
971    }
972
973    has_dynamic.then_some(out_dims)
974}
975
976pub(crate) fn dynamic_range_length_dimension(
977    start: i64,
978    delta: i64,
979    start_dim: Option<&DynamicDimension>,
980    limit: &DynamicDimension,
981) -> Option<DynamicDimension> {
982    if delta != 1 {
983        return None;
984    }
985
986    let (mut terms, mut constant) = parse_additive_dynamic_dim_expr(&limit.name)?;
987    if let Some(start_dim) = start_dim {
988        let (start_terms, start_constant) = parse_additive_dynamic_dim_expr(&start_dim.name)?;
989        for (name, coeff) in start_terms {
990            *terms.entry(name).or_insert(0) -= coeff;
991        }
992        constant -= start_constant;
993    } else {
994        constant -= start;
995    }
996    terms.retain(|_, coeff| *coeff != 0);
997    if terms.is_empty() {
998        return None;
999    }
1000
1001    let name = format_additive_dynamic_dim_expr(&terms, constant)?;
1002    if !is_runtime_resolvable_dynamic_dim_expr(&name) {
1003        return None;
1004    }
1005
1006    let max_size = u32::try_from((limit.max_size as i64).checked_sub(start)?).ok()?;
1007    Some(DynamicDimension { name, max_size })
1008}
1009
/// Settings that control how an ONNX model is converted to WebNN.
#[derive(Clone, Debug)]
pub struct ConvertOptions {
    /// Reference constants through an external weights file instead of
    /// embedding their bytes inline (default: `true`).
    pub extract_weights: bool,
    /// Where the converted graph is written (`.webnn` or `.json`).
    pub output_path: String,
    /// Where the weights file (`.weights`) is written, if any.
    pub weights_path: Option<String>,
    /// Where the manifest (`.manifest.json`) is written, if any.
    pub manifest_path: Option<String>,
    /// Fixed values for named dynamic dimensions (e.g. `batch_size=1`,
    /// `sequence_length=128`).
    ///
    /// A model's dimension name is looked up as-is first, then ASCII-lowercased.
    /// These entries take precedence over values from the model's
    /// `freeDimensionOverrides` metadata.
    pub free_dim_overrides: HashMap<String, u32>,
    /// Turn on constant folding and shape propagation optimizations.
    pub optimize: bool,
    /// Experimental: preserve unresolved dynamic input dimensions in v2 graph
    /// metadata.
    ///
    /// When disabled, such dimensions must be resolved by an override (or a
    /// built-in default for common batch names), otherwise conversion fails.
    pub experimental_dynamic_inputs: bool,
}

impl Default for ConvertOptions {
    /// Writes `output.webnn`, `output.weights` and `output.manifest.json`, with
    /// weight extraction on and no overrides, optimization or experimental
    /// dynamic inputs.
    fn default() -> Self {
        ConvertOptions {
            output_path: String::from("output.webnn"),
            weights_path: Some(String::from("output.weights")),
            manifest_path: Some(String::from("output.manifest.json")),
            extract_weights: true,
            optimize: false,
            experimental_dynamic_inputs: false,
            free_dim_overrides: HashMap::default(),
        }
    }
}
1042
/// Type and shape information for a named tensor (presumably its element type
/// and ONNX dimensions — TODO confirm against the code that fills it).
///
/// NOTE(review): neither field is read in the code visible here (hence the `_`
/// prefixes), and `OnnxConverter::new` leaves the map that holds these empty.
struct TensorInfo {
    _data_type: DataType,
    _shape: Vec<i64>,
}
1047
/// Main converter structure: turns an ONNX `ModelProto` into a WebNN `GraphJson`.
pub struct OnnxConverter {
    /// The ONNX model being converted.
    model: ModelProto,
    /// Output graph; created empty by `new` and populated by `convert`.
    graph: GraphJson,
    /// Per-tensor type/shape information, initialized empty by `new`.
    /// NOTE(review): not read in the code visible here.
    _value_info: HashMap<String, TensorInfo>,
}
1054
1055impl OnnxConverter {
1056    /// Create a new converter from an ONNX model
1057    pub fn new(model: ModelProto) -> Result<Self, OnnxError> {
1058        let graph_name = if let Some(graph) = &model.graph {
1059            if !graph.name.is_empty() {
1060                graph.name.as_str().to_string()
1061            } else {
1062                "graph".to_string()
1063            }
1064        } else {
1065            "graph".to_string()
1066        };
1067
1068        let graph = GraphJson {
1069            format: "webnn-graph-json".to_string(),
1070            version: 1,
1071            name: Some(graph_name),
1072            quantized: false,
1073            inputs: BTreeMap::new(),
1074            consts: BTreeMap::new(),
1075            nodes: Vec::new(),
1076            outputs: BTreeMap::new(),
1077        };
1078
1079        Ok(Self {
1080            model,
1081            graph,
1082            _value_info: HashMap::new(),
1083        })
1084    }
1085
1086    /// Extract metadata from ONNX model
1087    pub fn extract_metadata(&self) -> Result<(), OnnxError> {
1088        if self.model.graph.is_none() {
1089            return Err(OnnxError::ProtobufError(
1090                "Missing graph in model".to_string(),
1091            ));
1092        }
1093
1094        let graph = self.model.graph.as_ref().unwrap();
1095
1096        // Print basic info
1097        println!("Model name: {}", self.graph.name.as_ref().unwrap());
1098        println!("Inputs: {}", graph.input.as_slice().len());
1099        println!("Outputs: {}", graph.output.as_slice().len());
1100        println!("Nodes: {}", graph.node.as_slice().len());
1101        println!("Initializers: {}", graph.initializer.as_slice().len());
1102
1103        Ok(())
1104    }
1105
1106    /// Convert ONNX model to GraphJson
1107    pub fn convert(mut self, options: &ConvertOptions) -> Result<GraphJson, OnnxError> {
1108        if self.model.graph.is_none() {
1109            return Err(OnnxError::ProtobufError(
1110                "Missing graph in model".to_string(),
1111            ));
1112        }
1113
1114        // Validate opset imports
1115        for import in self.model.opset_import.as_slice() {
1116            let domain = import.domain.as_str();
1117            let version = import.version;
1118            let domain_name = if domain.is_empty() {
1119                "ai.onnx".to_string()
1120            } else {
1121                domain.to_string()
1122            };
1123
1124            if (domain.is_empty() || domain == "ai.onnx")
1125                && !(MIN_SUPPORTED_OPSET..=MAX_SUPPORTED_OPSET).contains(&version)
1126            {
1127                return Err(OnnxError::UnsupportedOpset {
1128                    domain: domain_name,
1129                    version,
1130                });
1131            }
1132        }
1133
1134        let onnx_graph = self.model.graph.as_ref().unwrap();
1135        let mut value_name_map: HashMap<String, String> = HashMap::new();
1136        let mut effective_overrides = options.free_dim_overrides.clone();
1137        let mut inference_overrides = effective_overrides.clone();
1138        let mut value_types: HashMap<String, DataType> = HashMap::new();
1139
1140        // Merge overrides from model metadata if present
1141        for meta in self.model.metadata_props.as_slice() {
1142            if meta
1143                .key
1144                .as_str()
1145                .eq_ignore_ascii_case("freedimensionoverrides")
1146            {
1147                if let Ok(json) = serde_json::from_str::<JsonValue>(meta.value.as_str()) {
1148                    let obj = json
1149                        .get("freeDimensionOverrides")
1150                        .unwrap_or(&json)
1151                        .as_object()
1152                        .cloned();
1153                    if let Some(map) = obj {
1154                        for (name, value) in map {
1155                            if let Some(v) = value.as_u64() {
1156                                effective_overrides.entry(name.clone()).or_insert(v as u32);
1157                            }
1158                        }
1159                    }
1160                }
1161            }
1162        }
1163
1164        // Process inputs (exclude initializers)
1165        let initializer_names: HashSet<String> = onnx_graph
1166            .initializer
1167            .as_slice()
1168            .iter()
1169            .map(|init| init.name.as_str().to_string())
1170            .collect();
1171
1172        let default_dynamic_max_size: u32 = 65_535;
1173        let default_inference_dim_values: HashMap<&str, u32> =
1174            HashMap::from([("batch_size", 1), ("batch", 1), ("n", 1), ("b", 1)]);
1175        let dynamic_max_for_dim = |name: &str| -> u32 {
1176            let lower = name.to_ascii_lowercase();
1177            if lower.contains("past")
1178                || lower.contains("seq")
1179                || lower.contains("length")
1180                || lower == "s"
1181                || lower == "t"
1182            {
1183                4096
1184            } else if lower.contains("batch") || lower == "b" || lower == "n" {
1185                8
1186            } else {
1187                default_dynamic_max_size
1188            }
1189        };
1190
1191        let resolve_dim_override =
1192            |dim_param: &str, overrides: &mut HashMap<String, u32>| -> Option<u32> {
1193                if let Some(v) = overrides.get(dim_param) {
1194                    return Some(*v);
1195                }
1196
1197                let lower = dim_param.to_ascii_lowercase();
1198                if let Some(v) = overrides.get(&lower) {
1199                    return Some(*v);
1200                }
1201                None
1202            };
1203        let resolve_dim_for_inference =
1204            |dim_param: &str, overrides: &mut HashMap<String, u32>| -> Option<u32> {
1205                if let Some(v) = resolve_dim_override(dim_param, overrides) {
1206                    return Some(v);
1207                }
1208                let lower = dim_param.to_ascii_lowercase();
1209                if let Some(v) = default_inference_dim_values.get(lower.as_str()) {
1210                    overrides.insert(dim_param.to_string(), *v);
1211                    return Some(*v);
1212                }
1213                None
1214            };
1215
1216        for input in onnx_graph.input.as_slice() {
1217            let raw_name = input.name.as_str().to_string();
1218            let name = sanitize_identifier(&raw_name);
1219
1220            // Skip if this is an initializer (constant)
1221            if initializer_names.contains(&raw_name) {
1222                continue;
1223            }
1224
1225            // Get type info
1226            if let Some(type_proto) = &input.r#type {
1227                if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1228                    let data_type = if tensor_type.elem_type != 0 {
1229                        let onnx_type = tensor_type.elem_type;
1230                        map_onnx_data_type(onnx_type)?
1231                    } else {
1232                        DataType::Float32 // Default
1233                    };
1234
1235                    let shape = if let Some(shape_proto) = &tensor_type.shape {
1236                        let mut resolved: Vec<Dimension> = Vec::new();
1237                        for (idx, dim) in shape_proto.dim.iter().enumerate() {
1238                            if let Some(dim_value) = &dim.value {
1239                                match dim_value {
1240                                    DimensionValue::DimValue(v) => {
1241                                        if *v > 0 {
1242                                            resolved.push(Dimension::Static(*v as u32));
1243                                        } else if options.experimental_dynamic_inputs {
1244                                            resolved.push(Dimension::Dynamic(DynamicDimension {
1245                                                name: format!("{}_dim{}", name, idx),
1246                                                max_size: default_dynamic_max_size,
1247                                            }));
1248                                        } else {
1249                                            let dim_hint = format!("{}_dim{}", name, idx);
1250                                            return Err(OnnxError::InvalidShape(format!(
1251                                                "Input '{}' has non-positive dim value ({}) at index {}. \
1252Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1253                                                raw_name,
1254                                                v,
1255                                                idx,
1256                                                dim_hint
1257                                            )));
1258                                        }
1259                                    }
1260                                    DimensionValue::DimParam(dim_param) => {
1261                                        if let Some(v) = resolve_dim_override(
1262                                            dim_param,
1263                                            &mut effective_overrides,
1264                                        ) {
1265                                            resolved.push(Dimension::Static(v));
1266                                        } else if options.experimental_dynamic_inputs {
1267                                            let max_size = dynamic_max_for_dim(dim_param);
1268                                            resolved.push(Dimension::Dynamic(DynamicDimension {
1269                                                name: dim_param.to_string(),
1270                                                max_size,
1271                                            }));
1272                                        } else if let Some(v) = resolve_dim_for_inference(
1273                                            dim_param,
1274                                            &mut inference_overrides,
1275                                        ) {
1276                                            effective_overrides
1277                                                .entry(dim_param.clone())
1278                                                .or_insert(v);
1279                                            resolved.push(Dimension::Static(v));
1280                                        } else {
1281                                            return Err(OnnxError::InvalidShape(format!(
1282                                                "Input '{}' has unresolved dynamic dimension '{}'. \
1283Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1284                                                raw_name, dim_param, dim_param
1285                                            )));
1286                                        }
1287                                    }
1288                                }
1289                            } else if options.experimental_dynamic_inputs {
1290                                resolved.push(Dimension::Dynamic(DynamicDimension {
1291                                    name: format!("{}_dim{}", name, idx),
1292                                    max_size: default_dynamic_max_size,
1293                                }));
1294                            } else {
1295                                let dim_hint = format!("{}_dim{}", name, idx);
1296                                return Err(OnnxError::InvalidShape(format!(
1297                                    "Input '{}' has unknown dimension at index {}. \
1298Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1299                                    raw_name, idx, dim_hint
1300                                )));
1301                            }
1302                        }
1303                        resolved
1304                    } else {
1305                        return Err(OnnxError::InvalidShape(format!(
1306                            "Input '{}' is missing shape information",
1307                            raw_name
1308                        )));
1309                    };
1310
1311                    if shape.is_empty() {
1312                        continue;
1313                    }
1314
1315                    self.graph.inputs.insert(
1316                        name.clone(),
1317                        crate::ast::OperandDesc {
1318                            data_type: data_type.clone(),
1319                            shape,
1320                        },
1321                    );
1322
1323                    value_name_map.insert(raw_name.clone(), name.clone());
1324                    value_name_map.insert(name.clone(), name.clone());
1325                    value_types.insert(raw_name.clone(), data_type.clone());
1326                    value_types.insert(name.clone(), data_type);
1327                }
1328            }
1329        }
1330
1331        // Process initializers (constants/weights)
1332        for initializer in onnx_graph.initializer.as_slice() {
1333            let name = sanitize_identifier(initializer.name.as_str());
1334            let raw_data = initializer.raw_data.as_slice();
1335
1336            // Skip initializers with no data (check both raw_data and typed data fields)
1337            let has_data = !raw_data.is_empty()
1338                || !initializer.float_data.as_slice().is_empty()
1339                || !initializer.int32_data.as_slice().is_empty()
1340                || !initializer.int64_data.as_slice().is_empty()
1341                || !initializer.double_data.as_slice().is_empty();
1342
1343            if !has_data {
1344                crate::debug_println!("Warning: Skipping initializer '{}' with no data", name);
1345                continue;
1346            }
1347
1348            let onnx_type = initializer.data_type;
1349            let data_type = map_onnx_data_type(onnx_type)?;
1350            let shape: Vec<u32> = initializer
1351                .dims
1352                .as_slice()
1353                .iter()
1354                .map(|d| *d as u32)
1355                .collect();
1356
1357            let init = if options.extract_weights {
1358                // External weights reference (use original name for weights file)
1359                crate::ast::ConstInit::Weights {
1360                    r#ref: sanitize_identifier(initializer.name.as_str()),
1361                }
1362            } else {
1363                // Inline bytes
1364                let bytes = raw_data.to_vec();
1365                crate::ast::ConstInit::InlineBytes { bytes }
1366            };
1367
1368            self.graph
1369                .consts
1370                .entry(name.clone())
1371                .or_insert(crate::ast::ConstDecl {
1372                    data_type: data_type.clone(),
1373                    shape,
1374                    init,
1375                });
1376
1377            value_name_map.insert(initializer.name.as_str().to_string(), name.clone());
1378            value_name_map.insert(name.clone(), name.clone());
1379            value_types.insert(initializer.name.as_str().to_string(), data_type.clone());
1380            value_types.insert(name, data_type);
1381        }
1382
1383        // Process nodes using OpRegistry
1384        let registry = crate::onnx::ops::OpRegistry::new();
1385
1386        // Build initializers map for resolving constant shapes
1387        let mut initializers_map = std::collections::HashMap::new();
1388        for initializer in onnx_graph.initializer.as_slice() {
1389            // Skip initializers with no data (check both raw_data and typed data fields)
1390            let has_data = !initializer.raw_data.as_slice().is_empty()
1391                || !initializer.float_data.as_slice().is_empty()
1392                || !initializer.int32_data.as_slice().is_empty()
1393                || !initializer.int64_data.as_slice().is_empty()
1394                || !initializer.double_data.as_slice().is_empty();
1395
1396            if !has_data {
1397                continue;
1398            }
1399            initializers_map.insert(initializer.name.as_str().to_string(), initializer);
1400        }
1401
1402        // Build value_shapes map from value_info and inputs for shape inference
1403        let mut value_shapes = std::collections::HashMap::new();
1404        let mut value_shape_dims = std::collections::HashMap::new();
1405
1406        // Add input shapes (already validated)
1407        for (raw_name, mapped_name) in value_name_map.clone() {
1408            if initializer_names.contains(&raw_name) {
1409                continue;
1410            }
1411            if let Some(input) = onnx_graph
1412                .input
1413                .as_slice()
1414                .iter()
1415                .find(|i| i.name.as_str() == raw_name)
1416            {
1417                if let Some(type_proto) = &input.r#type {
1418                    if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1419                        if let Some(shape_proto) = &tensor_type.shape {
1420                            let mut shape: Vec<i64> = Vec::new();
1421                            let mut unknown = false;
1422                            for dim in &shape_proto.dim {
1423                                if let Some(dim_value) = &dim.value {
1424                                    match dim_value {
1425                                        DimensionValue::DimValue(v) => {
1426                                            if *v > 0 {
1427                                                shape.push(*v);
1428                                            } else if options.experimental_dynamic_inputs {
1429                                                shape.push(default_dynamic_max_size as i64);
1430                                            } else {
1431                                                unknown = true;
1432                                                break;
1433                                            }
1434                                        }
1435                                        DimensionValue::DimParam(dim_param) => {
1436                                            if let Some(v) = resolve_dim_for_inference(
1437                                                dim_param,
1438                                                &mut inference_overrides,
1439                                            ) {
1440                                                shape.push(v as i64);
1441                                            } else if options.experimental_dynamic_inputs {
1442                                                shape.push(dynamic_max_for_dim(dim_param) as i64);
1443                                            } else {
1444                                                unknown = true;
1445                                                break;
1446                                            }
1447                                        }
1448                                    }
1449                                } else if options.experimental_dynamic_inputs {
1450                                    shape.push(default_dynamic_max_size as i64);
1451                                } else {
1452                                    unknown = true;
1453                                    break;
1454                                }
1455                            }
1456                            if !unknown && !shape.is_empty() {
1457                                value_shapes.insert(raw_name.clone(), shape.clone());
1458                                value_shapes.insert(mapped_name.clone(), shape);
1459                            }
1460                            let mut dims = Vec::new();
1461                            for dim in &shape_proto.dim {
1462                                if let Some(dim_value) = &dim.value {
1463                                    match dim_value {
1464                                        DimensionValue::DimValue(v) => {
1465                                            if *v > 0 {
1466                                                dims.push(crate::ast::Dimension::Static(*v as u32));
1467                                            }
1468                                        }
1469                                        DimensionValue::DimParam(dim_param) => {
1470                                            dims.push(crate::ast::Dimension::Dynamic(
1471                                                crate::ast::DynamicDimension {
1472                                                    name: dim_param.clone(),
1473                                                    max_size: dynamic_max_for_dim(dim_param),
1474                                                },
1475                                            ));
1476                                        }
1477                                    }
1478                                }
1479                            }
1480                            if !dims.is_empty() {
1481                                value_shape_dims.insert(raw_name.clone(), dims.clone());
1482                                value_shape_dims.insert(mapped_name.clone(), dims);
1483                            }
1484                        }
1485                    }
1486                }
1487            }
1488        }
1489
1490        // Add initializer shapes
1491        for initializer in onnx_graph.initializer.as_slice() {
1492            // Skip initializers with no data (check both raw_data and typed data fields)
1493            let has_data = !initializer.raw_data.as_slice().is_empty()
1494                || !initializer.float_data.as_slice().is_empty()
1495                || !initializer.int32_data.as_slice().is_empty()
1496                || !initializer.int64_data.as_slice().is_empty()
1497                || !initializer.double_data.as_slice().is_empty();
1498
1499            if !has_data {
1500                continue;
1501            }
1502            let shape: Vec<i64> = initializer.dims.as_slice().to_vec();
1503            value_shapes.insert(initializer.name.as_str().to_string(), shape);
1504            let dims: Vec<crate::ast::Dimension> = initializer
1505                .dims
1506                .iter()
1507                .copied()
1508                .filter(|d| *d > 0)
1509                .map(|d| crate::ast::Dimension::Static(d as u32))
1510                .collect();
1511            if !dims.is_empty() {
1512                value_shape_dims.insert(initializer.name.as_str().to_string(), dims);
1513            }
1514        }
1515
1516        // Add value_info shapes (intermediate tensors from shape inference)
1517        // Try to resolve dynamic dimensions using overrides
1518        for value_info in onnx_graph.value_info.as_slice() {
1519            if let Some(type_proto) = &value_info.r#type {
1520                if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1521                    if let Some(shape_proto) = &tensor_type.shape {
1522                        let mut shape: Vec<i64> = Vec::new();
1523                        let mut unknown = false;
1524
1525                        for dim in &shape_proto.dim {
1526                            if let Some(dim_value) = &dim.value {
1527                                match dim_value {
1528                                    DimensionValue::DimValue(v) => {
1529                                        if *v > 0 {
1530                                            shape.push(*v);
1531                                        } else if options.experimental_dynamic_inputs {
1532                                            shape.push(default_dynamic_max_size as i64);
1533                                        } else {
1534                                            unknown = true;
1535                                            break;
1536                                        }
1537                                    }
1538                                    DimensionValue::DimParam(dim_param) => {
1539                                        if let Some(v) = resolve_dim_for_inference(
1540                                            dim_param,
1541                                            &mut inference_overrides,
1542                                        ) {
1543                                            shape.push(v as i64);
1544                                        } else if options.experimental_dynamic_inputs {
1545                                            shape.push(dynamic_max_for_dim(dim_param) as i64);
1546                                        } else {
1547                                            unknown = true;
1548                                            break;
1549                                        }
1550                                    }
1551                                }
1552                            } else if options.experimental_dynamic_inputs {
1553                                shape.push(default_dynamic_max_size as i64);
1554                            } else {
1555                                unknown = true;
1556                                break;
1557                            }
1558                        }
1559
1560                        if !unknown && !shape.is_empty() && shape.iter().all(|&d| d > 0) {
1561                            value_shapes.insert(value_info.name.as_str().to_string(), shape);
1562                        }
1563                        let mut dims = Vec::new();
1564                        for dim in &shape_proto.dim {
1565                            if let Some(dim_value) = &dim.value {
1566                                match dim_value {
1567                                    DimensionValue::DimValue(v) => {
1568                                        if *v > 0 {
1569                                            dims.push(crate::ast::Dimension::Static(*v as u32));
1570                                        }
1571                                    }
1572                                    DimensionValue::DimParam(dim_param) => {
1573                                        dims.push(crate::ast::Dimension::Dynamic(
1574                                            crate::ast::DynamicDimension {
1575                                                name: dim_param.clone(),
1576                                                max_size: dynamic_max_for_dim(dim_param),
1577                                            },
1578                                        ));
1579                                    }
1580                                }
1581                            }
1582                        }
1583                        if !dims.is_empty() {
1584                            value_shape_dims.insert(value_info.name.as_str().to_string(), dims);
1585                        }
1586                    }
1587                }
1588            }
1589        }
1590
1591        // Seed const values with integer initializers and Constant nodes
1592        let mut const_values: HashMap<String, Vec<i64>> = HashMap::new();
1593        for (name, initializer) in &initializers_map {
1594            if initializer.data_type == TensorProto_DataType::Int64 as i32
1595                || initializer.data_type == TensorProto_DataType::Int32 as i32
1596            {
1597                let raw = initializer.raw_data.as_slice();
1598                let values = if !raw.is_empty() {
1599                    if initializer.data_type == TensorProto_DataType::Int32 as i32 {
1600                        raw.chunks_exact(4)
1601                            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1602                            .collect()
1603                    } else {
1604                        raw.chunks_exact(8)
1605                            .map(|c| {
1606                                i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
1607                            })
1608                            .collect()
1609                    }
1610                } else if !initializer.int64_data.as_slice().is_empty() {
1611                    initializer.int64_data.as_slice().to_vec()
1612                } else if !initializer.int32_data.as_slice().is_empty() {
1613                    initializer
1614                        .int32_data
1615                        .as_slice()
1616                        .iter()
1617                        .map(|&v| v as i64)
1618                        .collect()
1619                } else {
1620                    Vec::new()
1621                };
1622
1623                if !values.is_empty() {
1624                    const_values.insert(name.clone(), values);
1625                }
1626            }
1627        }
1628
1629        for node in onnx_graph.node.as_slice() {
1630            if node.op_type.as_str() == "Constant" {
1631                if let Some(attr) = node
1632                    .attribute
1633                    .as_slice()
1634                    .iter()
1635                    .find(|a| a.name.as_str() == "value" && a.t.is_some())
1636                {
1637                    let tensor = attr.t.as_ref().unwrap();
1638                    if tensor.data_type == TensorProto_DataType::Int64 as i32
1639                        || tensor.data_type == TensorProto_DataType::Int32 as i32
1640                    {
1641                        let raw = tensor.raw_data.as_slice();
1642                        let values = if !raw.is_empty() {
1643                            if tensor.data_type == TensorProto_DataType::Int32 as i32 {
1644                                raw.chunks_exact(4)
1645                                    .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1646                                    .collect()
1647                            } else {
1648                                raw.chunks_exact(8)
1649                                    .map(|c| {
1650                                        i64::from_le_bytes([
1651                                            c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7],
1652                                        ])
1653                                    })
1654                                    .collect()
1655                            }
1656                        } else if !tensor.int64_data.as_slice().is_empty() {
1657                            tensor.int64_data.as_slice().to_vec()
1658                        } else if !tensor.int32_data.as_slice().is_empty() {
1659                            tensor
1660                                .int32_data
1661                                .as_slice()
1662                                .iter()
1663                                .map(|&v| v as i64)
1664                                .collect()
1665                        } else {
1666                            Vec::new()
1667                        };
1668
1669                        if let Some(out) = node.output.as_slice().first() {
1670                            if !values.is_empty() {
1671                                const_values.insert(out.to_string(), values);
1672                                value_types.insert(out.to_string(), DataType::Int64);
1673                            }
1674                        }
1675                    }
1676                }
1677            }
1678        }
1679
1680        // Seed shapes/types/constants via static shape inference before lowering. On an
1681        // unresolved dynamic dim, retry with a fallback override (dynamic inputs only), else skip.
1682        let mut dynamic_inference_attempts: HashSet<String> = HashSet::new();
1683        loop {
1684            match crate::onnx::shape_inference::infer_static_shapes(
1685                &self.model,
1686                &inference_overrides,
1687            ) {
1688                Ok(inferred) => {
1689                    // Seed with or_insert so shapes/types already recorded above (declared
1690                    // value_info shapes, Constant outputs) take precedence over inferred ones.
1691                    for (k, v) in inferred.value_shapes {
1692                        value_shapes.entry(k).or_insert(v);
1693                    }
1694                    for (k, v) in inferred.value_types {
1695                        value_types.entry(k).or_insert(v);
1696                    }
1697                    for (k, v) in inferred.const_values {
1698                        // Use insert() instead of or_insert() to allow shape inference to correct
1699                        // earlier wrong values (e.g., Where operation heuristics)
1700                        if k.contains("rotary") && k.contains("Where") {
1701                            if let Some(old_val) = const_values.get(&k) {
1702                                crate::debug_println!(
1703                                    "[CONVERT] Overwriting {} from {:?} to {:?}",
1704                                    k,
1705                                    old_val,
1706                                    v
1707                                );
1708                            } else {
1709                                crate::debug_println!("[CONVERT] Inserting new {} = {:?}", k, v);
1710                            }
1711                        }
1712                        const_values.insert(k, v);
1713                    }
1714                    break;
1715                }
1716                Err(crate::onnx::shape_inference::ShapeInferenceError::DynamicDim {
1717                    input,
1718                    dim,
1719                }) => {
1720                    if options.experimental_dynamic_inputs
1721                        && !dynamic_inference_attempts.contains(dim.as_str())
1722                    {
1723                        let fallback = dynamic_max_for_dim(&dim);
1724                        inference_overrides.insert(dim.clone(), fallback);
1725                        dynamic_inference_attempts.insert(dim.clone());
1726                        crate::debug_println!(
1727                            "[CONVERT] Retrying static shape inference with inferred override {}={} \
1728                             (required by input '{}')",
1729                            dim,
1730                            fallback,
1731                            input
1732                        );
1733                        continue;
1734                    }
1735                    crate::debug_println!(
1736                        "[CONVERT] Skipping static shape inference due to unresolved dynamic dim '{}' on input '{}'",
1737                        dim,
1738                        input
1739                    );
1740                    break;
1741                }
1742                Err(e) => return Err(OnnxError::ShapeInference(e.to_string())),
1743            }
1744        }
1745
1746        // Propagate shapes and fold constant shape expressions in a few passes
1747        for _ in 0..3 {
1748            if options.optimize {
1749                let max_iterations = 10;
1750                for iteration in 0..max_iterations {
1751                    let initial_count = value_shapes.len();
1752
1753                    for onnx_node in onnx_graph.node.as_slice() {
1754                        let all_outputs_known = onnx_node
1755                            .output
1756                            .as_slice()
1757                            .iter()
1758                            .all(|out| value_shapes.contains_key(out.as_str()));
1759                        if all_outputs_known {
1760                            continue;
1761                        }
1762
1763                        if let Some(inferred) =
1764                            infer_shape(onnx_node, &value_shapes, &initializers_map, &const_values)
1765                        {
1766                            if let Some(output_name) = onnx_node.output.as_slice().first() {
1767                                // Debug: track shape changes for layer 15 operations
1768                                if output_name.contains("layers_15_self_attn")
1769                                    && (output_name.contains("Reshape")
1770                                        || output_name.contains("Transpose"))
1771                                {
1772                                    crate::debug_println!(
1773                                        "[SHAPE DEBUG] {} {} -> {:?}",
1774                                        onnx_node.op_type.as_str(),
1775                                        output_name,
1776                                        inferred
1777                                    );
1778                                }
1779                                // Force the correct shape - shape inference computes exact output shape
1780                                value_shapes.insert(output_name.to_string(), inferred);
1781                            }
1782                        }
1783                    }
1784
1785                    if value_shapes.len() == initial_count {
1786                        break;
1787                    }
1788
1789                    if iteration == max_iterations - 1 {
1790                        crate::debug_println!(
1791                            "Warning: Shape propagation reached max iterations ({}/{})",
1792                            value_shapes.len(),
1793                            onnx_graph.node.as_slice().len()
1794                        );
1795                    }
1796                }
1797            }
1798
1799            // If we know the input_ids shape (batch, seq), upgrade every rank-1 shape [n] with
1800            // n > 1 to [batch, seq, n] to unblock matmul/reshape resolution in decoder graphs that
1801            // lost batch/seq dims. NOTE(review): this also hits non-hidden 1-D tensors — confirm.
1802            if let Some(ids_shape) = value_shapes.get("input_ids") {
1803                if ids_shape.len() == 2 {
1804                    let (batch, seq) = (ids_shape[0], ids_shape[1]);
1805                    let upgrades: Vec<(String, Vec<i64>)> = value_shapes
1806                        .iter()
1807                        .filter_map(|(k, v)| {
1808                            if v.len() == 1 && v[0] > 1 {
1809                                Some((k.clone(), vec![batch, seq, v[0]]))
1810                            } else {
1811                                None
1812                            }
1813                        })
1814                        .collect();
1815                    for (k, v) in upgrades {
1816                        value_shapes.insert(k, v);
1817                    }
1818                }
1819            }
1820
1821            crate::debug_println!(
1822                "[debug] layer_norm shape {:?}",
1823                value_shapes.get("/decoder/block.0/layer.0/layer_norm/Mul_1_output_0")
1824            );
1825            crate::debug_println!(
1826                "[debug] matmul q shape {:?}",
1827                value_shapes.get("/decoder/block.0/layer.0/SelfAttention/q/MatMul_output_0")
1828            );
1829            crate::debug_println!(
1830                "[debug] input_ids shape {:?}",
1831                value_shapes.get("input_ids")
1832            );
1833            crate::debug_println!(
1834                "[debug] ln div shape {:?}",
1835                value_shapes.get("/decoder/block.0/layer.0/layer_norm/Div_output_0")
1836            );
1837
1838            let consts_before = const_values.len();
1839
1840            // DEBUG: Check value before propagation
1841            if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
1842                crate::debug_println!("[PROP BEFORE] /model/rotary_emb/Where_output_0 = {:?}", val);
1843            }
1844
1845            // Extend const value map for const-foldable shapes
1846            for node in onnx_graph.node.as_slice() {
1847                let op_type = node.op_type.as_str();
1848                if op_type == "Shape" {
1849                    if let (Some(inp), Some(out)) = (
1850                        node.input.as_slice().first(),
1851                        node.output.as_slice().first(),
1852                    ) {
1853                        let out = out.to_string();
1854                        if let Some(shape) = value_shapes.get(inp).cloned() {
1855                            if shape.iter().all(|d| *d > 0) {
1856                                // Propagate dynamic dim metadata: Shape output is a 1-D
1857                                // tensor whose elements correspond to input dimensions.
1858                                if options.experimental_dynamic_inputs {
1859                                    let inp_s = inp.to_string();
1860                                    if let Some(dims) = value_shape_dims.get(&inp_s).or_else(|| {
1861                                        value_shape_dims.get(&sanitize_identifier(&inp_s))
1862                                    }) {
1863                                        // Each element of the Shape output corresponds to one
1864                                        // input dimension.  Build a 1-D dim vector where
1865                                        // dynamic input dims become Dynamic elements.
1866                                        let out_dims: Vec<crate::ast::Dimension> = dims
1867                                            .iter()
1868                                            .map(|d| match d {
1869                                                crate::ast::Dimension::Dynamic(dd) => {
1870                                                    crate::ast::Dimension::Dynamic(dd.clone())
1871                                                }
1872                                                crate::ast::Dimension::Static(v) => {
1873                                                    crate::ast::Dimension::Static(*v)
1874                                                }
1875                                            })
1876                                            .collect();
1877                                        value_shape_dims.insert(out.clone(), out_dims);
1878                                    }
1879                                }
1880                                const_values.insert(out.clone(), shape.clone());
1881                                let inferred_shape = vec![shape.len() as i64];
1882                                // Force the correct shape - Shape operation computes exact output shape
1883                                value_shapes.insert(out.clone(), inferred_shape.clone());
1884                                value_shapes.insert(sanitize_identifier(&out), inferred_shape);
1885                                value_types.insert(out, DataType::Int64);
1886                            }
1887                        }
1888                    }
1889                } else if op_type == "Gather" {
1890                    if let (Some(data_name), Some(indices_name), Some(out)) = (
1891                        node.input.as_slice().first(),
1892                        node.input.as_slice().get(1),
1893                        node.output.as_slice().first(),
1894                    ) {
1895                        if let (Some(data), Some(indices)) =
1896                            (const_values.get(data_name), const_values.get(indices_name))
1897                        {
1898                            let axis = node
1899                                .attribute
1900                                .as_slice()
1901                                .iter()
1902                                .find(|a| a.name.as_str() == "axis" && a.i != 0)
1903                                .map(|a| a.i)
1904                                .unwrap_or(0);
1905
1906                            if axis == 0 {
1907                                let mut gathered = Vec::new();
1908                                let mut gathered_dims = Vec::new();
1909                                let data_dims = if options.experimental_dynamic_inputs {
1910                                    value_shape_dims
1911                                        .get(data_name)
1912                                        .or_else(|| {
1913                                            value_shape_dims.get(&sanitize_identifier(data_name))
1914                                        })
1915                                        .cloned()
1916                                } else {
1917                                    None
1918                                };
1919                                for &idx in indices {
1920                                    let i = if idx < 0 {
1921                                        (data.len() as i64 + idx) as usize
1922                                    } else {
1923                                        idx as usize
1924                                    };
1925                                    if let Some(v) = data.get(i) {
1926                                        gathered.push(*v);
1927                                        if let Some(ref dd) = data_dims {
1928                                            if let Some(dim) = dd.get(i) {
1929                                                gathered_dims.push(dim.clone());
1930                                            }
1931                                        }
1932                                    }
1933                                }
1934                                if !gathered.is_empty() {
1935                                    if options.experimental_dynamic_inputs
1936                                        && gathered_dims.len() == gathered.len()
1937                                        && gathered_dims
1938                                            .iter()
1939                                            .any(|d| matches!(d, crate::ast::Dimension::Dynamic(_)))
1940                                    {
1941                                        value_shape_dims.insert(out.to_string(), gathered_dims);
1942                                    }
1943                                    const_values.insert(out.to_string(), gathered.clone());
1944                                    let out_shape = if gathered.len() == 1 {
1945                                        Vec::new()
1946                                    } else {
1947                                        vec![gathered.len() as i64]
1948                                    };
1949                                    // Force the correct shape - Gather operation computes exact output shape
1950                                    value_shapes.insert(out.to_string(), out_shape.clone());
1951                                    value_shapes.insert(sanitize_identifier(out), out_shape);
1952                                    value_types.insert(out.to_string(), DataType::Int64);
1953                                }
1954                            }
1955                        }
1956                    }
1957                } else if matches!(op_type, "Add" | "Sub" | "Mul" | "Div") {
1958                    if node.input.as_slice().len() >= 2 {
1959                        if let (Some(a_name), Some(b_name), Some(out)) = (
1960                            node.input.as_slice().first(),
1961                            node.input.as_slice().get(1),
1962                            node.output.as_slice().first(),
1963                        ) {
1964                            let a = const_values.get(a_name);
1965                            let b = const_values.get(b_name);
1966                            if let (Some(a), Some(b)) = (a, b) {
1967                                let a_shape = const_shape_for_folding(a_name, a, &value_shapes);
1968                                let b_shape = const_shape_for_folding(b_name, b, &value_shapes);
1969                                if let Some((result_vals, out_shape)) =
1970                                    fold_binary_const_i64(op_type, a, b, &a_shape, &b_shape)
1971                                {
1972                                    if options.experimental_dynamic_inputs {
1973                                        let a_dims =
1974                                            value_shape_dims_for(a_name, &value_shape_dims);
1975                                        let b_dims =
1976                                            value_shape_dims_for(b_name, &value_shape_dims);
1977                                        if let Some(out_dims) = fold_binary_dynamic_dims(
1978                                            op_type, a, b, &a_shape, &b_shape, a_dims, b_dims,
1979                                        ) {
1980                                            value_shape_dims.insert(out.to_string(), out_dims);
1981                                        }
1982                                    }
1983                                    const_values.insert(out.to_string(), result_vals.clone());
1984                                    // Force the correct shape - Binary operations compute exact output shape
1985                                    value_shapes.insert(out.to_string(), out_shape.clone());
1986                                    value_shapes.insert(sanitize_identifier(out), out_shape);
1987                                    if let Some(dtype) = node
1988                                        .input
1989                                        .as_slice()
1990                                        .iter()
1991                                        .find_map(|i| value_types.get(i).cloned())
1992                                    {
1993                                        value_types.insert(out.to_string(), dtype);
1994                                    }
1995                                }
1996                            }
1997                        }
1998                    }
1999                } else if op_type == "Cast" || op_type == "Unsqueeze" || op_type == "Squeeze" {
2000                    if let (Some(inp), Some(out)) = (
2001                        node.input.as_slice().first(),
2002                        node.output.as_slice().first(),
2003                    ) {
2004                        if let Some(vals) = const_values.get(inp).cloned() {
2005                            // Propagate dynamic dim metadata
2006                            if options.experimental_dynamic_inputs {
2007                                if let Some(dims) = value_shape_dims
2008                                    .get(inp)
2009                                    .or_else(|| value_shape_dims.get(&sanitize_identifier(inp)))
2010                                    .cloned()
2011                                {
2012                                    value_shape_dims.insert(out.to_string(), dims);
2013                                }
2014                            }
2015                            const_values.insert(out.to_string(), vals.clone());
2016                            let out_shape = if vals.len() == 1 {
2017                                Vec::new()
2018                            } else {
2019                                vec![vals.len() as i64]
2020                            };
2021                            // Scalar if single-valued, else 1-D. NOTE(review): approximate — Unsqueeze of a scalar is [1], and Cast's dtype comes from `to`.
2022                            value_shapes.insert(out.to_string(), out_shape);
2023                            if let Some(dtype) = value_types.get(inp).cloned() {
2024                                value_types.insert(out.to_string(), dtype);
2025                            }
2026                        }
2027                    }
2028                } else if op_type == "Range" {
2029                    if node.input.as_slice().len() == 3 {
2030                        if let (Some(start_name), Some(limit_name), Some(delta_name)) = (
2031                            node.input.as_slice().first(),
2032                            node.input.as_slice().get(1),
2033                            node.input.as_slice().get(2),
2034                        ) {
2035                            if options.experimental_dynamic_inputs {
2036                                let start_dim = dynamic_scalar_dimension_for_value(
2037                                    start_name,
2038                                    &value_shape_dims,
2039                                );
2040                                if let Some(limit_dim) = dynamic_scalar_dimension_for_value(
2041                                    limit_name,
2042                                    &value_shape_dims,
2043                                ) {
2044                                    if let (Some(start_vals), Some(delta_vals), Some(out)) = (
2045                                        const_values.get(start_name),
2046                                        const_values.get(delta_name),
2047                                        node.output.as_slice().first(),
2048                                    ) {
2049                                        if !start_vals.is_empty() && !delta_vals.is_empty() {
2050                                            let start = start_vals[0];
2051                                            let delta = delta_vals[0];
2052                                            if let Some(range_dim) = dynamic_range_length_dimension(
2053                                                start,
2054                                                delta,
2055                                                start_dim.as_ref(),
2056                                                &limit_dim,
2057                                            ) {
2058                                                let out_shape = vec![range_dim.max_size as i64];
2059                                                value_shape_dims.insert(
2060                                                    out.to_string(),
2061                                                    vec![Dimension::Dynamic(range_dim.clone())],
2062                                                );
2063                                                value_shapes
2064                                                    .insert(out.to_string(), out_shape.clone());
2065                                                value_shapes
2066                                                    .insert(sanitize_identifier(out), out_shape);
2067                                                value_types
2068                                                    .insert(out.to_string(), DataType::Int64);
2069                                            }
2070                                        }
2071                                    }
2072                                    continue;
2073                                }
2074                            }
2075
2076                            // Range(start, limit, delta) -> [start, start+delta, start+2*delta, ...]
2077                            if let (Some(start_vals), Some(limit_vals), Some(delta_vals)) = (
2078                                const_values.get(start_name),
2079                                const_values.get(limit_name),
2080                                const_values.get(delta_name),
2081                            ) {
2082                                if !start_vals.is_empty()
2083                                    && !limit_vals.is_empty()
2084                                    && !delta_vals.is_empty()
2085                                {
2086                                    let start = start_vals[0];
2087                                    let limit = limit_vals[0];
2088                                    let delta = delta_vals[0];
2089
2090                                    let mut range_vals = Vec::new();
2091                                    if delta > 0 {
2092                                        let mut current = start;
2093                                        while current < limit {
2094                                            range_vals.push(current);
2095                                            current += delta;
2096                                        }
2097                                    } else if delta < 0 {
2098                                        let mut current = start;
2099                                        while current > limit {
2100                                            range_vals.push(current);
2101                                            current += delta;
2102                                        }
2103                                    }
2104
2105                                    if let Some(out) = node.output.as_slice().first() {
2106                                        const_values.insert(out.to_string(), range_vals.clone());
2107                                        let out_shape = vec![range_vals.len() as i64];
2108                                        // Force the correct shape - Range computes exact output shape
2109                                        value_shapes.insert(out.to_string(), out_shape.clone());
2110                                        value_shapes.insert(sanitize_identifier(out), out_shape);
2111                                        value_types.insert(out.to_string(), DataType::Int64);
2112                                    }
2113                                }
2114                            }
2115                        }
2116                    }
2117                } else if op_type == "Concat" {
2118                    // Concatenate constant inputs (often used to build shape tensors)
2119                    if let Some(out) = node.output.as_slice().first() {
2120                        let mut concatenated: Vec<i64> = Vec::new();
2121                        let mut all_const = true;
2122                        for inp in node.input.as_slice() {
2123                            if let Some(vals) = const_values.get(inp) {
2124                                concatenated.extend_from_slice(vals);
2125                            } else {
2126                                all_const = false;
2127                                break;
2128                            }
2129                        }
2130
2131                        // Handle axis=0 or axis=-1 (common for shape building)
2132                        let axis = node
2133                            .attribute
2134                            .as_slice()
2135                            .iter()
2136                            .find(|a| a.name.as_str() == "axis" && a.i != 0)
2137                            .map(|a| a.i)
2138                            .unwrap_or(0);
2139
2140                        if all_const && (axis == 0 || axis == -1) {
2141                            if out.contains("rotary") && out.contains("Where") {
2142                                crate::debug_println!(
2143                                    "[CONCAT WRITE] Writing {} = {:?}",
2144                                    out,
2145                                    concatenated
2146                                );
2147                            }
2148                            // Propagate dynamic dim metadata through concat
2149                            if options.experimental_dynamic_inputs {
2150                                let mut concat_dims: Vec<crate::ast::Dimension> = Vec::new();
2151                                let mut has_dynamic = false;
2152                                for inp in node.input.as_slice() {
2153                                    let inp_s = inp.to_string();
2154                                    if let Some(dims) = value_shape_dims.get(&inp_s).or_else(|| {
2155                                        value_shape_dims.get(&sanitize_identifier(&inp_s))
2156                                    }) {
2157                                        for d in dims {
2158                                            if matches!(d, crate::ast::Dimension::Dynamic(_)) {
2159                                                has_dynamic = true;
2160                                            }
2161                                            concat_dims.push(d.clone());
2162                                        }
2163                                    } else if let Some(vals) = const_values.get(inp) {
2164                                        for v in vals {
2165                                            concat_dims
2166                                                .push(crate::ast::Dimension::Static(*v as u32));
2167                                        }
2168                                    }
2169                                }
2170                                if has_dynamic && concat_dims.len() == concatenated.len() {
2171                                    value_shape_dims.insert(out.to_string(), concat_dims);
2172                                }
2173                            }
2174                            const_values.insert(out.to_string(), concatenated.clone());
2175                            let out_shape = vec![concatenated.len() as i64];
2176                            // Force the correct shape - Concat computes exact output shape
2177                            value_shapes.insert(out.to_string(), out_shape.clone());
2178                            value_shapes.insert(sanitize_identifier(out), out_shape);
2179                            value_types.insert(out.to_string(), DataType::Int64);
2180                        }
2181                    }
2182                } else if op_type == "ConstantOfShape" {
2183                    // ConstantOfShape(shape) -> tensor filled with constant value
2184                    if let Some(shape_name) = node.input.as_slice().first() {
2185                        let dynamic_output_dims = if options.experimental_dynamic_inputs {
2186                            value_shape_dims_for(shape_name, &value_shape_dims)
2187                                .map(|dims| dims.to_vec())
2188                                .filter(|dims| dims_contain_dynamic(dims))
2189                        } else {
2190                            None
2191                        };
2192
2193                        if let (Some(out), Some(dims)) =
2194                            (node.output.as_slice().first(), dynamic_output_dims.as_ref())
2195                        {
2196                            value_shape_dims.insert(out.to_string(), dims.to_vec());
2197                            const_values.remove(out.as_str());
2198                        }
2199
2200                        if let Some(shape_vals) = const_values.get(shape_name).cloned() {
2201                            // Get the fill value from attributes (default is 0)
2202                            let mut fill_value = 0i64;
2203                            for attr in node.attribute.as_slice() {
2204                                if attr.name.as_str() == "value" {
2205                                    if let Some(value_tensor) = attr.t.as_ref() {
2206                                        if value_tensor.data_type
2207                                            == crate::protos::onnx::TensorProto_DataType::Int64
2208                                                as i32
2209                                        {
2210                                            let raw = value_tensor.raw_data.as_slice();
2211                                            if !raw.is_empty() && raw.len() >= 8 {
2212                                                fill_value = i64::from_le_bytes([
2213                                                    raw[0], raw[1], raw[2], raw[3], raw[4], raw[5],
2214                                                    raw[6], raw[7],
2215                                                ]);
2216                                            } else if !value_tensor.int64_data.as_slice().is_empty()
2217                                            {
2218                                                fill_value = value_tensor.int64_data.as_slice()[0];
2219                                            }
2220                                        }
2221                                    }
2222                                }
2223                            }
2224
2225                            // Calculate number of elements
2226                            let numel = if shape_vals.is_empty() {
2227                                1
2228                            } else {
2229                                shape_vals.iter().product::<i64>()
2230                            };
2231
2232                            if numel > 0 && numel < 1_000_000 {
2233                                // Reasonable size limit
2234                                let filled_tensor = vec![fill_value; numel as usize];
2235                                if let Some(out) = node.output.as_slice().first() {
2236                                    let should_keep_const = dynamic_output_dims
2237                                        .as_ref()
2238                                        .is_none_or(|dims| !dims_contain_dynamic(dims));
2239                                    if should_keep_const {
2240                                        const_values.insert(out.to_string(), filled_tensor);
2241                                    } else {
2242                                        const_values.remove(out.as_str());
2243                                    }
2244                                    // Force the correct shape - ConstantOfShape creates exact output shape
2245                                    value_shapes.insert(out.to_string(), shape_vals.clone());
2246                                    value_shapes
2247                                        .insert(sanitize_identifier(out), shape_vals.clone());
2248                                    value_types.insert(out.to_string(), DataType::Int64);
2249                                }
2250                            }
2251                        }
2252                    }
2253                } else if op_type == "Equal" {
2254                    // Equal(a, b) -> boolean tensor (represented as i64: 1 for true, 0 for false)
2255                    if node.input.as_slice().len() >= 2 {
2256                        if let (Some(a_name), Some(b_name), Some(out)) = (
2257                            node.input.as_slice().first(),
2258                            node.input.as_slice().get(1),
2259                            node.output.as_slice().first(),
2260                        ) {
2261                            let a = const_values.get(a_name);
2262                            let b = const_values.get(b_name);
2263                            if let (Some(a), Some(b)) = (a, b) {
2264                                let a_shape = const_shape_for_folding(a_name, a, &value_shapes);
2265                                let b_shape = const_shape_for_folding(b_name, b, &value_shapes);
2266                                if let Some((result_vals, out_shape)) =
2267                                    fold_binary_const_i64("Equal", a, b, &a_shape, &b_shape)
2268                                {
2269                                    const_values.insert(out.to_string(), result_vals.clone());
2270                                    // Force the correct shape - Equal operation computes exact output shape
2271                                    value_shapes.insert(out.to_string(), out_shape.clone());
2272                                    value_shapes.insert(sanitize_identifier(out), out_shape);
2273                                    value_types.insert(out.to_string(), DataType::Int64);
2274                                }
2275                            }
2276                        }
2277                    }
2278                } else if op_type == "Where" {
2279                    if options.experimental_dynamic_inputs && node.input.as_slice().len() >= 3 {
2280                        if let Some(out) = node.output.as_slice().first() {
2281                            let cond = const_values.get(node.input.as_slice()[0].as_str());
2282                            let a_dims = dimension_vector_for_value(
2283                                node.input.as_slice()[1].as_str(),
2284                                &const_values,
2285                                &value_shape_dims,
2286                            );
2287                            let b_dims = dimension_vector_for_value(
2288                                node.input.as_slice()[2].as_str(),
2289                                &const_values,
2290                                &value_shape_dims,
2291                            );
2292                            let out_dims = if let (Some(cond), Some(a_dims), Some(b_dims)) =
2293                                (cond, a_dims.as_ref(), b_dims.as_ref())
2294                            {
2295                                if cond.len() == 1 && a_dims.len() == b_dims.len() {
2296                                    Some(if cond[0] != 0 {
2297                                        a_dims.clone()
2298                                    } else {
2299                                        b_dims.clone()
2300                                    })
2301                                } else if cond.len() == a_dims.len() && cond.len() == b_dims.len() {
2302                                    Some(
2303                                        cond.iter()
2304                                            .enumerate()
2305                                            .map(|(idx, c)| {
2306                                                if *c != 0 {
2307                                                    a_dims[idx].clone()
2308                                                } else {
2309                                                    b_dims[idx].clone()
2310                                                }
2311                                            })
2312                                            .collect(),
2313                                    )
2314                                } else {
2315                                    None
2316                                }
2317                            } else if let (Some(a_dims), Some(b_dims)) =
2318                                (a_dims.as_ref(), b_dims.as_ref())
2319                            {
2320                                let a_has_dynamic =
2321                                    a_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)));
2322                                let b_has_dynamic =
2323                                    b_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)));
2324                                if a_has_dynamic && !b_has_dynamic {
2325                                    Some(a_dims.clone())
2326                                } else if b_has_dynamic && !a_has_dynamic {
2327                                    Some(b_dims.clone())
2328                                } else if a_has_dynamic
2329                                    && b_has_dynamic
2330                                    && a_dims.len() == b_dims.len()
2331                                {
2332                                    Some(
2333                                        a_dims
2334                                            .iter()
2335                                            .zip(b_dims.iter())
2336                                            .map(|(a_dim, b_dim)| match (a_dim, b_dim) {
2337                                                (Dimension::Dynamic(dim), _) => {
2338                                                    Dimension::Dynamic(dim.clone())
2339                                                }
2340                                                (_, Dimension::Dynamic(dim)) => {
2341                                                    Dimension::Dynamic(dim.clone())
2342                                                }
2343                                                (Dimension::Static(v), _) => Dimension::Static(*v),
2344                                            })
2345                                            .collect(),
2346                                    )
2347                                } else {
2348                                    None
2349                                }
2350                            } else if let Some(a_dims) = a_dims.as_ref() {
2351                                if a_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
2352                                    && !is_trivial_static_dimension_vector(a_dims)
2353                                {
2354                                    Some(a_dims.clone())
2355                                } else {
2356                                    None
2357                                }
2358                            } else if let Some(b_dims) = b_dims.as_ref() {
2359                                if b_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
2360                                    && !is_trivial_static_dimension_vector(b_dims)
2361                                {
2362                                    Some(b_dims.clone())
2363                                } else {
2364                                    None
2365                                }
2366                            } else {
2367                                None
2368                            };
2369
2370                            if let Some(out_dims) = out_dims {
2371                                if out_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_))) {
2372                                    value_shape_dims.insert(out.to_string(), out_dims);
2373                                }
2374                            }
2375                        }
2376                    }
2377                    // Keep Where dynamic to avoid baking shape-driving expressions
2378                    // (e.g., past_sequence_length + 1) into fixed constants.
2379                    continue;
2380                }
2381            }
2382
2383            if const_values.len() == consts_before {
2384                break;
2385            }
2386
2387            // DEBUG: Check value after propagation pass
2388            if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
2389                crate::debug_println!("[PROP AFTER] /model/rotary_emb/Where_output_0 = {:?}", val);
2390            }
2391        }
2392
2393        // DEBUG: Check value before node conversion
2394        if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
2395            crate::debug_println!("[NODE CONV] /model/rotary_emb/Where_output_0 = {:?}", val);
2396        }
2397        for onnx_node in onnx_graph.node.as_slice() {
2398            // If all outputs are compile-time constants, emit them directly and skip conversion
2399            let outputs = onnx_node.output.as_slice();
2400            let has_dynamic_output_metadata = outputs.iter().any(|o| {
2401                value_shape_dims_for(o.as_str(), &value_shape_dims)
2402                    .map(|dims| dims.iter().any(|d| matches!(d, Dimension::Dynamic(_))))
2403                    .unwrap_or(false)
2404            });
2405            if !outputs.is_empty()
2406                && !has_dynamic_output_metadata
2407                && outputs
2408                    .iter()
2409                    .all(|o| const_values.contains_key(o.as_str()))
2410            {
2411                // Check if outputs are true scalars (rank 0), not just single-element tensors
2412                let all_scalar = outputs.iter().all(|o| {
2413                    value_shapes
2414                        .get(o.as_str())
2415                        .map(|s| s.is_empty()) // True scalar has empty shape
2416                        .unwrap_or_else(|| {
2417                            // Fallback: check if data length is 1
2418                            const_values
2419                                .get(o.as_str())
2420                                .map(|v| v.len() == 1)
2421                                .unwrap_or(false)
2422                        })
2423                });
2424
2425                // Handle scalar constants by emitting them inline
2426                if all_scalar {
2427                    for out in outputs {
2428                        if let Some(values) = const_values.get(out) {
2429                            let const_name = sanitize_identifier(out);
2430                            // Use the intended shape from value_shapes, not just empty for single-element
2431                            let shape = value_shapes
2432                                .get(out.as_str())
2433                                .map(|s| s.iter().map(|&d| d as u32).collect())
2434                                .unwrap_or_else(Vec::new);
2435
2436                            let decl = crate::ast::ConstDecl {
2437                                data_type: DataType::Int64,
2438                                shape,
2439                                init: crate::ast::ConstInit::InlineBytes {
2440                                    bytes: values[0].to_le_bytes().to_vec(),
2441                                },
2442                            };
2443
2444                            if let Some(existing) = self.graph.consts.get(&const_name) {
2445                                if existing != &decl {
2446                                    return Err(OnnxError::InvalidShape(format!(
2447                                        "Conflicting constant definitions for '{}'",
2448                                        const_name
2449                                    )));
2450                                }
2451                            } else {
2452                                self.graph.consts.insert(const_name.clone(), decl);
2453                            }
2454
2455                            value_name_map.insert(out.to_string(), const_name.clone());
2456                            value_name_map.insert(const_name.clone(), const_name.clone());
2457                            value_types.insert(out.to_string(), DataType::Int64);
2458                            value_types.insert(const_name, DataType::Int64);
2459                        }
2460                    }
2461                }
2462                // For non-scalar constants (like Range output), emit inline consts so downstream
2463                // nodes have a defined producer.
2464                for out in outputs {
2465                    if let Some(values) = const_values.get(out) {
2466                        let const_name = sanitize_identifier(out);
2467                        let mut shape = value_shapes
2468                            .get(out.as_str())
2469                            .cloned()
2470                            .unwrap_or_else(|| vec![values.len() as i64]);
2471                        let declared_numel = shape
2472                            .iter()
2473                            .try_fold(1usize, |acc, d| usize::try_from(*d).ok().map(|v| acc * v));
2474                        if declared_numel != Some(values.len()) {
2475                            // Some folded constants are broadcast candidates where value_shapes
2476                            // carries the post-broadcast shape but const_values stores the compact payload.
2477                            // Keep shape/data internally consistent by using the compact shape.
2478                            shape = vec![values.len() as i64];
2479                        }
2480                        let dtype = value_types
2481                            .get(out.as_str())
2482                            .cloned()
2483                            .unwrap_or(DataType::Int64);
2484
2485                        // Flatten i64 values into little-endian bytes
2486                        let mut bytes = Vec::with_capacity(values.len() * 8);
2487                        for v in values {
2488                            bytes.extend_from_slice(&v.to_le_bytes());
2489                        }
2490
2491                        let decl = crate::ast::ConstDecl {
2492                            data_type: dtype.clone(),
2493                            shape: shape.iter().map(|d| *d as u32).collect(),
2494                            init: crate::ast::ConstInit::InlineBytes { bytes },
2495                        };
2496
2497                        let existing = self.graph.consts.get(&const_name).cloned();
2498                        if existing.is_none() {
2499                            self.graph.consts.insert(const_name.clone(), decl);
2500                        }
2501
2502                        value_name_map.insert(out.to_string(), const_name.clone());
2503                        value_name_map.insert(const_name.clone(), const_name.clone());
2504                        value_types.insert(out.to_string(), dtype.clone());
2505                        value_types.insert(const_name, dtype);
2506                    }
2507                }
2508                continue;
2509            }
2510
2511            let context = crate::onnx::ops::ConversionContext {
2512                initializers: &initializers_map,
2513                value_shapes: &value_shapes,
2514                value_shape_dims: &value_shape_dims,
2515                const_values: &const_values,
2516                value_ids: &value_name_map,
2517                value_types: &value_types,
2518            };
2519
2520            let converted = registry.convert_node(onnx_node, &context)?;
2521
2522            for (name, mut decl) in converted.consts {
2523                if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2524                    let elem_size = match decl.data_type {
2525                        DataType::Float32 => 4,
2526                        DataType::Float16 => 2,
2527                        DataType::Int64 => 8,
2528                        DataType::Uint64 => 8,
2529                        DataType::Int32 => 4,
2530                        DataType::Uint32 => 4,
2531                        DataType::Int8 => 1,
2532                        DataType::Uint8 => 1,
2533                        DataType::Int4 | DataType::Uint4 => 0,
2534                    };
2535                    if elem_size > 0 {
2536                        let declared_numel = decl
2537                            .shape
2538                            .iter()
2539                            .try_fold(1usize, |acc, d| usize::try_from(*d).ok().map(|v| acc * v));
2540                        let declared_bytes = declared_numel.map(|n| n * elem_size);
2541                        if declared_bytes != Some(bytes.len()) && bytes.len() % elem_size == 0 {
2542                            // Keep const metadata internally consistent even when upstream shape
2543                            // metadata reflects a broadcasted view of compact inline data.
2544                            decl.shape = vec![(bytes.len() / elem_size) as u32];
2545                        }
2546                    }
2547                }
2548                let decl_dtype = decl.data_type.clone();
2549                if let Some(existing) = self.graph.consts.get(&name) {
2550                    if existing != &decl {
2551                        return Err(OnnxError::InvalidShape(format!(
2552                            "Conflicting constant definitions for '{}'",
2553                            name
2554                        )));
2555                    }
2556                } else {
2557                    self.graph.consts.insert(name.clone(), decl);
2558                }
2559                value_name_map.insert(name.clone(), name.clone());
2560                value_types.insert(name.clone(), decl_dtype);
2561            }
2562
2563            for (onnx_out, webnn_id) in converted.output_mappings {
2564                value_name_map.insert(onnx_out.clone(), webnn_id.clone());
2565                value_name_map.insert(sanitize_identifier(&onnx_out), webnn_id.clone());
2566            }
2567
2568            for (onnx_out, dtype) in converted.output_types {
2569                if let Some(webnn_id) = value_name_map.get(&onnx_out).cloned() {
2570                    value_types.insert(webnn_id, dtype);
2571                }
2572            }
2573
2574            // Track output shapes after conversion to prevent shape inflation
2575            // Use .insert() to force correct shapes (not .or_insert() which preserves old shapes)
2576            if let Some(inferred_shape) =
2577                infer_shape(onnx_node, &value_shapes, &initializers_map, &const_values)
2578            {
2579                for output_name in onnx_node.output.as_slice() {
2580                    // Insert shape for both raw and sanitized names
2581                    value_shapes.insert(output_name.to_string(), inferred_shape.clone());
2582                    value_shapes.insert(sanitize_identifier(output_name), inferred_shape.clone());
2583                }
2584            }
2585
2586            self.graph.nodes.extend(converted.nodes);
2587        }
2588
2589        // Process outputs
2590        for output in onnx_graph.output.as_slice() {
2591            let onnx_name = output.name.as_str();
2592            if let Some(mapped) = value_name_map.get(onnx_name) {
2593                self.graph
2594                    .outputs
2595                    .insert(sanitize_identifier(onnx_name), mapped.clone());
2596            } else {
2597                return Err(OnnxError::InvalidShape(format!(
2598                    "No WebNN value found for ONNX output '{}'",
2599                    onnx_name
2600                )));
2601            }
2602        }
2603
2604        let has_dynamic_inputs = self.graph.inputs.values().any(|operand| {
2605            operand
2606                .shape
2607                .iter()
2608                .any(|dim| matches!(dim, Dimension::Dynamic(_)))
2609        });
2610        self.graph.version = if has_dynamic_inputs { 2 } else { 1 };
2611
2612        Ok(self.graph)
2613    }
2614}
2615
2616/// Convert an ONNX file to WebNN format with optional weight extraction
2617pub fn convert_onnx<P: AsRef<Path>>(
2618    onnx_path: P,
2619    mut options: ConvertOptions,
2620) -> Result<GraphJson, OnnxError> {
2621    // Read ONNX file
2622    let onnx_path_ref = onnx_path.as_ref();
2623    let onnx_bytes = fs::read(onnx_path_ref)?;
2624
2625    // Parse protobuf
2626    let mut model: ModelProto =
2627        ModelProto::decode(&onnx_bytes[..]).map_err(|e| OnnxError::ProtobufError(e.to_string()))?;
2628
2629    // Apply constant folding if optimize flag is set
2630    if options.optimize {
2631        crate::debug_println!("Running constant folding...");
2632        let evaluators = crate::onnx::constant_folding::evaluators::get_evaluators();
2633        let nodes_folded =
2634            crate::onnx::constant_folding::fold_constants_in_model(&mut model, &evaluators)?;
2635        crate::debug_println!("Constant folding: {} nodes folded", nodes_folded);
2636    }
2637
2638    // Merge overrides from sidecar dims file if provided implicitly and not already set
2639    if options.free_dim_overrides.is_empty() {
2640        let mut sidecar = onnx_path_ref.to_path_buf();
2641        sidecar.set_extension("dims.json");
2642        if sidecar.exists() {
2643            let content = fs::read_to_string(&sidecar)?;
2644            if let Ok(json) = serde_json::from_str::<JsonValue>(&content) {
2645                if let Some(obj) = json
2646                    .get("freeDimensionOverrides")
2647                    .unwrap_or(&json)
2648                    .as_object()
2649                {
2650                    for (name, value) in obj {
2651                        if let Some(v) = value.as_u64() {
2652                            options
2653                                .free_dim_overrides
2654                                .entry(name.clone())
2655                                .or_insert(v as u32);
2656                        }
2657                    }
2658                }
2659            }
2660        }
2661    }
2662
2663    // Create converter
2664    let converter = OnnxConverter::new(model.clone())?;
2665
2666    // Extract metadata for debugging
2667    converter.extract_metadata()?;
2668
2669    // Convert to GraphJson
2670    let mut graph = converter.convert(&options)?;
2671
2672    // Extract weights if requested
2673    if options.extract_weights {
2674        if let (Some(weights_path), Some(manifest_path)) =
2675            (&options.weights_path, &options.manifest_path)
2676        {
2677            extract_weights_from_onnx(&model, &mut graph, weights_path, manifest_path)?;
2678        }
2679    }
2680
2681    Ok(graph)
2682}
2683
2684/// Extract weights from ONNX model to .weights and .manifest.json files.
2685/// Also extracts large inline constants from the converted graph into the weights file.
2686fn extract_weights_from_onnx(
2687    model: &ModelProto,
2688    graph: &mut GraphJson,
2689    weights_path: &str,
2690    manifest_path: &str,
2691) -> Result<(), OnnxError> {
2692    use crate::weights::{TensorEntry, WeightsManifest};
2693
2694    if model.graph.is_none() {
2695        return Err(OnnxError::ProtobufError(
2696            "Missing graph in model".to_string(),
2697        ));
2698    }
2699
2700    let onnx_graph = model.graph.as_ref().unwrap();
2701    let mut manifest = WeightsManifest {
2702        format: "wg-weights-manifest".to_string(),
2703        version: 1,
2704        endianness: "little".to_string(),
2705        tensors: BTreeMap::new(),
2706    };
2707
2708    let mut weights_data = Vec::new();
2709    let mut current_offset = 0u64;
2710
2711    // Process each initializer
2712    for initializer in onnx_graph.initializer.as_slice() {
2713        let name = sanitize_identifier(initializer.name.as_str());
2714
2715        // Convert ONNX data type enum to i32, then to WebNN DataType
2716        let onnx_type = initializer.data_type;
2717        let data_type = map_onnx_data_type(onnx_type)?;
2718
2719        let shape: Vec<u32> = initializer
2720            .dims
2721            .as_slice()
2722            .iter()
2723            .map(|d| *d as u32)
2724            .collect();
2725        let raw_data = initializer.raw_data.as_slice();
2726
2727        // Convert typed data to bytes if raw_data is empty
2728        let bytes_to_write: Vec<u8> = if raw_data.is_empty() {
2729            // Try to extract from typed data fields
2730            let int64_data = initializer.int64_data.as_slice();
2731            let float_data = initializer.float_data.as_slice();
2732            let int32_data = initializer.int32_data.as_slice();
2733            let double_data = initializer.double_data.as_slice();
2734
2735            if !int64_data.is_empty() {
2736                // Convert int64_data to bytes (little-endian)
2737                int64_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2738            } else if !float_data.is_empty() {
2739                // Convert float_data to bytes (little-endian)
2740                float_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2741            } else if !int32_data.is_empty() {
2742                // Convert int32_data to bytes (little-endian)
2743                int32_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2744            } else if !double_data.is_empty() {
2745                // Convert double_data to bytes (little-endian)
2746                double_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2747            } else {
2748                // No data at all - skip this initializer
2749                crate::debug_println!("Warning: Skipping initializer '{}' with no data", name);
2750                continue;
2751            }
2752        } else {
2753            raw_data.to_vec()
2754        };
2755
2756        let byte_length = bytes_to_write.len() as u64;
2757
2758        // Add to manifest
2759        manifest.tensors.insert(
2760            name,
2761            TensorEntry {
2762                data_type,
2763                shape,
2764                byte_offset: current_offset,
2765                byte_length,
2766                layout: None,
2767            },
2768        );
2769
2770        // Append to weights data
2771        weights_data.extend_from_slice(&bytes_to_write);
2772        current_offset += byte_length;
2773    }
2774
2775    // Extract large inline constants from the graph into the weights file.
2776    // Threshold: constants larger than 1 KiB are moved to external weights.
2777    const INLINE_THRESHOLD: usize = 1024;
2778    for (name, decl) in graph.consts.iter_mut() {
2779        if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2780            if bytes.len() > INLINE_THRESHOLD && !manifest.tensors.contains_key(name) {
2781                let byte_length = bytes.len() as u64;
2782                manifest.tensors.insert(
2783                    name.clone(),
2784                    TensorEntry {
2785                        data_type: decl.data_type.clone(),
2786                        shape: decl.shape.clone(),
2787                        byte_offset: current_offset,
2788                        byte_length,
2789                        layout: None,
2790                    },
2791                );
2792                weights_data.extend_from_slice(bytes);
2793                current_offset += byte_length;
2794            }
2795        }
2796    }
2797    // Update the graph consts to use weight references instead of inline bytes
2798    for (name, decl) in graph.consts.iter_mut() {
2799        if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2800            if bytes.len() > INLINE_THRESHOLD {
2801                decl.init = crate::ast::ConstInit::Weights {
2802                    r#ref: name.clone(),
2803                };
2804            }
2805        }
2806    }
2807
2808    // Write weights file
2809    fs::write(weights_path, &weights_data)?;
2810
2811    // Write manifest file
2812    let manifest_json = serde_json::to_string_pretty(&manifest)
2813        .map_err(|e| OnnxError::ProtobufError(e.to_string()))?;
2814    fs::write(manifest_path, manifest_json)?;
2815
2816    Ok(())
2817}
2818
2819#[cfg(test)]
2820mod tests {
2821    use super::*;
2822
2823    #[test]
2824    fn test_convert_options_default() {
2825        let options = ConvertOptions::default();
2826        assert!(options.extract_weights);
2827        assert_eq!(options.output_path, "output.webnn");
2828    }
2829
2830    #[test]
2831    fn test_sanitize_identifier_replaces_colons() {
2832        assert_eq!(sanitize_identifier("foo::bar"), "foo__bar");
2833        assert_eq!(sanitize_identifier("foo:bar"), "foo_bar");
2834    }
2835
2836    #[test]
2837    fn test_sanitize_identifier_replaces_dots() {
2838        assert_eq!(sanitize_identifier("encoder.block.0"), "encoder_block_0");
2839        assert_eq!(
2840            sanitize_identifier("model.layer.weight"),
2841            "model_layer_weight"
2842        );
2843        assert_eq!(sanitize_identifier("a.b.c"), "a_b_c");
2844    }
2845
2846    #[test]
2847    fn test_sanitize_identifier_replaces_combined() {
2848        // Test combinations of :: : and .
2849        assert_eq!(
2850            sanitize_identifier("module::class:method.field"),
2851            "module__class_method_field"
2852        );
2853        assert_eq!(
2854            sanitize_identifier("encoder.attention::output:dense"),
2855            "encoder_attention__output_dense"
2856        );
2857    }
2858
2859    #[test]
2860    fn test_sanitize_identifier_no_change() {
2861        // Identifiers that don't need sanitization
2862        assert_eq!(sanitize_identifier("simple_name"), "simple_name");
2863        assert_eq!(sanitize_identifier("CamelCase"), "CamelCase");
2864        assert_eq!(sanitize_identifier("name123"), "name123");
2865    }
2866
2867    #[test]
2868    fn test_inline_bytes_encoding_for_i64_values() {
2869        // Test the inline bytes encoding logic used for non-scalar constants
2870        // This simulates what happens when Range or similar ops produce constant arrays
2871        let values: Vec<i64> = vec![0, 1, 2, 3, 4];
2872        let mut bytes = Vec::with_capacity(values.len() * 8);
2873        for v in values {
2874            bytes.extend_from_slice(&v.to_le_bytes());
2875        }
2876
2877        // Verify byte length
2878        assert_eq!(bytes.len(), 40); // 5 values * 8 bytes each
2879
2880        // Verify first value (0)
2881        let first_bytes: [u8; 8] = bytes[0..8].try_into().unwrap();
2882        assert_eq!(i64::from_le_bytes(first_bytes), 0);
2883
2884        // Verify last value (4)
2885        let last_bytes: [u8; 8] = bytes[32..40].try_into().unwrap();
2886        assert_eq!(i64::from_le_bytes(last_bytes), 4);
2887    }
2888
2889    #[test]
2890    fn test_inline_bytes_encoding_single_value() {
2891        // Test single value encoding
2892        let values: Vec<i64> = vec![42];
2893        let mut bytes = Vec::with_capacity(values.len() * 8);
2894        for v in values {
2895            bytes.extend_from_slice(&v.to_le_bytes());
2896        }
2897
2898        assert_eq!(bytes.len(), 8);
2899        let decoded: [u8; 8] = bytes.try_into().unwrap();
2900        assert_eq!(i64::from_le_bytes(decoded), 42);
2901    }
2902
2903    #[test]
2904    fn test_inline_bytes_encoding_negative_values() {
2905        // Test with negative values (important for Range with negative delta)
2906        let values: Vec<i64> = vec![5, 4, 3, 2, 1, 0, -1, -2];
2907        let mut bytes = Vec::with_capacity(values.len() * 8);
2908        for v in values {
2909            bytes.extend_from_slice(&v.to_le_bytes());
2910        }
2911
2912        assert_eq!(bytes.len(), 64); // 8 values * 8 bytes each
2913
2914        // Verify a negative value
2915        let neg_bytes: [u8; 8] = bytes[56..64].try_into().unwrap();
2916        assert_eq!(i64::from_le_bytes(neg_bytes), -2);
2917    }
2918
2919    #[test]
2920    fn test_inline_bytes_encoding_large_values() {
2921        // Test with large i64 values
2922        let values: Vec<i64> = vec![i64::MAX, i64::MIN, 0];
2923        let mut bytes = Vec::with_capacity(values.len() * 8);
2924        for v in values {
2925            bytes.extend_from_slice(&v.to_le_bytes());
2926        }
2927
2928        assert_eq!(bytes.len(), 24);
2929
2930        // Verify MAX value
2931        let max_bytes: [u8; 8] = bytes[0..8].try_into().unwrap();
2932        assert_eq!(i64::from_le_bytes(max_bytes), i64::MAX);
2933
2934        // Verify MIN value
2935        let min_bytes: [u8; 8] = bytes[8..16].try_into().unwrap();
2936        assert_eq!(i64::from_le_bytes(min_bytes), i64::MIN);
2937    }
2938
2939    #[test]
2940    fn test_convert_preserves_dynamic_input_dim_without_override() {
2941        use crate::protos::onnx::{tensor_shape_proto, type_proto};
2942        use crate::protos::onnx::{GraphProto, ModelProto, TensorShapeProto, ValueInfoProto};
2943
2944        let dim_batch = tensor_shape_proto::Dimension {
2945            value: Some(tensor_shape_proto::dimension::Value::DimParam(
2946                "batch_size".to_string(),
2947            )),
2948            denotation: String::new(),
2949        };
2950        let dim_seq = tensor_shape_proto::Dimension {
2951            value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
2952            denotation: String::new(),
2953        };
2954        let shape = TensorShapeProto {
2955            dim: vec![dim_batch, dim_seq],
2956        };
2957
2958        let tensor_type = type_proto::Tensor {
2959            elem_type: TensorProto_DataType::Int64.into(),
2960            shape: Some(shape),
2961        };
2962        let type_proto = crate::protos::onnx::TypeProto {
2963            value: Some(type_proto::Value::TensorType(tensor_type)),
2964            denotation: String::new(),
2965        };
2966
2967        let input_vi = ValueInfoProto {
2968            name: "input_ids".to_string(),
2969            r#type: Some(type_proto.clone()),
2970            ..Default::default()
2971        };
2972        let output_vi = ValueInfoProto {
2973            name: "input_ids".to_string(),
2974            r#type: Some(type_proto),
2975            ..Default::default()
2976        };
2977
2978        let model = ModelProto {
2979            graph: Some(GraphProto {
2980                input: vec![input_vi],
2981                output: vec![output_vi],
2982                ..Default::default()
2983            }),
2984            ..Default::default()
2985        };
2986
2987        let converter = OnnxConverter::new(model).expect("converter");
2988        let graph = converter
2989            .convert(&ConvertOptions {
2990                experimental_dynamic_inputs: true,
2991                ..ConvertOptions::default()
2992            })
2993            .expect("convert");
2994
2995        let input = graph.inputs.get("input_ids").expect("input_ids input");
2996        assert_eq!(input.shape.len(), 2);
2997        assert!(matches!(
2998            &input.shape[0],
2999            Dimension::Dynamic(d) if d.name == "batch_size"
3000        ));
3001        assert!(matches!(&input.shape[1], Dimension::Static(1)));
3002        assert_eq!(graph.version, 2);
3003    }
3004
3005    #[test]
3006    fn test_convert_rejects_dynamic_input_dim_without_flag() {
3007        use crate::protos::onnx::{tensor_shape_proto, type_proto};
3008        use crate::protos::onnx::{GraphProto, ModelProto, TensorShapeProto, ValueInfoProto};
3009
3010        let dim_batch = tensor_shape_proto::Dimension {
3011            value: Some(tensor_shape_proto::dimension::Value::DimParam(
3012                "unknown_dim".to_string(),
3013            )),
3014            denotation: String::new(),
3015        };
3016        let dim_seq = tensor_shape_proto::Dimension {
3017            value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3018            denotation: String::new(),
3019        };
3020        let shape = TensorShapeProto {
3021            dim: vec![dim_batch, dim_seq],
3022        };
3023
3024        let tensor_type = type_proto::Tensor {
3025            elem_type: TensorProto_DataType::Int64.into(),
3026            shape: Some(shape),
3027        };
3028        let type_proto = crate::protos::onnx::TypeProto {
3029            value: Some(type_proto::Value::TensorType(tensor_type)),
3030            denotation: String::new(),
3031        };
3032
3033        let input_vi = ValueInfoProto {
3034            name: "input_ids".to_string(),
3035            r#type: Some(type_proto.clone()),
3036            ..Default::default()
3037        };
3038        let output_vi = ValueInfoProto {
3039            name: "input_ids".to_string(),
3040            r#type: Some(type_proto),
3041            ..Default::default()
3042        };
3043
3044        let model = ModelProto {
3045            graph: Some(GraphProto {
3046                input: vec![input_vi],
3047                output: vec![output_vi],
3048                ..Default::default()
3049            }),
3050            ..Default::default()
3051        };
3052
3053        let converter = OnnxConverter::new(model).expect("converter");
3054        let err = converter
3055            .convert(&ConvertOptions::default())
3056            .expect_err("should require overrides or flag");
3057        let msg = err.to_string();
3058        assert!(msg.contains("override-dim"));
3059        assert!(msg.contains("experimental-dynamic-inputs"));
3060    }
3061
3062    #[test]
3063    fn test_convert_dynamic_shape_concat_reshape_path_with_experimental_flag() {
3064        use crate::protos::onnx::{tensor_shape_proto, type_proto};
3065        use crate::protos::onnx::{
3066            AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3067            ValueInfoProto,
3068        };
3069
3070        let batch_dim = tensor_shape_proto::Dimension {
3071            value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3072            denotation: String::new(),
3073        };
3074        let seq_dim = tensor_shape_proto::Dimension {
3075            value: Some(tensor_shape_proto::dimension::Value::DimParam(
3076                "sequence_length".to_string(),
3077            )),
3078            denotation: String::new(),
3079        };
3080        let hidden_dim = tensor_shape_proto::Dimension {
3081            value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3082            denotation: String::new(),
3083        };
3084        let data_shape = TensorShapeProto {
3085            dim: vec![batch_dim, seq_dim, hidden_dim],
3086        };
3087
3088        let data_tensor_type = type_proto::Tensor {
3089            elem_type: TensorProto_DataType::Float.into(),
3090            shape: Some(data_shape),
3091        };
3092        let data_type_proto = crate::protos::onnx::TypeProto {
3093            value: Some(type_proto::Value::TensorType(data_tensor_type)),
3094            denotation: String::new(),
3095        };
3096
3097        let data_input = ValueInfoProto {
3098            name: "data".to_string(),
3099            r#type: Some(data_type_proto.clone()),
3100            ..Default::default()
3101        };
3102        let data_output = ValueInfoProto {
3103            name: "out".to_string(),
3104            r#type: Some(data_type_proto),
3105            ..Default::default()
3106        };
3107
3108        let idx0 = TensorProto {
3109            name: "idx0".to_string(),
3110            data_type: TensorProto_DataType::Int64 as i32,
3111            dims: vec![1],
3112            int64_data: vec![0],
3113            ..Default::default()
3114        };
3115        let idx1 = TensorProto {
3116            name: "idx1".to_string(),
3117            data_type: TensorProto_DataType::Int64 as i32,
3118            dims: vec![1],
3119            int64_data: vec![1],
3120            ..Default::default()
3121        };
3122        let last_dim = TensorProto {
3123            name: "last_dim".to_string(),
3124            data_type: TensorProto_DataType::Int64 as i32,
3125            dims: vec![1],
3126            int64_data: vec![4],
3127            ..Default::default()
3128        };
3129
3130        let shape_node = NodeProto {
3131            op_type: "Shape".to_string(),
3132            input: vec!["data".to_string()],
3133            output: vec!["shape_out".to_string()],
3134            ..Default::default()
3135        };
3136        let gather0 = NodeProto {
3137            op_type: "Gather".to_string(),
3138            input: vec!["shape_out".to_string(), "idx0".to_string()],
3139            output: vec!["dim0".to_string()],
3140            attribute: vec![AttributeProto {
3141                name: "axis".to_string(),
3142                i: 0,
3143                ..Default::default()
3144            }],
3145            ..Default::default()
3146        };
3147        let gather1 = NodeProto {
3148            op_type: "Gather".to_string(),
3149            input: vec!["shape_out".to_string(), "idx1".to_string()],
3150            output: vec!["dim1".to_string()],
3151            attribute: vec![AttributeProto {
3152                name: "axis".to_string(),
3153                i: 0,
3154                ..Default::default()
3155            }],
3156            ..Default::default()
3157        };
3158        let concat_shape = NodeProto {
3159            op_type: "Concat".to_string(),
3160            input: vec![
3161                "dim0".to_string(),
3162                "dim1".to_string(),
3163                "last_dim".to_string(),
3164            ],
3165            output: vec!["shape_for_reshape".to_string()],
3166            attribute: vec![AttributeProto {
3167                name: "axis".to_string(),
3168                i: 0,
3169                ..Default::default()
3170            }],
3171            ..Default::default()
3172        };
3173        let reshape = NodeProto {
3174            op_type: "Reshape".to_string(),
3175            input: vec!["data".to_string(), "shape_for_reshape".to_string()],
3176            output: vec!["out".to_string()],
3177            ..Default::default()
3178        };
3179
3180        let model = ModelProto {
3181            graph: Some(GraphProto {
3182                input: vec![data_input],
3183                output: vec![data_output],
3184                initializer: vec![idx0, idx1, last_dim],
3185                node: vec![shape_node, gather0, gather1, concat_shape, reshape],
3186                ..Default::default()
3187            }),
3188            ..Default::default()
3189        };
3190
3191        let converter = OnnxConverter::new(model).expect("converter");
3192        let graph = converter
3193            .convert(&ConvertOptions {
3194                optimize: true,
3195                experimental_dynamic_inputs: true,
3196                extract_weights: false,
3197                ..ConvertOptions::default()
3198            })
3199            .expect("dynamic reshape path should convert");
3200
3201        let reshape_node = graph
3202            .nodes
3203            .iter()
3204            .find(|n| n.op == "reshape")
3205            .expect("reshape node should exist");
3206        let shape = reshape_node
3207            .options
3208            .get("newShape")
3209            .and_then(|v| v.as_array())
3210            .expect("newShape should be an array");
3211        assert_eq!(shape.len(), 3);
3212        assert_eq!(shape[0].as_u64(), Some(1));
3213        assert_eq!(shape[2].as_u64(), Some(4));
3214        // The sequence dimension may be a concrete integer (concretized for lowering)
3215        // or a dynamic dimension object {"name": ..., "maxSize": N} when dynamic
3216        // dimension metadata is propagated.
3217        let dim1_ok = shape[1].as_u64().is_some_and(|v| v > 0)
3218            || shape[1].as_object().is_some_and(|o| {
3219                o.contains_key("name")
3220                    && o.get("maxSize")
3221                        .and_then(|v| v.as_u64())
3222                        .is_some_and(|v| v > 0)
3223            });
3224        assert!(
3225            dim1_ok,
3226            "sequence dimension should be concretized or dynamic for lowering, got: {:?}",
3227            shape[1]
3228        );
3229    }
3230
3231    #[test]
3232    fn test_convert_reshape_shape_path_survives_add_broadcast() {
3233        use crate::protos::onnx::{tensor_shape_proto, type_proto};
3234        use crate::protos::onnx::{
3235            AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3236            ValueInfoProto,
3237        };
3238
3239        let batch_dim = tensor_shape_proto::Dimension {
3240            value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3241            denotation: String::new(),
3242        };
3243        let seq_dim = tensor_shape_proto::Dimension {
3244            value: Some(tensor_shape_proto::dimension::Value::DimValue(128)),
3245            denotation: String::new(),
3246        };
3247        let hidden_dim = tensor_shape_proto::Dimension {
3248            value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3249            denotation: String::new(),
3250        };
3251        let data_shape = TensorShapeProto {
3252            dim: vec![batch_dim, seq_dim, hidden_dim],
3253        };
3254
3255        let data_tensor_type = type_proto::Tensor {
3256            elem_type: TensorProto_DataType::Float.into(),
3257            shape: Some(data_shape),
3258        };
3259        let data_type_proto = crate::protos::onnx::TypeProto {
3260            value: Some(type_proto::Value::TensorType(data_tensor_type)),
3261            denotation: String::new(),
3262        };
3263
3264        let data_input = ValueInfoProto {
3265            name: "data".to_string(),
3266            r#type: Some(data_type_proto.clone()),
3267            ..Default::default()
3268        };
3269        let data_output = ValueInfoProto {
3270            name: "out".to_string(),
3271            r#type: Some(data_type_proto),
3272            ..Default::default()
3273        };
3274
3275        let bias = TensorProto {
3276            name: "bias".to_string(),
3277            data_type: TensorProto_DataType::Float as i32,
3278            dims: vec![4],
3279            float_data: vec![0.0, 0.0, 0.0, 0.0],
3280            ..Default::default()
3281        };
3282        let idx0 = TensorProto {
3283            name: "idx0".to_string(),
3284            data_type: TensorProto_DataType::Int64 as i32,
3285            dims: vec![1],
3286            int64_data: vec![0],
3287            ..Default::default()
3288        };
3289        let idx1 = TensorProto {
3290            name: "idx1".to_string(),
3291            data_type: TensorProto_DataType::Int64 as i32,
3292            dims: vec![1],
3293            int64_data: vec![1],
3294            ..Default::default()
3295        };
3296        let last_dim = TensorProto {
3297            name: "last_dim".to_string(),
3298            data_type: TensorProto_DataType::Int64 as i32,
3299            dims: vec![1],
3300            int64_data: vec![4],
3301            ..Default::default()
3302        };
3303
3304        let add_node = NodeProto {
3305            op_type: "Add".to_string(),
3306            input: vec!["data".to_string(), "bias".to_string()],
3307            output: vec!["add_out".to_string()],
3308            ..Default::default()
3309        };
3310        let shape_node = NodeProto {
3311            op_type: "Shape".to_string(),
3312            input: vec!["add_out".to_string()],
3313            output: vec!["shape_out".to_string()],
3314            ..Default::default()
3315        };
3316        let gather0 = NodeProto {
3317            op_type: "Gather".to_string(),
3318            input: vec!["shape_out".to_string(), "idx0".to_string()],
3319            output: vec!["dim0".to_string()],
3320            attribute: vec![AttributeProto {
3321                name: "axis".to_string(),
3322                i: 0,
3323                ..Default::default()
3324            }],
3325            ..Default::default()
3326        };
3327        let gather1 = NodeProto {
3328            op_type: "Gather".to_string(),
3329            input: vec!["shape_out".to_string(), "idx1".to_string()],
3330            output: vec!["dim1".to_string()],
3331            attribute: vec![AttributeProto {
3332                name: "axis".to_string(),
3333                i: 0,
3334                ..Default::default()
3335            }],
3336            ..Default::default()
3337        };
3338        let concat_shape = NodeProto {
3339            op_type: "Concat".to_string(),
3340            input: vec![
3341                "dim0".to_string(),
3342                "dim1".to_string(),
3343                "last_dim".to_string(),
3344            ],
3345            output: vec!["shape_for_reshape".to_string()],
3346            attribute: vec![AttributeProto {
3347                name: "axis".to_string(),
3348                i: 0,
3349                ..Default::default()
3350            }],
3351            ..Default::default()
3352        };
3353        let reshape = NodeProto {
3354            op_type: "Reshape".to_string(),
3355            input: vec!["add_out".to_string(), "shape_for_reshape".to_string()],
3356            output: vec!["out".to_string()],
3357            ..Default::default()
3358        };
3359
3360        let model = ModelProto {
3361            graph: Some(GraphProto {
3362                input: vec![data_input],
3363                output: vec![data_output],
3364                initializer: vec![bias, idx0, idx1, last_dim],
3365                node: vec![
3366                    add_node,
3367                    shape_node,
3368                    gather0,
3369                    gather1,
3370                    concat_shape,
3371                    reshape,
3372                ],
3373                ..Default::default()
3374            }),
3375            ..Default::default()
3376        };
3377
3378        let converter = OnnxConverter::new(model).expect("converter");
3379        let graph = converter
3380            .convert(&ConvertOptions {
3381                optimize: true,
3382                extract_weights: false,
3383                ..ConvertOptions::default()
3384            })
3385            .expect("broadcasted shape path should convert");
3386
3387        let reshape_node = graph
3388            .nodes
3389            .iter()
3390            .find(|n| n.op == "reshape")
3391            .expect("reshape node should exist");
3392        assert_eq!(
3393            reshape_node.options.get("newShape"),
3394            Some(&serde_json::json!([1, 128, 4]))
3395        );
3396    }
3397
3398    #[test]
3399    fn test_convert_dynamic_range_lowers_to_slice_and_preserves_dynamic_reshape() {
3400        use crate::protos::onnx::{tensor_shape_proto, type_proto};
3401        use crate::protos::onnx::{
3402            AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3403            ValueInfoProto,
3404        };
3405
3406        let seq_dim = tensor_shape_proto::Dimension {
3407            value: Some(tensor_shape_proto::dimension::Value::DimParam(
3408                "sequence_length".to_string(),
3409            )),
3410            denotation: String::new(),
3411        };
3412        let data_shape = TensorShapeProto { dim: vec![seq_dim] };
3413
3414        let data_tensor_type = type_proto::Tensor {
3415            elem_type: TensorProto_DataType::Float.into(),
3416            shape: Some(data_shape),
3417        };
3418        let data_type_proto = crate::protos::onnx::TypeProto {
3419            value: Some(type_proto::Value::TensorType(data_tensor_type)),
3420            denotation: String::new(),
3421        };
3422
3423        let data_input = ValueInfoProto {
3424            name: "data".to_string(),
3425            r#type: Some(data_type_proto),
3426            ..Default::default()
3427        };
3428        let output_vi = ValueInfoProto {
3429            name: "out".to_string(),
3430            ..Default::default()
3431        };
3432
3433        let idx0 = TensorProto {
3434            name: "idx0".to_string(),
3435            data_type: TensorProto_DataType::Int64 as i32,
3436            dims: vec![1],
3437            int64_data: vec![0],
3438            ..Default::default()
3439        };
3440        let zero = TensorProto {
3441            name: "zero".to_string(),
3442            data_type: TensorProto_DataType::Int64 as i32,
3443            dims: vec![],
3444            int64_data: vec![0],
3445            ..Default::default()
3446        };
3447        let one = TensorProto {
3448            name: "one".to_string(),
3449            data_type: TensorProto_DataType::Int64 as i32,
3450            dims: vec![],
3451            int64_data: vec![1],
3452            ..Default::default()
3453        };
3454
3455        let shape_node = NodeProto {
3456            op_type: "Shape".to_string(),
3457            input: vec!["data".to_string()],
3458            output: vec!["shape_out".to_string()],
3459            ..Default::default()
3460        };
3461        let gather = NodeProto {
3462            op_type: "Gather".to_string(),
3463            input: vec!["shape_out".to_string(), "idx0".to_string()],
3464            output: vec!["seq_len".to_string()],
3465            attribute: vec![AttributeProto {
3466                name: "axis".to_string(),
3467                i: 0,
3468                ..Default::default()
3469            }],
3470            ..Default::default()
3471        };
3472        let add_limit = NodeProto {
3473            op_type: "Add".to_string(),
3474            input: vec!["seq_len".to_string(), "one".to_string()],
3475            output: vec!["range_limit".to_string()],
3476            ..Default::default()
3477        };
3478        let range = NodeProto {
3479            op_type: "Range".to_string(),
3480            input: vec![
3481                "zero".to_string(),
3482                "range_limit".to_string(),
3483                "one".to_string(),
3484            ],
3485            output: vec!["range_out".to_string()],
3486            ..Default::default()
3487        };
3488        let concat_shape = NodeProto {
3489            op_type: "Concat".to_string(),
3490            input: vec!["range_limit".to_string(), "one".to_string()],
3491            output: vec!["shape_for_reshape".to_string()],
3492            attribute: vec![AttributeProto {
3493                name: "axis".to_string(),
3494                i: 0,
3495                ..Default::default()
3496            }],
3497            ..Default::default()
3498        };
3499        let reshape = NodeProto {
3500            op_type: "Reshape".to_string(),
3501            input: vec!["range_out".to_string(), "shape_for_reshape".to_string()],
3502            output: vec!["out".to_string()],
3503            ..Default::default()
3504        };
3505
3506        let model = ModelProto {
3507            graph: Some(GraphProto {
3508                input: vec![data_input],
3509                output: vec![output_vi],
3510                initializer: vec![idx0, zero, one],
3511                node: vec![shape_node, gather, add_limit, range, concat_shape, reshape],
3512                ..Default::default()
3513            }),
3514            ..Default::default()
3515        };
3516
3517        let converter = OnnxConverter::new(model).expect("converter");
3518        let graph = converter
3519            .convert(&ConvertOptions {
3520                optimize: true,
3521                experimental_dynamic_inputs: true,
3522                extract_weights: false,
3523                ..ConvertOptions::default()
3524            })
3525            .expect("dynamic range path should convert");
3526
3527        let slice_node = graph
3528            .nodes
3529            .iter()
3530            .find(|n| n.op == "slice")
3531            .expect("range should lower to slice");
3532        let slice_sizes = slice_node
3533            .options
3534            .get("sizes")
3535            .and_then(|v| v.as_array())
3536            .expect("slice sizes should exist");
3537        assert_eq!(slice_sizes.len(), 1);
3538        let dynamic_size = slice_sizes[0]
3539            .as_object()
3540            .expect("dynamic range size should be a dimension object");
3541        assert_eq!(
3542            dynamic_size.get("name").and_then(|v| v.as_str()),
3543            Some("sequence_length + 1")
3544        );
3545        assert_eq!(
3546            dynamic_size.get("maxSize").and_then(|v| v.as_u64()),
3547            Some(4097)
3548        );
3549
3550        let reshape_node = graph
3551            .nodes
3552            .iter()
3553            .find(|n| n.op == "reshape")
3554            .expect("reshape node should exist");
3555        let new_shape = reshape_node
3556            .options
3557            .get("newShape")
3558            .and_then(|v| v.as_array())
3559            .expect("reshape newShape should exist");
3560        assert_eq!(new_shape.len(), 2);
3561        assert_eq!(new_shape[1].as_u64(), Some(1));
3562        let reshape_dim0 = new_shape[0]
3563            .as_object()
3564            .expect("reshape dim 0 should stay dynamic");
3565        assert_eq!(
3566            reshape_dim0.get("name").and_then(|v| v.as_str()),
3567            Some("sequence_length + 1")
3568        );
3569        assert_eq!(
3570            reshape_dim0.get("maxSize").and_then(|v| v.as_u64()),
3571            Some(4097)
3572        );
3573    }
3574
3575    #[test]
3576    fn test_convert_dynamic_range_with_dynamic_start_lowers_to_slice_and_add() {
3577        use crate::protos::onnx::{tensor_shape_proto, type_proto};
3578        use crate::protos::onnx::{
3579            AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3580            ValueInfoProto,
3581        };
3582
3583        let batch_dim = tensor_shape_proto::Dimension {
3584            value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3585            denotation: String::new(),
3586        };
3587        let seq_dim = tensor_shape_proto::Dimension {
3588            value: Some(tensor_shape_proto::dimension::Value::DimParam(
3589                "sequence_length".to_string(),
3590            )),
3591            denotation: String::new(),
3592        };
3593        let past_dim = tensor_shape_proto::Dimension {
3594            value: Some(tensor_shape_proto::dimension::Value::DimParam(
3595                "past_sequence_length".to_string(),
3596            )),
3597            denotation: String::new(),
3598        };
3599        let heads_dim = tensor_shape_proto::Dimension {
3600            value: Some(tensor_shape_proto::dimension::Value::DimValue(3)),
3601            denotation: String::new(),
3602        };
3603        let head_dim = tensor_shape_proto::Dimension {
3604            value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3605            denotation: String::new(),
3606        };
3607
3608        let ids_shape = TensorShapeProto {
3609            dim: vec![batch_dim.clone(), seq_dim.clone()],
3610        };
3611        let past_shape = TensorShapeProto {
3612            dim: vec![batch_dim, heads_dim, past_dim, head_dim],
3613        };
3614        let range_shape = TensorShapeProto {
3615            dim: vec![seq_dim.clone()],
3616        };
3617        let out_shape = TensorShapeProto {
3618            dim: vec![
3619                seq_dim,
3620                tensor_shape_proto::Dimension {
3621                    value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3622                    denotation: String::new(),
3623                },
3624            ],
3625        };
3626
3627        let ids_tensor_type = type_proto::Tensor {
3628            elem_type: TensorProto_DataType::Int64.into(),
3629            shape: Some(ids_shape),
3630        };
3631        let past_tensor_type = type_proto::Tensor {
3632            elem_type: TensorProto_DataType::Float.into(),
3633            shape: Some(past_shape),
3634        };
3635        let range_tensor_type = type_proto::Tensor {
3636            elem_type: TensorProto_DataType::Int64.into(),
3637            shape: Some(range_shape),
3638        };
3639        let out_tensor_type = type_proto::Tensor {
3640            elem_type: TensorProto_DataType::Int64.into(),
3641            shape: Some(out_shape),
3642        };
3643
3644        let ids_input = ValueInfoProto {
3645            name: "ids".to_string(),
3646            r#type: Some(crate::protos::onnx::TypeProto {
3647                value: Some(type_proto::Value::TensorType(ids_tensor_type)),
3648                denotation: String::new(),
3649            }),
3650            ..Default::default()
3651        };
3652        let past_input = ValueInfoProto {
3653            name: "past".to_string(),
3654            r#type: Some(crate::protos::onnx::TypeProto {
3655                value: Some(type_proto::Value::TensorType(past_tensor_type)),
3656                denotation: String::new(),
3657            }),
3658            ..Default::default()
3659        };
3660        let range_vi = ValueInfoProto {
3661            name: "range_out".to_string(),
3662            r#type: Some(crate::protos::onnx::TypeProto {
3663                value: Some(type_proto::Value::TensorType(range_tensor_type)),
3664                denotation: String::new(),
3665            }),
3666            ..Default::default()
3667        };
3668        let out_vi = ValueInfoProto {
3669            name: "out".to_string(),
3670            r#type: Some(crate::protos::onnx::TypeProto {
3671                value: Some(type_proto::Value::TensorType(out_tensor_type)),
3672                denotation: String::new(),
3673            }),
3674            ..Default::default()
3675        };
3676
3677        let idx1 = TensorProto {
3678            name: "idx1".to_string(),
3679            data_type: TensorProto_DataType::Int64 as i32,
3680            dims: vec![1],
3681            int64_data: vec![1],
3682            ..Default::default()
3683        };
3684        let idx2 = TensorProto {
3685            name: "idx2".to_string(),
3686            data_type: TensorProto_DataType::Int64 as i32,
3687            dims: vec![1],
3688            int64_data: vec![2],
3689            ..Default::default()
3690        };
3691        let one = TensorProto {
3692            name: "one".to_string(),
3693            data_type: TensorProto_DataType::Int64 as i32,
3694            dims: vec![],
3695            int64_data: vec![1],
3696            ..Default::default()
3697        };
3698        let reshape_shape = TensorProto {
3699            name: "reshape_shape".to_string(),
3700            data_type: TensorProto_DataType::Int64 as i32,
3701            dims: vec![2],
3702            int64_data: vec![4096, 1],
3703            ..Default::default()
3704        };
3705
3706        let shape_past = NodeProto {
3707            op_type: "Shape".to_string(),
3708            input: vec!["past".to_string()],
3709            output: vec!["past_shape".to_string()],
3710            ..Default::default()
3711        };
3712        let gather_start = NodeProto {
3713            op_type: "Gather".to_string(),
3714            input: vec!["past_shape".to_string(), "idx2".to_string()],
3715            output: vec!["range_start".to_string()],
3716            attribute: vec![AttributeProto {
3717                name: "axis".to_string(),
3718                i: 0,
3719                ..Default::default()
3720            }],
3721            ..Default::default()
3722        };
3723        let shape_ids = NodeProto {
3724            op_type: "Shape".to_string(),
3725            input: vec!["ids".to_string()],
3726            output: vec!["ids_shape".to_string()],
3727            ..Default::default()
3728        };
3729        let gather_seq = NodeProto {
3730            op_type: "Gather".to_string(),
3731            input: vec!["ids_shape".to_string(), "idx1".to_string()],
3732            output: vec!["seq_len".to_string()],
3733            attribute: vec![AttributeProto {
3734                name: "axis".to_string(),
3735                i: 0,
3736                ..Default::default()
3737            }],
3738            ..Default::default()
3739        };
3740        let add_limit = NodeProto {
3741            op_type: "Add".to_string(),
3742            input: vec!["range_start".to_string(), "seq_len".to_string()],
3743            output: vec!["range_limit".to_string()],
3744            ..Default::default()
3745        };
3746        let range = NodeProto {
3747            op_type: "Range".to_string(),
3748            input: vec![
3749                "range_start".to_string(),
3750                "range_limit".to_string(),
3751                "one".to_string(),
3752            ],
3753            output: vec!["range_out".to_string()],
3754            ..Default::default()
3755        };
3756        let reshape = NodeProto {
3757            op_type: "Reshape".to_string(),
3758            input: vec!["range_out".to_string(), "reshape_shape".to_string()],
3759            output: vec!["out".to_string()],
3760            ..Default::default()
3761        };
3762
3763        let model = ModelProto {
3764            graph: Some(GraphProto {
3765                input: vec![ids_input, past_input],
3766                output: vec![out_vi.clone()],
3767                value_info: vec![range_vi, out_vi],
3768                initializer: vec![idx1, idx2, one, reshape_shape],
3769                node: vec![
3770                    shape_past,
3771                    gather_start,
3772                    shape_ids,
3773                    gather_seq,
3774                    add_limit,
3775                    range,
3776                    reshape,
3777                ],
3778                ..Default::default()
3779            }),
3780            ..Default::default()
3781        };
3782
3783        let converter = OnnxConverter::new(model).expect("converter");
3784        let graph = converter
3785            .convert(&ConvertOptions {
3786                optimize: true,
3787                experimental_dynamic_inputs: true,
3788                extract_weights: false,
3789                ..ConvertOptions::default()
3790            })
3791            .expect("dynamic range with dynamic start should convert");
3792
3793        assert!(
3794            !graph.consts.contains_key("range_out"),
3795            "range output should stay runtime-computed"
3796        );
3797
3798        let slice_node = graph
3799            .nodes
3800            .iter()
3801            .find(|n| n.id == "range_out_slice" && n.op == "slice")
3802            .expect("range should lower to a slice");
3803        let slice_sizes = slice_node
3804            .options
3805            .get("sizes")
3806            .and_then(|v| v.as_array())
3807            .expect("slice sizes should exist");
3808        let dynamic_size = slice_sizes[0]
3809            .as_object()
3810            .expect("slice size should be dynamic");
3811        assert_eq!(
3812            dynamic_size.get("name").and_then(|v| v.as_str()),
3813            Some("sequence_length")
3814        );
3815        assert_eq!(
3816            dynamic_size.get("maxSize").and_then(|v| v.as_u64()),
3817            Some(4096)
3818        );
3819
3820        let add_node = graph
3821            .nodes
3822            .iter()
3823            .find(|n| n.id == "range_out" && n.op == "add")
3824            .expect("dynamic-start range should add the runtime start offset");
3825        assert_eq!(add_node.inputs.len(), 2);
3826        assert_eq!(add_node.inputs[0], "range_out_slice");
3827
3828        let reshape_node = graph
3829            .nodes
3830            .iter()
3831            .find(|n| n.op == "reshape")
3832            .expect("reshape node should exist");
3833        let new_shape = reshape_node
3834            .options
3835            .get("newShape")
3836            .and_then(|v| v.as_array())
3837            .expect("reshape newShape should exist");
3838        assert_eq!(new_shape.len(), 2);
3839        assert_eq!(new_shape[1].as_u64(), Some(1));
3840        let reshape_dim0 = new_shape[0]
3841            .as_object()
3842            .expect("reshape dim 0 should stay dynamic");
3843        assert_eq!(
3844            reshape_dim0.get("name").and_then(|v| v.as_str()),
3845            Some("sequence_length")
3846        );
3847        assert_eq!(
3848            reshape_dim0.get("maxSize").and_then(|v| v.as_u64()),
3849            Some(4096)
3850        );
3851    }
3852
3853    #[test]
3854    fn test_binary_const_folding_preserves_broadcast_shape() {
3855        let a = vec![-1];
3856        let b = [1, 2, 3, 4].repeat(128);
3857        let a_shape = Vec::<i64>::new();
3858        let b_shape = vec![1, 128, 4];
3859        let (out, out_shape) =
3860            fold_binary_const_i64("Mul", &a, &b, &a_shape, &b_shape).expect("broadcast fold");
3861        assert_eq!(out_shape, vec![1, 128, 4]);
3862        assert_eq!(out.len(), 512);
3863        assert_eq!(out[0], -1);
3864        assert_eq!(out[1], -2);
3865        assert_eq!(out[2], -3);
3866        assert_eq!(out[3], -4);
3867    }
3868
3869    #[test]
3870    fn test_convert_equal_broadcast_path_does_not_flatten_const_shape() {
3871        use crate::protos::onnx::{
3872            type_proto, AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto,
3873        };
3874
3875        let a = TensorProto {
3876            name: "shape_vec".to_string(),
3877            data_type: TensorProto_DataType::Int64 as i32,
3878            dims: vec![4],
3879            int64_data: vec![1, 128, 4, 8],
3880            ..Default::default()
3881        };
3882        let shape3 = TensorProto {
3883            name: "shape3".to_string(),
3884            data_type: TensorProto_DataType::Int64 as i32,
3885            dims: vec![3],
3886            int64_data: vec![1, 128, 4],
3887            ..Default::default()
3888        };
3889        let neg1 = TensorProto {
3890            name: "neg1".to_string(),
3891            data_type: TensorProto_DataType::Int64 as i32,
3892            dims: vec![],
3893            int64_data: vec![-1],
3894            ..Default::default()
3895        };
3896        let cos_fill = TensorProto {
3897            data_type: TensorProto_DataType::Int64 as i32,
3898            dims: vec![],
3899            int64_data: vec![1],
3900            ..Default::default()
3901        };
3902
3903        let cos = NodeProto {
3904            op_type: "ConstantOfShape".to_string(),
3905            input: vec!["shape3".to_string()],
3906            output: vec!["cos_out".to_string()],
3907            attribute: vec![AttributeProto {
3908                name: "value".to_string(),
3909                t: Some(cos_fill),
3910                ..Default::default()
3911            }],
3912            ..Default::default()
3913        };
3914        let mul = NodeProto {
3915            op_type: "Mul".to_string(),
3916            input: vec!["cos_out".to_string(), "neg1".to_string()],
3917            output: vec!["mul_out".to_string()],
3918            ..Default::default()
3919        };
3920        let eq = NodeProto {
3921            op_type: "Equal".to_string(),
3922            input: vec!["shape_vec".to_string(), "mul_out".to_string()],
3923            output: vec!["eq_out".to_string()],
3924            ..Default::default()
3925        };
3926
3927        let output_type = crate::protos::onnx::TypeProto {
3928            value: Some(type_proto::Value::TensorType(type_proto::Tensor {
3929                elem_type: TensorProto_DataType::Bool.into(),
3930                shape: None,
3931            })),
3932            denotation: String::new(),
3933        };
3934
3935        let model = ModelProto {
3936            graph: Some(GraphProto {
3937                initializer: vec![a, shape3, neg1],
3938                node: vec![cos, mul, eq],
3939                output: vec![crate::protos::onnx::ValueInfoProto {
3940                    name: "eq_out".to_string(),
3941                    r#type: Some(output_type),
3942                    ..Default::default()
3943                }],
3944                ..Default::default()
3945            }),
3946            ..Default::default()
3947        };
3948
3949        let converter = OnnxConverter::new(model).expect("converter");
3950        let graph = converter
3951            .convert(&ConvertOptions {
3952                optimize: true,
3953                extract_weights: false,
3954                ..ConvertOptions::default()
3955            })
3956            .expect("convert");
3957
3958        let mul_const = graph.consts.get("mul_out").expect("mul_out const");
3959        assert_eq!(mul_const.shape, vec![1, 128, 4]);
3960        assert!(
3961            !graph.consts.contains_key("eq_out")
3962                || graph
3963                    .consts
3964                    .get("eq_out")
3965                    .is_some_and(|decl| decl.shape == vec![1, 128, 4]),
3966            "eq_out constant must not be flattened"
3967        );
3968    }
3969}