scry_learn/linear/mod.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Linear models: OLS, Ridge, Logistic, Lasso, and ElasticNet.
//!
//! # Regularization naming convention
//!
//! scry-learn uses **`alpha`** as the regularization strength parameter across
//! all linear models — matching scikit-learn's `Ridge`, `Lasso`, and `ElasticNet`:
//!
//! | Model | Parameter | Meaning |
//! |-------|-----------|---------|
//! | [`LinearRegression`] | `alpha` | L2 penalty strength (0 = OLS) |
//! | [`Ridge`] | `alpha` | L2 penalty strength (constructor arg) |
//! | [`LassoRegression`] | `alpha` | L1 penalty strength |
//! | [`ElasticNet`] | `alpha` | Total penalty strength |
//! | [`LogisticRegression`] | `alpha` | Penalty strength (penalty type selected via [`Penalty`]) |
//!
//! ## sklearn migration note
//!
//! scikit-learn's `LogisticRegression` and `SVC` use **`C = 1/alpha`** (inverse
//! regularization strength). When porting sklearn code, convert via `alpha = 1.0 / C`.
//! All other sklearn linear models (`Ridge`, `Lasso`, `ElasticNet`) already use
//! `alpha`, so those translate directly.
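//!
//! For example, porting sklearn's `LogisticRegression(C=0.5)` means passing
//! `alpha = 2.0` here. A minimal sketch of the conversion arithmetic (the values
//! are illustrative only):
//!
//! ```
//! let c = 0.5_f64;      // sklearn's inverse regularization strength C
//! let alpha = 1.0 / c;  // scry-learn's regularization strength
//! assert!((alpha - 2.0).abs() < 1e-12);
//! ```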

mod elastic_net;
mod lasso;
mod lbfgs;
mod logistic;
pub(crate) mod qr;
mod regression;
pub(crate) mod svd;

pub use elastic_net::ElasticNet;
pub use lasso::LassoRegression;
pub use logistic::{LogisticRegression, Penalty, Solver};
pub use regression::LinearRegression;

use crate::dataset::Dataset;
use crate::error::Result;

/// Ridge regression — [`LinearRegression`] with L2 regularization.
///
/// This is a thin wrapper around [`LinearRegression`] that provides a more
/// discoverable API for users coming from scikit-learn's `Ridge` class.
///
/// # Example
///
/// ```
/// use scry_learn::linear::Ridge;
/// # use scry_learn::dataset::Dataset;
///
/// let data = Dataset::new(
///     vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
///     vec![2.0, 4.0, 6.0, 8.0, 10.0],
///     vec!["x".into()],
///     "y",
/// );
///
/// let mut model = Ridge::new(1.0);
/// model.fit(&data).unwrap();
/// ```
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Ridge {
    inner: LinearRegression,
}

impl Ridge {
    /// Create a new Ridge regression model with the given L2 regularization strength.
    ///
    /// Equivalent to `LinearRegression::new().alpha(alpha)`.
    pub fn new(alpha: f64) -> Self {
        Self {
            inner: LinearRegression::new().alpha(alpha),
        }
    }

    /// Train the model on the given dataset.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        self.inner.fit(data)
    }

    /// Predict target values for the given feature matrix.
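    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the same single-feature data layout as the
    /// constructor example above; only the output length is asserted, since the
    /// exact values depend on the fitted coefficients:
    ///
    /// ```
    /// use scry_learn::linear::Ridge;
    /// # use scry_learn::dataset::Dataset;
    /// # let data = Dataset::new(
    /// #     vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
    /// #     vec![2.0, 4.0, 6.0, 8.0, 10.0],
    /// #     vec!["x".into()],
    /// #     "y",
    /// # );
    /// let mut model = Ridge::new(1.0);
    /// model.fit(&data).unwrap();
    ///
    /// let predictions = model.predict(&[vec![6.0]]).unwrap();
    /// assert_eq!(predictions.len(), 1);
    /// ```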
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        self.inner.predict(features)
    }

    /// Get the learned coefficients.
    pub fn coefficients(&self) -> &[f64] {
        self.inner.coefficients()
    }

    /// Get the learned intercept.
    pub fn intercept(&self) -> f64 {
        self.inner.intercept()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ridge_alias() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        // Ridge(1.0) should produce the same result as LinearRegression::new().alpha(1.0).
        let mut ridge = Ridge::new(1.0);
        ridge.fit(&data).unwrap();

        let mut lr = LinearRegression::new().alpha(1.0);
        lr.fit(&data).unwrap();

        assert!(
            (ridge.coefficients()[0] - lr.coefficients()[0]).abs() < 1e-10,
            "Ridge and LinearRegression(alpha=1.0) should produce identical coefficients"
        );
        assert!(
            (ridge.intercept() - lr.intercept()).abs() < 1e-10,
            "Ridge and LinearRegression(alpha=1.0) should produce identical intercepts"
        );

        // Sanity: coefficient should be shrunk below 2.0 (the OLS solution).
        assert!(ridge.coefficients()[0] < 2.0);
        assert!(ridge.coefficients()[0] > 1.0);
    }
}
129}