1#![allow(dead_code)]
9use crate::{TorshDistributedError, TorshResult};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex, RwLock};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14use tracing::{info, warn};
15
/// Point-in-time snapshot of host-level resource utilization for one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
    /// CPU utilization in percent.
    pub cpu_utilization: f32,
    /// Resident memory usage in megabytes.
    pub memory_usage_mb: u64,
    /// GPU utilization in percent.
    pub gpu_utilization: f32,
    /// GPU memory usage in megabytes.
    pub gpu_memory_mb: u64,
    /// Network bandwidth in megabits per second.
    pub network_bandwidth_mbps: f32,
    /// Disk I/O throughput in megabytes per second.
    pub disk_io_mbps: f32,
    /// Device temperature in degrees Celsius.
    pub temperature_celsius: f32,
    /// Power draw in watts.
    pub power_watts: f32,
    /// Collection time as milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
}
38
/// Per-batch training progress metrics for one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Current epoch number.
    pub epoch: u32,
    /// Current batch index within the epoch.
    pub batch: u32,
    /// Training loss for the batch.
    pub loss: f32,
    /// Learning rate in effect for the batch.
    pub learning_rate: f32,
    /// Global gradient norm for the batch.
    pub gradient_norm: f32,
    /// Training throughput in samples per second.
    pub throughput_samples_per_sec: f32,
    /// Wall-clock time of the batch in milliseconds.
    pub batch_time_ms: u64,
    /// Memory consumed by the batch in megabytes.
    pub batch_memory_mb: u64,
    /// Collection time as milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
}
61
/// Collective-communication performance statistics for one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationMetrics {
    /// All-reduce operations completed per second.
    pub allreduce_ops_per_sec: f32,
    /// All-gather operations completed per second.
    pub allgather_ops_per_sec: f32,
    /// Broadcast operations completed per second.
    pub broadcast_ops_per_sec: f32,
    /// Point-to-point operations completed per second.
    pub p2p_ops_per_sec: f32,
    /// Average operation latency in microseconds.
    pub avg_latency_us: u64,
    /// Communication bandwidth in megabits per second.
    pub comm_bandwidth_mbps: f32,
    /// Count of failed communication operations.
    pub failed_ops_count: u32,
    /// Overall efficiency score; compared against 0.7 in health assessment.
    pub efficiency_score: f32,
    /// Collection time as milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
}
84
/// Health classification assigned to a node from its latest metrics.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum NodeHealthStatus {
    /// All monitored values are within configured thresholds.
    Healthy,
    /// A warning threshold was crossed; `reason` describes which.
    Degraded { reason: String },
    /// A critical threshold was crossed; `reason` describes which.
    Critical { reason: String },
    /// The node is considered failed.
    Failed { reason: String },
    /// The node is recovering; `progress` is a fractional completion value.
    Recovering { progress: f32 },
}
99
/// Caller-supplied parameters for one metrics-update cycle
/// (see `DistributedMonitor::update_node_metrics`).
#[derive(Debug, Clone)]
pub struct NodeMetricsUpdate {
    // Identity of the reporting node.
    pub node_id: String,
    // Rank of this node within the job.
    pub rank: u32,
    // Total number of nodes in the job.
    pub world_size: u32,
    // Current training loss, used to derive training metrics.
    pub training_loss: f32,
    // Current learning rate.
    pub learning_rate: f32,
    // Current epoch number.
    pub epoch: u32,
    // Current batch index.
    pub batch: u32,
}
111
/// Complete metrics record for one node at one point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeMetrics {
    /// Identity of the node.
    pub node_id: String,
    /// Rank of this node within the job.
    pub rank: u32,
    /// Total number of nodes in the job.
    pub world_size: u32,
    /// Host resource utilization snapshot.
    pub system_metrics: SystemMetrics,
    /// Training progress snapshot.
    pub training_metrics: TrainingMetrics,
    /// Communication performance snapshot.
    pub communication_metrics: CommunicationMetrics,
    /// Health classification derived from the snapshots above.
    pub health_status: NodeHealthStatus,
    /// Free-form user-defined metrics by name.
    pub custom_metrics: HashMap<String, f64>,
}
132
/// Alert severity, ordered from least (`Info`) to most (`Emergency`) severe;
/// the derived `Ord` follows declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
    Emergency,
}
141
142impl std::fmt::Display for AlertSeverity {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 match self {
145 AlertSeverity::Info => write!(f, "INFO"),
146 AlertSeverity::Warning => write!(f, "WARNING"),
147 AlertSeverity::Critical => write!(f, "CRITICAL"),
148 AlertSeverity::Emergency => write!(f, "EMERGENCY"),
149 }
150 }
151}
152
/// A single threshold or anomaly alert raised for a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Alert {
    /// Unique id, built from the alert kind, node id and timestamp.
    pub id: String,
    /// Severity classification.
    pub severity: AlertSeverity,
    /// Human-readable description.
    pub message: String,
    /// Node the alert refers to.
    pub node_id: String,
    /// Name of the metric that triggered the alert.
    pub metric_name: String,
    /// Observed metric value at alert time.
    pub current_value: f64,
    /// Threshold that was crossed (0.0 for statistical anomaly alerts).
    pub threshold_value: f64,
    /// Creation time as milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
    /// Whether the alert is still considered live; cleared alerts are
    /// removed by `clear_resolved_alerts`.
    pub is_active: bool,
}
175
/// Tunable configuration for `DistributedMonitor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// How often metrics are collected.
    pub collection_interval: Duration,
    /// Maximum number of metric snapshots retained in history.
    pub history_buffer_size: usize,
    /// Whether GPU metrics collection is enabled.
    pub enable_gpu_monitoring: bool,
    /// Whether communication analysis is enabled.
    pub enable_comm_analysis: bool,
    /// Thresholds that drive health assessment and alert generation.
    pub alert_thresholds: AlertThresholds,
    /// Maximum number of alerts retained in alert history.
    pub max_alerts: usize,
    /// Whether statistical anomaly detection runs on each update.
    pub enable_anomaly_detection: bool,
    /// Anomaly sensitivity in 0..1; higher values flag smaller deviations.
    pub anomaly_sensitivity: f32,
}
196
/// Warning/critical thresholds used for health assessment and alerting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlertThresholds {
    /// CPU utilization (percent) above which a warning is raised.
    pub cpu_warning_pct: f32,
    /// CPU utilization (percent) above which a critical alert is raised.
    pub cpu_critical_pct: f32,
    /// Memory usage (percent) warning threshold.
    pub memory_warning_pct: f32,
    /// Memory usage (percent) critical threshold.
    pub memory_critical_pct: f32,
    /// GPU utilization (percent) warning threshold.
    pub gpu_warning_pct: f32,
    /// GPU utilization (percent) critical threshold.
    pub gpu_critical_pct: f32,
    /// Communication latency (microseconds) warning threshold.
    pub latency_warning_us: u64,
    /// Communication latency (microseconds) critical threshold.
    pub latency_critical_us: u64,
    /// Throughput degradation (percent) warning threshold.
    /// NOTE(review): not referenced anywhere in this file — confirm callers.
    pub throughput_degradation_warning_pct: f32,
    /// Throughput degradation (percent) critical threshold.
    /// NOTE(review): not referenced anywhere in this file — confirm callers.
    pub throughput_degradation_critical_pct: f32,
}
221
222impl Default for MonitoringConfig {
223 fn default() -> Self {
224 Self {
225 collection_interval: Duration::from_secs(5),
226 history_buffer_size: 1000,
227 enable_gpu_monitoring: true,
228 enable_comm_analysis: true,
229 alert_thresholds: AlertThresholds::default(),
230 max_alerts: 10000,
231 enable_anomaly_detection: true,
232 anomaly_sensitivity: 0.7,
233 }
234 }
235}
236
237impl Default for AlertThresholds {
238 fn default() -> Self {
239 Self {
240 cpu_warning_pct: 80.0,
241 cpu_critical_pct: 95.0,
242 memory_warning_pct: 80.0,
243 memory_critical_pct: 95.0,
244 gpu_warning_pct: 85.0,
245 gpu_critical_pct: 98.0,
246 latency_warning_us: 10000, latency_critical_us: 50000, throughput_degradation_warning_pct: 20.0,
249 throughput_degradation_critical_pct: 50.0,
250 }
251 }
252}
253
/// Central metrics and alerting hub for one node of a distributed job.
///
/// All shared state sits behind `Arc`ed locks so the monitor can be shared
/// across threads. Coordinator instances additionally aggregate metrics
/// from every node in the cluster.
pub struct DistributedMonitor {
    /// Collection intervals, buffer sizes and alert thresholds.
    config: MonitoringConfig,
    /// Most recent metrics snapshot for this node, if any yet.
    current_metrics: Arc<RwLock<Option<NodeMetrics>>>,
    /// Rolling snapshot window, bounded by `config.history_buffer_size`.
    metrics_history: Arc<Mutex<VecDeque<NodeMetrics>>>,
    /// Latest snapshot per node id; populated only when `is_coordinator`.
    all_nodes_metrics: Arc<RwLock<HashMap<String, NodeMetrics>>>,
    /// Alerts currently considered live.
    active_alerts: Arc<Mutex<Vec<Alert>>>,
    /// Bounded log of raised alerts, capped at `config.max_alerts`.
    alert_history: Arc<Mutex<VecDeque<Alert>>>,
    /// Named performance baselines.
    /// NOTE(review): never read or written in this file — confirm usage.
    performance_baselines: Arc<RwLock<HashMap<String, f64>>>,
    /// Streaming statistical anomaly detector.
    anomaly_detector: Arc<Mutex<AnomalyDetector>>,
    /// Whether this node aggregates cluster-wide metrics.
    is_coordinator: bool,
}
275
/// Streaming anomaly detector built on exponentially weighted moving
/// statistics: a sample is anomalous when its z-score against the running
/// mean/deviation exceeds a sensitivity-derived multiplier.
#[derive(Debug)]
struct AnomalyDetector {
    /// EWMA of each tracked metric's value.
    moving_averages: HashMap<String, f64>,
    /// EWMA of each metric's absolute deviation from the running mean.
    standard_deviations: HashMap<String, f64>,
    /// Number of samples observed per metric.
    sample_counts: HashMap<String, usize>,
    /// Z-score above which a sample counts as an anomaly.
    threshold_multiplier: f64,
}

