Skip to main content

trustformers_debug/gradient_debugger/
monitoring.rs

1//! Real-time Gradient Monitoring and Adaptive Thresholds
2//!
3//! This module provides real-time gradient monitoring capabilities with adaptive
4//! thresholds that dynamically adjust based on gradient behavior patterns.
5
6use super::types::*;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11/// Adaptive thresholds for dynamic gradient monitoring
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AdaptiveThresholds {
14    pub layer_name: String,
15    pub vanishing_threshold: f64,
16    pub exploding_threshold: f64,
17    pub adaptation_rate: f64,
18    pub recent_gradients: VecDeque<f64>,
19    pub last_updated: DateTime<Utc>,
20}
21
22impl AdaptiveThresholds {
23    pub fn new(layer_name: String, initial_vanishing: f64, initial_exploding: f64) -> Self {
24        Self {
25            layer_name,
26            vanishing_threshold: initial_vanishing,
27            exploding_threshold: initial_exploding,
28            adaptation_rate: 0.1,
29            recent_gradients: VecDeque::with_capacity(100),
30            last_updated: Utc::now(),
31        }
32    }
33
34    pub fn update_thresholds(&mut self, gradient_norm: f64) {
35        // Add new gradient to history
36        if self.recent_gradients.len() >= 100 {
37            self.recent_gradients.pop_front();
38        }
39        self.recent_gradients.push_back(gradient_norm);
40
41        // Update thresholds based on recent history
42        if self.recent_gradients.len() >= 10 {
43            let mean =
44                self.recent_gradients.iter().sum::<f64>() / self.recent_gradients.len() as f64;
45            let variance = self.recent_gradients.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
46                / self.recent_gradients.len() as f64;
47            let std_dev = variance.sqrt();
48
49            // Adaptive threshold updates
50            let new_vanishing = (mean - 2.0 * std_dev).max(1e-8);
51            let new_exploding = mean + 3.0 * std_dev;
52
53            self.vanishing_threshold = self.vanishing_threshold * (1.0 - self.adaptation_rate)
54                + new_vanishing * self.adaptation_rate;
55            self.exploding_threshold = self.exploding_threshold * (1.0 - self.adaptation_rate)
56                + new_exploding * self.adaptation_rate;
57
58            self.last_updated = Utc::now();
59        }
60    }
61
62    pub fn check_thresholds(&self, gradient_norm: f64) -> Vec<GradientAlert> {
63        let mut alerts = Vec::new();
64
65        if gradient_norm < self.vanishing_threshold {
66            alerts.push(GradientAlert::VanishingGradients {
67                layer_name: self.layer_name.clone(),
68                norm: gradient_norm,
69                threshold: self.vanishing_threshold,
70            });
71        }
72
73        if gradient_norm > self.exploding_threshold {
74            alerts.push(GradientAlert::ExplodingGradients {
75                layer_name: self.layer_name.clone(),
76                norm: gradient_norm,
77                threshold: self.exploding_threshold,
78            });
79        }
80
81        alerts
82    }
83
84    /// Create adaptive thresholds from gradient history
85    pub fn from_history(history: &GradientHistory) -> Self {
86        let layer_name = history.layer_name.clone();
87
88        if history.gradient_norms.is_empty() {
89            return Self::new(layer_name, 1e-6, 10.0);
90        }
91
92        // Calculate initial thresholds from history
93        let norms: Vec<f64> = history.gradient_norms.iter().cloned().collect();
94        let mean = norms.iter().sum::<f64>() / norms.len() as f64;
95        let variance = norms.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / norms.len() as f64;
96        let std_dev = variance.sqrt();
97
98        let initial_vanishing = (mean - 2.0 * std_dev).max(1e-8);
99        let initial_exploding = mean + 3.0 * std_dev;
100
101        let mut thresholds = Self::new(layer_name, initial_vanishing, initial_exploding);
102
103        // Pre-populate with recent gradients
104        for &norm in norms.iter().rev().take(50) {
105            if thresholds.recent_gradients.len() >= 100 {
106                thresholds.recent_gradients.pop_front();
107            }
108            thresholds.recent_gradients.push_back(norm);
109        }
110
111        thresholds
112    }
113}
114
115/// Real-time gradient monitoring
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct RealTimeGradientMonitor {
118    pub layer_name: String,
119    pub current_gradient_norm: f64,
120    pub gradient_velocity: f64,
121    pub gradient_acceleration: f64,
122    pub stability_window: VecDeque<f64>,
123    pub anomaly_score: f64,
124}
125
126impl RealTimeGradientMonitor {
127    pub fn new(layer_name: String) -> Self {
128        Self {
129            layer_name,
130            current_gradient_norm: 0.0,
131            gradient_velocity: 0.0,
132            gradient_acceleration: 0.0,
133            stability_window: VecDeque::with_capacity(10),
134            anomaly_score: 0.0,
135        }
136    }
137
138    pub fn update(&mut self, new_gradient_norm: f64) {
139        let previous_norm = self.current_gradient_norm;
140        let previous_velocity = self.gradient_velocity;
141
142        self.current_gradient_norm = new_gradient_norm;
143        self.gradient_velocity = new_gradient_norm - previous_norm;
144        self.gradient_acceleration = self.gradient_velocity - previous_velocity;
145
146        // Update stability window
147        if self.stability_window.len() >= 10 {
148            self.stability_window.pop_front();
149        }
150        self.stability_window.push_back(new_gradient_norm);
151
152        // Update anomaly score
153        self.anomaly_score = self.compute_anomaly_score();
154    }
155
156    fn compute_anomaly_score(&self) -> f64 {
157        if self.stability_window.len() < 5 {
158            return 0.0;
159        }
160
161        let mean = self.stability_window.iter().sum::<f64>() / self.stability_window.len() as f64;
162        let variance = self.stability_window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
163            / self.stability_window.len() as f64;
164        let std_dev = variance.sqrt();
165
166        if std_dev == 0.0 {
167            return 0.0;
168        }
169
170        // Z-score based anomaly detection
171        let z_score = (self.current_gradient_norm - mean) / std_dev;
172        z_score.abs().min(5.0) / 5.0 // Normalize to 0-1 range
173    }
174
175    pub fn get_stability_score(&self) -> f64 {
176        if self.stability_window.len() < 3 {
177            return 1.0;
178        }
179
180        let variance = self
181            .stability_window
182            .iter()
183            .map(|&x| (x - self.current_gradient_norm).powi(2))
184            .sum::<f64>()
185            / self.stability_window.len() as f64;
186
187        // Higher variance = lower stability
188        1.0 / (1.0 + variance)
189    }
190
191    pub fn is_stable(&self, threshold: f64) -> bool {
192        self.get_stability_score() > threshold
193    }
194
195    pub fn is_oscillating(&self) -> bool {
196        if self.stability_window.len() < 6 {
197            return false;
198        }
199
200        // Check for oscillating pattern by looking at sign changes
201        let mut sign_changes = 0;
202        let values: Vec<f64> = self.stability_window.iter().cloned().collect();
203
204        for i in 1..values.len() - 1 {
205            let prev_diff = values[i] - values[i - 1];
206            let curr_diff = values[i + 1] - values[i];
207
208            if prev_diff * curr_diff < 0.0 {
209                sign_changes += 1;
210            }
211        }
212
213        // If more than half the intervals change sign, consider it oscillating
214        sign_changes > values.len() / 2
215    }
216}
217
218/// Gradient monitoring configuration
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct MonitoringConfig {
221    pub enable_adaptive_thresholds: bool,
222    pub enable_real_time_monitoring: bool,
223    pub stability_threshold: f64,
224    pub anomaly_threshold: f64,
225    pub update_frequency: usize,
226    pub history_window_size: usize,
227}
228
229impl Default for MonitoringConfig {
230    fn default() -> Self {
231        Self {
232            enable_adaptive_thresholds: true,
233            enable_real_time_monitoring: true,
234            stability_threshold: 0.8,
235            anomaly_threshold: 0.7,
236            update_frequency: 1,
237            history_window_size: 100,
238        }
239    }
240}
241
242/// Monitoring results and insights
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct MonitoringResults {
245    pub layer_name: String,
246    pub timestamp: DateTime<Utc>,
247    pub current_status: LayerHealth,
248    pub stability_score: f64,
249    pub anomaly_score: f64,
250    pub alerts: Vec<GradientAlert>,
251    pub recommendations: Vec<String>,
252}
253
254impl MonitoringResults {
255    pub fn new(layer_name: String) -> Self {
256        Self {
257            layer_name,
258            timestamp: Utc::now(),
259            current_status: LayerHealth::Healthy,
260            stability_score: 1.0,
261            anomaly_score: 0.0,
262            alerts: Vec::new(),
263            recommendations: Vec::new(),
264        }
265    }
266
267    pub fn add_alert(&mut self, alert: GradientAlert) {
268        self.alerts.push(alert);
269        self.update_status();
270    }
271
272    pub fn add_recommendation(&mut self, recommendation: String) {
273        self.recommendations.push(recommendation);
274    }
275
276    fn update_status(&mut self) {
277        if self.alerts.iter().any(|alert| {
278            matches!(
279                alert,
280                GradientAlert::ExplodingGradients { .. } | GradientAlert::NoGradientFlow { .. }
281            )
282        }) {
283            self.current_status = LayerHealth::Critical;
284        } else if !self.alerts.is_empty() || self.anomaly_score > 0.7 {
285            self.current_status = LayerHealth::Warning;
286        } else {
287            self.current_status = LayerHealth::Healthy;
288        }
289    }
290}