1use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::f64::consts::PI;
9
/// Errors produced by the weight-initialization routines in this module.
#[derive(Debug, Clone)]
pub enum InitError {
    /// `fan_in` evaluated to the contained (invalid) value — zero in practice.
    InvalidFanIn(usize),
    /// `fan_out` evaluated to the contained (invalid) value — zero in practice.
    InvalidFanOut(usize),
    /// Gain must be finite and strictly positive; carries the offending value.
    InvalidGain(f64),
    /// Standard deviation must be finite and strictly positive; carries the
    /// offending value.
    InvalidStd(f64),
    /// The shape has too few (or zero-sized) dimensions for the operation.
    ShapeTooSmall { shape: Vec<usize> },
    /// A zero-dimensional (empty) shape was supplied.
    EmptyShape,
    /// Assembling the output array from a flat buffer failed; carries the
    /// underlying ndarray error message.
    ShapeError(String),
}
32
33impl std::fmt::Display for InitError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 Self::InvalidFanIn(v) => write!(f, "invalid fan_in: {v}"),
37 Self::InvalidFanOut(v) => write!(f, "invalid fan_out: {v}"),
38 Self::InvalidGain(v) => write!(f, "invalid gain: {v}"),
39 Self::InvalidStd(v) => write!(f, "invalid std: {v}"),
40 Self::ShapeTooSmall { shape } => write!(f, "shape too small: {shape:?}"),
41 Self::EmptyShape => write!(f, "empty shape"),
42 Self::ShapeError(msg) => write!(f, "shape error: {msg}"),
43 }
44 }
45}
46
// Marker impl: the default `source()` (None) is fine; Display + Debug above
// satisfy the trait's supertraits.
impl std::error::Error for InitError {}
48
/// Selects which fan quantity scales the variance in Kaiming initialization.
///
/// `FanIn` preserves the magnitude of activations on the forward pass;
/// `FanOut` preserves the magnitude of gradients on the backward pass.
//
// Fieldless two-variant enum: derive Copy/Eq/Hash in addition to the original
// Clone/PartialEq so callers can pass it by value freely and use it as a map
// key. Purely additive — all existing uses keep compiling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FanMode {
    /// Scale by the number of input connections.
    FanIn,
    /// Scale by the number of output connections.
    FanOut,
}
61
/// Deterministic pseudo-random generator used by the initializers.
///
/// A 64-bit linear congruential generator using Knuth's MMIX constants; the
/// floating-point outputs are derived from the top 53 bits of the state.
/// Reproducible across runs for equal seeds; not cryptographically secure.
#[derive(Debug, Clone)]
pub struct InitRng {
    state: u64,
}

impl InitRng {
    /// Creates a generator seeded with `seed`. Equal seeds replay the
    /// identical sample sequence.
    pub fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the LCG state by one step (MMIX multiplier/increment).
    #[inline]
    fn step(&mut self) {
        let advanced = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state = advanced;
    }

