Skip to main content

trustformers_debug/
health_checker.rs

1//! Model and training health assessment system
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6use std::time::{Duration, SystemTime};
7
8use crate::{DashboardMetrics, DebugConfig};
9
/// Comprehensive health checker for model training.
///
/// Owns a bounded history of [`DashboardMetrics`] snapshots together with the
/// stability, convergence, overfitting and generalization analyzers those
/// snapshots are fed into, plus the assessments produced from them.
#[derive(Debug)]
pub struct HealthChecker {
    // Stored at construction; the allow silences the unused-field warning.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Recent metric snapshots, oldest first (bounded by `update`).
    metrics_history: VecDeque<DashboardMetrics>,
    /// Previously produced assessments, oldest first (bounded by `assess_health`).
    health_assessments: Vec<HealthAssessment>,
    /// Per-series rolling stability trackers.
    stability_tracker: StabilityTracker,
    /// Loss/accuracy convergence heuristics.
    convergence_analyzer: ConvergenceAnalyzer,
    /// Train-vs-validation gap heuristics.
    overfitting_detector: OverfittingDetector,
    /// Generalization heuristics.
    generalization_monitor: GeneralizationMonitor,
    /// Optional reference point; `None` until `set_baseline` is called.
    performance_baseline: Option<PerformanceBaseline>,
}
23
/// Overall health assessment.
///
/// A point-in-time snapshot produced by `HealthChecker::assess_health`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAssessment {
    /// Wall-clock time at which the assessment was produced.
    pub timestamp: SystemTime,
    /// Weighted combination of the component scores and the convergence probability.
    pub overall_health_score: f64,
    /// Mean stability of the loss, accuracy, gradient-norm and learning-rate series.
    pub training_stability_index: f64,
    /// Heuristic likelihood that training is converging.
    pub convergence_probability: f64,
    /// Bucketed risk from comparing training and validation metrics.
    pub overfitting_risk: OverfittingRisk,
    /// Heuristic generalization score.
    pub generalization_score: f64,
    /// Individual component scores.
    pub component_scores: ComponentHealthScores,
    /// Status band derived from `overall_health_score`.
    pub health_status: HealthStatus,
    /// Alerts raised during this assessment.
    pub alerts: Vec<HealthAlert>,
    /// Recommendations derived from `alerts` and the overall score.
    pub recommendations: Vec<HealthRecommendation>,
}
38
/// Per-component health scores that feed the overall health score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentHealthScores {
    /// Stability of the gradient-norm series.
    pub gradient_health: f64,
    /// Score for the direction of the recent loss trend.
    pub loss_health: f64,
    /// Mean of recent accuracy values averaged with accuracy stability.
    pub accuracy_health: f64,
    /// Throughput / GPU-utilisation score from the latest snapshot.
    pub performance_health: f64,
    /// Memory-usage band score from the latest snapshot.
    pub memory_health: f64,
    /// Training stability index (mean of the four stability trackers).
    pub stability_health: f64,
}
48
/// Coarse health band for an overall health score.
///
/// `HealthChecker::determine_health_status` assigns these using lower bounds
/// of 0.9, 0.75, 0.6 and 0.4 on the score; the percentages below express the
/// same bands as percentages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthStatus {
    Excellent, // 90-100%
    Good,      // 75-89%
    Fair,      // 60-74%
    Poor,      // 40-59%
    Critical,  // 0-39%
}
57
/// A single alert raised by a health assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAlert {
    /// Which aspect of training the alert concerns.
    pub alert_type: HealthAlertType,
    /// How serious the alert is.
    pub severity: AlertSeverity,
    /// Human-readable description.
    pub message: String,
    /// Metric value associated with the alert. For overfitting alerts this is
    /// a fixed representative value for the risk band, not a measured score.
    pub metric_value: f64,
    /// Threshold the metric was compared against.
    pub threshold: f64,
    /// Trend label attached by the code that raised the alert.
    pub trend: Trend,
}
67
/// Category of a health alert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthAlertType {
    /// Stability of the tracked training series is low.
    TrainingStability,
    /// Training does not appear to be converging.
    ConvergenceIssue,
    /// Training and validation metrics diverge.
    OverfittingDetected,
    PerformanceDegradation,
    MemoryIssue,
    GradientProblem,
}
77
/// Severity of a health alert, declared from most to least severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    Critical,
    High,
    Medium,
    Low,
}
85
/// Direction of change attributed to a monitored metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Trend {
    Improving,
    Stable,
    Degrading,
    Volatile,
}
93
/// An actionable suggestion attached to a health assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthRecommendation {
    /// Area of the training setup the recommendation targets.
    pub category: RecommendationCategory,
    /// Short headline.
    pub title: String,
    /// Longer explanation of what to change.
    pub description: String,
    /// How soon the recommendation should be acted on.
    pub urgency: RecommendationUrgency,
    /// Estimated benefit of following the recommendation.
    /// NOTE(review): the scale is not defined here; the generator uses values
    /// between 0.2 and 0.4. Confirm the intended range.
    pub expected_impact: f64,
}
102
/// Area of the training setup that a recommendation targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationCategory {
    Training,
    Architecture,
    Hyperparameters,
    Data,
    Performance,
}
111
/// How soon a recommendation should be acted on, from most to least urgent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationUrgency {
    Immediate,
    Soon,
    Eventually,
    Optional,
}
119
/// Training stability tracking.
///
/// Holds one rolling [`MetricStability`] tracker per monitored series.
#[derive(Debug)]
pub struct StabilityTracker {
    /// Stability of the loss series.
    loss_stability: MetricStability,
    /// Stability of the accuracy series.
    accuracy_stability: MetricStability,
    /// Stability of the gradient-norm series.
    gradient_stability: MetricStability,
    /// Stability of the learning-rate series.
    learning_rate_stability: MetricStability,
    // Not consulted: each `MetricStability` keeps its own fixed 50-sample
    // window. The allow silences the unused-field warning.
    #[allow(dead_code)]
    window_size: usize,
}
130
/// Rolling stability tracker for a single scalar series.
///
/// Retains the 50 most recent values and scores them on two axes:
/// dispersion (sample variance vs. `variance_threshold`) and how often the
/// direction of change flips between consecutive steps.
#[derive(Debug)]
pub struct MetricStability {
    values: VecDeque<f64>,
    variance_threshold: f64,
    // Stored for future use; no scoring method reads it yet.
    #[allow(dead_code)]
    trend_threshold: f64,
}

impl MetricStability {
    /// Creates an empty tracker with the given thresholds.
    pub fn new(variance_threshold: f64, trend_threshold: f64) -> Self {
        Self { values: VecDeque::new(), variance_threshold, trend_threshold }
    }

    /// Records `value`, evicting the oldest sample once more than 50 are held.
    pub fn update(&mut self, value: f64) {
        self.values.push_back(value);
        while self.values.len() > 50 {
            self.values.pop_front();
        }
    }

    /// Returns a stability score: the mean of a dispersion score and a
    /// direction-consistency score.
    ///
    /// Returns the neutral value 0.5 when fewer than 5 samples are available.
    pub fn calculate_stability(&self) -> f64 {
        let enough_samples = self.values.len() >= 5;
        if !enough_samples {
            return 0.5;
        }

        let variance = self.calculate_variance();
        // Full marks below the threshold, then decays as threshold / variance.
        let dispersion_score = if variance < self.variance_threshold {
            1.0
        } else {
            f64::min(self.variance_threshold / variance, 1.0)
        };
        let direction_score = self.calculate_trend_stability();

        (dispersion_score + direction_score) / 2.0
    }

