sklears_mixture/
regularization.rs

//! Regularization Techniques for Mixture Models
//!
//! This module provides various regularization techniques for mixture models,
//! including L1 regularization for sparsity, L2 regularization for stability,
//! elastic net for combined sparsity and stability, and group lasso for
//! structured sparsity.
//!
//! # Overview
//!
//! Regularization is crucial for:
//! - Preventing overfitting in high-dimensional settings
//! - Promoting sparsity in parameter estimates
//! - Improving numerical stability
//! - Incorporating structural constraints
//! - Feature selection in mixture models
//!
//! # Key Components
//!
//! - **L1 Regularization**: Promotes sparsity through the LASSO penalty
//! - **L2 Regularization**: Promotes stability through the ridge penalty
//! - **Elastic Net**: Combines the L1 and L2 penalties
//! - **Group Lasso**: Structured sparsity for grouped features
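//!
//! # Examples
//!
//! A minimal sketch of fitting the L1-regularized estimator defined below
//! (data and hyperparameter values are illustrative):
//!
//! ```
//! use sklears_mixture::regularization::L1RegularizedGMM;
//! use sklears_core::traits::Fit;
//! use scirs2_core::ndarray::array;
//!
//! let X = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0], [10.5, 11.5]];
//! let model = L1RegularizedGMM::builder()
//!     .n_components(2)
//!     .lambda(0.01)
//!     .build();
//! let fitted = model.fit(&X.view(), &()).unwrap();
//! ```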

use crate::common::CovarianceType;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

/// Type of regularization to apply
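///
/// # Examples
///
/// A small illustrative construction (the numeric values are arbitrary):
///
/// ```
/// use sklears_mixture::regularization::RegularizationType;
///
/// let penalty = RegularizationType::ElasticNet { l1_ratio: 0.5, lambda: 0.01 };
/// assert!(matches!(penalty, RegularizationType::ElasticNet { .. }));
/// ```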
#[derive(Debug, Clone, PartialEq)]
pub enum RegularizationType {
    /// L1 regularization (LASSO)
    L1 { lambda: f64 },
    /// L2 regularization (Ridge)
    L2 { lambda: f64 },
    /// Elastic Net (combination of L1 and L2)
    ElasticNet { l1_ratio: f64, lambda: f64 },
    /// Group LASSO for structured sparsity
    GroupLasso {
        lambda: f64,
        groups: Vec<Vec<usize>>,
    },
}

/// L1 Regularized Gaussian Mixture Model
///
/// Implements sparse Gaussian mixture modeling using L1 (LASSO) regularization.
/// This promotes sparsity in the parameter estimates, which is useful for
/// feature selection and high-dimensional data.
///
/// # Examples
///
/// ```
/// use sklears_mixture::regularization::L1RegularizedGMM;
/// use sklears_core::traits::Fit;
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
///
/// let model = L1RegularizedGMM::builder()
///     .n_components(2)
///     .lambda(0.01)
///     .build();
///
/// let fitted = model.fit(&X.view(), &()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct L1RegularizedGMM<S = Untrained> {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}

/// Trained L1 Regularized GMM
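///
/// # Examples
///
/// Constructing a summary by hand (a hypothetical one-component,
/// two-feature result, for illustration only):
///
/// ```
/// use sklears_mixture::regularization::L1RegularizedGMMTrained;
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let trained = L1RegularizedGMMTrained {
///     weights: Array1::from_elem(1, 1.0),
///     means: Array2::zeros((1, 2)),
///     covariances: Array2::ones((1, 2)),
///     sparsity_pattern: vec![vec![false, false]],
///     n_nonzero: 0,
///     log_likelihood_history: vec![],
///     n_iter: 0,
///     converged: false,
/// };
/// assert_eq!(trained.n_nonzero, 0);
/// ```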
#[derive(Debug, Clone)]
pub struct L1RegularizedGMMTrained {
    /// Component weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Per-component diagonal covariances (shape `(n_components, n_features)`)
    pub covariances: Array2<f64>,
    /// Sparsity pattern (true = non-zero coefficient)
    pub sparsity_pattern: Vec<Vec<bool>>,
    /// Number of non-zero parameters
    pub n_nonzero: usize,
    /// Log-likelihood history
    pub log_likelihood_history: Vec<f64>,
    /// Number of iterations
    pub n_iter: usize,
    /// Convergence status
    pub converged: bool,
}

/// Builder for L1 Regularized GMM
#[derive(Debug, Clone)]
pub struct L1RegularizedGMMBuilder {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}

impl L1RegularizedGMMBuilder {
    /// Create a new builder
    pub fn new() -> Self {
        Self {
            n_components: 1,
            lambda: 0.01,
            covariance_type: CovarianceType::Diagonal,
            max_iter: 100,
            tol: 1e-3,
            reg_covar: 1e-6,
            random_state: None,
        }
    }

    /// Set number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set L1 regularization parameter
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }

    /// Set covariance type
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set covariance regularization
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Build the model
    pub fn build(self) -> L1RegularizedGMM<Untrained> {
        L1RegularizedGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for L1RegularizedGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl L1RegularizedGMM<Untrained> {
    /// Create a new builder
    pub fn builder() -> L1RegularizedGMMBuilder {
        L1RegularizedGMMBuilder::new()
    }

    /// Soft thresholding operator for L1 regularization
    ///
    /// Computes the proximal operator of the L1 penalty,
    /// `sign(x) * max(|x| - lambda, 0)`, which shrinks `x` toward zero and
    /// sets it exactly to zero when `|x| <= lambda`.
    fn soft_threshold(x: f64, lambda: f64) -> f64 {
        if x > lambda {
            x - lambda
        } else if x < -lambda {
            x + lambda
        } else {
            0.0
        }
    }
}

impl Estimator for L1RegularizedGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for L1RegularizedGMM<Untrained> {
    type Fitted = L1RegularizedGMM<L1RegularizedGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X_owned = X.to_owned();
        let (n_samples, n_features) = X_owned.dim();

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be >= number of components".to_string(),
            ));
        }
        // Initialize means by sampling distinct data points (k-means-like).
        // NOTE: `random_state` is not yet honored; initialization currently
        // always uses the thread-local RNG.
        let mut rng = thread_rng();

        let mut means = Array2::zeros((self.n_components, n_features));
        let mut used_indices = Vec::new();
        for k in 0..self.n_components {
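            // Draw a data-point index not used by a previous component so
            // that the initial means are distinct samples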
            let idx = loop {
                let candidate = rng.gen_range(0..n_samples);
                if !used_indices.contains(&candidate) {
                    used_indices.push(candidate);
                    break candidate;
                }
            };
            means.row_mut(k).assign(&X_owned.row(idx));
        }

        let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
        // Per-component diagonal covariances (one row per component),
        // initialized to unit variance plus the regularization floor
        let mut covariances =
            Array2::<f64>::from_elem((self.n_components, n_features), 1.0 + self.reg_covar);

        let mut log_likelihood_history = Vec::new();
        let mut converged = false;

        // EM algorithm with L1 regularization
        for iter in 0..self.max_iter {
            // E-step: compute responsibilities and accumulate the data
            // log-likelihood under the current parameters
            let mut responsibilities = Array2::zeros((n_samples, self.n_components));
            let mut log_lik = 0.0;

            for i in 0..n_samples {
                let x = X_owned.row(i);
                let mut log_probs = Vec::new();

                for k in 0..self.n_components {
                    let mean = means.row(k);
                    let diff = &x.to_owned() - &mean.to_owned();
                    let cov_k = covariances.row(k);

                    // Squared Mahalanobis distance under a diagonal covariance
                    let mahal = diff
                        .iter()
                        .zip(cov_k.iter())
                        .map(|(d, c): (&f64, &f64)| d * d / c.max(self.reg_covar))
                        .sum::<f64>();

                    let log_det = cov_k
                        .iter()
                        .map(|c| c.max(self.reg_covar).ln())
                        .sum::<f64>();

                    let log_prob = weights[k].ln()
                        - 0.5 * (n_features as f64 * (2.0 * PI).ln() + log_det)
                        - 0.5 * mahal;

                    log_probs.push(log_prob);
                }

                // Normalize with the log-sum-exp trick for numerical stability
                let max_log = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let sum_exp: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
                log_lik += max_log + sum_exp.ln();

                for k in 0..self.n_components {
                    responsibilities[[i, k]] =
                        ((log_probs[k] - max_log).exp() / sum_exp).max(1e-10);
                }
            }

            // M-step with L1 regularization
            for k in 0..self.n_components {
                let resps = responsibilities.column(k);
                let nk = resps.sum().max(1e-10);

                weights[k] = nk / n_samples as f64;

                // Weighted mean of the samples assigned to this component
                let mut new_mean = Array1::zeros(n_features);
                for i in 0..n_samples {
                    new_mean += &(X_owned.row(i).to_owned() * resps[i]);
                }
                new_mean /= nk;

                // Apply the L1 penalty via soft thresholding
                for j in 0..n_features {
                    new_mean[j] = Self::soft_threshold(new_mean[j], self.lambda);
                }
                means.row_mut(k).assign(&new_mean);

                // Update this component's diagonal covariance
                let mut new_cov = Array1::zeros(n_features);
                for i in 0..n_samples {
                    let diff = &X_owned.row(i).to_owned() - &new_mean;
                    new_cov += &(diff.mapv(|x| x * x) * resps[i]);
                }
                new_cov = new_cov / nk + Array1::from_elem(n_features, self.reg_covar);
                covariances.row_mut(k).assign(&new_cov);
            }

            weights /= weights.sum();

            // Penalize the objective: subtract the L1 penalty from the data
            // log-likelihood accumulated in the E-step
            let l1_penalty: f64 = means.iter().map(|&m| m.abs()).sum::<f64>() * self.lambda;
            log_lik -= l1_penalty;

            log_likelihood_history.push(log_lik);

            if iter > 0 {
                let improvement = (log_lik - log_likelihood_history[iter - 1]).abs();
                if improvement < self.tol {
                    converged = true;
                    break;
                }
            }
        }

        // Compute sparsity pattern
        let mut sparsity_pattern = Vec::new();
        let mut n_nonzero = 0;
        for k in 0..self.n_components {
            let mut pattern = Vec::new();
            for j in 0..n_features {
                let is_nonzero = means[[k, j]].abs() > 1e-10;
                pattern.push(is_nonzero);
                if is_nonzero {
                    n_nonzero += 1;
                }
            }
            sparsity_pattern.push(pattern);
        }

        let n_iter = log_likelihood_history.len();
        let trained_state = L1RegularizedGMMTrained {
            weights,
            means,
            covariances,
            sparsity_pattern,
            n_nonzero,
            log_likelihood_history,
            n_iter,
            converged,
        };

        Ok(L1RegularizedGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
        .with_state(trained_state))
    }
}

impl L1RegularizedGMM<Untrained> {
    /// Transition to the trained type state.
    ///
    /// NOTE: the state parameter is phantom-only, so the fitted statistics in
    /// `_state` are currently discarded rather than stored on the returned
    /// model; `predict` is correspondingly a placeholder.
    fn with_state(
        self,
        _state: L1RegularizedGMMTrained,
    ) -> L1RegularizedGMM<L1RegularizedGMMTrained> {
        L1RegularizedGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for L1RegularizedGMM<L1RegularizedGMMTrained> {
    /// Placeholder prediction that assigns every sample to component 0.
    ///
    /// Real responsibility-based assignment requires the fitted parameters,
    /// which the phantom-only type state does not retain (see `with_state`).
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
        let (n_samples, _) = X.dim();
        Ok(Array1::zeros(n_samples))
    }
}

/// L2 Regularized Gaussian Mixture Model
///
/// Applies a ridge penalty to stabilize parameter estimates. This type
/// currently provides configuration only; no `Fit` implementation is
/// provided in this module yet.
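///
/// # Examples
///
/// A minimal configuration sketch (hyperparameter values are illustrative):
///
/// ```
/// use sklears_mixture::regularization::L2RegularizedGMM;
///
/// let model = L2RegularizedGMM::builder()
///     .n_components(2)
///     .lambda(0.1)
///     .build();
/// ```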
#[derive(Debug, Clone)]
pub struct L2RegularizedGMM<S = Untrained> {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}

/// Trained L2 Regularized GMM parameters
#[derive(Debug, Clone)]
pub struct L2RegularizedGMMTrained {
    /// Component weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Component covariances
    pub covariances: Array2<f64>,
    /// Log-likelihood history
    pub log_likelihood_history: Vec<f64>,
    /// Number of iterations
    pub n_iter: usize,
    /// Convergence status
    pub converged: bool,
}

/// Builder for L2 Regularized GMM
#[derive(Debug, Clone)]
pub struct L2RegularizedGMMBuilder {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}

impl L2RegularizedGMMBuilder {
    /// Create a new builder with default settings
    pub fn new() -> Self {
        Self {
            n_components: 1,
            lambda: 0.01,
            covariance_type: CovarianceType::Diagonal,
            max_iter: 100,
            tol: 1e-3,
            reg_covar: 1e-6,
            random_state: None,
        }
    }

    /// Set number of components
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Set L2 regularization parameter
    pub fn lambda(mut self, l: f64) -> Self {
        self.lambda = l;
        self
    }

    /// Build the model
    pub fn build(self) -> L2RegularizedGMM<Untrained> {
        L2RegularizedGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for L2RegularizedGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl L2RegularizedGMM<Untrained> {
    /// Create a new builder
    pub fn builder() -> L2RegularizedGMMBuilder {
        L2RegularizedGMMBuilder::new()
    }
}

/// Elastic Net Regularized Gaussian Mixture Model
///
/// Combines the L1 and L2 penalties, conventionally weighted as
/// `lambda * (l1_ratio * ||theta||_1 + (1 - l1_ratio) / 2 * ||theta||_2^2)`.
/// This type currently provides configuration only; no `Fit` implementation
/// is provided in this module yet.
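///
/// # Examples
///
/// A minimal configuration sketch (hyperparameter values are illustrative):
///
/// ```
/// use sklears_mixture::regularization::ElasticNetGMM;
///
/// let model = ElasticNetGMM::builder()
///     .n_components(3)
///     .l1_ratio(0.7)
///     .lambda(0.05)
///     .build();
/// ```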
#[derive(Debug, Clone)]
pub struct ElasticNetGMM<S = Untrained> {
    n_components: usize,
    l1_ratio: f64,
    lambda: f64,
    _phantom: std::marker::PhantomData<S>,
}

/// Trained Elastic Net GMM parameters
#[derive(Debug, Clone)]
pub struct ElasticNetGMMTrained {
    /// Component weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
}

/// Builder for Elastic Net GMM
#[derive(Debug, Clone)]
pub struct ElasticNetGMMBuilder {
    n_components: usize,
    l1_ratio: f64,
    lambda: f64,
}

impl ElasticNetGMMBuilder {
    /// Create a new builder with default settings
    pub fn new() -> Self {
        Self {
            n_components: 1,
            l1_ratio: 0.5,
            lambda: 0.01,
        }
    }

    /// Set number of components
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Set the mixing ratio between the L1 and L2 penalties (in `[0, 1]`)
    pub fn l1_ratio(mut self, r: f64) -> Self {
        self.l1_ratio = r;
        self
    }

    /// Set overall regularization strength
    pub fn lambda(mut self, l: f64) -> Self {
        self.lambda = l;
        self
    }

    /// Build the model
    pub fn build(self) -> ElasticNetGMM<Untrained> {
        ElasticNetGMM {
            n_components: self.n_components,
            l1_ratio: self.l1_ratio,
            lambda: self.lambda,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for ElasticNetGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl ElasticNetGMM<Untrained> {
    /// Create a new builder
    pub fn builder() -> ElasticNetGMMBuilder {
        ElasticNetGMMBuilder::new()
    }
}

/// Group LASSO Regularized Gaussian Mixture Model
///
/// Applies structured sparsity over user-defined feature groups, commonly via
/// a penalty of the form `lambda * sum_g sqrt(|g|) * ||theta_g||_2`, which
/// zeroes out entire groups at once. This type currently provides
/// configuration only; no `Fit` implementation is provided in this module yet.
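///
/// # Examples
///
/// A minimal configuration sketch (the groups below are illustrative:
/// features 0-2 form one group and features 3-4 another):
///
/// ```
/// use sklears_mixture::regularization::GroupLassoGMM;
///
/// let model = GroupLassoGMM::builder()
///     .n_components(2)
///     .lambda(0.02)
///     .add_group(vec![0, 1, 2])
///     .add_group(vec![3, 4])
///     .build();
/// ```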
#[derive(Debug, Clone)]
pub struct GroupLassoGMM<S = Untrained> {
    n_components: usize,
    lambda: f64,
    groups: Vec<Vec<usize>>,
    _phantom: std::marker::PhantomData<S>,
}

/// Trained Group LASSO GMM parameters
#[derive(Debug, Clone)]
pub struct GroupLassoGMMTrained {
    /// Component weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Which feature groups remain active (non-zero) after fitting
    pub active_groups: Vec<bool>,
}

/// Builder for Group LASSO GMM
#[derive(Debug, Clone)]
pub struct GroupLassoGMMBuilder {
    n_components: usize,
    lambda: f64,
    groups: Vec<Vec<usize>>,
}

impl GroupLassoGMMBuilder {
    /// Create a new builder with default settings
    pub fn new() -> Self {
        Self {
            n_components: 1,
            lambda: 0.01,
            groups: Vec::new(),
        }
    }

    /// Set number of components
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Set regularization strength
    pub fn lambda(mut self, l: f64) -> Self {
        self.lambda = l;
        self
    }

    /// Add a feature group (a set of feature indices penalized together)
    pub fn add_group(mut self, group: Vec<usize>) -> Self {
        self.groups.push(group);
        self
    }

    /// Build the model
    pub fn build(self) -> GroupLassoGMM<Untrained> {
        GroupLassoGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            groups: self.groups,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for GroupLassoGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl GroupLassoGMM<Untrained> {
    /// Create a new builder
    pub fn builder() -> GroupLassoGMMBuilder {
        GroupLassoGMMBuilder::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(L1RegularizedGMM::soft_threshold(2.0, 0.5), 1.5);
        assert_eq!(L1RegularizedGMM::soft_threshold(-2.0, 0.5), -1.5);
        assert_eq!(L1RegularizedGMM::soft_threshold(0.3, 0.5), 0.0);
    }

    #[test]
    fn test_l1_regularized_gmm_builder() {
        let model = L1RegularizedGMM::builder()
            .n_components(3)
            .lambda(0.05)
            .max_iter(50)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.lambda, 0.05);
        assert_eq!(model.max_iter, 50);
    }

    #[test]
    fn test_l1_regularized_gmm_fit() {
        let X = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0], [10.5, 11.5]];

        let model = L1RegularizedGMM::builder()
            .n_components(2)
            .lambda(0.01)
            .max_iter(20)
            .build();

        let result = model.fit(&X.view(), &());
        assert!(result.is_ok());
    }

    #[test]
    fn test_l2_regularized_gmm_builder() {
        let model = L2RegularizedGMM::builder()
            .n_components(2)
            .lambda(0.1)
            .build();

        assert_eq!(model.n_components, 2);
        assert_eq!(model.lambda, 0.1);
    }

    #[test]
    fn test_elastic_net_gmm_builder() {
        let model = ElasticNetGMM::builder()
            .n_components(3)
            .l1_ratio(0.7)
            .lambda(0.05)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.l1_ratio, 0.7);
        assert_eq!(model.lambda, 0.05);
    }

    #[test]
    fn test_group_lasso_gmm_builder() {
        let model = GroupLassoGMM::builder()
            .n_components(2)
            .lambda(0.02)
            .add_group(vec![0, 1, 2])
            .add_group(vec![3, 4])
            .build();

        assert_eq!(model.n_components, 2);
        assert_eq!(model.lambda, 0.02);
        assert_eq!(model.groups.len(), 2);
    }

    #[test]
    fn test_regularization_type() {
        let l1 = RegularizationType::L1 { lambda: 0.1 };
        let l2 = RegularizationType::L2 { lambda: 0.2 };
        let enet = RegularizationType::ElasticNet {
            l1_ratio: 0.5,
            lambda: 0.15,
        };

        assert!(matches!(l1, RegularizationType::L1 { .. }));
        assert!(matches!(l2, RegularizationType::L2 { .. }));
        assert!(matches!(enet, RegularizationType::ElasticNet { .. }));
    }
}