// scirs2_neural/training/tensor_parallel.rs
//! Tensor parallelism primitives: column-parallel, row-parallel linear layers,
//! and a vocabulary-partitioned parallel embedding.
//!
//! These components simulate the Megatron-LM tensor-parallel strategy
//! (Shoeybi et al., 2019) in a single-process environment by keeping
//! separate weight slices per "worker" in memory.
//!
//! ## Column Parallel Linear
//!
//! Splits the output dimension across workers.  Each worker computes
//! `y_i = x @ W_i + b_i` where `W_i = W[:, i*chunk:(i+1)*chunk]`.
//! The results are all-gathered (concatenated) to form the full output.
//!
//! ## Row Parallel Linear
//!
//! Splits the input dimension across workers.  Each worker handles
//! `x_i = x[:, i*chunk:(i+1)*chunk]` and computes `y_i = x_i @ W_i`.
//! An all-reduce (sum) combines the partial results, then the shared bias is added.
//!
//! ## Parallel Embedding
//!
//! Partitions the vocabulary across workers.  Each token index is routed to
//! the responsible worker; the resulting row is returned.
//!
//! ```rust
//! use scirs2_neural::training::tensor_parallel::{
//!     TensorParallelConfig, ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
//! };
//!
//! let cfg = TensorParallelConfig::default();
//! assert_eq!(cfg.n_workers, 2);
//!
//! let col = ColumnParallelLinear::new(8, 4, cfg.clone(), 0).expect("ok");
//! let input = scirs2_core::ndarray::Array2::<f64>::ones((3, 8));
//! let out = col.forward(&input).expect("ok");
//! assert_eq!(out.shape(), [3, 4]);
//! ```

39use crate::error::{NeuralError, Result as NeuralResult};
40use scirs2_core::ndarray::{s, Array1, Array2};
41use scirs2_core::random::rngs::SmallRng;
42use scirs2_core::random::{Rng, RngExt, SeedableRng};
43
// ============================================================================
// Config
// ============================================================================

/// Configuration for tensor-parallel layers.
///
/// Shared by the column- and row-parallel linear layers; each layer takes its
/// own clone at construction time.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Number of simulated workers.  Default: `2`.
    /// Must be > 0; layer constructors reject 0.
    pub n_workers: usize,
    /// If `true`, all-gather the per-worker outputs after column-parallel linear.
    /// If `false`, `ColumnParallelLinear::forward` returns only worker 0's slice.
    /// Default: `true`.
    pub gather_output: bool,
}
57
58impl Default for TensorParallelConfig {
59    fn default() -> Self {
60        Self {
61            n_workers: 2,
62            gather_output: true,
63        }
64    }
65}
66
67// ============================================================================
68// Helpers
69// ============================================================================
70
71/// Xavier (Glorot) uniform initialisation scaled by `sqrt(2 / (fan_in + fan_out))`.
72fn xavier_init(rng: &mut SmallRng, n_in: usize, n_out: usize) -> f64 {
73    let scale = (6.0_f64 / (n_in + n_out) as f64).sqrt();
74    rng.random::<f64>() * 2.0 * scale - scale
75}
76
// ============================================================================
// ColumnParallelLinear
// ============================================================================

