// yscv_model/distributed.rs
//! Distributed training primitives for multi-GPU and multi-node training.
//!
//! This module provides the building blocks for scaling training across multiple
//! workers (processes, threads, or machines):
//!
//! - **Transport layer** ([`Transport`] trait, [`InProcessTransport`]) -- byte-level
//!   send/recv/barrier used by aggregation strategies. For TCP-based networking see
//!   [`crate::tcp_transport`].
//! - **Gradient aggregation** ([`AllReduceAggregator`], [`ParameterServer`]) --
//!   strategies for combining gradients from different workers.
//! - **Gradient compression** ([`TopKCompressor`], [`compress_gradients`],
//!   [`decompress_gradients`]) -- reduce communication volume by sending only the
//!   most significant gradient elements.
//! - **Pipeline parallelism** ([`PipelineParallelConfig`], [`split_into_stages`]) --
//!   partition a model's layers across stages so different micro-batches execute
//!   concurrently in different stages.
//! - **Tensor sharding / FSDP** ([`shard_tensor`], [`gather_shards`]) -- split
//!   parameter tensors along the first dimension for Fully Sharded Data Parallel
//!   style training.
//! - **Distributed train step** ([`distributed_train_step`]) -- combines local
//!   gradient computation, aggregation, and parameter update into a single call.
//!
//! ## Limitations
//! - TCP transport only (no RDMA/InfiniBand)
//! - Gradient synchronization is synchronous (no async overlap with compute)
//! - Tested up to ~8 nodes; not validated at datacenter scale (100+)
//! - No NCCL/MPI backend (use TCP AllReduce or ParameterServer)
//!
//! # Quick start
//!
//! ```rust,ignore
//! use yscv_model::{AllReduceAggregator, InProcessTransport, DistributedConfig, distributed_train_step};
//!
//! // Create two in-process workers for testing
//! let transports = InProcessTransport::create_group(2);
//!
//! let config = DistributedConfig { world_size: 2, rank: 0, coordinator_addr: String::new() };
//! let mut aggregator = AllReduceAggregator::new(config, Box::new(transports.into_iter().next().unwrap()));
//!
//! let loss = distributed_train_step(
//!     || Ok((0.5, vec![/* local gradients */])),
//!     |aggregated| { /* apply gradients */ Ok(()) },
//!     &mut aggregator,
//! ).unwrap();
//! ```
47use std::sync::mpsc;
48use std::sync::{Arc, Mutex};
49
50use serde::{Deserialize, Serialize};
51
52use yscv_tensor::Tensor;
53
54use crate::ModelError;
55
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Identifies a worker inside a distributed training group.
///
/// Shared by every aggregation strategy in this module; derives `Serialize`/
/// `Deserialize` so it can be sent over a transport or stored alongside a
/// checkpoint.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Total number of workers.
    pub world_size: usize,
    /// Zero-based rank of this worker.
    pub rank: usize,
    /// `host:port` of the rank-0 coordinator.
    pub coordinator_addr: String,
}
70
// ---------------------------------------------------------------------------
// Transport trait
// ---------------------------------------------------------------------------

