// trustformers_debug/distributed_profiling.rs
//! Enhanced Distributed Training Profiling
//!
//! This module provides comprehensive profiling support for distributed training scenarios,
//! including multi-node coordination, gradient synchronization analysis, and communication
//! pattern optimization.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Context, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

/// Distributed training profiler
///
/// Provides advanced profiling capabilities for distributed training including:
/// - Cross-node communication analysis
/// - Gradient synchronization profiling
/// - Load balancing metrics
/// - Communication bottleneck detection
///
/// All event stores are wrapped in `Arc<RwLock<...>>` so the profiler can be
/// shared across threads and recorded into concurrently.
#[derive(Debug)]
pub struct DistributedProfiler {
    /// Configuration controlling which subsystems are profiled and event limits
    config: DistributedProfilerConfig,
    /// Node metadata, keyed by node ID
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
    /// Recorded communication events (bounded by `max_events_per_category`)
    comm_events: Arc<RwLock<Vec<CommunicationEvent>>>,
    /// Recorded synchronization events (bounded by `max_events_per_category`)
    sync_events: Arc<RwLock<Vec<SynchronizationEvent>>>,
    /// Performance snapshot history per node, keyed by node ID
    node_snapshots: Arc<RwLock<HashMap<String, Vec<NodePerformanceSnapshot>>>>,
    /// Instant the profiling session started; event timestamps and report
    /// durations are measured relative to this
    start_time: Instant,
}

/// Configuration for distributed profiling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedProfilerConfig {
    /// Enable communication profiling (`record_communication` is a no-op when false)
    pub enable_comm_profiling: bool,
    /// Enable gradient sync profiling (`record_synchronization` is a no-op when false)
    pub enable_grad_sync_profiling: bool,
    /// Enable load balance profiling
    pub enable_load_balance_profiling: bool,
    /// Enable network bandwidth analysis
    pub enable_bandwidth_analysis: bool,
    /// Sampling interval (milliseconds); used to convert idle sample counts to seconds
    pub sampling_interval_ms: u64,
    /// Maximum events to store per category; older events are evicted beyond this
    pub max_events_per_category: usize,
    /// Enable automatic bottleneck detection during report generation
    pub enable_bottleneck_detection: bool,
    /// Communication-overhead percentage above which a bottleneck is flagged
    pub bottleneck_threshold_pct: f64,
}

59impl Default for DistributedProfilerConfig {
60    fn default() -> Self {
61        Self {
62            enable_comm_profiling: true,
63            enable_grad_sync_profiling: true,
64            enable_load_balance_profiling: true,
65            enable_bandwidth_analysis: true,
66            sampling_interval_ms: 100,
67            max_events_per_category: 10000,
68            enable_bottleneck_detection: true,
69            bottleneck_threshold_pct: 80.0,
70        }
71    }
72}
73
/// Information about a node in the distributed cluster
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Node ID (unique identifier; used as the key in the profiler's node map)
    pub node_id: String,
    /// Rank in distributed training
    pub rank: usize,
    /// World size (total number of nodes)
    pub world_size: usize,
    /// Node hostname/IP
    pub host: String,
    /// Number of GPU devices on this node
    pub gpu_count: usize,
    /// Node role (master, worker, etc.)
    pub role: NodeRole,
    /// Node status (only `Active` nodes count toward realtime stats)
    pub status: NodeStatus,
}

/// Node role in distributed training
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
    /// Master/coordinator node
    Master,
    /// Worker node
    Worker,
    /// Parameter server
    ParameterServer,
    /// Hybrid role (combines aspects of the roles above)
    Hybrid,
}

/// Node status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is active and healthy
    Active,
    /// Node is idle
    Idle,
    /// Node has a warning
    Warning,
    /// Node has failed
    Failed,
    /// Node is disconnected
    Disconnected,
}

