1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6use std::time::{Duration, SystemTime};
7
8use crate::{DashboardMetrics, DebugConfig};
9
/// Tracks training telemetry over time and derives health assessments
/// (stability, convergence, overfitting, generalization) from it.
#[derive(Debug)]
pub struct HealthChecker {
    #[allow(dead_code)]
    config: DebugConfig,
    // Rolling window of raw dashboard samples; capped at 1000 in `update`.
    metrics_history: VecDeque<DashboardMetrics>,
    // Past assessments; compacted once they exceed 100 entries.
    health_assessments: Vec<HealthAssessment>,
    // Per-metric variance/trend trackers (loss, accuracy, gradients, LR).
    stability_tracker: StabilityTracker,
    convergence_analyzer: ConvergenceAnalyzer,
    overfitting_detector: OverfittingDetector,
    generalization_monitor: GeneralizationMonitor,
    // Optional reference point set via `set_baseline`.
    performance_baseline: Option<PerformanceBaseline>,
}
23
/// Snapshot of training health at one point in time; scores are in [0, 1].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAssessment {
    pub timestamp: SystemTime,
    /// Weighted blend of the component scores.
    pub overall_health_score: f64,
    /// Mean stability across the tracked metrics.
    pub training_stability_index: f64,
    /// Blend of loss/accuracy convergence signals.
    pub convergence_probability: f64,
    pub overfitting_risk: OverfittingRisk,
    pub generalization_score: f64,
    pub component_scores: ComponentHealthScores,
    /// Coarse grade derived from `overall_health_score`.
    pub health_status: HealthStatus,
    pub alerts: Vec<HealthAlert>,
    pub recommendations: Vec<HealthRecommendation>,
}
38
/// Per-subsystem health scores, each in [0, 1].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentHealthScores {
    pub gradient_health: f64,
    pub loss_health: f64,
    pub accuracy_health: f64,
    pub performance_health: f64,
    pub memory_health: f64,
    pub stability_health: f64,
}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub enum HealthStatus {
51 Excellent, Good, Fair, Poor, Critical, }
57
/// A single triggered health warning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthAlert {
    pub alert_type: HealthAlertType,
    pub severity: AlertSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// Observed value that triggered the alert.
    pub metric_value: f64,
    /// Threshold the value was compared against.
    pub threshold: f64,
    /// Direction the metric is judged to be moving.
    pub trend: Trend,
}
67
/// Category of a health alert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthAlertType {
    TrainingStability,
    ConvergenceIssue,
    OverfittingDetected,
    PerformanceDegradation,
    MemoryIssue,
    GradientProblem,
}
77
/// Alert severity ladder, most urgent first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlertSeverity {
    Critical,
    High,
    Medium,
    Low,
}
85
/// Direction a monitored quantity is moving over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Trend {
    Improving,
    Stable,
    Degrading,
    Volatile,
}
93
/// Actionable advice derived from alerts and the overall score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthRecommendation {
    pub category: RecommendationCategory,
    pub title: String,
    pub description: String,
    pub urgency: RecommendationUrgency,
    /// Estimated score improvement in [0, 1] if applied.
    pub expected_impact: f64,
}
102
/// Area of the training setup a recommendation targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationCategory {
    Training,
    Architecture,
    Hyperparameters,
    Data,
    Performance,
}
111
/// How soon a recommendation should be acted on, most urgent first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationUrgency {
    Immediate,
    Soon,
    Eventually,
    Optional,
}
119
/// Bundles the per-metric stability trackers used by `HealthChecker`.
#[derive(Debug)]
pub struct StabilityTracker {
    loss_stability: MetricStability,
    accuracy_stability: MetricStability,
    gradient_stability: MetricStability,
    learning_rate_stability: MetricStability,
    #[allow(dead_code)]
    window_size: usize,
}
130
/// Rolling window over one scalar metric, scoring how "stable" it looks from
/// its variance and the frequency of slope sign flips.
#[derive(Debug)]
pub struct MetricStability {
    values: VecDeque<f64>,
    variance_threshold: f64,
    #[allow(dead_code)]
    trend_threshold: f64,
}

impl MetricStability {
    /// Creates an empty tracker with the given variance/trend thresholds.
    pub fn new(variance_threshold: f64, trend_threshold: f64) -> Self {
        Self {
            values: VecDeque::new(),
            variance_threshold,
            trend_threshold,
        }
    }

    /// Records one sample, keeping only the 50 most recent values.
    pub fn update(&mut self, value: f64) {
        self.values.push_back(value);
        while self.values.len() > 50 {
            self.values.pop_front();
        }
    }

    /// Stability in [0, 1]: mean of a variance score and a trend score.
    /// Returns the neutral value 0.5 until at least 5 samples exist.
    pub fn calculate_stability(&self) -> f64 {
        if self.values.len() < 5 {
            return 0.5;
        }

        let variance = self.calculate_variance();
        let trend_score = self.calculate_trend_stability();

        let variance_score = if variance < self.variance_threshold {
            1.0
        } else {
            (self.variance_threshold / variance).min(1.0)
        };

        0.5 * (variance_score + trend_score)
    }