/// Byte-level communication primitive used by aggregation strategies.
///
/// Implementations must be `Send` so a transport can be moved into a worker
/// thread.
pub trait Transport: Send {
    /// Send `data` to the worker with the given rank.
    fn send(&self, dest_rank: usize, data: &[u8]) -> Result<(), ModelError>;
    /// Receive data from the worker with the given rank.
    ///
    /// Blocks until a message from `src_rank` is available.
    fn recv(&self, src_rank: usize) -> Result<Vec<u8>, ModelError>;
    /// Block until all workers reach the barrier.
    fn barrier(&self) -> Result<(), ModelError>;
}
84
/// In-process transport backed by `mpsc` channels (for testing).
///
/// Create a fully connected group with [`InProcessTransport::create_group`];
/// each returned transport owns one channel endpoint per peer, and all
/// transports in a group share the same barrier state.
pub struct InProcessTransport {
    // Zero-based rank of this worker within the group.
    rank: usize,
    _world_size: usize,
    // senders[dest_rank] = sender for messages TO that rank FROM this rank
    senders: Vec<mpsc::Sender<Vec<u8>>>,
    // receivers[src_rank] = receiver for messages FROM that rank TO this rank
    receivers: Vec<mpsc::Receiver<Vec<u8>>>,
    // Barrier bookkeeping shared by every transport in the group.
    barrier_state: Arc<Mutex<BarrierState>>,
}
95
/// Shared state behind the group-wide barrier of [`InProcessTransport`].
struct BarrierState {
    // Number of workers that have arrived at the current barrier.
    count: usize,
    // Incremented each time a barrier completes; written but never read
    // (kept for debugging / future use).
    generation: u64,
    // Number of arrivals needed to release the barrier (the world size).
    target: usize,
    // Per-rank wake-up channels: the last arriver signals these senders.
    condvar_senders: Vec<mpsc::Sender<()>>,
    // Receiving halves; a waiting rank `take`s its own receiver while blocked.
    condvar_receivers: Vec<Option<mpsc::Receiver<()>>>,
}
103
104impl InProcessTransport {
105    /// Creates `world_size` connected transports for in-process testing.
106    #[allow(clippy::type_complexity)]
107    pub fn create_group(world_size: usize) -> Vec<Self> {
108        // Build sender/receiver grids: sender_grid[src][dest], receiver_grid[src][dest]
109        let mut sender_grid: Vec<Vec<Option<mpsc::Sender<Vec<u8>>>>> = Vec::new();
110        let mut receiver_grid: Vec<Vec<Option<mpsc::Receiver<Vec<u8>>>>> = Vec::new();
111        for _ in 0..world_size {
112            let mut s_row = Vec::new();
113            let mut r_row = Vec::new();
114            for _ in 0..world_size {
115                let (tx, rx) = mpsc::channel();
116                s_row.push(Some(tx));
117                r_row.push(Some(rx));
118            }
119            sender_grid.push(s_row);
120            receiver_grid.push(r_row);
121        }
122
123        // Barrier state
124        let mut barrier_senders = Vec::new();
125        let mut barrier_receivers = Vec::new();
126        for _ in 0..world_size {
127            let (tx, rx) = mpsc::channel();
128            barrier_senders.push(tx);
129            barrier_receivers.push(Some(rx));
130        }
131
132        let barrier_state = Arc::new(Mutex::new(BarrierState {
133            count: 0,
134            generation: 0,
135            target: world_size,
136            condvar_senders: barrier_senders,
137            condvar_receivers: barrier_receivers,
138        }));
139
140        let mut transports = Vec::new();
141
142        for r in 0..world_size {
143            let mut my_senders = Vec::new();
144            let mut my_receivers = Vec::new();
145
146            // transport[r].senders[d] = sender_grid[r][d]
147            for item in sender_grid[r].iter_mut() {
148                my_senders.push(item.take().expect("sender not yet taken"));
149            }
150            // transport[r].receivers[s] = receiver_grid[s][r]
151            for row in receiver_grid.iter_mut() {
152                my_receivers.push(row[r].take().expect("receiver not yet taken"));
153            }
154
155            transports.push(InProcessTransport {
156                rank: r,
157                _world_size: world_size,
158                senders: my_senders,
159                receivers: my_receivers,
160                barrier_state: barrier_state.clone(),
161            });
162        }
163
164        transports
165    }
166}
167
168impl Transport for InProcessTransport {
169    fn send(&self, dest_rank: usize, data: &[u8]) -> Result<(), ModelError> {
170        self.senders[dest_rank].send(data.to_vec()).map_err(|e| {
171            ModelError::CheckpointSerialization {
172                message: format!("transport send failed: {e}"),
173            }
174        })
175    }
176
177    fn recv(&self, src_rank: usize) -> Result<Vec<u8>, ModelError> {
178        self.receivers[src_rank]
179            .recv()
180            .map_err(|e| ModelError::CheckpointSerialization {
181                message: format!("transport recv failed: {e}"),
182            })
183    }
184
185    fn barrier(&self) -> Result<(), ModelError> {
186        let mut state =
187            self.barrier_state
188                .lock()
189                .map_err(|e| ModelError::CheckpointSerialization {
190                    message: format!("barrier lock failed: {e}"),
191                })?;
192        state.count += 1;
193        if state.count == state.target {
194            state.count = 0;
195            state.generation += 1;
196            // Wake all waiting workers
197            for tx in &state.condvar_senders {
198                let _ = tx.send(());
199            }
200            Ok(())
201        } else {
202            let rx = state.condvar_receivers[self.rank].take();
203            drop(state);
204            if let Some(rx) = rx {
205                let _ = rx.recv();
206            }
207            // Put receiver back for next barrier
208            let (tx, rx) = mpsc::channel();
209            let mut state =
210                self.barrier_state
211                    .lock()
212                    .map_err(|e| ModelError::CheckpointSerialization {
213                        message: format!("barrier lock failed: {e}"),
214                    })?;
215            state.condvar_senders[self.rank] = tx;
216            state.condvar_receivers[self.rank] = Some(rx);
217            Ok(())
218        }
219    }
220}
221
// ---------------------------------------------------------------------------
// Gradient Aggregator trait
// ---------------------------------------------------------------------------

