1use crate::error::{NeuralError, Result as NeuralResult};
40use scirs2_core::ndarray::{s, Array1, Array2};
41use scirs2_core::random::rngs::SmallRng;
42use scirs2_core::random::{Rng, RngExt, SeedableRng};
43
/// Configuration shared by the simulated tensor-parallel layers.
#[derive(Debug, Clone)]
pub struct TensorParallelConfig {
    /// Number of simulated parallel workers (weight shards). Must be > 0;
    /// constructors return `ConfigError` otherwise.
    pub n_workers: usize,
    /// When true, `ColumnParallelLinear::forward` concatenates all worker
    /// shards into the full `[batch, n_out]` output; when false only the
    /// first worker's shard is returned.
    pub gather_output: bool,
}
57
58impl Default for TensorParallelConfig {
59 fn default() -> Self {
60 Self {
61 n_workers: 2,
62 gather_output: true,
63 }
64 }
65}
66
67fn xavier_init(rng: &mut SmallRng, n_in: usize, n_out: usize) -> f64 {
73 let scale = (6.0_f64 / (n_in + n_out) as f64).sqrt();
74 rng.random::<f64>() * 2.0 * scale - scale
75}
76
/// Linear layer whose weight matrix is split column-wise across simulated
/// workers: worker `i` owns a `[n_in, n_out / n_workers]` weight shard and a
/// matching bias shard.
pub struct ColumnParallelLinear {
    // One `[n_in, chunk]` weight shard per worker, in worker order.
    local_weights: Vec<Array2<f64>>,
    // One `[chunk]` bias shard per worker (initialized to zeros).
    local_biases: Vec<Array1<f64>>,
    // Worker count and gather behavior for `forward`.
    config: TensorParallelConfig,
    // Expected input feature dimension.
    n_in: usize,
    // Full output dimension across all workers.
    total_n_out: usize,
}
94
95impl ColumnParallelLinear {
96 pub fn new(
102 n_in: usize,
103 n_out: usize,
104 config: TensorParallelConfig,
105 seed: u64,
106 ) -> NeuralResult<Self> {
107 if config.n_workers == 0 {
108 return Err(NeuralError::ConfigError(
109 "TensorParallelConfig.n_workers must be > 0".into(),
110 ));
111 }
112 if !n_out.is_multiple_of(config.n_workers) {
113 return Err(NeuralError::ConfigError(format!(
114 "n_out ({n_out}) must be divisible by n_workers ({})",
115 config.n_workers
116 )));
117 }
118
119 let chunk = n_out / config.n_workers;
120 let mut rng = SmallRng::seed_from_u64(seed);
121
122 let mut local_weights = Vec::with_capacity(config.n_workers);
123 let mut local_biases = Vec::with_capacity(config.n_workers);
124
125 for _ in 0..config.n_workers {
126 let w = Array2::from_shape_fn((n_in, chunk), |_| xavier_init(&mut rng, n_in, n_out));
127 let b = Array1::zeros(chunk);
128 local_weights.push(w);
129 local_biases.push(b);
130 }
131
132 Ok(Self {
133 local_weights,
134 local_biases,
135 config,
136 n_in,
137 total_n_out: n_out,
138 })
139 }
140
141 pub fn forward(&self, input: &Array2<f64>) -> NeuralResult<Array2<f64>> {
148 let batch = input.shape()[0];
149 let n_in = input.shape()[1];
150 if n_in != self.n_in {
151 return Err(NeuralError::DimensionMismatch(format!(
152 "ColumnParallelLinear: expected n_in={}, got {n_in}",
153 self.n_in
154 )));
155 }
156
157 let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.config.n_workers);
158 for (w, b) in self.local_weights.iter().zip(self.local_biases.iter()) {
159 let y = input.dot(w) + b; parts.push(y);
161 }
162
163 if self.config.gather_output {
164 let chunk = self.total_n_out / self.config.n_workers;
166 let mut gathered = Array2::<f64>::zeros((batch, self.total_n_out));
167 for (wi, part) in parts.iter().enumerate() {
168 let start = wi * chunk;
169 let end = start + chunk;
170 gathered.slice_mut(s![.., start..end]).assign(part);
171 }
172 Ok(gathered)
173 } else {
174 parts
176 .into_iter()
177 .next()
178 .ok_or_else(|| NeuralError::ComputationError("no workers".into()))
179 }
180 }
181
182 pub fn n_out(&self) -> usize {
184 self.total_n_out
185 }
186
187 pub fn n_workers(&self) -> usize {
189 self.config.n_workers
190 }
191}
192
/// Linear layer whose weight matrix is split row-wise across simulated
/// workers: worker `i` owns a `[n_in / n_workers, n_out]` weight shard, and
/// partial products are summed (a simulated all-reduce) in `forward`.
pub struct RowParallelLinear {
    // One `[chunk, n_out]` weight shard per worker, in worker order.
    local_weights: Vec<Array2<f64>>,
    // Single shared bias of length `n_out`, added once after the reduction.
    bias: Array1<f64>,
    // Worker count used to slice the input.
    config: TensorParallelConfig,
    // Full input feature dimension across all workers.
    total_n_in: usize,
    // Output dimension.
    n_out: usize,
}
210
211impl RowParallelLinear {
212 pub fn new(
218 n_in: usize,
219 n_out: usize,
220 config: TensorParallelConfig,
221 seed: u64,
222 ) -> NeuralResult<Self> {
223 if config.n_workers == 0 {
224 return Err(NeuralError::ConfigError(
225 "TensorParallelConfig.n_workers must be > 0".into(),
226 ));
227 }
228 if !n_in.is_multiple_of(config.n_workers) {
229 return Err(NeuralError::ConfigError(format!(
230 "n_in ({n_in}) must be divisible by n_workers ({})",
231 config.n_workers
232 )));
233 }
234
235 let chunk = n_in / config.n_workers;
236 let mut rng = SmallRng::seed_from_u64(seed);
237
238 let mut local_weights = Vec::with_capacity(config.n_workers);
239 for _ in 0..config.n_workers {
240 let w = Array2::from_shape_fn((chunk, n_out), |_| xavier_init(&mut rng, n_in, n_out));
241 local_weights.push(w);
242 }
243 let bias = Array1::zeros(n_out);
244
245 Ok(Self {
246 local_weights,
247 bias,
248 config,
249 total_n_in: n_in,
250 n_out,
251 })
252 }
253
254 pub fn forward(&self, input: &Array2<f64>) -> NeuralResult<Array2<f64>> {
260 let batch = input.shape()[0];
261 let n_in = input.shape()[1];
262 if n_in != self.total_n_in {
263 return Err(NeuralError::DimensionMismatch(format!(
264 "RowParallelLinear: expected n_in={}, got {n_in}",
265 self.total_n_in
266 )));
267 }
268
269 let chunk = self.total_n_in / self.config.n_workers;
270 let mut acc = Array2::<f64>::zeros((batch, self.n_out));
271
272 for (wi, w) in self.local_weights.iter().enumerate() {
273 let start = wi * chunk;
274 let end = start + chunk;
275 let input_slice = input.slice(s![.., start..end]);
276 let partial = input_slice.dot(w); acc += &partial;
278 }
279
280 acc += &self.bias;
282
283 Ok(acc)
284 }
285
286 pub fn n_in(&self) -> usize {
288 self.total_n_in
289 }
290}
291
/// Embedding table split by vocabulary range across simulated workers:
/// worker `i` owns rows `[i * vocab_size / n_workers, (i + 1) * vocab_size / n_workers)`.
pub struct ParallelEmbedding {
    // One `[vocab_size / n_workers, embed_dim]` table per worker.
    local_tables: Vec<Array2<f64>>,
    // Total vocabulary size across all workers.
    vocab_size: usize,
    // Embedding vector length.
    embed_dim: usize,
    // Number of simulated workers (must evenly divide `vocab_size`).
    n_workers: usize,
}
308
309impl ParallelEmbedding {
310 pub fn new(
316 vocab_size: usize,
317 embed_dim: usize,
318 n_workers: usize,
319 seed: u64,
320 ) -> NeuralResult<Self> {
321 if n_workers == 0 {
322 return Err(NeuralError::ConfigError(
323 "ParallelEmbedding: n_workers must be > 0".into(),
324 ));
325 }
326 if !vocab_size.is_multiple_of(n_workers) {
327 return Err(NeuralError::ConfigError(format!(
328 "vocab_size ({vocab_size}) must be divisible by n_workers ({n_workers})"
329 )));
330 }
331
332 let local_vocab = vocab_size / n_workers;
333 let mut rng = SmallRng::seed_from_u64(seed);
334
335 let mut local_tables = Vec::with_capacity(n_workers);
337 for _ in 0..n_workers {
338 let table = Array2::from_shape_fn((local_vocab, embed_dim), |_| {
339 (rng.random::<f64>() * 2.0 - 1.0) * 0.02
340 });
341 local_tables.push(table);
342 }
343
344 Ok(Self {
345 local_tables,
346 vocab_size,
347 embed_dim,
348 n_workers,
349 })
350 }
351
352 pub fn forward(&self, indices: &[usize]) -> NeuralResult<Array2<f64>> {
359 let local_vocab = self.vocab_size / self.n_workers;
360 let mut out = Array2::<f64>::zeros((indices.len(), self.embed_dim));
361
362 for (row, &idx) in indices.iter().enumerate() {
363 if idx >= self.vocab_size {
364 return Err(NeuralError::InvalidArgument(format!(
365 "token index {idx} out of range (vocab_size={})",
366 self.vocab_size
367 )));
368 }
369 let worker_id = idx / local_vocab;
370 let local_idx = idx % local_vocab;
371 let embedding = self.local_tables[worker_id].slice(s![local_idx, ..]);
372 out.slice_mut(s![row, ..]).assign(&embedding);
373 }
374
375 Ok(out)
376 }
377
378 pub fn vocab_size(&self) -> usize {
380 self.vocab_size
381 }
382
383 pub fn embed_dim(&self) -> usize {
385 self.embed_dim
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Default config must be 2 workers with gathering enabled.
    #[test]
    fn test_default_config_n_workers_2() {
        let cfg = TensorParallelConfig::default();
        assert_eq!(cfg.n_workers, 2, "default n_workers must be 2");
        assert!(cfg.gather_output, "default gather_output must be true");
    }

    // Gathered column-parallel output has the full [batch, n_out] shape.
    #[test]
    fn test_column_parallel_output_shape() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(8, 4, cfg, 0).expect("ok");
        let input = Array2::<f64>::ones((5, 8));
        let out = layer.forward(&input).expect("forward ok");
        assert_eq!(out.shape(), [5, 4], "output shape should be [batch, n_out]");
    }

    // Accessors report the configured totals, not per-worker chunk sizes.
    #[test]
    fn test_column_parallel_n_out() {
        let cfg = TensorParallelConfig {
            n_workers: 4,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(6, 8, cfg, 1).expect("ok");
        assert_eq!(layer.n_out(), 8);
        assert_eq!(layer.n_workers(), 4);
    }

    // With a single worker the layer degenerates to a plain dense layer.
    #[test]
    fn test_column_parallel_n_workers_1_equivalent_to_linear() {
        let n_in = 4;
        let n_out = 6;
        let cfg = TensorParallelConfig {
            n_workers: 1,
            gather_output: true,
        };
        let layer = ColumnParallelLinear::new(n_in, n_out, cfg, 42).expect("ok");
        let input = Array2::from_shape_fn((3, n_in), |(i, j)| (i * n_in + j) as f64 * 0.1);
        let out = layer.forward(&input).expect("forward ok");
        // Reference: direct matmul with the single worker's shard.
        let expected = input.dot(&layer.local_weights[0]) + &layer.local_biases[0];
        let diff: f64 = (&out - &expected).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-12,
            "n_workers=1 must match single linear; diff={diff}"
        );
    }

    // Construction must fail when n_out is not divisible by n_workers.
    #[test]
    fn test_column_parallel_indivisible_n_out_error() {
        let cfg = TensorParallelConfig {
            n_workers: 3,
            gather_output: true,
        };
        assert!(
            ColumnParallelLinear::new(4, 7, cfg, 0).is_err(),
            "n_out=7 is not divisible by 3"
        );
    }

    // Row-parallel output also has the full [batch, n_out] shape.
    #[test]
    fn test_row_parallel_output_shape() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(8, 4, cfg, 0).expect("ok");
        let input = Array2::<f64>::ones((5, 8));
        let out = layer.forward(&input).expect("forward ok");
        assert_eq!(out.shape(), [5, 4], "output shape should be [batch, n_out]");
    }

    // n_in() reports the full input dimension across workers.
    #[test]
    fn test_row_parallel_n_in() {
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(6, 3, cfg, 0).expect("ok");
        assert_eq!(layer.n_in(), 6);
    }

    // Summing partial products (simulated all-reduce) must equal a single
    // matmul against the row-concatenated full weight matrix.
    #[test]
    fn test_row_parallel_all_reduce_equals_full_matmul() {
        let n_in = 8;
        let n_out = 4;
        let cfg = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let layer = RowParallelLinear::new(n_in, n_out, cfg, 7).expect("ok");
        let input = Array2::from_shape_fn((3, n_in), |(i, j)| (i * n_in + j) as f64 * 0.1);
        let out_parallel = layer.forward(&input).expect("row parallel ok");

        // Rebuild the full weight by stacking the shards along the rows.
        use scirs2_core::ndarray::concatenate;
        use scirs2_core::ndarray::Axis;
        let full_w: Array2<f64> = concatenate(
            Axis(0),
            &[layer.local_weights[0].view(), layer.local_weights[1].view()],
        )
        .expect("concat ok");
        let out_full = input.dot(&full_w) + &layer.bias;

        let diff: f64 = (&out_parallel - &out_full).mapv(|v| v.abs()).sum();
        assert!(
            diff < 1e-12,
            "row-parallel must equal full matmul; diff={diff}"
        );
    }

    // Column-parallel (gathered) feeding row-parallel composes shape-wise,
    // mirroring the classic MLP sharding pattern.
    #[test]
    fn test_col_row_composition_shape() {
        let n_in = 8;
        let hidden = 16;
        let n_out = 4;
        let cfg1 = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let cfg2 = TensorParallelConfig {
            n_workers: 2,
            gather_output: true,
        };
        let col = ColumnParallelLinear::new(n_in, hidden, cfg1, 0).expect("col ok");
        let row = RowParallelLinear::new(hidden, n_out, cfg2, 1).expect("row ok");
        let input = Array2::<f64>::ones((5, n_in));
        let mid = col.forward(&input).expect("col forward");
        let out = row.forward(&mid).expect("row forward");
        assert_eq!(out.shape(), [5, n_out]);
    }

    // Embedding lookup returns one row per requested index.
    #[test]
    fn test_parallel_embedding_output_shape() {
        let emb = ParallelEmbedding::new(8, 16, 2, 0).expect("ok");
        let indices = vec![0_usize, 1, 3, 7];
        let out = emb.forward(&indices).expect("forward ok");
        assert_eq!(
            out.shape(),
            [4, 16],
            "shape should be [n_indices, embed_dim]"
        );
    }

    // Accessors report the constructor arguments unchanged.
    #[test]
    fn test_parallel_embedding_vocab_and_dim() {
        let emb = ParallelEmbedding::new(100, 32, 4, 0).expect("ok");
        assert_eq!(emb.vocab_size(), 100);
        assert_eq!(emb.embed_dim(), 32);
    }

    // Lookups are deterministic: the same token always maps to the same row.
    #[test]
    fn test_parallel_embedding_same_index_same_vector() {
        let emb = ParallelEmbedding::new(8, 4, 2, 99).expect("ok");
        let out1 = emb.forward(&[3]).expect("ok");
        let out2 = emb.forward(&[3]).expect("ok");
        let diff: f64 = (&out1 - &out2).mapv(|v| v.abs()).sum();
        assert!(diff < 1e-15, "same index must always return same embedding");
    }

    // Indices at or beyond vocab_size must be rejected, not wrapped.
    #[test]
    fn test_parallel_embedding_out_of_range_error() {
        let emb = ParallelEmbedding::new(8, 4, 2, 0).expect("ok");
        assert!(
            emb.forward(&[8]).is_err(),
            "index 8 is out of range for vocab_size=8"
        );
    }

    // Construction must fail when vocab_size is not divisible by n_workers.
    #[test]
    fn test_parallel_embedding_indivisible_vocab_error() {
        assert!(
            ParallelEmbedding::new(7, 4, 2, 0).is_err(),
            "vocab_size=7 not divisible by 2"
        );
    }
}