
tensorlogic_train/regularization.rs

//! Regularization techniques for training.
//!
//! This module provides various regularization strategies to prevent overfitting:
//! - L1 regularization (Lasso): Encourages sparsity
//! - L2 regularization (Ridge): Prevents large weights
//! - Composite regularization: Combines multiple regularizers
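//!
//! # Examples
//!
//! A minimal sketch of adding a regularization penalty to a training loss
//! (the `use` path assumes these types are re-exported from the crate root;
//! adjust it to the actual module layout):
//!
//! ```ignore
//! use std::collections::HashMap;
//! use scirs2_core::ndarray::array;
//! use tensorlogic_train::{L2Regularization, Regularizer};
//!
//! let reg = L2Regularization::new(0.01);
//!
//! let mut params = HashMap::new();
//! params.insert("w".to_string(), array![[3.0_f64, 4.0]]);
//!
//! // Total loss = data loss + regularization penalty.
//! let data_loss = 1.0;
//! let total_loss = data_loss + reg.compute_penalty(&params).unwrap();
//!
//! // Penalty = 0.5 * 0.01 * (9 + 16) = 0.125
//! assert!((total_loss - 1.125).abs() < 1e-12);
//! ```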

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Trait for regularization strategies.
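///
/// # Examples
///
/// A minimal sketch of code that is generic over any regularizer (the `use`
/// path and the helper function are illustrative assumptions, not part of
/// this crate's API):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::{Array, Ix2};
/// use tensorlogic_train::{Regularizer, TrainResult};
///
/// // Hypothetical helper: adds the regularization penalty to a base loss.
/// fn regularized_loss<R: Regularizer>(
///     base_loss: f64,
///     reg: &R,
///     params: &HashMap<String, Array<f64, Ix2>>,
/// ) -> TrainResult<f64> {
///     Ok(base_loss + reg.compute_penalty(params)?)
/// }
/// ```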
pub trait Regularizer {
    /// Compute the regularization penalty for given parameters.
    ///
    /// # Arguments
    /// * `parameters` - Model parameters to regularize
    ///
    /// # Returns
    /// The regularization penalty value
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64>;

    /// Compute the gradient of the regularization penalty.
    ///
    /// # Arguments
    /// * `parameters` - Model parameters
    ///
    /// # Returns
    /// Gradients of the regularization penalty for each parameter
    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
}

/// L1 regularization (Lasso).
///
/// Adds a penalty proportional to the absolute values of the weights: λ * Σ|w|
/// Encourages sparsity by driving some weights to exactly zero.
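///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{L1Regularization, Regularizer};
///
/// let reg = L1Regularization::new(0.1);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0_f64, -2.0, 3.0]]);
///
/// // Penalty: 0.1 * (|1| + |-2| + |3|) = 0.6
/// assert!((reg.compute_penalty(&params).unwrap() - 0.6).abs() < 1e-12);
///
/// // Gradient: 0.1 * sign(w) = [0.1, -0.1, 0.1]
/// let grads = reg.compute_gradient(&params).unwrap();
/// assert_eq!(grads["w"][[0, 1]], -0.1);
/// ```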
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength (lambda).
    pub lambda: f64,
}

impl L1Regularization {
    /// Create a new L1 regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L1Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L1Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value.abs();
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient of L1: λ * sign(w)
            let grad = param.mapv(|w| self.lambda * w.signum());
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// L2 regularization (Ridge / Weight Decay).
///
/// Adds a penalty proportional to the squared weights: (λ / 2) * Σw²
/// Prevents weights from becoming too large; the 1/2 factor makes the gradient simply λ * w.
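///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{L2Regularization, Regularizer};
///
/// let reg = L2Regularization::new(0.1);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[3.0_f64, 4.0]]);
///
/// // Penalty: 0.5 * 0.1 * (9 + 16) = 1.25
/// assert!((reg.compute_penalty(&params).unwrap() - 1.25).abs() < 1e-12);
///
/// // Gradient: λ * w = [0.3, 0.4]
/// let grads = reg.compute_gradient(&params).unwrap();
/// assert!((grads["w"][[0, 0]] - 0.3).abs() < 1e-12);
/// ```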
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength (lambda).
    pub lambda: f64,
}

impl L2Regularization {
    /// Create a new L2 regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L2Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L2Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value * value;
            }
        }

        Ok(0.5 * self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient of L2: λ * w
            let grad = param.mapv(|w| self.lambda * w);
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Elastic Net regularization (combination of L1 and L2).
///
/// Combines the L1 and L2 penalties: λ * (l1_ratio * Σ|w| + (1 - l1_ratio) * 0.5 * Σw²)
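///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{ElasticNetRegularization, Regularizer};
///
/// let reg = ElasticNetRegularization::new(0.1, 0.5).unwrap();
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[2.0_f64]]);
///
/// // Penalty: 0.1 * (0.5 * |2| + 0.5 * 0.5 * 2²) = 0.1 * (1.0 + 1.0) = 0.2
/// assert!((reg.compute_penalty(&params).unwrap() - 0.2).abs() < 1e-12);
///
/// // An l1_ratio outside [0, 1] is rejected.
/// assert!(ElasticNetRegularization::new(0.1, 1.5).is_err());
/// ```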
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Overall regularization strength.
    pub lambda: f64,
    /// Balance between L1 and L2 (0.0 = pure L2, 1.0 = pure L1).
    pub l1_ratio: f64,
}

impl ElasticNetRegularization {
    /// Create a new Elastic Net regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Overall regularization strength
    /// * `l1_ratio` - Balance between L1 and L2 (should be in [0, 1])
    pub fn new(lambda: f64, l1_ratio: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(TrainError::InvalidParameter(
                "l1_ratio must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self { lambda, l1_ratio })
    }
}

impl Default for ElasticNetRegularization {
    fn default() -> Self {
        Self {
            lambda: 0.01,
            l1_ratio: 0.5,
        }
    }
}

impl Regularizer for ElasticNetRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut l1_penalty = 0.0;
        let mut l2_penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                l1_penalty += value.abs();
                l2_penalty += value * value;
            }
        }

        let penalty =
            self.lambda * (self.l1_ratio * l1_penalty + (1.0 - self.l1_ratio) * 0.5 * l2_penalty);

        Ok(penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient: λ * (l1_ratio * sign(w) + (1 - l1_ratio) * w)
            let grad = param
                .mapv(|w| self.lambda * (self.l1_ratio * w.signum() + (1.0 - self.l1_ratio) * w));
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Composite regularization that combines multiple regularizers.
///
/// Useful for applying different regularization strategies simultaneously.
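///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{
///     CompositeRegularization, L1Regularization, L2Regularization, Regularizer,
/// };
///
/// let mut composite = CompositeRegularization::new();
/// composite.add(L1Regularization::new(0.1));
/// composite.add(L2Regularization::new(0.1));
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[2.0_f64]]);
///
/// // L1: 0.1 * 2 = 0.2, L2: 0.5 * 0.1 * 4 = 0.2, total = 0.4
/// assert!((composite.compute_penalty(&params).unwrap() - 0.4).abs() < 1e-12);
/// ```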
#[derive(Clone)]
pub struct CompositeRegularization {
    regularizers: Vec<Box<dyn RegularizerClone>>,
}

impl std::fmt::Debug for CompositeRegularization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeRegularization")
            .field("num_regularizers", &self.regularizers.len())
            .finish()
    }
}

/// Helper trait for cloning boxed regularizers.
trait RegularizerClone: Regularizer {
    fn clone_box(&self) -> Box<dyn RegularizerClone>;
}

impl<T: Regularizer + Clone + 'static> RegularizerClone for T {
    fn clone_box(&self) -> Box<dyn RegularizerClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn RegularizerClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

impl Regularizer for Box<dyn RegularizerClone> {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        (**self).compute_penalty(parameters)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        (**self).compute_gradient(parameters)
    }
}

impl CompositeRegularization {
    /// Create a new composite regularizer.
    pub fn new() -> Self {
        Self {
            regularizers: Vec::new(),
        }
    }

    /// Add a regularizer to the composite.
    ///
    /// # Arguments
    /// * `regularizer` - Regularizer to add
    pub fn add<R: Regularizer + Clone + 'static>(&mut self, regularizer: R) {
        self.regularizers.push(Box::new(regularizer));
    }

    /// Get the number of regularizers in the composite.
    pub fn len(&self) -> usize {
        self.regularizers.len()
    }

    /// Check if the composite is empty.
    pub fn is_empty(&self) -> bool {
        self.regularizers.is_empty()
    }
}

impl Default for CompositeRegularization {
    fn default() -> Self {
        Self::new()
    }
}

impl Regularizer for CompositeRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut total_penalty = 0.0;

        for regularizer in &self.regularizers {
            total_penalty += regularizer.compute_penalty(parameters)?;
        }

        Ok(total_penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut total_gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();

        // Initialize with zeros
        for (name, param) in parameters {
            total_gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
        }

        // Accumulate gradients from all regularizers
        for regularizer in &self.regularizers {
            let grads = regularizer.compute_gradient(parameters)?;

            for (name, grad) in grads {
                if let Some(total_grad) = total_gradients.get_mut(&name) {
                    *total_grad = &*total_grad + &grad;
                }
            }
        }

        Ok(total_gradients)
    }
}

/// Spectral Normalization regularizer.
///
/// Penalizes deviation of each weight matrix's spectral norm (its largest singular value) from a target value.
/// Useful for stabilizing GAN training and improving generalization.
///
/// # References
/// - Miyato et al. (2018): "Spectral Normalization for Generative Adversarial Networks"
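///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{Regularizer, SpectralNormalization};
///
/// // Target spectral norm 1.0, estimated with 5 power iterations.
/// let reg = SpectralNormalization::new(0.1, 1.0, 5);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[2.0_f64, 0.0], [0.0, 1.0]]);
///
/// // The spectral norm of diag(2, 1) is 2, so the penalty is roughly
/// // 0.1 * (2 - 1)² = 0.1 (the power-iteration estimate is approximate).
/// let penalty = reg.compute_penalty(&params).unwrap();
/// assert!((penalty - 0.1).abs() < 0.01);
/// ```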
#[derive(Debug, Clone)]
pub struct SpectralNormalization {
    /// Target spectral norm (usually 1.0)
    pub target_norm: f64,
    /// Strength of the regularization
    pub lambda: f64,
    /// Number of power iterations for spectral norm estimation
    pub power_iterations: usize,
}

impl SpectralNormalization {
    /// Create a new spectral normalization regularizer.
    pub fn new(lambda: f64, target_norm: f64, power_iterations: usize) -> Self {
        Self {
            lambda,
            target_norm,
            power_iterations,
        }
    }

    /// Estimate spectral norm using power iteration.
    fn estimate_spectral_norm(&self, matrix: &Array<f64, Ix2>) -> f64 {
        if matrix.is_empty() {
            return 0.0;
        }

        let (nrows, ncols) = matrix.dim();
        if nrows == 0 || ncols == 0 {
            return 0.0;
        }

        // Initialize with a constant unit-norm vector (a deterministic stand-in for a random start)
        let mut v = Array::from_elem((ncols,), 1.0 / (ncols as f64).sqrt());

        // Power iteration to find dominant singular value
        for _ in 0..self.power_iterations {
            // u = W * v
            let u = matrix.dot(&v);
            let u_norm = u.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if u_norm < 1e-10 {
                break;
            }
            let u = u / u_norm;

            // v = W^T * u
            v = matrix.t().dot(&u);
            let v_norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-10 {
                break;
            }
            v /= v_norm;
        }

        // σ = ||W * v||
        let final_u = matrix.dot(&v);
        final_u.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }
}

impl Default for SpectralNormalization {
    fn default() -> Self {
        Self {
            target_norm: 1.0,
            lambda: 0.01,
            power_iterations: 1,
        }
    }
}

impl Regularizer for SpectralNormalization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            let spectral_norm = self.estimate_spectral_norm(param);
            // Penalty for deviation from target norm
            penalty += (spectral_norm - self.target_norm).powi(2);
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let spectral_norm = self.estimate_spectral_norm(param);
            if spectral_norm < 1e-10 {
                gradients.insert(name.clone(), Array::zeros(param.dim()));
                continue;
            }

            // Approximate gradient: the exact gradient of the spectral norm is the outer
            // product u * vᵀ of its leading singular vectors; here we use ∇||W||_2 ≈ W / ||W||_F
            let frobenius_norm = param.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if frobenius_norm < 1e-10 {
                gradients.insert(name.clone(), Array::zeros(param.dim()));
                continue;
            }

            let scale = 2.0 * self.lambda * (spectral_norm - self.target_norm) / frobenius_norm;
            let grad = param.mapv(|w| scale * w);
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// MaxNorm constraint regularizer.
///
/// Penalizes weight-vector norms that exceed a maximum value (a soft version of the max-norm constraint).
/// Useful for preventing weights from exploding and improving stability.
///
/// # References
/// - Hinton et al. (2012): "Improving neural networks by preventing co-adaptation of feature detectors"
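///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{MaxNormRegularization, Regularizer};
///
/// // Penalize rows (axis 0) whose L2 norm exceeds 1.0.
/// let reg = MaxNormRegularization::new(1.0, 0.1, 0);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[3.0_f64, 4.0], [0.1, 0.1]]);
///
/// // Only the first row (norm 5.0) violates the limit:
/// // penalty = 0.1 * (5.0 - 1.0)² = 1.6
/// let penalty = reg.compute_penalty(&params).unwrap();
/// assert!((penalty - 1.6).abs() < 1e-9);
/// ```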
#[derive(Debug, Clone)]
pub struct MaxNormRegularization {
    /// Maximum allowed norm
    pub max_norm: f64,
    /// Regularization strength
    pub lambda: f64,
    /// Axis along which to compute norms (0 for rows, 1 for columns)
    pub axis: usize,
}

impl MaxNormRegularization {
    /// Create a new max norm regularizer.
    pub fn new(max_norm: f64, lambda: f64, axis: usize) -> Self {
        Self {
            max_norm,
            lambda,
            axis,
        }
    }
}

impl Default for MaxNormRegularization {
    fn default() -> Self {
        Self {
            max_norm: 2.0,
            lambda: 0.01,
            axis: 0,
        }
    }
}

impl Regularizer for MaxNormRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            let axis_len = if self.axis == 0 {
                param.nrows()
            } else {
                param.ncols()
            };

            for i in 0..axis_len {
                let row_or_col = if self.axis == 0 {
                    param.row(i)
                } else {
                    param.column(i)
                };

                let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if norm > self.max_norm {
                    penalty += (norm - self.max_norm).powi(2);
                }
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let mut grad = Array::zeros(param.dim());

            let axis_len = if self.axis == 0 {
                param.nrows()
            } else {
                param.ncols()
            };

            for i in 0..axis_len {
                let row_or_col = if self.axis == 0 {
                    param.row(i)
                } else {
                    param.column(i)
                };

                let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if norm > self.max_norm {
                    let scale = 2.0 * self.lambda * (norm - self.max_norm) / (norm + 1e-10);

                    for (j, &val) in row_or_col.iter().enumerate() {
                        if self.axis == 0 {
                            grad[[i, j]] = scale * val;
                        } else {
                            grad[[j, i]] = scale * val;
                        }
                    }
                }
            }

            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Orthogonal regularization.
///
/// Encourages weight matrices to be orthogonal: W^T * W ≈ I
/// Helps keep weight matrices well-conditioned and improves gradient flow.
///
/// # References
/// - Brock et al. (2017): "Neural Photo Editing with Introspective Adversarial Networks"
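///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{OrthogonalRegularization, Regularizer};
///
/// let reg = OrthogonalRegularization::new(0.1);
///
/// // An orthogonal matrix (here the identity) incurs no penalty.
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0_f64, 0.0], [0.0, 1.0]]);
/// assert!(reg.compute_penalty(&params).unwrap().abs() < 1e-12);
///
/// // A rank-deficient matrix is penalized: WᵀW = [[2, 2], [2, 2]],
/// // so ||WᵀW - I||²_F = 10 and the penalty is 0.1 * 10 = 1.0.
/// params.insert("w".to_string(), array![[1.0_f64, 1.0], [1.0, 1.0]]);
/// assert!((reg.compute_penalty(&params).unwrap() - 1.0).abs() < 1e-12);
/// ```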
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization {
    /// Regularization strength
    pub lambda: f64,
}

impl OrthogonalRegularization {
    /// Create a new orthogonal regularizer.
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for OrthogonalRegularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for OrthogonalRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            // Compute W^T * W
            let wt_w = param.t().dot(param);

            // Compute ||W^T * W - I||_F^2
            let (n, _) = wt_w.dim();
            for i in 0..n {
                for j in 0..n {
                    let target = if i == j { 1.0 } else { 0.0 };
                    let diff = wt_w[[i, j]] - target;
                    penalty += diff * diff;
                }
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // W^T * W
            let wt_w = param.t().dot(param);

            // Create identity matrix
            let (n, _) = wt_w.dim();
            let mut identity = Array::zeros((n, n));
            for i in 0..n {
                identity[[i, i]] = 1.0;
            }

            // Gradient of λ * ||W^T * W - I||_F^2: 4 * λ * W * (W^T * W - I)
            let diff = &wt_w - &identity;
            let grad = param.dot(&diff) * (4.0 * self.lambda);

            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Group Lasso regularization.
///
/// Encourages group-wise sparsity by penalizing the L2 norm of groups.
/// Useful when features have natural groupings.
///
/// # References
/// - Yuan & Lin (2006): "Model selection and estimation in regression with grouped variables"
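///
/// # Examples
///
/// A minimal sketch (the `use` path is an assumption; adjust it to the
/// actual module layout):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::{GroupLassoRegularization, Regularizer};
///
/// // Parameters are flattened row-major and split into groups of 2.
/// let reg = GroupLassoRegularization::new(0.1, 2);
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[3.0_f64, 4.0, 0.0, 0.0]]);
///
/// // Group norms: ||[3, 4]|| = 5 and ||[0, 0]|| = 0,
/// // so the penalty is 0.1 * (5 + 0) = 0.5.
/// assert!((reg.compute_penalty(&params).unwrap() - 0.5).abs() < 1e-12);
///
/// // The all-zero group also gets a zero gradient, preserving group sparsity.
/// let grads = reg.compute_gradient(&params).unwrap();
/// assert_eq!(grads["w"][[0, 3]], 0.0);
/// ```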
#[derive(Debug, Clone)]
pub struct GroupLassoRegularization {
    /// Regularization strength
    pub lambda: f64,
    /// Group size (number of consecutive parameters per group)
    pub group_size: usize,
}

impl GroupLassoRegularization {
    /// Create a new group lasso regularizer.
    pub fn new(lambda: f64, group_size: usize) -> Self {
        Self { lambda, group_size }
    }
}

impl Default for GroupLassoRegularization {
    fn default() -> Self {
        Self {
            lambda: 0.01,
            group_size: 10,
        }
    }
}

impl Regularizer for GroupLassoRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            // Flatten to 1D
            let flat: Vec<f64> = param.iter().copied().collect();

            // Compute group norms
            for group in flat.chunks(self.group_size) {
                let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
                penalty += group_norm;
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            let mut grad_flat = Vec::new();
            let flat: Vec<f64> = param.iter().copied().collect();

            for group in flat.chunks(self.group_size) {
                let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
                if group_norm > 1e-10 {
                    let scale = self.lambda / group_norm;
                    grad_flat.extend(group.iter().map(|&x| scale * x));
                } else {
                    grad_flat.extend(vec![0.0; group.len()]);
                }
            }

            // Reshape back to original shape
            let grad = Array::from_shape_vec(param.dim(), grad_flat).map_err(|e| {
                TrainError::ModelError(format!("Failed to reshape gradient: {}", e))
            })?;
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_l1_regularization() {
        let regularizer = L1Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, -2.0], [3.0, -4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.1 * (1 + 2 + 3 + 4) = 1.0
        assert!((penalty - 1.0).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient should be λ * sign(w)
        assert_eq!(grad_w[[0, 0]], 0.1); // sign(1.0) = 1.0
        assert_eq!(grad_w[[0, 1]], -0.1); // sign(-2.0) = -1.0
        assert_eq!(grad_w[[1, 0]], 0.1); // sign(3.0) = 1.0
        assert_eq!(grad_w[[1, 1]], -0.1); // sign(-4.0) = -1.0
    }

    #[test]
    fn test_l2_regularization() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.5 * 0.1 * (1 + 4 + 9 + 16) = 1.5
        assert!((penalty - 1.5).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient should be λ * w
        assert!((grad_w[[0, 0]] - 0.1).abs() < 1e-10); // 0.1 * 1.0
        assert!((grad_w[[0, 1]] - 0.2).abs() < 1e-10); // 0.1 * 2.0
        assert!((grad_w[[1, 0]] - 0.3).abs() < 1e-10); // 0.1 * 3.0
        assert!((grad_w[[1, 1]] - 0.4).abs() < 1e-10); // 0.1 * 4.0
    }

    #[test]
    fn test_elastic_net_regularization() {
        let regularizer = ElasticNetRegularization::new(0.1, 0.5).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);
    }

    #[test]
    fn test_elastic_net_invalid_ratio() {
        let result = ElasticNetRegularization::new(0.1, 1.5);
        assert!(result.is_err());

        let result = ElasticNetRegularization::new(0.1, -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_regularization() {
        let mut composite = CompositeRegularization::new();
        composite.add(L1Regularization::new(0.1));
        composite.add(L2Regularization::new(0.1));

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        // L1: 0.1 * (1 + 2) = 0.3
        // L2: 0.5 * 0.1 * (1 + 4) = 0.25
        // Total: 0.55
        assert!((penalty - 0.55).abs() < 1e-6);

        let gradients = composite.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);

        // Gradient should combine both L1 and L2
        // For w[0,0] = 1.0: L1 grad = 0.1, L2 grad = 0.1, total = 0.2
        assert!((grad_w[[0, 0]] - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeRegularization::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn test_multiple_parameters() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w1".to_string(), array![[1.0, 2.0]]);
        params.insert("w2".to_string(), array![[3.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.5 * 0.1 * (1 + 4 + 9) = 0.7
        assert!((penalty - 0.7).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert_eq!(gradients.len(), 2);
        assert!(gradients.contains_key("w1"));
        assert!(gradients.contains_key("w2"));
    }

    #[test]
    fn test_zero_lambda() {
        let regularizer = L1Regularization::new(0.0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[100.0, 200.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w[[0, 0]], 0.0);
        assert_eq!(grad_w[[0, 1]], 0.0);
    }

    #[test]
    fn test_spectral_normalization() {
        let regularizer = SpectralNormalization::new(0.1, 1.0, 5);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[2.0, 0.0], [0.0, 1.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Spectral norm of [[2,0],[0,1]] is 2.0
        // Penalty = 0.1 * (2.0 - 1.0)^2 = 0.1
        assert!((penalty - 0.1).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_max_norm_regularization() {
        let regularizer = MaxNormRegularization::new(1.0, 0.1, 0);

        let mut params = HashMap::new();
        params.insert(
            "w".to_string(),
            array![[3.0, 4.0], [0.1, 0.1]], // First row has norm 5.0 > 1.0
        );

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // First row: norm = 5.0, exceeds max_norm = 1.0
        // Penalty = 0.1 * (5.0 - 1.0)^2 = 1.6
        assert!((penalty - 1.6).abs() < 0.1);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        // First row should have non-zero gradient
        assert!(grad_w[[0, 0]].abs() > 0.0);
        // Second row should have zero gradient (norm below max_norm)
        assert!(grad_w[[1, 0]].abs() < 1e-10);
    }

    #[test]
    fn test_orthogonal_regularization() {
        let regularizer = OrthogonalRegularization::new(0.1);

        let mut params = HashMap::new();
        // Identity matrix should have zero penalty
        params.insert("w".to_string(), array![[1.0, 0.0], [0.0, 1.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty.abs() < 1e-10);

        // Non-orthogonal matrix should have non-zero penalty
        params.insert("w".to_string(), array![[1.0, 1.0], [1.0, 1.0]]);
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_group_lasso_regularization() {
        let regularizer = GroupLassoRegularization::new(0.1, 2);

        let mut params = HashMap::new();
        params.insert(
            "w".to_string(),
            array![[1.0, 2.0], [3.0, 4.0]], // Flatten to [1,2,3,4], groups [1,2] and [3,4]
        );

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Group 1: sqrt(1^2 + 2^2) = sqrt(5) ≈ 2.236
        // Group 2: sqrt(3^2 + 4^2) = sqrt(25) = 5.0
        // Total: 0.1 * (2.236 + 5.0) ≈ 0.7236
        assert!((penalty - 0.7236).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.dim(), (2, 2));
    }

    #[test]
    fn test_spectral_normalization_zero_matrix() {
        let regularizer = SpectralNormalization::new(0.1, 1.0, 5);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[0.0, 0.0], [0.0, 0.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Spectral norm of zero matrix is 0
        // Penalty = 0.1 * (0 - 1.0)^2 = 0.1
        assert!((penalty - 0.1).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        // Gradient should be zero for zero matrix
        assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_max_norm_no_violation() {
        let regularizer = MaxNormRegularization::new(10.0, 0.1, 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // All norms are below 10.0, so no penalty
        assert!(penalty.abs() < 1e-10);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        // All gradients should be zero
        assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_orthogonal_non_square() {
        let regularizer = OrthogonalRegularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);

        // Non-square matrix: W^T * W will be 3x3
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0); // Should have some penalty

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert!(gradients.contains_key("w"));
    }

    #[test]
    fn test_group_lasso_single_group() {
        let regularizer = GroupLassoRegularization::new(0.1, 4);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[3.0, 4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Single group: sqrt(3^2 + 4^2) = 5.0
        // Penalty = 0.1 * 5.0 = 0.5
        assert!((penalty - 0.5).abs() < 0.01);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.dim(), (1, 2));
    }
}