sklears_linear/regularization_schemes.rs

//! Composable Regularization Schemes for Linear Models
//!
//! This module implements various regularization schemes that can be composed
//! and used with the modular framework. All regularization schemes implement
//! the Regularization trait for consistency and pluggability.
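//!
//! A minimal usage sketch (the import path assumes this module is exposed as
//! `sklears_linear::regularization_schemes`; the example is illustrative only):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//! use sklears_linear::regularization_schemes::RegularizationFactory;
//!
//! // Elastic net with total strength α = 1.0 and L1 ratio ρ = 0.5
//! let reg = RegularizationFactory::elastic_net(1.0, 0.5).unwrap();
//! let w = Array1::from_vec(vec![1.0, -2.0]);
//! let penalty = reg.penalty(&w).unwrap(); // α·ρ·‖w‖₁ + α·(1−ρ)/2·‖w‖₂²
//! ```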

use crate::modular_framework::Regularization;
use scirs2_core::ndarray::Array1;
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// L2 (Ridge) regularization: α/2 * ||w||²₂
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength
    pub alpha: Float,
}

impl L2Regularization {
    /// Create a new L2 regularization with the specified strength
    pub fn new(alpha: Float) -> Result<Self> {
        if alpha < 0.0 {
            return Err(SklearsError::InvalidParameter {
                name: "alpha".to_string(),
                reason: format!(
                    "Regularization strength must be non-negative, got {}",
                    alpha
                ),
            });
        }
        Ok(Self { alpha })
    }
}

impl Regularization for L2Regularization {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        let norm_squared = coefficients.mapv(|x| x * x).sum();
        Ok(0.5 * self.alpha * norm_squared)
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        Ok(self.alpha * coefficients)
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        // For L2: prox(x) = x / (1 + α * step_size)
        let shrinkage_factor = 1.0 / (1.0 + self.alpha * step_size);
        Ok(coefficients * shrinkage_factor)
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "L2Regularization"
    }
}

/// L1 (Lasso) regularization: α * ||w||₁
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength
    pub alpha: Float,
}

impl L1Regularization {
    /// Create a new L1 regularization with the specified strength
    pub fn new(alpha: Float) -> Result<Self> {
        if alpha < 0.0 {
            return Err(SklearsError::InvalidParameter {
                name: "alpha".to_string(),
                reason: format!(
                    "Regularization strength must be non-negative, got {}",
                    alpha
                ),
            });
        }
        Ok(Self { alpha })
    }
}

impl Regularization for L1Regularization {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        let l1_norm = coefficients.mapv(|x| x.abs()).sum();
        Ok(self.alpha * l1_norm)
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        // The L1 penalty is not differentiable at 0, so we return a subgradient
        let subgradient = coefficients.mapv(|x| {
            if x > 0.0 {
                self.alpha
            } else if x < 0.0 {
                -self.alpha
            } else {
                0.0 // Any value in [-α, α] is a valid subgradient at 0
            }
        });
        Ok(subgradient)
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        // Soft thresholding operator
        let threshold = self.alpha * step_size;
        let result = coefficients.mapv(|x| {
            if x > threshold {
                x - threshold
            } else if x < -threshold {
                x + threshold
            } else {
                0.0
            }
        });
        Ok(result)
    }

    fn is_non_smooth(&self) -> bool {
        true
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "L1Regularization"
    }
}

/// Elastic Net regularization: α * (ρ * ||w||₁ + (1-ρ)/2 * ||w||²₂)
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Total regularization strength
    pub alpha: Float,
    /// L1 ratio (ρ): 0 = Ridge, 1 = Lasso
    pub l1_ratio: Float,
}

impl ElasticNetRegularization {
    /// Create a new Elastic Net regularization
    pub fn new(alpha: Float, l1_ratio: Float) -> Result<Self> {
        if alpha < 0.0 {
            return Err(SklearsError::InvalidParameter {
                name: "alpha".to_string(),
                reason: format!(
                    "Regularization strength must be non-negative, got {}",
                    alpha
                ),
            });
        }
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(SklearsError::InvalidParameter {
                name: "l1_ratio".to_string(),
                reason: format!("L1 ratio must be between 0 and 1, got {}", l1_ratio),
            });
        }
        Ok(Self { alpha, l1_ratio })
    }

    /// Get the L1 regularization strength
    pub fn l1_strength(&self) -> Float {
        self.alpha * self.l1_ratio
    }

    /// Get the L2 regularization strength
    pub fn l2_strength(&self) -> Float {
        self.alpha * (1.0 - self.l1_ratio)
    }
}

impl Regularization for ElasticNetRegularization {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        let l1_norm = coefficients.mapv(|x| x.abs()).sum();
        let l2_norm_squared = coefficients.mapv(|x| x * x).sum();

        let l1_penalty = self.l1_strength() * l1_norm;
        let l2_penalty = 0.5 * self.l2_strength() * l2_norm_squared;

        Ok(l1_penalty + l2_penalty)
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        let l1_strength = self.l1_strength();
        let l2_strength = self.l2_strength();

        let gradient = coefficients.mapv(|x| {
            let l1_subgrad = if x > 0.0 {
                l1_strength
            } else if x < 0.0 {
                -l1_strength
            } else {
                0.0
            };
            let l2_grad = l2_strength * x;
            l1_subgrad + l2_grad
        });

        Ok(gradient)
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        let l1_strength = self.l1_strength();
        let l2_strength = self.l2_strength();

        // Elastic net proximal operator combines soft thresholding with L2 shrinkage
        let threshold = l1_strength * step_size;
        let shrinkage_factor = 1.0 / (1.0 + l2_strength * step_size);

        let result = coefficients.mapv(|x| {
            let soft_thresholded = if x > threshold {
                x - threshold
            } else if x < -threshold {
                x + threshold
            } else {
                0.0
            };
            soft_thresholded * shrinkage_factor
        });

        Ok(result)
    }

    fn is_non_smooth(&self) -> bool {
        self.l1_ratio > 0.0
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "ElasticNetRegularization"
    }
}

/// Group Lasso regularization for grouped features
#[derive(Debug, Clone)]
pub struct GroupLassoRegularization {
    /// Regularization strength
    pub alpha: Float,
    /// Group assignment for each feature (group_id for each coefficient)
    pub groups: Vec<usize>,
}

impl GroupLassoRegularization {
    /// Create a new Group Lasso regularization
    pub fn new(alpha: Float, groups: Vec<usize>) -> Result<Self> {
        if alpha < 0.0 {
            return Err(SklearsError::InvalidParameter {
                name: "alpha".to_string(),
                reason: format!(
                    "Regularization strength must be non-negative, got {}",
                    alpha
                ),
            });
        }
        Ok(Self { alpha, groups })
    }

    /// Compute the L2 norm of each group after checking that the coefficient
    /// vector matches the group assignment in length
    fn group_norms(&self, coefficients: &Array1<Float>) -> Result<Vec<Float>> {
        if coefficients.len() != self.groups.len() {
            return Err(SklearsError::DimensionMismatch {
                expected: self.groups.len(),
                actual: coefficients.len(),
            });
        }

        // Accumulate squared coefficients per group, then take square roots
        let max_group = *self.groups.iter().max().unwrap_or(&0);
        let mut norms = vec![0.0; max_group + 1];
        for (i, &group_id) in self.groups.iter().enumerate() {
            norms[group_id] += coefficients[i] * coefficients[i];
        }
        for norm in &mut norms {
            *norm = norm.sqrt();
        }
        Ok(norms)
    }
}

impl Regularization for GroupLassoRegularization {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        // Penalty is the sum of the group L2 norms, scaled by α
        let group_norms = self.group_norms(coefficients)?;
        Ok(self.alpha * group_norms.iter().sum::<Float>())
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        let group_norms = self.group_norms(coefficients)?;

        // Subgradient: α * w_i / ||w_g|| inside each non-zero group,
        // and 0 (a valid subgradient) where the whole group is zero
        let mut gradient = Array1::zeros(coefficients.len());
        for (i, &group_id) in self.groups.iter().enumerate() {
            if group_norms[group_id] > 0.0 {
                gradient[i] = self.alpha * coefficients[i] / group_norms[group_id];
            }
        }

        Ok(gradient)
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        let group_norms = self.group_norms(coefficients)?;

        // Apply group soft thresholding: groups whose norm is at or below the
        // threshold are zeroed, the rest are shrunk proportionally
        let threshold = self.alpha * step_size;
        let mut result = coefficients.clone();

        for (i, &group_id) in self.groups.iter().enumerate() {
            let group_norm = group_norms[group_id];
            if group_norm > threshold {
                result[i] *= (group_norm - threshold) / group_norm;
            } else {
                result[i] = 0.0;
            }
        }

        Ok(result)
    }

    fn is_non_smooth(&self) -> bool {
        true
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "GroupLassoRegularization"
    }
}

/// Composite regularization that combines multiple regularization schemes
#[derive(Debug)]
pub struct CompositeRegularization {
    /// List of regularization schemes with their weights
    regularizations: Vec<(Float, Box<dyn Regularization>)>,
}

impl Default for CompositeRegularization {
    fn default() -> Self {
        Self::new()
    }
}

impl CompositeRegularization {
    /// Create a new composite regularization
    pub fn new() -> Self {
        Self {
            regularizations: Vec::new(),
        }
    }

    /// Add a regularization scheme with a weight
    pub fn add_regularization(
        mut self,
        weight: Float,
        regularization: Box<dyn Regularization>,
    ) -> Self {
        self.regularizations.push((weight, regularization));
        self
    }

    /// Add L1 regularization
    pub fn add_l1(self, alpha: Float) -> Result<Self> {
        Ok(self.add_regularization(1.0, Box::new(L1Regularization::new(alpha)?)))
    }

    /// Add L2 regularization
    pub fn add_l2(self, alpha: Float) -> Result<Self> {
        Ok(self.add_regularization(1.0, Box::new(L2Regularization::new(alpha)?)))
    }

    /// Add Group Lasso regularization
    pub fn add_group_lasso(self, alpha: Float, groups: Vec<usize>) -> Result<Self> {
        Ok(self.add_regularization(1.0, Box::new(GroupLassoRegularization::new(alpha, groups)?)))
    }

    /// Check if any component is non-smooth
    pub fn is_any_non_smooth(&self) -> bool {
        self.regularizations
            .iter()
            .any(|(_, reg)| reg.is_non_smooth())
    }
}

impl Regularization for CompositeRegularization {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        let mut total_penalty = 0.0;
        for (weight, regularization) in &self.regularizations {
            total_penalty += weight * regularization.penalty(coefficients)?;
        }
        Ok(total_penalty)
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        let mut total_gradient = Array1::zeros(coefficients.len());
        for (weight, regularization) in &self.regularizations {
            let grad = regularization.penalty_gradient(coefficients)?;
            total_gradient = total_gradient + *weight * grad;
        }
        Ok(total_gradient)
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        // For composite regularization, we apply proximal operators sequentially.
        // This is an approximation - the exact proximal operator of a sum is
        // generally not available in closed form.
        let mut result = coefficients.clone();
        for (weight, regularization) in &self.regularizations {
            result = regularization.proximal_operator(&result, weight * step_size)?;
        }
        Ok(result)
    }

    fn is_non_smooth(&self) -> bool {
        self.is_any_non_smooth()
    }

    fn strength(&self) -> Float {
        // Return the sum of weighted strengths
        self.regularizations
            .iter()
            .map(|(weight, reg)| weight * reg.strength())
            .sum()
    }

    fn name(&self) -> &'static str {
        "CompositeRegularization"
    }
}

/// Factory for creating common regularization schemes
pub struct RegularizationFactory;

impl RegularizationFactory {
    /// Create L1 (Lasso) regularization
    pub fn l1(alpha: Float) -> Result<Box<dyn Regularization>> {
        Ok(Box::new(L1Regularization::new(alpha)?))
    }

    /// Create L2 (Ridge) regularization
    pub fn l2(alpha: Float) -> Result<Box<dyn Regularization>> {
        Ok(Box::new(L2Regularization::new(alpha)?))
    }

    /// Create Elastic Net regularization
    pub fn elastic_net(alpha: Float, l1_ratio: Float) -> Result<Box<dyn Regularization>> {
        Ok(Box::new(ElasticNetRegularization::new(alpha, l1_ratio)?))
    }

    /// Create Group Lasso regularization
    pub fn group_lasso(alpha: Float, groups: Vec<usize>) -> Result<Box<dyn Regularization>> {
        Ok(Box::new(GroupLassoRegularization::new(alpha, groups)?))
    }

    /// Create a composite regularization builder
    pub fn composite() -> CompositeRegularization {
        CompositeRegularization::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_l2_regularization() {
        let reg = L2Regularization::new(0.5).unwrap();
        let coefficients = Array::from_vec(vec![1.0, -2.0, 3.0]);

        let penalty = reg.penalty(&coefficients).unwrap();
        let expected = 0.5 * 0.5 * (1.0 + 4.0 + 9.0); // α/2 * ||w||²
        assert!((penalty - expected).abs() < 1e-10);

        let gradient = reg.penalty_gradient(&coefficients).unwrap();
        let expected_grad = Array::from_vec(vec![0.5, -1.0, 1.5]); // α * w
        for (actual, expected) in gradient.iter().zip(expected_grad.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }
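
    // Added test (not in the original suite) for the closed-form L2 proximal
    // operator x / (1 + α * step_size): with α = 0.5 and step size 2.0 the
    // shrinkage factor is 1 / (1 + 1) = 0.5.
    #[test]
    fn test_l2_proximal_operator() {
        let reg = L2Regularization::new(0.5).unwrap();
        let coefficients = Array::from_vec(vec![2.0, -4.0]);

        let result = reg.proximal_operator(&coefficients, 2.0).unwrap();
        let expected = Array::from_vec(vec![1.0, -2.0]);
        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }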

    #[test]
    fn test_l1_regularization() {
        let reg = L1Regularization::new(0.3).unwrap();
        let coefficients = Array::from_vec(vec![1.0, -2.0, 3.0]);

        let penalty = reg.penalty(&coefficients).unwrap();
        let expected = 0.3 * (1.0 + 2.0 + 3.0); // α * ||w||₁
        assert!((penalty - expected).abs() < 1e-10);

        assert!(reg.is_non_smooth());
    }
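
    // Added test for the L1 subgradient α * sign(x), with 0 chosen at x = 0
    // (any value in [-α, α] would be valid there).
    #[test]
    fn test_l1_subgradient() {
        let reg = L1Regularization::new(0.3).unwrap();
        let coefficients = Array::from_vec(vec![2.0, -1.0, 0.0]);

        let gradient = reg.penalty_gradient(&coefficients).unwrap();
        let expected = Array::from_vec(vec![0.3, -0.3, 0.0]);
        for (actual, expected) in gradient.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }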

    #[test]
    fn test_l1_proximal_operator() {
        let reg = L1Regularization::new(1.0).unwrap();
        let coefficients = Array::from_vec(vec![2.0, -1.0, 0.5]);
        let step_size = 1.0;

        let result = reg.proximal_operator(&coefficients, step_size).unwrap();
        // Soft thresholding with threshold = 1.0 * 1.0 = 1.0
        let expected = Array::from_vec(vec![1.0, 0.0, 0.0]);
        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_elastic_net_regularization() {
        let reg = ElasticNetRegularization::new(1.0, 0.7).unwrap();
        let coefficients = Array::from_vec(vec![1.0, -1.0]);

        let penalty = reg.penalty(&coefficients).unwrap();
        let l1_penalty = 0.7 * (1.0 + 1.0); // l1_ratio * α * ||w||₁
        let l2_penalty = 0.5 * 0.3 * (1.0 + 1.0); // (1-l1_ratio) * α/2 * ||w||²
        let expected = l1_penalty + l2_penalty;
        assert!((penalty - expected).abs() < 1e-10);

        assert!(reg.is_non_smooth()); // Because l1_ratio > 0
    }
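
    // Added test for the elastic net proximal operator: soft thresholding at
    // α·ρ·step, then L2 shrinkage by 1 / (1 + α·(1-ρ)·step). With α = 1.0,
    // ρ = 0.5, and step 1.0: threshold 0.5 and shrinkage factor 1/1.5.
    #[test]
    fn test_elastic_net_proximal_operator() {
        let reg = ElasticNetRegularization::new(1.0, 0.5).unwrap();
        assert!((reg.l1_strength() - 0.5).abs() < 1e-10);
        assert!((reg.l2_strength() - 0.5).abs() < 1e-10);

        let coefficients = Array::from_vec(vec![2.0, 0.25, -1.0]);
        let result = reg.proximal_operator(&coefficients, 1.0).unwrap();

        let shrink = 1.0 / 1.5;
        let expected = Array::from_vec(vec![1.5 * shrink, 0.0, -0.5 * shrink]);
        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }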

    #[test]
    fn test_group_lasso_regularization() {
        let groups = vec![0, 0, 1, 1]; // Two groups: {0,1} and {2,3}
        let reg = GroupLassoRegularization::new(1.0, groups).unwrap();
        let coefficients = Array::from_vec(vec![3.0, 4.0, 0.0, 0.0]); // First group has norm 5.0, second group is zero

        let penalty = reg.penalty(&coefficients).unwrap();
        let expected = 5.0 + 0.0; // Sum of group L2 norms
        assert!((penalty - expected).abs() < 1e-10);

        assert!(reg.is_non_smooth());
    }
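
    // Added test for group soft thresholding: a group with norm 5.0 and
    // threshold α·step = 1.0 is scaled by (5 - 1)/5 = 0.8, while a group
    // whose norm is at or below the threshold is zeroed entirely.
    #[test]
    fn test_group_lasso_proximal_operator() {
        let reg = GroupLassoRegularization::new(1.0, vec![0, 0, 1]).unwrap();
        let coefficients = Array::from_vec(vec![3.0, 4.0, 0.5]);

        let result = reg.proximal_operator(&coefficients, 1.0).unwrap();
        let expected = Array::from_vec(vec![2.4, 3.2, 0.0]);
        for (actual, expected) in result.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-10);
        }
    }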

    #[test]
    fn test_composite_regularization() {
        let composite = CompositeRegularization::new()
            .add_l1(0.1)
            .unwrap()
            .add_l2(0.2)
            .unwrap();

        let coefficients = Array::from_vec(vec![1.0, -2.0]);

        let penalty = composite.penalty(&coefficients).unwrap();
        let l1_penalty = 0.1 * (1.0 + 2.0);
        let l2_penalty = 0.5 * 0.2 * (1.0 + 4.0);
        let expected = l1_penalty + l2_penalty;
        assert!((penalty - expected).abs() < 1e-10);

        assert!(composite.is_non_smooth()); // Because it includes L1
    }
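
    // Added test documenting the sequential-prox approximation used by
    // CompositeRegularization: the L1 step runs first (threshold 1.0 maps
    // 2.0 to 1.0), then the L2 step shrinks by 1 / (1 + 1) to 0.5.
    #[test]
    fn test_composite_proximal_operator() {
        let composite = CompositeRegularization::new()
            .add_l1(1.0)
            .unwrap()
            .add_l2(1.0)
            .unwrap();

        let coefficients = Array::from_vec(vec![2.0]);
        let result = composite.proximal_operator(&coefficients, 1.0).unwrap();
        assert!((result[0] - 0.5).abs() < 1e-10);
    }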

    #[test]
    fn test_regularization_factory() {
        let l1 = RegularizationFactory::l1(0.5).unwrap();
        assert_eq!(l1.name(), "L1Regularization");

        let l2 = RegularizationFactory::l2(0.3).unwrap();
        assert_eq!(l2.name(), "L2Regularization");

        let elastic_net = RegularizationFactory::elastic_net(1.0, 0.8).unwrap();
        assert_eq!(elastic_net.name(), "ElasticNetRegularization");
    }
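
    // Added coverage for the remaining factory helpers and the weighted
    // strength() sum of a composite scheme (both component weights are 1.0,
    // so the total is 0.1 + 0.2).
    #[test]
    fn test_factory_group_lasso_and_composite_strength() {
        let group_lasso = RegularizationFactory::group_lasso(0.5, vec![0, 1]).unwrap();
        assert_eq!(group_lasso.name(), "GroupLassoRegularization");

        let composite = RegularizationFactory::composite()
            .add_l1(0.1)
            .unwrap()
            .add_l2(0.2)
            .unwrap();
        assert!((composite.strength() - 0.3).abs() < 1e-10);
    }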

    #[test]
    fn test_invalid_parameters() {
        // Negative alpha should fail
        assert!(L1Regularization::new(-1.0).is_err());
        assert!(L2Regularization::new(-0.1).is_err());

        // Invalid l1_ratio should fail
        assert!(ElasticNetRegularization::new(1.0, -0.1).is_err());
        assert!(ElasticNetRegularization::new(1.0, 1.5).is_err());
    }
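
    // Added test for the dimension check shared by the group lasso methods:
    // three coefficients against a two-entry group map must error.
    #[test]
    fn test_group_lasso_dimension_mismatch() {
        let reg = GroupLassoRegularization::new(1.0, vec![0, 1]).unwrap();
        let coefficients = Array::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(reg.penalty(&coefficients).is_err());
        assert!(reg.penalty_gradient(&coefficients).is_err());
        assert!(reg.proximal_operator(&coefficients, 0.1).is_err());
    }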
}