1#![allow(unused_variables)] #[allow(unused_imports)] use crate::errors::{runtime_error, tensor_op_error, Result};
11use crate::Tensor;
12use std::sync::Arc;
13
/// How a model is distributed across multiple devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelParallelStrategy {
    /// Split the model into sequential stages, one stage per device.
    Pipeline,
    /// Split individual tensors/layers across devices.
    Tensor,
    /// Combine pipeline and tensor parallelism.
    Hybrid,
}
24
/// Configuration for model-parallel execution.
#[derive(Debug, Clone)]
pub struct ModelParallelConfig {
    /// Total number of participating devices (the world size).
    pub num_devices: usize,
    /// Parallelism strategy to apply.
    pub strategy: ModelParallelStrategy,
    /// Physical device ids, indexed by rank.
    pub device_ids: Vec<usize>,
    /// Number of pipeline stages; `None` leaves it unspecified.
    pub pipeline_depth: Option<usize>,
    /// Default dimension along which tensors are split; `None` leaves it unspecified.
    pub tensor_split_dim: Option<usize>,
    /// Whether to trade recomputation for activation memory.
    pub gradient_checkpointing: bool,
    /// Backend used for collective communication.
    pub comm_backend: CommunicationBackend,
}
43
44impl Default for ModelParallelConfig {
45 fn default() -> Self {
46 Self {
47 num_devices: 1,
48 strategy: ModelParallelStrategy::Pipeline,
49 device_ids: vec![0],
50 pipeline_depth: None,
51 tensor_split_dim: None,
52 gradient_checkpointing: false,
53 comm_backend: CommunicationBackend::Nccl,
54 }
55 }
56}
57
/// Backend used for inter-device collective communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CommunicationBackend {
    /// NVIDIA NCCL; requires building with the `nccl` feature.
    Nccl,
    /// MPI-based backend.
    Mpi,
    /// NOTE(review): currently resolved to the no-op mock communicator.
    Gloo,
    /// Caller-provided backend; currently resolved to the no-op mock communicator.
    Custom,
}
70
/// A tensor partitioned across devices: the local shard held here plus the
/// metadata needed to reconstruct the global tensor.
#[derive(Debug, Clone)]
pub struct DistributedTensor {
    /// The portion of the global tensor stored on this device.
    pub local_shard: Tensor,
    /// Shape of the full (unpartitioned) tensor.
    pub global_shape: Vec<usize>,
    /// How the global tensor was split into shards.
    pub partition: TensorPartition,
    /// Device that owns `local_shard`.
    pub device_id: usize,
}
83
/// Describes one shard of a tensor split along a single dimension.
#[derive(Debug, Clone)]
pub struct TensorPartition {
    /// Dimension along which the global tensor is split.
    pub split_dim: usize,
    /// Start index (inclusive) of this shard within `split_dim`.
    pub start_idx: usize,
    /// End index (exclusive) of this shard within `split_dim`.
    pub end_idx: usize,
    /// Total number of shards.
    pub num_partitions: usize,
    /// 0-based index of this shard among `num_partitions`.
    pub partition_rank: usize,
}
98
99impl DistributedTensor {
100 pub fn new(
102 local_shard: Tensor,
103 global_shape: Vec<usize>,
104 partition: TensorPartition,
105 device_id: usize,
106 ) -> Self {
107 Self {
108 local_shard,
109 global_shape,
110 partition,
111 device_id,
112 }
113 }
114
115 pub fn local_shape(&self) -> Vec<usize> {
117 self.local_shard.shape()
118 }
119
120 pub fn requires_communication(&self) -> bool {
122 self.partition.num_partitions > 1
123 }
124}
125
/// Runtime state for model-parallel execution: the configuration, this
/// process's rank, and the communicator used for collectives.
pub struct ModelParallelContext {
    config: ModelParallelConfig,
    /// Rank of this process within the parallel group.
    rank: usize,
    /// Total number of participating devices.
    world_size: usize,
    /// Backend performing collective/point-to-point communication.
    pub(crate) communicator: Arc<dyn Communicator>,
    #[allow(dead_code)]
    device_mesh: DeviceMesh,
}
135
136impl ModelParallelContext {
137 pub fn new(config: ModelParallelConfig) -> Result<Self> {
139 let world_size = config.num_devices;
140 let rank = 0; let communicator = create_communicator(&config.comm_backend)?;
143 let device_mesh = DeviceMesh::new(&config.device_ids, config.strategy)?;
144
145 Ok(Self {
146 config,
147 rank,
148 world_size,
149 communicator,
150 device_mesh,
151 })
152 }
153
154 pub fn rank(&self) -> usize {
156 self.rank
157 }
158
159 pub fn world_size(&self) -> usize {
161 self.world_size
162 }
163
164 pub fn partition_tensor(&self, tensor: &Tensor, split_dim: usize) -> Result<DistributedTensor> {
166 let shape = tensor.shape();
167 if split_dim >= shape.len() {
168 return Err(tensor_op_error(
169 "split_tensor",
170 format!(
171 "Split dimension {} out of bounds for tensor with {} dimensions",
172 split_dim,
173 shape.len()
174 ),
175 ));
176 }
177
178 let dim_size = shape[split_dim];
179 let chunk_size = dim_size.div_ceil(self.world_size);
180 let start_idx = self.rank * chunk_size;
181 let end_idx = ((self.rank + 1) * chunk_size).min(dim_size);
182
183 let local_shard = tensor.slice(split_dim, start_idx, end_idx)?;
185
186 let partition = TensorPartition {
187 split_dim,
188 start_idx,
189 end_idx,
190 num_partitions: self.world_size,
191 partition_rank: self.rank,
192 };
193
194 Ok(DistributedTensor::new(
195 local_shard,
196 shape.to_vec(),
197 partition,
198 self.config.device_ids[self.rank],
199 ))
200 }
201
202 pub fn all_gather(&self, distributed: &DistributedTensor) -> Result<Tensor> {
204 if !distributed.requires_communication() {
205 return Ok(distributed.local_shard.clone());
206 }
207
208 self.communicator
209 .all_gather(&distributed.local_shard, distributed.partition.split_dim)
210 }
211
212 pub fn reduce_scatter(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor> {
214 self.communicator.reduce_scatter(tensor, split_dim)
215 }
216
217 pub fn all_reduce(&self, tensor: &mut Tensor) -> Result<()> {
219 self.communicator.all_reduce(tensor)
220 }
221
222 pub fn broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()> {
224 self.communicator.broadcast(tensor, root)
225 }
226}
227
/// Logical arrangement of devices; maps mesh coordinates to device ids.
#[derive(Debug, Clone)]
pub struct DeviceMesh {
    /// Device ids in row-major mesh order.
    device_ids: Vec<usize>,
    /// Shape/dimensionality of the mesh.
    topology: MeshTopology,
}
236
/// Shape of the device mesh.
#[derive(Debug, Clone)]
enum MeshTopology {
    /// 1-D chain of devices (pipeline and hybrid strategies).
    Linear,
    /// Row-major 2-D grid (tensor parallelism).
    Grid2D { rows: usize, cols: usize },
    /// Row-major 3-D grid; not constructed by any current strategy.
    #[allow(dead_code)]
    Grid3D { x: usize, y: usize, z: usize },
}
247
248impl DeviceMesh {
249 fn new(device_ids: &[usize], strategy: ModelParallelStrategy) -> Result<Self> {
250 let topology = match strategy {
251 ModelParallelStrategy::Pipeline => MeshTopology::Linear,
252 ModelParallelStrategy::Tensor => {
253 let n = device_ids.len();
255 let rows = (n as f64).sqrt().ceil() as usize;
256 let cols = n.div_ceil(rows);
257 MeshTopology::Grid2D { rows, cols }
258 },
259 ModelParallelStrategy::Hybrid => {
260 MeshTopology::Linear
263 },
264 };
265
266 Ok(Self {
267 device_ids: device_ids.to_vec(),
268 topology,
269 })
270 }
271
272 pub fn device_at(&self, coord: &[usize]) -> Option<usize> {
274 match &self.topology {
275 MeshTopology::Linear => {
276 coord.first().and_then(|&idx| self.device_ids.get(idx).copied())
277 },
278 MeshTopology::Grid2D { rows, cols } => {
279 if coord.len() >= 2 {
280 let idx = coord[0] * cols + coord[1];
281 self.device_ids.get(idx).copied()
282 } else {
283 None
284 }
285 },
286 MeshTopology::Grid3D { x, y, z } => {
287 if coord.len() >= 3 {
288 let idx = coord[0] * y * z + coord[1] * z + coord[2];
289 self.device_ids.get(idx).copied()
290 } else {
291 None
292 }
293 },
294 }
295 }
296}
297
/// Collective and point-to-point communication primitives used by
/// `ModelParallelContext`. Implementations must be shareable across threads.
pub trait Communicator: Send + Sync {
    /// Gathers shards from all ranks along `split_dim` into one full tensor.
    fn all_gather(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor>;

    /// Reduces across ranks, then scatters the result along `split_dim`.
    /// The reduction operation is implementation-defined here.
    fn reduce_scatter(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor>;

    /// Reduces `tensor` in place across all ranks.
    fn all_reduce(&self, tensor: &mut Tensor) -> Result<()>;

    /// Sends `tensor` to rank `dest`.
    fn send(&self, tensor: &Tensor, dest: usize) -> Result<()>;

    /// Receives a tensor of the given `shape` from rank `src`.
    fn recv(&self, shape: &[usize], src: usize) -> Result<Tensor>;

    /// Broadcasts `tensor` from rank `root` to all ranks, in place.
    fn broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()>;
}
318
/// Instantiates the communicator implementing the requested backend.
///
/// Gloo and Custom currently resolve to the no-op `MockCommunicator`;
/// NCCL requires the `nccl` compile-time feature and errors out otherwise.
fn create_communicator(backend: &CommunicationBackend) -> Result<Arc<dyn Communicator>> {
    match backend {
        CommunicationBackend::Nccl => {
            #[cfg(feature = "nccl")]
            {
                use super::nccl_communicator::create_nccl_communicator;
                // Rank / world size / local device come from the standard
                // launcher environment variables, defaulting to a
                // single-process layout when unset or unparsable.
                let rank =
                    std::env::var("RANK").unwrap_or_else(|_| "0".to_string()).parse().unwrap_or(0);
                let world_size = std::env::var("WORLD_SIZE")
                    .unwrap_or_else(|_| "1".to_string())
                    .parse()
                    .unwrap_or(1);
                let device_id = std::env::var("LOCAL_RANK")
                    .unwrap_or_else(|_| "0".to_string())
                    .parse()
                    .unwrap_or(0);

                create_nccl_communicator(rank, world_size, device_id)
            }

            #[cfg(not(feature = "nccl"))]
            return Err(runtime_error(
                "NCCL backend requested but not compiled with nccl feature",
            ));
        },
        CommunicationBackend::Mpi => {
            use super::mpi_communicator::MpiCommunicatorImpl;
            Ok(Arc::new(MpiCommunicatorImpl::new()?))
        },
        CommunicationBackend::Gloo => {
            // NOTE(review): no real Gloo implementation yet — mock stands in.
            Ok(Arc::new(MockCommunicator::new()))
        },
        CommunicationBackend::Custom => Ok(Arc::new(MockCommunicator::new())),
    }
}
358
/// Zero-sized stand-in communicator for backends without a real
/// implementation; every operation behaves as a single-participant no-op.
struct MockCommunicator;

impl MockCommunicator {
    /// Constructs the unit-struct communicator.
    fn new() -> MockCommunicator {
        MockCommunicator
    }
}
367
368impl Communicator for MockCommunicator {
369 fn all_gather(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
370 Ok(tensor.clone())
372 }
373
374 fn reduce_scatter(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
375 Ok(tensor.clone())
376 }
377
378 fn all_reduce(&self, _tensor: &mut Tensor) -> Result<()> {
379 Ok(())
380 }
381
382 fn send(&self, _tensor: &Tensor, _dest: usize) -> Result<()> {
383 Ok(())
384 }
385
386 fn recv(&self, shape: &[usize], _src: usize) -> Result<Tensor> {
387 Tensor::zeros(shape)
388 }
389
390 fn broadcast(&self, _tensor: &mut Tensor, _root: usize) -> Result<()> {
391 Ok(())
392 }
393}
394
/// A static execution schedule for pipeline-parallel training.
#[derive(Debug, Clone)]
pub struct PipelineSchedule {
    /// Number of pipeline stages.
    pub num_stages: usize,
    /// Number of microbatches each batch is split into.
    pub num_microbatches: usize,
    /// Scheduling policy used to order forward/backward work.
    pub schedule_type: PipelineScheduleType,
}

/// Supported pipeline scheduling policies.
#[derive(Debug, Clone, Copy)]
pub enum PipelineScheduleType {
    /// All forwards first, then all backwards (simple, high activation memory).
    Sequential,
    /// 1F1B: alternate one forward with one backward in steady state.
    OneForwardOneBackward,
    /// Interleaved 1F1B; currently falls back to plain 1F1B.
    InterleavedOneF1B,
}

impl PipelineSchedule {
    /// Creates a schedule description. No validation is performed.
    pub fn new(
        num_stages: usize,
        num_microbatches: usize,
        schedule_type: PipelineScheduleType,
    ) -> Self {
        Self {
            num_stages,
            num_microbatches,
            schedule_type,
        }
    }

    /// Returns the ordered list of operations stage `stage_id` should execute.
    pub fn get_stage_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
        match self.schedule_type {
            PipelineScheduleType::Sequential => self.sequential_schedule(stage_id),
            PipelineScheduleType::OneForwardOneBackward => self.one_f1b_schedule(stage_id),
            PipelineScheduleType::InterleavedOneF1B => self.interleaved_1f1b_schedule(stage_id),
        }
    }

    /// GPipe-style schedule: every forward microbatch in order, then every
    /// backward in reverse order. Identical for every stage, so the stage id
    /// is unused.
    fn sequential_schedule(&self, _stage_id: usize) -> Vec<PipelineOp> {
        let mut ops = Vec::with_capacity(2 * self.num_microbatches);

        for mb in 0..self.num_microbatches {
            ops.push(PipelineOp::Forward { microbatch_id: mb });
        }

        for mb in (0..self.num_microbatches).rev() {
            ops.push(PipelineOp::Backward { microbatch_id: mb });
        }

        ops
    }

    /// 1F1B schedule: `num_stages - stage_id - 1` warm-up forwards, then
    /// alternating forward/backward in steady state, then a backward drain.
    fn one_f1b_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
        let mut ops = Vec::new();
        // saturating_sub guards stage_id >= num_stages, which previously
        // underflowed (panicking in debug builds); such a stage simply has
        // zero warm-up forwards.
        let num_warmup = self.num_stages.saturating_sub(stage_id + 1);

        for mb in 0..num_warmup.min(self.num_microbatches) {
            ops.push(PipelineOp::Forward { microbatch_id: mb });
        }

        let steady_state_mbs = self.num_microbatches.saturating_sub(num_warmup);
        for i in 0..steady_state_mbs {
            let forward_mb = num_warmup + i;
            let backward_mb = i;

            if forward_mb < self.num_microbatches {
                ops.push(PipelineOp::Forward {
                    microbatch_id: forward_mb,
                });
            }
            ops.push(PipelineOp::Backward {
                microbatch_id: backward_mb,
            });
        }

        // Cooldown: drain the backwards not consumed in steady state.
        for mb in steady_state_mbs..self.num_microbatches {
            ops.push(PipelineOp::Backward { microbatch_id: mb });
        }

        ops
    }

    /// Placeholder: interleaved 1F1B currently delegates to plain 1F1B.
    fn interleaved_1f1b_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
        self.one_f1b_schedule(stage_id)
    }
}

