Skip to main content

trustformers_core/parallel/
model_parallel.rs

1//! Model Parallel Support for Large Models
2//!
3//! This module provides infrastructure for distributing model layers and tensors
4//! across multiple devices (GPUs) to enable training and inference of models
5//! that are too large to fit on a single device.
6
7#![allow(unused_variables)] // Model parallelism implementation
8
9#[allow(unused_imports)] // Used conditionally based on feature gates
10use crate::errors::{runtime_error, tensor_op_error, Result};
11use crate::Tensor;
12use std::sync::Arc;
13
/// Strategy for distributing a model across multiple devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelParallelStrategy {
    /// Pipeline parallelism - split the model by layers, with each device
    /// owning a contiguous stage of the network.
    Pipeline,
    /// Tensor parallelism - split individual layers (e.g. weight matrices)
    /// across devices along a chosen dimension.
    Tensor,
    /// Hybrid approach combining pipeline and tensor parallelism.
    Hybrid,
}
24
/// Configuration for model parallel execution.
///
/// Construct via struct-update syntax over [`Default`] to override only the
/// fields you need (see the tests in this module for an example).
#[derive(Debug, Clone)]
pub struct ModelParallelConfig {
    /// Number of devices to use; also used as the world size by
    /// `ModelParallelContext::new`.
    pub num_devices: usize,
    /// Parallelism strategy (determines the device-mesh topology).
    pub strategy: ModelParallelStrategy,
    /// Device IDs to use (e.g., [0, 1, 2, 3] for 4 GPUs). Indexed by rank,
    /// so its length should be at least `num_devices`.
    pub device_ids: Vec<usize>,
    /// Pipeline depth for pipeline parallelism (`None` = unspecified).
    pub pipeline_depth: Option<usize>,
    /// Tensor split dimension for tensor parallelism (`None` = unspecified).
    pub tensor_split_dim: Option<usize>,
    /// Enable gradient checkpointing to trade recomputation for memory.
    pub gradient_checkpointing: bool,
    /// Communication backend used for collectives.
    pub comm_backend: CommunicationBackend,
}
43
44impl Default for ModelParallelConfig {
45    fn default() -> Self {
46        Self {
47            num_devices: 1,
48            strategy: ModelParallelStrategy::Pipeline,
49            device_ids: vec![0],
50            pipeline_depth: None,
51            tensor_split_dim: None,
52            gradient_checkpointing: false,
53            comm_backend: CommunicationBackend::Nccl,
54        }
55    }
56}
57
/// Communication backend used for model-parallel collectives.
///
/// Serializable so it can round-trip through configuration files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CommunicationBackend {
    /// NVIDIA Collective Communication Library (requires the `nccl` feature).
    Nccl,
    /// Message Passing Interface.
    Mpi,
    /// Gloo (CPU communication). Currently falls back to the mock communicator.
    Gloo,
    /// Custom implementation. Currently backed by the mock communicator.
    Custom,
}
70
/// One shard of a logical tensor that has been split across devices.
///
/// Holds the local data plus enough metadata (`global_shape`, `partition`)
/// to reconstruct the full tensor via collective communication.
#[derive(Debug, Clone)]
pub struct DistributedTensor {
    /// Local shard of the tensor (this rank's slice of the global tensor).
    pub local_shard: Tensor,
    /// Global shape of the full, unsharded tensor.
    pub global_shape: Vec<usize>,
    /// How the global tensor was split (dimension, index range, rank).
    pub partition: TensorPartition,
    /// Device ID where this shard resides.
    pub device_id: usize,
}
83
/// Information about how a tensor is partitioned across ranks.
#[derive(Debug, Clone)]
pub struct TensorPartition {
    /// Dimension along which the tensor is split.
    pub split_dim: usize,
    /// Start index (inclusive) of this shard in the global tensor.
    pub start_idx: usize,
    /// End index (exclusive) of this shard in the global tensor.
    pub end_idx: usize,
    /// Total number of partitions (equal to the world size that produced it).
    pub num_partitions: usize,
    /// This partition's rank among the `num_partitions` shards.
    pub partition_rank: usize,
}
98
99impl DistributedTensor {
100    /// Create a new distributed tensor from a local shard
101    pub fn new(
102        local_shard: Tensor,
103        global_shape: Vec<usize>,
104        partition: TensorPartition,
105        device_id: usize,
106    ) -> Self {
107        Self {
108            local_shard,
109            global_shape,
110            partition,
111            device_id,
112        }
113    }
114
115    /// Get the local shape of this shard
116    pub fn local_shape(&self) -> Vec<usize> {
117        self.local_shard.shape()
118    }
119
120    /// Check if this tensor needs communication for operations
121    pub fn requires_communication(&self) -> bool {
122        self.partition.num_partitions > 1
123    }
124}
125
/// Model parallel context managing distributed execution.
///
/// Owns the communicator and device mesh; created once per process via
/// [`ModelParallelContext::new`].
pub struct ModelParallelContext {
    config: ModelParallelConfig,
    // NOTE(review): rank is currently fixed at 0 by `new` — confirm how the
    // real process rank is injected in a multi-process launch.
    rank: usize,
    world_size: usize,
    pub(crate) communicator: Arc<dyn Communicator>,
    #[allow(dead_code)]
    device_mesh: DeviceMesh,
}
135
136impl ModelParallelContext {
137    /// Initialize model parallel context
138    pub fn new(config: ModelParallelConfig) -> Result<Self> {
139        let world_size = config.num_devices;
140        let rank = 0; // Will be set by init process
141
142        let communicator = create_communicator(&config.comm_backend)?;
143        let device_mesh = DeviceMesh::new(&config.device_ids, config.strategy)?;
144
145        Ok(Self {
146            config,
147            rank,
148            world_size,
149            communicator,
150            device_mesh,
151        })
152    }
153
154    /// Get current process rank
155    pub fn rank(&self) -> usize {
156        self.rank
157    }
158
159    /// Get total world size
160    pub fn world_size(&self) -> usize {
161        self.world_size
162    }
163
164    /// Partition a tensor across devices
165    pub fn partition_tensor(&self, tensor: &Tensor, split_dim: usize) -> Result<DistributedTensor> {
166        let shape = tensor.shape();
167        if split_dim >= shape.len() {
168            return Err(tensor_op_error(
169                "split_tensor",
170                format!(
171                    "Split dimension {} out of bounds for tensor with {} dimensions",
172                    split_dim,
173                    shape.len()
174                ),
175            ));
176        }
177
178        let dim_size = shape[split_dim];
179        let chunk_size = dim_size.div_ceil(self.world_size);
180        let start_idx = self.rank * chunk_size;
181        let end_idx = ((self.rank + 1) * chunk_size).min(dim_size);
182
183        // Extract local shard by slicing the tensor along the split dimension
184        let local_shard = tensor.slice(split_dim, start_idx, end_idx)?;
185
186        let partition = TensorPartition {
187            split_dim,
188            start_idx,
189            end_idx,
190            num_partitions: self.world_size,
191            partition_rank: self.rank,
192        };
193
194        Ok(DistributedTensor::new(
195            local_shard,
196            shape.to_vec(),
197            partition,
198            self.config.device_ids[self.rank],
199        ))
200    }
201
202    /// Gather distributed tensor to full tensor
203    pub fn all_gather(&self, distributed: &DistributedTensor) -> Result<Tensor> {
204        if !distributed.requires_communication() {
205            return Ok(distributed.local_shard.clone());
206        }
207
208        self.communicator
209            .all_gather(&distributed.local_shard, distributed.partition.split_dim)
210    }
211
212    /// Reduce scattered tensor across devices
213    pub fn reduce_scatter(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor> {
214        self.communicator.reduce_scatter(tensor, split_dim)
215    }
216
217    /// All-reduce operation for gradient synchronization
218    pub fn all_reduce(&self, tensor: &mut Tensor) -> Result<()> {
219        self.communicator.all_reduce(tensor)
220    }
221
222    /// Broadcast tensor from root rank to all other ranks
223    pub fn broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()> {
224        self.communicator.broadcast(tensor, root)
225    }
226}
227
/// Device mesh organizing devices into a topology for model parallelism.
#[derive(Debug, Clone)]
pub struct DeviceMesh {
    /// Device IDs in the mesh, laid out in row-major order.
    device_ids: Vec<usize>,
    /// Topology of the mesh (chosen from the parallelism strategy).
    topology: MeshTopology,
}
236
#[derive(Debug, Clone)]
enum MeshTopology {
    /// Linear arrangement (for pipeline parallel).
    Linear,
    /// 2D mesh (for tensor parallel); devices stored row-major.
    Grid2D { rows: usize, cols: usize },
    /// 3D mesh (for hybrid parallel); not currently constructed by `new`.
    #[allow(dead_code)]
    Grid3D { x: usize, y: usize, z: usize },
}
247
248impl DeviceMesh {
249    fn new(device_ids: &[usize], strategy: ModelParallelStrategy) -> Result<Self> {
250        let topology = match strategy {
251            ModelParallelStrategy::Pipeline => MeshTopology::Linear,
252            ModelParallelStrategy::Tensor => {
253                // For tensor parallel, try to create a balanced 2D grid
254                let n = device_ids.len();
255                let rows = (n as f64).sqrt().ceil() as usize;
256                let cols = n.div_ceil(rows);
257                MeshTopology::Grid2D { rows, cols }
258            },
259            ModelParallelStrategy::Hybrid => {
260                // For hybrid, create a 3D mesh if possible
261                // This is a simplified version
262                MeshTopology::Linear
263            },
264        };
265
266        Ok(Self {
267            device_ids: device_ids.to_vec(),
268            topology,
269        })
270    }
271
272    /// Get device ID at a given coordinate
273    pub fn device_at(&self, coord: &[usize]) -> Option<usize> {
274        match &self.topology {
275            MeshTopology::Linear => {
276                coord.first().and_then(|&idx| self.device_ids.get(idx).copied())
277            },
278            MeshTopology::Grid2D { rows, cols } => {
279                if coord.len() >= 2 {
280                    let idx = coord[0] * cols + coord[1];
281                    self.device_ids.get(idx).copied()
282                } else {
283                    None
284                }
285            },
286            MeshTopology::Grid3D { x, y, z } => {
287                if coord.len() >= 3 {
288                    let idx = coord[0] * y * z + coord[1] * z + coord[2];
289                    self.device_ids.get(idx).copied()
290                } else {
291                    None
292                }
293            },
294        }
295    }
296}
297
/// Communication interface for model-parallel collective and point-to-point
/// operations.
///
/// Implementations must be `Send + Sync` so a single communicator can be
/// shared across threads behind an `Arc`.
pub trait Communicator: Send + Sync {
    /// All-gather: concatenate every rank's shard along `split_dim` and
    /// return the full tensor.
    fn all_gather(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor>;

    /// Reduce-scatter: element-wise reduce across ranks, then scatter the
    /// result along `split_dim`.
    fn reduce_scatter(&self, tensor: &Tensor, split_dim: usize) -> Result<Tensor>;

    /// All-reduce: element-wise reduce across ranks, result written in place.
    fn all_reduce(&self, tensor: &mut Tensor) -> Result<()>;

    /// Point-to-point send to rank `dest`.
    fn send(&self, tensor: &Tensor, dest: usize) -> Result<()>;

    /// Point-to-point receive of a tensor with `shape` from rank `src`.
    fn recv(&self, shape: &[usize], src: usize) -> Result<Tensor>;

    /// Broadcast `tensor` from rank `root` to all ranks, in place.
    fn broadcast(&self, tensor: &mut Tensor, root: usize) -> Result<()>;
}
318
/// Create the appropriate communicator for `backend`.
///
/// NCCL is only available when compiled with the `nccl` feature; Gloo and
/// Custom currently fall back to the no-op [`MockCommunicator`].
fn create_communicator(backend: &CommunicationBackend) -> Result<Arc<dyn Communicator>> {
    match backend {
        CommunicationBackend::Nccl => {
            // With the `nccl` feature this block is the arm's tail
            // expression; without it, the cfg(not) `return` below is the
            // only code left in the arm. Both paths type-check, but the
            // statement order here is load-bearing — keep it intact.
            #[cfg(feature = "nccl")]
            {
                use super::nccl_communicator::create_nccl_communicator;
                // Default to rank 0, world_size 1, device 0 for single-process case.
                // In a real distributed setup these come from the standard
                // launcher environment variables (RANK / WORLD_SIZE / LOCAL_RANK);
                // unparsable values silently fall back to the single-process defaults.
                let rank =
                    std::env::var("RANK").unwrap_or_else(|_| "0".to_string()).parse().unwrap_or(0);
                let world_size = std::env::var("WORLD_SIZE")
                    .unwrap_or_else(|_| "1".to_string())
                    .parse()
                    .unwrap_or(1);
                let device_id = std::env::var("LOCAL_RANK")
                    .unwrap_or_else(|_| "0".to_string())
                    .parse()
                    .unwrap_or(0);

                create_nccl_communicator(rank, world_size, device_id)
            }

            #[cfg(not(feature = "nccl"))]
            return Err(runtime_error(
                "NCCL backend requested but not compiled with nccl feature",
            ));
        },
        CommunicationBackend::Mpi => {
            use super::mpi_communicator::MpiCommunicatorImpl;
            Ok(Arc::new(MpiCommunicatorImpl::new()?))
        },
        CommunicationBackend::Gloo => {
            // Fallback to mock for now; a real Gloo binding is not wired up yet.
            Ok(Arc::new(MockCommunicator::new()))
        },
        CommunicationBackend::Custom => Ok(Arc::new(MockCommunicator::new())),
    }
}
358
/// Mock communicator for testing.
///
/// A stateless, zero-sized type whose collectives are all no-ops — suitable
/// only for single-process tests where no real communication is required.
struct MockCommunicator;

impl MockCommunicator {
    /// Construct the stateless mock.
    fn new() -> Self {
        Self
    }
}
367
368impl Communicator for MockCommunicator {
369    fn all_gather(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
370        // In mock mode, just return the tensor as-is
371        Ok(tensor.clone())
372    }
373
374    fn reduce_scatter(&self, tensor: &Tensor, _split_dim: usize) -> Result<Tensor> {
375        Ok(tensor.clone())
376    }
377
378    fn all_reduce(&self, _tensor: &mut Tensor) -> Result<()> {
379        Ok(())
380    }
381
382    fn send(&self, _tensor: &Tensor, _dest: usize) -> Result<()> {
383        Ok(())
384    }
385
386    fn recv(&self, shape: &[usize], _src: usize) -> Result<Tensor> {
387        Tensor::zeros(shape)
388    }
389
390    fn broadcast(&self, _tensor: &mut Tensor, _root: usize) -> Result<()> {
391        Ok(())
392    }
393}
394
/// Pipeline parallel schedule for forward/backward passes.
#[derive(Debug, Clone)]
pub struct PipelineSchedule {
    /// Number of pipeline stages (devices in the pipeline).
    pub num_stages: usize,
    /// Number of microbatches each batch is split into.
    pub num_microbatches: usize,
    /// Schedule type controlling how forwards and backwards interleave.
    pub schedule_type: PipelineScheduleType,
}
405
#[derive(Debug, Clone, Copy)]
pub enum PipelineScheduleType {
    /// Forward then backward (simple but memory-hungry: all activations
    /// are held until the backward phase starts).
    Sequential,
    /// 1F1B schedule (one forward, one backward) — bounds peak activation
    /// memory by the number of in-flight microbatches.
    OneForwardOneBackward,
    /// Interleaved 1F1B for better bubble efficiency (currently delegates
    /// to the plain 1F1B implementation).
    InterleavedOneF1B,
}
415
416impl PipelineSchedule {
417    /// Create a new pipeline schedule
418    pub fn new(
419        num_stages: usize,
420        num_microbatches: usize,
421        schedule_type: PipelineScheduleType,
422    ) -> Self {
423        Self {
424            num_stages,
425            num_microbatches,
426            schedule_type,
427        }
428    }
429
430    /// Get the schedule for a specific stage
431    pub fn get_stage_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
432        match self.schedule_type {
433            PipelineScheduleType::Sequential => self.sequential_schedule(stage_id),
434            PipelineScheduleType::OneForwardOneBackward => self.one_f1b_schedule(stage_id),
435            PipelineScheduleType::InterleavedOneF1B => self.interleaved_1f1b_schedule(stage_id),
436        }
437    }
438
439    fn sequential_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
440        let mut ops = Vec::new();
441
442        // All forwards first
443        for mb in 0..self.num_microbatches {
444            ops.push(PipelineOp::Forward { microbatch_id: mb });
445        }
446
447        // Then all backwards
448        for mb in (0..self.num_microbatches).rev() {
449            ops.push(PipelineOp::Backward { microbatch_id: mb });
450        }
451
452        ops
453    }
454
455    fn one_f1b_schedule(&self, stage_id: usize) -> Vec<PipelineOp> {
456        let mut ops = Vec::new();
457        let num_warmup = self.num_stages - stage_id - 1;
458
459        // Warmup phase - only forward
460        for mb in 0..num_warmup.min(self.num_microbatches) {
461            ops.push(PipelineOp::Forward { microbatch_id: mb });
462        }
463
464        // Steady state - 1F1B
465        let steady_state_mbs = self.num_microbatches.saturating_sub(num_warmup);
466        for i in 0..steady_state_mbs {
467            let forward_mb = num_warmup + i;
468            let backward_mb = i;
469
470            if forward_mb < self.num_microbatches {
471                ops.push(PipelineOp::Forward {
472                    microbatch_id: forward_mb,
473                });
474            }
475            ops.push(PipelineOp::Backward {
476                microbatch_id: backward_mb,
477            });
478        }
479
480        // Cooldown phase - only backward
481        for mb in steady_state_mbs..self.num_microbatches {
482            ops.push(PipelineOp::Backward { microbatch_id: mb });
483        }
484
485        ops
486    }
487
488    fn interleaved_1f1b_schedule(&self, _stage_id: usize) -> Vec<PipelineOp> {
489        // Simplified version - can be optimized further
490        self.one_f1b_schedule(_stage_id)
491    }
492}
493
/// A single operation in a pipeline stage's schedule.
#[derive(Debug, Clone)]
pub enum PipelineOp {
    /// Run the forward pass for one microbatch.
    Forward { microbatch_id: usize },
    /// Run the backward pass for one microbatch.
    Backward { microbatch_id: usize },
    /// Send activations to a downstream stage.
    SendActivation { to_stage: usize },
    /// Receive activations from an upstream stage.
    RecvActivation { from_stage: usize },
    /// Send gradients to an upstream stage.
    SendGradient { to_stage: usize },
    /// Receive gradients from a downstream stage.
    RecvGradient { from_stage: usize },
}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Partitioning a [128, 512] tensor over 4 ranks along dim 0 should give
    /// rank 0 (the only rank in a single-process context) rows 0..32.
    #[test]
    fn test_tensor_partition() {
        let ctx = ModelParallelContext::new(ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            comm_backend: CommunicationBackend::Custom, // Use mock backend for tests
            ..Default::default()
        })
        .expect("operation failed in test");

        let tensor = Tensor::zeros(&[128, 512]).expect("Failed to create zero tensor");
        let distributed = ctx.partition_tensor(&tensor, 0).expect("tensor operation failed");

        // Verify partition metadata is correct
        assert_eq!(distributed.global_shape, vec![128, 512]);
        assert_eq!(distributed.partition.split_dim, 0);
        assert_eq!(distributed.partition.start_idx, 0);
        assert_eq!(distributed.partition.end_idx, 32);
        assert_eq!(distributed.partition.num_partitions, 4);

        // Check local tensor shape after slicing
        let local_shape = distributed.local_shard.shape();
        assert_eq!(local_shape, vec![32, 512]); // First dimension should be sliced to 32
    }

    /// Four devices under the Tensor strategy form a 2x2 row-major grid.
    #[test]
    fn test_device_mesh() {
        let mesh = DeviceMesh::new(&[0, 1, 2, 3], ModelParallelStrategy::Tensor)
            .expect("tensor operation failed");

        assert_eq!(mesh.device_at(&[0, 0]), Some(0));
        assert_eq!(mesh.device_at(&[0, 1]), Some(1));
        assert_eq!(mesh.device_at(&[1, 0]), Some(2));
        assert_eq!(mesh.device_at(&[1, 1]), Some(3));
    }

    /// Smoke test: 1F1B stage 0 (of 4 stages, 8 microbatches) schedules at
    /// least one forward op.
    #[test]
    fn test_pipeline_schedule() {
        let schedule = PipelineSchedule::new(4, 8, PipelineScheduleType::OneForwardOneBackward);
        let stage0_ops = schedule.get_stage_schedule(0);

        // Stage 0 should have 3 warmup forwards
        let forward_ops: Vec<_> = stage0_ops
            .iter()
            .filter(|op| matches!(op, PipelineOp::Forward { .. }))
            .collect();
        assert!(!forward_ops.is_empty());
    }
}