1use crate::distributed::ProcessGroup;
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6use std::time::{Duration, Instant};
7use trustformers_core::tensor::Tensor;
8use trustformers_core::traits::Model;
9
/// Configuration for 3D (data / model / pipeline) parallelism.
///
/// `dp_size * mp_size * pp_size` must equal the world size; this is
/// validated in `Parallelism3D::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelismConfig {
    /// Data-parallel group size (number of model replicas).
    pub dp_size: usize,
    /// Model-(tensor-)parallel group size.
    pub mp_size: usize,
    /// Pipeline-parallel group size (number of pipeline stages).
    pub pp_size: usize,
    /// Number of micro-batches a global batch is split into for pipelining.
    pub num_micro_batches: usize,
    /// Whether gradients are accumulated across micro-batches.
    pub gradient_accumulation: bool,
    /// Number of accumulation steps before an optimizer update.
    pub accumulation_steps: usize,
    /// Whether activation checkpointing (recomputation) is enabled.
    pub activation_checkpointing: bool,
    /// Backend used for collective communication.
    pub comm_backend: CommBackend,
    /// Pipeline schedule variant; see [`PipelineSchedule`].
    pub pipeline_schedule: PipelineSchedule,
    /// Aggressiveness of memory optimizations; see [`MemoryOptimization`].
    pub memory_optimization: MemoryOptimization,
}
44
45impl Default for ParallelismConfig {
46 fn default() -> Self {
47 Self {
48 dp_size: 1,
49 mp_size: 1,
50 pp_size: 1,
51 num_micro_batches: 4,
52 gradient_accumulation: true,
53 accumulation_steps: 1,
54 activation_checkpointing: true,
55 comm_backend: CommBackend::NCCL,
56 pipeline_schedule: PipelineSchedule::GPipe,
57 memory_optimization: MemoryOptimization::Medium,
58 }
59 }
60}
61
/// Communication backend used for collective operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommBackend {
    /// NVIDIA collective communication library (GPU).
    NCCL,
    /// Gloo collective library (CPU-friendly).
    Gloo,
    /// Message Passing Interface.
    MPI,
    /// Direct InfiniBand transport.
    InfiniBand,
}
70
/// Pipeline-parallel scheduling strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PipelineSchedule {
    /// All forwards, then all backwards (simple, larger bubbles).
    GPipe,
    /// PipeDream-style interleaving of forward and backward passes.
    PipeDream,
    /// PipeDream with double-buffered weight updates (2BW).
    PipeDream2BW,
    /// Interleaved one-forward-one-backward schedule.
    Interleaved1F1B,
    /// Chooses a schedule at runtime from observed communication stats
    /// (see `Parallelism3D::forward_adaptive`).
    Adaptive,
}
85
/// How aggressively memory is traded for recomputation / communication.
/// The concrete actions per level are applied in
/// `Parallelism3D::optimize_memory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryOptimization {
    /// No memory optimization.
    None,
    /// Activation checkpointing only (every 4th activation).
    Low,
    /// Checkpointing (every 2nd) plus gradient compression.
    Medium,
    /// Checkpointing (every activation), compression, and CPU offloading.
    High,
    /// ZeRO-style optimizer/gradient/parameter partitioning.
    Extreme,
}
95
/// Per-process coordinator for 3D (data + model + pipeline) parallelism.
///
/// Holds this rank's coordinates in each parallel dimension, the process
/// groups used for collectives, and shared bookkeeping state (pipeline
/// progress, communication timings, memory pools).
#[allow(dead_code)]
pub struct Parallelism3D {
    config: ParallelismConfig,
    #[allow(dead_code)]
    // Rank of this process in the full world.
    global_rank: usize,
    world_size: usize,

    // Coordinates in each dimension, derived from global_rank in `new`:
    // global_rank = dp_rank * (mp_size * pp_size) + mp_rank * pp_size + pp_rank.
    dp_rank: usize,
    mp_rank: usize,
    pp_rank: usize,

