scirs2_neural/utils/
initializers.rs

1//! Weight initialization strategies for neural networks
2
3use crate::error::{NeuralError, Result};
4use scirs2_core::ndarray::{Array, Dimension, IxDyn};
5use scirs2_core::numeric::Float;
6use scirs2_core::random::Rng;
7use std::fmt::Debug;
/// Initialization strategies for neural network weights
///
/// The enum is a small, copyable description of *how* to fill a weight
/// array; the actual sampling is performed by `Initializer::initialize`.
/// Variants can be compared with `==`, which is convenient for
/// configuration checks and tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Initializer {
    /// Zero initialization: every weight is `0`
    Zeros,
    /// One initialization: every weight is `1`
    Ones,
    /// Uniform random initialization over the half-open range `[min, max)`
    Uniform {
        /// Minimum value (inclusive lower bound)
        min: f64,
        /// Maximum value (exclusive upper bound)
        max: f64,
    },
    /// Normal (Gaussian) random initialization
    Normal {
        /// Mean of the distribution
        mean: f64,
        /// Standard deviation of the distribution
        std: f64,
    },
    /// Xavier/Glorot initialization: uniform in `[-limit, limit)` with
    /// `limit = sqrt(6 / (fan_in + fan_out))`
    Xavier,
    /// He initialization: zero-mean normal with standard deviation
    /// `sqrt(2 / fan_in)`
    He,
    /// LeCun initialization: zero-mean normal with standard deviation
    /// `sqrt(1 / fan_in)`
    LeCun,
}
36impl Initializer {
37    /// Initialize weights according to the strategy
38    ///
39    /// # Arguments
40    /// * `shape` - Shape of the weights array
41    /// * `fan_in` - Number of input connections (for Xavier, He, LeCun)
42    /// * `fan_out` - Number of output connections (for Xavier)
43    /// * `rng` - Random number generator
44    /// # Returns
45    /// * Initialized weights array
46    pub fn initialize<F: Float + Debug, R: Rng>(
47        &self,
48        shape: IxDyn,
49        fan_in: usize,
50        fan_out: usize,
51        rng: &mut R,
52    ) -> Result<Array<F, IxDyn>> {
53        let size = shape.as_array_view().iter().product();
54        match self {
55            Initializer::Zeros => Ok(Array::zeros(shape)),
56            Initializer::Ones => {
57                let ones: Vec<F> = (0..size).map(|_| F::one()).collect();
58                Array::from_shape_vec(shape, ones).map_err(|e| {
59                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
60                })
61            }
62            Initializer::Uniform { min, max } => {
63                let values: Vec<F> = (0..size)
64                    .map(|_| {
65                        let val = rng.gen_range(*min..*max);
66                        F::from(val).ok_or_else(|| {
67                            NeuralError::InvalidArchitecture(
68                                "Failed to convert random value".to_string(),
69                            )
70                        })
71                    })
72                    .collect::<Result<Vec<F>>>()?;
73                Array::from_shape_vec(shape, values).map_err(|e| {
74                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
75                })
76            }
77            Initializer::Normal { mean, std } => {
78                let values: Vec<F> = (0..(size / 2 + 1))
79                    .flat_map(|_| {
80                        // Box-Muller transform to generate normal distribution
81                        let u1 = rng.gen_range(0.0..1.0);
82                        let u2 = rng.gen_range(0.0..1.0);
83                        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
84                        let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
85                        let val0 = mean + std * z0;
86                        let val1 = mean + std * z1;
87                        vec![
88                            F::from(val0).unwrap_or(F::zero()),
89                            F::from(val1).unwrap_or(F::zero()),
90                        ]
91                    })
92                    .take(size)
93                    .collect();
94                Array::from_shape_vec(shape, values).map_err(|e| {
95                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
96                })
97            }
98            Initializer::Xavier => {
99                let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
100                let values: Vec<F> = (0..size)
101                    .map(|_| {
102                        let val = rng.gen_range(-limit..limit);
103                        F::from(val).unwrap_or(F::zero())
104                    })
105                    .collect();
106                Array::from_shape_vec(shape, values).map_err(|e| {
107                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
108                })
109            }
110            Initializer::He => {
111                let std = (2.0 / fan_in as f64).sqrt();
112                let values: Vec<F> = (0..(size / 2 + 1))
113                    .flat_map(|_| {
114                        // Box-Muller transform for He initialization
115                        let u1 = rng.gen_range(0.0..1.0);
116                        let u2 = rng.gen_range(0.0..1.0);
117                        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
118                        let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
119                        let val0 = std * z0;
120                        let val1 = std * z1;
121                        vec![
122                            F::from(val0).unwrap_or(F::zero()),
123                            F::from(val1).unwrap_or(F::zero()),
124                        ]
125                    })
126                    .take(size)
127                    .collect();
128                Array::from_shape_vec(shape, values).map_err(|e| {
129                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
130                })
131            }
132            Initializer::LeCun => {
133                let std = (1.0 / fan_in as f64).sqrt();
134                let values: Vec<F> = (0..(size / 2 + 1))
135                    .flat_map(|_| {
136                        // Box-Muller transform for LeCun initialization
137                        let u1 = rng.gen_range(0.0..1.0);
138                        let u2 = rng.gen_range(0.0..1.0);
139                        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
140                        let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
141                        let val0 = std * z0;
142                        let val1 = std * z1;
143                        vec![
144                            F::from(val0).unwrap_or(F::zero()),
145                            F::from(val1).unwrap_or(F::zero()),
146                        ]
147                    })
148                    .take(size)
149                    .collect();
150                Array::from_shape_vec(shape, values).map_err(|e| {
151                    NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
152                })
153            }
154        }
155    }
156}
157/// Xavier/Glorot uniform initialization
158///
159/// # Arguments
160/// * `shape` - Shape of the weights array
161/// # Returns
162/// * Initialized weights array
163#[allow(dead_code)]
164pub fn xavier_uniform<F: Float + Debug>(shape: IxDyn) -> Result<Array<F, IxDyn>> {
165    let fan_in = match shape.ndim() {
166        0 => 1,
167        1 => shape[0],
168        _ => shape[0],
169    };
170    let fan_out = match shape.ndim() {
171        1 => 1,
172        _ => shape[1],
173    };
174    let mut rng = scirs2_core::random::rng();
175    Initializer::Xavier.initialize(shape, fan_in, fan_out, &mut rng)
176}