    /// Draws a uniform sample in `[0, 1)` from the high 53 bits of the state
    /// (the low LCG bits have poor statistical quality and are discarded).
    pub fn next_f64(&mut self) -> f64 {
        self.step();
        let mantissa = self.state >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Draws a standard-normal sample via the Box-Muller transform (cosine
    /// branch). The first uniform is clamped away from zero so the
    /// logarithm stays finite.
    pub fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(f64::MIN_POSITIVE);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * PI * u2;
        radius * angle.cos()
    }

    /// Draws a uniform sample in `[low, high)` by affine rescaling.
    pub fn next_uniform(&mut self, low: f64, high: f64) -> f64 {
        let t = self.next_f64();
        low + (high - low) * t
    }
}
107
108pub fn compute_fans(shape: &[usize]) -> Result<(usize, usize), InitError> {
119 match shape.len() {
120 0 => Err(InitError::EmptyShape),
121 1 => Err(InitError::ShapeTooSmall {
122 shape: shape.to_vec(),
123 }),
124 2 => {
125 let fan_out = shape[0];
126 let fan_in = shape[1];
127 if fan_in == 0 {
128 return Err(InitError::InvalidFanIn(0));
129 }
130 if fan_out == 0 {
131 return Err(InitError::InvalidFanOut(0));
132 }
133 Ok((fan_in, fan_out))
134 }
135 _ => {
136 let receptive_field: usize = shape[2..].iter().product();
137 let fan_in = shape[1] * receptive_field;
138 let fan_out = shape[0] * receptive_field;
139 if fan_in == 0 {
140 return Err(InitError::InvalidFanIn(0));
141 }
142 if fan_out == 0 {
143 return Err(InitError::InvalidFanOut(0));
144 }
145 Ok((fan_in, fan_out))
146 }
147 }
148}
149
150fn make_array(shape: &[usize], data: Vec<f64>) -> Result<ArrayD<f64>, InitError> {
155 ArrayD::from_shape_vec(IxDyn(shape), data).map_err(|e| InitError::ShapeError(e.to_string()))
156}
157
/// Number of scalar elements described by `shape` — the product of all
/// dimensions (an empty shape yields 1, ndarray's scalar convention).
fn total_elements(shape: &[usize]) -> usize {
    shape.iter().fold(1, |acc, &dim| acc * dim)
}
161
/// Recommended gain multiplier for a given activation function name.
///
/// Values follow the usual convention (cf. PyTorch's `calculate_gain`):
/// sqrt(2) for ReLU, 5/3 for tanh, sqrt(2 / (1 + slope^2)) for leaky ReLU
/// with slope 0.01, and 3/4 for SELU. Unknown names fall back to 1.0.
pub fn gain_for_activation(activation: &str) -> f64 {
    match activation {
        "tanh" => 5.0 / 3.0,
        "relu" => 2.0_f64.sqrt(),
        "leaky_relu" => {
            let slope = 0.01_f64;
            (2.0 / (1.0 + slope * slope)).sqrt()
        }
        "selu" => 3.0 / 4.0,
        // "linear", "sigmoid", and any unrecognized name use unit gain.
        _ => 1.0,
    }
}
187
188pub fn xavier_uniform(
196 shape: &[usize],
197 gain: f64,
198 rng: &mut InitRng,
199) -> Result<ArrayD<f64>, InitError> {
200 validate_gain(gain)?;
201 let (fan_in, fan_out) = compute_fans(shape)?;
202 let limit = gain * (6.0 / (fan_in + fan_out) as f64).sqrt();
203 let n = total_elements(shape);
204 let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-limit, limit)).collect();
205 make_array(shape, data)
206}
207
208pub fn xavier_normal(
212 shape: &[usize],
213 gain: f64,
214 rng: &mut InitRng,
215) -> Result<ArrayD<f64>, InitError> {
216 validate_gain(gain)?;
217 let (fan_in, fan_out) = compute_fans(shape)?;
218 let std = gain * (2.0 / (fan_in + fan_out) as f64).sqrt();
219 let n = total_elements(shape);
220 let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
221 make_array(shape, data)
222}
223
224pub fn kaiming_uniform(
232 shape: &[usize],
233 gain: f64,
234 mode: FanMode,
235 rng: &mut InitRng,
236) -> Result<ArrayD<f64>, InitError> {
237 validate_gain(gain)?;
238 let (fan_in, fan_out) = compute_fans(shape)?;
239 let fan = match mode {
240 FanMode::FanIn => fan_in,
241 FanMode::FanOut => fan_out,
242 };
243 let bound = gain * (3.0 / fan as f64).sqrt();
244 let n = total_elements(shape);
245 let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-bound, bound)).collect();
246 make_array(shape, data)
247}
248
249pub fn kaiming_normal(
253 shape: &[usize],
254 gain: f64,
255 mode: FanMode,
256 rng: &mut InitRng,
257) -> Result<ArrayD<f64>, InitError> {
258 validate_gain(gain)?;
259 let (fan_in, fan_out) = compute_fans(shape)?;
260 let fan = match mode {
261 FanMode::FanIn => fan_in,
262 FanMode::FanOut => fan_out,
263 };
264 let std = gain / (fan as f64).sqrt();
265 let n = total_elements(shape);
266 let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
267 make_array(shape, data)
268}
269
270pub fn lecun_normal(shape: &[usize], rng: &mut InitRng) -> Result<ArrayD<f64>, InitError> {
276 let (fan_in, _) = compute_fans(shape)?;
277 let std = 1.0 / (fan_in as f64).sqrt();
278 let n = total_elements(shape);
279 let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
280 make_array(shape, data)
281}
282
283pub fn lecun_uniform(shape: &[usize], rng: &mut InitRng) -> Result<ArrayD<f64>, InitError> {
285 let (fan_in, _) = compute_fans(shape)?;
286 let limit = (3.0 / fan_in as f64).sqrt();
287 let n = total_elements(shape);
288 let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-limit, limit)).collect();
289 make_array(shape, data)
290}
291
292pub fn constant_init(shape: &[usize], value: f64) -> ArrayD<f64> {
298 ArrayD::from_elem(IxDyn(shape), value)
299}
300
301pub fn zeros_init(shape: &[usize]) -> ArrayD<f64> {
303 ArrayD::zeros(IxDyn(shape))
304}
305
306pub fn ones_init(shape: &[usize]) -> ArrayD<f64> {
308 ArrayD::ones(IxDyn(shape))
309}
310
311pub fn normal_init(
317 shape: &[usize],
318 mean: f64,
319 std: f64,
320 rng: &mut InitRng,
321) -> Result<ArrayD<f64>, InitError> {
322 if std <= 0.0 || !std.is_finite() {
323 return Err(InitError::InvalidStd(std));
324 }
325 let n = total_elements(shape);
326 let data: Vec<f64> = (0..n).map(|_| mean + std * rng.next_normal()).collect();
327 make_array(shape, data)
328}
329
330pub fn uniform_init(
332 shape: &[usize],
333 low: f64,
334 high: f64,
335 rng: &mut InitRng,
336) -> Result<ArrayD<f64>, InitError> {
337 if low >= high {
338 return Err(InitError::InvalidStd(high - low)); }
340 let n = total_elements(shape);
341 let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(low, high)).collect();
342 make_array(shape, data)
343}
344
/// Orthogonal initialization: fills the tensor with a (semi-)orthogonal
/// matrix scaled by `gain`, obtained via modified Gram-Schmidt on a matrix
/// of standard-normal entries drawn from `rng`.
///
/// The shape is viewed as a 2-D matrix of `shape[0]` rows by
/// `product(shape[1..])` columns. When there are more columns than rows, the
/// transpose is orthonormalized instead, so the *rows* of the result come
/// out orthonormal.
///
/// # Errors
/// Returns an error for an invalid `gain`, fewer than two dimensions, or a
/// zero-sized row/column count.
pub fn orthogonal_init(
    shape: &[usize],
    gain: f64,
    rng: &mut InitRng,
) -> Result<ArrayD<f64>, InitError> {
    validate_gain(gain)?;
    if shape.len() < 2 {
        return Err(InitError::ShapeTooSmall {
            shape: shape.to_vec(),
        });
    }

    // Flatten all trailing dimensions: the working matrix is rows x cols.
    let rows = shape[0];
    let cols: usize = shape[1..].iter().product();
    if rows == 0 || cols == 0 {
        return Err(InitError::ShapeTooSmall {
            shape: shape.to_vec(),
        });
    }

    // Start from i.i.d. standard-normal entries in row-major order.
    let n = rows * cols;
    let mut flat: Vec<f64> = (0..n).map(|_| rng.next_normal()).collect();

    // Always orthonormalize the "tall" orientation (work_rows >= work_cols)
    // so Gram-Schmidt has enough dimensions for a full orthonormal set.
    let (work_rows, work_cols, transposed) = if rows >= cols {
        (rows, cols, false)
    } else {
        (cols, rows, true)
    };

    // Gather the working matrix column-by-column. In the transposed case a
    // "column" of the work matrix is a row of the original matrix.
    let mut columns: Vec<Vec<f64>> = if !transposed {
        (0..work_cols)
            .map(|c| (0..work_rows).map(|r| flat[r * cols + c]).collect())
            .collect()
    } else {
        (0..work_cols)
            .map(|c| (0..work_rows).map(|r| flat[c * cols + r]).collect())
            .collect()
    };

    // Modified Gram-Schmidt: normalize column i, then remove its component
    // from every later column.
    for i in 0..work_cols {
        let norm = dot_vec(&columns[i], &columns[i]).sqrt();
        if norm < 1e-15 {
            // Degenerate (numerically zero) column: fall back to the basis
            // vector e_i. NOTE(review): this replacement is not re-
            // orthogonalized against earlier columns; with Gaussian input
            // this branch is vanishingly unlikely to fire.
            for v in columns[i].iter_mut() {
                *v = 0.0;
            }
            // i < work_cols <= work_rows always holds here, so the index is
            // in-bounds by construction.
            if i < work_rows {
                columns[i][i] = 1.0;
            }
        } else {
            for v in columns[i].iter_mut() {
                *v /= norm;
            }
        }

        let qi = columns[i].clone();
        for col in columns.iter_mut().skip(i + 1) {
            let proj = dot_vec(&qi, col);
            for (v, q) in col.iter_mut().zip(qi.iter()) {
                *v -= proj * q;
            }
        }
    }

    // Write the result back in row-major order, applying the gain. The
    // `flat` buffer is reused to avoid a second allocation.
    flat.clear();
    if !transposed {
        // Element (r, c) = columns[c][r].
        for r in 0..rows {
            for col in columns.iter().take(cols) {
                flat.push(gain * col[r]);
            }
        }
    } else {
        // Each orthonormalized work-column is one output row.
        for col_vec in columns.iter().take(rows) {
            for &val in col_vec.iter().take(cols) {
                flat.push(gain * val);
            }
        }
    }

    make_array(shape, flat)
}
445
/// Dot product of two slices; `zip` truncates to the shorter length, so any
/// extra elements in the longer slice are ignored.
fn dot_vec(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0_f64;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
450
451fn validate_gain(gain: f64) -> Result<(), InitError> {
456 if gain <= 0.0 || !gain.is_finite() {
457 return Err(InitError::InvalidGain(gain));
458 }
459 Ok(())
460}
461
/// Summary statistics of an initialized tensor, for quick sanity checks.
#[derive(Debug, Clone)]
pub struct InitStats {
    // Shape the statistics were computed for.
    pub shape: Vec<usize>,
    // Total number of scalar elements.
    pub num_elements: usize,
    // Arithmetic mean of all elements (0.0 for an empty tensor).
    pub mean: f64,
    // Population standard deviation (divides by n, not n - 1).
    pub std: f64,
    // Smallest element seen.
    pub min: f64,
    // Largest element seen.
    pub max: f64,
    // Fan-in derived from `shape`; 0 when fans cannot be computed.
    pub fan_in: usize,
    // Fan-out derived from `shape`; 0 when fans cannot be computed.
    pub fan_out: usize,
}
486
487impl InitStats {
488 pub fn compute(tensor: &ArrayD<f64>, shape: &[usize]) -> Self {
490 let n = tensor.len();
491 let (fan_in, fan_out) = compute_fans(shape).unwrap_or((0, 0));
492
493 let mut sum = 0.0_f64;
494 let mut min_val = f64::INFINITY;
495 let mut max_val = f64::NEG_INFINITY;
496
497 for &v in tensor.iter() {
498 sum += v;
499 if v < min_val {
500 min_val = v;
501 }
502 if v > max_val {
503 max_val = v;
504 }
505 }
506
507 let mean = if n > 0 { sum / n as f64 } else { 0.0 };
508
509 let variance = if n > 1 {
510 let mut sq_sum = 0.0_f64;
511 for &v in tensor.iter() {
512 sq_sum += (v - mean).powi(2);
513 }
514 sq_sum / n as f64
515 } else {
516 0.0
517 };
518
519 Self {
520 shape: shape.to_vec(),
521 num_elements: n,
522 mean,
523 std: variance.sqrt(),
524 min: min_val,
525 max: max_val,
526 fan_in,
527 fan_out,
528 }
529 }
530
531 pub fn summary(&self) -> String {
533 format!(
534 "InitStats {{ shape: {:?}, n: {}, mean: {:.6}, std: {:.6}, \
535 min: {:.6}, max: {:.6}, fan_in: {}, fan_out: {} }}",
536 self.shape,
537 self.num_elements,
538 self.mean,
539 self.std,
540 self.min,
541 self.max,
542 self.fan_in,
543 self.fan_out,
544 )
545 }
546}
547
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_fans_2d() {
        // Convention: shape[0] = fan_out, shape[1] = fan_in.
        let (fan_in, fan_out) = compute_fans(&[10, 5]).expect("compute_fans failed");
        assert_eq!(fan_in, 5);
        assert_eq!(fan_out, 10);
    }

    #[test]
    fn test_compute_fans_4d() {
        // Conv-style shape: trailing spatial dims multiply into both fans.
        let (fan_in, fan_out) = compute_fans(&[16, 3, 3, 3]).expect("compute_fans failed");
        assert_eq!(fan_in, 3 * 3 * 3);
        assert_eq!(fan_out, 16 * 3 * 3);
    }

    #[test]
    fn test_xavier_uniform_range() {
        // Every sample must lie within +/- gain*sqrt(6/(fan_in+fan_out)).
        let shape = [64, 32];
        let (fan_in, fan_out) = compute_fans(&shape).expect("fans");
        let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
        let mut rng = InitRng::new(42);
        let arr = xavier_uniform(&shape, 1.0, &mut rng).expect("xavier_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -limit && v <= limit,
                "value {v} outside [{}, {}]",
                -limit,
                limit
            );
        }
    }

    #[test]
    fn test_xavier_normal_mean_near_zero() {
        // Large sample: empirical mean should sit close to zero.
        let shape = [256, 128];
        let mut rng = InitRng::new(123);
        let arr = xavier_normal(&shape, 1.0, &mut rng).expect("xavier_normal");
        let mean: f64 = arr.iter().sum::<f64>() / arr.len() as f64;
        assert!(mean.abs() < 0.05, "mean too far from zero: {mean}");
    }

    #[test]
    fn test_kaiming_uniform_fan_in() {
        // Bound check: samples stay within +/- gain*sqrt(3/fan_in).
        let shape = [64, 32];
        let gain = 2.0_f64.sqrt();
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let bound = gain * (3.0 / fan_in as f64).sqrt();
        let mut rng = InitRng::new(7);
        let arr = kaiming_uniform(&shape, gain, FanMode::FanIn, &mut rng).expect("kaiming_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -bound && v <= bound,
                "value {v} outside [{}, {}]",
                -bound,
                bound
            );
        }
    }

    #[test]
    fn test_kaiming_normal_std() {
        // Empirical std should be within 15% of gain/sqrt(fan_in).
        let shape = [256, 128];
        let gain = 2.0_f64.sqrt();
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let expected_std = gain / (fan_in as f64).sqrt();
        let mut rng = InitRng::new(99);
        let arr = kaiming_normal(&shape, gain, FanMode::FanIn, &mut rng).expect("kaiming_normal");
        let n = arr.len() as f64;
        let mean: f64 = arr.iter().sum::<f64>() / n;
        let var: f64 = arr.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let actual_std = var.sqrt();
        let ratio = actual_std / expected_std;
        assert!(
            (0.85..=1.15).contains(&ratio),
            "std ratio {ratio} (actual={actual_std}, expected={expected_std})"
        );
    }

    #[test]
    fn test_lecun_normal_shape() {
        // The output tensor must preserve the requested shape exactly.
        let shape = [16, 8, 3, 3];
        let mut rng = InitRng::new(55);
        let arr = lecun_normal(&shape, &mut rng).expect("lecun_normal");
        assert_eq!(arr.shape(), &[16, 8, 3, 3]);
    }

    #[test]
    fn test_lecun_uniform_range() {
        // Every sample must lie within +/- sqrt(3/fan_in).
        let shape = [32, 16];
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let limit = (3.0 / fan_in as f64).sqrt();
        let mut rng = InitRng::new(11);
        let arr = lecun_uniform(&shape, &mut rng).expect("lecun_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -limit && v <= limit,
                "value {v} outside [{}, {}]",
                -limit,
                limit
            );
        }
    }

    #[test]
    fn test_constant_init_value() {
        let arr = constant_init(&[3, 4], 3.15);
        for &v in arr.iter() {
            assert!((v - 3.15).abs() < 1e-12);
        }
    }

    #[test]
    fn test_zeros_init() {
        let arr = zeros_init(&[5, 5]);
        for &v in arr.iter() {
            assert!((v).abs() < 1e-15);
        }
    }

    #[test]
    fn test_ones_init() {
        let arr = ones_init(&[2, 3]);
        for &v in arr.iter() {
            assert!((v - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    fn test_orthogonal_init_square() {
        // For a square init, Q^T * Q must be (numerically) the identity.
        let shape = [8, 8];
        let mut rng = InitRng::new(77);
        let arr = orthogonal_init(&shape, 1.0, &mut rng).expect("orthogonal_init");
        let n = 8;
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0_f64;
                for k in 0..n {
                    dot += arr[[k, i].as_ref()] * arr[[k, j].as_ref()];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-8,
                    "Q^T Q [{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_normal_init_distribution() {
        // Empirical mean/std should land near the requested parameters.
        let shape = [512, 256];
        let target_mean = 2.0;
        let target_std = 0.5;
        let mut rng = InitRng::new(42);
        let arr = normal_init(&shape, target_mean, target_std, &mut rng).expect("normal_init");
        let n = arr.len() as f64;
        let mean: f64 = arr.iter().sum::<f64>() / n;
        let var: f64 = arr.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let actual_std = var.sqrt();
        assert!(
            (mean - target_mean).abs() < 0.05,
            "mean {mean} far from {target_mean}"
        );
        assert!(
            (actual_std - target_std).abs() < 0.05,
            "std {actual_std} far from {target_std}"
        );
    }

    #[test]
    fn test_uniform_init_bounds() {
        let shape = [100, 100];
        let mut rng = InitRng::new(13);
        let arr = uniform_init(&shape, -0.5, 0.5, &mut rng).expect("uniform_init");
        for &v in arr.iter() {
            // Half-open interval: the generator never reaches the upper bound.
            assert!((-0.5..0.5).contains(&v), "value {v} out of bounds");
        }
    }

    #[test]
    fn test_gain_for_relu() {
        let g = gain_for_activation("relu");
        assert!((g - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_gain_for_tanh() {
        let g = gain_for_activation("tanh");
        assert!((g - 5.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_gain_for_unknown() {
        // Unrecognized activation names fall back to unit gain.
        assert!((gain_for_activation("swish") - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_init_stats_compute() {
        let arr = ones_init(&[4, 5]);
        let stats = InitStats::compute(&arr, &[4, 5]);
        assert_eq!(stats.num_elements, 20);
        assert!((stats.mean - 1.0).abs() < 1e-12);
        assert!(stats.std < 1e-12);
    }

    #[test]
    fn test_init_stats_summary_nonempty() {
        let arr = zeros_init(&[3, 3]);
        let stats = InitStats::compute(&arr, &[3, 3]);
        let s = stats.summary();
        assert!(!s.is_empty());
        assert!(s.contains("InitStats"));
    }

    #[test]
    fn test_fan_mode_kaiming_changes_std() {
        // With fan_in != fan_out, the two fan modes must scale differently.
        let shape = [128, 32];
        let gain = 2.0_f64.sqrt();

        let mut rng1 = InitRng::new(1000);
        let arr_in =
            kaiming_normal(&shape, gain, FanMode::FanIn, &mut rng1).expect("kaiming fan_in");

        let mut rng2 = InitRng::new(1000);
        let arr_out =
            kaiming_normal(&shape, gain, FanMode::FanOut, &mut rng2).expect("kaiming fan_out");

        let std_in = {
            let n = arr_in.len() as f64;
            let m: f64 = arr_in.iter().sum::<f64>() / n;
            (arr_in.iter().map(|v| (v - m).powi(2)).sum::<f64>() / n).sqrt()
        };
        let std_out = {
            let n = arr_out.len() as f64;
            let m: f64 = arr_out.iter().sum::<f64>() / n;
            (arr_out.iter().map(|v| (v - m).powi(2)).sum::<f64>() / n).sqrt()
        };

        assert!(
            (std_in - std_out).abs() > 0.01,
            "std_in={std_in} and std_out={std_out} should differ significantly"
        );
    }
}