// scry_learn/linear/elastic_net.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! ElasticNet regression via coordinate descent (L1 + L2 regularization).
3//!
4//! Combines L1 (Lasso) and L2 (Ridge) penalties, controlled by `l1_ratio`:
5//! - `l1_ratio = 1.0` → pure Lasso
6//! - `l1_ratio = 0.0` → pure Ridge
7//! - `0 < l1_ratio < 1` → mixed penalty
8
9use crate::dataset::Dataset;
10use crate::error::{Result, ScryLearnError};
11use crate::sparse::{CscMatrix, CsrMatrix};
12
13/// ElasticNet regression (mixed L1 + L2 regularization).
14///
15/// Uses coordinate descent to minimize:
16///
17/// ```text
18/// (1 / 2n) ‖y − Xβ − β₀‖² + α × l1_ratio × ‖β‖₁ + α × (1 − l1_ratio) / 2 × ‖β‖²
19/// ```
20///
21/// # Example
22/// ```
23/// use scry_learn::dataset::Dataset;
24/// use scry_learn::linear::ElasticNet;
25///
26/// let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
27/// let target = vec![2.1, 4.0, 5.9, 8.1, 10.0];
28/// let data = Dataset::new(features, target, vec!["x".into()], "y");
29///
30/// let mut model = ElasticNet::new().alpha(0.1).l1_ratio(0.5);
31/// model.fit(&data).unwrap();
32/// let preds = model.predict(&[vec![3.0]]).unwrap();
33/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ElasticNet {
    /// Overall regularization strength (α ≥ 0; 0 disables both penalties).
    alpha: f64,
    /// Mix between L1 and L2 (0.0 = pure Ridge, 1.0 = pure Lasso).
    l1_ratio: f64,
    /// Maximum coordinate descent iterations (full sweeps over all features).
    max_iter: usize,
    /// Convergence tolerance: descent stops once no coefficient (or the
    /// intercept) moves by more than this in a full sweep.
    tol: f64,
    /// Learned coefficients (one per feature); empty until fitted.
    coefficients: Vec<f64>,
    /// Learned intercept.
    intercept: f64,
    /// Whether the model has been fitted.
    fitted: bool,
    // Serialized schema tag; `predict` checks it via
    // `crate::version::check_schema_version`. `serde(default)` lets older
    // payloads without the field still deserialize.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
55
56impl ElasticNet {
57    /// Create a new ElasticNet with default parameters.
58    pub fn new() -> Self {
59        Self {
60            alpha: 1.0,
61            l1_ratio: 0.5,
62            max_iter: 1000,
63            tol: crate::constants::DEFAULT_TOL,
64            coefficients: Vec::new(),
65            intercept: 0.0,
66            fitted: false,
67            _schema_version: crate::version::SCHEMA_VERSION,
68        }
69    }
70
71    /// Set the overall regularization strength.
72    pub fn alpha(mut self, a: f64) -> Self {
73        self.alpha = a;
74        self
75    }
76
77    /// Set the L1/L2 mix ratio (0.0 = Ridge, 1.0 = Lasso).
78    pub fn l1_ratio(mut self, r: f64) -> Self {
79        self.l1_ratio = r;
80        self
81    }
82
83    /// Set the maximum number of iterations.
84    pub fn max_iter(mut self, n: usize) -> Self {
85        self.max_iter = n;
86        self
87    }
88
89    /// Set convergence tolerance.
90    pub fn tol(mut self, t: f64) -> Self {
91        self.tol = t;
92        self
93    }
94
    /// Fit the ElasticNet model using coordinate descent.
    ///
    /// If the dataset exposes a sparse CSC view, fitting is delegated to
    /// [`ElasticNet::fit_sparse`]; otherwise the dense column-major feature
    /// storage (`data.features[j][i]`) is used directly.
    ///
    /// # Errors
    /// - [`ScryLearnError::EmptyDataset`] if the dataset has no samples.
    /// - [`ScryLearnError::InvalidParameter`] if `alpha < 0` or `l1_ratio`
    ///   lies outside `[0, 1]`.
    /// - Any error from `validate_finite` (non-finite feature/target values).
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if let Some(csc) = data.sparse_csc() {
            // fit_sparse performs its own parameter validation.
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(
                "alpha must be >= 0".into(),
            ));
        }
        if !(0.0..=1.0).contains(&self.l1_ratio) {
            return Err(ScryLearnError::InvalidParameter(
                "l1_ratio must be in [0, 1]".into(),
            ));
        }

        let y = &data.target;

        let mut beta = vec![0.0; m];
        // Start the intercept at the target mean; residuals are kept in sync
        // with (intercept, beta) throughout the descent.
        let mut intercept = y.iter().sum::<f64>() / n as f64;

        // Precompute ‖X_j‖² / n (using column-major data directly).
        let mut col_norm_sq: Vec<f64> = vec![0.0; m];
        for j in 0..m {
            let col = &data.features[j];
            let mut sq = 0.0;
            for &x in col {
                sq += x * x;
            }
            col_norm_sq[j] = sq / n as f64;
        }

        let n_f64 = n as f64;
        // Split α into its L1 and L2 parts once, outside the sweep loop.
        let l1_pen = self.alpha * self.l1_ratio;
        let l2_pen = self.alpha * (1.0 - self.l1_ratio);

        // Initialize residuals: r_i = y_i - intercept
        let mut residuals: Vec<f64> = y.iter().map(|&yi| yi - intercept).collect();

        for _iter in 0..self.max_iter {
            let mut max_change = 0.0_f64;

            // Update intercept: shift by mean of residuals (the intercept is
            // unpenalized, so its optimum is the residual mean).
            let r_mean = residuals.iter().sum::<f64>() / n_f64;
            let new_intercept = intercept + r_mean;
            max_change = max_change.max((new_intercept - intercept).abs());
            for r in &mut residuals {
                *r -= r_mean;
            }
            intercept = new_intercept;

            // Coordinate descent over each feature.
            for j in 0..m {
                // Near-zero-norm (constant-zero) columns carry no signal;
                // skip them and leave their coefficient at 0.
                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
                    continue;
                }

                let old_beta_j = beta[j];
                let col = &data.features[j];

                // Add back current j contribution to residuals.
                if old_beta_j != 0.0 {
                    for i in 0..n {
                        residuals[i] += col[i] * old_beta_j;
                    }
                }

                // ρ = (1/n) Σ x_ij * r_i
                let mut rho = 0.0;
                for i in 0..n {
                    rho += col[i] * residuals[i];
                }
                rho /= n_f64;

                // ElasticNet update: β_j = S(ρ, l1_pen) / (‖X_j‖²/n + l2_pen)
                let new_beta_j = soft_threshold(rho, l1_pen) / (col_norm_sq[j] + l2_pen);
                max_change = max_change.max((new_beta_j - old_beta_j).abs());
                beta[j] = new_beta_j;

                // Remove new j contribution from residuals.
                if new_beta_j != 0.0 {
                    for i in 0..n {
                        residuals[i] -= col[i] * new_beta_j;
                    }
                }
            }

            // Converged: neither the intercept nor any coefficient moved by
            // more than `tol` during this full sweep.
            if max_change < self.tol {
                break;
            }
        }

        self.coefficients = beta;
        self.intercept = intercept;
        self.fitted = true;
        Ok(())
    }
