1use crate::dataset::Dataset;
9use crate::error::{Result, ScryLearnError};
10use crate::sparse::{CscMatrix, CsrMatrix};
11
/// L1-regularized (Lasso) linear regression fit by cyclic coordinate descent.
///
/// Configure with the builder-style setters (`alpha`, `max_iter`, `tol`),
/// then train with `fit` (dense) or `fit_sparse` (CSC).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LassoRegression {
    /// L1 penalty strength; must be >= 0 (validated at fit time, not here).
    alpha: f64,
    /// Upper bound on coordinate-descent sweeps over all features.
    max_iter: usize,
    /// Convergence threshold on the largest parameter change per sweep.
    tol: f64,
    /// Learned weights, one per feature column; empty until fitted.
    coefficients: Vec<f64>,
    /// Learned bias term; 0.0 until fitted.
    intercept: f64,
    /// Set by a successful fit; guards the `predict*` methods.
    fitted: bool,
    /// Serialization schema tag; checked in `predict` for compatibility.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
54
55impl LassoRegression {
56 pub fn new() -> Self {
58 Self {
59 alpha: 1.0,
60 max_iter: 1000,
61 tol: crate::constants::DEFAULT_TOL,
62 coefficients: Vec::new(),
63 intercept: 0.0,
64 fitted: false,
65 _schema_version: crate::version::SCHEMA_VERSION,
66 }
67 }
68
69 pub fn alpha(mut self, a: f64) -> Self {
71 self.alpha = a;
72 self
73 }
74
75 pub fn max_iter(mut self, n: usize) -> Self {
77 self.max_iter = n;
78 self
79 }
80
81 pub fn tol(mut self, t: f64) -> Self {
83 self.tol = t;
84 self
85 }
86
    /// Fits the model to a dense [`Dataset`] by cyclic coordinate descent,
    /// delegating to [`Self::fit_sparse`] when the dataset exposes a CSC view.
    ///
    /// Features are used as-is (no standardization); per-column scale is
    /// handled through the `col_norm_sq` denominators.
    ///
    /// # Errors
    /// - whatever `validate_finite` returns on non-finite values
    /// - `EmptyDataset` when there are zero samples
    /// - `InvalidParameter` when `alpha < 0`
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        // Prefer the sparse solver when a CSC representation already exists.
        if let Some(csc) = data.sparse_csc() {
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(
                "alpha must be >= 0".into(),
            ));
        }

        // NOTE(review): target length vs. n is validated in fit_sparse but not
        // here — presumably Dataset::new guarantees it; confirm.
        let y = &data.target;

        // Start from beta = 0 with the target mean as intercept (the exact
        // intercept for the all-zero coefficient vector).
        let mut beta = vec![0.0; m];
        let mut intercept = y.iter().sum::<f64>() / n as f64;

        // Precompute (1/n) * ||x_j||^2 per column, the denominator of each
        // coordinate update. The indexing below shows data.features is
        // column-major: data.features[j] is column j.
        let mut col_norm_sq: Vec<f64> = vec![0.0; m];
        for j in 0..m {
            let col = &data.features[j];
            let mut sq = 0.0;
            for &x in col {
                sq += x * x;
            }
            col_norm_sq[j] = sq / n as f64;
        }

        let n_f64 = n as f64;

        // Residuals r_i = y_i - prediction_i, maintained incrementally as
        // each parameter changes (never recomputed from scratch).
        let mut residuals: Vec<f64> = y.iter().map(|&yi| yi - intercept).collect();

        for _iter in 0..self.max_iter {
            let mut max_change = 0.0_f64;

            // Unpenalized intercept update: absorb the residual mean, then
            // re-center the residuals to match.
            let r_mean = residuals.iter().sum::<f64>() / n_f64;
            let new_intercept = intercept + r_mean;
            max_change = max_change.max((new_intercept - intercept).abs());
            for r in &mut residuals {
                *r -= r_mean;
            }
            intercept = new_intercept;

            for j in 0..m {
                // A (near-)constant-zero column has no signal and would
                // divide by ~0 below; leave its coefficient untouched.
                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
                    continue;
                }

                let old_beta_j = beta[j];
                let col = &data.features[j];

                // Add feature j's current contribution back into the
                // residuals, so rho below is the partial residual correlation.
                if old_beta_j != 0.0 {
                    for i in 0..n {
                        residuals[i] += col[i] * old_beta_j;
                    }
                }

                // rho = (1/n) <x_j, r_without_j>
                let mut rho = 0.0;
                for i in 0..n {
                    rho += col[i] * residuals[i];
                }
                rho /= n_f64;

                // Closed-form coordinate minimizer of the lasso objective.
                let new_beta_j = soft_threshold(rho, self.alpha) / col_norm_sq[j];
                max_change = max_change.max((new_beta_j - old_beta_j).abs());
                beta[j] = new_beta_j;

                // Subtract the updated contribution to restore residuals.
                if new_beta_j != 0.0 {
                    for i in 0..n {
                        residuals[i] -= col[i] * new_beta_j;
                    }
                }
            }

            // Converged: no parameter (intercept included) moved more than tol.
            if max_change < self.tol {
                break;
            }
        }

        self.coefficients = beta;
        self.intercept = intercept;
        self.fitted = true;
        Ok(())
    }
