1use crate::backend::ReduceOp;
11use crate::collectives::{all_gather, all_reduce};
12use crate::{ProcessGroup, Rank, TorshDistributedError, TorshResult};
13use dashmap::DashMap;
14use parking_lot::RwLock;
15use std::collections::HashMap;
16use std::sync::{Arc, Mutex};
17use torsh_core::{device::DeviceType, error::Result, DType, Shape};
18use torsh_nn::{Module, Parameter};
19use torsh_tensor::Tensor;
20use tracing::{debug, info};
21
/// Configuration for Fully Sharded Data Parallel (FSDP) training.
#[derive(Debug, Clone)]
pub struct FsdpConfig {
    /// Parameters with fewer elements than this are kept whole (not sharded).
    pub min_num_params: usize,
    /// Policy controlling which sub-modules are wrapped automatically.
    pub auto_wrap_policy: AutoWrapPolicy,
    /// How model state is partitioned across ranks.
    pub sharding_strategy: ShardingStrategy,
    /// Optional mixed-precision settings; `None` keeps the original dtypes.
    pub mixed_precision: Option<MixedPrecisionConfig>,
    /// Whether parameters are offloaded to CPU memory.
    /// NOTE(review): not consulted by the visible implementation — confirm.
    pub cpu_offload: bool,
    /// Memory-management tuning knobs.
    pub memory_config: MemoryConfig,
    /// When to prefetch the next parameter shards during backward.
    pub backward_prefetch: BackwardPrefetch,
}
40
41impl Default for FsdpConfig {
42 fn default() -> Self {
43 Self {
44 min_num_params: 1000,
45 auto_wrap_policy: AutoWrapPolicy::SizeBasedAutoWrap {
46 min_num_params: 1000,
47 },
48 sharding_strategy: ShardingStrategy::FullShard,
49 mixed_precision: None,
50 cpu_offload: false,
51 memory_config: MemoryConfig::default(),
52 backward_prefetch: BackwardPrefetch::BackwardPre,
53 }
54 }
55}
56
/// Policy deciding which modules are wrapped in FSDP automatically.
#[derive(Debug, Clone)]
pub enum AutoWrapPolicy {
    /// Wrap modules whose parameter count is at least `min_num_params`.
    SizeBasedAutoWrap { min_num_params: usize },
    /// Wrap modules whose type name matches one of `module_types`.
    ModuleTypeBasedAutoWrap { module_types: Vec<String> },
    /// Wrapping is delegated to user-provided logic.
    CustomAutoWrap,
    /// Disable automatic wrapping entirely.
    NoAutoWrap,
}
69
/// Strategy for partitioning model state across ranks.
/// Variant names mirror PyTorch FSDP's `ShardingStrategy`.
#[derive(Debug, Clone, PartialEq)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer state.
    FullShard,
    /// Shard gradients and optimizer state; keep parameters replicated.
    ShardGradOp,
    /// No sharding (DDP-style replication).
    NoShard,
    /// Shard within a node, replicate across nodes.
    HybridShard,
}
82
/// Mixed-precision settings for FSDP.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Dtype used for parameters during compute.
    pub param_dtype: DType,
    /// Dtype used for gradient reduction.
    pub reduce_dtype: DType,
    /// Dtype used for module buffers.
    pub buffer_dtype: DType,
    /// Keep gradients in low precision after backward instead of upcasting.
    pub keep_low_precision_grads: bool,
}
95
/// Memory-management knobs for FSDP.
/// NOTE(review): none of these flags are consulted by the visible
/// implementation yet — they are configuration placeholders. Confirm.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Limit the number of in-flight all-gathers to bound peak memory.
    pub limit_all_gathers: bool,
    /// Expose original (unflattened) parameters to the optimizer.
    pub use_orig_params: bool,
    /// Offload sharded state to CPU memory.
    pub offload_to_cpu: bool,
}
106
107impl Default for MemoryConfig {
108 fn default() -> Self {
109 Self {
110 limit_all_gathers: true,
111 use_orig_params: false,
112 offload_to_cpu: false,
113 }
114 }
115}
116
/// When to prefetch the next set of parameter shards during backward.
#[derive(Debug, Clone, PartialEq)]
pub enum BackwardPrefetch {
    /// Prefetch before the current gradient computation (more overlap).
    BackwardPre,
    /// Prefetch after the current gradient computation.
    BackwardPost,
    /// No prefetching.
    None,
}
127
/// Describes one rank's contiguous 1-D shard of a flattened parameter.
#[derive(Debug, Clone)]
pub struct ShardInfo {
    /// Rank that owns this shard.
    pub rank: Rank,
    /// Element offset of the shard within the flattened parameter.
    pub start_idx: usize,
    /// Number of elements in the shard.
    pub shard_size: usize,
    /// Shape of the original (unflattened) parameter, used when regathering.
    pub original_shape: Shape,
    /// True when the shard is stored on the local rank.
    pub is_local: bool,
}
142
/// Lifecycle state of a parameter managed by FSDP.
#[derive(Debug)]
enum ParameterState {
    /// Only the local shard is resident; metadata describes the slice.
    Sharded {
        #[allow(dead_code)]
        shard_info: ShardInfo,
    },
    /// The full tensor is materialized (small params, or after an all-gather).
    Gathered {
        #[allow(dead_code)]
        full_tensor: Tensor,
    },
    /// An all-gather for this parameter is in flight.
    #[allow(dead_code)]
    Gathering,
    /// The parameter is being re-sharded.
    #[allow(dead_code)]
    Sharding,
}
163
/// Fully Sharded Data Parallel wrapper: shards a module's parameters across
/// the ranks of a process group to reduce per-rank memory usage.
pub struct FullyShardedDataParallel {
    /// The wrapped module.
    module: Arc<RwLock<dyn Module>>,
    /// Process group used for collective communication.
    process_group: Arc<ProcessGroup>,
    /// FSDP configuration supplied at construction.
    config: FsdpConfig,
    /// Per-parameter lifecycle state, keyed by parameter name.
    param_states: Arc<DashMap<String, ParameterState>>,
    /// This rank's local 1-D shard of each sharded parameter.
    sharded_params: Arc<DashMap<String, Tensor>>,
    /// Full tensors materialized by `gather_parameters`.
    #[allow(dead_code)]
    gathered_params: Arc<DashMap<String, Tensor>>,
    /// Flattened gradient buffers awaiting reduction.
    #[allow(dead_code)]
    grad_buffers: Arc<DashMap<String, Tensor>>,
    /// Training-mode flag, kept in sync with the wrapped module's mode.
    training: Arc<Mutex<bool>>,
    /// Placeholder for a compute-stream handle; unused by visible code.
    #[allow(dead_code)]
    compute_stream: Arc<Mutex<Option<String>>>,
    /// Communication counters and memory bookkeeping.
    memory_stats: Arc<Mutex<MemoryStats>>,
}
190
/// Memory and communication statistics collected by FSDP.
#[derive(Debug, Default)]
pub struct MemoryStats {
    /// Peak resident memory in MB. NOTE(review): never updated by visible code.
    pub peak_memory_mb: f64,
    /// Current resident memory in MB. NOTE(review): never updated by visible code.
    pub current_memory_mb: f64,
    /// Memory saved by sharding, in MB. NOTE(review): never updated by visible code.
    pub memory_saved_mb: f64,
    /// Number of all-gather collectives issued.
    pub num_all_gathers: u64,
    /// Number of gradient-reduction collectives issued.
    pub num_reduce_scatters: u64,
}
205
206impl FullyShardedDataParallel {
207 pub fn new(
209 module: Arc<RwLock<dyn Module>>,
210 process_group: Arc<ProcessGroup>,
211 config: FsdpConfig,
212 ) -> TorshResult<Self> {
213 let fsdp = Self {
214 module,
215 process_group,
216 config,
217 param_states: Arc::new(DashMap::new()),
218 sharded_params: Arc::new(DashMap::new()),
219 gathered_params: Arc::new(DashMap::new()),
220 grad_buffers: Arc::new(DashMap::new()),
221 training: Arc::new(Mutex::new(true)),
222 compute_stream: Arc::new(Mutex::new(None)),
223 memory_stats: Arc::new(Mutex::new(MemoryStats::default())),
224 };
225
226 fsdp.shard_parameters()?;
228
229 info!(
230 "FSDP initialized with strategy {:?} for {} workers",
231 fsdp.config.sharding_strategy,
232 fsdp.process_group.world_size()
233 );
234
235 Ok(fsdp)
236 }
237
238 fn shard_parameters(&self) -> TorshResult<()> {
240 let module_guard = self.module.read();
241 let parameters = module_guard.parameters();
242 drop(module_guard);
243
244 let world_size = self.process_group.world_size() as usize;
245 let rank = self.process_group.rank() as usize;
246
247 for (name, param) in parameters {
248 let tensor_arc = param.tensor();
249 let tensor_guard = tensor_arc.read();
250 if tensor_guard.numel() < self.config.min_num_params {
251 self.param_states.insert(
253 name.clone(),
254 ParameterState::Gathered {
255 full_tensor: tensor_guard.clone(),
256 },
257 );
258 continue;
259 }
260
261 let flat_param = tensor_guard.flatten()?;
263 let total_elements = flat_param.numel();
264
265 let base_shard_size = total_elements / world_size;
267 let remainder = total_elements % world_size;
268
269 let mut start_idx = 0;
270 for worker_rank in 0..world_size {
271 let shard_size = base_shard_size + if worker_rank < remainder { 1 } else { 0 };
272
273 if worker_rank == rank {
274 let shard = flat_param
276 .slice(0, start_idx, start_idx + shard_size)?
277 .to_tensor()?;
278 self.sharded_params.insert(name.clone(), shard);
279
280 let shard_info = ShardInfo {
281 rank: worker_rank as Rank,
282 start_idx,
283 shard_size,
284 original_shape: tensor_guard.shape().clone(),
285 is_local: true,
286 };
287
288 self.param_states
289 .insert(name.clone(), ParameterState::Sharded { shard_info });
290 }
291
292 start_idx += shard_size;
293 }
294
295 debug!(
296 "Sharded parameter '{}' with {} elements across {} workers",
297 name, total_elements, world_size
298 );
299 drop(tensor_guard);
300 }
301
302 Ok(())
303 }
304
305 #[allow(dead_code)]
307 async fn gather_parameters(&self, param_names: &[String]) -> TorshResult<()> {
308 for param_name in param_names {
309 if let Some(mut state_ref) = self.param_states.get_mut(param_name) {
310 if let ParameterState::Sharded { shard_info } = &*state_ref {
311 let original_shape = shard_info.original_shape.clone();
313 *state_ref = ParameterState::Gathering;
314 drop(state_ref);
315
316 let shard = self.sharded_params.get(param_name).ok_or_else(|| {
318 TorshDistributedError::backend_error(
319 "fsdp",
320 format!("Shard not found for parameter '{}'", param_name),
321 )
322 })?;
323
324 let mut gathered_tensors = Vec::new();
325 all_gather(&mut gathered_tensors, &*shard, &self.process_group).await?;
326
327 let gathered_tensor = if gathered_tensors.len() == 1 {
329 gathered_tensors
330 .into_iter()
331 .next()
332 .expect("gathered_tensors should not be empty")
333 } else {
334 gathered_tensors
336 .into_iter()
337 .next()
338 .expect("gathered_tensors should not be empty")
339 };
340
341 let shape_dims: Vec<i32> =
343 original_shape.dims().iter().map(|&x| x as i32).collect();
344 let reshaped = gathered_tensor.reshape(&shape_dims)?;
345
346 self.gathered_params
348 .insert(param_name.clone(), reshaped.clone());
349
350 self.param_states.insert(
352 param_name.clone(),
353 ParameterState::Gathered {
354 full_tensor: reshaped,
355 },
356 );
357
358 let mut stats = self
360 .memory_stats
361 .lock()
362 .expect("lock should not be poisoned");
363 stats.num_all_gathers += 1;
364 }
365 }
366 }
367
368 Ok(())
369 }
370
371 #[allow(dead_code)]
373 async fn reduce_scatter_gradients(&self, param_names: &[String]) -> TorshResult<()> {
374 for param_name in param_names {
375 if let Some(grad_buffer) = self.grad_buffers.get(param_name) {
376 let mut reduced_grad = grad_buffer.clone();
378 all_reduce(&mut reduced_grad, ReduceOp::Sum, &self.process_group).await?;
379
380 if let Some(state_ref) = self.param_states.get(param_name) {
382 if let ParameterState::Sharded { shard_info } = &*state_ref {
383 let grad_shard = reduced_grad.slice(
384 0,
385 shard_info.start_idx,
386 shard_info.start_idx + shard_info.shard_size,
387 )?;
388
389 if let Some(mut param_shard) = self.sharded_params.get_mut(param_name) {
391 let grad_tensor = grad_shard.to_tensor()?;
393 *param_shard = param_shard.sub(&grad_tensor)?;
394 }
395 }
396 }
397
398 let mut stats = self
400 .memory_stats
401 .lock()
402 .expect("lock should not be poisoned");
403 stats.num_reduce_scatters += 1;
404 }
405
406 self.param_states.insert(
408 param_name.clone(),
409 ParameterState::Sharded {
410 shard_info: self.get_shard_info(param_name)?,
411 },
412 );
413
414 self.gathered_params.remove(param_name);
416 }
417
418 Ok(())
419 }
420
421 #[allow(dead_code)]
423 fn get_shard_info(&self, param_name: &str) -> TorshResult<ShardInfo> {
424 if let Some(state_ref) = self.param_states.get(param_name) {
425 match &*state_ref {
426 ParameterState::Sharded { shard_info } => Ok(shard_info.clone()),
427 _ => Err(TorshDistributedError::backend_error(
428 "fsdp",
429 format!("Parameter '{}' is not in sharded state", param_name),
430 )),
431 }
432 } else {
433 Err(TorshDistributedError::backend_error(
434 "fsdp",
435 format!("Parameter '{}' not found", param_name),
436 ))
437 }
438 }
439
440 pub fn train(&self, mode: bool) {
442 *self.training.lock().expect("lock should not be poisoned") = mode;
443 let mut module_guard = self.module.write();
444 if mode {
445 module_guard.train();
446 } else {
447 module_guard.eval();
448 }
449 }
450
451 pub fn is_training(&self) -> bool {
453 *self.training.lock().expect("lock should not be poisoned")
454 }
455
456 pub fn memory_stats(&self) -> MemoryStats {
458 let stats = self
459 .memory_stats
460 .lock()
461 .expect("lock should not be poisoned");
462 MemoryStats {
463 peak_memory_mb: stats.peak_memory_mb,
464 current_memory_mb: stats.current_memory_mb,
465 memory_saved_mb: stats.memory_saved_mb,
466 num_all_gathers: stats.num_all_gathers,
467 num_reduce_scatters: stats.num_reduce_scatters,
468 }
469 }
470
471 pub fn num_parameters(&self) -> usize {
473 let module_guard = self.module.read();
474 let parameters = module_guard.parameters();
475 parameters.values().map(|p| p.tensor().read().numel()).sum()
476 }
477
478 pub fn local_sharding_ratio(&self) -> f64 {
480 let total_params = self.num_parameters();
481 let local_params: usize = self
482 .sharded_params
483 .iter()
484 .map(|entry| entry.value().numel())
485 .sum();
486
487 if total_params > 0 {
488 local_params as f64 / total_params as f64
489 } else {
490 0.0
491 }
492 }
493}
494
495impl Module for FullyShardedDataParallel {
496 fn forward(&self, input: &Tensor) -> Result<Tensor> {
497 let _param_names: Vec<String> = self
499 .param_states
500 .iter()
501 .filter_map(|entry| match entry.value() {
502 ParameterState::Sharded { .. } => Some(entry.key().clone()),
503 _ => None,
504 })
505 .collect();
506
507 let module_guard = self.module.read();
513 let output = module_guard.forward(input)?;
514 drop(module_guard);
515
516 if self.is_training() {
518 debug!("Forward pass completed, gradients will be reduce-scattered in backward");
521 } else {
522 }
525
526 Ok(output)
527 }
528
529 fn parameters(&self) -> HashMap<String, Parameter> {
530 let mut params = HashMap::new();
532
533 for entry in self.sharded_params.iter() {
534 let name = entry.key().clone();
535 let tensor = entry.value().clone();
536 params.insert(name, Parameter::new(tensor));
537 }
538
539 params
540 }
541
542 fn named_parameters(&self) -> HashMap<String, Parameter> {
543 self.parameters()
544 }
545
546 fn training(&self) -> bool {
547 *self.training.lock().expect("lock should not be poisoned")
548 }
549
550 fn train(&mut self) {
551 *self.training.lock().expect("lock should not be poisoned") = true;
552 }
553
554 fn eval(&mut self) {
555 *self.training.lock().expect("lock should not be poisoned") = false;
556 }
557
558 fn to_device(&mut self, _device: DeviceType) -> torsh_core::Result<()> {
559 Ok(())
561 }
562}
563
564pub fn fsdp_wrap<M: Module + 'static>(
566 module: M,
567 process_group: Arc<ProcessGroup>,
568 config: Option<FsdpConfig>,
569) -> TorshResult<FullyShardedDataParallel> {
570 let config = config.unwrap_or_default();
571 let module_arc = Arc::new(RwLock::new(module));
572 FullyShardedDataParallel::new(module_arc, process_group, config)
573}
574
575pub fn auto_wrap_modules<M: Module + 'static>(
577 module: M,
578 process_group: Arc<ProcessGroup>,
579 auto_wrap_policy: AutoWrapPolicy,
580) -> TorshResult<FullyShardedDataParallel> {
581 let config = FsdpConfig {
582 auto_wrap_policy,
583 ..Default::default()
584 };
585
586 fsdp_wrap(module, process_group, Some(config))
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};

    use torsh_nn::{prelude::Linear, Module};

    // NOTE(review): declares world_size = 2 but only this one process joins;
    // assumes init_process_group does not block waiting for the second peer —
    // confirm against the Gloo backend's connection behavior.
    #[tokio::test]
    async fn test_fsdp_initialization() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 12345).await?);

        let linear = Linear::new(128, 64, true);
        let config = FsdpConfig::default();

        let fsdp =
            FullyShardedDataParallel::new(Arc::new(RwLock::new(linear)), process_group, config)?;

        // Some shards must be local, and never more than the full model.
        assert!(fsdp.local_sharding_ratio() > 0.0);
        assert!(fsdp.local_sharding_ratio() <= 1.0);

        Ok(())
    }

    // Single-rank forward pass: output shape must match the Linear layer.
    #[tokio::test]
    async fn test_fsdp_forward_pass() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12346).await?);

        let linear = Linear::new(64, 32, true);
        let fsdp = fsdp_wrap(linear, process_group, None)?;

        let input = torsh_tensor::creation::randn(&[8, 64])?;
        let output = fsdp.forward(&input)?;

        assert_eq!(output.shape().dims(), &[8, 32]);

        Ok(())
    }

    // Verifies the documented defaults and struct-update overrides.
    #[test]
    fn test_fsdp_config() {
        let config = FsdpConfig::default();
        assert_eq!(config.min_num_params, 1000);
        assert_eq!(config.sharding_strategy, ShardingStrategy::FullShard);
        assert_eq!(config.backward_prefetch, BackwardPrefetch::BackwardPre);

        let custom_config = FsdpConfig {
            min_num_params: 500,
            sharding_strategy: ShardingStrategy::ShardGradOp,
            cpu_offload: true,
            ..Default::default()
        };

        assert_eq!(custom_config.min_num_params, 500);
        assert_eq!(
            custom_config.sharding_strategy,
            ShardingStrategy::ShardGradOp
        );
        assert!(custom_config.cpu_offload);
    }

    // ShardInfo is plain data; verify fields round-trip through construction.
    #[test]
    fn test_shard_info() {
        let shard_info = ShardInfo {
            rank: 0,
            start_idx: 0,
            shard_size: 1000,
            original_shape: Shape::new(vec![10, 100]),
            is_local: true,
        };

        assert_eq!(shard_info.rank, 0);
        assert_eq!(shard_info.shard_size, 1000);
        assert!(shard_info.is_local);
    }

    // Default-constructed stats start at zero.
    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats::default();
        assert_eq!(stats.peak_memory_mb, 0.0);
        assert_eq!(stats.num_all_gathers, 0);
        assert_eq!(stats.num_reduce_scatters, 0);
    }

    // With world_size = 1 the only rank should hold (nearly) all elements.
    #[tokio::test]
    async fn test_auto_wrap() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12347).await?);

        let linear = Linear::new(100, 50, true);
        let policy = AutoWrapPolicy::SizeBasedAutoWrap {
            min_num_params: 1000,
        };

        let fsdp = auto_wrap_modules(linear, process_group, policy)?;

        assert!(fsdp.local_sharding_ratio() >= 0.9); Ok(())
    }
}