Skip to main content

trustformers_debug/
unified_debug_session.rs

1//! Unified debugging session manager
2//!
3//! This module provides a high-level API for managing debugging sessions that
4//! integrate multiple debugging tools (TensorBoard, visualizers, stability checkers, etc.)
5
6use anyhow::Result;
7use std::path::{Path, PathBuf};
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use crate::{
11    ActivationVisualizer, AttentionVisualizer, GraphVisualizer, NetronExporter, StabilityChecker,
12    TensorBoardWriter,
13};
14
15/// Unified debugging session that manages multiple debugging tools
16#[derive(Debug)]
17pub struct UnifiedDebugSession {
18    /// Session ID
19    session_id: String,
20    /// Session directory for outputs
21    session_dir: PathBuf,
22    /// TensorBoard writer (optional)
23    tensorboard: Option<TensorBoardWriter>,
24    /// Activation visualizer
25    activation_viz: ActivationVisualizer,
26    /// Attention visualizer
27    attention_viz: AttentionVisualizer,
28    /// Stability checker
29    stability_checker: StabilityChecker,
30    /// Graph visualizer (optional)
31    graph_viz: Option<GraphVisualizer>,
32    /// Session configuration
33    config: UnifiedDebugSessionConfig,
34    /// Current step counter
35    step: u64,
36}
37
/// Options that control which tools a debug session uses and how it is named.
#[derive(Debug, Clone)]
pub struct UnifiedDebugSessionConfig {
    /// Whether scalars and histograms are logged through TensorBoard.
    pub enable_tensorboard: bool,
    /// Whether registered layer activations are recorded.
    pub enable_activation_viz: bool,
    /// Whether registered attention weights are recorded.
    pub enable_attention_viz: bool,
    /// Whether tensors are checked for numerical stability.
    pub enable_stability_check: bool,
    /// Whether a graph visualizer is created for the session.
    pub enable_graph_viz: bool,
    /// Number of steps between automatic saves; `0` turns auto-save off.
    pub auto_save_interval: u64,
    /// Optional prefix for the session ID; a generic prefix is used when `None`.
    pub session_name: Option<String>,
}
56
57impl Default for UnifiedDebugSessionConfig {
58    fn default() -> Self {
59        Self {
60            enable_tensorboard: true,
61            enable_activation_viz: true,
62            enable_attention_viz: true,
63            enable_stability_check: true,
64            enable_graph_viz: false,
65            auto_save_interval: 100,
66            session_name: None,
67        }
68    }
69}
70
/// Snapshot of what a debugging session has recorded so far.
#[derive(Debug, Clone)]
pub struct SessionSummary {
    /// Identifier of the session the snapshot was taken from.
    pub session_id: String,
    /// Value of the session's step counter when the snapshot was taken.
    pub total_steps: u64,
    /// Count reported by the activation visualizer.
    pub num_activations: usize,
    /// Count reported by the attention visualizer.
    pub num_attention_patterns: usize,
    /// Total number of issues reported by the stability checker.
    pub num_stability_issues: usize,
    /// Directory the session writes its files to.
    pub output_directory: PathBuf,
}
87
88impl UnifiedDebugSession {
89    /// Create a new debugging session
90    ///
91    /// # Arguments
92    ///
93    /// * `output_dir` - Directory where debugging outputs will be saved
94    ///
95    /// # Example
96    ///
97    /// ```
98    /// use trustformers_debug::UnifiedDebugSession;
99    ///
100    /// let session = UnifiedDebugSession::new("debug_outputs").unwrap();
101    /// ```
102    pub fn new<P: AsRef<Path>>(output_dir: P) -> Result<Self> {
103        Self::with_config(output_dir, UnifiedDebugSessionConfig::default())
104    }
105
106    /// Create a debugging session with custom configuration
107    pub fn with_config<P: AsRef<Path>>(
108        output_dir: P,
109        config: UnifiedDebugSessionConfig,
110    ) -> Result<Self> {
111        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
112
113        let session_id = if let Some(ref name) = config.session_name {
114            format!("{}_{}", name, timestamp)
115        } else {
116            format!("debug_session_{}", timestamp)
117        };
118
119        let session_dir = output_dir.as_ref().join(&session_id);
120        std::fs::create_dir_all(&session_dir)?;
121
122        let tensorboard = if config.enable_tensorboard {
123            let tb_dir = session_dir.join("tensorboard");
124            Some(TensorBoardWriter::new(&tb_dir)?)
125        } else {
126            None
127        };
128
129        let graph_viz = if config.enable_graph_viz {
130            Some(GraphVisualizer::new(&session_id))
131        } else {
132            None
133        };
134
135        Ok(Self {
136            session_id,
137            session_dir,
138            tensorboard,
139            activation_viz: ActivationVisualizer::new(),
140            attention_viz: AttentionVisualizer::new(),
141            stability_checker: StabilityChecker::new(),
142            graph_viz,
143            config,
144            step: 0,
145        })
146    }
147
148    /// Log a scalar metric
149    pub fn log_scalar(&mut self, tag: &str, value: f64) -> Result<()> {
150        if let Some(ref mut tb) = self.tensorboard {
151            tb.add_scalar(tag, value, self.step)?;
152        }
153        Ok(())
154    }
155
156    /// Log multiple scalars at once
157    pub fn log_scalars(&mut self, tag: &str, values: &[(&str, f64)]) -> Result<()> {
158        if let Some(tb) = &mut self.tensorboard {
159            for (name, value) in values {
160                tb.add_scalar(&format!("{}/{}", tag, name), *value, self.step)?;
161            }
162        }
163        Ok(())
164    }
165
166    /// Log histogram (e.g., weight distribution)
167    pub fn log_histogram(&mut self, tag: &str, values: &[f64]) -> Result<()> {
168        if let Some(tb) = &mut self.tensorboard {
169            tb.add_histogram(tag, values, self.step)?;
170        }
171        Ok(())
172    }
173
174    /// Register layer activations
175    pub fn register_activations(
176        &mut self,
177        layer_name: &str,
178        values: Vec<f32>,
179        shape: Vec<usize>,
180    ) -> Result<()> {
181        if self.config.enable_activation_viz {
182            self.activation_viz.register(layer_name, values, shape)?;
183        }
184        Ok(())
185    }
186
187    /// Register attention weights
188    pub fn register_attention(
189        &mut self,
190        layer_name: &str,
191        weights: Vec<Vec<Vec<f64>>>,
192        tokens: Vec<String>,
193    ) -> Result<()> {
194        if self.config.enable_attention_viz {
195            self.attention_viz.register(
196                layer_name,
197                weights,
198                tokens.clone(),
199                tokens,
200                crate::AttentionType::SelfAttention,
201            )?;
202        }
203        Ok(())
204    }
205
206    /// Check tensor for numerical stability
207    pub fn check_stability(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
208        if self.config.enable_stability_check {
209            self.stability_checker.check_tensor(layer_name, values)
210        } else {
211            Ok(0)
212        }
213    }
214
215    /// Increment step counter
216    pub fn step(&mut self) {
217        self.step += 1;
218
219        // Auto-save if enabled
220        if self.config.auto_save_interval > 0 && self.step % self.config.auto_save_interval == 0 {
221            let _ = self.save();
222        }
223    }
224
225    /// Get current step
226    pub fn current_step(&self) -> u64 {
227        self.step
228    }
229
230    /// Save all accumulated data to disk
231    pub fn save(&mut self) -> Result<()> {
232        // Flush TensorBoard
233        if let Some(tb) = &mut self.tensorboard {
234            tb.flush()?;
235        }
236
237        // Export activation summaries
238        if self.config.enable_activation_viz && self.activation_viz.num_layers() > 0 {
239            let act_summary = self.activation_viz.print_summary()?;
240            let act_path = self.session_dir.join("activation_summary.txt");
241            std::fs::write(act_path, act_summary)?;
242        }
243
244        // Export attention summaries
245        if self.config.enable_attention_viz && self.attention_viz.num_layers() > 0 {
246            let att_summary = self.attention_viz.summary();
247            let att_path = self.session_dir.join("attention_summary.txt");
248            std::fs::write(att_path, att_summary)?;
249        }
250
251        // Export stability report
252        if self.config.enable_stability_check && self.stability_checker.has_issues() {
253            let stability_report = self.stability_checker.report();
254            let stability_path = self.session_dir.join("stability_report.txt");
255            std::fs::write(stability_path, stability_report)?;
256
257            let issues_json = self.session_dir.join("stability_issues.json");
258            self.stability_checker.export_to_json(&issues_json)?;
259        }
260
261        Ok(())
262    }
263
264    /// Export activation visualization for a specific layer
265    pub fn export_activation_viz(&self, layer_name: &str, filename: &str) -> Result<()> {
266        let path = self.session_dir.join(filename);
267        self.activation_viz.export_statistics(layer_name, &path)
268    }
269
270    /// Export attention visualization for a specific layer
271    pub fn export_attention_viz(&self, layer_name: &str, filename: &str) -> Result<()> {
272        let path = self.session_dir.join(filename);
273        self.attention_viz.export_to_json(layer_name, &path)
274    }
275
276    /// Export attention as BertViz HTML
277    pub fn export_attention_bertviz(&self, layer_name: &str, filename: &str) -> Result<()> {
278        let path = self.session_dir.join(filename);
279        self.attention_viz.export_to_bertviz(layer_name, &path)
280    }
281
282    /// Get session summary
283    pub fn summary(&self) -> SessionSummary {
284        SessionSummary {
285            session_id: self.session_id.clone(),
286            total_steps: self.step,
287            num_activations: self.activation_viz.num_layers(),
288            num_attention_patterns: self.attention_viz.num_layers(),
289            num_stability_issues: self.stability_checker.total_issues(),
290            output_directory: self.session_dir.clone(),
291        }
292    }
293
294    /// Get session directory path
295    pub fn session_dir(&self) -> &Path {
296        &self.session_dir
297    }
298
299    /// Print summary to string
300    pub fn print_summary(&self) -> String {
301        let summary = self.summary();
302        format!(
303            r#"Debug Session Summary
304=====================
305Session ID: {}
306Total Steps: {}
307Activations Captured: {}
308Attention Patterns: {}
309Stability Issues: {}
310Output Directory: {}
311
312TensorBoard: {}
313Use: tensorboard --logdir={}
314"#,
315            summary.session_id,
316            summary.total_steps,
317            summary.num_activations,
318            summary.num_attention_patterns,
319            summary.num_stability_issues,
320            summary.output_directory.display(),
321            if self.tensorboard.is_some() { "Enabled" } else { "Disabled" },
322            self.session_dir.join("tensorboard").display(),
323        )
324    }
325
326    /// Close session and perform final save
327    pub fn close(mut self) -> Result<SessionSummary> {
328        self.save()?;
329        Ok(self.summary())
330    }
331
332    /// Get reference to activation visualizer
333    pub fn activation_visualizer(&self) -> &ActivationVisualizer {
334        &self.activation_viz
335    }
336
337    /// Get reference to attention visualizer
338    pub fn attention_visualizer(&self) -> &AttentionVisualizer {
339        &self.attention_viz
340    }
341
342    /// Get reference to stability checker
343    pub fn stability_checker(&self) -> &StabilityChecker {
344        &self.stability_checker
345    }
346
347    /// Get mutable reference to graph visualizer (if enabled)
348    pub fn graph_visualizer_mut(&mut self) -> Option<&mut GraphVisualizer> {
349        self.graph_viz.as_mut()
350    }
351
352    /// Export model architecture to Netron format
353    pub fn export_model_netron(&self, model_name: &str, description: &str) -> Result<PathBuf> {
354        let exporter = NetronExporter::new(model_name, description);
355        let path = self.session_dir.join(format!("{}.json", model_name));
356        exporter.export(&path)?;
357        Ok(path)
358    }
359}
360
361impl Drop for UnifiedDebugSession {
362    fn drop(&mut self) {
363        // Auto-save on drop
364        let _ = self.save();
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Drop `session` (which triggers its final save) and then remove the
    /// scratch directory so repeated test runs do not pile up in the temp dir.
    fn cleanup(session: UnifiedDebugSession, root: &std::path::Path) {
        drop(session);
        let _ = std::fs::remove_dir_all(root);
    }

    /// A fresh session starts at step zero and its directory exists.
    #[test]
    fn test_debug_session_creation() {
        let temp_dir = env::temp_dir().join("debug_session_test");
        let session = UnifiedDebugSession::new(&temp_dir).unwrap();

        assert_eq!(session.current_step(), 0);
        assert!(session.session_dir().exists());

        cleanup(session, &temp_dir);
    }

    /// Logging scalars succeeds and `step` advances the counter.
    #[test]
    fn test_log_scalar() {
        let temp_dir = env::temp_dir().join("debug_session_scalar_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).unwrap();

        session.log_scalar("test/loss", 0.5).unwrap();
        session.log_scalar("test/accuracy", 0.95).unwrap();
        session.step();

        assert_eq!(session.current_step(), 1);

        cleanup(session, &temp_dir);
    }

    /// Registering one layer is visible through the activation visualizer.
    #[test]
    fn test_register_activations() {
        let temp_dir = env::temp_dir().join("debug_session_activations_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).unwrap();

        let activations = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        session.register_activations("layer1", activations, vec![5]).unwrap();

        assert_eq!(session.activation_visualizer().num_layers(), 1);

        cleanup(session, &temp_dir);
    }

    /// A NaN in the input is reported as a stability issue.
    #[test]
    fn test_check_stability() {
        let temp_dir = env::temp_dir().join("debug_session_stability_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).unwrap();

        let values = vec![1.0, f64::NAN, 2.0];
        let issues = session.check_stability("layer1", &values).unwrap();

        assert!(issues > 0);
        assert!(session.stability_checker().has_issues());

        cleanup(session, &temp_dir);
    }

    /// `save` writes the activation summary once a layer has been registered.
    #[test]
    fn test_session_save() {
        let temp_dir = env::temp_dir().join("debug_session_save_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).unwrap();

        session.log_scalar("test/metric", 1.0).unwrap();
        session.register_activations("layer1", vec![1.0, 2.0, 3.0], vec![3]).unwrap();

        session.save().unwrap();

        // Check that files were created: one activation layer is registered and
        // activation visualization is on by default, so the summary must exist.
        let session_dir = session.session_dir().to_path_buf();
        assert!(session_dir.exists());
        assert!(session_dir.join("activation_summary.txt").exists());

        cleanup(session, &temp_dir);
    }

    /// The summary reflects the step counter and the registered activations.
    #[test]
    fn test_session_summary() {
        let temp_dir = env::temp_dir().join("debug_session_summary_test");
        let mut session = UnifiedDebugSession::new(&temp_dir).unwrap();

        session.register_activations("layer1", vec![1.0], vec![1]).unwrap();
        session.step();
        session.step();

        let summary = session.summary();
        assert_eq!(summary.total_steps, 2);
        assert_eq!(summary.num_activations, 1);

        cleanup(session, &temp_dir);
    }

    /// A custom configuration disables TensorBoard and prefixes the session ID.
    #[test]
    fn test_custom_config() {
        let temp_dir = env::temp_dir().join("debug_session_config_test");

        let config = UnifiedDebugSessionConfig {
            enable_tensorboard: false,
            enable_activation_viz: true,
            enable_attention_viz: false,
            enable_stability_check: true,
            enable_graph_viz: false,
            auto_save_interval: 0,
            session_name: Some("test_session".to_string()),
        };

        let session = UnifiedDebugSession::with_config(&temp_dir, config).unwrap();
        assert!(session.tensorboard.is_none());
        assert!(session.session_id.starts_with("test_session"));

        cleanup(session, &temp_dir);
    }

    /// With an interval of two, the second step (and not the first) saves.
    #[test]
    fn test_auto_save() {
        let temp_dir = env::temp_dir().join("debug_session_autosave_test");

        let config = UnifiedDebugSessionConfig {
            auto_save_interval: 2,
            ..Default::default()
        };

        let mut session = UnifiedDebugSession::with_config(&temp_dir, config).unwrap();
        let summary_file = session.session_dir().join("activation_summary.txt");

        session.log_scalar("test/value", 1.0).unwrap();
        session.register_activations("layer1", vec![1.0, 2.0, 3.0], vec![3]).unwrap();

        session.step(); // step 1 - below the interval, nothing saved yet
        assert!(!summary_file.exists());

        session.step(); // step 2 - should trigger auto-save
        assert_eq!(session.current_step(), 2);
        assert!(summary_file.exists());

        cleanup(session, &temp_dir);
    }

    /// The printed summary contains its header and the session ID label.
    #[test]
    fn test_print_summary() {
        let temp_dir = env::temp_dir().join("debug_session_print_test");
        let session = UnifiedDebugSession::new(&temp_dir).unwrap();

        let summary_str = session.print_summary();
        assert!(summary_str.contains("Debug Session Summary"));
        assert!(summary_str.contains("Session ID"));

        cleanup(session, &temp_dir);
    }
}