1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime};
8use uuid::Uuid;
9
10use crate::DebugConfig;
11
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One sample of training telemetry, as passed to `TrainingMonitor::update_metrics`.
///
/// Only `timestamp` and `memory_usage_mb` are mandatory; every other reading is
/// optional and is skipped by the alert checks and averages when it is `None`.
pub struct DashboardMetrics {
    /// Wall-clock time attached to the sample by its producer.
    pub timestamp: SystemTime,
    /// Loss value; feeds the loss-increase alert, the average loss, and the
    /// stability / convergence assessments.
    pub loss: Option<f64>,
    /// Accuracy value; the training summary reports the maximum seen.
    pub accuracy: Option<f64>,
    /// Learning rate. Not read by any monitoring logic visible in this file.
    pub learning_rate: Option<f64>,
    /// Memory usage in megabytes; compared against `AlertThresholds::memory_usage_max_mb`.
    pub memory_usage_mb: f64,
    /// GPU utilization; compared against `AlertThresholds::gpu_utilization_min`
    /// (presumably a 0.0-1.0 fraction, given the 0.7 default threshold; TODO confirm).
    pub gpu_utilization: Option<f64>,
    /// Throughput in tokens per second; compared against
    /// `AlertThresholds::tokens_per_second_min`.
    pub tokens_per_second: Option<f64>,
    /// Gradient norm; compared against `AlertThresholds::gradient_norm_max`.
    pub gradient_norm: Option<f64>,
    /// Epoch counter, if known. Not read by any monitoring logic visible in this file.
    pub epoch: Option<u32>,
    /// Step counter, if known. Not read by any monitoring logic visible in this file.
    pub step: Option<u64>,
}
26
#[derive(Debug)]
/// Keeps a bounded history of [`DashboardMetrics`] samples and the alerts
/// raised while ingesting them.
pub struct TrainingMonitor {
    /// Copy of the configuration passed to `new`; stored but not read here.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Recorded samples, oldest at the front.
    metrics_history: VecDeque<DashboardMetrics>,
    /// Maximum number of samples kept; the oldest is dropped when exceeded.
    max_history: usize,
    /// Instant at which the monitor was created; used for the summary duration.
    start_time: Instant,
    /// Limits the alert checks compare each new sample against.
    alert_thresholds: AlertThresholds,
    /// Alerts raised so far; holds at most one entry per `AlertType`.
    active_alerts: Vec<TrainingAlert>,
}
38
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Limits used by `TrainingMonitor::check_alerts` to decide when to raise an alert.
pub struct AlertThresholds {
    /// Multiplier applied to an earlier loss sample; a `LossIncrease` alert is
    /// raised when the current loss exceeds that product.
    pub loss_increase_threshold: f64,
    /// A gradient norm above this value raises `GradientExplosion`.
    pub gradient_norm_max: f64,
    /// Memory usage (MB) above this value raises `MemoryOveruse`.
    pub memory_usage_max_mb: f64,
    /// GPU utilization below this value raises `LowGpuUtilization`.
    pub gpu_utilization_min: f64,
    /// Minimum learning rate. No check visible in this file consults it.
    pub learning_rate_min: f64,
    /// Tokens per second below this value raises `SlowTokenProcessing`.
    pub tokens_per_second_min: f64,
}
49
50impl Default for AlertThresholds {
51 fn default() -> Self {
52 Self {
53 loss_increase_threshold: 1.5,
54 gradient_norm_max: 10.0,
55 memory_usage_max_mb: 8192.0,
56 gpu_utilization_min: 0.7,
57 learning_rate_min: 1e-8,
58 tokens_per_second_min: 100.0,
59 }
60 }
61}
62
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A single alert raised by the training monitor.
pub struct TrainingAlert {
    /// Category of the alert; the monitor keeps at most one active alert per category.
    pub alert_type: AlertType,
    /// How serious the condition is.
    pub severity: AlertSeverity,
    /// Human-readable description of the condition.
    pub message: String,
    /// Wall-clock time at which the alert was created.
    pub timestamp: SystemTime,
    /// The observed metric value that triggered the alert.
    pub metric_value: f64,
    /// The limit the observed value was compared against.
    pub threshold: f64,
    /// Human-readable remediation hint.
    pub suggested_action: String,
}
74
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// Categories of training alerts.
///
/// `LearningRateTooLow`, `ModelDivergence` and `TrainingStalled` are declared
/// but not raised by any check visible in this file.
pub enum AlertType {
    /// Current loss exceeded an earlier loss times the configured multiplier.
    LossIncrease,
    /// Gradient norm exceeded the configured maximum.
    GradientExplosion,
    /// Memory usage exceeded the configured maximum.
    MemoryOveruse,
    /// GPU utilization fell below the configured minimum.
    LowGpuUtilization,
    /// Learning rate below the configured minimum.
    LearningRateTooLow,
    /// Token throughput fell below the configured minimum.
    SlowTokenProcessing,
    /// Model divergence.
    ModelDivergence,
    /// Training stalled.
    TrainingStalled,
}
86
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Severity attached to a [`TrainingAlert`].
pub enum AlertSeverity {
    /// Informational (used for low GPU utilization).
    Info,
    /// Warning (used for loss increase, memory overuse and slow throughput).
    Warning,
    /// Critical (used for gradient explosion); critical alerts feed the
    /// dashboard report's insights and recommendations.
    Critical,
}
93
94impl TrainingMonitor {
95 pub fn new(config: &DebugConfig) -> Self {
97 Self {
98 config: config.clone(),
99 metrics_history: VecDeque::new(),
100 max_history: 10000,
101 start_time: Instant::now(),
102 alert_thresholds: AlertThresholds::default(),
103 active_alerts: Vec::new(),
104 }
105 }
106
107 pub fn update_metrics(&mut self, metrics: DashboardMetrics) {
109 self.metrics_history.push_back(metrics.clone());
111
112 if self.metrics_history.len() > self.max_history {
114 self.metrics_history.pop_front();
115 }
116
117 self.check_alerts(&metrics);
119 }
120
121 pub fn get_recent_metrics(&self, count: usize) -> Vec<DashboardMetrics> {
123 self.metrics_history.iter().rev().take(count).rev().cloned().collect()
124 }
125
126 pub fn get_active_alerts(&self) -> &[TrainingAlert] {
128 &self.active_alerts
129 }
130
131 pub fn clear_alert(&mut self, _alert_type: AlertType) {
133 self.active_alerts.retain(|alert| !matches!(&alert.alert_type, _alert_type));
134 }
135
136 pub fn set_alert_thresholds(&mut self, thresholds: AlertThresholds) {
138 self.alert_thresholds = thresholds;
139 }
140
141 pub fn generate_training_summary(&self) -> TrainingSummary {
143 let total_duration = self.start_time.elapsed();
144 let total_steps = self.metrics_history.len();
145
146 let avg_loss = self.calculate_average_loss();
147 let best_accuracy = self.calculate_best_accuracy();
148 let avg_tokens_per_second = self.calculate_average_tokens_per_second();
149 let training_stability = self.calculate_training_stability();
150
151 TrainingSummary {
152 total_duration,
153 total_steps,
154 avg_loss,
155 best_accuracy,
156 avg_tokens_per_second,
157 training_stability,
158 active_alerts_count: self.active_alerts.len(),
159 convergence_status: self.assess_convergence(),
160 }
161 }
162
163 fn check_alerts(&mut self, metrics: &DashboardMetrics) {
164 if let Some(current_loss) = metrics.loss {
166 if let Some(prev_metrics) =
167 self.metrics_history.get(self.metrics_history.len().saturating_sub(10))
168 {
169 if let Some(prev_loss) = prev_metrics.loss {
170 if current_loss > prev_loss * self.alert_thresholds.loss_increase_threshold {
171 self.add_alert(TrainingAlert {
172 alert_type: AlertType::LossIncrease,
173 severity: AlertSeverity::Warning,
174 message: "Loss has increased significantly".to_string(),
175 timestamp: SystemTime::now(),
176 metric_value: current_loss,
177 threshold: prev_loss * self.alert_thresholds.loss_increase_threshold,
178 suggested_action: "Check learning rate or data quality".to_string(),
179 });
180 }
181 }
182 }
183 }
184
185 if let Some(grad_norm) = metrics.gradient_norm {
187 if grad_norm > self.alert_thresholds.gradient_norm_max {
188 self.add_alert(TrainingAlert {
189 alert_type: AlertType::GradientExplosion,
190 severity: AlertSeverity::Critical,
191 message: "Gradient explosion detected".to_string(),
192 timestamp: SystemTime::now(),
193 metric_value: grad_norm,
194 threshold: self.alert_thresholds.gradient_norm_max,
195 suggested_action: "Apply gradient clipping or reduce learning rate".to_string(),
196 });
197 }
198 }
199
200 if metrics.memory_usage_mb > self.alert_thresholds.memory_usage_max_mb {
202 self.add_alert(TrainingAlert {
203 alert_type: AlertType::MemoryOveruse,
204 severity: AlertSeverity::Warning,
205 message: "High memory usage detected".to_string(),
206 timestamp: SystemTime::now(),
207 metric_value: metrics.memory_usage_mb,
208 threshold: self.alert_thresholds.memory_usage_max_mb,
209 suggested_action: "Reduce batch size or enable gradient checkpointing".to_string(),
210 });
211 }
212
213 if let Some(gpu_util) = metrics.gpu_utilization {
215 if gpu_util < self.alert_thresholds.gpu_utilization_min {
216 self.add_alert(TrainingAlert {
217 alert_type: AlertType::LowGpuUtilization,
218 severity: AlertSeverity::Info,
219 message: "Low GPU utilization".to_string(),
220 timestamp: SystemTime::now(),
221 metric_value: gpu_util,
222 threshold: self.alert_thresholds.gpu_utilization_min,
223 suggested_action: "Increase batch size or check data loading".to_string(),
224 });
225 }
226 }
227
228 if let Some(tps) = metrics.tokens_per_second {
230 if tps < self.alert_thresholds.tokens_per_second_min {
231 self.add_alert(TrainingAlert {
232 alert_type: AlertType::SlowTokenProcessing,
233 severity: AlertSeverity::Warning,
234 message: "Slow token processing detected".to_string(),
235 timestamp: SystemTime::now(),
236 metric_value: tps,
237 threshold: self.alert_thresholds.tokens_per_second_min,
238 suggested_action: "Optimize model or increase batch size".to_string(),
239 });
240 }
241 }
242 }
243
244 fn add_alert(&mut self, alert: TrainingAlert) {
245 if !self.active_alerts.iter().any(|a| a.alert_type == alert.alert_type) {
247 self.active_alerts.push(alert);
248 }
249 }
250
251 fn calculate_average_loss(&self) -> Option<f64> {
252 let losses: Vec<f64> = self.metrics_history.iter().filter_map(|m| m.loss).collect();
253
254 if losses.is_empty() {
255 None
256 } else {
257 Some(losses.iter().sum::<f64>() / losses.len() as f64)
258 }
259 }
260
261 fn calculate_best_accuracy(&self) -> Option<f64> {
262 self.metrics_history
263 .iter()
264 .filter_map(|m| m.accuracy)
265 .fold(None, |acc, x| match acc {
266 None => Some(x),
267 Some(y) => Some(x.max(y)),
268 })
269 }
270
271 fn calculate_average_tokens_per_second(&self) -> Option<f64> {
272 let tps_values: Vec<f64> =
273 self.metrics_history.iter().filter_map(|m| m.tokens_per_second).collect();
274
275 if tps_values.is_empty() {
276 None
277 } else {
278 Some(tps_values.iter().sum::<f64>() / tps_values.len() as f64)
279 }
280 }
281
282 fn calculate_training_stability(&self) -> TrainingStability {
283 if self.metrics_history.len() < 10 {
284 return TrainingStability::Insufficient;
285 }
286
287 let recent_losses: Vec<f64> =
288 self.metrics_history.iter().rev().take(50).filter_map(|m| m.loss).collect();
289
290 if recent_losses.len() < 10 {
291 return TrainingStability::Insufficient;
292 }
293
294 let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
296 let variance = recent_losses.iter().map(|&x| (x - mean_loss).powi(2)).sum::<f64>()
297 / recent_losses.len() as f64;
298
299 let std_dev = variance.sqrt();
300 let coefficient_of_variation = if mean_loss != 0.0 { std_dev / mean_loss } else { 0.0 };
301
302 match coefficient_of_variation {
303 cv if cv < 0.1 => TrainingStability::Stable,
304 cv if cv < 0.3 => TrainingStability::Moderate,
305 _ => TrainingStability::Unstable,
306 }
307 }
308
309 fn assess_convergence(&self) -> ConvergenceStatus {
310 if self.metrics_history.len() < 50 {
311 return ConvergenceStatus::TooEarly;
312 }
313
314 let recent_losses: Vec<f64> =
315 self.metrics_history.iter().rev().take(100).filter_map(|m| m.loss).collect();
316
317 if recent_losses.len() < 50 {
318 return ConvergenceStatus::TooEarly;
319 }
320
321 let first_half_avg =
323 recent_losses[25..].iter().sum::<f64>() / (recent_losses.len() - 25) as f64;
324 let second_half_avg = recent_losses[..25].iter().sum::<f64>() / 25.0;
325
326 if second_half_avg < first_half_avg * 0.95 {
327 ConvergenceStatus::Converging
328 } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.01 {
329 ConvergenceStatus::Converged
330 } else {
331 ConvergenceStatus::Diverging
332 }
333 }
334}
335
#[derive(Debug)]
/// Holds metrics for several models and produces pairwise comparisons and rankings.
pub struct ModelComparator {
    /// Registered models, keyed by `ModelMetrics::model_id`.
    models: HashMap<String, ModelMetrics>,
    /// Settings selecting the metric used for comparison and ranking.
    comparison_config: ComparisonConfig,
}
342
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Metrics describing one model that takes part in a comparison.
pub struct ModelMetrics {
    /// Identifier of the model; used as the key when registering it with a comparator.
    pub model_id: String,
    /// Display name used in recommendation messages and rankings.
    pub model_name: String,
    /// Samples recorded for this model. Not read by the comparison logic in this file.
    pub metrics_history: Vec<DashboardMetrics>,
    /// Final loss, if available; used when the primary metric is `"loss"`.
    pub final_loss: Option<f64>,
    /// Final accuracy, if available; used when the primary metric is `"accuracy"`.
    pub final_accuracy: Option<f64>,
    /// Training time; used in the efficiency comparison.
    pub training_time: Duration,
    /// Parameter count. Not read by the comparison logic in this file.
    pub parameter_count: usize,
    /// Model size in megabytes; used in the efficiency comparison.
    pub model_size_mb: f64,
}
354
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Settings for model comparison.
pub struct ComparisonConfig {
    /// Name of the metric to compare on. The comparator recognises `"loss"` and
    /// `"accuracy"`; any other value yields neutral (zero) differences and scores.
    pub primary_metric: String,
    /// Comparison window size. Not read by the comparison logic visible in this file.
    pub comparison_window: usize,
    /// Significance threshold. Not read by the comparison logic visible in this file.
    pub significance_threshold: f64,
}
361
362impl Default for ComparisonConfig {
363 fn default() -> Self {
364 Self {
365 primary_metric: "loss".to_string(),
366 comparison_window: 100,
367 significance_threshold: 0.05,
368 }
369 }
370}
371
372impl ModelComparator {
373 pub fn new() -> Self {
375 Self {
376 models: HashMap::new(),
377 comparison_config: ComparisonConfig::default(),
378 }
379 }
380
381 pub fn add_model(&mut self, model_metrics: ModelMetrics) {
383 self.models.insert(model_metrics.model_id.clone(), model_metrics);
384 }
385
386 pub fn compare_models(&self) -> ModelComparisonReport {
388 let mut comparisons = Vec::new();
389 let model_ids: Vec<String> = self.models.keys().cloned().collect();
390
391 for i in 0..model_ids.len() {
392 for j in (i + 1)..model_ids.len() {
393 let model_a = &self.models[&model_ids[i]];
394 let model_b = &self.models[&model_ids[j]];
395
396 let comparison = self.compare_two_models(model_a, model_b);
397 comparisons.push(comparison);
398 }
399 }
400
401 let best_model = self.find_best_model();
402 let ranking = self.rank_models();
403
404 ModelComparisonReport {
405 comparisons,
406 best_model,
407 ranking,
408 comparison_config: self.comparison_config.clone(),
409 }
410 }
411
412 fn compare_two_models(
413 &self,
414 model_a: &ModelMetrics,
415 model_b: &ModelMetrics,
416 ) -> ModelComparison {
417 let performance_diff = self.calculate_performance_difference(model_a, model_b);
418 let efficiency_diff = self.calculate_efficiency_difference(model_a, model_b);
419 let statistical_significance = self.test_statistical_significance(model_a, model_b);
420
421 ModelComparison {
422 model_a_id: model_a.model_id.clone(),
423 model_b_id: model_b.model_id.clone(),
424 performance_difference: performance_diff,
425 efficiency_difference: efficiency_diff,
426 statistical_significance,
427 recommendation: self.generate_recommendation(model_a, model_b, performance_diff),
428 }
429 }
430
431 fn calculate_performance_difference(
432 &self,
433 model_a: &ModelMetrics,
434 model_b: &ModelMetrics,
435 ) -> f64 {
436 match self.comparison_config.primary_metric.as_str() {
437 "loss" => {
438 if let (Some(loss_a), Some(loss_b)) = (model_a.final_loss, model_b.final_loss) {
439 (loss_b - loss_a) / loss_a } else {
441 0.0
442 }
443 },
444 "accuracy" => {
445 if let (Some(acc_a), Some(acc_b)) = (model_a.final_accuracy, model_b.final_accuracy)
446 {
447 (acc_b - acc_a) / acc_a } else {
449 0.0
450 }
451 },
452 _ => 0.0,
453 }
454 }
455
456 fn calculate_efficiency_difference(
457 &self,
458 model_a: &ModelMetrics,
459 model_b: &ModelMetrics,
460 ) -> f64 {
461 let time_diff =
463 model_b.training_time.as_secs_f64() / model_a.training_time.as_secs_f64() - 1.0;
464
465 let size_diff = model_b.model_size_mb / model_a.model_size_mb - 1.0;
467
468 (time_diff + size_diff) / 2.0
470 }
471
472 fn test_statistical_significance(
473 &self,
474 _model_a: &ModelMetrics,
475 _model_b: &ModelMetrics,
476 ) -> bool {
477 true }
480
481 fn generate_recommendation(
482 &self,
483 model_a: &ModelMetrics,
484 model_b: &ModelMetrics,
485 perf_diff: f64,
486 ) -> String {
487 if perf_diff.abs() < 0.01 {
488 "Models perform similarly - choose based on other factors".to_string()
489 } else if perf_diff < 0.0 {
490 format!(
491 "Model {} performs {:.1}% better",
492 model_a.model_name,
493 perf_diff.abs() * 100.0
494 )
495 } else {
496 format!(
497 "Model {} performs {:.1}% better",
498 model_b.model_name,
499 perf_diff * 100.0
500 )
501 }
502 }
503
504 fn find_best_model(&self) -> Option<String> {
505 let mut best_model = None;
506 let mut best_score = f64::NEG_INFINITY;
507
508 for model in self.models.values() {
509 let score = match self.comparison_config.primary_metric.as_str() {
510 "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
511 "accuracy" => model.final_accuracy.unwrap_or(0.0),
512 _ => 0.0,
513 };
514
515 if score > best_score {
516 best_score = score;
517 best_model = Some(model.model_id.clone());
518 }
519 }
520
521 best_model
522 }
523
524 fn rank_models(&self) -> Vec<ModelRanking> {
525 let mut rankings: Vec<ModelRanking> = self
526 .models
527 .values()
528 .map(|model| {
529 let score = match self.comparison_config.primary_metric.as_str() {
530 "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
531 "accuracy" => model.final_accuracy.unwrap_or(0.0),
532 _ => 0.0,
533 };
534
535 ModelRanking {
536 model_id: model.model_id.clone(),
537 model_name: model.model_name.clone(),
538 score,
539 rank: 0, }
541 })
542 .collect();
543
544 rankings.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
545
546 for (i, ranking) in rankings.iter_mut().enumerate() {
547 ranking.rank = i + 1;
548 }
549
550 rankings
551 }
552}
553
#[derive(Debug)]
#[allow(dead_code)]
/// Tracks hyperparameter experiments and derives recommendations from them.
pub struct HyperparameterExplorer {
    /// Recorded experiments, keyed by `HyperparameterExperiment::experiment_id`.
    experiments: HashMap<String, HyperparameterExperiment>,
    /// Search space; initialised to the default and not read in this file.
    #[allow(dead_code)]
    search_space: HyperparameterSearchSpace,
    /// Optimisation steps; starts empty and is not appended to in this file.
    optimization_history: Vec<OptimizationStep>,
}
563
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One hyperparameter experiment: the values tried, the outcome and the current status.
pub struct HyperparameterExperiment {
    /// Identifier of the experiment; used as the key when it is added to an explorer.
    pub experiment_id: String,
    /// Hyperparameter values, keyed by parameter name.
    pub hyperparameters: HashMap<String, HyperparameterValue>,
    /// Outcome of the experiment; `final_loss` is what experiments are ranked by.
    pub results: ExperimentResults,
    /// Lifecycle state. Not consulted when ranking experiments in this file.
    pub status: ExperimentStatus,
}
571
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A hyperparameter value of one of four primitive kinds.
pub enum HyperparameterValue {
    /// Floating-point value (used here for learning rate and dropout rate suggestions).
    Float(f64),
    /// Integer value (used here for batch size suggestions).
    Integer(i64),
    /// Textual value.
    String(String),
    /// Boolean flag.
    Boolean(bool),
}
579
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Outcome of a hyperparameter experiment.
pub struct ExperimentResults {
    /// Final loss, if available; experiments without one are ranked last.
    pub final_loss: Option<f64>,
    /// Final accuracy, if available.
    pub final_accuracy: Option<f64>,
    /// Time spent training.
    pub training_time: Duration,
    /// Epoch at which the run converged, if recorded.
    pub convergence_epoch: Option<u32>,
    /// Best validation score, if recorded.
    pub best_validation_score: Option<f64>,
}
588
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Lifecycle state of a hyperparameter experiment.
pub enum ExperimentStatus {
    /// The experiment is still running.
    Running,
    /// The experiment finished.
    Completed,
    /// The experiment failed.
    Failed,
    /// The experiment was cancelled.
    Cancelled,
}
596
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Bounds for each tunable hyperparameter, stored as a pair of values
/// (presumably `(min, max)`, judging by the defaults; TODO confirm inclusivity).
pub struct HyperparameterSearchSpace {
    /// Learning-rate bounds.
    pub learning_rate: (f64, f64),
    /// Batch-size bounds.
    pub batch_size: (i64, i64),
    /// Dropout-rate bounds.
    pub dropout_rate: (f64, f64),
    /// Weight-decay bounds.
    pub weight_decay: (f64, f64),
    /// Layer-count bounds.
    pub num_layers: (i64, i64),
    /// Hidden-size bounds.
    pub hidden_size: (i64, i64),
}
606
607impl Default for HyperparameterSearchSpace {
608 fn default() -> Self {
609 Self {
610 learning_rate: (1e-5, 1e-1),
611 batch_size: (4, 128),
612 dropout_rate: (0.0, 0.5),
613 weight_decay: (0.0, 1e-2),
614 num_layers: (1, 12),
615 hidden_size: (64, 2048),
616 }
617 }
618}
619
#[derive(Debug, Clone, Serialize, Deserialize)]
/// One entry of a hyperparameter optimisation history.
///
/// Nothing in the visible part of this file constructs or reads these entries.
pub struct OptimizationStep {
    /// Index of the optimisation step.
    pub step: usize,
    /// Identifier of the best experiment at this step.
    pub best_experiment_id: String,
    /// Score of that best experiment.
    pub best_score: f64,
    /// Number of exploration moves (presumably cumulative; TODO confirm).
    pub exploration_count: usize,
    /// Number of exploitation moves (presumably cumulative; TODO confirm).
    pub exploitation_count: usize,
}
628
629impl HyperparameterExplorer {
630 pub fn new() -> Self {
632 Self {
633 experiments: HashMap::new(),
634 search_space: HyperparameterSearchSpace::default(),
635 optimization_history: Vec::new(),
636 }
637 }
638
639 pub fn add_experiment(&mut self, experiment: HyperparameterExperiment) {
641 self.experiments.insert(experiment.experiment_id.clone(), experiment);
642 }
643
644 pub fn get_recommendations(&self) -> HyperparameterRecommendations {
646 let best_experiments = self.find_best_experiments(5);
647 let parameter_importance = self.analyze_parameter_importance();
648 let suggested_ranges = self.suggest_search_ranges();
649 let next_experiments = self.suggest_next_experiments(3);
650
651 HyperparameterRecommendations {
652 best_experiments,
653 parameter_importance,
654 suggested_ranges,
655 next_experiments,
656 total_experiments: self.experiments.len(),
657 }
658 }
659
660 fn find_best_experiments(&self, limit: usize) -> Vec<String> {
661 let mut experiments: Vec<_> = self.experiments.values().collect();
662 experiments.sort_by(|a, b| {
663 let score_a = a.results.final_loss.unwrap_or(f64::INFINITY);
664 let score_b = b.results.final_loss.unwrap_or(f64::INFINITY);
665 score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
666 });
667
668 experiments.iter().take(limit).map(|exp| exp.experiment_id.clone()).collect()
669 }
670
671 fn analyze_parameter_importance(&self) -> HashMap<String, f64> {
672 let mut importance = HashMap::new();
674 importance.insert("learning_rate".to_string(), 0.8);
675 importance.insert("batch_size".to_string(), 0.6);
676 importance.insert("dropout_rate".to_string(), 0.4);
677 importance.insert("weight_decay".to_string(), 0.3);
678 importance
679 }
680
681 fn suggest_search_ranges(&self) -> HashMap<String, (f64, f64)> {
682 let mut ranges = HashMap::new();
684 ranges.insert("learning_rate".to_string(), (1e-4, 1e-2));
685 ranges.insert("dropout_rate".to_string(), (0.1, 0.3));
686 ranges
687 }
688
689 fn suggest_next_experiments(&self, count: usize) -> Vec<HashMap<String, HyperparameterValue>> {
690 let mut suggestions = Vec::new();
691
692 for i in 0..count {
693 let mut params = HashMap::new();
694
695 params.insert(
697 "learning_rate".to_string(),
698 HyperparameterValue::Float(0.001 * (1.0 + i as f64 * 0.5)),
699 );
700 params.insert(
701 "batch_size".to_string(),
702 HyperparameterValue::Integer(32 * (1 + i as i64)),
703 );
704 params.insert(
705 "dropout_rate".to_string(),
706 HyperparameterValue::Float(0.1 + i as f64 * 0.1),
707 );
708
709 suggestions.push(params);
710 }
711
712 suggestions
713 }
714}
715
#[derive(Debug)]
/// Facade combining the training monitor, model comparator and hyperparameter
/// explorer behind one dashboard object.
pub struct InteractiveDashboard {
    /// Copy of the configuration passed to `new`; stored but not read here.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Receives every metrics update and owns the alerts.
    training_monitor: TrainingMonitor,
    /// Source of the model comparison report.
    model_comparator: ModelComparator,
    /// Source of the hyperparameter recommendations.
    hyperparameter_explorer: HyperparameterExplorer,
    /// Display and session settings included in snapshots.
    dashboard_state: DashboardState,
    /// Set by `start`; while it is `Some`, updates are also broadcast.
    websocket_server: Option<WebSocketServer>,
}
727
#[derive(Debug, Serialize, Deserialize)]
/// Display and session settings of the dashboard.
pub struct DashboardState {
    /// Identifier of the active session, if any (`None` for a new dashboard).
    pub active_session_id: Option<Uuid>,
    /// Refresh interval in milliseconds (1000 for a new dashboard).
    pub refresh_rate_ms: u64,
    /// Whether automatic alerts are enabled (`true` for a new dashboard).
    /// Not consulted by the logic visible in this file.
    pub auto_alerts: bool,
    /// Which view is shown (`Overview` for a new dashboard).
    pub display_mode: DisplayMode,
}
735
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Views the dashboard can display.
pub enum DisplayMode {
    /// Overview; the mode a new dashboard starts in.
    Overview,
    /// Detailed metrics view.
    DetailedMetrics,
    /// Model comparison view.
    ModelComparison,
    /// Hyperparameter optimisation view.
    HyperparameterOptimization,
    /// Alerts-only view.
    AlertsOnly,
}
744
#[derive(Debug)]
#[allow(dead_code)]
/// State record for the dashboard's WebSocket endpoint.
///
/// The visible code only stores these values; it does not open a socket.
pub struct WebSocketServer {
    /// Port the dashboard was started on.
    #[allow(dead_code)]
    port: u16,
    /// Shared list of connected clients; created empty by `InteractiveDashboard::start`
    /// (presumably holds client identifiers; TODO confirm).
    connected_clients: Arc<Mutex<Vec<String>>>,
}
753
754impl InteractiveDashboard {
755 pub fn new(config: &DebugConfig) -> Self {
757 Self {
758 config: config.clone(),
759 training_monitor: TrainingMonitor::new(config),
760 model_comparator: ModelComparator::new(),
761 hyperparameter_explorer: HyperparameterExplorer::new(),
762 dashboard_state: DashboardState {
763 active_session_id: None,
764 refresh_rate_ms: 1000,
765 auto_alerts: true,
766 display_mode: DisplayMode::Overview,
767 },
768 websocket_server: None,
769 }
770 }
771
772 pub async fn start(&mut self, port: Option<u16>) -> Result<()> {
774 let port = port.unwrap_or(8080);
775
776 self.websocket_server = Some(WebSocketServer {
777 port,
778 connected_clients: Arc::new(Mutex::new(Vec::new())),
779 });
780
781 tracing::info!("Interactive dashboard started on port {}", port);
782 Ok(())
783 }
784
785 pub fn update(&mut self, metrics: DashboardMetrics) {
787 self.training_monitor.update_metrics(metrics.clone());
788
789 if let Some(_ws_server) = &self.websocket_server {
791 self.broadcast_update(metrics);
792 }
793 }
794
795 pub fn get_dashboard_snapshot(&self) -> DashboardSnapshot {
797 let training_summary = self.training_monitor.generate_training_summary();
798 let recent_metrics = self.training_monitor.get_recent_metrics(100);
799 let active_alerts = self.training_monitor.get_active_alerts().to_vec();
800 let model_comparison = self.model_comparator.compare_models();
801 let hyperparameter_recommendations = self.hyperparameter_explorer.get_recommendations();
802
803 DashboardSnapshot {
804 timestamp: SystemTime::now(),
805 training_summary,
806 recent_metrics,
807 active_alerts,
808 model_comparison,
809 hyperparameter_recommendations,
810 dashboard_state: DashboardState {
811 active_session_id: self.dashboard_state.active_session_id,
812 refresh_rate_ms: self.dashboard_state.refresh_rate_ms,
813 auto_alerts: self.dashboard_state.auto_alerts,
814 display_mode: self.dashboard_state.display_mode.clone(),
815 },
816 }
817 }
818
819 pub async fn export_dashboard_data(&self, path: &str) -> Result<()> {
821 let snapshot = self.get_dashboard_snapshot();
822 let json = serde_json::to_string_pretty(&snapshot)?;
823 tokio::fs::write(path, json).await?;
824 Ok(())
825 }
826
827 fn broadcast_update(&self, _metrics: DashboardMetrics) {
828 tracing::debug!("Broadcasting dashboard update to connected clients");
830 }
831}
832
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Aggregate view of a training run, produced by
/// `TrainingMonitor::generate_training_summary`.
pub struct TrainingSummary {
    /// Time elapsed since the monitor was created.
    pub total_duration: Duration,
    /// Number of metric samples currently retained in the history.
    pub total_steps: usize,
    /// Mean of the recorded losses, if any were recorded.
    pub avg_loss: Option<f64>,
    /// Highest recorded accuracy, if any was recorded.
    pub best_accuracy: Option<f64>,
    /// Mean of the recorded throughput values, if any were recorded.
    pub avg_tokens_per_second: Option<f64>,
    /// Stability classification of the recent losses.
    pub training_stability: TrainingStability,
    /// Number of alerts active when the summary was built.
    pub active_alerts_count: usize,
    /// Convergence classification of the recent losses.
    pub convergence_status: ConvergenceStatus,
}
846
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Stability classification based on the coefficient of variation of recent losses.
pub enum TrainingStability {
    /// Coefficient of variation below 0.1.
    Stable,
    /// Coefficient of variation below 0.3 (and not below 0.1).
    Moderate,
    /// Any other coefficient of variation.
    Unstable,
    /// Fewer than 10 samples or loss values were available.
    Insufficient,
}
854
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Convergence classification comparing the newest losses with older ones.
pub enum ConvergenceStatus {
    /// Fewer than 50 samples or loss values were available.
    TooEarly,
    /// The newest average loss is more than 5% below the older average.
    Converging,
    /// The newest and older averages differ by less than 1% (relative).
    Converged,
    /// Neither of the above conditions held.
    Diverging,
}
862
#[derive(Debug, Serialize, Deserialize)]
/// Result of `ModelComparator::compare_models`.
pub struct ModelComparisonReport {
    /// One entry per unordered pair of registered models.
    pub comparisons: Vec<ModelComparison>,
    /// Id of the best-scoring model, if one could be determined.
    pub best_model: Option<String>,
    /// All models ordered by descending score.
    pub ranking: Vec<ModelRanking>,
    /// Copy of the settings the report was produced with.
    pub comparison_config: ComparisonConfig,
}
870
#[derive(Debug, Serialize, Deserialize)]
/// Comparison of one pair of models.
pub struct ModelComparison {
    /// Id of the first model (the baseline of the relative differences).
    pub model_a_id: String,
    /// Id of the second model.
    pub model_b_id: String,
    /// Relative change of the primary metric from model A to model B.
    pub performance_difference: f64,
    /// Average of the relative training-time and model-size changes from A to B.
    pub efficiency_difference: f64,
    /// Outcome of the significance test (currently always `true`).
    pub statistical_significance: bool,
    /// Human-readable verdict.
    pub recommendation: String,
}
880
#[derive(Debug, Serialize, Deserialize)]
/// Position of one model in the ranking.
pub struct ModelRanking {
    /// Id of the ranked model.
    pub model_id: String,
    /// Display name of the ranked model.
    pub model_name: String,
    /// Score used for ordering; higher is better (loss is negated).
    pub score: f64,
    /// 1-based position in the ranking.
    pub rank: usize,
}
888
#[derive(Debug, Serialize, Deserialize)]
/// Result of `HyperparameterExplorer::get_recommendations`.
pub struct HyperparameterRecommendations {
    /// Ids of the best experiments, lowest final loss first.
    pub best_experiments: Vec<String>,
    /// Importance weight per parameter name.
    pub parameter_importance: HashMap<String, f64>,
    /// Suggested range per parameter name.
    pub suggested_ranges: HashMap<String, (f64, f64)>,
    /// Proposed parameter sets for further experiments.
    pub next_experiments: Vec<HashMap<String, HyperparameterValue>>,
    /// Number of experiments recorded when the recommendations were built.
    pub total_experiments: usize,
}
897
#[derive(Debug, Serialize, Deserialize)]
/// Point-in-time capture of the whole dashboard; this is what gets exported as JSON.
pub struct DashboardSnapshot {
    /// Wall-clock time at which the snapshot was taken.
    pub timestamp: SystemTime,
    /// Summary of the training run.
    pub training_summary: TrainingSummary,
    /// Most recent metric samples, oldest first.
    pub recent_metrics: Vec<DashboardMetrics>,
    /// Alerts active at snapshot time.
    pub active_alerts: Vec<TrainingAlert>,
    /// Comparison of the registered models.
    pub model_comparison: ModelComparisonReport,
    /// Recommendations from the hyperparameter explorer.
    pub hyperparameter_recommendations: HyperparameterRecommendations,
    /// Copy of the dashboard's display and session settings.
    pub dashboard_state: DashboardState,
}
908
#[derive(Debug, Serialize, Deserialize)]
/// Session report produced by `InteractiveDashboard::generate_report`.
pub struct DashboardReport {
    /// Time elapsed since the training monitor was created.
    pub session_duration: Duration,
    /// Number of metric samples retained in the monitor's history.
    pub total_metrics_recorded: usize,
    /// Number of alerts active when the report was built.
    pub alerts_triggered: usize,
    /// Number of models registered with the comparator.
    pub models_compared: usize,
    /// Number of experiments recorded by the explorer.
    pub experiments_tracked: usize,
    /// Summary of the training run.
    pub performance_summary: TrainingSummary,
    /// Human-readable observations.
    pub key_insights: Vec<String>,
    /// Human-readable suggested next actions.
    pub recommendations: Vec<String>,
}
921
922impl InteractiveDashboard {
923 pub async fn generate_report(&self) -> Result<DashboardReport> {
925 let training_summary = self.training_monitor.generate_training_summary();
926 let total_metrics = self.training_monitor.metrics_history.len();
927 let alerts_count = self.training_monitor.active_alerts.len();
928 let models_count = self.model_comparator.models.len();
929 let experiments_count = self.hyperparameter_explorer.experiments.len();
930
931 let key_insights = self.generate_key_insights();
932 let recommendations = self.generate_recommendations();
933
934 Ok(DashboardReport {
935 session_duration: training_summary.total_duration,
936 total_metrics_recorded: total_metrics,
937 alerts_triggered: alerts_count,
938 models_compared: models_count,
939 experiments_tracked: experiments_count,
940 performance_summary: training_summary,
941 key_insights,
942 recommendations,
943 })
944 }
945
946 fn generate_key_insights(&self) -> Vec<String> {
947 let mut insights = Vec::new();
948
949 match self.training_monitor.generate_training_summary().training_stability {
951 TrainingStability::Stable => insights.push("Training is proceeding stably".to_string()),
952 TrainingStability::Unstable => insights.push(
953 "Training shows high variance - consider adjusting hyperparameters".to_string(),
954 ),
955 _ => {},
956 }
957
958 if self.model_comparator.models.len() > 1 {
960 let comparison = self.model_comparator.compare_models();
961 if let Some(best_model) = comparison.best_model {
962 insights.push(format!("Best performing model: {}", best_model));
963 }
964 }
965
966 let critical_alerts = self
968 .training_monitor
969 .active_alerts
970 .iter()
971 .filter(|alert| matches!(alert.severity, AlertSeverity::Critical))
972 .count();
973
974 if critical_alerts > 0 {
975 insights.push(format!(
976 "{} critical alerts require immediate attention",
977 critical_alerts
978 ));
979 }
980
981 insights
982 }
983
984 fn generate_recommendations(&self) -> Vec<String> {
985 let mut recommendations = Vec::new();
986
987 for alert in &self.training_monitor.active_alerts {
989 if matches!(alert.severity, AlertSeverity::Critical) {
990 recommendations.push(alert.suggested_action.clone());
991 }
992 }
993
994 if self.hyperparameter_explorer.experiments.len() > 5 {
996 recommendations.push(
997 "Continue hyperparameter optimization with narrowed search ranges".to_string(),
998 );
999 }
1000
1001 if self.model_comparator.models.len() > 1 {
1003 recommendations
1004 .push("Focus on the best performing model architecture for production".to_string());
1005 }
1006
1007 if recommendations.is_empty() {
1008 recommendations.push("Continue monitoring training progress".to_string());
1009 }
1010
1011 recommendations
1012 }
1013}
1014
1015#[cfg(test)]
1016mod tests {
1017 use super::*;
1018
1019 fn make_config() -> DebugConfig {
1020 DebugConfig::default()
1021 }
1022
1023 fn make_metrics_with(
1024 loss: Option<f64>,
1025 accuracy: Option<f64>,
1026 memory_mb: f64,
1027 ) -> DashboardMetrics {
1028 DashboardMetrics {
1029 timestamp: SystemTime::now(),
1030 loss,
1031 accuracy,
1032 learning_rate: Some(0.001),
1033 memory_usage_mb: memory_mb,
1034 gpu_utilization: Some(0.8),
1035 tokens_per_second: Some(200.0),
1036 gradient_norm: Some(1.0),
1037 epoch: Some(1),
1038 step: Some(100),
1039 }
1040 }
1041
1042 fn make_metrics_simple() -> DashboardMetrics {
1043 make_metrics_with(Some(0.5), Some(0.85), 2048.0)
1044 }
1045
1046 #[test]
1049 fn test_alert_thresholds_default() {
1050 let thresholds = AlertThresholds::default();
1051 assert!((thresholds.loss_increase_threshold - 1.5).abs() < 1e-9);
1052 assert!((thresholds.gradient_norm_max - 10.0).abs() < 1e-9);
1053 assert!((thresholds.memory_usage_max_mb - 8192.0).abs() < 1e-9);
1054 }
1055
1056 #[test]
1059 fn test_training_monitor_new() {
1060 let config = make_config();
1061 let monitor = TrainingMonitor::new(&config);
1062 assert!(monitor.metrics_history.is_empty());
1063 assert!(monitor.active_alerts.is_empty());
1064 assert_eq!(monitor.max_history, 10000);
1065 }
1066
1067 #[test]
1068 fn test_training_monitor_update_metrics() {
1069 let config = make_config();
1070 let mut monitor = TrainingMonitor::new(&config);
1071 monitor.update_metrics(make_metrics_simple());
1072 assert_eq!(monitor.metrics_history.len(), 1);
1073 }
1074
1075 #[test]
1076 fn test_training_monitor_history_limit() {
1077 let config = make_config();
1078 let mut monitor = TrainingMonitor::new(&config);
1079 monitor.max_history = 5;
1080 for _ in 0..10 {
1081 monitor.update_metrics(make_metrics_simple());
1082 }
1083 assert_eq!(monitor.metrics_history.len(), 5);
1084 }
1085
1086 #[test]
1087 fn test_training_monitor_get_recent_metrics() {
1088 let config = make_config();
1089 let mut monitor = TrainingMonitor::new(&config);
1090 for _ in 0..5 {
1091 monitor.update_metrics(make_metrics_simple());
1092 }
1093 let recent = monitor.get_recent_metrics(3);
1094 assert_eq!(recent.len(), 3);
1095 }
1096
1097 #[test]
1098 fn test_training_monitor_get_recent_metrics_more_than_available() {
1099 let config = make_config();
1100 let mut monitor = TrainingMonitor::new(&config);
1101 monitor.update_metrics(make_metrics_simple());
1102 let recent = monitor.get_recent_metrics(10);
1103 assert_eq!(recent.len(), 1);
1104 }
1105
1106 #[test]
1107 fn test_training_monitor_set_alert_thresholds() {
1108 let config = make_config();
1109 let mut monitor = TrainingMonitor::new(&config);
1110 let thresholds = AlertThresholds {
1111 loss_increase_threshold: 2.0,
1112 gradient_norm_max: 5.0,
1113 memory_usage_max_mb: 4096.0,
1114 gpu_utilization_min: 0.5,
1115 learning_rate_min: 1e-6,
1116 tokens_per_second_min: 50.0,
1117 };
1118 monitor.set_alert_thresholds(thresholds);
1119 assert!((monitor.alert_thresholds.gradient_norm_max - 5.0).abs() < 1e-9);
1120 }
1121
1122 #[test]
1123 fn test_training_monitor_gradient_explosion_alert() {
1124 let config = make_config();
1125 let mut monitor = TrainingMonitor::new(&config);
1126 let mut metrics = make_metrics_simple();
1127 metrics.gradient_norm = Some(100.0);
1128 monitor.update_metrics(metrics);
1129 assert!(monitor
1130 .active_alerts
1131 .iter()
1132 .any(|a| a.alert_type == AlertType::GradientExplosion));
1133 }
1134
1135 #[test]
1136 fn test_training_monitor_memory_overuse_alert() {
1137 let config = make_config();
1138 let mut monitor = TrainingMonitor::new(&config);
1139 let metrics = make_metrics_with(Some(0.5), Some(0.8), 10000.0);
1140 monitor.update_metrics(metrics);
1141 assert!(monitor.active_alerts.iter().any(|a| a.alert_type == AlertType::MemoryOveruse));
1142 }
1143
/// A sample reporting a GPU utilization of 0.1 must raise a
/// `LowGpuUtilization` alert.
#[test]
fn test_training_monitor_low_gpu_alert() {
    let mut monitor = TrainingMonitor::new(&make_config());

    // Start from the simple baseline sample and override only the GPU figure.
    let sample = DashboardMetrics {
        gpu_utilization: Some(0.1),
        ..make_metrics_simple()
    };
    monitor.update_metrics(sample);

    let raised = monitor
        .active_alerts
        .iter()
        .any(|alert| matches!(alert.alert_type, AlertType::LowGpuUtilization));
    assert!(raised);
}
1156
/// A sample reporting 10.0 tokens per second must raise a
/// `SlowTokenProcessing` alert.
#[test]
fn test_training_monitor_slow_token_alert() {
    let mut monitor = TrainingMonitor::new(&make_config());

    // Start from the simple baseline sample and override only the throughput.
    let sample = DashboardMetrics {
        tokens_per_second: Some(10.0),
        ..make_metrics_simple()
    };
    monitor.update_metrics(sample);

    let raised = monitor
        .active_alerts
        .iter()
        .any(|alert| matches!(alert.alert_type, AlertType::SlowTokenProcessing));
    assert!(raised);
}
1169
/// Feeding the same offending sample twice must leave exactly one
/// `GradientExplosion` alert active.
#[test]
fn test_training_monitor_no_duplicate_alerts() {
    let mut monitor = TrainingMonitor::new(&make_config());

    let mut sample = make_metrics_simple();
    sample.gradient_norm = Some(100.0);

    // Same sample, submitted twice in a row.
    monitor.update_metrics(sample.clone());
    monitor.update_metrics(sample);

    let explosion_count = monitor
        .active_alerts
        .iter()
        .filter(|alert| matches!(alert.alert_type, AlertType::GradientExplosion))
        .count();
    assert_eq!(explosion_count, 1);
}
1185
/// With no metrics recorded, the average loss is `None`.
#[test]
fn test_training_monitor_average_loss_none() {
    let empty_monitor = TrainingMonitor::new(&make_config());
    let average = empty_monitor.calculate_average_loss();
    assert!(average.is_none());
}
1192
/// Two samples with losses 1.0 and 2.0 must average to 1.5.
#[test]
fn test_training_monitor_average_loss() {
    let mut monitor = TrainingMonitor::new(&make_config());
    for loss in [1.0, 2.0] {
        monitor.update_metrics(make_metrics_with(Some(loss), None, 1024.0));
    }

    let average = monitor.calculate_average_loss();
    assert!(average.is_some());

    // Floating-point result: compare against 1.5 with a tolerance.
    let deviation = average.expect("should be some") - 1.5;
    assert!(deviation.abs() < 1e-9);
}
1203
/// With no metrics recorded, the best accuracy is `None`.
#[test]
fn test_training_monitor_best_accuracy_none() {
    let empty_monitor = TrainingMonitor::new(&make_config());
    let best = empty_monitor.calculate_best_accuracy();
    assert!(best.is_none());
}
1210
/// Of the accuracies 0.7, 0.9 and 0.8 (in that order), the best reported
/// accuracy must be 0.9.
#[test]
fn test_training_monitor_best_accuracy() {
    let mut monitor = TrainingMonitor::new(&make_config());
    for accuracy in [0.7, 0.9, 0.8] {
        monitor.update_metrics(make_metrics_with(None, Some(accuracy), 1024.0));
    }

    let best = monitor.calculate_best_accuracy();
    assert!(best.is_some());

    // Floating-point result: compare against 0.9 with a tolerance.
    let deviation = best.expect("should be some") - 0.9;
    assert!(deviation.abs() < 1e-9);
}
1222
/// With no metrics recorded, the average tokens-per-second is `None`.
#[test]
fn test_training_monitor_avg_tps_none() {
    let empty_monitor = TrainingMonitor::new(&make_config());
    let average = empty_monitor.calculate_average_tokens_per_second();
    assert!(average.is_none());
}
1229
/// A monitor without any recorded metrics reports
/// `TrainingStability::Insufficient`.
#[test]
fn test_training_stability_insufficient() {
    let monitor = TrainingMonitor::new(&make_config());
    let stability = monitor.calculate_training_stability();
    assert!(matches!(stability, TrainingStability::Insufficient));
}
1239
/// A monitor without any recorded metrics assesses convergence as
/// `ConvergenceStatus::TooEarly`.
#[test]
fn test_convergence_too_early() {
    let monitor = TrainingMonitor::new(&make_config());
    let status = monitor.assess_convergence();
    assert!(matches!(status, ConvergenceStatus::TooEarly));
}
1249
/// The summary of a monitor without recorded metrics has zero total steps and
/// a `TooEarly` convergence status.
#[test]
fn test_generate_training_summary() {
    let monitor = TrainingMonitor::new(&make_config());
    let summary = monitor.generate_training_summary();

    assert_eq!(summary.total_steps, 0);

    let too_early = matches!(summary.convergence_status, ConvergenceStatus::TooEarly);
    assert!(too_early);
}
1261
/// A freshly constructed comparator holds no models.
#[test]
fn test_model_comparator_new() {
    let fresh = ModelComparator::new();
    let no_models = fresh.models.is_empty();
    assert!(no_models);
}
1269
/// Adding one model to an empty comparator leaves exactly one model stored.
#[test]
fn test_model_comparator_add_model() {
    let model = ModelMetrics {
        model_id: String::from("m1"),
        model_name: String::from("Model A"),
        metrics_history: Vec::new(),
        final_loss: Some(0.5),
        final_accuracy: Some(0.9),
        training_time: Duration::from_secs(100),
        parameter_count: 1000,
        model_size_mb: 10.0,
    };

    let mut comparator = ModelComparator::new();
    comparator.add_model(model);

    assert_eq!(comparator.models.len(), 1);
}
1285
/// With no models registered there is no best model to report.
#[test]
fn test_model_comparator_find_best_model_empty() {
    let empty = ModelComparator::new();
    let best = empty.find_best_model();
    assert!(best.is_none());
}
1291
/// Given "m1" (loss 0.5, accuracy 0.9) and "m2" (loss 0.3, accuracy 0.95),
/// the comparator must pick "m2" as the best model.
#[test]
fn test_model_comparator_find_best_model() {
    let model_a = ModelMetrics {
        model_id: String::from("m1"),
        model_name: String::from("Model A"),
        metrics_history: Vec::new(),
        final_loss: Some(0.5),
        final_accuracy: Some(0.9),
        training_time: Duration::from_secs(100),
        parameter_count: 1000,
        model_size_mb: 10.0,
    };
    let model_b = ModelMetrics {
        model_id: String::from("m2"),
        model_name: String::from("Model B"),
        metrics_history: Vec::new(),
        final_loss: Some(0.3),
        final_accuracy: Some(0.95),
        training_time: Duration::from_secs(200),
        parameter_count: 2000,
        model_size_mb: 20.0,
    };

    // Registration order is preserved: "m1" first, then "m2".
    let mut comparator = ModelComparator::new();
    comparator.add_model(model_a);
    comparator.add_model(model_b);

    let winner = comparator.find_best_model();
    assert!(winner.is_some());
    assert_eq!(winner.expect("should find best"), "m2");
}
1319
/// Ranking a comparator that holds a single model yields one entry with
/// rank 1.
#[test]
fn test_model_comparator_rank_models() {
    let only_model = ModelMetrics {
        model_id: String::from("m1"),
        model_name: String::from("A"),
        metrics_history: Vec::new(),
        final_loss: Some(0.5),
        final_accuracy: None,
        training_time: Duration::from_secs(100),
        parameter_count: 1000,
        model_size_mb: 10.0,
    };

    let mut comparator = ModelComparator::new();
    comparator.add_model(only_model);

    let ranked = comparator.rank_models();
    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 1);
}
1337
/// Comparing a model against itself with a third argument of 0.0 produces a
/// recommendation containing the word "similarly".
#[test]
fn test_model_comparator_generate_recommendation_similar() {
    let model = ModelMetrics {
        model_id: String::from("a"),
        model_name: String::from("A"),
        metrics_history: Vec::new(),
        final_loss: Some(0.5),
        final_accuracy: None,
        training_time: Duration::from_secs(100),
        parameter_count: 1000,
        model_size_mb: 10.0,
    };

    let comparator = ModelComparator::new();
    // The same model is deliberately passed on both sides of the comparison.
    let recommendation = comparator.generate_recommendation(&model, &model, 0.0);

    assert!(recommendation.contains("similarly"));
}
1354
/// A freshly constructed explorer holds no experiments.
#[test]
fn test_hyperparameter_explorer_new() {
    let fresh = HyperparameterExplorer::new();
    let no_experiments = fresh.experiments.is_empty();
    assert!(no_experiments);
}
1362
/// Adding one completed experiment to an empty explorer leaves exactly one
/// experiment stored.
#[test]
fn test_hyperparameter_explorer_add_experiment() {
    let results = ExperimentResults {
        final_loss: Some(0.5),
        final_accuracy: Some(0.9),
        training_time: Duration::from_secs(100),
        convergence_epoch: Some(50),
        best_validation_score: Some(0.88),
    };
    let experiment = HyperparameterExperiment {
        experiment_id: String::from("exp1"),
        hyperparameters: HashMap::new(),
        results,
        status: ExperimentStatus::Completed,
    };

    let mut explorer = HyperparameterExplorer::new();
    explorer.add_experiment(experiment);

    assert_eq!(explorer.experiments.len(), 1);
}
1380
/// Recommendations from an explorer without experiments report zero total
/// experiments but still carry a non-empty parameter-importance collection.
#[test]
fn test_hyperparameter_explorer_get_recommendations() {
    let explorer = HyperparameterExplorer::new();
    let recommendations = explorer.get_recommendations();

    assert_eq!(recommendations.total_experiments, 0);

    let has_importance = !recommendations.parameter_importance.is_empty();
    assert!(has_importance);
}
1388
/// Requesting three suggestions from an empty explorer yields exactly three.
#[test]
fn test_hyperparameter_explorer_suggest_next_experiments() {
    let explorer = HyperparameterExplorer::new();
    let suggested = explorer.suggest_next_experiments(3);
    let suggestion_count = suggested.len();
    assert_eq!(suggestion_count, 3);
}
1395
/// A freshly constructed dashboard has no websocket server attached.
#[test]
fn test_interactive_dashboard_new() {
    let cfg = make_config();
    let dashboard = InteractiveDashboard::new(&cfg);

    let server_absent = dashboard.websocket_server.is_none();
    assert!(server_absent);
}
1404
/// Updating the dashboard with one sample records one entry in the training
/// monitor's metrics history.
#[test]
fn test_interactive_dashboard_update() {
    let cfg = make_config();
    let mut dashboard = InteractiveDashboard::new(&cfg);

    dashboard.update(make_metrics_simple());

    let recorded = dashboard.training_monitor.metrics_history.len();
    assert_eq!(recorded, 1);
}
1412
/// A snapshot taken from a dashboard that received no updates has no recent
/// metrics.
#[test]
fn test_interactive_dashboard_snapshot() {
    let cfg = make_config();
    let dashboard = InteractiveDashboard::new(&cfg);

    let snap = dashboard.get_dashboard_snapshot();
    let no_recent_metrics = snap.recent_metrics.is_empty();
    assert!(no_recent_metrics);
}
1420
/// Even without any updates, the dashboard produces at least one
/// recommendation.
#[test]
fn test_interactive_dashboard_generate_recommendations() {
    let cfg = make_config();
    let dashboard = InteractiveDashboard::new(&cfg);

    let recommendations = dashboard.generate_recommendations();
    let has_recommendations = !recommendations.is_empty();
    assert!(has_recommendations);
}
1428
/// Smoke test: generating key insights on a dashboard that received no
/// updates must complete without panicking.
#[test]
fn test_interactive_dashboard_generate_key_insights() {
    let config = make_config();
    let dashboard = InteractiveDashboard::new(&config);

    // Deliberately no assertion on the contents: what a fresh dashboard
    // reports is not pinned down here, and an `is_empty() || !is_empty()`
    // check would be a tautology that can never fail. Only the absence of a
    // panic is verified.
    let _insights = dashboard.generate_key_insights();
}
1437
/// The default comparison config uses "loss" as its primary metric and a
/// comparison window of 100.
#[test]
fn test_comparison_config_default() {
    let defaults = ComparisonConfig::default();

    let metric = &defaults.primary_metric;
    assert_eq!(*metric, "loss");

    let window = &defaults.comparison_window;
    assert_eq!(*window, 100);
}
1446
/// In the default search space, the first component of both the learning-rate
/// and batch-size ranges is strictly below the second.
#[test]
fn test_search_space_default() {
    let defaults = HyperparameterSearchSpace::default();

    let lr_range = &defaults.learning_rate;
    assert!(lr_range.0 < lr_range.1);

    let batch_range = &defaults.batch_size;
    assert!(batch_range.0 < batch_range.1);
}
1455}