scirs2_neural/utils/
initializers.rs1use crate::error::{NeuralError, Result};
4use ndarray::{Array, Dimension, IxDyn};
5use num_traits::Float;
6use rand::Rng;
7use std::fmt::Debug;
8
/// Strategy used to fill a freshly created weight array.
///
/// The fan-based variants (`Xavier`, `He`, `LeCun`) take their scale from the
/// `fan_in` / `fan_out` values supplied to `Initializer::initialize`.
#[derive(Debug, Clone, Copy)]
pub enum Initializer {
    /// Every element is zero.
    Zeros,
    /// Every element is one.
    Ones,
    /// Elements are drawn uniformly from the range `min..max`.
    Uniform {
        /// Lower bound of the sampling range.
        min: f64,
        /// Upper bound of the sampling range.
        max: f64,
    },
    /// Elements are drawn from a normal distribution.
    Normal {
        /// Mean of the distribution.
        mean: f64,
        /// Standard deviation of the distribution.
        std: f64,
    },
    /// Uniform samples bounded by `±sqrt(6 / (fan_in + fan_out))`.
    Xavier,
    /// Zero-mean normal samples with standard deviation `sqrt(2 / fan_in)`.
    He,
    /// Zero-mean normal samples with standard deviation `sqrt(1 / fan_in)`.
    LeCun,
}
37
38impl Initializer {
39 pub fn initialize<F: Float + Debug, R: Rng>(
52 &self,
53 shape: IxDyn,
54 fan_in: usize,
55 fan_out: usize,
56 rng: &mut R,
57 ) -> Result<Array<F, IxDyn>> {
58 let size = shape.as_array_view().iter().product();
59
60 match self {
61 Initializer::Zeros => Ok(Array::zeros(shape)),
62 Initializer::Ones => {
63 let ones: Vec<F> = (0..size).map(|_| F::one()).collect();
64
65 Array::from_shape_vec(shape, ones).map_err(|e| {
66 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
67 })
68 }
69 Initializer::Uniform { min, max } => {
70 let values: Vec<F> = (0..size)
71 .map(|_| {
72 let val = rng.random_range(*min..*max);
73 F::from(val).ok_or_else(|| {
74 NeuralError::InvalidArchitecture(
75 "Failed to convert random value".to_string(),
76 )
77 })
78 })
79 .collect::<Result<Vec<F>>>()?;
80
81 Array::from_shape_vec(shape, values).map_err(|e| {
82 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
83 })
84 }
85 Initializer::Normal { mean, std } => {
86 let values: Vec<F> = (0..size)
87 .map(|_| {
88 let u1 = rng.random_range(0.0..1.0);
90 let u2 = rng.random_range(0.0..1.0);
91
92 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
93 let val = mean + std * z;
94
95 F::from(val).ok_or_else(|| {
96 NeuralError::InvalidArchitecture(
97 "Failed to convert random value".to_string(),
98 )
99 })
100 })
101 .collect::<Result<Vec<F>>>()?;
102
103 Array::from_shape_vec(shape, values).map_err(|e| {
104 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
105 })
106 }
107 Initializer::Xavier => {
108 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
109
110 let values: Vec<F> = (0..size)
111 .map(|_| {
112 let val = rng.random_range(-limit..limit);
113 F::from(val).ok_or_else(|| {
114 NeuralError::InvalidArchitecture(
115 "Failed to convert random value".to_string(),
116 )
117 })
118 })
119 .collect::<Result<Vec<F>>>()?;
120
121 Array::from_shape_vec(shape, values).map_err(|e| {
122 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
123 })
124 }
125 Initializer::He => {
126 let std = (2.0 / fan_in as f64).sqrt();
127
128 let values: Vec<F> = (0..size)
129 .map(|_| {
130 let u1 = rng.random_range(0.0..1.0);
132 let u2 = rng.random_range(0.0..1.0);
133
134 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
135 let val = std * z;
136
137 F::from(val).ok_or_else(|| {
138 NeuralError::InvalidArchitecture(
139 "Failed to convert random value".to_string(),
140 )
141 })
142 })
143 .collect::<Result<Vec<F>>>()?;
144
145 Array::from_shape_vec(shape, values).map_err(|e| {
146 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
147 })
148 }
149 Initializer::LeCun => {
150 let std = (1.0 / fan_in as f64).sqrt();
151
152 let values: Vec<F> = (0..size)
153 .map(|_| {
154 let u1 = rng.random_range(0.0..1.0);
156 let u2 = rng.random_range(0.0..1.0);
157
158 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
159 let val = std * z;
160
161 F::from(val).ok_or_else(|| {
162 NeuralError::InvalidArchitecture(
163 "Failed to convert random value".to_string(),
164 )
165 })
166 })
167 .collect::<Result<Vec<F>>>()?;
168
169 Array::from_shape_vec(shape, values).map_err(|e| {
170 NeuralError::InvalidArchitecture(format!("Failed to create array: {}", e))
171 })
172 }
173 }
174 }
175}
176
177pub fn xavier_uniform<F: Float + Debug>(shape: IxDyn) -> Result<Array<F, IxDyn>> {
187 let fan_in = match shape.ndim() {
188 0 => 1,
189 1 => shape[0],
190 _ => shape[0],
191 };
192
193 let fan_out = match shape.ndim() {
194 0 => 1,
195 1 => 1,
196 _ => shape[1],
197 };
198
199 let mut rng = rand::rng();
200 Initializer::Xavier.initialize(shape, fan_in, fan_out, &mut rng)
201}