Skip to main content

torsh_distributed/
distributed_monitoring.rs

1//! Advanced Distributed Training Monitoring System
2//!
3//! This module provides comprehensive real-time monitoring and analytics for distributed
4//! training across multiple nodes, including performance metrics, resource utilization,
5//! communication patterns, and system health monitoring.
6
7// Framework infrastructure - components designed for future use
8#![allow(dead_code)]
9use crate::{TorshDistributedError, TorshResult};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use tracing::{info, warn};
15
/// Comprehensive system metrics for distributed training
///
/// A point-in-time sample of host resource usage for one node. Within this file it is
/// produced by `DistributedMonitor::collect_system_metrics`, which currently returns
/// simulated values rather than readings from real system monitoring APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
    /// CPU utilization percentage (0.0 to 100.0)
    pub cpu_utilization: f32,
    /// Memory usage in MB
    pub memory_usage_mb: u64,
    /// GPU utilization percentage (0.0 to 100.0)
    pub gpu_utilization: f32,
    /// GPU memory usage in MB
    pub gpu_memory_mb: u64,
    /// Network bandwidth utilization in MB/s (documented as megabytes per second despite
    /// the `mbps` suffix)
    pub network_bandwidth_mbps: f32,
    /// Disk I/O rate in MB/s
    pub disk_io_mbps: f32,
    /// System temperature in Celsius
    pub temperature_celsius: f32,
    /// Power consumption in watts
    pub power_watts: f32,
    /// Timestamp of measurement, in milliseconds since the UNIX epoch
    pub timestamp_ms: u64,
}
38
/// Training performance metrics
///
/// A snapshot of training progress for one node. Within this file it is built by
/// `DistributedMonitor::collect_training_metrics`. That function takes loss, learning rate,
/// epoch and batch from the caller and derives the remaining fields from the loss.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Current epoch
    pub epoch: u32,
    /// Current batch within epoch
    pub batch: u32,
    /// Current training loss
    pub loss: f32,
    /// Current learning rate
    pub learning_rate: f32,
    /// Gradient norm
    pub gradient_norm: f32,
    /// Throughput in samples per second
    pub throughput_samples_per_sec: f32,
    /// Time per batch in milliseconds
    pub batch_time_ms: u64,
    /// Memory usage for this batch in MB
    pub batch_memory_mb: u64,
    /// Timestamp of measurement, in milliseconds since the UNIX epoch
    pub timestamp_ms: u64,
}
61
/// Communication pattern metrics
///
/// A snapshot of collective-communication activity for one node. Within this file it is
/// produced with simulated values by `DistributedMonitor::collect_communication_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationMetrics {
    /// All-reduce operations per second
    pub allreduce_ops_per_sec: f32,
    /// All-gather operations per second
    pub allgather_ops_per_sec: f32,
    /// Broadcast operations per second
    pub broadcast_ops_per_sec: f32,
    /// Point-to-point operations per second
    pub p2p_ops_per_sec: f32,
    /// Average communication latency in microseconds
    pub avg_latency_us: u64,
    /// Communication bandwidth utilization in MB/s
    pub comm_bandwidth_mbps: f32,
    /// Number of failed communication operations
    pub failed_ops_count: u32,
    /// Communication efficiency score (0.0 to 1.0); health assessment treats values
    /// below 0.7 as degraded
    pub efficiency_score: f32,
    /// Timestamp of measurement, in milliseconds since the UNIX epoch
    pub timestamp_ms: u64,
}
84
/// Health status of a distributed training node
///
/// In the code visible alongside this type, `DistributedMonitor::assess_node_health` only
/// produces `Healthy`, `Degraded` and `Critical`. `Failed` and `Recovering` are presumably
/// set elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum NodeHealthStatus {
    /// Node is healthy and operating normally
    Healthy,
    /// Node is experiencing degraded performance
    Degraded { reason: String },
    /// Node is critical and may fail soon
    Critical { reason: String },
    /// Node has failed and is not responding
    Failed { reason: String },
    /// Node is recovering from a failure
    // NOTE(review): the unit of `progress` (fraction 0..1 vs percentage) is not established
    // here; confirm with the code that constructs this variant.
    Recovering { progress: f32 },
}
99
/// Parameters for updating node metrics
///
/// Bundles the caller-supplied inputs to `DistributedMonitor::update_node_metrics`. All
/// other metrics in the resulting snapshot are collected by the monitor itself.
#[derive(Debug, Clone)]
pub struct NodeMetricsUpdate {
    /// Identifier of the node being updated (also the key in the coordinator's node map)
    pub node_id: String,
    /// Rank of the node in the distributed job
    pub rank: u32,
    /// Total number of participants in the distributed job
    pub world_size: u32,
    /// Current training loss, recorded as `TrainingMetrics::loss`
    pub training_loss: f32,
    /// Current learning rate, recorded as `TrainingMetrics::learning_rate`
    pub learning_rate: f32,
    /// Current epoch
    pub epoch: u32,
    /// Current batch within the epoch
    pub batch: u32,
}
111
/// Comprehensive node metrics
///
/// One complete snapshot for a node: identity, the three metric groups, and the health
/// status derived from them. `DistributedMonitor::update_node_metrics` stores it as the
/// current metrics, appends it to the history and, on coordinators, to the cluster map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeMetrics {
    /// Node identifier
    pub node_id: String,
    /// Rank of this node in the distributed training
    pub rank: u32,
    /// World size (total number of nodes)
    pub world_size: u32,
    /// System resource metrics
    pub system_metrics: SystemMetrics,
    /// Training performance metrics
    pub training_metrics: TrainingMetrics,
    /// Communication pattern metrics
    pub communication_metrics: CommunicationMetrics,
    /// Overall health status
    pub health_status: NodeHealthStatus,
    /// Custom metrics from user applications (`update_node_metrics` creates this map empty)
    pub custom_metrics: HashMap<String, f64>,
}
132
/// Alert severity levels
///
/// Variants are declared from least to most severe. Because `PartialOrd`/`Ord` are derived,
/// comparisons follow declaration order (`Info < Warning < Critical < Emergency`), so
/// reordering the variants changes how severities compare.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
    /// Lowest severity; informational
    Info,
    /// A warning threshold was crossed; also used for anomaly-detection alerts
    Warning,
    /// A critical threshold was crossed
    Critical,
    /// Highest severity; not raised by the alerting code visible in this file
    Emergency,
}
141
142impl std::fmt::Display for AlertSeverity {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            AlertSeverity::Info => write!(f, "INFO"),
146            AlertSeverity::Warning => write!(f, "WARNING"),
147            AlertSeverity::Critical => write!(f, "CRITICAL"),
148            AlertSeverity::Emergency => write!(f, "EMERGENCY"),
149        }
150    }
151}
152
/// System alert for monitoring
///
/// Raised by threshold checks and by anomaly detection in `DistributedMonitor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Alert {
    /// Unique alert identifier (built in this file from an alert-kind prefix, the node id
    /// and a millisecond timestamp)
    pub id: String,
    /// Alert severity level
    pub severity: AlertSeverity,
    /// Human-readable alert message
    pub message: String,
    /// Node that generated the alert
    pub node_id: String,
    /// Metric that triggered the alert
    pub metric_name: String,
    /// Current metric value
    pub current_value: f64,
    /// Threshold value that was exceeded (`0.0` for anomaly alerts, which have no fixed
    /// threshold)
    pub threshold_value: f64,
    /// Timestamp when alert was generated, in milliseconds since the UNIX epoch
    pub timestamp_ms: u64,
    /// Whether the alert is currently active
    pub is_active: bool,
}
175
/// Configuration for monitoring system
///
/// See the `Default` impl for the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Collection interval for metrics
    // NOTE(review): not read by the code visible in this file; presumably consumed by a
    // polling loop elsewhere.
    pub collection_interval: Duration,
    /// Maximum number of `NodeMetrics` snapshots kept in the monitor's history; the oldest
    /// snapshot is dropped once the limit is exceeded
    pub history_buffer_size: usize,
    /// Whether to enable detailed GPU monitoring
    // NOTE(review): not consulted by the collection code visible in this file.
    pub enable_gpu_monitoring: bool,
    /// Whether to enable communication pattern analysis
    // NOTE(review): not consulted by the collection code visible in this file.
    pub enable_comm_analysis: bool,
    /// Alert thresholds configuration (also used for node health assessment)
    pub alert_thresholds: AlertThresholds,
    /// Maximum number of alerts to retain in the alert history (oldest evicted first)
    pub max_alerts: usize,
    /// Whether to enable anomaly detection
    pub enable_anomaly_detection: bool,
    /// Anomaly detection sensitivity (0.0 to 1.0); the z-score threshold used by the
    /// detector is `max(2.0 - sensitivity, 1.0)`, so higher values flag more samples
    pub anomaly_sensitivity: f32,
}
196
/// Alert threshold configuration
///
/// Limits are compared with a strict `>`: a value exactly at a threshold does not trip it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlertThresholds {
    /// CPU utilization warning threshold (percentage)
    pub cpu_warning_pct: f32,
    /// CPU utilization critical threshold (percentage)
    pub cpu_critical_pct: f32,
    /// Memory usage warning threshold (percentage)
    // NOTE(review): `SystemMetrics` reports memory only in MB, and the health/alert checks
    // visible in this file never consult the two memory thresholds.
    pub memory_warning_pct: f32,
    /// Memory usage critical threshold (percentage)
    pub memory_critical_pct: f32,
    /// GPU utilization warning threshold (percentage)
    pub gpu_warning_pct: f32,
    /// GPU utilization critical threshold (percentage)
    pub gpu_critical_pct: f32,
    /// Communication latency warning threshold (microseconds)
    pub latency_warning_us: u64,
    /// Communication latency critical threshold (microseconds)
    pub latency_critical_us: u64,
    /// Training throughput degradation warning threshold (percentage)
    // NOTE(review): the throughput thresholds are not consulted by the checks visible here.
    pub throughput_degradation_warning_pct: f32,
    /// Training throughput degradation critical threshold (percentage)
    pub throughput_degradation_critical_pct: f32,
}
221
222impl Default for MonitoringConfig {
223    fn default() -> Self {
224        Self {
225            collection_interval: Duration::from_secs(5),
226            history_buffer_size: 1000,
227            enable_gpu_monitoring: true,
228            enable_comm_analysis: true,
229            alert_thresholds: AlertThresholds::default(),
230            max_alerts: 10000,
231            enable_anomaly_detection: true,
232            anomaly_sensitivity: 0.7,
233        }
234    }
235}
236
237impl Default for AlertThresholds {
238    fn default() -> Self {
239        Self {
240            cpu_warning_pct: 80.0,
241            cpu_critical_pct: 95.0,
242            memory_warning_pct: 80.0,
243            memory_critical_pct: 95.0,
244            gpu_warning_pct: 85.0,
245            gpu_critical_pct: 98.0,
246            latency_warning_us: 10000,  // 10ms
247            latency_critical_us: 50000, // 50ms
248            throughput_degradation_warning_pct: 20.0,
249            throughput_degradation_critical_pct: 50.0,
250        }
251    }
252}
253
/// Advanced distributed monitoring system
///
/// Holds the latest metrics snapshot for the local node, a bounded snapshot history, alert
/// state and an anomaly detector. Mutable state sits behind `Arc`-wrapped `Mutex`/`RwLock`s,
/// which lets the update methods take `&self`.
pub struct DistributedMonitor {
    /// Configuration
    config: MonitoringConfig,
    /// Current node metrics (`None` until the first `update_node_metrics` call)
    current_metrics: Arc<RwLock<Option<NodeMetrics>>>,
    /// Metrics history for trend analysis, bounded by `config.history_buffer_size`
    metrics_history: Arc<Mutex<VecDeque<NodeMetrics>>>,
    /// All active nodes metrics (for coordinators), keyed by node id; only written when
    /// `is_coordinator` is true
    all_nodes_metrics: Arc<RwLock<HashMap<String, NodeMetrics>>>,
    /// Active alerts
    active_alerts: Arc<Mutex<Vec<Alert>>>,
    /// Alert history, bounded by `config.max_alerts` (oldest evicted first)
    alert_history: Arc<Mutex<VecDeque<Alert>>>,
    /// Performance baselines for comparison
    // NOTE(review): not read or written by the code visible in this file.
    performance_baselines: Arc<RwLock<HashMap<String, f64>>>,
    /// Anomaly detection model state
    anomaly_detector: Arc<Mutex<AnomalyDetector>>,
    /// Whether this monitor is the coordinator; gates `all_nodes_metrics` updates and access
    /// to the cluster summary
    is_coordinator: bool,
}
275
/// Simple anomaly detection using statistical methods
///
/// Keeps per-metric running statistics keyed by metric name (e.g. `"cpu_utilization"`). A
/// value is flagged when its distance from the running mean, measured in units of the
/// running spread estimate, exceeds `threshold_multiplier`.
#[derive(Debug)]
struct AnomalyDetector {
    /// Exponentially smoothed mean per metric
    moving_averages: HashMap<String, f64>,
    /// Running spread estimate per metric, used as the z-score denominator
    standard_deviations: HashMap<String, f64>,
    /// Number of samples folded into the statistics, per metric
    sample_counts: HashMap<String, usize>,
    /// Anomaly detection threshold multiplier (z-score cut-off derived from sensitivity)
    threshold_multiplier: f64,
}
288
289impl AnomalyDetector {
290    fn new(sensitivity: f32) -> Self {
291        Self {
292            moving_averages: HashMap::new(),
293            standard_deviations: HashMap::new(),
294            sample_counts: HashMap::new(),
295            threshold_multiplier: (2.0 - sensitivity as f64).max(1.0), // Higher sensitivity = lower threshold
296        }
297    }
298
299    /// Update anomaly detection model with new metric value
300    fn update_metric(&mut self, metric_name: &str, value: f64) {
301        let avg = self
302            .moving_averages
303            .entry(metric_name.to_string())
304            .or_insert(value);
305        let count = self
306            .sample_counts
307            .entry(metric_name.to_string())
308            .or_insert(0);
309
310        // Update moving average using exponential smoothing
311        let alpha = 0.1; // Smoothing factor
312        *avg = alpha * value + (1.0 - alpha) * *avg;
313        *count += 1;
314
315        // Update standard deviation estimate
316        if *count > 1 {
317            let variance_estimate = (value - *avg).powi(2);
318            let std_dev = self
319                .standard_deviations
320                .entry(metric_name.to_string())
321                .or_insert(0.0);
322            *std_dev = alpha * variance_estimate.sqrt() + (1.0 - alpha) * *std_dev;
323        }
324    }
325
326    /// Check if a metric value is anomalous
327    fn is_anomaly(&self, metric_name: &str, value: f64) -> bool {
328        if let (Some(&avg), Some(&std_dev)) = (
329            self.moving_averages.get(metric_name),
330            self.standard_deviations.get(metric_name),
331        ) {
332            let z_score = (value - avg).abs() / std_dev.max(0.01); // Avoid division by zero
333            z_score > self.threshold_multiplier
334        } else {
335            false // Not enough data yet
336        }
337    }
338}
339
340impl DistributedMonitor {
341    /// Create new distributed monitor
342    pub fn new(config: MonitoringConfig, is_coordinator: bool) -> Self {
343        let anomaly_detector = AnomalyDetector::new(config.anomaly_sensitivity);
344
345        Self {
346            config: config.clone(),
347            current_metrics: Arc::new(RwLock::new(None)),
348            metrics_history: Arc::new(Mutex::new(VecDeque::with_capacity(
349                config.history_buffer_size,
350            ))),
351            all_nodes_metrics: Arc::new(RwLock::new(HashMap::new())),
352            active_alerts: Arc::new(Mutex::new(Vec::new())),
353            alert_history: Arc::new(Mutex::new(VecDeque::with_capacity(config.max_alerts))),
354            performance_baselines: Arc::new(RwLock::new(HashMap::new())),
355            anomaly_detector: Arc::new(Mutex::new(anomaly_detector)),
356            is_coordinator,
357        }
358    }
359
360    /// Collect current system metrics
361    pub fn collect_system_metrics(&self) -> TorshResult<SystemMetrics> {
362        // In production, this would interface with actual system monitoring APIs
363        // For now, we'll simulate realistic metrics
364        let timestamp_ms = SystemTime::now()
365            .duration_since(UNIX_EPOCH)
366            .expect("time should be after UNIX_EPOCH")
367            .as_millis() as u64;
368
369        // Simulate realistic system metrics with some variation
370        let base_time = timestamp_ms % 100000;
371        let variation = (base_time as f32 / 1000.0).sin();
372
373        Ok(SystemMetrics {
374            cpu_utilization: 45.0 + variation * 20.0, // 25-65% range
375            memory_usage_mb: 8000 + (variation * 2000.0) as u64, // 6-10GB range
376            gpu_utilization: 80.0 + variation * 15.0, // 65-95% range
377            gpu_memory_mb: 16000 + (variation * 4000.0) as u64, // 12-20GB range
378            network_bandwidth_mbps: 1000.0 + variation * 500.0, // 500-1500 MB/s
379            disk_io_mbps: 200.0 + variation * 100.0,  // 100-300 MB/s
380            temperature_celsius: 65.0 + variation * 10.0, // 55-75°C
381            power_watts: 250.0 + variation * 50.0,    // 200-300W
382            timestamp_ms,
383        })
384    }
385
386    /// Collect current training metrics
387    pub fn collect_training_metrics(
388        &self,
389        current_loss: f32,
390        current_lr: f32,
391        epoch: u32,
392        batch: u32,
393    ) -> TorshResult<TrainingMetrics> {
394        let timestamp_ms = SystemTime::now()
395            .duration_since(UNIX_EPOCH)
396            .expect("time should be after UNIX_EPOCH")
397            .as_millis() as u64;
398
399        // Calculate derived metrics
400        let gradient_norm = current_loss * 0.1 + 0.5; // Realistic gradient norm
401        let throughput = 1000.0 / (current_loss + 0.1); // Higher loss = slower throughput
402        let batch_time_ms = (1000.0 / throughput * 32.0) as u64; // Assume batch size 32
403        let batch_memory_mb = 2000 + (batch_time_ms / 10); // Memory proportional to batch time
404
405        Ok(TrainingMetrics {
406            epoch,
407            batch,
408            loss: current_loss,
409            learning_rate: current_lr,
410            gradient_norm,
411            throughput_samples_per_sec: throughput,
412            batch_time_ms,
413            batch_memory_mb,
414            timestamp_ms,
415        })
416    }
417
418    /// Collect communication metrics
419    pub fn collect_communication_metrics(&self) -> TorshResult<CommunicationMetrics> {
420        let timestamp_ms = SystemTime::now()
421            .duration_since(UNIX_EPOCH)
422            .expect("time should be after UNIX_EPOCH")
423            .as_millis() as u64;
424
425        // Simulate realistic communication patterns
426        let base_ops = 10.0; // Base operations per second
427        let network_quality = 0.8; // Simulate network quality
428
429        Ok(CommunicationMetrics {
430            allreduce_ops_per_sec: base_ops * network_quality,
431            allgather_ops_per_sec: base_ops * 0.5 * network_quality,
432            broadcast_ops_per_sec: base_ops * 0.3 * network_quality,
433            p2p_ops_per_sec: base_ops * 0.2 * network_quality,
434            avg_latency_us: ((1.0 - network_quality) * 20000.0 + 1000.0) as u64,
435            comm_bandwidth_mbps: 800.0 * network_quality,
436            failed_ops_count: if network_quality < 0.9 { 1 } else { 0 },
437            efficiency_score: network_quality,
438            timestamp_ms,
439        })
440    }
441
442    /// Update node metrics with comprehensive data
443    pub fn update_node_metrics(&self, params: NodeMetricsUpdate) -> TorshResult<()> {
444        let NodeMetricsUpdate {
445            node_id,
446            rank,
447            world_size,
448            training_loss,
449            learning_rate,
450            epoch,
451            batch,
452        } = params;
453        // Collect all metric types
454        let system_metrics = self.collect_system_metrics()?;
455        let training_metrics =
456            self.collect_training_metrics(training_loss, learning_rate, epoch, batch)?;
457        let communication_metrics = self.collect_communication_metrics()?;
458
459        // Determine health status based on metrics
460        let health_status =
461            self.assess_node_health(&system_metrics, &training_metrics, &communication_metrics)?;
462
463        // Create comprehensive node metrics
464        let node_metrics = NodeMetrics {
465            node_id: node_id.clone(),
466            rank,
467            world_size,
468            system_metrics,
469            training_metrics,
470            communication_metrics,
471            health_status,
472            custom_metrics: HashMap::new(),
473        };
474
475        // Update current metrics
476        {
477            let mut current = self.current_metrics.write().map_err(|e| {
478                TorshDistributedError::communication_error(
479                    "metrics_update",
480                    format!("Lock error: {}", e),
481                )
482            })?;
483            *current = Some(node_metrics.clone());
484        }
485
486        // Add to history
487        {
488            let mut history = self.metrics_history.lock().map_err(|e| {
489                TorshDistributedError::communication_error(
490                    "metrics_history",
491                    format!("Lock error: {}", e),
492                )
493            })?;
494            history.push_back(node_metrics.clone());
495            if history.len() > self.config.history_buffer_size {
496                history.pop_front();
497            }
498        }
499
500        // Update all nodes metrics if coordinator
501        if self.is_coordinator {
502            let mut all_nodes = self.all_nodes_metrics.write().map_err(|e| {
503                TorshDistributedError::communication_error(
504                    "all_nodes_update",
505                    format!("Lock error: {}", e),
506                )
507            })?;
508            all_nodes.insert(node_id.clone(), node_metrics.clone());
509        }
510
511        // Check for alerts
512        self.check_and_generate_alerts(&node_metrics)?;
513
514        // Update anomaly detection
515        if self.config.enable_anomaly_detection {
516            self.update_anomaly_detection(&node_metrics)?;
517        }
518
519        info!(
520            "Updated metrics for node {} (rank {}): health={:?}",
521            node_id, rank, node_metrics.health_status
522        );
523        Ok(())
524    }
525
526    /// Assess node health based on current metrics
527    fn assess_node_health(
528        &self,
529        system: &SystemMetrics,
530        _training: &TrainingMetrics,
531        comm: &CommunicationMetrics,
532    ) -> TorshResult<NodeHealthStatus> {
533        let thresholds = &self.config.alert_thresholds;
534
535        // Check for critical conditions
536        if system.cpu_utilization > thresholds.cpu_critical_pct {
537            return Ok(NodeHealthStatus::Critical {
538                reason: format!("CPU utilization at {:.1}%", system.cpu_utilization),
539            });
540        }
541
542        if system.gpu_utilization > thresholds.gpu_critical_pct {
543            return Ok(NodeHealthStatus::Critical {
544                reason: format!("GPU utilization at {:.1}%", system.gpu_utilization),
545            });
546        }
547
548        if comm.avg_latency_us > thresholds.latency_critical_us {
549            return Ok(NodeHealthStatus::Critical {
550                reason: format!("Communication latency at {}μs", comm.avg_latency_us),
551            });
552        }
553
554        // Check for degraded conditions
555        if system.cpu_utilization > thresholds.cpu_warning_pct
556            || system.gpu_utilization > thresholds.gpu_warning_pct
557            || comm.avg_latency_us > thresholds.latency_warning_us
558        {
559            return Ok(NodeHealthStatus::Degraded {
560                reason: "Performance metrics above warning thresholds".to_string(),
561            });
562        }
563
564        // Check communication efficiency
565        if comm.efficiency_score < 0.7 {
566            return Ok(NodeHealthStatus::Degraded {
567                reason: format!("Communication efficiency at {:.2}", comm.efficiency_score),
568            });
569        }
570
571        Ok(NodeHealthStatus::Healthy)
572    }
573
574    /// Check metrics against thresholds and generate alerts
575    fn check_and_generate_alerts(&self, metrics: &NodeMetrics) -> TorshResult<()> {
576        let thresholds = &self.config.alert_thresholds;
577        let timestamp_ms = SystemTime::now()
578            .duration_since(UNIX_EPOCH)
579            .expect("time should be after UNIX_EPOCH")
580            .as_millis() as u64;
581
582        let mut new_alerts = Vec::new();
583
584        // CPU utilization alerts
585        if metrics.system_metrics.cpu_utilization > thresholds.cpu_critical_pct {
586            new_alerts.push(Alert {
587                id: format!("cpu_critical_{}_{}", metrics.node_id, timestamp_ms),
588                severity: AlertSeverity::Critical,
589                message: format!(
590                    "CPU utilization critically high on node {}",
591                    metrics.node_id
592                ),
593                node_id: metrics.node_id.clone(),
594                metric_name: "cpu_utilization".to_string(),
595                current_value: metrics.system_metrics.cpu_utilization as f64,
596                threshold_value: thresholds.cpu_critical_pct as f64,
597                timestamp_ms,
598                is_active: true,
599            });
600        } else if metrics.system_metrics.cpu_utilization > thresholds.cpu_warning_pct {
601            new_alerts.push(Alert {
602                id: format!("cpu_warning_{}_{}", metrics.node_id, timestamp_ms),
603                severity: AlertSeverity::Warning,
604                message: format!("CPU utilization high on node {}", metrics.node_id),
605                node_id: metrics.node_id.clone(),
606                metric_name: "cpu_utilization".to_string(),
607                current_value: metrics.system_metrics.cpu_utilization as f64,
608                threshold_value: thresholds.cpu_warning_pct as f64,
609                timestamp_ms,
610                is_active: true,
611            });
612        }
613
614        // GPU utilization alerts
615        if metrics.system_metrics.gpu_utilization > thresholds.gpu_critical_pct {
616            new_alerts.push(Alert {
617                id: format!("gpu_critical_{}_{}", metrics.node_id, timestamp_ms),
618                severity: AlertSeverity::Critical,
619                message: format!(
620                    "GPU utilization critically high on node {}",
621                    metrics.node_id
622                ),
623                node_id: metrics.node_id.clone(),
624                metric_name: "gpu_utilization".to_string(),
625                current_value: metrics.system_metrics.gpu_utilization as f64,
626                threshold_value: thresholds.gpu_critical_pct as f64,
627                timestamp_ms,
628                is_active: true,
629            });
630        }
631
632        // Communication latency alerts
633        if metrics.communication_metrics.avg_latency_us > thresholds.latency_critical_us {
634            new_alerts.push(Alert {
635                id: format!("latency_critical_{}_{}", metrics.node_id, timestamp_ms),
636                severity: AlertSeverity::Critical,
637                message: format!(
638                    "Communication latency critically high on node {}",
639                    metrics.node_id
640                ),
641                node_id: metrics.node_id.clone(),
642                metric_name: "avg_latency_us".to_string(),
643                current_value: metrics.communication_metrics.avg_latency_us as f64,
644                threshold_value: thresholds.latency_critical_us as f64,
645                timestamp_ms,
646                is_active: true,
647            });
648        }
649
650        // Add new alerts
651        if !new_alerts.is_empty() {
652            let mut active_alerts = self.active_alerts.lock().map_err(|e| {
653                TorshDistributedError::communication_error(
654                    "alerts_update",
655                    format!("Lock error: {}", e),
656                )
657            })?;
658
659            for alert in &new_alerts {
660                warn!("Generated alert: {} - {}", alert.severity, alert.message);
661                active_alerts.push(alert.clone());
662            }
663
664            // Add to history
665            let mut alert_history = self.alert_history.lock().map_err(|e| {
666                TorshDistributedError::communication_error(
667                    "alert_history",
668                    format!("Lock error: {}", e),
669                )
670            })?;
671
672            for alert in new_alerts {
673                alert_history.push_back(alert);
674                if alert_history.len() > self.config.max_alerts {
675                    alert_history.pop_front();
676                }
677            }
678        }
679
680        Ok(())
681    }
682
683    /// Update anomaly detection with new metrics
684    fn update_anomaly_detection(&self, metrics: &NodeMetrics) -> TorshResult<()> {
685        if !self.config.enable_anomaly_detection {
686            return Ok(());
687        }
688
689        let mut detector = self.anomaly_detector.lock().map_err(|e| {
690            TorshDistributedError::communication_error(
691                "anomaly_detector",
692                format!("Lock error: {}", e),
693            )
694        })?;
695
696        // Update key metrics for anomaly detection
697        detector.update_metric(
698            "cpu_utilization",
699            metrics.system_metrics.cpu_utilization as f64,
700        );
701        detector.update_metric(
702            "gpu_utilization",
703            metrics.system_metrics.gpu_utilization as f64,
704        );
705        detector.update_metric(
706            "throughput",
707            metrics.training_metrics.throughput_samples_per_sec as f64,
708        );
709        detector.update_metric(
710            "comm_latency",
711            metrics.communication_metrics.avg_latency_us as f64,
712        );
713        detector.update_metric(
714            "comm_efficiency",
715            metrics.communication_metrics.efficiency_score as f64,
716        );
717
718        // Check for anomalies
719        let metrics_to_check = [
720            (
721                "cpu_utilization",
722                metrics.system_metrics.cpu_utilization as f64,
723            ),
724            (
725                "gpu_utilization",
726                metrics.system_metrics.gpu_utilization as f64,
727            ),
728            (
729                "throughput",
730                metrics.training_metrics.throughput_samples_per_sec as f64,
731            ),
732            (
733                "comm_latency",
734                metrics.communication_metrics.avg_latency_us as f64,
735            ),
736            (
737                "comm_efficiency",
738                metrics.communication_metrics.efficiency_score as f64,
739            ),
740        ];
741
742        for (metric_name, value) in &metrics_to_check {
743            if detector.is_anomaly(metric_name, *value) {
744                warn!(
745                    "Anomaly detected: {} = {:.2} on node {}",
746                    metric_name, value, metrics.node_id
747                );
748
749                // Generate anomaly alert
750                let timestamp_ms = SystemTime::now()
751                    .duration_since(UNIX_EPOCH)
752                    .expect("time should be after UNIX_EPOCH")
753                    .as_millis() as u64;
754
755                let alert = Alert {
756                    id: format!("anomaly_{}_{}", metrics.node_id, timestamp_ms),
757                    severity: AlertSeverity::Warning,
758                    message: format!(
759                        "Anomaly detected in {} on node {}",
760                        metric_name, metrics.node_id
761                    ),
762                    node_id: metrics.node_id.clone(),
763                    metric_name: metric_name.to_string(),
764                    current_value: *value,
765                    threshold_value: 0.0, // Anomaly detection doesn't use fixed thresholds
766                    timestamp_ms,
767                    is_active: true,
768                };
769
770                // Add to active alerts
771                let mut active_alerts = self.active_alerts.lock().map_err(|e| {
772                    TorshDistributedError::communication_error(
773                        "anomaly_alerts",
774                        format!("Lock error: {}", e),
775                    )
776                })?;
777                active_alerts.push(alert);
778            }
779        }
780
781        Ok(())
782    }
783
784    /// Get current node metrics
785    pub fn get_current_metrics(&self) -> TorshResult<Option<NodeMetrics>> {
786        let current = self.current_metrics.read().map_err(|e| {
787            TorshDistributedError::communication_error(
788                "get_current_metrics",
789                format!("Lock error: {}", e),
790            )
791        })?;
792        Ok(current.clone())
793    }
794
795    /// Get metrics history for trend analysis
796    pub fn get_metrics_history(&self) -> TorshResult<Vec<NodeMetrics>> {
797        let history = self.metrics_history.lock().map_err(|e| {
798            TorshDistributedError::communication_error(
799                "get_metrics_history",
800                format!("Lock error: {}", e),
801            )
802        })?;
803        Ok(history.iter().cloned().collect())
804    }
805
806    /// Get all active alerts
807    pub fn get_active_alerts(&self) -> TorshResult<Vec<Alert>> {
808        let alerts = self.active_alerts.lock().map_err(|e| {
809            TorshDistributedError::communication_error(
810                "get_active_alerts",
811                format!("Lock error: {}", e),
812            )
813        })?;
814        Ok(alerts.clone())
815    }
816
817    /// Get cluster-wide metrics summary (for coordinators)
818    pub fn get_cluster_summary(&self) -> TorshResult<ClusterSummary> {
819        if !self.is_coordinator {
820            return Err(TorshDistributedError::communication_error(
821                "cluster_summary",
822                "Only coordinator nodes can access cluster summary".to_string(),
823            ));
824        }
825
826        let all_nodes = self.all_nodes_metrics.read().map_err(|e| {
827            TorshDistributedError::communication_error(
828                "cluster_summary",
829                format!("Lock error: {}", e),
830            )
831        })?;
832
833        let total_nodes = all_nodes.len();
834        let healthy_nodes = all_nodes
835            .values()
836            .filter(|n| matches!(n.health_status, NodeHealthStatus::Healthy))
837            .count();
838        let degraded_nodes = all_nodes
839            .values()
840            .filter(|n| matches!(n.health_status, NodeHealthStatus::Degraded { .. }))
841            .count();
842        let critical_nodes = all_nodes
843            .values()
844            .filter(|n| matches!(n.health_status, NodeHealthStatus::Critical { .. }))
845            .count();
846        let failed_nodes = all_nodes
847            .values()
848            .filter(|n| matches!(n.health_status, NodeHealthStatus::Failed { .. }))
849            .count();
850
851        // Calculate aggregate metrics
852        let total_cpu_util: f32 = all_nodes
853            .values()
854            .map(|n| n.system_metrics.cpu_utilization)
855            .sum();
856        let avg_cpu_util = if total_nodes > 0 {
857            total_cpu_util / total_nodes as f32
858        } else {
859            0.0
860        };
861
862        let total_gpu_util: f32 = all_nodes
863            .values()
864            .map(|n| n.system_metrics.gpu_utilization)
865            .sum();
866        let avg_gpu_util = if total_nodes > 0 {
867            total_gpu_util / total_nodes as f32
868        } else {
869            0.0
870        };
871
872        let total_throughput: f32 = all_nodes
873            .values()
874            .map(|n| n.training_metrics.throughput_samples_per_sec)
875            .sum();
876
877        let avg_comm_latency: u64 = if total_nodes > 0 {
878            all_nodes
879                .values()
880                .map(|n| n.communication_metrics.avg_latency_us)
881                .sum::<u64>()
882                / total_nodes as u64
883        } else {
884            0
885        };
886
887        Ok(ClusterSummary {
888            total_nodes,
889            healthy_nodes,
890            degraded_nodes,
891            critical_nodes,
892            failed_nodes,
893            avg_cpu_utilization: avg_cpu_util,
894            avg_gpu_utilization: avg_gpu_util,
895            total_throughput,
896            avg_communication_latency_us: avg_comm_latency,
897            timestamp_ms: SystemTime::now()
898                .duration_since(UNIX_EPOCH)
899                .expect("time should be after UNIX_EPOCH")
900                .as_millis() as u64,
901        })
902    }
903
904    /// Clear resolved alerts
905    pub fn clear_resolved_alerts(&self) -> TorshResult<usize> {
906        let mut active_alerts = self.active_alerts.lock().map_err(|e| {
907            TorshDistributedError::communication_error("clear_alerts", format!("Lock error: {}", e))
908        })?;
909
910        let initial_count = active_alerts.len();
911        active_alerts.retain(|alert| alert.is_active);
912        let cleared_count = initial_count - active_alerts.len();
913
914        info!("Cleared {} resolved alerts", cleared_count);
915        Ok(cleared_count)
916    }
917
918    /// Export monitoring data for external analysis
919    pub fn export_monitoring_data(&self) -> TorshResult<MonitoringExport> {
920        let current_metrics = self.get_current_metrics()?;
921        let metrics_history = self.get_metrics_history()?;
922        let active_alerts = self.get_active_alerts()?;
923
924        let cluster_summary = if self.is_coordinator {
925            Some(self.get_cluster_summary()?)
926        } else {
927            None
928        };
929
930        Ok(MonitoringExport {
931            current_metrics,
932            metrics_history,
933            active_alerts,
934            cluster_summary,
935            export_timestamp_ms: SystemTime::now()
936                .duration_since(UNIX_EPOCH)
937                .expect("time should be after UNIX_EPOCH")
938                .as_millis() as u64,
939        })
940    }
941}
942
/// Cluster-wide summary metrics
///
/// Built by `DistributedMonitor::get_cluster_summary` from the latest metrics of every
/// node in the coordinator's node-metrics map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSummary {
    /// Number of nodes present in the coordinator's node-metrics map.
    pub total_nodes: usize,
    /// Nodes whose health status is `NodeHealthStatus::Healthy`.
    pub healthy_nodes: usize,
    /// Nodes whose health status is `NodeHealthStatus::Degraded`.
    pub degraded_nodes: usize,
    /// Nodes whose health status is `NodeHealthStatus::Critical`.
    pub critical_nodes: usize,
    /// Nodes whose health status is `NodeHealthStatus::Failed`.
    pub failed_nodes: usize,
    /// Mean of per-node CPU utilization percentages (0.0 when no nodes are known).
    pub avg_cpu_utilization: f32,
    /// Mean of per-node GPU utilization percentages (0.0 when no nodes are known).
    pub avg_gpu_utilization: f32,
    /// Sum of per-node training throughput, in samples per second.
    pub total_throughput: f32,
    /// Mean of per-node average communication latency, in microseconds (integer division).
    pub avg_communication_latency_us: u64,
    /// Time the summary was built, in milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
}
957
/// Complete monitoring data export
///
/// Snapshot assembled by `DistributedMonitor::export_monitoring_data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringExport {
    /// Latest metrics for this node, or `None` if none have been recorded yet.
    pub current_metrics: Option<NodeMetrics>,
    /// Retained history of this node's metrics, in the order stored by the monitor.
    pub metrics_history: Vec<NodeMetrics>,
    /// Contents of the monitor's alert list at export time (may include alerts
    /// marked inactive that have not yet been removed by `clear_resolved_alerts`).
    pub active_alerts: Vec<Alert>,
    /// Cluster-wide summary; only populated when the exporting node is a coordinator.
    pub cluster_summary: Option<ClusterSummary>,
    /// Time of the export, in milliseconds since the UNIX epoch.
    pub export_timestamp_ms: u64,
}
967
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_distributed_monitor_creation() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        // A freshly created monitor has not collected anything yet.
        let current_metrics = monitor.get_current_metrics()?;
        assert!(current_metrics.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_system_metrics_collection() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        let metrics = monitor.collect_system_metrics()?;
        assert!(metrics.cpu_utilization >= 0.0 && metrics.cpu_utilization <= 100.0);
        assert!(metrics.gpu_utilization >= 0.0 && metrics.gpu_utilization <= 100.0);
        assert!(metrics.memory_usage_mb > 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_node_metrics_update() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 4,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 10,
            batch: 100,
        })?;

        let metrics = monitor
            .get_current_metrics()?
            .expect("current metrics should be recorded after update_node_metrics");
        assert_eq!(metrics.node_id, "test_node");
        assert_eq!(metrics.rank, 0);
        assert_eq!(metrics.world_size, 4);

        Ok(())
    }

    #[tokio::test]
    async fn test_alert_generation() -> TorshResult<()> {
        let mut config = MonitoringConfig::default();
        config.alert_thresholds.cpu_warning_pct = 50.0; // Low threshold for testing

        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 1,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 1,
            batch: 1,
        })?;

        // Whether an alert fires depends on the collected system metrics, so the exact
        // alert count is not asserted. What is deterministic: clearing resolved alerts
        // removes exactly the alerts whose `is_active` flag is false, and leaves only
        // active ones behind.
        let alerts = monitor.get_active_alerts()?;
        let inactive = alerts.iter().filter(|alert| !alert.is_active).count();
        assert_eq!(monitor.clear_resolved_alerts()?, inactive);
        assert!(monitor
            .get_active_alerts()?
            .iter()
            .all(|alert| alert.is_active));

        Ok(())
    }

    #[tokio::test]
    async fn test_anomaly_detection() -> TorshResult<()> {
        let mut detector = AnomalyDetector::new(0.7);

        // Feed normal values
        for i in 0..50 {
            detector.update_metric("test_metric", 50.0 + (i as f64 % 10.0));
        }

        // Test normal value
        assert!(!detector.is_anomaly("test_metric", 55.0));

        // Test anomalous value
        assert!(detector.is_anomaly("test_metric", 200.0));

        Ok(())
    }

    #[tokio::test]
    async fn test_monitoring_export() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 1,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 1,
            batch: 1,
        })?;

        let export = monitor.export_monitoring_data()?;
        assert!(export.current_metrics.is_some());
        assert!(!export.metrics_history.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_cluster_summary_requires_coordinator() -> TorshResult<()> {
        let monitor = DistributedMonitor::new(MonitoringConfig::default(), false);

        // Non-coordinators are refused a cluster summary...
        assert!(monitor.get_cluster_summary().is_err());

        // ...but can still export, just without the summary section.
        let export = monitor.export_monitoring_data()?;
        assert!(export.cluster_summary.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_cluster_summary_on_coordinator() -> TorshResult<()> {
        let monitor = DistributedMonitor::new(MonitoringConfig::default(), true);

        // Each node lands in at most one health bucket.
        let summary = monitor.get_cluster_summary()?;
        let bucketed = summary.healthy_nodes
            + summary.degraded_nodes
            + summary.critical_nodes
            + summary.failed_nodes;
        assert!(bucketed <= summary.total_nodes);

        let export = monitor.export_monitoring_data()?;
        assert!(export.cluster_summary.is_some());

        Ok(())
    }
}