Skip to main content

trustformers_training/
advanced_stability_monitor.rs

//! Advanced Training Stability Monitoring System
//!
//! This module provides predictive anomaly detection and proactive recovery mechanisms
//! that go beyond traditional reactive monitoring to prevent training failures before they occur.
5
6use anyhow::Result;
7use log;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10use trustformers_core::errors::runtime_error;
11use trustformers_core::tensor::Tensor;
12
/// Configuration for advanced stability monitoring.
///
/// The boolean flags switch individual monitoring stages on or off; the numeric fields
/// size the rolling histories and set decision thresholds. The `Default` implementation
/// enables every stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedStabilityConfig {
    /// Enable predictive anomaly detection. When `false`, `analyze_step` does not run
    /// the anomaly predictors.
    pub predictive_detection: bool,
    /// Enable proactive recovery mechanisms. When `false`, `apply_proactive_recovery`
    /// returns an empty action list.
    pub proactive_recovery: bool,
    /// Enable training dynamics analysis during `analyze_step`.
    pub dynamics_analysis: bool,
    /// Enable loss landscape monitoring during `analyze_step`.
    pub loss_landscape_monitoring: bool,
    /// Prediction horizon (steps ahead). Default: 10.
    // NOTE(review): not read by the code visible in this part of the file - confirm usage.
    pub prediction_horizon: usize,
    /// Confidence threshold for predictions. A prediction must reach this confidence
    /// before it is reported or acted upon. Default: 0.7.
    pub prediction_confidence_threshold: f32,
    /// Pattern detection window size. Also the maximum number of samples kept in the
    /// loss, gradient-norm, learning-rate, landscape and stability-score histories.
    /// Default: 50.
    pub pattern_window_size: usize,
    /// Stability score threshold. Default: 0.8.
    // NOTE(review): not read by the code visible in this part of the file - confirm usage.
    pub stability_threshold: f32,
    /// Adaptive recovery enabled. Default: `true`.
    // NOTE(review): not read by the code visible in this part of the file - confirm usage.
    pub adaptive_recovery: bool,
}
35
36impl Default for AdvancedStabilityConfig {
37    fn default() -> Self {
38        Self {
39            predictive_detection: true,
40            proactive_recovery: true,
41            dynamics_analysis: true,
42            loss_landscape_monitoring: true,
43            prediction_horizon: 10,
44            prediction_confidence_threshold: 0.7,
45            pattern_window_size: 50,
46            stability_threshold: 0.8,
47            adaptive_recovery: true,
48        }
49    }
50}
51
/// Training dynamics patterns.
///
/// A snapshot derived from the monitor's rolling loss, gradient-norm and learning-rate
/// histories at one analyzed step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamics {
    /// Loss trajectory trend
    pub loss_trend: TrendDirection,
    /// Gradient norm evolution
    pub gradient_trend: TrendDirection,
    /// Learning rate effectiveness (heuristic score; 0.5 when there is too little history)
    pub lr_effectiveness: f32,
    /// Convergence velocity (0.0 when the loss is not decreasing, capped at 1.0)
    pub convergence_velocity: f32,
    /// Oscillation frequency (fraction of recent loss triples that change direction)
    pub oscillation_frequency: f32,
    /// Phase space trajectory: one point per retained history sample
    pub phase_trajectory: Vec<(f32, f32)>, // (loss, gradient_norm)
}
68
/// Qualitative direction of a metric's recent history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrendDirection {
    /// The fitted slope is clearly negative.
    Decreasing,
    /// The fitted slope is clearly positive.
    Increasing,
    /// Neither a clear slope nor high variance (also used when history is too short).
    Stable,
    /// The recent values show high variance.
    Oscillating,
    /// The metric is running away.
    // NOTE(review): the trend helper visible in this part of the file never returns this
    // variant, although the gradient-explosion predictor matches on it - confirm producer.
    Diverging,
}
77
/// Predictive anomaly detection result.
///
/// Describes one anomaly the monitor expects to happen, together with the actions it
/// suggests to prevent it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictiveAnomaly {
    /// Predicted step where anomaly will occur
    pub predicted_step: usize,
    /// Type of predicted anomaly
    pub anomaly_type: PredictedAnomalyType,
    /// Confidence of prediction (0-1)
    pub confidence: f32,
    /// Time to occurrence (estimated steps), measured from the step at which the
    /// prediction was made
    pub time_to_occurrence: usize,
    /// Suggested preventive actions
    pub preventive_actions: Vec<PreventiveAction>,
    /// Risk level
    pub risk_level: RiskLevel,
}
94
/// Category of anomaly that a prediction refers to.
// NOTE(review): variant meanings below are inferred from their names; only some of the
// predictors that emit them are visible in this part of the file - confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PredictedAnomalyType {
    /// Gradient norms are expected to blow up.
    GradientExplosion,
    /// Gradient norms are expected to shrink towards zero.
    GradientVanishing,
    /// The loss is expected to stop improving.
    TrainingStagnation,
    /// Training is expected to fail to converge.
    ConvergenceFailure,
    /// Non-finite or extreme values are expected.
    NumericalInstability,
    /// The loss is expected to keep oscillating.
    OscillatingLoss,
    /// Memory is expected to run out.
    MemoryExhaustion,
    /// The learning rate is expected to become ineffective.
    LearningRateDeterioration,
}
106
/// Action the monitor can suggest (or apply) to head off a predicted anomaly.
// NOTE(review): the code that interprets these variants is not visible in this part of
// the file; field semantics below are inferred from names - confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PreventiveAction {
    /// Scale the learning rate down.
    ReduceLearningRate {
        /// Presumably a multiplicative factor applied to the current learning rate.
        factor: f32,
    },
    /// Change the gradient-clipping threshold.
    IncreaseGradientClipping {
        /// Clipping threshold to use from now on.
        new_threshold: f32,
    },
    /// Adjust optimizer hyperparameters.
    AdjustOptimizer {
        /// Suggested values keyed by parameter name.
        suggested_params: HashMap<String, f32>,
    },
    /// Save a checkpoint ahead of schedule.
    TriggerEarlyCheckpoint,
    /// Change the batch size.
    ModifyBatchSize {
        /// Batch size to use from now on.
        new_size: usize,
    },
    /// Adjust the warmup schedule.
    AdjustWarmupSchedule,
    /// Turn on noise injection.
    EnableNoise {
        /// Magnitude of the injected noise.
        noise_level: f32,
    },
    /// Discard accumulated gradients.
    ResetAccumulatedGradients,
}
128
/// Severity attached to a predicted anomaly, from least to most severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskLevel {
    /// Lowest severity.
    Low,
    /// Moderate severity.
    Medium,
    /// High severity.
    High,
    /// Highest severity.
    Critical,
}
136
/// Loss landscape analysis.
///
/// Heuristic estimates computed from the current gradients and the monitor's rolling
/// histories; none of these values is an exact geometric quantity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossLandscapeAnalysis {
    /// Local curvature estimate (finite difference of recent gradient norms, capped at 10.0)
    pub local_curvature: f32,
    /// Gradient consistency score (0-1; higher means per-tensor norms are more alike)
    pub gradient_consistency: f32,
    /// Escape difficulty from current region (0-1)
    pub escape_difficulty: f32,
    /// Basin stability (0-1; higher means less relative variation in recent losses)
    pub basin_stability: f32,
    /// Saddle point probability
    pub saddle_point_prob: f32,
}
151
/// Stability score breakdown.
///
/// `overall_score` is the unweighted mean of the four component scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityScore {
    /// Overall stability score (0-1)
    pub overall_score: f32,
    /// Gradient stability component
    pub gradient_stability: f32,
    /// Loss stability component
    pub loss_stability: f32,
    /// Convergence stability component
    pub convergence_stability: f32,
    /// Numerical stability component
    pub numerical_stability: f32,
    /// Recommendations for improvement (human-readable messages)
    pub recommendations: Vec<String>,
}
168
/// Advanced stability monitor.
///
/// Keeps rolling histories of the metrics passed to `analyze_step` and derives
/// dynamics, landscape estimates, stability scores and anomaly predictions from them.
#[allow(dead_code)]
pub struct AdvancedStabilityMonitor {
    /// Settings supplied at construction time.
    config: AdvancedStabilityConfig,
    /// Most recent loss values, oldest first, capped at `pattern_window_size`.
    loss_history: VecDeque<f32>,
    /// Most recent gradient norms, oldest first, capped at `pattern_window_size`.
    gradient_history: VecDeque<f32>,
    /// Most recent learning rates, oldest first, capped at `pattern_window_size`.
    lr_history: VecDeque<f32>,
    /// Dynamics snapshots appended by `analyze_step` when dynamics analysis is enabled.
    dynamics_history: Vec<TrainingDynamics>,
    /// Predictions accumulated by `analyze_step` when predictive detection is enabled.
    predicted_anomalies: Vec<PredictiveAnomaly>,
    /// Most recent landscape analyses, capped at `pattern_window_size`.
    landscape_analyses: VecDeque<LossLandscapeAnalysis>,
    /// Most recent stability scores, capped at `pattern_window_size`.
    stability_scores: VecDeque<StabilityScore>,
    // NOTE(review): `PreventiveAction` derives neither `Hash` nor `Eq`, so nothing can be
    // inserted into this map as written; a different key type is probably needed - confirm.
    #[allow(dead_code)]
    recovery_effectiveness: HashMap<PreventiveAction, f32>,
    /// Pattern detector created alongside the monitor.
    pattern_detector: PatternDetector,
}
184
185impl AdvancedStabilityMonitor {
186    pub fn new(config: AdvancedStabilityConfig) -> Self {
187        Self {
188            config,
189            loss_history: VecDeque::new(),
190            gradient_history: VecDeque::new(),
191            lr_history: VecDeque::new(),
192            dynamics_history: Vec::new(),
193            predicted_anomalies: Vec::new(),
194            landscape_analyses: VecDeque::new(),
195            stability_scores: VecDeque::new(),
196            recovery_effectiveness: HashMap::new(),
197            pattern_detector: PatternDetector::new(),
198        }
199    }
200
201    /// Analyze current training step and predict future stability
202    pub fn analyze_step(
203        &mut self,
204        step: usize,
205        loss: f32,
206        gradient_norm: f32,
207        learning_rate: f32,
208        gradients: &HashMap<String, Tensor>,
209    ) -> Result<()> {
210        // Update histories
211        self.update_histories(loss, gradient_norm, learning_rate);
212
213        // Analyze training dynamics
214        if self.config.dynamics_analysis {
215            let dynamics = self.analyze_training_dynamics()?;
216            self.dynamics_history.push(dynamics);
217        }
218
219        // Perform loss landscape analysis
220        if self.config.loss_landscape_monitoring {
221            let landscape = self.analyze_loss_landscape(gradients)?;
222            self.landscape_analyses.push_back(landscape);
223            if self.landscape_analyses.len() > self.config.pattern_window_size {
224                self.landscape_analyses.pop_front();
225            }
226        }
227
228        // Compute stability score
229        let stability = self.compute_stability_score()?;
230        self.stability_scores.push_back(stability);
231        if self.stability_scores.len() > self.config.pattern_window_size {
232            self.stability_scores.pop_front();
233        }
234
235        // Predictive anomaly detection
236        if self.config.predictive_detection {
237            let predictions = self.predict_anomalies(step)?;
238            self.predicted_anomalies.extend(predictions);
239        }
240
241        Ok(())
242    }
243
244    /// Get stability report with predictions and recommendations
245    pub fn get_stability_report(&self) -> StabilityReport {
246        let current_stability =
247            self.stability_scores.back().map(|s| s.overall_score).unwrap_or(1.0);
248
249        let immediate_risks: Vec<PredictiveAnomaly> = self
250            .predicted_anomalies
251            .iter()
252            .filter(|anomaly| anomaly.time_to_occurrence <= 5)
253            .cloned()
254            .collect();
255
256        let trend_analysis = self.analyze_stability_trend();
257
258        StabilityReport {
259            current_stability_score: current_stability,
260            stability_trend: trend_analysis,
261            immediate_risks,
262            predicted_anomalies: self.predicted_anomalies.clone(),
263            landscape_health: self.landscape_analyses.back().cloned(),
264            recommendations: self.generate_recommendations(),
265            confidence_level: self.compute_prediction_confidence(),
266        }
267    }
268
269    /// Apply proactive recovery based on predictions
270    pub fn apply_proactive_recovery(
271        &mut self,
272        trainer_params: &mut TrainerParameters,
273    ) -> Result<Vec<PreventiveAction>> {
274        if !self.config.proactive_recovery {
275            return Ok(Vec::new());
276        }
277
278        let mut applied_actions = Vec::new();
279
280        // Collect actions to apply first to avoid borrowing conflicts
281        let mut actions_to_apply = Vec::new();
282
283        for anomaly in &self.predicted_anomalies {
284            if anomaly.confidence >= self.config.prediction_confidence_threshold
285                && anomaly.time_to_occurrence <= 3
286            {
287                for action in &anomaly.preventive_actions {
288                    if self.should_apply_action(action, trainer_params) {
289                        actions_to_apply.push(action.clone());
290                    }
291                }
292            }
293        }
294
295        // Apply the collected actions
296        for action in actions_to_apply {
297            self.apply_preventive_action(&action, trainer_params)?;
298            applied_actions.push(action);
299        }
300
301        Ok(applied_actions)
302    }
303
304    fn update_histories(&mut self, loss: f32, gradient_norm: f32, learning_rate: f32) {
305        self.loss_history.push_back(loss);
306        self.gradient_history.push_back(gradient_norm);
307        self.lr_history.push_back(learning_rate);
308
309        let max_len = self.config.pattern_window_size;
310        if self.loss_history.len() > max_len {
311            self.loss_history.pop_front();
312        }
313        if self.gradient_history.len() > max_len {
314            self.gradient_history.pop_front();
315        }
316        if self.lr_history.len() > max_len {
317            self.lr_history.pop_front();
318        }
319    }
320
321    fn analyze_training_dynamics(&self) -> Result<TrainingDynamics> {
322        let loss_trend = self.compute_trend(&self.loss_history);
323        let gradient_trend = self.compute_trend(&self.gradient_history);
324        let lr_effectiveness = self.compute_lr_effectiveness();
325        let convergence_velocity = self.compute_convergence_velocity();
326        let oscillation_frequency = self.compute_oscillation_frequency();
327        let phase_trajectory = self.compute_phase_trajectory();
328
329        Ok(TrainingDynamics {
330            loss_trend,
331            gradient_trend,
332            lr_effectiveness,
333            convergence_velocity,
334            oscillation_frequency,
335            phase_trajectory,
336        })
337    }
338
339    fn analyze_loss_landscape(
340        &self,
341        gradients: &HashMap<String, Tensor>,
342    ) -> Result<LossLandscapeAnalysis> {
343        let local_curvature = self.estimate_local_curvature(gradients).unwrap_or_else(|e| {
344            log::warn!("Failed to estimate local curvature: {}", e);
345            0.1
346        });
347
348        let gradient_consistency =
349            self.compute_gradient_consistency(gradients).unwrap_or_else(|e| {
350                log::warn!("Failed to compute gradient consistency: {}", e);
351                0.8
352            });
353
354        let escape_difficulty = self.estimate_escape_difficulty();
355        let basin_stability = self.estimate_basin_stability();
356
357        let saddle_point_prob =
358            self.estimate_saddle_point_probability(gradients).unwrap_or_else(|e| {
359                log::warn!("Failed to estimate saddle point probability: {}", e);
360                0.2
361            });
362
363        Ok(LossLandscapeAnalysis {
364            local_curvature,
365            gradient_consistency,
366            escape_difficulty,
367            basin_stability,
368            saddle_point_prob,
369        })
370    }
371
372    fn compute_stability_score(&self) -> Result<StabilityScore> {
373        let gradient_stability = self.compute_gradient_stability();
374        let loss_stability = self.compute_loss_stability();
375        let convergence_stability = self.compute_convergence_stability();
376        let numerical_stability = self.compute_numerical_stability();
377
378        let overall_score =
379            (gradient_stability + loss_stability + convergence_stability + numerical_stability)
380                / 4.0;
381
382        let recommendations = self.generate_stability_recommendations(
383            gradient_stability,
384            loss_stability,
385            convergence_stability,
386            numerical_stability,
387        );
388
389        Ok(StabilityScore {
390            overall_score,
391            gradient_stability,
392            loss_stability,
393            convergence_stability,
394            numerical_stability,
395            recommendations,
396        })
397    }
398
399    fn predict_anomalies(&self, current_step: usize) -> Result<Vec<PredictiveAnomaly>> {
400        let mut predictions = Vec::new();
401
402        // Predict gradient explosion
403        if let Some(anomaly) = self.predict_gradient_explosion(current_step)? {
404            predictions.push(anomaly);
405        }
406
407        // Predict training stagnation
408        if let Some(anomaly) = self.predict_training_stagnation(current_step)? {
409            predictions.push(anomaly);
410        }
411
412        // Predict numerical instability
413        if let Some(anomaly) = self.predict_numerical_instability(current_step)? {
414            predictions.push(anomaly);
415        }
416
417        // Predict oscillating loss
418        if let Some(anomaly) = self.predict_oscillating_loss(current_step)? {
419            predictions.push(anomaly);
420        }
421
422        Ok(predictions)
423    }
424
425    // Helper methods for trend analysis
426    fn compute_trend(&self, history: &VecDeque<f32>) -> TrendDirection {
427        if history.len() < 3 {
428            return TrendDirection::Stable;
429        }
430
431        // Take the 10 most recent values and restore chronological order
432        let mut recent: Vec<f32> = history.iter().rev().take(10).cloned().collect();
433        recent.reverse(); // Restore chronological order for slope computation
434        let slope = self.compute_slope(&recent);
435        let variance = self.compute_variance(&recent);
436
437        if variance > 0.1 {
438            TrendDirection::Oscillating
439        } else if slope < -0.01 {
440            TrendDirection::Decreasing
441        } else if slope > 0.01 {
442            TrendDirection::Increasing
443        } else {
444            TrendDirection::Stable
445        }
446    }
447
448    fn compute_slope(&self, values: &[f32]) -> f32 {
449        if values.len() < 2 {
450            return 0.0;
451        }
452
453        let n = values.len() as f32;
454        let sum_x: f32 = (0..values.len()).map(|i| i as f32).sum();
455        let sum_y: f32 = values.iter().sum();
456        let sum_xy: f32 = values.iter().enumerate().map(|(i, &y)| i as f32 * y).sum();
457        let sum_x2: f32 = (0..values.len()).map(|i| (i as f32).powi(2)).sum();
458
459        (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
460    }
461
462    fn compute_variance(&self, values: &[f32]) -> f32 {
463        if values.is_empty() {
464            return 0.0;
465        }
466
467        let mean: f32 = values.iter().sum::<f32>() / values.len() as f32;
468        let variance: f32 =
469            values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
470
471        variance
472    }
473
474    // Enhanced implementations for complex analysis methods
475    fn compute_lr_effectiveness(&self) -> f32 {
476        if self.loss_history.len() < 5 || self.lr_history.len() < 5 {
477            return 0.5;
478        }
479
480        // Compute correlation between LR changes and loss improvements
481        let mut lr_effectiveness_scores = Vec::new();
482
483        for window in self
484            .loss_history
485            .iter()
486            .zip(self.lr_history.iter())
487            .collect::<Vec<_>>()
488            .windows(3)
489        {
490            if let [(l1, lr1), (l2, lr2), (_l3, _lr3)] = window {
491                let loss_improvement = (*l1 - *l2) / l1.max(1e-8f32);
492                let lr_change = (*lr2 - *lr1) / lr1.max(1e-8f32);
493
494                // Higher effectiveness if LR increases lead to loss decreases (and vice versa)
495                if loss_improvement > 0.0 && lr_change > 0.0 {
496                    lr_effectiveness_scores.push(0.8);
497                } else if loss_improvement < 0.0 && lr_change < 0.0 {
498                    lr_effectiveness_scores.push(0.6);
499                } else {
500                    lr_effectiveness_scores.push(0.3);
501                }
502            }
503        }
504
505        if lr_effectiveness_scores.is_empty() {
506            0.5
507        } else {
508            lr_effectiveness_scores.iter().sum::<f32>() / lr_effectiveness_scores.len() as f32
509        }
510    }
511
512    fn compute_convergence_velocity(&self) -> f32 {
513        if self.loss_history.len() < 10 {
514            return 0.0;
515        }
516
517        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
518        let slope = self.compute_slope(&recent_losses);
519
520        // Normalize slope to get velocity (more negative slope = faster convergence)
521
522        if slope < 0.0 {
523            (-slope * 100.0).min(1.0)
524        } else {
525            0.0
526        }
527    }
528
529    fn compute_oscillation_frequency(&self) -> f32 {
530        if self.loss_history.len() < 10 {
531            return 0.0;
532        }
533
534        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(20).cloned().collect();
535        let mut direction_changes = 0;
536
537        for window in recent_losses.windows(3) {
538            if (window[1] > window[0]) != (window[2] > window[1]) {
539                direction_changes += 1;
540            }
541        }
542
543        // Normalize by the number of possible direction changes
544        direction_changes as f32 / (recent_losses.len() - 2).max(1) as f32
545    }
546
547    fn compute_phase_trajectory(&self) -> Vec<(f32, f32)> {
548        self.loss_history
549            .iter()
550            .zip(self.gradient_history.iter())
551            .map(|(&l, &g)| (l, g))
552            .collect()
553    }
554
555    fn estimate_local_curvature(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
556        if gradients.is_empty() || self.gradient_history.len() < 3 {
557            return Ok(0.1);
558        }
559
560        // Estimate curvature using finite differences of gradient norms
561        let _current_norm = self.compute_total_gradient_norm(gradients)?;
562        let recent_norms: Vec<f32> = self.gradient_history.iter().rev().take(3).cloned().collect();
563
564        if recent_norms.len() >= 3 {
565            // Second derivative approximation using finite differences
566            let second_derivative = recent_norms[0] - 2.0 * recent_norms[1] + recent_norms[2];
567            let curvature = second_derivative.abs() / (recent_norms[1].max(1e-8));
568            Ok(curvature.min(10.0)) // Cap extreme curvature values
569        } else {
570            Ok(0.1)
571        }
572    }
573
574    fn compute_gradient_consistency(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
575        if gradients.len() < 2 {
576            return Ok(1.0);
577        }
578
579        // Compute consistency by checking gradient norm ratios across layers
580        let mut norms = Vec::new();
581        for tensor in gradients.values() {
582            let data = tensor.data().unwrap_or_default();
583            let norm = data.iter().map(|&x| x * x).sum::<f32>().sqrt();
584            norms.push(norm);
585        }
586
587        if norms.is_empty() {
588            return Ok(1.0);
589        }
590
591        let mean_norm = norms.iter().sum::<f32>() / norms.len() as f32;
592        let variance =
593            norms.iter().map(|&x| (x - mean_norm).powi(2)).sum::<f32>() / norms.len() as f32;
594        let cv = variance.sqrt() / mean_norm.max(1e-8);
595
596        // Higher consistency (lower CV) gets higher score
597        Ok((1.0 / (1.0 + cv * 2.0)).clamp(0.0, 1.0))
598    }
599
600    fn estimate_escape_difficulty(&self) -> f32 {
601        if self.loss_history.len() < 20 {
602            return 0.3;
603        }
604
605        // Estimate difficulty based on local minima detection
606        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(20).cloned().collect();
607        let mut local_minima_count = 0;
608
609        for window in recent_losses.windows(5) {
610            if window[2] < window[0]
611                && window[2] < window[1]
612                && window[2] < window[3]
613                && window[2] < window[4]
614            {
615                local_minima_count += 1;
616            }
617        }
618
619        // More local minima suggest higher escape difficulty
620        (local_minima_count as f32 / 5.0).min(1.0)
621    }
622
623    fn estimate_basin_stability(&self) -> f32 {
624        if self.loss_history.len() < 10 {
625            return 0.7;
626        }
627
628        // Estimate stability based on loss variance and trend
629        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
630        let variance = self.compute_variance(&recent_losses);
631        let mean_loss = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
632        let cv = variance.sqrt() / mean_loss.max(1e-8);
633
634        // Lower variance indicates more stable basin
635        (1.0 / (1.0 + cv * 3.0)).clamp(0.0, 1.0)
636    }
637
638    fn estimate_saddle_point_probability(
639        &self,
640        gradients: &HashMap<String, Tensor>,
641    ) -> Result<f32> {
642        if gradients.is_empty() || self.gradient_history.len() < 5 {
643            return Ok(0.2);
644        }
645
646        let current_grad_norm = self.compute_total_gradient_norm(gradients)?;
647
648        // Saddle points typically have small gradients but high curvature
649        let small_gradient = current_grad_norm < 0.01;
650        let curvature = self.estimate_local_curvature(gradients)?;
651        let high_curvature = curvature > 0.1;
652
653        let probability = if small_gradient && high_curvature {
654            0.8
655        } else if small_gradient {
656            0.4
657        } else {
658            0.1
659        };
660
661        Ok(probability)
662    }
663
664    fn compute_gradient_stability(&self) -> f32 {
665        if self.gradient_history.len() < 5 {
666            return 1.0;
667        }
668        let variance =
669            self.compute_variance(&self.gradient_history.iter().cloned().collect::<Vec<_>>());
670        (1.0 / (1.0 + variance)).clamp(0.0, 1.0)
671    }
672
673    fn compute_loss_stability(&self) -> f32 {
674        if self.loss_history.len() < 5 {
675            return 1.0;
676        }
677        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
678        let slope = self.compute_slope(&recent_losses);
679        if slope < 0.0 {
680            0.9
681        } else if slope < 0.01 {
682            0.7
683        } else {
684            0.3
685        }
686    }
687
688    fn compute_convergence_stability(&self) -> f32 {
689        if self.loss_history.len() < 10 {
690            return 0.8;
691        }
692
693        let convergence_velocity = self.compute_convergence_velocity();
694        let oscillation_freq = self.compute_oscillation_frequency();
695
696        // Balance between good convergence speed and low oscillation
697        let velocity_score = convergence_velocity.min(0.5) * 2.0; // Normalize to 0-1
698        let stability_score = (1.0 - oscillation_freq).max(0.0);
699
700        (velocity_score * 0.6 + stability_score * 0.4).clamp(0.0, 1.0)
701    }
702
703    fn compute_numerical_stability(&self) -> f32 {
704        if self.loss_history.is_empty() || self.gradient_history.is_empty() {
705            return 0.9;
706        }
707
708        // Check for numerical issues in recent history
709        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
710        let recent_grads: Vec<f32> = self.gradient_history.iter().rev().take(10).cloned().collect();
711
712        let loss_issues = recent_losses.iter().any(|&x| !x.is_finite());
713        let grad_issues = recent_grads.iter().any(|&x| !x.is_finite());
714
715        let extreme_values = recent_losses.iter().any(|&x| !(-1e6..=1e6).contains(&x))
716            || recent_grads.iter().any(|&x| !(-1e6..=1e6).contains(&x));
717
718        if loss_issues || grad_issues {
719            0.0 // Critical numerical instability
720        } else if extreme_values {
721            0.3 // Potential instability
722        } else {
723            let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
724            let max_grad = recent_grads.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
725
726            // Penalize very large values
727            let loss_penalty = if max_loss > 1000.0 { 0.3 } else { 0.0 };
728            let grad_penalty = if max_grad > 100.0 { 0.2 } else { 0.0 };
729
730            (1.0f32 - loss_penalty - grad_penalty).max(0.0f32)
731        }
732    }
733
734    fn generate_stability_recommendations(
735        &self,
736        gs: f32,
737        ls: f32,
738        cs: f32,
739        ns: f32,
740    ) -> Vec<String> {
741        let mut recommendations = Vec::new();
742
743        if gs < 0.5 {
744            recommendations.push(
745                "Consider gradient clipping or normalization to improve gradient stability"
746                    .to_string(),
747            );
748        }
749
750        if ls < 0.5 {
751            recommendations.push(
752                "Loss appears unstable - consider reducing learning rate or adjusting optimizer"
753                    .to_string(),
754            );
755        }
756
757        if cs < 0.5 {
758            recommendations.push("Poor convergence stability - consider learning rate scheduling or different optimizer".to_string());
759        }
760
761        if ns < 0.5 {
762            recommendations.push("Numerical instability detected - check for NaN/Inf values and consider mixed precision".to_string());
763        }
764
765        let overall_score = (gs + ls + cs + ns) / 4.0;
766
767        if overall_score < 0.3 {
768            recommendations.push(
769                "Critical stability issues - consider checkpoint rollback and parameter reset"
770                    .to_string(),
771            );
772        } else if overall_score < 0.6 {
773            recommendations.push("Moderate stability issues - monitor closely and consider conservative training settings".to_string());
774        } else if recommendations.is_empty() {
775            recommendations.push("Training stability is good - continue monitoring".to_string());
776        }
777
778        recommendations
779    }
780
781    fn analyze_stability_trend(&self) -> TrendDirection {
782        let scores: Vec<f32> = self.stability_scores.iter().map(|s| s.overall_score).collect();
783        self.compute_trend(&scores.into_iter().collect())
784    }
785
786    fn generate_recommendations(&self) -> Vec<String> {
787        vec!["Continue monitoring training progress".to_string()]
788    }
789
790    fn compute_prediction_confidence(&self) -> f32 {
791        // Base confidence on data quality and history length
792        let history_quality = if self.loss_history.len() >= 20 { 0.9 } else { 0.5 };
793        let data_quality = if self.loss_history.iter().all(|&x| x.is_finite()) { 0.9 } else { 0.3 };
794        let trend_consistency = if self.dynamics_history.len() >= 3 { 0.8 } else { 0.4 };
795
796        (history_quality * 0.4f32 + data_quality * 0.4f32 + trend_consistency * 0.2f32).min(1.0f32)
797    }
798
799    /// Helper method to compute total gradient norm across all tensors
800    fn compute_total_gradient_norm(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
801        let mut total_norm_sq = 0.0f32;
802
803        for tensor in gradients.values() {
804            let data = tensor.data().map_err(|_| runtime_error("Failed to get tensor data"))?;
805            let tensor_norm_sq: f32 = data.iter().map(|&x| x * x).sum();
806            total_norm_sq += tensor_norm_sq;
807        }
808
809        Ok(total_norm_sq.sqrt())
810    }
811
812    /// Detect exponential growth in gradient sequence
813    fn detect_exponential_growth(&self, values: &[f32]) -> bool {
814        if values.len() < 5 {
815            return false;
816        }
817
818        // Check if each value is consistently larger than the previous by a significant factor
819        let mut growth_count = 0;
820        for window in values.windows(2) {
821            if window[0] > 0.0 && window[1] / window[0] > 1.5 {
822                growth_count += 1;
823            }
824        }
825
826        growth_count >= (values.len() - 1) / 2 // At least half show significant growth
827    }
828
829    /// Detect increasing variance in recent values
830    fn detect_variance_increase(&self, values: &[f32]) -> bool {
831        if values.len() < 8 {
832            return false;
833        }
834
835        let mid_point = values.len() / 2;
836        let early_half = &values[..mid_point];
837        let recent_half = &values[mid_point..];
838
839        let early_variance = self.compute_variance(early_half);
840        let recent_variance = self.compute_variance(recent_half);
841
842        recent_variance > early_variance * 2.0 // Recent variance is significantly higher
843    }
844
845    /// Detect lack of improvement over a threshold
846    fn detect_no_improvement(&self, losses: &[f32], threshold: f32) -> bool {
847        if losses.len() < 5 {
848            return false;
849        }
850
851        let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
852        let recent_loss = losses[0]; // Most recent (reversed order)
853
854        // No improvement if recent loss is not significantly better than the best
855        (best_loss - recent_loss) / best_loss.max(1e-8) < threshold
856    }
857
858    /// Compute oscillation amplitude
859    fn compute_oscillation_amplitude(&self) -> f32 {
860        if self.loss_history.len() < 10 {
861            return 0.0;
862        }
863
864        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).cloned().collect();
865        let max_loss = recent_losses.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
866        let min_loss = recent_losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
867        let mean_loss = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
868
869        if mean_loss > 0.0 {
870            (max_loss - min_loss) / mean_loss
871        } else {
872            0.0
873        }
874    }
875
876    fn predict_gradient_explosion(&self, current_step: usize) -> Result<Option<PredictiveAnomaly>> {
877        if self.gradient_history.len() < 5 {
878            return Ok(None);
879        }
880
881        let recent_grads: Vec<f32> = self.gradient_history.iter().rev().take(10).cloned().collect();
882
883        // Multiple indicators for gradient explosion
884        let trend = self.compute_trend(&self.gradient_history);
885        let exponential_growth = self.detect_exponential_growth(&recent_grads);
886        let variance_increase = self.detect_variance_increase(&recent_grads);
887
888        let base_confidence = match trend {
889            TrendDirection::Increasing => 0.6,
890            TrendDirection::Diverging => 0.9,
891            _ => 0.0,
892        };
893
894        let growth_factor = if exponential_growth { 0.3f32 } else { 0.0f32 };
895        let variance_factor = if variance_increase { 0.2f32 } else { 0.0f32 };
896
897        let confidence = (base_confidence + growth_factor + variance_factor).min(1.0f32);
898
899        if confidence >= self.config.prediction_confidence_threshold {
900            let time_to_occurrence = if exponential_growth { 2 } else { 5 };
901            let risk_level = if confidence > 0.8 { RiskLevel::Critical } else { RiskLevel::High };
902
903            return Ok(Some(PredictiveAnomaly {
904                predicted_step: current_step + time_to_occurrence,
905                anomaly_type: PredictedAnomalyType::GradientExplosion,
906                confidence,
907                time_to_occurrence,
908                preventive_actions: vec![
909                    PreventiveAction::ReduceLearningRate {
910                        factor: if confidence > 0.8 { 0.1 } else { 0.5 },
911                    },
912                    PreventiveAction::IncreaseGradientClipping { new_threshold: 1.0 },
913                    PreventiveAction::TriggerEarlyCheckpoint,
914                ],
915                risk_level,
916            }));
917        }
918
919        Ok(None)
920    }
921
922    fn predict_training_stagnation(
923        &self,
924        current_step: usize,
925    ) -> Result<Option<PredictiveAnomaly>> {
926        if self.loss_history.len() < 20 {
927            return Ok(None);
928        }
929
930        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(15).cloned().collect();
931        let variance = self.compute_variance(&recent_losses);
932        let slope = self.compute_slope(&recent_losses);
933
934        // Multiple stagnation indicators
935        let low_variance = variance < 1e-6;
936        let flat_slope = slope.abs() < 1e-5;
937        let no_improvement = self.detect_no_improvement(&recent_losses, 0.001);
938
939        let stagnation_indicators =
940            [low_variance, flat_slope, no_improvement].iter().filter(|&&x| x).count();
941
942        if stagnation_indicators >= 2 {
943            let confidence = match stagnation_indicators {
944                3 => 0.95,
945                2 => 0.7,
946                _ => 0.5,
947            };
948
949            if confidence >= self.config.prediction_confidence_threshold {
950                return Ok(Some(PredictiveAnomaly {
951                    predicted_step: current_step + 10,
952                    anomaly_type: PredictedAnomalyType::TrainingStagnation,
953                    confidence,
954                    time_to_occurrence: 10,
955                    preventive_actions: vec![
956                        PreventiveAction::AdjustOptimizer {
957                            suggested_params: [
958                                ("momentum".to_string(), 0.9),
959                                ("learning_rate_multiplier".to_string(), 1.5),
960                            ]
961                            .into_iter()
962                            .collect(),
963                        },
964                        PreventiveAction::EnableNoise { noise_level: 0.01 },
965                        PreventiveAction::AdjustWarmupSchedule,
966                    ],
967                    risk_level: if confidence > 0.8 { RiskLevel::High } else { RiskLevel::Medium },
968                }));
969            }
970        }
971
972        Ok(None)
973    }
974
975    fn predict_numerical_instability(
976        &self,
977        current_step: usize,
978    ) -> Result<Option<PredictiveAnomaly>> {
979        if self.loss_history.len() < 5 {
980            return Ok(None);
981        }
982
983        let recent_loss = self.loss_history.back().unwrap_or(&1.0);
984        if recent_loss.is_nan() || recent_loss.is_infinite() || *recent_loss > 1e6 {
985            return Ok(Some(PredictiveAnomaly {
986                predicted_step: current_step + 1,
987                anomaly_type: PredictedAnomalyType::NumericalInstability,
988                confidence: 0.95,
989                time_to_occurrence: 1,
990                preventive_actions: vec![
991                    PreventiveAction::ReduceLearningRate { factor: 0.1 },
992                    PreventiveAction::TriggerEarlyCheckpoint,
993                ],
994                risk_level: RiskLevel::Critical,
995            }));
996        }
997
998        Ok(None)
999    }
1000
1001    fn predict_oscillating_loss(&self, current_step: usize) -> Result<Option<PredictiveAnomaly>> {
1002        if self.loss_history.len() < 15 {
1003            return Ok(None);
1004        }
1005
1006        let oscillation_freq = self.compute_oscillation_frequency();
1007        let amplitude = self.compute_oscillation_amplitude();
1008
1009        // Oscillation severity based on frequency and amplitude
1010        let severity_score = oscillation_freq * amplitude;
1011
1012        if oscillation_freq > 0.3 || severity_score > 0.2 {
1013            let confidence = (oscillation_freq * 2.0 + severity_score).min(1.0);
1014
1015            if confidence >= self.config.prediction_confidence_threshold {
1016                return Ok(Some(PredictiveAnomaly {
1017                    predicted_step: current_step + 5,
1018                    anomaly_type: PredictedAnomalyType::OscillatingLoss,
1019                    confidence,
1020                    time_to_occurrence: 5,
1021                    preventive_actions: vec![
1022                        PreventiveAction::ReduceLearningRate {
1023                            factor: if severity_score > 0.5 { 0.5 } else { 0.8 },
1024                        },
1025                        PreventiveAction::AdjustWarmupSchedule,
1026                        PreventiveAction::EnableNoise { noise_level: 0.005 }, // Small noise to break oscillations
1027                        PreventiveAction::ModifyBatchSize { new_size: 64 }, // Larger batch for stability
1028                    ],
1029                    risk_level: if severity_score > 0.5 {
1030                        RiskLevel::High
1031                    } else {
1032                        RiskLevel::Medium
1033                    },
1034                }));
1035            }
1036        }
1037
1038        Ok(None)
1039    }
1040
    /// Decides whether a preventive action should be applied.
    ///
    /// Currently a stub: both arguments are ignored and every action is approved.
    fn should_apply_action(&self, _action: &PreventiveAction, _params: &TrainerParameters) -> bool {
        true // Simplified logic
    }
1044
1045    fn apply_preventive_action(
1046        &mut self,
1047        action: &PreventiveAction,
1048        params: &mut TrainerParameters,
1049    ) -> Result<()> {
1050        match action {
1051            PreventiveAction::ReduceLearningRate { factor } => {
1052                params.learning_rate *= factor;
1053            },
1054            PreventiveAction::IncreaseGradientClipping { new_threshold } => {
1055                params.gradient_clip_threshold = *new_threshold;
1056            },
1057            PreventiveAction::ModifyBatchSize { new_size } => {
1058                params.batch_size = *new_size;
1059            },
1060            _ => {
1061                // Other actions would be implemented based on trainer interface
1062            },
1063        }
1064        Ok(())
1065    }
1066}
1067
/// Pattern detector for complex training dynamics.
///
/// Holds a library of [`Pattern`]s keyed by a `String`.
pub struct PatternDetector {
    // Nothing in the visible part of this file reads the library yet, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    pattern_library: HashMap<String, Pattern>,
}
1073
// `Default` simply delegates to `PatternDetector::new`.
impl Default for PatternDetector {
    fn default() -> Self {
        Self::new()
    }
}
1079
1080impl PatternDetector {
1081    pub fn new() -> Self {
1082        Self {
1083            pattern_library: HashMap::new(),
1084        }
1085    }
1086
1087    pub fn detect_patterns(&self, _dynamics: &TrainingDynamics) -> Vec<DetectedPattern> {
1088        Vec::new() // Placeholder
1089    }
1090}
1091
/// A named training-dynamics pattern, described by a set of indicators.
#[derive(Debug, Clone)]
pub struct Pattern {
    /// Name of the pattern.
    pub name: String,
    /// Free-form description of the pattern.
    pub description: String,
    /// Indicators associated with this pattern.
    pub indicators: Vec<PatternIndicator>,
}
1098
/// A single indicator belonging to a [`Pattern`].
// NOTE(review): the accepted values of `metric` and `condition` are not defined in the
// visible part of this file — confirm against the code that evaluates indicators.
#[derive(Debug, Clone)]
pub struct PatternIndicator {
    /// Name of the metric this indicator refers to.
    pub metric: String,
    /// Condition applied to the metric, stored as a string.
    pub condition: String,
    /// Threshold value used together with `condition`.
    pub threshold: f32,
}
1105
/// A pattern reported by [`PatternDetector::detect_patterns`].
#[derive(Debug, Clone)]
pub struct DetectedPattern {
    /// The pattern that was detected.
    pub pattern: Pattern,
    /// Confidence of the detection (presumably in `0.0..=1.0` — TODO confirm).
    pub confidence: f32,
    /// Risk level assigned to the detection.
    pub severity: RiskLevel,
}
1112
/// Trainer parameters that can be modified by recovery actions
#[derive(Debug, Clone)]
pub struct TrainerParameters {
    /// Learning rate; multiplied by the factor of a `ReduceLearningRate` action.
    pub learning_rate: f32,
    /// Gradient clipping threshold; overwritten by an `IncreaseGradientClipping` action.
    pub gradient_clip_threshold: f32,
    /// Batch size; overwritten by a `ModifyBatchSize` action.
    pub batch_size: usize,
    /// Named optimizer parameters. No action visible in this part of the file modifies
    /// them yet.
    pub optimizer_params: HashMap<String, f32>,
}
1121
/// Comprehensive stability report
// NOTE(review): the code that fills this report is outside the visible part of the
// file; the field notes below follow the field names and types — confirm against
// `get_stability_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityReport {
    /// Current stability score (the tests in this file only assert it is non-negative).
    pub current_stability_score: f32,
    /// Direction in which stability is trending.
    pub stability_trend: TrendDirection,
    /// Anomalies considered immediate risks.
    pub immediate_risks: Vec<PredictiveAnomaly>,
    /// Anomalies predicted by the monitor.
    pub predicted_anomalies: Vec<PredictiveAnomaly>,
    /// Loss landscape analysis, if one is available.
    pub landscape_health: Option<LossLandscapeAnalysis>,
    /// Human-readable recommendations.
    pub recommendations: Vec<String>,
    /// Confidence level of the report (the tests only assert it is non-negative).
    pub confidence_level: f32,
}
1133
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built monitor has no recorded losses and no predicted anomalies.
    #[test]
    fn test_advanced_stability_monitor_creation() {
        let monitor = AdvancedStabilityMonitor::new(AdvancedStabilityConfig::default());

        assert!(monitor.loss_history.is_empty());
        assert!(monitor.predicted_anomalies.is_empty());
    }

    /// Analysing a single step with no gradients succeeds.
    #[test]
    fn test_stability_analysis() {
        let mut monitor = AdvancedStabilityMonitor::new(AdvancedStabilityConfig::default());
        let empty_gradients = HashMap::new();

        let outcome = monitor.analyze_step(0, 1.0, 0.5, 0.001, &empty_gradients);

        assert!(outcome.is_ok());
    }

    /// A steadily falling sequence is classified as a decreasing trend.
    #[test]
    fn test_trend_computation() {
        let monitor = AdvancedStabilityMonitor::new(AdvancedStabilityConfig::default());
        let falling: VecDeque<f32> = VecDeque::from(vec![1.0, 0.9, 0.8, 0.7, 0.6]);

        let trend = monitor.compute_trend(&falling);

        assert!(matches!(trend, TrendDirection::Decreasing));
    }

    /// The report of a fresh monitor carries non-negative score and confidence.
    #[test]
    fn test_stability_report_generation() {
        let monitor = AdvancedStabilityMonitor::new(AdvancedStabilityConfig::default());

        let report = monitor.get_stability_report();

        assert!(report.current_stability_score >= 0.0);
        assert!(report.confidence_level >= 0.0);
    }
}