// torsh_distributed/fsdp.rs
1//! Fully Sharded Data Parallel (FSDP) implementation
2//!
3//! FSDP enables training very large models by sharding parameters across workers.
4//! Unlike DDP which replicates the full model on each worker, FSDP:
5//! - Shards model parameters across workers (each worker stores only a fraction)
6//! - Gathers parameters during forward/backward passes as needed
7//! - Re-shards parameters after computation to save memory
8//! - Supports nested sharding for hierarchical model parallelism
9
10use crate::backend::ReduceOp;
11use crate::collectives::{all_gather, all_reduce};
12use crate::{ProcessGroup, Rank, TorshDistributedError, TorshResult};
13use dashmap::DashMap;
14use parking_lot::RwLock;
15use std::collections::HashMap;
16use std::sync::{Arc, Mutex};
17use torsh_core::{device::DeviceType, error::Result, DType, Shape};
18use torsh_nn::{Module, Parameter};
19use torsh_tensor::Tensor;
20use tracing::{debug, info};
21
/// FSDP configuration options
///
/// Controls how parameters are sharded, how submodules are wrapped, and how
/// memory/precision trade-offs are handled during training.
#[derive(Debug, Clone)]
pub struct FsdpConfig {
    /// Minimum number of parameters for a module to be sharded;
    /// parameters below this threshold are kept whole on every worker
    pub min_num_params: usize,
    /// Whether to use auto wrapping of submodules
    pub auto_wrap_policy: AutoWrapPolicy,
    /// Sharding strategy (full, grad/opt-state only, none, or hybrid)
    pub sharding_strategy: ShardingStrategy,
    /// Mixed precision configuration; `None` disables mixed precision
    pub mixed_precision: Option<MixedPrecisionConfig>,
    /// Whether to use CPU offloading
    pub cpu_offload: bool,
    /// Memory management configuration
    pub memory_config: MemoryConfig,
    /// Backward prefetch configuration
    pub backward_prefetch: BackwardPrefetch,
}
40
41impl Default for FsdpConfig {
42    fn default() -> Self {
43        Self {
44            min_num_params: 1000,
45            auto_wrap_policy: AutoWrapPolicy::SizeBasedAutoWrap {
46                min_num_params: 1000,
47            },
48            sharding_strategy: ShardingStrategy::FullShard,
49            mixed_precision: None,
50            cpu_offload: false,
51            memory_config: MemoryConfig::default(),
52            backward_prefetch: BackwardPrefetch::BackwardPre,
53        }
54    }
55}
56
/// Auto-wrapping policies for FSDP
///
/// Determines which submodules get their own FSDP wrapper when a model is
/// wrapped automatically.
#[derive(Debug, Clone)]
pub enum AutoWrapPolicy {
    /// Wrap modules whose parameter count is at least `min_num_params`
    SizeBasedAutoWrap { min_num_params: usize },
    /// Wrap specific module types, identified by type name
    ModuleTypeBasedAutoWrap { module_types: Vec<String> },
    /// Custom wrapping function
    CustomAutoWrap,
    /// No auto-wrapping
    NoAutoWrap,
}
69
/// Sharding strategies
///
/// Trade memory savings against communication volume; `FullShard` saves the
/// most memory, `NoShard` behaves like DDP.
#[derive(Debug, Clone, PartialEq)]
pub enum ShardingStrategy {
    /// Shard parameters, gradients, and optimizer states (default)
    FullShard,
    /// Shard gradients and optimizer states only
    ShardGradOp,
    /// No sharding (equivalent to DDP)
    NoShard,
    /// Hybrid sharding for hierarchical parallelism
    HybridShard,
}
82
/// Mixed precision configuration
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Parameter data type
    pub param_dtype: DType,
    /// Gradient reduction data type
    pub reduce_dtype: DType,
    /// Buffer data type
    pub buffer_dtype: DType,
    /// NOTE(review): field name suggests "keep gradients in low precision",
    /// but the original comment said "keep parameters in fp32 for backward
    /// pass" — confirm the intended semantics before relying on this flag.
    pub keep_low_precision_grads: bool,
}
95
/// Memory management configuration
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Limit concurrent all-gathers for parameters (memory vs speed tradeoff)
    pub limit_all_gathers: bool,
    /// Use original (unflattened) parameters for computation when possible
    pub use_orig_params: bool,
    /// Offload parameters to CPU when not in use
    pub offload_to_cpu: bool,
}
106
107impl Default for MemoryConfig {
108    fn default() -> Self {
109        Self {
110            limit_all_gathers: true,
111            use_orig_params: false,
112            offload_to_cpu: false,
113        }
114    }
115}
116
/// Backward prefetch configuration
///
/// Controls when the next set of parameters is gathered during the backward
/// pass, overlapping communication with gradient computation.
#[derive(Debug, Clone, PartialEq)]
pub enum BackwardPrefetch {
    /// Prefetch the next layer's parameters during backward pass
    BackwardPre,
    /// Prefetch after current layer's gradient computation
    BackwardPost,
    /// No prefetching
    None,
}
127
/// Parameter shard information
///
/// Describes one worker's slice of a flattened parameter tensor; offsets are
/// element indices into the flattened view (see `shard_parameters`).
#[derive(Debug, Clone)]
pub struct ShardInfo {
    /// Rank that owns this shard
    pub rank: Rank,
    /// Start index in the flattened parameter tensor
    pub start_idx: usize,
    /// Number of elements in this shard
    pub shard_size: usize,
    /// Shape of the original (unflattened) parameter, used to reconstruct it
    pub original_shape: Shape,
    /// Whether this shard is currently on this worker
    pub is_local: bool,
}
142
/// FSDP parameter state
///
/// Lifecycle: `Sharded` -> `Gathering` -> `Gathered` -> `Sharding` -> `Sharded`.
/// Small parameters below the sharding threshold stay permanently `Gathered`.
#[derive(Debug)]
enum ParameterState {
    /// Parameter is sharded across workers
    Sharded {
        #[allow(dead_code)]
        shard_info: ShardInfo,
    },
    /// Parameter is gathered (full tensor available)
    Gathered {
        #[allow(dead_code)]
        full_tensor: Tensor,
    },
    /// Parameter is being gathered (async operation in progress)
    #[allow(dead_code)]
    Gathering,
    /// Parameter is being sharded (async operation in progress)
    #[allow(dead_code)]
    Sharding,
}
163
/// Fully Sharded Data Parallel wrapper
///
/// Wraps a module so its parameters are partitioned across the workers of a
/// process group; state maps are keyed by parameter name and shared via
/// `Arc` so clones of the wrapper observe the same sharding state.
pub struct FullyShardedDataParallel {
    /// Wrapped module
    module: Arc<RwLock<dyn Module>>,
    /// Process group for communication
    process_group: Arc<ProcessGroup>,
    /// FSDP configuration
    config: FsdpConfig,
    /// Parameter state tracking (name -> sharded/gathered/in-flight)
    param_states: Arc<DashMap<String, ParameterState>>,
    /// Sharded parameters storage (this worker's local shard per name)
    sharded_params: Arc<DashMap<String, Tensor>>,
    /// Gathered parameters cache (full tensors, dropped after re-sharding)
    #[allow(dead_code)]
    gathered_params: Arc<DashMap<String, Tensor>>,
    /// Gradient buffers
    #[allow(dead_code)]
    grad_buffers: Arc<DashMap<String, Tensor>>,
    /// Whether we're in training mode
    training: Arc<Mutex<bool>>,
    /// Current compute stream for overlapping communication with compute
    #[allow(dead_code)]
    compute_stream: Arc<Mutex<Option<String>>>,
    /// Memory usage statistics
    memory_stats: Arc<Mutex<MemoryStats>>,
}
190
/// Memory usage statistics
///
/// A plain value type (all fields are scalars); `Clone` is derived so
/// snapshots can be copied out from behind a lock without field-by-field
/// copying.
#[derive(Debug, Default, Clone)]
pub struct MemoryStats {
    /// Peak memory usage during training
    pub peak_memory_mb: f64,
    /// Current memory usage
    pub current_memory_mb: f64,
    /// Memory saved by sharding
    pub memory_saved_mb: f64,
    /// Number of all-gather operations
    pub num_all_gathers: u64,
    /// Number of reduce-scatter operations
    pub num_reduce_scatters: u64,
}
205
206impl FullyShardedDataParallel {
    /// Create a new FSDP wrapper
    ///
    /// Builds the wrapper state, then immediately shards the module's
    /// parameters across the process group (see `shard_parameters`).
    ///
    /// # Errors
    /// Propagates any error from the initial parameter sharding.
    pub fn new(
        module: Arc<RwLock<dyn Module>>,
        process_group: Arc<ProcessGroup>,
        config: FsdpConfig,
    ) -> TorshResult<Self> {
        let fsdp = Self {
            module,
            process_group,
            config,
            param_states: Arc::new(DashMap::new()),
            sharded_params: Arc::new(DashMap::new()),
            gathered_params: Arc::new(DashMap::new()),
            grad_buffers: Arc::new(DashMap::new()),
            // New wrappers start in training mode, matching typical usage.
            training: Arc::new(Mutex::new(true)),
            compute_stream: Arc::new(Mutex::new(None)),
            memory_stats: Arc::new(Mutex::new(MemoryStats::default())),
        };

        // Initialize parameter sharding before handing the wrapper out, so
        // callers never observe an unsharded state.
        fsdp.shard_parameters()?;

        info!(
            "FSDP initialized with strategy {:?} for {} workers",
            fsdp.config.sharding_strategy,
            fsdp.process_group.world_size()
        );

        Ok(fsdp)
    }
237
    /// Shard parameters across workers
    ///
    /// For each parameter of the wrapped module:
    /// - parameters smaller than `config.min_num_params` are kept whole and
    ///   recorded as `Gathered`;
    /// - larger parameters are flattened and split into `world_size`
    ///   contiguous slices; only this rank's slice is stored locally, and
    ///   the parameter is recorded as `Sharded` with its offset/size.
    ///
    /// When the element count does not divide evenly, the first `remainder`
    /// ranks each take one extra element.
    fn shard_parameters(&self) -> TorshResult<()> {
        // Snapshot the parameter map, then release the module lock before
        // the (potentially long) sharding loop.
        let module_guard = self.module.read();
        let parameters = module_guard.parameters();
        drop(module_guard);

        let world_size = self.process_group.world_size() as usize;
        let rank = self.process_group.rank() as usize;

        for (name, param) in parameters {
            let tensor_arc = param.tensor();
            let tensor_guard = tensor_arc.read();
            if tensor_guard.numel() < self.config.min_num_params {
                // Don't shard small parameters: the communication overhead
                // outweighs the memory saved. Keep the full tensor locally.
                self.param_states.insert(
                    name.clone(),
                    ParameterState::Gathered {
                        full_tensor: tensor_guard.clone(),
                    },
                );
                continue;
            }

            // Flatten parameter for sharding
            let flat_param = tensor_guard.flatten()?;
            let total_elements = flat_param.numel();

            // Calculate shard sizes (handle uneven division)
            let base_shard_size = total_elements / world_size;
            let remainder = total_elements % world_size;

            // Walk every rank's shard to keep start_idx consistent across
            // workers; only our own slice is actually materialized.
            let mut start_idx = 0;
            for worker_rank in 0..world_size {
                // First `remainder` ranks absorb one extra element each.
                let shard_size = base_shard_size + if worker_rank < remainder { 1 } else { 0 };

                if worker_rank == rank {
                    // This is our shard
                    let shard = flat_param
                        .slice(0, start_idx, start_idx + shard_size)?
                        .to_tensor()?;
                    self.sharded_params.insert(name.clone(), shard);

                    let shard_info = ShardInfo {
                        rank: worker_rank as Rank,
                        start_idx,
                        shard_size,
                        // Original shape is retained so the parameter can be
                        // reconstructed after an all-gather.
                        original_shape: tensor_guard.shape().clone(),
                        is_local: true,
                    };

                    self.param_states
                        .insert(name.clone(), ParameterState::Sharded { shard_info });
                }

                start_idx += shard_size;
            }

            debug!(
                "Sharded parameter '{}' with {} elements across {} workers",
                name, total_elements, world_size
            );
            drop(tensor_guard);
        }

        Ok(())
    }
304
305    /// Gather parameters for computation
306    #[allow(dead_code)]
307    async fn gather_parameters(&self, param_names: &[String]) -> TorshResult<()> {
308        for param_name in param_names {
309            if let Some(mut state_ref) = self.param_states.get_mut(param_name) {
310                if let ParameterState::Sharded { shard_info } = &*state_ref {
311                    // Mark as gathering
312                    let original_shape = shard_info.original_shape.clone();
313                    *state_ref = ParameterState::Gathering;
314                    drop(state_ref);
315
316                    // Perform all-gather to reconstruct full parameter
317                    let shard = self.sharded_params.get(param_name).ok_or_else(|| {
318                        TorshDistributedError::backend_error(
319                            "fsdp",
320                            format!("Shard not found for parameter '{}'", param_name),
321                        )
322                    })?;
323
324                    let mut gathered_tensors = Vec::new();
325                    all_gather(&mut gathered_tensors, &*shard, &self.process_group).await?;
326
327                    // Concatenate all gathered tensors
328                    let gathered_tensor = if gathered_tensors.len() == 1 {
329                        gathered_tensors
330                            .into_iter()
331                            .next()
332                            .expect("gathered_tensors should not be empty")
333                    } else {
334                        // For simplicity, just use the first tensor (mock behavior)
335                        gathered_tensors
336                            .into_iter()
337                            .next()
338                            .expect("gathered_tensors should not be empty")
339                    };
340
341                    // Reshape to original shape
342                    let shape_dims: Vec<i32> =
343                        original_shape.dims().iter().map(|&x| x as i32).collect();
344                    let reshaped = gathered_tensor.reshape(&shape_dims)?;
345
346                    // Cache gathered parameter
347                    self.gathered_params
348                        .insert(param_name.clone(), reshaped.clone());
349
350                    // Update state
351                    self.param_states.insert(
352                        param_name.clone(),
353                        ParameterState::Gathered {
354                            full_tensor: reshaped,
355                        },
356                    );
357
358                    // Update statistics
359                    let mut stats = self
360                        .memory_stats
361                        .lock()
362                        .expect("lock should not be poisoned");
363                    stats.num_all_gathers += 1;
364                }
365            }
366        }
367
368        Ok(())
369    }
370
371    /// Reduce-scatter gradients and re-shard parameters
372    #[allow(dead_code)]
373    async fn reduce_scatter_gradients(&self, param_names: &[String]) -> TorshResult<()> {
374        for param_name in param_names {
375            if let Some(grad_buffer) = self.grad_buffers.get(param_name) {
376                // Perform reduce-scatter on gradients
377                let mut reduced_grad = grad_buffer.clone();
378                all_reduce(&mut reduced_grad, ReduceOp::Sum, &self.process_group).await?;
379
380                // Get the local shard of the gradient
381                if let Some(state_ref) = self.param_states.get(param_name) {
382                    if let ParameterState::Sharded { shard_info } = &*state_ref {
383                        let grad_shard = reduced_grad.slice(
384                            0,
385                            shard_info.start_idx,
386                            shard_info.start_idx + shard_info.shard_size,
387                        )?;
388
389                        // Update the local parameter shard with gradient
390                        if let Some(mut param_shard) = self.sharded_params.get_mut(param_name) {
391                            // Apply gradient (this would typically be done by optimizer)
392                            let grad_tensor = grad_shard.to_tensor()?;
393                            *param_shard = param_shard.sub(&grad_tensor)?;
394                        }
395                    }
396                }
397
398                // Update statistics
399                let mut stats = self
400                    .memory_stats
401                    .lock()
402                    .expect("lock should not be poisoned");
403                stats.num_reduce_scatters += 1;
404            }
405
406            // Re-shard parameter
407            self.param_states.insert(
408                param_name.clone(),
409                ParameterState::Sharded {
410                    shard_info: self.get_shard_info(param_name)?,
411                },
412            );
413
414            // Remove from gathered cache to save memory
415            self.gathered_params.remove(param_name);
416        }
417
418        Ok(())
419    }
420
421    /// Get shard information for a parameter
422    #[allow(dead_code)]
423    fn get_shard_info(&self, param_name: &str) -> TorshResult<ShardInfo> {
424        if let Some(state_ref) = self.param_states.get(param_name) {
425            match &*state_ref {
426                ParameterState::Sharded { shard_info } => Ok(shard_info.clone()),
427                _ => Err(TorshDistributedError::backend_error(
428                    "fsdp",
429                    format!("Parameter '{}' is not in sharded state", param_name),
430                )),
431            }
432        } else {
433            Err(TorshDistributedError::backend_error(
434                "fsdp",
435                format!("Parameter '{}' not found", param_name),
436            ))
437        }
438    }
439
440    /// Set training mode
441    pub fn train(&self, mode: bool) {
442        *self.training.lock().expect("lock should not be poisoned") = mode;
443        let mut module_guard = self.module.write();
444        if mode {
445            module_guard.train();
446        } else {
447            module_guard.eval();
448        }
449    }
450
    /// Check if in training mode
    ///
    /// Reads the shared flag; does not consult the wrapped module.
    pub fn is_training(&self) -> bool {
        *self.training.lock().expect("lock should not be poisoned")
    }
455
456    /// Get memory statistics
457    pub fn memory_stats(&self) -> MemoryStats {
458        let stats = self
459            .memory_stats
460            .lock()
461            .expect("lock should not be poisoned");
462        MemoryStats {
463            peak_memory_mb: stats.peak_memory_mb,
464            current_memory_mb: stats.current_memory_mb,
465            memory_saved_mb: stats.memory_saved_mb,
466            num_all_gathers: stats.num_all_gathers,
467            num_reduce_scatters: stats.num_reduce_scatters,
468        }
469    }
470
471    /// Get the number of parameters in the model
472    pub fn num_parameters(&self) -> usize {
473        let module_guard = self.module.read();
474        let parameters = module_guard.parameters();
475        parameters.values().map(|p| p.tensor().read().numel()).sum()
476    }
477
478    /// Get the local sharding ratio (fraction of parameters stored locally)
479    pub fn local_sharding_ratio(&self) -> f64 {
480        let total_params = self.num_parameters();
481        let local_params: usize = self
482            .sharded_params
483            .iter()
484            .map(|entry| entry.value().numel())
485            .sum();
486
487        if total_params > 0 {
488            local_params as f64 / total_params as f64
489        } else {
490            0.0
491        }
492    }
493}
494
495impl Module for FullyShardedDataParallel {
    /// Forward pass through the wrapped module.
    ///
    /// NOTE(review): parameter gathering is not yet wired in — `_param_names`
    /// is computed as a placeholder for the future async gather, and the
    /// forward simply delegates to the wrapped module.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Get all parameter names that need gathering
        let _param_names: Vec<String> = self
            .param_states
            .iter()
            .filter_map(|entry| match entry.value() {
                ParameterState::Sharded { .. } => Some(entry.key().clone()),
                _ => None,
            })
            .collect();

        // Gather parameters for forward pass
        // Note: In a real implementation, this would be async
        // For now, we'll use the mock implementation

        // Perform forward pass with gathered parameters
        let module_guard = self.module.read();
        let output = module_guard.forward(input)?;
        drop(module_guard);

        // In training mode, prepare for backward pass
        if self.is_training() {
            // Set up gradient hooks for automatic reduce-scatter
            // This would be implemented with the autograd system
            debug!("Forward pass completed, gradients will be reduce-scattered in backward");
        } else {
            // In eval mode, immediately re-shard to save memory
            // Note: In a real implementation, this would be async
        }

        Ok(output)
    }