/// Communication event between nodes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationEvent {
    /// Event ID
    pub event_id: usize,
    /// Timestamp, relative to the profiler's session start
    pub timestamp: Duration,
    /// Source node ID
    pub source_node: String,
    /// Destination node ID
    pub dest_node: String,
    /// Communication type (collective or point-to-point)
    pub comm_type: CommunicationType,
    /// Data size (bytes)
    pub data_size_bytes: usize,
    /// Duration (milliseconds)
    pub duration_ms: f64,
    /// Bandwidth (MB/s)
    pub bandwidth_mbps: f64,
}

/// Type of communication between nodes
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CommunicationType {
    /// Point-to-point send
    Send,
    /// Point-to-point receive
    Recv,
    /// All-reduce collective operation
    AllReduce,
    /// All-gather collective operation
    AllGather,
    /// Broadcast operation
    Broadcast,
    /// Scatter operation
    Scatter,
    /// Reduce operation
    Reduce,
    /// Barrier synchronization
    Barrier,
}

/// Gradient synchronization event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynchronizationEvent {
    /// Event ID
    pub event_id: usize,
    /// Timestamp, relative to the profiler's session start
    pub timestamp: Duration,
    /// IDs of the participating nodes
    pub nodes: Vec<String>,
    /// Synchronization type
    pub sync_type: SyncType,
    /// Total gradient size (bytes)
    pub gradient_size_bytes: usize,
    /// Synchronization duration (milliseconds)
    pub duration_ms: f64,
    /// Success status
    pub success: bool,
    /// Error message (if failed)
    pub error: Option<String>,
}

/// Type of gradient synchronization
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncType {
    /// Data-parallel all-reduce
    DataParallel,
    /// Model-parallel send/recv
    ModelParallel,
    /// Pipeline-parallel forward
    PipelineForward,
    /// Pipeline-parallel backward
    PipelineBackward,
    /// Hybrid parallel
    Hybrid,
}

/// Performance snapshot for a single node
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodePerformanceSnapshot {
    /// Timestamp, relative to the profiler's session start
    pub timestamp: Duration,
    /// Node ID
    pub node_id: String,
    /// Compute utilization (0-100); values below 10 count as "idle" samples
    pub compute_utilization_pct: f64,
    /// Memory utilization (0-100)
    pub memory_utilization_pct: f64,
    /// Network utilization (0-100)
    pub network_utilization_pct: f64,
    /// Throughput (samples/sec)
    pub throughput: f64,
    /// Active communication count
    pub active_communications: usize,
    /// Pending operations
    pub pending_operations: usize,
}

/// Distributed profiling report
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedProfilingReport {
    /// Total profiling duration (seconds since profiler creation)
    pub total_duration_secs: f64,
    /// Number of nodes profiled
    pub num_nodes: usize,
    /// Communication summary
    pub communication_summary: CommunicationSummary,
    /// Synchronization summary
    pub synchronization_summary: SynchronizationSummary,
    /// Load balance analysis
    pub load_balance: LoadBalanceAnalysis,
    /// Detected bottlenecks (empty when detection is disabled)
    pub bottlenecks: Vec<Bottleneck>,
    /// Performance recommendations
    pub recommendations: Vec<String>,
}

/// Summary of communication patterns
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationSummary {
    /// Total communication events
    pub total_events: usize,
    /// Total data transferred (bytes)
    pub total_data_bytes: usize,
    /// Average bandwidth (MB/s)
    pub avg_bandwidth_mbps: f64,
    /// Peak bandwidth (MB/s)
    pub peak_bandwidth_mbps: f64,
    /// Communication overhead (percentage of total session time)
    pub overhead_pct: f64,
    /// Most common communication type (`None` when no events were recorded)
    pub most_common_type: Option<CommunicationType>,
    /// Slowest communication event (`None` when no events were recorded)
    pub slowest_comm: Option<CommunicationEvent>,
}

