// File: tensorlogic_train/weight_init.rs
//! Weight initialization strategies for neural network parameters.
//!
//! Provides common initialization methods including Xavier/Glorot, Kaiming/He,
//! LeCun, orthogonal, and basic constant/normal/uniform initializations.
//! Uses a deterministic LCG-based RNG (no `rand` crate dependency).

use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::f64::consts::PI;

// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------

/// Errors that can occur during weight initialization.
#[derive(Debug, Clone)]
pub enum InitError {
    /// Fan-in value is invalid (zero).
    InvalidFanIn(usize),
    /// Fan-out value is invalid (zero).
    InvalidFanOut(usize),
    /// Gain value is invalid (non-positive or non-finite).
    InvalidGain(f64),
    /// Standard deviation is invalid (non-positive or non-finite).
    InvalidStd(f64),
    /// Shape is too small for the requested operation.
    ShapeTooSmall { shape: Vec<usize> },
    /// Shape is empty.
    EmptyShape,
    /// Array creation failed.
    ShapeError(String),
}

impl std::fmt::Display for InitError {
    /// Render each variant as a short, lowercase diagnostic message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            InitError::EmptyShape => write!(f, "empty shape"),
            InitError::ShapeTooSmall { shape } => write!(f, "shape too small: {shape:?}"),
            InitError::InvalidFanIn(v) => write!(f, "invalid fan_in: {v}"),
            InitError::InvalidFanOut(v) => write!(f, "invalid fan_out: {v}"),
            InitError::InvalidGain(v) => write!(f, "invalid gain: {v}"),
            InitError::InvalidStd(v) => write!(f, "invalid std: {v}"),
            InitError::ShapeError(msg) => write!(f, "shape error: {msg}"),
        }
    }
}

impl std::error::Error for InitError {}

// ---------------------------------------------------------------------------
// FanMode
// ---------------------------------------------------------------------------

/// Selects whether to use fan_in or fan_out for Kaiming initialization.
///
/// A fieldless two-variant enum, so it additionally derives `Copy`, `Eq`,
/// and `Hash` (backward-compatible; avoids needless clones at call sites).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FanMode {
    /// Use the number of input connections (fan_in).
    FanIn,
    /// Use the number of output connections (fan_out).
    FanOut,
}

// ---------------------------------------------------------------------------
// Deterministic LCG RNG
// ---------------------------------------------------------------------------

/// Multiplier of the LCG (Knuth's MMIX constants).
const LCG_MULT: u64 = 6_364_136_223_846_793_005;
/// Increment of the LCG (Knuth's MMIX constants).
const LCG_INC: u64 = 1_442_695_040_888_963_407;

/// A deterministic linear congruential generator (LCG) for reproducible
/// weight initialization without depending on the `rand` crate.
#[derive(Debug, Clone)]
pub struct InitRng {
    state: u64,
}

impl InitRng {
    /// Create a new RNG with the given seed.
    pub fn new(seed: u64) -> Self {
        InitRng { state: seed }
    }

    /// Advance the LCG state by one step.
    #[inline]
    fn step(&mut self) {
        self.state = self.state.wrapping_mul(LCG_MULT).wrapping_add(LCG_INC);
    }

    /// Return the next uniform value in `[0, 1)`.
    ///
    /// Uses the top 53 bits of the state so the result carries full
    /// `f64` mantissa precision.
    pub fn next_f64(&mut self) -> f64 {
        self.step();
        let top53 = self.state >> 11;
        top53 as f64 / (1u64 << 53) as f64
    }

    /// Return a sample from the standard normal distribution N(0,1)
    /// using the Box-Muller transform.
    pub fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so ln() stays finite.
        let u1 = self.next_f64().max(f64::MIN_POSITIVE);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * PI * u2;
        radius * angle.cos()
    }

    /// Return a uniform value in `[low, high)`.
    pub fn next_uniform(&mut self, low: f64, high: f64) -> f64 {
        let t = self.next_f64();
        low + (high - low) * t
    }
}