    /// Sample variance (n - 1 denominator); 0.0 below two samples.
    fn calculate_variance(&self) -> f64 {
        let n = self.values.len();
        if n < 2 {
            return 0.0;
        }

        let mean = self.values.iter().sum::<f64>() / n as f64;
        self.values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64
    }

    /// One minus the fraction of consecutive slope sign flips; neutral 0.5
    /// below 10 samples.
    fn calculate_trend_stability(&self) -> f64 {
        if self.values.len() < 10 {
            return 0.5;
        }

        let series: Vec<f64> = self.values.iter().copied().collect();
        let flips = series
            .windows(3)
            .filter(|w| ((w[1] - w[0]) > 0.0) != ((w[2] - w[1]) > 0.0))
            .count();

        let change_rate = flips as f64 / (series.len() - 2) as f64;
        (1.0 - change_rate).max(0.0)
    }
}
205
/// Estimates how likely training is to converge from recent loss/accuracy
/// behavior over a fixed observation window.
#[derive(Debug)]
pub struct ConvergenceAnalyzer {
    loss_history: VecDeque<f64>,
    accuracy_history: VecDeque<f64>,
    convergence_window: usize,
    convergence_threshold: f64,
}

impl Default for ConvergenceAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

impl ConvergenceAnalyzer {
    /// Analyzer with a 100-step window and a 1% relative-improvement threshold.
    pub fn new() -> Self {
        Self {
            loss_history: VecDeque::new(),
            accuracy_history: VecDeque::new(),
            convergence_window: 100,
            convergence_threshold: 0.01,
        }
    }

    /// Appends the latest samples, keeping at most two windows of each series.
    pub fn update(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
        let cap = self.convergence_window * 2;
        if let Some(l) = loss {
            Self::push_bounded(&mut self.loss_history, l, cap);
        }
        if let Some(a) = accuracy {
            Self::push_bounded(&mut self.accuracy_history, a, cap);
        }
    }

    /// Appends `value` and evicts the oldest entry once `cap` is exceeded.
    fn push_bounded(history: &mut VecDeque<f64>, value: f64, cap: usize) {
        history.push_back(value);
        if history.len() > cap {
            history.pop_front();
        }
    }

    /// Weighted blend of the loss (70%) and accuracy (30%) signals.
    pub fn calculate_convergence_probability(&self) -> f64 {
        0.7 * self.analyze_loss_convergence() + 0.3 * self.analyze_accuracy_convergence()
    }

    /// Scores loss convergence by comparing the newest half-window's mean loss
    /// against the half-window before it; coarse bucketed result.
    fn analyze_loss_convergence(&self) -> f64 {
        if self.loss_history.len() < self.convergence_window {
            return 0.3; // not enough data yet
        }

        let half = self.convergence_window / 2;
        let newest: Vec<f64> = self.loss_history.iter().rev().take(half).cloned().collect();
        let older: Vec<f64> =
            self.loss_history.iter().rev().skip(half).take(half).cloned().collect();

        if newest.is_empty() || older.is_empty() {
            return 0.3;
        }

        let newest_avg = newest.iter().sum::<f64>() / newest.len() as f64;
        let older_avg = older.iter().sum::<f64>() / older.len() as f64;

        if newest_avg < older_avg {
            // Loss still dropping: strong signal when the relative improvement
            // clears the threshold, moderate otherwise.
            let improvement_rate = (older_avg - newest_avg) / older_avg;
            if improvement_rate > self.convergence_threshold {
                0.8
            } else {
                0.6
            }
        } else if self.calculate_variance(&newest) < self.convergence_threshold {
            // Flat but steady: plausibly converged.
            0.4
        } else {
            // Flat and noisy: weak signal.
            0.2
        }
    }

    /// Scores accuracy convergence from the variance of the newest half-window.
    fn analyze_accuracy_convergence(&self) -> f64 {
        if self.accuracy_history.len() < self.convergence_window {
            return 0.5;
        }

        let half = self.convergence_window / 2;
        let newest: Vec<f64> =
            self.accuracy_history.iter().rev().take(half).cloned().collect();

        match self.calculate_variance(&newest) {
            v if v < 0.01 => 0.8,
            v if v < 0.05 => 0.6,
            _ => 0.4,
        }
    }

