scirs2_neural/utils/
initializers.rs1use crate::error::{NeuralError, Result};
4use scirs2_core::ndarray::{Array, Dimension, IxDyn};
5use scirs2_core::numeric::Float;
6use scirs2_core::random::Rng;
7use std::fmt::Debug;
/// Strategy used to fill a freshly created weight array.
///
/// The variants are consumed by `Initializer::initialize`, which receives the
/// target shape together with `fan_in` / `fan_out` for the fan-dependent
/// strategies (`Xavier`, `He`, `LeCun`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Initializer {
    /// Every element is `0`.
    Zeros,
    /// Every element is `1`.
    Ones,
    /// Elements are sampled uniformly from the half-open range `[min, max)`.
    Uniform {
        /// Inclusive lower bound of the sampling range.
        min: f64,
        /// Exclusive upper bound of the sampling range.
        max: f64,
    },
    /// Elements are sampled from a normal distribution (Box-Muller transform).
    Normal {
        /// Mean of the distribution.
        mean: f64,
        /// Standard deviation of the distribution.
        std: f64,
    },
    /// Xavier/Glorot uniform: samples from `[-limit, limit)` with
    /// `limit = sqrt(6 / (fan_in + fan_out))`.
    Xavier,
    /// He: zero-mean normal with standard deviation `sqrt(2 / fan_in)`.
    He,
    /// LeCun: zero-mean normal with standard deviation `sqrt(1 / fan_in)`.
    LeCun,
}
36impl Initializer {
37 pub fn initialize<F: Float + Debug, R: Rng>(
47 &self,
48 shape: IxDyn,
49 fan_in: usize,
50 fan_out: usize,
51 rng: &mut R,
52 ) -> Result<Array<F, IxDyn>> {
53 let size = shape.as_array_view().iter().product();
54 match self {
55 Initializer::Zeros => Ok(Array::zeros(shape)),
56 Initializer::Ones => {
57 let ones: Vec<F> = (0..size).map(|_| F::one()).collect();
58 Array::from_shape_vec(shape, ones).map_err(|e| {
59 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
60 })
61 }
62 Initializer::Uniform { min, max } => {
63 let values: Vec<F> = (0..size)
64 .map(|_| {
65 let val = rng.gen_range(*min..*max);
66 F::from(val).ok_or_else(|| {
67 NeuralError::InvalidArchitecture(
68 "Failed to convert random value".to_string(),
69 )
70 })
71 })
72 .collect::<Result<Vec<F>>>()?;
73 Array::from_shape_vec(shape, values).map_err(|e| {
74 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
75 })
76 }
77 Initializer::Normal { mean, std } => {
78 let values: Vec<F> = (0..(size / 2 + 1))
79 .flat_map(|_| {
80 let u1 = rng.gen_range(0.0..1.0);
82 let u2 = rng.gen_range(0.0..1.0);
83 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
84 let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
85 let val0 = mean + std * z0;
86 let val1 = mean + std * z1;
87 vec![
88 F::from(val0).unwrap_or(F::zero()),
89 F::from(val1).unwrap_or(F::zero()),
90 ]
91 })
92 .take(size)
93 .collect();
94 Array::from_shape_vec(shape, values).map_err(|e| {
95 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
96 })
97 }
98 Initializer::Xavier => {
99 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
100 let values: Vec<F> = (0..size)
101 .map(|_| {
102 let val = rng.gen_range(-limit..limit);
103 F::from(val).unwrap_or(F::zero())
104 })
105 .collect();
106 Array::from_shape_vec(shape, values).map_err(|e| {
107 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
108 })
109 }
110 Initializer::He => {
111 let std = (2.0 / fan_in as f64).sqrt();
112 let values: Vec<F> = (0..(size / 2 + 1))
113 .flat_map(|_| {
114 let u1 = rng.gen_range(0.0..1.0);
116 let u2 = rng.gen_range(0.0..1.0);
117 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
118 let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
119 let val0 = std * z0;
120 let val1 = std * z1;
121 vec![
122 F::from(val0).unwrap_or(F::zero()),
123 F::from(val1).unwrap_or(F::zero()),
124 ]
125 })
126 .take(size)
127 .collect();
128 Array::from_shape_vec(shape, values).map_err(|e| {
129 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
130 })
131 }
132 Initializer::LeCun => {
133 let std = (1.0 / fan_in as f64).sqrt();
134 let values: Vec<F> = (0..(size / 2 + 1))
135 .flat_map(|_| {
136 let u1 = rng.gen_range(0.0..1.0);
138 let u2 = rng.gen_range(0.0..1.0);
139 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
140 let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).sin();
141 let val0 = std * z0;
142 let val1 = std * z1;
143 vec![
144 F::from(val0).unwrap_or(F::zero()),
145 F::from(val1).unwrap_or(F::zero()),
146 ]
147 })
148 .take(size)
149 .collect();
150 Array::from_shape_vec(shape, values).map_err(|e| {
151 NeuralError::InvalidArchitecture(format!("Failed to create array: {e}"))
152 })
153 }
154 }
155 }
156}
157#[allow(dead_code)]
164pub fn xavier_uniform<F: Float + Debug>(shape: IxDyn) -> Result<Array<F, IxDyn>> {
165 let fan_in = match shape.ndim() {
166 0 => 1,
167 1 => shape[0],
168 _ => shape[0],
169 };
170 let fan_out = match shape.ndim() {
171 1 => 1,
172 _ => shape[1],
173 };
174 let mut rng = scirs2_core::random::rng();
175 Initializer::Xavier.initialize(shape, fan_in, fan_out, &mut rng)
176}