// trustformers_core/parallel/pipeline_parallel.rs
//! Pipeline Parallelism for Large Model Training
//!
//! This module implements pipeline parallelism, which splits a model into stages
//! across multiple devices and processes microbatches in a pipelined manner.

#![allow(unused_variables)] // Distributed parallelism implementation with reserved parameters

use super::model_parallel::{
    ModelParallelContext, PipelineOp, PipelineSchedule, PipelineScheduleType,
};
use crate::errors::{runtime_error, Result};
use crate::Tensor;
use parking_lot::{Mutex, RwLock};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Layer wrapper for pipeline stages.
///
/// Implementors must be `Send + Sync` because stages may be driven from
/// multiple threads via the shared executor locks.
pub trait PipelineLayer: Send + Sync {
    /// Runs the layer's forward computation on `input`, returning the activation.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;
    /// Propagates `grad_output` backward through the layer and returns the
    /// gradient w.r.t. the layer's input. Takes `&mut self` so implementations
    /// can store parameter gradients as a side effect.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor>;
}

/// A single stage in the pipeline: an ordered run of the model's layers
/// assigned to one device.
pub struct PipelineStage {
    /// Stage ID (0-indexed position in the pipeline)
    pub stage_id: usize,
    /// Layers in this stage, executed in insertion order
    pub layers: Vec<Box<dyn PipelineLayer>>,
    /// Device ID for this stage
    pub device_id: usize,
    /// Whether this stage requires gradient computation
    pub requires_grad: bool,
}

35impl PipelineStage {
36    pub fn new(stage_id: usize, device_id: usize) -> Self {
37        Self {
38            stage_id,
39            layers: Vec::new(),
40            device_id,
41            requires_grad: true,
42        }
43    }
44
45    pub fn add_layer(&mut self, layer: Box<dyn PipelineLayer>) {
46        self.layers.push(layer);
47    }
48
49    /// Forward pass through all layers in the stage
50    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
51        let mut output = input.clone();
52        for layer in &self.layers {
53            output = layer.forward(&output)?;
54        }
55        Ok(output)
56    }
57
58    /// Backward pass through all layers in the stage
59    pub fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor> {
60        let mut grad = grad_output.clone();
61        // Process layers in reverse order
62        for layer in self.layers.iter_mut().rev() {
63            grad = layer.backward(&grad)?;
64        }
65        Ok(grad)
66    }
67}
68
/// Model split into pipeline stages.
pub struct PipelineModel {
    /// All pipeline stages (every rank holds the full list; only the local
    /// stage is executed on this rank)
    pub stages: Vec<PipelineStage>,
    /// Model parallel context
    pub mp_context: Arc<ModelParallelContext>,
    /// Stage ID assigned to this rank, if any stage's device matches
    /// `mp_context.rank()`
    pub local_stage_id: Option<usize>,
}

