// trustformers_core/monitoring/mod.rs
pub mod activation_stats;
3pub mod attention;
4pub mod gradient_flow;
5pub mod memory;
6pub mod metrics;
7pub mod profiler;
8pub mod tensorboard;
9
10pub use activation_stats::*;
11pub use attention::*;
12pub use gradient_flow::*;
13pub use memory::*;
14pub use metrics::*;
15pub use profiler::*;
16pub use tensorboard::*;
17
18use anyhow::Result;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::time::{Duration, Instant};
22
23#[derive(Debug, Clone)]
25pub struct ModelMonitor {
26 memory_tracker: MemoryTracker,
27 attention_visualizer: AttentionVisualizer,
28 profiler: ModelProfiler,
29 metrics_collector: MetricsCollector,
30 enabled: bool,
31}
32
33impl Default for ModelMonitor {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl ModelMonitor {
40 pub fn new() -> Self {
41 Self {
42 memory_tracker: MemoryTracker::new(),
43 attention_visualizer: AttentionVisualizer::new(),
44 profiler: ModelProfiler::new(),
45 metrics_collector: MetricsCollector::new(),
46 enabled: true,
47 }
48 }
49
50 pub fn with_config(config: MonitoringConfig) -> Self {
51 Self {
52 memory_tracker: MemoryTracker::with_config(config.memory_config),
53 attention_visualizer: AttentionVisualizer::with_config(config.attention_config),
54 profiler: ModelProfiler::with_config(config.profiler_config),
55 metrics_collector: MetricsCollector::with_config(config.metrics_config),
56 enabled: config.enabled,
57 }
58 }
59
60 pub fn start_forward_pass(
62 &mut self,
63 batch_size: usize,
64 sequence_length: usize,
65 ) -> Result<MonitoringSession> {
66 if !self.enabled {
67 return Ok(MonitoringSession::disabled());
68 }
69
70 let session_id = uuid::Uuid::new_v4().to_string();
71 let start_time = Instant::now();
72
73 self.memory_tracker.start_tracking(&session_id)?;
74 self.attention_visualizer.start_tracking(&session_id)?;
75 self.profiler.start_profiling(&session_id)?;
76
77 Ok(MonitoringSession {
78 id: session_id,
79 start_time,
80 batch_size,
81 sequence_length,
82 enabled: true,
83 })
84 }
85
86 pub fn track_attention(
88 &mut self,
89 session: &MonitoringSession,
90 layer_idx: usize,
91 attention_weights: &crate::tensor::Tensor,
92 input_tokens: Option<&[String]>,
93 ) -> Result<()> {
94 if !session.enabled {
95 return Ok(());
96 }
97
98 self.attention_visualizer.track_attention(
99 &session.id,
100 layer_idx,
101 attention_weights,
102 input_tokens,
103 )
104 }
105
106 pub fn track_memory(
108 &mut self,
109 session: &MonitoringSession,
110 checkpoint: &str,
111 ) -> Result<MemorySnapshot> {
112 if !session.enabled {
113 return Ok(MemorySnapshot::default());
114 }
115
116 self.memory_tracker.take_snapshot(&session.id, checkpoint)
117 }
118
119 pub fn end_session(&mut self, session: MonitoringSession) -> Result<MonitoringReport> {
121 if !session.enabled {
122 return Ok(MonitoringReport::default());
123 }
124
125 let duration = session.start_time.elapsed();
126
127 let memory_report = self.memory_tracker.end_tracking(&session.id)?;
128 let profiling_report = self.profiler.end_profiling(&session.id)?;
129 let attention_report = self.attention_visualizer.get_report(&session.id)?;
130
131 let report = MonitoringReport {
132 session_id: session.id,
133 duration,
134 batch_size: session.batch_size,
135 sequence_length: session.sequence_length,
136 memory_report,
137 profiling_report,
138 attention_report,
139 metrics: self.metrics_collector.collect_metrics()?,
140 };
141
142 Ok(report)
143 }
144
145 pub fn set_enabled(&mut self, enabled: bool) {
147 self.enabled = enabled;
148 }
149
150 pub fn is_enabled(&self) -> bool {
152 self.enabled
153 }
154
155 pub fn clear(&mut self) -> Result<()> {
157 self.memory_tracker.clear()?;
158 self.attention_visualizer.clear()?;
159 self.profiler.clear()?;
160 self.metrics_collector.clear()?;
161 Ok(())
162 }
163}
164
165#[derive(Debug, Clone)]
167pub struct MonitoringConfig {
168 pub enabled: bool,
169 pub memory_config: MemoryTrackerConfig,
170 pub attention_config: AttentionVisualizerConfig,
171 pub profiler_config: ProfilerConfig,
172 pub metrics_config: MetricsCollectorConfig,
173}
174
175impl Default for MonitoringConfig {
176 fn default() -> Self {
177 Self {
178 enabled: true,
179 memory_config: MemoryTrackerConfig::default(),
180 attention_config: AttentionVisualizerConfig::default(),
181 profiler_config: ProfilerConfig::default(),
182 metrics_config: MetricsCollectorConfig::default(),
183 }
184 }
185}
186
/// Handle for one monitored forward pass, created by
/// `ModelMonitor::start_forward_pass` and consumed by
/// `ModelMonitor::end_session`.
#[derive(Clone, Debug)]
pub struct MonitoringSession {
    /// Session identifier (a UUID string; empty for a disabled session).
    pub id: String,
    /// Instant captured when the session was created.
    pub start_time: Instant,
    /// Batch size supplied by the caller (0 for a disabled session).
    pub batch_size: usize,
    /// Sequence length supplied by the caller (0 for a disabled session).
    pub sequence_length: usize,
    /// `false` marks a placeholder session that the tracking calls skip.
    pub enabled: bool,
}
196
197impl MonitoringSession {
198 fn disabled() -> Self {
199 Self {
200 id: String::new(),
201 start_time: Instant::now(),
202 batch_size: 0,
203 sequence_length: 0,
204 enabled: false,
205 }
206 }
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct MonitoringReport {
212 pub session_id: String,
213 pub duration: Duration,
214 pub batch_size: usize,
215 pub sequence_length: usize,
216 pub memory_report: MemoryReport,
217 pub profiling_report: ProfilingReport,
218 pub attention_report: AttentionReport,
219 pub metrics: HashMap<String, f64>,
220}
221
222impl Default for MonitoringReport {
223 fn default() -> Self {
224 Self {
225 session_id: String::new(),
226 duration: Duration::from_secs(0),
227 batch_size: 0,
228 sequence_length: 0,
229 memory_report: MemoryReport::default(),
230 profiling_report: ProfilingReport::default(),
231 attention_report: AttentionReport::default(),
232 metrics: HashMap::new(),
233 }
234 }
235}
236
237impl MonitoringReport {
238 pub fn save_to_file(&self, path: &str) -> Result<()> {
240 let json = serde_json::to_string_pretty(self)?;
241 std::fs::write(path, json)?;
242 Ok(())
243 }
244
245 pub fn load_from_file(path: &str) -> Result<Self> {
247 let content = std::fs::read_to_string(path)?;
248 let report = serde_json::from_str(&content)?;
249 Ok(report)
250 }
251
252 pub fn print_summary(&self) {
254 println!("Monitoring Report Summary");
255 println!("========================");
256 println!("Session ID: {}", self.session_id);
257 println!("Duration: {:.2}ms", self.duration.as_millis());
258 println!("Batch Size: {}", self.batch_size);
259 println!("Sequence Length: {}", self.sequence_length);
260 println!();
261
262 self.memory_report.print_summary();
263 println!();
264
265 self.profiling_report.print_summary();
266 println!();
267
268 self.attention_report.print_summary();
269 println!();
270
271 if !self.metrics.is_empty() {
272 println!("Additional Metrics:");
273 for (name, value) in &self.metrics {
274 println!(" {}: {:.4}", name, value);
275 }
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created monitor is enabled.
    #[test]
    fn test_monitor_creation() {
        let monitor = ModelMonitor::new();
        assert!(monitor.is_enabled());
    }

    /// The `enabled` flag of the config is carried over to the monitor.
    #[test]
    fn test_monitor_with_config() {
        let config = MonitoringConfig {
            enabled: false,
            ..Default::default()
        };

        let monitor = ModelMonitor::with_config(config);
        assert!(!monitor.is_enabled());
    }

    /// An enabled session records the caller's sizes and yields a report
    /// with a non-zero duration.
    #[test]
    fn test_monitoring_session() -> Result<()> {
        let mut monitor = ModelMonitor::new();

        let session = monitor.start_forward_pass(4, 128)?;
        assert_eq!(session.batch_size, 4);
        assert_eq!(session.sequence_length, 128);
        assert!(session.enabled);

        let report = monitor.end_session(session)?;
        assert!(report.duration > Duration::from_nanos(0));

        Ok(())
    }

    /// A disabled monitor hands out disabled sessions and default reports.
    #[test]
    fn test_disabled_monitoring() -> Result<()> {
        let mut monitor = ModelMonitor::new();
        monitor.set_enabled(false);

        let session = monitor.start_forward_pass(4, 128)?;
        assert!(!session.enabled);

        let report = monitor.end_session(session)?;
        assert_eq!(report.session_id, "");

        Ok(())
    }

    /// Clearing after a completed session succeeds.
    #[test]
    fn test_monitor_clear() -> Result<()> {
        let mut monitor = ModelMonitor::new();

        let session = monitor.start_forward_pass(4, 128)?;
        let _report = monitor.end_session(session)?;

        monitor.clear()?;

        Ok(())
    }

    /// The default configuration has monitoring enabled.
    #[test]
    fn test_monitoring_config_default() {
        let config = MonitoringConfig::default();
        assert!(config.enabled);
    }

    /// A report survives a save/load round trip through a JSON file.
    #[test]
    fn test_monitoring_report_serialization() -> Result<()> {
        let report = MonitoringReport::default();

        // Use the platform's temp directory and a per-process file name so
        // the test is portable and concurrent runs do not share a file.
        let temp_file = std::env::temp_dir().join(format!(
            "test_monitoring_report_{}.json",
            std::process::id()
        ));
        let temp_path = temp_file
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("temporary path is not valid UTF-8"))?;

        // Run the round trip first and clean up before propagating any
        // error, so a failure does not leave the file behind.
        let round_trip = report
            .save_to_file(temp_path)
            .and_then(|()| MonitoringReport::load_from_file(temp_path));
        std::fs::remove_file(temp_path).ok();
        let loaded_report = round_trip?;

        assert_eq!(report.session_id, loaded_report.session_id);
        assert_eq!(report.batch_size, loaded_report.batch_size);

        Ok(())
    }
}