1use crate::{TorshDistributedError, TorshResult};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
/// DeepSpeed ZeRO (Zero Redundancy Optimizer) stage.
///
/// Discriminants match the integer `stage` values used in DeepSpeed-style
/// JSON configuration files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ZeroStage {
    /// No ZeRO optimization (plain data parallelism).
    Stage0 = 0,
    /// Optimizer state partitioning.
    Stage1 = 1,
    /// Optimizer state + gradient partitioning.
    Stage2 = 2,
    /// Optimizer state + gradient + parameter partitioning.
    Stage3 = 3,
}
23
/// Top-level configuration mirroring the DeepSpeed JSON config schema.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DeepSpeedConfig {
    /// ZeRO stage selection plus its tuning knobs.
    pub zero_optimization: ZeroOptimizationConfig,
    // Gradient clipping threshold; presumably a max-norm value — TODO confirm
    // against the consumer of this field.
    pub gradient_clipping: Option<f32>,
    // Number of micro-batches accumulated before an optimizer step.
    pub gradient_accumulation_steps: Option<u32>,
    /// FP16 mixed-precision settings; `None` disables fp16 handling entirely.
    pub fp16: Option<FP16Config>,
    // Mirrors DeepSpeed's flag of the same name; `get_stats` reports this as
    // "cpu_offload_enabled".
    pub zero_force_ds_cpu_optimizer: Option<bool>,
    /// Activation checkpointing settings; `None` leaves it unconfigured.
    pub activation_checkpointing: Option<ActivationCheckpointingConfig>,
}
40
/// ZeRO optimizer settings. All tuning fields are optional; stage
/// initializers substitute built-in defaults for `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroOptimizationConfig {
    /// Which ZeRO stage to run.
    pub stage: ZeroStage,
    /// Allgather communication bucket size in bytes (Stage 2).
    pub allgather_bucket_size: Option<u64>,
    /// Reduce communication bucket size in bytes (Stages 1/2).
    pub reduce_bucket_size: Option<u64>,
    /// Overlap communication with computation (Stage 2 reads this).
    pub overlap_comm: Option<bool>,
    // Mirrors DeepSpeed's contiguous_gradients flag; not read by the
    // initializers in this file.
    pub contiguous_gradients: Option<bool>,
    // Mirrors DeepSpeed's sub_group_size; not read in this file.
    pub sub_group_size: Option<u32>,
    /// When true, `to_gradient_compression_config` emits a compression config.
    pub reduce_scatter: Option<bool>,
    // Mirrors DeepSpeed's allgather_partitions; not read in this file.
    pub allgather_partitions: Option<bool>,
    /// Stage 3: cap on simultaneously live parameters. Required by
    /// `validate_config` when `stage` is `Stage3`.
    pub stage3_max_live_parameters: Option<u64>,
    /// Stage 3: maximum parameter reuse distance.
    pub stage3_max_reuse_distance: Option<u64>,
    /// Stage 3: prefetch bucket size in bytes.
    pub stage3_prefetch_bucket_size: Option<u64>,
    // Stage 3 persistence threshold; not read by the initializers here.
    pub stage3_param_persistence_threshold: Option<u64>,
    /// Optimizer state offload target; device string must be non-empty.
    pub offload_optimizer: Option<OffloadOptimizerConfig>,
    /// Parameter offload target; device string must be non-empty.
    pub offload_param: Option<OffloadParamConfig>,
}
73
/// FP16 mixed-precision settings, mirroring DeepSpeed's `fp16` JSON section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FP16Config {
    /// Master switch; when false the whole section is treated as disabled.
    pub enabled: bool,
    // Fixed loss scale; None presumably means dynamic scaling — TODO confirm.
    pub loss_scale: Option<f32>,
    // Window (in steps) for dynamic loss-scale adjustment.
    pub loss_scale_window: Option<u32>,
    // Hysteresis for loss-scale decreases.
    pub hysteresis: Option<u32>,
    // Lower bound for the dynamic loss scale.
    pub min_loss_scale: Option<f32>,
    // Initial loss scale expressed as a power of two.
    pub initial_scale_power: Option<u32>,
}
90
/// Activation checkpointing settings, mirroring DeepSpeed's
/// `activation_checkpointing` JSON section. Only `partition_activations`
/// is read in this file (by `get_stats`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationCheckpointingConfig {
    /// Partition checkpointed activations across ranks; drives the
    /// `activation_checkpointing_enabled` stat.
    pub partition_activations: Option<bool>,
    // Offload checkpointed activations to CPU.
    pub cpu_checkpointing: Option<bool>,
    // Use contiguous memory buffers for checkpoints.
    pub contiguous_memory_optimization: Option<bool>,
    // Number of checkpoints to keep.
    pub number_checkpoints: Option<u32>,
    // Synchronize at checkpoint boundaries.
    pub synchronize_checkpoint_boundary: Option<bool>,
    // Enable checkpointing profiling output.
    pub profile: Option<bool>,
}
107
/// Optimizer state offload target, mirroring DeepSpeed's
/// `offload_optimizer` JSON section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OffloadOptimizerConfig {
    /// Target device name (e.g. "cpu"); validated to be non-empty.
    pub device: String,
    // Filesystem path for NVMe offload; presumably only used when the
    // device targets NVMe — TODO confirm.
    pub nvme_path: Option<String>,
    // Pin host memory for faster transfers.
    pub pin_memory: Option<bool>,
    // Number of offload buffers.
    pub buffer_count: Option<u32>,
    // Fast initialization toggle.
    pub fast_init: Option<bool>,
}
122
/// Parameter offload target, mirroring DeepSpeed's `offload_param`
/// JSON section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OffloadParamConfig {
    /// Target device name (e.g. "cpu"); validated to be non-empty.
    pub device: String,
    // Filesystem path for NVMe offload; presumably only relevant for NVMe
    // devices — TODO confirm.
    pub nvme_path: Option<String>,
    // Pin host memory for faster transfers.
    pub pin_memory: Option<bool>,
    // Number of offload buffers.
    pub buffer_count: Option<u32>,
    // Maximum bytes of parameters kept in CPU memory.
    pub max_in_cpu: Option<u64>,
}
137
/// Bridges a DeepSpeed-style configuration into this crate's distributed
/// primitives (FSDP config, gradient compression config).
pub struct DeepSpeedIntegration {
    // Parsed configuration driving validation, stage setup, and conversions.
    config: DeepSpeedConfig,
    // Set by `initialize` on success; makes repeated initialization a no-op.
    initialized: bool,
}
143
144impl DeepSpeedIntegration {
145 pub fn new(config: DeepSpeedConfig) -> Self {
147 Self {
148 config,
149 initialized: false,
150 }
151 }
152
153 pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
155 let content = std::fs::read_to_string(path).map_err(|e| {
156 TorshDistributedError::configuration_error(format!(
157 "Failed to read DeepSpeed config file: {}",
158 e
159 ))
160 })?;
161
162 let config: DeepSpeedConfig = serde_json::from_str(&content).map_err(|e| {
163 TorshDistributedError::configuration_error(format!(
164 "Failed to parse DeepSpeed config: {}",
165 e
166 ))
167 })?;
168
169 Ok(Self::new(config))
170 }
171
172 pub fn from_json(json: &str) -> TorshResult<Self> {
174 let config: DeepSpeedConfig = serde_json::from_str(json).map_err(|e| {
175 TorshDistributedError::configuration_error(format!(
176 "Failed to parse DeepSpeed config: {}",
177 e
178 ))
179 })?;
180
181 Ok(Self::new(config))
182 }
183
184 pub fn initialize(&mut self) -> TorshResult<()> {
186 if self.initialized {
187 return Ok(());
188 }
189
190 self.validate_config()?;
192
193 match self.config.zero_optimization.stage {
195 ZeroStage::Stage0 => {
196 tracing::info!(
198 "DeepSpeed integration initialized with ZeRO Stage 0 (no optimization)"
199 );
200 }
201 ZeroStage::Stage1 => {
202 self.initialize_zero_stage1()?;
204 }
205 ZeroStage::Stage2 => {
206 self.initialize_zero_stage2()?;
208 }
209 ZeroStage::Stage3 => {
210 self.initialize_zero_stage3()?;
212 }
213 }
214
215 self.initialized = true;
216 Ok(())
217 }
218
219 fn validate_config(&self) -> TorshResult<()> {
221 if matches!(self.config.zero_optimization.stage, ZeroStage::Stage3)
223 && self
224 .config
225 .zero_optimization
226 .stage3_max_live_parameters
227 .is_none()
228 {
229 return Err(TorshDistributedError::configuration_error(
230 "ZeRO Stage 3 requires stage3_max_live_parameters to be set",
231 ));
232 }
233
234 if let Some(ref offload_config) = self.config.zero_optimization.offload_optimizer {
236 if offload_config.device.is_empty() {
237 return Err(TorshDistributedError::configuration_error(
238 "Offload optimizer device cannot be empty",
239 ));
240 }
241 }
242
243 if let Some(ref offload_config) = self.config.zero_optimization.offload_param {
244 if offload_config.device.is_empty() {
245 return Err(TorshDistributedError::configuration_error(
246 "Offload parameter device cannot be empty",
247 ));
248 }
249 }
250
251 Ok(())
252 }
253
254 fn initialize_zero_stage1(&self) -> TorshResult<()> {
256 tracing::info!("Initializing DeepSpeed ZeRO Stage 1 (optimizer state partitioning)");
257
258 let bucket_size = self
260 .config
261 .zero_optimization
262 .reduce_bucket_size
263 .unwrap_or(2e8 as u64);
264 tracing::debug!("ZeRO Stage 1 - Reduce bucket size: {}", bucket_size);
265
266 Ok(())
267 }
268
269 fn initialize_zero_stage2(&self) -> TorshResult<()> {
271 tracing::info!("Initializing DeepSpeed ZeRO Stage 2 (gradient partitioning)");
272
273 let allgather_bucket_size = self
275 .config
276 .zero_optimization
277 .allgather_bucket_size
278 .unwrap_or(2e8 as u64);
279 let reduce_bucket_size = self
280 .config
281 .zero_optimization
282 .reduce_bucket_size
283 .unwrap_or(2e8 as u64);
284 let overlap_comm = self.config.zero_optimization.overlap_comm.unwrap_or(true);
285
286 tracing::debug!(
287 "ZeRO Stage 2 - Allgather bucket size: {}",
288 allgather_bucket_size
289 );
290 tracing::debug!("ZeRO Stage 2 - Reduce bucket size: {}", reduce_bucket_size);
291 tracing::debug!("ZeRO Stage 2 - Overlap communication: {}", overlap_comm);
292
293 Ok(())
294 }
295
296 fn initialize_zero_stage3(&self) -> TorshResult<()> {
298 tracing::info!("Initializing DeepSpeed ZeRO Stage 3 (parameter partitioning)");
299
300 let max_live_params = self
302 .config
303 .zero_optimization
304 .stage3_max_live_parameters
305 .unwrap_or(1e9 as u64);
306 let max_reuse_distance = self
307 .config
308 .zero_optimization
309 .stage3_max_reuse_distance
310 .unwrap_or(1000);
311 let prefetch_bucket_size = self
312 .config
313 .zero_optimization
314 .stage3_prefetch_bucket_size
315 .unwrap_or(5e8 as u64);
316
317 tracing::debug!("ZeRO Stage 3 - Max live parameters: {}", max_live_params);
318 tracing::debug!("ZeRO Stage 3 - Max reuse distance: {}", max_reuse_distance);
319 tracing::debug!(
320 "ZeRO Stage 3 - Prefetch bucket size: {}",
321 prefetch_bucket_size
322 );
323
324 Ok(())
325 }
326
327 pub fn config(&self) -> &DeepSpeedConfig {
329 &self.config
330 }
331
332 pub fn is_initialized(&self) -> bool {
334 self.initialized
335 }
336
337 pub fn to_fsdp_config(&self) -> TorshResult<crate::fsdp::FsdpConfig> {
339 use crate::fsdp::{FsdpConfig, MixedPrecisionConfig, ShardingStrategy};
340
341 let sharding_strategy = match self.config.zero_optimization.stage {
342 ZeroStage::Stage0 => ShardingStrategy::NoShard,
343 ZeroStage::Stage1 => ShardingStrategy::ShardGradOp,
344 ZeroStage::Stage2 => ShardingStrategy::ShardGradOp,
345 ZeroStage::Stage3 => ShardingStrategy::FullShard,
346 };
347
348 let mixed_precision = if let Some(ref fp16_config) = self.config.fp16 {
349 if fp16_config.enabled {
350 Some(MixedPrecisionConfig {
351 param_dtype: torsh_core::DType::F16,
352 reduce_dtype: torsh_core::DType::F16,
353 buffer_dtype: torsh_core::DType::F16,
354 keep_low_precision_grads: false,
355 })
356 } else {
357 None
358 }
359 } else {
360 None
361 };
362
363 Ok(FsdpConfig {
364 min_num_params: 1000,
365 auto_wrap_policy: crate::fsdp::AutoWrapPolicy::SizeBasedAutoWrap {
366 min_num_params: 1000,
367 },
368 sharding_strategy,
369 mixed_precision,
370 cpu_offload: self.config.zero_optimization.offload_optimizer.is_some()
371 || self.config.zero_optimization.offload_param.is_some(),
372 memory_config: crate::fsdp::MemoryConfig::default(),
373 backward_prefetch: crate::fsdp::BackwardPrefetch::BackwardPre,
374 })
375 }
376
377 pub fn to_gradient_compression_config(
379 &self,
380 ) -> Option<crate::gradient_compression::CompressionConfig> {
381 if self
383 .config
384 .zero_optimization
385 .reduce_scatter
386 .unwrap_or(false)
387 {
388 Some(crate::gradient_compression::CompressionConfig {
389 method: crate::gradient_compression::CompressionMethod::TopK { k: 0.1 },
390 compression_ratio: 0.1,
391 error_feedback: true,
392 error_feedback_momentum: 0.9,
393 memory_efficient: true,
394 warmup_steps: 5,
395 })
396 } else {
397 None
398 }
399 }
400
401 pub fn get_stats(&self) -> DeepSpeedStats {
403 DeepSpeedStats {
404 zero_stage: self.config.zero_optimization.stage,
405 initialized: self.initialized,
406 fp16_enabled: self
407 .config
408 .fp16
409 .as_ref()
410 .map(|c| c.enabled)
411 .unwrap_or(false),
412 cpu_offload_enabled: self.config.zero_force_ds_cpu_optimizer.unwrap_or(false),
413 activation_checkpointing_enabled: self
414 .config
415 .activation_checkpointing
416 .as_ref()
417 .map(|c| c.partition_activations.unwrap_or(false))
418 .unwrap_or(false),
419 }
420 }
421}
422
423impl Default for DeepSpeedIntegration {
424 fn default() -> Self {
425 Self::new(DeepSpeedConfig::default())
426 }
427}
428
/// Lightweight snapshot of integration state, produced by
/// `DeepSpeedIntegration::get_stats`.
#[derive(Debug, Clone)]
pub struct DeepSpeedStats {
    /// Configured ZeRO stage.
    pub zero_stage: ZeroStage,
    /// Whether initialization has completed successfully.
    pub initialized: bool,
    /// Whether the fp16 section is present and enabled.
    pub fp16_enabled: bool,
    /// Mirrors `zero_force_ds_cpu_optimizer` (not offload-section presence).
    pub cpu_offload_enabled: bool,
    /// Whether activation checkpointing partitions activations.
    pub activation_checkpointing_enabled: bool,
}
443
444impl Default for ZeroOptimizationConfig {
445 fn default() -> Self {
446 Self {
447 stage: ZeroStage::Stage0,
448 allgather_bucket_size: None,
449 reduce_bucket_size: None,
450 overlap_comm: None,
451 contiguous_gradients: None,
452 sub_group_size: None,
453 reduce_scatter: None,
454 allgather_partitions: None,
455 stage3_max_live_parameters: None,
456 stage3_max_reuse_distance: None,
457 stage3_prefetch_bucket_size: None,
458 stage3_param_persistence_threshold: None,
459 offload_optimizer: None,
460 offload_param: None,
461 }
462 }
463}
464
465pub mod utils {
467 use super::*;
468
469 pub fn create_zero_stage1_config() -> DeepSpeedConfig {
471 DeepSpeedConfig {
472 zero_optimization: ZeroOptimizationConfig {
473 stage: ZeroStage::Stage1,
474 overlap_comm: Some(true),
475 contiguous_gradients: Some(true),
476 reduce_bucket_size: Some(2e8 as u64),
477 ..Default::default()
478 },
479 ..Default::default()
480 }
481 }
482
483 pub fn create_zero_stage2_config() -> DeepSpeedConfig {
485 DeepSpeedConfig {
486 zero_optimization: ZeroOptimizationConfig {
487 stage: ZeroStage::Stage2,
488 overlap_comm: Some(true),
489 contiguous_gradients: Some(true),
490 reduce_bucket_size: Some(2e8 as u64),
491 allgather_bucket_size: Some(2e8 as u64),
492 ..Default::default()
493 },
494 ..Default::default()
495 }
496 }
497
498 pub fn create_zero_stage3_config() -> DeepSpeedConfig {
500 DeepSpeedConfig {
501 zero_optimization: ZeroOptimizationConfig {
502 stage: ZeroStage::Stage3,
503 overlap_comm: Some(true),
504 contiguous_gradients: Some(true),
505 reduce_bucket_size: Some(2e8 as u64),
506 allgather_bucket_size: Some(2e8 as u64),
507 stage3_max_live_parameters: Some(1e9 as u64),
508 stage3_max_reuse_distance: Some(1000),
509 stage3_prefetch_bucket_size: Some(5e8 as u64),
510 stage3_param_persistence_threshold: Some(1e6 as u64),
511 ..Default::default()
512 },
513 ..Default::default()
514 }
515 }
516
517 pub fn create_fp16_config() -> DeepSpeedConfig {
519 DeepSpeedConfig {
520 zero_optimization: ZeroOptimizationConfig {
521 stage: ZeroStage::Stage2,
522 overlap_comm: Some(true),
523 contiguous_gradients: Some(true),
524 reduce_bucket_size: Some(2e8 as u64),
525 allgather_bucket_size: Some(2e8 as u64),
526 ..Default::default()
527 },
528 fp16: Some(FP16Config {
529 enabled: true,
530 loss_scale: None,
531 loss_scale_window: Some(1000),
532 hysteresis: Some(2),
533 min_loss_scale: Some(1.0),
534 initial_scale_power: Some(16),
535 }),
536 ..Default::default()
537 }
538 }
539
540 pub fn create_cpu_offload_config() -> DeepSpeedConfig {
542 DeepSpeedConfig {
543 zero_optimization: ZeroOptimizationConfig {
544 stage: ZeroStage::Stage3,
545 overlap_comm: Some(true),
546 contiguous_gradients: Some(true),
547 reduce_bucket_size: Some(2e8 as u64),
548 allgather_bucket_size: Some(2e8 as u64),
549 stage3_max_live_parameters: Some(1e9 as u64),
550 stage3_max_reuse_distance: Some(1000),
551 stage3_prefetch_bucket_size: Some(5e8 as u64),
552 stage3_param_persistence_threshold: Some(1e6 as u64),
553 offload_optimizer: Some(OffloadOptimizerConfig {
554 device: "cpu".to_string(),
555 nvme_path: None,
556 pin_memory: Some(false),
557 buffer_count: Some(4),
558 fast_init: Some(false),
559 }),
560 offload_param: Some(OffloadParamConfig {
561 device: "cpu".to_string(),
562 nvme_path: None,
563 pin_memory: Some(false),
564 buffer_count: Some(4),
565 max_in_cpu: Some(1e9 as u64),
566 }),
567 ..Default::default()
568 },
569 zero_force_ds_cpu_optimizer: Some(true),
570 ..Default::default()
571 }
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a Stage 2 preset through serde_json and checks that key
    /// fields survive.
    #[test]
    fn test_deepspeed_config_serialization() {
        let original = utils::create_zero_stage2_config();
        let json = serde_json::to_string(&original).unwrap();
        let restored: DeepSpeedConfig = serde_json::from_str(&json).unwrap();

        let (a, b) = (&original.zero_optimization, &restored.zero_optimization);
        assert_eq!(a.stage, b.stage);
        assert_eq!(a.overlap_comm, b.overlap_comm);
    }

    /// `initialize` flips the initialized flag exactly once.
    #[test]
    fn test_deepspeed_integration_initialization() {
        let mut integration = DeepSpeedIntegration::new(utils::create_zero_stage1_config());

        assert!(!integration.is_initialized());
        integration.initialize().unwrap();
        assert!(integration.is_initialized());
    }

    /// ZeRO Stage 3 must translate to full parameter sharding on the FSDP side.
    #[test]
    fn test_deepspeed_to_fsdp_config() {
        let integration = DeepSpeedIntegration::new(utils::create_zero_stage3_config());

        let fsdp = integration.to_fsdp_config().unwrap();
        assert_eq!(
            fsdp.sharding_strategy,
            crate::fsdp::ShardingStrategy::FullShard
        );
    }

    /// Stats reflect the fp16 preset: Stage 2, fp16 on, not yet initialized.
    #[test]
    fn test_deepspeed_stats() {
        let stats = DeepSpeedIntegration::new(utils::create_fp16_config()).get_stats();

        assert_eq!(stats.zero_stage, ZeroStage::Stage2);
        assert!(!stats.initialized);
        assert!(stats.fp16_enabled);
    }

    /// Stage 3 without stage3_max_live_parameters must fail validation.
    #[test]
    fn test_deepspeed_config_validation() {
        let mut config = utils::create_zero_stage3_config();
        config.zero_optimization.stage3_max_live_parameters = None;

        assert!(DeepSpeedIntegration::new(config).initialize().is_err());
    }
}