/// One step in a pipeline stage's schedule.
#[derive(Debug, Clone)]
pub enum PipelineOp {
    Forward { microbatch_id: usize },
    Backward { microbatch_id: usize },
    SendActivation { to_stage: usize },
    RecvActivation { from_stage: usize },
    SendGradient { to_stage: usize },
    RecvGradient { from_stage: usize },
}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Partitioning a [128, 512] tensor over 4 ranks: rank 0 owns rows 0..32.
    #[test]
    fn test_tensor_partition() {
        let ctx = ModelParallelContext::new(ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            // Custom resolves to the no-op mock communicator, so the test
            // needs no real communication stack.
            comm_backend: CommunicationBackend::Custom, ..Default::default()
        })
        .expect("operation failed in test");

        let tensor = Tensor::zeros(&[128, 512]).expect("Failed to create zero tensor");
        let distributed = ctx.partition_tensor(&tensor, 0).expect("tensor operation failed");

        assert_eq!(distributed.global_shape, vec![128, 512]);
        assert_eq!(distributed.partition.split_dim, 0);
        assert_eq!(distributed.partition.start_idx, 0);
        assert_eq!(distributed.partition.end_idx, 32);
        assert_eq!(distributed.partition.num_partitions, 4);

        // The local shard keeps the non-split dimension intact.
        let local_shape = distributed.local_shard.shape();
        assert_eq!(local_shape, vec![32, 512]);
    }

    /// Four devices under the Tensor strategy form a row-major 2x2 grid.
    #[test]
    fn test_device_mesh() {
        let mesh = DeviceMesh::new(&[0, 1, 2, 3], ModelParallelStrategy::Tensor)
            .expect("tensor operation failed");

        assert_eq!(mesh.device_at(&[0, 0]), Some(0));
        assert_eq!(mesh.device_at(&[0, 1]), Some(1));
        assert_eq!(mesh.device_at(&[1, 0]), Some(2));
        assert_eq!(mesh.device_at(&[1, 1]), Some(3));
    }

    /// A 4-stage, 8-microbatch 1F1B schedule emits forwards for stage 0.
    #[test]
    fn test_pipeline_schedule() {
        let schedule = PipelineSchedule::new(4, 8, PipelineScheduleType::OneForwardOneBackward);
        let stage0_ops = schedule.get_stage_schedule(0);

        let forward_ops: Vec<_> = stage0_ops
            .iter()
            .filter(|op| matches!(op, PipelineOp::Forward { .. }))
            .collect();
        assert!(!forward_ops.is_empty());
    }
}