Skip to main content

trustformers_core/monitoring/
mod.rs

1// Monitoring and debugging tools for TrustformeRS
2pub mod activation_stats;
3pub mod attention;
4pub mod gradient_flow;
5pub mod memory;
6pub mod metrics;
7pub mod profiler;
8pub mod tensorboard;
9
10pub use activation_stats::*;
11pub use attention::*;
12pub use gradient_flow::*;
13pub use memory::*;
14pub use metrics::*;
15pub use profiler::*;
16pub use tensorboard::*;
17
18use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::time::{Duration, Instant};
22
23/// Central monitoring system for tracking model performance and resource usage
24#[derive(Debug, Clone)]
25pub struct ModelMonitor {
26    memory_tracker: MemoryTracker,
27    attention_visualizer: AttentionVisualizer,
28    profiler: ModelProfiler,
29    metrics_collector: MetricsCollector,
30    enabled: bool,
31}
32
33impl Default for ModelMonitor {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl ModelMonitor {
40    pub fn new() -> Self {
41        Self {
42            memory_tracker: MemoryTracker::new(),
43            attention_visualizer: AttentionVisualizer::new(),
44            profiler: ModelProfiler::new(),
45            metrics_collector: MetricsCollector::new(),
46            enabled: true,
47        }
48    }
49
50    pub fn with_config(config: MonitoringConfig) -> Self {
51        Self {
52            memory_tracker: MemoryTracker::with_config(config.memory_config),
53            attention_visualizer: AttentionVisualizer::with_config(config.attention_config),
54            profiler: ModelProfiler::with_config(config.profiler_config),
55            metrics_collector: MetricsCollector::with_config(config.metrics_config),
56            enabled: config.enabled,
57        }
58    }
59
60    /// Start monitoring a forward pass
61    pub fn start_forward_pass(
62        &mut self,
63        batch_size: usize,
64        sequence_length: usize,
65    ) -> Result<MonitoringSession> {
66        if !self.enabled {
67            return Ok(MonitoringSession::disabled());
68        }
69
70        let session_id = uuid::Uuid::new_v4().to_string();
71        let start_time = Instant::now();
72
73        self.memory_tracker.start_tracking(&session_id)?;
74        self.attention_visualizer.start_tracking(&session_id)?;
75        self.profiler.start_profiling(&session_id)?;
76
77        Ok(MonitoringSession {
78            id: session_id,
79            start_time,
80            batch_size,
81            sequence_length,
82            enabled: true,
83        })
84    }
85
86    /// Track attention weights for visualization
87    pub fn track_attention(
88        &mut self,
89        session: &MonitoringSession,
90        layer_idx: usize,
91        attention_weights: &crate::tensor::Tensor,
92        input_tokens: Option<&[String]>,
93    ) -> Result<()> {
94        if !session.enabled {
95            return Ok(());
96        }
97
98        self.attention_visualizer.track_attention(
99            &session.id,
100            layer_idx,
101            attention_weights,
102            input_tokens,
103        )
104    }
105
106    /// Track memory usage at a specific point
107    pub fn track_memory(
108        &mut self,
109        session: &MonitoringSession,
110        checkpoint: &str,
111    ) -> Result<MemorySnapshot> {
112        if !session.enabled {
113            return Ok(MemorySnapshot::default());
114        }
115
116        self.memory_tracker.take_snapshot(&session.id, checkpoint)
117    }
118
119    /// End monitoring session and collect results
120    pub fn end_session(&mut self, session: MonitoringSession) -> Result<MonitoringReport> {
121        if !session.enabled {
122            return Ok(MonitoringReport::default());
123        }
124
125        let duration = session.start_time.elapsed();
126
127        let memory_report = self.memory_tracker.end_tracking(&session.id)?;
128        let profiling_report = self.profiler.end_profiling(&session.id)?;
129        let attention_report = self.attention_visualizer.get_report(&session.id)?;
130
131        let report = MonitoringReport {
132            session_id: session.id,
133            duration,
134            batch_size: session.batch_size,
135            sequence_length: session.sequence_length,
136            memory_report,
137            profiling_report,
138            attention_report,
139            metrics: self.metrics_collector.collect_metrics()?,
140        };
141
142        Ok(report)
143    }
144
145    /// Enable or disable monitoring
146    pub fn set_enabled(&mut self, enabled: bool) {
147        self.enabled = enabled;
148    }
149
150    /// Get current monitoring status
151    pub fn is_enabled(&self) -> bool {
152        self.enabled
153    }
154
155    /// Clear all collected data
156    pub fn clear(&mut self) -> Result<()> {
157        self.memory_tracker.clear()?;
158        self.attention_visualizer.clear()?;
159        self.profiler.clear()?;
160        self.metrics_collector.clear()?;
161        Ok(())
162    }
163}
164
165/// Configuration for the monitoring system
166#[derive(Debug, Clone)]
167pub struct MonitoringConfig {
168    pub enabled: bool,
169    pub memory_config: MemoryTrackerConfig,
170    pub attention_config: AttentionVisualizerConfig,
171    pub profiler_config: ProfilerConfig,
172    pub metrics_config: MetricsCollectorConfig,
173}
174
175impl Default for MonitoringConfig {
176    fn default() -> Self {
177        Self {
178            enabled: true,
179            memory_config: MemoryTrackerConfig::default(),
180            attention_config: AttentionVisualizerConfig::default(),
181            profiler_config: ProfilerConfig::default(),
182            metrics_config: MetricsCollectorConfig::default(),
183        }
184    }
185}
186
/// Handle for a single monitored forward pass.
///
/// Returned by `ModelMonitor::start_forward_pass` and consumed by
/// `ModelMonitor::end_session`. A session whose `enabled` flag is `false` is
/// inert: the monitor's tracking calls ignore it.
#[derive(Clone, Debug)]
pub struct MonitoringSession {
    /// Session identifier (a UUID string for live sessions, empty for inert ones).
    pub id: String,
    /// When the session was opened; the report duration is measured from here.
    pub start_time: Instant,
    /// Batch size supplied when the session was started.
    pub batch_size: usize,
    /// Sequence length supplied when the session was started.
    pub sequence_length: usize,
    /// Whether this session is actually being monitored.
    pub enabled: bool,
}
196
197impl MonitoringSession {
198    fn disabled() -> Self {
199        Self {
200            id: String::new(),
201            start_time: Instant::now(),
202            batch_size: 0,
203            sequence_length: 0,
204            enabled: false,
205        }
206    }
207}
208
209/// Complete monitoring report for a session
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct MonitoringReport {
212    pub session_id: String,
213    pub duration: Duration,
214    pub batch_size: usize,
215    pub sequence_length: usize,
216    pub memory_report: MemoryReport,
217    pub profiling_report: ProfilingReport,
218    pub attention_report: AttentionReport,
219    pub metrics: HashMap<String, f64>,
220}
221
222impl Default for MonitoringReport {
223    fn default() -> Self {
224        Self {
225            session_id: String::new(),
226            duration: Duration::from_secs(0),
227            batch_size: 0,
228            sequence_length: 0,
229            memory_report: MemoryReport::default(),
230            profiling_report: ProfilingReport::default(),
231            attention_report: AttentionReport::default(),
232            metrics: HashMap::new(),
233        }
234    }
235}
236
237impl MonitoringReport {
238    /// Save report to file
239    pub fn save_to_file(&self, path: &str) -> Result<()> {
240        let json = serde_json::to_string_pretty(self)?;
241        std::fs::write(path, json)?;
242        Ok(())
243    }
244
245    /// Load report from file
246    pub fn load_from_file(path: &str) -> Result<Self> {
247        let content = std::fs::read_to_string(path)?;
248        let report = serde_json::from_str(&content)?;
249        Ok(report)
250    }
251
252    /// Print a summary of the report
253    pub fn print_summary(&self) {
254        println!("Monitoring Report Summary");
255        println!("========================");
256        println!("Session ID: {}", self.session_id);
257        println!("Duration: {:.2}ms", self.duration.as_millis());
258        println!("Batch Size: {}", self.batch_size);
259        println!("Sequence Length: {}", self.sequence_length);
260        println!();
261
262        self.memory_report.print_summary();
263        println!();
264
265        self.profiling_report.print_summary();
266        println!();
267
268        self.attention_report.print_summary();
269        println!();
270
271        if !self.metrics.is_empty() {
272            println!("Additional Metrics:");
273            for (name, value) in &self.metrics {
274                println!("  {}: {:.4}", name, value);
275            }
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_monitor_creation() {
        let monitor = ModelMonitor::new();
        assert!(monitor.is_enabled());
    }

    #[test]
    fn test_monitor_with_config() {
        let config = MonitoringConfig {
            enabled: false,
            ..Default::default()
        };

        let monitor = ModelMonitor::with_config(config);
        assert!(!monitor.is_enabled());
    }

    #[test]
    fn test_monitoring_session() -> Result<()> {
        let mut monitor = ModelMonitor::new();

        let session = monitor.start_forward_pass(4, 128)?;
        assert_eq!(session.batch_size, 4);
        assert_eq!(session.sequence_length, 128);
        assert!(session.enabled);

        let report = monitor.end_session(session)?;
        assert!(report.duration > Duration::from_nanos(0));

        Ok(())
    }

    #[test]
    fn test_disabled_monitoring() -> Result<()> {
        let mut monitor = ModelMonitor::new();
        monitor.set_enabled(false);

        let session = monitor.start_forward_pass(4, 128)?;
        assert!(!session.enabled);

        let report = monitor.end_session(session)?;
        assert_eq!(report.session_id, "");

        Ok(())
    }

    #[test]
    fn test_monitor_clear() -> Result<()> {
        let mut monitor = ModelMonitor::new();

        // Start and end a session to populate some data
        let session = monitor.start_forward_pass(4, 128)?;
        let _report = monitor.end_session(session)?;

        // Clear should not fail
        monitor.clear()?;

        Ok(())
    }

    #[test]
    fn test_monitoring_config_default() {
        let config = MonitoringConfig::default();
        assert!(config.enabled);
    }

    #[test]
    fn test_monitoring_report_serialization() -> Result<()> {
        let report = MonitoringReport::default();

        // Use the platform temp dir plus the process id instead of a hard-coded
        // `/tmp/...` path. `/tmp` does not exist on Windows, and a fixed file name
        // lets concurrent test runs overwrite each other's file.
        let temp_file = std::env::temp_dir().join(format!(
            "trustformers_test_monitoring_report_{}.json",
            std::process::id()
        ));
        let temp_path = temp_file
            .to_str()
            .expect("temporary directory path should be valid UTF-8");

        report.save_to_file(temp_path)?;
        let loaded = MonitoringReport::load_from_file(temp_path);
        // Remove the file before asserting so a failed assertion does not leave it behind.
        std::fs::remove_file(&temp_file).ok();
        let loaded_report = loaded?;

        assert_eq!(report.session_id, loaded_report.session_id);
        assert_eq!(report.duration, loaded_report.duration);
        assert_eq!(report.batch_size, loaded_report.batch_size);
        assert_eq!(report.sequence_length, loaded_report.sequence_length);
        assert_eq!(report.metrics, loaded_report.metrics);

        Ok(())
    }
}