1use yscv_autograd::Graph;
2use yscv_onnx::{
3 OnnxExportAttr, OnnxExportGraph, OnnxExportNode, OnnxExportValueInfo, export_onnx_model,
4 export_onnx_model_to_file,
5};
6use yscv_tensor::Tensor;
7
8use crate::{ModelError, ModelLayer, SequentialModel};
9
10pub fn export_sequential_to_onnx(
15 model: &SequentialModel,
16 graph: &Graph,
17 input_shape: &[i64],
18 producer_name: &str,
19 model_name: &str,
20) -> Result<Vec<u8>, ModelError> {
21 let export_graph = build_onnx_graph(model, graph, input_shape)?;
22 export_onnx_model(&export_graph, producer_name, model_name)
23 .map_err(|e| ModelError::OnnxExport(e.to_string()))
24}
25
26pub fn export_sequential_to_onnx_file(
28 model: &SequentialModel,
29 graph: &Graph,
30 input_shape: &[i64],
31 producer_name: &str,
32 model_name: &str,
33 path: &std::path::Path,
34) -> Result<(), ModelError> {
35 let export_graph = build_onnx_graph(model, graph, input_shape)?;
36 export_onnx_model_to_file(&export_graph, producer_name, model_name, path)
37 .map_err(|e| ModelError::OnnxExport(e.to_string()))
38}
39
40fn build_onnx_graph(
41 model: &SequentialModel,
42 graph: &Graph,
43 input_shape: &[i64],
44) -> Result<OnnxExportGraph, ModelError> {
45 let mut nodes = Vec::new();
46 let mut initializers: Vec<(String, Tensor)> = Vec::new();
47 let mut current_name = "input".to_string();
48 let mut node_counter = 0usize;
49
50 for (idx, layer) in model.layers().iter().enumerate() {
51 let out_name = format!("layer{idx}_out");
52 match layer {
53 ModelLayer::Linear(l) => {
54 let weight_tensor = graph
55 .value(l.weight_node().expect("linear layer has weight node"))
56 .map_err(|e| ModelError::OnnxExport(format!("Linear weight: {e}")))?
57 .clone();
58 let bias_tensor = graph
59 .value(l.bias_node().expect("linear layer has bias node"))
60 .map_err(|e| ModelError::OnnxExport(format!("Linear bias: {e}")))?
61 .clone();
62
63 let w_name = format!("linear{idx}_weight");
64 let b_name = format!("linear{idx}_bias");
65
66 let transposed = weight_tensor
67 .transpose_2d()
68 .map_err(|e| ModelError::OnnxExport(format!("transpose weight: {e}")))?;
69 initializers.push((w_name.clone(), transposed));
70 initializers.push((b_name.clone(), bias_tensor));
71
72 nodes.push(OnnxExportNode {
73 op_type: "Gemm".to_string(),
74 name: format!("node{node_counter}"),
75 inputs: vec![current_name.clone(), w_name, b_name],
76 outputs: vec![out_name.clone()],
77 attributes: vec![OnnxExportAttr::Int("transB".to_string(), 0)],
78 });
79 node_counter += 1;
80 }
81 ModelLayer::ReLU(_) => {
82 nodes.push(OnnxExportNode {
83 op_type: "Relu".to_string(),
84 name: format!("node{node_counter}"),
85 inputs: vec![current_name.clone()],
86 outputs: vec![out_name.clone()],
87 attributes: vec![],
88 });
89 node_counter += 1;
90 }
91 ModelLayer::LeakyReLU(l) => {
92 nodes.push(OnnxExportNode {
93 op_type: "LeakyRelu".to_string(),
94 name: format!("node{node_counter}"),
95 inputs: vec![current_name.clone()],
96 outputs: vec![out_name.clone()],
97 attributes: vec![OnnxExportAttr::Float(
98 "alpha".to_string(),
99 l.negative_slope(),
100 )],
101 });
102 node_counter += 1;
103 }
104 ModelLayer::Sigmoid(_) => {
105 nodes.push(OnnxExportNode {
106 op_type: "Sigmoid".to_string(),
107 name: format!("node{node_counter}"),
108 inputs: vec![current_name.clone()],
109 outputs: vec![out_name.clone()],
110 attributes: vec![],
111 });
112 node_counter += 1;
113 }
114 ModelLayer::Tanh(_) => {
115 nodes.push(OnnxExportNode {
116 op_type: "Tanh".to_string(),
117 name: format!("node{node_counter}"),
118 inputs: vec![current_name.clone()],
119 outputs: vec![out_name.clone()],
120 attributes: vec![],
121 });
122 node_counter += 1;
123 }
124 ModelLayer::Dropout(d) => {
125 let ratio_name = format!("dropout{idx}_ratio");
126 initializers.push((ratio_name.clone(), Tensor::scalar(d.rate())));
127 nodes.push(OnnxExportNode {
128 op_type: "Dropout".to_string(),
129 name: format!("node{node_counter}"),
130 inputs: vec![current_name.clone(), ratio_name],
131 outputs: vec![out_name.clone()],
132 attributes: vec![],
133 });
134 node_counter += 1;
135 }
136 ModelLayer::Conv2d(l) => {
137 let weight_nhwc = l.weight();
138 let w_onnx = nhwc_weight_to_nchw(weight_nhwc)?;
139 let w_name = format!("conv{idx}_weight");
140 initializers.push((w_name.clone(), w_onnx));
141
142 let mut inputs = vec![current_name.clone(), w_name];
143 if let Some(bias) = l.bias() {
144 let b_name = format!("conv{idx}_bias");
145 initializers.push((b_name.clone(), bias.clone()));
146 inputs.push(b_name);
147 }
148
149 nodes.push(OnnxExportNode {
150 op_type: "Conv".to_string(),
151 name: format!("node{node_counter}"),
152 inputs,
153 outputs: vec![out_name.clone()],
154 attributes: vec![
155 OnnxExportAttr::Ints(
156 "kernel_shape".to_string(),
157 vec![l.kernel_h() as i64, l.kernel_w() as i64],
158 ),
159 OnnxExportAttr::Ints(
160 "strides".to_string(),
161 vec![l.stride_h() as i64, l.stride_w() as i64],
162 ),
163 ],
164 });
165 node_counter += 1;
166 }
167 ModelLayer::BatchNorm2d(l) => {
168 let scale_name = format!("bn{idx}_scale");
169 let bias_name = format!("bn{idx}_bias");
170 let mean_name = format!("bn{idx}_mean");
171 let var_name = format!("bn{idx}_var");
172 initializers.push((scale_name.clone(), l.gamma().clone()));
173 initializers.push((bias_name.clone(), l.beta().clone()));
174 initializers.push((mean_name.clone(), l.running_mean().clone()));
175 initializers.push((var_name.clone(), l.running_var().clone()));
176
177 nodes.push(OnnxExportNode {
178 op_type: "BatchNormalization".to_string(),
179 name: format!("node{node_counter}"),
180 inputs: vec![
181 current_name.clone(),
182 scale_name,
183 bias_name,
184 mean_name,
185 var_name,
186 ],
187 outputs: vec![out_name.clone()],
188 attributes: vec![OnnxExportAttr::Float("epsilon".to_string(), l.epsilon())],
189 });
190 node_counter += 1;
191 }
192 ModelLayer::MaxPool2d(l) => {
193 nodes.push(OnnxExportNode {
194 op_type: "MaxPool".to_string(),
195 name: format!("node{node_counter}"),
196 inputs: vec![current_name.clone()],
197 outputs: vec![out_name.clone()],
198 attributes: vec![
199 OnnxExportAttr::Ints(
200 "kernel_shape".to_string(),
201 vec![l.kernel_h() as i64, l.kernel_w() as i64],
202 ),
203 OnnxExportAttr::Ints(
204 "strides".to_string(),
205 vec![l.stride_h() as i64, l.stride_w() as i64],
206 ),
207 ],
208 });
209 node_counter += 1;
210 }
211 ModelLayer::AvgPool2d(l) => {
212 nodes.push(OnnxExportNode {
213 op_type: "AveragePool".to_string(),
214 name: format!("node{node_counter}"),
215 inputs: vec![current_name.clone()],
216 outputs: vec![out_name.clone()],
217 attributes: vec![
218 OnnxExportAttr::Ints(
219 "kernel_shape".to_string(),
220 vec![l.kernel_h() as i64, l.kernel_w() as i64],
221 ),
222 OnnxExportAttr::Ints(
223 "strides".to_string(),
224 vec![l.stride_h() as i64, l.stride_w() as i64],
225 ),
226 ],
227 });
228 node_counter += 1;
229 }
230 ModelLayer::GlobalAvgPool2d(_) => {
231 nodes.push(OnnxExportNode {
232 op_type: "GlobalAveragePool".to_string(),
233 name: format!("node{node_counter}"),
234 inputs: vec![current_name.clone()],
235 outputs: vec![out_name.clone()],
236 attributes: vec![],
237 });
238 node_counter += 1;
239 }
240 ModelLayer::Flatten(_) => {
241 nodes.push(OnnxExportNode {
242 op_type: "Flatten".to_string(),
243 name: format!("node{node_counter}"),
244 inputs: vec![current_name.clone()],
245 outputs: vec![out_name.clone()],
246 attributes: vec![OnnxExportAttr::Int("axis".to_string(), 1)],
247 });
248 node_counter += 1;
249 }
250 ModelLayer::Softmax(_) => {
251 nodes.push(OnnxExportNode {
252 op_type: "Softmax".to_string(),
253 name: format!("node{node_counter}"),
254 inputs: vec![current_name.clone()],
255 outputs: vec![out_name.clone()],
256 attributes: vec![OnnxExportAttr::Int("axis".to_string(), -1)],
257 });
258 node_counter += 1;
259 }
260 ModelLayer::DepthwiseConv2d(l) => {
261 let w_nhwc = l.weight();
262 let s = w_nhwc.shape();
264 let (kh, kw, c) = (s[0], s[1], s[2]);
265 let src = w_nhwc.data();
266 let mut dst = vec![0.0f32; src.len()];
267 for ch in 0..c {
268 for r in 0..kh {
269 for col in 0..kw {
270 let src_idx = (r * kw + col) * c + ch;
271 let dst_idx = (ch * kh + r) * kw + col;
272 dst[dst_idx] = src[src_idx];
273 }
274 }
275 }
276 let w_onnx = Tensor::from_vec(vec![c, 1, kh, kw], dst)
277 .map_err(|e| ModelError::OnnxExport(e.to_string()))?;
278 let w_name = format!("dwconv{idx}_weight");
279 initializers.push((w_name.clone(), w_onnx));
280
281 let mut inputs = vec![current_name.clone(), w_name];
282 if let Some(bias) = l.bias() {
283 let b_name = format!("dwconv{idx}_bias");
284 initializers.push((b_name.clone(), bias.clone()));
285 inputs.push(b_name);
286 }
287
288 nodes.push(OnnxExportNode {
289 op_type: "Conv".to_string(),
290 name: format!("node{node_counter}"),
291 inputs,
292 outputs: vec![out_name.clone()],
293 attributes: vec![
294 OnnxExportAttr::Ints(
295 "kernel_shape".to_string(),
296 vec![l.kernel_h() as i64, l.kernel_w() as i64],
297 ),
298 OnnxExportAttr::Ints(
299 "strides".to_string(),
300 vec![l.stride_h() as i64, l.stride_w() as i64],
301 ),
302 OnnxExportAttr::Int("group".to_string(), c as i64),
303 ],
304 });
305 node_counter += 1;
306 }
307 ModelLayer::SeparableConv2d(l) => {
308 let dw = l.depthwise();
310 let s = dw.weight().shape();
311 let (kh, kw, c) = (s[0], s[1], s[2]);
312 let src = dw.weight().data();
313 let mut dst = vec![0.0f32; src.len()];
314 for ch in 0..c {
315 for r in 0..kh {
316 for col in 0..kw {
317 let src_idx = (r * kw + col) * c + ch;
318 let dst_idx = (ch * kh + r) * kw + col;
319 dst[dst_idx] = src[src_idx];
320 }
321 }
322 }
323 let dw_onnx = Tensor::from_vec(vec![c, 1, kh, kw], dst)
324 .map_err(|e| ModelError::OnnxExport(e.to_string()))?;
325 let dw_name = format!("sepconv{idx}_dw_weight");
326 initializers.push((dw_name.clone(), dw_onnx));
327
328 let dw_out_name = format!("layer{idx}_dw_out");
329 nodes.push(OnnxExportNode {
330 op_type: "Conv".to_string(),
331 name: format!("node{node_counter}"),
332 inputs: vec![current_name.clone(), dw_name],
333 outputs: vec![dw_out_name.clone()],
334 attributes: vec![
335 OnnxExportAttr::Ints(
336 "kernel_shape".to_string(),
337 vec![kh as i64, kw as i64],
338 ),
339 OnnxExportAttr::Ints(
340 "strides".to_string(),
341 vec![l.stride_h() as i64, l.stride_w() as i64],
342 ),
343 OnnxExportAttr::Int("group".to_string(), c as i64),
344 ],
345 });
346 node_counter += 1;
347
348 let pw = l.pointwise();
350 let pw_onnx = nhwc_weight_to_nchw(pw.weight())?;
351 let pw_name = format!("sepconv{idx}_pw_weight");
352 initializers.push((pw_name.clone(), pw_onnx));
353 let mut pw_inputs = vec![dw_out_name, pw_name];
354 if let Some(bias) = pw.bias() {
355 let b_name = format!("sepconv{idx}_bias");
356 initializers.push((b_name.clone(), bias.clone()));
357 pw_inputs.push(b_name);
358 }
359 nodes.push(OnnxExportNode {
360 op_type: "Conv".to_string(),
361 name: format!("node{node_counter}"),
362 inputs: pw_inputs,
363 outputs: vec![out_name.clone()],
364 attributes: vec![
365 OnnxExportAttr::Ints("kernel_shape".to_string(), vec![1, 1]),
366 OnnxExportAttr::Ints("strides".to_string(), vec![1, 1]),
367 ],
368 });
369 node_counter += 1;
370 }
371 ModelLayer::ResidualBlock(r) => {
372 let skip_name = current_name.clone();
374 let mut inner_name = current_name.clone();
375 for (sub_idx, _inner_layer) in r.layers().iter().enumerate() {
376 let inner_out = format!("layer{idx}_res{sub_idx}_out");
377 nodes.push(OnnxExportNode {
378 op_type: "Identity".to_string(),
379 name: format!("node{node_counter}"),
380 inputs: vec![inner_name.clone()],
381 outputs: vec![inner_out.clone()],
382 attributes: vec![],
383 });
384 node_counter += 1;
385 inner_name = inner_out;
386 }
387 nodes.push(OnnxExportNode {
389 op_type: "Add".to_string(),
390 name: format!("node{node_counter}"),
391 inputs: vec![inner_name, skip_name],
392 outputs: vec![out_name.clone()],
393 attributes: vec![],
394 });
395 node_counter += 1;
396 }
397 ModelLayer::Embedding(_)
398 | ModelLayer::LayerNorm(_)
399 | ModelLayer::GroupNorm(_)
400 | ModelLayer::LoraLinear(_)
401 | ModelLayer::Conv1d(_)
402 | ModelLayer::Conv3d(_)
403 | ModelLayer::ConvTranspose2d(_)
404 | ModelLayer::AdaptiveAvgPool2d(_)
405 | ModelLayer::AdaptiveMaxPool2d(_)
406 | ModelLayer::InstanceNorm(_)
407 | ModelLayer::PixelShuffle(_)
408 | ModelLayer::Upsample(_)
409 | ModelLayer::GELU(_)
410 | ModelLayer::SiLU(_)
411 | ModelLayer::Mish(_)
412 | ModelLayer::PReLU(_)
413 | ModelLayer::Rnn(_)
414 | ModelLayer::Lstm(_)
415 | ModelLayer::Gru(_)
416 | ModelLayer::MultiHeadAttention(_)
417 | ModelLayer::TransformerEncoder(_)
418 | ModelLayer::FeedForward(_)
419 | ModelLayer::DeformableConv2d(_) => {
420 nodes.push(OnnxExportNode {
421 op_type: "Identity".to_string(),
422 name: format!("node{node_counter}"),
423 inputs: vec![current_name.clone()],
424 outputs: vec![out_name.clone()],
425 attributes: vec![],
426 });
427 node_counter += 1;
428 }
429 }
430 current_name = out_name;
431 }
432
433 let inputs = vec![OnnxExportValueInfo {
434 name: "input".to_string(),
435 shape: input_shape.to_vec(),
436 }];
437 let outputs = vec![OnnxExportValueInfo {
438 name: current_name,
439 shape: vec![],
440 }];
441
442 Ok(OnnxExportGraph {
443 nodes,
444 initializers,
445 inputs,
446 outputs,
447 })
448}
449
450fn nhwc_weight_to_nchw(w: &Tensor) -> Result<Tensor, ModelError> {
453 let s = w.shape();
454 if s.len() != 4 {
455 return Err(ModelError::OnnxExport(format!(
456 "conv weight rank must be 4, got {}",
457 s.len()
458 )));
459 }
460 let (kh, kw, cin, cout) = (s[0], s[1], s[2], s[3]);
461 let src = w.data();
462 let mut dst = vec![0.0f32; src.len()];
463 for oc in 0..cout {
464 for ic in 0..cin {
465 for r in 0..kh {
466 for c in 0..kw {
467 let src_idx = ((r * kw + c) * cin + ic) * cout + oc;
468 let dst_idx = ((oc * cin + ic) * kh + r) * kw + c;
469 dst[dst_idx] = src[src_idx];
470 }
471 }
472 }
473 }
474 Tensor::from_vec(vec![cout, cin, kh, kw], dst)
475 .map_err(|e| ModelError::OnnxExport(e.to_string()))
476}