/// Splits the output dimension across `n_workers`.
///
/// With `n_workers = W` and output size `N`, each worker holds
/// weights of shape `[n_in, N/W]` and bias of shape `[N/W]`.
pub struct ColumnParallelLinear {
    /// Per-worker weight slices `[n_in, n_out/n_workers]`.
    local_weights: Vec<Array2<f64>>,
    /// Per-worker bias vectors `[n_out/n_workers]`.
    local_biases: Vec<Array1<f64>>,
    /// Worker count and gather behaviour (cloned at construction).
    config: TensorParallelConfig,
    /// Expected input feature count; `forward` validates against this.
    n_in: usize,
    /// Full output width `N` (sum of all per-worker chunks).
    total_n_out: usize,
}
94
95impl ColumnParallelLinear {
96    /// Create a column-parallel linear layer.
97    ///
98    /// # Errors
99    /// - `n_out` is not divisible by `config.n_workers`.
100    /// - `config.n_workers == 0`.
101    pub fn new(
102        n_in: usize,
103        n_out: usize,
104        config: TensorParallelConfig,
105        seed: u64,
106    ) -> NeuralResult<Self> {
107        if config.n_workers == 0 {
108            return Err(NeuralError::ConfigError(
109                "TensorParallelConfig.n_workers must be > 0".into(),
110            ));
111        }
112        if !n_out.is_multiple_of(config.n_workers) {
113            return Err(NeuralError::ConfigError(format!(
114                "n_out ({n_out}) must be divisible by n_workers ({})",
115                config.n_workers
116            )));
117        }
118
119        let chunk = n_out / config.n_workers;
120        let mut rng = SmallRng::seed_from_u64(seed);
121
122        let mut local_weights = Vec::with_capacity(config.n_workers);
123        let mut local_biases = Vec::with_capacity(config.n_workers);
124
125        for _ in 0..config.n_workers {
126            let w = Array2::from_shape_fn((n_in, chunk), |_| xavier_init(&mut rng, n_in, n_out));
127            let b = Array1::zeros(chunk);
128            local_weights.push(w);
129            local_biases.push(b);
130        }
131
132        Ok(Self {
133            local_weights,
134            local_biases,
135            config,
136            n_in,
137            total_n_out: n_out,
138        })
139    }
140
141    /// Forward pass.
142    ///
143    /// Each worker computes `y_i = input @ W_i + b_i`.  If `gather_output`,
144    /// the results are concatenated to `[batch, n_out]`; otherwise only the
145    /// first worker's output is returned (for single-process simulation with
146    /// `gather_output = false`).
147    pub fn forward(&self, input: &Array2<f64>) -> NeuralResult<Array2<f64>> {
148        let batch = input.shape()[0];
149        let n_in = input.shape()[1];
150        if n_in != self.n_in {
151            return Err(NeuralError::DimensionMismatch(format!(
152                "ColumnParallelLinear: expected n_in={}, got {n_in}",
153                self.n_in
154            )));
155        }
156
157        let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.config.n_workers);
158        for (w, b) in self.local_weights.iter().zip(self.local_biases.iter()) {
159            let y = input.dot(w) + b; // [batch, chunk]
160            parts.push(y);
161        }
162
163        if self.config.gather_output {
164            // Concatenate along feature axis.
165            let chunk = self.total_n_out / self.config.n_workers;
166            let mut gathered = Array2::<f64>::zeros((batch, self.total_n_out));
167            for (wi, part) in parts.iter().enumerate() {
168                let start = wi * chunk;
169                let end = start + chunk;
170                gathered.slice_mut(s![.., start..end]).assign(part);
171            }
172            Ok(gathered)
173        } else {
174            // Return first worker's slice.
175            parts
176                .into_iter()
177                .next()
178                .ok_or_else(|| NeuralError::ComputationError("no workers".into()))
179        }
180    }
181
182    /// Total output features (after all-gather).
183    pub fn n_out(&self) -> usize {
184        self.total_n_out
185    }
186
187    /// Number of simulated workers.
188    pub fn n_workers(&self) -> usize {
189        self.config.n_workers
190    }
191}
192
// ============================================================================
// RowParallelLinear
// ============================================================================

