use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Layer normalization: normalizes the last dimension to zero mean and unit variance,
/// then applies a learned scale and shift: `y = gamma * (x - mean) / sqrt(var + eps) + beta`.
#[derive(Debug)]
pub struct LayerNorm<F: Float + Debug + Send + Sync> {
    /// Shape of the normalized (last) dimension
    normalizedshape: Vec<usize>,
    /// Learnable per-feature scale, initialized to ones
    gamma: Array<F, IxDyn>,
    /// Learnable per-feature shift, initialized to zeros
    beta: Array<F, IxDyn>,
    /// Accumulated gradient for gamma
    dgamma: Arc<RwLock<Array<F, IxDyn>>>,
    /// Accumulated gradient for beta
    dbeta: Arc<RwLock<Array<F, IxDyn>>>,
    /// Small constant added to the variance for numerical stability
    eps: F,
    /// Cached input from the last forward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached normalized output from the last forward pass
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached per-row mean from the last forward pass
    mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached per-row variance from the last forward pass
    var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Clone for LayerNorm<F> {
    fn clone(&self) -> Self {
        // Clone the cached state, falling back to empty caches if a lock is poisoned.
        let input_cache_clone = match self.input_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let norm_cache_clone = match self.norm_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let mean_cache_clone = match self.mean_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let var_cache_clone = match self.var_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        // Clone the accumulated gradients, falling back to zeros on a poisoned lock
        // instead of panicking.
        let dgamma_clone = match self.dgamma.read() {
            Ok(guard) => guard.clone(),
            Err(_) => Array::zeros(self.gamma.raw_dim()),
        };
        let dbeta_clone = match self.dbeta.read() {
            Ok(guard) => guard.clone(),
            Err(_) => Array::zeros(self.beta.raw_dim()),
        };

        Self {
            normalizedshape: self.normalizedshape.clone(),
            gamma: self.gamma.clone(),
            beta: self.beta.clone(),
            dgamma: Arc::new(RwLock::new(dgamma_clone)),
            dbeta: Arc::new(RwLock::new(dbeta_clone)),
            eps: self.eps,
            input_cache: Arc::new(RwLock::new(input_cache_clone)),
            norm_cache: Arc::new(RwLock::new(norm_cache_clone)),
            mean_cache: Arc::new(RwLock::new(mean_cache_clone)),
            var_cache: Arc::new(RwLock::new(var_cache_clone)),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> LayerNorm<F> {
    /// Creates a new layer normalization layer.
    ///
    /// * `normalizedshape` - size of the last (feature) dimension to normalize over
    /// * `eps` - small constant added to the variance for numerical stability
    /// * `_rng` - random number generator (unused; gamma and beta are initialized
    ///   deterministically to ones and zeros)
    pub fn new<R: Rng>(normalizedshape: usize, eps: f64, _rng: &mut R) -> Result<Self> {
        let gamma = Array::<F, IxDyn>::from_elem(IxDyn(&[normalizedshape]), F::one());
        let beta = Array::<F, IxDyn>::from_elem(IxDyn(&[normalizedshape]), F::zero());

        let dgamma = Arc::new(RwLock::new(Array::<F, IxDyn>::zeros(IxDyn(&[
            normalizedshape,
        ]))));
        let dbeta = Arc::new(RwLock::new(Array::<F, IxDyn>::zeros(IxDyn(&[
            normalizedshape,
        ]))));

        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            normalizedshape: vec![normalizedshape],
            gamma,
            beta,
            dgamma,
            dbeta,
            eps,
            input_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            mean_cache: Arc::new(RwLock::new(None)),
            var_cache: Arc::new(RwLock::new(None)),
        })
    }

    /// Returns the size of the normalized (feature) dimension.
    pub fn normalizedshape(&self) -> usize {
        self.normalizedshape[0]
    }

    /// Returns the epsilon value used for numerical stability.
    #[allow(dead_code)]
    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for LayerNorm<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input for a future backward pass.
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        }

        let inputshape = input.shape();
        let ndim = input.ndim();

        if ndim < 1 {
            return Err(NeuralError::InferenceError(
                "Input must have at least 1 dimension".to_string(),
            ));
        }

        // The last dimension must match the configured feature size.
        let feat_dim = inputshape[ndim - 1];
        if feat_dim != self.normalizedshape[0] {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Last dimension of input ({}) must match normalizedshape ({})",
                feat_dim, self.normalizedshape[0]
            )));
        }

        // Flatten all leading dimensions into a single batch dimension.
        let batchshape: Vec<usize> = inputshape[..ndim - 1].to_vec();
        let batch_size: usize = batchshape.iter().product();

        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {e}")))?;

        // Compute the per-row mean and (biased) variance over the feature dimension.
        let mut mean = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, 1]));
        let mut var = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, 1]));

        for i in 0..batch_size {
            let mut sum = F::zero();
            for j in 0..feat_dim {
                sum = sum + reshaped[[i, j]];
            }
            mean[[i, 0]] = sum / F::from(feat_dim).unwrap();

            let mut sum_sq = F::zero();
            for j in 0..feat_dim {
                let diff = reshaped[[i, j]] - mean[[i, 0]];
                sum_sq = sum_sq + diff * diff;
            }
            var[[i, 0]] = sum_sq / F::from(feat_dim).unwrap();
        }

        // Cache the statistics for a future backward pass.
        if let Ok(mut cache) = self.mean_cache.write() {
            *cache = Some(mean.clone());
        }
        if let Ok(mut cache) = self.var_cache.write() {
            *cache = Some(var.clone());
        }

        // Normalize, then apply the learned scale and shift:
        // y = gamma * (x - mean) / sqrt(var + eps) + beta
        let mut normalized = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, feat_dim]));
        for i in 0..batch_size {
            for j in 0..feat_dim {
                let x_norm = (reshaped[[i, j]] - mean[[i, 0]]) / (var[[i, 0]] + self.eps).sqrt();
                normalized[[i, j]] = x_norm * self.gamma[[j]] + self.beta[[j]];
            }
        }

        // `normalized` is already dynamic-dimensional, so it can be cached directly.
        if let Ok(mut cache) = self.norm_cache.write() {
            *cache = Some(normalized.clone());
        }

        // Restore the original input shape.
        let output = normalized
            .into_shape_with_order(IxDyn(inputshape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {e}")))?;

        Ok(output)
    }
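
    // Note: the backward pass below is currently a pass-through stub that returns
    // `grad_output` unchanged and does not use the cached statistics. For reference only,
    // this is what a complete layer-norm backward would compute, with
    // x_hat = (x - mean) / sqrt(var + eps) and N the number of features in a row:
    //
    //   For each feature j (summing over the batch rows i):
    //     dbeta_j  = sum_i(dy[i, j])
    //     dgamma_j = sum_i(dy[i, j] * x_hat[i, j])
    //
    //   For each row i (summing over the N features k), with dx_hat = dy * gamma:
    //     dx[i, j] = (1 / (N * sqrt(var_i + eps)))
    //                * (N * dx_hat[i, j]
    //                   - sum_k(dx_hat[i, k])
    //                   - x_hat[i, j] * sum_k(dx_hat[i, k] * x_hat[i, k]))
    //
    // This is the standard derivation, given here as a sketch of what the cached
    // mean/var/norm values would be used for; it is not what the stub implements.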

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Pass-through: the gradient is returned unchanged and dgamma/dbeta are not updated.
        Ok(grad_output.clone())
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        // No-op: gradients are not applied to gamma and beta.
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "LayerNorm"
    }

    fn parameter_count(&self) -> usize {
        self.gamma.len() + self.beta.len()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for LayerNorm<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // No gradients are exposed because the backward pass does not compute them.
        vec![]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}
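
// A minimal smoke-test sketch for `LayerNorm`. It is generic over the RNG because no
// particular random-number-generator constructor from `scirs2_core::random` is assumed
// here; the caller supplies any `Rng` implementation (the constructor ignores it anyway).
// With a constant input every row has zero variance, so the normalized values collapse to
// `beta`, which is all zeros under the default initialization.
#[cfg(test)]
#[allow(dead_code)]
fn layer_norm_smoke_test<R: Rng>(rng: &mut R) -> Result<()> {
    let layer = LayerNorm::<f64>::new(4, 1e-5, rng)?;
    let input = Array::<f64, IxDyn>::from_elem(IxDyn(&[2, 4]), 1.5);
    let output = layer.forward(&input)?;

    // The output keeps the input shape and, for a constant input, is all zeros.
    assert_eq!(output.shape(), input.shape());
    assert!(output.iter().all(|&v| v.abs() < 1e-6));
    Ok(())
}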

/// Batch normalization layer (placeholder implementation).
///
/// Holds the learnable scale (`gamma`) and shift (`beta`) parameters; the current
/// forward pass does not yet normalize its input.
#[derive(Debug, Clone)]
pub struct BatchNorm<F: Float + Debug + Send + Sync> {
    /// Number of features (channels) to normalize
    num_features: usize,
    /// Learnable per-feature scale, initialized to ones
    gamma: Array<F, IxDyn>,
    /// Learnable per-feature shift, initialized to zeros
    beta: Array<F, IxDyn>,
    /// Small constant added to the variance for numerical stability
    #[allow(dead_code)]
    eps: F,
    /// Momentum for the running statistics
    #[allow(dead_code)]
    momentum: F,
    /// Whether the layer is in training mode
    training: bool,
}
291
292impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> BatchNorm<F> {
293 pub fn new<R: Rng>(
295 _num_features: usize,
296 momentum: f64,
297 eps: f64,
298 _rng: &mut R,
299 ) -> Result<Self> {
300 let gamma = Array::<F, IxDyn>::from_elem(IxDyn(&[_num_features]), F::one());
301 let beta = Array::<F, IxDyn>::from_elem(IxDyn(&[_num_features]), F::zero());
302
303 let momentum = F::from(momentum).ok_or_else(|| {
304 NeuralError::InvalidArchitecture("Failed to convert momentum to type F".to_string())
305 })?;
306
307 let eps = F::from(eps).ok_or_else(|| {
308 NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
309 })?;
310
311 Ok(Self {
312 num_features: _num_features,
313 gamma,
314 beta,
315 eps,
316 momentum,
317 training: true,
318 })
319 }
320
321 #[allow(dead_code)]
323 pub fn set_training(&mut self, training: bool) {
324 self.training = training;
325 }
326
327 #[allow(dead_code)]
329 pub fn num_features(&self) -> usize {
330 self.num_features
331 }
332}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Placeholder: the input is returned unchanged; no normalization is applied.
        Ok(input.clone())
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Pass-through: the gradient is returned unchanged.
        Ok(grad_output.clone())
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        // No-op: gradients are not applied to gamma and beta.
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "BatchNorm"
    }

    fn parameter_count(&self) -> usize {
        self.gamma.len() + self.beta.len()
    }
}
368
369impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for BatchNorm<F> {
370 fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
371 vec![self.gamma.clone(), self.beta.clone()]
372 }
373
374 fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
375 vec![]
376 }
377
378 fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
379 if params.len() != 2 {
380 return Err(NeuralError::InvalidArchitecture(format!(
381 "Expected 2 parameters, got {}",
382 params.len()
383 )));
384 }
385
386 self.gamma = params[0].clone();
387 self.beta = params[1].clone();
388
389 Ok(())
390 }
391}
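
// A minimal smoke-test sketch for `BatchNorm`, mirroring the layer-norm check above. It is
// generic over the RNG for the same reason: no particular constructor from
// `scirs2_core::random` is assumed. Because the current forward pass is an identity
// pass-through, the output is expected to equal the input.
#[cfg(test)]
#[allow(dead_code)]
fn batch_norm_smoke_test<R: Rng>(rng: &mut R) -> Result<()> {
    let layer = BatchNorm::<f64>::new(3, 0.1, 1e-5, rng)?;
    let input = Array::<f64, IxDyn>::from_elem(IxDyn(&[2, 3]), 0.5);
    let output = layer.forward(&input)?;

    // Identity behavior and parameter count (gamma + beta = 3 + 3).
    assert_eq!(output, input);
    assert_eq!(layer.parameter_count(), 6);
    Ok(())
}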