Skip to main content

scirs2_neural/export/
onnx.rs

//! ONNX-like model representation and export utilities.
//!
//! This module provides pure-Rust data structures that mirror the ONNX protobuf
//! schema (nodes, graphs, tensors, value-info) without requiring any C library.
//! Serialisation uses `oxicode` for compact binary interchange and `serde_json`
//! for human-readable JSON; neither produces the ONNX protobuf wire format.
//!
//! # ONNX compatibility notes
//!
//! - Default opset version: **17** (introduced with ONNX 1.12; supported by
//!   ONNX Runtime 1.12 and later).
//! - Only float32 weights are stored in [`OnnxTensor`].  Float64 sources are
//!   downcast transparently during export.
//! - Dynamic batch dimensions are represented as `None` in [`OnnxValueInfo::shape`].
14
15use crate::error::{NeuralError, Result};
16use oxicode::{config as oxicode_config, serde as oxicode_serde};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20// ndarray imports via scirs2-core re-exports
21use scirs2_core::ndarray::{Array1, Array2, Array4};
22
23// ---------------------------------------------------------------------------
24// Core data types
25// ---------------------------------------------------------------------------
26
27/// Supported ONNX element data types (subset used by this exporter).
28#[non_exhaustive]
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[repr(i32)]
31#[derive(Default)]
32pub enum OnnxDataType {
33    /// 32-bit IEEE-754 float (ONNX type 1)
34    #[default]
35    Float32 = 1,
36    /// 32-bit signed integer (ONNX type 6)
37    Int32 = 6,
38    /// 64-bit signed integer (ONNX type 7)
39    Int64 = 7,
40    /// 64-bit IEEE-754 float (ONNX type 11)
41    Float64 = 11,
42}
43
44/// Attribute value attached to an [`OnnxNode`].
45///
46/// Follows ONNX `AttributeProto` semantics.
47#[non_exhaustive]
48#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
49pub enum OnnxAttribute {
50    /// Single float scalar.
51    Float(f32),
52    /// Single int scalar.
53    Int(i64),
54    /// String attribute.
55    String(String),
56    /// Repeated floats.
57    Floats(Vec<f32>),
58    /// Repeated ints.
59    Ints(Vec<i64>),
60}
61
62/// A single compute node in the ONNX graph (corresponds to `NodeProto`).
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct OnnxNode {
65    /// ONNX operator type string (e.g. `"Gemm"`, `"Conv"`, `"Relu"`).
66    pub op_type: String,
67    /// Human-readable name (unique within the graph).
68    pub name: String,
69    /// Names of input tensors consumed by this node.
70    pub inputs: Vec<String>,
71    /// Names of output tensors produced by this node.
72    pub outputs: Vec<String>,
73    /// Operator attributes (kernel size, dilations, activations, …).
74    pub attributes: HashMap<String, OnnxAttribute>,
75}
76
77impl OnnxNode {
78    /// Construct a node with no attributes.
79    pub fn new(
80        op_type: impl Into<String>,
81        name: impl Into<String>,
82        inputs: Vec<String>,
83        outputs: Vec<String>,
84    ) -> Self {
85        Self {
86            op_type: op_type.into(),
87            name: name.into(),
88            inputs,
89            outputs,
90            attributes: HashMap::new(),
91        }
92    }
93
94    /// Add an attribute and return `self` for builder-style chaining.
95    pub fn with_attr(mut self, key: impl Into<String>, value: OnnxAttribute) -> Self {
96        self.attributes.insert(key.into(), value);
97        self
98    }
99}
100
101/// A named constant tensor stored in the graph (corresponds to `TensorProto`).
102///
103/// Only float32 data is kept; callers should downcast f64 weights.
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct OnnxTensor {
106    /// Tensor name (must match an [`OnnxNode`] input name).
107    pub name: String,
108    /// Element type tag.
109    pub data_type: OnnxDataType,
110    /// Dimension sizes in row-major order.
111    pub dims: Vec<i64>,
112    /// Flat float32 payload (populated when `data_type == Float32`).
113    pub float_data: Vec<f32>,
114    /// Flat int64 payload (populated when `data_type == Int64`).
115    pub int64_data: Vec<i64>,
116}
117
118impl OnnxTensor {
119    /// Build an `OnnxTensor` from a flat slice of f64 values.
120    ///
121    /// Values are cast to f32 for wire-format compatibility.
122    pub fn from_f64_slice(name: impl Into<String>, dims: Vec<i64>, data: &[f64]) -> Self {
123        Self {
124            name: name.into(),
125            data_type: OnnxDataType::Float32,
126            dims,
127            float_data: data.iter().map(|&v| v as f32).collect(),
128            int64_data: Vec::new(),
129        }
130    }
131
132    /// Build an `OnnxTensor` directly from f32 values (no cast required).
133    pub fn from_f32_slice(name: impl Into<String>, dims: Vec<i64>, data: &[f32]) -> Self {
134        Self {
135            name: name.into(),
136            data_type: OnnxDataType::Float32,
137            dims,
138            float_data: data.to_vec(),
139            int64_data: Vec::new(),
140        }
141    }
142
143    /// Return the total number of elements (product of dims).
144    pub fn numel(&self) -> usize {
145        self.dims
146            .iter()
147            .map(|&d| d as usize)
148            .product::<usize>()
149            .max(1)
150    }
151}
152
153/// Typed tensor description used for graph inputs/outputs (corresponds to
154/// `ValueInfoProto`).
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct OnnxValueInfo {
157    /// Tensor name.
158    pub name: String,
159    /// Element type.
160    pub data_type: OnnxDataType,
161    /// Shape; `None` entries indicate dynamic (symbolic) dimensions.
162    pub shape: Vec<Option<i64>>,
163}
164
165impl OnnxValueInfo {
166    /// Convenience constructor.
167    pub fn new(name: impl Into<String>, data_type: OnnxDataType, shape: Vec<Option<i64>>) -> Self {
168        Self {
169            name: name.into(),
170            data_type,
171            shape,
172        }
173    }
174}
175
176/// A complete ONNX compute graph (`GraphProto`).
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct OnnxGraph {
179    /// Ordered list of compute nodes.
180    pub nodes: Vec<OnnxNode>,
181    /// Graph input descriptors.
182    pub inputs: Vec<OnnxValueInfo>,
183    /// Graph output descriptors.
184    pub outputs: Vec<OnnxValueInfo>,
185    /// Constant tensors (model weights and biases).
186    pub initializers: Vec<OnnxTensor>,
187}
188
189impl OnnxGraph {
190    /// Create an empty graph.
191    pub fn new() -> Self {
192        Self {
193            nodes: Vec::new(),
194            inputs: Vec::new(),
195            outputs: Vec::new(),
196            initializers: Vec::new(),
197        }
198    }
199}
200
201impl Default for OnnxGraph {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207/// Top-level ONNX model wrapper (`ModelProto`).
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct OnnxModel {
210    /// Compute graph.
211    pub graph: OnnxGraph,
212    /// ONNX opset version (default: 17).
213    pub opset_version: i64,
214    /// ONNX IR version (default: 8).
215    pub ir_version: i64,
216    /// Framework that produced this model.
217    pub producer_name: String,
218    /// Caller-assigned model version integer.
219    pub model_version: i64,
220}
221
222impl Default for OnnxModel {
223    fn default() -> Self {
224        Self {
225            graph: OnnxGraph::new(),
226            opset_version: 17,
227            ir_version: 8,
228            producer_name: "scirs2-neural".to_string(),
229            model_version: 1,
230        }
231    }
232}
233
234impl OnnxModel {
235    /// Create a new model with default metadata and the given graph.
236    pub fn new(graph: OnnxGraph) -> Self {
237        Self {
238            graph,
239            ..Default::default()
240        }
241    }
242
243    // ------------------------------------------------------------------
244    // Serialization helpers
245    // ------------------------------------------------------------------
246
247    /// Serialise the model to compact oxicode bytes.
248    ///
249    /// The binary layout uses `oxicode`'s `serde` API which produces a
250    /// self-describing, SIMD-optimised payload.  Use [`from_bytes`] to
251    /// deserialise.
252    pub fn to_bytes(&self) -> Result<Vec<u8>> {
253        let cfg = oxicode_config::standard();
254        oxicode_serde::encode_to_vec(self, cfg)
255            .map_err(|e| NeuralError::SerializationError(format!("oxicode encode error: {e}")))
256    }
257
258    /// Deserialise a model that was produced by [`to_bytes`].
259    pub fn from_bytes(data: &[u8]) -> Result<Self> {
260        let cfg = oxicode_config::standard();
261        oxicode_serde::decode_owned_from_slice(data, cfg)
262            .map(|(model, _)| model)
263            .map_err(|e| NeuralError::DeserializationError(format!("oxicode decode error: {e}")))
264    }
265
266    /// Serialise to pretty-printed JSON (useful for inspection / debugging).
267    pub fn to_json(&self) -> Result<String> {
268        serde_json::to_string_pretty(self)
269            .map_err(|e| NeuralError::SerializationError(format!("JSON encode error: {e}")))
270    }
271
272    /// Deserialise from JSON produced by [`to_json`].
273    pub fn from_json(json: &str) -> Result<Self> {
274        serde_json::from_str(json)
275            .map_err(|e| NeuralError::DeserializationError(format!("JSON decode error: {e}")))
276    }
277}
278
279// ---------------------------------------------------------------------------
280// OnnxExportable trait
281// ---------------------------------------------------------------------------
282
283/// Implemented by neural-network layer types that can produce ONNX nodes.
284pub trait OnnxExportable {
285    /// Emit the ONNX compute node(s) for this layer.
286    ///
287    /// * `input_name`  — name of the tensor flowing *into* this layer.
288    /// * `output_name` — name of the tensor produced by this layer.
289    /// * `prefix`      — namespace prefix for generated weight tensor names.
290    fn to_onnx_nodes(&self, input_name: &str, output_name: &str, prefix: &str) -> Vec<OnnxNode>;
291
292    /// Emit the weight [`OnnxTensor`]s (initializers) for this layer.
293    fn to_onnx_initializers(&self, prefix: &str) -> Vec<OnnxTensor>;
294}
295
296// ---------------------------------------------------------------------------
297// Free-standing layer exporters
298// ---------------------------------------------------------------------------
299
300/// Export a fully-connected (`Gemm`) layer.
301///
302/// `weights` has shape `[out_features, in_features]` following PyTorch/ONNX
303/// convention.  The produced node uses `transB=1` so that the weight matrix
304/// stored as `[out, in]` is transposed during the matrix-multiply.
305///
306/// Returns `(nodes, initializers)` ready for insertion into an [`OnnxGraph`].
307pub fn export_linear(
308    weights: &Array2<f64>,
309    bias: Option<&Array1<f64>>,
310    input_name: &str,
311    output_name: &str,
312    prefix: &str,
313) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
314    let w_name = format!("{prefix}.weight");
315    let b_name = format!("{prefix}.bias");
316
317    let out_features = weights.nrows() as i64;
318    let in_features = weights.ncols() as i64;
319
320    // Build weight initializer
321    let w_flat: Vec<f64> = weights.iter().copied().collect();
322    let w_tensor = OnnxTensor::from_f64_slice(&w_name, vec![out_features, in_features], &w_flat);
323
324    let mut node_inputs = vec![input_name.to_string(), w_name.clone()];
325    let mut initializers = vec![w_tensor];
326
327    // Optionally include bias
328    if let Some(b) = bias {
329        let b_flat: Vec<f64> = b.iter().copied().collect();
330        let b_tensor = OnnxTensor::from_f64_slice(&b_name, vec![out_features], &b_flat);
331        initializers.push(b_tensor);
332        node_inputs.push(b_name.clone());
333    }
334
335    let node = OnnxNode::new(
336        "Gemm",
337        format!("{prefix}/Gemm"),
338        node_inputs,
339        vec![output_name.to_string()],
340    )
341    .with_attr("transB", OnnxAttribute::Int(1))
342    .with_attr("alpha", OnnxAttribute::Float(1.0))
343    .with_attr("beta", OnnxAttribute::Float(1.0));
344
345    (vec![node], initializers)
346}
347
348/// Export a 2-D convolution (`Conv`) layer.
349///
350/// `weights` has ONNX layout `[out_channels, in_channels, kH, kW]`.
351/// `stride` and `padding` are `[H, W]` pairs.
352///
353/// Returns `(nodes, initializers)`.
354pub fn export_conv2d(
355    weights: &Array4<f64>,
356    bias: Option<&Array1<f64>>,
357    stride: &[usize],
358    padding: &[usize],
359    input_name: &str,
360    output_name: &str,
361    prefix: &str,
362) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
363    let w_name = format!("{prefix}.weight");
364    let b_name = format!("{prefix}.bias");
365
366    let shape = weights.shape();
367    let dims: Vec<i64> = shape.iter().map(|&d| d as i64).collect();
368    let w_flat: Vec<f64> = weights.iter().copied().collect();
369    let w_tensor = OnnxTensor::from_f64_slice(&w_name, dims, &w_flat);
370
371    let strides_attr: Vec<i64> = stride.iter().map(|&s| s as i64).collect();
372    let pads_onnx: Vec<i64> = padding
373        .iter()
374        .chain(padding.iter())
375        .map(|&p| p as i64)
376        .collect(); // [top, left, bottom, right]
377
378    let mut node_inputs = vec![input_name.to_string(), w_name.clone()];
379    let mut initializers = vec![w_tensor];
380
381    if let Some(b) = bias {
382        let out_channels = shape[0] as i64;
383        let b_flat: Vec<f64> = b.iter().copied().collect();
384        let b_tensor = OnnxTensor::from_f64_slice(&b_name, vec![out_channels], &b_flat);
385        initializers.push(b_tensor);
386        node_inputs.push(b_name.clone());
387    }
388
389    let node = OnnxNode::new(
390        "Conv",
391        format!("{prefix}/Conv"),
392        node_inputs,
393        vec![output_name.to_string()],
394    )
395    .with_attr("strides", OnnxAttribute::Ints(strides_attr))
396    .with_attr("pads", OnnxAttribute::Ints(pads_onnx));
397
398    (vec![node], initializers)
399}
400
401/// Export an elementwise activation function.
402///
403/// Supported `kind` values: `"relu"`, `"sigmoid"`, `"tanh"`, `"gelu"`,
404/// `"leaky_relu"`, `"elu"`, `"selu"`, `"softmax"`, `"log_softmax"`.
405///
406/// Unknown kinds fall back to `"Relu"` with a warning attribute.
407pub fn export_activation(kind: &str, input_name: &str, output_name: &str) -> OnnxNode {
408    let (op_type, extra): (&str, Option<(&str, OnnxAttribute)>) = match kind.to_lowercase().as_str()
409    {
410        "relu" => ("Relu", None),
411        "sigmoid" => ("Sigmoid", None),
412        "tanh" => ("Tanh", None),
413        "gelu" => ("Gelu", None),
414        "leaky_relu" => ("LeakyRelu", Some(("alpha", OnnxAttribute::Float(0.01)))),
415        "elu" => ("Elu", Some(("alpha", OnnxAttribute::Float(1.0)))),
416        "selu" => ("Selu", None),
417        "softmax" => ("Softmax", Some(("axis", OnnxAttribute::Int(-1)))),
418        "log_softmax" => ("LogSoftmax", Some(("axis", OnnxAttribute::Int(-1)))),
419        unknown => {
420            let mut node = OnnxNode::new(
421                "Relu",
422                format!("{unknown}/fallback_Relu"),
423                vec![input_name.to_string()],
424                vec![output_name.to_string()],
425            );
426            node.attributes.insert(
427                "_scirs2_unsupported_activation".to_string(),
428                OnnxAttribute::String(unknown.to_string()),
429            );
430            return node;
431        }
432    };
433
434    let mut node = OnnxNode::new(
435        op_type,
436        format!("{input_name}/{op_type}"),
437        vec![input_name.to_string()],
438        vec![output_name.to_string()],
439    );
440
441    if let Some((key, val)) = extra {
442        node.attributes.insert(key.to_string(), val);
443    }
444
445    node
446}
447
448/// Export a batch-normalisation layer.
449///
450/// Produces a single `BatchNormalization` node with four weight initializers:
451/// `scale` (γ), `bias` (β), `mean` (running mean), `var` (running variance).
452///
453/// `epsilon` defaults to `1e-5` if `None`.
454///
455/// Returns `(nodes, initializers)`.
456pub fn export_batchnorm(
457    scale: &[f64],
458    bias: &[f64],
459    mean: &[f64],
460    var: &[f64],
461    epsilon: Option<f32>,
462    input_name: &str,
463    output_name: &str,
464    prefix: &str,
465) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
466    let num_features = scale.len() as i64;
467    let eps = epsilon.unwrap_or(1e-5_f32);
468
469    let scale_name = format!("{prefix}.scale");
470    let bias_name = format!("{prefix}.bias");
471    let mean_name = format!("{prefix}.mean");
472    let var_name = format!("{prefix}.var");
473
474    let initializers = vec![
475        OnnxTensor::from_f64_slice(&scale_name, vec![num_features], scale),
476        OnnxTensor::from_f64_slice(&bias_name, vec![num_features], bias),
477        OnnxTensor::from_f64_slice(&mean_name, vec![num_features], mean),
478        OnnxTensor::from_f64_slice(&var_name, vec![num_features], var),
479    ];
480
481    let node = OnnxNode::new(
482        "BatchNormalization",
483        format!("{prefix}/BatchNormalization"),
484        vec![
485            input_name.to_string(),
486            scale_name,
487            bias_name,
488            mean_name,
489            var_name,
490        ],
491        vec![output_name.to_string()],
492    )
493    .with_attr("epsilon", OnnxAttribute::Float(eps));
494
495    (vec![node], initializers)
496}
497
498// ---------------------------------------------------------------------------
499// Sequential model exporter
500// ---------------------------------------------------------------------------
501
502/// Assemble an [`OnnxModel`] from a list of pre-exported layer segments.
503///
504/// Each entry in `layers` is `(layer_name, nodes, initializers)`.  Tensor
505/// names are assigned automatically: the graph input is `"input_0"` and
506/// intermediate activations follow the convention `"{layer_name}_out"`.
507///
508/// `input_shape` describes the graph input (use `None` for dynamic / batch
509/// dimensions).
510///
511/// ```rust
512/// use scirs2_neural::export::onnx::{export_linear, export_activation, export_sequential};
513/// use scirs2_core::ndarray::Array2;
514///
515/// let w1 = Array2::<f64>::zeros((64, 784));
516/// let (n1, i1) = export_linear(&w1, None, "input_0", "fc0_out", "fc0");
517/// let act1 = export_activation("relu", "fc0_out", "act0_out");
518///
519/// let w2 = Array2::<f64>::zeros((10, 64));
520/// let (n2, i2) = export_linear(&w2, None, "act0_out", "output_0", "fc1");
521///
522/// let layers = vec![
523///     ("fc0".to_string(), n1, i1),
524///     ("act0".to_string(), vec![act1], vec![]),
525///     ("fc1".to_string(), n2, i2),
526/// ];
527///
528/// let model = export_sequential(&layers, &[None, Some(784)]);
529/// assert_eq!(model.graph.nodes.len(), 3);
530/// assert_eq!(model.opset_version, 17);
531/// ```
532pub fn export_sequential(
533    layers: &[(String, Vec<OnnxNode>, Vec<OnnxTensor>)],
534    input_shape: &[Option<i64>],
535) -> OnnxModel {
536    let mut graph = OnnxGraph::new();
537
538    // Graph input
539    graph.inputs.push(OnnxValueInfo::new(
540        "input_0",
541        OnnxDataType::Float32,
542        input_shape.to_vec(),
543    ));
544
545    // Collect all nodes and initializers
546    let mut last_output = "input_0".to_string();
547    for (layer_name, nodes, inits) in layers {
548        graph.initializers.extend(inits.iter().cloned());
549        for node in nodes {
550            graph.nodes.push(node.clone());
551        }
552        // Track last output produced by this layer's final node
553        if let Some(last_node) = nodes.last() {
554            if let Some(out) = last_node.outputs.first() {
555                last_output = out.clone();
556            } else {
557                last_output = format!("{layer_name}_out");
558            }
559        }
560    }
561
562    // Graph output
563    graph.outputs.push(OnnxValueInfo::new(
564        last_output,
565        OnnxDataType::Float32,
566        vec![None],
567    ));
568
569    OnnxModel::new(graph)
570}
571
572// ---------------------------------------------------------------------------
573// Tests
574// ---------------------------------------------------------------------------
575
576#[cfg(test)]
577mod tests {
578    use super::*;
579    use scirs2_core::ndarray::{Array1, Array2, Array4};
580
581    // -----------------------------------------------------------------------
582    // Node construction tests
583    // -----------------------------------------------------------------------
584
585    #[test]
586    fn test_onnx_activation_node_relu() {
587        let node = export_activation("relu", "x", "y");
588        assert_eq!(node.op_type, "Relu");
589        assert_eq!(node.inputs, vec!["x".to_string()]);
590        assert_eq!(node.outputs, vec!["y".to_string()]);
591    }
592
593    #[test]
594    fn test_onnx_activation_node_sigmoid() {
595        let node = export_activation("sigmoid", "x", "y");
596        assert_eq!(node.op_type, "Sigmoid");
597    }
598
599    #[test]
600    fn test_onnx_activation_node_tanh() {
601        let node = export_activation("tanh", "x", "y");
602        assert_eq!(node.op_type, "Tanh");
603    }
604
605    #[test]
606    fn test_onnx_activation_node_softmax() {
607        let node = export_activation("softmax", "x", "y");
608        assert_eq!(node.op_type, "Softmax");
609        assert!(node.attributes.contains_key("axis"));
610    }
611
612    #[test]
613    fn test_onnx_activation_node_unknown_fallback() {
614        let node = export_activation("crelu_custom", "x", "y");
615        // Falls back to Relu but records the unknown name
616        assert_eq!(node.op_type, "Relu");
617        assert!(node
618            .attributes
619            .contains_key("_scirs2_unsupported_activation"));
620    }
621
622    // -----------------------------------------------------------------------
623    // Linear (Gemm) exporter tests
624    // -----------------------------------------------------------------------
625
626    #[test]
627    fn test_onnx_linear_node_no_bias() {
628        let w = Array2::<f64>::zeros((4, 8));
629        let (nodes, inits) = export_linear(&w, None, "x", "y", "fc");
630        assert_eq!(nodes.len(), 1);
631        assert_eq!(nodes[0].op_type, "Gemm");
632        // Only weight initializer (no bias)
633        assert_eq!(inits.len(), 1);
634        assert_eq!(inits[0].dims, vec![4_i64, 8_i64]);
635        assert_eq!(inits[0].float_data.len(), 32);
636    }
637
638    #[test]
639    fn test_onnx_linear_node_with_bias() {
640        let w = Array2::<f64>::zeros((4, 8));
641        let b = Array1::<f64>::zeros(4);
642        let (nodes, inits) = export_linear(&w, Some(&b), "x", "y", "fc");
643        assert_eq!(nodes.len(), 1);
644        assert_eq!(inits.len(), 2); // weight + bias
645                                    // Bias is the second initializer
646        assert_eq!(inits[1].dims, vec![4_i64]);
647        assert_eq!(inits[1].float_data.len(), 4);
648    }
649
650    #[test]
651    fn test_onnx_linear_trans_b_attribute() {
652        let w = Array2::<f64>::zeros((3, 5));
653        let (nodes, _) = export_linear(&w, None, "x", "y", "fc");
654        let trans_b = nodes[0].attributes.get("transB").expect("transB attribute");
655        assert_eq!(trans_b, &OnnxAttribute::Int(1));
656    }
657
658    // -----------------------------------------------------------------------
659    // Conv2d exporter tests
660    // -----------------------------------------------------------------------
661
662    #[test]
663    fn test_onnx_conv2d_node() {
664        // [out_channels, in_channels, kH, kW]
665        let w = Array4::<f64>::zeros((16, 3, 3, 3));
666        let (nodes, inits) = export_conv2d(&w, None, &[1, 1], &[1, 1], "x", "y", "conv1");
667        assert_eq!(nodes.len(), 1);
668        assert_eq!(nodes[0].op_type, "Conv");
669        assert_eq!(inits.len(), 1);
670        assert_eq!(inits[0].dims, vec![16, 3, 3, 3]);
671        assert_eq!(inits[0].float_data.len(), 16 * 3 * 3 * 3);
672    }
673
674    #[test]
675    fn test_onnx_conv2d_with_bias() {
676        let w = Array4::<f64>::zeros((8, 1, 5, 5));
677        let b = Array1::<f64>::zeros(8);
678        let (nodes, inits) = export_conv2d(&w, Some(&b), &[2, 2], &[0, 0], "x", "y", "conv0");
679        assert_eq!(nodes.len(), 1);
680        assert_eq!(inits.len(), 2);
681        // Check stride attribute
682        let strides = nodes[0].attributes.get("strides").expect("strides");
683        assert_eq!(strides, &OnnxAttribute::Ints(vec![2, 2]));
684    }
685
686    // -----------------------------------------------------------------------
687    // BatchNorm exporter tests
688    // -----------------------------------------------------------------------
689
690    #[test]
691    fn test_onnx_batchnorm_export() {
692        let scale = vec![1.0_f64; 32];
693        let bias = vec![0.0_f64; 32];
694        let mean = vec![0.0_f64; 32];
695        let var = vec![1.0_f64; 32];
696        let (nodes, inits) = export_batchnorm(&scale, &bias, &mean, &var, None, "x", "y", "bn1");
697        assert_eq!(nodes.len(), 1);
698        assert_eq!(nodes[0].op_type, "BatchNormalization");
699        // Must have exactly 4 initializers: scale, bias, mean, var
700        assert_eq!(inits.len(), 4);
701        for init in &inits {
702            assert_eq!(init.dims, vec![32_i64]);
703            assert_eq!(init.float_data.len(), 32);
704        }
705    }
706
707    #[test]
708    fn test_onnx_batchnorm_epsilon_attribute() {
709        let v = vec![1.0_f64; 4];
710        let (nodes, _) = export_batchnorm(&v, &v, &v, &v, Some(1e-3), "x", "y", "bn");
711        let eps = nodes[0].attributes.get("epsilon").expect("epsilon attr");
712        assert_eq!(eps, &OnnxAttribute::Float(1e-3_f32));
713    }
714
715    // -----------------------------------------------------------------------
716    // OnnxModel metadata tests
717    // -----------------------------------------------------------------------
718
719    #[test]
720    fn test_onnx_opset_default() {
721        let model = OnnxModel::default();
722        assert_eq!(model.opset_version, 17);
723        assert_eq!(model.ir_version, 8);
724        assert_eq!(model.producer_name, "scirs2-neural");
725    }
726
727    // -----------------------------------------------------------------------
728    // Serialisation round-trip tests
729    // -----------------------------------------------------------------------
730
731    fn build_small_model() -> OnnxModel {
732        let w = Array2::<f64>::zeros((4, 8));
733        let b = Array1::<f64>::zeros(4);
734        let (nodes, inits) = export_linear(&w, Some(&b), "input_0", "output_0", "fc0");
735        let mut graph = OnnxGraph::new();
736        graph.inputs.push(OnnxValueInfo::new(
737            "input_0",
738            OnnxDataType::Float32,
739            vec![None, Some(8)],
740        ));
741        graph.outputs.push(OnnxValueInfo::new(
742            "output_0",
743            OnnxDataType::Float32,
744            vec![None, Some(4)],
745        ));
746        graph.nodes.extend(nodes);
747        graph.initializers.extend(inits);
748        OnnxModel::new(graph)
749    }
750
751    #[test]
752    fn test_onnx_model_roundtrip_bytes() {
753        let original = build_small_model();
754        let bytes = original.to_bytes().expect("to_bytes failed");
755        let restored = OnnxModel::from_bytes(&bytes).expect("from_bytes failed");
756        assert_eq!(restored.opset_version, original.opset_version);
757        assert_eq!(restored.graph.nodes.len(), original.graph.nodes.len());
758        assert_eq!(
759            restored.graph.initializers.len(),
760            original.graph.initializers.len()
761        );
762        assert_eq!(
763            restored.graph.initializers[0].float_data.len(),
764            original.graph.initializers[0].float_data.len()
765        );
766    }
767
768    #[test]
769    fn test_onnx_json_roundtrip() {
770        let original = build_small_model();
771        let json = original.to_json().expect("to_json failed");
772        assert!(json.contains("Gemm"));
773        let restored = OnnxModel::from_json(&json).expect("from_json failed");
774        assert_eq!(restored.graph.nodes[0].op_type, "Gemm");
775        assert_eq!(restored.graph.inputs[0].name, "input_0");
776    }
777
778    #[test]
779    fn test_onnx_json_contains_producer_name() {
780        let model = OnnxModel::default();
781        let json = model.to_json().expect("to_json");
782        assert!(json.contains("scirs2-neural"));
783    }
784
785    // -----------------------------------------------------------------------
786    // Sequential exporter tests
787    // -----------------------------------------------------------------------
788
789    #[test]
790    fn test_onnx_sequential_graph() {
791        let w1 = Array2::<f64>::zeros((64, 784));
792        let (n1, i1) = export_linear(&w1, None, "input_0", "fc0_out", "fc0");
793        let act1 = export_activation("relu", "fc0_out", "act0_out");
794
795        let w2 = Array2::<f64>::zeros((10, 64));
796        let (n2, i2) = export_linear(&w2, None, "act0_out", "output_0", "fc1");
797
798        let layers = vec![
799            ("fc0".to_string(), n1, i1),
800            ("act0".to_string(), vec![act1], vec![]),
801            ("fc1".to_string(), n2, i2),
802        ];
803
804        let model = export_sequential(&layers, &[None, Some(784)]);
805        // 3 nodes total: Gemm, Relu, Gemm
806        assert_eq!(model.graph.nodes.len(), 3);
807        assert_eq!(model.graph.nodes[0].op_type, "Gemm");
808        assert_eq!(model.graph.nodes[1].op_type, "Relu");
809        assert_eq!(model.graph.nodes[2].op_type, "Gemm");
810        // 2 weight initializers (one per linear layer)
811        assert_eq!(model.graph.initializers.len(), 2);
812        assert_eq!(model.opset_version, 17);
813    }
814
815    #[test]
816    fn test_onnx_sequential_single_layer() {
817        let w = Array2::<f64>::zeros((2, 3));
818        let (nodes, inits) = export_linear(&w, None, "input_0", "output_0", "fc");
819        let layers = vec![("fc".to_string(), nodes, inits)];
820        let model = export_sequential(&layers, &[None, Some(3)]);
821        assert_eq!(model.graph.nodes.len(), 1);
822        assert_eq!(model.graph.inputs[0].name, "input_0");
823    }
824
825    #[test]
826    fn test_onnx_tensor_numel() {
827        let t = OnnxTensor::from_f64_slice("t", vec![2, 3, 4], &[0.0_f64; 24]);
828        assert_eq!(t.numel(), 24);
829    }
830
831    #[test]
832    fn test_onnx_node_builder_with_attr() {
833        let node = OnnxNode::new("Relu", "r", vec!["x".to_string()], vec!["y".to_string()])
834            .with_attr("alpha", OnnxAttribute::Float(0.1));
835        assert!(node.attributes.contains_key("alpha"));
836    }
837}