/// Strategy for combining gradients across distributed workers.
///
/// Implemented by [`LocalAggregator`] (no-op) and [`AllReduceAggregator`];
/// [`distributed_train_step`] accepts any implementor.
pub trait GradientAggregator: Send {
    /// Aggregate local gradients, returning the combined result.
    ///
    /// Takes `&mut self` so implementations may use transports or buffers
    /// that require exclusive access.
    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError>;
}
231
232/// No-op aggregator for single-machine training (API uniformity).
233pub struct LocalAggregator;
234
235impl GradientAggregator for LocalAggregator {
236    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
237        Ok(local_gradients.to_vec())
238    }
239}
240
/// All-reduce aggregator: averages gradients across all workers via ring reduce.
///
/// Each worker independently computes local gradients and then calls
/// [`GradientAggregator::aggregate`]. Internally the aggregator serialises the
/// gradient tensors, circulates the payloads around a ring over the configured
/// [`Transport`], sums them, and returns the element-wise average. This is the
/// most common strategy for synchronous data-parallel training because every
/// worker ends up with the same averaged gradients without a central
/// bottleneck.
///
/// # Example
///
/// ```rust,ignore
/// let transports = InProcessTransport::create_group(2);
/// let cfg = DistributedConfig { world_size: 2, rank: 0, coordinator_addr: String::new() };
/// let mut agg = AllReduceAggregator::new(cfg, Box::new(transports.into_iter().next().unwrap()));
/// let averaged = agg.aggregate(&local_grads).unwrap();
/// ```
pub struct AllReduceAggregator {
    // World size and rank drive the ring send/recv schedule.
    config: DistributedConfig,
    // Byte transport used to exchange serialized gradients with neighbours.
    transport: Box<dyn Transport>,
}

impl AllReduceAggregator {
    /// Builds an aggregator for the worker described by `config`,
    /// communicating over `transport`.
    pub fn new(config: DistributedConfig, transport: Box<dyn Transport>) -> Self {
        Self { config, transport }
    }
}
269
270impl GradientAggregator for AllReduceAggregator {
271    fn aggregate(&mut self, local_gradients: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
272        let world = self.config.world_size;
273        if world <= 1 {
274            return Ok(local_gradients.to_vec());
275        }
276
277        // Serialize local gradients
278        let local_bytes = serialize_tensors(local_gradients)?;
279
280        // Ring all-reduce: each rank sends to (rank+1) % world, receives from (rank-1+world) % world
281        let next = (self.config.rank + 1) % world;
282        let prev = (self.config.rank + world - 1) % world;
283
284        // Scatter-reduce phase
285        let mut accumulated = local_bytes.clone();
286        for _ in 0..(world - 1) {
287            self.transport.send(next, &accumulated)?;
288            let received = self.transport.recv(prev)?;
289            // Accumulate: element-wise add
290            let recv_tensors = deserialize_tensors(&received)?;
291            let acc_tensors = deserialize_tensors(&accumulated)?;
292            let mut summed = Vec::new();
293            for (a, r) in acc_tensors.iter().zip(recv_tensors.iter()) {
294                summed.push(a.add(r)?);
295            }
296            accumulated = serialize_tensors(&summed)?;
297        }
298
299        // Average
300        let result_tensors = deserialize_tensors(&accumulated)?;
301        let scale = 1.0 / world as f32;
302        let mut averaged = Vec::new();
303        for t in &result_tensors {
304            averaged.push(t.scale(scale));
305        }
306
307        self.transport.barrier()?;
308        Ok(averaged)
309    }
310}
311
// ---------------------------------------------------------------------------
// Parameter Server
// ---------------------------------------------------------------------------

