trustformers_core/parallel/pipeline_parallel.rs

//! Pipeline parallelism: pipeline stages, microbatch management, and a
//! 1F1B (one-forward-one-backward) executor driven by `PipelineSchedule`.

#![allow(unused_variables)]

use super::model_parallel::{
    ModelParallelContext, PipelineOp, PipelineSchedule, PipelineScheduleType,
};
use crate::errors::{runtime_error, Result};
use crate::Tensor;
use parking_lot::{Mutex, RwLock};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
/// A unit of computation that can be placed inside a pipeline stage.
pub trait PipelineLayer: Send + Sync {
    /// Run the layer on `input` and return its activation.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;
    /// Propagate `grad_output` backwards and return the gradient w.r.t. the input.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor>;
}
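
// Minimal sketch of a `PipelineLayer` implementation, assuming only that
// `Tensor` is `Clone` (as it is used elsewhere in this module).
// `IdentityLayer` is a hypothetical name introduced for illustration, not
// part of the crate's API.
//
//     struct IdentityLayer;
//
//     impl PipelineLayer for IdentityLayer {
//         fn forward(&self, input: &Tensor) -> Result<Tensor> {
//             Ok(input.clone())
//         }
//         fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor> {
//             Ok(grad_output.clone())
//         }
//     }
//
// A stage is then populated via `stage.add_layer(Box::new(IdentityLayer))`.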

/// A contiguous group of layers assigned to one device in the pipeline.
pub struct PipelineStage {
    /// Index of this stage within the pipeline.
    pub stage_id: usize,
    /// Layers executed sequentially by this stage.
    pub layers: Vec<Box<dyn PipelineLayer>>,
    /// Device (rank) that owns this stage.
    pub device_id: usize,
    /// Whether gradients should be computed for this stage.
    pub requires_grad: bool,
}

impl PipelineStage {
    pub fn new(stage_id: usize, device_id: usize) -> Self {
        Self {
            stage_id,
            layers: Vec::new(),
            device_id,
            requires_grad: true,
        }
    }

    pub fn add_layer(&mut self, layer: Box<dyn PipelineLayer>) {
        self.layers.push(layer);
    }

    /// Run all layers of this stage in order.
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    /// Propagate a gradient through this stage's layers in reverse order.
    pub fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor> {
        let mut grad = grad_output.clone();
        for layer in self.layers.iter_mut().rev() {
            grad = layer.backward(&grad)?;
        }
        Ok(grad)
    }
}

/// A model partitioned into pipeline stages across devices.
pub struct PipelineModel {
    /// All stages of the pipeline, indexed by stage ID.
    pub stages: Vec<PipelineStage>,
    /// Model-parallel context describing ranks and communication.
    pub mp_context: Arc<ModelParallelContext>,
    /// Stage owned by the local rank, if any.
    pub local_stage_id: Option<usize>,
}

impl PipelineModel {
    pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
        Self {
            stages: Vec::new(),
            mp_context,
            local_stage_id: None,
        }
    }

    /// Add a stage; if it is placed on the local rank, remember it as the local stage.
    pub fn add_stage(&mut self, stage: PipelineStage) {
        if stage.device_id == self.mp_context.rank() {
            self.local_stage_id = Some(stage.stage_id);
        }
        self.stages.push(stage);
    }

    /// Return the stage assigned to the local rank.
    pub fn local_stage(&self) -> Result<&PipelineStage> {
        let stage_id =
            self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
        self.stages.get(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
    }

    /// Return a mutable reference to the stage assigned to the local rank.
    pub fn local_stage_mut(&mut self) -> Result<&mut PipelineStage> {
        let stage_id =
            self.local_stage_id.ok_or_else(|| runtime_error("No local stage assigned"))?;
        self.stages.get_mut(stage_id).ok_or_else(|| runtime_error("Invalid stage ID"))
    }

    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }
}
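
// Sketch of assembling a `PipelineModel` by hand (the builder at the bottom of
// this file offers the same thing fluently). `config` is assumed to be a valid
// `ModelParallelConfig`, constructed as in the tests below.
//
//     let mp_context = Arc::new(ModelParallelContext::new(config)?);
//     let mut model = PipelineModel::new(mp_context);
//     model.add_stage(PipelineStage::new(0, 0)); // stage 0 on device 0
//     model.add_stage(PipelineStage::new(1, 1)); // stage 1 on device 1
//     let local = model.local_stage()?;          // stage whose device_id == rank()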

/// Per-microbatch state tracked during a pipeline step.
#[derive(Clone)]
pub struct Microbatch {
    /// Index of this microbatch within the step.
    pub id: usize,
    /// Input tensor (consumed by the first stage).
    pub input: Option<Tensor>,
    /// Activation produced by the local stage's forward pass.
    pub output: Option<Tensor>,
    /// Gradient received from the next stage.
    pub grad_output: Option<Tensor>,
    /// Gradient produced by the local stage's backward pass.
    pub grad_input: Option<Tensor>,
    /// Target labels (used on the last stage).
    pub labels: Option<Tensor>,
}

impl Microbatch {
    pub fn new(id: usize) -> Self {
        Self {
            id,
            input: None,
            output: None,
            grad_output: None,
            grad_input: None,
            labels: None,
        }
    }
}

/// Tracks microbatch state and work queues, optionally checkpointing activations
/// (dropping them after forward and recomputing them before backward).
pub struct MicrobatchManager {
    microbatches: Vec<Microbatch>,
    checkpoint_activations: bool,
    forward_queue: VecDeque<usize>,
    backward_queue: VecDeque<usize>,
}

impl MicrobatchManager {
    pub fn new(num_microbatches: usize, checkpoint_activations: bool) -> Self {
        let microbatches = (0..num_microbatches).map(Microbatch::new).collect();

        Self {
            microbatches,
            checkpoint_activations,
            forward_queue: VecDeque::new(),
            backward_queue: VecDeque::new(),
        }
    }

    pub fn get(&self, id: usize) -> Result<&Microbatch> {
        self.microbatches
            .get(id)
            .ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
    }

    pub fn get_mut(&mut self, id: usize) -> Result<&mut Microbatch> {
        self.microbatches
            .get_mut(id)
            .ok_or_else(|| runtime_error(format!("Invalid microbatch ID: {}", id)))
    }

    pub fn enqueue_forward(&mut self, mb_id: usize) {
        self.forward_queue.push_back(mb_id);
    }

    pub fn enqueue_backward(&mut self, mb_id: usize) {
        self.backward_queue.push_back(mb_id);
    }

    pub fn dequeue_forward(&mut self) -> Option<usize> {
        self.forward_queue.pop_front()
    }

    pub fn dequeue_backward(&mut self) -> Option<usize> {
        self.backward_queue.pop_front()
    }

    /// Drop the stored activation if activation checkpointing is enabled.
    pub fn maybe_clear_activation(&mut self, mb_id: usize) -> Result<()> {
        if self.checkpoint_activations {
            let mb = self.get_mut(mb_id)?;
            mb.output = None;
        }
        Ok(())
    }

    /// Recompute the activation from the stored input if it was checkpointed away.
    pub fn maybe_recompute_activation(
        &mut self,
        mb_id: usize,
        stage: &PipelineStage,
    ) -> Result<()> {
        let should_recompute = self.checkpoint_activations;
        let mb = self.get_mut(mb_id)?;
        if should_recompute && mb.output.is_none() {
            if let Some(input) = &mb.input {
                mb.output = Some(stage.forward(input)?);
            }
        }
        Ok(())
    }
}
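
// Sketch of the activation-checkpointing flow that the executor below drives
// (tensor values elided; `mb_id` and `stage` are assumed to be valid):
//
//     let mut manager = MicrobatchManager::new(4, /* checkpoint_activations */ true);
//     manager.enqueue_forward(mb_id);
//     // ... forward pass stores the activation in `output` ...
//     manager.maybe_clear_activation(mb_id)?;             // drop it to save memory
//     // ... later, just before backward ...
//     manager.maybe_recompute_activation(mb_id, &stage)?; // rebuild it from `input`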

/// Drives one pipeline-parallel training step for the local stage, following
/// the schedule produced by `PipelineSchedule`.
pub struct PipelineExecutor {
    model: Arc<RwLock<PipelineModel>>,
    schedule: PipelineSchedule,
    mb_manager: Arc<Mutex<MicrobatchManager>>,
    #[allow(dead_code)]
    send_buffers: HashMap<usize, Tensor>,
    _recv_buffers: HashMap<usize, Tensor>,
}

impl PipelineExecutor {
    pub fn new(
        model: Arc<RwLock<PipelineModel>>,
        num_microbatches: usize,
        checkpoint_activations: bool,
    ) -> Result<Self> {
        let num_stages = {
            let model_read = model.read();
            model_read.num_stages()
        };

        let schedule = PipelineSchedule::new(
            num_stages,
            num_microbatches,
            PipelineScheduleType::OneForwardOneBackward,
        );

        let mb_manager = Arc::new(Mutex::new(MicrobatchManager::new(
            num_microbatches,
            checkpoint_activations,
        )));

        Ok(Self {
            model,
            schedule,
            mb_manager,
            send_buffers: HashMap::new(),
            _recv_buffers: HashMap::new(),
        })
    }

    /// Run one training step: split `inputs`/`labels` into microbatches, then
    /// execute the local stage's 1F1B schedule and return the mean loss.
    pub fn execute_step(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<f32> {
        let num_inputs = inputs.len();

        self.prepare_microbatches(inputs, labels)?;

        let stage_id = {
            let model = self.model.read();
            model.local_stage_id.ok_or_else(|| runtime_error("No local stage"))?
        };

        let ops = self.schedule.get_stage_schedule(stage_id);

        let mut total_loss = 0.0;
        for op in ops {
            match op {
                PipelineOp::Forward { microbatch_id } => {
                    self.execute_forward(microbatch_id)?;
                },
                PipelineOp::Backward { microbatch_id } => {
                    let loss = self.execute_backward(microbatch_id)?;
                    total_loss += loss;
                },
                PipelineOp::SendActivation { to_stage } => {
                    self.send_activation(to_stage)?;
                },
                PipelineOp::RecvActivation { from_stage } => {
                    self.recv_activation(from_stage)?;
                },
                PipelineOp::SendGradient { to_stage } => {
                    self.send_gradient(to_stage)?;
                },
                PipelineOp::RecvGradient { from_stage } => {
                    self.recv_gradient(from_stage)?;
                },
            }
        }

        Ok(total_loss / num_inputs as f32)
    }
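
    // Usage sketch (tensor construction elided, since the `Tensor` constructors
    // are not shown in this file; `inputs` and `labels` are assumed to be
    // `Vec<Tensor>` of equal length, one entry per microbatch):
    //
    //     let mut executor = PipelineExecutor::new(model, 4, true)?;
    //     let mean_loss = executor.execute_step(inputs, labels)?;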

    fn prepare_microbatches(&mut self, inputs: Vec<Tensor>, labels: Vec<Tensor>) -> Result<()> {
        let mut mb_manager = self.mb_manager.lock();

        for (i, (input, label)) in inputs.into_iter().zip(labels).enumerate() {
            let mb = mb_manager.get_mut(i)?;
            mb.input = Some(input);
            mb.labels = Some(label);
            mb_manager.enqueue_forward(i);
        }

        Ok(())
    }

    fn execute_forward(&mut self, mb_id: usize) -> Result<()> {
        let mut model = self.model.write();
        let stage = model.local_stage_mut()?;

        let mut mb_manager = self.mb_manager.lock();
        let mb = mb_manager.get_mut(mb_id)?;

        // Stage 0 consumes the raw microbatch input; later stages consume the
        // activation previously stored for this microbatch.
        let input = if stage.stage_id == 0 {
            mb.input.as_ref().ok_or_else(|| runtime_error("Missing input"))?
        } else {
            mb.output.as_ref().ok_or_else(|| runtime_error("Missing activation"))?
        };

        let output = stage.forward(input)?;
        mb.output = Some(output);

        mb_manager.maybe_clear_activation(mb_id)?;

        Ok(())
    }

    fn execute_backward(&mut self, mb_id: usize) -> Result<f32> {
        let (is_last_stage, stage_id) = {
            let model = self.model.read();
            let stage = model.local_stage()?;
            (stage.stage_id == model.num_stages() - 1, stage.stage_id)
        };

        let mut model = self.model.write();
        let stage = model.local_stage_mut()?;

        let mut mb_manager = self.mb_manager.lock();

        // Restore the activation first if it was checkpointed away after forward.
        mb_manager.maybe_recompute_activation(mb_id, stage)?;

        let mb = mb_manager.get_mut(mb_id)?;

        // Placeholder loss value; a full implementation would compute it from
        // `mb.labels` on the last stage.
        let loss = if is_last_stage { 1.0 } else { 0.0 };

        // The last stage starts backward from its own output; earlier stages
        // start from the gradient received from the next stage.
        let grad_output = if is_last_stage {
            mb.output.as_ref().ok_or_else(|| runtime_error("Missing output"))?.clone()
        } else {
            mb.grad_output
                .as_ref()
                .ok_or_else(|| runtime_error("Missing grad_output"))?
                .clone()
        };

        let grad_input = stage.backward(&grad_output)?;
        mb.grad_input = Some(grad_input);

        Ok(loss)
    }

    fn send_activation(&mut self, to_stage: usize) -> Result<()> {
        // Stub: no communication is performed here.
        Ok(())
    }

    fn recv_activation(&mut self, from_stage: usize) -> Result<()> {
        // Stub: no communication is performed here.
        Ok(())
    }

    fn send_gradient(&mut self, to_stage: usize) -> Result<()> {
        // Stub: no communication is performed here.
        Ok(())
    }

    fn recv_gradient(&mut self, from_stage: usize) -> Result<()> {
        // Stub: no communication is performed here.
        Ok(())
    }
}

/// Gradient-accumulating optimizer for pipeline training: gradients are summed
/// over `accumulation_steps` calls before a parameter update is triggered.
pub struct PipelineOptimizer {
    #[allow(dead_code)]
    lr: f32,
    _weight_decay: f32,
    accumulation_steps: usize,
    current_step: usize,
    accumulated_grads: HashMap<String, Tensor>,
}

impl PipelineOptimizer {
    pub fn new(lr: f32, weight_decay: f32, accumulation_steps: usize) -> Self {
        Self {
            lr,
            _weight_decay: weight_decay,
            accumulation_steps,
            current_step: 0,
            accumulated_grads: HashMap::new(),
        }
    }

    /// Add a set of per-parameter gradients to the running accumulation.
    pub fn accumulate_gradients(&mut self, grads: HashMap<String, Tensor>) -> Result<()> {
        for (name, grad) in grads {
            if let Some(acc_grad) = self.accumulated_grads.get_mut(&name) {
                *acc_grad = acc_grad.add(&grad)?;
            } else {
                self.accumulated_grads.insert(name, grad);
            }
        }

        self.current_step += 1;
        Ok(())
    }

    /// Apply an optimizer step once enough gradients have been accumulated.
    /// Returns `Ok(false)` while more accumulation steps are still needed.
    pub fn step(&mut self, model: &mut PipelineModel) -> Result<bool> {
        if self.current_step < self.accumulation_steps {
            return Ok(false);
        }

        // The accumulated gradients would be scaled by `scale` and applied to
        // the model's parameters here; currently only the accumulation state
        // is reset.
        let scale = 1.0 / self.accumulation_steps as f32;

        self.accumulated_grads.clear();
        self.current_step = 0;

        Ok(true)
    }
}
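
// Accumulation sketch: with `accumulation_steps = 4`, `step()` returns
// `Ok(false)` until gradients have been accumulated four times, then returns
// `Ok(true)` and resets the state. `microbatch_grads` is a hypothetical
// iterator of `HashMap<String, Tensor>` gradient maps.
//
//     let mut opt = PipelineOptimizer::new(1e-3, 0.0, 4);
//     for grads in microbatch_grads {
//         opt.accumulate_gradients(grads)?;
//     }
//     assert!(opt.step(&mut model)?);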

/// Builder for assembling a `PipelineModel` from stages.
pub struct PipelineModelBuilder {
    mp_context: Arc<ModelParallelContext>,
    stages: Vec<PipelineStage>,
    /// Optional hint for how many layers to place per stage (not applied by `build`).
    layers_per_stage: Option<usize>,
}

impl PipelineModelBuilder {
    pub fn new(mp_context: Arc<ModelParallelContext>) -> Self {
        Self {
            mp_context,
            stages: Vec::new(),
            layers_per_stage: None,
        }
    }

    pub fn layers_per_stage(mut self, layers_per_stage: usize) -> Self {
        self.layers_per_stage = Some(layers_per_stage);
        self
    }

    pub fn add_stage(mut self, stage: PipelineStage) -> Self {
        self.stages.push(stage);
        self
    }

    pub fn build(self) -> Result<PipelineModel> {
        let mut model = PipelineModel::new(self.mp_context);

        for stage in self.stages {
            model.add_stage(stage);
        }

        Ok(model)
    }
}

#[cfg(test)]
mod tests {
    use super::super::model_parallel::{
        CommunicationBackend, ModelParallelConfig, ModelParallelStrategy,
    };
    use super::*;

    #[test]
    fn test_pipeline_stage() {
        let stage = PipelineStage::new(0, 0);
        assert_eq!(stage.stage_id, 0);
        assert_eq!(stage.device_id, 0);
        assert!(stage.requires_grad);
    }

    #[test]
    fn test_microbatch_manager() {
        let mut manager = MicrobatchManager::new(4, true);

        manager.enqueue_forward(0);
        manager.enqueue_forward(1);

        assert_eq!(manager.dequeue_forward(), Some(0));
        assert_eq!(manager.dequeue_forward(), Some(1));
        assert_eq!(manager.dequeue_forward(), None);
    }

    #[test]
    fn test_pipeline_model_builder() {
        let config = ModelParallelConfig {
            num_devices: 4,
            device_ids: vec![0, 1, 2, 3],
            strategy: ModelParallelStrategy::Pipeline,
            comm_backend: CommunicationBackend::Custom,
            ..Default::default()
        };

        let mp_context =
            Arc::new(ModelParallelContext::new(config).expect("operation failed in test"));

        let model = PipelineModelBuilder::new(mp_context)
            .add_stage(PipelineStage::new(0, 0))
            .add_stage(PipelineStage::new(1, 1))
            .build()
            .expect("operation failed in test");

        assert_eq!(model.num_stages(), 2);
    }
}