sklears_cross_decomposition/
regularization.rs

//! Regularization techniques for cross-decomposition methods
//!
//! This module provides various regularization methods including elastic net,
//! group lasso, fused lasso, and other penalty functions that can be applied
//! to cross-decomposition algorithms.

#![allow(non_snake_case)]

use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Elastic Net regularization combining L1 and L2 penalties
///
/// The elastic net penalty is: alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||_2^2),
/// where alpha controls the overall regularization strength and l1_ratio controls
/// the balance between the L1 and L2 penalties.
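///
/// # Example
///
/// A minimal usage sketch with illustrative data; the import path is assumed
/// from this crate's layout and `Float` is assumed to be `f64`:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::ElasticNet;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// // alpha = 0.1, l1_ratio = 0.5: equal mix of L1 and L2.
/// let enet = ElasticNet::new(0.1, 0.5).max_iter(500).tol(1e-6);
/// let coef = enet.fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 2);
/// ```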
#[derive(Debug, Clone)]
pub struct ElasticNet {
    /// Regularization strength (alpha)
    pub alpha: Float,
    /// L1 ratio (0.0 = pure L2, 1.0 = pure L1)
    pub l1_ratio: Float,
    /// Maximum number of iterations for coordinate descent
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Whether to normalize features
    pub normalize: bool,
    /// Positive constraint on coefficients
    pub positive: bool,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl ElasticNet {
    /// Create a new ElasticNet regularizer
    pub fn new(alpha: Float, l1_ratio: Float) -> Self {
        Self {
            alpha,
            l1_ratio: l1_ratio.clamp(0.0, 1.0),
            max_iter: 1000,
            tol: 1e-4,
            fit_intercept: true,
            normalize: false,
            positive: false,
            random_state: None,
        }
    }

    /// Set maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tol(mut self, tol: Float) -> Self {
        self.tol = tol;
        self
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Set whether to normalize features
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Set positive constraint
    pub fn positive(mut self, positive: bool) -> Self {
        self.positive = positive;
        self
    }

    /// Set random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Fit elastic net to data using coordinate descent
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(format!(
                "X and y have incompatible shapes: {} vs {}",
                n_samples,
                y.len()
            )));
        }

        // Normalize features if requested
        let (X_norm, _feature_means, feature_stds) = if self.normalize {
            self.normalize_features(X)?
        } else {
            (
                X.clone(),
                Array1::zeros(n_features),
                Array1::ones(n_features),
            )
        };

        // Center target if fitting intercept
        let (y_centered, _y_mean) = if self.fit_intercept {
            let mean = y.mean().unwrap();
            (y - mean, mean)
        } else {
            (y.clone(), 0.0)
        };

        // Initialize coefficients
        let mut coef = Array1::zeros(n_features);

        // Coordinate descent algorithm
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            for j in 0..n_features {
                // Compute residual without j-th feature
                let mut residual = y_centered.clone();
                for k in 0..n_features {
                    if k != j {
                        let col_k = X_norm.column(k);
                        residual = residual - coef[k] * &col_k;
                    }
                }

                // Compute correlation with j-th feature
                let col_j = X_norm.column(j);
                let rho = col_j.dot(&residual) / n_samples as Float;

                // Soft thresholding for L1 penalty
                let l1_penalty = self.alpha * self.l1_ratio;
                let l2_penalty = self.alpha * (1.0 - self.l1_ratio);

                // Denominator includes L2 penalty
                let denominator = col_j.dot(&col_j) / n_samples as Float + l2_penalty;
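                // Closed-form coordinate update (a sketch of the math the branch
                // below implements): with S(z, t) = sign(z) * max(|z| - t, 0),
                //     w_j = S(rho, alpha * l1_ratio) / (||x_j||^2 / n + alpha * (1 - l1_ratio)).
                // For example, rho = 0.9, alpha = 1.0, l1_ratio = 0.5 and
                // ||x_j||^2 / n = 1.0 give w_j = (0.9 - 0.5) / 1.5 ≈ 0.267.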

                let new_coef = if rho > l1_penalty {
                    (rho - l1_penalty) / denominator
                } else if rho < -l1_penalty {
                    (rho + l1_penalty) / denominator
                } else {
                    0.0
                };

                // Apply positive constraint if needed
                coef[j] = if self.positive {
                    new_coef.max(0.0)
                } else {
                    new_coef
                };
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        // Rescale coefficients if features were normalized
        if self.normalize {
            for j in 0..n_features {
                if feature_stds[j] > 1e-8 {
                    coef[j] /= feature_stds[j];
                }
            }
        }

        Ok(coef)
    }

    /// Normalize features to zero mean and unit variance
    fn normalize_features(
        &self,
        X: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array1<Float>)> {
        let means = X.mean_axis(Axis(0)).unwrap();
        let centered = X - &means.view().insert_axis(Axis(0));
        let stds = centered.var_axis(Axis(0), 1.0).mapv(|x| x.sqrt());

        let mut normalized = centered;
        for (j, &std) in stds.iter().enumerate() {
            if std > 1e-8 {
                normalized.column_mut(j).mapv_inplace(|x| x / std);
            }
        }

        Ok((normalized, means, stds))
    }

    /// Compute elastic net penalty value
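    ///
    /// As a worked example (illustrative numbers): with `alpha = 1.0`,
    /// `l1_ratio = 0.5` and `coef = [1.0, -2.0]`, the value is
    /// `1.0 * (0.5 * 3.0 + 0.5 * 0.5 * 5.0) = 2.75`.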
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        let l1_penalty = self.l1_ratio * coef.mapv(|x| x.abs()).sum();
        let l2_penalty = (1.0 - self.l1_ratio) * 0.5 * coef.mapv(|x| x * x).sum();
        self.alpha * (l1_penalty + l2_penalty)
    }

    /// Compute elastic net regularization path for different alpha values
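    ///
    /// A minimal sketch of how the path can be inspected (the alpha grid and
    /// data are illustrative):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
    /// let y = array![1.0, 2.0, 3.0, 4.0];
    /// let alphas = array![0.01, 0.1, 1.0];
    ///
    /// let path = ElasticNet::new(0.1, 0.5).path(&x, &y, &alphas).unwrap();
    /// // One column of coefficients per alpha value.
    /// assert_eq!(path.shape(), &[2, 3]);
    /// ```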
    pub fn path(
        &self,
        X: &Array2<Float>,
        y: &Array1<Float>,
        alphas: &Array1<Float>,
    ) -> Result<Array2<Float>> {
        let n_features = X.ncols();
        let n_alphas = alphas.len();
        let mut coef_path = Array2::zeros((n_features, n_alphas));

        for (alpha_idx, &alpha) in alphas.iter().enumerate() {
            let mut elastic_net = self.clone();
            elastic_net.alpha = alpha;

            let coef = elastic_net.fit(X, y)?;
            coef_path.column_mut(alpha_idx).assign(&coef);
        }

        Ok(coef_path)
    }
}

impl Default for ElasticNet {
    fn default() -> Self {
        Self::new(1.0, 0.5)
    }
}

/// Group Lasso regularization for structured sparsity
///
/// Group lasso applies L2 penalty within groups and L1 penalty between groups,
/// encouraging entire groups of features to be zeroed out together.
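///
/// # Example
///
/// A minimal sketch with an illustrative grouping (the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::GroupLasso;
///
/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
/// let y = array![1.0, 2.0, 3.0];
///
/// // Features 0 and 1 form one group, feature 2 its own group.
/// let groups = vec![vec![0, 1], vec![2]];
/// let coef = GroupLasso::new(0.1, groups).fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 3);
/// ```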
#[derive(Debug, Clone)]
pub struct GroupLasso {
    /// Regularization strength
    pub alpha: Float,
    /// Groups of features (indices for each group)
    pub groups: Vec<Vec<usize>>,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Whether to fit intercept
    pub fit_intercept: bool,
}

impl GroupLasso {
    /// Create a new GroupLasso regularizer
    pub fn new(alpha: Float, groups: Vec<Vec<usize>>) -> Self {
        Self {
            alpha,
            groups,
            max_iter: 1000,
            tol: 1e-4,
            fit_intercept: true,
        }
    }

    /// Fit group lasso using block coordinate descent
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();
        let mut coef = Array1::zeros(n_features);

        // Validate groups
        let mut all_indices = Vec::new();
        for group in &self.groups {
            for &idx in group {
                if idx >= n_features {
                    return Err(SklearsError::InvalidInput(format!(
                        "Group index {} exceeds number of features {}",
                        idx, n_features
                    )));
                }
                all_indices.push(idx);
            }
        }
        all_indices.sort_unstable();
        all_indices.dedup();

        if all_indices.len() != n_features {
            return Err(SklearsError::InvalidInput(
                "Groups must cover all features exactly once".to_string(),
            ));
        }

        // Center target if fitting intercept
        let y_centered = if self.fit_intercept {
            let mean = y.mean().unwrap();
            y - mean
        } else {
            y.clone()
        };

        // Block coordinate descent
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            for group in &self.groups {
                if group.is_empty() {
                    continue;
                }

                // Extract group features
                let mut X_group = Array2::zeros((n_samples, group.len()));
                for (g_idx, &feat_idx) in group.iter().enumerate() {
                    X_group.column_mut(g_idx).assign(&X.column(feat_idx));
                }

                // Compute residual without current group
                let mut residual = y_centered.clone();
                for k in 0..n_features {
                    if !group.contains(&k) {
                        residual = residual - coef[k] * &X.column(k);
                    }
                }

                // Compute group gradient
                let gradient = X_group.t().dot(&residual) / n_samples as Float;
                let gradient_norm = (gradient.dot(&gradient)).sqrt();

                // Group soft thresholding
                let group_penalty = self.alpha * (group.len() as Float).sqrt();

                if gradient_norm <= group_penalty {
                    // Zero out entire group
                    for &feat_idx in group {
                        coef[feat_idx] = 0.0;
                    }
                } else {
                    // Shrink group coefficients
                    let shrinkage = 1.0 - group_penalty / gradient_norm;

                    // Solve within-group problem (L2 regularized least squares)
                    let XtX = X_group.t().dot(&X_group) / n_samples as Float;
                    let group_coef = self.solve_within_group(&XtX, &gradient, shrinkage)?;

                    for (g_idx, &feat_idx) in group.iter().enumerate() {
                        coef[feat_idx] = group_coef[g_idx];
                    }
                }
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        Ok(coef)
    }

    /// Solve within-group optimization problem
    fn solve_within_group(
        &self,
        _XtX: &Array2<Float>,
        gradient: &Array1<Float>,
        shrinkage: Float,
    ) -> Result<Array1<Float>> {
        // Simplified solution: scale the group gradient by the shrinkage factor.
        // A full implementation would solve the L2-regularized least-squares
        // system built from `_XtX` instead.
        Ok(shrinkage * gradient)
    }

    /// Compute group lasso penalty
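    ///
    /// As a worked example (illustrative numbers): with `alpha = 1.0`,
    /// groups `[[0, 1], [2]]` and `coef = [3.0, 4.0, 2.0]`, the value is
    /// `sqrt(2) * 5.0 + 1.0 * 2.0 ≈ 9.07`.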
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        let mut penalty = 0.0;

        for group in &self.groups {
            let mut group_norm_sq = 0.0;
            for &idx in group {
                group_norm_sq += coef[idx] * coef[idx];
            }
            penalty += (group.len() as Float).sqrt() * group_norm_sq.sqrt();
        }

        self.alpha * penalty
    }
}

/// Fused Lasso regularization for sequential data
///
/// Fused lasso combines an L1 penalty on the coefficients with an L1 penalty
/// on differences between adjacent coefficients, encouraging sparse,
/// piecewise constant solutions.
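///
/// # Example
///
/// A minimal sketch with illustrative data (the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::FusedLasso;
///
/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
/// let y = array![1.0, 2.0, 3.0];
///
/// // alpha1 penalizes coefficient magnitudes, alpha2 penalizes jumps
/// // between adjacent coefficients.
/// let coef = FusedLasso::new(0.1, 0.1).fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 3);
/// ```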
#[derive(Debug, Clone)]
pub struct FusedLasso {
    /// L1 penalty on coefficients
    pub alpha1: Float,
    /// L1 penalty on differences between adjacent coefficients
    pub alpha2: Float,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
}

impl FusedLasso {
    /// Create a new FusedLasso regularizer
    pub fn new(alpha1: Float, alpha2: Float) -> Self {
        Self {
            alpha1,
            alpha2,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    /// Fit fused lasso using proximal gradient method
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();
        let mut coef = Array1::zeros(n_features);

        // Learning rate (should be < 1 / largest eigenvalue of X^T X)
        let step_size = 0.001;

        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // Compute gradient of least squares loss
            let residual = y - &X.dot(&coef);
            let gradient = -X.t().dot(&residual) / n_samples as Float;

            // Gradient step
            let mut updated_coef = &coef - step_size * &gradient;

            // Proximal operator for fused lasso penalty
            updated_coef = self.proximal_operator(&updated_coef, step_size)?;

            coef = updated_coef;

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        Ok(coef)
    }

    /// Proximal operator for fused lasso penalty
    fn proximal_operator(&self, coef: &Array1<Float>, step_size: Float) -> Result<Array1<Float>> {
        let n = coef.len();
        let mut result = coef.clone();

        // Apply soft thresholding for L1 penalty on coefficients
        let l1_threshold = step_size * self.alpha1;
        for i in 0..n {
            result[i] = if result[i] > l1_threshold {
                result[i] - l1_threshold
            } else if result[i] < -l1_threshold {
                result[i] + l1_threshold
            } else {
                0.0
            };
        }

        // Apply fusion penalty (simplified; an exact solution would use dynamic programming)
        let fusion_threshold = step_size * self.alpha2;
        for i in 1..n {
            let diff = result[i] - result[i - 1];
            if diff.abs() < fusion_threshold {
                let avg = (result[i] + result[i - 1]) / 2.0;
                result[i] = avg;
                result[i - 1] = avg;
            }
        }

        Ok(result)
    }

    /// Compute fused lasso penalty
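    ///
    /// As a worked example (illustrative numbers): with `alpha1 = 1.0`,
    /// `alpha2 = 1.0` and `coef = [1.0, 1.0, 3.0]`, the value is
    /// `1.0 * 5.0 + 1.0 * (|0.0| + |2.0|) = 7.0`.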
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        let l1_penalty = self.alpha1 * coef.mapv(|x| x.abs()).sum();

        let mut fusion_penalty = 0.0;
        for i in 1..coef.len() {
            fusion_penalty += (coef[i] - coef[i - 1]).abs();
        }
        fusion_penalty *= self.alpha2;

        l1_penalty + fusion_penalty
    }
}

/// Adaptive Lasso regularization with adaptive weights
///
/// Adaptive lasso uses data-dependent weights in the L1 penalty,
/// providing oracle properties under certain conditions.
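///
/// # Example
///
/// A minimal sketch where the weights are derived from an initial OLS fit
/// (illustrative data; the import path is assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::AdaptiveLasso;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// // gamma = 1.0: weights are 1 / |beta_ols_j|.
/// let adaptive = AdaptiveLasso::from_ols(0.1, &x, &y, 1.0).unwrap();
/// let coef = adaptive.fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 2);
/// ```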
#[derive(Debug, Clone)]
pub struct AdaptiveLasso {
    /// Regularization strength
    pub alpha: Float,
    /// Adaptive weights for each feature
    pub weights: Array1<Float>,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
}

impl AdaptiveLasso {
    /// Create adaptive lasso with weights based on initial OLS estimates
    pub fn from_ols(
        alpha: Float,
        X: &Array2<Float>,
        y: &Array1<Float>,
        gamma: Float,
    ) -> Result<Self> {
        // Compute OLS solution for weights
        let XtX = X.t().dot(X);
        let Xty = X.t().dot(y);

        // Solve normal equations (simplified - should use a proper linear solver)
        let ols_coef = Self::solve_normal_equations(&XtX, &Xty)?;

        // Compute adaptive weights: w_j = 1 / |beta_ols_j|^gamma
        let weights = ols_coef.mapv(|x| 1.0 / (x.abs() + 1e-8).powf(gamma));

        Ok(Self {
            alpha,
            weights,
            max_iter: 1000,
            tol: 1e-4,
        })
    }

    /// Create adaptive lasso with custom weights
    pub fn with_weights(alpha: Float, weights: Array1<Float>) -> Self {
        Self {
            alpha,
            weights,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    /// Fit adaptive lasso using coordinate descent
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();
        let mut coef = Array1::zeros(n_features);

        if self.weights.len() != n_features {
            return Err(SklearsError::InvalidInput(
                "Weights length must match number of features".to_string(),
            ));
        }

        // Coordinate descent
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            for j in 0..n_features {
                // Compute residual without j-th feature
                let mut residual = y.clone();
                for k in 0..n_features {
                    if k != j {
                        residual = residual - coef[k] * &X.column(k);
                    }
                }

                // Correlation with j-th feature
                let col_j = X.column(j);
                let rho = col_j.dot(&residual) / n_samples as Float;

                // Weighted soft thresholding
                let threshold = self.alpha * self.weights[j];
                let denominator = col_j.dot(&col_j) / n_samples as Float;

                coef[j] = if rho > threshold {
                    (rho - threshold) / denominator
                } else if rho < -threshold {
                    (rho + threshold) / denominator
                } else {
                    0.0
                };
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        Ok(coef)
    }

    /// Simple normal equations solver (for weights computation)
    fn solve_normal_equations(XtX: &Array2<Float>, Xty: &Array1<Float>) -> Result<Array1<Float>> {
        // Simplified solver - in practice use proper matrix inversion/Cholesky
        let n = XtX.nrows();
        let mut result = Array1::zeros(n);

        // Diagonal approximation for simplicity
        for i in 0..n {
            if XtX[[i, i]].abs() > 1e-8 {
                result[i] = Xty[i] / XtX[[i, i]];
            }
        }

        Ok(result)
    }

    /// Compute adaptive lasso penalty
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        self.alpha
            * coef
                .iter()
                .zip(self.weights.iter())
                .map(|(&c, &w)| w * c.abs())
                .sum::<Float>()
    }
}

/// SCAD (Smoothly Clipped Absolute Deviation) regularization
///
/// SCAD is a nonconvex penalty, continuously differentiable away from zero,
/// that applies less penalty to large coefficients than L1, reducing bias in
/// large coefficients while maintaining sparsity for small ones.
///
/// The SCAD penalty function is:
/// - For |β| ≤ λ: λ|β|
/// - For λ < |β| ≤ aλ: (2aλ|β| - β² - λ²)/(2(a-1))
/// - For |β| > aλ: λ²(a+1)/2
///
/// where λ is the regularization parameter and a > 2 is a shape parameter.
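///
/// # Example
///
/// A minimal sketch with illustrative data (the import path is assumed;
/// `a = 3.7` is the value commonly recommended in the literature):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::SCAD;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// let scad = SCAD::new(0.1, 3.7).unwrap();
/// let coef = scad.fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 2);
/// ```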
#[derive(Debug, Clone)]
pub struct SCAD {
    /// Regularization strength (lambda)
    pub lambda: Float,
    /// Shape parameter (a > 2)
    pub a: Float,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Step size for coordinate descent
    pub step_size: Float,
}

impl SCAD {
    /// Create a new SCAD regularizer
    pub fn new(lambda: Float, a: Float) -> Result<Self> {
        if a <= 2.0 {
            return Err(SklearsError::InvalidInput(
                "SCAD parameter 'a' must be greater than 2.0".to_string(),
            ));
        }

        Ok(Self {
            lambda,
            a,
            max_iter: 1000,
            tol: 1e-4,
            fit_intercept: true,
            step_size: 0.01,
        })
    }

    /// Set maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tol(mut self, tol: Float) -> Self {
        self.tol = tol;
        self
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Set step size for optimization
    pub fn step_size(mut self, step_size: Float) -> Self {
        self.step_size = step_size;
        self
    }

    /// Fit SCAD regularized regression using coordinate descent
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(format!(
                "X and y have incompatible shapes: {} vs {}",
                n_samples,
                y.len()
            )));
        }

        // Center target if fitting intercept
        let y_centered = if self.fit_intercept {
            let mean = y.mean().unwrap();
            y - mean
        } else {
            y.clone()
        };

        // Initialize coefficients
        let mut coef = Array1::zeros(n_features);

        // Coordinate descent with SCAD penalty
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            for j in 0..n_features {
                // Compute residual without j-th feature
                let mut residual = y_centered.clone();
                for k in 0..n_features {
                    if k != j {
                        let col_k = X.column(k);
                        residual = residual - coef[k] * &col_k;
                    }
                }

                // Compute gradient and apply SCAD soft thresholding
                let col_j = X.column(j);
                let gradient = -col_j.dot(&residual) / n_samples as Float;
                let hessian = col_j.dot(&col_j) / n_samples as Float;

                if hessian > 0.0 {
                    let beta_old = coef[j];
                    let beta_unpenalized = beta_old - self.step_size * gradient / hessian;

                    // Apply SCAD soft thresholding
                    coef[j] = self.scad_soft_threshold(
                        beta_unpenalized,
                        self.lambda * self.step_size / hessian,
                    );
                }
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        Ok(coef)
    }

    /// SCAD soft thresholding operator
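    ///
    /// For the univariate problem, the standard SCAD update (Fan & Li, 2001) is,
    /// with `t` the effective threshold:
    /// - `|β| ≤ 2t`:  `sign(β) * max(|β| - t, 0)`
    /// - `2t < |β| ≤ a·t`:  `((a - 1)β - sign(β)·a·t) / (a - 2)`
    /// - `|β| > a·t`:  `β` (no shrinkage)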
    fn scad_soft_threshold(&self, beta: Float, threshold: Float) -> Float {
        let abs_beta = beta.abs();
        let sign = if beta >= 0.0 { 1.0 } else { -1.0 };

        if abs_beta <= 2.0 * threshold {
            // Soft-thresholding (L1) region
            let shrunk = abs_beta - threshold;
            if shrunk > 0.0 {
                sign * shrunk
            } else {
                0.0
            }
        } else if abs_beta <= self.a * threshold {
            // SCAD penalty region
            let numerator = (self.a - 1.0) * beta - sign * self.a * threshold;
            let denominator = self.a - 2.0;
            numerator / denominator
        } else {
            // No penalty region
            beta
        }
    }

    /// Compute SCAD penalty value
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        coef.iter()
            .map(|&beta| self.scad_penalty_single(beta))
            .sum()
    }

    /// SCAD penalty for a single coefficient
    fn scad_penalty_single(&self, beta: Float) -> Float {
        let abs_beta = beta.abs();

        if abs_beta <= self.lambda {
            self.lambda * abs_beta
        } else if abs_beta <= self.a * self.lambda {
            (2.0 * self.a * self.lambda * abs_beta - beta * beta - self.lambda * self.lambda)
                / (2.0 * (self.a - 1.0))
        } else {
            self.lambda * self.lambda * (self.a + 1.0) / 2.0
        }
    }
}

impl Default for SCAD {
    fn default() -> Self {
        Self::new(1.0, 3.7).unwrap() // Standard choice a = 3.7
    }
}

/// MCP (Minimax Concave Penalty) regularization
///
/// MCP provides a concave penalty that applies decreasing marginal penalty
/// as coefficient magnitude increases, reducing bias more aggressively than SCAD
/// while maintaining variable selection properties.
///
/// The MCP penalty function is:
/// - For |β| ≤ γλ: λ|β| - β²/(2γ)
/// - For |β| > γλ: γλ²/2
///
/// where λ is the regularization parameter and γ > 1 is a shape parameter.
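///
/// # Example
///
/// A minimal sketch with illustrative data (the import path is assumed;
/// `gamma = 3.0` is a common default):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_cross_decomposition::regularization::MCP;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// let mcp = MCP::new(0.1, 3.0).unwrap();
/// let coef = mcp.fit(&x, &y).unwrap();
/// assert_eq!(coef.len(), 2);
/// ```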
#[derive(Debug, Clone)]
pub struct MCP {
    /// Regularization strength (lambda)
    pub lambda: Float,
    /// Shape parameter (gamma > 1)
    pub gamma: Float,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Step size for coordinate descent
    pub step_size: Float,
}

impl MCP {
    /// Create a new MCP regularizer
    pub fn new(lambda: Float, gamma: Float) -> Result<Self> {
        if gamma <= 1.0 {
            return Err(SklearsError::InvalidInput(
                "MCP parameter 'gamma' must be greater than 1.0".to_string(),
            ));
        }

        Ok(Self {
            lambda,
            gamma,
            max_iter: 1000,
            tol: 1e-4,
            fit_intercept: true,
            step_size: 0.01,
        })
    }

    /// Set maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tol(mut self, tol: Float) -> Self {
        self.tol = tol;
        self
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Set step size for optimization
    pub fn step_size(mut self, step_size: Float) -> Self {
        self.step_size = step_size;
        self
    }

    /// Fit MCP regularized regression using coordinate descent
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (n_samples, n_features) = X.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(format!(
                "X and y have incompatible shapes: {} vs {}",
                n_samples,
                y.len()
            )));
        }

        // Center target if fitting intercept
        let y_centered = if self.fit_intercept {
            let mean = y.mean().unwrap();
            y - mean
        } else {
            y.clone()
        };

        // Initialize coefficients
        let mut coef = Array1::zeros(n_features);

        // Coordinate descent with MCP penalty
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            for j in 0..n_features {
                // Compute residual without j-th feature
                let mut residual = y_centered.clone();
                for k in 0..n_features {
                    if k != j {
                        let col_k = X.column(k);
                        residual = residual - coef[k] * &col_k;
                    }
                }

                // Compute gradient and apply MCP soft thresholding
                let col_j = X.column(j);
                let gradient = -col_j.dot(&residual) / n_samples as Float;
                let hessian = col_j.dot(&col_j) / n_samples as Float;

                if hessian > 0.0 {
                    let beta_old = coef[j];
                    let beta_unpenalized = beta_old - self.step_size * gradient / hessian;

                    // Apply MCP soft thresholding
                    coef[j] = self.mcp_soft_threshold(
                        beta_unpenalized,
                        self.lambda * self.step_size / hessian,
                    );
                }
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(|x| x.abs()).sum();
            if coef_change < self.tol {
                break;
            }
        }

        Ok(coef)
    }

    /// MCP soft thresholding operator
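    ///
    /// For the univariate problem, this is the "firm" thresholding rule
    /// (Zhang, 2010), with `t` the effective threshold:
    /// - `|β| ≤ γ·t`:  `sign(β) * max(|β| - t, 0) / (1 - 1/γ)`
    /// - `|β| > γ·t`:  `β` (no shrinkage)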
    fn mcp_soft_threshold(&self, beta: Float, threshold: Float) -> Float {
        let abs_beta = beta.abs();
        let sign = if beta >= 0.0 { 1.0 } else { -1.0 };

        if abs_beta <= self.gamma * threshold {
            // MCP penalty region
            let shrunk = abs_beta - threshold;
            if shrunk > 0.0 {
                let denominator = 1.0 - 1.0 / self.gamma;
                sign * shrunk / denominator
            } else {
                0.0
            }
        } else {
            // No penalty region
            beta
        }
    }

    /// Compute MCP penalty value
    pub fn penalty(&self, coef: &Array1<Float>) -> Float {
        coef.iter().map(|&beta| self.mcp_penalty_single(beta)).sum()
    }

    /// MCP penalty for a single coefficient
    fn mcp_penalty_single(&self, beta: Float) -> Float {
        let abs_beta = beta.abs();

        if abs_beta <= self.gamma * self.lambda {
            self.lambda * abs_beta - beta * beta / (2.0 * self.gamma)
        } else {
            self.gamma * self.lambda * self.lambda / 2.0
        }
    }
}

impl Default for MCP {
    fn default() -> Self {
        Self::new(1.0, 3.0).unwrap() // Standard choice gamma = 3.0
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_elastic_net() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let elastic_net = ElasticNet::new(0.1, 0.5);
        let coef = elastic_net.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 2);

        // Test penalty computation
        let penalty = elastic_net.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_group_lasso() {
        let X = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![1.0, 2.0, 3.0];
        let groups = vec![vec![0, 1], vec![2]];

        let group_lasso = GroupLasso::new(0.1, groups);
        let coef = group_lasso.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 3);

        let penalty = group_lasso.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_fused_lasso() {
        let X = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![1.0, 2.0, 3.0];

        let fused_lasso = FusedLasso::new(0.1, 0.1);
        let coef = fused_lasso.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 3);

        let penalty = fused_lasso.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_adaptive_lasso() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let adaptive_lasso = AdaptiveLasso::from_ols(0.1, &X, &y, 1.0).unwrap();
        let coef = adaptive_lasso.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 2);

        let penalty = adaptive_lasso.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_elastic_net_path() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let alphas = array![0.001, 0.01, 0.1, 1.0];

        let elastic_net = ElasticNet::new(0.1, 0.5);
        let coef_path = elastic_net.path(&X, &y, &alphas).unwrap();

        assert_eq!(coef_path.shape(), &[2, 4]);
    }

    #[test]
    fn test_elastic_net_edge_cases() {
        let X = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
        let y = array![1.0, 1.0, 2.0, 1.0];

        // Pure L1 (Lasso) with weaker regularization
        let lasso = ElasticNet::new(0.01, 1.0);
        let coef_l1 = lasso.fit(&X, &y).unwrap();

        // Pure L2 (Ridge) with weaker regularization
        let ridge = ElasticNet::new(0.01, 0.0);
        let coef_l2 = ridge.fit(&X, &y).unwrap();

        // At least one coefficient should be non-zero
        assert!(coef_l1.iter().any(|&x| x.abs() > 1e-6) || coef_l2.iter().any(|&x| x.abs() > 1e-6));
    }

    #[test]
    fn test_scad() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let scad = SCAD::new(0.1, 3.7).unwrap();
        let coef = scad.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 2);

        let penalty = scad.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_scad_error_cases() {
        // Test invalid parameter a <= 2
        assert!(SCAD::new(0.1, 2.0).is_err());
        assert!(SCAD::new(0.1, 1.5).is_err());

        // Valid parameter should work
        assert!(SCAD::new(0.1, 2.1).is_ok());
        assert!(SCAD::new(0.1, 3.7).is_ok());
    }

    #[test]
    fn test_scad_penalty_function() {
        let scad = SCAD::new(1.0, 3.7).unwrap();

        // Test penalty for small coefficients (L1 region)
        let coef_small = array![0.5, 0.3];
        let penalty_small = scad.penalty(&coef_small);
        let expected_small = 0.5 + 0.3; // L1 penalty
        assert!((penalty_small - expected_small).abs() < 1e-6);

        // Test penalty for large coefficients (constant penalty region)
        let coef_large = array![5.0, 4.0];
        let penalty_large = scad.penalty(&coef_large);
        let expected_large = 2.0 * (3.7 + 1.0) / 2.0; // Two coefficients in the constant region
        assert!((penalty_large - expected_large).abs() < 1e-6);
    }

    #[test]
    fn test_mcp() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let mcp = MCP::new(0.1, 3.0).unwrap();
        let coef = mcp.fit(&X, &y).unwrap();

        assert_eq!(coef.len(), 2);

        let penalty = mcp.penalty(&coef);
        assert!(penalty >= 0.0);
    }

    #[test]
    fn test_mcp_error_cases() {
        // Test invalid parameter gamma <= 1
        assert!(MCP::new(0.1, 1.0).is_err());
        assert!(MCP::new(0.1, 0.5).is_err());

        // Valid parameter should work
        assert!(MCP::new(0.1, 1.1).is_ok());
        assert!(MCP::new(0.1, 3.0).is_ok());
    }

    #[test]
    fn test_mcp_penalty_function() {
        let mcp = MCP::new(1.0, 3.0).unwrap();

        // Test penalty for small coefficients (MCP region)
        let coef_small = array![0.5, 0.3];
        let penalty_small = mcp.penalty(&coef_small);
        let expected_small = (0.5 - 0.5 * 0.5 / (2.0 * 3.0)) + (0.3 - 0.3 * 0.3 / (2.0 * 3.0));
        assert!((penalty_small - expected_small).abs() < 1e-6);

        // Test penalty for large coefficients (constant penalty region)
        let coef_large = array![5.0, 4.0];
        let penalty_large = mcp.penalty(&coef_large);
        let expected_large = 2.0 * (3.0 * 1.0 * 1.0 / 2.0); // Two coefficients in the constant region
        assert!((penalty_large - expected_large).abs() < 1e-6);
    }

    #[test]
    fn test_scad_vs_mcp_penalties() {
        let scad = SCAD::new(0.1, 3.7).unwrap();
        let mcp = MCP::new(0.1, 3.0).unwrap();

        // For very small coefficients, both should behave like L1
        let coef_tiny = array![0.01, 0.02];
        let scad_penalty = scad.penalty(&coef_tiny);
        let mcp_penalty = mcp.penalty(&coef_tiny);
        let l1_penalty = 0.1 * (0.01 + 0.02);

        // SCAD should be approximately L1 for small coefficients
        assert!((scad_penalty - l1_penalty).abs() < 1e-3);

        // MCP should be less than L1 for small coefficients due to concavity
        assert!(mcp_penalty < l1_penalty);
    }

    #[test]
    fn test_regularization_builders() {
        // Test SCAD builder pattern
        let scad = SCAD::new(0.1, 3.7)
            .unwrap()
            .max_iter(500)
            .tol(1e-6)
            .fit_intercept(false)
            .step_size(0.02);

        assert_eq!(scad.max_iter, 500);
        assert_eq!(scad.tol, 1e-6);
        assert_eq!(scad.fit_intercept, false);
        assert_eq!(scad.step_size, 0.02);

        // Test MCP builder pattern
        let mcp = MCP::new(0.2, 2.5)
            .unwrap()
            .max_iter(800)
            .tol(1e-5)
            .fit_intercept(true)
            .step_size(0.05);

        assert_eq!(mcp.max_iter, 800);
        assert_eq!(mcp.tol, 1e-5);
        assert_eq!(mcp.fit_intercept, true);
        assert_eq!(mcp.step_size, 0.05);
    }
}