    /// Sample variance (n - 1 denominator); 0.0 below two values.
    fn calculate_variance(&self, values: &[f64]) -> f64 {
        let n = values.len();
        if n < 2 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / n as f64;
        values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (n - 1) as f64
    }
}
332
/// Tracks train vs. validation loss/accuracy (last 100 samples each) to
/// estimate overfitting risk from the gap between the two splits.
#[derive(Debug)]
pub struct OverfittingDetector {
    train_loss_history: VecDeque<f64>,
    val_loss_history: VecDeque<f64>,
    train_accuracy_history: VecDeque<f64>,
    val_accuracy_history: VecDeque<f64>,
    // Relative gap treated as "fully overfitting" when scaling gap scores.
    overfitting_threshold: f64,
}
342
/// Qualitative overfitting risk level, from none to severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OverfittingRisk {
    None,
    Low,
    Medium,
    High,
    Severe,
}
351
impl Default for OverfittingDetector {
    /// Same as [`OverfittingDetector::new`].
    fn default() -> Self {
        Self::new()
    }
}
357
358impl OverfittingDetector {
359 pub fn new() -> Self {
360 Self {
361 train_loss_history: VecDeque::new(),
362 val_loss_history: VecDeque::new(),
363 train_accuracy_history: VecDeque::new(),
364 val_accuracy_history: VecDeque::new(),
365 overfitting_threshold: 0.1,
366 }
367 }
368
369 pub fn update_train_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
370 if let Some(loss) = loss {
371 self.train_loss_history.push_back(loss);
372 if self.train_loss_history.len() > 100 {
373 self.train_loss_history.pop_front();
374 }
375 }
376
377 if let Some(accuracy) = accuracy {
378 self.train_accuracy_history.push_back(accuracy);
379 if self.train_accuracy_history.len() > 100 {
380 self.train_accuracy_history.pop_front();
381 }
382 }
383 }
384
385 pub fn update_validation_metrics(&mut self, loss: Option<f64>, accuracy: Option<f64>) {
386 if let Some(loss) = loss {
387 self.val_loss_history.push_back(loss);
388 if self.val_loss_history.len() > 100 {
389 self.val_loss_history.pop_front();
390 }
391 }
392
393 if let Some(accuracy) = accuracy {
394 self.val_accuracy_history.push_back(accuracy);
395 if self.val_accuracy_history.len() > 100 {
396 self.val_accuracy_history.pop_front();
397 }
398 }
399 }
400
401 pub fn detect_overfitting(&self) -> OverfittingRisk {
402 let loss_gap = self.calculate_loss_gap();
403 let accuracy_gap = self.calculate_accuracy_gap();
404 let trend_analysis = self.analyze_overfitting_trend();
405
406 let overfitting_score = (loss_gap + accuracy_gap + trend_analysis) / 3.0;
407
408 match overfitting_score {
409 score if score > 0.8 => OverfittingRisk::Severe,
410 score if score > 0.6 => OverfittingRisk::High,
411 score if score > 0.4 => OverfittingRisk::Medium,
412 score if score > 0.2 => OverfittingRisk::Low,
413 _ => OverfittingRisk::None,
414 }
415 }
416
417 fn calculate_loss_gap(&self) -> f64 {
418 if self.train_loss_history.len() < 10 || self.val_loss_history.len() < 10 {
419 return 0.0;
420 }
421
422 let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
423
424 let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
425
426 if recent_train_loss < recent_val_loss {
427 let gap = (recent_val_loss - recent_train_loss) / recent_train_loss;
428 (gap / self.overfitting_threshold).min(1.0)
429 } else {
430 0.0
431 }
432 }
433
434 fn calculate_accuracy_gap(&self) -> f64 {
435 if self.train_accuracy_history.len() < 10 || self.val_accuracy_history.len() < 10 {
436 return 0.0;
437 }
438
439 let recent_train_acc =
440 self.train_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
441
442 let recent_val_acc = self.val_accuracy_history.iter().rev().take(10).sum::<f64>() / 10.0;
443
444 if recent_train_acc > recent_val_acc {
445 let gap = recent_train_acc - recent_val_acc;
446 (gap / self.overfitting_threshold).min(1.0)
447 } else {
448 0.0
449 }
450 }
451
452 fn analyze_overfitting_trend(&self) -> f64 {
453 if self.train_loss_history.len() < 20 || self.val_loss_history.len() < 20 {
455 return 0.0;
456 }
457
458 let early_train_loss = self.train_loss_history.iter().take(10).sum::<f64>() / 10.0;
459
460 let recent_train_loss = self.train_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
461
462 let early_val_loss = self.val_loss_history.iter().take(10).sum::<f64>() / 10.0;
463
464 let recent_val_loss = self.val_loss_history.iter().rev().take(10).sum::<f64>() / 10.0;
465
466 let early_gap = (early_val_loss - early_train_loss).max(0.0);
467 let recent_gap = (recent_val_loss - recent_train_loss).max(0.0);
468
469 if recent_gap > early_gap && early_gap > 0.0 {
470 ((recent_gap - early_gap) / early_gap).min(1.0)
471 } else {
472 0.0
473 }
474 }
475}
476
/// Estimates generalization by combining train/holdout agreement,
/// cross-validation consistency, and a capacity-vs-data complexity penalty.
#[derive(Debug)]
pub struct GeneralizationMonitor {
    cross_validation_scores: Vec<f64>,
    holdout_performance: Option<f64>,
    train_performance: Option<f64>,
    complexity_metrics: ComplexityMetrics,
}

/// Rough model-capacity bookkeeping feeding the complexity penalty.
#[derive(Debug)]
pub struct ComplexityMetrics {
    parameter_count: usize,
    #[allow(dead_code)]
    effective_capacity: f64,
    data_size: usize,
}

impl Default for GeneralizationMonitor {
    fn default() -> Self {
        Self::new()
    }
}

