Skip to main content

trustformers_debug/
netron_export.rs

//! Netron export functionality for model visualization
//!
//! This module provides tools to export TrustformeRS models to formats compatible with
//! Netron (<https://netron.app/>), a powerful neural network visualizer.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10use std::path::Path;
11
12/// ONNX-like model representation for Netron visualization
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct NetronModel {
15    /// Model metadata
16    pub metadata: ModelMetadata,
17    /// Graph definition
18    pub graph: ModelGraph,
19    /// Model version
20    pub version: String,
21}
22
23/// Model metadata
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ModelMetadata {
26    /// Model name
27    pub name: String,
28    /// Model description
29    pub description: String,
30    /// Model author
31    pub author: Option<String>,
32    /// Model version
33    pub version: Option<String>,
34    /// License information
35    pub license: Option<String>,
36    /// Additional properties
37    pub properties: HashMap<String, String>,
38}
39
40/// Model graph containing nodes and edges
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ModelGraph {
43    /// Graph name
44    pub name: String,
45    /// Input tensors
46    pub inputs: Vec<TensorInfo>,
47    /// Output tensors
48    pub outputs: Vec<TensorInfo>,
49    /// Graph nodes (layers/operations)
50    pub nodes: Vec<GraphNode>,
51    /// Initializers (weights and biases)
52    pub initializers: Vec<TensorData>,
53}
54
55/// Tensor information
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TensorInfo {
58    /// Tensor name
59    pub name: String,
60    /// Data type (e.g., "float32", "int64")
61    pub dtype: String,
62    /// Tensor shape
63    pub shape: Vec<i64>,
64    /// Optional documentation
65    pub doc_string: Option<String>,
66}
67
68/// Graph node representing a layer or operation
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct GraphNode {
71    /// Node name
72    pub name: String,
73    /// Operation type (e.g., "Linear", "Conv2d", "Softmax")
74    pub op_type: String,
75    /// Input tensor names
76    pub inputs: Vec<String>,
77    /// Output tensor names
78    pub outputs: Vec<String>,
79    /// Node attributes
80    pub attributes: HashMap<String, AttributeValue>,
81    /// Optional documentation
82    pub doc_string: Option<String>,
83}
84
85/// Attribute value types
86#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(untagged)]
88pub enum AttributeValue {
89    /// Integer value
90    Int(i64),
91    /// Float value
92    Float(f64),
93    /// String value
94    String(String),
95    /// Boolean value
96    Bool(bool),
97    /// Array of integers
98    Ints(Vec<i64>),
99    /// Array of floats
100    Floats(Vec<f64>),
101    /// Array of strings
102    Strings(Vec<String>),
103}
104
105/// Tensor data for weights and biases
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct TensorData {
108    /// Tensor name
109    pub name: String,
110    /// Data type
111    pub dtype: String,
112    /// Tensor shape
113    pub shape: Vec<i64>,
114    /// Raw data (encoded as base64 for binary data)
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub data: Option<Vec<f32>>,
117    /// Data location (for external data)
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub data_location: Option<String>,
120}
121
122/// Netron exporter for model visualization
123pub struct NetronExporter {
124    model: NetronModel,
125    output_format: ExportFormat,
126}
127
/// Output formats selectable through [`NetronExporter::with_format`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// Pretty-printed, human-readable JSON.
    Json,
    /// ONNX-like output. [`NetronExporter::export`] currently writes the same JSON
    /// document for this variant; no ONNX protobuf encoding is performed.
    Onnx,
}
136
137impl NetronExporter {
138    /// Create a new Netron exporter
139    ///
140    /// # Arguments
141    ///
142    /// * `model_name` - Name of the model
143    /// * `description` - Model description
144    ///
145    /// # Example
146    ///
147    /// ```
148    /// use trustformers_debug::NetronExporter;
149    ///
150    /// let exporter = NetronExporter::new("bert-base", "BERT base model");
151    /// ```
152    pub fn new(model_name: &str, description: &str) -> Self {
153        let metadata = ModelMetadata {
154            name: model_name.to_string(),
155            description: description.to_string(),
156            author: None,
157            version: None,
158            license: None,
159            properties: HashMap::new(),
160        };
161
162        let graph = ModelGraph {
163            name: format!("{}_graph", model_name),
164            inputs: Vec::new(),
165            outputs: Vec::new(),
166            nodes: Vec::new(),
167            initializers: Vec::new(),
168        };
169
170        let model = NetronModel {
171            metadata,
172            graph,
173            version: "1.0".to_string(),
174        };
175
176        Self {
177            model,
178            output_format: ExportFormat::Json,
179        }
180    }
181
182    /// Set the export format
183    pub fn with_format(mut self, format: ExportFormat) -> Self {
184        self.output_format = format;
185        self
186    }
187
188    /// Set model metadata
189    pub fn set_metadata(&mut self, metadata: ModelMetadata) {
190        self.model.metadata = metadata;
191    }
192
193    /// Add model author
194    pub fn set_author(&mut self, author: &str) {
195        self.model.metadata.author = Some(author.to_string());
196    }
197
198    /// Add model version
199    pub fn set_version(&mut self, version: &str) {
200        self.model.metadata.version = Some(version.to_string());
201    }
202
203    /// Add a custom property to metadata
204    pub fn add_property(&mut self, key: &str, value: &str) {
205        self.model.metadata.properties.insert(key.to_string(), value.to_string());
206    }
207
208    /// Add an input tensor
209    pub fn add_input(&mut self, name: &str, dtype: &str, shape: Vec<i64>) {
210        self.model.graph.inputs.push(TensorInfo {
211            name: name.to_string(),
212            dtype: dtype.to_string(),
213            shape,
214            doc_string: None,
215        });
216    }
217
218    /// Add an output tensor
219    pub fn add_output(&mut self, name: &str, dtype: &str, shape: Vec<i64>) {
220        self.model.graph.outputs.push(TensorInfo {
221            name: name.to_string(),
222            dtype: dtype.to_string(),
223            shape,
224            doc_string: None,
225        });
226    }
227
228    /// Add a graph node (layer/operation)
229    ///
230    /// # Example
231    ///
232    /// ```
233    /// # use trustformers_debug::NetronExporter;
234    /// # use std::collections::HashMap;
235    /// let mut exporter = NetronExporter::new("model", "test model");
236    ///
237    /// let mut attrs = HashMap::new();
238    /// attrs.insert("in_features".to_string(),
239    ///              trustformers_debug::netron_export::AttributeValue::Int(768));
240    /// attrs.insert("out_features".to_string(),
241    ///              trustformers_debug::netron_export::AttributeValue::Int(3072));
242    ///
243    /// exporter.add_node(
244    ///     "fc1",
245    ///     "Linear",
246    ///     vec!["input".to_string()],
247    ///     vec!["hidden".to_string()],
248    ///     attrs,
249    /// );
250    /// ```
251    pub fn add_node(
252        &mut self,
253        name: &str,
254        op_type: &str,
255        inputs: Vec<String>,
256        outputs: Vec<String>,
257        attributes: HashMap<String, AttributeValue>,
258    ) {
259        self.model.graph.nodes.push(GraphNode {
260            name: name.to_string(),
261            op_type: op_type.to_string(),
262            inputs,
263            outputs,
264            attributes,
265            doc_string: None,
266        });
267    }
268
269    /// Add a node with documentation
270    pub fn add_node_with_doc(
271        &mut self,
272        name: &str,
273        op_type: &str,
274        inputs: Vec<String>,
275        outputs: Vec<String>,
276        attributes: HashMap<String, AttributeValue>,
277        doc_string: &str,
278    ) {
279        self.model.graph.nodes.push(GraphNode {
280            name: name.to_string(),
281            op_type: op_type.to_string(),
282            inputs,
283            outputs,
284            attributes,
285            doc_string: Some(doc_string.to_string()),
286        });
287    }
288
289    /// Add tensor data (weights/biases)
290    pub fn add_tensor_data(
291        &mut self,
292        name: &str,
293        dtype: &str,
294        shape: Vec<i64>,
295        data: Option<Vec<f32>>,
296    ) {
297        self.model.graph.initializers.push(TensorData {
298            name: name.to_string(),
299            dtype: dtype.to_string(),
300            shape,
301            data,
302            data_location: None,
303        });
304    }
305
306    /// Export the model to a file
307    ///
308    /// # Arguments
309    ///
310    /// * `path` - Output file path
311    ///
312    /// # Example
313    ///
314    /// ```no_run
315    /// # use trustformers_debug::NetronExporter;
316    /// # let exporter = NetronExporter::new("model", "test");
317    /// exporter.export("model.json").unwrap();
318    /// ```
319    pub fn export<P: AsRef<Path>>(&self, path: P) -> Result<()> {
320        let path = path.as_ref();
321
322        // Create parent directory if needed
323        if let Some(parent) = path.parent() {
324            fs::create_dir_all(parent)?;
325        }
326
327        match self.output_format {
328            ExportFormat::Json => {
329                let json = serde_json::to_string_pretty(&self.model)?;
330                fs::write(path, json)?;
331            },
332            ExportFormat::Onnx => {
333                // For now, export as JSON with .onnx extension
334                // A full ONNX protobuf implementation would require additional dependencies
335                let json = serde_json::to_string_pretty(&self.model)?;
336                fs::write(path, json)?;
337            },
338        }
339
340        Ok(())
341    }
342
343    /// Get a reference to the model
344    pub fn model(&self) -> &NetronModel {
345        &self.model
346    }
347
348    /// Get a mutable reference to the model
349    pub fn model_mut(&mut self) -> &mut NetronModel {
350        &mut self.model
351    }
352
353    /// Export model to a string (JSON format)
354    pub fn to_json_string(&self) -> Result<String> {
355        Ok(serde_json::to_string_pretty(&self.model)?)
356    }
357
358    /// Create a simple linear layer node
359    pub fn create_linear_node(
360        name: &str,
361        input_name: &str,
362        output_name: &str,
363        in_features: i64,
364        out_features: i64,
365        has_bias: bool,
366    ) -> GraphNode {
367        let mut attributes = HashMap::new();
368        attributes.insert("in_features".to_string(), AttributeValue::Int(in_features));
369        attributes.insert(
370            "out_features".to_string(),
371            AttributeValue::Int(out_features),
372        );
373        attributes.insert("bias".to_string(), AttributeValue::Bool(has_bias));
374
375        GraphNode {
376            name: name.to_string(),
377            op_type: "Linear".to_string(),
378            inputs: vec![input_name.to_string()],
379            outputs: vec![output_name.to_string()],
380            attributes,
381            doc_string: None,
382        }
383    }
384
385    /// Create a transformer attention node
386    pub fn create_attention_node(
387        name: &str,
388        input_name: &str,
389        output_name: &str,
390        num_heads: i64,
391        head_dim: i64,
392    ) -> GraphNode {
393        let mut attributes = HashMap::new();
394        attributes.insert("num_heads".to_string(), AttributeValue::Int(num_heads));
395        attributes.insert("head_dim".to_string(), AttributeValue::Int(head_dim));
396
397        GraphNode {
398            name: name.to_string(),
399            op_type: "MultiHeadAttention".to_string(),
400            inputs: vec![input_name.to_string()],
401            outputs: vec![output_name.to_string()],
402            attributes,
403            doc_string: Some("Multi-head self-attention layer".to_string()),
404        }
405    }
406
407    /// Create a layer normalization node
408    pub fn create_layernorm_node(
409        name: &str,
410        input_name: &str,
411        output_name: &str,
412        normalized_shape: Vec<i64>,
413        eps: f64,
414    ) -> GraphNode {
415        let mut attributes = HashMap::new();
416        attributes.insert(
417            "normalized_shape".to_string(),
418            AttributeValue::Ints(normalized_shape),
419        );
420        attributes.insert("eps".to_string(), AttributeValue::Float(eps));
421
422        GraphNode {
423            name: name.to_string(),
424            op_type: "LayerNorm".to_string(),
425            inputs: vec![input_name.to_string()],
426            outputs: vec![output_name.to_string()],
427            attributes,
428            doc_string: None,
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_netron_exporter_creation() {
        let exporter = NetronExporter::new("test_model", "A test model");
        assert_eq!(exporter.model.metadata.name, "test_model");
        assert_eq!(exporter.model.metadata.description, "A test model");
        assert_eq!(exporter.model.graph.name, "test_model_graph");
        assert_eq!(exporter.output_format, ExportFormat::Json);
    }

    #[test]
    fn test_with_format() {
        let exporter = NetronExporter::new("test", "test").with_format(ExportFormat::Onnx);
        assert_eq!(exporter.output_format, ExportFormat::Onnx);
    }

    #[test]
    fn test_add_input_output() {
        let mut exporter = NetronExporter::new("test", "test");

        exporter.add_input("input_ids", "int64", vec![1, 128]);
        exporter.add_output("logits", "float32", vec![1, 128, 30522]);

        assert_eq!(exporter.model.graph.inputs.len(), 1);
        assert_eq!(exporter.model.graph.outputs.len(), 1);
        assert_eq!(exporter.model.graph.inputs[0].name, "input_ids");
        assert_eq!(exporter.model.graph.outputs[0].shape, vec![1, 128, 30522]);
    }

    #[test]
    fn test_add_node() {
        let mut exporter = NetronExporter::new("test", "test");

        let mut attrs = HashMap::new();
        attrs.insert("in_features".to_string(), AttributeValue::Int(768));
        attrs.insert("out_features".to_string(), AttributeValue::Int(3072));

        exporter.add_node(
            "fc1",
            "Linear",
            vec!["input".to_string()],
            vec!["output".to_string()],
            attrs,
        );

        assert_eq!(exporter.model.graph.nodes.len(), 1);
        assert_eq!(exporter.model.graph.nodes[0].name, "fc1");
        assert_eq!(exporter.model.graph.nodes[0].op_type, "Linear");
        assert!(exporter.model.graph.nodes[0].doc_string.is_none());
    }

    #[test]
    fn test_export_json() {
        // Per-process file name so concurrent test runs sharing the temp dir don't collide.
        let output_path = env::temp_dir().join(format!(
            "netron_export_test_model_{}.json",
            std::process::id()
        ));

        let mut exporter = NetronExporter::new("test_model", "Test model");
        exporter.add_input("input", "float32", vec![1, 10]);
        exporter.add_output("output", "float32", vec![1, 5]);

        exporter.export(&output_path).unwrap();

        // Read and clean up before asserting so a failed assertion never leaks the file.
        let written = fs::read_to_string(&output_path);
        let _ = fs::remove_file(&output_path);

        let written = written.expect("exported file should exist and be readable");
        let parsed: NetronModel = serde_json::from_str(&written).unwrap();
        assert_eq!(parsed.metadata.name, "test_model");
        assert_eq!(parsed.graph.inputs.len(), 1);
        assert_eq!(parsed.graph.outputs[0].shape, vec![1, 5]);
    }

    #[test]
    fn test_create_linear_node() {
        let node = NetronExporter::create_linear_node("fc1", "input", "output", 768, 3072, true);

        assert_eq!(node.name, "fc1");
        assert_eq!(node.op_type, "Linear");
        assert!(matches!(
            node.attributes.get("in_features"),
            Some(AttributeValue::Int(768))
        ));
        assert!(matches!(
            node.attributes.get("out_features"),
            Some(AttributeValue::Int(3072))
        ));
        assert!(matches!(
            node.attributes.get("bias"),
            Some(AttributeValue::Bool(true))
        ));
    }

    #[test]
    fn test_create_attention_node() {
        let node = NetronExporter::create_attention_node("attn", "input", "output", 12, 64);

        assert_eq!(node.op_type, "MultiHeadAttention");
        assert!(node.doc_string.is_some());
        assert!(matches!(
            node.attributes.get("num_heads"),
            Some(AttributeValue::Int(12))
        ));
        assert!(matches!(
            node.attributes.get("head_dim"),
            Some(AttributeValue::Int(64))
        ));
    }

    #[test]
    fn test_create_layernorm_node() {
        let node = NetronExporter::create_layernorm_node("ln", "input", "output", vec![768], 1e-5);

        assert_eq!(node.op_type, "LayerNorm");
        assert_eq!(node.inputs, vec!["input".to_string()]);
        assert_eq!(node.outputs, vec!["output".to_string()]);
        assert!(matches!(
            node.attributes.get("normalized_shape"),
            Some(AttributeValue::Ints(shape)) if *shape == vec![768_i64]
        ));
        assert!(matches!(
            node.attributes.get("eps"),
            Some(AttributeValue::Float(_))
        ));
    }

    #[test]
    fn test_metadata_setters() {
        let mut exporter = NetronExporter::new("test", "test");

        exporter.set_author("Test Author");
        exporter.set_version("1.0.0");
        exporter.add_property("framework", "TrustformeRS");

        assert_eq!(
            exporter.model.metadata.author,
            Some("Test Author".to_string())
        );
        assert_eq!(exporter.model.metadata.version, Some("1.0.0".to_string()));
        assert_eq!(
            exporter.model.metadata.properties.get("framework"),
            Some(&"TrustformeRS".to_string())
        );
    }

    #[test]
    fn test_to_json_string() {
        let mut exporter = NetronExporter::new("test", "test");
        exporter.add_input("input", "float32", vec![1, 10]);

        let json = exporter.to_json_string().unwrap();

        // Parse the output back so the check is on structure, not on substrings.
        let parsed: NetronModel = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.metadata.name, "test");
        assert_eq!(parsed.graph.inputs[0].name, "input");
    }

    #[test]
    fn test_add_tensor_data() {
        let mut exporter = NetronExporter::new("test", "test");

        let weights = vec![0.1, 0.2, 0.3, 0.4];
        exporter.add_tensor_data("layer.weight", "float32", vec![2, 2], Some(weights));

        assert_eq!(exporter.model.graph.initializers.len(), 1);

        let stored = &exporter.model.graph.initializers[0];
        assert_eq!(stored.name, "layer.weight");
        assert_eq!(stored.shape, vec![2, 2]);
        assert_eq!(stored.data.as_deref(), Some(&[0.1_f32, 0.2, 0.3, 0.4][..]));
        assert!(stored.data_location.is_none());
    }
}