/// Centralized parameter server: rank 0 collects, averages, and broadcasts
/// gradients (or parameters).
///
/// Use this instead of [`AllReduceAggregator`] when you want a star topology
/// (all workers communicate only with rank 0). This is simpler to reason about
/// and works well when the coordinator has high bandwidth, but can become a
/// bottleneck at large scale compared to ring all-reduce.
///
/// * [`broadcast_params`](Self::broadcast_params) -- rank 0 sends the current
///   parameters to all workers (useful at initialisation or after a checkpoint
///   restore).
/// * [`reduce_gradients`](Self::reduce_gradients) -- workers send local
///   gradients to rank 0; rank 0 averages them and broadcasts the result back.
pub struct ParameterServer {
    // Identity of this worker; rank 0 acts as the coordinator.
    config: DistributedConfig,
    // Byte transport connecting this worker with the others.
    transport: Box<dyn Transport>,
}
333
334impl ParameterServer {
335    pub fn new(config: DistributedConfig, transport: Box<dyn Transport>) -> Self {
336        Self { config, transport }
337    }
338
339    /// Rank 0 broadcasts params to all workers; workers receive and return them.
340    pub fn broadcast_params(&self, params: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
341        let world = self.config.world_size;
342        if world <= 1 {
343            return Ok(params.to_vec());
344        }
345
346        if self.config.rank == 0 {
347            let data = serialize_tensors(params)?;
348            for dest in 1..world {
349                self.transport.send(dest, &data)?;
350            }
351            Ok(params.to_vec())
352        } else {
353            let data = self.transport.recv(0)?;
354            deserialize_tensors(&data)
355        }
356    }
357
358    /// Workers send gradients to rank 0; rank 0 averages and returns result.
359    pub fn reduce_gradients(&self, grads: &[Tensor]) -> Result<Vec<Tensor>, ModelError> {
360        let world = self.config.world_size;
361        if world <= 1 {
362            return Ok(grads.to_vec());
363        }
364
365        if self.config.rank == 0 {
366            let mut acc = grads.to_vec();
367            for src in 1..world {
368                let data = self.transport.recv(src)?;
369                let remote_grads = deserialize_tensors(&data)?;
370                for (a, r) in acc.iter_mut().zip(remote_grads.iter()) {
371                    *a = a.add(r)?;
372                }
373            }
374            let scale = 1.0 / world as f32;
375            let mut averaged = Vec::new();
376            for t in &acc {
377                averaged.push(t.scale(scale));
378            }
379            // Broadcast averaged gradients back
380            let data = serialize_tensors(&averaged)?;
381            for dest in 1..world {
382                self.transport.send(dest, &data)?;
383            }
384            Ok(averaged)
385        } else {
386            let data = serialize_tensors(grads)?;
387            self.transport.send(0, &data)?;
388            let result_data = self.transport.recv(0)?;
389            deserialize_tensors(&result_data)
390        }
391    }
392}
393
// ---------------------------------------------------------------------------
// Data-parallel config
// ---------------------------------------------------------------------------

/// Configuration for data-parallel distributed training.
///
/// Bundles a worker's identity with the gradient aggregation strategy it
/// should use.
pub struct DataParallelConfig {
    /// Worker identity (world size, rank, coordinator address).
    pub config: DistributedConfig,
    /// Strategy used to combine gradients across workers.
    pub aggregator: Box<dyn GradientAggregator>,
}
403
// ---------------------------------------------------------------------------
// Gradient compression
// ---------------------------------------------------------------------------

