// SPDX-License-Identifier: MIT OR Apache-2.0
//! Linear models: OLS, Ridge, Logistic, Lasso, and ElasticNet.
//!
//! # Regularization naming convention
//!
//! scry-learn uses **`alpha`** as the regularization strength parameter across
//! all linear models — matching scikit-learn's `Ridge`, `Lasso`, and `ElasticNet`:
//!
//! | Model | Parameter | Meaning |
//! |-------|-----------|---------|
//! | [`LinearRegression`] | `alpha` | L2 penalty strength (0 = OLS) |
//! | [`Ridge`] | `alpha` | L2 penalty strength (constructor arg) |
//! | [`LassoRegression`] | `alpha` | L1 penalty strength |
//! | [`ElasticNet`] | `alpha` | Total penalty strength |
//! | [`LogisticRegression`] | `alpha` | Penalty strength (type set by [`Penalty`]) |
//!
//! ## sklearn migration note
//!
//! scikit-learn's `LogisticRegression` and `SVC` use **`C = 1/alpha`** (inverse
//! regularization strength). When porting sklearn code, convert via `alpha = 1.0 / C`.
//! All other sklearn linear models (`Ridge`, `Lasso`, `ElasticNet`) already use
//! `alpha`, so those translate directly.

// Submodules. `qr` and `svd` are crate-visible numerical helpers shared with
// other modules in the crate; `lbfgs` stays private (presumably consumed by
// `logistic`'s solver — NOTE(review): confirm, not visible from this file).
mod elastic_net;
mod lasso;
mod lbfgs;
mod logistic;
pub(crate) mod qr;
mod regression;
pub(crate) mod svd;

// Public surface of `scry_learn::linear`: the model types plus the
// `Penalty`/`Solver` configuration enums for logistic regression.
pub use elastic_net::ElasticNet;
pub use lasso::LassoRegression;
pub use logistic::{LogisticRegression, Penalty, Solver};
pub use regression::LinearRegression;

use crate::dataset::Dataset;
use crate::error::Result;
/// Ridge regression — [`LinearRegression`] with L2 regularization.
///
/// This is a thin wrapper around [`LinearRegression`] that provides a more
/// discoverable API for users coming from scikit-learn's `Ridge` class.
/// It has no solver of its own: construction, fitting, and prediction all
/// delegate to the wrapped [`LinearRegression`] configured with `alpha`.
///
/// # Example
///
/// ```
/// use scry_learn::linear::Ridge;
/// # use scry_learn::dataset::Dataset;
///
/// let data = Dataset::new(
///     vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
///     vec![2.0, 4.0, 6.0, 8.0, 10.0],
///     vec!["x".into()],
///     "y",
/// );
///
/// let mut model = Ridge::new(1.0);
/// model.fit(&data).unwrap();
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Ridge {
    // Wrapped model, configured by `Ridge::new` via `.alpha(...)`; all
    // fitting/prediction logic lives in `LinearRegression`.
    inner: LinearRegression,
}
67
68impl Ridge {
69    /// Create a new Ridge regression model with the given L2 regularization strength.
70    ///
71    /// Equivalent to `LinearRegression::new().alpha(alpha)`.
72    pub fn new(alpha: f64) -> Self {
73        Self {
74            inner: LinearRegression::new().alpha(alpha),
75        }
76    }
77
78    /// Train the model on the given dataset.
79    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
80        self.inner.fit(data)
81    }
82
83    /// Predict target values for the given feature matrix.
84    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
85        self.inner.predict(features)
86    }
87
88    /// Get the learned coefficients.
89    pub fn coefficients(&self) -> &[f64] {
90        self.inner.coefficients()
91    }
92
93    /// Get the learned intercept.
94    pub fn intercept(&self) -> f64 {
95        self.inner.intercept()
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    // The wrapper must be indistinguishable from a manually configured
    // `LinearRegression` with the same alpha.
    #[test]
    fn test_ridge_alias() {
        let data = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            vec![2.0, 4.0, 6.0, 8.0, 10.0],
            vec!["x".into()],
            "y",
        );

        let mut wrapper = Ridge::new(1.0);
        wrapper.fit(&data).unwrap();

        let mut baseline = LinearRegression::new().alpha(1.0);
        baseline.fit(&data).unwrap();

        let coef_gap = (wrapper.coefficients()[0] - baseline.coefficients()[0]).abs();
        assert!(
            coef_gap < 1e-10,
            "Ridge and LinearRegression(alpha=1.0) should produce identical coefficients"
        );

        let intercept_gap = (wrapper.intercept() - baseline.intercept()).abs();
        assert!(
            intercept_gap < 1e-10,
            "Ridge and LinearRegression(alpha=1.0) should produce identical intercepts"
        );

        // Sanity: coefficient should be shrunk below 2.0 (the OLS solution).
        let coef = wrapper.coefficients()[0];
        assert!(coef < 2.0);
        assert!(coef > 1.0);
    }
}
129}