1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime};
8use uuid::Uuid;
9
10use crate::DebugConfig;
11
/// One point-in-time sample of training telemetry pushed to the dashboard.
/// Fields are `Option` where a training step may legitimately not report them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DashboardMetrics {
    /// Wall-clock time the sample was recorded.
    pub timestamp: SystemTime,
    /// Training loss for this step, if reported.
    pub loss: Option<f64>,
    /// Accuracy for this step, if reported.
    pub accuracy: Option<f64>,
    /// Learning rate in effect, if reported.
    pub learning_rate: Option<f64>,
    /// Process memory usage in megabytes (always reported).
    pub memory_usage_mb: f64,
    /// GPU utilization — presumably a 0.0-1.0 fraction (the default alert
    /// threshold is 0.7); TODO confirm with the metric producer.
    pub gpu_utilization: Option<f64>,
    /// Throughput in tokens per second, if reported.
    pub tokens_per_second: Option<f64>,
    /// Global gradient norm, if reported.
    pub gradient_norm: Option<f64>,
    /// Epoch number, if known.
    pub epoch: Option<u32>,
    /// Global training step, if known.
    pub step: Option<u64>,
}
26
/// Rolling monitor over training metrics: keeps a bounded sample history,
/// raises threshold-based alerts, and summarizes overall progress.
#[derive(Debug)]
pub struct TrainingMonitor {
    #[allow(dead_code)]
    config: DebugConfig,
    // Bounded FIFO of samples; oldest at the front, newest at the back.
    metrics_history: VecDeque<DashboardMetrics>,
    // Cap on `metrics_history`; older samples are evicted beyond this.
    max_history: usize,
    // Creation time; used to report total training duration.
    start_time: Instant,
    // Limits that decide when `check_alerts` raises an alert.
    alert_thresholds: AlertThresholds,
    // Raised-and-not-yet-cleared alerts (at most one per `AlertType`).
    active_alerts: Vec<TrainingAlert>,
}
38
/// Tunable limits that decide when `TrainingMonitor` raises an alert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlertThresholds {
    /// Multiplier over a recent loss above which a loss increase is flagged.
    pub loss_increase_threshold: f64,
    /// Gradient norm above which a gradient explosion is flagged.
    pub gradient_norm_max: f64,
    /// Memory usage (MB) above which overuse is flagged.
    pub memory_usage_max_mb: f64,
    /// GPU utilization below which under-utilization is flagged.
    pub gpu_utilization_min: f64,
    /// NOTE(review): declared but never checked by `check_alerts` as written.
    pub learning_rate_min: f64,
    /// Token throughput below which slow processing is flagged.
    pub tokens_per_second_min: f64,
}
49
impl Default for AlertThresholds {
    /// Heuristic defaults — presumably tuned for a single-GPU run; adjust
    /// per workload.
    fn default() -> Self {
        Self {
            loss_increase_threshold: 1.5, // flag a 50%+ jump vs. recent loss
            gradient_norm_max: 10.0,
            memory_usage_max_mb: 8192.0, // 8 GiB
            gpu_utilization_min: 0.7,
            learning_rate_min: 1e-8,
            tokens_per_second_min: 100.0,
        }
    }
}
62
/// A raised alert: what fired, how bad it is, the offending value, the
/// threshold it crossed, and a suggested remediation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingAlert {
    pub alert_type: AlertType,
    pub severity: AlertSeverity,
    /// Short human-readable description.
    pub message: String,
    /// When the alert was raised.
    pub timestamp: SystemTime,
    /// The observed value that triggered the alert.
    pub metric_value: f64,
    /// The threshold that `metric_value` violated.
    pub threshold: f64,
    /// Recommended operator action.
    pub suggested_action: String,
}
74
/// Category of a training alert. `PartialEq` enables de-duplication in
/// `add_alert` and clearing by type in `clear_alert`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AlertType {
    LossIncrease,
    GradientExplosion,
    MemoryOveruse,
    LowGpuUtilization,
    /// NOTE(review): declared but never raised by `check_alerts` as written.
    LearningRateTooLow,
    SlowTokenProcessing,
    /// NOTE(review): declared but never raised.
    ModelDivergence,
    /// NOTE(review): declared but never raised.
    TrainingStalled,
}
86
/// Urgency of an alert, from informational to requires-immediate-action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
}
93
94impl TrainingMonitor {
95 pub fn new(config: &DebugConfig) -> Self {
97 Self {
98 config: config.clone(),
99 metrics_history: VecDeque::new(),
100 max_history: 10000,
101 start_time: Instant::now(),
102 alert_thresholds: AlertThresholds::default(),
103 active_alerts: Vec::new(),
104 }
105 }
106
107 pub fn update_metrics(&mut self, metrics: DashboardMetrics) {
109 self.metrics_history.push_back(metrics.clone());
111
112 if self.metrics_history.len() > self.max_history {
114 self.metrics_history.pop_front();
115 }
116
117 self.check_alerts(&metrics);
119 }
120
121 pub fn get_recent_metrics(&self, count: usize) -> Vec<DashboardMetrics> {
123 self.metrics_history.iter().rev().take(count).rev().cloned().collect()
124 }
125
126 pub fn get_active_alerts(&self) -> &[TrainingAlert] {
128 &self.active_alerts
129 }
130
131 pub fn clear_alert(&mut self, _alert_type: AlertType) {
133 self.active_alerts.retain(|alert| !matches!(&alert.alert_type, _alert_type));
134 }
135
136 pub fn set_alert_thresholds(&mut self, thresholds: AlertThresholds) {
138 self.alert_thresholds = thresholds;
139 }
140
141 pub fn generate_training_summary(&self) -> TrainingSummary {
143 let total_duration = self.start_time.elapsed();
144 let total_steps = self.metrics_history.len();
145
146 let avg_loss = self.calculate_average_loss();
147 let best_accuracy = self.calculate_best_accuracy();
148 let avg_tokens_per_second = self.calculate_average_tokens_per_second();
149 let training_stability = self.calculate_training_stability();
150
151 TrainingSummary {
152 total_duration,
153 total_steps,
154 avg_loss,
155 best_accuracy,
156 avg_tokens_per_second,
157 training_stability,
158 active_alerts_count: self.active_alerts.len(),
159 convergence_status: self.assess_convergence(),
160 }
161 }
162
163 fn check_alerts(&mut self, metrics: &DashboardMetrics) {
164 if let Some(current_loss) = metrics.loss {
166 if let Some(prev_metrics) =
167 self.metrics_history.get(self.metrics_history.len().saturating_sub(10))
168 {
169 if let Some(prev_loss) = prev_metrics.loss {
170 if current_loss > prev_loss * self.alert_thresholds.loss_increase_threshold {
171 self.add_alert(TrainingAlert {
172 alert_type: AlertType::LossIncrease,
173 severity: AlertSeverity::Warning,
174 message: "Loss has increased significantly".to_string(),
175 timestamp: SystemTime::now(),
176 metric_value: current_loss,
177 threshold: prev_loss * self.alert_thresholds.loss_increase_threshold,
178 suggested_action: "Check learning rate or data quality".to_string(),
179 });
180 }
181 }
182 }
183 }
184
185 if let Some(grad_norm) = metrics.gradient_norm {
187 if grad_norm > self.alert_thresholds.gradient_norm_max {
188 self.add_alert(TrainingAlert {
189 alert_type: AlertType::GradientExplosion,
190 severity: AlertSeverity::Critical,
191 message: "Gradient explosion detected".to_string(),
192 timestamp: SystemTime::now(),
193 metric_value: grad_norm,
194 threshold: self.alert_thresholds.gradient_norm_max,
195 suggested_action: "Apply gradient clipping or reduce learning rate".to_string(),
196 });
197 }
198 }
199
200 if metrics.memory_usage_mb > self.alert_thresholds.memory_usage_max_mb {
202 self.add_alert(TrainingAlert {
203 alert_type: AlertType::MemoryOveruse,
204 severity: AlertSeverity::Warning,
205 message: "High memory usage detected".to_string(),
206 timestamp: SystemTime::now(),
207 metric_value: metrics.memory_usage_mb,
208 threshold: self.alert_thresholds.memory_usage_max_mb,
209 suggested_action: "Reduce batch size or enable gradient checkpointing".to_string(),
210 });
211 }
212
213 if let Some(gpu_util) = metrics.gpu_utilization {
215 if gpu_util < self.alert_thresholds.gpu_utilization_min {
216 self.add_alert(TrainingAlert {
217 alert_type: AlertType::LowGpuUtilization,
218 severity: AlertSeverity::Info,
219 message: "Low GPU utilization".to_string(),
220 timestamp: SystemTime::now(),
221 metric_value: gpu_util,
222 threshold: self.alert_thresholds.gpu_utilization_min,
223 suggested_action: "Increase batch size or check data loading".to_string(),
224 });
225 }
226 }
227
228 if let Some(tps) = metrics.tokens_per_second {
230 if tps < self.alert_thresholds.tokens_per_second_min {
231 self.add_alert(TrainingAlert {
232 alert_type: AlertType::SlowTokenProcessing,
233 severity: AlertSeverity::Warning,
234 message: "Slow token processing detected".to_string(),
235 timestamp: SystemTime::now(),
236 metric_value: tps,
237 threshold: self.alert_thresholds.tokens_per_second_min,
238 suggested_action: "Optimize model or increase batch size".to_string(),
239 });
240 }
241 }
242 }
243
244 fn add_alert(&mut self, alert: TrainingAlert) {
245 if !self.active_alerts.iter().any(|a| a.alert_type == alert.alert_type) {
247 self.active_alerts.push(alert);
248 }
249 }
250
251 fn calculate_average_loss(&self) -> Option<f64> {
252 let losses: Vec<f64> = self.metrics_history.iter().filter_map(|m| m.loss).collect();
253
254 if losses.is_empty() {
255 None
256 } else {
257 Some(losses.iter().sum::<f64>() / losses.len() as f64)
258 }
259 }
260
261 fn calculate_best_accuracy(&self) -> Option<f64> {
262 self.metrics_history
263 .iter()
264 .filter_map(|m| m.accuracy)
265 .fold(None, |acc, x| match acc {
266 None => Some(x),
267 Some(y) => Some(x.max(y)),
268 })
269 }
270
271 fn calculate_average_tokens_per_second(&self) -> Option<f64> {
272 let tps_values: Vec<f64> =
273 self.metrics_history.iter().filter_map(|m| m.tokens_per_second).collect();
274
275 if tps_values.is_empty() {
276 None
277 } else {
278 Some(tps_values.iter().sum::<f64>() / tps_values.len() as f64)
279 }
280 }
281
282 fn calculate_training_stability(&self) -> TrainingStability {
283 if self.metrics_history.len() < 10 {
284 return TrainingStability::Insufficient;
285 }
286
287 let recent_losses: Vec<f64> =
288 self.metrics_history.iter().rev().take(50).filter_map(|m| m.loss).collect();
289
290 if recent_losses.len() < 10 {
291 return TrainingStability::Insufficient;
292 }
293
294 let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
296 let variance = recent_losses.iter().map(|&x| (x - mean_loss).powi(2)).sum::<f64>()
297 / recent_losses.len() as f64;
298
299 let std_dev = variance.sqrt();
300 let coefficient_of_variation = if mean_loss != 0.0 { std_dev / mean_loss } else { 0.0 };
301
302 match coefficient_of_variation {
303 cv if cv < 0.1 => TrainingStability::Stable,
304 cv if cv < 0.3 => TrainingStability::Moderate,
305 _ => TrainingStability::Unstable,
306 }
307 }
308
309 fn assess_convergence(&self) -> ConvergenceStatus {
310 if self.metrics_history.len() < 50 {
311 return ConvergenceStatus::TooEarly;
312 }
313
314 let recent_losses: Vec<f64> =
315 self.metrics_history.iter().rev().take(100).filter_map(|m| m.loss).collect();
316
317 if recent_losses.len() < 50 {
318 return ConvergenceStatus::TooEarly;
319 }
320
321 let first_half_avg =
323 recent_losses[25..].iter().sum::<f64>() / (recent_losses.len() - 25) as f64;
324 let second_half_avg = recent_losses[..25].iter().sum::<f64>() / 25.0;
325
326 if second_half_avg < first_half_avg * 0.95 {
327 ConvergenceStatus::Converging
328 } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.01 {
329 ConvergenceStatus::Converged
330 } else {
331 ConvergenceStatus::Diverging
332 }
333 }
334}
335
/// Compares registered models pairwise on a configurable primary metric.
#[derive(Debug)]
pub struct ModelComparator {
    // Registered models, keyed by `ModelMetrics::model_id`.
    models: HashMap<String, ModelMetrics>,
    comparison_config: ComparisonConfig,
}
342
/// Final metrics and cost characteristics of one trained model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetrics {
    /// Unique key used to register the model with `ModelComparator`.
    pub model_id: String,
    /// Display name used in human-readable recommendations.
    pub model_name: String,
    /// Per-step samples collected during this model's training run.
    pub metrics_history: Vec<DashboardMetrics>,
    pub final_loss: Option<f64>,
    pub final_accuracy: Option<f64>,
    /// Total wall-clock training time.
    pub training_time: Duration,
    pub parameter_count: usize,
    pub model_size_mb: f64,
}
354
/// Settings controlling how models are compared.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonConfig {
    /// Metric used for ranking; "loss" and "accuracy" are recognized.
    pub primary_metric: String,
    /// Trailing-sample window size — NOTE(review): not yet consulted by the
    /// comparison code in this file.
    pub comparison_window: usize,
    /// Significance cutoff — NOTE(review): not yet consulted (the
    /// significance test is a placeholder).
    pub significance_threshold: f64,
}
361
impl Default for ComparisonConfig {
    /// Defaults to ranking by final loss, a 100-sample window, and a
    /// conventional 5% significance level.
    fn default() -> Self {
        Self {
            primary_metric: "loss".to_string(),
            comparison_window: 100,
            significance_threshold: 0.05,
        }
    }
}
371
372impl ModelComparator {
373 pub fn new() -> Self {
375 Self {
376 models: HashMap::new(),
377 comparison_config: ComparisonConfig::default(),
378 }
379 }
380
381 pub fn add_model(&mut self, model_metrics: ModelMetrics) {
383 self.models.insert(model_metrics.model_id.clone(), model_metrics);
384 }
385
386 pub fn compare_models(&self) -> ModelComparisonReport {
388 let mut comparisons = Vec::new();
389 let model_ids: Vec<String> = self.models.keys().cloned().collect();
390
391 for i in 0..model_ids.len() {
392 for j in (i + 1)..model_ids.len() {
393 let model_a = &self.models[&model_ids[i]];
394 let model_b = &self.models[&model_ids[j]];
395
396 let comparison = self.compare_two_models(model_a, model_b);
397 comparisons.push(comparison);
398 }
399 }
400
401 let best_model = self.find_best_model();
402 let ranking = self.rank_models();
403
404 ModelComparisonReport {
405 comparisons,
406 best_model,
407 ranking,
408 comparison_config: self.comparison_config.clone(),
409 }
410 }
411
412 fn compare_two_models(
413 &self,
414 model_a: &ModelMetrics,
415 model_b: &ModelMetrics,
416 ) -> ModelComparison {
417 let performance_diff = self.calculate_performance_difference(model_a, model_b);
418 let efficiency_diff = self.calculate_efficiency_difference(model_a, model_b);
419 let statistical_significance = self.test_statistical_significance(model_a, model_b);
420
421 ModelComparison {
422 model_a_id: model_a.model_id.clone(),
423 model_b_id: model_b.model_id.clone(),
424 performance_difference: performance_diff,
425 efficiency_difference: efficiency_diff,
426 statistical_significance,
427 recommendation: self.generate_recommendation(model_a, model_b, performance_diff),
428 }
429 }
430
431 fn calculate_performance_difference(
432 &self,
433 model_a: &ModelMetrics,
434 model_b: &ModelMetrics,
435 ) -> f64 {
436 match self.comparison_config.primary_metric.as_str() {
437 "loss" => {
438 if let (Some(loss_a), Some(loss_b)) = (model_a.final_loss, model_b.final_loss) {
439 (loss_b - loss_a) / loss_a } else {
441 0.0
442 }
443 },
444 "accuracy" => {
445 if let (Some(acc_a), Some(acc_b)) = (model_a.final_accuracy, model_b.final_accuracy)
446 {
447 (acc_b - acc_a) / acc_a } else {
449 0.0
450 }
451 },
452 _ => 0.0,
453 }
454 }
455
456 fn calculate_efficiency_difference(
457 &self,
458 model_a: &ModelMetrics,
459 model_b: &ModelMetrics,
460 ) -> f64 {
461 let time_diff =
463 model_b.training_time.as_secs_f64() / model_a.training_time.as_secs_f64() - 1.0;
464
465 let size_diff = model_b.model_size_mb / model_a.model_size_mb - 1.0;
467
468 (time_diff + size_diff) / 2.0
470 }
471
472 fn test_statistical_significance(
473 &self,
474 _model_a: &ModelMetrics,
475 _model_b: &ModelMetrics,
476 ) -> bool {
477 true }
480
481 fn generate_recommendation(
482 &self,
483 model_a: &ModelMetrics,
484 model_b: &ModelMetrics,
485 perf_diff: f64,
486 ) -> String {
487 if perf_diff.abs() < 0.01 {
488 "Models perform similarly - choose based on other factors".to_string()
489 } else if perf_diff < 0.0 {
490 format!(
491 "Model {} performs {:.1}% better",
492 model_a.model_name,
493 perf_diff.abs() * 100.0
494 )
495 } else {
496 format!(
497 "Model {} performs {:.1}% better",
498 model_b.model_name,
499 perf_diff * 100.0
500 )
501 }
502 }
503
504 fn find_best_model(&self) -> Option<String> {
505 let mut best_model = None;
506 let mut best_score = f64::NEG_INFINITY;
507
508 for model in self.models.values() {
509 let score = match self.comparison_config.primary_metric.as_str() {
510 "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
511 "accuracy" => model.final_accuracy.unwrap_or(0.0),
512 _ => 0.0,
513 };
514
515 if score > best_score {
516 best_score = score;
517 best_model = Some(model.model_id.clone());
518 }
519 }
520
521 best_model
522 }
523
524 fn rank_models(&self) -> Vec<ModelRanking> {
525 let mut rankings: Vec<ModelRanking> = self
526 .models
527 .values()
528 .map(|model| {
529 let score = match self.comparison_config.primary_metric.as_str() {
530 "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
531 "accuracy" => model.final_accuracy.unwrap_or(0.0),
532 _ => 0.0,
533 };
534
535 ModelRanking {
536 model_id: model.model_id.clone(),
537 model_name: model.model_name.clone(),
538 score,
539 rank: 0, }
541 })
542 .collect();
543
544 rankings.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
545
546 for (i, ranking) in rankings.iter_mut().enumerate() {
547 ranking.rank = i + 1;
548 }
549
550 rankings
551 }
552}
553
/// Tracks hyperparameter experiments and derives search recommendations.
#[derive(Debug)]
#[allow(dead_code)]
pub struct HyperparameterExplorer {
    // Experiments keyed by `experiment_id`.
    experiments: HashMap<String, HyperparameterExperiment>,
    #[allow(dead_code)]
    search_space: HyperparameterSearchSpace,
    // Best-so-far snapshots — NOTE(review): never written to in this file.
    optimization_history: Vec<OptimizationStep>,
}
563
/// One hyperparameter configuration and the results of training with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterExperiment {
    pub experiment_id: String,
    /// Parameter name -> value used for this run.
    pub hyperparameters: HashMap<String, HyperparameterValue>,
    pub results: ExperimentResults,
    pub status: ExperimentStatus,
}
571
/// A typed hyperparameter value (float, integer, string, or flag).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HyperparameterValue {
    Float(f64),
    Integer(i64),
    String(String),
    Boolean(bool),
}
579
/// Outcome metrics of a single experiment run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentResults {
    pub final_loss: Option<f64>,
    pub final_accuracy: Option<f64>,
    pub training_time: Duration,
    /// Epoch at which training converged, if it did.
    pub convergence_epoch: Option<u32>,
    pub best_validation_score: Option<f64>,
}
588
/// Lifecycle state of an experiment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExperimentStatus {
    Running,
    Completed,
    Failed,
    Cancelled,
}
596
/// (min, max) bounds for each tunable hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSearchSpace {
    pub learning_rate: (f64, f64),
    pub batch_size: (i64, i64),
    pub dropout_rate: (f64, f64),
    pub weight_decay: (f64, f64),
    pub num_layers: (i64, i64),
    pub hidden_size: (i64, i64),
}
606
impl Default for HyperparameterSearchSpace {
    /// Broad starting ranges suitable for an initial coarse sweep.
    fn default() -> Self {
        Self {
            learning_rate: (1e-5, 1e-1),
            batch_size: (4, 128),
            dropout_rate: (0.0, 0.5),
            weight_decay: (0.0, 1e-2),
            num_layers: (1, 12),
            hidden_size: (64, 2048),
        }
    }
}
619
/// Snapshot of the optimization loop after one round of experiments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationStep {
    pub step: usize,
    /// Id of the best experiment seen so far.
    pub best_experiment_id: String,
    pub best_score: f64,
    /// Count of exploration-type trials taken at this step.
    pub exploration_count: usize,
    /// Count of exploitation-type trials taken at this step.
    pub exploitation_count: usize,
}
628
629impl HyperparameterExplorer {
630 pub fn new() -> Self {
632 Self {
633 experiments: HashMap::new(),
634 search_space: HyperparameterSearchSpace::default(),
635 optimization_history: Vec::new(),
636 }
637 }
638
639 pub fn add_experiment(&mut self, experiment: HyperparameterExperiment) {
641 self.experiments.insert(experiment.experiment_id.clone(), experiment);
642 }
643
644 pub fn get_recommendations(&self) -> HyperparameterRecommendations {
646 let best_experiments = self.find_best_experiments(5);
647 let parameter_importance = self.analyze_parameter_importance();
648 let suggested_ranges = self.suggest_search_ranges();
649 let next_experiments = self.suggest_next_experiments(3);
650
651 HyperparameterRecommendations {
652 best_experiments,
653 parameter_importance,
654 suggested_ranges,
655 next_experiments,
656 total_experiments: self.experiments.len(),
657 }
658 }
659
660 fn find_best_experiments(&self, limit: usize) -> Vec<String> {
661 let mut experiments: Vec<_> = self.experiments.values().collect();
662 experiments.sort_by(|a, b| {
663 let score_a = a.results.final_loss.unwrap_or(f64::INFINITY);
664 let score_b = b.results.final_loss.unwrap_or(f64::INFINITY);
665 score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
666 });
667
668 experiments.iter().take(limit).map(|exp| exp.experiment_id.clone()).collect()
669 }
670
671 fn analyze_parameter_importance(&self) -> HashMap<String, f64> {
672 let mut importance = HashMap::new();
674 importance.insert("learning_rate".to_string(), 0.8);
675 importance.insert("batch_size".to_string(), 0.6);
676 importance.insert("dropout_rate".to_string(), 0.4);
677 importance.insert("weight_decay".to_string(), 0.3);
678 importance
679 }
680
681 fn suggest_search_ranges(&self) -> HashMap<String, (f64, f64)> {
682 let mut ranges = HashMap::new();
684 ranges.insert("learning_rate".to_string(), (1e-4, 1e-2));
685 ranges.insert("dropout_rate".to_string(), (0.1, 0.3));
686 ranges
687 }
688
689 fn suggest_next_experiments(&self, count: usize) -> Vec<HashMap<String, HyperparameterValue>> {
690 let mut suggestions = Vec::new();
691
692 for i in 0..count {
693 let mut params = HashMap::new();
694
695 params.insert(
697 "learning_rate".to_string(),
698 HyperparameterValue::Float(0.001 * (1.0 + i as f64 * 0.5)),
699 );
700 params.insert(
701 "batch_size".to_string(),
702 HyperparameterValue::Integer(32 * (1 + i as i64)),
703 );
704 params.insert(
705 "dropout_rate".to_string(),
706 HyperparameterValue::Float(0.1 + i as f64 * 0.1),
707 );
708
709 suggestions.push(params);
710 }
711
712 suggestions
713 }
714}
715
/// Top-level dashboard: owns the monitor, comparator, and explorer, plus
/// UI state and an optional (placeholder) websocket server handle.
#[derive(Debug)]
pub struct InteractiveDashboard {
    #[allow(dead_code)]
    config: DebugConfig,
    training_monitor: TrainingMonitor,
    model_comparator: ModelComparator,
    hyperparameter_explorer: HyperparameterExplorer,
    dashboard_state: DashboardState,
    // Set by `start`; `None` until then.
    websocket_server: Option<WebSocketServer>,
}
727
/// Mutable UI state of the dashboard.
#[derive(Debug, Serialize, Deserialize)]
pub struct DashboardState {
    pub active_session_id: Option<Uuid>,
    /// How often clients should refresh, in milliseconds.
    pub refresh_rate_ms: u64,
    /// Whether alerts are surfaced automatically.
    pub auto_alerts: bool,
    pub display_mode: DisplayMode,
}
735
/// Which dashboard view is currently shown.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DisplayMode {
    Overview,
    DetailedMetrics,
    ModelComparison,
    HyperparameterOptimization,
    AlertsOnly,
}
744
/// Placeholder websocket server handle — NOTE(review): no socket is actually
/// bound by any code in this file.
#[derive(Debug)]
#[allow(dead_code)]
pub struct WebSocketServer {
    #[allow(dead_code)]
    port: u16,
    // Client identifiers, shared with (future) connection handlers.
    connected_clients: Arc<Mutex<Vec<String>>>,
}
753
754impl InteractiveDashboard {
755 pub fn new(config: &DebugConfig) -> Self {
757 Self {
758 config: config.clone(),
759 training_monitor: TrainingMonitor::new(config),
760 model_comparator: ModelComparator::new(),
761 hyperparameter_explorer: HyperparameterExplorer::new(),
762 dashboard_state: DashboardState {
763 active_session_id: None,
764 refresh_rate_ms: 1000,
765 auto_alerts: true,
766 display_mode: DisplayMode::Overview,
767 },
768 websocket_server: None,
769 }
770 }
771
772 pub async fn start(&mut self, port: Option<u16>) -> Result<()> {
774 let port = port.unwrap_or(8080);
775
776 self.websocket_server = Some(WebSocketServer {
777 port,
778 connected_clients: Arc::new(Mutex::new(Vec::new())),
779 });
780
781 tracing::info!("Interactive dashboard started on port {}", port);
782 Ok(())
783 }
784
785 pub fn update(&mut self, metrics: DashboardMetrics) {
787 self.training_monitor.update_metrics(metrics.clone());
788
789 if let Some(_ws_server) = &self.websocket_server {
791 self.broadcast_update(metrics);
792 }
793 }
794
795 pub fn get_dashboard_snapshot(&self) -> DashboardSnapshot {
797 let training_summary = self.training_monitor.generate_training_summary();
798 let recent_metrics = self.training_monitor.get_recent_metrics(100);
799 let active_alerts = self.training_monitor.get_active_alerts().to_vec();
800 let model_comparison = self.model_comparator.compare_models();
801 let hyperparameter_recommendations = self.hyperparameter_explorer.get_recommendations();
802
803 DashboardSnapshot {
804 timestamp: SystemTime::now(),
805 training_summary,
806 recent_metrics,
807 active_alerts,
808 model_comparison,
809 hyperparameter_recommendations,
810 dashboard_state: DashboardState {
811 active_session_id: self.dashboard_state.active_session_id,
812 refresh_rate_ms: self.dashboard_state.refresh_rate_ms,
813 auto_alerts: self.dashboard_state.auto_alerts,
814 display_mode: self.dashboard_state.display_mode.clone(),
815 },
816 }
817 }
818
819 pub async fn export_dashboard_data(&self, path: &str) -> Result<()> {
821 let snapshot = self.get_dashboard_snapshot();
822 let json = serde_json::to_string_pretty(&snapshot)?;
823 tokio::fs::write(path, json).await?;
824 Ok(())
825 }
826
827 fn broadcast_update(&self, _metrics: DashboardMetrics) {
828 tracing::debug!("Broadcasting dashboard update to connected clients");
830 }
831}
832
/// Aggregated view of a training run produced by `TrainingMonitor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSummary {
    pub total_duration: Duration,
    /// Number of metric samples recorded (not necessarily optimizer steps).
    pub total_steps: usize,
    pub avg_loss: Option<f64>,
    pub best_accuracy: Option<f64>,
    pub avg_tokens_per_second: Option<f64>,
    pub training_stability: TrainingStability,
    pub active_alerts_count: usize,
    pub convergence_status: ConvergenceStatus,
}
846
/// Stability class derived from the coefficient of variation of recent losses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainingStability {
    Stable,
    Moderate,
    Unstable,
    /// Not enough loss samples to judge.
    Insufficient,
}
854
/// Convergence class derived from the recent loss trend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceStatus {
    /// Not enough loss samples to judge yet.
    TooEarly,
    Converging,
    Converged,
    Diverging,
}
862
/// Full output of `ModelComparator::compare_models`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelComparisonReport {
    /// One entry per unordered pair of registered models.
    pub comparisons: Vec<ModelComparison>,
    /// Id of the top model on the primary metric, if any model is registered.
    pub best_model: Option<String>,
    pub ranking: Vec<ModelRanking>,
    /// The configuration the report was produced under.
    pub comparison_config: ComparisonConfig,
}
870
/// Pairwise comparison of two models on performance and cost.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelComparison {
    pub model_a_id: String,
    pub model_b_id: String,
    /// Relative difference on the primary metric.
    pub performance_difference: f64,
    /// Relative difference in training time / model size.
    pub efficiency_difference: f64,
    pub statistical_significance: bool,
    /// Human-readable verdict.
    pub recommendation: String,
}
880
/// One row of the model leaderboard.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelRanking {
    pub model_id: String,
    pub model_name: String,
    /// Higher-is-better score on the primary metric.
    pub score: f64,
    /// 1-based position; 1 is best.
    pub rank: usize,
}
888
/// Output of `HyperparameterExplorer::get_recommendations`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HyperparameterRecommendations {
    /// Ids of the best experiments so far, best first.
    pub best_experiments: Vec<String>,
    /// Parameter name -> relative importance weight.
    pub parameter_importance: HashMap<String, f64>,
    /// Parameter name -> suggested (min, max) search range.
    pub suggested_ranges: HashMap<String, (f64, f64)>,
    /// Concrete configurations proposed for the next runs.
    pub next_experiments: Vec<HashMap<String, HyperparameterValue>>,
    pub total_experiments: usize,
}
897
/// Complete serializable view of the dashboard at one instant, as produced
/// by `InteractiveDashboard::get_dashboard_snapshot`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DashboardSnapshot {
    pub timestamp: SystemTime,
    pub training_summary: TrainingSummary,
    pub recent_metrics: Vec<DashboardMetrics>,
    pub active_alerts: Vec<TrainingAlert>,
    pub model_comparison: ModelComparisonReport,
    pub hyperparameter_recommendations: HyperparameterRecommendations,
    pub dashboard_state: DashboardState,
}
908
/// End-of-session report produced by `InteractiveDashboard::generate_report`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DashboardReport {
    pub session_duration: Duration,
    pub total_metrics_recorded: usize,
    /// Count of alerts still active at report time.
    pub alerts_triggered: usize,
    pub models_compared: usize,
    pub experiments_tracked: usize,
    pub performance_summary: TrainingSummary,
    /// Notable observations derived from the session.
    pub key_insights: Vec<String>,
    /// Suggested follow-up actions.
    pub recommendations: Vec<String>,
}
921
922impl InteractiveDashboard {
923 pub async fn generate_report(&self) -> Result<DashboardReport> {
925 let training_summary = self.training_monitor.generate_training_summary();
926 let total_metrics = self.training_monitor.metrics_history.len();
927 let alerts_count = self.training_monitor.active_alerts.len();
928 let models_count = self.model_comparator.models.len();
929 let experiments_count = self.hyperparameter_explorer.experiments.len();
930
931 let key_insights = self.generate_key_insights();
932 let recommendations = self.generate_recommendations();
933
934 Ok(DashboardReport {
935 session_duration: training_summary.total_duration,
936 total_metrics_recorded: total_metrics,
937 alerts_triggered: alerts_count,
938 models_compared: models_count,
939 experiments_tracked: experiments_count,
940 performance_summary: training_summary,
941 key_insights,
942 recommendations,
943 })
944 }
945
946 fn generate_key_insights(&self) -> Vec<String> {
947 let mut insights = Vec::new();
948
949 match self.training_monitor.generate_training_summary().training_stability {
951 TrainingStability::Stable => insights.push("Training is proceeding stably".to_string()),
952 TrainingStability::Unstable => insights.push(
953 "Training shows high variance - consider adjusting hyperparameters".to_string(),
954 ),
955 _ => {},
956 }
957
958 if self.model_comparator.models.len() > 1 {
960 let comparison = self.model_comparator.compare_models();
961 if let Some(best_model) = comparison.best_model {
962 insights.push(format!("Best performing model: {}", best_model));
963 }
964 }
965
966 let critical_alerts = self
968 .training_monitor
969 .active_alerts
970 .iter()
971 .filter(|alert| matches!(alert.severity, AlertSeverity::Critical))
972 .count();
973
974 if critical_alerts > 0 {
975 insights.push(format!(
976 "{} critical alerts require immediate attention",
977 critical_alerts
978 ));
979 }
980
981 insights
982 }
983
984 fn generate_recommendations(&self) -> Vec<String> {
985 let mut recommendations = Vec::new();
986
987 for alert in &self.training_monitor.active_alerts {
989 if matches!(alert.severity, AlertSeverity::Critical) {
990 recommendations.push(alert.suggested_action.clone());
991 }
992 }
993
994 if self.hyperparameter_explorer.experiments.len() > 5 {
996 recommendations.push(
997 "Continue hyperparameter optimization with narrowed search ranges".to_string(),
998 );
999 }
1000
1001 if self.model_comparator.models.len() > 1 {
1003 recommendations
1004 .push("Focus on the best performing model architecture for production".to_string());
1005 }
1006
1007 if recommendations.is_empty() {
1008 recommendations.push("Continue monitoring training progress".to_string());
1009 }
1010
1011 recommendations
1012 }
1013}