/// Compressed gradient: stores only the top-k elements by magnitude.
///
/// Produced by [`compress_gradients`]; reconstruct the dense tensor with
/// [`decompress_gradients`]. Serializable so it can be sent over a transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedGradient {
    /// Flat indices of the kept elements in the original tensor data.
    pub indices: Vec<usize>,
    /// Values at those indices (parallel to `indices`).
    pub values: Vec<f32>,
    /// Element count of the original (dense) tensor.
    pub original_len: usize,
}
415
416/// Top-K gradient compressor: keeps only the top `ratio` fraction of gradients.
417pub struct TopKCompressor {
418    pub ratio: f32,
419}
420
421impl TopKCompressor {
422    pub fn new(ratio: f32) -> Self {
423        Self {
424            ratio: ratio.clamp(0.0, 1.0),
425        }
426    }
427}
428
429/// Compress gradients by keeping only top-k% elements by magnitude.
430pub fn compress_gradients(gradients: &[Tensor], ratio: f32) -> Vec<CompressedGradient> {
431    let ratio = ratio.clamp(0.0, 1.0);
432    gradients
433        .iter()
434        .map(|t| {
435            let data = t.data();
436            let k = ((data.len() as f32 * ratio).ceil() as usize)
437                .max(1)
438                .min(data.len());
439
440            // Find top-k by magnitude
441            let mut indexed: Vec<(usize, f32)> = data.iter().copied().enumerate().collect();
442            indexed.sort_by(|a, b| {
443                b.1.abs()
444                    .partial_cmp(&a.1.abs())
445                    .unwrap_or(std::cmp::Ordering::Equal)
446            });
447            indexed.truncate(k);
448
449            let indices: Vec<usize> = indexed.iter().map(|(i, _)| *i).collect();
450            let values: Vec<f32> = indexed.iter().map(|(_, v)| *v).collect();
451
452            CompressedGradient {
453                indices,
454                values,
455                original_len: data.len(),
456            }
457        })
458        .collect()
459}
460
461/// Decompress gradients back to full tensors.
462pub fn decompress_gradients(
463    compressed: &[CompressedGradient],
464    shapes: &[Vec<usize>],
465) -> Result<Vec<Tensor>, ModelError> {
466    let mut result = Vec::with_capacity(compressed.len());
467    for (cg, shape) in compressed.iter().zip(shapes.iter()) {
468        let mut data = vec![0.0f32; cg.original_len];
469        for (&idx, &val) in cg.indices.iter().zip(cg.values.iter()) {
470            if idx < data.len() {
471                data[idx] = val;
472            }
473        }
474        result.push(Tensor::from_vec(shape.clone(), data)?);
475    }
476    Ok(result)
477}
478
479// ---------------------------------------------------------------------------
480// Distributed training step
481// ---------------------------------------------------------------------------
482
483/// Performs a single distributed training step: forward, backward, aggregate, update.
484///
485/// This is the main entry point for one iteration of distributed training.
486/// It calls `compute_gradients_fn` to run the local forward and backward pass,
487/// then aggregates gradients across workers via the provided `aggregator`,
488/// and finally applies the aggregated gradients through `apply_gradients_fn`.
489///
490/// # Arguments
491///
492/// * `compute_gradients_fn` -- closure that returns `(loss, local_gradients)`.
493/// * `apply_gradients_fn` -- closure that receives the aggregated gradients and
494///   updates model parameters (e.g. via an optimizer step).
495/// * `aggregator` -- the gradient aggregation strategy (e.g. [`AllReduceAggregator`]
496///   or [`LocalAggregator`] for single-worker training).
497///
498/// # Returns
499///
500/// The scalar loss value produced by `compute_gradients_fn`.
501///
502/// # Example
503///
504/// ```rust,ignore
505/// let loss = distributed_train_step(
506///     || { /* forward + backward */ Ok((loss_val, grads)) },
507///     |agg_grads| { optimizer.apply(agg_grads); Ok(()) },
508///     &mut aggregator,
509/// )?;
510/// ```
511pub fn distributed_train_step<F, G>(
512    compute_gradients_fn: F,
513    apply_gradients_fn: G,
514    aggregator: &mut dyn GradientAggregator,
515) -> Result<f32, ModelError>
516where
517    F: FnOnce() -> Result<(f32, Vec<Tensor>), ModelError>,
518    G: FnOnce(&[Tensor]) -> Result<(), ModelError>,
519{
520    let (loss, local_grads) = compute_gradients_fn()?;
521    let aggregated = aggregator.aggregate(&local_grads)?;
522    apply_gradients_fn(&aggregated)?;
523    Ok(loss)
524}
525
526// ---------------------------------------------------------------------------
527// Serialization helpers
528// ---------------------------------------------------------------------------
529
530fn serialize_tensors(tensors: &[Tensor]) -> Result<Vec<u8>, ModelError> {
531    let mut entries = Vec::new();
532    for t in tensors {
533        let shape = t.shape().to_vec();
534        let data = t.data().to_vec();
535        entries.push((shape, data));
536    }
537    serde_json::to_vec(&entries).map_err(|e| ModelError::CheckpointSerialization {
538        message: format!("tensor serialization failed: {e}"),
539    })
540}
541
542fn deserialize_tensors(data: &[u8]) -> Result<Vec<Tensor>, ModelError> {
543    let entries: Vec<(Vec<usize>, Vec<f32>)> =
544        serde_json::from_slice(data).map_err(|e| ModelError::CheckpointSerialization {
545            message: format!("tensor deserialization failed: {e}"),
546        })?;
547    let mut tensors = Vec::with_capacity(entries.len());
548    for (shape, values) in entries {
549        tensors.push(Tensor::from_vec(shape, values)?);
550    }
551    Ok(tensors)
552}
553
// ---------------------------------------------------------------------------
// Pipeline Parallelism
// ---------------------------------------------------------------------------

