1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::Path;
11
/// Top-level serializable model document.
///
/// Bundles descriptive metadata, the computation graph, and a version string.
/// This is the value that `NetronExporter` serializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetronModel {
    /// Descriptive information about the model (name, description, author, ...).
    pub metadata: ModelMetadata,
    /// The graph: inputs, outputs, nodes and initializers.
    pub graph: ModelGraph,
    /// Version string stored at the top level of the document.
    ///
    /// `NetronExporter::new` initializes it to `"1.0"`. It is a separate field
    /// from the optional `ModelMetadata::version`.
    pub version: String,
}
22
/// Descriptive metadata attached to a [`NetronModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Free-form description of the model.
    pub description: String,
    /// Optional author; `None` until set.
    pub author: Option<String>,
    /// Optional version string; `None` until set.
    pub version: Option<String>,
    /// Optional license; `None` until set.
    pub license: Option<String>,
    /// Arbitrary additional key/value properties.
    pub properties: HashMap<String, String>,
}
39
/// The graph section of a [`NetronModel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelGraph {
    /// Graph name; `NetronExporter::new` derives it as `"{model_name}_graph"`.
    pub name: String,
    /// Tensors declared as graph inputs.
    pub inputs: Vec<TensorInfo>,
    /// Tensors declared as graph outputs.
    pub outputs: Vec<TensorInfo>,
    /// Operation nodes, kept in insertion order.
    pub nodes: Vec<GraphNode>,
    /// Tensor entries registered with the graph (see [`TensorData`]).
    pub initializers: Vec<TensorData>,
}
54
/// Description of a graph input or output tensor (no payload).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Tensor name.
    pub name: String,
    /// Element type as a free-form string (the tests in this file use
    /// values such as `"float32"` and `"int64"`); it is not validated here.
    pub dtype: String,
    /// Tensor dimensions.
    // NOTE(review): the signed element type suggests negative values may mark
    // dynamic dimensions, but nothing in this file relies on that - confirm.
    pub shape: Vec<i64>,
    /// Optional human-readable documentation for the tensor.
    pub doc_string: Option<String>,
}
67
/// A single operation node in a [`ModelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node name.
    pub name: String,
    /// Operation type as a free-form string (e.g. `"Linear"`, `"LayerNorm"`).
    pub op_type: String,
    /// Names of the values this node consumes.
    pub inputs: Vec<String>,
    /// Names of the values this node produces.
    pub outputs: Vec<String>,
    /// Named attributes configuring the operation.
    pub attributes: HashMap<String, AttributeValue>,
    /// Optional human-readable documentation for the node.
    pub doc_string: Option<String>,
}
84
/// Value of a node attribute.
///
/// The enum is `#[serde(untagged)]`, so a value is serialized as the bare
/// payload with no variant name. When deserializing, serde tries the variants
/// in declaration order and keeps the first that matches, so the order of the
/// variants below is significant and must not be rearranged casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AttributeValue {
    /// A single signed integer.
    Int(i64),
    /// A single floating-point number.
    Float(f64),
    /// A single string.
    String(String),
    /// A boolean flag.
    Bool(bool),
    /// A list of signed integers.
    Ints(Vec<i64>),
    /// A list of floating-point numbers.
    Floats(Vec<f64>),
    /// A list of strings.
    Strings(Vec<String>),
}
104
/// A tensor entry stored in `ModelGraph::initializers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorData {
    /// Tensor name.
    pub name: String,
    /// Element type as a free-form string; it is not validated here.
    pub dtype: String,
    /// Tensor dimensions.
    pub shape: Vec<i64>,
    /// Optional inline values; the field is omitted from serialized output
    /// when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Vec<f32>>,
    /// Optional location string; the field is omitted from serialized output
    /// when it is `None`. `NetronExporter::add_tensor_data` always leaves it unset.
    // NOTE(review): presumably points at externally stored tensor data - confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_location: Option<String>,
}
121
/// Builder that accumulates a [`NetronModel`] and writes it to disk.
pub struct NetronExporter {
    /// The model being assembled.
    model: NetronModel,
    /// Format selected for `export`; defaults to [`ExportFormat::Json`].
    output_format: ExportFormat,
}
127
/// Output format selected for `NetronExporter::export`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// Pretty-printed JSON.
    Json,
    /// ONNX. `NetronExporter::export` in this file writes the same
    /// pretty-printed JSON for this variant as for [`ExportFormat::Json`].
    Onnx,
}
136
137impl NetronExporter {
138 pub fn new(model_name: &str, description: &str) -> Self {
153 let metadata = ModelMetadata {
154 name: model_name.to_string(),
155 description: description.to_string(),
156 author: None,
157 version: None,
158 license: None,
159 properties: HashMap::new(),
160 };
161
162 let graph = ModelGraph {
163 name: format!("{}_graph", model_name),
164 inputs: Vec::new(),
165 outputs: Vec::new(),
166 nodes: Vec::new(),
167 initializers: Vec::new(),
168 };
169
170 let model = NetronModel {
171 metadata,
172 graph,
173 version: "1.0".to_string(),
174 };
175
176 Self {
177 model,
178 output_format: ExportFormat::Json,
179 }
180 }
181
182 pub fn with_format(mut self, format: ExportFormat) -> Self {
184 self.output_format = format;
185 self
186 }
187
188 pub fn set_metadata(&mut self, metadata: ModelMetadata) {
190 self.model.metadata = metadata;
191 }
192
193 pub fn set_author(&mut self, author: &str) {
195 self.model.metadata.author = Some(author.to_string());
196 }
197
198 pub fn set_version(&mut self, version: &str) {
200 self.model.metadata.version = Some(version.to_string());
201 }
202
203 pub fn add_property(&mut self, key: &str, value: &str) {
205 self.model.metadata.properties.insert(key.to_string(), value.to_string());
206 }
207
208 pub fn add_input(&mut self, name: &str, dtype: &str, shape: Vec<i64>) {
210 self.model.graph.inputs.push(TensorInfo {
211 name: name.to_string(),
212 dtype: dtype.to_string(),
213 shape,
214 doc_string: None,
215 });
216 }
217
218 pub fn add_output(&mut self, name: &str, dtype: &str, shape: Vec<i64>) {
220 self.model.graph.outputs.push(TensorInfo {
221 name: name.to_string(),
222 dtype: dtype.to_string(),
223 shape,
224 doc_string: None,
225 });
226 }
227
228 pub fn add_node(
252 &mut self,
253 name: &str,
254 op_type: &str,
255 inputs: Vec<String>,
256 outputs: Vec<String>,
257 attributes: HashMap<String, AttributeValue>,
258 ) {
259 self.model.graph.nodes.push(GraphNode {
260 name: name.to_string(),
261 op_type: op_type.to_string(),
262 inputs,
263 outputs,
264 attributes,
265 doc_string: None,
266 });
267 }
268
269 pub fn add_node_with_doc(
271 &mut self,
272 name: &str,
273 op_type: &str,
274 inputs: Vec<String>,
275 outputs: Vec<String>,
276 attributes: HashMap<String, AttributeValue>,
277 doc_string: &str,
278 ) {
279 self.model.graph.nodes.push(GraphNode {
280 name: name.to_string(),
281 op_type: op_type.to_string(),
282 inputs,
283 outputs,
284 attributes,
285 doc_string: Some(doc_string.to_string()),
286 });
287 }
288
289 pub fn add_tensor_data(
291 &mut self,
292 name: &str,
293 dtype: &str,
294 shape: Vec<i64>,
295 data: Option<Vec<f32>>,
296 ) {
297 self.model.graph.initializers.push(TensorData {
298 name: name.to_string(),
299 dtype: dtype.to_string(),
300 shape,
301 data,
302 data_location: None,
303 });
304 }
305
306 pub fn export<P: AsRef<Path>>(&self, path: P) -> Result<()> {
320 let path = path.as_ref();
321
322 if let Some(parent) = path.parent() {
324 fs::create_dir_all(parent)?;
325 }
326
327 match self.output_format {
328 ExportFormat::Json => {
329 let json = serde_json::to_string_pretty(&self.model)?;
330 fs::write(path, json)?;
331 },
332 ExportFormat::Onnx => {
333 let json = serde_json::to_string_pretty(&self.model)?;
336 fs::write(path, json)?;
337 },
338 }
339
340 Ok(())
341 }
342
343 pub fn model(&self) -> &NetronModel {
345 &self.model
346 }
347
348 pub fn model_mut(&mut self) -> &mut NetronModel {
350 &mut self.model
351 }
352
353 pub fn to_json_string(&self) -> Result<String> {
355 Ok(serde_json::to_string_pretty(&self.model)?)
356 }
357
358 pub fn create_linear_node(
360 name: &str,
361 input_name: &str,
362 output_name: &str,
363 in_features: i64,
364 out_features: i64,
365 has_bias: bool,
366 ) -> GraphNode {
367 let mut attributes = HashMap::new();
368 attributes.insert("in_features".to_string(), AttributeValue::Int(in_features));
369 attributes.insert(
370 "out_features".to_string(),
371 AttributeValue::Int(out_features),
372 );
373 attributes.insert("bias".to_string(), AttributeValue::Bool(has_bias));
374
375 GraphNode {
376 name: name.to_string(),
377 op_type: "Linear".to_string(),
378 inputs: vec![input_name.to_string()],
379 outputs: vec![output_name.to_string()],
380 attributes,
381 doc_string: None,
382 }
383 }
384
385 pub fn create_attention_node(
387 name: &str,
388 input_name: &str,
389 output_name: &str,
390 num_heads: i64,
391 head_dim: i64,
392 ) -> GraphNode {
393 let mut attributes = HashMap::new();
394 attributes.insert("num_heads".to_string(), AttributeValue::Int(num_heads));
395 attributes.insert("head_dim".to_string(), AttributeValue::Int(head_dim));
396
397 GraphNode {
398 name: name.to_string(),
399 op_type: "MultiHeadAttention".to_string(),
400 inputs: vec![input_name.to_string()],
401 outputs: vec![output_name.to_string()],
402 attributes,
403 doc_string: Some("Multi-head self-attention layer".to_string()),
404 }
405 }
406
407 pub fn create_layernorm_node(
409 name: &str,
410 input_name: &str,
411 output_name: &str,
412 normalized_shape: Vec<i64>,
413 eps: f64,
414 ) -> GraphNode {
415 let mut attributes = HashMap::new();
416 attributes.insert(
417 "normalized_shape".to_string(),
418 AttributeValue::Ints(normalized_shape),
419 );
420 attributes.insert("eps".to_string(), AttributeValue::Float(eps));
421
422 GraphNode {
423 name: name.to_string(),
424 op_type: "LayerNorm".to_string(),
425 inputs: vec![input_name.to_string()],
426 outputs: vec![output_name.to_string()],
427 attributes,
428 doc_string: None,
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Returns a path in the system temp directory that is unique to this
    /// test process, so concurrent test runs do not overwrite each other's
    /// output files.
    fn unique_temp_path(file_name: &str) -> std::path::PathBuf {
        env::temp_dir().join(format!(
            "netron_exporter_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    #[test]
    fn test_netron_exporter_creation() {
        let exporter = NetronExporter::new("test_model", "A test model");
        assert_eq!(exporter.model.metadata.name, "test_model");
        assert_eq!(exporter.model.metadata.description, "A test model");
        assert_eq!(exporter.model.graph.name, "test_model_graph");
        assert_eq!(exporter.output_format, ExportFormat::Json);
    }

    #[test]
    fn test_with_format() {
        let exporter = NetronExporter::new("test", "test").with_format(ExportFormat::Onnx);
        assert_eq!(exporter.output_format, ExportFormat::Onnx);
    }

    #[test]
    fn test_add_input_output() {
        let mut exporter = NetronExporter::new("test", "test");

        exporter.add_input("input_ids", "int64", vec![1, 128]);
        exporter.add_output("logits", "float32", vec![1, 128, 30522]);

        assert_eq!(exporter.model.graph.inputs.len(), 1);
        assert_eq!(exporter.model.graph.outputs.len(), 1);
        assert_eq!(exporter.model.graph.inputs[0].name, "input_ids");
        assert_eq!(exporter.model.graph.outputs[0].shape, vec![1, 128, 30522]);
    }

    #[test]
    fn test_add_node() {
        let mut exporter = NetronExporter::new("test", "test");

        let mut attrs = HashMap::new();
        attrs.insert("in_features".to_string(), AttributeValue::Int(768));
        attrs.insert("out_features".to_string(), AttributeValue::Int(3072));

        exporter.add_node(
            "fc1",
            "Linear",
            vec!["input".to_string()],
            vec!["output".to_string()],
            attrs,
        );

        assert_eq!(exporter.model.graph.nodes.len(), 1);
        assert_eq!(exporter.model.graph.nodes[0].name, "fc1");
        assert_eq!(exporter.model.graph.nodes[0].op_type, "Linear");
        assert!(exporter.model.graph.nodes[0].doc_string.is_none());
    }

    #[test]
    fn test_add_node_with_doc() {
        let mut exporter = NetronExporter::new("test", "test");

        exporter.add_node_with_doc(
            "act",
            "Gelu",
            vec!["input".to_string()],
            vec!["output".to_string()],
            HashMap::new(),
            "Activation",
        );

        assert_eq!(exporter.model.graph.nodes.len(), 1);
        assert_eq!(
            exporter.model.graph.nodes[0].doc_string,
            Some("Activation".to_string())
        );
    }

    #[test]
    fn test_export_json() {
        let output_path = unique_temp_path("test_model.json");

        let mut exporter = NetronExporter::new("test_model", "Test model");
        exporter.add_input("input", "float32", vec![1, 10]);
        exporter.add_output("output", "float32", vec![1, 5]);

        exporter.export(&output_path).expect("operation failed in test");
        assert!(output_path.exists());

        let written = fs::read_to_string(&output_path).expect("exported file should be readable");
        // Clean up before the content assertions so a failing assertion does
        // not leave the file behind.
        let _ = fs::remove_file(&output_path);

        assert_eq!(
            written,
            exporter.to_json_string().expect("operation failed in test")
        );

        let parsed: NetronModel =
            serde_json::from_str(&written).expect("exported JSON should deserialize");
        assert_eq!(parsed.metadata.name, "test_model");
        assert_eq!(parsed.graph.inputs.len(), 1);
        assert_eq!(parsed.graph.outputs.len(), 1);
    }

    #[test]
    fn test_create_linear_node() {
        let node = NetronExporter::create_linear_node("fc1", "input", "output", 768, 3072, true);

        assert_eq!(node.name, "fc1");
        assert_eq!(node.op_type, "Linear");
        assert!(node.attributes.contains_key("in_features"));
        assert!(node.attributes.contains_key("out_features"));
        assert!(node.attributes.contains_key("bias"));
    }

    #[test]
    fn test_create_attention_node() {
        let node = NetronExporter::create_attention_node("attn", "input", "output", 12, 64);

        assert_eq!(node.op_type, "MultiHeadAttention");
        assert!(node.doc_string.is_some());
    }

    #[test]
    fn test_create_layernorm_node() {
        let node = NetronExporter::create_layernorm_node("ln", "input", "output", vec![768], 1e-5);

        assert_eq!(node.op_type, "LayerNorm");
        assert_eq!(node.inputs, vec!["input".to_string()]);
        assert_eq!(node.outputs, vec!["output".to_string()]);
        match node.attributes.get("normalized_shape") {
            Some(AttributeValue::Ints(shape)) => assert_eq!(shape, &vec![768]),
            other => panic!("unexpected normalized_shape attribute: {:?}", other),
        }
        match node.attributes.get("eps") {
            Some(AttributeValue::Float(eps)) => assert_eq!(*eps, 1e-5),
            other => panic!("unexpected eps attribute: {:?}", other),
        }
    }

    #[test]
    fn test_metadata_setters() {
        let mut exporter = NetronExporter::new("test", "test");

        exporter.set_author("Test Author");
        exporter.set_version("1.0.0");
        exporter.add_property("framework", "TrustformeRS");

        assert_eq!(
            exporter.model.metadata.author,
            Some("Test Author".to_string())
        );
        assert_eq!(exporter.model.metadata.version, Some("1.0.0".to_string()));
        assert_eq!(
            exporter.model.metadata.properties.get("framework"),
            Some(&"TrustformeRS".to_string())
        );
    }

    #[test]
    fn test_to_json_string() {
        let mut exporter = NetronExporter::new("test", "test");
        exporter.add_input("input", "float32", vec![1, 10]);

        let json = exporter.to_json_string().expect("operation failed in test");
        assert!(json.contains("test"));
        assert!(json.contains("input"));
    }

    #[test]
    fn test_add_tensor_data() {
        let mut exporter = NetronExporter::new("test", "test");

        let weights = vec![0.1, 0.2, 0.3, 0.4];
        exporter.add_tensor_data("layer.weight", "float32", vec![2, 2], Some(weights));

        assert_eq!(exporter.model.graph.initializers.len(), 1);
        assert_eq!(exporter.model.graph.initializers[0].name, "layer.weight");
        assert!(exporter.model.graph.initializers[0].data_location.is_none());
    }
}