Skip to main content

trustformers_debug/
health_checker.rs

1//! Model and training health assessment system
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6use std::time::{Duration, SystemTime};
7
8use crate::{DashboardMetrics, DebugConfig};
9
/// Comprehensive health checker for model training.
///
/// Owns a bounded history of [`DashboardMetrics`] samples plus one sub-analyzer
/// per concern (stability, convergence, overfitting, generalization), and keeps
/// the assessments it has produced so far.
#[derive(Debug)]
pub struct HealthChecker {
    /// Copy of the configuration passed to `new`; not read by the code visible here.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Recent metric samples; new samples are appended at the back.
    metrics_history: VecDeque<DashboardMetrics>,
    /// Assessments returned by `assess_health`, in the order they were produced.
    health_assessments: Vec<HealthAssessment>,
    /// Per-metric stability windows (loss, accuracy, gradient norm, learning rate).
    stability_tracker: StabilityTracker,
    /// Estimates how likely training is to converge from loss/accuracy history.
    convergence_analyzer: ConvergenceAnalyzer,
    /// Scores the gap between training and validation metrics.
    overfitting_detector: OverfittingDetector,
    /// Scores how consistent training and held-out performance are.
    generalization_monitor: GeneralizationMonitor,
    /// Optional reference point installed through `set_baseline`.
    performance_baseline: Option<PerformanceBaseline>,
}
23
/// Overall health assessment: a snapshot produced by one call to
/// `HealthChecker::assess_health`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAssessment {
    /// Wall-clock time at which the assessment was assembled.
    pub timestamp: SystemTime,
    /// Weighted combination of the component scores and the convergence probability.
    pub overall_health_score: f64,
    /// Mean stability of loss, accuracy, gradient norm and learning rate.
    pub training_stability_index: f64,
    /// Output of `ConvergenceAnalyzer::calculate_convergence_probability`.
    pub convergence_probability: f64,
    /// Output of `OverfittingDetector::detect_overfitting`.
    pub overfitting_risk: OverfittingRisk,
    /// Output of `GeneralizationMonitor::calculate_generalization_score`.
    pub generalization_score: f64,
    /// Individual scores the overall score is derived from.
    pub component_scores: ComponentHealthScores,
    /// Coarse bucket derived from `overall_health_score`.
    pub health_status: HealthStatus,
    /// Alerts raised for this snapshot (may be empty).
    pub alerts: Vec<HealthAlert>,
    /// Recommendations derived from the alerts and the overall score (may be empty).
    pub recommendations: Vec<HealthRecommendation>,
}
38
/// Per-component health scores that feed the overall health score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentHealthScores {
    /// Stability score of the gradient-norm window.
    pub gradient_health: f64,
    /// Score derived from the trend of the most recent losses.
    pub loss_health: f64,
    /// Mean of recent average accuracy and accuracy stability.
    pub accuracy_health: f64,
    /// Score derived from tokens-per-second and GPU utilization of the latest sample.
    pub performance_health: f64,
    /// Score derived from the memory usage (in MB) of the latest sample.
    pub memory_health: f64,
    /// Training stability index (mean of the four tracked stabilities).
    pub stability_health: f64,
}
48
/// Coarse health bucket chosen from the overall health score.
///
/// The percentages are the score ranges (score * 100) used by
/// `HealthChecker::determine_health_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthStatus {
    Excellent, // score >= 0.90
    Good,      // 0.75 <= score < 0.90
    Fair,      // 0.60 <= score < 0.75
    Poor,      // 0.40 <= score < 0.60
    Critical,  // score < 0.40 (also the fallback for any other value)
}
57
/// A single alert raised during a health assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAlert {
    /// Which area of training the alert concerns.
    pub alert_type: HealthAlertType,
    /// How serious the alert is.
    pub severity: AlertSeverity,
    /// Human-readable description.
    pub message: String,
    /// Value reported with the alert. Stability and convergence alerts carry the
    /// measured value; overfitting alerts carry a fixed representative number.
    pub metric_value: f64,
    /// Threshold the alert is associated with.
    pub threshold: f64,
    /// Direction attached to the alert; the generator in this file assigns a fixed
    /// trend per alert kind rather than measuring it.
    pub trend: Trend,
}
67
/// Area of training an alert refers to.
///
/// Only `TrainingStability`, `ConvergenceIssue` and `OverfittingDetected` are
/// emitted by the alert generator visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthAlertType {
    TrainingStability,
    ConvergenceIssue,
    OverfittingDetected,
    PerformanceDegradation,
    MemoryIssue,
    GradientProblem,
}
77
/// Severity attached to a [`HealthAlert`], listed from most to least severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    Critical,
    High,
    Medium,
    Low,
}
85
/// Direction in which the metric behind a [`HealthAlert`] is moving.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Trend {
    Improving,
    Stable,
    Degrading,
    Volatile,
}
93
/// A suggested action attached to a health assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthRecommendation {
    /// Area of the training setup the action applies to.
    pub category: RecommendationCategory,
    /// Short headline.
    pub title: String,
    /// Longer explanation of what to do.
    pub description: String,
    /// How soon the action should be taken.
    pub urgency: RecommendationUrgency,
    /// Estimated benefit; the generator in this file uses values between 0.2 and 0.4.
    /// NOTE(review): the scale is not defined here; presumably a 0-1 fraction.
    pub expected_impact: f64,
}
102
/// Area of the training setup a [`HealthRecommendation`] targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationCategory {
    Training,
    Architecture,
    Hyperparameters,
    Data,
    Performance,
}
111
/// How soon a [`HealthRecommendation`] should be acted on, most urgent first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationUrgency {
    Immediate,
    Soon,
    Eventually,
    Optional,
}
119
/// Training stability tracking: one [`MetricStability`] window per tracked metric.
#[derive(Debug)]
pub struct StabilityTracker {
    /// Window over the reported loss values.
    loss_stability: MetricStability,
    /// Window over the reported accuracy values.
    accuracy_stability: MetricStability,
    /// Window over the reported gradient norms.
    gradient_stability: MetricStability,
    /// Window over the reported learning rates.
    learning_rate_stability: MetricStability,
    /// Nominal window size; stored but not read by the code visible here
    /// (each `MetricStability` enforces its own cap).
    #[allow(dead_code)]
    window_size: usize,
}
130
/// Rolling window over one scalar metric, together with the thresholds used to
/// score how stable that metric has recently been.
#[derive(Debug)]
pub struct MetricStability {
    /// Most recent samples, oldest first; never longer than `MAX_SAMPLES`.
    values: VecDeque<f64>,
    /// Sample variance below which the metric counts as fully stable.
    variance_threshold: f64,
    /// Stored for future use; no method in this type reads it.
    #[allow(dead_code)]
    trend_threshold: f64,
}

