trustformers_debug/gradient_debugger/
monitoring.rs1use super::types::*;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AdaptiveThresholds {
14 pub layer_name: String,
15 pub vanishing_threshold: f64,
16 pub exploding_threshold: f64,
17 pub adaptation_rate: f64,
18 pub recent_gradients: VecDeque<f64>,
19 pub last_updated: DateTime<Utc>,
20}
21
22impl AdaptiveThresholds {
23 pub fn new(layer_name: String, initial_vanishing: f64, initial_exploding: f64) -> Self {
24 Self {
25 layer_name,
26 vanishing_threshold: initial_vanishing,
27 exploding_threshold: initial_exploding,
28 adaptation_rate: 0.1,
29 recent_gradients: VecDeque::with_capacity(100),
30 last_updated: Utc::now(),
31 }
32 }
33
34 pub fn update_thresholds(&mut self, gradient_norm: f64) {
35 if self.recent_gradients.len() >= 100 {
37 self.recent_gradients.pop_front();
38 }
39 self.recent_gradients.push_back(gradient_norm);
40
41 if self.recent_gradients.len() >= 10 {
43 let mean =
44 self.recent_gradients.iter().sum::<f64>() / self.recent_gradients.len() as f64;
45 let variance = self.recent_gradients.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
46 / self.recent_gradients.len() as f64;
47 let std_dev = variance.sqrt();
48
49 let new_vanishing = (mean - 2.0 * std_dev).max(1e-8);
51 let new_exploding = mean + 3.0 * std_dev;
52
53 self.vanishing_threshold = self.vanishing_threshold * (1.0 - self.adaptation_rate)
54 + new_vanishing * self.adaptation_rate;
55 self.exploding_threshold = self.exploding_threshold * (1.0 - self.adaptation_rate)
56 + new_exploding * self.adaptation_rate;
57
58 self.last_updated = Utc::now();
59 }
60 }
61
62 pub fn check_thresholds(&self, gradient_norm: f64) -> Vec<GradientAlert> {
63 let mut alerts = Vec::new();
64
65 if gradient_norm < self.vanishing_threshold {
66 alerts.push(GradientAlert::VanishingGradients {
67 layer_name: self.layer_name.clone(),
68 norm: gradient_norm,
69 threshold: self.vanishing_threshold,
70 });
71 }
72
73 if gradient_norm > self.exploding_threshold {
74 alerts.push(GradientAlert::ExplodingGradients {
75 layer_name: self.layer_name.clone(),
76 norm: gradient_norm,
77 threshold: self.exploding_threshold,
78 });
79 }
80
81 alerts
82 }
83
84 pub fn from_history(history: &GradientHistory) -> Self {
86 let layer_name = history.layer_name.clone();
87
88 if history.gradient_norms.is_empty() {
89 return Self::new(layer_name, 1e-6, 10.0);
90 }
91
92 let norms: Vec<f64> = history.gradient_norms.iter().cloned().collect();
94 let mean = norms.iter().sum::<f64>() / norms.len() as f64;
95 let variance = norms.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / norms.len() as f64;
96 let std_dev = variance.sqrt();
97
98 let initial_vanishing = (mean - 2.0 * std_dev).max(1e-8);
99 let initial_exploding = mean + 3.0 * std_dev;
100
101 let mut thresholds = Self::new(layer_name, initial_vanishing, initial_exploding);
102
103 for &norm in norms.iter().rev().take(50) {
105 if thresholds.recent_gradients.len() >= 100 {
106 thresholds.recent_gradients.pop_front();
107 }
108 thresholds.recent_gradients.push_back(norm);
109 }
110
111 thresholds
112 }
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct RealTimeGradientMonitor {
118 pub layer_name: String,
119 pub current_gradient_norm: f64,
120 pub gradient_velocity: f64,
121 pub gradient_acceleration: f64,
122 pub stability_window: VecDeque<f64>,
123 pub anomaly_score: f64,
124}
125
126impl RealTimeGradientMonitor {
127 pub fn new(layer_name: String) -> Self {
128 Self {
129 layer_name,
130 current_gradient_norm: 0.0,
131 gradient_velocity: 0.0,
132 gradient_acceleration: 0.0,
133 stability_window: VecDeque::with_capacity(10),
134 anomaly_score: 0.0,
135 }
136 }
137
138 pub fn update(&mut self, new_gradient_norm: f64) {
139 let previous_norm = self.current_gradient_norm;
140 let previous_velocity = self.gradient_velocity;
141
142 self.current_gradient_norm = new_gradient_norm;
143 self.gradient_velocity = new_gradient_norm - previous_norm;
144 self.gradient_acceleration = self.gradient_velocity - previous_velocity;
145
146 if self.stability_window.len() >= 10 {
148 self.stability_window.pop_front();
149 }
150 self.stability_window.push_back(new_gradient_norm);
151
152 self.anomaly_score = self.compute_anomaly_score();
154 }
155
156 fn compute_anomaly_score(&self) -> f64 {
157 if self.stability_window.len() < 5 {
158 return 0.0;
159 }
160
161 let mean = self.stability_window.iter().sum::<f64>() / self.stability_window.len() as f64;
162 let variance = self.stability_window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
163 / self.stability_window.len() as f64;
164 let std_dev = variance.sqrt();
165
166 if std_dev == 0.0 {
167 return 0.0;
168 }
169
170 let z_score = (self.current_gradient_norm - mean) / std_dev;
172 z_score.abs().min(5.0) / 5.0 }
174
175 pub fn get_stability_score(&self) -> f64 {
176 if self.stability_window.len() < 3 {
177 return 1.0;
178 }
179
180 let variance = self
181 .stability_window
182 .iter()
183 .map(|&x| (x - self.current_gradient_norm).powi(2))
184 .sum::<f64>()
185 / self.stability_window.len() as f64;
186
187 1.0 / (1.0 + variance)
189 }
190
191 pub fn is_stable(&self, threshold: f64) -> bool {
192 self.get_stability_score() > threshold
193 }
194
195 pub fn is_oscillating(&self) -> bool {
196 if self.stability_window.len() < 6 {
197 return false;
198 }
199
200 let mut sign_changes = 0;
202 let values: Vec<f64> = self.stability_window.iter().cloned().collect();
203
204 for i in 1..values.len() - 1 {
205 let prev_diff = values[i] - values[i - 1];
206 let curr_diff = values[i + 1] - values[i];
207
208 if prev_diff * curr_diff < 0.0 {
209 sign_changes += 1;
210 }
211 }
212
213 sign_changes > values.len() / 2
215 }
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct MonitoringConfig {
221 pub enable_adaptive_thresholds: bool,
222 pub enable_real_time_monitoring: bool,
223 pub stability_threshold: f64,
224 pub anomaly_threshold: f64,
225 pub update_frequency: usize,
226 pub history_window_size: usize,
227}
228
229impl Default for MonitoringConfig {
230 fn default() -> Self {
231 Self {
232 enable_adaptive_thresholds: true,
233 enable_real_time_monitoring: true,
234 stability_threshold: 0.8,
235 anomaly_threshold: 0.7,
236 update_frequency: 1,
237 history_window_size: 100,
238 }
239 }
240}
241
242#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct MonitoringResults {
245 pub layer_name: String,
246 pub timestamp: DateTime<Utc>,
247 pub current_status: LayerHealth,
248 pub stability_score: f64,
249 pub anomaly_score: f64,
250 pub alerts: Vec<GradientAlert>,
251 pub recommendations: Vec<String>,
252}
253
254impl MonitoringResults {
255 pub fn new(layer_name: String) -> Self {
256 Self {
257 layer_name,
258 timestamp: Utc::now(),
259 current_status: LayerHealth::Healthy,
260 stability_score: 1.0,
261 anomaly_score: 0.0,
262 alerts: Vec::new(),
263 recommendations: Vec::new(),
264 }
265 }
266
267 pub fn add_alert(&mut self, alert: GradientAlert) {
268 self.alerts.push(alert);
269 self.update_status();
270 }
271
272 pub fn add_recommendation(&mut self, recommendation: String) {
273 self.recommendations.push(recommendation);
274 }
275
276 fn update_status(&mut self) {
277 if self.alerts.iter().any(|alert| {
278 matches!(
279 alert,
280 GradientAlert::ExplodingGradients { .. } | GradientAlert::NoGradientFlow { .. }
281 )
282 }) {
283 self.current_status = LayerHealth::Critical;
284 } else if !self.alerts.is_empty() || self.anomaly_score > 0.7 {
285 self.current_status = LayerHealth::Warning;
286 } else {
287 self.current_status = LayerHealth::Healthy;
288 }
289 }
290}