impl AnomalyDetector {
    /// Builds a detector; higher `sensitivity` (nominally 0..1) lowers the
    /// z-score threshold, floored at 1.0.
    fn new(sensitivity: f32) -> Self {
        Self {
            threshold_multiplier: (2.0 - sensitivity as f64).max(1.0),
            moving_averages: HashMap::new(),
            standard_deviations: HashMap::new(),
            sample_counts: HashMap::new(),
        }
    }

    /// Folds one observation into the running statistics for `metric_name`.
    fn update_metric(&mut self, metric_name: &str, value: f64) {
        let alpha = 0.1; // EWMA smoothing factor

        // First observation seeds the mean with the value itself.
        let mean = self
            .moving_averages
            .entry(metric_name.to_string())
            .or_insert(value);
        *mean = alpha * value + (1.0 - alpha) * *mean;

        let seen = self
            .sample_counts
            .entry(metric_name.to_string())
            .or_insert(0);
        *seen += 1;

        // Deviation tracking starts once a baseline mean exists.
        if *seen > 1 {
            let deviation = (value - *mean).powi(2).sqrt();
            let spread = self
                .standard_deviations
                .entry(metric_name.to_string())
                .or_insert(0.0);
            *spread = alpha * deviation + (1.0 - alpha) * *spread;
        }
    }

