// trustformers_debug/tensorboard_integration.rs
1//! TensorBoard integration for exporting training metrics and visualizations
2//!
3//! This module provides functionality to export TrustformeRS debugging data to TensorBoard format,
4//! enabling integration with TensorBoard's powerful visualization tools.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::Write;
11use std::path::{Path, PathBuf};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use scirs2_core::ndarray::ArrayD;
15
/// TensorBoard event writer for logging scalars, histograms, and embeddings
///
/// Events are buffered in memory and written as JSONL files under
/// `log_dir/run_name` when `flush` is called (also attempted on drop).
#[derive(Debug)]
pub struct TensorBoardWriter {
    /// Root directory under which per-run subdirectories are created.
    log_dir: PathBuf,
    /// Name of this run; also the subdirectory holding its log files.
    run_name: String,
    /// Internal counter advanced by `increment_step`; independent of the
    /// `step` values callers pass to the `add_*` methods.
    step_counter: u64,
    // Buffered events. `flush` writes them to disk but does NOT clear them,
    // so every flush rewrites the full history of each file.
    scalar_logs: Vec<ScalarEvent>,
    histogram_logs: Vec<HistogramEvent>,
    text_logs: Vec<TextEvent>,
    embedding_logs: Vec<EmbeddingEvent>,
}
27
/// Scalar event for logging single values
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarEvent {
    /// Metric name/identifier (e.g. "loss/train").
    pub tag: String,
    /// The logged scalar value.
    pub value: f64,
    /// Training step at which the value was recorded.
    pub step: u64,
    /// Unix timestamp (seconds) when the event was created.
    pub timestamp: u64,
}
36
/// Histogram event for logging distributions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramEvent {
    /// Name/identifier for this histogram (e.g. "layer.0.weight").
    pub tag: String,
    /// Raw values the histogram was built from.
    pub values: Vec<f64>,
    /// Training step at which the values were recorded.
    pub step: u64,
    /// Unix timestamp (seconds) when the event was created.
    pub timestamp: u64,
    // Summary statistics over `values`, precomputed by `add_histogram`.
    /// Minimum of `values`.
    pub min: f64,
    /// Maximum of `values`.
    pub max: f64,
    /// Number of values.
    pub num: usize,
    /// Sum of values.
    pub sum: f64,
    /// Sum of squared values (enables variance computation downstream).
    pub sum_squares: f64,
}
50
/// Text event for logging text data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEvent {
    /// Name/identifier for this text entry.
    pub tag: String,
    /// The logged text content.
    pub text: String,
    /// Training step at which the text was recorded.
    pub step: u64,
    /// Unix timestamp (seconds) when the event was created.
    pub timestamp: u64,
}
59
/// Embedding event for projector visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEvent {
    /// Name/identifier for this embedding set.
    pub tag: String,
    /// One embedding vector per row; rows are assumed to share a common
    /// dimensionality (not validated here).
    pub embeddings: Vec<Vec<f64>>,
    /// Optional label per embedding row, parallel to `embeddings`.
    pub labels: Option<Vec<String>>,
    /// Training step at which the embeddings were recorded.
    pub step: u64,
    /// Unix timestamp (seconds) when the event was created.
    pub timestamp: u64,
}
69
/// Graph node for model architecture visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node name (e.g. "layer1").
    pub name: String,
    /// Operation type label (e.g. "Linear").
    pub op_type: String,
    /// Names of this node's input tensors/nodes.
    pub input_names: Vec<String>,
    /// Names of this node's output tensors/nodes.
    pub output_names: Vec<String>,
    /// Free-form key/value attributes attached to the node.
    pub attributes: HashMap<String, String>,
}
79
/// Graph definition for model structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDef {
    /// All nodes of the model graph.
    pub nodes: Vec<GraphNode>,
    /// Free-form graph-level metadata (key/value pairs).
    pub metadata: HashMap<String, String>,
}
86
87impl TensorBoardWriter {
88    /// Create a new TensorBoard writer with the specified log directory
89    ///
90    /// # Arguments
91    ///
92    /// * `log_dir` - Directory where TensorBoard logs will be written
93    ///
94    /// # Example
95    ///
96    /// ```no_run
97    /// use trustformers_debug::TensorBoardWriter;
98    ///
99    /// let writer = TensorBoardWriter::new("runs/experiment1").unwrap();
100    /// ```
101    pub fn new<P: AsRef<Path>>(log_dir: P) -> Result<Self> {
102        let log_dir = log_dir.as_ref().to_path_buf();
103        let run_name = format!(
104            "run_{}",
105            SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()
106        );
107
108        // Create log directory if it doesn't exist
109        fs::create_dir_all(&log_dir)?;
110
111        Ok(Self {
112            log_dir,
113            run_name,
114            step_counter: 0,
115            scalar_logs: Vec::new(),
116            histogram_logs: Vec::new(),
117            text_logs: Vec::new(),
118            embedding_logs: Vec::new(),
119        })
120    }
121
122    /// Create a new TensorBoard writer with a custom run name
123    pub fn with_run_name<P: AsRef<Path>>(log_dir: P, run_name: String) -> Result<Self> {
124        let log_dir = log_dir.as_ref().to_path_buf();
125
126        // Create log directory if it doesn't exist
127        fs::create_dir_all(&log_dir)?;
128
129        Ok(Self {
130            log_dir,
131            run_name,
132            step_counter: 0,
133            scalar_logs: Vec::new(),
134            histogram_logs: Vec::new(),
135            text_logs: Vec::new(),
136            embedding_logs: Vec::new(),
137        })
138    }
139
140    /// Add a scalar value to the log
141    ///
142    /// # Arguments
143    ///
144    /// * `tag` - Name/identifier for this scalar
145    /// * `value` - Scalar value to log
146    /// * `step` - Training step number
147    ///
148    /// # Example
149    ///
150    /// ```no_run
151    /// # use trustformers_debug::TensorBoardWriter;
152    /// # let mut writer = TensorBoardWriter::new("runs/test").unwrap();
153    /// writer.add_scalar("loss/train", 0.5, 100).unwrap();
154    /// writer.add_scalar("accuracy/val", 0.95, 100).unwrap();
155    /// ```
156    pub fn add_scalar(&mut self, tag: &str, value: f64, step: u64) -> Result<()> {
157        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
158
159        self.scalar_logs.push(ScalarEvent {
160            tag: tag.to_string(),
161            value,
162            step,
163            timestamp,
164        });
165
166        Ok(())
167    }
168
169    /// Add multiple scalars at once
170    pub fn add_scalars(
171        &mut self,
172        main_tag: &str,
173        tag_scalar_dict: HashMap<String, f64>,
174        step: u64,
175    ) -> Result<()> {
176        for (tag, value) in tag_scalar_dict {
177            let full_tag = format!("{}/{}", main_tag, tag);
178            self.add_scalar(&full_tag, value, step)?;
179        }
180        Ok(())
181    }
182
183    /// Add a histogram of values
184    ///
185    /// # Arguments
186    ///
187    /// * `tag` - Name/identifier for this histogram
188    /// * `values` - Array of values to create histogram from
189    /// * `step` - Training step number
190    ///
191    /// # Example
192    ///
193    /// ```no_run
194    /// # use trustformers_debug::TensorBoardWriter;
195    /// # let mut writer = TensorBoardWriter::new("runs/test").unwrap();
196    /// let weights = vec![0.1, 0.2, 0.15, 0.3, 0.25];
197    /// writer.add_histogram("layer.0.weight", &weights, 100).unwrap();
198    /// ```
199    pub fn add_histogram(&mut self, tag: &str, values: &[f64], step: u64) -> Result<()> {
200        if values.is_empty() {
201            return Ok(());
202        }
203
204        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
205
206        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
207        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
208        let sum: f64 = values.iter().sum();
209        let sum_squares: f64 = values.iter().map(|x| x * x).sum();
210
211        self.histogram_logs.push(HistogramEvent {
212            tag: tag.to_string(),
213            values: values.to_vec(),
214            step,
215            timestamp,
216            min,
217            max,
218            num: values.len(),
219            sum,
220            sum_squares,
221        });
222
223        Ok(())
224    }
225
226    /// Add text data for logging
227    pub fn add_text(&mut self, tag: &str, text: &str, step: u64) -> Result<()> {
228        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
229
230        self.text_logs.push(TextEvent {
231            tag: tag.to_string(),
232            text: text.to_string(),
233            step,
234            timestamp,
235        });
236
237        Ok(())
238    }
239
240    /// Add embeddings for projector visualization
241    ///
242    /// # Arguments
243    ///
244    /// * `tag` - Name for this embedding
245    /// * `embeddings` - 2D array of embedding vectors
246    /// * `labels` - Optional labels for each embedding
247    /// * `step` - Training step number
248    pub fn add_embedding(
249        &mut self,
250        tag: &str,
251        embeddings: Vec<Vec<f64>>,
252        labels: Option<Vec<String>>,
253        step: u64,
254    ) -> Result<()> {
255        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
256
257        self.embedding_logs.push(EmbeddingEvent {
258            tag: tag.to_string(),
259            embeddings,
260            labels,
261            step,
262            timestamp,
263        });
264
265        Ok(())
266    }
267
268    /// Add a graph definition for model architecture visualization
269    pub fn add_graph(&mut self, graph: &GraphDef) -> Result<()> {
270        let graph_path = self.log_dir.join(&self.run_name).join("graph.json");
271        if let Some(parent) = graph_path.parent() {
272            fs::create_dir_all(parent)?;
273        }
274
275        let graph_json = serde_json::to_string_pretty(graph)?;
276        fs::write(graph_path, graph_json)?;
277
278        Ok(())
279    }
280
281    /// Flush all pending logs to disk
282    pub fn flush(&mut self) -> Result<()> {
283        let run_dir = self.log_dir.join(&self.run_name);
284        fs::create_dir_all(&run_dir)?;
285
286        // Write scalar logs
287        if !self.scalar_logs.is_empty() {
288            let scalars_path = run_dir.join("scalars.jsonl");
289            let mut file = File::create(scalars_path)?;
290            for event in &self.scalar_logs {
291                let line = serde_json::to_string(event)?;
292                writeln!(file, "{}", line)?;
293            }
294        }
295
296        // Write histogram logs
297        if !self.histogram_logs.is_empty() {
298            let histograms_path = run_dir.join("histograms.jsonl");
299            let mut file = File::create(histograms_path)?;
300            for event in &self.histogram_logs {
301                let line = serde_json::to_string(event)?;
302                writeln!(file, "{}", line)?;
303            }
304        }
305
306        // Write text logs
307        if !self.text_logs.is_empty() {
308            let text_path = run_dir.join("text.jsonl");
309            let mut file = File::create(text_path)?;
310            for event in &self.text_logs {
311                let line = serde_json::to_string(event)?;
312                writeln!(file, "{}", line)?;
313            }
314        }
315
316        // Write embedding logs
317        if !self.embedding_logs.is_empty() {
318            let embeddings_path = run_dir.join("embeddings.jsonl");
319            let mut file = File::create(embeddings_path)?;
320            for event in &self.embedding_logs {
321                let line = serde_json::to_string(event)?;
322                writeln!(file, "{}", line)?;
323            }
324        }
325
326        Ok(())
327    }
328
329    /// Get the path to the log directory
330    pub fn log_dir(&self) -> &Path {
331        &self.log_dir
332    }
333
334    /// Get the current run name
335    pub fn run_name(&self) -> &str {
336        &self.run_name
337    }
338
339    /// Increment the internal step counter
340    pub fn increment_step(&mut self) -> u64 {
341        self.step_counter += 1;
342        self.step_counter
343    }
344
345    /// Get the current step counter value
346    pub fn current_step(&self) -> u64 {
347        self.step_counter
348    }
349
350    /// Close the writer and flush all remaining data
351    pub fn close(mut self) -> Result<()> {
352        self.flush()
353    }
354}
355
356impl Drop for TensorBoardWriter {
357    fn drop(&mut self) {
358        // Auto-flush on drop
359        let _ = self.flush();
360    }
361}
362
363/// Helper function to create a graph node
364pub fn create_graph_node(
365    name: String,
366    op_type: String,
367    inputs: Vec<String>,
368    outputs: Vec<String>,
369) -> GraphNode {
370    GraphNode {
371        name,
372        op_type,
373        input_names: inputs,
374        output_names: outputs,
375        attributes: HashMap::new(),
376    }
377}
378
379/// Helper to convert tensor statistics to histogram-compatible format
380pub fn tensor_to_histogram_values(tensor: &ArrayD<f32>) -> Vec<f64> {
381    tensor.iter().map(|&x| x as f64).collect()
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    // NOTE: tests share fixed subdirectories of the OS temp dir; directory
    // creation is idempotent (create_dir_all), so reruns are safe.

    #[test]
    fn test_tensorboard_writer_creation() {
        let temp_dir = env::temp_dir().join("tensorboard_test");
        let writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");
        assert!(writer.log_dir().exists());
    }

    #[test]
    fn test_add_scalar() {
        let temp_dir = env::temp_dir().join("tensorboard_scalar_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        writer.add_scalar("test/loss", 0.5, 0).expect("failed to add scalar");
        writer.add_scalar("test/loss", 0.4, 1).expect("failed to add scalar");
        writer.add_scalar("test/loss", 0.3, 2).expect("failed to add scalar");

        assert_eq!(writer.scalar_logs.len(), 3);
        assert_eq!(writer.scalar_logs[0].value, 0.5);
        assert_eq!(writer.scalar_logs[1].value, 0.4);
    }

    #[test]
    fn test_add_histogram() {
        let temp_dir = env::temp_dir().join("tensorboard_histogram_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        writer.add_histogram("test/weights", &values, 0).expect("failed to add histogram");

        assert_eq!(writer.histogram_logs.len(), 1);
        assert_eq!(writer.histogram_logs[0].min, 1.0);
        assert_eq!(writer.histogram_logs[0].max, 5.0);
        assert_eq!(writer.histogram_logs[0].num, 5);
        assert_eq!(writer.histogram_logs[0].sum, 15.0);
    }

    #[test]
    fn test_add_text() {
        let temp_dir = env::temp_dir().join("tensorboard_text_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        writer.add_text("test/note", "This is a test", 0).expect("failed to add text");
        assert_eq!(writer.text_logs.len(), 1);
        assert_eq!(writer.text_logs[0].text, "This is a test");
    }

    #[test]
    fn test_flush() {
        let temp_dir = env::temp_dir().join("tensorboard_flush_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        writer.add_scalar("test/metric", 1.0, 0).expect("failed to add scalar");
        writer.flush().expect("flush failed");

        let scalars_path = temp_dir.join(writer.run_name()).join("scalars.jsonl");
        assert!(scalars_path.exists());
    }

    #[test]
    fn test_add_scalars() {
        let temp_dir = env::temp_dir().join("tensorboard_scalars_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.5);
        metrics.insert("accuracy".to_string(), 0.95);

        writer.add_scalars("train", metrics, 0).expect("failed to add scalars");
        assert_eq!(writer.scalar_logs.len(), 2);
    }

    #[test]
    fn test_add_embedding() {
        let temp_dir = env::temp_dir().join("tensorboard_embedding_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).expect("failed to create writer");

        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let labels = vec!["class1".to_string(), "class2".to_string()];

        writer
            .add_embedding("test/emb", embeddings, Some(labels), 0)
            .expect("failed to add embedding");
        assert_eq!(writer.embedding_logs.len(), 1);
    }

    #[test]
    fn test_graph_node_creation() {
        let node = create_graph_node(
            "layer1".to_string(),
            "Linear".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "layer1");
        assert_eq!(node.op_type, "Linear");
        assert_eq!(node.input_names.len(), 1);
        assert_eq!(node.output_names.len(), 1);
    }
}