/// Pipeline parallelism: split a sequential model across multiple stages.
///
/// Each stage holds a contiguous subset of layers. During forward pass,
/// micro-batches flow through stages sequentially. This enables training
/// models larger than single-device memory.
pub struct PipelineStage {
    /// Layer indices [start, end) in the original model
    pub start_layer: usize,
    /// Exclusive end of this stage's layer range.
    pub end_layer: usize,
    /// Stage rank in the pipeline (0 is the first stage).
    pub rank: usize,
}
570
/// Configuration for pipeline-parallel training.
///
/// Pipeline parallelism partitions a model's sequential layers into
/// `num_stages` stages, each assigned to a different device. During training,
/// the mini-batch is split into `num_micro_batches` micro-batches that flow
/// through the pipeline concurrently (GPipe-style scheduling), reducing the
/// bubble time compared to naive sequential execution.
///
/// Use [`split_into_stages`] to compute the layer ranges for each stage.
pub struct PipelineParallelConfig {
    /// Number of pipeline stages
    pub num_stages: usize,
    /// Number of micro-batches per mini-batch
    pub num_micro_batches: usize,
}

impl PipelineParallelConfig {
    /// Creates a config; values are stored as-is (no validation here --
    /// [`split_into_stages`] asserts the stage count is feasible).
    pub fn new(num_stages: usize, num_micro_batches: usize) -> Self {
        Self {
            num_stages,
            num_micro_batches,
        }
    }
}
595
596/// Split a model with `num_layers` layers into `num_stages` roughly equal stages.
597pub fn split_into_stages(num_layers: usize, num_stages: usize) -> Vec<PipelineStage> {
598    assert!(num_stages > 0 && num_stages <= num_layers);
599    let base = num_layers / num_stages;
600    let remainder = num_layers % num_stages;
601    let mut stages = Vec::with_capacity(num_stages);
602    let mut start = 0;
603    for rank in 0..num_stages {
604        let extra = if rank < remainder { 1 } else { 0 };
605        let end = start + base + extra;
606        stages.push(PipelineStage {
607            start_layer: start,
608            end_layer: end,
609            rank,
610        });
611        start = end;
612    }
613    stages
614}
615
616// ---------------------------------------------------------------------------
617// Tensor Sharding (FSDP-lite)
618// ---------------------------------------------------------------------------
619
620/// Shard a tensor along its first dimension into `num_shards` roughly equal parts.
621///
622/// This is the core primitive for Fully Sharded Data Parallel (FSDP)-style
623/// training: large parameter tensors are split across workers so that each
624/// worker stores only its shard. Before a forward or backward pass the
625/// shards are gathered (see [`gather_shards`]), and after the pass only the
626/// local shard's gradients are kept.
627///
628/// Each shard is a separate [`Tensor`] that can be placed on a different device.
629/// If `num_shards` does not evenly divide the first dimension, earlier shards
630/// receive one extra row.
631pub fn shard_tensor(tensor: &Tensor, num_shards: usize) -> Result<Vec<Tensor>, ModelError> {
632    if num_shards == 0 {
633        return Err(ModelError::InvalidConv2dStride {
634            stride_h: 0,
635            stride_w: 0,
636        });
637    }
638    let shape = tensor.shape();
639    if shape.is_empty() {
640        return Ok(vec![tensor.clone()]);
641    }
642    let first_dim = shape[0];
643    if num_shards > first_dim {
644        return Err(ModelError::InvalidParameterShape {
645            parameter: "shard_tensor",
646            expected: vec![num_shards],
647            got: shape.to_vec(),
648        });
649    }
650
651    let data = tensor.data();
652    let stride = data.len() / first_dim; // elements per row along dim 0
653    let base = first_dim / num_shards;
654    let remainder = first_dim % num_shards;
655
656    let mut shards = Vec::with_capacity(num_shards);
657    let mut offset = 0;
658    for i in 0..num_shards {
659        let rows = base + if i < remainder { 1 } else { 0 };
660        let start = offset * stride;
661        let end = (offset + rows) * stride;
662        let mut shard_shape = shape.to_vec();
663        shard_shape[0] = rows;
664        shards.push(Tensor::from_vec(shard_shape, data[start..end].to_vec())?);
665        offset += rows;
666    }
667    Ok(shards)
668}
669
670/// Reassemble shards (produced by [`shard_tensor`]) back into a single tensor.
671///
672/// Concatenates along the first dimension. The invariant
673/// `gather_shards(&shard_tensor(t, n)?) == Ok(t.clone())` holds for any
674/// valid `n`.
675pub fn gather_shards(shards: &[Tensor]) -> Result<Tensor, ModelError> {
676    if shards.is_empty() {
677        return Err(ModelError::InvalidParameterShape {
678            parameter: "gather_shards",
679            expected: vec![1],
680            got: vec![0],
681        });
682    }
683    if shards.len() == 1 {
684        return Ok(shards[0].clone());
685    }
686
687    let first_shape = shards[0].shape();
688    let tail: Vec<usize> = first_shape[1..].to_vec();
689    let stride: usize = tail.iter().product::<usize>().max(1);
690
691    let total_rows: usize = shards.iter().map(|s| s.shape()[0]).sum();
692    let mut data = Vec::with_capacity(total_rows * stride);
693    for shard in shards {
694        data.extend_from_slice(shard.data());
695    }
696
697    let mut out_shape = vec![total_rows];
698    out_shape.extend_from_slice(&tail);
699    Tensor::from_vec(out_shape, data).map_err(ModelError::Tensor)
700}
701
#[cfg(test)]
mod tests {
    use super::*;

