Skip to main content

trustformers_debug/
unified_debug_session.rs

1//! Unified debugging session manager
2//!
3//! This module provides a high-level API for managing debugging sessions that
4//! integrate multiple debugging tools (TensorBoard, visualizers, stability checkers, etc.)
5
6use anyhow::Result;
7use std::path::{Path, PathBuf};
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use crate::{
11    ActivationVisualizer, AttentionVisualizer, GraphVisualizer, NetronExporter, StabilityChecker,
12    TensorBoardWriter,
13};
14
15/// Unified debugging session that manages multiple debugging tools
16#[derive(Debug)]
17pub struct UnifiedDebugSession {
18    /// Session ID
19    session_id: String,
20    /// Session directory for outputs
21    session_dir: PathBuf,
22    /// TensorBoard writer (optional)
23    tensorboard: Option<TensorBoardWriter>,
24    /// Activation visualizer
25    activation_viz: ActivationVisualizer,
26    /// Attention visualizer
27    attention_viz: AttentionVisualizer,
28    /// Stability checker
29    stability_checker: StabilityChecker,
30    /// Graph visualizer (optional)
31    graph_viz: Option<GraphVisualizer>,
32    /// Session configuration
33    config: UnifiedDebugSessionConfig,
34    /// Current step counter
35    step: u64,
36}
37
/// Configuration for a [`UnifiedDebugSession`].
///
/// Each `enable_*` flag switches one debugging tool on or off. The type is
/// plain data, so it can be cloned and compared for equality (useful when
/// asserting on configurations in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnifiedDebugSessionConfig {
    /// Enable TensorBoard logging
    pub enable_tensorboard: bool,
    /// Enable activation visualization
    pub enable_activation_viz: bool,
    /// Enable attention visualization
    pub enable_attention_viz: bool,
    /// Enable stability checking
    pub enable_stability_check: bool,
    /// Enable graph visualization
    pub enable_graph_viz: bool,
    /// Auto-save interval in steps; `0` disables auto-saving
    pub auto_save_interval: u64,
    /// Prefix of the session ID; when `None` the prefix `debug_session` is used
    pub session_name: Option<String>,
}
56
57impl Default for UnifiedDebugSessionConfig {
58    fn default() -> Self {
59        Self {
60            enable_tensorboard: true,
61            enable_activation_viz: true,
62            enable_attention_viz: true,
63            enable_stability_check: true,
64            enable_graph_viz: false,
65            auto_save_interval: 100,
66            session_name: None,
67        }
68    }
69}
70
/// Summary of debugging session results.
///
/// A snapshot of the session's counters at the moment it was requested. The
/// type is plain data, so it can be cloned and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionSummary {
    /// Session ID
    pub session_id: String,
    /// Total steps recorded
    pub total_steps: u64,
    /// Number of layers held by the activation visualizer
    pub num_activations: usize,
    /// Number of layers held by the attention visualizer
    pub num_attention_patterns: usize,
    /// Total number of issues reported by the stability checker
    pub num_stability_issues: usize,
    /// Directory the session writes its outputs to
    pub output_directory: PathBuf,
}
87
88impl UnifiedDebugSession {
89    /// Create a new debugging session
90    ///
91    /// # Arguments
92    ///
93    /// * `output_dir` - Directory where debugging outputs will be saved
94    ///
95    /// # Example
96    ///
97    /// ```
98    /// use trustformers_debug::UnifiedDebugSession;
99    ///
100    /// let session = UnifiedDebugSession::new("debug_outputs").unwrap();
101    /// ```
102    pub fn new<P: AsRef<Path>>(output_dir: P) -> Result<Self> {
103        Self::with_config(output_dir, UnifiedDebugSessionConfig::default())
104    }
105
106    /// Create a debugging session with custom configuration
107    pub fn with_config<P: AsRef<Path>>(
108        output_dir: P,
109        config: UnifiedDebugSessionConfig,
110    ) -> Result<Self> {
111        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
112
113        let session_id = if let Some(ref name) = config.session_name {
114            format!("{}_{}", name, timestamp)
115        } else {
116            format!("debug_session_{}", timestamp)
117        };
118
119        let session_dir = output_dir.as_ref().join(&session_id);
120        std::fs::create_dir_all(&session_dir)?;
121
122        let tensorboard = if config.enable_tensorboard {
123            let tb_dir = session_dir.join("tensorboard");
124            Some(TensorBoardWriter::new(&tb_dir)?)
125        } else {
126            None
127        };
128
129        let graph_viz = if config.enable_graph_viz {
130            Some(GraphVisualizer::new(&session_id))
131        } else {
132            None
133        };
134
135        Ok(Self {
136            session_id,
137            session_dir,
138            tensorboard,
139            activation_viz: ActivationVisualizer::new(),
140            attention_viz: AttentionVisualizer::new(),
141            stability_checker: StabilityChecker::new(),
142            graph_viz,
143            config,
144            step: 0,
145        })
146    }
147
148    /// Log a scalar metric
149    pub fn log_scalar(&mut self, tag: &str, value: f64) -> Result<()> {
150        if let Some(ref mut tb) = self.tensorboard {
151            tb.add_scalar(tag, value, self.step)?;
152        }
153        Ok(())
154    }
155
156    /// Log multiple scalars at once
157    pub fn log_scalars(&mut self, tag: &str, values: &[(&str, f64)]) -> Result<()> {
158        if let Some(tb) = &mut self.tensorboard {
159            for (name, value) in values {
160                tb.add_scalar(&format!("{}/{}", tag, name), *value, self.step)?;
161            }
162        }
163        Ok(())
164    }
165
166    /// Log histogram (e.g., weight distribution)
167    pub fn log_histogram(&mut self, tag: &str, values: &[f64]) -> Result<()> {
168        if let Some(tb) = &mut self.tensorboard {
169            tb.add_histogram(tag, values, self.step)?;
170        }
171        Ok(())
172    }
173
174    /// Register layer activations
175    pub fn register_activations(
176        &mut self,
177        layer_name: &str,
178        values: Vec<f32>,
179        shape: Vec<usize>,
180    ) -> Result<()> {
181        if self.config.enable_activation_viz {
182            self.activation_viz.register(layer_name, values, shape)?;
183        }
184        Ok(())
185    }
186
187    /// Register attention weights
188    pub fn register_attention(
189        &mut self,
190        layer_name: &str,
191        weights: Vec<Vec<Vec<f64>>>,
192        tokens: Vec<String>,
193    ) -> Result<()> {
194        if self.config.enable_attention_viz {
195            self.attention_viz.register(
196                layer_name,
197                weights,
198                tokens.clone(),
199                tokens,
200                crate::AttentionType::SelfAttention,
201            )?;
202        }
203        Ok(())
204    }
205
206    /// Check tensor for numerical stability
207    pub fn check_stability(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
208        if self.config.enable_stability_check {
209            self.stability_checker.check_tensor(layer_name, values)
210        } else {
211            Ok(0)
212        }
213    }
214
215    /// Increment step counter
216    pub fn step(&mut self) {
217        self.step += 1;
218
219        // Auto-save if enabled
220        if self.config.auto_save_interval > 0
221            && self.step.is_multiple_of(self.config.auto_save_interval)
222        {
223            let _ = self.save();
224        }
225    }
226
227    /// Get current step
228    pub fn current_step(&self) -> u64 {
229        self.step
230    }
231
232    /// Save all accumulated data to disk
233    pub fn save(&mut self) -> Result<()> {
234        // Flush TensorBoard
235        if let Some(tb) = &mut self.tensorboard {
236            tb.flush()?;
237        }
238
239        // Export activation summaries
240        if self.config.enable_activation_viz && self.activation_viz.num_layers() > 0 {
241            let act_summary = self.activation_viz.print_summary()?;
242            let act_path = self.session_dir.join("activation_summary.txt");
243            std::fs::write(act_path, act_summary)?;
244        }
245
246        // Export attention summaries
247        if self.config.enable_attention_viz && self.attention_viz.num_layers() > 0 {
248            let att_summary = self.attention_viz.summary();
249            let att_path = self.session_dir.join("attention_summary.txt");
250            std::fs::write(att_path, att_summary)?;
251        }
252
253        // Export stability report
254        if self.config.enable_stability_check && self.stability_checker.has_issues() {
255            let stability_report = self.stability_checker.report();
256            let stability_path = self.session_dir.join("stability_report.txt");
257            std::fs::write(stability_path, stability_report)?;
258
259            let issues_json = self.session_dir.join("stability_issues.json");
260            self.stability_checker.export_to_json(&issues_json)?;
261        }
262
263        Ok(())
264    }
265
266    /// Export activation visualization for a specific layer
267    pub fn export_activation_viz(&self, layer_name: &str, filename: &str) -> Result<()> {
268        let path = self.session_dir.join(filename);
269        self.activation_viz.export_statistics(layer_name, &path)
270    }
271
272    /// Export attention visualization for a specific layer
273    pub fn export_attention_viz(&self, layer_name: &str, filename: &str) -> Result<()> {
274        let path = self.session_dir.join(filename);
275        self.attention_viz.export_to_json(layer_name, &path)
276    }
277
278    /// Export attention as BertViz HTML
279    pub fn export_attention_bertviz(&self, layer_name: &str, filename: &str) -> Result<()> {
280        let path = self.session_dir.join(filename);
281        self.attention_viz.export_to_bertviz(layer_name, &path)
282    }
283
284    /// Get session summary
285    pub fn summary(&self) -> SessionSummary {
286        SessionSummary {
287            session_id: self.session_id.clone(),
288            total_steps: self.step,
289            num_activations: self.activation_viz.num_layers(),
290            num_attention_patterns: self.attention_viz.num_layers(),
291            num_stability_issues: self.stability_checker.total_issues(),
292            output_directory: self.session_dir.clone(),
293        }
294    }
295
296    /// Get session directory path
297    pub fn session_dir(&self) -> &Path {
298        &self.session_dir
299    }
300
301    /// Print summary to string
302    pub fn print_summary(&self) -> String {
303        let summary = self.summary();
304        format!(
305            r#"Debug Session Summary
306=====================
307Session ID: {}
308Total Steps: {}
309Activations Captured: {}
310Attention Patterns: {}
311Stability Issues: {}
312Output Directory: {}
313
314TensorBoard: {}
315Use: tensorboard --logdir={}
316"#,
317            summary.session_id,
318            summary.total_steps,
319            summary.num_activations,
320            summary.num_attention_patterns,
321            summary.num_stability_issues,
322            summary.output_directory.display(),
323            if self.tensorboard.is_some() { "Enabled" } else { "Disabled" },
324            self.session_dir.join("tensorboard").display(),
325        )
326    }
327
328    /// Close session and perform final save
329    pub fn close(mut self) -> Result<SessionSummary> {
330        self.save()?;
331        Ok(self.summary())
332    }
333
334    /// Get reference to activation visualizer
335    pub fn activation_visualizer(&self) -> &ActivationVisualizer {
336        &self.activation_viz
337    }
338
339    /// Get reference to attention visualizer
340    pub fn attention_visualizer(&self) -> &AttentionVisualizer {
341        &self.attention_viz
342    }
343
344    /// Get reference to stability checker
345    pub fn stability_checker(&self) -> &StabilityChecker {
346        &self.stability_checker
347    }
348
349    /// Get mutable reference to graph visualizer (if enabled)
350    pub fn graph_visualizer_mut(&mut self) -> Option<&mut GraphVisualizer> {
351        self.graph_viz.as_mut()
352    }
353
354    /// Export model architecture to Netron format
355    pub fn export_model_netron(&self, model_name: &str, description: &str) -> Result<PathBuf> {
356        let exporter = NetronExporter::new(model_name, description);
357        let path = self.session_dir.join(format!("{}.json", model_name));
358        exporter.export(&path)?;
359        Ok(path)
360    }
361}
362
363impl Drop for UnifiedDebugSession {
364    fn drop(&mut self) {
365        // Auto-save on drop
366        let _ = self.save();
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Returns an empty per-test directory path under the system temp dir,
    /// removing anything a previous run may have left behind.
    fn fresh_scratch_dir(name: &str) -> PathBuf {
        let dir = env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// Drops the session first (its `Drop` impl still writes into the
    /// directory) and then removes the scratch directory. Cleanup is
    /// best-effort, so a failure to delete does not fail the test.
    fn finish(session: UnifiedDebugSession, dir: &Path) {
        drop(session);
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn test_debug_session_creation() {
        let temp_dir = fresh_scratch_dir("debug_session_test");
        let session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        assert_eq!(session.current_step(), 0);
        assert!(session.session_dir().exists());

        finish(session, &temp_dir);
    }

    #[test]
    fn test_log_scalar() {
        let temp_dir = fresh_scratch_dir("debug_session_scalar_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        session.log_scalar("test/loss", 0.5).expect("operation failed in test");
        session.log_scalar("test/accuracy", 0.95).expect("operation failed in test");
        session.step();

        assert_eq!(session.current_step(), 1);

        finish(session, &temp_dir);
    }

    #[test]
    fn test_register_activations() {
        let temp_dir = fresh_scratch_dir("debug_session_activations_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        let activations = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        session
            .register_activations("layer1", activations, vec![5])
            .expect("operation failed in test");

        assert_eq!(session.activation_visualizer().num_layers(), 1);

        finish(session, &temp_dir);
    }

    #[test]
    fn test_check_stability() {
        let temp_dir = fresh_scratch_dir("debug_session_stability_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        let values = vec![1.0, f64::NAN, 2.0];
        let issues = session.check_stability("layer1", &values).expect("operation failed in test");

        assert!(issues > 0);
        assert!(session.stability_checker().has_issues());

        finish(session, &temp_dir);
    }

    #[test]
    fn test_session_save() {
        let temp_dir = fresh_scratch_dir("debug_session_save_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        session.log_scalar("test/metric", 1.0).expect("operation failed in test");
        session
            .register_activations("layer1", vec![1.0, 2.0, 3.0], vec![3])
            .expect("operation failed in test");

        session.save().expect("operation failed in test");

        // Check that files were created: with activation visualization
        // enabled and one layer registered, a successful save must have
        // written the activation summary.
        let session_dir = session.session_dir();
        assert!(session_dir.exists());
        assert!(session_dir.join("activation_summary.txt").exists());

        finish(session, &temp_dir);
    }

    #[test]
    fn test_session_summary() {
        let temp_dir = fresh_scratch_dir("debug_session_summary_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        session
            .register_activations("layer1", vec![1.0], vec![1])
            .expect("operation failed in test");
        session.step();
        session.step();

        let summary = session.summary();
        assert_eq!(summary.total_steps, 2);
        assert_eq!(summary.num_activations, 1);

        finish(session, &temp_dir);
    }

    #[test]
    fn test_custom_config() {
        let temp_dir = fresh_scratch_dir("debug_session_config_test");

        let config = UnifiedDebugSessionConfig {
            enable_tensorboard: false,
            enable_activation_viz: true,
            enable_attention_viz: false,
            enable_stability_check: true,
            enable_graph_viz: false,
            auto_save_interval: 0,
            session_name: Some("test_session".to_string()),
        };

        let session =
            UnifiedDebugSession::with_config(&temp_dir, config).expect("session creation failed");
        assert!(session.tensorboard.is_none());
        assert!(session.session_id.starts_with("test_session"));

        finish(session, &temp_dir);
    }

    #[test]
    fn test_auto_save() {
        let temp_dir = fresh_scratch_dir("debug_session_autosave_test");

        let config = UnifiedDebugSessionConfig {
            auto_save_interval: 2,
            ..Default::default()
        };

        let mut session =
            UnifiedDebugSession::with_config(&temp_dir, config).expect("session creation failed");

        session.log_scalar("test/value", 1.0).expect("operation failed in test");
        // Registered activations give the auto-save something observable to
        // write, so the test can tell whether a save actually happened.
        session
            .register_activations("layer1", vec![1.0, 2.0, 3.0], vec![3])
            .expect("operation failed in test");
        let summary_path = session.session_dir().join("activation_summary.txt");

        session.step(); // step 1 - not a multiple of the interval, no save
        assert!(!summary_path.exists());

        session.step(); // step 2 - should trigger auto-save
        assert_eq!(session.current_step(), 2);
        assert!(summary_path.exists());

        finish(session, &temp_dir);
    }

    #[test]
    fn test_print_summary() {
        let temp_dir = fresh_scratch_dir("debug_session_print_test");
        let session = UnifiedDebugSession::new(&temp_dir).expect("session creation failed");

        let summary_str = session.print_summary();
        assert!(summary_str.contains("Debug Session Summary"));
        assert!(summary_str.contains("Session ID"));
        // The default configuration enables TensorBoard logging.
        assert!(summary_str.contains("TensorBoard: Enabled"));

        finish(session, &temp_dir);
    }
}