Skip to main content

torsh_distributed/three_d_parallelism/
config.rs

1//! Configuration types and validation for 3D parallelism
2//!
//! This module contains all configuration types, validation logic,
//! and rank-mapping utilities for coordinating 3D parallelism operations.
5
6use crate::{TorshDistributedError, TorshResult};
7use std::collections::HashMap;
8
9/// Configuration for 3D parallelism (Data, Tensor, Pipeline)
10#[derive(Debug, Clone)]
11pub struct ThreeDParallelismConfig {
12    /// Data parallel dimension size
13    pub dp_size: usize,
14    /// Tensor parallel dimension size
15    pub tp_size: usize,
16    /// Pipeline parallel dimension size
17    pub pp_size: usize,
18    /// Total number of layers in the model
19    pub num_layers: usize,
20    /// Micro-batch size for pipeline parallelism
21    pub micro_batch_size: usize,
22    /// Memory optimization strategy
23    pub memory_strategy: MemoryOptimizationStrategy,
24    /// Communication optimization strategy
25    pub comm_strategy: CommunicationStrategy,
26    /// Whether to enable gradient checkpointing
27    pub enable_gradient_checkpointing: bool,
28    /// Whether to enable mixed precision training
29    pub enable_mixed_precision: bool,
30    /// Pipeline schedule type
31    pub pipeline_schedule: PipelineSchedule,
32    /// Maximum memory usage per device (in GB)
33    pub max_memory_per_device: f32,
34    /// Communication timeout in milliseconds
35    pub communication_timeout_ms: u64,
36}
37
38impl Default for ThreeDParallelismConfig {
39    fn default() -> Self {
40        Self {
41            dp_size: 1,
42            tp_size: 1,
43            pp_size: 1,
44            num_layers: 24,
45            micro_batch_size: 1,
46            memory_strategy: MemoryOptimizationStrategy::Standard,
47            comm_strategy: CommunicationStrategy::AllReduce,
48            enable_gradient_checkpointing: false,
49            enable_mixed_precision: false,
50            pipeline_schedule: PipelineSchedule::Interleaved,
51            max_memory_per_device: 8.0,
52            communication_timeout_ms: 30000,
53        }
54    }
55}
56
57impl ThreeDParallelismConfig {
58    /// Validate configuration against available world size
59    pub fn validate(&self, world_size: usize) -> TorshResult<()> {
60        let expected_world_size = self.dp_size * self.tp_size * self.pp_size;
61        if expected_world_size != world_size {
62            return Err(TorshDistributedError::InvalidArgument {
63                arg: "world_size".to_string(),
64                expected: format!("{} devices", expected_world_size),
65                reason: format!(
66                    "3D parallelism configuration mismatch: expected {} devices ({}*{}*{}), got {}",
67                    expected_world_size, self.dp_size, self.tp_size, self.pp_size, world_size
68                ),
69            });
70        }
71
72        if self.num_layers % self.pp_size != 0 {
73            return Err(TorshDistributedError::InvalidArgument {
74                arg: "num_layers".to_string(),
75                expected: format!("divisible by {}", self.pp_size),
76                reason: format!(
77                    "Number of layers ({}) must be divisible by pipeline parallel size ({})",
78                    self.num_layers, self.pp_size
79                ),
80            });
81        }
82
83        if self.micro_batch_size == 0 {
84            return Err(TorshDistributedError::invalid_argument(
85                "micro_batch_size",
86                "greater than 0",
87                "Micro-batch size must be greater than 0",
88            ));
89        }
90
91        Ok(())
92    }
93
94    /// Get number of layers per pipeline stage
95    pub fn layers_per_stage(&self) -> usize {
96        self.num_layers / self.pp_size
97    }
98
99    /// Calculate memory requirements per device
100    pub fn memory_requirements(&self) -> MemoryRequirements {
101        let layers_per_stage = self.layers_per_stage();
102        let model_memory_per_layer = 1024.0; // MB per layer (rough estimate)
103
104        let model_memory = layers_per_stage as f32 * model_memory_per_layer / self.tp_size as f32;
105        let activation_memory = match self.memory_strategy {
106            MemoryOptimizationStrategy::Basic => model_memory * 2.0,
107            MemoryOptimizationStrategy::Standard => model_memory * 1.5,
108            MemoryOptimizationStrategy::Aggressive => model_memory * 1.2,
109            MemoryOptimizationStrategy::Extreme => model_memory * 1.0,
110        };
111
112        let optimizer_memory = model_memory * 2.0; // Adam optimizer
113        let total_memory = model_memory + activation_memory + optimizer_memory;
114
115        MemoryRequirements {
116            model_memory_mb: model_memory,
117            activation_memory_mb: activation_memory,
118            optimizer_memory_mb: optimizer_memory,
119            total_memory_mb: total_memory,
120        }
121    }
122}
123
/// Memory optimization strategies for 3D parallelism.
///
/// This is a fieldless enum, so it derives `Eq` and `Hash` as well as
/// `PartialEq`. That allows it to be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryOptimizationStrategy {
    /// Basic memory management with minimal optimizations
    Basic,
    /// Standard memory management with gradient checkpointing
    Standard,
    /// Aggressive memory optimization with activation recomputation
    Aggressive,
    /// Extreme memory optimization with disk offloading
    Extreme,
}
136
/// Communication optimization strategies.
///
/// This is a fieldless enum, so it derives `Eq` and `Hash` as well as
/// `PartialEq`. That allows it to be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommunicationStrategy {
    /// Standard all-reduce communication
    AllReduce,
    /// Hierarchical all-reduce with local reduction first
    HierarchicalAllReduce,
    /// Ring-based all-reduce for better bandwidth utilization
    RingAllReduce,
    /// Tree-based all-reduce for latency optimization
    TreeAllReduce,
    /// Adaptive strategy that switches based on message size
    Adaptive,
}
151
/// Pipeline scheduling strategies.
///
/// This is a fieldless enum, so it derives `Eq` and `Hash` as well as
/// `PartialEq`. That allows it to be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineSchedule {
    /// Simple round-robin scheduling
    RoundRobin,
    /// Interleaved scheduling for better pipeline utilization
    Interleaved,
    /// GPipe scheduling with micro-batching
    GPipe,
    /// 1F1B (One Forward One Backward) scheduling
    OneForwardOneBackward,
}
164
/// Per-device memory requirement breakdown.
///
/// All values are in megabytes (MB).
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRequirements {
    /// Memory for the model parameters held by this device.
    pub model_memory_mb: f32,
    /// Memory for activations.
    pub activation_memory_mb: f32,
    /// Memory for optimizer state.
    pub optimizer_memory_mb: f32,
    /// Total estimated memory (normally the sum of the three components above).
    pub total_memory_mb: f32,
}
173
174/// Rank mapping for 3D parallelism coordinates
175#[derive(Debug, Clone)]
176pub struct RankMapping {
177    /// Global rank in the distributed system
178    pub global_rank: usize,
179    /// Data parallel rank (0 to dp_size-1)
180    pub dp_rank: usize,
181    /// Tensor parallel rank (0 to tp_size-1)
182    pub tp_rank: usize,
183    /// Pipeline parallel rank (0 to pp_size-1)
184    pub pp_rank: usize,
185    /// Local rank on the node
186    pub local_rank: usize,
187    /// World size
188    pub world_size: usize,
189    /// 3D parallelism configuration
190    pub config: ThreeDParallelismConfig,
191}
192
193impl RankMapping {
194    /// Create new rank mapping from global rank and configuration
195    pub fn new(config: &ThreeDParallelismConfig, global_rank: usize) -> Self {
196        let world_size = config.dp_size * config.tp_size * config.pp_size;
197
198        // Calculate 3D coordinates from global rank
199        // Layout: [dp][tp][pp] with pp as the fastest changing dimension
200        let pp_rank = global_rank % config.pp_size;
201        let tp_rank = (global_rank / config.pp_size) % config.tp_size;
202        let dp_rank = global_rank / (config.pp_size * config.tp_size);
203
204        let local_rank = global_rank % 8; // Assuming 8 GPUs per node
205
206        Self {
207            global_rank,
208            dp_rank,
209            tp_rank,
210            pp_rank,
211            local_rank,
212            world_size,
213            config: config.clone(),
214        }
215    }
216
217    /// Get global rank from 3D coordinates
218    pub fn from_3d_coords(
219        config: &ThreeDParallelismConfig,
220        dp_rank: usize,
221        tp_rank: usize,
222        pp_rank: usize,
223    ) -> usize {
224        dp_rank * (config.tp_size * config.pp_size) + tp_rank * config.pp_size + pp_rank
225    }
226
227    /// Check if this rank is the first in the data parallel group
228    pub fn is_dp_head(&self) -> bool {
229        self.dp_rank == 0
230    }
231
232    /// Check if this rank is the first in the tensor parallel group
233    pub fn is_tp_head(&self) -> bool {
234        self.tp_rank == 0
235    }
236
237    /// Check if this rank is the first in the pipeline parallel group
238    pub fn is_pp_head(&self) -> bool {
239        self.pp_rank == 0
240    }
241
242    /// Check if this rank is the last in the pipeline parallel group
243    pub fn is_pp_tail(&self) -> bool {
244        self.pp_rank == self.config.pp_size - 1
245    }
246
247    /// Get ranks in the same data parallel group
248    pub fn dp_group_ranks(&self) -> Vec<usize> {
249        (0..self.config.dp_size)
250            .map(|dp| Self::from_3d_coords(&self.config, dp, self.tp_rank, self.pp_rank))
251            .collect()
252    }
253
254    /// Get ranks in the same tensor parallel group
255    pub fn tp_group_ranks(&self) -> Vec<usize> {
256        (0..self.config.tp_size)
257            .map(|tp| Self::from_3d_coords(&self.config, self.dp_rank, tp, self.pp_rank))
258            .collect()
259    }
260
261    /// Get ranks in the same pipeline parallel group
262    pub fn pp_group_ranks(&self) -> Vec<usize> {
263        (0..self.config.pp_size)
264            .map(|pp| Self::from_3d_coords(&self.config, self.dp_rank, self.tp_rank, pp))
265            .collect()
266    }
267
268    /// Get the next rank in the pipeline
269    pub fn next_pp_rank(&self) -> Option<usize> {
270        if self.pp_rank < self.config.pp_size - 1 {
271            Some(Self::from_3d_coords(
272                &self.config,
273                self.dp_rank,
274                self.tp_rank,
275                self.pp_rank + 1,
276            ))
277        } else {
278            None
279        }
280    }
281
282    /// Get the previous rank in the pipeline
283    pub fn prev_pp_rank(&self) -> Option<usize> {
284        if self.pp_rank > 0 {
285            Some(Self::from_3d_coords(
286                &self.config,
287                self.dp_rank,
288                self.tp_rank,
289                self.pp_rank - 1,
290            ))
291        } else {
292            None
293        }
294    }
295}
296
/// String identifiers for every process group along each parallelism axis.
///
/// Each map is keyed by the coordinates along the two *other* axes.
#[derive(Clone, Debug)]
pub struct ProcessGroupIds {
    /// Data parallel groups, keyed by `(tp_rank, pp_rank)`.
    pub dp_groups: HashMap<(usize, usize), String>,
    /// Tensor parallel groups, keyed by `(dp_rank, pp_rank)`.
    pub tp_groups: HashMap<(usize, usize), String>,
    /// Pipeline parallel groups, keyed by `(dp_rank, tp_rank)`.
    pub pp_groups: HashMap<(usize, usize), String>,
}
307
308impl ProcessGroupIds {
309    /// Create process group identifiers for a given configuration
310    pub fn new(config: &ThreeDParallelismConfig) -> Self {
311        let mut dp_groups = HashMap::new();
312        let mut tp_groups = HashMap::new();
313        let mut pp_groups = HashMap::new();
314
315        // Create DP groups: one group for each (tp_rank, pp_rank) combination
316        for tp_rank in 0..config.tp_size {
317            for pp_rank in 0..config.pp_size {
318                let group_id = format!("dp_group_tp{}_pp{}", tp_rank, pp_rank);
319                dp_groups.insert((tp_rank, pp_rank), group_id);
320            }
321        }
322
323        // Create TP groups: one group for each (dp_rank, pp_rank) combination
324        for dp_rank in 0..config.dp_size {
325            for pp_rank in 0..config.pp_size {
326                let group_id = format!("tp_group_dp{}_pp{}", dp_rank, pp_rank);
327                tp_groups.insert((dp_rank, pp_rank), group_id);
328            }
329        }
330
331        // Create PP groups: one group for each (dp_rank, tp_rank) combination
332        for dp_rank in 0..config.dp_size {
333            for tp_rank in 0..config.tp_size {
334                let group_id = format!("pp_group_dp{}_tp{}", dp_rank, tp_rank);
335                pp_groups.insert((dp_rank, tp_rank), group_id);
336            }
337        }
338
339        Self {
340            dp_groups,
341            tp_groups,
342            pp_groups,
343        }
344    }
345
346    /// Get data parallel group ID for given coordinates
347    pub fn get_dp_group_id(&self, tp_rank: usize, pp_rank: usize) -> Option<&String> {
348        self.dp_groups.get(&(tp_rank, pp_rank))
349    }
350
351    /// Get tensor parallel group ID for given coordinates
352    pub fn get_tp_group_id(&self, dp_rank: usize, pp_rank: usize) -> Option<&String> {
353        self.tp_groups.get(&(dp_rank, pp_rank))
354    }
355
356    /// Get pipeline parallel group ID for given coordinates
357    pub fn get_pp_group_id(&self, dp_rank: usize, tp_rank: usize) -> Option<&String> {
358        self.pp_groups.get(&(dp_rank, tp_rank))
359    }
360}