// trustformers_debug/tensorboard_integration.rs
//! TensorBoard integration for exporting training metrics and visualizations
//!
//! This module provides functionality to export TrustformeRS debugging data to TensorBoard format,
//! enabling integration with TensorBoard's powerful visualization tools.

use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

use anyhow::Result;
use serde::{Deserialize, Serialize};

use scirs2_core::ndarray::ArrayD;

/// TensorBoard event writer for logging scalars, histograms, and embeddings
///
/// Events are buffered in memory and written out as per-run JSONL files under
/// `log_dir/<run_name>/` when `flush` is called (also attempted on drop).
#[derive(Debug)]
pub struct TensorBoardWriter {
    /// Root directory under which run subdirectories are created.
    log_dir: PathBuf,
    /// Name of the current run; used as the output subdirectory name.
    run_name: String,
    /// Internal counter managed via `increment_step`/`current_step`.
    step_counter: u64,
    /// Buffered scalar events; written to `scalars.jsonl` on flush.
    scalar_logs: Vec<ScalarEvent>,
    /// Buffered histogram events; written to `histograms.jsonl` on flush.
    histogram_logs: Vec<HistogramEvent>,
    /// Buffered text events; written to `text.jsonl` on flush.
    text_logs: Vec<TextEvent>,
    /// Buffered embedding events; written to `embeddings.jsonl` on flush.
    embedding_logs: Vec<EmbeddingEvent>,
}

/// Scalar event for logging single values
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarEvent {
    /// Name/identifier for the scalar series (e.g. "loss/train").
    pub tag: String,
    /// The logged scalar value.
    pub value: f64,
    /// Training step at which the value was recorded.
    pub step: u64,
    /// Unix timestamp (whole seconds) when the event was created.
    pub timestamp: u64,
}

/// Histogram event for logging distributions
///
/// Stores the raw values together with precomputed summary statistics
/// (min, max, count, sum, sum of squares) as filled in by
/// `TensorBoardWriter::add_histogram`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramEvent {
    /// Name/identifier for the histogram series.
    pub tag: String,
    /// Raw values the histogram was built from.
    pub values: Vec<f64>,
    /// Training step at which the values were recorded.
    pub step: u64,
    /// Unix timestamp (whole seconds) when the event was created.
    pub timestamp: u64,
    /// Smallest value in `values`.
    pub min: f64,
    /// Largest value in `values`.
    pub max: f64,
    /// Number of entries in `values`.
    pub num: usize,
    /// Sum of all values.
    pub sum: f64,
    /// Sum of the squares of all values.
    pub sum_squares: f64,
}

/// Text event for logging text data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEvent {
    /// Name/identifier for the text series.
    pub tag: String,
    /// The logged text content.
    pub text: String,
    /// Training step at which the text was recorded.
    pub step: u64,
    /// Unix timestamp (whole seconds) when the event was created.
    pub timestamp: u64,
}

/// Embedding event for projector visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEvent {
    /// Name/identifier for this embedding set.
    pub tag: String,
    /// One embedding vector per row.
    pub embeddings: Vec<Vec<f64>>,
    /// Optional per-row labels (assumed parallel to `embeddings` — callers
    /// are responsible for matching lengths; not validated here).
    pub labels: Option<Vec<String>>,
    /// Training step at which the embeddings were recorded.
    pub step: u64,
    /// Unix timestamp (whole seconds) when the event was created.
    pub timestamp: u64,
}

/// Graph node for model architecture visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Name of the node within the graph.
    pub name: String,
    /// Operator type (e.g. "Linear").
    pub op_type: String,
    /// Names of this node's inputs.
    pub input_names: Vec<String>,
    /// Names of this node's outputs.
    pub output_names: Vec<String>,
    /// Free-form key/value attributes attached to the node.
    pub attributes: HashMap<String, String>,
}

/// Graph definition for model structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDef {
    /// All nodes in the graph.
    pub nodes: Vec<GraphNode>,
    /// Free-form graph-level metadata.
    pub metadata: HashMap<String, String>,
}