    // Process groups for collectives in each parallel dimension.
    dp_group: Arc<dyn ProcessGroup>,
    mp_group: Arc<dyn ProcessGroup>,
    pp_group: Arc<dyn ProcessGroup>,

    // Pipeline bookkeeping (micro-batch progress, per-stage timings).
    pipeline_state: Arc<RwLock<PipelineState>>,

    // Accumulated communication timings and efficiency figures.
    comm_stats: Arc<Mutex<CommunicationStats>>,

    // Tensor pools and checkpointed activations for memory optimization.
    memory_manager: Arc<Mutex<MemoryManager>>,
}
126
/// Mutable pipeline-execution bookkeeping, shared behind an `RwLock`
/// on [`Parallelism3D`].
#[derive(Debug, Default)]
#[allow(dead_code)]
struct PipelineState {
    #[allow(dead_code)]
    // Index of the micro-batch currently in flight.
    current_micro_batch: usize,
    forward_passes_completed: usize,
    backward_passes_completed: usize,
    // Idle slots in the schedule; feeds the pipeline-efficiency estimate.
    pipeline_bubbles: usize,
    // Execution time per pipeline stage, keyed by stage index.
    stage_timings: HashMap<usize, Duration>,
    communication_overhead: Duration,
}
139
/// Accumulated collective-communication timings, shared behind a `Mutex`
/// on [`Parallelism3D`].
#[derive(Debug, Default)]
#[allow(dead_code)]
struct CommunicationStats {
    // Time spent in data-parallel gradient all-reduce.
    dp_all_reduce_time: Duration,
    // Time spent in model-parallel collectives.
    mp_all_reduce_time: Duration,
    // Time spent in pipeline stage-to-stage send/recv.
    pp_send_recv_time: Duration,
    #[allow(dead_code)]
    total_bytes_communicated: u64,
    // NOTE(review): the two fields below are never written in this file —
    // presumably filled in by the communication backend; confirm before
    // relying on them in reports.
    communication_efficiency: f32,
    bandwidth_utilization: f32,
}
152
/// Tensor pools and counters backing the memory-optimization passes.
#[derive(Debug)]
#[allow(dead_code)]
struct MemoryManager {
    #[allow(dead_code)]
    // Reusable activation tensors, keyed by pool name.
    activation_memory_pool: HashMap<String, Vec<Tensor>>,
    // Reusable gradient tensors, keyed by pool name.
    gradient_memory_pool: HashMap<String, Vec<Tensor>>,
    // High-water mark of observed memory usage, in bytes.
    peak_memory_usage: u64,
    current_memory_usage: u64,
    // Level copied from the config at construction time.
    memory_optimization_level: MemoryOptimization,
    // Activations saved by `apply_activation_checkpointing`, keyed
    // "checkpoint_{index}".
    checkpointed_activations: HashMap<String, Vec<Tensor>>,
}
165
166impl Default for MemoryManager {
167 fn default() -> Self {
168 Self {
169 activation_memory_pool: HashMap::new(),
170 gradient_memory_pool: HashMap::new(),
171 peak_memory_usage: 0,
172 current_memory_usage: 0,
173 memory_optimization_level: MemoryOptimization::Medium,
174 checkpointed_activations: HashMap::new(),
175 }
176 }
177}
178
179impl Parallelism3D {
180 pub fn new(
182 config: ParallelismConfig,
183 global_rank: usize,
184 world_size: usize,
185 dp_group: Arc<dyn ProcessGroup>,
186 mp_group: Arc<dyn ProcessGroup>,
187 pp_group: Arc<dyn ProcessGroup>,
188 ) -> Result<Self> {
189 if config.dp_size * config.mp_size * config.pp_size != world_size {
191 return Err(anyhow!(
192 "Invalid parallelism configuration: dp_size ({}) * mp_size ({}) * pp_size ({}) != world_size ({})",
193 config.dp_size, config.mp_size, config.pp_size, world_size
194 ));
195 }
196
197 let dp_rank = global_rank / (config.mp_size * config.pp_size);
199 let mp_rank = (global_rank / config.pp_size) % config.mp_size;
200 let pp_rank = global_rank % config.pp_size;
201
202 let memory_manager = MemoryManager {
203 memory_optimization_level: config.memory_optimization.clone(),
204 ..Default::default()
205 };
206
207 Ok(Self {
208 config,
209 global_rank,
210 world_size,
211 dp_rank,
212 mp_rank,
213 pp_rank,
214 dp_group,
215 mp_group,
216 pp_group,
217 pipeline_state: Arc::new(RwLock::new(PipelineState::default())),
218 comm_stats: Arc::new(Mutex::new(CommunicationStats::default())),
219 memory_manager: Arc::new(Mutex::new(memory_manager)),
220 })
221 }
222
223 pub fn forward_pass<M: Model>(
225 &self,
226 model: &M,
227 inputs: &[Tensor],
228 micro_batch_id: usize,
229 ) -> Result<Vec<Tensor>> {
230 let _start_time = Instant::now();
231
232 match self.config.pipeline_schedule {
234 PipelineSchedule::GPipe => self.forward_gpipe(model, inputs, micro_batch_id),
235 PipelineSchedule::PipeDream => self.forward_pipedream(model, inputs, micro_batch_id),
236 PipelineSchedule::PipeDream2BW => {
237 self.forward_pipedream_2bw(model, inputs, micro_batch_id)
238 },
239 PipelineSchedule::Interleaved1F1B => {
240 self.forward_interleaved_1f1b(model, inputs, micro_batch_id)
241 },
242 PipelineSchedule::Adaptive => self.forward_adaptive(model, inputs, micro_batch_id),
243 }
244 }
245
246 pub fn backward_pass<M: Model>(
248 &self,
249 model: &mut M,
250 gradients: &[Tensor],
251 micro_batch_id: usize,
252 ) -> Result<Vec<Tensor>> {
253 let _start_time = Instant::now();
254
255 match self.config.pipeline_schedule {
257 PipelineSchedule::GPipe => self.backward_gpipe(model, gradients, micro_batch_id),
258 PipelineSchedule::PipeDream => {
259 self.backward_pipedream(model, gradients, micro_batch_id)
260 },
261 PipelineSchedule::PipeDream2BW => {
262 self.backward_pipedream_2bw(model, gradients, micro_batch_id)
263 },
264 PipelineSchedule::Interleaved1F1B => {
265 self.backward_interleaved_1f1b(model, gradients, micro_batch_id)
266 },
267 PipelineSchedule::Adaptive => self.backward_adaptive(model, gradients, micro_batch_id),
268 }
269 }
270
271 pub fn synchronize_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
273 let start_time = Instant::now();
274
275 if self.config.mp_size > 1 {
277 self.mp_reduce_scatter_gradients(gradients)?;
278 }
279
280 if self.config.dp_size > 1 {
282 self.dp_all_reduce_gradients(gradients)?;
283 }
284
285 if self.config.mp_size > 1 {
287 self.mp_all_gather_gradients(gradients)?;
288 }
289
290 let mut stats = self.comm_stats.lock().expect("lock should not be poisoned");
292 stats.dp_all_reduce_time += start_time.elapsed();
293
294 Ok(())
295 }
296
297 pub fn optimize_memory(&self, tensors: &mut [Tensor]) -> Result<()> {
299 let memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
300
301 match memory_manager.memory_optimization_level {
302 MemoryOptimization::None => {
303 Ok(())
305 },
306 MemoryOptimization::Low => {
307 self.apply_activation_checkpointing(tensors, 4)
309 },
310 MemoryOptimization::Medium => {
311 self.apply_activation_checkpointing(tensors, 2)?;
313 self.apply_gradient_compression(tensors)
314 },
315 MemoryOptimization::High => {
316 self.apply_activation_checkpointing(tensors, 1)?;
318 self.apply_gradient_compression(tensors)?;
319 self.apply_cpu_offloading(tensors)
320 },
321 MemoryOptimization::Extreme => {
322 self.apply_zero_optimization(tensors)
324 },
325 }
326 }
327
328 pub fn optimize_pipeline_bubbles(&self) -> Result<()> {
330 let state = self.pipeline_state.write().expect("lock should not be poisoned");
331
332 let total_stages = self.config.pp_size;
334 let avg_stage_time = state.stage_timings.values().sum::<Duration>() / total_stages as u32;
335
336 let mut bottleneck_stages = Vec::new();
338 for (stage, timing) in &state.stage_timings {
339 if *timing > avg_stage_time * 2 {
340 bottleneck_stages.push(*stage);
341 }
342 }
343
344 if !bottleneck_stages.is_empty() {
346 self.apply_load_balancing(&bottleneck_stages)?;
347 }
348
349 let pipeline_efficiency = 1.0
351 - (state.pipeline_bubbles as f32
352 / (state.forward_passes_completed + state.backward_passes_completed) as f32);
353
354 if pipeline_efficiency < 0.8 {
355 self.adjust_micro_batch_size()?;
356 }
357
358 Ok(())
359 }
360
361 pub fn get_statistics(&self) -> Result<Parallelism3DStats> {
363 let pipeline_state = self.pipeline_state.read().expect("lock should not be poisoned");
364 let comm_stats = self.comm_stats.lock().expect("lock should not be poisoned");
365 let memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
366
367 Ok(Parallelism3DStats {
368 dp_rank: self.dp_rank,
369 mp_rank: self.mp_rank,
370 pp_rank: self.pp_rank,
371 pipeline_efficiency: self.calculate_pipeline_efficiency(&pipeline_state),
372 communication_efficiency: comm_stats.communication_efficiency,
373 memory_efficiency: self.calculate_memory_efficiency(&memory_manager),
374 total_communication_time: comm_stats.dp_all_reduce_time
375 + comm_stats.mp_all_reduce_time
376 + comm_stats.pp_send_recv_time,
377 peak_memory_usage: memory_manager.peak_memory_usage,
378 pipeline_bubbles: pipeline_state.pipeline_bubbles,
379 micro_batches_processed: pipeline_state.forward_passes_completed,
380 })
381 }
382
383 fn forward_gpipe<M: Model>(
386 &self,
387 model: &M,
388 inputs: &[Tensor],
389 _micro_batch_id: usize,
390 ) -> Result<Vec<Tensor>> {
391 if self.pp_rank == 0 {
394 let outputs = self.process_pipeline_stage(model, inputs, 0)?;
396
397 if self.config.pp_size > 1 {
399 self.send_to_next_stage(&outputs)?;
400 }
401
402 Ok(outputs)
403 } else {
404 let received_inputs = self.receive_from_previous_stage()?;
406 let outputs = self.process_pipeline_stage(model, &received_inputs, self.pp_rank)?;
407
408 if self.pp_rank < self.config.pp_size - 1 {
409 self.send_to_next_stage(&outputs)?;
410 }
411
412 Ok(outputs)
413 }
414 }
415
416 fn forward_pipedream<M: Model>(
417 &self,
418 model: &M,
419 inputs: &[Tensor],
420 micro_batch_id: usize,
421 ) -> Result<Vec<Tensor>> {
422 self.forward_gpipe(model, inputs, micro_batch_id) }
426
427 fn forward_pipedream_2bw<M: Model>(
428 &self,
429 model: &M,
430 inputs: &[Tensor],
431 micro_batch_id: usize,
432 ) -> Result<Vec<Tensor>> {
433 self.forward_gpipe(model, inputs, micro_batch_id) }
436
437 fn forward_interleaved_1f1b<M: Model>(
438 &self,
439 model: &M,
440 inputs: &[Tensor],
441 micro_batch_id: usize,
442 ) -> Result<Vec<Tensor>> {
443 self.forward_gpipe(model, inputs, micro_batch_id) }
446
447 fn forward_adaptive<M: Model>(
448 &self,
449 model: &M,
450 inputs: &[Tensor],
451 micro_batch_id: usize,
452 ) -> Result<Vec<Tensor>> {
453 let stats = self.comm_stats.lock().expect("lock should not be poisoned");
457 let communication_time_ratio = stats.pp_send_recv_time.as_millis() as f32
458 / (stats.dp_all_reduce_time.as_millis() + stats.mp_all_reduce_time.as_millis() + 1)
459 as f32;
460
461 if communication_time_ratio > 2.0 {
462 self.forward_gpipe(model, inputs, micro_batch_id)
464 } else {
465 self.forward_interleaved_1f1b(model, inputs, micro_batch_id)
467 }
468 }
469
470 fn backward_gpipe<M: Model>(
472 &self,
473 _model: &mut M,
474 gradients: &[Tensor],
475 _micro_batch_id: usize,
476 ) -> Result<Vec<Tensor>> {
477 Ok(gradients.to_vec()) }
480
481 fn backward_pipedream<M: Model>(
482 &self,
483 _model: &mut M,
484 gradients: &[Tensor],
485 _micro_batch_id: usize,
486 ) -> Result<Vec<Tensor>> {
487 Ok(gradients.to_vec()) }
489
490 fn backward_pipedream_2bw<M: Model>(
491 &self,
492 _model: &mut M,
493 gradients: &[Tensor],
494 _micro_batch_id: usize,
495 ) -> Result<Vec<Tensor>> {
496 Ok(gradients.to_vec()) }
498
499 fn backward_interleaved_1f1b<M: Model>(
500 &self,
501 _model: &mut M,
502 gradients: &[Tensor],
503 _micro_batch_id: usize,
504 ) -> Result<Vec<Tensor>> {
505 Ok(gradients.to_vec()) }
507
508 fn backward_adaptive<M: Model>(
509 &self,
510 _model: &mut M,
511 gradients: &[Tensor],
512 _micro_batch_id: usize,
513 ) -> Result<Vec<Tensor>> {
514 Ok(gradients.to_vec()) }
516
517 fn mp_reduce_scatter_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
519 for _tensor in gradients.iter_mut() {
522 }
524 Ok(())
525 }
526
527 fn dp_all_reduce_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
528 self.dp_group.all_reduce(gradients)?;
530
531 for _tensor in gradients.iter_mut() {
533 }
535
536 Ok(())
537 }
538
539 fn mp_all_gather_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
540 for _tensor in gradients.iter_mut() {
543 }
545 Ok(())
546 }
547
548 fn send_to_next_stage(&self, _tensors: &[Tensor]) -> Result<()> {
549 Ok(())
552 }
553
554 fn receive_from_previous_stage(&self) -> Result<Vec<Tensor>> {
555 Ok(vec![Tensor::zeros(&[1])?]) }
559
560 fn process_pipeline_stage<M: Model>(
561 &self,
562 _model: &M,
563 inputs: &[Tensor],
564 _stage: usize,
565 ) -> Result<Vec<Tensor>> {
566 Ok(inputs.to_vec()) }
570
571 fn apply_activation_checkpointing(
573 &self,
574 tensors: &mut [Tensor],
575 checkpoint_ratio: usize,
576 ) -> Result<()> {
577 let mut memory_manager = self.memory_manager.lock().expect("lock should not be poisoned");
578
579 for (i, tensor) in tensors.iter().enumerate() {
581 if i % checkpoint_ratio == 0 {
582 memory_manager
583 .checkpointed_activations
584 .entry(format!("checkpoint_{}", i))
585 .or_default()
586 .push(tensor.clone());
587 }
588 }
589
590 Ok(())
591 }
592
593 fn apply_gradient_compression(&self, tensors: &mut [Tensor]) -> Result<()> {
594 for _tensor in tensors.iter_mut() {
596 }
601 Ok(())
602 }
603
604 fn apply_cpu_offloading(&self, _tensors: &mut [Tensor]) -> Result<()> {
605 Ok(())
607 }
608
609 fn apply_zero_optimization(&self, _tensors: &mut [Tensor]) -> Result<()> {
610 Ok(())
612 }
613
614 fn apply_load_balancing(&self, _bottleneck_stages: &[usize]) -> Result<()> {
616 Ok(())
618 }
619
620 fn adjust_micro_batch_size(&self) -> Result<()> {
621 Ok(())
623 }
624
625 fn calculate_pipeline_efficiency(&self, state: &PipelineState) -> f32 {
627 if state.forward_passes_completed == 0 {
628 return 0.0;
629 }
630
631 let total_passes = state.forward_passes_completed + state.backward_passes_completed;
632 1.0 - (state.pipeline_bubbles as f32 / total_passes as f32)
633 }
634
635 fn calculate_memory_efficiency(&self, memory_manager: &MemoryManager) -> f32 {
636 if memory_manager.peak_memory_usage == 0 {
637 return 1.0;
638 }
639
640 memory_manager.current_memory_usage as f32 / memory_manager.peak_memory_usage as f32
641 }
642}
643
/// Snapshot of one rank's parallelism statistics, produced by
/// `Parallelism3D::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parallelism3DStats {
    /// This rank's data-parallel coordinate.
    pub dp_rank: usize,
    /// This rank's model-parallel coordinate.
    pub mp_rank: usize,
    /// This rank's pipeline-parallel coordinate.
    pub pp_rank: usize,
    /// 1 - bubbles / total passes (0.0 before any forward pass).
    pub pipeline_efficiency: f32,
    pub communication_efficiency: f32,
    /// Current / peak memory ratio (1.0 when nothing observed yet).
    pub memory_efficiency: f32,
    /// Sum of DP, MP, and pipeline communication time.
    pub total_communication_time: Duration,
    /// High-water mark of memory usage, in bytes.
    pub peak_memory_usage: u64,
    pub pipeline_bubbles: usize,
    /// Number of completed forward passes.
    pub micro_batches_processed: usize,
}
658
/// Registry of named [`Parallelism3D`] coordinators with aggregate
/// statistics and configuration tuning.
pub struct Parallelism3DManager {
    // Coordinators keyed by caller-chosen name.
    coordinators: HashMap<String, Arc<Parallelism3D>>,
    // Baseline configuration that `optimize_configuration` adjusts.
    global_config: ParallelismConfig,
    // Recorded performance samples feeding configuration tuning.
    performance_tracker: Arc<Mutex<PerformanceTracker>>,
}
665
/// Rolling performance samples used by
/// `Parallelism3DManager::optimize_configuration`.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct PerformanceTracker {
    #[allow(dead_code)]
    // Wall-clock time per training iteration.
    iteration_times: Vec<Duration>,
    communication_times: Vec<Duration>,
    // Sampled memory usage in bytes.
    memory_usage_samples: Vec<u64>,
    // Pipeline-efficiency samples in [0, 1].
    efficiency_scores: Vec<f32>,
}
675
676impl Parallelism3DManager {
677 pub fn new(config: ParallelismConfig) -> Self {
679 Self {
680 coordinators: HashMap::new(),
681 global_config: config,
682 performance_tracker: Arc::new(Mutex::new(PerformanceTracker::default())),
683 }
684 }
685
686 pub fn register_coordinator(
688 &mut self,
689 name: String,
690 coordinator: Arc<Parallelism3D>,
691 ) -> Result<()> {
692 self.coordinators.insert(name, coordinator);
693 Ok(())
694 }
695
696 pub fn get_aggregate_stats(&self) -> Result<AggregateParallelismStats> {
698 let mut aggregate = AggregateParallelismStats::default();
699
700 for coordinator in self.coordinators.values() {
701 let stats = coordinator.get_statistics()?;
702 aggregate.total_pipeline_efficiency += stats.pipeline_efficiency;
703 aggregate.total_communication_time += stats.total_communication_time;
704 aggregate.total_memory_usage += stats.peak_memory_usage;
705 aggregate.total_micro_batches += stats.micro_batches_processed;
706 }
707
708 if !self.coordinators.is_empty() {
709 aggregate.average_pipeline_efficiency =
710 aggregate.total_pipeline_efficiency / self.coordinators.len() as f32;
711 }
712
713 aggregate.num_coordinators = self.coordinators.len();
714
715 Ok(aggregate)
716 }
717
718 pub fn optimize_configuration(&mut self) -> Result<ParallelismConfig> {
720 let tracker = self.performance_tracker.lock().expect("lock should not be poisoned");
721
722 if tracker.efficiency_scores.is_empty() {
723 return Ok(self.global_config.clone());
724 }
725
726 let avg_efficiency =
727 tracker.efficiency_scores.iter().sum::<f32>() / tracker.efficiency_scores.len() as f32;
728 let mut optimized_config = self.global_config.clone();
729
730 if avg_efficiency < 0.8 {
732 optimized_config.num_micro_batches = (optimized_config.num_micro_batches * 2).min(16);
733 } else if avg_efficiency > 0.95 {
734 optimized_config.num_micro_batches = (optimized_config.num_micro_batches / 2).max(1);
735 }
736
737 let avg_memory_usage = tracker.memory_usage_samples.iter().sum::<u64>()
739 / tracker.memory_usage_samples.len() as u64;
740 if avg_memory_usage > (0.9 * (32u64 * 1024 * 1024 * 1024) as f64) as u64 {
741 optimized_config.memory_optimization = MemoryOptimization::High;
743 }
744
745 Ok(optimized_config)
746 }
747}
748
/// Statistics summed over all coordinators registered with a
/// [`Parallelism3DManager`].
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct AggregateParallelismStats {
    /// Number of coordinators included in the aggregate.
    pub num_coordinators: usize,
    /// Sum of per-coordinator pipeline efficiencies.
    pub total_pipeline_efficiency: f32,
    /// total_pipeline_efficiency / num_coordinators (0.0 when empty).
    pub average_pipeline_efficiency: f32,
    pub total_communication_time: Duration,
    /// Sum of per-coordinator peak memory usage, in bytes.
    pub total_memory_usage: u64,
    pub total_micro_batches: usize,
}
759
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::SimulatedProcessGroup;

    /// The factored parallel sizes must multiply to the world size.
    #[test]
    fn test_parallelism_config_validation() {
        let config = ParallelismConfig {
            dp_size: 2,
            mp_size: 2,
            pp_size: 2,
            ..Default::default()
        };

        let world_size = 8;
        assert_eq!(config.dp_size * config.mp_size * config.pp_size, world_size);
    }

    /// Rank decomposition: dp is the slowest-varying dimension, pp the
    /// fastest (rank 5 in a 2x2x2 grid -> dp=1, mp=0, pp=1).
    #[test]
    fn test_rank_calculation() {
        let config = ParallelismConfig {
            dp_size: 2,
            mp_size: 2,
            pp_size: 2,
            ..Default::default()
        };

        let global_rank = 5;
        let _world_size = 8;

        let dp_rank = global_rank / (config.mp_size * config.pp_size);
        let mp_rank = (global_rank / config.pp_size) % config.mp_size;
        let pp_rank = global_rank % config.pp_size;

        assert_eq!(dp_rank, 1);
        assert_eq!(mp_rank, 0);
        assert_eq!(pp_rank, 1);
    }

    /// Construction succeeds when dp*mp*pp == world_size and rank 0 sits at
    /// the origin of the grid.
    #[test]
    fn test_3d_parallelism_creation() {
        let config = ParallelismConfig {
            dp_size: 2,
            mp_size: 1,
            pp_size: 1,
            ..Default::default()
        };

        let dp_group = Arc::new(SimulatedProcessGroup::new(0, 2));
        let mp_group = Arc::new(SimulatedProcessGroup::new(0, 1));
        let pp_group = Arc::new(SimulatedProcessGroup::new(0, 1));

        let parallelism = Parallelism3D::new(config, 0, 2, dp_group, mp_group, pp_group);

        assert!(parallelism.is_ok());
        let p = parallelism.expect("operation failed in test");
        assert_eq!(p.dp_rank, 0);
        assert_eq!(p.mp_rank, 0);
        assert_eq!(p.pp_rank, 0);
    }

    /// Serde round-trip for every memory-optimization level.
    /// BUG FIX: the old `assert!(matches!(x, _))` used a wildcard pattern
    /// that matches anything and could never fail; assert the round-trip
    /// reproduces the same JSON instead.
    #[test]
    fn test_memory_optimization_levels() {
        use MemoryOptimization::*;

        let levels = vec![None, Low, Medium, High, Extreme];

        for level in levels {
            let config = ParallelismConfig {
                memory_optimization: level,
                ..Default::default()
            };

            let json = serde_json::to_string(&config).expect("JSON serialization failed");
            let deserialized: ParallelismConfig =
                serde_json::from_str(&json).expect("JSON deserialization failed");
            let reserialized =
                serde_json::to_string(&deserialized).expect("JSON serialization failed");

            assert_eq!(json, reserialized);
        }
    }

    /// Serde round-trip for every pipeline schedule.
    /// BUG FIX: replaces another vacuous wildcard `matches!` assertion with
    /// a real round-trip equality check.
    #[test]
    fn test_pipeline_schedule_types() {
        use PipelineSchedule::*;

        let schedules = vec![GPipe, PipeDream, PipeDream2BW, Interleaved1F1B, Adaptive];

        for schedule in schedules {
            let config = ParallelismConfig {
                pipeline_schedule: schedule,
                ..Default::default()
            };

            let json = serde_json::to_string(&config).expect("JSON serialization failed");
            let deserialized: ParallelismConfig =
                serde_json::from_str(&json).expect("JSON deserialization failed");
            let reserialized =
                serde_json::to_string(&deserialized).expect("JSON serialization failed");

            assert_eq!(json, reserialized);
        }
    }

    /// An empty manager optimizes to the default config and reports zeroed
    /// aggregate statistics.
    #[test]
    fn test_3d_parallelism_manager() {
        let config = ParallelismConfig::default();
        let mut manager = Parallelism3DManager::new(config);

        let optimized_config = manager.optimize_configuration();
        assert!(optimized_config.is_ok());

        let stats = manager.get_aggregate_stats();
        assert!(stats.is_ok());

        let stats = stats.expect("operation failed in test");
        assert_eq!(stats.num_coordinators, 0);
        assert_eq!(stats.average_pipeline_efficiency, 0.0);
    }

    /// Field-by-field serde round-trip for a fully-specified config.
    #[test]
    fn test_config_serialization() {
        let config = ParallelismConfig {
            dp_size: 4,
            mp_size: 2,
            pp_size: 8,
            num_micro_batches: 16,
            gradient_accumulation: true,
            accumulation_steps: 4,
            activation_checkpointing: true,
            comm_backend: CommBackend::NCCL,
            pipeline_schedule: PipelineSchedule::Interleaved1F1B,
            memory_optimization: MemoryOptimization::High,
        };

        let json = serde_json::to_string(&config).expect("JSON serialization failed");
        let deserialized: ParallelismConfig =
            serde_json::from_str(&json).expect("JSON deserialization failed");

        assert_eq!(config.dp_size, deserialized.dp_size);
        assert_eq!(config.mp_size, deserialized.mp_size);
        assert_eq!(config.pp_size, deserialized.pp_size);
        assert_eq!(config.num_micro_batches, deserialized.num_micro_batches);
        assert_eq!(
            config.gradient_accumulation,
            deserialized.gradient_accumulation
        );
        assert_eq!(config.accumulation_steps, deserialized.accumulation_steps);
        assert_eq!(
            config.activation_checkpointing,
            deserialized.activation_checkpointing
        );
    }
}