    /// True when `value` deviates from the learned baseline by more than
    /// `threshold_multiplier` (approximate) standard deviations. Metrics
    /// without accumulated statistics never register as anomalous.
    fn is_anomaly(&self, metric_name: &str, value: f64) -> bool {
        match (
            self.moving_averages.get(metric_name),
            self.standard_deviations.get(metric_name),
        ) {
            (Some(&mean), Some(&spread)) => {
                // Floor the spread so a near-constant series cannot blow up
                // the z-score via division by ~0.
                let z_score = (value - mean).abs() / spread.max(0.01);
                z_score > self.threshold_multiplier
            }
            _ => false,
        }
    }
}
339
340impl DistributedMonitor {
341 pub fn new(config: MonitoringConfig, is_coordinator: bool) -> Self {
343 let anomaly_detector = AnomalyDetector::new(config.anomaly_sensitivity);
344
345 Self {
346 config: config.clone(),
347 current_metrics: Arc::new(RwLock::new(None)),
348 metrics_history: Arc::new(Mutex::new(VecDeque::with_capacity(
349 config.history_buffer_size,
350 ))),
351 all_nodes_metrics: Arc::new(RwLock::new(HashMap::new())),
352 active_alerts: Arc::new(Mutex::new(Vec::new())),
353 alert_history: Arc::new(Mutex::new(VecDeque::with_capacity(config.max_alerts))),
354 performance_baselines: Arc::new(RwLock::new(HashMap::new())),
355 anomaly_detector: Arc::new(Mutex::new(anomaly_detector)),
356 is_coordinator,
357 }
358 }
359
360 pub fn collect_system_metrics(&self) -> TorshResult<SystemMetrics> {
362 let timestamp_ms = SystemTime::now()
365 .duration_since(UNIX_EPOCH)
366 .expect("time should be after UNIX_EPOCH")
367 .as_millis() as u64;
368
369 let base_time = timestamp_ms % 100000;
371 let variation = (base_time as f32 / 1000.0).sin();
372
373 Ok(SystemMetrics {
374 cpu_utilization: 45.0 + variation * 20.0, memory_usage_mb: 8000 + (variation * 2000.0) as u64, gpu_utilization: 80.0 + variation * 15.0, gpu_memory_mb: 16000 + (variation * 4000.0) as u64, network_bandwidth_mbps: 1000.0 + variation * 500.0, disk_io_mbps: 200.0 + variation * 100.0, temperature_celsius: 65.0 + variation * 10.0, power_watts: 250.0 + variation * 50.0, timestamp_ms,
383 })
384 }
385
386 pub fn collect_training_metrics(
388 &self,
389 current_loss: f32,
390 current_lr: f32,
391 epoch: u32,
392 batch: u32,
393 ) -> TorshResult<TrainingMetrics> {
394 let timestamp_ms = SystemTime::now()
395 .duration_since(UNIX_EPOCH)
396 .expect("time should be after UNIX_EPOCH")
397 .as_millis() as u64;
398
399 let gradient_norm = current_loss * 0.1 + 0.5; let throughput = 1000.0 / (current_loss + 0.1); let batch_time_ms = (1000.0 / throughput * 32.0) as u64; let batch_memory_mb = 2000 + (batch_time_ms / 10); Ok(TrainingMetrics {
406 epoch,
407 batch,
408 loss: current_loss,
409 learning_rate: current_lr,
410 gradient_norm,
411 throughput_samples_per_sec: throughput,
412 batch_time_ms,
413 batch_memory_mb,
414 timestamp_ms,
415 })
416 }
417
418 pub fn collect_communication_metrics(&self) -> TorshResult<CommunicationMetrics> {
420 let timestamp_ms = SystemTime::now()
421 .duration_since(UNIX_EPOCH)
422 .expect("time should be after UNIX_EPOCH")
423 .as_millis() as u64;
424
425 let base_ops = 10.0; let network_quality = 0.8; Ok(CommunicationMetrics {
430 allreduce_ops_per_sec: base_ops * network_quality,
431 allgather_ops_per_sec: base_ops * 0.5 * network_quality,
432 broadcast_ops_per_sec: base_ops * 0.3 * network_quality,
433 p2p_ops_per_sec: base_ops * 0.2 * network_quality,
434 avg_latency_us: ((1.0 - network_quality) * 20000.0 + 1000.0) as u64,
435 comm_bandwidth_mbps: 800.0 * network_quality,
436 failed_ops_count: if network_quality < 0.9 { 1 } else { 0 },
437 efficiency_score: network_quality,
438 timestamp_ms,
439 })
440 }
441
    /// Collects fresh system/training/communication metrics for one node,
    /// stores them, and runs alerting plus (optionally) anomaly detection.
    ///
    /// This is the main per-step entry point a training loop calls.
    pub fn update_node_metrics(&self, params: NodeMetricsUpdate) -> TorshResult<()> {
        let NodeMetricsUpdate {
            node_id,
            rank,
            world_size,
            training_loss,
            learning_rate,
            epoch,
            batch,
        } = params;
        let system_metrics = self.collect_system_metrics()?;
        let training_metrics =
            self.collect_training_metrics(training_loss, learning_rate, epoch, batch)?;
        let communication_metrics = self.collect_communication_metrics()?;

        // Health is derived from the freshly collected metrics before they
        // are moved into the NodeMetrics record below.
        let health_status =
            self.assess_node_health(&system_metrics, &training_metrics, &communication_metrics)?;

        let node_metrics = NodeMetrics {
            node_id: node_id.clone(),
            rank,
            world_size,
            system_metrics,
            training_metrics,
            communication_metrics,
            health_status,
            custom_metrics: HashMap::new(),
        };

        // Each lock is taken in its own scope so no two guards overlap.
        {
            let mut current = self.current_metrics.write().map_err(|e| {
                TorshDistributedError::communication_error(
                    "metrics_update",
                    format!("Lock error: {}", e),
                )
            })?;
            *current = Some(node_metrics.clone());
        }

        {
            let mut history = self.metrics_history.lock().map_err(|e| {
                TorshDistributedError::communication_error(
                    "metrics_history",
                    format!("Lock error: {}", e),
                )
            })?;
            history.push_back(node_metrics.clone());
            // Keep the history bounded to the configured buffer size.
            if history.len() > self.config.history_buffer_size {
                history.pop_front();
            }
        }

        // Coordinators additionally maintain the cluster-wide node map.
        if self.is_coordinator {
            let mut all_nodes = self.all_nodes_metrics.write().map_err(|e| {
                TorshDistributedError::communication_error(
                    "all_nodes_update",
                    format!("Lock error: {}", e),
                )
            })?;
            all_nodes.insert(node_id.clone(), node_metrics.clone());
        }

        self.check_and_generate_alerts(&node_metrics)?;

        if self.config.enable_anomaly_detection {
            self.update_anomaly_detection(&node_metrics)?;
        }

        info!(
            "Updated metrics for node {} (rank {}): health={:?}",
            node_id, rank, node_metrics.health_status
        );
        Ok(())
    }
525
526 fn assess_node_health(
528 &self,
529 system: &SystemMetrics,
530 _training: &TrainingMetrics,
531 comm: &CommunicationMetrics,
532 ) -> TorshResult<NodeHealthStatus> {
533 let thresholds = &self.config.alert_thresholds;
534
535 if system.cpu_utilization > thresholds.cpu_critical_pct {
537 return Ok(NodeHealthStatus::Critical {
538 reason: format!("CPU utilization at {:.1}%", system.cpu_utilization),
539 });
540 }
541
542 if system.gpu_utilization > thresholds.gpu_critical_pct {
543 return Ok(NodeHealthStatus::Critical {
544 reason: format!("GPU utilization at {:.1}%", system.gpu_utilization),
545 });
546 }
547
548 if comm.avg_latency_us > thresholds.latency_critical_us {
549 return Ok(NodeHealthStatus::Critical {
550 reason: format!("Communication latency at {}μs", comm.avg_latency_us),
551 });
552 }
553
554 if system.cpu_utilization > thresholds.cpu_warning_pct
556 || system.gpu_utilization > thresholds.gpu_warning_pct
557 || comm.avg_latency_us > thresholds.latency_warning_us
558 {
559 return Ok(NodeHealthStatus::Degraded {
560 reason: "Performance metrics above warning thresholds".to_string(),
561 });
562 }
563
564 if comm.efficiency_score < 0.7 {
566 return Ok(NodeHealthStatus::Degraded {
567 reason: format!("Communication efficiency at {:.2}", comm.efficiency_score),
568 });
569 }
570
571 Ok(NodeHealthStatus::Healthy)
572 }
573
574 fn check_and_generate_alerts(&self, metrics: &NodeMetrics) -> TorshResult<()> {
576 let thresholds = &self.config.alert_thresholds;
577 let timestamp_ms = SystemTime::now()
578 .duration_since(UNIX_EPOCH)
579 .expect("time should be after UNIX_EPOCH")
580 .as_millis() as u64;
581
582 let mut new_alerts = Vec::new();
583
584 if metrics.system_metrics.cpu_utilization > thresholds.cpu_critical_pct {
586 new_alerts.push(Alert {
587 id: format!("cpu_critical_{}_{}", metrics.node_id, timestamp_ms),
588 severity: AlertSeverity::Critical,
589 message: format!(
590 "CPU utilization critically high on node {}",
591 metrics.node_id
592 ),
593 node_id: metrics.node_id.clone(),
594 metric_name: "cpu_utilization".to_string(),
595 current_value: metrics.system_metrics.cpu_utilization as f64,
596 threshold_value: thresholds.cpu_critical_pct as f64,
597 timestamp_ms,
598 is_active: true,
599 });
600 } else if metrics.system_metrics.cpu_utilization > thresholds.cpu_warning_pct {
601 new_alerts.push(Alert {
602 id: format!("cpu_warning_{}_{}", metrics.node_id, timestamp_ms),
603 severity: AlertSeverity::Warning,
604 message: format!("CPU utilization high on node {}", metrics.node_id),
605 node_id: metrics.node_id.clone(),
606 metric_name: "cpu_utilization".to_string(),
607 current_value: metrics.system_metrics.cpu_utilization as f64,
608 threshold_value: thresholds.cpu_warning_pct as f64,
609 timestamp_ms,
610 is_active: true,
611 });
612 }
613
614 if metrics.system_metrics.gpu_utilization > thresholds.gpu_critical_pct {
616 new_alerts.push(Alert {
617 id: format!("gpu_critical_{}_{}", metrics.node_id, timestamp_ms),
618 severity: AlertSeverity::Critical,
619 message: format!(
620 "GPU utilization critically high on node {}",
621 metrics.node_id
622 ),
623 node_id: metrics.node_id.clone(),
624 metric_name: "gpu_utilization".to_string(),
625 current_value: metrics.system_metrics.gpu_utilization as f64,
626 threshold_value: thresholds.gpu_critical_pct as f64,
627 timestamp_ms,
628 is_active: true,
629 });
630 }
631
632 if metrics.communication_metrics.avg_latency_us > thresholds.latency_critical_us {
634 new_alerts.push(Alert {
635 id: format!("latency_critical_{}_{}", metrics.node_id, timestamp_ms),
636 severity: AlertSeverity::Critical,
637 message: format!(
638 "Communication latency critically high on node {}",
639 metrics.node_id
640 ),
641 node_id: metrics.node_id.clone(),
642 metric_name: "avg_latency_us".to_string(),
643 current_value: metrics.communication_metrics.avg_latency_us as f64,
644 threshold_value: thresholds.latency_critical_us as f64,
645 timestamp_ms,
646 is_active: true,
647 });
648 }
649
650 if !new_alerts.is_empty() {
652 let mut active_alerts = self.active_alerts.lock().map_err(|e| {
653 TorshDistributedError::communication_error(
654 "alerts_update",
655 format!("Lock error: {}", e),
656 )
657 })?;
658
659 for alert in &new_alerts {
660 warn!("Generated alert: {} - {}", alert.severity, alert.message);
661 active_alerts.push(alert.clone());
662 }
663
664 let mut alert_history = self.alert_history.lock().map_err(|e| {
666 TorshDistributedError::communication_error(
667 "alert_history",
668 format!("Lock error: {}", e),
669 )
670 })?;
671
672 for alert in new_alerts {
673 alert_history.push_back(alert);
674 if alert_history.len() > self.config.max_alerts {
675 alert_history.pop_front();
676 }
677 }
678 }
679
680 Ok(())
681 }
682
683 fn update_anomaly_detection(&self, metrics: &NodeMetrics) -> TorshResult<()> {
685 if !self.config.enable_anomaly_detection {
686 return Ok(());
687 }
688
689 let mut detector = self.anomaly_detector.lock().map_err(|e| {
690 TorshDistributedError::communication_error(
691 "anomaly_detector",
692 format!("Lock error: {}", e),
693 )
694 })?;
695
696 detector.update_metric(
698 "cpu_utilization",
699 metrics.system_metrics.cpu_utilization as f64,
700 );
701 detector.update_metric(
702 "gpu_utilization",
703 metrics.system_metrics.gpu_utilization as f64,
704 );
705 detector.update_metric(
706 "throughput",
707 metrics.training_metrics.throughput_samples_per_sec as f64,
708 );
709 detector.update_metric(
710 "comm_latency",
711 metrics.communication_metrics.avg_latency_us as f64,
712 );
713 detector.update_metric(
714 "comm_efficiency",
715 metrics.communication_metrics.efficiency_score as f64,
716 );
717
718 let metrics_to_check = [
720 (
721 "cpu_utilization",
722 metrics.system_metrics.cpu_utilization as f64,
723 ),
724 (
725 "gpu_utilization",
726 metrics.system_metrics.gpu_utilization as f64,
727 ),
728 (
729 "throughput",
730 metrics.training_metrics.throughput_samples_per_sec as f64,
731 ),
732 (
733 "comm_latency",
734 metrics.communication_metrics.avg_latency_us as f64,
735 ),
736 (
737 "comm_efficiency",
738 metrics.communication_metrics.efficiency_score as f64,
739 ),
740 ];
741
742 for (metric_name, value) in &metrics_to_check {
743 if detector.is_anomaly(metric_name, *value) {
744 warn!(
745 "Anomaly detected: {} = {:.2} on node {}",
746 metric_name, value, metrics.node_id
747 );
748
749 let timestamp_ms = SystemTime::now()
751 .duration_since(UNIX_EPOCH)
752 .expect("time should be after UNIX_EPOCH")
753 .as_millis() as u64;
754
755 let alert = Alert {
756 id: format!("anomaly_{}_{}", metrics.node_id, timestamp_ms),
757 severity: AlertSeverity::Warning,
758 message: format!(
759 "Anomaly detected in {} on node {}",
760 metric_name, metrics.node_id
761 ),
762 node_id: metrics.node_id.clone(),
763 metric_name: metric_name.to_string(),
764 current_value: *value,
765 threshold_value: 0.0, timestamp_ms,
767 is_active: true,
768 };
769
770 let mut active_alerts = self.active_alerts.lock().map_err(|e| {
772 TorshDistributedError::communication_error(
773 "anomaly_alerts",
774 format!("Lock error: {}", e),
775 )
776 })?;
777 active_alerts.push(alert);
778 }
779 }
780
781 Ok(())
782 }
783
784 pub fn get_current_metrics(&self) -> TorshResult<Option<NodeMetrics>> {
786 let current = self.current_metrics.read().map_err(|e| {
787 TorshDistributedError::communication_error(
788 "get_current_metrics",
789 format!("Lock error: {}", e),
790 )
791 })?;
792 Ok(current.clone())
793 }
794
795 pub fn get_metrics_history(&self) -> TorshResult<Vec<NodeMetrics>> {
797 let history = self.metrics_history.lock().map_err(|e| {
798 TorshDistributedError::communication_error(
799 "get_metrics_history",
800 format!("Lock error: {}", e),
801 )
802 })?;
803 Ok(history.iter().cloned().collect())
804 }
805
806 pub fn get_active_alerts(&self) -> TorshResult<Vec<Alert>> {
808 let alerts = self.active_alerts.lock().map_err(|e| {
809 TorshDistributedError::communication_error(
810 "get_active_alerts",
811 format!("Lock error: {}", e),
812 )
813 })?;
814 Ok(alerts.clone())
815 }
816
817 pub fn get_cluster_summary(&self) -> TorshResult<ClusterSummary> {
819 if !self.is_coordinator {
820 return Err(TorshDistributedError::communication_error(
821 "cluster_summary",
822 "Only coordinator nodes can access cluster summary".to_string(),
823 ));
824 }
825
826 let all_nodes = self.all_nodes_metrics.read().map_err(|e| {
827 TorshDistributedError::communication_error(
828 "cluster_summary",
829 format!("Lock error: {}", e),
830 )
831 })?;
832
833 let total_nodes = all_nodes.len();
834 let healthy_nodes = all_nodes
835 .values()
836 .filter(|n| matches!(n.health_status, NodeHealthStatus::Healthy))
837 .count();
838 let degraded_nodes = all_nodes
839 .values()
840 .filter(|n| matches!(n.health_status, NodeHealthStatus::Degraded { .. }))
841 .count();
842 let critical_nodes = all_nodes
843 .values()
844 .filter(|n| matches!(n.health_status, NodeHealthStatus::Critical { .. }))
845 .count();
846 let failed_nodes = all_nodes
847 .values()
848 .filter(|n| matches!(n.health_status, NodeHealthStatus::Failed { .. }))
849 .count();
850
851 let total_cpu_util: f32 = all_nodes
853 .values()
854 .map(|n| n.system_metrics.cpu_utilization)
855 .sum();
856 let avg_cpu_util = if total_nodes > 0 {
857 total_cpu_util / total_nodes as f32
858 } else {
859 0.0
860 };
861
862 let total_gpu_util: f32 = all_nodes
863 .values()
864 .map(|n| n.system_metrics.gpu_utilization)
865 .sum();
866 let avg_gpu_util = if total_nodes > 0 {
867 total_gpu_util / total_nodes as f32
868 } else {
869 0.0
870 };
871
872 let total_throughput: f32 = all_nodes
873 .values()
874 .map(|n| n.training_metrics.throughput_samples_per_sec)
875 .sum();
876
877 let avg_comm_latency: u64 = if total_nodes > 0 {
878 all_nodes
879 .values()
880 .map(|n| n.communication_metrics.avg_latency_us)
881 .sum::<u64>()
882 / total_nodes as u64
883 } else {
884 0
885 };
886
887 Ok(ClusterSummary {
888 total_nodes,
889 healthy_nodes,
890 degraded_nodes,
891 critical_nodes,
892 failed_nodes,
893 avg_cpu_utilization: avg_cpu_util,
894 avg_gpu_utilization: avg_gpu_util,
895 total_throughput,
896 avg_communication_latency_us: avg_comm_latency,
897 timestamp_ms: SystemTime::now()
898 .duration_since(UNIX_EPOCH)
899 .expect("time should be after UNIX_EPOCH")
900 .as_millis() as u64,
901 })
902 }
903
904 pub fn clear_resolved_alerts(&self) -> TorshResult<usize> {
906 let mut active_alerts = self.active_alerts.lock().map_err(|e| {
907 TorshDistributedError::communication_error("clear_alerts", format!("Lock error: {}", e))
908 })?;
909
910 let initial_count = active_alerts.len();
911 active_alerts.retain(|alert| alert.is_active);
912 let cleared_count = initial_count - active_alerts.len();
913
914 info!("Cleared {} resolved alerts", cleared_count);
915 Ok(cleared_count)
916 }
917
918 pub fn export_monitoring_data(&self) -> TorshResult<MonitoringExport> {
920 let current_metrics = self.get_current_metrics()?;
921 let metrics_history = self.get_metrics_history()?;
922 let active_alerts = self.get_active_alerts()?;
923
924 let cluster_summary = if self.is_coordinator {
925 Some(self.get_cluster_summary()?)
926 } else {
927 None
928 };
929
930 Ok(MonitoringExport {
931 current_metrics,
932 metrics_history,
933 active_alerts,
934 cluster_summary,
935 export_timestamp_ms: SystemTime::now()
936 .duration_since(UNIX_EPOCH)
937 .expect("time should be after UNIX_EPOCH")
938 .as_millis() as u64,
939 })
940 }
941}
942
/// Cluster-wide aggregation of the latest per-node metrics
/// (produced by `DistributedMonitor::get_cluster_summary`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSummary {
    /// Number of nodes the coordinator has metrics for.
    pub total_nodes: usize,
    /// Nodes currently classified `Healthy`.
    pub healthy_nodes: usize,
    /// Nodes currently classified `Degraded`.
    pub degraded_nodes: usize,
    /// Nodes currently classified `Critical`.
    pub critical_nodes: usize,
    /// Nodes currently classified `Failed`.
    pub failed_nodes: usize,
    /// Mean CPU utilization (percent) across nodes.
    pub avg_cpu_utilization: f32,
    /// Mean GPU utilization (percent) across nodes.
    pub avg_gpu_utilization: f32,
    /// Sum of per-node training throughput (samples/sec).
    pub total_throughput: f32,
    /// Mean communication latency (microseconds, integer division).
    pub avg_communication_latency_us: u64,
    /// Summary creation time as milliseconds since the UNIX epoch.
    pub timestamp_ms: u64,
}
957
/// Serializable bundle of the monitor's full state
/// (produced by `DistributedMonitor::export_monitoring_data`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringExport {
    /// Latest snapshot for this node, if any.
    pub current_metrics: Option<NodeMetrics>,
    /// Rolling metrics history, oldest first.
    pub metrics_history: Vec<NodeMetrics>,
    /// Alerts live at export time.
    pub active_alerts: Vec<Alert>,
    /// Cluster summary; `None` on non-coordinator nodes.
    pub cluster_summary: Option<ClusterSummary>,
    /// Export time as milliseconds since the UNIX epoch.
    pub export_timestamp_ms: u64,
}
967
#[cfg(test)]
mod tests {
    use super::*;