/// Summary of synchronization operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynchronizationSummary {
    /// Total synchronization events
    pub total_syncs: usize,
    /// Successful syncs
    pub successful_syncs: usize,
    /// Failed syncs (total minus successful)
    pub failed_syncs: usize,
    /// Average sync duration (milliseconds)
    pub avg_sync_duration_ms: f64,
    /// Maximum sync duration (milliseconds)
    pub max_sync_duration_ms: f64,
    /// Total time in synchronization (seconds)
    pub total_sync_time_secs: f64,
    /// Synchronization efficiency (0-1; 1.0 when no syncs recorded)
    pub sync_efficiency: f64,
}

/// Load balance analysis across nodes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalanceAnalysis {
    /// Load imbalance score (coefficient of variation of throughput; 0 = perfectly balanced)
    pub imbalance_score: f64,
    /// Average compute utilization per node ID
    pub compute_utilization: HashMap<String, f64>,
    /// Average memory utilization per node ID
    pub memory_utilization: HashMap<String, f64>,
    /// Average throughput per node ID
    pub throughput: HashMap<String, f64>,
    /// Straggler nodes (throughput more than 30% below the mean)
    pub stragglers: Vec<String>,
    /// Estimated idle time per node ID (seconds)
    pub idle_time: HashMap<String, f64>,
}

/// Detected performance bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bottleneck {
    /// Bottleneck type
    pub bottleneck_type: BottleneckType,
    /// Severity (0-100); values above 50 are surfaced as high-priority recommendations
    pub severity: f64,
    /// Affected node IDs ("all" is used for cluster-wide bottlenecks)
    pub affected_nodes: Vec<String>,
    /// Human-readable description
    pub description: String,
    /// Suggested fix
    pub suggestion: String,
}

/// Type of performance bottleneck
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BottleneckType {
    /// Communication bottleneck
    Communication,
    /// Synchronization bottleneck
    Synchronization,
    /// Compute imbalance across nodes
    ComputeImbalance,
    /// Memory bottleneck
    Memory,
    /// Network congestion
    NetworkCongestion,
    /// Straggler node
    Straggler,
}