    /// Unbiased (n - 1) sample variance; 0.0 for fewer than two samples.
    fn calculate_variance(&self) -> f64 {
        let count = self.values.len();
        if count < 2 {
            return 0.0;
        }

        let mean = self.values.iter().sum::<f64>() / count as f64;
        let squared_deviations: f64 = self.values.iter().map(|&v| (v - mean).powi(2)).sum();
        squared_deviations / (count - 1) as f64
    }

    /// Fraction of consecutive steps whose direction does *not* flip.
    ///
    /// A step counts as "up" only when its delta is strictly positive, so a
    /// zero delta is grouped with "down". Returns 0.5 below 10 samples.
    fn calculate_trend_stability(&self) -> f64 {
        if self.values.len() < 10 {
            return 0.5;
        }

        let samples: Vec<f64> = self.values.iter().copied().collect();
        let direction_flips = samples
            .windows(3)
            .filter(|w| (w[1] - w[0] > 0.0) != (w[2] - w[1] > 0.0))
            .count();

        let flip_rate = direction_flips as f64 / (samples.len() - 2) as f64;
        (1.0 - flip_rate).max(0.0)
    }
}
205
/// Convergence analysis.
///
/// Keeps rolling loss and accuracy histories (each bounded to
/// `2 * convergence_window` samples) and turns them into a heuristic
/// convergence probability.
#[derive(Debug)]
pub struct ConvergenceAnalyzer {
    loss_history: VecDeque<f64>,
    accuracy_history: VecDeque<f64>,
    /// Samples required before a series is analysed; half of it is the size
    /// of each comparison window.
    convergence_window: usize,
    /// Minimum relative loss improvement treated as good convergence; also
    /// the variance cut-off for a plateaued loss.
    convergence_threshold: f64,
}

impl Default for ConvergenceAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

impl ConvergenceAnalyzer {
    /// Creates an analyzer with a 100-sample window and a 0.01 threshold.
    pub fn new() -> Self {
        Self {
            loss_history: VecDeque::new(),
            accuracy_history: VecDeque::new(),
            convergence_window: 100,
            convergence_threshold: 0.01,
        }
    }

    /// Records the loss and/or accuracy of one step; `None` values are skipped.
    ///
    /// Each history keeps at most `2 * convergence_window` of its newest samples.
    pub fn update(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
        let capacity = self.convergence_window * 2;
        if let Some(loss) = loss {
            Self::push_bounded(&mut self.loss_history, loss, capacity);
        }
        if let Some(accuracy) = accuracy {
            Self::push_bounded(&mut self.accuracy_history, accuracy, capacity);
        }
    }

    /// Combines the loss (weight 0.7) and accuracy (weight 0.3) convergence scores.
    pub fn calculate_convergence_probability(&self) -> f64 {
        let loss_convergence = self.analyze_loss_convergence();
        let accuracy_convergence = self.analyze_accuracy_convergence();

        // Weight loss convergence more heavily
        0.7 * loss_convergence + 0.3 * accuracy_convergence
    }

    /// Appends `value`, dropping the oldest entry once `capacity` is exceeded.
    fn push_bounded(history: &mut VecDeque<f64>, value: f64, capacity: usize) {
        history.push_back(value);
        if history.len() > capacity {
            history.pop_front();
        }
    }

    /// Scores the loss trend by comparing the mean of the newest half-window
    /// with the mean of the half-window before it.
    ///
    /// Returns 0.8 (relative improvement above `convergence_threshold`),
    /// 0.6 (improving slowly), 0.4 (not improving, low variance),
    /// 0.2 (not improving, noisy) or 0.3 (insufficient data).
    fn analyze_loss_convergence(&self) -> f64 {
        if self.loss_history.len() < self.convergence_window {
            return 0.3; // Insufficient data
        }

        let recent_window = self.convergence_window / 2;
        let recent_losses: Vec<f64> =
            self.loss_history.iter().rev().take(recent_window).copied().collect();
        let earlier_losses: Vec<f64> = self
            .loss_history
            .iter()
            .rev()
            .skip(recent_window)
            .take(recent_window)
            .copied()
            .collect();

        if recent_losses.is_empty() || earlier_losses.is_empty() {
            return 0.3;
        }

        let recent_avg = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
        let earlier_avg = earlier_losses.iter().sum::<f64>() / earlier_losses.len() as f64;

        if recent_avg < earlier_avg {
            // Relative improvement is measured against the *magnitude* of the
            // earlier mean. Dividing by a negative mean (negative-valued losses)
            // would flip the sign and misreport real progress as slow. The
            // numerator is strictly positive here, so a zero earlier mean gives
            // +inf, i.e. good convergence.
            let improvement_rate = (earlier_avg - recent_avg) / earlier_avg.abs();

            if improvement_rate > self.convergence_threshold {
                0.8 // Good convergence
            } else {
                0.6 // Slow convergence
            }
        } else {
            // Loss increasing or stagnant
            let variance = self.calculate_variance(&recent_losses);
            if variance < self.convergence_threshold {
                0.4 // Converged but not improving
            } else {
                0.2 // Diverging or unstable
            }
        }
    }

    /// Scores accuracy convergence from the variance of the newest half-window.
    ///
    /// Returns 0.8 (variance < 0.01), 0.6 (< 0.05), 0.4 otherwise, or 0.5 when
    /// fewer than `convergence_window` samples exist.
    fn analyze_accuracy_convergence(&self) -> f64 {
        if self.accuracy_history.len() < self.convergence_window {
            return 0.5;
        }

        let recent_window = self.convergence_window / 2;
        let recent_accuracy: Vec<f64> =
            self.accuracy_history.iter().rev().take(recent_window).copied().collect();

        let variance = self.calculate_variance(&recent_accuracy);

        // Low variance in accuracy suggests convergence
        if variance < 0.01 {
            0.8
        } else if variance < 0.05 {
            0.6
        } else {
            0.4
        }
    }