108// ---------------------------------------------------------------------------
109// Helper: compute fan_in / fan_out
110// ---------------------------------------------------------------------------
111
112/// Compute `(fan_in, fan_out)` from a weight tensor shape.
113///
114/// - 2-D `[out_features, in_features]`:  fan_in = in_features, fan_out = out_features
115/// - N-D (N >= 3) convolution `[out_channels, in_channels, k1, k2, ...]`:
116///   fan_in  = in_channels  * product(k_dims)
117///   fan_out = out_channels * product(k_dims)
118pub fn compute_fans(shape: &[usize]) -> Result<(usize, usize), InitError> {
119    match shape.len() {
120        0 => Err(InitError::EmptyShape),
121        1 => Err(InitError::ShapeTooSmall {
122            shape: shape.to_vec(),
123        }),
124        2 => {
125            let fan_out = shape[0];
126            let fan_in = shape[1];
127            if fan_in == 0 {
128                return Err(InitError::InvalidFanIn(0));
129            }
130            if fan_out == 0 {
131                return Err(InitError::InvalidFanOut(0));
132            }
133            Ok((fan_in, fan_out))
134        }
135        _ => {
136            let receptive_field: usize = shape[2..].iter().product();
137            let fan_in = shape[1] * receptive_field;
138            let fan_out = shape[0] * receptive_field;
139            if fan_in == 0 {
140                return Err(InitError::InvalidFanIn(0));
141            }
142            if fan_out == 0 {
143                return Err(InitError::InvalidFanOut(0));
144            }
145            Ok((fan_in, fan_out))
146        }
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Helper: build ArrayD from a Vec
152// ---------------------------------------------------------------------------
153
154fn make_array(shape: &[usize], data: Vec<f64>) -> Result<ArrayD<f64>, InitError> {
155    ArrayD::from_shape_vec(IxDyn(shape), data).map_err(|e| InitError::ShapeError(e.to_string()))
156}
157
/// Total number of elements implied by `shape` (1 for an empty shape,
/// matching the convention of an empty product).
fn total_elements(shape: &[usize]) -> usize {
    shape.iter().copied().fold(1, |acc, dim| acc * dim)
}

// ---------------------------------------------------------------------------
// Gain helper
// ---------------------------------------------------------------------------

/// Return the recommended gain for a given activation function name.
///
/// | Activation    | Gain                              |
/// |---------------|-----------------------------------|
/// | `"linear"`    | 1.0                               |
/// | `"sigmoid"`   | 1.0                               |
/// | `"tanh"`      | 5.0 / 3.0                         |
/// | `"relu"`      | sqrt(2.0)                         |
/// | `"leaky_relu"`| sqrt(2.0 / (1 + 0.01^2))         |
/// | `"selu"`      | 0.75                              |
/// | other         | 1.0                               |
pub fn gain_for_activation(activation: &str) -> f64 {
    match activation {
        "tanh" => 5.0 / 3.0,
        // sqrt() is correctly rounded, so this equals 2.0_f64.sqrt() exactly.
        "relu" => std::f64::consts::SQRT_2,
        "leaky_relu" => {
            // Default negative slope of 0.01, as in the leaky-ReLU convention.
            let negative_slope: f64 = 0.01;
            (2.0 / (1.0 + negative_slope.powi(2))).sqrt()
        }
        "selu" => 0.75,
        // "linear", "sigmoid", and anything unrecognised share the default.
        _ => 1.0,
    }
}

188// ---------------------------------------------------------------------------
189// Xavier / Glorot
190// ---------------------------------------------------------------------------
191
192/// Xavier (Glorot) **uniform** initialization.
193///
194/// Values drawn from U(-limit, limit) where `limit = gain * sqrt(6 / (fan_in + fan_out))`.
195pub fn xavier_uniform(
196    shape: &[usize],
197    gain: f64,
198    rng: &mut InitRng,
199) -> Result<ArrayD<f64>, InitError> {
200    validate_gain(gain)?;
201    let (fan_in, fan_out) = compute_fans(shape)?;
202    let limit = gain * (6.0 / (fan_in + fan_out) as f64).sqrt();
203    let n = total_elements(shape);
204    let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-limit, limit)).collect();
205    make_array(shape, data)
206}
207
208/// Xavier (Glorot) **normal** initialization.
209///
210/// Values drawn from N(0, std) where `std = gain * sqrt(2 / (fan_in + fan_out))`.
211pub fn xavier_normal(
212    shape: &[usize],
213    gain: f64,
214    rng: &mut InitRng,
215) -> Result<ArrayD<f64>, InitError> {
216    validate_gain(gain)?;
217    let (fan_in, fan_out) = compute_fans(shape)?;
218    let std = gain * (2.0 / (fan_in + fan_out) as f64).sqrt();
219    let n = total_elements(shape);
220    let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
221    make_array(shape, data)
222}
223
224// ---------------------------------------------------------------------------
225// Kaiming / He
226// ---------------------------------------------------------------------------
227
228/// Kaiming (He) **uniform** initialization.
229///
230/// Values drawn from U(-bound, bound) where `bound = gain * sqrt(3 / fan)`.
231pub fn kaiming_uniform(
232    shape: &[usize],
233    gain: f64,
234    mode: FanMode,
235    rng: &mut InitRng,
236) -> Result<ArrayD<f64>, InitError> {
237    validate_gain(gain)?;
238    let (fan_in, fan_out) = compute_fans(shape)?;
239    let fan = match mode {
240        FanMode::FanIn => fan_in,
241        FanMode::FanOut => fan_out,
242    };
243    let bound = gain * (3.0 / fan as f64).sqrt();
244    let n = total_elements(shape);
245    let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-bound, bound)).collect();
246    make_array(shape, data)
247}
248
249/// Kaiming (He) **normal** initialization.
250///
251/// Values drawn from N(0, std) where `std = gain / sqrt(fan)`.
252pub fn kaiming_normal(
253    shape: &[usize],
254    gain: f64,
255    mode: FanMode,
256    rng: &mut InitRng,
257) -> Result<ArrayD<f64>, InitError> {
258    validate_gain(gain)?;
259    let (fan_in, fan_out) = compute_fans(shape)?;
260    let fan = match mode {
261        FanMode::FanIn => fan_in,
262        FanMode::FanOut => fan_out,
263    };
264    let std = gain / (fan as f64).sqrt();
265    let n = total_elements(shape);
266    let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
267    make_array(shape, data)
268}
269
270// ---------------------------------------------------------------------------
271// LeCun
272// ---------------------------------------------------------------------------
273
274/// LeCun **normal** initialization: N(0, 1/sqrt(fan_in)).
275pub fn lecun_normal(shape: &[usize], rng: &mut InitRng) -> Result<ArrayD<f64>, InitError> {
276    let (fan_in, _) = compute_fans(shape)?;
277    let std = 1.0 / (fan_in as f64).sqrt();
278    let n = total_elements(shape);
279    let data: Vec<f64> = (0..n).map(|_| std * rng.next_normal()).collect();
280    make_array(shape, data)
281}
282
283/// LeCun **uniform** initialization: U(-limit, limit) where `limit = sqrt(3/fan_in)`.
284pub fn lecun_uniform(shape: &[usize], rng: &mut InitRng) -> Result<ArrayD<f64>, InitError> {
285    let (fan_in, _) = compute_fans(shape)?;
286    let limit = (3.0 / fan_in as f64).sqrt();
287    let n = total_elements(shape);
288    let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(-limit, limit)).collect();
289    make_array(shape, data)
290}
291
292// ---------------------------------------------------------------------------
293// Constant / Zeros / Ones
294// ---------------------------------------------------------------------------
295
296/// Initialize all elements to a constant value.
297pub fn constant_init(shape: &[usize], value: f64) -> ArrayD<f64> {
298    ArrayD::from_elem(IxDyn(shape), value)
299}
300
301/// Initialize all elements to zero.
302pub fn zeros_init(shape: &[usize]) -> ArrayD<f64> {
303    ArrayD::zeros(IxDyn(shape))
304}
305
306/// Initialize all elements to one.
307pub fn ones_init(shape: &[usize]) -> ArrayD<f64> {
308    ArrayD::ones(IxDyn(shape))
309}
310
311// ---------------------------------------------------------------------------
312// Normal / Uniform
313// ---------------------------------------------------------------------------
314
315/// Normal initialization with specified mean and standard deviation.
316pub fn normal_init(
317    shape: &[usize],
318    mean: f64,
319    std: f64,
320    rng: &mut InitRng,
321) -> Result<ArrayD<f64>, InitError> {
322    if std <= 0.0 || !std.is_finite() {
323        return Err(InitError::InvalidStd(std));
324    }
325    let n = total_elements(shape);
326    let data: Vec<f64> = (0..n).map(|_| mean + std * rng.next_normal()).collect();
327    make_array(shape, data)
328}
329
330/// Uniform initialization with specified bounds `[low, high)`.
331pub fn uniform_init(
332    shape: &[usize],
333    low: f64,
334    high: f64,
335    rng: &mut InitRng,
336) -> Result<ArrayD<f64>, InitError> {
337    if low >= high {
338        return Err(InitError::InvalidStd(high - low)); // reuse for "bad range"
339    }
340    let n = total_elements(shape);
341    let data: Vec<f64> = (0..n).map(|_| rng.next_uniform(low, high)).collect();
342    make_array(shape, data)
343}
344
// ---------------------------------------------------------------------------
// Orthogonal
// ---------------------------------------------------------------------------

