use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

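/// A penalty on model parameters that can be added to the training loss.
///
/// Implementors return both the scalar penalty and the gradient of that
/// penalty with respect to each parameter matrix, keyed by parameter name.
///
/// A minimal usage sketch; the `my_crate::regularization` path is
/// hypothetical, adjust it to wherever this module is exported:
///
/// ```ignore
/// use std::collections::HashMap;
/// use my_crate::regularization::{L2Regularization, Regularizer};
/// use scirs2_core::ndarray::array;
///
/// let reg = L2Regularization::new(0.01);
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, 2.0]]);
/// let penalty = reg.compute_penalty(&params)?; // added to the loss
/// let grads = reg.compute_gradient(&params)?;  // added to the loss gradients
/// ```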
pub trait Regularizer {
    /// Computes the scalar penalty term for the given parameters.
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64>;

    /// Computes the gradient of the penalty with respect to each parameter,
    /// keyed by the same parameter names.
    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
}

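/// L1 (lasso) regularization: `penalty = lambda * sum(|w|)`.
///
/// Encourages sparse parameters by penalizing absolute values.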
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength.
    pub lambda: f64,
}

impl L1Regularization {
    /// Creates an L1 regularizer with the given strength.
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L1Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L1Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value.abs();
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Subgradient of |w|. Note that `f64::signum` returns 1.0 for +0.0
            // (and -1.0 for -0.0), so exact zeros are not mapped to 0 here.
            let grad = param.mapv(|w| self.lambda * w.signum());
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

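/// L2 (ridge / weight decay) regularization: `penalty = 0.5 * lambda * sum(w^2)`.
///
/// The 0.5 factor makes the gradient simply `lambda * w`.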
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength.
    pub lambda: f64,
}

impl L2Regularization {
    /// Creates an L2 regularizer with the given strength.
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L2Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L2Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value * value;
            }
        }

        Ok(0.5 * self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let grad = param.mapv(|w| self.lambda * w);
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

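/// Elastic net regularization: a convex combination of L1 and L2 penalties,
/// `penalty = lambda * (l1_ratio * sum(|w|) + (1 - l1_ratio) * 0.5 * sum(w^2))`.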
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Overall regularization strength.
    pub lambda: f64,
    /// Mixing weight in [0, 1]: 1.0 is pure L1, 0.0 is pure L2.
    pub l1_ratio: f64,
}

impl ElasticNetRegularization {
    /// Creates an elastic net regularizer; fails if `l1_ratio` is outside [0, 1].
    pub fn new(lambda: f64, l1_ratio: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(TrainError::InvalidParameter(
                "l1_ratio must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self { lambda, l1_ratio })
    }
}

impl Default for ElasticNetRegularization {
    fn default() -> Self {
        Self {
            lambda: 0.01,
            l1_ratio: 0.5,
        }
    }
}

impl Regularizer for ElasticNetRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut l1_penalty = 0.0;
        let mut l2_penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                l1_penalty += value.abs();
                l2_penalty += value * value;
            }
        }

        let penalty =
            self.lambda * (self.l1_ratio * l1_penalty + (1.0 - self.l1_ratio) * 0.5 * l2_penalty);

        Ok(penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let grad = param
                .mapv(|w| self.lambda * (self.l1_ratio * w.signum() + (1.0 - self.l1_ratio) * w));
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

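/// Applies several regularizers at once; penalties and gradients are summed
/// across all of them.
///
/// A minimal sketch, assuming `params: HashMap<String, Array<f64, Ix2>>` is
/// already in scope:
///
/// ```ignore
/// let mut composite = CompositeRegularization::new();
/// composite.add(L1Regularization::new(0.01));
/// composite.add(L2Regularization::new(0.001));
/// let penalty = composite.compute_penalty(&params)?; // L1 + L2 penalties summed
/// ```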
#[derive(Clone)]
pub struct CompositeRegularization {
    regularizers: Vec<Box<dyn RegularizerClone>>,
}

impl std::fmt::Debug for CompositeRegularization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeRegularization")
            .field("num_regularizers", &self.regularizers.len())
            .finish()
    }
}

/// Object-safe helper trait so boxed regularizers can be cloned.
trait RegularizerClone: Regularizer {
    fn clone_box(&self) -> Box<dyn RegularizerClone>;
}

impl<T: Regularizer + Clone + 'static> RegularizerClone for T {
    fn clone_box(&self) -> Box<dyn RegularizerClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn RegularizerClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

impl Regularizer for Box<dyn RegularizerClone> {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        (**self).compute_penalty(parameters)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        (**self).compute_gradient(parameters)
    }
}

impl CompositeRegularization {
    /// Creates an empty composite with no regularizers.
    pub fn new() -> Self {
        Self {
            regularizers: Vec::new(),
        }
    }

    /// Adds a regularizer to the composite.
    pub fn add<R: Regularizer + Clone + 'static>(&mut self, regularizer: R) {
        self.regularizers.push(Box::new(regularizer));
    }

    /// Returns the number of regularizers in the composite.
    pub fn len(&self) -> usize {
        self.regularizers.len()
    }

    /// Returns true if no regularizers have been added.
    pub fn is_empty(&self) -> bool {
        self.regularizers.is_empty()
    }
}

impl Default for CompositeRegularization {
    fn default() -> Self {
        Self::new()
    }
}

impl Regularizer for CompositeRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut total_penalty = 0.0;

        for regularizer in &self.regularizers {
            total_penalty += regularizer.compute_penalty(parameters)?;
        }

        Ok(total_penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut total_gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();

        // Start from zero gradients so every parameter has an entry even
        // when no regularizer contributes to it.
        for (name, param) in parameters {
            total_gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
        }

        // Accumulate each regularizer's contribution.
        for regularizer in &self.regularizers {
            let grads = regularizer.compute_gradient(parameters)?;

            for (name, grad) in grads {
                if let Some(total_grad) = total_gradients.get_mut(&name) {
                    *total_grad = &*total_grad + &grad;
                }
            }
        }

        Ok(total_gradients)
    }
}

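/// Penalizes deviation of each parameter matrix's largest singular value
/// (estimated by power iteration) from `target_norm`:
/// `penalty = lambda * sum((sigma_max - target_norm)^2)`.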
#[derive(Debug, Clone)]
pub struct SpectralNormalization {
    /// Desired largest singular value.
    pub target_norm: f64,
    /// Regularization strength.
    pub lambda: f64,
    /// Number of power-iteration steps used to estimate the spectral norm.
    pub power_iterations: usize,
}

impl SpectralNormalization {
    pub fn new(lambda: f64, target_norm: f64, power_iterations: usize) -> Self {
        Self {
            lambda,
            target_norm,
            power_iterations,
        }
    }

    /// Estimates the largest singular value of `matrix` via power iteration.
    fn estimate_spectral_norm(&self, matrix: &Array<f64, Ix2>) -> f64 {
        if matrix.is_empty() {
            return 0.0;
        }

        let (nrows, ncols) = matrix.dim();
        if nrows == 0 || ncols == 0 {
            return 0.0;
        }

        // Start from a uniform unit vector.
        let mut v = Array::from_elem((ncols,), 1.0 / (ncols as f64).sqrt());

        for _ in 0..self.power_iterations {
            let u = matrix.dot(&v);
            let u_norm = u.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if u_norm < 1e-10 {
                break;
            }
            let u = u / u_norm;

            v = matrix.t().dot(&u);
            let v_norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-10 {
                break;
            }
            v /= v_norm;
        }

        // sigma_max ~= ||A v|| for the converged right singular vector v.
        let final_u = matrix.dot(&v);
        final_u.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }
}

impl Default for SpectralNormalization {
    fn default() -> Self {
        Self {
            target_norm: 1.0,
            lambda: 0.01,
            power_iterations: 1,
        }
    }
}

impl Regularizer for SpectralNormalization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            let spectral_norm = self.estimate_spectral_norm(param);
            penalty += (spectral_norm - self.target_norm).powi(2);
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let spectral_norm = self.estimate_spectral_norm(param);
            if spectral_norm < 1e-10 {
                gradients.insert(name.clone(), Array::zeros(param.dim()));
                continue;
            }

            let frobenius_norm = param.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if frobenius_norm < 1e-10 {
                gradients.insert(name.clone(), Array::zeros(param.dim()));
                continue;
            }

            // Approximate gradient: the exact gradient of (sigma_max - t)^2 is
            // 2 * (sigma_max - t) * u v^T; here the whole matrix, scaled by its
            // Frobenius norm, is used as a cheap surrogate direction.
            let scale = 2.0 * self.lambda * (spectral_norm - self.target_norm) / frobenius_norm;
            let grad = param.mapv(|w| scale * w);
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

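/// Penalizes rows (axis 0) or columns (otherwise) whose Euclidean norm
/// exceeds `max_norm`: `penalty = lambda * sum(max(0, norm - max_norm)^2)`.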
#[derive(Debug, Clone)]
pub struct MaxNormRegularization {
    /// Maximum allowed row/column norm before a penalty applies.
    pub max_norm: f64,
    /// Regularization strength.
    pub lambda: f64,
    /// Axis along which norms are taken: 0 for rows, otherwise columns.
    pub axis: usize,
}

impl MaxNormRegularization {
    pub fn new(max_norm: f64, lambda: f64, axis: usize) -> Self {
        Self {
            max_norm,
            lambda,
            axis,
        }
    }
}

impl Default for MaxNormRegularization {
    fn default() -> Self {
        Self {
            max_norm: 2.0,
            lambda: 0.01,
            axis: 0,
        }
    }
}

impl Regularizer for MaxNormRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            let axis_len = if self.axis == 0 {
                param.nrows()
            } else {
                param.ncols()
            };

            for i in 0..axis_len {
                let row_or_col = if self.axis == 0 {
                    param.row(i)
                } else {
                    param.column(i)
                };

                let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if norm > self.max_norm {
                    penalty += (norm - self.max_norm).powi(2);
                }
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let mut grad = Array::zeros(param.dim());

            let axis_len = if self.axis == 0 {
                param.nrows()
            } else {
                param.ncols()
            };

            for i in 0..axis_len {
                let row_or_col = if self.axis == 0 {
                    param.row(i)
                } else {
                    param.column(i)
                };

                let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if norm > self.max_norm {
                    // d/dw (norm - max)^2 = 2 * (norm - max) * w / norm;
                    // the small epsilon guards against division by zero.
                    let scale = 2.0 * self.lambda * (norm - self.max_norm) / (norm + 1e-10);

                    for (j, &val) in row_or_col.iter().enumerate() {
                        if self.axis == 0 {
                            grad[[i, j]] = scale * val;
                        } else {
                            grad[[j, i]] = scale * val;
                        }
                    }
                }
            }

            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

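/// Encourages orthonormal columns by penalizing `||W^T W - I||_F^2`.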
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization {
    /// Regularization strength.
    pub lambda: f64,
}

impl OrthogonalRegularization {
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for OrthogonalRegularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for OrthogonalRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            // Gram matrix W^T W; compare against the identity.
            let wt_w = param.t().dot(param);

            let (n, _) = wt_w.dim();
            for i in 0..n {
                for j in 0..n {
                    let target = if i == j { 1.0 } else { 0.0 };
                    let diff = wt_w[[i, j]] - target;
                    penalty += diff * diff;
                }
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let wt_w = param.t().dot(param);

            let (n, _) = wt_w.dim();
            let mut identity = Array::zeros((n, n));
            for i in 0..n {
                identity[[i, i]] = 1.0;
            }

            // d/dW lambda * ||W^T W - I||_F^2 = 4 * lambda * W (W^T W - I),
            // matching the penalty computed above.
            let diff = &wt_w - &identity;
            let grad = param.dot(&diff) * (4.0 * self.lambda);

            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

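/// Group lasso: parameters are flattened in iteration order (row-major for
/// standard layouts) and split into consecutive groups of `group_size`;
/// `penalty = lambda * sum(||group||_2)`, which drives entire groups to zero
/// together.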
#[derive(Debug, Clone)]
pub struct GroupLassoRegularization {
    /// Regularization strength.
    pub lambda: f64,
    /// Number of consecutive elements per group.
    pub group_size: usize,
}

impl GroupLassoRegularization {
    pub fn new(lambda: f64, group_size: usize) -> Self {
        Self { lambda, group_size }
    }
}

impl Default for GroupLassoRegularization {
    fn default() -> Self {
        Self {
            lambda: 0.01,
            group_size: 10,
        }
    }
}

impl Regularizer for GroupLassoRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            let flat: Vec<f64> = param.iter().copied().collect();

            for group in flat.chunks(self.group_size) {
                let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
                penalty += group_norm;
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let mut grad_flat = Vec::new();
            let flat: Vec<f64> = param.iter().copied().collect();

            for group in flat.chunks(self.group_size) {
                let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if group_norm > 1e-10 {
                    // d/dw ||group||_2 = w / ||group||_2.
                    let scale = self.lambda / group_norm;
                    grad_flat.extend(group.iter().map(|&x| scale * x));
                } else {
                    // The norm is not differentiable at zero; use 0 as the subgradient.
                    grad_flat.extend(vec![0.0; group.len()]);
                }
            }

            let grad = Array::from_shape_vec(param.dim(), grad_flat).map_err(|e| {
                TrainError::ModelError(format!("Failed to reshape gradient: {}", e))
            })?;
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_l1_regularization() {
        let regularizer = L1Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, -2.0], [3.0, -4.0]]);

        // |1| + |-2| + |3| + |-4| = 10; penalty = 0.1 * 10 = 1.0.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 1.0).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient is lambda * sign(w).
        assert_eq!(grad_w[[0, 0]], 0.1);
        assert_eq!(grad_w[[0, 1]], -0.1);
        assert_eq!(grad_w[[1, 0]], 0.1);
        assert_eq!(grad_w[[1, 1]], -0.1);
    }

    #[test]
    fn test_l2_regularization() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        // 0.5 * 0.1 * (1 + 4 + 9 + 16) = 1.5.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 1.5).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient is lambda * w.
        assert!((grad_w[[0, 0]] - 0.1).abs() < 1e-10);
        assert!((grad_w[[0, 1]] - 0.2).abs() < 1e-10);
        assert!((grad_w[[1, 0]] - 0.3).abs() < 1e-10);
        assert!((grad_w[[1, 1]] - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_elastic_net_regularization() {
        let regularizer = ElasticNetRegularization::new(0.1, 0.5).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);
    }

    #[test]
    fn test_elastic_net_invalid_ratio() {
        let result = ElasticNetRegularization::new(0.1, 1.5);
        assert!(result.is_err());

        let result = ElasticNetRegularization::new(0.1, -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_regularization() {
        let mut composite = CompositeRegularization::new();
        composite.add(L1Regularization::new(0.1));
        composite.add(L2Regularization::new(0.1));

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        // L1: 0.1 * (1 + 2) = 0.3
        // L2: 0.5 * 0.1 * (1 + 4) = 0.25
        // Total: 0.55
        let penalty = composite.compute_penalty(&params).unwrap();
        assert!((penalty - 0.55).abs() < 1e-6);

        let gradients = composite.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);

        // L1 grad: 0.1 * sign(1) = 0.1; L2 grad: 0.1 * 1 = 0.1; sum = 0.2.
        assert!((grad_w[[0, 0]] - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeRegularization::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn test_multiple_parameters() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w1".to_string(), array![[1.0, 2.0]]);
        params.insert("w2".to_string(), array![[3.0]]);

        // 0.5 * 0.1 * (1 + 4 + 9) = 0.7 across both parameters.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.7).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert_eq!(gradients.len(), 2);
        assert!(gradients.contains_key("w1"));
        assert!(gradients.contains_key("w2"));
    }

    #[test]
    fn test_zero_lambda() {
        let regularizer = L1Regularization::new(0.0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[100.0, 200.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w[[0, 0]], 0.0);
        assert_eq!(grad_w[[0, 1]], 0.0);
    }

    #[test]
    fn test_spectral_normalization() {
        let regularizer = SpectralNormalization::new(0.1, 1.0, 5);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[2.0, 0.0], [0.0, 1.0]]);

        // Largest singular value is 2.0; penalty = 0.1 * (2 - 1)^2 = 0.1.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.1).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_max_norm_regularization() {
        let regularizer = MaxNormRegularization::new(1.0, 0.1, 0);

        let mut params = HashMap::new();
        params.insert(
            "w".to_string(),
            array![[3.0, 4.0], [0.1, 0.1]],
        );

        // Row 0 has norm 5 > 1: penalty = 0.1 * (5 - 1)^2 = 1.6;
        // row 1 has norm ~0.14 and contributes nothing.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 1.6).abs() < 0.1);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        // The violating row gets a nonzero gradient...
        assert!(grad_w[[0, 0]].abs() > 0.0);
        // ...while the row within the limit does not.
        assert!(grad_w[[1, 0]].abs() < 1e-10);
    }

    #[test]
    fn test_orthogonal_regularization() {
        let regularizer = OrthogonalRegularization::new(0.1);

        // The identity matrix is orthogonal, so the penalty is zero.
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 0.0], [0.0, 1.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty.abs() < 1e-10);

        // A matrix with correlated columns is penalized.
        params.insert("w".to_string(), array![[1.0, 1.0], [1.0, 1.0]]);
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_group_lasso_regularization() {
        let regularizer = GroupLassoRegularization::new(0.1, 2);

        let mut params = HashMap::new();
        params.insert(
            "w".to_string(),
            array![[1.0, 2.0], [3.0, 4.0]],
        );

        // Groups of 2 over the flattened matrix: ||(1, 2)|| = sqrt(5) ~ 2.236
        // and ||(3, 4)|| = 5; penalty = 0.1 * (2.236 + 5) ~ 0.7236.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.7236).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.dim(), (2, 2));
    }

    #[test]
    fn test_spectral_normalization_zero_matrix() {
        let regularizer = SpectralNormalization::new(0.1, 1.0, 5);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[0.0, 0.0], [0.0, 0.0]]);

        // Spectral norm is 0, so penalty = 0.1 * (0 - 1)^2 = 0.1.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.1).abs() < 0.01);

        // The gradient falls back to zero for a degenerate matrix.
        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_max_norm_no_violation() {
        let regularizer = MaxNormRegularization::new(10.0, 0.1, 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        // Every row norm is below max_norm, so there is no penalty...
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty.abs() < 1e-10);

        // ...and the gradient is zero everywhere.
        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_orthogonal_non_square() {
        let regularizer = OrthogonalRegularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);

        // W^T W is 3x3 with rank at most 2, so it cannot equal the identity
        // and the penalty is strictly positive.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_group_lasso_single_group() {
        let regularizer = GroupLassoRegularization::new(0.1, 4);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[3.0, 4.0]]);

        // Only two elements, so they form a single group:
        // ||(3, 4)|| = 5; penalty = 0.1 * 5 = 0.5.
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.5).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.dim(), (1, 2));
    }
}