impl GeneralizationMonitor {
    /// Monitor with no recorded performance and zeroed complexity metrics.
    pub fn new() -> Self {
        Self {
            cross_validation_scores: Vec::new(),
            holdout_performance: None,
            train_performance: None,
            complexity_metrics: ComplexityMetrics {
                parameter_count: 0,
                effective_capacity: 0.0,
                data_size: 0,
            },
        }
    }

    /// Records the latest train performance and, when given, holdout performance.
    pub fn update_performance(&mut self, train_perf: f64, val_perf: Option<f64>) {
        self.train_performance = Some(train_perf);
        if val_perf.is_some() {
            self.holdout_performance = val_perf;
        }
    }

    /// Mean of the two consistency signals and the inverted complexity penalty.
    pub fn calculate_generalization_score(&self) -> f64 {
        let consistency = self.calculate_performance_consistency();
        let penalty = self.calculate_complexity_penalty();
        let cv = self.calculate_cv_consistency();

        (consistency + cv + (1.0 - penalty)) / 3.0
    }

    /// One minus the absolute train/holdout gap; neutral 0.5 when either side
    /// is missing.
    fn calculate_performance_consistency(&self) -> f64 {
        if let (Some(train), Some(val)) = (self.train_performance, self.holdout_performance) {
            let gap = (train - val).abs();
            (1.0 - gap.min(1.0)).max(0.0)
        } else {
            0.5
        }
    }

    /// Penalty that steps up with parameters per training sample; 0.0 when
    /// the data size is unknown (also avoids a division by zero).
    fn calculate_complexity_penalty(&self) -> f64 {
        if self.complexity_metrics.data_size == 0 {
            return 0.0;
        }

        let params_per_sample = self.complexity_metrics.parameter_count as f64
            / self.complexity_metrics.data_size as f64;

        if params_per_sample > 1.0 {
            0.8
        } else if params_per_sample > 0.1 {
            0.4
        } else {
            0.1
        }
    }