184
185 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
189 crate::version::check_schema_version(self._schema_version)?;
190 if !self.fitted {
191 return Err(ScryLearnError::NotFitted);
192 }
193 Ok(features
194 .iter()
195 .map(|row| {
196 row.iter()
197 .zip(self.coefficients.iter())
198 .map(|(x, b)| x * b)
199 .sum::<f64>()
200 + self.intercept
201 })
202 .collect())
203 }
204
    /// Fits the model to a sparse CSC feature matrix by cyclic coordinate
    /// descent; same algorithm as [`Self::fit`], iterating only the stored
    /// non-zeros of each column.
    ///
    /// NOTE(review): unlike the dense path, this does not validate that the
    /// inputs are finite — NaN/inf entries would propagate silently; confirm
    /// whether callers guarantee finiteness.
    ///
    /// # Errors
    /// - `EmptyDataset` when `features` has zero rows
    /// - `InvalidParameter` when `target.len() != n_rows` or `alpha < 0`
    #[allow(clippy::needless_range_loop)]
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }
        if self.alpha < 0.0 {
            return Err(ScryLearnError::InvalidParameter(
                "alpha must be >= 0".into(),
            ));
        }

        let n_f64 = n as f64;
        // beta = 0 with the target mean as the starting intercept.
        let mut beta = vec![0.0; m];
        let mut intercept = target.iter().sum::<f64>() / n_f64;

        // (1/n) * ||x_j||^2 per column; zeros outside the stored entries
        // contribute nothing, so summing the non-zeros is exact.
        let mut col_norm_sq = vec![0.0; m];
        for j in 0..m {
            let mut sq_sum = 0.0;
            for (_, val) in features.col(j).iter() {
                sq_sum += val * val;
            }
            col_norm_sq[j] = sq_sum / n_f64;
        }

        // Residuals r_i = y_i - prediction_i, maintained incrementally.
        let mut residuals: Vec<f64> = target.iter().map(|&y| y - intercept).collect();

        for _iter in 0..self.max_iter {
            let mut max_change = 0.0_f64;

            // Unpenalized intercept update: absorb the residual mean, then
            // re-center the residuals.
            let r_mean = residuals.iter().sum::<f64>() / n_f64;
            let new_intercept = intercept + r_mean;
            max_change = max_change.max((new_intercept - intercept).abs());
            for r in &mut residuals {
                *r -= r_mean;
            }
            intercept = new_intercept;

            for j in 0..m {
                // Skip (near-)empty columns: no signal, and the division
                // below would blow up.
                if col_norm_sq[j] < crate::constants::NEAR_ZERO {
                    continue;
                }

                let old_beta_j = beta[j];

                // Add column j's current contribution back (non-zeros only),
                // so rho is computed against the partial residual.
                if old_beta_j != 0.0 {
                    for (row_idx, val) in features.col(j).iter() {
                        residuals[row_idx] += val * old_beta_j;
                    }
                }

                // rho = (1/n) <x_j, r_without_j> over stored entries.
                let mut rho = 0.0;
                for (row_idx, val) in features.col(j).iter() {
                    rho += val * residuals[row_idx];
                }
                rho /= n_f64;

                // Closed-form coordinate minimizer of the lasso objective.
                let new_beta_j = soft_threshold(rho, self.alpha) / col_norm_sq[j];
                max_change = max_change.max((new_beta_j - old_beta_j).abs());
                beta[j] = new_beta_j;

                // Subtract the updated contribution to restore residuals.
                if new_beta_j != 0.0 {
                    for (row_idx, val) in features.col(j).iter() {
                        residuals[row_idx] -= val * new_beta_j;
                    }
                }
            }

            // Converged: largest parameter movement this sweep is below tol.
            if max_change < self.tol {
                break;
            }
        }

        self.coefficients = beta;
        self.intercept = intercept;
        self.fitted = true;
        Ok(())
    }