impl MetricStability {
    /// Maximum number of samples retained in the window.
    const MAX_SAMPLES: usize = 50;
    /// Fewer samples than this and the overall score is the neutral value.
    const MIN_SAMPLES: usize = 5;
    /// Fewer samples than this and the trend score is the neutral value.
    const MIN_TREND_SAMPLES: usize = 10;
    /// Score reported when there is not enough data to judge.
    const NEUTRAL_SCORE: f64 = 0.5;

    /// Creates an empty window with the given thresholds.
    pub fn new(variance_threshold: f64, trend_threshold: f64) -> Self {
        Self {
            values: VecDeque::new(),
            variance_threshold,
            trend_threshold,
        }
    }

    /// Appends `value`, evicting the oldest sample once the window is full.
    pub fn update(&mut self, value: f64) {
        self.values.push_back(value);
        if self.values.len() > Self::MAX_SAMPLES {
            self.values.pop_front();
        }
    }

    /// Returns a stability score in `[0, 1]` (higher is more stable).
    ///
    /// * `0.5` while fewer than five samples are available.
    /// * `0.0` if the window contains a NaN or infinite sample. Without this
    ///   check a NaN variance would be swallowed by `f64::min` and NaN slopes
    ///   would never register as sign changes, so a blown-up metric would be
    ///   reported as perfectly stable.
    /// * Otherwise the mean of a variance score and a trend score.
    pub fn calculate_stability(&self) -> f64 {
        if self.values.len() < Self::MIN_SAMPLES {
            return Self::NEUTRAL_SCORE; // Insufficient data
        }

        if self.values.iter().any(|v| !v.is_finite()) {
            return 0.0;
        }

        let variance = self.calculate_variance();
        let variance_score = if !variance.is_finite() {
            // Finite samples can still overflow when squared.
            0.0
        } else if variance < self.variance_threshold {
            1.0
        } else {
            (self.variance_threshold / variance).min(1.0)
        };

        (variance_score + self.calculate_trend_stability()) / 2.0
    }

    /// Unbiased sample variance of the window (`0.0` for fewer than two samples).
    fn calculate_variance(&self) -> f64 {
        let n = self.values.len();
        if n < 2 {
            return 0.0;
        }

        let mean = self.values.iter().sum::<f64>() / n as f64;
        self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64
    }

    /// Returns `1 - (fraction of consecutive slope pairs whose sign flips)`.
    ///
    /// A slope counts as "positive" only when strictly greater than zero, so a
    /// flat step followed by a rise is counted as a flip. Returns `0.5` while
    /// fewer than ten samples are available.
    fn calculate_trend_stability(&self) -> f64 {
        let n = self.values.len();
        if n < Self::MIN_TREND_SAMPLES {
            return Self::NEUTRAL_SCORE;
        }

        // Walk every consecutive triple (a, b, c) and compare slope b-a with c-b.
        let slope_changes = self
            .values
            .iter()
            .zip(self.values.iter().skip(1))
            .zip(self.values.iter().skip(2))
            .filter(|&((a, b), c)| (b - a > 0.0) != (c - b > 0.0))
            .count();

        let change_rate = slope_changes as f64 / (n - 2) as f64;
        (1.0 - change_rate).max(0.0)
    }
}
205
/// Convergence analysis over rolling loss and accuracy histories.
#[derive(Debug)]
pub struct ConvergenceAnalyzer {
    /// Recent loss samples, oldest first (at most `2 * convergence_window`).
    loss_history: VecDeque<f64>,
    /// Recent accuracy samples, oldest first (at most `2 * convergence_window`).
    accuracy_history: VecDeque<f64>,
    /// Number of samples required before any judgement is made.
    convergence_window: usize,
    /// Minimum relative improvement (and maximum variance) used as cut-off.
    convergence_threshold: f64,
}

impl Default for ConvergenceAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

impl ConvergenceAnalyzer {
    /// Creates an analyzer with a window of 100 samples and a threshold of 0.01.
    pub fn new() -> Self {
        Self {
            loss_history: VecDeque::new(),
            accuracy_history: VecDeque::new(),
            convergence_window: 100,
            convergence_threshold: 0.01,
        }
    }

    /// Records the given samples; `None` values are skipped.
    pub fn update(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
        let capacity = self.convergence_window * 2;

        if let Some(loss) = loss {
            Self::push_bounded(&mut self.loss_history, loss, capacity);
        }
        if let Some(accuracy) = accuracy {
            Self::push_bounded(&mut self.accuracy_history, accuracy, capacity);
        }
    }

    /// Appends `value` and drops the oldest entry once `capacity` is exceeded.
    fn push_bounded(history: &mut VecDeque<f64>, value: f64, capacity: usize) {
        history.push_back(value);
        if history.len() > capacity {
            history.pop_front();
        }
    }