/// Orthogonal initialization via QR-like Gram-Schmidt on a random matrix.
///
/// Generates a random matrix, then orthogonalises it. For non-square shapes
/// the result is reshaped to the requested dimensions. The `gain` parameter
/// scales the resulting orthogonal matrix.
///
/// The tensor is viewed as a 2-D matrix of `shape[0]` rows by
/// `product(shape[1..])` columns. Gram-Schmidt always runs on the "tall"
/// orientation (rows >= cols), transposing first if needed, so the shorter
/// dimension ends up with orthonormal vectors.
pub fn orthogonal_init(
    shape: &[usize],
    gain: f64,
    rng: &mut InitRng,
) -> Result<ArrayD<f64>, InitError> {
    validate_gain(gain)?;
    if shape.len() < 2 {
        return Err(InitError::ShapeTooSmall {
            shape: shape.to_vec(),
        });
    }

    let rows = shape[0];
    let cols: usize = shape[1..].iter().product();
    if rows == 0 || cols == 0 {
        return Err(InitError::ShapeTooSmall {
            shape: shape.to_vec(),
        });
    }

    // Generate a random matrix (rows x cols), row-major in `flat`.
    let n = rows * cols;
    let mut flat: Vec<f64> = (0..n).map(|_| rng.next_normal()).collect();

    // Determine whether we QR on the matrix or its transpose.
    // Gram-Schmidt needs work_rows >= work_cols to produce a full
    // orthonormal set of columns.
    let (work_rows, work_cols, transposed) = if rows >= cols {
        (rows, cols, false)
    } else {
        (cols, rows, true)
    };

    // Build column-major representation for Gram-Schmidt
    // We work on a (work_rows x work_cols) matrix stored as columns.
    let mut columns: Vec<Vec<f64>> = if !transposed {
        // columns of the original matrix: element (r, c) is flat[r * cols + c]
        (0..work_cols)
            .map(|c| (0..work_rows).map(|r| flat[r * cols + c]).collect())
            .collect()
    } else {
        // columns of the transpose: rows of the original
        // (column c of A^T, entry r, equals A[c][r] = flat[c * cols + r])
        (0..work_cols)
            .map(|c| (0..work_rows).map(|r| flat[c * cols + r]).collect())
            .collect()
    };

    // Modified Gram-Schmidt orthogonalisation: normalise column i, then
    // immediately remove its component from all later columns.
    for i in 0..work_cols {
        // Normalise column i
        let norm = dot_vec(&columns[i], &columns[i]).sqrt();
        if norm < 1e-15 {
            // Degenerate column – fill with a canonical basis vector
            for v in columns[i].iter_mut() {
                *v = 0.0;
            }
            if i < work_rows {
                columns[i][i] = 1.0;
            }
        } else {
            for v in columns[i].iter_mut() {
                *v /= norm;
            }
        }

        // Project out column i from subsequent columns
        let qi = columns[i].clone();
        for col in columns.iter_mut().skip(i + 1) {
            let proj = dot_vec(&qi, col);
            for (v, q) in col.iter_mut().zip(qi.iter()) {
                *v -= proj * q;
            }
        }
    }

    // Reassemble into flat (rows x cols) row-major, applying `gain`.
    flat.clear();
    if !transposed {
        // Result element (r, c) is Q[r][c] = columns[c][r].
        for r in 0..rows {
            for col in columns.iter().take(cols) {
                flat.push(gain * col[r]);
            }
        }
    } else {
        // Result is Q^T: element (r, c) is Q[c][r] = columns[r][c],
        // so emitting each stored column in order yields row-major Q^T.
        for col_vec in columns.iter().take(rows) {
            for &val in col_vec.iter().take(cols) {
                flat.push(gain * val);
            }
        }
    }

    make_array(shape, flat)
}

