1use std::sync::mpsc;
48use std::sync::{Arc, Mutex};
49
50use serde::{Deserialize, Serialize};
51
52use yscv_tensor::Tensor;
53
54use crate::ModelError;
55
/// Static description of one process's place in a distributed training group.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Total number of participating ranks.
    pub world_size: usize,
    /// This process's rank, expected in `0..world_size`.
    pub rank: usize,
    /// Coordinator endpoint; not read by `InProcessTransport` in this module.
    pub coordinator_addr: String,
}
70
/// Point-to-point byte transport plus a group-wide barrier between ranks.
pub trait Transport: Send {
    /// Delivers `data` to rank `dest_rank`.
    fn send(&self, dest_rank: usize, data: &[u8]) -> Result<(), ModelError>;
    /// Blocks until a message from rank `src_rank` is available and returns its payload.
    fn recv(&self, src_rank: usize) -> Result<Vec<u8>, ModelError>;
    /// Blocks until every rank in the group has reached the barrier.
    fn barrier(&self) -> Result<(), ModelError>;
}
84
/// `Transport` backed by `std::sync::mpsc` channels, for running all ranks
/// as threads inside a single process (e.g. tests or local simulation).
pub struct InProcessTransport {
    // This transport's own rank.
    rank: usize,
    // Group size; stored for completeness but not read by any method here.
    _world_size: usize,
    // senders[d] carries messages from this rank to rank d.
    senders: Vec<mpsc::Sender<Vec<u8>>>,
    // receivers[s] yields messages sent by rank s to this rank.
    receivers: Vec<mpsc::Receiver<Vec<u8>>>,
    // Barrier bookkeeping shared by every transport in the group.
    barrier_state: Arc<Mutex<BarrierState>>,
}
95
/// Shared bookkeeping behind `InProcessTransport::barrier`.
struct BarrierState {
    // Number of ranks that have arrived at the current barrier round.
    count: usize,
    // Completed barrier rounds; incremented on release but never read back.
    generation: u64,
    // Arrivals required to release the barrier (== world size).
    target: usize,
    // Per-rank wake-up channels: the last arriver signals these senders...
    condvar_senders: Vec<mpsc::Sender<()>>,
    // ...while each waiting rank blocks on its own receiver, which it takes
    // out of the Vec (leaving None) so it can wait without holding the lock.
    condvar_receivers: Vec<Option<mpsc::Receiver<()>>>,
}
103
104impl InProcessTransport {
105 #[allow(clippy::type_complexity)]
107 pub fn create_group(world_size: usize) -> Vec<Self> {
108 let mut sender_grid: Vec<Vec<Option<mpsc::Sender<Vec<u8>>>>> = Vec::new();
110 let mut receiver_grid: Vec<Vec<Option<mpsc::Receiver<Vec<u8>>>>> = Vec::new();
111 for _ in 0..world_size {
112 let mut s_row = Vec::new();
113 let mut r_row = Vec::new();
114 for _ in 0..world_size {
115 let (tx, rx) = mpsc::channel();
116 s_row.push(Some(tx));
117 r_row.push(Some(rx));
118 }
119 sender_grid.push(s_row);
120 receiver_grid.push(r_row);
121 }
122
123 let mut barrier_senders = Vec::new();
125 let mut barrier_receivers = Vec::new();
126 for _ in 0..world_size {
127 let (tx, rx) = mpsc::channel();
128 barrier_senders.push(tx);
129 barrier_receivers.push(Some(rx));
130 }
131
132 let barrier_state = Arc::new(Mutex::new(BarrierState {
133 count: 0,
134 generation: 0,
135 target: world_size,
136 condvar_senders: barrier_senders,
137 condvar_receivers: barrier_receivers,
138 }));
139
140 let mut transports = Vec::new();
141
142 for r in 0..world_size {
143 let mut my_senders = Vec::new();
144 let mut my_receivers = Vec::new();
145
146 for item in sender_grid[r].iter_mut() {
148 my_senders.push(item.take().expect("sender not yet taken"));
149 }
150 for row in receiver_grid.iter_mut() {
152 my_receivers.push(row[r].take().expect("receiver not yet taken"));
153 }
154
155 transports.push(InProcessTransport {
156 rank: r,
157 _world_size: world_size,
158 senders: my_senders,
159 receivers: my_receivers,
160 barrier_state: barrier_state.clone(),
161 });
162 }
163
164 transports
165 }
166}
167
168impl Transport for InProcessTransport {
169 fn send(&self, dest_rank: usize, data: &[u8]) -> Result<(), ModelError> {
170 self.senders[dest_rank].send(data.to_vec()).map_err(|e| {
171 ModelError::CheckpointSerialization {
172 message: format!("transport send failed: {e}"),
173 }
174 })
175 }
176
177 fn recv(&self, src_rank: usize) -> Result<Vec<u8>, ModelError> {
178 self.receivers[src_rank]
179 .recv()
180 .map_err(|e| ModelError::CheckpointSerialization {
181 message: format!("transport recv failed: {e}"),
182 })
183 }
184
185 fn barrier(&self) -> Result<(), ModelError> {
186 let mut state =
187 self.barrier_state
188 .lock()
189 .map_err(|e| ModelError::CheckpointSerialization {
190 message: format!("barrier lock failed: {e}"),
191 })?;
192 state.count += 1;
193 if state.count == state.target {
194 state.count = 0;
195 state.generation += 1;
196 for tx in &state.condvar_senders {
198 let _ = tx.send(());
199 }
200 Ok(())
201 } else {
202 let rx = state.condvar_receivers[self.rank].take();
203 drop(state);
204 if let Some(rx) = rx {
205 let _ = rx.recv();
206 }
207 let (tx, rx) = mpsc::channel();
209 let mut state =
210 self.barrier_state
211 .lock()
212 .map_err(|e| ModelError::CheckpointSerialization {
213 message: format!("barrier lock failed: {e}"),
214 })?;
215 state.condvar_senders[self.rank] = tx;
216 state.condvar_receivers[self.rank] = Some(rx);
217 Ok(())
218 }
219 }
220}
221
/// Strategy for combining per-rank gradients into the set applied locally.
pub trait GradientAggregator: Send {
    /// Aggregates `local_gradients` (possibly communicating with other ranks)
    /// and returns the gradients this rank should apply.
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}
231
/// No-op aggregator for single-process training: returns the local gradients
/// unchanged (as an owned copy).
pub struct LocalAggregator;