    /// Unbiased (n - 1) sample variance; 0.0 for fewer than two values.
    fn calculate_variance(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64
    }
}
332
/// Overfitting detection.
///
/// Compares rolling training and validation loss/accuracy histories; each
/// history is bounded by the update methods.
#[derive(Debug)]
pub struct OverfittingDetector {
    train_loss_history: VecDeque<f64>,
    val_loss_history: VecDeque<f64>,
    train_accuracy_history: VecDeque<f64>,
    val_accuracy_history: VecDeque<f64>,
    /// Gap size that maps to a full (1.0) gap score; `new` sets it to 0.1.
    overfitting_threshold: f64,
}
342
/// Overfitting risk level.
///
/// `OverfittingDetector::detect_overfitting` averages three indicators and
/// buckets the result: > 0.8 Severe, > 0.6 High, > 0.4 Medium, > 0.2 Low,
/// otherwise None.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OverfittingRisk {
    None,
    Low,
    Medium,
    High,
    Severe,
}
351
352impl Default for OverfittingDetector {
353    fn default() -> Self {
354        Self::new()
355    }
356}
357
358impl OverfittingDetector {
359    pub fn new() -> Self {
360        Self {
361            train_loss_history: VecDeque::new(),
362            val_loss_history: VecDeque::new(),
363            train_accuracy_history: VecDeque::new(),
364            val_accuracy_history: VecDeque::new(),
365            overfitting_threshold: 0.1,
366        }
367    }
368
369    pub fn update_train_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
370        if let Some(loss) = loss {
371            self.train_loss_history.push_back(loss);
372            if self.train_loss_history.len() > 100 {
373                self.train_loss_history.pop_front();
374            }
375        }
376
377        if let Some(accuracy) = accuracy {
378            self.train_accuracy_history.push_back(accuracy);
379            if self.train_accuracy_history.len() > 100 {
380                self.train_accuracy_history.pop_front();
381            }
382        }
383    }
384
385    pub fn update_validation_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
386        if let Some(loss) = loss {
387            self.val_loss_history.push_back(loss);
388            if self.val_loss_history.len() > 100 {
389                self.val_loss_history.pop_front();
390            }
391        }
392
393        if let Some(accuracy) = accuracy {
394            self.val_accuracy_history.push_back(accuracy);
395            if self.val_accuracy_history.len() > 100 {
396                self.val_accuracy_history.pop_front();
397            }
398        }
399    }
400
401    pub fn detect_overfitting(&self) -> OverfittingRisk {
402        let loss_gap = self.calculate_loss_gap();
403        let accuracy_gap = self.calculate_accuracy_gap();
404        let trend_analysis = self.analyze_overfitting_trend();
405
406        let overfitting_score = (loss_gap + accuracy_gap + trend_analysis) / 3.0;
407
408        match overfitting_score {
409            score if score > 0.8 => OverfittingRisk::Severe,
410            score if score > 0.6 => OverfittingRisk::High,
411            score if score > 0.4 => OverfittingRisk::Medium,
412            score if score > 0.2 => OverfittingRisk::Low,
413            _ => OverfittingRisk::None,
414        }
415    }
416
417    fn calculate_loss_gap(&self) -> f64 {
418        if self.train_loss_history.len() < 10 || self.val_loss_history.len() < 10 {
419            return 0.0;
420        }
421
422        let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
423
424        let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
425
426        if recent_train_loss < recent_val_loss {
427            let gap = (recent_val_loss - recent_train_loss) / recent_train_loss;
428            (gap / self.overfitting_threshold).min(1.0)
429        } else {
430            0.0
431        }
432    }
433
434    fn calculate_accuracy_gap(&self) -> f64 {
435        if self.train_accuracy_history.len() < 10 || self.val_accuracy_history.len() < 10 {
436            return 0.0;
437        }
438
439        let recent_train_acc =
440            self.train_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
441
442        let recent_val_acc = self.val_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
443
444        if recent_train_acc > recent_val_acc {
445            let gap = recent_train_acc - recent_val_acc;
446            (gap / self.overfitting_threshold).min(1.0)
447        } else {
448            0.0
449        }
450    }
451
452    fn analyze_overfitting_trend(&self) -> f64 {
453        // Analyze if the gap between train and validation is increasing
454        if self.train_loss_history.len() < 20 || self.val_loss_history.len() < 20 {
455            return 0.0;
456        }
457
458        let early_train_loss = self.train_loss_history.iter().take(10).sum::<f64>() / 10.0;
459
460        let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
461
462        let early_val_loss = self.val_loss_history.iter().take(10).sum::<f64>() / 10.0;
463
464        let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
465
466        let early_gap = (early_val_loss - early_train_loss).max(0.0);
467        let recent_gap = (recent_val_loss - recent_train_loss).max(0.0);
468
469        if recent_gap > early_gap && early_gap > 0.0 {
470            ((recent_gap - early_gap) / early_gap).min(1.0)
471        } else {
472            0.0
473        }
474    }
475}
476
/// Generalization monitoring.
///
/// Combines train/holdout consistency, cross-validation spread and a
/// parameters-per-sample complexity penalty into a single score.
#[derive(Debug)]
pub struct GeneralizationMonitor {
    cross_validation_scores: Vec<f64>,
    holdout_performance: Option<f64>,
    train_performance: Option<f64>,
    complexity_metrics: ComplexityMetrics,
}

/// Model-size and data-size figures used for the complexity penalty.
#[derive(Debug)]
pub struct ComplexityMetrics {
    parameter_count: usize,
    // Not read by any scoring method yet; the allow silences the warning.
    #[allow(dead_code)]
    effective_capacity: f64,
    data_size: usize,
}

impl Default for GeneralizationMonitor {
    fn default() -> Self {
        Self::new()
    }
}

impl GeneralizationMonitor {
    /// Creates a monitor with no performance readings and zeroed complexity figures.
    pub fn new() -> Self {
        let complexity_metrics =
            ComplexityMetrics { parameter_count: 0, effective_capacity: 0.0, data_size: 0 };
        Self {
            cross_validation_scores: Vec::new(),
            holdout_performance: None,
            train_performance: None,
            complexity_metrics,
        }
    }

    /// Stores the latest training performance and, when given, holdout
    /// performance. A `None` holdout keeps the previous holdout reading.
    pub fn update_performance(&mut self, train_perf: f64, val_perf: Option<f64>) {
        self.train_performance = Some(train_perf);
        self.holdout_performance = val_perf.or(self.holdout_performance);
    }

    /// Mean of performance consistency, CV consistency and (1 - complexity penalty).
    pub fn calculate_generalization_score(&self) -> f64 {
        let consistency = self.calculate_performance_consistency();
        let cv = self.calculate_cv_consistency();
        let simplicity = 1.0 - self.calculate_complexity_penalty();
        (consistency + cv + simplicity) / 3.0
    }

    /// `1 - |train - holdout|`, clamped to `[0, 1]`; 0.5 if either reading is missing.
    fn calculate_performance_consistency(&self) -> f64 {
        if let (Some(train), Some(val)) = (self.train_performance, self.holdout_performance) {
            let clamped_gap = (train - val).abs().min(1.0);
            (1.0 - clamped_gap).max(0.0)
        } else {
            0.5 // Unknown
        }
    }

    /// Step penalty on parameters per training sample; 0.0 when `data_size` is 0.
    fn calculate_complexity_penalty(&self) -> f64 {
        let ComplexityMetrics { parameter_count, data_size, .. } = self.complexity_metrics;
        if data_size == 0 {
            return 0.0;
        }

        match parameter_count as f64 / data_size as f64 {
            ratio if ratio > 1.0 => 0.8, // High complexity
            ratio if ratio > 0.1 => 0.4, // Medium complexity
            _ => 0.1,                    // Low complexity
        }
    }