198
199    /// Predict target values for new samples.
200    ///
201    /// `features` is row-major: `features[sample_idx][feature_idx]`.
202    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
203        crate::version::check_schema_version(self._schema_version)?;
204        if !self.fitted {
205            return Err(ScryLearnError::NotFitted);
206        }
207        Ok(features
208            .iter()
209            .map(|row| {
210                row.iter()
211                    .zip(self.coefficients.iter())
212                    .map(|(x, b)| x * b)
213                    .sum::<f64>()
214                    + self.intercept
215            })
216            .collect())
217    }
218
    /// Fit on sparse features using coordinate descent.
    ///
    /// Same objective and update rule as the dense [`ElasticNet::fit`], but
    /// every column pass touches only the stored (non-zero) entries of the
    /// CSC matrix, so cost scales with `nnz` rather than `n × m`.
    ///
    /// # Errors
    /// - [`ScryLearnError::EmptyDataset`] if `features` has no rows.
    /// - [`ScryLearnError::InvalidParameter`] if `target.len()` differs from
    ///   the row count, `alpha < 0`, or `l1_ratio` is outside `[0, 1]`.
    #[allow(clippy::needless_range_loop)]
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }
        if self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(
                "alpha must be >= 0".into(),
            ));
        }
        if !(0.0..=1.0).contains(&self.l1_ratio) {
            return Err(ScryLearnError::InvalidParameter(
                "l1_ratio must be in [0, 1]".into(),
            ));
        }

        let n_f64 = n as f64;
        // Split α into its L1 and L2 parts once, outside the sweep loop.
        let l1_pen = self.alpha * self.l1_ratio;
        let l2_pen = self.alpha * (1.0 - self.l1_ratio);

        let mut beta = vec![0.0; m];
        // Unpenalized intercept starts at the target mean.
        let mut intercept = target.iter().sum::<f64>() / n_f64;

        // ‖X_j‖² / n per column; zeros contribute nothing, so only stored
        // entries are summed.
        let mut col_norm_sq = vec![0.0; m];
        for j in 0..m {
            let mut sq_sum = 0.0;
            for (_, val) in features.col(j).iter() {
                sq_sum += val * val;
            }
            col_norm_sq[j] = sq_sum / n_f64;
        }

        // Residuals track r_i = y_i - intercept - Σ_j x_ij β_j throughout.
        let mut residuals: Vec<f64> = target.iter().map(|&y| y - intercept).collect();

        for _iter in 0..self.max_iter {
            let mut max_change = 0.0_f64;

            // Intercept step: shift by the residual mean.
            let r_mean = residuals.iter().sum::<f64>() / n_f64;
            let new_intercept = intercept + r_mean;
            max_change = max_change.max((new_intercept - intercept).abs());
            for r in &mut residuals {
                *r -= r_mean;
            }
            intercept = new_intercept;

            for j in 0..m {
                // Skip near-zero-norm columns; their coefficient stays 0.
                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
                    continue;
                }

                let old_beta_j = beta[j];

                // Add back column j's current contribution (stored entries
                // only; zero entries contribute nothing either way).
                if old_beta_j != 0.0 {
                    for (row_idx, val) in features.col(j).iter() {
                        residuals[row_idx] += val * old_beta_j;
                    }
                }

                // ρ = (1/n) Σ x_ij * r_i over stored entries.
                let mut rho = 0.0;
                for (row_idx, val) in features.col(j).iter() {
                    rho += val * residuals[row_idx];
                }
                rho /= n_f64;

                // ElasticNet update: β_j = S(ρ, l1_pen) / (‖X_j‖²/n + l2_pen)
                let new_beta_j = soft_threshold(rho, l1_pen) / (col_norm_sq[j] + l2_pen);
                max_change = max_change.max((new_beta_j - old_beta_j).abs());
                beta[j] = new_beta_j;

                // Subtract column j's new contribution.
                if new_beta_j != 0.0 {
                    for (row_idx, val) in features.col(j).iter() {
                        residuals[row_idx] -= val * new_beta_j;
                    }
                }
            }

            // Full sweep moved nothing by more than `tol` → converged.
            if max_change < self.tol {
                break;
            }
        }

        self.coefficients = beta;
        self.intercept = intercept;
        self.fitted = true;
        Ok(())
    }