impl GradientAggregator for LocalAggregator {
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
        Ok(local_gradients.to_vec())
    }
}
240
/// Gradient aggregator that averages gradients across all ranks with a
/// ring-style exchange over the configured transport.
pub struct AllReduceAggregator {
    config: DistributedConfig,
    transport: Box<dyn Transport>,
}

impl AllReduceAggregator {
    /// Creates an aggregator for the rank described by `config`,
    /// communicating over `transport`.
    pub fn new(config: DistributedConfig, transport: Box<dyn Transport>) -> Self {
        Self { config, transport }
    }
}
269
270impl GradientAggregator for AllReduceAggregator {
271 fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
272 let world = self.config.world_size;
273 if world <= 1 {
274 return Ok(local_gradients.to_vec());
275 }
276
277 let local_bytes = serialize_tensors(local_gradients)?;
279
280 let next = (self.config.rank + 1) % world;
282 let prev = (self.config.rank + world - 1) % world;
283
284 let mut accumulated = local_bytes.clone();
286 for _ in 0..(world - 1) {
287 self.transport.send(next, &accumulated)?;
288 let received = self.transport.recv(prev)?;
289 let recv_tensors = deserialize_tensors(&received)?;
291 let acc_tensors = deserialize_tensors(&accumulated)?;
292 let mut summed = Vec::new();
293 for (a, r) in acc_tensors.iter().zip(recv_tensors.iter()) {
294 summed.push(a.add(r)?);
295 }
296 accumulated = serialize_tensors(&summed)?;
297 }
298
299 let result_tensors = deserialize_tensors(&accumulated)?;
301 let scale = 1.0 / world as f32;
302 let mut averaged = Vec::new();
303 for t in &result_tensors {
304 averaged.push(t.scale(scale));
305 }
306
307 self.transport.barrier()?;
308 Ok(averaged)
309 }
310}
311
/// Rank-0-coordinated parameter server built on a `Transport`: rank 0 owns
/// broadcast and reduction; all other ranks act as workers.
pub struct ParameterServer {
    config: DistributedConfig,
    transport: Box<dyn Transport>,
}
333
334impl ParameterServer {
335 pub fn new(config: DistributedConfig, transport: Box<dyn Transport>) -> Self {
336 Self { config, transport }
337 }
338
339 pub fn broadcast_params(&self, params: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
341 let world = self.config.world_size;
342 if world <= 1 {
343 return Ok(params.to_vec());
344 }
345
346 if self.config.rank == 0 {
347 let data = serialize_tensors(params)?;
348 for dest in 1..world {
349 self.transport.send(dest, &data)?;
350 }
351 Ok(params.to_vec())
352 } else {
353 let data = self.transport.recv(0)?;
354 deserialize_tensors(&data)
355 }
356 }
357
358 pub fn reduce_gradients(&self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
360 let world = self.config.world_size;
361 if world <= 1 {
362 return Ok(grads.to_vec());
363 }
364
365 if self.config.rank == 0 {
366 let mut acc = grads.to_vec();
367 for src in 1..world {
368 let data = self.transport.recv(src)?;
369 let remote_grads = deserialize_tensors(&data)?;
370 for (a, r) in acc.iter_mut().zip(remote_grads.iter()) {
371 *a = a.add(r)?;
372 }
373 }
374 let scale = 1.0 / world as f32;
375 let mut averaged = Vec::new();
376 for t in &acc {
377 averaged.push(t.scale(scale));
378 }
379 let data = serialize_tensors(&averaged)?;
381 for dest in 1..world {
382 self.transport.send(dest, &data)?;
383 }
384 Ok(averaged)
385 } else {
386 let data = serialize_tensors(grads)?;
387 self.transport.send(0, &data)?;
388 let result_data = self.transport.recv(0)?;
389 deserialize_tensors(&result_data)
390 }
391 }
392}
393
/// Bundles the distributed topology with the gradient-aggregation strategy
/// used for data-parallel training.
pub struct DataParallelConfig {
    /// Topology (world size, rank, coordinator address).
    pub config: DistributedConfig,
    /// Strategy used to combine per-rank gradients.
    pub aggregator: Box<dyn GradientAggregator>,
}
403
/// Sparse representation of a gradient tensor as produced by
/// `compress_gradients`: parallel `indices`/`values` arrays plus the dense
/// length needed to reconstruct the tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedGradient {
    /// Flat indices of the retained entries.
    pub indices: Vec<usize>,
    /// Values at those indices (same order as `indices`).
    pub values: Vec<f32>,
    /// Dense element count of the original tensor.
    pub original_len: usize,
}
415
/// Configuration for top-k gradient sparsification; `ratio` is the fraction
/// of entries to keep, clamped to `[0.0, 1.0]`. Intended for use with
/// `compress_gradients` by passing `self.ratio` as its `ratio` argument.
pub struct TopKCompressor {
    pub ratio: f32,
}

