Skip to main content

trustformers_debug/
dashboard.rs

1//! Interactive dashboards for real-time monitoring and analysis
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime};
8use uuid::Uuid;
9
10use crate::DebugConfig;
11
12/// Real-time metrics for dashboard display
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct DashboardMetrics {
15    pub timestamp: SystemTime,
16    pub loss: Option<f64>,
17    pub accuracy: Option<f64>,
18    pub learning_rate: Option<f64>,
19    pub memory_usage_mb: f64,
20    pub gpu_utilization: Option<f64>,
21    pub tokens_per_second: Option<f64>,
22    pub gradient_norm: Option<f64>,
23    pub epoch: Option<u32>,
24    pub step: Option<u64>,
25}
26
27/// Training monitor for real-time tracking
28#[derive(Debug)]
29pub struct TrainingMonitor {
30    #[allow(dead_code)]
31    config: DebugConfig,
32    metrics_history: VecDeque<DashboardMetrics>,
33    max_history: usize,
34    start_time: Instant,
35    alert_thresholds: AlertThresholds,
36    active_alerts: Vec<TrainingAlert>,
37}
38
39/// Alert thresholds for training monitoring
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct AlertThresholds {
42    pub loss_increase_threshold: f64,
43    pub gradient_norm_max: f64,
44    pub memory_usage_max_mb: f64,
45    pub gpu_utilization_min: f64,
46    pub learning_rate_min: f64,
47    pub tokens_per_second_min: f64,
48}
49
50impl Default for AlertThresholds {
51    fn default() -> Self {
52        Self {
53            loss_increase_threshold: 1.5,
54            gradient_norm_max: 10.0,
55            memory_usage_max_mb: 8192.0,
56            gpu_utilization_min: 0.7,
57            learning_rate_min: 1e-8,
58            tokens_per_second_min: 100.0,
59        }
60    }
61}
62
63/// Training alert types
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct TrainingAlert {
66    pub alert_type: AlertType,
67    pub severity: AlertSeverity,
68    pub message: String,
69    pub timestamp: SystemTime,
70    pub metric_value: f64,
71    pub threshold: f64,
72    pub suggested_action: String,
73}
74
75#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
76pub enum AlertType {
77    LossIncrease,
78    GradientExplosion,
79    MemoryOveruse,
80    LowGpuUtilization,
81    LearningRateTooLow,
82    SlowTokenProcessing,
83    ModelDivergence,
84    TrainingStalled,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub enum AlertSeverity {
89    Info,
90    Warning,
91    Critical,
92}
93
94impl TrainingMonitor {
95    /// Create a new training monitor
96    pub fn new(config: &DebugConfig) -> Self {
97        Self {
98            config: config.clone(),
99            metrics_history: VecDeque::new(),
100            max_history: 10000,
101            start_time: Instant::now(),
102            alert_thresholds: AlertThresholds::default(),
103            active_alerts: Vec::new(),
104        }
105    }
106
107    /// Update metrics and check for alerts
108    pub fn update_metrics(&mut self, metrics: DashboardMetrics) {
109        // Add to history
110        self.metrics_history.push_back(metrics.clone());
111
112        // Trim history if needed
113        if self.metrics_history.len() > self.max_history {
114            self.metrics_history.pop_front();
115        }
116
117        // Check for alerts
118        self.check_alerts(&metrics);
119    }
120
121    /// Get recent metrics for dashboard
122    pub fn get_recent_metrics(&self, count: usize) -> Vec<DashboardMetrics> {
123        self.metrics_history.iter().rev().take(count).rev().cloned().collect()
124    }
125
126    /// Get active alerts
127    pub fn get_active_alerts(&self) -> &[TrainingAlert] {
128        &self.active_alerts
129    }
130
131    /// Clear resolved alerts
132    pub fn clear_alert(&mut self, _alert_type: AlertType) {
133        self.active_alerts.retain(|alert| !matches!(&alert.alert_type, _alert_type));
134    }
135
136    /// Set custom alert thresholds
137    pub fn set_alert_thresholds(&mut self, thresholds: AlertThresholds) {
138        self.alert_thresholds = thresholds;
139    }
140
141    /// Generate training summary
142    pub fn generate_training_summary(&self) -> TrainingSummary {
143        let total_duration = self.start_time.elapsed();
144        let total_steps = self.metrics_history.len();
145
146        let avg_loss = self.calculate_average_loss();
147        let best_accuracy = self.calculate_best_accuracy();
148        let avg_tokens_per_second = self.calculate_average_tokens_per_second();
149        let training_stability = self.calculate_training_stability();
150
151        TrainingSummary {
152            total_duration,
153            total_steps,
154            avg_loss,
155            best_accuracy,
156            avg_tokens_per_second,
157            training_stability,
158            active_alerts_count: self.active_alerts.len(),
159            convergence_status: self.assess_convergence(),
160        }
161    }
162
163    fn check_alerts(&mut self, metrics: &DashboardMetrics) {
164        // Check for loss increase
165        if let Some(current_loss) = metrics.loss {
166            if let Some(prev_metrics) =
167                self.metrics_history.get(self.metrics_history.len().saturating_sub(10))
168            {
169                if let Some(prev_loss) = prev_metrics.loss {
170                    if current_loss > prev_loss * self.alert_thresholds.loss_increase_threshold {
171                        self.add_alert(TrainingAlert {
172                            alert_type: AlertType::LossIncrease,
173                            severity: AlertSeverity::Warning,
174                            message: "Loss has increased significantly".to_string(),
175                            timestamp: SystemTime::now(),
176                            metric_value: current_loss,
177                            threshold: prev_loss * self.alert_thresholds.loss_increase_threshold,
178                            suggested_action: "Check learning rate or data quality".to_string(),
179                        });
180                    }
181                }
182            }
183        }
184
185        // Check gradient norm
186        if let Some(grad_norm) = metrics.gradient_norm {
187            if grad_norm > self.alert_thresholds.gradient_norm_max {
188                self.add_alert(TrainingAlert {
189                    alert_type: AlertType::GradientExplosion,
190                    severity: AlertSeverity::Critical,
191                    message: "Gradient explosion detected".to_string(),
192                    timestamp: SystemTime::now(),
193                    metric_value: grad_norm,
194                    threshold: self.alert_thresholds.gradient_norm_max,
195                    suggested_action: "Apply gradient clipping or reduce learning rate".to_string(),
196                });
197            }
198        }
199
200        // Check memory usage
201        if metrics.memory_usage_mb > self.alert_thresholds.memory_usage_max_mb {
202            self.add_alert(TrainingAlert {
203                alert_type: AlertType::MemoryOveruse,
204                severity: AlertSeverity::Warning,
205                message: "High memory usage detected".to_string(),
206                timestamp: SystemTime::now(),
207                metric_value: metrics.memory_usage_mb,
208                threshold: self.alert_thresholds.memory_usage_max_mb,
209                suggested_action: "Reduce batch size or enable gradient checkpointing".to_string(),
210            });
211        }
212
213        // Check GPU utilization
214        if let Some(gpu_util) = metrics.gpu_utilization {
215            if gpu_util < self.alert_thresholds.gpu_utilization_min {
216                self.add_alert(TrainingAlert {
217                    alert_type: AlertType::LowGpuUtilization,
218                    severity: AlertSeverity::Info,
219                    message: "Low GPU utilization".to_string(),
220                    timestamp: SystemTime::now(),
221                    metric_value: gpu_util,
222                    threshold: self.alert_thresholds.gpu_utilization_min,
223                    suggested_action: "Increase batch size or check data loading".to_string(),
224                });
225            }
226        }
227
228        // Check tokens per second
229        if let Some(tps) = metrics.tokens_per_second {
230            if tps < self.alert_thresholds.tokens_per_second_min {
231                self.add_alert(TrainingAlert {
232                    alert_type: AlertType::SlowTokenProcessing,
233                    severity: AlertSeverity::Warning,
234                    message: "Slow token processing detected".to_string(),
235                    timestamp: SystemTime::now(),
236                    metric_value: tps,
237                    threshold: self.alert_thresholds.tokens_per_second_min,
238                    suggested_action: "Optimize model or increase batch size".to_string(),
239                });
240            }
241        }
242    }
243
244    fn add_alert(&mut self, alert: TrainingAlert) {
245        // Avoid duplicate alerts of same type
246        if !self.active_alerts.iter().any(|a| a.alert_type == alert.alert_type) {
247            self.active_alerts.push(alert);
248        }
249    }
250
251    fn calculate_average_loss(&self) -> Option<f64> {
252        let losses: Vec<f64> = self.metrics_history.iter().filter_map(|m| m.loss).collect();
253
254        if losses.is_empty() {
255            None
256        } else {
257            Some(losses.iter().sum::<f64>() / losses.len() as f64)
258        }
259    }
260
261    fn calculate_best_accuracy(&self) -> Option<f64> {
262        self.metrics_history
263            .iter()
264            .filter_map(|m| m.accuracy)
265            .fold(None, |acc, x| match acc {
266                None => Some(x),
267                Some(y) => Some(x.max(y)),
268            })
269    }
270
271    fn calculate_average_tokens_per_second(&self) -> Option<f64> {
272        let tps_values: Vec<f64> =
273            self.metrics_history.iter().filter_map(|m| m.tokens_per_second).collect();
274
275        if tps_values.is_empty() {
276            None
277        } else {
278            Some(tps_values.iter().sum::<f64>() / tps_values.len() as f64)
279        }
280    }
281
282    fn calculate_training_stability(&self) -> TrainingStability {
283        if self.metrics_history.len() < 10 {
284            return TrainingStability::Insufficient;
285        }
286
287        let recent_losses: Vec<f64> =
288            self.metrics_history.iter().rev().take(50).filter_map(|m| m.loss).collect();
289
290        if recent_losses.len() < 10 {
291            return TrainingStability::Insufficient;
292        }
293
294        // Calculate loss variance
295        let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
296        let variance = recent_losses.iter().map(|&x| (x - mean_loss).powi(2)).sum::<f64>()
297            / recent_losses.len() as f64;
298
299        let std_dev = variance.sqrt();
300        let coefficient_of_variation = if mean_loss != 0.0 { std_dev / mean_loss } else { 0.0 };
301
302        match coefficient_of_variation {
303            cv if cv < 0.1 => TrainingStability::Stable,
304            cv if cv < 0.3 => TrainingStability::Moderate,
305            _ => TrainingStability::Unstable,
306        }
307    }
308
309    fn assess_convergence(&self) -> ConvergenceStatus {
310        if self.metrics_history.len() < 50 {
311            return ConvergenceStatus::TooEarly;
312        }
313
314        let recent_losses: Vec<f64> =
315            self.metrics_history.iter().rev().take(100).filter_map(|m| m.loss).collect();
316
317        if recent_losses.len() < 50 {
318            return ConvergenceStatus::TooEarly;
319        }
320
321        // Check if loss is decreasing
322        let first_half_avg =
323            recent_losses[25..].iter().sum::<f64>() / (recent_losses.len() - 25) as f64;
324        let second_half_avg = recent_losses[..25].iter().sum::<f64>() / 25.0;
325
326        if second_half_avg < first_half_avg * 0.95 {
327            ConvergenceStatus::Converging
328        } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.01 {
329            ConvergenceStatus::Converged
330        } else {
331            ConvergenceStatus::Diverging
332        }
333    }
334}
335
336/// Model comparison tool for A/B testing
337#[derive(Debug)]
338pub struct ModelComparator {
339    models: HashMap<String, ModelMetrics>,
340    comparison_config: ComparisonConfig,
341}
342
343#[derive(Debug, Clone, Serialize, Deserialize)]
344pub struct ModelMetrics {
345    pub model_id: String,
346    pub model_name: String,
347    pub metrics_history: Vec<DashboardMetrics>,
348    pub final_loss: Option<f64>,
349    pub final_accuracy: Option<f64>,
350    pub training_time: Duration,
351    pub parameter_count: usize,
352    pub model_size_mb: f64,
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct ComparisonConfig {
357    pub primary_metric: String,
358    pub comparison_window: usize,
359    pub significance_threshold: f64,
360}
361
362impl Default for ComparisonConfig {
363    fn default() -> Self {
364        Self {
365            primary_metric: "loss".to_string(),
366            comparison_window: 100,
367            significance_threshold: 0.05,
368        }
369    }
370}
371
372impl ModelComparator {
373    /// Create new model comparator
374    pub fn new() -> Self {
375        Self {
376            models: HashMap::new(),
377            comparison_config: ComparisonConfig::default(),
378        }
379    }
380
381    /// Add model for comparison
382    pub fn add_model(&mut self, model_metrics: ModelMetrics) {
383        self.models.insert(model_metrics.model_id.clone(), model_metrics);
384    }
385
386    /// Compare models and generate report
387    pub fn compare_models(&self) -> ModelComparisonReport {
388        let mut comparisons = Vec::new();
389        let model_ids: Vec<String> = self.models.keys().cloned().collect();
390
391        for i in 0..model_ids.len() {
392            for j in (i + 1)..model_ids.len() {
393                let model_a = &self.models[&model_ids[i]];
394                let model_b = &self.models[&model_ids[j]];
395
396                let comparison = self.compare_two_models(model_a, model_b);
397                comparisons.push(comparison);
398            }
399        }
400
401        let best_model = self.find_best_model();
402        let ranking = self.rank_models();
403
404        ModelComparisonReport {
405            comparisons,
406            best_model,
407            ranking,
408            comparison_config: self.comparison_config.clone(),
409        }
410    }
411
412    fn compare_two_models(
413        &self,
414        model_a: &ModelMetrics,
415        model_b: &ModelMetrics,
416    ) -> ModelComparison {
417        let performance_diff = self.calculate_performance_difference(model_a, model_b);
418        let efficiency_diff = self.calculate_efficiency_difference(model_a, model_b);
419        let statistical_significance = self.test_statistical_significance(model_a, model_b);
420
421        ModelComparison {
422            model_a_id: model_a.model_id.clone(),
423            model_b_id: model_b.model_id.clone(),
424            performance_difference: performance_diff,
425            efficiency_difference: efficiency_diff,
426            statistical_significance,
427            recommendation: self.generate_recommendation(model_a, model_b, performance_diff),
428        }
429    }
430
431    fn calculate_performance_difference(
432        &self,
433        model_a: &ModelMetrics,
434        model_b: &ModelMetrics,
435    ) -> f64 {
436        match self.comparison_config.primary_metric.as_str() {
437            "loss" => {
438                if let (Some(loss_a), Some(loss_b)) = (model_a.final_loss, model_b.final_loss) {
439                    (loss_b - loss_a) / loss_a // Negative means model_a is better
440                } else {
441                    0.0
442                }
443            },
444            "accuracy" => {
445                if let (Some(acc_a), Some(acc_b)) = (model_a.final_accuracy, model_b.final_accuracy)
446                {
447                    (acc_b - acc_a) / acc_a // Positive means model_b is better
448                } else {
449                    0.0
450                }
451            },
452            _ => 0.0,
453        }
454    }
455
456    fn calculate_efficiency_difference(
457        &self,
458        model_a: &ModelMetrics,
459        model_b: &ModelMetrics,
460    ) -> f64 {
461        // Compare training time efficiency
462        let time_diff =
463            model_b.training_time.as_secs_f64() / model_a.training_time.as_secs_f64() - 1.0;
464
465        // Compare model size efficiency
466        let size_diff = model_b.model_size_mb / model_a.model_size_mb - 1.0;
467
468        // Combined efficiency score (lower is better)
469        (time_diff + size_diff) / 2.0
470    }
471
472    fn test_statistical_significance(
473        &self,
474        _model_a: &ModelMetrics,
475        _model_b: &ModelMetrics,
476    ) -> bool {
477        // Simplified statistical test - in practice would use proper statistical methods
478        true // Placeholder
479    }
480
481    fn generate_recommendation(
482        &self,
483        model_a: &ModelMetrics,
484        model_b: &ModelMetrics,
485        perf_diff: f64,
486    ) -> String {
487        if perf_diff.abs() < 0.01 {
488            "Models perform similarly - choose based on other factors".to_string()
489        } else if perf_diff < 0.0 {
490            format!(
491                "Model {} performs {:.1}% better",
492                model_a.model_name,
493                perf_diff.abs() * 100.0
494            )
495        } else {
496            format!(
497                "Model {} performs {:.1}% better",
498                model_b.model_name,
499                perf_diff * 100.0
500            )
501        }
502    }
503
504    fn find_best_model(&self) -> Option<String> {
505        let mut best_model = None;
506        let mut best_score = f64::NEG_INFINITY;
507
508        for model in self.models.values() {
509            let score = match self.comparison_config.primary_metric.as_str() {
510                "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
511                "accuracy" => model.final_accuracy.unwrap_or(0.0),
512                _ => 0.0,
513            };
514
515            if score > best_score {
516                best_score = score;
517                best_model = Some(model.model_id.clone());
518            }
519        }
520
521        best_model
522    }
523
524    fn rank_models(&self) -> Vec<ModelRanking> {
525        let mut rankings: Vec<ModelRanking> = self
526            .models
527            .values()
528            .map(|model| {
529                let score = match self.comparison_config.primary_metric.as_str() {
530                    "loss" => model.final_loss.map(|l| -l).unwrap_or(f64::NEG_INFINITY),
531                    "accuracy" => model.final_accuracy.unwrap_or(0.0),
532                    _ => 0.0,
533                };
534
535                ModelRanking {
536                    model_id: model.model_id.clone(),
537                    model_name: model.model_name.clone(),
538                    score,
539                    rank: 0, // Will be filled below
540                }
541            })
542            .collect();
543
544        rankings.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
545
546        for (i, ranking) in rankings.iter_mut().enumerate() {
547            ranking.rank = i + 1;
548        }
549
550        rankings
551    }
552}
553
554/// Hyperparameter explorer for optimization guidance
555#[derive(Debug)]
556#[allow(dead_code)]
557pub struct HyperparameterExplorer {
558    experiments: HashMap<String, HyperparameterExperiment>,
559    #[allow(dead_code)]
560    search_space: HyperparameterSearchSpace,
561    optimization_history: Vec<OptimizationStep>,
562}
563
564#[derive(Debug, Clone, Serialize, Deserialize)]
565pub struct HyperparameterExperiment {
566    pub experiment_id: String,
567    pub hyperparameters: HashMap<String, HyperparameterValue>,
568    pub results: ExperimentResults,
569    pub status: ExperimentStatus,
570}
571
572#[derive(Debug, Clone, Serialize, Deserialize)]
573pub enum HyperparameterValue {
574    Float(f64),
575    Integer(i64),
576    String(String),
577    Boolean(bool),
578}
579
580#[derive(Debug, Clone, Serialize, Deserialize)]
581pub struct ExperimentResults {
582    pub final_loss: Option<f64>,
583    pub final_accuracy: Option<f64>,
584    pub training_time: Duration,
585    pub convergence_epoch: Option<u32>,
586    pub best_validation_score: Option<f64>,
587}
588
589#[derive(Debug, Clone, Serialize, Deserialize)]
590pub enum ExperimentStatus {
591    Running,
592    Completed,
593    Failed,
594    Cancelled,
595}
596
597#[derive(Debug, Clone, Serialize, Deserialize)]
598pub struct HyperparameterSearchSpace {
599    pub learning_rate: (f64, f64),
600    pub batch_size: (i64, i64),
601    pub dropout_rate: (f64, f64),
602    pub weight_decay: (f64, f64),
603    pub num_layers: (i64, i64),
604    pub hidden_size: (i64, i64),
605}
606
607impl Default for HyperparameterSearchSpace {
608    fn default() -> Self {
609        Self {
610            learning_rate: (1e-5, 1e-1),
611            batch_size: (4, 128),
612            dropout_rate: (0.0, 0.5),
613            weight_decay: (0.0, 1e-2),
614            num_layers: (1, 12),
615            hidden_size: (64, 2048),
616        }
617    }
618}
619
620#[derive(Debug, Clone, Serialize, Deserialize)]
621pub struct OptimizationStep {
622    pub step: usize,
623    pub best_experiment_id: String,
624    pub best_score: f64,
625    pub exploration_count: usize,
626    pub exploitation_count: usize,
627}
628
629impl HyperparameterExplorer {
630    /// Create new hyperparameter explorer
631    pub fn new() -> Self {
632        Self {
633            experiments: HashMap::new(),
634            search_space: HyperparameterSearchSpace::default(),
635            optimization_history: Vec::new(),
636        }
637    }
638
639    /// Add experiment result
640    pub fn add_experiment(&mut self, experiment: HyperparameterExperiment) {
641        self.experiments.insert(experiment.experiment_id.clone(), experiment);
642    }
643
644    /// Get hyperparameter recommendations
645    pub fn get_recommendations(&self) -> HyperparameterRecommendations {
646        let best_experiments = self.find_best_experiments(5);
647        let parameter_importance = self.analyze_parameter_importance();
648        let suggested_ranges = self.suggest_search_ranges();
649        let next_experiments = self.suggest_next_experiments(3);
650
651        HyperparameterRecommendations {
652            best_experiments,
653            parameter_importance,
654            suggested_ranges,
655            next_experiments,
656            total_experiments: self.experiments.len(),
657        }
658    }
659
660    fn find_best_experiments(&self, limit: usize) -> Vec<String> {
661        let mut experiments: Vec<_> = self.experiments.values().collect();
662        experiments.sort_by(|a, b| {
663            let score_a = a.results.final_loss.unwrap_or(f64::INFINITY);
664            let score_b = b.results.final_loss.unwrap_or(f64::INFINITY);
665            score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
666        });
667
668        experiments.iter().take(limit).map(|exp| exp.experiment_id.clone()).collect()
669    }
670
671    fn analyze_parameter_importance(&self) -> HashMap<String, f64> {
672        // Simplified parameter importance analysis
673        let mut importance = HashMap::new();
674        importance.insert("learning_rate".to_string(), 0.8);
675        importance.insert("batch_size".to_string(), 0.6);
676        importance.insert("dropout_rate".to_string(), 0.4);
677        importance.insert("weight_decay".to_string(), 0.3);
678        importance
679    }
680
681    fn suggest_search_ranges(&self) -> HashMap<String, (f64, f64)> {
682        // Analyze best experiments to narrow search ranges
683        let mut ranges = HashMap::new();
684        ranges.insert("learning_rate".to_string(), (1e-4, 1e-2));
685        ranges.insert("dropout_rate".to_string(), (0.1, 0.3));
686        ranges
687    }
688
689    fn suggest_next_experiments(&self, count: usize) -> Vec<HashMap<String, HyperparameterValue>> {
690        let mut suggestions = Vec::new();
691
692        for i in 0..count {
693            let mut params = HashMap::new();
694
695            // Generate varied parameter combinations based on best results
696            params.insert(
697                "learning_rate".to_string(),
698                HyperparameterValue::Float(0.001 * (1.0 + i as f64 * 0.5)),
699            );
700            params.insert(
701                "batch_size".to_string(),
702                HyperparameterValue::Integer(32 * (1 + i as i64)),
703            );
704            params.insert(
705                "dropout_rate".to_string(),
706                HyperparameterValue::Float(0.1 + i as f64 * 0.1),
707            );
708
709            suggestions.push(params);
710        }
711
712        suggestions
713    }
714}
715
716/// Dashboard aggregator that combines all monitoring tools
717#[derive(Debug)]
718pub struct InteractiveDashboard {
719    #[allow(dead_code)]
720    config: DebugConfig,
721    training_monitor: TrainingMonitor,
722    model_comparator: ModelComparator,
723    hyperparameter_explorer: HyperparameterExplorer,
724    dashboard_state: DashboardState,
725    websocket_server: Option<WebSocketServer>,
726}
727
728#[derive(Debug, Serialize, Deserialize)]
729pub struct DashboardState {
730    pub active_session_id: Option<Uuid>,
731    pub refresh_rate_ms: u64,
732    pub auto_alerts: bool,
733    pub display_mode: DisplayMode,
734}
735
736#[derive(Debug, Clone, Serialize, Deserialize)]
737pub enum DisplayMode {
738    Overview,
739    DetailedMetrics,
740    ModelComparison,
741    HyperparameterOptimization,
742    AlertsOnly,
743}
744
/// Placeholder endpoint for pushing live dashboard updates over WebSocket.
///
/// Holds only the configured port and a shared client list; this type opens
/// no socket.
// Lint allowances kept from the original: the fields are not read yet.
#[allow(dead_code)]
#[derive(Debug)]
pub struct WebSocketServer {
    /// Port the endpoint was registered on.
    #[allow(dead_code)]
    port: u16,
    /// Connected client entries behind a shared mutex (presumably ids or
    /// addresses — confirm when broadcasting is implemented).
    connected_clients: Arc<Mutex<Vec<String>>>,
}
753
754impl InteractiveDashboard {
755    /// Create new interactive dashboard
756    pub fn new(config: &DebugConfig) -> Self {
757        Self {
758            config: config.clone(),
759            training_monitor: TrainingMonitor::new(config),
760            model_comparator: ModelComparator::new(),
761            hyperparameter_explorer: HyperparameterExplorer::new(),
762            dashboard_state: DashboardState {
763                active_session_id: None,
764                refresh_rate_ms: 1000,
765                auto_alerts: true,
766                display_mode: DisplayMode::Overview,
767            },
768            websocket_server: None,
769        }
770    }
771
772    /// Start dashboard with WebSocket server
773    pub async fn start(&mut self, port: Option<u16>) -> Result<()> {
774        let port = port.unwrap_or(8080);
775
776        self.websocket_server = Some(WebSocketServer {
777            port,
778            connected_clients: Arc::new(Mutex::new(Vec::new())),
779        });
780
781        tracing::info!("Interactive dashboard started on port {}", port);
782        Ok(())
783    }
784
785    /// Update dashboard with new metrics
786    pub fn update(&mut self, metrics: DashboardMetrics) {
787        self.training_monitor.update_metrics(metrics.clone());
788
789        // Broadcast to connected clients if WebSocket server is running
790        if let Some(_ws_server) = &self.websocket_server {
791            self.broadcast_update(metrics);
792        }
793    }
794
795    /// Get current dashboard snapshot
796    pub fn get_dashboard_snapshot(&self) -> DashboardSnapshot {
797        let training_summary = self.training_monitor.generate_training_summary();
798        let recent_metrics = self.training_monitor.get_recent_metrics(100);
799        let active_alerts = self.training_monitor.get_active_alerts().to_vec();
800        let model_comparison = self.model_comparator.compare_models();
801        let hyperparameter_recommendations = self.hyperparameter_explorer.get_recommendations();
802
803        DashboardSnapshot {
804            timestamp: SystemTime::now(),
805            training_summary,
806            recent_metrics,
807            active_alerts,
808            model_comparison,
809            hyperparameter_recommendations,
810            dashboard_state: DashboardState {
811                active_session_id: self.dashboard_state.active_session_id,
812                refresh_rate_ms: self.dashboard_state.refresh_rate_ms,
813                auto_alerts: self.dashboard_state.auto_alerts,
814                display_mode: self.dashboard_state.display_mode.clone(),
815            },
816        }
817    }
818
819    /// Export dashboard data to file
820    pub async fn export_dashboard_data(&self, path: &str) -> Result<()> {
821        let snapshot = self.get_dashboard_snapshot();
822        let json = serde_json::to_string_pretty(&snapshot)?;
823        tokio::fs::write(path, json).await?;
824        Ok(())
825    }
826
827    fn broadcast_update(&self, _metrics: DashboardMetrics) {
828        // In a real implementation, this would send updates to WebSocket clients
829        tracing::debug!("Broadcasting dashboard update to connected clients");
830    }
831}
832
833// Supporting data structures
834
835#[derive(Debug, Clone, Serialize, Deserialize)]
836pub struct TrainingSummary {
837    pub total_duration: Duration,
838    pub total_steps: usize,
839    pub avg_loss: Option<f64>,
840    pub best_accuracy: Option<f64>,
841    pub avg_tokens_per_second: Option<f64>,
842    pub training_stability: TrainingStability,
843    pub active_alerts_count: usize,
844    pub convergence_status: ConvergenceStatus,
845}
846
847#[derive(Debug, Clone, Serialize, Deserialize)]
848pub enum TrainingStability {
849    Stable,
850    Moderate,
851    Unstable,
852    Insufficient,
853}
854
855#[derive(Debug, Clone, Serialize, Deserialize)]
856pub enum ConvergenceStatus {
857    TooEarly,
858    Converging,
859    Converged,
860    Diverging,
861}
862
863#[derive(Debug, Serialize, Deserialize)]
864pub struct ModelComparisonReport {
865    pub comparisons: Vec<ModelComparison>,
866    pub best_model: Option<String>,
867    pub ranking: Vec<ModelRanking>,
868    pub comparison_config: ComparisonConfig,
869}
870
871#[derive(Debug, Serialize, Deserialize)]
872pub struct ModelComparison {
873    pub model_a_id: String,
874    pub model_b_id: String,
875    pub performance_difference: f64,
876    pub efficiency_difference: f64,
877    pub statistical_significance: bool,
878    pub recommendation: String,
879}
880
881#[derive(Debug, Serialize, Deserialize)]
882pub struct ModelRanking {
883    pub model_id: String,
884    pub model_name: String,
885    pub score: f64,
886    pub rank: usize,
887}
888
889#[derive(Debug, Serialize, Deserialize)]
890pub struct HyperparameterRecommendations {
891    pub best_experiments: Vec<String>,
892    pub parameter_importance: HashMap<String, f64>,
893    pub suggested_ranges: HashMap<String, (f64, f64)>,
894    pub next_experiments: Vec<HashMap<String, HyperparameterValue>>,
895    pub total_experiments: usize,
896}
897
/// Point-in-time capture of everything the dashboard displays.
#[derive(Debug, Serialize, Deserialize)]
pub struct DashboardSnapshot {
    /// When the snapshot was taken.
    pub timestamp: SystemTime,
    /// Aggregate training statistics at snapshot time.
    pub training_summary: TrainingSummary,
    /// Most recent metric samples; how many are included is decided by the producer.
    pub recent_metrics: Vec<DashboardMetrics>,
    /// Alerts included in the snapshot (presumably those active at snapshot time).
    pub active_alerts: Vec<TrainingAlert>,
    /// Model comparison results at snapshot time.
    pub model_comparison: ModelComparisonReport,
    /// Hyperparameter suggestions at snapshot time.
    pub hyperparameter_recommendations: HyperparameterRecommendations,
    /// UI/dashboard state captured alongside the data.
    pub dashboard_state: DashboardState,
}
908
/// Dashboard report for integration with main debug system
///
/// Built by `InteractiveDashboard::generate_report`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DashboardReport {
    /// Copied from the training summary's `total_duration`.
    pub session_duration: Duration,
    /// Number of entries in the training monitor's metrics history.
    pub total_metrics_recorded: usize,
    /// Number of alerts *currently active* on the training monitor.
    /// This is not a cumulative count of every alert ever raised.
    pub alerts_triggered: usize,
    /// Number of models registered with the model comparator.
    pub models_compared: usize,
    /// Number of experiments tracked by the hyperparameter explorer.
    pub experiments_tracked: usize,
    /// Full training summary the report was built from.
    pub performance_summary: TrainingSummary,
    /// Short human-readable observations about the session.
    pub key_insights: Vec<String>,
    /// Human-readable next steps; never empty when produced by `generate_report`.
    pub recommendations: Vec<String>,
}
921
922impl InteractiveDashboard {
923    /// Generate comprehensive dashboard report
924    pub async fn generate_report(&self) -> Result<DashboardReport> {
925        let training_summary = self.training_monitor.generate_training_summary();
926        let total_metrics = self.training_monitor.metrics_history.len();
927        let alerts_count = self.training_monitor.active_alerts.len();
928        let models_count = self.model_comparator.models.len();
929        let experiments_count = self.hyperparameter_explorer.experiments.len();
930
931        let key_insights = self.generate_key_insights();
932        let recommendations = self.generate_recommendations();
933
934        Ok(DashboardReport {
935            session_duration: training_summary.total_duration,
936            total_metrics_recorded: total_metrics,
937            alerts_triggered: alerts_count,
938            models_compared: models_count,
939            experiments_tracked: experiments_count,
940            performance_summary: training_summary,
941            key_insights,
942            recommendations,
943        })
944    }
945
946    fn generate_key_insights(&self) -> Vec<String> {
947        let mut insights = Vec::new();
948
949        // Training stability insights
950        match self.training_monitor.generate_training_summary().training_stability {
951            TrainingStability::Stable => insights.push("Training is proceeding stably".to_string()),
952            TrainingStability::Unstable => insights.push(
953                "Training shows high variance - consider adjusting hyperparameters".to_string(),
954            ),
955            _ => {},
956        }
957
958        // Model comparison insights
959        if self.model_comparator.models.len() > 1 {
960            let comparison = self.model_comparator.compare_models();
961            if let Some(best_model) = comparison.best_model {
962                insights.push(format!("Best performing model: {}", best_model));
963            }
964        }
965
966        // Alert insights
967        let critical_alerts = self
968            .training_monitor
969            .active_alerts
970            .iter()
971            .filter(|alert| matches!(alert.severity, AlertSeverity::Critical))
972            .count();
973
974        if critical_alerts > 0 {
975            insights.push(format!(
976                "{} critical alerts require immediate attention",
977                critical_alerts
978            ));
979        }
980
981        insights
982    }
983
984    fn generate_recommendations(&self) -> Vec<String> {
985        let mut recommendations = Vec::new();
986
987        // Based on active alerts
988        for alert in &self.training_monitor.active_alerts {
989            if matches!(alert.severity, AlertSeverity::Critical) {
990                recommendations.push(alert.suggested_action.clone());
991            }
992        }
993
994        // Based on hyperparameter exploration
995        if self.hyperparameter_explorer.experiments.len() > 5 {
996            recommendations.push(
997                "Continue hyperparameter optimization with narrowed search ranges".to_string(),
998            );
999        }
1000
1001        // Based on model comparison
1002        if self.model_comparator.models.len() > 1 {
1003            recommendations
1004                .push("Focus on the best performing model architecture for production".to_string());
1005        }
1006
1007        if recommendations.is_empty() {
1008            recommendations.push("Continue monitoring training progress".to_string());
1009        }
1010
1011        recommendations
1012    }
1013}