yscv_model/onnx_export.rs

use yscv_autograd::Graph;
use yscv_onnx::{
    OnnxExportAttr, OnnxExportGraph, OnnxExportNode, OnnxExportValueInfo, export_onnx_model,
    export_onnx_model_to_file,
};
use yscv_tensor::Tensor;

use crate::{ModelError, ModelLayer, SequentialModel};

/// Exports a `SequentialModel` to an ONNX protobuf byte vector.
///
/// `input_shape` is the full shape including batch dimension, e.g. `[1, 28, 28, 1]` for NHWC.
/// Linear layer weights are read from the autograd `graph`.
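///
/// A minimal usage sketch (not a compiled doctest): `model` and `graph` stand in
/// for whatever trained `SequentialModel` and autograd `Graph` you already have,
/// and the producer/model names are placeholders.
///
/// ```ignore
/// let bytes = export_sequential_to_onnx(&model, &graph, &[1, 28, 28, 1], "yscv", "mnist")?;
/// std::fs::write("mnist.onnx", &bytes)?;
/// ```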
pub fn export_sequential_to_onnx(
    model: &SequentialModel,
    graph: &Graph,
    input_shape: &[i64],
    producer_name: &str,
    model_name: &str,
) -> Result<Vec<u8>, ModelError> {
    let export_graph = build_onnx_graph(model, graph, input_shape)?;
    export_onnx_model(&export_graph, producer_name, model_name)
        .map_err(|e| ModelError::OnnxExport(e.to_string()))
}

/// Exports a `SequentialModel` to an ONNX file.
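///
/// Same assumptions as [`export_sequential_to_onnx`]; a sketch of the call:
///
/// ```ignore
/// let path = std::path::Path::new("mnist.onnx");
/// export_sequential_to_onnx_file(&model, &graph, &[1, 28, 28, 1], "yscv", "mnist", path)?;
/// ```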
pub fn export_sequential_to_onnx_file(
    model: &SequentialModel,
    graph: &Graph,
    input_shape: &[i64],
    producer_name: &str,
    model_name: &str,
    path: &std::path::Path,
) -> Result<(), ModelError> {
    let export_graph = build_onnx_graph(model, graph, input_shape)?;
    export_onnx_model_to_file(&export_graph, producer_name, model_name, path)
        .map_err(|e| ModelError::OnnxExport(e.to_string()))
}

/// Walks the model's layers in order, emitting one or more ONNX nodes per layer
/// and threading each layer's output name into the next layer's input.
fn build_onnx_graph(
    model: &SequentialModel,
    graph: &Graph,
    input_shape: &[i64],
) -> Result<OnnxExportGraph, ModelError> {
    let mut nodes = Vec::new();
    let mut initializers: Vec<(String, Tensor)> = Vec::new();
    let mut current_name = "input".to_string();
    let mut node_counter = 0usize;

    for (idx, layer) in model.layers().iter().enumerate() {
        let out_name = format!("layer{idx}_out");
        match layer {
            ModelLayer::Linear(l) => {
                let weight_tensor = graph
                    .value(l.weight_node().expect("linear layer has weight node"))
                    .map_err(|e| ModelError::OnnxExport(format!("Linear weight: {e}")))?
                    .clone();
                let bias_tensor = graph
                    .value(l.bias_node().expect("linear layer has bias node"))
                    .map_err(|e| ModelError::OnnxExport(format!("Linear bias: {e}")))?
                    .clone();

                let w_name = format!("linear{idx}_weight");
                let b_name = format!("linear{idx}_bias");

                // Gemm with transB = 0 computes Y = X * W + b, so the stored
                // weight is transposed up front to match that layout.
                let transposed = weight_tensor
                    .transpose_2d()
                    .map_err(|e| ModelError::OnnxExport(format!("transpose weight: {e}")))?;
                initializers.push((w_name.clone(), transposed));
                initializers.push((b_name.clone(), bias_tensor));

                nodes.push(OnnxExportNode {
                    op_type: "Gemm".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone(), w_name, b_name],
                    outputs: vec![out_name.clone()],
                    attributes: vec![OnnxExportAttr::Int("transB".to_string(), 0)],
                });
                node_counter += 1;
            }
            ModelLayer::ReLU(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Relu".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            ModelLayer::LeakyReLU(l) => {
                nodes.push(OnnxExportNode {
                    op_type: "LeakyRelu".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![OnnxExportAttr::Float(
                        "alpha".to_string(),
                        l.negative_slope(),
                    )],
                });
                node_counter += 1;
            }
            ModelLayer::Sigmoid(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Sigmoid".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            ModelLayer::Tanh(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Tanh".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            ModelLayer::Dropout(d) => {
                // ONNX Dropout (opset 12+) takes the ratio as a second input
                // rather than an attribute.
                let ratio_name = format!("dropout{idx}_ratio");
                initializers.push((ratio_name.clone(), Tensor::scalar(d.rate())));
                nodes.push(OnnxExportNode {
                    op_type: "Dropout".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone(), ratio_name],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            ModelLayer::Conv2d(l) => {
                let weight_nhwc = l.weight();
                let w_onnx = nhwc_weight_to_nchw(weight_nhwc)?;
                let w_name = format!("conv{idx}_weight");
                initializers.push((w_name.clone(), w_onnx));

                let mut inputs = vec![current_name.clone(), w_name];
                if let Some(bias) = l.bias() {
                    let b_name = format!("conv{idx}_bias");
                    initializers.push((b_name.clone(), bias.clone()));
                    inputs.push(b_name);
                }

                // No `pads` attribute is emitted; ONNX defaults to zero padding.
                nodes.push(OnnxExportNode {
                    op_type: "Conv".to_string(),
                    name: format!("node{node_counter}"),
                    inputs,
                    outputs: vec![out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints(
                            "kernel_shape".to_string(),
                            vec![l.kernel_h() as i64, l.kernel_w() as i64],
                        ),
                        OnnxExportAttr::Ints(
                            "strides".to_string(),
                            vec![l.stride_h() as i64, l.stride_w() as i64],
                        ),
                    ],
                });
                node_counter += 1;
            }
            ModelLayer::BatchNorm2d(l) => {
                let scale_name = format!("bn{idx}_scale");
                let bias_name = format!("bn{idx}_bias");
                let mean_name = format!("bn{idx}_mean");
                let var_name = format!("bn{idx}_var");
                initializers.push((scale_name.clone(), l.gamma().clone()));
                initializers.push((bias_name.clone(), l.beta().clone()));
                initializers.push((mean_name.clone(), l.running_mean().clone()));
                initializers.push((var_name.clone(), l.running_var().clone()));

                nodes.push(OnnxExportNode {
                    op_type: "BatchNormalization".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![
                        current_name.clone(),
                        scale_name,
                        bias_name,
                        mean_name,
                        var_name,
                    ],
                    outputs: vec![out_name.clone()],
                    attributes: vec![OnnxExportAttr::Float("epsilon".to_string(), l.epsilon())],
                });
                node_counter += 1;
            }
            ModelLayer::MaxPool2d(l) => {
                nodes.push(OnnxExportNode {
                    op_type: "MaxPool".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints(
                            "kernel_shape".to_string(),
                            vec![l.kernel_h() as i64, l.kernel_w() as i64],
                        ),
                        OnnxExportAttr::Ints(
                            "strides".to_string(),
                            vec![l.stride_h() as i64, l.stride_w() as i64],
                        ),
                    ],
                });
                node_counter += 1;
            }
            ModelLayer::AvgPool2d(l) => {
                nodes.push(OnnxExportNode {
                    op_type: "AveragePool".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints(
                            "kernel_shape".to_string(),
                            vec![l.kernel_h() as i64, l.kernel_w() as i64],
                        ),
                        OnnxExportAttr::Ints(
                            "strides".to_string(),
                            vec![l.stride_h() as i64, l.stride_w() as i64],
                        ),
                    ],
                });
                node_counter += 1;
            }
            ModelLayer::GlobalAvgPool2d(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "GlobalAveragePool".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            ModelLayer::Flatten(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Flatten".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    // axis = 1 keeps the batch dimension and flattens the rest.
                    attributes: vec![OnnxExportAttr::Int("axis".to_string(), 1)],
                });
                node_counter += 1;
            }
            ModelLayer::Softmax(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Softmax".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    // axis = -1: softmax over the last dimension.
                    attributes: vec![OnnxExportAttr::Int("axis".to_string(), -1)],
                });
                node_counter += 1;
            }
            ModelLayer::DepthwiseConv2d(l) => {
                // Depthwise weight [KH,KW,C,1] → ONNX group conv [C,1,KH,KW]
                let c = l.weight().shape()[2];
                let w_onnx = depthwise_weight_to_onnx(l.weight())?;
                let w_name = format!("dwconv{idx}_weight");
                initializers.push((w_name.clone(), w_onnx));

                let mut inputs = vec![current_name.clone(), w_name];
                if let Some(bias) = l.bias() {
                    let b_name = format!("dwconv{idx}_bias");
                    initializers.push((b_name.clone(), bias.clone()));
                    inputs.push(b_name);
                }

                nodes.push(OnnxExportNode {
                    op_type: "Conv".to_string(),
                    name: format!("node{node_counter}"),
                    inputs,
                    outputs: vec![out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints(
                            "kernel_shape".to_string(),
                            vec![l.kernel_h() as i64, l.kernel_w() as i64],
                        ),
                        OnnxExportAttr::Ints(
                            "strides".to_string(),
                            vec![l.stride_h() as i64, l.stride_w() as i64],
                        ),
                        // group == C makes this a depthwise convolution.
                        OnnxExportAttr::Int("group".to_string(), c as i64),
                    ],
                });
                node_counter += 1;
            }
            ModelLayer::SeparableConv2d(l) => {
                // Depthwise part: same [KH,KW,C,1] → [C,1,KH,KW] layout as
                // DepthwiseConv2d above.
                let dw = l.depthwise();
                let s = dw.weight().shape();
                let (kh, kw, c) = (s[0], s[1], s[2]);
                let dw_onnx = depthwise_weight_to_onnx(dw.weight())?;
                let dw_name = format!("sepconv{idx}_dw_weight");
                initializers.push((dw_name.clone(), dw_onnx));

                let dw_out_name = format!("layer{idx}_dw_out");
                nodes.push(OnnxExportNode {
                    op_type: "Conv".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone(), dw_name],
                    outputs: vec![dw_out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints(
                            "kernel_shape".to_string(),
                            vec![kh as i64, kw as i64],
                        ),
                        OnnxExportAttr::Ints(
                            "strides".to_string(),
                            vec![l.stride_h() as i64, l.stride_w() as i64],
                        ),
                        OnnxExportAttr::Int("group".to_string(), c as i64),
                    ],
                });
                node_counter += 1;

                // Pointwise part: a plain 1x1 convolution.
                let pw = l.pointwise();
                let pw_onnx = nhwc_weight_to_nchw(pw.weight())?;
                let pw_name = format!("sepconv{idx}_pw_weight");
                initializers.push((pw_name.clone(), pw_onnx));
                let mut pw_inputs = vec![dw_out_name, pw_name];
                if let Some(bias) = pw.bias() {
                    let b_name = format!("sepconv{idx}_bias");
                    initializers.push((b_name.clone(), bias.clone()));
                    pw_inputs.push(b_name);
                }
                nodes.push(OnnxExportNode {
                    op_type: "Conv".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: pw_inputs,
                    outputs: vec![out_name.clone()],
                    attributes: vec![
                        OnnxExportAttr::Ints("kernel_shape".to_string(), vec![1, 1]),
                        OnnxExportAttr::Ints("strides".to_string(), vec![1, 1]),
                    ],
                });
                node_counter += 1;
            }
            ModelLayer::ResidualBlock(r) => {
                // NOTE: inner layers are exported as Identity placeholders
                // (their parameters are not serialized); only the skip
                // connection is represented faithfully.
                let skip_name = current_name.clone();
                let mut inner_name = current_name.clone();
                for (sub_idx, _inner_layer) in r.layers().iter().enumerate() {
                    let inner_out = format!("layer{idx}_res{sub_idx}_out");
                    nodes.push(OnnxExportNode {
                        op_type: "Identity".to_string(),
                        name: format!("node{node_counter}"),
                        inputs: vec![inner_name.clone()],
                        outputs: vec![inner_out.clone()],
                        attributes: vec![],
                    });
                    node_counter += 1;
                    inner_name = inner_out;
                }
                // Skip connection: output = inner_output + skip_input.
                nodes.push(OnnxExportNode {
                    op_type: "Add".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![inner_name, skip_name],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
            // Layers without a dedicated ONNX mapping yet are passed through as
            // Identity placeholders so the exported graph stays connected.
            ModelLayer::Embedding(_)
            | ModelLayer::LayerNorm(_)
            | ModelLayer::GroupNorm(_)
            | ModelLayer::LoraLinear(_)
            | ModelLayer::Conv1d(_)
            | ModelLayer::Conv3d(_)
            | ModelLayer::ConvTranspose2d(_)
            | ModelLayer::AdaptiveAvgPool2d(_)
            | ModelLayer::AdaptiveMaxPool2d(_)
            | ModelLayer::InstanceNorm(_)
            | ModelLayer::PixelShuffle(_)
            | ModelLayer::Upsample(_)
            | ModelLayer::GELU(_)
            | ModelLayer::SiLU(_)
            | ModelLayer::Mish(_)
            | ModelLayer::PReLU(_)
            | ModelLayer::Rnn(_)
            | ModelLayer::Lstm(_)
            | ModelLayer::Gru(_)
            | ModelLayer::MultiHeadAttention(_)
            | ModelLayer::TransformerEncoder(_)
            | ModelLayer::FeedForward(_)
            | ModelLayer::DeformableConv2d(_) => {
                nodes.push(OnnxExportNode {
                    op_type: "Identity".to_string(),
                    name: format!("node{node_counter}"),
                    inputs: vec![current_name.clone()],
                    outputs: vec![out_name.clone()],
                    attributes: vec![],
                });
                node_counter += 1;
            }
        }
        current_name = out_name;
    }

    let inputs = vec![OnnxExportValueInfo {
        name: "input".to_string(),
        shape: input_shape.to_vec(),
    }];
    // The graph output is whatever name the final layer produced; its shape is
    // left unspecified.
    let outputs = vec![OnnxExportValueInfo {
        name: current_name,
        shape: vec![],
    }];

    Ok(OnnxExportGraph {
        nodes,
        initializers,
        inputs,
        outputs,
    })
}

/// Transpose convolution weight from NHWC `[KH, KW, Cin, Cout]`
/// to ONNX NCHW `[Cout, Cin, KH, KW]`.
fn nhwc_weight_to_nchw(w: &Tensor) -> Result<Tensor, ModelError> {
    let s = w.shape();
    if s.len() != 4 {
        return Err(ModelError::OnnxExport(format!(
            "conv weight rank must be 4, got {}",
            s.len()
        )));
    }
    let (kh, kw, cin, cout) = (s[0], s[1], s[2], s[3]);
    let src = w.data();
    let mut dst = vec![0.0f32; src.len()];
    for oc in 0..cout {
        for ic in 0..cin {
            for r in 0..kh {
                for c in 0..kw {
                    let src_idx = ((r * kw + c) * cin + ic) * cout + oc;
                    let dst_idx = ((oc * cin + ic) * kh + r) * kw + c;
                    dst[dst_idx] = src[src_idx];
                }
            }
        }
    }
    Tensor::from_vec(vec![cout, cin, kh, kw], dst)
        .map_err(|e| ModelError::OnnxExport(e.to_string()))
}

/// Transpose a depthwise convolution weight from NHWC `[KH, KW, C, 1]`
/// to the ONNX grouped-conv layout `[C, 1, KH, KW]` (shared by the
/// `DepthwiseConv2d` and `SeparableConv2d` export arms above).
fn depthwise_weight_to_onnx(w: &Tensor) -> Result<Tensor, ModelError> {
    let s = w.shape();
    let (kh, kw, c) = (s[0], s[1], s[2]);
    let src = w.data();
    let mut dst = vec![0.0f32; src.len()];
    for ch in 0..c {
        for r in 0..kh {
            for col in 0..kw {
                let src_idx = (r * kw + col) * c + ch;
                let dst_idx = (ch * kh + r) * kw + col;
                dst[dst_idx] = src[src_idx];
            }
        }
    }
    Tensor::from_vec(vec![c, 1, kh, kw], dst)
        .map_err(|e| ModelError::OnnxExport(e.to_string()))
}
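
// A minimal sanity-check sketch for the two layout transposes. It assumes the
// `Tensor::from_vec`/`shape`/`data` API behaves exactly as used above; treat it
// as an illustration of the index math, not as the crate's official test suite.
#[cfg(test)]
mod weight_layout_tests {
    use super::*;

    #[test]
    fn nhwc_1x1_kernel_becomes_cout_cin() {
        // NHWC [KH=1, KW=1, Cin=2, Cout=3]: flat index is ic * 3 + oc.
        let w = Tensor::from_vec(vec![1, 1, 2, 3], vec![0.0, 1.0, 2.0, 10.0, 11.0, 12.0]).unwrap();
        let t = nhwc_weight_to_nchw(&w).unwrap();
        // NCHW [Cout=3, Cin=2, 1, 1]: flat index is oc * 2 + ic.
        assert_eq!(t.shape(), &[3, 2, 1, 1]);
        assert_eq!(t.data(), &[0.0, 10.0, 1.0, 11.0, 2.0, 12.0]);
    }

    #[test]
    fn depthwise_2x1_kernel_groups_by_channel() {
        // NHWC [KH=2, KW=1, C=2, 1]: flat index is r * 2 + ch.
        let w = Tensor::from_vec(vec![2, 1, 2, 1], vec![0.0, 10.0, 1.0, 11.0]).unwrap();
        let t = depthwise_weight_to_onnx(&w).unwrap();
        // ONNX [C=2, 1, KH=2, KW=1]: channel 0 rows first, then channel 1.
        assert_eq!(t.shape(), &[2, 1, 2, 1]);
        assert_eq!(t.data(), &[0.0, 1.0, 10.0, 11.0]);
    }
}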