    /// Returns `0.7 * loss score + 0.3 * accuracy score`.
    pub fn calculate_convergence_probability(&self) -> f64 {
        let loss_convergence = self.analyze_loss_convergence();
        let accuracy_convergence = self.analyze_accuracy_convergence();

        // Weight loss convergence more heavily
        0.7 * loss_convergence + 0.3 * accuracy_convergence
    }

    /// Scores the loss history by comparing the newest half-window with the
    /// half-window before it.
    ///
    /// Returns `0.3` with insufficient data, `0.8`/`0.6` for fast/slow
    /// improvement, `0.4` for a flat low-variance loss and `0.2` otherwise.
    fn analyze_loss_convergence(&self) -> f64 {
        if self.loss_history.len() < self.convergence_window {
            return 0.3; // Insufficient data
        }

        let half = self.convergence_window / 2;
        let recent: Vec<f64> = self.loss_history.iter().rev().take(half).copied().collect();
        let earlier: Vec<f64> =
            self.loss_history.iter().rev().skip(half).take(half).copied().collect();

        if recent.is_empty() || earlier.is_empty() {
            return 0.3;
        }

        let recent_avg = recent.iter().sum::<f64>() / recent.len() as f64;
        let earlier_avg = earlier.iter().sum::<f64>() / earlier.len() as f64;

        if recent_avg < earlier_avg {
            // Normalise by the magnitude of the earlier average. Dividing by the
            // signed value would turn an improvement of a negative loss into a
            // negative rate. A zero baseline yields +inf, i.e. "good".
            let improvement_rate = (earlier_avg - recent_avg) / earlier_avg.abs();

            if improvement_rate > self.convergence_threshold {
                0.8 // Good convergence
            } else {
                0.6 // Slow convergence
            }
        } else if self.calculate_variance(&recent) < self.convergence_threshold {
            0.4 // Converged but not improving
        } else {
            0.2 // Diverging or unstable
        }
    }

    /// Scores the accuracy history from the variance of the newest half-window:
    /// `0.8` below 0.01, `0.6` below 0.05, `0.4` otherwise; `0.5` with
    /// insufficient data.
    fn analyze_accuracy_convergence(&self) -> f64 {
        if self.accuracy_history.len() < self.convergence_window {
            return 0.5;
        }

        let half = self.convergence_window / 2;
        let recent: Vec<f64> = self.accuracy_history.iter().rev().take(half).copied().collect();
        let variance = self.calculate_variance(&recent);

        // Low variance in accuracy suggests convergence
        if variance < 0.01 {
            0.8
        } else if variance < 0.05 {
            0.6
        } else {
            0.4
        }
    }

    /// Unbiased sample variance of `values` (`0.0` for fewer than two values).
    fn calculate_variance(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64
    }
}
332
/// Overfitting detection based on the gap between training and validation metrics.
#[derive(Debug)]
pub struct OverfittingDetector {
    /// Recent training losses; new samples are appended at the back.
    train_loss_history: VecDeque<f64>,
    /// Recent validation losses; new samples are appended at the back.
    val_loss_history: VecDeque<f64>,
    /// Recent training accuracies; new samples are appended at the back.
    train_accuracy_history: VecDeque<f64>,
    /// Recent validation accuracies; new samples are appended at the back.
    val_accuracy_history: VecDeque<f64>,
    /// Gap size at which a gap score saturates at 1.0.
    overfitting_threshold: f64,
}
342
/// Overfitting risk level, from no risk to severe.
///
/// `OverfittingDetector::detect_overfitting` maps its averaged score to a level
/// with the cut-offs `> 0.2` (Low), `> 0.4` (Medium), `> 0.6` (High) and
/// `> 0.8` (Severe); anything else is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OverfittingRisk {
    None,
    Low,
    Medium,
    High,
    Severe,
}
351
352impl Default for OverfittingDetector {
353    fn default() -> Self {
354        Self::new()
355    }
356}
357
358impl OverfittingDetector {
359    pub fn new() -> Self {
360        Self {
361            train_loss_history: VecDeque::new(),
362            val_loss_history: VecDeque::new(),
363            train_accuracy_history: VecDeque::new(),
364            val_accuracy_history: VecDeque::new(),
365            overfitting_threshold: 0.1,
366        }
367    }
368
369    pub fn update_train_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
370        if let Some(loss) = loss {
371            self.train_loss_history.push_back(loss);
372            if self.train_loss_history.len() > 100 {
373                self.train_loss_history.pop_front();
374            }
375        }
376
377        if let Some(accuracy) = accuracy {
378            self.train_accuracy_history.push_back(accuracy);
379            if self.train_accuracy_history.len() > 100 {
380                self.train_accuracy_history.pop_front();
381            }
382        }
383    }
384
385    pub fn update_validation_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
386        if let Some(loss) = loss {
387            self.val_loss_history.push_back(loss);
388            if self.val_loss_history.len() > 100 {
389                self.val_loss_history.pop_front();
390            }
391        }
392
393        if let Some(accuracy) = accuracy {
394            self.val_accuracy_history.push_back(accuracy);
395            if self.val_accuracy_history.len() > 100 {
396                self.val_accuracy_history.pop_front();
397            }
398        }
399    }
400
401    pub fn detect_overfitting(&self) -> OverfittingRisk {
402        let loss_gap = self.calculate_loss_gap();
403        let accuracy_gap = self.calculate_accuracy_gap();
404        let trend_analysis = self.analyze_overfitting_trend();
405
406        let overfitting_score = (loss_gap + accuracy_gap + trend_analysis) / 3.0;
407
408        match overfitting_score {
409            score if score > 0.8 => OverfittingRisk::Severe,
410            score if score > 0.6 => OverfittingRisk::High,
411            score if score > 0.4 => OverfittingRisk::Medium,
412            score if score > 0.2 => OverfittingRisk::Low,
413            _ => OverfittingRisk::None,
414        }
415    }
416
417    fn calculate_loss_gap(&self) -> f64 {
418        if self.train_loss_history.len() < 10 || self.val_loss_history.len() < 10 {
419            return 0.0;
420        }
421
422        let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
423
424        let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
425
426        if recent_train_loss < recent_val_loss {
427            let gap = (recent_val_loss - recent_train_loss) / recent_train_loss;
428            (gap / self.overfitting_threshold).min(1.0)
429        } else {
430            0.0
431        }
432    }
433
434    fn calculate_accuracy_gap(&self) -> f64 {
435        if self.train_accuracy_history.len() < 10 || self.val_accuracy_history.len() < 10 {
436            return 0.0;
437        }
438
439        let recent_train_acc =
440            self.train_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
441
442        let recent_val_acc = self.val_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
443
444        if recent_train_acc > recent_val_acc {
445            let gap = recent_train_acc - recent_val_acc;
446            (gap / self.overfitting_threshold).min(1.0)
447        } else {
448            0.0
449        }
450    }
451
452    fn analyze_overfitting_trend(&self) -> f64 {
453        // Analyze if the gap between train and validation is increasing
454        if self.train_loss_history.len() < 20 || self.val_loss_history.len() < 20 {
455            return 0.0;
456        }
457
458        let early_train_loss = self.train_loss_history.iter().take(10).sum::<f64>() / 10.0;
459
460        let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
461
462        let early_val_loss = self.val_loss_history.iter().take(10).sum::<f64>() / 10.0;
463
464        let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
465
466        let early_gap = (early_val_loss - early_train_loss).max(0.0);
467        let recent_gap = (recent_val_loss - recent_train_loss).max(0.0);
468
469        if recent_gap > early_gap && early_gap > 0.0 {
470            ((recent_gap - early_gap) / early_gap).min(1.0)
471        } else {
472            0.0
473        }
474    }
475}
476
/// Generalization monitoring: combines train/held-out consistency, a model
/// complexity penalty and cross-validation consistency into one score.
#[derive(Debug)]
pub struct GeneralizationMonitor {
    /// Cross-validation scores; empty unless populated elsewhere.
    cross_validation_scores: Vec<f64>,
    /// Latest held-out performance, if one was ever reported.
    holdout_performance: Option<f64>,
    /// Latest training performance, if one was ever reported.
    train_performance: Option<f64>,
    /// Model size versus data size.
    complexity_metrics: ComplexityMetrics,
}