/// Dot product of two equal-length vectors (0.0 for empty input).
fn dot_vec(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}

451// ---------------------------------------------------------------------------
452// Validation helpers
453// ---------------------------------------------------------------------------
454
455fn validate_gain(gain: f64) -> Result<(), InitError> {
456    if gain <= 0.0 || !gain.is_finite() {
457        return Err(InitError::InvalidGain(gain));
458    }
459    Ok(())
460}
461
462// ---------------------------------------------------------------------------
463// InitStats
464// ---------------------------------------------------------------------------
465
466/// Statistics about an initialized weight tensor.
467#[derive(Debug, Clone)]
468pub struct InitStats {
469    /// Shape of the tensor.
470    pub shape: Vec<usize>,
471    /// Total number of elements.
472    pub num_elements: usize,
473    /// Mean value.
474    pub mean: f64,
475    /// Standard deviation.
476    pub std: f64,
477    /// Minimum value.
478    pub min: f64,
479    /// Maximum value.
480    pub max: f64,
481    /// Computed fan_in.
482    pub fan_in: usize,
483    /// Computed fan_out.
484    pub fan_out: usize,
485}
486
487impl InitStats {
488    /// Compute statistics for the given tensor and shape.
489    pub fn compute(tensor: &ArrayD<f64>, shape: &[usize]) -> Self {
490        let n = tensor.len();
491        let (fan_in, fan_out) = compute_fans(shape).unwrap_or((0, 0));
492
493        let mut sum = 0.0_f64;
494        let mut min_val = f64::INFINITY;
495        let mut max_val = f64::NEG_INFINITY;
496
497        for &v in tensor.iter() {
498            sum += v;
499            if v < min_val {
500                min_val = v;
501            }
502            if v > max_val {
503                max_val = v;
504            }
505        }
506
507        let mean = if n > 0 { sum / n as f64 } else { 0.0 };
508
509        let variance = if n > 1 {
510            let mut sq_sum = 0.0_f64;
511            for &v in tensor.iter() {
512                sq_sum += (v - mean).powi(2);
513            }
514            sq_sum / n as f64
515        } else {
516            0.0
517        };
518
519        Self {
520            shape: shape.to_vec(),
521            num_elements: n,
522            mean,
523            std: variance.sqrt(),
524            min: min_val,
525            max: max_val,
526            fan_in,
527            fan_out,
528        }
529    }
530
531    /// Return a human-readable summary string.
532    pub fn summary(&self) -> String {
533        format!(
534            "InitStats {{ shape: {:?}, n: {}, mean: {:.6}, std: {:.6}, \
535             min: {:.6}, max: {:.6}, fan_in: {}, fan_out: {} }}",
536            self.shape,
537            self.num_elements,
538            self.mean,
539            self.std,
540            self.min,
541            self.max,
542            self.fan_in,
543            self.fan_out,
544        )
545    }
546}
547
// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // --- fan computation -------------------------------------------------

    #[test]
    fn test_compute_fans_2d() {
        let (fan_in, fan_out) = compute_fans(&[10, 5]).expect("compute_fans failed");
        assert_eq!(fan_in, 5);
        assert_eq!(fan_out, 10);
    }

    #[test]
    fn test_compute_fans_4d() {
        // [out_ch=16, in_ch=3, kH=3, kW=3]
        let (fan_in, fan_out) = compute_fans(&[16, 3, 3, 3]).expect("compute_fans failed");
        assert_eq!(fan_in, 3 * 3 * 3); // 27
        assert_eq!(fan_out, 16 * 3 * 3); // 144
    }

    // --- Xavier / Glorot --------------------------------------------------

    #[test]
    fn test_xavier_uniform_range() {
        // Every sample must fall inside the Glorot bound for gain = 1.
        let shape = [64, 32];
        let (fan_in, fan_out) = compute_fans(&shape).expect("fans");
        let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
        let mut rng = InitRng::new(42);
        let arr = xavier_uniform(&shape, 1.0, &mut rng).expect("xavier_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -limit && v <= limit,
                "value {v} outside [{}, {}]",
                -limit,
                limit
            );
        }
    }

    #[test]
    fn test_xavier_normal_mean_near_zero() {
        // Statistical check: sample mean of a large zero-mean draw.
        let shape = [256, 128];
        let mut rng = InitRng::new(123);
        let arr = xavier_normal(&shape, 1.0, &mut rng).expect("xavier_normal");
        let mean: f64 = arr.iter().sum::<f64>() / arr.len() as f64;
        assert!(mean.abs() < 0.05, "mean too far from zero: {mean}");
    }

    // --- Kaiming / He -----------------------------------------------------

    #[test]
    fn test_kaiming_uniform_fan_in() {
        let shape = [64, 32];
        let gain = 2.0_f64.sqrt();
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let bound = gain * (3.0 / fan_in as f64).sqrt();
        let mut rng = InitRng::new(7);
        let arr = kaiming_uniform(&shape, gain, FanMode::FanIn, &mut rng).expect("kaiming_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -bound && v <= bound,
                "value {v} outside [{}, {}]",
                -bound,
                bound
            );
        }
    }

    #[test]
    fn test_kaiming_normal_std() {
        // Empirical std should be within ~15% of the analytic target.
        let shape = [256, 128];
        let gain = 2.0_f64.sqrt();
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let expected_std = gain / (fan_in as f64).sqrt();
        let mut rng = InitRng::new(99);
        let arr = kaiming_normal(&shape, gain, FanMode::FanIn, &mut rng).expect("kaiming_normal");
        let n = arr.len() as f64;
        let mean: f64 = arr.iter().sum::<f64>() / n;
        let var: f64 = arr.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let actual_std = var.sqrt();
        let ratio = actual_std / expected_std;
        assert!(
            (0.85..=1.15).contains(&ratio),
            "std ratio {ratio} (actual={actual_std}, expected={expected_std})"
        );
    }

    // --- LeCun ------------------------------------------------------------

    #[test]
    fn test_lecun_normal_shape() {
        let shape = [16, 8, 3, 3];
        let mut rng = InitRng::new(55);
        let arr = lecun_normal(&shape, &mut rng).expect("lecun_normal");
        assert_eq!(arr.shape(), &[16, 8, 3, 3]);
    }

    #[test]
    fn test_lecun_uniform_range() {
        let shape = [32, 16];
        let (fan_in, _) = compute_fans(&shape).expect("fans");
        let limit = (3.0 / fan_in as f64).sqrt();
        let mut rng = InitRng::new(11);
        let arr = lecun_uniform(&shape, &mut rng).expect("lecun_uniform");
        for &v in arr.iter() {
            assert!(
                v >= -limit && v <= limit,
                "value {v} outside [{}, {}]",
                -limit,
                limit
            );
        }
    }

    // --- constant / zeros / ones -------------------------------------------

    #[test]
    fn test_constant_init_value() {
        let arr = constant_init(&[3, 4], 3.15);
        for &v in arr.iter() {
            assert!((v - 3.15).abs() < 1e-12);
        }
    }

    #[test]
    fn test_zeros_init() {
        let arr = zeros_init(&[5, 5]);
        for &v in arr.iter() {
            assert!((v).abs() < 1e-15);
        }
    }

    #[test]
    fn test_ones_init() {
        let arr = ones_init(&[2, 3]);
        for &v in arr.iter() {
            assert!((v - 1.0).abs() < 1e-15);
        }
    }

    // --- orthogonal ---------------------------------------------------------

    #[test]
    fn test_orthogonal_init_square() {
        let shape = [8, 8];
        let mut rng = InitRng::new(77);
        let arr = orthogonal_init(&shape, 1.0, &mut rng).expect("orthogonal_init");
        // Check Q^T Q ≈ I  (columns are orthonormal)
        let n = 8;
        for i in 0..n {
            for j in 0..n {
                let mut dot = 0.0_f64;
                for k in 0..n {
                    // arr[[k, i]] * arr[[k, j]]
                    dot += arr[[k, i].as_ref()] * arr[[k, j].as_ref()];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-8,
                    "Q^T Q [{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    // --- plain normal / uniform --------------------------------------------

    #[test]
    fn test_normal_init_distribution() {
        let shape = [512, 256];
        let target_mean = 2.0;
        let target_std = 0.5;
        let mut rng = InitRng::new(42);
        let arr = normal_init(&shape, target_mean, target_std, &mut rng).expect("normal_init");
        let n = arr.len() as f64;
        let mean: f64 = arr.iter().sum::<f64>() / n;
        let var: f64 = arr.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let actual_std = var.sqrt();
        assert!(
            (mean - target_mean).abs() < 0.05,
            "mean {mean} far from {target_mean}"
        );
        assert!(
            (actual_std - target_std).abs() < 0.05,
            "std {actual_std} far from {target_std}"
        );
    }

    #[test]
    fn test_uniform_init_bounds() {
        let shape = [100, 100];
        let mut rng = InitRng::new(13);
        let arr = uniform_init(&shape, -0.5, 0.5, &mut rng).expect("uniform_init");
        for &v in arr.iter() {
            // half-open range matches the documented [low, high) contract
            assert!((-0.5..0.5).contains(&v), "value {v} out of bounds");
        }
    }

    // --- gain table ---------------------------------------------------------

    #[test]
    fn test_gain_for_relu() {
        let g = gain_for_activation("relu");
        assert!((g - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_gain_for_tanh() {
        let g = gain_for_activation("tanh");
        assert!((g - 5.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_gain_for_unknown() {
        assert!((gain_for_activation("swish") - 1.0).abs() < 1e-12);
    }

    // --- stats --------------------------------------------------------------

    #[test]
    fn test_init_stats_compute() {
        let arr = ones_init(&[4, 5]);
        let stats = InitStats::compute(&arr, &[4, 5]);
        assert_eq!(stats.num_elements, 20);
        assert!((stats.mean - 1.0).abs() < 1e-12);
        assert!(stats.std < 1e-12);
    }

    #[test]
    fn test_init_stats_summary_nonempty() {
        let arr = zeros_init(&[3, 3]);
        let stats = InitStats::compute(&arr, &[3, 3]);
        let s = stats.summary();
        assert!(!s.is_empty());
        assert!(s.contains("InitStats"));
    }

    // --- FanMode behavior ----------------------------------------------------

    #[test]
    fn test_fan_mode_kaiming_changes_std() {
        // For a non-square shape, fan_in != fan_out, so distributions differ.
        let shape = [128, 32];
        let gain = 2.0_f64.sqrt();

        let mut rng1 = InitRng::new(1000);
        let arr_in =
            kaiming_normal(&shape, gain, FanMode::FanIn, &mut rng1).expect("kaiming fan_in");

        let mut rng2 = InitRng::new(1000);
        let arr_out =
            kaiming_normal(&shape, gain, FanMode::FanOut, &mut rng2).expect("kaiming fan_out");

        let std_in = {
            let n = arr_in.len() as f64;
            let m: f64 = arr_in.iter().sum::<f64>() / n;
            (arr_in.iter().map(|v| (v - m).powi(2)).sum::<f64>() / n).sqrt()
        };
        let std_out = {
            let n = arr_out.len() as f64;
            let m: f64 = arr_out.iter().sum::<f64>() / n;
            (arr_out.iter().map(|v| (v - m).powi(2)).sum::<f64>() / n).sqrt()
        };

        // fan_in=32 vs fan_out=128, so std_in should be larger than std_out
        assert!(
            (std_in - std_out).abs() > 0.01,
            "std_in={std_in} and std_out={std_out} should differ significantly"
        );
    }
}