314
315    /// Predict from sparse features (CSR format).
316    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
317        if !self.fitted {
318            return Err(ScryLearnError::NotFitted);
319        }
320        Ok((0..features.n_rows())
321            .map(|i| {
322                let mut y = self.intercept;
323                for (col, val) in features.row(i).iter() {
324                    if col < self.coefficients.len() {
325                        y += self.coefficients[col] * val;
326                    }
327                }
328                y
329            })
330            .collect())
331    }
332
333    /// Get learned coefficients.
334    pub fn coefficients(&self) -> &[f64] {
335        &self.coefficients
336    }
337
338    /// Get learned intercept.
339    pub fn intercept(&self) -> f64 {
340        self.intercept
341    }
342}
343
344impl Default for ElasticNet {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
/// Soft-thresholding operator: S(z, γ) = sign(z) max(|z| - γ, 0).
///
/// Shrinks `z` toward zero by `gamma`, clipping to exactly zero inside the
/// dead zone `[-gamma, gamma]`.
#[inline]
fn soft_threshold(z: f64, gamma: f64) -> f64 {
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elastic_net_fit_predict() {
        // Data follows y = 2x + 1 exactly.
        let xs = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let ys = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(xs, ys, vec!["x".into()], "y");

        let mut model = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        model.fit(&data).unwrap();

        let preds = model.predict(&[vec![3.0]]).unwrap();
        assert!(
            (preds[0] - 7.0).abs() < 0.5,
            "expected ~7.0, got {}",
            preds[0]
        );
    }

    #[test]
    fn test_elastic_net_pure_lasso() {
        // With l1_ratio = 1.0 the penalty is pure L1, so the coefficient of
        // the irrelevant second feature should be shrunk to (near) zero.
        let n_samples = 50;
        let mut rng = crate::rng::FastRng::new(42);
        let mut feat_a = Vec::with_capacity(n_samples);
        let mut feat_b = Vec::with_capacity(n_samples);
        let mut target = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            let a = rng.f64() * 10.0;
            let b = rng.f64() * 10.0;
            feat_a.push(a);
            feat_b.push(b);
            target.push(3.0 * a + 1.0); // x2 is irrelevant
        }

        let data = Dataset::new(
            vec![feat_a, feat_b],
            target,
            vec!["x1".into(), "x2".into()],
            "y",
        );

        let mut model = ElasticNet::new().alpha(0.5).l1_ratio(1.0).max_iter(5000);
        model.fit(&data).unwrap();

        let coefs = model.coefficients();
        assert!(
            coefs[1].abs() < 0.1,
            "x2 should be ~0 with pure lasso, got {}",
            coefs[1]
        );
    }

    #[test]
    fn test_elastic_net_not_fitted() {
        // Predicting before a successful fit must fail.
        let model = ElasticNet::new();
        assert!(model.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_sparse_elastic_net_matches_dense() {
        // Fitting the same tiny dataset densely and sparsely should land on
        // essentially the same coefficient.
        let cols = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let ys = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(cols.clone(), ys.clone(), vec!["x".into()], "y");

        let mut dense = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&cols);
        let mut sparse = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        sparse.fit_sparse(&csc, &ys).unwrap();

        assert!(
            (dense.coefficients()[0] - sparse.coefficients()[0]).abs() < 0.1,
            "Dense={} vs Sparse={}",
            dense.coefficients()[0],
            sparse.coefficients()[0]
        );
    }
}