//! FairScale integration for torsh-distributed.
//!
//! Provides configuration types mirroring FairScale's FSDP, OSS (optimizer
//! state sharding), sharded gradient scaling, activation checkpointing,
//! pipeline parallelism, and memory-optimization features, plus helpers for
//! converting them into the native Torsh FSDP and pipeline configurations.

use crate::fsdp::{AutoWrapPolicy, MemoryConfig};
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::path::Path;
use torsh_core::DType;

/// Top-level FairScale integration configuration.
///
/// Every sub-configuration is optional; leaving a field as `None` disables the
/// corresponding feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleConfig {
    /// Fully Sharded Data Parallel (FSDP) settings.
    pub fsdp: Option<FairScaleFsdpConfig>,
    /// Optimizer State Sharding (OSS) settings.
    pub oss: Option<FairScaleOssConfig>,
    /// Sharded gradient scaler settings for mixed-precision training.
    pub sharded_grad_scaler: Option<FairScaleGradScalerConfig>,
    /// Activation checkpointing settings.
    pub activation_checkpointing: Option<FairScaleActivationCheckpointingConfig>,
    /// Pipeline parallelism settings.
    pub pipeline_parallelism: Option<FairScalePipelineConfig>,
    /// Additional memory optimization settings.
    pub memory_optimization: Option<FairScaleMemoryOptimizationConfig>,
}

/// FairScale-style FSDP (Fully Sharded Data Parallel) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleFsdpConfig {
    pub auto_wrap_policy: Option<FairScaleAutoWrapPolicy>,
    pub min_num_params: Option<u64>,
    pub wrapper_cls: Option<String>,
    pub mixed_precision: Option<bool>,
    pub flatten_parameters: Option<bool>,
    pub bucket_cap_mb: Option<f32>,
    pub compute_dtype: Option<String>,
    pub buffer_dtype: Option<String>,
    pub reshard_after_forward: Option<bool>,
    pub move_grads_to_cpu: Option<bool>,
    pub move_params_to_cpu: Option<bool>,
}

/// Policy controlling how modules are automatically wrapped for FSDP.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleAutoWrapPolicy {
    None,
    SizeBased,
    TransformerBased,
    CustomFunction,
}

/// FairScale OSS (Optimizer State Sharding) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleOssConfig {
    pub optimizer: String,
    pub broadcast_buffers: Option<bool>,
    pub compress_gradients: Option<bool>,
    pub gradient_compression: Option<String>,
    pub partition_optimizer: Option<bool>,
    pub gradient_predivide_factor: Option<f32>,
    pub gradient_postdivide_factor: Option<f32>,
}

/// Sharded gradient scaler settings for mixed-precision training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleGradScalerConfig {
    pub init_scale: Option<f32>,
    pub growth_factor: Option<f32>,
    pub backoff_factor: Option<f32>,
    pub growth_interval: Option<u32>,
    pub enabled: Option<bool>,
}

/// Activation checkpointing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleActivationCheckpointingConfig {
    pub strategy: FairScaleCheckpointingStrategy,
    pub checkpoint_ratio: Option<f32>,
    pub offload_to_cpu: Option<bool>,
    pub checkpoint_every_n_layers: Option<u32>,
    pub use_gradient_checkpointing: Option<bool>,
}

/// Strategy for selecting which activations to checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleCheckpointingStrategy {
    None,
    Uniform,
    Selective,
    Adaptive,
}

/// Pipeline parallelism settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScalePipelineConfig {
    pub stages: u32,
    pub micro_batch_size: Option<u32>,
    pub balance_mode: Option<FairScaleBalanceMode>,
    pub schedule: Option<FairScalePipelineSchedule>,
    pub checkpoint_activation: Option<bool>,
    pub distributed_backend: Option<String>,
}

/// How pipeline stages are balanced across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScaleBalanceMode {
    Auto,
    Manual,
    Parameters,
    Time,
}

/// Pipeline execution schedule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FairScalePipelineSchedule {
    GPipe,
    OneF1B,
    Interleaved,
}

/// Additional memory optimization toggles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FairScaleMemoryOptimizationConfig {
    pub cpu_offloading: Option<bool>,
    pub parameter_offloading: Option<bool>,
    pub optimizer_offloading: Option<bool>,
    pub gradient_compression: Option<bool>,
    pub memory_defragmentation: Option<bool>,
    pub lazy_parameter_init: Option<bool>,
}

/// Runtime statistics collected by [`FairScaleIntegration`].
#[derive(Debug, Clone, Default)]
pub struct FairScaleStats {
    pub fsdp_ops: u64,
    pub fsdp_time_sec: f64,
    pub oss_ops: u64,
    pub oss_time_sec: f64,
    pub memory_saved_bytes: u64,
    pub checkpointed_activations: u64,
    pub pipeline_efficiency: f64,
    pub gradient_scaling_events: u64,
    pub average_shard_size: f64,
}

/// Coordinates FairScale-style distributed training features and tracks their
/// runtime statistics for the current process.
pub struct FairScaleIntegration {
    config: FairScaleConfig,
    stats: FairScaleStats,
    initialized: bool,
    rank: u32,
    world_size: u32,
    local_rank: u32,
    local_size: u32,
}

impl FairScaleIntegration {
    /// Creates a new, uninitialized integration from the given configuration.
    pub fn new(config: FairScaleConfig) -> Self {
        Self {
            config,
            stats: FairScaleStats::default(),
            initialized: false,
            rank: 0,
            world_size: 1,
            local_rank: 0,
            local_size: 1,
        }
    }

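    /// Loads a [`FairScaleConfig`] from a JSON file and wraps it in a new,
    /// uninitialized integration.
    ///
    /// Illustrative sketch only (not compiled as a doc test); the file name and
    /// JSON layout below are assumptions based on this module's serde derives:
    ///
    /// ```ignore
    /// // fairscale.json might contain, for example:
    /// // { "fsdp": { "min_num_params": 100000, "mixed_precision": true },
    /// //   "oss": { "optimizer": "AdamW" } }
    /// let integration = FairScaleIntegration::from_file("fairscale.json")?;
    /// assert!(!integration.is_initialized());
    /// ```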
    pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
        let content = std::fs::read_to_string(path).map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to read FairScale config file: {}",
                e
            ))
        })?;

        let config: FairScaleConfig = serde_json::from_str(&content).map_err(|e| {
            TorshDistributedError::configuration_error(format!(
                "Failed to parse FairScale config: {}",
                e
            ))
        })?;

        Ok(Self::new(config))
    }

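    /// Validates the configuration and sets up every configured FairScale
    /// feature for the given process topology.
    ///
    /// Illustrative sketch only (not compiled as a doc test); the rank and
    /// world-size values are arbitrary:
    ///
    /// ```ignore
    /// let mut integration = FairScaleIntegration::new(FairScaleIntegration::default_config());
    /// // rank 0 of 4 processes overall, local rank 0 of 2 on this node
    /// integration.initialize(0, 4, 0, 2)?;
    /// assert!(integration.is_initialized());
    /// ```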
    pub fn initialize(
        &mut self,
        rank: u32,
        world_size: u32,
        local_rank: u32,
        local_size: u32,
    ) -> TorshResult<()> {
        if self.initialized {
            return Err(TorshDistributedError::configuration_error(
                "FairScale integration already initialized",
            ));
        }

        self.rank = rank;
        self.world_size = world_size;
        self.local_rank = local_rank;
        self.local_size = local_size;

        self.validate_config()?;
        self.setup_fsdp()?;
        self.setup_oss()?;
        self.setup_grad_scaler()?;
        self.setup_activation_checkpointing()?;
        self.setup_pipeline_parallelism()?;
        self.setup_memory_optimization()?;

        self.initialized = true;
        tracing::info!(
            "FairScale integration initialized - rank: {}, world_size: {}, local_rank: {}",
            self.rank,
            self.world_size,
            self.local_rank
        );

        Ok(())
    }

    fn validate_config(&self) -> TorshResult<()> {
        if let Some(ref fsdp) = self.config.fsdp {
            if let Some(min_params) = fsdp.min_num_params {
                if min_params == 0 {
                    return Err(TorshDistributedError::configuration_error(
                        "FSDP min_num_params must be greater than 0",
                    ));
                }
            }

            if let Some(bucket_cap) = fsdp.bucket_cap_mb {
                if bucket_cap <= 0.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "FSDP bucket_cap_mb must be greater than 0",
                    ));
                }
            }
        }

        if let Some(ref pipeline) = self.config.pipeline_parallelism {
            if pipeline.stages == 0 {
                return Err(TorshDistributedError::configuration_error(
                    "Pipeline stages must be greater than 0",
                ));
            }

            if pipeline.stages > self.world_size {
                return Err(TorshDistributedError::configuration_error(
                    "Pipeline stages cannot exceed world size",
                ));
            }

            if let Some(micro_batch_size) = pipeline.micro_batch_size {
                if micro_batch_size == 0 {
                    return Err(TorshDistributedError::configuration_error(
                        "Pipeline micro_batch_size must be greater than 0",
                    ));
                }
            }
        }

        if let Some(ref grad_scaler) = self.config.sharded_grad_scaler {
            if let Some(init_scale) = grad_scaler.init_scale {
                if init_scale <= 0.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "GradScaler init_scale must be greater than 0",
                    ));
                }
            }

            if let Some(growth_factor) = grad_scaler.growth_factor {
                if growth_factor <= 1.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "GradScaler growth_factor must be greater than 1",
                    ));
                }
            }

            if let Some(backoff_factor) = grad_scaler.backoff_factor {
                if backoff_factor <= 0.0 || backoff_factor >= 1.0 {
                    return Err(TorshDistributedError::configuration_error(
                        "GradScaler backoff_factor must be between 0 and 1",
                    ));
                }
            }
        }

        Ok(())
    }

    fn setup_fsdp(&self) -> TorshResult<()> {
        if let Some(ref fsdp) = self.config.fsdp {
            tracing::info!("Setting up FairScale FSDP");

            let auto_wrap_policy = fsdp
                .auto_wrap_policy
                .unwrap_or(FairScaleAutoWrapPolicy::SizeBased);
            tracing::debug!("FSDP auto-wrap policy: {:?}", auto_wrap_policy);

            let min_params = fsdp.min_num_params.unwrap_or(100_000);
            tracing::debug!("FSDP min parameters for auto-wrap: {}", min_params);

            let mixed_precision = fsdp.mixed_precision.unwrap_or(false);
            tracing::debug!("FSDP mixed precision: {}", mixed_precision);

            let flatten_parameters = fsdp.flatten_parameters.unwrap_or(true);
            tracing::debug!("FSDP flatten parameters: {}", flatten_parameters);

            let bucket_cap_mb = fsdp.bucket_cap_mb.unwrap_or(25.0);
            tracing::debug!("FSDP bucket capacity: {} MB", bucket_cap_mb);

            let reshard_after_forward = fsdp.reshard_after_forward.unwrap_or(true);
            tracing::debug!("FSDP reshard after forward: {}", reshard_after_forward);

            if fsdp.move_grads_to_cpu.unwrap_or(false) {
                tracing::debug!("FSDP gradient CPU offloading enabled");
            }

            if fsdp.move_params_to_cpu.unwrap_or(false) {
                tracing::debug!("FSDP parameter CPU offloading enabled");
            }
        }
        Ok(())
    }

    fn setup_oss(&self) -> TorshResult<()> {
        if let Some(ref oss) = self.config.oss {
            tracing::info!("Setting up FairScale OSS (Optimizer State Sharding)");

            tracing::debug!("OSS optimizer: {}", oss.optimizer);

            let broadcast_buffers = oss.broadcast_buffers.unwrap_or(true);
            tracing::debug!("OSS broadcast buffers: {}", broadcast_buffers);

            let compress_gradients = oss.compress_gradients.unwrap_or(false);
            tracing::debug!("OSS compress gradients: {}", compress_gradients);

            if let Some(ref compression) = oss.gradient_compression {
                tracing::debug!("OSS gradient compression algorithm: {}", compression);
            }

            let partition_optimizer = oss.partition_optimizer.unwrap_or(true);
            tracing::debug!("OSS partition optimizer: {}", partition_optimizer);

            let predivide_factor = oss.gradient_predivide_factor.unwrap_or(1.0);
            tracing::debug!("OSS gradient predivide factor: {}", predivide_factor);

            let postdivide_factor = oss.gradient_postdivide_factor.unwrap_or(1.0);
            tracing::debug!("OSS gradient postdivide factor: {}", postdivide_factor);
        }
        Ok(())
    }

    fn setup_grad_scaler(&self) -> TorshResult<()> {
        if let Some(ref grad_scaler) = self.config.sharded_grad_scaler {
            tracing::info!("Setting up FairScale ShardedGradScaler");

            let init_scale = grad_scaler.init_scale.unwrap_or(2.0_f32.powi(16));
            tracing::debug!("GradScaler initial scale: {}", init_scale);

            let growth_factor = grad_scaler.growth_factor.unwrap_or(2.0);
            tracing::debug!("GradScaler growth factor: {}", growth_factor);

            let backoff_factor = grad_scaler.backoff_factor.unwrap_or(0.5);
            tracing::debug!("GradScaler backoff factor: {}", backoff_factor);

            let growth_interval = grad_scaler.growth_interval.unwrap_or(2000);
            tracing::debug!("GradScaler growth interval: {}", growth_interval);

            let enabled = grad_scaler.enabled.unwrap_or(true);
            tracing::debug!("GradScaler enabled: {}", enabled);
        }
        Ok(())
    }

    fn setup_activation_checkpointing(&self) -> TorshResult<()> {
        if let Some(ref checkpoint) = self.config.activation_checkpointing {
            tracing::info!("Setting up FairScale activation checkpointing");

            tracing::debug!(
                "Activation checkpointing strategy: {:?}",
                checkpoint.strategy
            );

            let checkpoint_ratio = checkpoint.checkpoint_ratio.unwrap_or(0.5);
            tracing::debug!("Activation checkpointing ratio: {}", checkpoint_ratio);

            let offload_to_cpu = checkpoint.offload_to_cpu.unwrap_or(false);
            tracing::debug!("Activation checkpointing CPU offload: {}", offload_to_cpu);

            if let Some(every_n) = checkpoint.checkpoint_every_n_layers {
                tracing::debug!("Checkpoint every {} layers", every_n);
            }

            let use_gradient_checkpointing = checkpoint.use_gradient_checkpointing.unwrap_or(false);
            tracing::debug!("Use gradient checkpointing: {}", use_gradient_checkpointing);
        }
        Ok(())
    }

    fn setup_pipeline_parallelism(&self) -> TorshResult<()> {
        if let Some(ref pipeline) = self.config.pipeline_parallelism {
            tracing::info!("Setting up FairScale pipeline parallelism");

            tracing::debug!("Pipeline stages: {}", pipeline.stages);

            let micro_batch_size = pipeline.micro_batch_size.unwrap_or(1);
            tracing::debug!("Pipeline micro batch size: {}", micro_batch_size);

            let balance_mode = pipeline.balance_mode.unwrap_or(FairScaleBalanceMode::Auto);
            tracing::debug!("Pipeline balance mode: {:?}", balance_mode);

            let schedule = pipeline
                .schedule
                .unwrap_or(FairScalePipelineSchedule::GPipe);
            tracing::debug!("Pipeline schedule: {:?}", schedule);

            let checkpoint_activation = pipeline.checkpoint_activation.unwrap_or(false);
            tracing::debug!("Pipeline checkpoint activation: {}", checkpoint_activation);

            if let Some(ref backend) = pipeline.distributed_backend {
                tracing::debug!("Pipeline distributed backend: {}", backend);
            }
        }
        Ok(())
    }

    fn setup_memory_optimization(&self) -> TorshResult<()> {
        if let Some(ref memory) = self.config.memory_optimization {
            tracing::info!("Setting up FairScale memory optimization");

            let cpu_offloading = memory.cpu_offloading.unwrap_or(false);
            tracing::debug!("Memory CPU offloading: {}", cpu_offloading);

            let parameter_offloading = memory.parameter_offloading.unwrap_or(false);
            tracing::debug!("Memory parameter offloading: {}", parameter_offloading);

            let optimizer_offloading = memory.optimizer_offloading.unwrap_or(false);
            tracing::debug!("Memory optimizer offloading: {}", optimizer_offloading);

            let gradient_compression = memory.gradient_compression.unwrap_or(false);
            tracing::debug!("Memory gradient compression: {}", gradient_compression);

            let memory_defragmentation = memory.memory_defragmentation.unwrap_or(false);
            tracing::debug!("Memory defragmentation: {}", memory_defragmentation);

            let lazy_parameter_init = memory.lazy_parameter_init.unwrap_or(false);
            tracing::debug!(
                "Memory lazy parameter initialization: {}",
                lazy_parameter_init
            );
        }
        Ok(())
    }

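    /// Translates the FairScale FSDP settings into the native [`crate::fsdp::FsdpConfig`].
    ///
    /// Sketch of the mapping (not compiled as a doc test); the expectations
    /// below follow the defaults produced by [`Self::default_config`]:
    ///
    /// ```ignore
    /// let integration = FairScaleIntegration::new(FairScaleIntegration::default_config());
    /// let fsdp = integration.to_fsdp_config()?;
    /// // reshard_after_forward = true maps to full sharding, and
    /// // move_params_to_cpu = false keeps CPU offload disabled.
    /// assert!(matches!(fsdp.sharding_strategy, crate::fsdp::ShardingStrategy::FullShard));
    /// assert!(!fsdp.cpu_offload);
    /// ```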
    pub fn to_fsdp_config(&self) -> TorshResult<crate::fsdp::FsdpConfig> {
        use crate::fsdp::{BackwardPrefetch, FsdpConfig, MixedPrecisionConfig, ShardingStrategy};

        let sharding_strategy = if let Some(ref fsdp) = self.config.fsdp {
            if fsdp.move_params_to_cpu.unwrap_or(false) {
                ShardingStrategy::NoShard
            } else if fsdp.reshard_after_forward.unwrap_or(true) {
                ShardingStrategy::FullShard
            } else {
                ShardingStrategy::ShardGradOp
            }
        } else {
            ShardingStrategy::FullShard
        };

        let mixed_precision = if let Some(ref fsdp) = self.config.fsdp {
            if fsdp.mixed_precision.unwrap_or(false) {
                Some(MixedPrecisionConfig {
                    param_dtype: DType::F16,
                    reduce_dtype: DType::F16,
                    buffer_dtype: DType::F16,
                    keep_low_precision_grads: false,
                })
            } else {
                None
            }
        } else {
            None
        };

        let config = FsdpConfig {
            min_num_params: 1000,
            auto_wrap_policy: AutoWrapPolicy::SizeBasedAutoWrap {
                min_num_params: 1000,
            },
            sharding_strategy,
            mixed_precision,
            cpu_offload: self
                .config
                .fsdp
                .as_ref()
                .map(|f| f.move_params_to_cpu.unwrap_or(false))
                .unwrap_or(false),
            memory_config: MemoryConfig {
                limit_all_gathers: true,
                use_orig_params: false,
                offload_to_cpu: self
                    .config
                    .fsdp
                    .as_ref()
                    .map(|f| f.move_params_to_cpu.unwrap_or(false))
                    .unwrap_or(false),
            },
            backward_prefetch: BackwardPrefetch::BackwardPre,
        };

        Ok(config)
    }

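    /// Translates the pipeline-parallelism settings, if any, into the native
    /// [`crate::pipeline::PipelineConfig`].
    ///
    /// Illustrative sketch only (not compiled as a doc test):
    ///
    /// ```ignore
    /// let integration =
    ///     FairScaleIntegration::new(FairScaleIntegration::config_with_pipeline_parallelism(4));
    /// // Returns None when no pipeline parallelism is configured.
    /// let pipeline = integration.to_pipeline_config()?;
    /// assert!(pipeline.is_some());
    /// ```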
    pub fn to_pipeline_config(&self) -> TorshResult<Option<crate::pipeline::PipelineConfig>> {
        if let Some(ref pipeline) = self.config.pipeline_parallelism {
            use crate::pipeline::{PipelineConfig, ScheduleType};

            let schedule_type = match pipeline
                .schedule
                .unwrap_or(FairScalePipelineSchedule::GPipe)
            {
                FairScalePipelineSchedule::GPipe => ScheduleType::GPipe,
                FairScalePipelineSchedule::OneF1B => ScheduleType::OneFOneBInterleaved,
                FairScalePipelineSchedule::Interleaved => ScheduleType::InterleavedOneFOneB,
            };

            let config = PipelineConfig {
                num_micro_batches: pipeline.micro_batch_size.unwrap_or(1) as usize,
                schedule: schedule_type,
                accumulate_gradients: false,
                base_tag: 0,
                async_comm: true,
                comm_timeout_ms: 5000,
            };

            Ok(Some(config))
        } else {
            Ok(None)
        }
    }

    pub fn config(&self) -> &FairScaleConfig {
        &self.config
    }

    pub fn stats(&self) -> &FairScaleStats {
        &self.stats
    }

    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    pub fn rank(&self) -> u32 {
        self.rank
    }

    pub fn world_size(&self) -> u32 {
        self.world_size
    }

    pub fn local_rank(&self) -> u32 {
        self.local_rank
    }

    pub fn local_size(&self) -> u32 {
        self.local_size
    }

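    /// Records an FSDP operation in the integration statistics, estimating the
    /// per-rank shard size and the memory saved by sharding.
    ///
    /// Illustrative sketch only (not compiled as a doc test); the operation name
    /// and parameter count are arbitrary:
    ///
    /// ```ignore
    /// integration.initialize(0, 4, 0, 2)?;
    /// integration.fsdp_operation("forward", 1_000_000)?;
    /// // With 4 ranks, each shard holds roughly 250_000 parameters.
    /// assert_eq!(integration.stats().average_shard_size, 250_000.0);
    /// ```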
    pub fn fsdp_operation(
        &mut self,
        operation_name: &str,
        parameter_count: usize,
    ) -> TorshResult<()> {
        if !self.initialized {
            return Err(TorshDistributedError::BackendNotInitialized);
        }

        let start_time = std::time::Instant::now();

        tracing::debug!(
            "FairScale FSDP operation: {} ({} params)",
            operation_name,
            parameter_count
        );

        self.stats.fsdp_ops += 1;
        self.stats.fsdp_time_sec += start_time.elapsed().as_secs_f64();

        let shard_size = parameter_count / self.world_size as usize;
        let memory_saved = parameter_count - shard_size;
        // Memory estimate assumes 4 bytes (f32) per parameter.
        self.stats.memory_saved_bytes += (memory_saved * 4) as u64;
        self.stats.average_shard_size = shard_size as f64;

        Ok(())
    }

    /// Records an OSS operation in the integration statistics.
    pub fn oss_operation(
        &mut self,
        operation_name: &str,
        optimizer_state_size: usize,
    ) -> TorshResult<()> {
        if !self.initialized {
            return Err(TorshDistributedError::BackendNotInitialized);
        }

        let start_time = std::time::Instant::now();

        tracing::debug!(
            "FairScale OSS operation: {} ({} bytes)",
            operation_name,
            optimizer_state_size
        );

        self.stats.oss_ops += 1;
        self.stats.oss_time_sec += start_time.elapsed().as_secs_f64();

        Ok(())
    }

    /// Records an activation checkpoint event if activation checkpointing is configured.
    pub fn record_activation_checkpoint(&mut self, layer_name: &str, memory_saved: usize) {
        if self.config.activation_checkpointing.is_some() {
            tracing::debug!(
                "Activation checkpoint: {} (saved {} bytes)",
                layer_name,
                memory_saved
            );
            self.stats.checkpointed_activations += 1;
            self.stats.memory_saved_bytes += memory_saved as u64;
        }
    }

    /// Records a gradient scaling event if a sharded gradient scaler is configured.
    pub fn record_gradient_scaling_event(&mut self, scale_factor: f32) {
        if self.config.sharded_grad_scaler.is_some() {
            tracing::debug!("Gradient scaling event: scale factor {}", scale_factor);
            self.stats.gradient_scaling_events += 1;
        }
    }

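    /// Builds a conservative default configuration: size-based FSDP auto-wrapping
    /// and AdamW-based OSS, with mixed precision, gradient scaling, activation
    /// checkpointing, and pipeline parallelism left disabled.
    ///
    /// Illustrative usage (not compiled as a doc test):
    ///
    /// ```ignore
    /// let config = FairScaleIntegration::default_config();
    /// assert!(config.fsdp.is_some());
    /// assert!(config.pipeline_parallelism.is_none());
    /// ```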
    pub fn default_config() -> FairScaleConfig {
        FairScaleConfig {
            fsdp: Some(FairScaleFsdpConfig {
                auto_wrap_policy: Some(FairScaleAutoWrapPolicy::SizeBased),
                min_num_params: Some(100_000),
                wrapper_cls: None,
                mixed_precision: Some(false),
                flatten_parameters: Some(true),
                bucket_cap_mb: Some(25.0),
                compute_dtype: Some("float32".to_string()),
                buffer_dtype: Some("float32".to_string()),
                reshard_after_forward: Some(true),
                move_grads_to_cpu: Some(false),
                move_params_to_cpu: Some(false),
            }),
            oss: Some(FairScaleOssConfig {
                optimizer: "AdamW".to_string(),
                broadcast_buffers: Some(true),
                compress_gradients: Some(false),
                gradient_compression: None,
                partition_optimizer: Some(true),
                gradient_predivide_factor: Some(1.0),
                gradient_postdivide_factor: Some(1.0),
            }),
            sharded_grad_scaler: None,
            activation_checkpointing: None,
            pipeline_parallelism: None,
            memory_optimization: Some(FairScaleMemoryOptimizationConfig {
                cpu_offloading: Some(false),
                parameter_offloading: Some(false),
                optimizer_offloading: Some(false),
                gradient_compression: Some(false),
                memory_defragmentation: Some(false),
                lazy_parameter_init: Some(false),
            }),
        }
    }

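    /// Like [`Self::default_config`], but with FP16 mixed precision enabled for
    /// FSDP and a sharded gradient scaler configured to match.
    ///
    /// Illustrative usage (not compiled as a doc test):
    ///
    /// ```ignore
    /// let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
    /// assert_eq!(config.fsdp.as_ref().and_then(|f| f.mixed_precision), Some(true));
    /// assert!(config.sharded_grad_scaler.is_some());
    /// ```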
    pub fn config_with_fsdp_mixed_precision() -> FairScaleConfig {
        let mut config = Self::default_config();

        if let Some(ref mut fsdp) = config.fsdp {
            fsdp.mixed_precision = Some(true);
            fsdp.compute_dtype = Some("float16".to_string());
            fsdp.buffer_dtype = Some("float16".to_string());
        }

        config.sharded_grad_scaler = Some(FairScaleGradScalerConfig {
            init_scale: Some(2.0_f32.powi(16)),
            growth_factor: Some(2.0),
            backoff_factor: Some(0.5),
            growth_interval: Some(2000),
            enabled: Some(true),
        });

        config
    }

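    /// Like [`Self::default_config`], but with pipeline parallelism over `stages`
    /// stages (1F1B schedule) and uniform activation checkpointing enabled.
    ///
    /// Illustrative usage (not compiled as a doc test); the stage count must not
    /// exceed the world size later passed to [`Self::initialize`]:
    ///
    /// ```ignore
    /// let config = FairScaleIntegration::config_with_pipeline_parallelism(4);
    /// assert_eq!(config.pipeline_parallelism.as_ref().map(|p| p.stages), Some(4));
    /// assert!(config.activation_checkpointing.is_some());
    /// ```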
    pub fn config_with_pipeline_parallelism(stages: u32) -> FairScaleConfig {
        let mut config = Self::default_config();

        config.pipeline_parallelism = Some(FairScalePipelineConfig {
            stages,
            micro_batch_size: Some(1),
            balance_mode: Some(FairScaleBalanceMode::Auto),
            schedule: Some(FairScalePipelineSchedule::OneF1B),
            checkpoint_activation: Some(true),
            distributed_backend: Some("nccl".to_string()),
        });

        config.activation_checkpointing = Some(FairScaleActivationCheckpointingConfig {
            strategy: FairScaleCheckpointingStrategy::Uniform,
            checkpoint_ratio: Some(0.5),
            offload_to_cpu: Some(false),
            checkpoint_every_n_layers: Some(4),
            use_gradient_checkpointing: Some(true),
        });

        config
    }
}

impl Default for FairScaleConfig {
    fn default() -> Self {
        FairScaleIntegration::default_config()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fairscale_config_validation() {
        let config = FairScaleIntegration::default_config();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_initialized());
        assert_eq!(integration.rank(), 0);
        assert_eq!(integration.world_size(), 4);
        assert_eq!(integration.local_rank(), 0);
    }

    #[test]
    fn test_fairscale_fsdp_config_conversion() {
        let config = FairScaleIntegration::default_config();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        let fsdp_config = integration.to_fsdp_config().unwrap();
        assert!(matches!(
            fsdp_config.sharding_strategy,
            crate::fsdp::ShardingStrategy::FullShard
        ));
        assert!(!fsdp_config.cpu_offload);
        assert_eq!(fsdp_config.min_num_params, 1000);
    }

    #[test]
    fn test_fairscale_pipeline_config_conversion() {
        let config = FairScaleIntegration::config_with_pipeline_parallelism(4);
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        let pipeline_config = integration.to_pipeline_config().unwrap();
        assert!(pipeline_config.is_some());

        if let Some(config) = pipeline_config {
            assert_eq!(config.num_micro_batches, 1);
            assert!(matches!(
                config.schedule,
                crate::pipeline::ScheduleType::OneFOneBInterleaved
            ));
        }
    }

    #[test]
    fn test_fairscale_fsdp_operations() {
        let config = FairScaleIntegration::default_config();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.fsdp_operation("forward", 1000000).is_ok());
        assert!(integration.fsdp_operation("backward", 1000000).is_ok());

        let stats = integration.stats();
        assert_eq!(stats.fsdp_ops, 2);
        assert!(stats.fsdp_time_sec >= 0.0);
        assert!(stats.memory_saved_bytes > 0);
        // 1_000_000 parameters sharded across 4 ranks.
        assert_eq!(stats.average_shard_size, 250000.0);
    }

    #[test]
    fn test_fairscale_oss_operations() {
        let config = FairScaleIntegration::default_config();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.oss_operation("step", 1024).is_ok());
        assert!(integration.oss_operation("zero_grad", 1024).is_ok());

        let stats = integration.stats();
        assert_eq!(stats.oss_ops, 2);
        assert!(stats.oss_time_sec >= 0.0);
    }

    #[test]
    fn test_fairscale_mixed_precision_config() {
        let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        let fsdp_config = integration.to_fsdp_config().unwrap();
        assert!(fsdp_config.mixed_precision.is_some());

        if let Some(mp_config) = fsdp_config.mixed_precision {
            assert_eq!(mp_config.param_dtype, DType::F16);
            assert_eq!(mp_config.reduce_dtype, DType::F16);
            assert_eq!(mp_config.buffer_dtype, DType::F16);
        }
    }

    #[test]
    fn test_fairscale_invalid_pipeline_stages() {
        // Zero stages is invalid and should be rejected during initialization.
        let config = FairScaleIntegration::config_with_pipeline_parallelism(0);
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_err());
    }

    #[test]
    fn test_fairscale_config_serialization() {
        let config = FairScaleIntegration::config_with_fsdp_mixed_precision();

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("float16"));
        assert!(json.contains("fsdp"));
        assert!(json.contains("sharded_grad_scaler"));

        let deserialized: FairScaleConfig = serde_json::from_str(&json).unwrap();
        assert!(deserialized.fsdp.is_some());
        assert!(deserialized.sharded_grad_scaler.is_some());
    }

    #[test]
    fn test_fairscale_activation_checkpointing() {
        let config = FairScaleIntegration::config_with_pipeline_parallelism(4);
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        integration.record_activation_checkpoint("layer1", 1024);
        integration.record_activation_checkpoint("layer2", 2048);

        let stats = integration.stats();
        assert_eq!(stats.checkpointed_activations, 2);
        // At least the 1024 + 2048 bytes recorded above.
        assert!(stats.memory_saved_bytes >= 3072);
    }

    #[test]
    fn test_fairscale_gradient_scaling() {
        let config = FairScaleIntegration::config_with_fsdp_mixed_precision();
        let mut integration = FairScaleIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        integration.record_gradient_scaling_event(65536.0);
        integration.record_gradient_scaling_event(32768.0);

        let stats = integration.stats();
        assert_eq!(stats.gradient_scaling_events, 2);
    }
}