87impl TensorBoardWriter {
88    /// Create a new TensorBoard writer with the specified log directory
89    ///
90    /// # Arguments
91    ///
92    /// * `log_dir` - Directory where TensorBoard logs will be written
93    ///
94    /// # Example
95    ///
96    /// ```no_run
97    /// use trustformers_debug::TensorBoardWriter;
98    ///
99    /// let writer = TensorBoardWriter::new("runs/experiment1").unwrap();
100    /// ```
101    pub fn new<P: AsRef<Path>>(log_dir: P) -> Result<Self> {
102        let log_dir = log_dir.as_ref().to_path_buf();
103        let run_name = format!(
104            "run_{}",
105            SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()
106        );
107
108        // Create log directory if it doesn't exist
109        fs::create_dir_all(&log_dir)?;
110
111        Ok(Self {
112            log_dir,
113            run_name,
114            step_counter: 0,
115            scalar_logs: Vec::new(),
116            histogram_logs: Vec::new(),
117            text_logs: Vec::new(),
118            embedding_logs: Vec::new(),
119        })
120    }
121
122    /// Create a new TensorBoard writer with a custom run name
123    pub fn with_run_name<P: AsRef<Path>>(log_dir: P, run_name: String) -> Result<Self> {
124        let log_dir = log_dir.as_ref().to_path_buf();
125
126        // Create log directory if it doesn't exist
127        fs::create_dir_all(&log_dir)?;
128
129        Ok(Self {
130            log_dir,
131            run_name,
132            step_counter: 0,
133            scalar_logs: Vec::new(),
134            histogram_logs: Vec::new(),
135            text_logs: Vec::new(),
136            embedding_logs: Vec::new(),
137        })
138    }
139
140    /// Add a scalar value to the log
141    ///
142    /// # Arguments
143    ///
144    /// * `tag` - Name/identifier for this scalar
145    /// * `value` - Scalar value to log
146    /// * `step` - Training step number
147    ///
148    /// # Example
149    ///
150    /// ```no_run
151    /// # use trustformers_debug::TensorBoardWriter;
152    /// # let mut writer = TensorBoardWriter::new("runs/test").unwrap();
153    /// writer.add_scalar("loss/train", 0.5, 100).unwrap();
154    /// writer.add_scalar("accuracy/val", 0.95, 100).unwrap();
155    /// ```
156    pub fn add_scalar(&mut self, tag: &str, value: f64, step: u64) -> Result<()> {
157        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
158
159        self.scalar_logs.push(ScalarEvent {
160            tag: tag.to_string(),
161            value,
162            step,
163            timestamp,
164        });
165
166        Ok(())
167    }
168
169    /// Add multiple scalars at once
170    pub fn add_scalars(
171        &mut self,
172        main_tag: &str,
173        tag_scalar_dict: HashMap<String, f64>,
174        step: u64,
175    ) -> Result<()> {
176        for (tag, value) in tag_scalar_dict {
177            let full_tag = format!("{}/{}", main_tag, tag);
178            self.add_scalar(&full_tag, value, step)?;
179        }
180        Ok(())
181    }
182
183    /// Add a histogram of values
184    ///
185    /// # Arguments
186    ///
187    /// * `tag` - Name/identifier for this histogram
188    /// * `values` - Array of values to create histogram from
189    /// * `step` - Training step number
190    ///
191    /// # Example
192    ///
193    /// ```no_run
194    /// # use trustformers_debug::TensorBoardWriter;
195    /// # let mut writer = TensorBoardWriter::new("runs/test").unwrap();
196    /// let weights = vec![0.1, 0.2, 0.15, 0.3, 0.25];
197    /// writer.add_histogram("layer.0.weight", &weights, 100).unwrap();
198    /// ```
199    pub fn add_histogram(&mut self, tag: &str, values: &[f64], step: u64) -> Result<()> {
200        if values.is_empty() {
201            return Ok(());
202        }
203
204        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
205
206        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
207        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
208        let sum: f64 = values.iter().sum();
209        let sum_squares: f64 = values.iter().map(|x| x * x).sum();
210
211        self.histogram_logs.push(HistogramEvent {
212            tag: tag.to_string(),
213            values: values.to_vec(),
214            step,
215            timestamp,
216            min,
217            max,
218            num: values.len(),
219            sum,
220            sum_squares,
221        });
222
223        Ok(())
224    }
225
226    /// Add text data for logging
227    pub fn add_text(&mut self, tag: &str, text: &str, step: u64) -> Result<()> {
228        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
229
230        self.text_logs.push(TextEvent {
231            tag: tag.to_string(),
232            text: text.to_string(),
233            step,
234            timestamp,
235        });
236
237        Ok(())
238    }
239
240    /// Add embeddings for projector visualization
241    ///
242    /// # Arguments
243    ///
244    /// * `tag` - Name for this embedding
245    /// * `embeddings` - 2D array of embedding vectors
246    /// * `labels` - Optional labels for each embedding
247    /// * `step` - Training step number
248    pub fn add_embedding(
249        &mut self,
250        tag: &str,
251        embeddings: Vec<Vec<f64>>,
252        labels: Option<Vec<String>>,
253        step: u64,
254    ) -> Result<()> {
255        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
256
257        self.embedding_logs.push(EmbeddingEvent {
258            tag: tag.to_string(),
259            embeddings,
260            labels,
261            step,
262            timestamp,
263        });
264
265        Ok(())
266    }
267
268    /// Add a graph definition for model architecture visualization
269    pub fn add_graph(&mut self, graph: &GraphDef) -> Result<()> {
270        let graph_path = self.log_dir.join(&self.run_name).join("graph.json");
271        fs::create_dir_all(graph_path.parent().unwrap())?;
272
273        let graph_json = serde_json::to_string_pretty(graph)?;
274        fs::write(graph_path, graph_json)?;
275
276        Ok(())
277    }
278
279    /// Flush all pending logs to disk
280    pub fn flush(&mut self) -> Result<()> {
281        let run_dir = self.log_dir.join(&self.run_name);
282        fs::create_dir_all(&run_dir)?;
283
284        // Write scalar logs
285        if !self.scalar_logs.is_empty() {
286            let scalars_path = run_dir.join("scalars.jsonl");
287            let mut file = File::create(scalars_path)?;
288            for event in &self.scalar_logs {
289                let line = serde_json::to_string(event)?;
290                writeln!(file, "{}", line)?;
291            }
292        }
293
294        // Write histogram logs
295        if !self.histogram_logs.is_empty() {
296            let histograms_path = run_dir.join("histograms.jsonl");
297            let mut file = File::create(histograms_path)?;
298            for event in &self.histogram_logs {
299                let line = serde_json::to_string(event)?;
300                writeln!(file, "{}", line)?;
301            }
302        }
303
304        // Write text logs
305        if !self.text_logs.is_empty() {
306            let text_path = run_dir.join("text.jsonl");
307            let mut file = File::create(text_path)?;
308            for event in &self.text_logs {
309                let line = serde_json::to_string(event)?;
310                writeln!(file, "{}", line)?;
311            }
312        }
313
314        // Write embedding logs
315        if !self.embedding_logs.is_empty() {
316            let embeddings_path = run_dir.join("embeddings.jsonl");
317            let mut file = File::create(embeddings_path)?;
318            for event in &self.embedding_logs {
319                let line = serde_json::to_string(event)?;
320                writeln!(file, "{}", line)?;
321            }
322        }
323
324        Ok(())
325    }
326
327    /// Get the path to the log directory
328    pub fn log_dir(&self) -> &Path {
329        &self.log_dir
330    }
331
332    /// Get the current run name
333    pub fn run_name(&self) -> &str {
334        &self.run_name
335    }
336
337    /// Increment the internal step counter
338    pub fn increment_step(&mut self) -> u64 {
339        self.step_counter += 1;
340        self.step_counter
341    }
342
343    /// Get the current step counter value
344    pub fn current_step(&self) -> u64 {
345        self.step_counter
346    }
347
348    /// Close the writer and flush all remaining data
349    pub fn close(mut self) -> Result<()> {
350        self.flush()
351    }
352}
353
354impl Drop for TensorBoardWriter {
355    fn drop(&mut self) {
356        // Auto-flush on drop
357        let _ = self.flush();
358    }
359}
360
361/// Helper function to create a graph node
362pub fn create_graph_node(
363    name: String,
364    op_type: String,
365    inputs: Vec<String>,
366    outputs: Vec<String>,
367) -> GraphNode {
368    GraphNode {
369        name,
370        op_type,
371        input_names: inputs,
372        output_names: outputs,
373        attributes: HashMap::new(),
374    }
375}
376
377/// Helper to convert tensor statistics to histogram-compatible format
378pub fn tensor_to_histogram_values(tensor: &ArrayD<f32>) -> Vec<f64> {
379    tensor.iter().map(|&x| x as f64).collect()
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::process;

    /// Per-process temp directory for a test fixture, so test runs from
    /// different processes don't collide on shared fixed paths under the
    /// system temp directory.
    fn test_dir(label: &str) -> PathBuf {
        env::temp_dir().join(format!("{}_{}", label, process::id()))
    }

    #[test]
    fn test_tensorboard_writer_creation() {
        let writer = TensorBoardWriter::new(test_dir("tensorboard_test")).unwrap();
        assert!(writer.log_dir().exists());
    }

    #[test]
    fn test_add_scalar() {
        let mut writer = TensorBoardWriter::new(test_dir("tensorboard_scalar_test")).unwrap();

        writer.add_scalar("test/loss", 0.5, 0).unwrap();
        writer.add_scalar("test/loss", 0.4, 1).unwrap();
        writer.add_scalar("test/loss", 0.3, 2).unwrap();

        assert_eq!(writer.scalar_logs.len(), 3);
        assert_eq!(writer.scalar_logs[0].value, 0.5);
        assert_eq!(writer.scalar_logs[1].value, 0.4);
    }

    #[test]
    fn test_add_histogram() {
        let mut writer = TensorBoardWriter::new(test_dir("tensorboard_histogram_test")).unwrap();

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        writer.add_histogram("test/weights", &values, 0).unwrap();

        assert_eq!(writer.histogram_logs.len(), 1);
        assert_eq!(writer.histogram_logs[0].min, 1.0);
        assert_eq!(writer.histogram_logs[0].max, 5.0);
        assert_eq!(writer.histogram_logs[0].num, 5);
    }

    #[test]
    fn test_add_text() {
        let mut writer = TensorBoardWriter::new(test_dir("tensorboard_text_test")).unwrap();

        writer.add_text("test/note", "This is a test", 0).unwrap();
        assert_eq!(writer.text_logs.len(), 1);
        assert_eq!(writer.text_logs[0].text, "This is a test");
    }

    #[test]
    fn test_flush() {
        let temp_dir = test_dir("tensorboard_flush_test");
        let mut writer = TensorBoardWriter::new(&temp_dir).unwrap();

        writer.add_scalar("test/metric", 1.0, 0).unwrap();
        writer.flush().unwrap();

        let scalars_path = temp_dir.join(writer.run_name()).join("scalars.jsonl");
        assert!(scalars_path.exists());
    }

    #[test]
    fn test_add_scalars() {
        let mut writer = TensorBoardWriter::new(test_dir("tensorboard_scalars_test")).unwrap();

        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.5);
        metrics.insert("accuracy".to_string(), 0.95);

        writer.add_scalars("train", metrics, 0).unwrap();
        assert_eq!(writer.scalar_logs.len(), 2);
    }

    #[test]
    fn test_add_embedding() {
        let mut writer = TensorBoardWriter::new(test_dir("tensorboard_embedding_test")).unwrap();

        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let labels = vec!["class1".to_string(), "class2".to_string()];

        writer
            .add_embedding("test/emb", embeddings, Some(labels), 0)
            .unwrap();
        assert_eq!(writer.embedding_logs.len(), 1);
    }

    #[test]
    fn test_graph_node_creation() {
        let node = create_graph_node(
            "layer1".to_string(),
            "Linear".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "layer1");
        assert_eq!(node.op_type, "Linear");
        assert_eq!(node.input_names.len(), 1);
        assert_eq!(node.output_names.len(), 1);
    }
}