/// Splits the input dimension across `n_workers`.
///
/// Each worker holds weights `[n_in/n_workers, n_out]`.  The partial results
/// are summed (all-reduce) and the shared bias is added once.
pub struct RowParallelLinear {
    /// Per-worker weight slices `[n_in/n_workers, n_out]`.
    local_weights: Vec<Array2<f64>>,
    /// Shared bias `[n_out]` (added after all-reduce, not per worker).
    bias: Array1<f64>,
    /// Worker count configuration (cloned at construction).
    config: TensorParallelConfig,
    /// Full input width; `forward` validates against this.
    total_n_in: usize,
    /// Output feature count.
    n_out: usize,
}
210
211impl RowParallelLinear {
212    /// Create a row-parallel linear layer.
213    ///
214    /// # Errors
215    /// - `n_in` is not divisible by `config.n_workers`.
216    /// - `config.n_workers == 0`.
217    pub fn new(
218        n_in: usize,
219        n_out: usize,
220        config: TensorParallelConfig,
221        seed: u64,
222    ) -> NeuralResult<Self> {
223        if config.n_workers == 0 {
224            return Err(NeuralError::ConfigError(
225                "TensorParallelConfig.n_workers must be > 0".into(),
226            ));
227        }
228        if !n_in.is_multiple_of(config.n_workers) {
229            return Err(NeuralError::ConfigError(format!(
230                "n_in ({n_in}) must be divisible by n_workers ({})",
231                config.n_workers
232            )));
233        }
234
235        let chunk = n_in / config.n_workers;
236        let mut rng = SmallRng::seed_from_u64(seed);
237
238        let mut local_weights = Vec::with_capacity(config.n_workers);
239        for _ in 0..config.n_workers {
240            let w = Array2::from_shape_fn((chunk, n_out), |_| xavier_init(&mut rng, n_in, n_out));
241            local_weights.push(w);
242        }
243        let bias = Array1::zeros(n_out);
244
245        Ok(Self {
246            local_weights,
247            bias,
248            config,
249            total_n_in: n_in,
250            n_out,
251        })
252    }
253
254    /// Forward pass.
255    ///
256    /// Each worker computes `y_i = input_i @ W_i` where
257    /// `input_i = input[:, i*chunk:(i+1)*chunk]`.
258    /// The partial products are summed and the bias is added: `y = Σ y_i + bias`.
259    pub fn forward(&self, input: &Array2<f64>) -> NeuralResult<Array2<f64>> {
260        let batch = input.shape()[0];
261        let n_in = input.shape()[1];
262        if n_in != self.total_n_in {
263            return Err(NeuralError::DimensionMismatch(format!(
264                "RowParallelLinear: expected n_in={}, got {n_in}",
265                self.total_n_in
266            )));
267        }
268
269        let chunk = self.total_n_in / self.config.n_workers;
270        let mut acc = Array2::<f64>::zeros((batch, self.n_out));
271
272        for (wi, w) in self.local_weights.iter().enumerate() {
273            let start = wi * chunk;
274            let end = start + chunk;
275            let input_slice = input.slice(s![.., start..end]);
276            let partial = input_slice.dot(w); // [batch, n_out]
277            acc += &partial;
278        }
279
280        // Add shared bias.
281        acc += &self.bias;
282
283        Ok(acc)
284    }
285
286    /// Total input features (across all workers).
287    pub fn n_in(&self) -> usize {
288        self.total_n_in
289    }
290}
291
// ============================================================================
// ParallelEmbedding
// ============================================================================

