1use std::collections::HashMap;
8use std::sync::Arc;
9use trustformers_core::errors::Result;
10use trustformers_core::parallel::{
11 CommunicationBackend, ModelParallelConfig, ModelParallelContext,
12};
13use trustformers_core::tensor::Tensor;
14use trustformers_core::traits::Optimizer;
15
16use trustformers_core::parallel::{mpi_utils, MpiCommunicatorImpl};
17
18use crate::zero::{ZeROConfig, ZeROOptimizer, ZeROStage};
19
/// Topology and tuning options for multi-node distributed training.
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
    /// Number of nodes participating in the job.
    pub num_nodes: usize,
    /// Number of devices (ranks) hosted on each node.
    pub devices_per_node: usize,
    /// 0-based index of this node within the job.
    pub node_rank: usize,
    /// 0-based rank of this process within its node.
    pub local_rank: usize,
    /// Rank of this process across the whole job
    /// (`node_rank * devices_per_node + local_rank`).
    pub global_rank: usize,
    /// ZeRO sharding configuration forwarded to the wrapped optimizer.
    pub zero_config: ZeROConfig,
    /// Whether gradient compression is requested.
    /// NOTE(review): only reported in stats; no code path in this file
    /// actually compresses gradients yet.
    pub gradient_compression: bool,
    /// Collective-communication backend (e.g. MPI).
    pub comm_backend: CommunicationBackend,
    /// Whether to overlap communication with computation where possible.
    pub overlap_comm_compute: bool,
    /// Gradient bucket size in MiB.
    /// NOTE(review): not read by the visible trainer code — presumably
    /// consumed elsewhere; confirm before removing.
    pub gradient_bucket_size_mb: usize,
}
44
45impl Default for MultiNodeConfig {
46 fn default() -> Self {
47 Self {
48 num_nodes: 1,
49 devices_per_node: 1,
50 node_rank: 0,
51 local_rank: 0,
52 global_rank: 0,
53 zero_config: ZeROConfig::default(),
54 gradient_compression: false,
55 comm_backend: CommunicationBackend::Mpi,
56 overlap_comm_compute: true,
57 gradient_bucket_size_mb: 25,
58 }
59 }
60}
61
62impl MultiNodeConfig {
63 pub fn new(
65 num_nodes: usize,
66 devices_per_node: usize,
67 node_rank: usize,
68 local_rank: usize,
69 ) -> Self {
70 let global_rank = node_rank * devices_per_node + local_rank;
71
72 Self {
73 num_nodes,
74 devices_per_node,
75 node_rank,
76 local_rank,
77 global_rank,
78 ..Default::default()
79 }
80 }
81
82 pub fn world_size(&self) -> usize {
84 self.num_nodes * self.devices_per_node
85 }
86
87 pub fn is_master(&self) -> bool {
89 self.global_rank == 0
90 }
91
92 pub fn node_local_ranks(&self) -> Vec<usize> {
94 let start = self.node_rank * self.devices_per_node;
95 (start..start + self.devices_per_node).collect()
96 }
97}
98
/// Coordinates ZeRO-sharded optimization and hierarchical gradient
/// synchronization across multiple nodes.
pub struct MultiNodeTrainer<T: Optimizer> {
    // Topology and tuning options for this trainer.
    config: MultiNodeConfig,
    // Shared model-parallel context used for collective operations.
    mp_context: Arc<ModelParallelContext>,
    // ZeRO wrapper around the caller-supplied base optimizer.
    zero_optimizer: ZeROOptimizer<T>,
    // MPI communicator; `Some` only when `config.comm_backend` is MPI.
    mpi_communicator: Option<Arc<MpiCommunicatorImpl>>,
    // Per-parameter gradient accumulation buffers, keyed by parameter name.
    gradient_buffers: HashMap<String, GradientSyncBuffer>,
    // Mirrors `config.overlap_comm_compute`; stored but not yet consulted
    // by any code path in this file.
    #[allow(dead_code)]
    communication_overlap: bool,
    // Global ranks hosted on this node.
    node_local_group: Option<Vec<usize>>,
    // One representative rank (local rank 0) per node, used for the
    // inter-node phase of the hierarchical all-reduce.
    cross_node_group: Option<Vec<usize>>,
}
111
/// Accumulates gradients for one parameter between synchronization points.
#[derive(Debug, Clone)]
struct GradientSyncBuffer {
    // Accumulated gradients keyed by parameter name. As used by the
    // trainer each buffer holds a single entry (one buffer per parameter),
    // but the map allows grouping.
    gradients: HashMap<String, Tensor>,
    // Number of `add_gradient` calls since the last `clear`.
    accumulation_steps: usize,
    // Compression bookkeeping; never populated by code in this file.
    compression_info: Option<CompressionInfo>,
}
122
/// Bookkeeping for a compressed gradient payload.
/// NOTE(review): no code path in this file constructs or reads these
/// fields yet — placeholder for the `gradient_compression` feature.
#[derive(Debug, Clone)]
struct CompressionInfo {
    // Ratio of compressed size to original size.
    #[allow(dead_code)]
    compression_ratio: f32,
    // Size in bytes before compression.
    #[allow(dead_code)]
    original_size: usize,
    // Size in bytes after compression.
    #[allow(dead_code)]
    compressed_size: usize,
}
135
136impl GradientSyncBuffer {
137 fn new() -> Self {
138 Self {
139 gradients: HashMap::new(),
140 accumulation_steps: 0,
141 compression_info: None,
142 }
143 }
144
145 fn add_gradient(&mut self, name: String, gradient: Tensor) -> Result<()> {
146 if let Some(existing) = self.gradients.get_mut(&name) {
147 *existing = existing.add(&gradient)?;
148 } else {
149 self.gradients.insert(name, gradient);
150 }
151 self.accumulation_steps += 1;
152 Ok(())
153 }
154
155 fn clear(&mut self) {
156 self.gradients.clear();
157 self.accumulation_steps = 0;
158 self.compression_info = None;
159 }
160
161 fn average_gradients(&mut self) -> Result<()> {
162 if self.accumulation_steps <= 1 {
163 return Ok(());
164 }
165
166 let divisor = self.accumulation_steps as f32;
167 for gradient in self.gradients.values_mut() {
168 *gradient = gradient.scalar_div(divisor)?;
169 }
170 Ok(())
171 }
172}
173
174impl<T: Optimizer> MultiNodeTrainer<T> {
175 pub fn new(config: MultiNodeConfig, base_optimizer: T) -> Result<Self> {
177 let mp_config = ModelParallelConfig {
179 num_devices: config.world_size(),
180 device_ids: (0..config.world_size()).collect(),
181 comm_backend: config.comm_backend,
182 ..Default::default()
183 };
184
185 let mp_context = Arc::new(ModelParallelContext::new(mp_config)?);
187
188 let zero_optimizer = ZeROOptimizer::new(
190 base_optimizer,
191 config.zero_config.clone(),
192 mp_context.clone(),
193 )?;
194
195 let mpi_communicator = if config.comm_backend == CommunicationBackend::Mpi {
197 Some(Arc::new(MpiCommunicatorImpl::new()?))
198 } else {
199 None
200 };
201
202 let node_local_group = Some(config.node_local_ranks());
204 let cross_node_group =
205 Some((0..config.num_nodes).map(|i| i * config.devices_per_node).collect());
206
207 let communication_overlap = config.overlap_comm_compute;
208
209 Ok(Self {
210 config,
211 mp_context,
212 zero_optimizer,
213 mpi_communicator,
214 gradient_buffers: HashMap::new(),
215 communication_overlap,
216 node_local_group,
217 cross_node_group,
218 })
219 }
220
221 pub fn initialize_environment() -> Result<()> {
223 mpi_utils::init_mpi_environment()?;
224 mpi_utils::check_mpi_environment()?;
225
226 let (local_rank, local_size) = mpi_utils::get_node_local_info()?;
228 println!("Multi-node environment initialized:");
229 println!(" Local rank: {}", local_rank);
230 println!(" Local size: {}", local_size);
231
232 Ok(())
233 }
234
235 pub fn register_parameters(&mut self, parameters: HashMap<String, Tensor>) -> Result<()> {
237 self.zero_optimizer.register_parameters(parameters.clone())?;
239
240 for name in parameters.keys() {
242 self.gradient_buffers.insert(name.clone(), GradientSyncBuffer::new());
243 }
244
245 println!("Multi-node training initialized:");
246 println!(" Node rank: {}", self.config.node_rank);
247 println!(" Global rank: {}", self.config.global_rank);
248 println!(" World size: {}", self.config.world_size());
249 println!(" ZeRO stage: {:?}", self.zero_optimizer.get_stage());
250 println!(" Parameters: {}", parameters.len());
251
252 Ok(())
253 }
254
255 pub fn update_gradients(&mut self, gradients: HashMap<String, Tensor>) -> Result<()> {
257 for (name, gradient) in gradients {
259 if let Some(buffer) = self.gradient_buffers.get_mut(&name) {
260 buffer.add_gradient(name.clone(), gradient)?;
261 }
262 }
263
264 self.zero_optimizer.update_gradients(self.collect_local_gradients()?)?;
266
267 Ok(())
268 }
269
270 fn collect_local_gradients(&self) -> Result<HashMap<String, Tensor>> {
272 let mut gradients = HashMap::new();
273 for (name, buffer) in &self.gradient_buffers {
274 if let Some(grad) = buffer.gradients.get(name) {
275 gradients.insert(name.clone(), grad.clone());
276 }
277 }
278 Ok(gradients)
279 }
280
281 pub fn synchronize_gradients(&mut self) -> Result<()> {
283 if self.config.world_size() == 1 {
284 return Ok(()); }
286
287 for buffer in self.gradient_buffers.values_mut() {
289 buffer.average_gradients()?;
290 }
291
292 self.hierarchical_all_reduce()?;
294
295 for buffer in self.gradient_buffers.values_mut() {
297 buffer.clear();
298 }
299
300 Ok(())
301 }
302
303 fn hierarchical_all_reduce(&mut self) -> Result<()> {
305 let has_mpi = self.mpi_communicator.is_some();
306
307 if has_mpi {
308 self.node_local_reduce()?;
310
311 if self.config.local_rank == 0 {
313 self.cross_node_all_reduce()?;
314 }
315
316 self.node_local_broadcast()?;
318
319 if let Some(ref mpi_comm) = self.mpi_communicator {
321 mpi_comm.barrier()?;
322 }
323 } else {
324 for buffer in self.gradient_buffers.values_mut() {
326 for gradient in buffer.gradients.values_mut() {
327 self.mp_context.all_reduce(gradient)?;
328 }
329 }
330 }
331
332 Ok(())
333 }
334
335 fn node_local_reduce(&mut self) -> Result<()> {
337 if let Some(ref _local_ranks) = self.node_local_group {
339 for buffer in self.gradient_buffers.values_mut() {
340 for gradient in buffer.gradients.values_mut() {
341 self.mp_context.all_reduce(gradient)?;
344 }
345 }
346 }
347 Ok(())
348 }
349
350 fn cross_node_all_reduce(&mut self) -> Result<()> {
352 if let Some(ref _cross_ranks) = self.cross_node_group {
354 for buffer in self.gradient_buffers.values_mut() {
355 for gradient in buffer.gradients.values_mut() {
356 self.mp_context.all_reduce(gradient)?;
357 }
358 }
359 }
360 Ok(())
361 }
362
363 fn node_local_broadcast(&mut self) -> Result<()> {
365 let root_rank = self.config.node_rank * self.config.devices_per_node;
367
368 for buffer in self.gradient_buffers.values_mut() {
369 for gradient in buffer.gradients.values_mut() {
370 self.mp_context.broadcast(gradient, root_rank)?;
371 }
372 }
373 Ok(())
374 }
375
376 pub fn apply_gradients(&mut self, accumulation_steps: usize) -> Result<()> {
378 self.synchronize_gradients()?;
380
381 self.zero_optimizer.apply_accumulated_grads(accumulation_steps)?;
383
384 Ok(())
385 }
386
387 pub fn optimizer_step(&mut self) -> Result<()> {
389 self.synchronize_gradients()?;
391
392 self.zero_optimizer.optimizer_step()?;
394
395 Ok(())
396 }
397
398 pub fn get_memory_usage(&self) -> HashMap<String, usize> {
400 let memory_stats = self.zero_optimizer.get_memory_stats();
401 let mut stats = HashMap::new();
402
403 stats.insert(
405 "optimizer_memory_saved".to_string(),
406 memory_stats.optimizer_memory_saved,
407 );
408 stats.insert(
409 "gradient_memory_saved".to_string(),
410 memory_stats.gradient_memory_saved,
411 );
412 stats.insert(
413 "parameter_memory_saved".to_string(),
414 memory_stats.parameter_memory_saved,
415 );
416 stats.insert(
417 "communication_overhead".to_string(),
418 memory_stats.communication_overhead,
419 );
420 stats.insert(
421 "total_memory_saved".to_string(),
422 memory_stats.total_memory_saved,
423 );
424
425 let mut buffer_memory = 0;
427 for buffer in self.gradient_buffers.values() {
428 for gradient in buffer.gradients.values() {
429 buffer_memory += gradient.memory_usage();
430 }
431 }
432 stats.insert("gradient_sync_buffers".to_string(), buffer_memory);
433
434 let comm_overhead = self.config.world_size() * 1024 * 1024; stats.insert("communication_overhead".to_string(), comm_overhead);
437
438 stats
439 }
440
441 pub fn get_training_stats(&self) -> MultiNodeStats {
443 let memory_stats = self.zero_optimizer.get_memory_stats();
444 let mut memory_savings = HashMap::new();
445
446 let total_memory = memory_stats.total_memory_saved;
448 if total_memory > 0 {
449 memory_savings.insert(
450 "optimizer_states".to_string(),
451 memory_stats.optimizer_memory_saved as f32 / total_memory as f32,
452 );
453 memory_savings.insert(
454 "gradients".to_string(),
455 memory_stats.gradient_memory_saved as f32 / total_memory as f32,
456 );
457 memory_savings.insert(
458 "parameters".to_string(),
459 memory_stats.parameter_memory_saved as f32 / total_memory as f32,
460 );
461 }
462
463 MultiNodeStats {
464 node_rank: self.config.node_rank,
465 global_rank: self.config.global_rank,
466 world_size: self.config.world_size(),
467 zero_stage: self.zero_optimizer.get_stage(),
468 memory_savings,
469 communication_backend: self.config.comm_backend,
470 gradient_compression_enabled: self.config.gradient_compression,
471 }
472 }
473
474 pub fn should_save_checkpoint(&self) -> bool {
476 self.config.is_master()
477 }
478
479 pub fn barrier(&self) -> Result<()> {
481 if let Some(ref mpi_comm) = self.mpi_communicator {
482 mpi_comm.barrier()?;
483 }
484
485 Ok(())
486 }
487
488 pub fn finalize() -> Result<()> {
490 MpiCommunicatorImpl::finalize()?;
491
492 println!("Multi-node training finalized");
493 Ok(())
494 }
495}
496
/// Snapshot of a trainer's distributed-training statistics.
#[derive(Debug, Clone)]
pub struct MultiNodeStats {
    /// Index of the reporting node.
    pub node_rank: usize,
    /// Global rank of the reporting process.
    pub global_rank: usize,
    /// Total number of ranks in the job.
    pub world_size: usize,
    /// ZeRO sharding stage in effect.
    pub zero_stage: ZeROStage,
    /// Per-component fraction of total memory saved (0.0..=1.0),
    /// keyed by component name.
    pub memory_savings: HashMap<String, f32>,
    /// Communication backend in use.
    pub communication_backend: CommunicationBackend,
    /// Whether gradient compression was enabled in the config.
    pub gradient_compression_enabled: bool,
}
508
509impl MultiNodeStats {
510 pub fn print_stats(&self) {
512 println!("=== Multi-Node Training Statistics ===");
513 println!("Node Rank: {}", self.node_rank);
514 println!("Global Rank: {}", self.global_rank);
515 println!("World Size: {}", self.world_size);
516 println!("ZeRO Stage: {:?}", self.zero_stage);
517 println!("Communication Backend: {:?}", self.communication_backend);
518 println!(
519 "Gradient Compression: {}",
520 self.gradient_compression_enabled
521 );
522
523 println!("Memory Savings:");
524 for (component, savings) in &self.memory_savings {
525 println!(" {}: {:.1}%", component, savings * 100.0);
526 }
527 println!("=====================================");
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::adam::Adam;

    /// Topology math: derived ranks, world size, master detection.
    #[test]
    fn test_multinode_config() {
        let cfg = MultiNodeConfig::new(4, 8, 2, 3);

        assert_eq!(cfg.num_nodes, 4);
        assert_eq!(cfg.devices_per_node, 8);
        assert_eq!(cfg.node_rank, 2);
        assert_eq!(cfg.local_rank, 3);
        // 2 * 8 + 3
        assert_eq!(cfg.global_rank, 19);
        // 4 * 8
        assert_eq!(cfg.world_size(), 32);
        assert!(!cfg.is_master());

        let master = MultiNodeConfig::new(4, 8, 0, 0);
        assert!(master.is_master());
    }

    /// Trainer construction; may legitimately fail without an MPI
    /// environment, so only the success path carries assertions.
    #[test]
    fn test_multinode_trainer_creation() {
        let cfg = MultiNodeConfig::new(2, 4, 0, 0);
        let opt = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.0);

        match MultiNodeTrainer::new(cfg, opt) {
            Ok(trainer) => {
                assert_eq!(trainer.config.world_size(), 8);
                assert!(trainer.config.is_master());
            },
            Err(e) => println!("Expected error in test environment: {}", e),
        }
    }

    /// Accumulation counting, in-place summing, and clearing.
    #[test]
    fn test_gradient_sync_buffer() {
        let mut buffer = GradientSyncBuffer::new();

        buffer
            .add_gradient("param1".to_string(), Tensor::ones(&[2, 2]).unwrap())
            .unwrap();
        buffer
            .add_gradient("param1".to_string(), Tensor::ones(&[2, 2]).unwrap())
            .unwrap();

        // Same name twice: one entry, two accumulation steps.
        assert_eq!(buffer.accumulation_steps, 2);
        assert_eq!(buffer.gradients.len(), 1);

        buffer.average_gradients().unwrap();
        buffer.clear();
        assert_eq!(buffer.gradients.len(), 0);
        assert_eq!(buffer.accumulation_steps, 0);
    }

    /// Node 1 of a 3-node, 4-device topology hosts global ranks 4..8.
    #[test]
    fn test_node_groups() {
        let cfg = MultiNodeConfig::new(3, 4, 1, 2);
        assert_eq!(cfg.node_local_ranks(), vec![4, 5, 6, 7]);
    }
}