    // Stages must tile [0, num_layers) with no gaps or overlaps.
    #[test]
    fn pipeline_stages_cover_all_layers() {
        let stages = split_into_stages(10, 3);
        assert_eq!(stages.len(), 3);
        assert_eq!(stages[0].start_layer, 0);
        assert_eq!(stages[2].end_layer, 10);
        // Each stage starts exactly where the previous one ended.
        for i in 1..stages.len() {
            assert_eq!(stages[i].start_layer, stages[i - 1].end_layer);
        }
    }

    // shard -> gather must be a lossless round trip for an even split.
    #[test]
    fn shard_and_gather_roundtrip() {
        let t = Tensor::from_vec(vec![6, 4], (0..24).map(|i| i as f32).collect()).unwrap();
        let shards = shard_tensor(&t, 3).unwrap();
        assert_eq!(shards.len(), 3);
        assert_eq!(shards[0].shape(), &[2, 4]);
        assert_eq!(shards[1].shape(), &[2, 4]);
        assert_eq!(shards[2].shape(), &[2, 4]);
        let gathered = gather_shards(&shards).unwrap();
        assert_eq!(gathered.shape(), t.shape());
        assert_eq!(gathered.data(), t.data());
    }

    // Remainder rows go to the earliest shards when the split is uneven.
    #[test]
    fn shard_uneven_split() {
        let t = Tensor::from_vec(vec![7, 2], (0..14).map(|i| i as f32).collect()).unwrap();
        let shards = shard_tensor(&t, 3).unwrap();
        // 7 / 3 = 2 base + 1 remainder → first shard gets 3, others get 2
        assert_eq!(shards[0].shape()[0], 3);
        assert_eq!(shards[1].shape()[0], 2);
        assert_eq!(shards[2].shape()[0], 2);
        let gathered = gather_shards(&shards).unwrap();
        assert_eq!(gathered.data(), t.data());
    }
}