use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use ndarray::{Array, IxDyn, ScalarOperand};
use num_traits::Float;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

/// Activity regularization layer that applies L1 and/or L2 penalties to the
/// activations passing through it. The forward pass returns the input
/// unchanged; the backward pass adds the regularization gradients.
#[derive(Debug, Clone)]
pub struct ActivityRegularization<F: Float + Debug + Send + Sync> {
    /// L1 regularization factor, if enabled
    l1_factor: Option<F>,
    /// L2 regularization factor, if enabled
    l2_factor: Option<F>,
    /// Optional layer name
    name: Option<String>,
    /// Input cached during the forward pass for use in backward()
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Most recently computed activity regularization loss
    activity_loss: Arc<RwLock<F>>,
    _phantom: PhantomData<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ActivityRegularization<F> {
    /// Creates a new activity regularization layer.
    ///
    /// At least one of `l1_factor` and `l2_factor` must be provided, and any
    /// provided factor must be non-negative.
    pub fn new(l1_factor: Option<f64>, l2_factor: Option<f64>, name: Option<&str>) -> Result<Self> {
        if l1_factor.is_none() && l2_factor.is_none() {
            return Err(NeuralError::InvalidArchitecture(
                "At least one of L1 or L2 regularization factor must be provided".to_string(),
            ));
        }

        if let Some(l1) = l1_factor {
            if l1 < 0.0 {
                return Err(NeuralError::InvalidArchitecture(
                    "L1 regularization factor must be non-negative".to_string(),
                ));
            }
        }

        if let Some(l2) = l2_factor {
            if l2 < 0.0 {
                return Err(NeuralError::InvalidArchitecture(
                    "L2 regularization factor must be non-negative".to_string(),
                ));
            }
        }

        Ok(Self {
            l1_factor: l1_factor.map(|x| F::from(x).unwrap()),
            l2_factor: l2_factor.map(|x| F::from(x).unwrap()),
            name: name.map(String::from),
            input_cache: Arc::new(RwLock::new(None)),
            activity_loss: Arc::new(RwLock::new(F::zero())),
            _phantom: PhantomData,
        })
    }

    /// Returns the name of the layer, if set.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Returns the activity regularization loss computed during the most
    /// recent forward pass.
    pub fn get_activity_loss(&self) -> Result<F> {
        match self.activity_loss.read() {
            Ok(loss) => Ok(*loss),
            Err(_) => Err(NeuralError::InferenceError(
                "Failed to acquire read lock on activity loss".to_string(),
            )),
        }
    }

    /// Computes the activity regularization loss for the given activations:
    /// `l1_factor * sum(|x|) + l2_factor * sum(x^2)`, skipping any term whose
    /// factor is not set.
    fn calculate_activity_loss(&self, input: &Array<F, IxDyn>) -> F {
        let mut total_loss = F::zero();

        if let Some(l1_factor) = self.l1_factor {
            let l1_loss = input.mapv(|x| x.abs()).sum();
            total_loss = total_loss + l1_factor * l1_loss;
        }

        if let Some(l2_factor) = self.l2_factor {
            let l2_loss = input.mapv(|x| x * x).sum();
            total_loss = total_loss + l2_factor * l2_loss;
        }

        total_loss
    }

    /// Computes the gradient of the activity regularization loss with respect
    /// to the activations: `l1_factor * sign(x) + 2 * l2_factor * x`, using a
    /// subgradient of zero for the L1 term at `x == 0`.
    fn calculate_activity_gradients(&self, input: &Array<F, IxDyn>) -> Array<F, IxDyn> {
        let mut grad = Array::<F, IxDyn>::zeros(input.raw_dim());

        if let Some(l1_factor) = self.l1_factor {
            let l1_grad = input.mapv(|x| {
                if x > F::zero() {
                    l1_factor
                } else if x < F::zero() {
                    -l1_factor
                } else {
                    F::zero()
                }
            });
            grad = grad + l1_grad;
        }

        if let Some(l2_factor) = self.l2_factor {
            let two = F::from(2.0).unwrap();
            let l2_grad = input.mapv(|x| two * l2_factor * x);
            grad = grad + l2_grad;
        }

        grad
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F>
    for ActivityRegularization<F>
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input so backward() can recompute the regularization gradients
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        // Compute and store the activity regularization loss
        let loss = self.calculate_activity_loss(input);
        if let Ok(mut loss_cache) = self.activity_loss.write() {
            *loss_cache = loss;
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on activity loss cache".to_string(),
            ));
        }

        // The layer is an identity transform; the input passes through unchanged
        Ok(input.clone())
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        if input_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached input for backward pass. Call forward() first.".to_string(),
            ));
        }
        let cached_input = input_ref.as_ref().unwrap();

        if cached_input.shape() != grad_output.shape() {
            return Err(NeuralError::InferenceError(
                "Input and gradient output shapes must match".to_string(),
            ));
        }

        // Add the regularization gradients to the incoming gradient
        let activity_grad = self.calculate_activity_gradients(cached_input);
        Ok(grad_output + &activity_grad)
    }

    fn update(&mut self, _learning_rate: F) -> Result<()> {
        // No trainable parameters to update
        Ok(())
    }

    fn layer_type(&self) -> &str {
        "ActivityRegularization"
    }

    fn parameter_count(&self) -> usize {
        0
    }

    fn layer_description(&self) -> String {
        let l1_str = match self.l1_factor {
            Some(l1) => format!("{:?}", l1),
            None => "None".to_string(),
        };
        let l2_str = match self.l2_factor {
            Some(l2) => format!("{:?}", l2),
            None => "None".to_string(),
        };

        format!(
            "type:ActivityRegularization, l1:{}, l2:{}, name:{}",
            l1_str,
            l2_str,
            self.name.as_ref().map_or("None", |s| s)
        )
    }
}

/// Activity regularization layer that applies only an L1 penalty.
///
/// Thin wrapper around [`ActivityRegularization`] with the L2 factor disabled.
#[derive(Debug, Clone)]
pub struct L1ActivityRegularization<F: Float + Debug + Send + Sync> {
    inner: ActivityRegularization<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> L1ActivityRegularization<F> {
    /// Creates a new L1 activity regularization layer with the given factor.
    pub fn new(factor: f64, name: Option<&str>) -> Result<Self> {
        Ok(Self {
            inner: ActivityRegularization::new(Some(factor), None, name)?,
        })
    }

    /// Returns the name of the layer, if set.
    pub fn name(&self) -> Option<&str> {
        self.inner.name()
    }

    /// Returns the activity regularization loss from the most recent forward pass.
    pub fn get_activity_loss(&self) -> Result<F> {
        self.inner.get_activity_loss()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F>
    for L1ActivityRegularization<F>
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        self.inner.forward(input)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.inner.backward(input, grad_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.inner.update(learning_rate)
    }

    fn layer_type(&self) -> &str {
        "L1ActivityRegularization"
    }

    fn parameter_count(&self) -> usize {
        self.inner.parameter_count()
    }

    fn layer_description(&self) -> String {
        self.inner
            .layer_description()
            .replace("ActivityRegularization", "L1ActivityRegularization")
    }
}

/// Activity regularization layer that applies only an L2 penalty.
///
/// Thin wrapper around [`ActivityRegularization`] with the L1 factor disabled.
#[derive(Debug, Clone)]
pub struct L2ActivityRegularization<F: Float + Debug + Send + Sync> {
    inner: ActivityRegularization<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> L2ActivityRegularization<F> {
    /// Creates a new L2 activity regularization layer with the given factor.
    pub fn new(factor: f64, name: Option<&str>) -> Result<Self> {
        Ok(Self {
            inner: ActivityRegularization::new(None, Some(factor), name)?,
        })
    }

    /// Returns the name of the layer, if set.
    pub fn name(&self) -> Option<&str> {
        self.inner.name()
    }

    /// Returns the activity regularization loss from the most recent forward pass.
    pub fn get_activity_loss(&self) -> Result<F> {
        self.inner.get_activity_loss()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F>
    for L2ActivityRegularization<F>
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        self.inner.forward(input)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.inner.backward(input, grad_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.inner.update(learning_rate)
    }

    fn layer_type(&self) -> &str {
        "L2ActivityRegularization"
    }

    fn parameter_count(&self) -> usize {
        self.inner.parameter_count()
    }

    fn layer_description(&self) -> String {
        self.inner
            .layer_description()
            .replace("ActivityRegularization", "L2ActivityRegularization")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    #[test]
    fn test_activity_regularization_creation() {
        let l1_reg = ActivityRegularization::<f64>::new(Some(0.01), None, Some("l1")).unwrap();
        assert!(l1_reg.l1_factor.is_some());
        assert!(l1_reg.l2_factor.is_none());

        let l2_reg = ActivityRegularization::<f64>::new(None, Some(0.02), Some("l2")).unwrap();
        assert!(l2_reg.l1_factor.is_none());
        assert!(l2_reg.l2_factor.is_some());

        let both_reg =
            ActivityRegularization::<f64>::new(Some(0.01), Some(0.02), Some("both")).unwrap();
        assert!(both_reg.l1_factor.is_some());
        assert!(both_reg.l2_factor.is_some());

        assert!(ActivityRegularization::<f64>::new(None, None, Some("none")).is_err());
    }

    #[test]
    fn test_activity_regularization_forward() {
        let reg =
            ActivityRegularization::<f64>::new(Some(0.01), Some(0.02), Some("test")).unwrap();

        let input = Array2::<f64>::from_elem((2, 3), 1.0);
        let input_dyn = input.clone().into_dyn();
        let output = reg.forward(&input_dyn).unwrap();

        assert_eq!(input.into_dyn().shape(), output.shape());
        for (a, b) in input_dyn.iter().zip(output.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_activity_regularization_backward() {
        let reg = ActivityRegularization::<f64>::new(Some(0.1), Some(0.1), Some("test")).unwrap();

        let input = array![[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]].into_dyn();
        let grad_output = array![[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]].into_dyn();

        let _output = reg.forward(&input).unwrap();
        let grad_input = reg.backward(&input, &grad_output).unwrap();

        assert_eq!(grad_input.shape(), input.shape());
    }
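
    // A hedged sketch (not in the original suite): checks the exact L1 subgradient
    // added by backward(), i.e. grad_output + l1_factor * sign(x), assuming the
    // subgradient at x == 0 is taken to be zero as implemented above.
    #[test]
    fn test_l1_subgradient_values() {
        let reg = ActivityRegularization::<f64>::new(Some(0.1), None, Some("l1_grad")).unwrap();

        let input = array![[1.0, -2.0, 0.0]].into_dyn();
        let grad_output = array![[1.0, 1.0, 1.0]].into_dyn();

        let _output = reg.forward(&input).unwrap();
        let grad_input = reg.backward(&input, &grad_output).unwrap();

        // Expected: 1.0 + 0.1 * sign(x) -> [1.1, 0.9, 1.0]
        let expected = [1.1, 0.9, 1.0];
        for (g, e) in grad_input.iter().zip(expected.iter()) {
            assert!((g - e).abs() < 1e-10);
        }
    }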

    #[test]
    fn test_l1_activity_regularization() {
        let reg = L1ActivityRegularization::<f64>::new(0.01, Some("l1_test")).unwrap();

        let input = Array2::<f64>::from_elem((2, 3), 2.0);
        let input_dyn = input.clone().into_dyn();
        let output = reg.forward(&input_dyn).unwrap();

        assert_eq!(input.into_dyn().shape(), output.shape());

        let loss = reg.get_activity_loss().unwrap();
        assert!(loss > 0.0);
    }

    #[test]
    fn test_l2_activity_regularization() {
        let reg = L2ActivityRegularization::<f64>::new(0.01, Some("l2_test")).unwrap();

        let input = Array2::<f64>::from_elem((2, 3), 2.0);
        let input_dyn = input.clone().into_dyn();
        let output = reg.forward(&input_dyn).unwrap();

        assert_eq!(input.into_dyn().shape(), output.shape());

        let loss = reg.get_activity_loss().unwrap();
        assert!(loss > 0.0);
    }
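
    // A hedged sketch (not in the original suite): exact loss values for the
    // convenience wrappers, assuming loss = factor * sum(|x|) for L1 and
    // loss = factor * sum(x^2) for L2, as computed by calculate_activity_loss.
    #[test]
    fn test_wrapper_loss_values() {
        // 2x3 input filled with 2.0: sum(|x|) = 12.0 and sum(x^2) = 24.0
        let input = Array2::<f64>::from_elem((2, 3), 2.0).into_dyn();

        let l1 = L1ActivityRegularization::<f64>::new(0.01, None).unwrap();
        let _ = l1.forward(&input).unwrap();
        assert!((l1.get_activity_loss().unwrap() - 0.12).abs() < 1e-10);

        let l2 = L2ActivityRegularization::<f64>::new(0.01, None).unwrap();
        let _ = l2.forward(&input).unwrap();
        assert!((l2.get_activity_loss().unwrap() - 0.24).abs() < 1e-10);
    }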

    #[test]
    fn test_activity_loss_calculation() {
        let reg = ActivityRegularization::<f64>::new(Some(0.1), Some(0.1), Some("test")).unwrap();

        let input = array![[1.0, -1.0], [2.0, 0.0]].into_dyn();
        let _output = reg.forward(&input).unwrap();

        let loss = reg.get_activity_loss().unwrap();

        // L1 part: 0.1 * (1 + 1 + 2 + 0) = 0.4; L2 part: 0.1 * (1 + 1 + 4 + 0) = 0.6
        assert!((loss - 1.0).abs() < 1e-10);
    }
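
    // A hedged sketch (not in the original suite): calling backward() before any
    // forward() should surface the missing input cache as an error rather than panic.
    #[test]
    fn test_backward_without_forward_errors() {
        let reg = ActivityRegularization::<f64>::new(Some(0.1), None, None).unwrap();

        let input = array![[1.0, 2.0]].into_dyn();
        let grad_output = array![[1.0, 1.0]].into_dyn();

        assert!(reg.backward(&input, &grad_output).is_err());
    }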
}