/// Size figures used for the complexity penalty.
#[derive(Debug)]
pub struct ComplexityMetrics {
    /// Number of model parameters.
    parameter_count: usize,
    /// Stored but not read by the code in this type.
    #[allow(dead_code)]
    effective_capacity: f64,
    /// Number of data samples; zero disables the penalty.
    data_size: usize,
}

impl Default for GeneralizationMonitor {
    fn default() -> Self {
        Self::new()
    }
}

impl GeneralizationMonitor {
    /// Creates a monitor with no recorded performance and zeroed complexity figures.
    pub fn new() -> Self {
        let complexity_metrics = ComplexityMetrics {
            parameter_count: 0,
            effective_capacity: 0.0,
            data_size: 0,
        };

        Self {
            cross_validation_scores: Vec::new(),
            holdout_performance: None,
            train_performance: None,
            complexity_metrics,
        }
    }

    /// Stores the training performance; the held-out performance is only
    /// overwritten when `val_perf` is `Some`.
    pub fn update_performance(&mut self, train_perf: f64, val_perf: Option<f64>) {
        self.train_performance = Some(train_perf);
        if val_perf.is_some() {
            self.holdout_performance = val_perf;
        }
    }

    /// Mean of performance consistency, cross-validation consistency and
    /// `1 - complexity penalty`.
    pub fn calculate_generalization_score(&self) -> f64 {
        let consistency = self.calculate_performance_consistency();
        let penalty = self.calculate_complexity_penalty();
        let cv = self.calculate_cv_consistency();

        (consistency + cv + (1.0 - penalty)) / 3.0
    }

    /// `1 - |train - held-out|`, clamped to `[0, 1]`; `0.5` when either is unknown.
    fn calculate_performance_consistency(&self) -> f64 {
        self.train_performance
            .zip(self.holdout_performance)
            .map_or(0.5, |(train, val)| {
                let gap = (train - val).abs();
                (1.0 - gap.min(1.0)).max(0.0)
            })
    }

    /// Step penalty from the parameters-per-sample ratio: `0.8` above 1.0,
    /// `0.4` above 0.1, `0.1` otherwise; `0.0` when the data size is zero.
    fn calculate_complexity_penalty(&self) -> f64 {
        let metrics = &self.complexity_metrics;
        if metrics.data_size == 0 {
            return 0.0;
        }

        let params_per_sample = metrics.parameter_count as f64 / metrics.data_size as f64;
        match params_per_sample {
            ratio if ratio > 1.0 => 0.8, // High complexity
            ratio if ratio > 0.1 => 0.4, // Medium complexity
            _ => 0.1,                    // Low complexity
        }
    }