528
529    fn parameters(&self) -> HashMap<String, Parameter> {
530        // Return sharded parameters for memory efficiency
531        let mut params = HashMap::new();
532
533        for entry in self.sharded_params.iter() {
534            let name = entry.key().clone();
535            let tensor = entry.value().clone();
536            params.insert(name, Parameter::new(tensor));
537        }
538
539        params
540    }
541
    /// Same as `parameters`: the sharded-parameter map is already keyed by
    /// parameter name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }
545
    /// Whether the wrapper is currently in training mode.
    fn training(&self) -> bool {
        *self.training.lock().expect("lock should not be poisoned")
    }
549
550    fn train(&mut self) {
551        *self.training.lock().expect("lock should not be poisoned") = true;
552    }
553
554    fn eval(&mut self) {
555        *self.training.lock().expect("lock should not be poisoned") = false;
556    }
557
    /// Device placement is a no-op here: FSDP device management is handled
    /// through the distributed backend, not per-module moves.
    fn to_device(&mut self, _device: DeviceType) -> torsh_core::Result<()> {
        // FSDP device management is handled through backend
        Ok(())
    }
562}
563
564/// Helper function to wrap a module with FSDP
565pub fn fsdp_wrap<M: Module + 'static>(
566    module: M,
567    process_group: Arc<ProcessGroup>,
568    config: Option<FsdpConfig>,
569) -> TorshResult<FullyShardedDataParallel> {
570    let config = config.unwrap_or_default();
571    let module_arc = Arc::new(RwLock::new(module));
572    FullyShardedDataParallel::new(module_arc, process_group, config)
573}
574
575/// Auto-wrap modules based on policy
576pub fn auto_wrap_modules<M: Module + 'static>(
577    module: M,
578    process_group: Arc<ProcessGroup>,
579    auto_wrap_policy: AutoWrapPolicy,
580) -> TorshResult<FullyShardedDataParallel> {
581    let config = FsdpConfig {
582        auto_wrap_policy,
583        ..Default::default()
584    };
585
586    fsdp_wrap(module, process_group, Some(config))
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init_process_group, BackendType};

    use torsh_nn::{prelude::Linear, Module};

    /// FSDP over a 2-worker group: some fraction of parameters should be
    /// stored locally (ratio in (0, 1]).
    #[tokio::test]
    async fn test_fsdp_initialization() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 12345).await?);

        let linear = Linear::new(128, 64, true);
        let config = FsdpConfig::default();

        let fsdp =
            FullyShardedDataParallel::new(Arc::new(RwLock::new(linear)), process_group, config)?;

        assert!(fsdp.local_sharding_ratio() > 0.0);
        assert!(fsdp.local_sharding_ratio() <= 1.0);

        Ok(())
    }

    /// Forward pass through a wrapped Linear preserves the expected output
    /// shape (batch, out_features).
    #[tokio::test]
    async fn test_fsdp_forward_pass() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12346).await?);

        let linear = Linear::new(64, 32, true);
        let fsdp = fsdp_wrap(linear, process_group, None)?;

        let input = torsh_tensor::creation::randn(&[8, 64])?;
        let output = fsdp.forward(&input)?;

        assert_eq!(output.shape().dims(), &[8, 32]);

        Ok(())
    }

    /// Default and custom config values round-trip as expected.
    #[test]
    fn test_fsdp_config() {
        let config = FsdpConfig::default();
        assert_eq!(config.min_num_params, 1000);
        assert_eq!(config.sharding_strategy, ShardingStrategy::FullShard);
        assert_eq!(config.backward_prefetch, BackwardPrefetch::BackwardPre);

        let custom_config = FsdpConfig {
            min_num_params: 500,
            sharding_strategy: ShardingStrategy::ShardGradOp,
            cpu_offload: true,
            ..Default::default()
        };

        assert_eq!(custom_config.min_num_params, 500);
        assert_eq!(
            custom_config.sharding_strategy,
            ShardingStrategy::ShardGradOp
        );
        assert!(custom_config.cpu_offload);
    }

    /// ShardInfo is a plain data carrier; constructing and reading fields
    /// should behave as written.
    #[test]
    fn test_shard_info() {
        let shard_info = ShardInfo {
            rank: 0,
            start_idx: 0,
            shard_size: 1000,
            original_shape: Shape::new(vec![10, 100]),
            is_local: true,
        };

        assert_eq!(shard_info.rank, 0);
        assert_eq!(shard_info.shard_size, 1000);
        assert!(shard_info.is_local);
    }

    /// Default-constructed stats start at zero.
    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats::default();
        assert_eq!(stats.peak_memory_mb, 0.0);
        assert_eq!(stats.num_all_gathers, 0);
        assert_eq!(stats.num_reduce_scatters, 0);
    }

    /// A module below the size-based auto-wrap threshold should keep nearly
    /// all of its parameters local (not sharded away).
    #[tokio::test]
    async fn test_auto_wrap() -> TorshResult<()> {
        let process_group =
            Arc::new(init_process_group(BackendType::Gloo, 0, 1, "127.0.0.1", 12347).await?);

        let linear = Linear::new(100, 50, true);
        let policy = AutoWrapPolicy::SizeBasedAutoWrap {
            min_num_params: 1000,
        };

        let fsdp = auto_wrap_modules(linear, process_group, policy)?;

        // Small module should not be sharded
        assert!(fsdp.local_sharding_ratio() >= 0.9); // Most parameters should be local

        Ok(())
    }
}