    /// `1 - stddev` of the cross-validation scores, clamped to `[0, 1]`;
    /// 0.5 with fewer than three scores.
    fn calculate_cv_consistency(&self) -> f64 {
        let scores = &self.cross_validation_scores;
        if scores.len() < 3 {
            return 0.5;
        }

        let n = scores.len() as f64;
        let mean = scores.iter().sum::<f64>() / n;
        let sample_variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / (n - 1.0);

        (1.0 - sample_variance.sqrt().min(1.0)).max(0.0)
    }
}
570
/// Performance baseline for comparison.
///
/// Installed on a checker via `HealthChecker::set_baseline`.
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
    /// Reference loss value.
    pub baseline_loss: f64,
    /// Reference accuracy value.
    pub baseline_accuracy: f64,
    /// Reference training duration.
    pub baseline_training_time: Duration,
    /// Reference memory usage.
    /// NOTE(review): the unit is not stated; presumably MB like
    /// `memory_usage_mb`. Confirm.
    pub baseline_memory_usage: f64,
    /// When the baseline was recorded.
    pub established_at: SystemTime,
}
580
581impl HealthChecker {
582    /// Create new health checker
583    pub fn new(config: &DebugConfig) -> Self {
584        Self {
585            config: config.clone(),
586            metrics_history: VecDeque::new(),
587            health_assessments: Vec::new(),
588            stability_tracker: StabilityTracker {
589                loss_stability: MetricStability::new(0.1, 0.05),
590                accuracy_stability: MetricStability::new(0.01, 0.02),
591                gradient_stability: MetricStability::new(1.0, 0.1),
592                learning_rate_stability: MetricStability::new(0.0001, 0.001),
593                window_size: 50,
594            },
595            convergence_analyzer: ConvergenceAnalyzer::new(),
596            overfitting_detector: OverfittingDetector::new(),
597            generalization_monitor: GeneralizationMonitor::new(),
598            performance_baseline: None,
599        }
600    }
601
602    /// Update health checker with new metrics
603    pub fn update(&mut self, metrics: DashboardMetrics) {
604        self.metrics_history.push_back(metrics.clone());
605
606        // Keep only recent metrics to prevent unbounded growth
607        if self.metrics_history.len() > 1000 {
608            self.metrics_history.pop_front();
609        }
610
611        // Update stability tracking
612        if let Some(loss) = metrics.loss {
613            self.stability_tracker.loss_stability.update(loss);
614        }
615        if let Some(accuracy) = metrics.accuracy {
616            self.stability_tracker.accuracy_stability.update(accuracy);
617        }
618        if let Some(grad_norm) = metrics.gradient_norm {
619            self.stability_tracker.gradient_stability.update(grad_norm);
620        }
621        if let Some(lr) = metrics.learning_rate {
622            self.stability_tracker.learning_rate_stability.update(lr);
623        }
624
625        // Update convergence analysis
626        self.convergence_analyzer.update(metrics.loss, metrics.accuracy);
627
628        // Update overfitting detection (assuming these are training metrics)
629        self.overfitting_detector.update_train_metrics(metrics.loss, metrics.accuracy);
630
631        // Update generalization monitoring
632        if let (Some(accuracy), Some(loss)) = (metrics.accuracy, metrics.loss) {
633            self.generalization_monitor.update_performance(accuracy, Some(1.0 - loss));
634        }
635    }
636
637    /// Perform comprehensive health assessment
638    pub fn assess_health(&mut self) -> Result<HealthAssessment> {
639        let overall_health_score = self.calculate_overall_health_score();
640        let training_stability_index = self.calculate_training_stability_index();
641        let convergence_probability = self.convergence_analyzer.calculate_convergence_probability();
642        let overfitting_risk = self.overfitting_detector.detect_overfitting();
643        let generalization_score = self.generalization_monitor.calculate_generalization_score();
644
645        let component_scores = self.calculate_component_scores();
646        let health_status = self.determine_health_status(overall_health_score);
647        let alerts = self.generate_health_alerts();
648        let recommendations = self.generate_health_recommendations(&alerts);
649
650        let assessment = HealthAssessment {
651            timestamp: SystemTime::now(),
652            overall_health_score,
653            training_stability_index,
654            convergence_probability,
655            overfitting_risk,
656            generalization_score,
657            component_scores,
658            health_status,
659            alerts,
660            recommendations,
661        };
662
663        self.health_assessments.push(assessment.clone());
664
665        // Keep only recent assessments
666        if self.health_assessments.len() > 100 {
667            self.health_assessments.drain(0..50);
668        }
669
670        Ok(assessment)
671    }
672
673    fn calculate_overall_health_score(&self) -> f64 {
674        let component_scores = self.calculate_component_scores();
675
676        // Weighted average of component scores
677        let weights = [
678            ("stability", 0.25),
679            ("convergence", 0.20),
680            ("gradient", 0.15),
681            ("loss", 0.15),
682            ("accuracy", 0.10),
683            ("performance", 0.10),
684            ("memory", 0.05),
685        ];
686
687        let mut weighted_sum = 0.0;
688        weighted_sum += weights[0].1 * component_scores.stability_health;
689        weighted_sum +=
690            weights[1].1 * self.convergence_analyzer.calculate_convergence_probability();
691        weighted_sum += weights[2].1 * component_scores.gradient_health;
692        weighted_sum += weights[3].1 * component_scores.loss_health;
693        weighted_sum += weights[4].1 * component_scores.accuracy_health;
694        weighted_sum += weights[5].1 * component_scores.performance_health;
695        weighted_sum += weights[6].1 * component_scores.memory_health;
696
697        weighted_sum
698    }
699
700    fn calculate_training_stability_index(&self) -> f64 {
701        let loss_stability = self.stability_tracker.loss_stability.calculate_stability();
702        let accuracy_stability = self.stability_tracker.accuracy_stability.calculate_stability();
703        let gradient_stability = self.stability_tracker.gradient_stability.calculate_stability();
704        let lr_stability = self.stability_tracker.learning_rate_stability.calculate_stability();
705
706        (loss_stability + accuracy_stability + gradient_stability + lr_stability) / 4.0
707    }
708
709    fn calculate_component_scores(&self) -> ComponentHealthScores {
710        ComponentHealthScores {
711            gradient_health: self.stability_tracker.gradient_stability.calculate_stability(),
712            loss_health: self.calculate_loss_health(),
713            accuracy_health: self.calculate_accuracy_health(),
714            performance_health: self.calculate_performance_health(),
715            memory_health: self.calculate_memory_health(),
716            stability_health: self.calculate_training_stability_index(),
717        }
718    }
719
720    fn calculate_loss_health(&self) -> f64 {
721        if self.metrics_history.len() < 10 {
722            return 0.5;
723        }
724
725        let recent_losses: Vec<f64> =
726            self.metrics_history.iter().rev().take(10).filter_map(|m| m.loss).collect();
727
728        if recent_losses.is_empty() {
729            return 0.5;
730        }
731
732        // Check if loss is generally decreasing
733        let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
734            / (recent_losses.len() / 2) as f64;
735        let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
736            / (recent_losses.len() - recent_losses.len() / 2) as f64;
737
738        if second_half_avg < first_half_avg {
739            0.8 // Loss decreasing
740        } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.05 {
741            0.6 // Loss stable
742        } else {
743            0.3 // Loss increasing
744        }
745    }
746
747    fn calculate_accuracy_health(&self) -> f64 {
748        if self.metrics_history.len() < 10 {
749            return 0.5;
750        }
751
752        let recent_accuracies: Vec<f64> =
753            self.metrics_history.iter().rev().take(10).filter_map(|m| m.accuracy).collect();
754
755        if recent_accuracies.is_empty() {
756            return 0.5;
757        }
758
759        let avg_accuracy = recent_accuracies.iter().sum::<f64>() / recent_accuracies.len() as f64;
760
761        // Score based on absolute accuracy and stability
762        let accuracy_score = avg_accuracy; // Assuming accuracy is 0-1
763        let stability_score = self.stability_tracker.accuracy_stability.calculate_stability();
764
765        (accuracy_score + stability_score) / 2.0
766    }
767
768    fn calculate_performance_health(&self) -> f64 {
769        // Check tokens per second and GPU utilization
770        if let Some(last_metrics) = self.metrics_history.back() {
771            let mut score = 0.0;
772            let mut components = 0;
773
774            if let Some(tps) = last_metrics.tokens_per_second {
775                score += if tps > 100.0 { 0.8 } else { tps / 125.0 };
776                components += 1;
777            }
778
779            if let Some(gpu_util) = last_metrics.gpu_utilization {
780                score += gpu_util;
781                components += 1;
782            }
783
784            if components > 0 {
785                score / components as f64
786            } else {
787                0.5
788            }
789        } else {
790            0.5
791        }
792    }
793
794    fn calculate_memory_health(&self) -> f64 {
795        if let Some(last_metrics) = self.metrics_history.back() {
796            let memory_usage = last_metrics.memory_usage_mb;
797
798            // Assume 8GB as reasonable upper limit
799
800            if memory_usage < 4096.0 {
801                0.9
802            } else if memory_usage < 6144.0 {
803                0.7
804            } else if memory_usage < 8192.0 {
805                0.5
806            } else {
807                0.2
808            }
809        } else {
810            0.5
811        }
812    }
813
814    fn determine_health_status(&self, score: f64) -> HealthStatus {
815        match score {
816            s if s >= 0.9 => HealthStatus::Excellent,
817            s if s >= 0.75 => HealthStatus::Good,
818            s if s >= 0.6 => HealthStatus::Fair,
819            s if s >= 0.4 => HealthStatus::Poor,
820            _ => HealthStatus::Critical,
821        }
822    }
823
824    fn generate_health_alerts(&self) -> Vec<HealthAlert> {
825        let mut alerts = Vec::new();
826
827        // Training stability alerts
828        let stability_index = self.calculate_training_stability_index();
829        if stability_index < 0.3 {
830            alerts.push(HealthAlert {
831                alert_type: HealthAlertType::TrainingStability,
832                severity: AlertSeverity::High,
833                message: "Training is highly unstable".to_string(),
834                metric_value: stability_index,
835                threshold: 0.3,
836                trend: Trend::Degrading,
837            });
838        }
839
840        // Convergence alerts
841        let convergence_prob = self.convergence_analyzer.calculate_convergence_probability();
842        if convergence_prob < 0.2 {
843            alerts.push(HealthAlert {
844                alert_type: HealthAlertType::ConvergenceIssue,
845                severity: AlertSeverity::Medium,
846                message: "Low probability of convergence".to_string(),
847                metric_value: convergence_prob,
848                threshold: 0.2,
849                trend: Trend::Stable,
850            });
851        }
852
853        // Overfitting alerts
854        match self.overfitting_detector.detect_overfitting() {
855            OverfittingRisk::High | OverfittingRisk::Severe => {
856                alerts.push(HealthAlert {
857                    alert_type: HealthAlertType::OverfittingDetected,
858                    severity: AlertSeverity::High,
859                    message: "Significant overfitting detected".to_string(),
860                    metric_value: 0.8,
861                    threshold: 0.6,
862                    trend: Trend::Degrading,
863                });
864            },
865            OverfittingRisk::Medium => {
866                alerts.push(HealthAlert {
867                    alert_type: HealthAlertType::OverfittingDetected,
868                    severity: AlertSeverity::Medium,
869                    message: "Moderate overfitting risk".to_string(),
870                    metric_value: 0.5,
871                    threshold: 0.4,
872                    trend: Trend::Stable,
873                });
874            },
875            _ => {},
876        }
877
878        alerts
879    }
880
881    fn generate_health_recommendations(&self, alerts: &[HealthAlert]) -> Vec<HealthRecommendation> {
882        let mut recommendations = Vec::new();
883
884        for alert in alerts {
885            match alert.alert_type {
886                HealthAlertType::TrainingStability => {
887                    recommendations.push(HealthRecommendation {
888                        category: RecommendationCategory::Training,
889                        title: "Improve Training Stability".to_string(),
890                        description:
891                            "Reduce learning rate or increase batch size to stabilize training"
892                                .to_string(),
893                        urgency: RecommendationUrgency::Soon,
894                        expected_impact: 0.3,
895                    });
896                },
897                HealthAlertType::ConvergenceIssue => {
898                    recommendations.push(HealthRecommendation {
899                        category: RecommendationCategory::Hyperparameters,
900                        title: "Adjust Learning Rate Schedule".to_string(),
901                        description:
902                            "Implement learning rate scheduling or adjust optimizer settings"
903                                .to_string(),
904                        urgency: RecommendationUrgency::Eventually,
905                        expected_impact: 0.2,
906                    });
907                },
908                HealthAlertType::OverfittingDetected => {
909                    recommendations.push(HealthRecommendation {
910                        category: RecommendationCategory::Training,
911                        title: "Add Regularization".to_string(),
912                        description: "Implement dropout, weight decay, or early stopping to reduce overfitting".to_string(),
913                        urgency: RecommendationUrgency::Soon,
914                        expected_impact: 0.25,
915                    });
916                },
917                _ => {},
918            }
919        }
920
921        // Add general recommendations based on overall health
922        let overall_score = self.calculate_overall_health_score();
923        if overall_score < 0.6 {
924            recommendations.push(HealthRecommendation {
925                category: RecommendationCategory::Training,
926                title: "Comprehensive Training Review".to_string(),
927                description: "Review entire training setup including data, model architecture, and hyperparameters".to_string(),
928                urgency: RecommendationUrgency::Immediate,
929                expected_impact: 0.4,
930            });
931        }
932
933        recommendations
934    }
935
936    /// Set performance baseline for comparison
937    pub fn set_baseline(&mut self, baseline: PerformanceBaseline) {
938        self.performance_baseline = Some(baseline);
939    }
940
941    /// Get health assessment history
942    pub fn get_health_history(&self) -> &[HealthAssessment] {
943        &self.health_assessments
944    }
945
946    /// Quick health check for simplified interface
947    pub async fn quick_health_check(&self) -> Result<crate::QuickHealthSummary> {
948        let score = if let Some(assessment) = self.health_assessments.last() {
949            assessment.overall_health_score * 100.0
950        } else {
951            // If no assessments yet, do a basic check
952            50.0 // Default fair score
953        };
954
955        let status = match score {
956            90.0..=100.0 => "Excellent",
957            75.0..89.9 => "Good",
958            60.0..74.9 => "Fair",
959            40.0..59.9 => "Poor",
960            _ => "Critical",
961        }
962        .to_string();
963
964        let mut recommendations = Vec::new();
965        if score < 60.0 {
966            recommendations.push("Review training configuration and data quality".to_string());
967        }
968        if score < 40.0 {
969            recommendations
970                .push("Consider adjusting learning rate and model architecture".to_string());
971        }
972        if score < 80.0 {
973            recommendations.push("Monitor training stability and convergence".to_string());
974        }
975
976        Ok(crate::QuickHealthSummary {
977            score,
978            status,
979            recommendations,
980        })
981    }
982
983    /// Generate health report
984    pub async fn generate_report(&self) -> Result<HealthReport> {
985        let current_assessment = if let Some(assessment) = self.health_assessments.last() {
986            assessment.clone()
987        } else {
988            return Ok(HealthReport::default());
989        };
990
991        let health_trends = self.analyze_health_trends();
992        let risk_assessment = self.assess_risks();
993        let improvement_suggestions = self.generate_improvement_suggestions();
994
995        Ok(HealthReport {
996            current_health: current_assessment,
997            health_trends,
998            risk_assessment,
999            improvement_suggestions,
1000            baseline_comparison: self.compare_with_baseline(),
1001            summary: self.generate_health_summary(),
1002        })
1003    }
1004
1005    fn analyze_health_trends(&self) -> HealthTrends {
1006        if self.health_assessments.len() < 5 {
1007            return HealthTrends::default();
1008        }
1009
1010        let recent_scores: Vec<f64> = self
1011            .health_assessments
1012            .iter()
1013            .rev()
1014            .take(10)
1015            .map(|a| a.overall_health_score)
1016            .collect();
1017
1018        let first_half_avg = recent_scores[recent_scores.len() / 2..].iter().sum::<f64>()
1019            / (recent_scores.len() - recent_scores.len() / 2) as f64;
1020        let second_half_avg = recent_scores[..recent_scores.len() / 2].iter().sum::<f64>()
1021            / (recent_scores.len() / 2) as f64;
1022
1023        let trend = if second_half_avg > first_half_avg * 1.05 {
1024            Trend::Improving
1025        } else if second_half_avg < first_half_avg * 0.95 {
1026            Trend::Degrading
1027        } else {
1028            Trend::Stable
1029        };
1030
1031        HealthTrends {
1032            overall_trend: trend,
1033            stability_trend: Trend::Stable, // Simplified
1034            convergence_trend: Trend::Stable,
1035            overfitting_trend: Trend::Stable,
1036        }
1037    }
1038
1039    fn assess_risks(&self) -> Vec<HealthRisk> {
1040        let mut risks = Vec::new();
1041
1042        if let Some(current) = self.health_assessments.last() {
1043            if current.overall_health_score < 0.4 {
1044                risks.push(HealthRisk {
1045                    risk_type: "Poor Overall Health".to_string(),
1046                    probability: 0.9,
1047                    impact: 0.8,
1048                    description: "Model training is in poor health and may fail".to_string(),
1049                });
1050            }
1051
1052            match current.overfitting_risk {
1053                OverfittingRisk::High | OverfittingRisk::Severe => {
1054                    risks.push(HealthRisk {
1055                        risk_type: "Overfitting".to_string(),
1056                        probability: 0.8,
1057                        impact: 0.6,
1058                        description: "Model is likely overfitting and will generalize poorly"
1059                            .to_string(),
1060                    });
1061                },
1062                _ => {},
1063            }
1064
1065            if current.convergence_probability < 0.3 {
1066                risks.push(HealthRisk {
1067                    risk_type: "Training Failure".to_string(),
1068                    probability: 0.7,
1069                    impact: 0.9,
1070                    description: "Training may not converge to a useful solution".to_string(),
1071                });
1072            }
1073        }
1074
1075        risks
1076    }
1077
1078    fn generate_improvement_suggestions(&self) -> Vec<ImprovementSuggestion> {
1079        let mut suggestions = Vec::new();
1080
1081        if let Some(current) = self.health_assessments.last() {
1082            if current.component_scores.stability_health < 0.5 {
1083                suggestions.push(ImprovementSuggestion {
1084                    area: "Training Stability".to_string(),
1085                    suggestion: "Reduce learning rate and increase batch size".to_string(),
1086                    expected_improvement: 0.3,
1087                    implementation_effort: "Low".to_string(),
1088                });
1089            }
1090
1091            if current.convergence_probability < 0.5 {
1092                suggestions.push(ImprovementSuggestion {
1093                    area: "Convergence".to_string(),
1094                    suggestion: "Implement learning rate scheduling and gradient clipping"
1095                        .to_string(),
1096                    expected_improvement: 0.25,
1097                    implementation_effort: "Medium".to_string(),
1098                });
1099            }
1100
1101            match current.overfitting_risk {
1102                OverfittingRisk::Medium | OverfittingRisk::High | OverfittingRisk::Severe => {
1103                    suggestions.push(ImprovementSuggestion {
1104                        area: "Overfitting Prevention".to_string(),
1105                        suggestion: "Add dropout layers, implement early stopping, or increase training data".to_string(),
1106                        expected_improvement: 0.4,
1107                        implementation_effort: "Medium".to_string(),
1108                    });
1109                },
1110                _ => {},
1111            }
1112        }
1113
1114        suggestions
1115    }
1116
1117    fn compare_with_baseline(&self) -> Option<BaselineComparison> {
1118        if let (Some(_baseline), Some(current)) =
1119            (&self.performance_baseline, self.health_assessments.last())
1120        {
1121            Some(BaselineComparison {
1122                health_score_change: current.overall_health_score - 0.8, // Simplified baseline score
1123                stability_change: current.training_stability_index - 0.7,
1124                convergence_change: current.convergence_probability - 0.6,
1125                improvement_percentage: ((current.overall_health_score - 0.8) / 0.8 * 100.0)
1126                    .max(-100.0),
1127            })
1128        } else {
1129            None
1130        }
1131    }
1132
1133    fn generate_health_summary(&self) -> String {
1134        if let Some(current) = self.health_assessments.last() {
1135            match current.health_status {
1136                HealthStatus::Excellent => "Training is in excellent health with stable convergence and no significant issues detected.".to_string(),
1137                HealthStatus::Good => "Training is proceeding well with minor optimization opportunities.".to_string(),
1138                HealthStatus::Fair => "Training shows some concerning patterns that should be addressed.".to_string(),
1139                HealthStatus::Poor => "Training has significant issues requiring immediate attention.".to_string(),
1140                HealthStatus::Critical => "Training is in critical condition and may fail without intervention.".to_string(),
1141            }
1142        } else {
1143            "Insufficient data for health assessment.".to_string()
1144        }
1145    }
1146}
1147
1148// Report structures
1149
1150#[derive(Debug, Serialize, Deserialize)]
1151pub struct HealthReport {
1152    pub current_health: HealthAssessment,
1153    pub health_trends: HealthTrends,
1154    pub risk_assessment: Vec<HealthRisk>,
1155    pub improvement_suggestions: Vec<ImprovementSuggestion>,
1156    pub baseline_comparison: Option<BaselineComparison>,
1157    pub summary: String,
1158}
1159
1160impl Default for HealthReport {
1161    fn default() -> Self {
1162        Self {
1163            current_health: HealthAssessment {
1164                timestamp: SystemTime::now(),
1165                overall_health_score: 0.5,
1166                training_stability_index: 0.5,
1167                convergence_probability: 0.5,
1168                overfitting_risk: OverfittingRisk::None,
1169                generalization_score: 0.5,
1170                component_scores: ComponentHealthScores {
1171                    gradient_health: 0.5,
1172                    loss_health: 0.5,
1173                    accuracy_health: 0.5,
1174                    performance_health: 0.5,
1175                    memory_health: 0.5,
1176                    stability_health: 0.5,
1177                },
1178                health_status: HealthStatus::Fair,
1179                alerts: Vec::new(),
1180                recommendations: Vec::new(),
1181            },
1182            health_trends: HealthTrends::default(),
1183            risk_assessment: Vec::new(),
1184            improvement_suggestions: Vec::new(),
1185            baseline_comparison: None,
1186            summary: "No health data available yet.".to_string(),
1187        }
1188    }
1189}
1190
1191#[derive(Debug, Clone, Serialize, Deserialize)]
1192pub struct HealthTrends {
1193    pub overall_trend: Trend,
1194    pub stability_trend: Trend,
1195    pub convergence_trend: Trend,
1196    pub overfitting_trend: Trend,
1197}
1198
1199impl Default for HealthTrends {
1200    fn default() -> Self {
1201        Self {
1202            overall_trend: Trend::Stable,
1203            stability_trend: Trend::Stable,
1204            convergence_trend: Trend::Stable,
1205            overfitting_trend: Trend::Stable,
1206        }
1207    }
1208}
1209
1210#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct HealthRisk {
1212    pub risk_type: String,
1213    pub probability: f64,
1214    pub impact: f64,
1215    pub description: String,
1216}
1217
1218#[derive(Debug, Clone, Serialize, Deserialize)]
1219pub struct ImprovementSuggestion {
1220    pub area: String,
1221    pub suggestion: String,
1222    pub expected_improvement: f64,
1223    pub implementation_effort: String,
1224}
1225
1226#[derive(Debug, Clone, Serialize, Deserialize)]
1227pub struct BaselineComparison {
1228    pub health_score_change: f64,
1229    pub stability_change: f64,
1230    pub convergence_change: f64,
1231    pub improvement_percentage: f64,
1232}
1233
// Unit tests for the health checker and its stability, convergence,
// overfitting and generalization helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::SystemTime;

    /// Build a metrics sample with fixed auxiliary values and the given loss/accuracy.
    fn make_metrics(loss: Option<f64>, accuracy: Option<f64>) -> DashboardMetrics {
        DashboardMetrics {
            timestamp: SystemTime::now(),
            loss,
            accuracy,
            learning_rate: Some(0.001),
            memory_usage_mb: 2048.0,
            gpu_utilization: Some(0.75),
            tokens_per_second: Some(200.0),
            gradient_norm: Some(1.0),
            epoch: Some(1),
            step: Some(100),
        }
    }

    fn make_config() -> DebugConfig {
        DebugConfig::default()
    }

    // --- MetricStability tests ---

    #[test]
    fn test_metric_stability_new() {
        let ms = MetricStability::new(0.1, 0.05);
        assert!(ms.values.is_empty());
        assert!((ms.variance_threshold - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_metric_stability_update() {
        let mut ms = MetricStability::new(0.1, 0.05);
        ms.update(1.0);
        ms.update(2.0);
        assert_eq!(ms.values.len(), 2);
    }

    #[test]
    fn test_metric_stability_update_overflow() {
        let mut ms = MetricStability::new(0.1, 0.05);
        for i in 0..60 {
            ms.update(i as f64);
        }
        assert_eq!(ms.values.len(), 50);
    }

    #[test]
    fn test_metric_stability_insufficient_data() {
        let ms = MetricStability::new(0.1, 0.05);
        let stability = ms.calculate_stability();
        assert!((stability - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_metric_stability_perfect_stability() {
        let mut ms = MetricStability::new(0.1, 0.05);
        for _ in 0..20 {
            ms.update(5.0);
        }
        let stability = ms.calculate_stability();
        // Zero variance -> variance_score = 1.0
        // All same values -> no slope changes -> trend_stability = 1.0 ideally
        assert!(stability > 0.7);
    }

    #[test]
    fn test_metric_stability_variance_calculation() {
        let mut ms = MetricStability::new(0.1, 0.05);
        ms.update(2.0);
        ms.update(4.0);
        let variance = ms.calculate_variance();
        // sample variance = ((2-3)^2 + (4-3)^2) / (2-1) = 2.0
        assert!((variance - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_metric_stability_single_value_variance() {
        let mut ms = MetricStability::new(0.1, 0.05);
        ms.update(5.0);
        let variance = ms.calculate_variance();
        assert!((variance - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_metric_stability_trend_stability_insufficient() {
        let mut ms = MetricStability::new(0.1, 0.05);
        for i in 0..5 {
            ms.update(i as f64);
        }
        let trend = ms.calculate_trend_stability();
        assert!((trend - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_metric_stability_trend_stability_monotonic() {
        let mut ms = MetricStability::new(0.1, 0.05);
        for i in 0..20 {
            ms.update(i as f64);
        }
        let trend = ms.calculate_trend_stability();
        // Monotonically increasing -> no slope changes -> trend = 1.0
        assert!((trend - 1.0).abs() < 1e-9);
    }

    // --- ConvergenceAnalyzer tests ---

    #[test]
    fn test_convergence_analyzer_new() {
        let ca = ConvergenceAnalyzer::new();
        assert_eq!(ca.convergence_window, 100);
        assert!((ca.convergence_threshold - 0.01).abs() < 1e-9);
    }

    #[test]
    fn test_convergence_analyzer_default() {
        let ca = ConvergenceAnalyzer::default();
        assert!(ca.loss_history.is_empty());
    }

    #[test]
    fn test_convergence_analyzer_update_loss() {
        let mut ca = ConvergenceAnalyzer::new();
        ca.update(Some(1.5), None);
        assert_eq!(ca.loss_history.len(), 1);
        assert!(ca.accuracy_history.is_empty());
    }

    #[test]
    fn test_convergence_analyzer_update_accuracy() {
        let mut ca = ConvergenceAnalyzer::new();
        ca.update(None, Some(0.85));
        assert!(ca.loss_history.is_empty());
        assert_eq!(ca.accuracy_history.len(), 1);
    }

    #[test]
    fn test_convergence_analyzer_history_limit() {
        let mut ca = ConvergenceAnalyzer::new();
        for i in 0..250 {
            ca.update(Some(i as f64), None);
        }
        assert!(ca.loss_history.len() <= 200);
    }

    #[test]
    fn test_convergence_probability_insufficient_data() {
        let ca = ConvergenceAnalyzer::new();
        let prob = ca.calculate_convergence_probability();
        // With no data: 0.7 * 0.3 + 0.3 * 0.5 = 0.21 + 0.15 = 0.36
        assert!(prob > 0.0 && prob < 1.0);
    }

    #[test]
    fn test_convergence_variance_empty() {
        let ca = ConvergenceAnalyzer::new();
        let var = ca.calculate_variance(&[]);
        assert!((var - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_convergence_variance_single() {
        let ca = ConvergenceAnalyzer::new();
        let var = ca.calculate_variance(&[5.0]);
        assert!((var - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_convergence_variance_values() {
        let ca = ConvergenceAnalyzer::new();
        let var = ca.calculate_variance(&[1.0, 3.0]);
        // mean=2, variance = ((1-2)^2 + (3-2)^2)/(2-1) = 2.0
        assert!((var - 2.0).abs() < 1e-9);
    }

    // --- OverfittingDetector tests ---

    #[test]
    fn test_overfitting_detector_new() {
        let od = OverfittingDetector::new();
        assert!(od.train_loss_history.is_empty());
        assert!((od.overfitting_threshold - 0.1).abs() < 1e-9);
    }

    #[test]
    fn test_overfitting_detector_default() {
        let od = OverfittingDetector::default();
        assert!(od.val_loss_history.is_empty());
    }

    #[test]
    fn test_overfitting_detector_update_train_metrics() {
        let mut od = OverfittingDetector::new();
        od.update_train_metrics(Some(0.5), Some(0.9));
        assert_eq!(od.train_loss_history.len(), 1);
        assert_eq!(od.train_accuracy_history.len(), 1);
    }

    #[test]
    fn test_overfitting_detector_update_validation_metrics() {
        let mut od = OverfittingDetector::new();
        od.update_validation_metrics(Some(0.6), Some(0.85));
        assert_eq!(od.val_loss_history.len(), 1);
        assert_eq!(od.val_accuracy_history.len(), 1);
    }

    #[test]
    fn test_overfitting_detector_history_limit() {
        let mut od = OverfittingDetector::new();
        for i in 0..120 {
            od.update_train_metrics(Some(i as f64), None);
        }
        assert_eq!(od.train_loss_history.len(), 100);
    }

    #[test]
    fn test_overfitting_no_data() {
        let od = OverfittingDetector::new();
        let risk = od.detect_overfitting();
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(risk, OverfittingRisk::None));
    }

    #[test]
    fn test_overfitting_loss_gap_insufficient() {
        let od = OverfittingDetector::new();
        let gap = od.calculate_loss_gap();
        assert!((gap - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_overfitting_accuracy_gap_insufficient() {
        let od = OverfittingDetector::new();
        let gap = od.calculate_accuracy_gap();
        assert!((gap - 0.0).abs() < 1e-9);
    }

    // --- GeneralizationMonitor tests ---

    #[test]
    fn test_generalization_monitor_new() {
        let gm = GeneralizationMonitor::new();
        assert!(gm.cross_validation_scores.is_empty());
        assert!(gm.holdout_performance.is_none());
        assert!(gm.train_performance.is_none());
    }

    #[test]
    fn test_generalization_monitor_default() {
        let gm = GeneralizationMonitor::default();
        assert!(gm.holdout_performance.is_none());
    }

    #[test]
    fn test_generalization_update_performance() {
        let mut gm = GeneralizationMonitor::new();
        gm.update_performance(0.95, Some(0.90));
        assert!((gm.train_performance.expect("should be set") - 0.95).abs() < 1e-9);
        assert!((gm.holdout_performance.expect("should be set") - 0.90).abs() < 1e-9);
    }

    #[test]
    fn test_generalization_score_no_data() {
        let gm = GeneralizationMonitor::new();
        let score = gm.calculate_generalization_score();
        // (0.5 + 0.5 + 1.0) / 3.0 = 0.667
        assert!(score > 0.0 && score < 1.0);
    }

    #[test]
    fn test_generalization_performance_consistency_perfect() {
        let mut gm = GeneralizationMonitor::new();
        gm.update_performance(0.9, Some(0.9));
        let consistency = gm.calculate_performance_consistency();
        assert!((consistency - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_generalization_performance_consistency_with_gap() {
        let mut gm = GeneralizationMonitor::new();
        gm.update_performance(0.9, Some(0.7));
        let consistency = gm.calculate_performance_consistency();
        // 1.0 - 0.2 = 0.8
        assert!((consistency - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_generalization_complexity_penalty_no_data() {
        let gm = GeneralizationMonitor::new();
        let penalty = gm.calculate_complexity_penalty();
        assert!((penalty - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_generalization_cv_consistency_insufficient() {
        let gm = GeneralizationMonitor::new();
        let cv = gm.calculate_cv_consistency();
        assert!((cv - 0.5).abs() < 1e-9);
    }

    // --- HealthChecker tests ---

    #[test]
    fn test_health_checker_new() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        assert!(hc.metrics_history.is_empty());
        assert!(hc.health_assessments.is_empty());
        assert!(hc.performance_baseline.is_none());
    }

    #[test]
    fn test_health_checker_update() {
        let config = make_config();
        let mut hc = HealthChecker::new(&config);
        hc.update(make_metrics(Some(0.5), Some(0.8)));
        assert_eq!(hc.metrics_history.len(), 1);
    }

    #[test]
    fn test_health_checker_update_limit() {
        let config = make_config();
        let mut hc = HealthChecker::new(&config);
        for i in 0..1100 {
            hc.update(make_metrics(Some(1.0 / (i as f64 + 1.0)), Some(0.5)));
        }
        assert!(hc.metrics_history.len() <= 1000);
    }

    #[test]
    fn test_health_checker_determine_health_status() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        assert!(matches!(
            hc.determine_health_status(0.95),
            HealthStatus::Excellent
        ));
        assert!(matches!(
            hc.determine_health_status(0.80),
            HealthStatus::Good
        ));
        assert!(matches!(
            hc.determine_health_status(0.65),
            HealthStatus::Fair
        ));
        assert!(matches!(
            hc.determine_health_status(0.45),
            HealthStatus::Poor
        ));
        assert!(matches!(
            hc.determine_health_status(0.1),
            HealthStatus::Critical
        ));
    }

    #[test]
    fn test_health_checker_set_baseline() {
        let config = make_config();
        let mut hc = HealthChecker::new(&config);
        let baseline = PerformanceBaseline {
            baseline_loss: 0.5,
            baseline_accuracy: 0.9,
            baseline_training_time: Duration::from_secs(3600),
            baseline_memory_usage: 4096.0,
            established_at: SystemTime::now(),
        };
        hc.set_baseline(baseline);
        assert!(hc.performance_baseline.is_some());
    }

    #[test]
    fn test_health_checker_get_health_history() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let history = hc.get_health_history();
        assert!(history.is_empty());
    }

    #[test]
    fn test_health_checker_assess_health() {
        let config = make_config();
        let mut hc = HealthChecker::new(&config);
        for i in 0..20 {
            hc.update(make_metrics(
                Some(1.0 - i as f64 * 0.04),
                Some(0.5 + i as f64 * 0.02),
            ));
        }
        let result = hc.assess_health();
        assert!(result.is_ok());
        let assessment = result.expect("should succeed");
        assert!(assessment.overall_health_score >= 0.0);
        assert!(assessment.overall_health_score <= 1.0);
    }

    #[test]
    fn test_health_checker_generate_health_summary_no_data() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let summary = hc.generate_health_summary();
        assert!(summary.contains("Insufficient"));
    }

    #[test]
    fn test_health_checker_loss_health_insufficient_data() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let health = hc.calculate_loss_health();
        assert!((health - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_health_checker_accuracy_health_insufficient() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let health = hc.calculate_accuracy_health();
        assert!((health - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_health_checker_performance_health_no_metrics() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let health = hc.calculate_performance_health();
        assert!((health - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_health_checker_memory_health_no_metrics() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let health = hc.calculate_memory_health();
        assert!((health - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_health_checker_memory_health_low_usage() {
        let config = make_config();
        let mut hc = HealthChecker::new(&config);
        hc.update(make_metrics(Some(0.5), Some(0.8)));
        let health = hc.calculate_memory_health();
        assert!((health - 0.9).abs() < 1e-9);
    }

    #[test]
    fn test_health_checker_compare_with_baseline_none() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        assert!(hc.compare_with_baseline().is_none());
    }

    #[test]
    fn test_health_checker_health_trends_insufficient() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let trends = hc.analyze_health_trends();
        assert!(matches!(trends.overall_trend, Trend::Stable));
    }

    #[test]
    fn test_health_checker_assess_risks_empty() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let risks = hc.assess_risks();
        assert!(risks.is_empty());
    }

    #[test]
    fn test_health_checker_improvement_suggestions_empty() {
        let config = make_config();
        let hc = HealthChecker::new(&config);
        let suggestions = hc.generate_improvement_suggestions();
        assert!(suggestions.is_empty());
    }

    #[test]
    fn test_health_report_default() {
        let report = HealthReport::default();
        assert!((report.current_health.overall_health_score - 0.5).abs() < 1e-9);
        assert!(matches!(
            report.current_health.health_status,
            HealthStatus::Fair
        ));
    }

    #[test]
    fn test_health_trends_default() {
        let trends = HealthTrends::default();
        assert!(matches!(trends.overall_trend, Trend::Stable));
        assert!(matches!(trends.stability_trend, Trend::Stable));
    }
}