// scry_learn/linear/lasso.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Lasso regression via coordinate descent (L1 regularization).
//!
//! Coordinate descent iteratively optimizes one coefficient at a time,
//! applying the soft-thresholding operator to drive small coefficients
//! exactly to zero — producing sparse models.

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::sparse::{CscMatrix, CsrMatrix};
/// Lasso regression (L1-regularized linear regression).
///
/// Uses coordinate descent to find the coefficients `β` that minimize:
///
/// ```text
/// (1 / 2n) ‖y − Xβ − β₀‖² + α ‖β‖₁
/// ```
///
/// Higher `alpha` produces sparser models (more coefficients driven to zero).
///
/// # Example
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::linear::LassoRegression;
///
/// let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
/// let target = vec![2.1, 4.0, 5.9, 8.1, 10.0];
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut lasso = LassoRegression::new().alpha(0.1);
/// lasso.fit(&data).unwrap();
/// let preds = lasso.predict(&[vec![3.0]]).unwrap();
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LassoRegression {
    /// L1 regularization strength (`α` in the objective above).
    alpha: f64,
    /// Maximum coordinate descent iterations (fit may stop earlier on convergence).
    max_iter: usize,
    /// Convergence tolerance: fitting stops once the largest per-iteration
    /// coefficient change drops below this value.
    tol: f64,
    /// Learned coefficients (one per feature); empty until fitted.
    coefficients: Vec<f64>,
    /// Learned intercept (`β₀`).
    intercept: f64,
    /// Whether the model has been fitted (guards `predict`).
    fitted: bool,
    // Serialized schema tag; `serde(default)` lets older payloads deserialize.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
54
55impl LassoRegression {
56    /// Create a new Lasso with default parameters (alpha=1.0, max_iter=1000).
57    pub fn new() -> Self {
58        Self {
59            alpha: 1.0,
60            max_iter: 1000,
61            tol: crate::constants::DEFAULT_TOL,
62            coefficients: Vec::new(),
63            intercept: 0.0,
64            fitted: false,
65            _schema_version: crate::version::SCHEMA_VERSION,
66        }
67    }
68
69    /// Set the L1 regularization strength.
70    pub fn alpha(mut self, a: f64) -> Self {
71        self.alpha = a;
72        self
73    }
74
75    /// Set the maximum number of iterations.
76    pub fn max_iter(mut self, n: usize) -> Self {
77        self.max_iter = n;
78        self
79    }
80
81    /// Set convergence tolerance.
82    pub fn tol(mut self, t: f64) -> Self {
83        self.tol = t;
84        self
85    }
86
87    /// Fit the Lasso model using coordinate descent.
88    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
89        data.validate_finite()?;
90        if let Some(csc) = data.sparse_csc() {
91            return self.fit_sparse(csc, &data.target);
92        }
93        let n = data.n_samples();
94        let m = data.n_features();
95        if n == 0 {
96            return Err(ScryLearnError::EmptyDataset);
97        }
98        if self.alpha < 0.0 {
99            return Err(ScryLearnError::InvalidParameter(
100                "alpha must be >= 0".into(),
101            ));
102        }
103
104        let y = &data.target;
105
106        // Initialize coefficients to zero.
107        let mut beta = vec![0.0; m];
108        let mut intercept = y.iter().sum::<f64>() / n as f64;
109
110        // Precompute feature norms: ‖X_j‖² / n (using column-major data directly).
111        let mut col_norm_sq: Vec<f64> = vec![0.0; m];
112        for j in 0..m {
113            let col = &data.features[j];
114            let mut sq = 0.0;
115            for &x in col {
116                sq += x * x;
117            }
118            col_norm_sq[j] = sq / n as f64;
119        }
120
121        let n_f64 = n as f64;
122
123        // Initialize residuals: r_i = y_i - intercept
124        let mut residuals: Vec<f64> = y.iter().map(|&yi| yi - intercept).collect();
125
126        for _iter in 0..self.max_iter {
127            let mut max_change = 0.0_f64;
128
129            // Update intercept: shift by mean of residuals.
130            let r_mean = residuals.iter().sum::<f64>() / n_f64;
131            let new_intercept = intercept + r_mean;
132            max_change = max_change.max((new_intercept - intercept).abs());
133            for r in &mut residuals {
134                *r -= r_mean;
135            }
136            intercept = new_intercept;
137
138            // Coordinate descent over each feature.
139            for j in 0..m {
140                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
141                    continue; // skip constant features
142                }
143
144                let old_beta_j = beta[j];
145                let col = &data.features[j];
146
147                // Add back current j contribution to residuals.
148                if old_beta_j != 0.0 {
149                    for i in 0..n {
150                        residuals[i] += col[i] * old_beta_j;
151                    }
152                }
153
154                // ρ = (1/n) Σ x_ij * r_i
155                let mut rho = 0.0;
156                for i in 0..n {
157                    rho += col[i] * residuals[i];
158                }
159                rho /= n_f64;
160
161                // Soft-thresholding: β_j = S(ρ, α) / ‖X_j‖²/n
162                let new_beta_j = soft_threshold(rho, self.alpha) / col_norm_sq[j];
163                max_change = max_change.max((new_beta_j - old_beta_j).abs());
164                beta[j] = new_beta_j;
165
166                // Remove new j contribution from residuals.
167                if new_beta_j != 0.0 {
168                    for i in 0..n {
169                        residuals[i] -= col[i] * new_beta_j;
170                    }
171                }
172            }
173
174            if max_change < self.tol {
175                break;
176            }
177        }
178
179        self.coefficients = beta;
180        self.intercept = intercept;
181        self.fitted = true;
182        Ok(())
183    }
184
185    /// Predict target values for new samples.
186    ///
187    /// `features` is row-major: `features[sample_idx][feature_idx]`.
188    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
189        crate::version::check_schema_version(self._schema_version)?;
190        if !self.fitted {
191            return Err(ScryLearnError::NotFitted);
192        }
193        Ok(features
194            .iter()
195            .map(|row| {
196                row.iter()
197                    .zip(self.coefficients.iter())
198                    .map(|(x, b)| x * b)
199                    .sum::<f64>()
200                    + self.intercept
201            })
202            .collect())
203    }
204
205    /// Fit on sparse features using coordinate descent.
206    ///
207    /// Accepts `CscMatrix` for efficient column-oriented coordinate descent.
208    #[allow(clippy::needless_range_loop)]
209    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
210        let n = features.n_rows();
211        let m = features.n_cols();
212        if n == 0 {
213            return Err(ScryLearnError::EmptyDataset);
214        }
215        if target.len() != n {
216            return Err(ScryLearnError::InvalidParameter(format!(
217                "target length {} != n_rows {}",
218                target.len(),
219                n
220            )));
221        }
222        if self.alpha < 0.0 {
223            return Err(ScryLearnError::InvalidParameter(
224                "alpha must be >= 0".into(),
225            ));
226        }
227
228        let n_f64 = n as f64;
229        let mut beta = vec![0.0; m];
230        let mut intercept = target.iter().sum::<f64>() / n_f64;
231
232        // Precompute ‖X_j‖² / n from sparse columns.
233        let mut col_norm_sq = vec![0.0; m];
234        for j in 0..m {
235            let mut sq_sum = 0.0;
236            for (_, val) in features.col(j).iter() {
237                sq_sum += val * val;
238            }
239            col_norm_sq[j] = sq_sum / n_f64;
240        }
241
242        // Residuals: r_i = y_i - intercept
243        let mut residuals: Vec<f64> = target.iter().map(|&y| y - intercept).collect();
244
245        for _iter in 0..self.max_iter {
246            let mut max_change = 0.0_f64;
247
248            // Update intercept.
249            let r_mean = residuals.iter().sum::<f64>() / n_f64;
250            let new_intercept = intercept + r_mean;
251            max_change = max_change.max((new_intercept - intercept).abs());
252            for r in &mut residuals {
253                *r -= r_mean;
254            }
255            intercept = new_intercept;
256
257            // Coordinate descent.
258            for j in 0..m {
259                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
260                    continue;
261                }
262
263                let old_beta_j = beta[j];
264
265                // Add back current j contribution to residuals.
266                if old_beta_j != 0.0 {
267                    for (row_idx, val) in features.col(j).iter() {
268                        residuals[row_idx] += val * old_beta_j;
269                    }
270                }
271
272                // ρ = (1/n) Σ x_ij * r_i (only non-zero x_ij entries)
273                let mut rho = 0.0;
274                for (row_idx, val) in features.col(j).iter() {
275                    rho += val * residuals[row_idx];
276                }
277                rho /= n_f64;
278
279                let new_beta_j = soft_threshold(rho, self.alpha) / col_norm_sq[j];
280                max_change = max_change.max((new_beta_j - old_beta_j).abs());
281                beta[j] = new_beta_j;
282
283                // Remove new j contribution from residuals.
284                if new_beta_j != 0.0 {
285                    for (row_idx, val) in features.col(j).iter() {
286                        residuals[row_idx] -= val * new_beta_j;
287                    }
288                }
289            }
290
291            if max_change < self.tol {
292                break;
293            }
294        }
295
296        self.coefficients = beta;
297        self.intercept = intercept;
298        self.fitted = true;
299        Ok(())
300    }
301
302    /// Predict from sparse features (CSR format).
303    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
304        if !self.fitted {
305            return Err(ScryLearnError::NotFitted);
306        }
307        Ok((0..features.n_rows())
308            .map(|i| {
309                let mut y = self.intercept;
310                for (col, val) in features.row(i).iter() {
311                    if col < self.coefficients.len() {
312                        y += self.coefficients[col] * val;
313                    }
314                }
315                y
316            })
317            .collect())
318    }
319
320    /// Get learned coefficients.
321    pub fn coefficients(&self) -> &[f64] {
322        &self.coefficients
323    }
324
325    /// Get learned intercept.
326    pub fn intercept(&self) -> f64 {
327        self.intercept
328    }
329}
330
impl Default for LassoRegression {
    /// Equivalent to [`LassoRegression::new`] (alpha=1.0, max_iter=1000).
    fn default() -> Self {
        Self::new()
    }
}
336
/// Soft-thresholding operator: S(z, γ) = sign(z) · max(|z| − γ, 0).
///
/// Values of `z` inside the band `[-γ, γ]` (including NaN) map to exactly 0,
/// which is what makes Lasso produce sparse coefficient vectors.
#[inline]
fn soft_threshold(z: f64, gamma: f64) -> f64 {
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A noise-free line y = 2x + 1 with tiny alpha should predict ~7 at x = 3.
    #[test]
    fn test_lasso_fit_predict() {
        // y = 2x + 1
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lasso = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso.fit(&data).unwrap();

        let preds = lasso.predict(&[vec![3.0]]).unwrap();
        assert!(
            (preds[0] - 7.0).abs() < 0.5,
            "expected ~7.0, got {}",
            preds[0]
        );
    }

    /// Irrelevant features (x2, x4) should be driven to ~0 by the L1 penalty,
    /// while the true predictors (x1, x3) keep significant coefficients.
    #[test]
    fn test_lasso_sparsity() {
        // y = 2*x1 + 3*x3 + 1, x2 and x4 are noise
        let n = 100;
        let mut rng = crate::rng::FastRng::new(42);
        let mut x1 = Vec::with_capacity(n);
        let mut x2 = Vec::with_capacity(n);
        let mut x3 = Vec::with_capacity(n);
        let mut x4 = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);

        for _ in 0..n {
            let v1 = rng.f64() * 10.0;
            let v2 = rng.f64() * 10.0;
            let v3 = rng.f64() * 10.0;
            let v4 = rng.f64() * 10.0;
            x1.push(v1);
            x2.push(v2);
            x3.push(v3);
            x4.push(v4);
            y.push(2.0 * v1 + 3.0 * v3 + 1.0);
        }

        let data = Dataset::new(
            vec![x1, x2, x3, x4],
            y,
            vec!["x1".into(), "x2".into(), "x3".into(), "x4".into()],
            "y",
        );

        let mut lasso = LassoRegression::new().alpha(0.5).max_iter(5000);
        lasso.fit(&data).unwrap();

        let coefs = lasso.coefficients();
        // x2 and x4 coefficients should be driven to ~0
        assert!(
            coefs[1].abs() < 0.1,
            "x2 coef should be ~0, got {}",
            coefs[1]
        );
        assert!(
            coefs[3].abs() < 0.1,
            "x4 coef should be ~0, got {}",
            coefs[3]
        );
        // x1 and x3 should be significant
        assert!(coefs[0].abs() > 0.5, "x1 coef should be significant");
        assert!(coefs[2].abs() > 0.5, "x3 coef should be significant");
    }

    /// Calling predict before fit must return an error, not panic or guess.
    #[test]
    fn test_lasso_not_fitted() {
        let lasso = LassoRegression::new();
        assert!(lasso.predict(&[vec![1.0]]).is_err());
    }

    /// Dense and sparse (CSC fit / CSR predict) paths should agree closely
    /// on the same single-feature dataset.
    #[test]
    fn test_sparse_lasso_matches_dense() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(features.clone(), target.clone(), vec!["x".into()], "y");

        let mut lasso_dense = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let mut lasso_sparse = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso_sparse.fit_sparse(&csc, &target).unwrap();

        assert!(
            (lasso_dense.coefficients()[0] - lasso_sparse.coefficients()[0]).abs() < 0.1,
            "Dense={} vs Sparse={}",
            lasso_dense.coefficients()[0],
            lasso_sparse.coefficients()[0]
        );

        let test = vec![vec![3.0]];
        let csr = CsrMatrix::from_dense(&test);
        let pred_d = lasso_dense.predict(&test).unwrap()[0];
        let pred_s = lasso_sparse.predict_sparse(&csr).unwrap()[0];
        assert!(
            (pred_d - pred_s).abs() < 0.5,
            "Dense pred={pred_d} vs Sparse pred={pred_s}"
        );
    }
}