trustformers_debug/model_diagnostics/
alerts.rs

//! Alert system and diagnostic notifications.
//!
//! This module provides comprehensive alert management for model diagnostics,
//! including threshold-based monitoring, alert prioritization, notification
//! systems, and automated response recommendations.

use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use std::collections::VecDeque;

use super::types::{
    ConvergenceStatus, LayerActivationStats, ModelDiagnosticAlert, ModelPerformanceMetrics,
    TrainingDynamics, TrainingStability,
};
15
/// Alert manager for monitoring and managing diagnostic alerts.
///
/// Tracks currently active alerts, keeps a bounded history of resolved ones,
/// and compares incoming metrics against configurable thresholds and an
/// optional performance baseline.
#[derive(Debug)]
pub struct AlertManager {
    /// Alerts that are currently active (not yet moved to history).
    active_alerts: Vec<ActiveAlert>,
    /// Bounded history of past alerts; oldest entries are evicted first.
    alert_history: VecDeque<HistoricalAlert>,
    /// Alert system configuration (cooldown, history size, notifications).
    config: AlertConfig,
    /// Thresholds that trigger new alerts.
    thresholds: AlertThresholds,
    /// Performance baseline used for degradation comparisons; `None` until
    /// explicitly set or established from metrics.
    performance_baseline: Option<PerformanceBaseline>,
}
30
/// Configuration for the alert system.
#[derive(Debug, Clone)]
pub struct AlertConfig {
    /// Maximum number of alerts to keep in history (oldest evicted first).
    pub max_history_size: usize,
    /// Minimum time between duplicate alerts of the same kind.
    pub duplicate_alert_cooldown: Duration,
    /// Alert severity levels to monitor.
    // NOTE(review): not currently consulted by `AlertManager` in this module —
    // alerts of every severity are recorded; confirm intended filtering.
    pub monitored_severities: Vec<AlertSeverity>,
    /// Enable automatic alert resolution.
    // NOTE(review): not currently consulted by `AlertManager` in this module.
    pub auto_resolve_alerts: bool,
    /// Alert notification settings (console/file/webhook).
    pub notification_settings: NotificationSettings,
}
45
/// Alert thresholds for various metrics.
#[derive(Debug, Clone)]
pub struct AlertThresholds {
    /// Performance degradation threshold, in percent relative to baseline.
    pub performance_degradation_percent: f64,
    /// Absolute memory usage threshold, in megabytes.
    pub memory_usage_threshold_mb: f64,
    /// Memory leak detection threshold, in MB growth per training step.
    pub memory_leak_threshold_mb_per_step: f64,
    /// Training instability variance threshold.
    pub training_instability_variance: f64,
    /// Dead neuron ratio threshold (fraction in [0, 1]).
    pub dead_neuron_ratio_threshold: f64,
    /// Saturated neuron ratio threshold (fraction in [0, 1]).
    pub saturated_neuron_ratio_threshold: f64,
    /// Convergence plateau duration threshold, in steps.
    pub plateau_duration_threshold: usize,
    /// Learning rate adjustment threshold.
    pub learning_rate_adjustment_threshold: f64,
}
66
/// Performance baseline for comparison.
///
/// Captured once (see `AlertManager::establish_baseline_from_metrics`) and
/// used as the reference point for degradation alerts.
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
    /// Baseline loss value.
    pub baseline_loss: f64,
    /// Baseline throughput, in samples per second.
    pub baseline_throughput: f64,
    /// Baseline memory usage, in megabytes.
    pub baseline_memory_mb: f64,
    /// Baseline accuracy, if the model reports one.
    pub baseline_accuracy: Option<f64>,
    /// When the baseline was established (UTC).
    pub established_at: DateTime<Utc>,
}
81
/// Active alert with current status.
#[derive(Debug, Clone)]
pub struct ActiveAlert {
    /// The underlying diagnostic alert.
    pub alert: ModelDiagnosticAlert,
    /// Alert severity.
    pub severity: AlertSeverity,
    /// When the alert was first triggered (UTC).
    pub triggered_at: DateTime<Utc>,
    /// Number of times the alert has been triggered (duplicates within the
    /// cooldown window count toward this total).
    pub trigger_count: usize,
    /// Recommended remediation actions, generated from the alert type.
    pub recommended_actions: Vec<String>,
    /// Current lifecycle status of the alert.
    pub status: AlertStatus,
}
98
/// Historical alert record.
///
/// Created when an active alert is resolved and moved into the bounded
/// history kept by `AlertManager`.
#[derive(Debug, Clone)]
pub struct HistoricalAlert {
    /// The underlying diagnostic alert.
    pub alert: ModelDiagnosticAlert,
    /// Alert severity at the time it was active.
    pub severity: AlertSeverity,
    /// When the alert was triggered (UTC).
    pub triggered_at: DateTime<Utc>,
    /// When the alert was resolved, if known.
    pub resolved_at: Option<DateTime<Utc>>,
    /// How the alert was resolved (free-form description).
    pub resolution_method: Option<String>,
    /// How long the alert was active, if known.
    pub duration: Option<Duration>,
}
115
/// Alert severity levels.
///
/// Declaration order matters: the derived `Ord`/`PartialOrd` order the
/// variants from least to most severe (`Info < Warning < Critical <
/// Emergency`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
    /// Informational alerts.
    Info,
    /// Warning alerts.
    Warning,
    /// Critical alerts requiring immediate attention.
    Critical,
    /// Emergency alerts indicating system failure.
    Emergency,
}
128
/// Alert status tracking.
#[derive(Debug, Clone, PartialEq)]
pub enum AlertStatus {
    /// Alert is active and unresolved.
    Active,
    /// Alert is acknowledged but not resolved.
    Acknowledged,
    /// Alert is being investigated.
    InvestigationInProgress,
    /// Alert has been resolved.
    Resolved,
    /// Alert was determined to be a false positive.
    FalsePositive,
}
143
/// Notification settings for alerts.
#[derive(Debug, Clone)]
pub struct NotificationSettings {
    /// Enable console (stdout) notifications.
    pub console_notifications: bool,
    /// Enable file logging of alerts.
    pub file_logging: bool,
    /// Log file path for alerts; used only when `file_logging` is enabled.
    pub log_file_path: Option<String>,
    /// Enable webhook notifications.
    pub webhook_notifications: bool,
    /// Webhook URL; used only when `webhook_notifications` is enabled.
    pub webhook_url: Option<String>,
}
158
159impl Default for AlertConfig {
160    fn default() -> Self {
161        Self {
162            max_history_size: 1000,
163            duplicate_alert_cooldown: Duration::minutes(5),
164            monitored_severities: vec![
165                AlertSeverity::Warning,
166                AlertSeverity::Critical,
167                AlertSeverity::Emergency,
168            ],
169            auto_resolve_alerts: true,
170            notification_settings: NotificationSettings::default(),
171        }
172    }
173}
174
175impl Default for NotificationSettings {
176    fn default() -> Self {
177        Self {
178            console_notifications: true,
179            file_logging: false,
180            log_file_path: None,
181            webhook_notifications: false,
182            webhook_url: None,
183        }
184    }
185}
186
187impl Default for AlertThresholds {
188    fn default() -> Self {
189        Self {
190            performance_degradation_percent: 10.0,
191            memory_usage_threshold_mb: 8192.0, // 8GB
192            memory_leak_threshold_mb_per_step: 1.0,
193            training_instability_variance: 0.1,
194            dead_neuron_ratio_threshold: 0.1,
195            saturated_neuron_ratio_threshold: 0.05,
196            plateau_duration_threshold: 100,
197            learning_rate_adjustment_threshold: 0.01,
198        }
199    }
200}
201
202impl AlertManager {
203    /// Create a new alert manager.
204    pub fn new() -> Self {
205        Self {
206            active_alerts: Vec::new(),
207            alert_history: VecDeque::new(),
208            config: AlertConfig::default(),
209            thresholds: AlertThresholds::default(),
210            performance_baseline: None,
211        }
212    }
213
214    /// Create alert manager with custom configuration.
215    pub fn with_config(config: AlertConfig, thresholds: AlertThresholds) -> Self {
216        Self {
217            active_alerts: Vec::new(),
218            alert_history: VecDeque::new(),
219            config,
220            thresholds,
221            performance_baseline: None,
222        }
223    }
224
225    /// Set performance baseline for comparison.
226    pub fn set_performance_baseline(&mut self, baseline: PerformanceBaseline) {
227        self.performance_baseline = Some(baseline);
228    }
229
230    /// Establish baseline from current metrics.
231    pub fn establish_baseline_from_metrics(&mut self, metrics: &ModelPerformanceMetrics) {
232        self.performance_baseline = Some(PerformanceBaseline {
233            baseline_loss: metrics.loss,
234            baseline_throughput: metrics.throughput_samples_per_sec,
235            baseline_memory_mb: metrics.memory_usage_mb,
236            baseline_accuracy: metrics.accuracy,
237            established_at: Utc::now(),
238        });
239    }
240
241    /// Process performance metrics and generate alerts.
242    pub fn process_performance_metrics(
243        &mut self,
244        metrics: &ModelPerformanceMetrics,
245    ) -> Result<Vec<ModelDiagnosticAlert>> {
246        let mut new_alerts = Vec::new();
247
248        // Check for performance degradation
249        if let Some(baseline) = &self.performance_baseline {
250            let loss_degradation =
251                ((metrics.loss - baseline.baseline_loss) / baseline.baseline_loss) * 100.0;
252            if loss_degradation > self.thresholds.performance_degradation_percent {
253                let alert = ModelDiagnosticAlert::PerformanceDegradation {
254                    metric: "loss".to_string(),
255                    current: metrics.loss,
256                    previous_avg: baseline.baseline_loss,
257                    degradation_percent: loss_degradation,
258                };
259                new_alerts.push(alert);
260            }
261
262            let throughput_degradation = ((baseline.baseline_throughput
263                - metrics.throughput_samples_per_sec)
264                / baseline.baseline_throughput)
265                * 100.0;
266            if throughput_degradation > self.thresholds.performance_degradation_percent {
267                let alert = ModelDiagnosticAlert::PerformanceDegradation {
268                    metric: "throughput".to_string(),
269                    current: metrics.throughput_samples_per_sec,
270                    previous_avg: baseline.baseline_throughput,
271                    degradation_percent: throughput_degradation,
272                };
273                new_alerts.push(alert);
274            }
275        }
276
277        // Check for memory issues
278        if metrics.memory_usage_mb > self.thresholds.memory_usage_threshold_mb {
279            let alert = ModelDiagnosticAlert::MemoryLeak {
280                current_usage_mb: metrics.memory_usage_mb,
281                growth_rate_mb_per_step: 0.0, // Would need historical data to calculate
282            };
283            new_alerts.push(alert);
284        }
285
286        // Process new alerts
287        for alert in &new_alerts {
288            self.add_alert(alert.clone(), self.determine_alert_severity(alert))?;
289        }
290
291        Ok(new_alerts)
292    }
293
294    /// Process training dynamics and generate alerts.
295    pub fn process_training_dynamics(
296        &mut self,
297        dynamics: &TrainingDynamics,
298    ) -> Result<Vec<ModelDiagnosticAlert>> {
299        let mut new_alerts = Vec::new();
300
301        // Check for training instability
302        if matches!(
303            dynamics.training_stability,
304            TrainingStability::Unstable | TrainingStability::HighVariance
305        ) {
306            let alert = ModelDiagnosticAlert::TrainingInstability {
307                variance: 0.0, // Would need to extract from dynamics
308                threshold: self.thresholds.training_instability_variance,
309            };
310            new_alerts.push(alert);
311        }
312
313        // Check for convergence issues
314        match dynamics.convergence_status {
315            ConvergenceStatus::Diverging => {
316                let alert = ModelDiagnosticAlert::ConvergenceIssue {
317                    issue_type: ConvergenceStatus::Diverging,
318                    duration_steps: 0, // Would need historical tracking
319                };
320                new_alerts.push(alert);
321            },
322            ConvergenceStatus::Plateau => {
323                if let Some(plateau_info) = &dynamics.plateau_detection {
324                    if plateau_info.duration_steps > self.thresholds.plateau_duration_threshold {
325                        let alert = ModelDiagnosticAlert::ConvergenceIssue {
326                            issue_type: ConvergenceStatus::Plateau,
327                            duration_steps: plateau_info.duration_steps,
328                        };
329                        new_alerts.push(alert);
330                    }
331                }
332            },
333            _ => {},
334        }
335
336        // Process new alerts
337        for alert in &new_alerts {
338            self.add_alert(alert.clone(), self.determine_alert_severity(alert))?;
339        }
340
341        Ok(new_alerts)
342    }
343
344    /// Process layer statistics and generate alerts.
345    pub fn process_layer_stats(
346        &mut self,
347        stats: &LayerActivationStats,
348    ) -> Result<Vec<ModelDiagnosticAlert>> {
349        let mut new_alerts = Vec::new();
350
351        // Check for dead neurons
352        if stats.dead_neurons_ratio > self.thresholds.dead_neuron_ratio_threshold {
353            let alert = ModelDiagnosticAlert::ArchitecturalConcern {
354                concern: format!(
355                    "High dead neuron ratio in layer {}: {:.2}%",
356                    stats.layer_name,
357                    stats.dead_neurons_ratio * 100.0
358                ),
359                recommendation: "Consider adjusting learning rate or initialization".to_string(),
360            };
361            new_alerts.push(alert);
362        }
363
364        // Check for saturated neurons
365        if stats.saturated_neurons_ratio > self.thresholds.saturated_neuron_ratio_threshold {
366            let alert = ModelDiagnosticAlert::ArchitecturalConcern {
367                concern: format!(
368                    "High saturated neuron ratio in layer {}: {:.2}%",
369                    stats.layer_name,
370                    stats.saturated_neurons_ratio * 100.0
371                ),
372                recommendation: "Consider adjusting activation function or scaling".to_string(),
373            };
374            new_alerts.push(alert);
375        }
376
377        // Process new alerts
378        for alert in &new_alerts {
379            self.add_alert(alert.clone(), self.determine_alert_severity(alert))?;
380        }
381
382        Ok(new_alerts)
383    }
384
385    /// Add a new alert to the system.
386    pub fn add_alert(
387        &mut self,
388        alert: ModelDiagnosticAlert,
389        severity: AlertSeverity,
390    ) -> Result<()> {
391        // Check for duplicate alerts within cooldown period
392        if self.is_duplicate_alert(&alert) {
393            return Ok(());
394        }
395
396        let active_alert = ActiveAlert {
397            alert: alert.clone(),
398            severity: severity.clone(),
399            triggered_at: Utc::now(),
400            trigger_count: 1,
401            recommended_actions: self.generate_recommended_actions(&alert),
402            status: AlertStatus::Active,
403        };
404
405        self.active_alerts.push(active_alert);
406
407        // Send notification
408        self.send_notification(&alert, &severity)?;
409
410        Ok(())
411    }
412
413    /// Resolve an alert.
414    pub fn resolve_alert(&mut self, alert_index: usize, resolution_method: String) -> Result<()> {
415        if alert_index >= self.active_alerts.len() {
416            return Err(anyhow::anyhow!("Invalid alert index"));
417        }
418
419        let mut active_alert = self.active_alerts.remove(alert_index);
420        active_alert.status = AlertStatus::Resolved;
421
422        let historical_alert = HistoricalAlert {
423            alert: active_alert.alert,
424            severity: active_alert.severity,
425            triggered_at: active_alert.triggered_at,
426            resolved_at: Some(Utc::now()),
427            resolution_method: Some(resolution_method),
428            duration: Some(Utc::now() - active_alert.triggered_at),
429        };
430
431        self.add_to_history(historical_alert);
432        Ok(())
433    }
434
435    /// Get all active alerts.
436    pub fn get_active_alerts(&self) -> &[ActiveAlert] {
437        &self.active_alerts
438    }
439
440    /// Get alerts by severity.
441    pub fn get_alerts_by_severity(&self, severity: AlertSeverity) -> Vec<&ActiveAlert> {
442        self.active_alerts.iter().filter(|alert| alert.severity == severity).collect()
443    }
444
445    /// Get alert statistics.
446    pub fn get_alert_statistics(&self) -> AlertStatistics {
447        let mut stats = AlertStatistics::default();
448
449        for alert in &self.active_alerts {
450            match alert.severity {
451                AlertSeverity::Info => stats.info_count += 1,
452                AlertSeverity::Warning => stats.warning_count += 1,
453                AlertSeverity::Critical => stats.critical_count += 1,
454                AlertSeverity::Emergency => stats.emergency_count += 1,
455            }
456        }
457
458        stats.total_active = self.active_alerts.len();
459        stats.total_historical = self.alert_history.len();
460
461        stats
462    }
463
464    /// Clear resolved alerts from active list.
465    pub fn clear_resolved_alerts(&mut self) {
466        let now = Utc::now();
467        let mut resolved_alerts = Vec::new();
468
469        self.active_alerts.retain(|alert| {
470            if matches!(alert.status, AlertStatus::Resolved) {
471                resolved_alerts.push(HistoricalAlert {
472                    alert: alert.alert.clone(),
473                    severity: alert.severity.clone(),
474                    triggered_at: alert.triggered_at,
475                    resolved_at: Some(now),
476                    resolution_method: Some("Auto-resolved".to_string()),
477                    duration: Some(now - alert.triggered_at),
478                });
479                false
480            } else {
481                true
482            }
483        });
484
485        for historical in resolved_alerts {
486            self.add_to_history(historical);
487        }
488    }
489
490    /// Determine alert severity based on alert type.
491    fn determine_alert_severity(&self, alert: &ModelDiagnosticAlert) -> AlertSeverity {
492        match alert {
493            ModelDiagnosticAlert::PerformanceDegradation {
494                degradation_percent,
495                ..
496            } => {
497                if *degradation_percent > 50.0 {
498                    AlertSeverity::Critical
499                } else if *degradation_percent > 25.0 {
500                    AlertSeverity::Warning
501                } else {
502                    AlertSeverity::Info
503                }
504            },
505            ModelDiagnosticAlert::MemoryLeak {
506                current_usage_mb, ..
507            } => {
508                if *current_usage_mb > 16384.0 {
509                    // 16GB
510                    AlertSeverity::Emergency
511                } else if *current_usage_mb > 8192.0 {
512                    // 8GB
513                    AlertSeverity::Critical
514                } else {
515                    AlertSeverity::Warning
516                }
517            },
518            ModelDiagnosticAlert::TrainingInstability { .. } => AlertSeverity::Warning,
519            ModelDiagnosticAlert::ConvergenceIssue { issue_type, .. } => match issue_type {
520                ConvergenceStatus::Diverging => AlertSeverity::Critical,
521                ConvergenceStatus::Plateau => AlertSeverity::Warning,
522                _ => AlertSeverity::Info,
523            },
524            ModelDiagnosticAlert::ArchitecturalConcern { .. } => AlertSeverity::Info,
525        }
526    }
527
528    /// Check if alert is a duplicate within cooldown period.
529    fn is_duplicate_alert(&self, alert: &ModelDiagnosticAlert) -> bool {
530        let now = Utc::now();
531        let cooldown_threshold = now - self.config.duplicate_alert_cooldown;
532
533        self.active_alerts.iter().any(|active| {
534            active.triggered_at > cooldown_threshold
535                && std::mem::discriminant(&active.alert) == std::mem::discriminant(alert)
536        })
537    }
538
539    /// Generate recommended actions for an alert.
540    fn generate_recommended_actions(&self, alert: &ModelDiagnosticAlert) -> Vec<String> {
541        match alert {
542            ModelDiagnosticAlert::PerformanceDegradation { metric, .. } => {
543                vec![
544                    format!("Investigate {} degradation causes", metric),
545                    "Check for data quality issues".to_string(),
546                    "Review recent configuration changes".to_string(),
547                    "Consider adjusting learning rate".to_string(),
548                ]
549            },
550            ModelDiagnosticAlert::MemoryLeak { .. } => {
551                vec![
552                    "Monitor memory usage patterns".to_string(),
553                    "Check for gradient accumulation issues".to_string(),
554                    "Review batch size configuration".to_string(),
555                    "Consider implementing memory cleanup".to_string(),
556                ]
557            },
558            ModelDiagnosticAlert::TrainingInstability { .. } => {
559                vec![
560                    "Reduce learning rate".to_string(),
561                    "Enable gradient clipping".to_string(),
562                    "Check data preprocessing".to_string(),
563                    "Consider using learning rate scheduling".to_string(),
564                ]
565            },
566            ModelDiagnosticAlert::ConvergenceIssue { issue_type, .. } => match issue_type {
567                ConvergenceStatus::Diverging => vec![
568                    "Immediately reduce learning rate".to_string(),
569                    "Check gradient magnitudes".to_string(),
570                    "Review loss function implementation".to_string(),
571                ],
572                ConvergenceStatus::Plateau => vec![
573                    "Consider learning rate annealing".to_string(),
574                    "Try different optimization algorithm".to_string(),
575                    "Evaluate model capacity".to_string(),
576                ],
577                _ => vec!["Monitor training progress".to_string()],
578            },
579            ModelDiagnosticAlert::ArchitecturalConcern { recommendation, .. } => {
580                vec![recommendation.clone()]
581            },
582        }
583    }
584
585    /// Send notification for an alert.
586    fn send_notification(
587        &self,
588        alert: &ModelDiagnosticAlert,
589        severity: &AlertSeverity,
590    ) -> Result<()> {
591        if self.config.notification_settings.console_notifications {
592            println!("[{:?}] Alert: {:?}", severity, alert);
593        }
594
595        if self.config.notification_settings.file_logging {
596            if let Some(log_path) = &self.config.notification_settings.log_file_path {
597                // Would implement file logging here
598                let _ = log_path; // Suppress unused warning
599            }
600        }
601
602        if self.config.notification_settings.webhook_notifications {
603            if let Some(webhook_url) = &self.config.notification_settings.webhook_url {
604                // Would implement webhook notification here
605                let _ = webhook_url; // Suppress unused warning
606            }
607        }
608
609        Ok(())
610    }
611
612    /// Add alert to history with size management.
613    fn add_to_history(&mut self, historical_alert: HistoricalAlert) {
614        self.alert_history.push_back(historical_alert);
615
616        while self.alert_history.len() > self.config.max_history_size {
617            self.alert_history.pop_front();
618        }
619    }
620}
621
/// Alert system statistics.
///
/// Snapshot produced by `AlertManager::get_alert_statistics`; per-severity
/// counts cover active alerts only.
#[derive(Debug, Default)]
pub struct AlertStatistics {
    /// Number of active info alerts.
    pub info_count: usize,
    /// Number of active warning alerts.
    pub warning_count: usize,
    /// Number of active critical alerts.
    pub critical_count: usize,
    /// Number of active emergency alerts.
    pub emergency_count: usize,
    /// Total number of active alerts (all severities).
    pub total_active: usize,
    /// Total number of alerts in history.
    pub total_historical: usize,
}
638
639impl Default for AlertManager {
640    fn default() -> Self {
641        Self::new()
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager starts with no active or historical alerts.
    #[test]
    fn test_alert_manager_creation() {
        let manager = AlertManager::new();
        assert_eq!(manager.active_alerts.len(), 0);
        assert_eq!(manager.alert_history.len(), 0);
    }

    /// Adding an alert makes it appear in the active list.
    #[test]
    fn test_add_alert() {
        let mut manager = AlertManager::new();
        let alert = ModelDiagnosticAlert::PerformanceDegradation {
            metric: "loss".to_string(),
            current: 1.5,
            previous_avg: 1.0,
            degradation_percent: 50.0,
        };

        manager.add_alert(alert, AlertSeverity::Warning).unwrap();
        assert_eq!(manager.active_alerts.len(), 1);
    }

    /// Degradation above 50% maps to `Critical` severity.
    #[test]
    fn test_alert_severity_determination() {
        let manager = AlertManager::new();

        let high_degradation = ModelDiagnosticAlert::PerformanceDegradation {
            metric: "loss".to_string(),
            current: 2.0,
            previous_avg: 1.0,
            degradation_percent: 60.0,
        };

        let severity = manager.determine_alert_severity(&high_degradation);
        assert_eq!(severity, AlertSeverity::Critical);
    }

    /// A second alert of the same kind within the cooldown window must not
    /// create a second active entry.
    #[test]
    fn test_duplicate_alert_detection() {
        let mut manager = AlertManager::new();
        let alert = ModelDiagnosticAlert::TrainingInstability {
            variance: 0.2,
            threshold: 0.1,
        };

        // Add first alert
        manager.add_alert(alert.clone(), AlertSeverity::Warning).unwrap();
        assert_eq!(manager.active_alerts.len(), 1);

        // Try to add duplicate - should be filtered out
        manager.add_alert(alert, AlertSeverity::Warning).unwrap();
        assert_eq!(manager.active_alerts.len(), 1);
    }
}
701}