/// Vocabulary-partitioned embedding table.
///
/// The vocabulary is split evenly across `n_workers`.  Each token index is
/// routed to worker `index / (vocab_size / n_workers)` and the corresponding
/// row is returned.
pub struct ParallelEmbedding {
    /// Per-worker embedding sub-tables `[vocab_size/n_workers, embed_dim]`.
    local_tables: Vec<Array2<f64>>,
    /// Total vocabulary size across all workers.
    vocab_size: usize,
    /// Width of each embedding row.
    embed_dim: usize,
    /// Number of simulated workers (must evenly divide `vocab_size`).
    n_workers: usize,
}
308
309impl ParallelEmbedding {
310    /// Create a parallel embedding table.
311    ///
312    /// # Errors
313    /// - `vocab_size` is not divisible by `n_workers`.
314    /// - `n_workers == 0`.
315    pub fn new(
316        vocab_size: usize,
317        embed_dim: usize,
318        n_workers: usize,
319        seed: u64,
320    ) -> NeuralResult<Self> {
321        if n_workers == 0 {
322            return Err(NeuralError::ConfigError(
323                "ParallelEmbedding: n_workers must be > 0".into(),
324            ));
325        }
326        if !vocab_size.is_multiple_of(n_workers) {
327            return Err(NeuralError::ConfigError(format!(
328                "vocab_size ({vocab_size}) must be divisible by n_workers ({n_workers})"
329            )));
330        }
331
332        let local_vocab = vocab_size / n_workers;
333        let mut rng = SmallRng::seed_from_u64(seed);
334
335        // Small normal initialisation for embeddings.
336        let mut local_tables = Vec::with_capacity(n_workers);
337        for _ in 0..n_workers {
338            let table = Array2::from_shape_fn((local_vocab, embed_dim), |_| {
339                (rng.random::<f64>() * 2.0 - 1.0) * 0.02
340            });
341            local_tables.push(table);
342        }
343
344        Ok(Self {
345            local_tables,
346            vocab_size,
347            embed_dim,
348            n_workers,
349        })
350    }
351
352    /// Look up embeddings for a sequence of token indices.
353    ///
354    /// Returns an array of shape `[len(indices), embed_dim]`.
355    ///
356    /// # Errors
357    /// Returns `NeuralError::InvalidArgument` if any index >= `vocab_size`.
358    pub fn forward(&self, indices: &[usize]) -> NeuralResult<Array2<f64>> {
359        let local_vocab = self.vocab_size / self.n_workers;
360        let mut out = Array2::<f64>::zeros((indices.len(), self.embed_dim));
361
362        for (row, &idx) in indices.iter().enumerate() {
363            if idx >= self.vocab_size {
364                return Err(NeuralError::InvalidArgument(format!(
365                    "token index {idx} out of range (vocab_size={})",
366                    self.vocab_size
367                )));
368            }
369            let worker_id = idx / local_vocab;
370            let local_idx = idx % local_vocab;
371            let embedding = self.local_tables[worker_id].slice(s![local_idx, ..]);
372            out.slice_mut(s![row, ..]).assign(&embedding);
373        }
374
375        Ok(out)
376    }
377
378    /// Total vocabulary size.
379    pub fn vocab_size(&self) -> usize {
380        self.vocab_size
381    }
382
383    /// Embedding dimension.
384    pub fn embed_dim(&self) -> usize {
385        self.embed_dim
386    }
387}
388
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    // `tests` is a child module, so it can read the layers' private fields
    // (local_weights, local_biases, bias) for equivalence checks below.
    use super::*;
    use scirs2_core::ndarray::Array2;

    // --- TensorParallelConfig ---

    #[test]
    fn test_default_config_n_workers_2() {
        let cfg = TensorParallelConfig::default();
        assert_eq!(cfg.n_workers, 2, "default n_workers must be 2");
        assert!(cfg.gather_output, "default gather_output must be true");
    }

    // --- ColumnParallelLinear ---

    #[test]
    fn test_column_parallel_output_shape() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(8, 4, cfg, 0).expect("ok");
        let input = Array2::<f64>::ones((5, 8));
        let out = layer.forward(&input).expect("forward ok");
        assert_eq!(out.shape(), [5, 4], "output shape should be [batch, n_out]");
    }

    #[test]
    fn test_column_parallel_n_out() {
        // n_out() reports the total (pre-split) width, not the per-worker chunk.
        let cfg = TensorParallelConfig {
            n_workers: 4,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(6, 8, cfg, 1).expect("ok");
        assert_eq!(layer.n_out(), 8);
        assert_eq!(layer.n_workers(), 4);
    }

    #[test]
    fn test_column_parallel_n_workers_1_equivalent_to_linear() {
        // With 1 worker, output should be same as a regular linear (W*X + b).
        let n_in = 4;
        let n_out = 6;
        let cfg = TensorParallelConfig {
            n_workers: 1,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(n_in, n_out, cfg, 42).expect("ok");
        let input = Array2::from_shape_fn((3, n_in), |(i, j)| (i * n_in + j) as f64 * 0.1);
        let out = layer.forward(&input).expect("forward ok");
        // Manual linear: y = input @ W + b, using the single worker's slice.
        let expected = input.dot(&layer.local_weights[0]) + &layer.local_biases[0];
        let diff: f64 = (&out - &expected).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-12,
            "n_workers=1 must match single linear; diff={diff}"
        );
    }

    #[test]
    fn test_column_parallel_indivisible_n_out_error() {
        let cfg = TensorParallelConfig {
            n_workers: 3,
            gather_output: true,
        };
        assert!(
            ColumnParallelLinear::new(4, 7, cfg, 0).is_err(),
            "n_out=7 is not divisible by 3"
        );
    }

    // --- RowParallelLinear ---

    #[test]
    fn test_row_parallel_output_shape() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(8, 4, cfg, 0).expect("ok");
        let input = Array2::<f64>::ones((5, 8));
        let out = layer.forward(&input).expect("forward ok");
        assert_eq!(out.shape(), [5, 4], "output shape should be [batch, n_out]");
    }

    #[test]
    fn test_row_parallel_n_in() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(6, 3, cfg, 0).expect("ok");
        assert_eq!(layer.n_in(), 6);
    }

    #[test]
    fn test_row_parallel_all_reduce_equals_full_matmul() {
        // Row-parallel sum across workers must equal a full matrix multiply.
        let n_in = 8;
        let n_out = 4;
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(n_in, n_out, cfg, 7).expect("ok");
        let input = Array2::from_shape_fn((3, n_in), |(i, j)| (i * n_in + j) as f64 * 0.1);
        let out_parallel = layer.forward(&input).expect("row parallel ok");

        // Reconstruct full weight by concatenating [W_0; W_1] along rows,
        // matching the row (input-dimension) partitioning of the layer.
        use scirs2_core::ndarray::concatenate;
        use scirs2_core::ndarray::Axis;
        let full_w: Array2<f64> = concatenate(
            Axis(0),
            &[layer.local_weights[0].view(), layer.local_weights[1].view()],
        )
        .expect("concat ok");
        let out_full = input.dot(&full_w) + &layer.bias;

        let diff: f64 = (&out_parallel - &out_full).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-12,
            "row-parallel must equal full matmul; diff={diff}"
        );
    }

    #[test]
    fn test_col_row_composition_shape() {
        // Column-parallel (gathered) output feeds row-parallel input, mirroring
        // the Megatron MLP pattern: n_in -> hidden -> n_out.
        let n_in = 8;
        let hidden = 16;
        let n_out = 4;
        let cfg1 = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let cfg2 = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let col = ColumnParallelLinear::new(n_in, hidden, cfg1, 0).expect("col ok");
        let row = RowParallelLinear::new(hidden, n_out, cfg2, 1).expect("row ok");
        let input = Array2::<f64>::ones((5, n_in));
        let mid = col.forward(&input).expect("col forward");
        let out = row.forward(&mid).expect("row forward");
        assert_eq!(out.shape(), [5, n_out]);
    }

    // --- ParallelEmbedding ---

    #[test]
    fn test_parallel_embedding_output_shape() {
        let emb = ParallelEmbedding::new(8, 16, 2, 0).expect("ok");
        let indices = vec![0_usize, 1, 3, 7];
        let out = emb.forward(&indices).expect("forward ok");
        assert_eq!(
            out.shape(),
            [4, 16],
            "shape should be [n_indices, embed_dim]"
        );
    }

    #[test]
    fn test_parallel_embedding_vocab_and_dim() {
        let emb = ParallelEmbedding::new(100, 32, 4, 0).expect("ok");
        assert_eq!(emb.vocab_size(), 100);
        assert_eq!(emb.embed_dim(), 32);
    }

    #[test]
    fn test_parallel_embedding_same_index_same_vector() {
        // Lookups are pure reads of the seeded table: repeat calls must agree.
        let emb = ParallelEmbedding::new(8, 4, 2, 99).expect("ok");
        let out1 = emb.forward(&[3]).expect("ok");
        let out2 = emb.forward(&[3]).expect("ok");
        let diff: f64 = (&out1 - &out2).mapv(|v| v.abs()).sum();
        assert!(diff < 1e-15, "same index must always return same embedding");
    }

    #[test]
    fn test_parallel_embedding_out_of_range_error() {
        let emb = ParallelEmbedding::new(8, 4, 2, 0).expect("ok");
        assert!(
            emb.forward(&[8]).is_err(),
            "index 8 is out of range for vocab_size=8"
        );
    }

    #[test]
    fn test_parallel_embedding_indivisible_vocab_error() {
        assert!(
            ParallelEmbedding::new(7, 4, 2, 0).is_err(),
            "vocab_size=7 not divisible by 2"
        );
    }
}