impl TopKCompressor {
    /// Builds a compressor, clamping `ratio` into the valid `[0.0, 1.0]`
    /// range.
    pub fn new(ratio: f32) -> Self {
        let ratio = ratio.clamp(0.0, 1.0);
        TopKCompressor { ratio }
    }
}
428
429pub fn compress_gradients(gradients: &[Tensor], ratio: f32) -> Vec<CompressedGradient> {
431 let ratio = ratio.clamp(0.0, 1.0);
432 gradients
433 .iter()
434 .map(|t| {
435 let data = t.data();
436 let k = ((data.len() as f32 * ratio).ceil() as usize)
437 .max(1)
438 .min(data.len());
439
440 let mut indexed: Vec<(usize, f32)> = data.iter().copied().enumerate().collect();
442 indexed.sort_by(|a, b| {
443 b.1.abs()
444 .partial_cmp(&a.1.abs())
445 .unwrap_or(std::cmp::Ordering::Equal)
446 });
447 indexed.truncate(k);
448
449 let indices: Vec<usize> = indexed.iter().map(|(i, _)| *i).collect();
450 let values: Vec<f32> = indexed.iter().map(|(_, v)| *v).collect();
451
452 CompressedGradient {
453 indices,
454 values,
455 original_len: data.len(),
456 }
457 })
458 .collect()
459}
460
461pub fn decompress_gradients(
463 compressed: &[CompressedGradient],
464 shapes: &[Vec<usize>],
465) -> Result<Vec<Tensor>, ModelError> {
466 let mut result = Vec::with_capacity(compressed.len());
467 for (cg, shape) in compressed.iter().zip(shapes.iter()) {
468 let mut data = vec![0.0f32; cg.original_len];
469 for (&idx, &val) in cg.indices.iter().zip(cg.values.iter()) {
470 if idx < data.len() {
471 data[idx] = val;
472 }
473 }
474 result.push(Tensor::from_vec(shape.clone(), data)?);
475 }
476 Ok(result)
477}
478
/// Runs one data-parallel training step: compute local gradients, aggregate
/// them across ranks via `aggregator`, then apply the aggregated gradients.
///
/// Returns the *local* loss from `compute_gradients_fn`; losses are not
/// averaged across ranks here.
pub fn distributed_train_step<F, G>(
    compute_gradients_fn: F,
    apply_gradients_fn: G,
    aggregator: &mut dyn GradientAggregator,
) -> Result<f32, ModelError>
where
    F: FnOnce() -> Result<(f32, Vec<Tensor>), ModelError>,
    G: FnOnce(&[Tensor]) -> Result<(), ModelError>,
{
    let (loss, local_grads) = compute_gradients_fn()?;
    let aggregated = aggregator.aggregate(&local_grads)?;
    apply_gradients_fn(&aggregated)?;
    Ok(loss)
}
525
526fn serialize_tensors(tensors: &[Tensor]) -> Result<Vec<u8>, ModelError> {
531 let mut entries = Vec::new();
532 for t in tensors {
533 let shape = t.shape().to_vec();
534 let data = t.data().to_vec();
535 entries.push((shape, data));
536 }
537 serde_json::to_vec(&entries).map_err(|e| ModelError::CheckpointSerialization {
538 message: format!("tensor serialization failed: {e}"),
539 })
540}
541
542fn deserialize_tensors(data: &[u8]) -> Result<Vec<Tensor>, ModelError> {
543 let entries: Vec<(Vec<usize>, Vec<f32>)> =
544 serde_json::from_slice(data).map_err(|e| ModelError::CheckpointSerialization {
545 message: format!("tensor deserialization failed: {e}"),
546 })?;
547 let mut tensors = Vec::with_capacity(entries.len());
548 for (shape, values) in entries {
549 tensors.push(Tensor::from_vec(shape, values)?);
550 }
551 Ok(tensors)
552}
553
/// One contiguous slice of a layered model assigned to a pipeline rank.
pub struct PipelineStage {
    /// First layer index owned by this stage (inclusive).
    pub start_layer: usize,
    /// One past the last layer owned by this stage (exclusive).
    pub end_layer: usize,
    /// Pipeline rank that runs this stage.
    pub rank: usize,
}