    /// `1 - sample standard deviation` of the cross-validation scores, clamped
    /// to `[0, 1]`; `0.5` with fewer than three scores.
    fn calculate_cv_consistency(&self) -> f64 {
        let scores = &self.cross_validation_scores;
        let count = scores.len();
        if count < 3 {
            return 0.5;
        }

        let mean = scores.iter().sum::<f64>() / count as f64;
        let squared_deviation = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>();
        let std_dev = (squared_deviation / (count - 1) as f64).sqrt();

        (1.0 - std_dev.min(1.0)).max(0.0)
    }
}
570
/// Performance baseline for comparison with later measurements.
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
    /// Loss recorded for the baseline.
    pub baseline_loss: f64,
    /// Accuracy recorded for the baseline.
    pub baseline_accuracy: f64,
    /// Training time recorded for the baseline.
    pub baseline_training_time: Duration,
    /// Memory usage recorded for the baseline.
    /// NOTE(review): unit is not stated here; presumably MB to match
    /// `DashboardMetrics::memory_usage_mb` — confirm.
    pub baseline_memory_usage: f64,
    /// When the baseline was taken.
    pub established_at: SystemTime,
}
580
581impl HealthChecker {
582    /// Create new health checker
583    pub fn new(config: &DebugConfig) -> Self {
584        Self {
585            config: config.clone(),
586            metrics_history: VecDeque::new(),
587            health_assessments: Vec::new(),
588            stability_tracker: StabilityTracker {
589                loss_stability: MetricStability::new(0.1, 0.05),
590                accuracy_stability: MetricStability::new(0.01, 0.02),
591                gradient_stability: MetricStability::new(1.0, 0.1),
592                learning_rate_stability: MetricStability::new(0.0001, 0.001),
593                window_size: 50,
594            },
595            convergence_analyzer: ConvergenceAnalyzer::new(),
596            overfitting_detector: OverfittingDetector::new(),
597            generalization_monitor: GeneralizationMonitor::new(),
598            performance_baseline: None,
599        }
600    }
601
602    /// Update health checker with new metrics
603    pub fn update(&mut self, metrics: DashboardMetrics) {
604        self.metrics_history.push_back(metrics.clone());
605
606        // Keep only recent metrics to prevent unbounded growth
607        if self.metrics_history.len() > 1000 {
608            self.metrics_history.pop_front();
609        }
610
611        // Update stability tracking
612        if let Some(loss) = metrics.loss {
613            self.stability_tracker.loss_stability.update(loss);
614        }
615        if let Some(accuracy) = metrics.accuracy {
616            self.stability_tracker.accuracy_stability.update(accuracy);
617        }
618        if let Some(grad_norm) = metrics.gradient_norm {
619            self.stability_tracker.gradient_stability.update(grad_norm);
620        }
621        if let Some(lr) = metrics.learning_rate {
622            self.stability_tracker.learning_rate_stability.update(lr);
623        }
624
625        // Update convergence analysis
626        self.convergence_analyzer.update(metrics.loss, metrics.accuracy);
627
628        // Update overfitting detection (assuming these are training metrics)
629        self.overfitting_detector.update_train_metrics(metrics.loss, metrics.accuracy);
630
631        // Update generalization monitoring
632        if let (Some(accuracy), Some(loss)) = (metrics.accuracy, metrics.loss) {
633            self.generalization_monitor.update_performance(accuracy, Some(1.0 - loss));
634        }
635    }
636
637    /// Perform comprehensive health assessment
638    pub fn assess_health(&mut self) -> Result<HealthAssessment> {
639        let overall_health_score = self.calculate_overall_health_score();
640        let training_stability_index = self.calculate_training_stability_index();
641        let convergence_probability = self.convergence_analyzer.calculate_convergence_probability();
642        let overfitting_risk = self.overfitting_detector.detect_overfitting();
643        let generalization_score = self.generalization_monitor.calculate_generalization_score();
644
645        let component_scores = self.calculate_component_scores();
646        let health_status = self.determine_health_status(overall_health_score);
647        let alerts = self.generate_health_alerts();
648        let recommendations = self.generate_health_recommendations(&alerts);
649
650        let assessment = HealthAssessment {
651            timestamp: SystemTime::now(),
652            overall_health_score,
653            training_stability_index,
654            convergence_probability,
655            overfitting_risk,
656            generalization_score,
657            component_scores,
658            health_status,
659            alerts,
660            recommendations,
661        };
662
663        self.health_assessments.push(assessment.clone());
664
665        // Keep only recent assessments
666        if self.health_assessments.len() > 100 {
667            self.health_assessments.drain(0..50);
668        }
669
670        Ok(assessment)
671    }
672
673    fn calculate_overall_health_score(&self) -> f64 {
674        let component_scores = self.calculate_component_scores();
675
676        // Weighted average of component scores
677        let weights = [
678            ("stability", 0.25),
679            ("convergence", 0.20),
680            ("gradient", 0.15),
681            ("loss", 0.15),
682            ("accuracy", 0.10),
683            ("performance", 0.10),
684            ("memory", 0.05),
685        ];
686
687        let mut weighted_sum = 0.0;
688        weighted_sum += weights[0].1 * component_scores.stability_health;
689        weighted_sum +=
690            weights[1].1 * self.convergence_analyzer.calculate_convergence_probability();
691        weighted_sum += weights[2].1 * component_scores.gradient_health;
692        weighted_sum += weights[3].1 * component_scores.loss_health;
693        weighted_sum += weights[4].1 * component_scores.accuracy_health;
694        weighted_sum += weights[5].1 * component_scores.performance_health;
695        weighted_sum += weights[6].1 * component_scores.memory_health;
696
697        weighted_sum
698    }
699
700    fn calculate_training_stability_index(&self) -> f64 {
701        let loss_stability = self.stability_tracker.loss_stability.calculate_stability();
702        let accuracy_stability = self.stability_tracker.accuracy_stability.calculate_stability();
703        let gradient_stability = self.stability_tracker.gradient_stability.calculate_stability();
704        let lr_stability = self.stability_tracker.learning_rate_stability.calculate_stability();
705
706        (loss_stability + accuracy_stability + gradient_stability + lr_stability) / 4.0
707    }
708
709    fn calculate_component_scores(&self) -> ComponentHealthScores {
710        ComponentHealthScores {
711            gradient_health: self.stability_tracker.gradient_stability.calculate_stability(),
712            loss_health: self.calculate_loss_health(),
713            accuracy_health: self.calculate_accuracy_health(),
714            performance_health: self.calculate_performance_health(),
715            memory_health: self.calculate_memory_health(),
716            stability_health: self.calculate_training_stability_index(),
717        }
718    }
719
720    fn calculate_loss_health(&self) -> f64 {
721        if self.metrics_history.len() < 10 {
722            return 0.5;
723        }
724
725        let recent_losses: Vec<f64> =
726            self.metrics_history.iter().rev().take(10).filter_map(|m| m.loss).collect();
727
728        if recent_losses.is_empty() {
729            return 0.5;
730        }
731
732        // Check if loss is generally decreasing
733        let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
734            / (recent_losses.len() / 2) as f64;
735        let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
736            / (recent_losses.len() - recent_losses.len() / 2) as f64;
737
738        if second_half_avg < first_half_avg {
739            0.8 // Loss decreasing
740        } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.05 {
741            0.6 // Loss stable
742        } else {
743            0.3 // Loss increasing
744        }
745    }
746
747    fn calculate_accuracy_health(&self) -> f64 {
748        if self.metrics_history.len() < 10 {
749            return 0.5;
750        }
751
752        let recent_accuracies: Vec<f64> =
753            self.metrics_history.iter().rev().take(10).filter_map(|m| m.accuracy).collect();
754
755        if recent_accuracies.is_empty() {
756            return 0.5;
757        }
758
759        let avg_accuracy = recent_accuracies.iter().sum::<f64>() / recent_accuracies.len() as f64;
760
761        // Score based on absolute accuracy and stability
762        let accuracy_score = avg_accuracy; // Assuming accuracy is 0-1
763        let stability_score = self.stability_tracker.accuracy_stability.calculate_stability();
764
765        (accuracy_score + stability_score) / 2.0
766    }
767
768    fn calculate_performance_health(&self) -> f64 {
769        // Check tokens per second and GPU utilization
770        if let Some(last_metrics) = self.metrics_history.back() {
771            let mut score = 0.0;
772            let mut components = 0;
773
774            if let Some(tps) = last_metrics.tokens_per_second {
775                score += if tps > 100.0 { 0.8 } else { tps / 125.0 };
776                components += 1;
777            }
778
779            if let Some(gpu_util) = last_metrics.gpu_utilization {
780                score += gpu_util;
781                components += 1;
782            }
783
784            if components > 0 {
785                score / components as f64
786            } else {
787                0.5
788            }
789        } else {
790            0.5
791        }
792    }
793
794    fn calculate_memory_health(&self) -> f64 {
795        if let Some(last_metrics) = self.metrics_history.back() {
796            let memory_usage = last_metrics.memory_usage_mb;
797
798            // Assume 8GB as reasonable upper limit
799
800            if memory_usage < 4096.0 {
801                0.9
802            } else if memory_usage < 6144.0 {
803                0.7
804            } else if memory_usage < 8192.0 {
805                0.5
806            } else {
807                0.2
808            }
809        } else {
810            0.5
811        }
812    }
813
814    fn determine_health_status(&self, score: f64) -> HealthStatus {
815        match score {
816            s if s >= 0.9 => HealthStatus::Excellent,
817            s if s >= 0.75 => HealthStatus::Good,
818            s if s >= 0.6 => HealthStatus::Fair,
819            s if s >= 0.4 => HealthStatus::Poor,
820            _ => HealthStatus::Critical,
821        }
822    }
823
824    fn generate_health_alerts(&self) -> Vec<HealthAlert> {
825        let mut alerts = Vec::new();
826
827        // Training stability alerts
828        let stability_index = self.calculate_training_stability_index();
829        if stability_index < 0.3 {
830            alerts.push(HealthAlert {
831                alert_type: HealthAlertType::TrainingStability,
832                severity: AlertSeverity::High,
833                message: "Training is highly unstable".to_string(),
834                metric_value: stability_index,
835                threshold: 0.3,
836                trend: Trend::Degrading,
837            });
838        }
839
840        // Convergence alerts
841        let convergence_prob = self.convergence_analyzer.calculate_convergence_probability();
842        if convergence_prob < 0.2 {
843            alerts.push(HealthAlert {
844                alert_type: HealthAlertType::ConvergenceIssue,
845                severity: AlertSeverity::Medium,
846                message: "Low probability of convergence".to_string(),
847                metric_value: convergence_prob,
848                threshold: 0.2,
849                trend: Trend::Stable,
850            });
851        }
852
853        // Overfitting alerts
854        match self.overfitting_detector.detect_overfitting() {
855            OverfittingRisk::High | OverfittingRisk::Severe => {
856                alerts.push(HealthAlert {
857                    alert_type: HealthAlertType::OverfittingDetected,
858                    severity: AlertSeverity::High,
859                    message: "Significant overfitting detected".to_string(),
860                    metric_value: 0.8,
861                    threshold: 0.6,
862                    trend: Trend::Degrading,
863                });
864            },
865            OverfittingRisk::Medium => {
866                alerts.push(HealthAlert {
867                    alert_type: HealthAlertType::OverfittingDetected,
868                    severity: AlertSeverity::Medium,
869                    message: "Moderate overfitting risk".to_string(),
870                    metric_value: 0.5,
871                    threshold: 0.4,
872                    trend: Trend::Stable,
873                });
874            },
875            _ => {},
876        }
877
878        alerts
879    }
880
881    fn generate_health_recommendations(&self, alerts: &[HealthAlert]) -> Vec<HealthRecommendation> {
882        let mut recommendations = Vec::new();
883
884        for alert in alerts {
885            match alert.alert_type {
886                HealthAlertType::TrainingStability => {
887                    recommendations.push(HealthRecommendation {
888                        category: RecommendationCategory::Training,
889                        title: "Improve Training Stability".to_string(),
890                        description:
891                            "Reduce learning rate or increase batch size to stabilize training"
892                                .to_string(),
893                        urgency: RecommendationUrgency::Soon,
894                        expected_impact: 0.3,
895                    });
896                },
897                HealthAlertType::ConvergenceIssue => {
898                    recommendations.push(HealthRecommendation {
899                        category: RecommendationCategory::Hyperparameters,
900                        title: "Adjust Learning Rate Schedule".to_string(),
901                        description:
902                            "Implement learning rate scheduling or adjust optimizer settings"
903                                .to_string(),
904                        urgency: RecommendationUrgency::Eventually,
905                        expected_impact: 0.2,
906                    });
907                },
908                HealthAlertType::OverfittingDetected => {
909                    recommendations.push(HealthRecommendation {
910                        category: RecommendationCategory::Training,
911                        title: "Add Regularization".to_string(),
912                        description: "Implement dropout, weight decay, or early stopping to reduce overfitting".to_string(),
913                        urgency: RecommendationUrgency::Soon,
914                        expected_impact: 0.25,
915                    });
916                },
917                _ => {},
918            }
919        }
920
921        // Add general recommendations based on overall health
922        let overall_score = self.calculate_overall_health_score();
923        if overall_score < 0.6 {
924            recommendations.push(HealthRecommendation {
925                category: RecommendationCategory::Training,
926                title: "Comprehensive Training Review".to_string(),
927                description: "Review entire training setup including data, model architecture, and hyperparameters".to_string(),
928                urgency: RecommendationUrgency::Immediate,
929                expected_impact: 0.4,
930            });
931        }
932
933        recommendations
934    }
935
936    /// Set performance baseline for comparison
937    pub fn set_baseline(&mut self, baseline: PerformanceBaseline) {
938        self.performance_baseline = Some(baseline);
939    }
940
941    /// Get health assessment history
942    pub fn get_health_history(&self) -> &[HealthAssessment] {
943        &self.health_assessments
944    }
945
946    /// Quick health check for simplified interface
947    pub async fn quick_health_check(&self) -> Result<crate::QuickHealthSummary> {
948        let score = if let Some(assessment) = self.health_assessments.last() {
949            assessment.overall_health_score * 100.0
950        } else {
951            // If no assessments yet, do a basic check
952            50.0 // Default fair score
953        };
954
955        let status = match score {
956            90.0..=100.0 => "Excellent",
957            75.0..89.9 => "Good",
958            60.0..74.9 => "Fair",
959            40.0..59.9 => "Poor",
960            _ => "Critical",
961        }
962        .to_string();
963
964        let mut recommendations = Vec::new();
965        if score < 60.0 {
966            recommendations.push("Review training configuration and data quality".to_string());
967        }
968        if score < 40.0 {
969            recommendations
970                .push("Consider adjusting learning rate and model architecture".to_string());
971        }
972        if score < 80.0 {
973            recommendations.push("Monitor training stability and convergence".to_string());
974        }
975
976        Ok(crate::QuickHealthSummary {
977            score,
978            status,
979            recommendations,
980        })
981    }
982
983    /// Generate health report
984    pub async fn generate_report(&self) -> Result<HealthReport> {
985        let current_assessment = if let Some(assessment) = self.health_assessments.last() {
986            assessment.clone()
987        } else {
988            return Ok(HealthReport::default());
989        };
990
991        let health_trends = self.analyze_health_trends();
992        let risk_assessment = self.assess_risks();
993        let improvement_suggestions = self.generate_improvement_suggestions();
994
995        Ok(HealthReport {
996            current_health: current_assessment,
997            health_trends,
998            risk_assessment,
999            improvement_suggestions,
1000            baseline_comparison: self.compare_with_baseline(),
1001            summary: self.generate_health_summary(),
1002        })
1003    }
1004
1005    fn analyze_health_trends(&self) -> HealthTrends {
1006        if self.health_assessments.len() < 5 {
1007            return HealthTrends::default();
1008        }
1009
1010        let recent_scores: Vec<f64> = self
1011            .health_assessments
1012            .iter()
1013            .rev()
1014            .take(10)
1015            .map(|a| a.overall_health_score)
1016            .collect();
1017
1018        let first_half_avg = recent_scores[recent_scores.len() / 2..].iter().sum::<f64>()
1019            / (recent_scores.len() - recent_scores.len() / 2) as f64;
1020        let second_half_avg = recent_scores[..recent_scores.len() / 2].iter().sum::<f64>()
1021            / (recent_scores.len() / 2) as f64;
1022
1023        let trend = if second_half_avg > first_half_avg * 1.05 {
1024            Trend::Improving
1025        } else if second_half_avg < first_half_avg * 0.95 {
1026            Trend::Degrading
1027        } else {
1028            Trend::Stable
1029        };
1030
1031        HealthTrends {
1032            overall_trend: trend,
1033            stability_trend: Trend::Stable, // Simplified
1034            convergence_trend: Trend::Stable,
1035            overfitting_trend: Trend::Stable,
1036        }
1037    }
1038
1039    fn assess_risks(&self) -> Vec<HealthRisk> {
1040        let mut risks = Vec::new();
1041
1042        if let Some(current) = self.health_assessments.last() {
1043            if current.overall_health_score < 0.4 {
1044                risks.push(HealthRisk {
1045                    risk_type: "Poor Overall Health".to_string(),
1046                    probability: 0.9,
1047                    impact: 0.8,
1048                    description: "Model training is in poor health and may fail".to_string(),
1049                });
1050            }
1051
1052            match current.overfitting_risk {
1053                OverfittingRisk::High | OverfittingRisk::Severe => {
1054                    risks.push(HealthRisk {
1055                        risk_type: "Overfitting".to_string(),
1056                        probability: 0.8,
1057                        impact: 0.6,
1058                        description: "Model is likely overfitting and will generalize poorly"
1059                            .to_string(),
1060                    });
1061                },
1062                _ => {},
1063            }
1064
1065            if current.convergence_probability < 0.3 {
1066                risks.push(HealthRisk {
1067                    risk_type: "Training Failure".to_string(),
1068                    probability: 0.7,
1069                    impact: 0.9,
1070                    description: "Training may not converge to a useful solution".to_string(),
1071                });
1072            }
1073        }
1074
1075        risks
1076    }
1077
1078    fn generate_improvement_suggestions(&self) -> Vec<ImprovementSuggestion> {
1079        let mut suggestions = Vec::new();
1080
1081        if let Some(current) = self.health_assessments.last() {
1082            if current.component_scores.stability_health < 0.5 {
1083                suggestions.push(ImprovementSuggestion {
1084                    area: "Training Stability".to_string(),
1085                    suggestion: "Reduce learning rate and increase batch size".to_string(),
1086                    expected_improvement: 0.3,
1087                    implementation_effort: "Low".to_string(),
1088                });
1089            }
1090
1091            if current.convergence_probability < 0.5 {
1092                suggestions.push(ImprovementSuggestion {
1093                    area: "Convergence".to_string(),
1094                    suggestion: "Implement learning rate scheduling and gradient clipping"
1095                        .to_string(),
1096                    expected_improvement: 0.25,
1097                    implementation_effort: "Medium".to_string(),
1098                });
1099            }
1100
1101            match current.overfitting_risk {
1102                OverfittingRisk::Medium | OverfittingRisk::High | OverfittingRisk::Severe => {
1103                    suggestions.push(ImprovementSuggestion {
1104                        area: "Overfitting Prevention".to_string(),
1105                        suggestion: "Add dropout layers, implement early stopping, or increase training data".to_string(),
1106                        expected_improvement: 0.4,
1107                        implementation_effort: "Medium".to_string(),
1108                    });
1109                },
1110                _ => {},
1111            }
1112        }
1113
1114        suggestions
1115    }
1116
1117    fn compare_with_baseline(&self) -> Option<BaselineComparison> {
1118        if let (Some(_baseline), Some(current)) =
1119            (&self.performance_baseline, self.health_assessments.last())
1120        {
1121            Some(BaselineComparison {
1122                health_score_change: current.overall_health_score - 0.8, // Simplified baseline score
1123                stability_change: current.training_stability_index - 0.7,
1124                convergence_change: current.convergence_probability - 0.6,
1125                improvement_percentage: ((current.overall_health_score - 0.8) / 0.8 * 100.0)
1126                    .max(-100.0),
1127            })
1128        } else {
1129            None
1130        }
1131    }
1132
1133    fn generate_health_summary(&self) -> String {
1134        if let Some(current) = self.health_assessments.last() {
1135            match current.health_status {
1136                HealthStatus::Excellent => "Training is in excellent health with stable convergence and no significant issues detected.".to_string(),
1137                HealthStatus::Good => "Training is proceeding well with minor optimization opportunities.".to_string(),
1138                HealthStatus::Fair => "Training shows some concerning patterns that should be addressed.".to_string(),
1139                HealthStatus::Poor => "Training has significant issues requiring immediate attention.".to_string(),
1140                HealthStatus::Critical => "Training is in critical condition and may fail without intervention.".to_string(),
1141            }
1142        } else {
1143            "Insufficient data for health assessment.".to_string()
1144        }
1145    }
1146}
1147
1148// Report structures
1149
1150#[derive(Debug, Serialize, Deserialize)]
1151pub struct HealthReport {
1152    pub current_health: HealthAssessment,
1153    pub health_trends: HealthTrends,
1154    pub risk_assessment: Vec<HealthRisk>,
1155    pub improvement_suggestions: Vec<ImprovementSuggestion>,
1156    pub baseline_comparison: Option<BaselineComparison>,
1157    pub summary: String,
1158}
1159
1160impl Default for HealthReport {
1161    fn default() -> Self {
1162        Self {
1163            current_health: HealthAssessment {
1164                timestamp: SystemTime::now(),
1165                overall_health_score: 0.5,
1166                training_stability_index: 0.5,
1167                convergence_probability: 0.5,
1168                overfitting_risk: OverfittingRisk::None,
1169                generalization_score: 0.5,
1170                component_scores: ComponentHealthScores {
1171                    gradient_health: 0.5,
1172                    loss_health: 0.5,
1173                    accuracy_health: 0.5,
1174                    performance_health: 0.5,
1175                    memory_health: 0.5,
1176                    stability_health: 0.5,
1177                },
1178                health_status: HealthStatus::Fair,
1179                alerts: Vec::new(),
1180                recommendations: Vec::new(),
1181            },
1182            health_trends: HealthTrends::default(),
1183            risk_assessment: Vec::new(),
1184            improvement_suggestions: Vec::new(),
1185            baseline_comparison: None,
1186            summary: "No health data available yet.".to_string(),
1187        }
1188    }
1189}
1190
1191#[derive(Debug, Clone, Serialize, Deserialize)]
1192pub struct HealthTrends {
1193    pub overall_trend: Trend,
1194    pub stability_trend: Trend,
1195    pub convergence_trend: Trend,
1196    pub overfitting_trend: Trend,
1197}
1198
1199impl Default for HealthTrends {
1200    fn default() -> Self {
1201        Self {
1202            overall_trend: Trend::Stable,
1203            stability_trend: Trend::Stable,
1204            convergence_trend: Trend::Stable,
1205            overfitting_trend: Trend::Stable,
1206        }
1207    }
1208}
1209
1210#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct HealthRisk {
1212    pub risk_type: String,
1213    pub probability: f64,
1214    pub impact: f64,
1215    pub description: String,
1216}
1217
1218#[derive(Debug, Clone, Serialize, Deserialize)]
1219pub struct ImprovementSuggestion {
1220    pub area: String,
1221    pub suggestion: String,
1222    pub expected_improvement: f64,
1223    pub implementation_effort: String,
1224}
1225
1226#[derive(Debug, Clone, Serialize, Deserialize)]
1227pub struct BaselineComparison {
1228    pub health_score_change: f64,
1229    pub stability_change: f64,
1230    pub convergence_change: f64,
1231    pub improvement_percentage: f64,
1232}