79impl PipelineModel {
80    pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
81        Self {
82            stages: Vec::new(),
83            mp_context,
84            local_stage_id: None,
85        }
86    }
87
88    /// Add a stage to the pipeline
89    pub fn add_stage(&mut self, stage: PipelineStage) {
90        if stage.device_id == self.mp_context.rank() {
91            self.local_stage_id = Some(stage.stage_id);
92        }
93        self.stages.push(stage);
94    }
95
96    /// Get the local stage for this rank
97    pub fn local_stage(&self) -> Result<&PipelineStage> {
98        let stage_id =
99            self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
100        self.stages.get(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
101    }
102
103    /// Get mutable local stage
104    pub fn local_stage_mut(&mut self) -> Result<&mut PipelineStage> {
105        let stage_id =
106            self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
107        self.stages.get_mut(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
108    }
109
110    /// Get total number of stages
111    pub fn num_stages(&self) -> usize {
112        self.stages.len()
113    }
114}
115
/// Microbatch data structure.
///
/// All tensor slots start as `None` and are filled in as the pipeline
/// progresses through forward/backward passes.
#[derive(Clone)]
pub struct Microbatch {
    /// Microbatch ID
    pub id: usize,
    /// Input tensor (set when the batch is split into microbatches)
    pub input: Option<Tensor>,
    /// Output tensor (activations produced by the local stage's forward pass)
    pub output: Option<Tensor>,
    /// Gradient w.r.t output
    pub grad_output: Option<Tensor>,
    /// Gradient w.r.t input (produced by the local stage's backward pass)
    pub grad_input: Option<Tensor>,
    /// Labels for loss computation (only for last stage)
    pub labels: Option<Tensor>,
}

133impl Microbatch {
134    pub fn new(id: usize) -> Self {
135        Self {
136            id,
137            input: None,
138            output: None,
139            grad_output: None,
140            grad_input: None,
141            labels: None,
142        }
143    }
144}
145
/// Manages microbatches across pipeline stages.
pub struct MicrobatchManager {
    /// All microbatches, indexed by their ID
    microbatches: Vec<Microbatch>,
    /// Activation checkpointing enabled (trade memory for recompute)
    checkpoint_activations: bool,
    /// Queue of pending forward passes (microbatch IDs, FIFO)
    forward_queue: VecDeque<usize>,
    /// Queue of pending backward passes (microbatch IDs, FIFO)
    backward_queue: VecDeque<usize>,
}

158impl MicrobatchManager {
159    pub fn new(num_microbatches: usize, checkpoint_activations: bool) -> Self {
160        let microbatches = (0..num_microbatches).map(Microbatch::new).collect();
161
162        Self {
163            microbatches,
164            checkpoint_activations,
165            forward_queue: VecDeque::new(),
166            backward_queue: VecDeque::new(),
167        }
168    }
169
170    /// Get microbatch by ID
171    pub fn get(&self, id: usize) -> Result<&Microbatch> {
172        self.microbatches
173            .get(id)
174            .ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
175    }
176
177    /// Get mutable microbatch
178    pub fn get_mut(&mut self, id: usize) -> Result<&mut Microbatch> {
179        self.microbatches
180            .get_mut(id)
181            .ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
182    }
183
184    /// Add microbatch to forward queue
185    pub fn enqueue_forward(&mut self, mb_id: usize) {
186        self.forward_queue.push_back(mb_id);
187    }
188
189    /// Add microbatch to backward queue
190    pub fn enqueue_backward(&mut self, mb_id: usize) {
191        self.backward_queue.push_back(mb_id);
192    }
193
194    /// Get next forward microbatch
195    pub fn dequeue_forward(&mut self) -> Option<usize> {
196        self.forward_queue.pop_front()
197    }
198
199    /// Get next backward microbatch
200    pub fn dequeue_backward(&mut self) -> Option<usize> {
201        self.backward_queue.pop_front()
202    }
203
204    /// Clear activation if checkpointing is enabled
205    pub fn maybe_clear_activation(&mut self, mb_id: usize) -> Result<()> {
206        if self.checkpoint_activations {
207            let mb = self.get_mut(mb_id)?;
208            mb.output = None; // Clear to save memory
209        }
210        Ok(())
211    }
212
213    /// Recompute activation if needed
214    pub fn maybe_recompute_activation(
215        &mut self,
216        mb_id: usize,
217        stage: &PipelineStage,
218    ) -> Result<()> {
219        let should_recompute = self.checkpoint_activations;
220        let mb = self.get_mut(mb_id)?;
221        if should_recompute && mb.output.is_none() {
222            // Recompute forward pass
223            if let Some(input) = &mb.input {
224                mb.output = Some(stage.forward(input)?);
225            }
226        }
227        Ok(())
228    }
229}
230
/// Pipeline executor that manages the execution schedule.
pub struct PipelineExecutor {
    /// Pipeline model, shared behind a lock so forward/backward passes can
    /// borrow the local stage
    model: Arc<RwLock<PipelineModel>>,
    /// Pipeline schedule (the ordering of forward/backward/communication ops)
    schedule: PipelineSchedule,
    /// Microbatch manager guarding per-microbatch tensors and work queues
    mb_manager: Arc<Mutex<MicrobatchManager>>,
    /// Communication buffers — reserved for a future transport layer; the
    /// send/recv methods below are currently stubs
    #[allow(dead_code)]
    send_buffers: HashMap<usize, Tensor>,
    /// Receive-side counterpart of `send_buffers`; currently unused
    _recv_buffers: HashMap<usize, Tensor>,
}

245impl PipelineExecutor {
246    pub fn new(
247        model: Arc<RwLock<PipelineModel>>,
248        num_microbatches: usize,
249        checkpoint_activations: bool,
250    ) -> Result<Self> {
251        let num_stages = {
252            let model_read = model.read();
253            model_read.num_stages()
254        };
255
256        let schedule = PipelineSchedule::new(
257            num_stages,
258            num_microbatches,
259            PipelineScheduleType::OneForwardOneBackward,
260        );
261
262        let mb_manager = Arc::new(Mutex::new(MicrobatchManager::new(
263            num_microbatches,
264            checkpoint_activations,
265        )));
266
267        Ok(Self {
268            model,
269            schedule,
270            mb_manager,
271            send_buffers: HashMap::new(),
272            _recv_buffers: HashMap::new(),
273        })
274    }
275
276    /// Execute one training step
277    pub fn execute_step(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<f32> {
278        let num_inputs = inputs.len();
279
280        // Split inputs into microbatches
281        self.prepare_microbatches(inputs, labels)?;
282
283        // Get schedule for local stage
284        let stage_id = {
285            let model = self.model.read();
286            model.local_stage_id.ok_or_else(|| runtime_error("No local stage"))?
287        };
288
289        let ops = self.schedule.get_stage_schedule(stage_id);
290
291        // Execute operations according to schedule
292        let mut total_loss = 0.0;
293        for op in ops {
294            match op {
295                PipelineOp::Forward { microbatch_id } => {
296                    self.execute_forward(microbatch_id)?;
297                },
298                PipelineOp::Backward { microbatch_id } => {
299                    let loss = self.execute_backward(microbatch_id)?;
300                    total_loss += loss;
301                },
302                PipelineOp::SendActivation { to_stage } => {
303                    self.send_activation(to_stage)?;
304                },
305                PipelineOp::RecvActivation { from_stage } => {
306                    self.recv_activation(from_stage)?;
307                },
308                PipelineOp::SendGradient { to_stage } => {
309                    self.send_gradient(to_stage)?;
310                },
311                PipelineOp::RecvGradient { from_stage } => {
312                    self.recv_gradient(from_stage)?;
313                },
314            }
315        }
316
317        Ok(total_loss / num_inputs as f32)
318    }
319
320    /// Prepare microbatches from full batch
321    fn prepare_microbatches(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<()> {
322        let mut mb_manager = self.mb_manager.lock();
323
324        for (i, (input, label)) in inputs.into_iter().zip(labels).enumerate() {
325            let mb = mb_manager.get_mut(i)?;
326            mb.input = Some(input);
327            mb.labels = Some(label);
328            mb_manager.enqueue_forward(i);
329        }
330
331        Ok(())
332    }
333
334    /// Execute forward pass for a microbatch
335    fn execute_forward(&mut self, mb_id: usize) -> Result<()> {
336        let mut model = self.model.write();
337        let stage = model.local_stage_mut()?;
338
339        let mut mb_manager = self.mb_manager.lock();
340        let mb = mb_manager.get_mut(mb_id)?;
341
342        // Get input (from previous stage or initial input)
343        let input = if stage.stage_id == 0 {
344            mb.input.as_ref().ok_or_else(|| runtime_error("Missing input"))?
345        } else {
346            // Would receive from previous stage
347            mb.output.as_ref().ok_or_else(|| runtime_error("Missing activation"))?
348        };
349
350        // Forward pass
351        let output = stage.forward(input)?;
352        mb.output = Some(output);
353
354        // Maybe clear activation for checkpointing
355        mb_manager.maybe_clear_activation(mb_id)?;
356
357        Ok(())
358    }
359
360    /// Execute backward pass for a microbatch
361    fn execute_backward(&mut self, mb_id: usize) -> Result<f32> {
362        let (is_last_stage, stage_id) = {
363            let model = self.model.read();
364            let stage = model.local_stage()?;
365            (stage.stage_id == model.num_stages() - 1, stage.stage_id)
366        };
367
368        let mut model = self.model.write();
369        let stage = model.local_stage_mut()?;
370
371        let mut mb_manager = self.mb_manager.lock();
372
373        // Recompute activation if needed
374        mb_manager.maybe_recompute_activation(mb_id, stage)?;
375
376        let mb = mb_manager.get_mut(mb_id)?;
377
378        // Compute loss and gradient for last stage
379        let loss = if is_last_stage {
380            // Compute loss (simplified - would use actual loss function)
381            1.0
382        } else {
383            0.0
384        };
385
386        // Get gradient w.r.t output
387        let grad_output = if is_last_stage {
388            // Compute gradient from loss
389            mb.output.as_ref().ok_or_else(|| runtime_error("Missing output"))?.clone()
390        } else {
391            // Would receive from next stage
392            mb.grad_output
393                .as_ref()
394                .ok_or_else(|| runtime_error("Missing grad_output"))?
395                .clone()
396        };
397
398        // Backward pass
399        let grad_input = stage.backward(&grad_output)?;
400        mb.grad_input = Some(grad_input);
401
402        Ok(loss)
403    }
404
405    /// Send activation to next stage
406    fn send_activation(&mut self, to_stage: usize) -> Result<()> {
407        // In practice, would use MPI/NCCL for communication
408        Ok(())
409    }
410
411    /// Receive activation from previous stage
412    fn recv_activation(&mut self, from_stage: usize) -> Result<()> {
413        // In practice, would use MPI/NCCL for communication
414        Ok(())
415    }
416
417    /// Send gradient to previous stage
418    fn send_gradient(&mut self, to_stage: usize) -> Result<()> {
419        // In practice, would use MPI/NCCL for communication
420        Ok(())
421    }
422
423    /// Receive gradient from next stage
424    fn recv_gradient(&mut self, from_stage: usize) -> Result<()> {
425        // In practice, would use MPI/NCCL for communication
426        Ok(())
427    }
428}
429
/// Optimizer for pipeline parallel training.
///
/// Sums gradients across microbatches and applies them once every
/// `accumulation_steps` accumulations.
pub struct PipelineOptimizer {
    /// Learning rate (reserved — the parameter update itself is not yet
    /// implemented in this file)
    #[allow(dead_code)]
    lr: f32,
    /// Weight decay (reserved, currently unread)
    _weight_decay: f32,
    /// Gradient accumulation steps before an update is applied
    accumulation_steps: usize,
    /// Current accumulation step (resets to 0 after each applied update)
    current_step: usize,
    /// Accumulated gradients, keyed by parameter name
    accumulated_grads: HashMap<String, Tensor>,
}

445impl PipelineOptimizer {
446    pub fn new(lr: f32, weight_decay: f32, accumulation_steps: usize) -> Self {
447        Self {
448            lr,
449            _weight_decay: weight_decay,
450            accumulation_steps,
451            current_step: 0,
452            accumulated_grads: HashMap::new(),
453        }
454    }
455
456    /// Accumulate gradients from microbatch
457    pub fn accumulate_gradients(&mut self, grads: HashMap<String, Tensor>) -> Result<()> {
458        for (name, grad) in grads {
459            if let Some(acc_grad) = self.accumulated_grads.get_mut(&name) {
460                *acc_grad = acc_grad.add(&grad)?;
461            } else {
462                self.accumulated_grads.insert(name, grad);
463            }
464        }
465
466        self.current_step += 1;
467        Ok(())
468    }
469
470    /// Apply gradients if accumulation is complete
471    pub fn step(&mut self, model: &mut PipelineModel) -> Result<bool> {
472        if self.current_step < self.accumulation_steps {
473            return Ok(false);
474        }
475
476        // Apply accumulated gradients
477        let scale = 1.0 / self.accumulation_steps as f32;
478
479        // In practice, would update model parameters
480        // For now, just clear accumulated gradients
481        self.accumulated_grads.clear();
482        self.current_step = 0;
483
484        Ok(true)
485    }
486}
487
/// Builder for creating pipeline models.
pub struct PipelineModelBuilder {
    /// Model-parallel context handed to the built model
    mp_context: Arc<ModelParallelContext>,
    /// Stages added so far, in pipeline order
    stages: Vec<PipelineStage>,
    /// Number of layers per stage (for automatic partitioning — recorded but
    /// not yet consumed by `build()` in this file)
    layers_per_stage: Option<usize>,
}

495impl PipelineModelBuilder {
496    pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
497        Self {
498            mp_context,
499            stages: Vec::new(),
500            layers_per_stage: None,
501        }
502    }
503
504    /// Set number of layers per stage (for automatic partitioning)
505    pub fn layers_per_stage(mut self, layers_per_stage: usize) -> Self {
506        self.layers_per_stage = Some(layers_per_stage);
507        self
508    }
509
510    /// Add a pre-configured stage
511    pub fn add_stage(mut self, stage: PipelineStage) -> Self {
512        self.stages.push(stage);
513        self
514    }
515
516    /// Build the pipeline model
517    pub fn build(self) -> Result<PipelineModel> {
518        let mut model = PipelineModel::new(self.mp_context);
519
520        for stage in self.stages {
521            model.add_stage(stage);
522        }
523
524        Ok(model)
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::super::model_parallel::{
        CommunicationBackend, ModelParallelConfig, ModelParallelStrategy,
    };
    use super::*;

    /// A freshly built stage carries its IDs and requires gradients by default.
    #[test]
    fn test_pipeline_stage() {
        let new_stage = PipelineStage::new(0, 0);
        assert_eq!(new_stage.stage_id, 0);
        assert_eq!(new_stage.device_id, 0);
        assert!(new_stage.requires_grad);
    }

    /// The forward queue is FIFO and yields `None` once drained.
    #[test]
    fn test_microbatch_manager() {
        let mut manager = MicrobatchManager::new(4, true);

        for mb_id in [0usize, 1] {
            manager.enqueue_forward(mb_id);
        }

        assert_eq!(manager.dequeue_forward(), Some(0));
        assert_eq!(manager.dequeue_forward(), Some(1));
        assert_eq!(manager.dequeue_forward(), None);
    }

    /// The builder registers every added stage in the built model.
    #[test]
    fn test_pipeline_model_builder() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));

        let built = PipelineModelBuilder::new(mp_context)
            .add_stage(PipelineStage::new(0, 0))
            .add_stage(PipelineStage::new(1, 1))
            .build()
            .expect("operation failed in test");

        assert_eq!(built.num_stages(), 2);
    }
}