// trustformers_optim/multinode/mod.rs
//! Multi-Node Distributed Training Support
//!
//! This module provides infrastructure for distributed training across
//! multiple nodes using the MPI communication backend. It integrates with
//! the existing ZeRO optimization for memory-efficient multi-node training.

use std::collections::HashMap;
use std::sync::Arc;

use trustformers_core::errors::Result;
use trustformers_core::parallel::{
    mpi_utils, CommunicationBackend, ModelParallelConfig, ModelParallelContext,
    MpiCommunicatorImpl,
};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::zero::{ZeROConfig, ZeROOptimizer, ZeROStage};
19
20/// Multi-node training configuration
21#[derive(Debug, Clone)]
22pub struct MultiNodeConfig {
23    /// Number of nodes in the cluster
24    pub num_nodes: usize,
25    /// Number of devices per node
26    pub devices_per_node: usize,
27    /// Node rank (0-based)
28    pub node_rank: usize,
29    /// Local device rank within node
30    pub local_rank: usize,
31    /// Global rank across all nodes
32    pub global_rank: usize,
33    /// ZeRO configuration for memory optimization
34    pub zero_config: ZeROConfig,
35    /// Enable gradient compression
36    pub gradient_compression: bool,
37    /// Communication backend
38    pub comm_backend: CommunicationBackend,
39    /// Enable overlap of computation and communication
40    pub overlap_comm_compute: bool,
41    /// Bucket size for gradient synchronization (MB)
42    pub gradient_bucket_size_mb: usize,
43}
44
45impl Default for MultiNodeConfig {
46    fn default() -> Self {
47        Self {
48            num_nodes: 1,
49            devices_per_node: 1,
50            node_rank: 0,
51            local_rank: 0,
52            global_rank: 0,
53            zero_config: ZeROConfig::default(),
54            gradient_compression: false,
55            comm_backend: CommunicationBackend::Mpi,
56            overlap_comm_compute: true,
57            gradient_bucket_size_mb: 25,
58        }
59    }
60}
61
62impl MultiNodeConfig {
63    /// Create configuration for multi-node training
64    pub fn new(
65        num_nodes: usize,
66        devices_per_node: usize,
67        node_rank: usize,
68        local_rank: usize,
69    ) -> Self {
70        let global_rank = node_rank * devices_per_node + local_rank;
71
72        Self {
73            num_nodes,
74            devices_per_node,
75            node_rank,
76            local_rank,
77            global_rank,
78            ..Default::default()
79        }
80    }
81
82    /// Get total world size across all nodes
83    pub fn world_size(&self) -> usize {
84        self.num_nodes * self.devices_per_node
85    }
86
87    /// Check if this is the master rank
88    pub fn is_master(&self) -> bool {
89        self.global_rank == 0
90    }
91
92    /// Get node-local ranks for this node
93    pub fn node_local_ranks(&self) -> Vec<usize> {
94        let start = self.node_rank * self.devices_per_node;
95        (start..start + self.devices_per_node).collect()
96    }
97}
98
/// Multi-node distributed training coordinator
///
/// Wraps a [`ZeROOptimizer`] and layers node-aware gradient
/// synchronization (hierarchical all-reduce when MPI is available)
/// on top of it.
pub struct MultiNodeTrainer<T: Optimizer> {
    // Cluster topology and feature flags for this process.
    config: MultiNodeConfig,
    // Shared model-parallel context used for all-reduce/broadcast.
    mp_context: Arc<ModelParallelContext>,
    // Memory-efficient optimizer that performs the actual updates.
    zero_optimizer: ZeROOptimizer<T>,
    // Present only when the configured backend is MPI.
    mpi_communicator: Option<Arc<MpiCommunicatorImpl>>,
    // One accumulation buffer per registered parameter name.
    gradient_buffers: HashMap<String, GradientSyncBuffer>,
    // Copied from config.overlap_comm_compute; not read yet.
    #[allow(dead_code)]
    communication_overlap: bool,
    // Global ranks of processes on this node (intra-node reduce group).
    node_local_group: Option<Vec<usize>>,
    // One representative rank per node (inter-node all-reduce group).
    cross_node_group: Option<Vec<usize>>,
}
111
/// Buffer for gradient synchronization across nodes
#[derive(Debug, Clone)]
struct GradientSyncBuffer {
    /// Buffered gradients, keyed by parameter name. In practice each
    /// buffer holds a single parameter's gradient — the trainer keeps
    /// one buffer per parameter (see `MultiNodeTrainer::update_gradients`).
    gradients: HashMap<String, Tensor>,
    /// Number of accumulated steps (total `add_gradient` calls since
    /// the last `clear`).
    accumulation_steps: usize,
    /// Compression metadata; populated only when compression is used.
    compression_info: Option<CompressionInfo>,
}
122
/// Metadata describing a gradient-compression pass. The fields are not
/// read anywhere in this module yet (hence the `dead_code` allows); they
/// are retained for future reporting.
#[derive(Debug, Clone)]
struct CompressionInfo {
    /// Compression ratio achieved
    #[allow(dead_code)]
    compression_ratio: f32,
    /// Original size in bytes
    #[allow(dead_code)]
    original_size: usize,
    /// Compressed size in bytes
    #[allow(dead_code)]
    compressed_size: usize,
}
135
136impl GradientSyncBuffer {
137    fn new() -> Self {
138        Self {
139            gradients: HashMap::new(),
140            accumulation_steps: 0,
141            compression_info: None,
142        }
143    }
144
145    fn add_gradient(&mut self, name: String, gradient: Tensor) -> Result<()> {
146        if let Some(existing) = self.gradients.get_mut(&name) {
147            *existing = existing.add(&gradient)?;
148        } else {
149            self.gradients.insert(name, gradient);
150        }
151        self.accumulation_steps += 1;
152        Ok(())
153    }
154
155    fn clear(&mut self) {
156        self.gradients.clear();
157        self.accumulation_steps = 0;
158        self.compression_info = None;
159    }
160
161    fn average_gradients(&mut self) -> Result<()> {
162        if self.accumulation_steps <= 1 {
163            return Ok(());
164        }
165
166        let divisor = self.accumulation_steps as f32;
167        for gradient in self.gradients.values_mut() {
168            *gradient = gradient.scalar_div(divisor)?;
169        }
170        Ok(())
171    }
172}
173
174impl<T: Optimizer> MultiNodeTrainer<T> {
175    /// Create a new multi-node trainer
176    pub fn new(config: MultiNodeConfig, base_optimizer: T) -> Result<Self> {
177        // Create model parallel configuration
178        let mp_config = ModelParallelConfig {
179            num_devices: config.world_size(),
180            device_ids: (0..config.world_size()).collect(),
181            comm_backend: config.comm_backend,
182            ..Default::default()
183        };
184
185        // Initialize model parallel context
186        let mp_context = Arc::new(ModelParallelContext::new(mp_config)?);
187
188        // Initialize ZeRO optimizer
189        let zero_optimizer = ZeROOptimizer::new(
190            base_optimizer,
191            config.zero_config.clone(),
192            mp_context.clone(),
193        )?;
194
195        // Initialize MPI communicator
196        let mpi_communicator = if config.comm_backend == CommunicationBackend::Mpi {
197            Some(Arc::new(MpiCommunicatorImpl::new()?))
198        } else {
199            None
200        };
201
202        // Create node groups for hierarchical communication
203        let node_local_group = Some(config.node_local_ranks());
204        let cross_node_group =
205            Some((0..config.num_nodes).map(|i| i * config.devices_per_node).collect());
206
207        let communication_overlap = config.overlap_comm_compute;
208
209        Ok(Self {
210            config,
211            mp_context,
212            zero_optimizer,
213            mpi_communicator,
214            gradient_buffers: HashMap::new(),
215            communication_overlap,
216            node_local_group,
217            cross_node_group,
218        })
219    }
220
221    /// Initialize MPI environment for multi-node training
222    pub fn initialize_environment() -> Result<()> {
223        mpi_utils::init_mpi_environment()?;
224        mpi_utils::check_mpi_environment()?;
225
226        // Get node-local information
227        let (local_rank, local_size) = mpi_utils::get_node_local_info()?;
228        println!("Multi-node environment initialized:");
229        println!("  Local rank: {}", local_rank);
230        println!("  Local size: {}", local_size);
231
232        Ok(())
233    }
234
235    /// Register parameters for multi-node training
236    pub fn register_parameters(&mut self, parameters: HashMap<String, Tensor>) -> Result<()> {
237        // Register with ZeRO optimizer
238        self.zero_optimizer.register_parameters(parameters.clone())?;
239
240        // Initialize gradient buffers for each parameter
241        for name in parameters.keys() {
242            self.gradient_buffers.insert(name.clone(), GradientSyncBuffer::new());
243        }
244
245        println!("Multi-node training initialized:");
246        println!("  Node rank: {}", self.config.node_rank);
247        println!("  Global rank: {}", self.config.global_rank);
248        println!("  World size: {}", self.config.world_size());
249        println!("  ZeRO stage: {:?}", self.zero_optimizer.get_stage());
250        println!("  Parameters: {}", parameters.len());
251
252        Ok(())
253    }
254
255    /// Update gradients with multi-node synchronization
256    pub fn update_gradients(&mut self, gradients: HashMap<String, Tensor>) -> Result<()> {
257        // Accumulate gradients locally
258        for (name, gradient) in gradients {
259            if let Some(buffer) = self.gradient_buffers.get_mut(&name) {
260                buffer.add_gradient(name.clone(), gradient)?;
261            }
262        }
263
264        // Update ZeRO optimizer (local processing)
265        self.zero_optimizer.update_gradients(self.collect_local_gradients()?)?;
266
267        Ok(())
268    }
269
270    /// Collect local gradients from buffers
271    fn collect_local_gradients(&self) -> Result<HashMap<String, Tensor>> {
272        let mut gradients = HashMap::new();
273        for (name, buffer) in &self.gradient_buffers {
274            if let Some(grad) = buffer.gradients.get(name) {
275                gradients.insert(name.clone(), grad.clone());
276            }
277        }
278        Ok(gradients)
279    }
280
281    /// Synchronize gradients across all nodes
282    pub fn synchronize_gradients(&mut self) -> Result<()> {
283        if self.config.world_size() == 1 {
284            return Ok(()); // No synchronization needed for single node
285        }
286
287        // Average accumulated gradients
288        for buffer in self.gradient_buffers.values_mut() {
289            buffer.average_gradients()?;
290        }
291
292        // Perform hierarchical gradient synchronization
293        self.hierarchical_all_reduce()?;
294
295        // Clear buffers after synchronization
296        for buffer in self.gradient_buffers.values_mut() {
297            buffer.clear();
298        }
299
300        Ok(())
301    }
302
303    /// Hierarchical all-reduce for better network utilization
304    fn hierarchical_all_reduce(&mut self) -> Result<()> {
305        let has_mpi = self.mpi_communicator.is_some();
306
307        if has_mpi {
308            // Step 1: Reduce within each node
309            self.node_local_reduce()?;
310
311            // Step 2: All-reduce across nodes (one rank per node)
312            if self.config.local_rank == 0 {
313                self.cross_node_all_reduce()?;
314            }
315
316            // Step 3: Broadcast within each node
317            self.node_local_broadcast()?;
318
319            // Synchronize all processes
320            if let Some(ref mpi_comm) = self.mpi_communicator {
321                mpi_comm.barrier()?;
322            }
323        } else {
324            // Fallback to regular all-reduce using model parallel context
325            for buffer in self.gradient_buffers.values_mut() {
326                for gradient in buffer.gradients.values_mut() {
327                    self.mp_context.all_reduce(gradient)?;
328                }
329            }
330        }
331
332        Ok(())
333    }
334
335    /// Reduce gradients within each node
336    fn node_local_reduce(&mut self) -> Result<()> {
337        // Simplified implementation - in practice would use node-local communicator
338        if let Some(ref _local_ranks) = self.node_local_group {
339            for buffer in self.gradient_buffers.values_mut() {
340                for gradient in buffer.gradients.values_mut() {
341                    // Use MPI reduce operation within node
342                    // For now, use the general all_reduce
343                    self.mp_context.all_reduce(gradient)?;
344                }
345            }
346        }
347        Ok(())
348    }
349
350    /// All-reduce gradients across nodes
351    fn cross_node_all_reduce(&mut self) -> Result<()> {
352        // Only performed by one rank per node (usually local_rank == 0)
353        if let Some(ref _cross_ranks) = self.cross_node_group {
354            for buffer in self.gradient_buffers.values_mut() {
355                for gradient in buffer.gradients.values_mut() {
356                    self.mp_context.all_reduce(gradient)?;
357                }
358            }
359        }
360        Ok(())
361    }
362
363    /// Broadcast gradients within each node
364    fn node_local_broadcast(&mut self) -> Result<()> {
365        // Broadcast from local rank 0 to other ranks on the same node
366        let root_rank = self.config.node_rank * self.config.devices_per_node;
367
368        for buffer in self.gradient_buffers.values_mut() {
369            for gradient in buffer.gradients.values_mut() {
370                self.mp_context.broadcast(gradient, root_rank)?;
371            }
372        }
373        Ok(())
374    }
375
376    /// Apply gradients with multi-node coordination
377    pub fn apply_gradients(&mut self, accumulation_steps: usize) -> Result<()> {
378        // Synchronize gradients across nodes first
379        self.synchronize_gradients()?;
380
381        // Apply gradients using ZeRO optimizer
382        self.zero_optimizer.apply_accumulated_grads(accumulation_steps)?;
383
384        Ok(())
385    }
386
387    /// Perform optimizer step with multi-node coordination
388    pub fn optimizer_step(&mut self) -> Result<()> {
389        // Synchronize gradients across nodes
390        self.synchronize_gradients()?;
391
392        // Perform optimizer step using ZeRO
393        self.zero_optimizer.optimizer_step()?;
394
395        Ok(())
396    }
397
398    /// Get comprehensive memory usage across nodes
399    pub fn get_memory_usage(&self) -> HashMap<String, usize> {
400        let memory_stats = self.zero_optimizer.get_memory_stats();
401        let mut stats = HashMap::new();
402
403        // Add ZeRO memory statistics
404        stats.insert(
405            "optimizer_memory_saved".to_string(),
406            memory_stats.optimizer_memory_saved,
407        );
408        stats.insert(
409            "gradient_memory_saved".to_string(),
410            memory_stats.gradient_memory_saved,
411        );
412        stats.insert(
413            "parameter_memory_saved".to_string(),
414            memory_stats.parameter_memory_saved,
415        );
416        stats.insert(
417            "communication_overhead".to_string(),
418            memory_stats.communication_overhead,
419        );
420        stats.insert(
421            "total_memory_saved".to_string(),
422            memory_stats.total_memory_saved,
423        );
424
425        // Add multi-node specific memory usage
426        let mut buffer_memory = 0;
427        for buffer in self.gradient_buffers.values() {
428            for gradient in buffer.gradients.values() {
429                buffer_memory += gradient.memory_usage();
430            }
431        }
432        stats.insert("gradient_sync_buffers".to_string(), buffer_memory);
433
434        // Add communication overhead estimate
435        let comm_overhead = self.config.world_size() * 1024 * 1024; // 1MB per process estimate
436        stats.insert("communication_overhead".to_string(), comm_overhead);
437
438        stats
439    }
440
441    /// Get multi-node training statistics
442    pub fn get_training_stats(&self) -> MultiNodeStats {
443        let memory_stats = self.zero_optimizer.get_memory_stats();
444        let mut memory_savings = HashMap::new();
445
446        // Convert memory stats to savings percentages (simplified calculation)
447        let total_memory = memory_stats.total_memory_saved;
448        if total_memory > 0 {
449            memory_savings.insert(
450                "optimizer_states".to_string(),
451                memory_stats.optimizer_memory_saved as f32 / total_memory as f32,
452            );
453            memory_savings.insert(
454                "gradients".to_string(),
455                memory_stats.gradient_memory_saved as f32 / total_memory as f32,
456            );
457            memory_savings.insert(
458                "parameters".to_string(),
459                memory_stats.parameter_memory_saved as f32 / total_memory as f32,
460            );
461        }
462
463        MultiNodeStats {
464            node_rank: self.config.node_rank,
465            global_rank: self.config.global_rank,
466            world_size: self.config.world_size(),
467            zero_stage: self.zero_optimizer.get_stage(),
468            memory_savings,
469            communication_backend: self.config.comm_backend,
470            gradient_compression_enabled: self.config.gradient_compression,
471        }
472    }
473
474    /// Check if this process should save checkpoints
475    pub fn should_save_checkpoint(&self) -> bool {
476        self.config.is_master()
477    }
478
479    /// Barrier synchronization across all nodes
480    pub fn barrier(&self) -> Result<()> {
481        if let Some(ref mpi_comm) = self.mpi_communicator {
482            mpi_comm.barrier()?;
483        }
484
485        Ok(())
486    }
487
488    /// Finalize multi-node training
489    pub fn finalize() -> Result<()> {
490        MpiCommunicatorImpl::finalize()?;
491
492        println!("Multi-node training finalized");
493        Ok(())
494    }
495}
496
/// Statistics for multi-node training
#[derive(Debug, Clone)]
pub struct MultiNodeStats {
    /// Zero-based rank of this node within the cluster.
    pub node_rank: usize,
    /// Zero-based rank of this process across all nodes.
    pub global_rank: usize,
    /// Total number of processes in the cluster.
    pub world_size: usize,
    /// ZeRO sharding stage in use.
    pub zero_stage: ZeROStage,
    /// Per-component memory savings as fractions of total saved bytes.
    pub memory_savings: HashMap<String, f32>,
    /// Collective-communication backend in use.
    pub communication_backend: CommunicationBackend,
    /// Whether gradient compression was enabled.
    pub gradient_compression_enabled: bool,
}
508
509impl MultiNodeStats {
510    /// Print training statistics
511    pub fn print_stats(&self) {
512        println!("=== Multi-Node Training Statistics ===");
513        println!("Node Rank: {}", self.node_rank);
514        println!("Global Rank: {}", self.global_rank);
515        println!("World Size: {}", self.world_size);
516        println!("ZeRO Stage: {:?}", self.zero_stage);
517        println!("Communication Backend: {:?}", self.communication_backend);
518        println!(
519            "Gradient Compression: {}",
520            self.gradient_compression_enabled
521        );
522
523        println!("Memory Savings:");
524        for (component, savings) in &self.memory_savings {
525            println!("  {}: {:.1}%", component, savings * 100.0);
526        }
527        println!("=====================================");
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adam::Adam;

    /// Rank arithmetic and master detection on a 4x8 cluster.
    #[test]
    fn test_multinode_config() {
        let cfg = MultiNodeConfig::new(4, 8, 2, 3);

        assert_eq!(cfg.num_nodes, 4);
        assert_eq!(cfg.devices_per_node, 8);
        assert_eq!(cfg.node_rank, 2);
        assert_eq!(cfg.local_rank, 3);
        // global = node_rank * devices_per_node + local_rank = 2 * 8 + 3
        assert_eq!(cfg.global_rank, 19);
        // world = num_nodes * devices_per_node = 4 * 8
        assert_eq!(cfg.world_size(), 32);
        assert!(!cfg.is_master());

        // Only (node 0, local 0) is the master.
        assert!(MultiNodeConfig::new(4, 8, 0, 0).is_master());
    }

    /// Trainer construction; MPI may be unavailable in CI, so an error
    /// is tolerated.
    #[test]
    fn test_multinode_trainer_creation() {
        let cfg = MultiNodeConfig::new(2, 4, 0, 0);
        let base = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);

        match MultiNodeTrainer::new(cfg, base) {
            Ok(trainer) => {
                assert_eq!(trainer.config.world_size(), 8);
                assert!(trainer.config.is_master());
            },
            Err(e) => {
                // Expected in test environment without proper MPI setup
                println!("Expected error in test environment: {}", e);
            },
        }
    }

    /// Accumulate two gradients under one name, average, then clear.
    #[test]
    fn test_gradient_sync_buffer() {
        let mut buf = GradientSyncBuffer::new();

        let g1 = Tensor::ones(&[2, 2]).unwrap();
        let g2 = Tensor::ones(&[2, 2]).unwrap();

        buf.add_gradient("param1".to_string(), g1).unwrap();
        buf.add_gradient("param1".to_string(), g2).unwrap();

        // Two adds under a single key: one entry, two steps.
        assert_eq!(buf.accumulation_steps, 2);
        assert_eq!(buf.gradients.len(), 1);

        buf.average_gradients().unwrap();
        // After averaging, each element should be 1.0 (2.0 / 2)

        buf.clear();
        assert_eq!(buf.gradients.len(), 0);
        assert_eq!(buf.accumulation_steps, 0);
    }

    /// Node 1 of a 3x4 cluster owns global ranks 4..8.
    #[test]
    fn test_node_groups() {
        let cfg = MultiNodeConfig::new(3, 4, 1, 2);
        assert_eq!(cfg.node_local_ranks(), vec![4, 5, 6, 7]);
    }
}