// trustformers_debug/distributed_profiling.rs
//! Enhanced Distributed Training Profiling
//!
//! This module provides comprehensive profiling support for distributed training scenarios,
//! including multi-node coordination, gradient synchronization analysis, and communication
//! pattern optimization.

use anyhow::{Context, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info};
14
/// Distributed training profiler
///
/// Provides advanced profiling capabilities for distributed training including:
/// - Cross-node communication analysis
/// - Gradient synchronization profiling
/// - Load balancing metrics
/// - Communication bottleneck detection
///
/// All stores are wrapped in `Arc<RwLock<..>>` so the profiler can be shared
/// across threads; the recording methods cap each event store at
/// `config.max_events_per_category`, evicting the oldest entry first.
#[derive(Debug)]
pub struct DistributedProfiler {
    /// Configuration
    config: DistributedProfilerConfig,
    /// Node metadata, keyed by `NodeInfo::node_id`
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
    /// Recorded communication events (oldest evicted first when full)
    comm_events: Arc<RwLock<Vec<CommunicationEvent>>>,
    /// Recorded synchronization events (oldest evicted first when full)
    sync_events: Arc<RwLock<Vec<SynchronizationEvent>>>,
    /// Performance snapshot history per node, keyed by node ID
    node_snapshots: Arc<RwLock<HashMap<String, Vec<NodePerformanceSnapshot>>>>,
    /// Start time of the profiling session; report durations and the
    /// communication-overhead percentage are measured relative to this instant
    start_time: Instant,
}
37
/// Configuration for distributed profiling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedProfilerConfig {
    /// Enable communication profiling (gates `record_communication`)
    pub enable_comm_profiling: bool,
    /// Enable gradient sync profiling (gates `record_synchronization`)
    pub enable_grad_sync_profiling: bool,
    /// Enable load balance profiling
    pub enable_load_balance_profiling: bool,
    /// Enable network bandwidth analysis
    pub enable_bandwidth_analysis: bool,
    /// Sampling interval in milliseconds; also used to convert idle snapshot
    /// counts into idle seconds in the load-balance analysis
    pub sampling_interval_ms: u64,
    /// Maximum events to store per category; oldest entries are evicted first
    pub max_events_per_category: usize,
    /// Enable automatic bottleneck detection when generating reports
    pub enable_bottleneck_detection: bool,
    /// Bottleneck threshold (percentage, 0-100); compared against the
    /// measured communication overhead percentage
    pub bottleneck_threshold_pct: f64,
}
58
59impl Default for DistributedProfilerConfig {
60    fn default() -> Self {
61        Self {
62            enable_comm_profiling: true,
63            enable_grad_sync_profiling: true,
64            enable_load_balance_profiling: true,
65            enable_bandwidth_analysis: true,
66            sampling_interval_ms: 100,
67            max_events_per_category: 10000,
68            enable_bottleneck_detection: true,
69            bottleneck_threshold_pct: 80.0,
70        }
71    }
72}
73
/// Information about a node in the distributed cluster
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Node ID (unique identifier; used as the key in the profiler's node map,
    /// so re-registering the same ID replaces the previous entry)
    pub node_id: String,
    /// Rank in distributed training (0-based)
    pub rank: usize,
    /// World size (total number of nodes in the job)
    pub world_size: usize,
    /// Node hostname/IP
    pub host: String,
    /// Number of GPU devices on this node
    pub gpu_count: usize,
    /// Node role (master, worker, etc.)
    pub role: NodeRole,
    /// Current node status
    pub status: NodeStatus,
}
92
/// Role a node plays in the distributed training topology
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeRole {
    /// Master/coordinator node
    Master,
    /// Worker node
    Worker,
    /// Parameter server
    ParameterServer,
    /// Hybrid role (combines several of the above)
    Hybrid,
}
105
/// Node status
///
/// Only `Active` nodes are counted toward `RealtimeStats::active_nodes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is active and healthy
    Active,
    /// Node is idle
    Idle,
    /// Node has a warning
    Warning,
    /// Node has failed
    Failed,
    /// Node is disconnected
    Disconnected,
}
120
/// Communication event between nodes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationEvent {
    /// Event ID
    pub event_id: usize,
    /// Timestamp — presumably the offset from the start of the profiling
    /// session, since it is compared against session-elapsed cutoffs in the
    /// realtime statistics (TODO confirm with callers)
    pub timestamp: Duration,
    /// Source node
    pub source_node: String,
    /// Destination node
    pub dest_node: String,
    /// Communication type
    pub comm_type: CommunicationType,
    /// Data size (bytes)
    pub data_size_bytes: usize,
    /// Duration (milliseconds)
    pub duration_ms: f64,
    /// Bandwidth (MB/s)
    pub bandwidth_mbps: f64,
}
141
/// Type of communication between nodes
///
/// Mirrors the standard collective/point-to-point primitives of distributed
/// training frameworks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CommunicationType {
    /// Point-to-point send
    Send,
    /// Point-to-point receive
    Recv,
    /// All-reduce operation
    AllReduce,
    /// All-gather operation
    AllGather,
    /// Broadcast operation
    Broadcast,
    /// Scatter operation
    Scatter,
    /// Reduce operation
    Reduce,
    /// Barrier synchronization
    Barrier,
}
162
/// Gradient synchronization event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynchronizationEvent {
    /// Event ID
    pub event_id: usize,
    /// Timestamp (offset from session start — see `CommunicationEvent::timestamp`)
    pub timestamp: Duration,
    /// Participating nodes
    pub nodes: Vec<String>,
    /// Synchronization type
    pub sync_type: SyncType,
    /// Total gradient size (bytes); used to estimate the theoretical minimum
    /// sync time in the efficiency calculation
    pub gradient_size_bytes: usize,
    /// Synchronization duration (milliseconds)
    pub duration_ms: f64,
    /// Success status
    pub success: bool,
    /// Error message (populated only if the sync failed)
    pub error: Option<String>,
}
183
/// Type of gradient synchronization
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SyncType {
    /// Data-parallel all-reduce
    DataParallel,
    /// Model-parallel send/recv
    ModelParallel,
    /// Pipeline-parallel forward pass hand-off
    PipelineForward,
    /// Pipeline-parallel backward pass hand-off
    PipelineBackward,
    /// Hybrid parallel (mix of the above strategies)
    Hybrid,
}
198
/// Performance snapshot for a single node
///
/// Snapshots are aggregated per node by the load-balance analysis; samples
/// with compute utilization below 10% are counted as idle time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodePerformanceSnapshot {
    /// Timestamp (offset from session start)
    pub timestamp: Duration,
    /// Node ID this snapshot belongs to
    pub node_id: String,
    /// Compute utilization (0-100)
    pub compute_utilization_pct: f64,
    /// Memory utilization (0-100)
    pub memory_utilization_pct: f64,
    /// Network utilization (0-100)
    pub network_utilization_pct: f64,
    /// Throughput (samples/sec)
    pub throughput: f64,
    /// Number of in-flight communication operations
    pub active_communications: usize,
    /// Number of queued/pending operations
    pub pending_operations: usize,
}
219
/// Distributed profiling report
///
/// Produced by `DistributedProfiler::generate_report`; aggregates all
/// analyses for the session so far.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedProfilingReport {
    /// Total profiling duration (seconds since the profiler was created)
    pub total_duration_secs: f64,
    /// Number of nodes profiled
    pub num_nodes: usize,
    /// Communication summary
    pub communication_summary: CommunicationSummary,
    /// Synchronization summary
    pub synchronization_summary: SynchronizationSummary,
    /// Load balance analysis
    pub load_balance: LoadBalanceAnalysis,
    /// Detected bottlenecks (empty if detection is disabled)
    pub bottlenecks: Vec<Bottleneck>,
    /// Performance recommendations
    pub recommendations: Vec<String>,
}
238
/// Summary of communication patterns
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationSummary {
    /// Total communication events recorded (after any eviction)
    pub total_events: usize,
    /// Total data transferred (bytes)
    pub total_data_bytes: usize,
    /// Average bandwidth (MB/s)
    pub avg_bandwidth_mbps: f64,
    /// Peak bandwidth (MB/s)
    pub peak_bandwidth_mbps: f64,
    /// Communication overhead: total communication time as a percentage of
    /// session wall-clock time
    pub overhead_pct: f64,
    /// Most common communication type (`None` when no events were recorded)
    pub most_common_type: Option<CommunicationType>,
    /// Slowest single communication event
    pub slowest_comm: Option<CommunicationEvent>,
}
257
/// Summary of synchronization operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynchronizationSummary {
    /// Total synchronization events
    pub total_syncs: usize,
    /// Successful syncs
    pub successful_syncs: usize,
    /// Failed syncs (total minus successful)
    pub failed_syncs: usize,
    /// Average sync duration (milliseconds)
    pub avg_sync_duration_ms: f64,
    /// Maximum sync duration (milliseconds)
    pub max_sync_duration_ms: f64,
    /// Total time spent in synchronization (seconds)
    pub total_sync_time_secs: f64,
    /// Synchronization efficiency (0-1, where 1.0 is ideal); ratio of the
    /// estimated theoretical minimum transfer time to the observed time
    pub sync_efficiency: f64,
}
276
/// Load balance analysis across nodes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalanceAnalysis {
    /// Load imbalance score: coefficient of variation of per-node throughput
    /// (0 means perfectly balanced; lower is better)
    pub imbalance_score: f64,
    /// Average compute utilization per node (node ID -> percentage)
    pub compute_utilization: HashMap<String, f64>,
    /// Average memory utilization per node (node ID -> percentage)
    pub memory_utilization: HashMap<String, f64>,
    /// Average throughput per node (node ID -> samples/sec)
    pub throughput: HashMap<String, f64>,
    /// Straggler nodes (throughput more than 30% below the mean)
    pub stragglers: Vec<String>,
    /// Estimated idle time per node (seconds)
    pub idle_time: HashMap<String, f64>,
}
293
/// Detected performance bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Bottleneck {
    /// Bottleneck type
    pub bottleneck_type: BottleneckType,
    /// Severity (0-100; values above 50 are flagged as high priority in the
    /// generated recommendations)
    pub severity: f64,
    /// Affected nodes (the sentinel `"all"` is used for cluster-wide issues)
    pub affected_nodes: Vec<String>,
    /// Human-readable description
    pub description: String,
    /// Suggested fix
    pub suggestion: String,
}
308
/// Type of performance bottleneck
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BottleneckType {
    /// Communication bottleneck (high communication overhead)
    Communication,
    /// Synchronization bottleneck (low sync efficiency)
    Synchronization,
    /// Compute imbalance across nodes
    ComputeImbalance,
    /// Memory bottleneck
    Memory,
    /// Network congestion
    NetworkCongestion,
    /// Straggler node (slow node holding back the cluster)
    Straggler,
}
325
326impl DistributedProfiler {
327    /// Create a new distributed profiler
328    ///
329    /// # Arguments
330    /// * `config` - Profiler configuration
331    ///
332    /// # Example
333    /// ```rust
334    /// use trustformers_debug::{DistributedProfiler, DistributedProfilerConfig};
335    ///
336    /// let config = DistributedProfilerConfig::default();
337    /// let profiler = DistributedProfiler::new(config);
338    /// ```
339    pub fn new(config: DistributedProfilerConfig) -> Self {
340        info!("Initializing distributed profiler");
341        Self {
342            config,
343            nodes: Arc::new(RwLock::new(HashMap::new())),
344            comm_events: Arc::new(RwLock::new(Vec::new())),
345            sync_events: Arc::new(RwLock::new(Vec::new())),
346            node_snapshots: Arc::new(RwLock::new(HashMap::new())),
347            start_time: Instant::now(),
348        }
349    }
350
351    /// Register a node in the cluster
352    ///
353    /// # Arguments
354    /// * `node_info` - Information about the node
355    pub fn register_node(&self, node_info: NodeInfo) -> Result<()> {
356        debug!(
357            "Registering node: {} (rank {})",
358            node_info.node_id, node_info.rank
359        );
360
361        let mut nodes = self.nodes.write();
362        nodes.insert(node_info.node_id.clone(), node_info);
363
364        Ok(())
365    }
366
367    /// Record a communication event
368    ///
369    /// # Arguments
370    /// * `event` - Communication event to record
371    pub fn record_communication(&self, event: CommunicationEvent) -> Result<()> {
372        if !self.config.enable_comm_profiling {
373            return Ok(());
374        }
375
376        let mut events = self.comm_events.write();
377
378        // Limit stored events
379        if events.len() >= self.config.max_events_per_category {
380            events.remove(0);
381        }
382
383        events.push(event);
384        Ok(())
385    }
386
387    /// Record a synchronization event
388    ///
389    /// # Arguments
390    /// * `event` - Synchronization event to record
391    pub fn record_synchronization(&self, event: SynchronizationEvent) -> Result<()> {
392        if !self.config.enable_grad_sync_profiling {
393            return Ok(());
394        }
395
396        let mut events = self.sync_events.write();
397
398        // Limit stored events
399        if events.len() >= self.config.max_events_per_category {
400            events.remove(0);
401        }
402
403        events.push(event);
404        Ok(())
405    }
406
407    /// Record a performance snapshot for a node
408    ///
409    /// # Arguments
410    /// * `snapshot` - Performance snapshot
411    pub fn record_snapshot(&self, snapshot: NodePerformanceSnapshot) -> Result<()> {
412        let mut snapshots = self.node_snapshots.write();
413
414        let node_history = snapshots.entry(snapshot.node_id.clone()).or_default();
415
416        // Limit stored snapshots
417        if node_history.len() >= self.config.max_events_per_category {
418            node_history.remove(0);
419        }
420
421        node_history.push(snapshot);
422        Ok(())
423    }
424
425    /// Generate a comprehensive profiling report
426    ///
427    /// # Returns
428    /// Detailed profiling report with analysis and recommendations
429    pub fn generate_report(&self) -> Result<DistributedProfilingReport> {
430        info!("Generating distributed profiling report");
431
432        let total_duration = self.start_time.elapsed().as_secs_f64();
433        let nodes = self.nodes.read();
434        let num_nodes = nodes.len();
435
436        // Analyze communication patterns
437        let communication_summary = self.analyze_communication()?;
438
439        // Analyze synchronization
440        let synchronization_summary = self.analyze_synchronization()?;
441
442        // Analyze load balance
443        let load_balance = self.analyze_load_balance()?;
444
445        // Detect bottlenecks
446        let bottlenecks = if self.config.enable_bottleneck_detection {
447            self.detect_bottlenecks(
448                &communication_summary,
449                &synchronization_summary,
450                &load_balance,
451            )?
452        } else {
453            Vec::new()
454        };
455
456        // Generate recommendations
457        let recommendations = self.generate_recommendations(&bottlenecks, &load_balance)?;
458
459        Ok(DistributedProfilingReport {
460            total_duration_secs: total_duration,
461            num_nodes,
462            communication_summary,
463            synchronization_summary,
464            load_balance,
465            bottlenecks,
466            recommendations,
467        })
468    }
469
470    /// Analyze communication patterns
471    fn analyze_communication(&self) -> Result<CommunicationSummary> {
472        let events = self.comm_events.read();
473
474        if events.is_empty() {
475            return Ok(CommunicationSummary {
476                total_events: 0,
477                total_data_bytes: 0,
478                avg_bandwidth_mbps: 0.0,
479                peak_bandwidth_mbps: 0.0,
480                overhead_pct: 0.0,
481                most_common_type: None,
482                slowest_comm: None,
483            });
484        }
485
486        let total_events = events.len();
487        let total_data_bytes: usize = events.iter().map(|e| e.data_size_bytes).sum();
488
489        let bandwidths: Vec<f64> = events.iter().map(|e| e.bandwidth_mbps).collect();
490        let avg_bandwidth_mbps = bandwidths.iter().sum::<f64>() / bandwidths.len() as f64;
491        let peak_bandwidth_mbps = bandwidths.iter().fold(0.0f64, |a, &b| a.max(b));
492
493        let total_comm_time: f64 = events.iter().map(|e| e.duration_ms).sum();
494        let overhead_pct =
495            (total_comm_time / 1000.0) / self.start_time.elapsed().as_secs_f64() * 100.0;
496
497        // Find most common type
498        let mut type_counts: HashMap<CommunicationType, usize> = HashMap::new();
499        for event in events.iter() {
500            *type_counts.entry(event.comm_type).or_insert(0) += 1;
501        }
502        let most_common_type =
503            type_counts.iter().max_by_key(|(_, count)| *count).map(|(typ, _)| *typ);
504
505        // Find slowest communication
506        let slowest_comm = events
507            .iter()
508            .max_by(|a, b| {
509                a.duration_ms.partial_cmp(&b.duration_ms).unwrap_or(std::cmp::Ordering::Equal)
510            })
511            .cloned();
512
513        Ok(CommunicationSummary {
514            total_events,
515            total_data_bytes,
516            avg_bandwidth_mbps,
517            peak_bandwidth_mbps,
518            overhead_pct,
519            most_common_type,
520            slowest_comm,
521        })
522    }
523
524    /// Analyze synchronization operations
525    fn analyze_synchronization(&self) -> Result<SynchronizationSummary> {
526        let events = self.sync_events.read();
527
528        if events.is_empty() {
529            return Ok(SynchronizationSummary {
530                total_syncs: 0,
531                successful_syncs: 0,
532                failed_syncs: 0,
533                avg_sync_duration_ms: 0.0,
534                max_sync_duration_ms: 0.0,
535                total_sync_time_secs: 0.0,
536                sync_efficiency: 1.0,
537            });
538        }
539
540        let total_syncs = events.len();
541        let successful_syncs = events.iter().filter(|e| e.success).count();
542        let failed_syncs = total_syncs - successful_syncs;
543
544        let durations: Vec<f64> = events.iter().map(|e| e.duration_ms).collect();
545        let avg_sync_duration_ms = durations.iter().sum::<f64>() / durations.len() as f64;
546        let max_sync_duration_ms = durations.iter().fold(0.0f64, |a, &b| a.max(b));
547        let total_sync_time_secs = durations.iter().sum::<f64>() / 1000.0;
548
549        // Calculate efficiency (theoretical min time / actual time)
550        let theoretical_min = events.iter()
551            .map(|e| e.gradient_size_bytes as f64 / 1_000_000.0) // Convert to MB
552            .sum::<f64>()
553            / 10.0; // Assume 10 MB/s ideal bandwidth
554        let sync_efficiency = (theoretical_min / total_sync_time_secs).min(1.0);
555
556        Ok(SynchronizationSummary {
557            total_syncs,
558            successful_syncs,
559            failed_syncs,
560            avg_sync_duration_ms,
561            max_sync_duration_ms,
562            total_sync_time_secs,
563            sync_efficiency,
564        })
565    }
566
567    /// Analyze load balance across nodes
568    fn analyze_load_balance(&self) -> Result<LoadBalanceAnalysis> {
569        let snapshots = self.node_snapshots.read();
570
571        if snapshots.is_empty() {
572            return Ok(LoadBalanceAnalysis {
573                imbalance_score: 0.0,
574                compute_utilization: HashMap::new(),
575                memory_utilization: HashMap::new(),
576                throughput: HashMap::new(),
577                stragglers: Vec::new(),
578                idle_time: HashMap::new(),
579            });
580        }
581
582        let mut compute_utilization = HashMap::new();
583        let mut memory_utilization = HashMap::new();
584        let mut throughput = HashMap::new();
585        let mut idle_time = HashMap::new();
586
587        // Calculate averages per node
588        for (node_id, node_snapshots) in snapshots.iter() {
589            if node_snapshots.is_empty() {
590                continue;
591            }
592
593            let avg_compute = node_snapshots.iter().map(|s| s.compute_utilization_pct).sum::<f64>()
594                / node_snapshots.len() as f64;
595
596            let avg_memory = node_snapshots.iter().map(|s| s.memory_utilization_pct).sum::<f64>()
597                / node_snapshots.len() as f64;
598
599            let avg_throughput = node_snapshots.iter().map(|s| s.throughput).sum::<f64>()
600                / node_snapshots.len() as f64;
601
602            // Calculate idle time (when compute utilization < 10%)
603            let idle_samples =
604                node_snapshots.iter().filter(|s| s.compute_utilization_pct < 10.0).count();
605            let idle_secs =
606                idle_samples as f64 * (self.config.sampling_interval_ms as f64 / 1000.0);
607
608            compute_utilization.insert(node_id.clone(), avg_compute);
609            memory_utilization.insert(node_id.clone(), avg_memory);
610            throughput.insert(node_id.clone(), avg_throughput);
611            idle_time.insert(node_id.clone(), idle_secs);
612        }
613
614        // Calculate imbalance score (coefficient of variation of throughput)
615        let throughput_values: Vec<f64> = throughput.values().copied().collect();
616        let imbalance_score = if !throughput_values.is_empty() {
617            let mean = throughput_values.iter().sum::<f64>() / throughput_values.len() as f64;
618            let variance = throughput_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
619                / throughput_values.len() as f64;
620            let std_dev = variance.sqrt();
621            std_dev / mean
622        } else {
623            0.0
624        };
625
626        // Identify stragglers (nodes with significantly lower throughput)
627        let mean_throughput =
628            throughput_values.iter().sum::<f64>() / throughput_values.len().max(1) as f64;
629        let stragglers: Vec<String> = throughput.iter()
630            .filter(|(_, &t)| t < mean_throughput * 0.7) // 30% below average
631            .map(|(node_id, _)| node_id.clone())
632            .collect();
633
634        Ok(LoadBalanceAnalysis {
635            imbalance_score,
636            compute_utilization,
637            memory_utilization,
638            throughput,
639            stragglers,
640            idle_time,
641        })
642    }
643
644    /// Detect performance bottlenecks
645    fn detect_bottlenecks(
646        &self,
647        comm_summary: &CommunicationSummary,
648        sync_summary: &SynchronizationSummary,
649        load_balance: &LoadBalanceAnalysis,
650    ) -> Result<Vec<Bottleneck>> {
651        let mut bottlenecks = Vec::new();
652
653        // Check for communication bottleneck
654        if comm_summary.overhead_pct > self.config.bottleneck_threshold_pct {
655            bottlenecks.push(Bottleneck {
656                bottleneck_type: BottleneckType::Communication,
657                severity: comm_summary.overhead_pct,
658                affected_nodes: vec!["all".to_string()],
659                description: format!(
660                    "Communication overhead is {:.1}%, significantly impacting performance",
661                    comm_summary.overhead_pct
662                ),
663                suggestion: "Consider reducing communication frequency, increasing batch size, or using gradient compression".to_string(),
664            });
665        }
666
667        // Check for synchronization bottleneck
668        if sync_summary.sync_efficiency < 0.5 {
669            bottlenecks.push(Bottleneck {
670                bottleneck_type: BottleneckType::Synchronization,
671                severity: (1.0 - sync_summary.sync_efficiency) * 100.0,
672                affected_nodes: vec!["all".to_string()],
673                description: format!(
674                    "Synchronization efficiency is only {:.1}%, indicating significant overhead",
675                    sync_summary.sync_efficiency * 100.0
676                ),
677                suggestion: "Use gradient accumulation, optimize all-reduce operations, or consider hierarchical synchronization".to_string(),
678            });
679        }
680
681        // Check for load imbalance
682        if load_balance.imbalance_score > 0.3 {
683            bottlenecks.push(Bottleneck {
684                bottleneck_type: BottleneckType::ComputeImbalance,
685                severity: load_balance.imbalance_score * 100.0,
686                affected_nodes: load_balance.stragglers.clone(),
687                description: format!(
688                    "High load imbalance detected (score: {:.2}), {} straggler node(s)",
689                    load_balance.imbalance_score,
690                    load_balance.stragglers.len()
691                ),
692                suggestion: "Balance data distribution, check for hardware heterogeneity, or implement dynamic load balancing".to_string(),
693            });
694        }
695
696        // Check for straggler nodes
697        for straggler in &load_balance.stragglers {
698            if let Some(&idle_time) = load_balance.idle_time.get(straggler) {
699                if idle_time > 5.0 {
700                    // More than 5 seconds idle
701                    bottlenecks.push(Bottleneck {
702                        bottleneck_type: BottleneckType::Straggler,
703                        severity: 75.0,
704                        affected_nodes: vec![straggler.clone()],
705                        description: format!(
706                            "Node {} is a straggler with {:.1}s idle time",
707                            straggler, idle_time
708                        ),
709                        suggestion: format!(
710                            "Investigate node {} for hardware issues, resource contention, or network problems",
711                            straggler
712                        ),
713                    });
714                }
715            }
716        }
717
718        Ok(bottlenecks)
719    }
720
721    /// Generate optimization recommendations
722    fn generate_recommendations(
723        &self,
724        bottlenecks: &[Bottleneck],
725        load_balance: &LoadBalanceAnalysis,
726    ) -> Result<Vec<String>> {
727        let mut recommendations = Vec::new();
728
729        // General recommendations based on bottlenecks
730        for bottleneck in bottlenecks {
731            if bottleneck.severity > 50.0 {
732                recommendations.push(format!(
733                    "[HIGH PRIORITY] {}: {}",
734                    match bottleneck.bottleneck_type {
735                        BottleneckType::Communication => "Communication Bottleneck",
736                        BottleneckType::Synchronization => "Synchronization Bottleneck",
737                        BottleneckType::ComputeImbalance => "Load Imbalance",
738                        BottleneckType::Memory => "Memory Bottleneck",
739                        BottleneckType::NetworkCongestion => "Network Congestion",
740                        BottleneckType::Straggler => "Straggler Node",
741                    },
742                    bottleneck.suggestion
743                ));
744            }
745        }
746
747        // Load balance recommendations
748        if load_balance.imbalance_score > 0.2 {
749            recommendations.push(
750                "Consider implementing dynamic batch size adjustment per node based on compute capability".to_string()
751            );
752        }
753
754        // Check for underutilized nodes
755        let underutilized: Vec<_> = load_balance
756            .compute_utilization
757            .iter()
758            .filter(|(_, &util)| util < 50.0)
759            .collect();
760
761        if !underutilized.is_empty() {
762            recommendations.push(format!(
763                "{} node(s) are underutilized (<50% compute). Consider increasing batch size or model complexity",
764                underutilized.len()
765            ));
766        }
767
768        // If no specific recommendations, add general ones
769        if recommendations.is_empty() {
770            recommendations.push(
771                "Performance looks good! Continue monitoring for any degradation".to_string(),
772            );
773            recommendations.push(
774                "Consider enabling gradient compression to reduce communication overhead"
775                    .to_string(),
776            );
777            recommendations
778                .push("Experiment with mixed-precision training for better throughput".to_string());
779        }
780
781        Ok(recommendations)
782    }
783
784    /// Export profiling data to JSON
785    ///
786    /// # Arguments
787    /// * `path` - Output file path
788    pub fn export_json(&self, path: &std::path::Path) -> Result<()> {
789        let report = self.generate_report()?;
790        let json =
791            serde_json::to_string_pretty(&report).context("Failed to serialize report to JSON")?;
792        std::fs::write(path, json).context("Failed to write JSON file")?;
793        info!("Exported profiling report to {}", path.display());
794        Ok(())
795    }
796
797    /// Get real-time statistics (for dashboards)
798    ///
799    /// # Returns
800    /// Current profiling statistics
801    pub fn get_realtime_stats(&self) -> Result<RealtimeStats> {
802        let nodes = self.nodes.read();
803        let comm_events = self.comm_events.read();
804        let sync_events = self.sync_events.read();
805
806        // Calculate recent metrics (last 10 seconds)
807        let recent_cutoff = self.start_time.elapsed().saturating_sub(Duration::from_secs(10));
808
809        let recent_comm_count = comm_events.iter().filter(|e| e.timestamp >= recent_cutoff).count();
810
811        let recent_sync_count = sync_events.iter().filter(|e| e.timestamp >= recent_cutoff).count();
812
813        let active_nodes = nodes.values().filter(|n| n.status == NodeStatus::Active).count();
814
815        Ok(RealtimeStats {
816            active_nodes,
817            total_nodes: nodes.len(),
818            recent_communications: recent_comm_count,
819            recent_synchronizations: recent_sync_count,
820            elapsed_time_secs: self.start_time.elapsed().as_secs_f64(),
821        })
822    }
823}
824
/// Real-time statistics for dashboards
///
/// Snapshot of the profiler's current state as returned by
/// `DistributedProfiler::get_realtime_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealtimeStats {
    /// Number of nodes currently in `NodeStatus::Active`
    pub active_nodes: usize,
    /// Total number of registered nodes
    pub total_nodes: usize,
    /// Communication events recorded in the last 10 seconds
    pub recent_communications: usize,
    /// Synchronization events recorded in the last 10 seconds
    pub recent_synchronizations: usize,
    /// Elapsed time since profiling started (seconds)
    pub elapsed_time_secs: f64,
}
839
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: constructing a profiler with default config must not panic.
    #[test]
    fn test_profiler_creation() {
        let config = DistributedProfilerConfig::default();
        let _profiler = DistributedProfiler::new(config);
    }

    // Registering a node makes it visible in the internal node map.
    #[test]
    fn test_node_registration() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        let node = NodeInfo {
            node_id: "node-0".to_string(),
            rank: 0,
            world_size: 4,
            host: "localhost".to_string(),
            gpu_count: 1,
            role: NodeRole::Master,
            status: NodeStatus::Active,
        };

        profiler.register_node(node)?;

        // Registration should be observable through the internal node map.
        let nodes = profiler.nodes.read();
        assert_eq!(nodes.len(), 1);
        assert!(nodes.contains_key("node-0"));

        Ok(())
    }

    // A recorded communication event is stored (comm profiling enabled by default).
    #[test]
    fn test_communication_recording() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        let event = CommunicationEvent {
            event_id: 0,
            timestamp: Duration::from_millis(100),
            source_node: "node-0".to_string(),
            dest_node: "node-1".to_string(),
            comm_type: CommunicationType::AllReduce,
            data_size_bytes: 1024 * 1024,
            duration_ms: 10.5,
            bandwidth_mbps: 95.0,
        };

        profiler.record_communication(event)?;

        let events = profiler.comm_events.read();
        assert_eq!(events.len(), 1);

        Ok(())
    }

    // End-to-end: register nodes, record events, and check report aggregates.
    #[test]
    fn test_report_generation() -> Result<()> {
        let config = DistributedProfilerConfig::default();
        let profiler = DistributedProfiler::new(config);

        // Register nodes
        for i in 0..4 {
            let node = NodeInfo {
                node_id: format!("node-{}", i),
                rank: i,
                world_size: 4,
                host: "localhost".to_string(),
                gpu_count: 1,
                role: if i == 0 { NodeRole::Master } else { NodeRole::Worker },
                status: NodeStatus::Active,
            };
            profiler.register_node(node)?;
        }

        // Record some events
        for i in 0..10 {
            let event = CommunicationEvent {
                event_id: i,
                timestamp: Duration::from_millis(i as u64 * 100),
                source_node: format!("node-{}", i % 4),
                dest_node: format!("node-{}", (i + 1) % 4),
                comm_type: CommunicationType::AllReduce,
                data_size_bytes: 1024 * 1024,
                duration_ms: 10.0 + (i as f64 * 0.5),
                bandwidth_mbps: 100.0 - (i as f64 * 2.0),
            };
            profiler.record_communication(event)?;
        }

        let report = profiler.generate_report()?;

        assert_eq!(report.num_nodes, 4);
        assert_eq!(report.communication_summary.total_events, 10);
        assert!(report.communication_summary.avg_bandwidth_mbps > 0.0);

        Ok(())
    }
}