1use anyhow::Result;
7use log;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10use trustformers_core::errors::runtime_error;
11use trustformers_core::tensor::Tensor;
12
/// Configuration for the advanced training-stability monitor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedStabilityConfig {
    /// Enable forecasting of future anomalies (gradient explosion, stagnation, ...).
    pub predictive_detection: bool,
    /// Enable automatic application of preventive actions before anomalies hit.
    pub proactive_recovery: bool,
    /// Enable per-step analysis of training dynamics (trends, oscillation, ...).
    pub dynamics_analysis: bool,
    /// Enable loss-landscape analysis (curvature, basin stability, ...).
    pub loss_landscape_monitoring: bool,
    /// How many steps ahead predictions may look.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub prediction_horizon: usize,
    /// Minimum confidence in [0, 1] before a prediction is emitted/acted on.
    pub prediction_confidence_threshold: f32,
    /// Rolling-window length used to bound every history kept by the monitor.
    pub pattern_window_size: usize,
    /// Overall-score threshold below which training counts as unstable.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub stability_threshold: f32,
    /// Enable adaptive selection of recovery actions.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub adaptive_recovery: bool,
}
35
impl Default for AdvancedStabilityConfig {
    /// Defaults: all analyses enabled, 10-step horizon, 0.7 confidence gate,
    /// 50-sample rolling windows, 0.8 stability threshold.
    fn default() -> Self {
        Self {
            predictive_detection: true,
            proactive_recovery: true,
            dynamics_analysis: true,
            loss_landscape_monitoring: true,
            prediction_horizon: 10,
            prediction_confidence_threshold: 0.7,
            pattern_window_size: 50,
            stability_threshold: 0.8,
            adaptive_recovery: true,
        }
    }
}
51
/// Snapshot of the training dynamics derived from the recent histories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamics {
    /// Direction the loss has been moving in.
    pub loss_trend: TrendDirection,
    /// Direction the gradient norm has been moving in.
    pub gradient_trend: TrendDirection,
    /// Heuristic score in [0, 1] for how well LR changes track loss improvements.
    pub lr_effectiveness: f32,
    /// Normalized speed of loss decrease (0 when not improving).
    pub convergence_velocity: f32,
    /// Fraction of recent steps where the loss changed direction.
    pub oscillation_frequency: f32,
    /// (loss, gradient-norm) pairs over the rolling window.
    pub phase_trajectory: Vec<(f32, f32)>,
}
68
/// Coarse classification of a metric's recent direction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrendDirection {
    /// Metric is clearly falling.
    Decreasing,
    /// Metric is clearly rising.
    Increasing,
    /// No clear movement.
    Stable,
    /// High variance / frequent direction changes.
    Oscillating,
    /// Runaway growth (treated as the highest-risk trend by predictors).
    Diverging,
}
77
/// A forecasted (not yet observed) training anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictiveAnomaly {
    /// Absolute step at which the anomaly is expected to occur.
    pub predicted_step: usize,
    /// Kind of anomaly being forecast.
    pub anomaly_type: PredictedAnomalyType,
    /// Prediction confidence in [0, 1].
    pub confidence: f32,
    /// Steps remaining until the predicted occurrence.
    pub time_to_occurrence: usize,
    /// Actions that may prevent or soften the anomaly.
    pub preventive_actions: Vec<PreventiveAction>,
    /// Severity classification of the prediction.
    pub risk_level: RiskLevel,
}
94
/// Categories of training anomalies the monitor can forecast.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PredictedAnomalyType {
    GradientExplosion,
    GradientVanishing,
    TrainingStagnation,
    ConvergenceFailure,
    NumericalInstability,
    OscillatingLoss,
    MemoryExhaustion,
    LearningRateDeterioration,
}
106
/// Mitigations that can be applied before a predicted anomaly materializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreventiveAction {
    /// Multiply the learning rate by `factor` (< 1.0 reduces it).
    ReduceLearningRate {
        factor: f32,
    },
    /// Tighten gradient clipping to `new_threshold`.
    IncreaseGradientClipping {
        new_threshold: f32,
    },
    /// Suggest new optimizer hyperparameters keyed by name.
    AdjustOptimizer {
        suggested_params: HashMap<String, f32>,
    },
    /// Save a checkpoint immediately so a rollback point exists.
    TriggerEarlyCheckpoint,
    /// Change the batch size to `new_size`.
    ModifyBatchSize {
        new_size: usize,
    },
    /// Re-tune the warmup schedule (details left to the trainer).
    AdjustWarmupSchedule,
    /// Inject noise of the given magnitude (used here to escape plateaus).
    EnableNoise {
        noise_level: f32,
    },
    /// Clear accumulated gradients.
    ResetAccumulatedGradients,
}
128
/// Severity buckets for predicted anomalies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}
136
/// Heuristic characterization of the local loss landscape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossLandscapeAnalysis {
    /// Finite-difference curvature estimate (clamped to [0, 10]).
    pub local_curvature: f32,
    /// Agreement of per-parameter gradient norms, in [0, 1].
    pub gradient_consistency: f32,
    /// Difficulty of escaping the current region, in [0, 1].
    pub escape_difficulty: f32,
    /// Stability of the current loss basin, in [0, 1].
    pub basin_stability: f32,
    /// Heuristic probability of being near a saddle point, in [0, 1].
    pub saddle_point_prob: f32,
}
151
/// Component-wise stability assessment for one training step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityScore {
    /// Unweighted mean of the four component scores, in [0, 1].
    pub overall_score: f32,
    /// Stability of the gradient-norm history.
    pub gradient_stability: f32,
    /// Stability of the loss history.
    pub loss_stability: f32,
    /// Quality of convergence (velocity vs. oscillation).
    pub convergence_stability: f32,
    /// Numerical health (NaN/Inf and magnitude checks).
    pub numerical_stability: f32,
    /// Human-readable suggestions derived from the scores.
    pub recommendations: Vec<String>,
}
168
/// Stateful monitor that tracks rolling training metrics, scores stability,
/// and forecasts anomalies with suggested preventive actions.
#[allow(dead_code)]
pub struct AdvancedStabilityMonitor {
    config: AdvancedStabilityConfig,
    // Rolling metric windows, each bounded by `config.pattern_window_size`.
    loss_history: VecDeque<f32>,
    gradient_history: VecDeque<f32>,
    lr_history: VecDeque<f32>,
    // Per-step dynamics snapshots appended by `analyze_step`.
    dynamics_history: Vec<TrainingDynamics>,
    // Outstanding anomaly forecasts appended by `analyze_step`.
    predicted_anomalies: Vec<PredictiveAnomaly>,
    landscape_analyses: VecDeque<LossLandscapeAnalysis>,
    stability_scores: VecDeque<StabilityScore>,
    // NOTE(review): never written in this file, and `PreventiveAction` does not
    // derive Eq/Hash, so entries cannot currently be inserted — confirm whether
    // this map is still planned.
    #[allow(dead_code)]
    recovery_effectiveness: HashMap<PreventiveAction, f32>,
    pattern_detector: PatternDetector,
}
184
185impl AdvancedStabilityMonitor {
186 pub fn new(config: AdvancedStabilityConfig) -> Self {
187 Self {
188 config,
189 loss_history: VecDeque::new(),
190 gradient_history: VecDeque::new(),
191 lr_history: VecDeque::new(),
192 dynamics_history: Vec::new(),
193 predicted_anomalies: Vec::new(),
194 landscape_analyses: VecDeque::new(),
195 stability_scores: VecDeque::new(),
196 recovery_effectiveness: HashMap::new(),
197 pattern_detector: PatternDetector::new(),
198 }
199 }
200
201 pub fn analyze_step(
203 &mut self,
204 step: usize,
205 loss: f32,
206 gradient_norm: f32,
207 learning_rate: f32,
208 gradients: &HashMap<String, Tensor>,
209 ) -> Result<()> {
210 self.update_histories(loss, gradient_norm, learning_rate);
212
213 if self.config.dynamics_analysis {
215 let dynamics = self.analyze_training_dynamics()?;
216 self.dynamics_history.push(dynamics);
217 }
218
219 if self.config.loss_landscape_monitoring {
221 let landscape = self.analyze_loss_landscape(gradients)?;
222 self.landscape_analyses.push_back(landscape);
223 if self.landscape_analyses.len() > self.config.pattern_window_size {
224 self.landscape_analyses.pop_front();
225 }
226 }
227
228 let stability = self.compute_stability_score()?;
230 self.stability_scores.push_back(stability);
231 if self.stability_scores.len() > self.config.pattern_window_size {
232 self.stability_scores.pop_front();
233 }
234
235 if self.config.predictive_detection {
237 let predictions = self.predict_anomalies(step)?;
238 self.predicted_anomalies.extend(predictions);
239 }
240
241 Ok(())
242 }
243
244 pub fn get_stability_report(&self) -> StabilityReport {
246 let current_stability =
247 self.stability_scores.back().map(|s| s.overall_score).unwrap_or(1.0);
248
249 let immediate_risks: Vec<PredictiveAnomaly> = self
250 .predicted_anomalies
251 .iter()
252 .filter(|anomaly| anomaly.time_to_occurrence <= 5)
253 .cloned()
254 .collect();
255
256 let trend_analysis = self.analyze_stability_trend();
257
258 StabilityReport {
259 current_stability_score: current_stability,
260 stability_trend: trend_analysis,
261 immediate_risks,
262 predicted_anomalies: self.predicted_anomalies.clone(),
263 landscape_health: self.landscape_analyses.back().cloned(),
264 recommendations: self.generate_recommendations(),
265 confidence_level: self.compute_prediction_confidence(),
266 }
267 }
268
    /// Applies preventive actions for high-confidence, imminent predicted
    /// anomalies, returning the actions that were actually applied.
    ///
    /// Returns an empty Vec when `proactive_recovery` is disabled.
    pub fn apply_proactive_recovery(
        &mut self,
        trainer_params: &mut TrainerParameters,
    ) -> Result<Vec<PreventiveAction>> {
        if !self.config.proactive_recovery {
            return Ok(Vec::new());
        }

        let mut applied_actions = Vec::new();

        // Collect first, apply second: `apply_preventive_action` needs
        // `&mut self`, so we cannot mutate while borrowing the predictions.
        let mut actions_to_apply = Vec::new();

        for anomaly in &self.predicted_anomalies {
            // Only act on confident predictions expected within ~3 steps.
            if anomaly.confidence >= self.config.prediction_confidence_threshold
                && anomaly.time_to_occurrence <= 3
            {
                for action in &anomaly.preventive_actions {
                    if self.should_apply_action(action, trainer_params) {
                        actions_to_apply.push(action.clone());
                    }
                }
            }
        }

        for action in actions_to_apply {
            self.apply_preventive_action(&action, trainer_params)?;
            applied_actions.push(action);
        }

        Ok(applied_actions)
    }
303
304 fn update_histories(&mut self, loss: f32, gradient_norm: f32, learning_rate: f32) {
305 self.loss_history.push_back(loss);
306 self.gradient_history.push_back(gradient_norm);
307 self.lr_history.push_back(learning_rate);
308
309 let max_len = self.config.pattern_window_size;
310 if self.loss_history.len() > max_len {
311 self.loss_history.pop_front();
312 }
313 if self.gradient_history.len() > max_len {
314 self.gradient_history.pop_front();
315 }
316 if self.lr_history.len() > max_len {
317 self.lr_history.pop_front();
318 }
319 }
320
321 fn analyze_training_dynamics(&self) -> Result<TrainingDynamics> {
322 let loss_trend = self.compute_trend(&self.loss_history);
323 let gradient_trend = self.compute_trend(&self.gradient_history);
324 let lr_effectiveness = self.compute_lr_effectiveness();
325 let convergence_velocity = self.compute_convergence_velocity();
326 let oscillation_frequency = self.compute_oscillation_frequency();
327 let phase_trajectory = self.compute_phase_trajectory();
328
329 Ok(TrainingDynamics {
330 loss_trend,
331 gradient_trend,
332 lr_effectiveness,
333 convergence_velocity,
334 oscillation_frequency,
335 phase_trajectory,
336 })
337 }
338
    /// Estimates local loss-landscape properties from the current gradients.
    ///
    /// Each fallible sub-estimate falls back to a neutral default (with a
    /// warning) instead of failing the whole analysis.
    fn analyze_loss_landscape(
        &self,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<LossLandscapeAnalysis> {
        let local_curvature = self.estimate_local_curvature(gradients).unwrap_or_else(|e| {
            log::warn!("Failed to estimate local curvature: {}", e);
            0.1
        });

        let gradient_consistency =
            self.compute_gradient_consistency(gradients).unwrap_or_else(|e| {
                log::warn!("Failed to compute gradient consistency: {}", e);
                0.8
            });

        let escape_difficulty = self.estimate_escape_difficulty();
        let basin_stability = self.estimate_basin_stability();

        let saddle_point_prob =
            self.estimate_saddle_point_probability(gradients).unwrap_or_else(|e| {
                log::warn!("Failed to estimate saddle point probability: {}", e);
                0.2
            });

        Ok(LossLandscapeAnalysis {
            local_curvature,
            gradient_consistency,
            escape_difficulty,
            basin_stability,
            saddle_point_prob,
        })
    }
371
372 fn compute_stability_score(&self) -> Result<StabilityScore> {
373 let gradient_stability = self.compute_gradient_stability();
374 let loss_stability = self.compute_loss_stability();
375 let convergence_stability = self.compute_convergence_stability();
376 let numerical_stability = self.compute_numerical_stability();
377
378 let overall_score =
379 (gradient_stability + loss_stability + convergence_stability + numerical_stability)
380 / 4.0;
381
382 let recommendations = self.generate_stability_recommendations(
383 gradient_stability,
384 loss_stability,
385 convergence_stability,
386 numerical_stability,
387 );
388
389 Ok(StabilityScore {
390 overall_score,
391 gradient_stability,
392 loss_stability,
393 convergence_stability,
394 numerical_stability,
395 recommendations,
396 })
397 }
398
399 fn predict_anomalies(&self, current_step: usize) -> Result<Vec<PredictiveAnomaly>> {
400 let mut predictions = Vec::new();
401
402 if let Some(anomaly) = self.predict_gradient_explosion(current_step)? {
404 predictions.push(anomaly);
405 }
406
407 if let Some(anomaly) = self.predict_training_stagnation(current_step)? {
409 predictions.push(anomaly);
410 }
411
412 if let Some(anomaly) = self.predict_numerical_instability(current_step)? {
414 predictions.push(anomaly);
415 }
416
417 if let Some(anomaly) = self.predict_oscillating_loss(current_step)? {
419 predictions.push(anomaly);
420 }
421
422 Ok(predictions)
423 }
424
    /// Classifies the recent direction of `history` (up to the last 10
    /// samples) via a least-squares slope plus a variance check.
    ///
    /// Returns `Stable` with fewer than 3 samples.
    /// NOTE(review): `Diverging` is never produced here even though callers
    /// match on it — confirm whether divergence detection is still pending.
    fn compute_trend(&self, history: &VecDeque<f32>) -> TrendDirection {
        if history.len() < 3 {
            return TrendDirection::Stable;
        }

        // Take the last 10 samples, then restore chronological order.
        let mut recent: Vec<f32> = history.iter().rev().take(10).cloned().collect();
        recent.reverse();
        let slope = self.compute_slope(&recent);
        let variance = self.compute_variance(&recent);

        // High variance dominates: the series is oscillating regardless of slope.
        if variance > 0.1 {
            TrendDirection::Oscillating
        } else if slope < -0.01 {
            TrendDirection::Decreasing
        } else if slope > 0.01 {
            TrendDirection::Increasing
        } else {
            TrendDirection::Stable
        }
    }
447
    /// Least-squares slope of `values` against their indices (0, 1, 2, ...).
    ///
    /// Returns 0.0 for fewer than two points. For n >= 2 the denominator
    /// `n*sum_x2 - sum_x^2` is strictly positive, so no division by zero.
    fn compute_slope(&self, values: &[f32]) -> f32 {
        if values.len() < 2 {
            return 0.0;
        }

        let n = values.len() as f32;
        let sum_x: f32 = (0..values.len()).map(|i| i as f32).sum();
        let sum_y: f32 = values.iter().sum();
        let sum_xy: f32 = values.iter().enumerate().map(|(i, &y)| i as f32 * y).sum();
        let sum_x2: f32 = (0..values.len()).map(|i| (i as f32).powi(2)).sum();

        (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
    }
461
462 fn compute_variance(&self, values: &[f32]) -> f32 {
463 if values.is_empty() {
464 return 0.0;
465 }
466
467 let mean: f32 = values.iter().sum::<f32>() / values.len() as f32;
468 let variance: f32 =
469 values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
470
471 variance
472 }
473
    /// Heuristic score in [0, 1] for how well learning-rate changes line up
    /// with loss improvements over sliding 3-step windows.
    ///
    /// Per window: 0.8 when the loss improved while the LR rose, 0.6 when the
    /// loss worsened while the LR fell, 0.3 otherwise. Returns a neutral 0.5
    /// with insufficient history or no scoreable windows.
    fn compute_lr_effectiveness(&self) -> f32 {
        if self.loss_history.len() < 5 || self.lr_history.len() < 5 {
            return 0.5;
        }

        let mut lr_effectiveness_scores = Vec::new();

        // Pair (loss, lr) samples chronologically and slide a 3-wide window.
        for window in self
            .loss_history
            .iter()
            .zip(self.lr_history.iter())
            .collect::<Vec<_>>()
            .windows(3)
        {
            if let [(l1, lr1), (l2, lr2), (_l3, _lr3)] = window {
                // Relative changes, guarded against division by ~0.
                let loss_improvement = (*l1 - *l2) / l1.max(1e-8f32);
                let lr_change = (*lr2 - *lr1) / lr1.max(1e-8f32);

                if loss_improvement > 0.0 && lr_change > 0.0 {
                    lr_effectiveness_scores.push(0.8);
                } else if loss_improvement < 0.0 && lr_change < 0.0 {
                    lr_effectiveness_scores.push(0.6);
                } else {
                    lr_effectiveness_scores.push(0.3);
                }
            }
        }

        if lr_effectiveness_scores.is_empty() {
            0.5
        } else {
            lr_effectiveness_scores.iter().sum::<f32>() / lr_effectiveness_scores.len() as f32
        }
    }
511
512 fn compute_convergence_velocity(&self) -> f32 {
513 if self.loss_history.len() < 10 {
514 return 0.0;
515 }
516
517 let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
518 let slope = self.compute_slope(&recent_losses);
519
520 if slope < 0.0 {
523 (-slope * 100.0).min(1.0)
524 } else {
525 0.0
526 }
527 }
528
    /// Fraction of recent steps where the loss changed direction, in [0, 1].
    ///
    /// NOTE(review): the guard requires 10 samples but the window takes up to
    /// 20 — confirm whether the guard should also be 20.
    fn compute_oscillation_frequency(&self) -> f32 {
        if self.loss_history.len() < 10 {
            return 0.0;
        }

        // Most-recent-first order; direction-change counting is order-symmetric.
        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(20).cloned().collect();
        let mut direction_changes = 0;

        for window in recent_losses.windows(3) {
            // A sign change between consecutive deltas is one oscillation.
            if (window[1] > window[0]) != (window[2] > window[1]) {
                direction_changes += 1;
            }
        }

        direction_changes as f32 / (recent_losses.len() - 2).max(1) as f32
    }
546
547 fn compute_phase_trajectory(&self) -> Vec<(f32, f32)> {
548 self.loss_history
549 .iter()
550 .zip(self.gradient_history.iter())
551 .map(|(&l, &g)| (l, g))
552 .collect()
553 }
554
    /// Rough curvature estimate: second-order finite difference over the last
    /// three recorded gradient norms, normalized by the middle norm and
    /// clamped to [0, 10]. Returns a neutral 0.1 with insufficient data.
    fn estimate_local_curvature(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
        if gradients.is_empty() || self.gradient_history.len() < 3 {
            return Ok(0.1);
        }

        // Computed for its error check only; the estimate uses the history.
        let _current_norm = self.compute_total_gradient_norm(gradients)?;
        // Most-recent-first; the symmetric difference below is order-agnostic.
        let recent_norms: Vec<f32> = self.gradient_history.iter().rev().take(3).cloned().collect();

        if recent_norms.len() >= 3 {
            // f'' ~= f(t) - 2*f(t-1) + f(t-2)
            let second_derivative = recent_norms[0] - 2.0 * recent_norms[1] + recent_norms[2];
            let curvature = second_derivative.abs() / (recent_norms[1].max(1e-8));
            Ok(curvature.min(10.0))
        } else {
            Ok(0.1)
        }
    }
573
    /// Consistency of per-parameter gradient norms in [0, 1]: near 1 when all
    /// tensors have similar norms, lower as their coefficient of variation
    /// grows. Trivially 1.0 with fewer than two gradient tensors.
    fn compute_gradient_consistency(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
        if gradients.len() < 2 {
            return Ok(1.0);
        }

        // L2 norm per tensor; tensors whose data is unavailable contribute 0.
        let mut norms = Vec::new();
        for tensor in gradients.values() {
            let data = tensor.data().unwrap_or_default();
            let norm = data.iter().map(|&x| x * x).sum::<f32>().sqrt();
            norms.push(norm);
        }

        if norms.is_empty() {
            return Ok(1.0);
        }

        let mean_norm = norms.iter().sum::<f32>() / norms.len() as f32;
        let variance =
            norms.iter().map(|&x| (x - mean_norm).powi(2)).sum::<f32>() / norms.len() as f32;
        // Coefficient of variation, guarded against a ~0 mean.
        let cv = variance.sqrt() / mean_norm.max(1e-8);

        Ok((1.0 / (1.0 + cv * 2.0)).clamp(0.0, 1.0))
    }
599
    /// Difficulty of escaping the current region in [0, 1], proxied by how
    /// many local minima appear among the last 20 losses (more minima means a
    /// bumpier landscape). Returns a neutral 0.3 with insufficient history.
    fn estimate_escape_difficulty(&self) -> f32 {
        if self.loss_history.len() < 20 {
            return 0.3;
        }

        // Most-recent-first; local-minimum counting is order-symmetric.
        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(20).cloned().collect();
        let mut local_minima_count = 0;

        // The window center must be strictly below all four neighbors.
        for window in recent_losses.windows(5) {
            if window[2] < window[0]
                && window[2] < window[1]
                && window[2] < window[3]
                && window[2] < window[4]
            {
                local_minima_count += 1;
            }
        }

        (local_minima_count as f32 / 5.0).min(1.0)
    }
622
623 fn estimate_basin_stability(&self) -> f32 {
624 if self.loss_history.len() < 10 {
625 return 0.7;
626 }
627
628 let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
630 let variance = self.compute_variance(&recent_losses);
631 let mean_loss = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
632 let cv = variance.sqrt() / mean_loss.max(1e-8);
633
634 (1.0 / (1.0 + cv * 3.0)).clamp(0.0, 1.0)
636 }
637
    /// Heuristic probability of sitting near a saddle point: a small gradient
    /// norm combined with non-trivial curvature is treated as the signature.
    /// Returns a neutral 0.2 with insufficient data.
    fn estimate_saddle_point_probability(
        &self,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<f32> {
        if gradients.is_empty() || self.gradient_history.len() < 5 {
            return Ok(0.2);
        }

        let current_grad_norm = self.compute_total_gradient_norm(gradients)?;

        let small_gradient = current_grad_norm < 0.01;
        let curvature = self.estimate_local_curvature(gradients)?;
        let high_curvature = curvature > 0.1;

        // Small gradient + curvature -> likely saddle; small gradient alone
        // could also be a flat minimum, so it scores lower.
        let probability = if small_gradient && high_curvature {
            0.8
        } else if small_gradient {
            0.4
        } else {
            0.1
        };

        Ok(probability)
    }
663
664 fn compute_gradient_stability(&self) -> f32 {
665 if self.gradient_history.len() < 5 {
666 return 1.0;
667 }
668 let variance =
669 self.compute_variance(&self.gradient_history.iter().cloned().collect::<Vec<_>>());
670 (1.0 / (1.0 + variance)).clamp(0.0, 1.0)
671 }
672
673 fn compute_loss_stability(&self) -> f32 {
674 if self.loss_history.len() < 5 {
675 return 1.0;
676 }
677 let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
678 let slope = self.compute_slope(&recent_losses);
679 if slope < 0.0 {
680 0.9
681 } else if slope < 0.01 {
682 0.7
683 } else {
684 0.3
685 }
686 }
687
    /// Convergence-quality score in [0, 1]: rewards steady loss decrease,
    /// penalizes oscillation. Returns a neutral 0.8 with little history.
    fn compute_convergence_stability(&self) -> f32 {
        if self.loss_history.len() < 10 {
            return 0.8;
        }

        let convergence_velocity = self.compute_convergence_velocity();
        let oscillation_freq = self.compute_oscillation_frequency();

        // Velocity saturates at 0.5, so this maps it onto [0, 1].
        let velocity_score = convergence_velocity.min(0.5) * 2.0;
        let stability_score = (1.0 - oscillation_freq).max(0.0);

        (velocity_score * 0.6 + stability_score * 0.4).clamp(0.0, 1.0)
    }
702
    /// Numerical-health score in [0, 1]: 0.0 on any NaN/Inf, 0.3 on extreme
    /// magnitudes (|x| > 1e6), otherwise 1.0 minus graduated penalties for
    /// large losses/gradients. Returns a neutral 0.9 with no history.
    fn compute_numerical_stability(&self) -> f32 {
        if self.loss_history.is_empty() || self.gradient_history.is_empty() {
            return 0.9;
        }

        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
        let recent_grads: Vec<f32> = self.gradient_history.iter().rev().take(10).cloned().collect();

        // Hard failures: any non-finite value.
        let loss_issues = recent_losses.iter().any(|&x| !x.is_finite());
        let grad_issues = recent_grads.iter().any(|&x| !x.is_finite());

        // Soft failures: finite but implausibly large magnitudes.
        let extreme_values = recent_losses.iter().any(|&x| !(-1e6..=1e6).contains(&x))
            || recent_grads.iter().any(|&x| !(-1e6..=1e6).contains(&x));

        if loss_issues || grad_issues {
            0.0
        } else if extreme_values {
            0.3
        } else {
            let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let max_grad = recent_grads.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            // Graduated penalties for merely-large values.
            let loss_penalty = if max_loss > 1000.0 { 0.3 } else { 0.0 };
            let grad_penalty = if max_grad > 100.0 { 0.2 } else { 0.0 };

            (1.0f32 - loss_penalty - grad_penalty).max(0.0f32)
        }
    }
733
    /// Turns the four component scores — gradient (`gs`), loss (`ls`),
    /// convergence (`cs`), numerical (`ns`) — into human-readable
    /// recommendations; always returns at least one entry.
    fn generate_stability_recommendations(
        &self,
        gs: f32,
        ls: f32,
        cs: f32,
        ns: f32,
    ) -> Vec<String> {
        let mut recommendations = Vec::new();

        if gs < 0.5 {
            recommendations.push(
                "Consider gradient clipping or normalization to improve gradient stability"
                    .to_string(),
            );
        }

        if ls < 0.5 {
            recommendations.push(
                "Loss appears unstable - consider reducing learning rate or adjusting optimizer"
                    .to_string(),
            );
        }

        if cs < 0.5 {
            recommendations.push("Poor convergence stability - consider learning rate scheduling or different optimizer".to_string());
        }

        if ns < 0.5 {
            recommendations.push("Numerical instability detected - check for NaN/Inf values and consider mixed precision".to_string());
        }

        // Overall verdict appended after the per-component advice.
        let overall_score = (gs + ls + cs + ns) / 4.0;

        if overall_score < 0.3 {
            recommendations.push(
                "Critical stability issues - consider checkpoint rollback and parameter reset"
                    .to_string(),
            );
        } else if overall_score < 0.6 {
            recommendations.push("Moderate stability issues - monitor closely and consider conservative training settings".to_string());
        } else if recommendations.is_empty() {
            recommendations.push("Training stability is good - continue monitoring".to_string());
        }

        recommendations
    }
780
781 fn analyze_stability_trend(&self) -> TrendDirection {
782 let scores: Vec<f32> = self.stability_scores.iter().map(|s| s.overall_score).collect();
783 self.compute_trend(&scores.into_iter().collect())
784 }
785
786 fn generate_recommendations(&self) -> Vec<String> {
787 vec!["Continue monitoring training progress".to_string()]
788 }
789
790 fn compute_prediction_confidence(&self) -> f32 {
791 let history_quality = if self.loss_history.len() >= 20 { 0.9 } else { 0.5 };
793 let data_quality = if self.loss_history.iter().all(|&x| x.is_finite()) { 0.9 } else { 0.3 };
794 let trend_consistency = if self.dynamics_history.len() >= 3 { 0.8 } else { 0.4 };
795
796 (history_quality * 0.4f32 + data_quality * 0.4f32 + trend_consistency * 0.2f32).min(1.0f32)
797 }
798
    /// Global L2 norm across all gradient tensors (square root of the sum of
    /// every squared element).
    ///
    /// # Errors
    /// Returns a runtime error if any tensor's data cannot be read.
    fn compute_total_gradient_norm(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
        let mut total_norm_sq = 0.0f32;

        for tensor in gradients.values() {
            let data = tensor.data().map_err(|_| runtime_error("Failed to get tensor data"))?;
            let tensor_norm_sq: f32 = data.iter().map(|&x| x * x).sum();
            total_norm_sq += tensor_norm_sq;
        }

        Ok(total_norm_sq.sqrt())
    }
811
    /// True when at least half of the consecutive pairs in `values` grow by
    /// more than 1.5x (`values[i+1] / values[i] > 1.5`); false for fewer than
    /// five samples.
    ///
    /// Order-sensitive: this reads as *growth over time* only when `values`
    /// is chronological (oldest-first); on a most-recent-first slice the same
    /// check matches decay instead.
    fn detect_exponential_growth(&self, values: &[f32]) -> bool {
        if values.len() < 5 {
            return false;
        }

        let mut growth_count = 0;
        for window in values.windows(2) {
            // Only positive bases count; ratio > 1.5 marks one growth step.
            if window[0] > 0.0 && window[1] / window[0] > 1.5 {
                growth_count += 1;
            }
        }

        growth_count >= (values.len() - 1) / 2
    }
828
    /// True when the variance of the second half of `values` exceeds twice
    /// the variance of the first half; false for fewer than eight samples.
    ///
    /// Order-sensitive: this reads as "variance increasing over time" only
    /// when `values` is chronological (oldest-first).
    fn detect_variance_increase(&self, values: &[f32]) -> bool {
        if values.len() < 8 {
            return false;
        }

        let mid_point = values.len() / 2;
        let early_half = &values[..mid_point];
        let recent_half = &values[mid_point..];

        let early_variance = self.compute_variance(early_half);
        let recent_variance = self.compute_variance(recent_half);

        recent_variance > early_variance * 2.0
    }
844
845 fn detect_no_improvement(&self, losses: &[f32], threshold: f32) -> bool {
847 if losses.len() < 5 {
848 return false;
849 }
850
851 let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
852 let recent_loss = losses[0]; (best_loss - recent_loss) / best_loss.max(1e-8) < threshold
856 }
857
    /// Peak-to-peak range of the last 10 losses relative to their mean; 0.0
    /// with insufficient history or a non-positive mean.
    fn compute_oscillation_amplitude(&self) -> f32 {
        if self.loss_history.len() < 10 {
            return 0.0;
        }

        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
        let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let min_loss = recent_losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let mean_loss = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;

        if mean_loss > 0.0 {
            (max_loss - min_loss) / mean_loss
        } else {
            0.0
        }
    }
875
876 fn predict_gradient_explosion(&self, current_step: usize) -> Result<Option<PredictiveAnomaly>> {
877 if self.gradient_history.len() < 5 {
878 return Ok(None);
879 }
880
881 let recent_grads: Vec<f32> = self.gradient_history.iter().rev().take(10).cloned().collect();
882
883 let trend = self.compute_trend(&self.gradient_history);
885 let exponential_growth = self.detect_exponential_growth(&recent_grads);
886 let variance_increase = self.detect_variance_increase(&recent_grads);
887
888 let base_confidence = match trend {
889 TrendDirection::Increasing => 0.6,
890 TrendDirection::Diverging => 0.9,
891 _ => 0.0,
892 };
893
894 let growth_factor = if exponential_growth { 0.3f32 } else { 0.0f32 };
895 let variance_factor = if variance_increase { 0.2f32 } else { 0.0f32 };
896
897 let confidence = (base_confidence + growth_factor + variance_factor).min(1.0f32);
898
899 if confidence >= self.config.prediction_confidence_threshold {
900 let time_to_occurrence = if exponential_growth { 2 } else { 5 };
901 let risk_level = if confidence > 0.8 { RiskLevel::Critical } else { RiskLevel::High };
902
903 return Ok(Some(PredictiveAnomaly {
904 predicted_step: current_step + time_to_occurrence,
905 anomaly_type: PredictedAnomalyType::GradientExplosion,
906 confidence,
907 time_to_occurrence,
908 preventive_actions: vec![
909 PreventiveAction::ReduceLearningRate {
910 factor: if confidence > 0.8 { 0.1 } else { 0.5 },
911 },
912 PreventiveAction::IncreaseGradientClipping { new_threshold: 1.0 },
913 PreventiveAction::TriggerEarlyCheckpoint,
914 ],
915 risk_level,
916 }));
917 }
918
919 Ok(None)
920 }
921
    /// Predicts training stagnation when at least two of three indicators
    /// fire: near-zero variance, near-flat slope, and no recent improvement.
    ///
    /// Confidence scales with agreement (2 indicators -> 0.7, 3 -> 0.95) and
    /// must still clear the configured threshold to emit a prediction.
    fn predict_training_stagnation(
        &self,
        current_step: usize,
    ) -> Result<Option<PredictiveAnomaly>> {
        if self.loss_history.len() < 20 {
            return Ok(None);
        }

        // Most-recent-first window; only |slope| and variance are used below,
        // both of which are insensitive to the ordering.
        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(15).cloned().collect();
        let variance = self.compute_variance(&recent_losses);
        let slope = self.compute_slope(&recent_losses);

        let low_variance = variance < 1e-6;
        let flat_slope = slope.abs() < 1e-5;
        let no_improvement = self.detect_no_improvement(&recent_losses, 0.001);

        let stagnation_indicators =
            [low_variance, flat_slope, no_improvement].iter().filter(|&&x| x).count();

        if stagnation_indicators >= 2 {
            let confidence = match stagnation_indicators {
                3 => 0.95,
                2 => 0.7,
                _ => 0.5,
            };

            if confidence >= self.config.prediction_confidence_threshold {
                return Ok(Some(PredictiveAnomaly {
                    predicted_step: current_step + 10,
                    anomaly_type: PredictedAnomalyType::TrainingStagnation,
                    confidence,
                    time_to_occurrence: 10,
                    preventive_actions: vec![
                        // Nudge the optimizer and inject mild noise to
                        // encourage escape from the plateau.
                        PreventiveAction::AdjustOptimizer {
                            suggested_params: [
                                ("momentum".to_string(), 0.9),
                                ("learning_rate_multiplier".to_string(), 1.5),
                            ]
                            .into_iter()
                            .collect(),
                        },
                        PreventiveAction::EnableNoise { noise_level: 0.01 },
                        PreventiveAction::AdjustWarmupSchedule,
                    ],
                    risk_level: if confidence > 0.8 { RiskLevel::High } else { RiskLevel::Medium },
                }));
            }
        }

        Ok(None)
    }
974
    /// Flags imminent numerical instability when the latest loss is NaN,
    /// infinite, or absurdly large (> 1e6); the prediction carries a fixed
    /// 0.95 confidence and Critical risk.
    fn predict_numerical_instability(
        &self,
        current_step: usize,
    ) -> Result<Option<PredictiveAnomaly>> {
        if self.loss_history.len() < 5 {
            return Ok(None);
        }

        let recent_loss = self.loss_history.back().unwrap_or(&1.0);
        if recent_loss.is_nan() || recent_loss.is_infinite() || *recent_loss > 1e6 {
            return Ok(Some(PredictiveAnomaly {
                predicted_step: current_step + 1,
                anomaly_type: PredictedAnomalyType::NumericalInstability,
                confidence: 0.95,
                time_to_occurrence: 1,
                preventive_actions: vec![
                    PreventiveAction::ReduceLearningRate { factor: 0.1 },
                    PreventiveAction::TriggerEarlyCheckpoint,
                ],
                risk_level: RiskLevel::Critical,
            }));
        }

        Ok(None)
    }
1000
    /// Predicts disruptive loss oscillation from the combination of
    /// direction-change frequency and peak-to-peak amplitude.
    fn predict_oscillating_loss(&self, current_step: usize) -> Result<Option<PredictiveAnomaly>> {
        if self.loss_history.len() < 15 {
            return Ok(None);
        }

        let oscillation_freq = self.compute_oscillation_frequency();
        let amplitude = self.compute_oscillation_amplitude();

        // Frequency and amplitude jointly determine how disruptive it is.
        let severity_score = oscillation_freq * amplitude;

        if oscillation_freq > 0.3 || severity_score > 0.2 {
            let confidence = (oscillation_freq * 2.0 + severity_score).min(1.0);

            if confidence >= self.config.prediction_confidence_threshold {
                return Ok(Some(PredictiveAnomaly {
                    predicted_step: current_step + 5,
                    anomaly_type: PredictedAnomalyType::OscillatingLoss,
                    confidence,
                    time_to_occurrence: 5,
                    preventive_actions: vec![
                        // Milder LR cut for mild oscillation, harder for severe.
                        PreventiveAction::ReduceLearningRate {
                            factor: if severity_score > 0.5 { 0.5 } else { 0.8 },
                        },
                        PreventiveAction::AdjustWarmupSchedule,
                        // Mild noise plus a larger batch can damp oscillation.
                        PreventiveAction::EnableNoise { noise_level: 0.005 },
                        PreventiveAction::ModifyBatchSize { new_size: 64 },
                    ],
                    risk_level: if severity_score > 0.5 {
                        RiskLevel::High
                    } else {
                        RiskLevel::Medium
                    },
                }));
            }
        }

        Ok(None)
    }
1040
    /// Gate deciding whether a preventive action should be applied.
    ///
    /// NOTE(review): currently a placeholder that always approves; presumably
    /// intended to consult `recovery_effectiveness` and the trainer state —
    /// confirm before relying on it for filtering.
    fn should_apply_action(&self, _action: &PreventiveAction, _params: &TrainerParameters) -> bool {
        true
    }
1044
    /// Mutates `params` according to `action`.
    ///
    /// Only learning rate, gradient clipping, and batch size are handled
    /// here; the remaining variants (checkpointing, optimizer/warmup changes,
    /// noise, gradient reset) need trainer-level cooperation and are no-ops.
    fn apply_preventive_action(
        &mut self,
        action: &PreventiveAction,
        params: &mut TrainerParameters,
    ) -> Result<()> {
        match action {
            PreventiveAction::ReduceLearningRate { factor } => {
                params.learning_rate *= factor;
            },
            PreventiveAction::IncreaseGradientClipping { new_threshold } => {
                params.gradient_clip_threshold = *new_threshold;
            },
            PreventiveAction::ModifyBatchSize { new_size } => {
                params.batch_size = *new_size;
            },
            _ => {
                // Intentionally ignored here — handled by the caller.
            },
        }
        Ok(())
    }
1066}
1067
/// Matches training dynamics against a library of known instability patterns.
pub struct PatternDetector {
    // NOTE(review): the library is never populated in this file — confirm
    // where patterns are meant to be loaded from.
    #[allow(dead_code)]
    pattern_library: HashMap<String, Pattern>,
}
1073
impl Default for PatternDetector {
    /// Equivalent to [`PatternDetector::new`].
    fn default() -> Self {
        Self::new()
    }
}
1079
impl PatternDetector {
    /// Creates a detector with an empty pattern library.
    pub fn new() -> Self {
        Self {
            pattern_library: HashMap::new(),
        }
    }

    /// Matches the given dynamics against the pattern library.
    ///
    /// NOTE(review): placeholder — always returns no matches.
    pub fn detect_patterns(&self, _dynamics: &TrainingDynamics) -> Vec<DetectedPattern> {
        Vec::new()
    }
}
1091
/// A named instability pattern described by a set of indicator conditions.
#[derive(Debug, Clone)]
pub struct Pattern {
    pub name: String,
    pub description: String,
    /// Conditions that must hold for the pattern to match.
    pub indicators: Vec<PatternIndicator>,
}
1098
/// One condition of a [`Pattern`], expressed as metric/condition/threshold.
#[derive(Debug, Clone)]
pub struct PatternIndicator {
    /// Name of the metric being tested.
    pub metric: String,
    /// Comparison to apply (stored as text).
    pub condition: String,
    pub threshold: f32,
}
1105
/// A pattern-library match with its confidence and severity.
#[derive(Debug, Clone)]
pub struct DetectedPattern {
    pub pattern: Pattern,
    /// Match confidence in [0, 1].
    pub confidence: f32,
    pub severity: RiskLevel,
}
1112
/// Mutable trainer knobs that preventive actions are allowed to adjust.
#[derive(Debug, Clone)]
pub struct TrainerParameters {
    pub learning_rate: f32,
    pub gradient_clip_threshold: f32,
    pub batch_size: usize,
    /// Additional optimizer hyperparameters keyed by name.
    pub optimizer_params: HashMap<String, f32>,
}
1121
/// Point-in-time summary produced by `get_stability_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityReport {
    /// Latest overall stability score (1.0 before any step is analyzed).
    pub current_stability_score: f32,
    /// Trend of the stability score over the rolling window.
    pub stability_trend: TrendDirection,
    /// Predicted anomalies expected within the next 5 steps.
    pub immediate_risks: Vec<PredictiveAnomaly>,
    /// All outstanding predictions.
    pub predicted_anomalies: Vec<PredictiveAnomaly>,
    /// Most recent loss-landscape analysis, if any.
    pub landscape_health: Option<LossLandscapeAnalysis>,
    pub recommendations: Vec<String>,
    /// Confidence in [0, 1] of the monitor's own predictions.
    pub confidence_level: f32,
}
1133
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh monitor starts with empty histories and no predictions.
    #[test]
    fn test_advanced_stability_monitor_creation() {
        let config = AdvancedStabilityConfig::default();
        let monitor = AdvancedStabilityMonitor::new(config);
        assert!(monitor.loss_history.is_empty());
        assert!(monitor.predicted_anomalies.is_empty());
    }

    /// A single step with an empty gradient map analyzes without error.
    #[test]
    fn test_stability_analysis() {
        let config = AdvancedStabilityConfig::default();
        let mut monitor = AdvancedStabilityMonitor::new(config);
        let gradients = HashMap::new();

        let result = monitor.analyze_step(0, 1.0, 0.5, 0.001, &gradients);
        assert!(result.is_ok());
    }

    /// A strictly falling series is classified as Decreasing.
    #[test]
    fn test_trend_computation() {
        let config = AdvancedStabilityConfig::default();
        let monitor = AdvancedStabilityMonitor::new(config);

        let values: VecDeque<f32> = vec![1.0, 0.9, 0.8, 0.7, 0.6].into();
        let trend = monitor.compute_trend(&values);
        assert!(matches!(trend, TrendDirection::Decreasing));
    }

    /// Report fields are populated with sane (non-negative) values even
    /// before any step has been analyzed.
    #[test]
    fn test_stability_report_generation() {
        let config = AdvancedStabilityConfig::default();
        let monitor = AdvancedStabilityMonitor::new(config);

        let report = monitor.get_stability_report();
        assert!(report.current_stability_score >= 0.0);
        assert!(report.confidence_level >= 0.0);
    }
}