/// Shape of a pipeline-parallel schedule: how many stages the model is cut
/// into and how many micro-batches flow through them.
pub struct PipelineParallelConfig {
    pub num_stages: usize,
    pub num_micro_batches: usize,
}

impl PipelineParallelConfig {
    /// Creates a pipeline configuration with the given stage and
    /// micro-batch counts.
    pub fn new(num_stages: usize, num_micro_batches: usize) -> Self {
        Self {
            num_stages,
            num_micro_batches,
        }
    }
}

/// Partitions `num_layers` consecutive layers into `num_stages` contiguous
/// stages. Earlier stages absorb the remainder, so stage sizes differ by at
/// most one layer and the stages exactly tile `0..num_layers`.
///
/// # Panics
///
/// Panics when `num_stages` is zero or exceeds `num_layers`.
pub fn split_into_stages(num_layers: usize, num_stages: usize) -> Vec<PipelineStage> {
    assert!(num_stages > 0 && num_stages <= num_layers);
    let base = num_layers / num_stages;
    let remainder = num_layers % num_stages;
    let mut next_start = 0;
    (0..num_stages)
        .map(|rank| {
            let size = base + usize::from(rank < remainder);
            let start_layer = next_start;
            next_start += size;
            PipelineStage {
                start_layer,
                end_layer: next_start,
                rank,
            }
        })
        .collect()
}
615
616pub fn shard_tensor(tensor: &Tensor, num_shards: usize) -> Result<Vec<Tensor>, ModelError> {
632 if num_shards == 0 {
633 return Err(ModelError::InvalidConv2dStride {
634 stride_h: 0,
635 stride_w: 0,
636 });
637 }
638 let shape = tensor.shape();
639 if shape.is_empty() {
640 return Ok(vec![tensor.clone()]);
641 }
642 let first_dim = shape[0];
643 if num_shards > first_dim {
644 return Err(ModelError::InvalidParameterShape {
645 parameter: "shard_tensor",
646 expected: vec![num_shards],
647 got: shape.to_vec(),
648 });
649 }
650
651 let data = tensor.data();
652 let stride = data.len() / first_dim; let base = first_dim / num_shards;
654 let remainder = first_dim % num_shards;
655
656 let mut shards = Vec::with_capacity(num_shards);
657 let mut offset = 0;
658 for i in 0..num_shards {
659 let rows = base + if i < remainder { 1 } else { 0 };
660 let start = offset * stride;
661 let end = (offset + rows) * stride;
662 let mut shard_shape = shape.to_vec();
663 shard_shape[0] = rows;
664 shards.push(Tensor::from_vec(shard_shape, data[start..end].to_vec())?);
665 offset += rows;
666 }
667 Ok(shards)
668}
669
670pub fn gather_shards(shards: &[Tensor]) -> Result<Tensor, ModelError> {
676 if shards.is_empty() {
677 return Err(ModelError::InvalidParameterShape {
678 parameter: "gather_shards",
679 expected: vec![1],
680 got: vec![0],
681 });
682 }
683 if shards.len() == 1 {
684 return Ok(shards[0].clone());
685 }
686
687 let first_shape = shards[0].shape();
688 let tail: Vec<usize> = first_shape[1..].to_vec();
689 let stride: usize = tail.iter().product::<usize>().max(1);
690
691 let total_rows: usize = shards.iter().map(|s| s.shape()[0]).sum();
692 let mut data = Vec::with_capacity(total_rows * stride);
693 for shard in shards {
694 data.extend_from_slice(shard.data());
695 }
696
697 let mut out_shape = vec![total_rows];
698 out_shape.extend_from_slice(&tail);
699 Tensor::from_vec(out_shape, data).map_err(ModelError::Tensor)
700}
701
#[cfg(test)]
mod tests {
    use super::*;

