//! torsh_distributed/three_d_parallelism/config.rs

use crate::{TorshDistributedError, TorshResult};
use std::collections::HashMap;
/// Configuration for 3D parallelism: data-parallel (DP), tensor-parallel (TP),
/// and pipeline-parallel (PP) dimensions, plus memory/communication tuning knobs.
#[derive(Debug, Clone)]
pub struct ThreeDParallelismConfig {
    // Data-parallel group size.
    pub dp_size: usize,
    // Tensor-parallel group size.
    pub tp_size: usize,
    // Pipeline-parallel group size (number of pipeline stages).
    pub pp_size: usize,
    // Total number of model layers split across pipeline stages.
    pub num_layers: usize,
    // Micro-batch size used by the pipeline schedule.
    pub micro_batch_size: usize,
    // Strategy for trading recomputation against activation memory.
    pub memory_strategy: MemoryOptimizationStrategy,
    // Collective-communication strategy used for synchronization.
    pub comm_strategy: CommunicationStrategy,
    // Enable gradient (activation) checkpointing.
    pub enable_gradient_checkpointing: bool,
    // Enable mixed-precision training.
    pub enable_mixed_precision: bool,
    // Pipeline execution schedule.
    pub pipeline_schedule: PipelineSchedule,
    // Per-device memory budget; unit looks like GB (default 8.0) — TODO confirm.
    pub max_memory_per_device: f32,
    // Timeout for collective communication, in milliseconds.
    pub communication_timeout_ms: u64,
}
37
38impl Default for ThreeDParallelismConfig {
39 fn default() -> Self {
40 Self {
41 dp_size: 1,
42 tp_size: 1,
43 pp_size: 1,
44 num_layers: 24,
45 micro_batch_size: 1,
46 memory_strategy: MemoryOptimizationStrategy::Standard,
47 comm_strategy: CommunicationStrategy::AllReduce,
48 enable_gradient_checkpointing: false,
49 enable_mixed_precision: false,
50 pipeline_schedule: PipelineSchedule::Interleaved,
51 max_memory_per_device: 8.0,
52 communication_timeout_ms: 30000,
53 }
54 }
55}
56
57impl ThreeDParallelismConfig {
58 pub fn validate(&self, world_size: usize) -> TorshResult<()> {
60 let expected_world_size = self.dp_size * self.tp_size * self.pp_size;
61 if expected_world_size != world_size {
62 return Err(TorshDistributedError::InvalidArgument {
63 arg: "world_size".to_string(),
64 expected: format!("{} devices", expected_world_size),
65 reason: format!(
66 "3D parallelism configuration mismatch: expected {} devices ({}*{}*{}), got {}",
67 expected_world_size, self.dp_size, self.tp_size, self.pp_size, world_size
68 ),
69 });
70 }
71
72 if self.num_layers % self.pp_size != 0 {
73 return Err(TorshDistributedError::InvalidArgument {
74 arg: "num_layers".to_string(),
75 expected: format!("divisible by {}", self.pp_size),
76 reason: format!(
77 "Number of layers ({}) must be divisible by pipeline parallel size ({})",
78 self.num_layers, self.pp_size
79 ),
80 });
81 }
82
83 if self.micro_batch_size == 0 {
84 return Err(TorshDistributedError::invalid_argument(
85 "micro_batch_size",
86 "greater than 0",
87 "Micro-batch size must be greater than 0",
88 ));
89 }
90
91 Ok(())
92 }
93
94 pub fn layers_per_stage(&self) -> usize {
96 self.num_layers / self.pp_size
97 }
98
99 pub fn memory_requirements(&self) -> MemoryRequirements {
101 let layers_per_stage = self.layers_per_stage();
102 let model_memory_per_layer = 1024.0; let model_memory = layers_per_stage as f32 * model_memory_per_layer / self.tp_size as f32;
105 let activation_memory = match self.memory_strategy {
106 MemoryOptimizationStrategy::Basic => model_memory * 2.0,
107 MemoryOptimizationStrategy::Standard => model_memory * 1.5,
108 MemoryOptimizationStrategy::Aggressive => model_memory * 1.2,
109 MemoryOptimizationStrategy::Extreme => model_memory * 1.0,
110 };
111
112 let optimizer_memory = model_memory * 2.0; let total_memory = model_memory + activation_memory + optimizer_memory;
114
115 MemoryRequirements {
116 model_memory_mb: model_memory,
117 activation_memory_mb: activation_memory,
118 optimizer_memory_mb: optimizer_memory,
119 total_memory_mb: total_memory,
120 }
121 }
122}
123
/// How aggressively to trade recomputation/communication for lower activation
/// memory; used as a multiplier selector in `memory_requirements`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryOptimizationStrategy {
    // No activation-memory optimization (2.0x model memory estimate).
    Basic,
    // Default balance (1.5x model memory estimate).
    Standard,
    // Stronger savings (1.2x model memory estimate).
    Aggressive,
    // Maximum savings (1.0x model memory estimate).
    Extreme,
}
136
/// Collective-communication algorithm used for synchronization across ranks.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CommunicationStrategy {
    // Flat all-reduce across the whole group.
    AllReduce,
    // All-reduce staged across node/intra-node hierarchy.
    HierarchicalAllReduce,
    // Ring-based all-reduce.
    RingAllReduce,
    // Tree-based all-reduce.
    TreeAllReduce,
    // Strategy chosen at runtime — presumably based on topology/size; confirm.
    Adaptive,
}
151
/// Execution schedule for pipeline-parallel micro-batches.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PipelineSchedule {
    // Stages served in simple rotation.
    RoundRobin,
    // Interleaved (virtual-stage) schedule; the default.
    Interleaved,
    // GPipe-style all-forward-then-all-backward schedule.
    GPipe,
    // 1F1B: alternate one forward and one backward micro-batch in steady state.
    OneForwardOneBackward,
}
164
/// Per-device memory estimate (all values in MB) produced by
/// `ThreeDParallelismConfig::memory_requirements`.
#[derive(Debug, Clone)]
pub struct MemoryRequirements {
    // Parameter memory after tensor-parallel sharding.
    pub model_memory_mb: f32,
    // Activation memory, scaled by the memory-optimization strategy.
    pub activation_memory_mb: f32,
    // Optimizer-state memory (estimated at 2x model memory).
    pub optimizer_memory_mb: f32,
    // Sum of the three components above.
    pub total_memory_mb: f32,
}
173
/// A rank's position in the 3D (DP x TP x PP) parallelism grid.
#[derive(Debug, Clone)]
pub struct RankMapping {
    // Flat rank across the whole world.
    pub global_rank: usize,
    // Coordinate along the data-parallel dimension.
    pub dp_rank: usize,
    // Coordinate along the tensor-parallel dimension.
    pub tp_rank: usize,
    // Coordinate along the pipeline-parallel dimension (stage index).
    pub pp_rank: usize,
    // Device index within the node (global_rank % 8 in `new`).
    pub local_rank: usize,
    // Total number of ranks: dp_size * tp_size * pp_size.
    pub world_size: usize,
    // Owned copy of the configuration used to derive the mapping.
    pub config: ThreeDParallelismConfig,
}
192
193impl RankMapping {
194 pub fn new(config: &ThreeDParallelismConfig, global_rank: usize) -> Self {
196 let world_size = config.dp_size * config.tp_size * config.pp_size;
197
198 let pp_rank = global_rank % config.pp_size;
201 let tp_rank = (global_rank / config.pp_size) % config.tp_size;
202 let dp_rank = global_rank / (config.pp_size * config.tp_size);
203
204 let local_rank = global_rank % 8; Self {
207 global_rank,
208 dp_rank,
209 tp_rank,
210 pp_rank,
211 local_rank,
212 world_size,
213 config: config.clone(),
214 }
215 }
216
217 pub fn from_3d_coords(
219 config: &ThreeDParallelismConfig,
220 dp_rank: usize,
221 tp_rank: usize,
222 pp_rank: usize,
223 ) -> usize {
224 dp_rank * (config.tp_size * config.pp_size) + tp_rank * config.pp_size + pp_rank
225 }
226
227 pub fn is_dp_head(&self) -> bool {
229 self.dp_rank == 0
230 }
231
232 pub fn is_tp_head(&self) -> bool {
234 self.tp_rank == 0
235 }
236
237 pub fn is_pp_head(&self) -> bool {
239 self.pp_rank == 0
240 }
241
242 pub fn is_pp_tail(&self) -> bool {
244 self.pp_rank == self.config.pp_size - 1
245 }
246
247 pub fn dp_group_ranks(&self) -> Vec<usize> {
249 (0..self.config.dp_size)
250 .map(|dp| Self::from_3d_coords(&self.config, dp, self.tp_rank, self.pp_rank))
251 .collect()
252 }
253
254 pub fn tp_group_ranks(&self) -> Vec<usize> {
256 (0..self.config.tp_size)
257 .map(|tp| Self::from_3d_coords(&self.config, self.dp_rank, tp, self.pp_rank))
258 .collect()
259 }
260
261 pub fn pp_group_ranks(&self) -> Vec<usize> {
263 (0..self.config.pp_size)
264 .map(|pp| Self::from_3d_coords(&self.config, self.dp_rank, self.tp_rank, pp))
265 .collect()
266 }
267
268 pub fn next_pp_rank(&self) -> Option<usize> {
270 if self.pp_rank < self.config.pp_size - 1 {
271 Some(Self::from_3d_coords(
272 &self.config,
273 self.dp_rank,
274 self.tp_rank,
275 self.pp_rank + 1,
276 ))
277 } else {
278 None
279 }
280 }
281
282 pub fn prev_pp_rank(&self) -> Option<usize> {
284 if self.pp_rank > 0 {
285 Some(Self::from_3d_coords(
286 &self.config,
287 self.dp_rank,
288 self.tp_rank,
289 self.pp_rank - 1,
290 ))
291 } else {
292 None
293 }
294 }
295}
296
/// String identifiers for every DP/TP/PP process group, keyed by the two
/// grid coordinates that are held fixed within each group.
#[derive(Debug, Clone)]
pub struct ProcessGroupIds {
    // Data-parallel groups, keyed by (tp_rank, pp_rank).
    pub dp_groups: HashMap<(usize, usize), String>,
    // Tensor-parallel groups, keyed by (dp_rank, pp_rank).
    pub tp_groups: HashMap<(usize, usize), String>,
    // Pipeline-parallel groups, keyed by (dp_rank, tp_rank).
    pub pp_groups: HashMap<(usize, usize), String>,
}
308impl ProcessGroupIds {
309 pub fn new(config: &ThreeDParallelismConfig) -> Self {
311 let mut dp_groups = HashMap::new();
312 let mut tp_groups = HashMap::new();
313 let mut pp_groups = HashMap::new();
314
315 for tp_rank in 0..config.tp_size {
317 for pp_rank in 0..config.pp_size {
318 let group_id = format!("dp_group_tp{}_pp{}", tp_rank, pp_rank);
319 dp_groups.insert((tp_rank, pp_rank), group_id);
320 }
321 }
322
323 for dp_rank in 0..config.dp_size {
325 for pp_rank in 0..config.pp_size {
326 let group_id = format!("tp_group_dp{}_pp{}", dp_rank, pp_rank);
327 tp_groups.insert((dp_rank, pp_rank), group_id);
328 }
329 }
330
331 for dp_rank in 0..config.dp_size {
333 for tp_rank in 0..config.tp_size {
334 let group_id = format!("pp_group_dp{}_tp{}", dp_rank, tp_rank);
335 pp_groups.insert((dp_rank, tp_rank), group_id);
336 }
337 }
338
339 Self {
340 dp_groups,
341 tp_groups,
342 pp_groups,
343 }
344 }
345
346 pub fn get_dp_group_id(&self, tp_rank: usize, pp_rank: usize) -> Option<&String> {
348 self.dp_groups.get(&(tp_rank, pp_rank))
349 }
350
351 pub fn get_tp_group_id(&self, dp_rank: usize, pp_rank: usize) -> Option<&String> {
353 self.tp_groups.get(&(dp_rank, pp_rank))
354 }
355
356 pub fn get_pp_group_id(&self, dp_rank: usize, tp_rank: usize) -> Option<&String> {
358 self.pp_groups.get(&(dp_rank, tp_rank))
359 }
360}