326impl DistributedProfiler {
327    /// Create a new distributed profiler
328    ///
329    /// # Arguments
330    /// * `config` - Profiler configuration
331    ///
332    /// # Example
333    /// ```rust
334    /// use trustformers_debug::{DistributedProfiler, DistributedProfilerConfig};
335    ///
336    /// let config = DistributedProfilerConfig::default();
337    /// let profiler = DistributedProfiler::new(config);
338    /// ```
339    pub fn new(config: DistributedProfilerConfig) -> Self {
340        info!("Initializing distributed profiler");
341        Self {
342            config,
343            nodes: Arc::new(RwLock::new(HashMap::new())),
344            comm_events: Arc::new(RwLock::new(Vec::new())),
345            sync_events: Arc::new(RwLock::new(Vec::new())),
346            node_snapshots: Arc::new(RwLock::new(HashMap::new())),
347            start_time: Instant::now(),
348        }
349    }
350
351    /// Register a node in the cluster
352    ///
353    /// # Arguments
354    /// * `node_info` - Information about the node
355    pub fn register_node(&self, node_info: NodeInfo) -> Result<()> {
356        debug!(
357            "Registering node: {} (rank {})",
358            node_info.node_id, node_info.rank
359        );
360
361        let mut nodes = self.nodes.write();
362        nodes.insert(node_info.node_id.clone(), node_info);
363
364        Ok(())
365    }
366
367    /// Record a communication event
368    ///
369    /// # Arguments
370    /// * `event` - Communication event to record
371    pub fn record_communication(&self, event: CommunicationEvent) -> Result<()> {
372        if !self.config.enable_comm_profiling {
373            return Ok(());
374        }
375
376        let mut events = self.comm_events.write();
377
378        // Limit stored events
379        if events.len() >= self.config.max_events_per_category {
380            events.remove(0);
381        }
382
383        events.push(event);
384        Ok(())
385    }
386
387    /// Record a synchronization event
388    ///
389    /// # Arguments
390    /// * `event` - Synchronization event to record
391    pub fn record_synchronization(&self, event: SynchronizationEvent) -> Result<()> {
392        if !self.config.enable_grad_sync_profiling {
393            return Ok(());
394        }
395
396        let mut events = self.sync_events.write();
397
398        // Limit stored events
399        if events.len() >= self.config.max_events_per_category {
400            events.remove(0);
401        }
402
403        events.push(event);
404        Ok(())
405    }
406
407    /// Record a performance snapshot for a node
408    ///
409    /// # Arguments
410    /// * `snapshot` - Performance snapshot
411    pub fn record_snapshot(&self, snapshot: NodePerformanceSnapshot) -> Result<()> {
412        let mut snapshots = self.node_snapshots.write();
413
414        let node_history = snapshots.entry(snapshot.node_id.clone()).or_default();
415
416        // Limit stored snapshots
417        if node_history.len() >= self.config.max_events_per_category {
418            node_history.remove(0);
419        }
420
421        node_history.push(snapshot);
422        Ok(())
423    }
424
425    /// Generate a comprehensive profiling report
426    ///
427    /// # Returns
428    /// Detailed profiling report with analysis and recommendations
429    pub fn generate_report(&self) -> Result<DistributedProfilingReport> {
430        info!("Generating distributed profiling report");
431
432        let total_duration = self.start_time.elapsed().as_secs_f64();
433        let nodes = self.nodes.read();
434        let num_nodes = nodes.len();
435
436        // Analyze communication patterns
437        let communication_summary = self.analyze_communication()?;
438
439        // Analyze synchronization
440        let synchronization_summary = self.analyze_synchronization()?;
441
442        // Analyze load balance
443        let load_balance = self.analyze_load_balance()?;
444
445        // Detect bottlenecks
446        let bottlenecks = if self.config.enable_bottleneck_detection {
447            self.detect_bottlenecks(
448                &communication_summary,
449                &synchronization_summary,
450                &load_balance,
451            )?
452        } else {
453            Vec::new()
454        };
455
456        // Generate recommendations
457        let recommendations = self.generate_recommendations(&bottlenecks, &load_balance)?;
458
459        Ok(DistributedProfilingReport {
460            total_duration_secs: total_duration,
461            num_nodes,
462            communication_summary,
463            synchronization_summary,
464            load_balance,
465            bottlenecks,
466            recommendations,
467        })
468    }
469
470    /// Analyze communication patterns
471    fn analyze_communication(&self) -> Result<CommunicationSummary> {
472        let events = self.comm_events.read();
473
474        if events.is_empty() {
475            return Ok(CommunicationSummary {
476                total_events: 0,
477                total_data_bytes: 0,
478                avg_bandwidth_mbps: 0.0,
479                peak_bandwidth_mbps: 0.0,
480                overhead_pct: 0.0,
481                most_common_type: None,
482                slowest_comm: None,
483            });
484        }
485
486        let total_events = events.len();
487        let total_data_bytes: usize = events.iter().map(|e| e.data_size_bytes).sum();
488
489        let bandwidths: Vec<f64> = events.iter().map(|e| e.bandwidth_mbps).collect();
490        let avg_bandwidth_mbps = bandwidths.iter().sum::<f64>() / bandwidths.len() as f64;
491        let peak_bandwidth_mbps = bandwidths.iter().fold(0.0f64, |a, &b| a.max(b));
492
493        let total_comm_time: f64 = events.iter().map(|e| e.duration_ms).sum();
494        let overhead_pct =
495            (total_comm_time / 1000.0) / self.start_time.elapsed().as_secs_f64() * 100.0;
496
497        // Find most common type
498        let mut type_counts: HashMap<CommunicationType, usize> = HashMap::new();
499        for event in events.iter() {
500            *type_counts.entry(event.comm_type).or_insert(0) += 1;
501        }
502        let most_common_type =
503            type_counts.iter().max_by_key(|(_, count)| *count).map(|(typ, _)| *typ);
504
505        // Find slowest communication
506        let slowest_comm = events
507            .iter()
508            .max_by(|a, b| a.duration_ms.partial_cmp(&b.duration_ms).unwrap())
509            .cloned();
510
511        Ok(CommunicationSummary {
512            total_events,
513            total_data_bytes,
514            avg_bandwidth_mbps,
515            peak_bandwidth_mbps,
516            overhead_pct,
517            most_common_type,
518            slowest_comm,
519        })
520    }
521
522    /// Analyze synchronization operations
523    fn analyze_synchronization(&self) -> Result<SynchronizationSummary> {
524        let events = self.sync_events.read();
525
526        if events.is_empty() {
527            return Ok(SynchronizationSummary {
528                total_syncs: 0,
529                successful_syncs: 0,
530                failed_syncs: 0,
531                avg_sync_duration_ms: 0.0,
532                max_sync_duration_ms: 0.0,
533                total_sync_time_secs: 0.0,
534                sync_efficiency: 1.0,
535            });
536        }
537
538        let total_syncs = events.len();
539        let successful_syncs = events.iter().filter(|e| e.success).count();
540        let failed_syncs = total_syncs - successful_syncs;
541
542        let durations: Vec<f64> = events.iter().map(|e| e.duration_ms).collect();
543        let avg_sync_duration_ms = durations.iter().sum::<f64>() / durations.len() as f64;
544        let max_sync_duration_ms = durations.iter().fold(0.0f64, |a, &b| a.max(b));
545        let total_sync_time_secs = durations.iter().sum::<f64>() / 1000.0;
546
547        // Calculate efficiency (theoretical min time / actual time)
548        let theoretical_min = events.iter()
549            .map(|e| e.gradient_size_bytes as f64 / 1_000_000.0) // Convert to MB
550            .sum::<f64>()
551            / 10.0; // Assume 10 MB/s ideal bandwidth
552        let sync_efficiency = (theoretical_min / total_sync_time_secs).min(1.0);
553
554        Ok(SynchronizationSummary {
555            total_syncs,
556            successful_syncs,
557            failed_syncs,
558            avg_sync_duration_ms,
559            max_sync_duration_ms,
560            total_sync_time_secs,
561            sync_efficiency,
562        })
563    }
564
565    /// Analyze load balance across nodes
566    fn analyze_load_balance(&self) -> Result<LoadBalanceAnalysis> {
567        let snapshots = self.node_snapshots.read();
568
569        if snapshots.is_empty() {
570            return Ok(LoadBalanceAnalysis {
571                imbalance_score: 0.0,
572                compute_utilization: HashMap::new(),
573                memory_utilization: HashMap::new(),
574                throughput: HashMap::new(),
575                stragglers: Vec::new(),
576                idle_time: HashMap::new(),
577            });
578        }
579
580        let mut compute_utilization = HashMap::new();
581        let mut memory_utilization = HashMap::new();
582        let mut throughput = HashMap::new();
583        let mut idle_time = HashMap::new();
584
585        // Calculate averages per node
586        for (node_id, node_snapshots) in snapshots.iter() {
587            if node_snapshots.is_empty() {
588                continue;
589            }
590
591            let avg_compute = node_snapshots.iter().map(|s| s.compute_utilization_pct).sum::<f64>()
592                / node_snapshots.len() as f64;
593
594            let avg_memory = node_snapshots.iter().map(|s| s.memory_utilization_pct).sum::<f64>()
595                / node_snapshots.len() as f64;
596
597            let avg_throughput = node_snapshots.iter().map(|s| s.throughput).sum::<f64>()
598                / node_snapshots.len() as f64;
599
600            // Calculate idle time (when compute utilization < 10%)
601            let idle_samples =
602                node_snapshots.iter().filter(|s| s.compute_utilization_pct < 10.0).count();
603            let idle_secs =
604                idle_samples as f64 * (self.config.sampling_interval_ms as f64 / 1000.0);
605
606            compute_utilization.insert(node_id.clone(), avg_compute);
607            memory_utilization.insert(node_id.clone(), avg_memory);
608            throughput.insert(node_id.clone(), avg_throughput);
609            idle_time.insert(node_id.clone(), idle_secs);
610        }
611
612        // Calculate imbalance score (coefficient of variation of throughput)
613        let throughput_values: Vec<f64> = throughput.values().copied().collect();
614        let imbalance_score = if !throughput_values.is_empty() {
615            let mean = throughput_values.iter().sum::<f64>() / throughput_values.len() as f64;
616            let variance = throughput_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
617                / throughput_values.len() as f64;
618            let std_dev = variance.sqrt();
619            std_dev / mean
620        } else {
621            0.0
622        };
623
624        // Identify stragglers (nodes with significantly lower throughput)
625        let mean_throughput =
626            throughput_values.iter().sum::<f64>() / throughput_values.len().max(1) as f64;
627        let stragglers: Vec<String> = throughput.iter()
628            .filter(|(_, &t)| t < mean_throughput * 0.7) // 30% below average
629            .map(|(node_id, _)| node_id.clone())
630            .collect();
631
632        Ok(LoadBalanceAnalysis {
633            imbalance_score,
634            compute_utilization,
635            memory_utilization,
636            throughput,
637            stragglers,
638            idle_time,
639        })
640    }
641
642    /// Detect performance bottlenecks
643    fn detect_bottlenecks(
644        &self,
645        comm_summary: &CommunicationSummary,
646        sync_summary: &SynchronizationSummary,
647        load_balance: &LoadBalanceAnalysis,
648    ) -> Result<Vec<Bottleneck>> {
649        let mut bottlenecks = Vec::new();
650
651        // Check for communication bottleneck
652        if comm_summary.overhead_pct > self.config.bottleneck_threshold_pct {
653            bottlenecks.push(Bottleneck {
654                bottleneck_type: BottleneckType::Communication,
655                severity: comm_summary.overhead_pct,
656                affected_nodes: vec!["all".to_string()],
657                description: format!(
658                    "Communication overhead is {:.1}%, significantly impacting performance",
659                    comm_summary.overhead_pct
660                ),
661                suggestion: "Consider reducing communication frequency, increasing batch size, or using gradient compression".to_string(),
662            });
663        }
664
665        // Check for synchronization bottleneck
666        if sync_summary.sync_efficiency < 0.5 {
667            bottlenecks.push(Bottleneck {
668                bottleneck_type: BottleneckType::Synchronization,
669                severity: (1.0 - sync_summary.sync_efficiency) * 100.0,
670                affected_nodes: vec!["all".to_string()],
671                description: format!(
672                    "Synchronization efficiency is only {:.1}%, indicating significant overhead",
673                    sync_summary.sync_efficiency * 100.0
674                ),
675                suggestion: "Use gradient accumulation, optimize all-reduce operations, or consider hierarchical synchronization".to_string(),
676            });
677        }
678
679        // Check for load imbalance
680        if load_balance.imbalance_score > 0.3 {
681            bottlenecks.push(Bottleneck {
682                bottleneck_type: BottleneckType::ComputeImbalance,
683                severity: load_balance.imbalance_score * 100.0,
684                affected_nodes: load_balance.stragglers.clone(),
685                description: format!(
686                    "High load imbalance detected (score: {:.2}), {} straggler node(s)",
687                    load_balance.imbalance_score,
688                    load_balance.stragglers.len()
689                ),
690                suggestion: "Balance data distribution, check for hardware heterogeneity, or implement dynamic load balancing".to_string(),
691            });
692        }
693
694        // Check for straggler nodes
695        for straggler in &load_balance.stragglers {
696            if let Some(&idle_time) = load_balance.idle_time.get(straggler) {
697                if idle_time > 5.0 {
698                    // More than 5 seconds idle
699                    bottlenecks.push(Bottleneck {
700                        bottleneck_type: BottleneckType::Straggler,
701                        severity: 75.0,
702                        affected_nodes: vec![straggler.clone()],
703                        description: format!(
704                            "Node {} is a straggler with {:.1}s idle time",
705                            straggler, idle_time
706                        ),
707                        suggestion: format!(
708                            "Investigate node {} for hardware issues, resource contention, or network problems",
709                            straggler
710                        ),
711                    });
712                }
713            }
714        }
715
716        Ok(bottlenecks)
717    }
718
719    /// Generate optimization recommendations
720    fn generate_recommendations(
721        &self,
722        bottlenecks: &[Bottleneck],
723        load_balance: &LoadBalanceAnalysis,
724    ) -> Result<Vec<String>> {
725        let mut recommendations = Vec::new();
726
727        // General recommendations based on bottlenecks
728        for bottleneck in bottlenecks {
729            if bottleneck.severity > 50.0 {
730                recommendations.push(format!(
731                    "[HIGH PRIORITY] {}: {}",
732                    match bottleneck.bottleneck_type {
733                        BottleneckType::Communication => "Communication Bottleneck",
734                        BottleneckType::Synchronization => "Synchronization Bottleneck",
735                        BottleneckType::ComputeImbalance => "Load Imbalance",
736                        BottleneckType::Memory => "Memory Bottleneck",
737                        BottleneckType::NetworkCongestion => "Network Congestion",
738                        BottleneckType::Straggler => "Straggler Node",
739                    },
740                    bottleneck.suggestion
741                ));
742            }
743        }
744
745        // Load balance recommendations
746        if load_balance.imbalance_score > 0.2 {
747            recommendations.push(
748                "Consider implementing dynamic batch size adjustment per node based on compute capability".to_string()
749            );
750        }
751
752        // Check for underutilized nodes
753        let underutilized: Vec<_> = load_balance
754            .compute_utilization
755            .iter()
756            .filter(|(_, &util)| util < 50.0)
757            .collect();
758
759        if !underutilized.is_empty() {
760            recommendations.push(format!(
761                "{} node(s) are underutilized (<50% compute). Consider increasing batch size or model complexity",
762                underutilized.len()
763            ));
764        }
765
766        // If no specific recommendations, add general ones
767        if recommendations.is_empty() {
768            recommendations.push(
769                "Performance looks good! Continue monitoring for any degradation".to_string(),
770            );
771            recommendations.push(
772                "Consider enabling gradient compression to reduce communication overhead"
773                    .to_string(),
774            );
775            recommendations
776                .push("Experiment with mixed-precision training for better throughput".to_string());
777        }
778
779        Ok(recommendations)
780    }
781
782    /// Export profiling data to JSON
783    ///
784    /// # Arguments
785    /// * `path` - Output file path
786    pub fn export_json(&self, path: &std::path::Path) -> Result<()> {
787        let report = self.generate_report()?;
788        let json =
789            serde_json::to_string_pretty(&report).context("Failed to serialize report to JSON")?;
790        std::fs::write(path, json).context("Failed to write JSON file")?;
791        info!("Exported profiling report to {}", path.display());
792        Ok(())
793    }
794
795    /// Get real-time statistics (for dashboards)
796    ///
797    /// # Returns
798    /// Current profiling statistics
799    pub fn get_realtime_stats(&self) -> Result<RealtimeStats> {
800        let nodes = self.nodes.read();
801        let comm_events = self.comm_events.read();
802        let sync_events = self.sync_events.read();
803
804        // Calculate recent metrics (last 10 seconds)
805        let recent_cutoff = self.start_time.elapsed().saturating_sub(Duration::from_secs(10));
806
807        let recent_comm_count = comm_events.iter().filter(|e| e.timestamp >= recent_cutoff).count();
808
809        let recent_sync_count = sync_events.iter().filter(|e| e.timestamp >= recent_cutoff).count();
810
811        let active_nodes = nodes.values().filter(|n| n.status == NodeStatus::Active).count();
812
813        Ok(RealtimeStats {
814            active_nodes,
815            total_nodes: nodes.len(),
816            recent_communications: recent_comm_count,
817            recent_synchronizations: recent_sync_count,
818            elapsed_time_secs: self.start_time.elapsed().as_secs_f64(),
819        })
820    }
821}
822
/// Real-time statistics for dashboards
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealtimeStats {
    /// Number of nodes currently in `NodeStatus::Active`
    pub active_nodes: usize,
    /// Total number of registered nodes
    pub total_nodes: usize,
    /// Communication events recorded in the last 10 seconds
    pub recent_communications: usize,
    /// Synchronization events recorded in the last 10 seconds
    pub recent_synchronizations: usize,
    /// Elapsed time since profiling started (seconds)
    pub elapsed_time_secs: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a profiler with default config must not panic.
    #[test]
    fn test_profiler_creation() {
        let config = DistributedProfilerConfig::default();
        let _profiler = DistributedProfiler::new(config);
    }

    /// Registered nodes are stored and retrievable by node ID.
    #[test]
    fn test_node_registration() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        let node = NodeInfo {
            node_id: "node-0".to_string(),
            rank: 0,
            world_size: 4,
            host: "localhost".to_string(),
            gpu_count: 1,
            role: NodeRole::Master,
            status: NodeStatus::Active,
        };

        profiler.register_node(node)?;

        let nodes = profiler.nodes.read();
        assert_eq!(nodes.len(), 1);
        assert!(nodes.contains_key("node-0"));

        Ok(())
    }

    /// Recorded communication events land in the internal event store.
    #[test]
    fn test_communication_recording() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        let event = CommunicationEvent {
            event_id: 0,
            timestamp: Duration::from_millis(100),
            source_node: "node-0".to_string(),
            dest_node: "node-1".to_string(),
            comm_type: CommunicationType::AllReduce,
            data_size_bytes: 1024 * 1024,
            duration_ms: 10.5,
            bandwidth_mbps: 95.0,
        };

        profiler.record_communication(event)?;

        let events = profiler.comm_events.read();
        assert_eq!(events.len(), 1);

        Ok(())
    }

    /// End-to-end: register nodes, record events, and check the report totals.
    #[test]
    fn test_report_generation() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        // Register nodes (rank 0 is the master, the rest are workers)
        for i in 0..4 {
            let node = NodeInfo {
                node_id: format!("node-{}", i),
                rank: i,
                world_size: 4,
                host: "localhost".to_string(),
                gpu_count: 1,
                role: if i == 0 { NodeRole::Master } else { NodeRole::Worker },
                status: NodeStatus::Active,
            };
            profiler.register_node(node)?;
        }

        // Record some ring-style all-reduce events
        for i in 0..10 {
            let event = CommunicationEvent {
                event_id: i,
                timestamp: Duration::from_millis(i as u64 * 100),
                source_node: format!("node-{}", i % 4),
                dest_node: format!("node-{}", (i + 1) % 4),
                comm_type: CommunicationType::AllReduce,
                data_size_bytes: 1024 * 1024,
                duration_ms: 10.0 + (i as f64 * 0.5),
                bandwidth_mbps: 100.0 - (i as f64 * 2.0),
            };
            profiler.record_communication(event)?;
        }

        let report = profiler.generate_report()?;

        assert_eq!(report.num_nodes, 4);
        assert_eq!(report.communication_summary.total_events, 10);
        assert!(report.communication_summary.avg_bandwidth_mbps > 0.0);

        Ok(())
    }
}