    // 10 layers over 3 stages must tile [0, 10) contiguously with no gaps.
    #[test]
    fn pipeline_stages_cover_all_layers() {
        let stages = split_into_stages(10, 3);
        assert_eq!(stages.len(), 3);
        assert_eq!(stages[0].start_layer, 0);
        assert_eq!(stages[2].end_layer, 10);
        for i in 1..stages.len() {
            // Each stage starts exactly where the previous one ended.
            assert_eq!(stages[i].start_layer, stages[i - 1].end_layer);
        }
    }

    // Even split: shard then gather must reproduce shape and data exactly.
    #[test]
    fn shard_and_gather_roundtrip() {
        let t = Tensor::from_vec(vec![6, 4], (0..24).map(|i| i as f32).collect()).unwrap();
        let shards = shard_tensor(&t, 3).unwrap();
        assert_eq!(shards.len(), 3);
        assert_eq!(shards[0].shape(), &[2, 4]);
        assert_eq!(shards[1].shape(), &[2, 4]);
        assert_eq!(shards[2].shape(), &[2, 4]);
        let gathered = gather_shards(&shards).unwrap();
        assert_eq!(gathered.shape(), t.shape());
        assert_eq!(gathered.data(), t.data());
    }

    // Uneven split (7 rows, 3 shards): the leading shard takes the extra row.
    #[test]
    fn shard_uneven_split() {
        let t = Tensor::from_vec(vec![7, 2], (0..14).map(|i| i as f32).collect()).unwrap();
        let shards = shard_tensor(&t, 3).unwrap();
        assert_eq!(shards[0].shape()[0], 3);
        assert_eq!(shards[1].shape()[0], 2);
        assert_eq!(shards[2].shape()[0], 2);
        let gathered = gather_shards(&shards).unwrap();
        assert_eq!(gathered.data(), t.data());
    }
}