    /// One minus the standard deviation of CV scores (clamped to [0, 1]);
    /// neutral 0.5 with fewer than three folds.
    fn calculate_cv_consistency(&self) -> f64 {
        let scores = &self.cross_validation_scores;
        if scores.len() < 3 {
            return 0.5;
        }

        let mean = scores.iter().sum::<f64>() / scores.len() as f64;
        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>()
            / (scores.len() - 1) as f64;

        (1.0 - variance.sqrt().min(1.0)).max(0.0)
    }
}
570
/// Reference metrics captured at a known-good point, for later comparison.
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
    pub baseline_loss: f64,
    pub baseline_accuracy: f64,
    pub baseline_training_time: Duration,
    pub baseline_memory_usage: f64,
    pub established_at: SystemTime,
}
580
581impl HealthChecker {
582 pub fn new(config: &DebugConfig) -> Self {
584 Self {
585 config: config.clone(),
586 metrics_history: VecDeque::new(),
587 health_assessments: Vec::new(),
588 stability_tracker: StabilityTracker {
589 loss_stability: MetricStability::new(0.1, 0.05),
590 accuracy_stability: MetricStability::new(0.01, 0.02),
591 gradient_stability: MetricStability::new(1.0, 0.1),
592 learning_rate_stability: MetricStability::new(0.0001, 0.001),
593 window_size: 50,
594 },
595 convergence_analyzer: ConvergenceAnalyzer::new(),
596 overfitting_detector: OverfittingDetector::new(),
597 generalization_monitor: GeneralizationMonitor::new(),
598 performance_baseline: None,
599 }
600 }
601
    /// Feeds one dashboard sample into every sub-tracker.
    pub fn update(&mut self, metrics: DashboardMetrics) {
        self.metrics_history.push_back(metrics.clone());

        // Cap raw history at 1000 samples.
        if self.metrics_history.len() > 1000 {
            self.metrics_history.pop_front();
        }

        // Route each present metric to its stability tracker.
        if let Some(loss) = metrics.loss {
            self.stability_tracker.loss_stability.update(loss);
        }
        if let Some(accuracy) = metrics.accuracy {
            self.stability_tracker.accuracy_stability.update(accuracy);
        }
        if let Some(grad_norm) = metrics.gradient_norm {
            self.stability_tracker.gradient_stability.update(grad_norm);
        }
        if let Some(lr) = metrics.learning_rate {
            self.stability_tracker.learning_rate_stability.update(lr);
        }

        self.convergence_analyzer.update(metrics.loss, metrics.accuracy);

        // Only the train side is fed here; validation metrics must reach the
        // detector through `update_validation_metrics` elsewhere.
        self.overfitting_detector.update_train_metrics(metrics.loss, metrics.accuracy);

        // NOTE(review): treats accuracy as train performance and (1 - loss) as
        // a holdout proxy — confirm this approximation is intended.
        if let (Some(accuracy), Some(loss)) = (metrics.accuracy, metrics.loss) {
            self.generalization_monitor.update_performance(accuracy, Some(1.0 - loss));
        }
    }
636
637 pub fn assess_health(&mut self) -> Result<HealthAssessment> {
639 let overall_health_score = self.calculate_overall_health_score();
640 let training_stability_index = self.calculate_training_stability_index();
641 let convergence_probability = self.convergence_analyzer.calculate_convergence_probability();
642 let overfitting_risk = self.overfitting_detector.detect_overfitting();
643 let generalization_score = self.generalization_monitor.calculate_generalization_score();
644
645 let component_scores = self.calculate_component_scores();
646 let health_status = self.determine_health_status(overall_health_score);
647 let alerts = self.generate_health_alerts();
648 let recommendations = self.generate_health_recommendations(&alerts);
649
650 let assessment = HealthAssessment {
651 timestamp: SystemTime::now(),
652 overall_health_score,
653 training_stability_index,
654 convergence_probability,
655 overfitting_risk,
656 generalization_score,
657 component_scores,
658 health_status,
659 alerts,
660 recommendations,
661 };
662
663 self.health_assessments.push(assessment.clone());
664
665 if self.health_assessments.len() > 100 {
667 self.health_assessments.drain(0..50);
668 }
669
670 Ok(assessment)
671 }
672
673 fn calculate_overall_health_score(&self) -> f64 {
674 let component_scores = self.calculate_component_scores();
675
676 let weights = [
678 ("stability", 0.25),
679 ("convergence", 0.20),
680 ("gradient", 0.15),
681 ("loss", 0.15),
682 ("accuracy", 0.10),
683 ("performance", 0.10),
684 ("memory", 0.05),
685 ];
686
687 let mut weighted_sum = 0.0;
688 weighted_sum += weights[0].1 * component_scores.stability_health;
689 weighted_sum +=
690 weights[1].1 * self.convergence_analyzer.calculate_convergence_probability();
691 weighted_sum += weights[2].1 * component_scores.gradient_health;
692 weighted_sum += weights[3].1 * component_scores.loss_health;
693 weighted_sum += weights[4].1 * component_scores.accuracy_health;
694 weighted_sum += weights[5].1 * component_scores.performance_health;
695 weighted_sum += weights[6].1 * component_scores.memory_health;
696
697 weighted_sum
698 }
699
700 fn calculate_training_stability_index(&self) -> f64 {
701 let loss_stability = self.stability_tracker.loss_stability.calculate_stability();
702 let accuracy_stability = self.stability_tracker.accuracy_stability.calculate_stability();
703 let gradient_stability = self.stability_tracker.gradient_stability.calculate_stability();
704 let lr_stability = self.stability_tracker.learning_rate_stability.calculate_stability();
705
706 (loss_stability + accuracy_stability + gradient_stability + lr_stability) / 4.0
707 }
708
    /// Gathers the per-component scores feeding the weighted overall score.
    fn calculate_component_scores(&self) -> ComponentHealthScores {
        ComponentHealthScores {
            gradient_health: self.stability_tracker.gradient_stability.calculate_stability(),
            loss_health: self.calculate_loss_health(),
            accuracy_health: self.calculate_accuracy_health(),
            performance_health: self.calculate_performance_health(),
            memory_health: self.calculate_memory_health(),
            stability_health: self.calculate_training_stability_index(),
        }
    }
719
720 fn calculate_loss_health(&self) -> f64 {
721 if self.metrics_history.len() < 10 {
722 return 0.5;
723 }
724
725 let recent_losses: Vec<f64> =
726 self.metrics_history.iter().rev().take(10).filter_map(|m| m.loss).collect();
727
728 if recent_losses.is_empty() {
729 return 0.5;
730 }
731
732 let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
734 / (recent_losses.len() / 2) as f64;
735 let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
736 / (recent_losses.len() - recent_losses.len() / 2) as f64;
737
738 if second_half_avg < first_half_avg {
739 0.8 } else if (second_half_avg - first_half_avg).abs() / first_half_avg < 0.05 {
741 0.6 } else {
743 0.3 }
745 }
746
747 fn calculate_accuracy_health(&self) -> f64 {
748 if self.metrics_history.len() < 10 {
749 return 0.5;
750 }
751
752 let recent_accuracies: Vec<f64> =
753 self.metrics_history.iter().rev().take(10).filter_map(|m| m.accuracy).collect();
754
755 if recent_accuracies.is_empty() {
756 return 0.5;
757 }
758
759 let avg_accuracy = recent_accuracies.iter().sum::<f64>() / recent_accuracies.len() as f64;
760
761 let accuracy_score = avg_accuracy; let stability_score = self.stability_tracker.accuracy_stability.calculate_stability();
764
765 (accuracy_score + stability_score) / 2.0
766 }
767
768 fn calculate_performance_health(&self) -> f64 {
769 if let Some(last_metrics) = self.metrics_history.back() {
771 let mut score = 0.0;
772 let mut components = 0;
773
774 if let Some(tps) = last_metrics.tokens_per_second {
775 score += if tps > 100.0 { 0.8 } else { tps / 125.0 };
776 components += 1;
777 }
778
779 if let Some(gpu_util) = last_metrics.gpu_utilization {
780 score += gpu_util;
781 components += 1;
782 }
783
784 if components > 0 {
785 score / components as f64
786 } else {
787 0.5
788 }
789 } else {
790 0.5
791 }
792 }
793
794 fn calculate_memory_health(&self) -> f64 {
795 if let Some(last_metrics) = self.metrics_history.back() {
796 let memory_usage = last_metrics.memory_usage_mb;
797
798 if memory_usage < 4096.0 {
801 0.9
802 } else if memory_usage < 6144.0 {
803 0.7
804 } else if memory_usage < 8192.0 {
805 0.5
806 } else {
807 0.2
808 }
809 } else {
810 0.5
811 }
812 }
813
814 fn determine_health_status(&self, score: f64) -> HealthStatus {
815 match score {
816 s if s >= 0.9 => HealthStatus::Excellent,
817 s if s >= 0.75 => HealthStatus::Good,
818 s if s >= 0.6 => HealthStatus::Fair,
819 s if s >= 0.4 => HealthStatus::Poor,
820 _ => HealthStatus::Critical,
821 }
822 }
823
    /// Emits alerts for low stability, low convergence probability, and
    /// medium-or-worse overfitting risk.
    fn generate_health_alerts(&self) -> Vec<HealthAlert> {
        let mut alerts = Vec::new();

        let stability_index = self.calculate_training_stability_index();
        if stability_index < 0.3 {
            alerts.push(HealthAlert {
                alert_type: HealthAlertType::TrainingStability,
                severity: AlertSeverity::High,
                message: "Training is highly unstable".to_string(),
                metric_value: stability_index,
                threshold: 0.3,
                trend: Trend::Degrading,
            });
        }

        let convergence_prob = self.convergence_analyzer.calculate_convergence_probability();
        if convergence_prob < 0.2 {
            alerts.push(HealthAlert {
                alert_type: HealthAlertType::ConvergenceIssue,
                severity: AlertSeverity::Medium,
                message: "Low probability of convergence".to_string(),
                metric_value: convergence_prob,
                threshold: 0.2,
                trend: Trend::Stable,
            });
        }

        // metric_value/threshold below are representative constants, not the
        // detector's internal score (which it does not expose).
        match self.overfitting_detector.detect_overfitting() {
            OverfittingRisk::High | OverfittingRisk::Severe => {
                alerts.push(HealthAlert {
                    alert_type: HealthAlertType::OverfittingDetected,
                    severity: AlertSeverity::High,
                    message: "Significant overfitting detected".to_string(),
                    metric_value: 0.8,
                    threshold: 0.6,
                    trend: Trend::Degrading,
                });
            },
            OverfittingRisk::Medium => {
                alerts.push(HealthAlert {
                    alert_type: HealthAlertType::OverfittingDetected,
                    severity: AlertSeverity::Medium,
                    message: "Moderate overfitting risk".to_string(),
                    metric_value: 0.5,
                    threshold: 0.4,
                    trend: Trend::Stable,
                });
            },
            _ => {},
        }

        alerts
    }
880
    /// Maps alert types to actionable recommendations, plus a blanket review
    /// recommendation when the overall score falls below 0.6.
    fn generate_health_recommendations(&self, alerts: &[HealthAlert]) -> Vec<HealthRecommendation> {
        let mut recommendations = Vec::new();

        for alert in alerts {
            match alert.alert_type {
                HealthAlertType::TrainingStability => {
                    recommendations.push(HealthRecommendation {
                        category: RecommendationCategory::Training,
                        title: "Improve Training Stability".to_string(),
                        description:
                            "Reduce learning rate or increase batch size to stabilize training"
                                .to_string(),
                        urgency: RecommendationUrgency::Soon,
                        expected_impact: 0.3,
                    });
                },
                HealthAlertType::ConvergenceIssue => {
                    recommendations.push(HealthRecommendation {
                        category: RecommendationCategory::Hyperparameters,
                        title: "Adjust Learning Rate Schedule".to_string(),
                        description:
                            "Implement learning rate scheduling or adjust optimizer settings"
                                .to_string(),
                        urgency: RecommendationUrgency::Eventually,
                        expected_impact: 0.2,
                    });
                },
                HealthAlertType::OverfittingDetected => {
                    recommendations.push(HealthRecommendation {
                        category: RecommendationCategory::Training,
                        title: "Add Regularization".to_string(),
                        description: "Implement dropout, weight decay, or early stopping to reduce overfitting".to_string(),
                        urgency: RecommendationUrgency::Soon,
                        expected_impact: 0.25,
                    });
                },
                // Performance/memory/gradient alerts currently have no mapped
                // recommendation.
                _ => {},
            }
        }

        let overall_score = self.calculate_overall_health_score();
        if overall_score < 0.6 {
            recommendations.push(HealthRecommendation {
                category: RecommendationCategory::Training,
                title: "Comprehensive Training Review".to_string(),
                description: "Review entire training setup including data, model architecture, and hyperparameters".to_string(),
                urgency: RecommendationUrgency::Immediate,
                expected_impact: 0.4,
            });
        }

        recommendations
    }
935
    /// Stores a reference point for later baseline comparisons.
    pub fn set_baseline(&mut self, baseline: PerformanceBaseline) {
        self.performance_baseline = Some(baseline);
    }
940
    /// Read-only view of all retained assessments, oldest first.
    pub fn get_health_history(&self) -> &[HealthAssessment] {
        &self.health_assessments
    }
945
946 pub async fn quick_health_check(&self) -> Result<crate::QuickHealthSummary> {
948 let score = if let Some(assessment) = self.health_assessments.last() {
949 assessment.overall_health_score * 100.0
950 } else {
951 50.0 };
954
955 let status = match score {
956 90.0..=100.0 => "Excellent",
957 75.0..89.9 => "Good",
958 60.0..74.9 => "Fair",
959 40.0..59.9 => "Poor",
960 _ => "Critical",
961 }
962 .to_string();
963
964 let mut recommendations = Vec::new();
965 if score < 60.0 {
966 recommendations.push("Review training configuration and data quality".to_string());
967 }
968 if score < 40.0 {
969 recommendations
970 .push("Consider adjusting learning rate and model architecture".to_string());
971 }
972 if score < 80.0 {
973 recommendations.push("Monitor training stability and convergence".to_string());
974 }
975
976 Ok(crate::QuickHealthSummary {
977 score,
978 status,
979 recommendations,
980 })
981 }
982
    /// Builds a full report around the most recent assessment; returns the
    /// neutral `HealthReport::default()` when none exists yet.
    pub async fn generate_report(&self) -> Result<HealthReport> {
        let current_assessment = if let Some(assessment) = self.health_assessments.last() {
            assessment.clone()
        } else {
            return Ok(HealthReport::default());
        };

        let health_trends = self.analyze_health_trends();
        let risk_assessment = self.assess_risks();
        let improvement_suggestions = self.generate_improvement_suggestions();

        Ok(HealthReport {
            current_health: current_assessment,
            health_trends,
            risk_assessment,
            improvement_suggestions,
            baseline_comparison: self.compare_with_baseline(),
            summary: self.generate_health_summary(),
        })
    }
1004
    /// Classifies the overall-score trend over the last up-to-10 assessments;
    /// the per-component sub-trends are currently static placeholders.
    fn analyze_health_trends(&self) -> HealthTrends {
        if self.health_assessments.len() < 5 {
            return HealthTrends::default();
        }

        // Newest first because of `.rev()`.
        let recent_scores: Vec<f64> = self
            .health_assessments
            .iter()
            .rev()
            .take(10)
            .map(|a| a.overall_health_score)
            .collect();

        // With newest-first ordering, `first_half` is the OLDER half and
        // `second_half` the newer half of the window.
        let first_half_avg = recent_scores[recent_scores.len() / 2..].iter().sum::<f64>()
            / (recent_scores.len() - recent_scores.len() / 2) as f64;
        let second_half_avg = recent_scores[..recent_scores.len() / 2].iter().sum::<f64>()
            / (recent_scores.len() / 2) as f64;

        // A 5% relative move is required to call the trend non-stable.
        let trend = if second_half_avg > first_half_avg * 1.05 {
            Trend::Improving
        } else if second_half_avg < first_half_avg * 0.95 {
            Trend::Degrading
        } else {
            Trend::Stable
        };

        HealthTrends {
            overall_trend: trend,
            stability_trend: Trend::Stable,
            convergence_trend: Trend::Stable,
            overfitting_trend: Trend::Stable,
        }
    }
1038
    /// Derives concrete risk entries (poor overall health, overfitting,
    /// likely non-convergence) from the latest assessment; empty when no
    /// assessment exists yet.
    fn assess_risks(&self) -> Vec<HealthRisk> {
        let mut risks = Vec::new();

        if let Some(current) = self.health_assessments.last() {
            if current.overall_health_score < 0.4 {
                risks.push(HealthRisk {
                    risk_type: "Poor Overall Health".to_string(),
                    probability: 0.9,
                    impact: 0.8,
                    description: "Model training is in poor health and may fail".to_string(),
                });
            }

            match current.overfitting_risk {
                OverfittingRisk::High | OverfittingRisk::Severe => {
                    risks.push(HealthRisk {
                        risk_type: "Overfitting".to_string(),
                        probability: 0.8,
                        impact: 0.6,
                        description: "Model is likely overfitting and will generalize poorly"
                            .to_string(),
                    });
                },
                _ => {},
            }

            if current.convergence_probability < 0.3 {
                risks.push(HealthRisk {
                    risk_type: "Training Failure".to_string(),
                    probability: 0.7,
                    impact: 0.9,
                    description: "Training may not converge to a useful solution".to_string(),
                });
            }
        }

        risks
    }
1077
    /// Suggests fixes for weak stability, weak convergence, and elevated
    /// overfitting risk in the latest assessment; empty when none exists.
    fn generate_improvement_suggestions(&self) -> Vec<ImprovementSuggestion> {
        let mut suggestions = Vec::new();

        if let Some(current) = self.health_assessments.last() {
            if current.component_scores.stability_health < 0.5 {
                suggestions.push(ImprovementSuggestion {
                    area: "Training Stability".to_string(),
                    suggestion: "Reduce learning rate and increase batch size".to_string(),
                    expected_improvement: 0.3,
                    implementation_effort: "Low".to_string(),
                });
            }

            if current.convergence_probability < 0.5 {
                suggestions.push(ImprovementSuggestion {
                    area: "Convergence".to_string(),
                    suggestion: "Implement learning rate scheduling and gradient clipping"
                        .to_string(),
                    expected_improvement: 0.25,
                    implementation_effort: "Medium".to_string(),
                });
            }

            match current.overfitting_risk {
                OverfittingRisk::Medium | OverfittingRisk::High | OverfittingRisk::Severe => {
                    suggestions.push(ImprovementSuggestion {
                        area: "Overfitting Prevention".to_string(),
                        suggestion: "Add dropout layers, implement early stopping, or increase training data".to_string(),
                        expected_improvement: 0.4,
                        implementation_effort: "Medium".to_string(),
                    });
                },
                _ => {},
            }
        }

        suggestions
    }
1116
    /// Compares the latest assessment against fixed reference values; `None`
    /// when no baseline is set or no assessment exists.
    ///
    /// NOTE(review): the stored `PerformanceBaseline` is ignored (`_baseline`)
    /// and 0.8 / 0.7 / 0.6 appear to be placeholder baselines — confirm
    /// whether the real baseline fields should be used here.
    fn compare_with_baseline(&self) -> Option<BaselineComparison> {
        if let (Some(_baseline), Some(current)) =
            (&self.performance_baseline, self.health_assessments.last())
        {
            Some(BaselineComparison {
                health_score_change: current.overall_health_score - 0.8,
                stability_change: current.training_stability_index - 0.7,
                convergence_change: current.convergence_probability - 0.6,
                improvement_percentage: ((current.overall_health_score - 0.8) / 0.8 * 100.0)
                    .max(-100.0),
            })
        } else {
            None
        }
    }
1132
1133 fn generate_health_summary(&self) -> String {
1134 if let Some(current) = self.health_assessments.last() {
1135 match current.health_status {
1136 HealthStatus::Excellent => "Training is in excellent health with stable convergence and no significant issues detected.".to_string(),
1137 HealthStatus::Good => "Training is proceeding well with minor optimization opportunities.".to_string(),
1138 HealthStatus::Fair => "Training shows some concerning patterns that should be addressed.".to_string(),
1139 HealthStatus::Poor => "Training has significant issues requiring immediate attention.".to_string(),
1140 HealthStatus::Critical => "Training is in critical condition and may fail without intervention.".to_string(),
1141 }
1142 } else {
1143 "Insufficient data for health assessment.".to_string()
1144 }
1145 }
1146}
1147
/// Full health report: latest assessment plus trends, risks, suggestions,
/// an optional baseline comparison, and a one-line summary.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthReport {
    pub current_health: HealthAssessment,
    pub health_trends: HealthTrends,
    pub risk_assessment: Vec<HealthRisk>,
    pub improvement_suggestions: Vec<ImprovementSuggestion>,
    pub baseline_comparison: Option<BaselineComparison>,
    pub summary: String,
}
1159
impl Default for HealthReport {
    /// Neutral placeholder report (all scores 0.5, no alerts) used before
    /// any assessment has been produced.
    fn default() -> Self {
        Self {
            current_health: HealthAssessment {
                timestamp: SystemTime::now(),
                overall_health_score: 0.5,
                training_stability_index: 0.5,
                convergence_probability: 0.5,
                overfitting_risk: OverfittingRisk::None,
                generalization_score: 0.5,
                component_scores: ComponentHealthScores {
                    gradient_health: 0.5,
                    loss_health: 0.5,
                    accuracy_health: 0.5,
                    performance_health: 0.5,
                    memory_health: 0.5,
                    stability_health: 0.5,
                },
                health_status: HealthStatus::Fair,
                alerts: Vec::new(),
                recommendations: Vec::new(),
            },
            health_trends: HealthTrends::default(),
            risk_assessment: Vec::new(),
            improvement_suggestions: Vec::new(),
            baseline_comparison: None,
            summary: "No health data available yet.".to_string(),
        }
    }
}
1190
/// Direction of each tracked health dimension over recent assessments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthTrends {
    pub overall_trend: Trend,
    pub stability_trend: Trend,
    pub convergence_trend: Trend,
    pub overfitting_trend: Trend,
}
1198
impl Default for HealthTrends {
    /// All dimensions start as `Stable`.
    fn default() -> Self {
        Self {
            overall_trend: Trend::Stable,
            stability_trend: Trend::Stable,
            convergence_trend: Trend::Stable,
            overfitting_trend: Trend::Stable,
        }
    }
}
1209
/// A specific risk surfaced from an assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthRisk {
    pub risk_type: String,
    /// Estimated likelihood in [0, 1].
    pub probability: f64,
    /// Estimated severity in [0, 1] if it materializes.
    pub impact: f64,
    pub description: String,
}
1217
/// A concrete improvement idea with its expected payoff and cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImprovementSuggestion {
    pub area: String,
    pub suggestion: String,
    /// Estimated score improvement in [0, 1].
    pub expected_improvement: f64,
    /// Free-form effort label (e.g. "Low", "Medium").
    pub implementation_effort: String,
}
1225
/// Deltas of the current assessment relative to baseline reference values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineComparison {
    pub health_score_change: f64,
    pub stability_change: f64,
    pub convergence_change: f64,
    /// Relative overall-score change in percent, floored at -100.
    pub improvement_percentage: f64,
}