301
302 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
304 if !self.fitted {
305 return Err(ScryLearnError::NotFitted);
306 }
307 Ok((0..features.n_rows())
308 .map(|i| {
309 let mut y = self.intercept;
310 for (col, val) in features.row(i).iter() {
311 if col < self.coefficients.len() {
312 y += self.coefficients[col] * val;
313 }
314 }
315 y
316 })
317 .collect())
318 }
319
320 pub fn coefficients(&self) -> &[f64] {
322 &self.coefficients
323 }
324
    /// Returns the fitted intercept (bias) term; 0.0 before fitting.
    pub fn intercept(&self) -> f64 {
        self.intercept
    }
329}
330
331impl Default for LassoRegression {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
/// Scalar soft-thresholding operator: shrinks `z` toward zero by `gamma`,
/// mapping everything in `[-gamma, gamma]` to exactly 0.0.
#[inline]
fn soft_threshold(z: f64, gamma: f64) -> f64 {
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense fit on a perfectly linear single-feature problem (y = 2x + 1);
    /// with a tiny alpha the prediction at x = 3 should land near 7.
    #[test]
    fn test_lasso_fit_predict() {
        // Dataset::new takes features column-major: one inner Vec per column.
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lasso = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso.fit(&data).unwrap();

        let preds = lasso.predict(&[vec![3.0]]).unwrap();
        assert!(
            (preds[0] - 7.0).abs() < 0.5,
            "expected ~7.0, got {}",
            preds[0]
        );
    }

    /// The L1 penalty should zero out the two irrelevant features (x2, x4)
    /// while keeping the informative ones (x1, x3) clearly non-zero.
    #[test]
    fn test_lasso_sparsity() {
        let n = 100;
        // Seeded RNG for a deterministic fixture.
        let mut rng = crate::rng::FastRng::new(42);
        let mut x1 = Vec::with_capacity(n);
        let mut x2 = Vec::with_capacity(n);
        let mut x3 = Vec::with_capacity(n);
        let mut x4 = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);

        for _ in 0..n {
            let v1 = rng.f64() * 10.0;
            let v2 = rng.f64() * 10.0;
            let v3 = rng.f64() * 10.0;
            let v4 = rng.f64() * 10.0;
            x1.push(v1);
            x2.push(v2);
            x3.push(v3);
            x4.push(v4);
            // Ground truth depends only on x1 and x3.
            y.push(2.0 * v1 + 3.0 * v3 + 1.0);
        }

        let data = Dataset::new(
            vec![x1, x2, x3, x4],
            y,
            vec!["x1".into(), "x2".into(), "x3".into(), "x4".into()],
            "y",
        );

        let mut lasso = LassoRegression::new().alpha(0.5).max_iter(5000);
        lasso.fit(&data).unwrap();

        let coefs = lasso.coefficients();
        assert!(
            coefs[1].abs() < 0.1,
            "x2 coef should be ~0, got {}",
            coefs[1]
        );
        assert!(
            coefs[3].abs() < 0.1,
            "x4 coef should be ~0, got {}",
            coefs[3]
        );
        assert!(coefs[0].abs() > 0.5, "x1 coef should be significant");
        assert!(coefs[2].abs() > 0.5, "x3 coef should be significant");
    }

    /// Predicting before fitting must fail (NotFitted).
    #[test]
    fn test_lasso_not_fitted() {
        let lasso = LassoRegression::new();
        assert!(lasso.predict(&[vec![1.0]]).is_err());
    }

    /// The sparse solver on a CSC copy of a dense dataset should agree with
    /// the dense solver on both coefficients and predictions.
    #[test]
    fn test_sparse_lasso_matches_dense() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(features.clone(), target.clone(), vec!["x".into()], "y");

        let mut lasso_dense = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let mut lasso_sparse = LassoRegression::new().alpha(0.01).max_iter(5000);
        lasso_sparse.fit_sparse(&csc, &target).unwrap();

        assert!(
            (lasso_dense.coefficients()[0] - lasso_sparse.coefficients()[0]).abs() < 0.1,
            "Dense={} vs Sparse={}",
            lasso_dense.coefficients()[0],
            lasso_sparse.coefficients()[0]
        );

        let test = vec![vec![3.0]];
        let csr = CsrMatrix::from_dense(&test);
        let pred_d = lasso_dense.predict(&test).unwrap()[0];
        let pred_s = lasso_sparse.predict_sparse(&csr).unwrap()[0];
        assert!(
            (pred_d - pred_s).abs() < 0.5,
            "Dense pred={pred_d} vs Sparse pred={pred_s}"
        );
    }
}