    // These tests contain no awaits, so plain #[test] is used instead of
    // #[tokio::test] — no async runtime is required.

    /// A freshly created monitor has no metrics yet.
    #[test]
    fn test_distributed_monitor_creation() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        let current_metrics = monitor.get_current_metrics()?;
        assert!(current_metrics.is_none());
        Ok(())
    }

    /// Simulated system metrics stay within their documented ranges.
    #[test]
    fn test_system_metrics_collection() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        let metrics = monitor.collect_system_metrics()?;
        assert!(metrics.cpu_utilization >= 0.0 && metrics.cpu_utilization <= 100.0);
        assert!(metrics.gpu_utilization >= 0.0 && metrics.gpu_utilization <= 100.0);
        assert!(metrics.memory_usage_mb > 0);

        Ok(())
    }

    /// A full update cycle stores a snapshot carrying the caller's identity.
    #[test]
    fn test_node_metrics_update() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 4,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 10,
            batch: 100,
        })?;

        let current_metrics = monitor.get_current_metrics()?;
        assert!(current_metrics.is_some());

        let metrics = current_metrics.unwrap();
        assert_eq!(metrics.node_id, "test_node");
        assert_eq!(metrics.rank, 0);
        assert_eq!(metrics.world_size, 4);

        Ok(())
    }

    /// Any alerts generated during an update are well-formed.
    ///
    /// Whether the simulated CPU value crosses the lowered threshold depends
    /// on wall-clock time, so the test asserts properties of whatever alerts
    /// were produced rather than a fixed count. (The previous version held a
    /// tautological `is_empty() || !is_empty()` assertion.)
    #[test]
    fn test_alert_generation() -> TorshResult<()> {
        let mut config = MonitoringConfig::default();
        config.alert_thresholds.cpu_warning_pct = 50.0;
        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 1,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 1,
            batch: 1,
        })?;

        let alerts = monitor.get_active_alerts()?;
        for alert in &alerts {
            assert!(alert.is_active);
            assert_eq!(alert.node_id, "test_node");
            // Threshold alerts are never Info-level.
            assert!(alert.severity >= AlertSeverity::Warning);
            assert!(!alert.metric_name.is_empty());
        }

        Ok(())
    }

    /// After a stable baseline, nearby values pass and outliers are flagged.
    #[test]
    fn test_anomaly_detection() -> TorshResult<()> {
        let mut detector = AnomalyDetector::new(0.7);

        for i in 0..50 {
            detector.update_metric("test_metric", 50.0 + (i as f64 % 10.0));
        }

        assert!(!detector.is_anomaly("test_metric", 55.0));
        assert!(detector.is_anomaly("test_metric", 200.0));

        Ok(())
    }

    /// Export captures the snapshot and history written by an update.
    #[test]
    fn test_monitoring_export() -> TorshResult<()> {
        let config = MonitoringConfig::default();
        let monitor = DistributedMonitor::new(config, false);

        monitor.update_node_metrics(NodeMetricsUpdate {
            node_id: "test_node".to_string(),
            rank: 0,
            world_size: 1,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 1,
            batch: 1,
        })?;

        let export = monitor.export_monitoring_data()?;
        assert!(export.current_metrics.is_some());
        assert!(!export.metrics_history.is_empty());

        Ok(())
    }
}