1use crate::dataset::Dataset;
10use crate::error::{Result, ScryLearnError};
11use crate::sparse::{CscMatrix, CsrMatrix};
12
/// Elastic-net linear regression: least squares with a mixed L1/L2 penalty,
/// fitted by cyclic coordinate descent (see `fit` / `fit_sparse`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct ElasticNet {
 /// Overall regularization strength; must be >= 0 (checked at fit time).
 alpha: f64,
 /// Mix between penalties in [0, 1]: 1.0 = pure lasso (L1), 0.0 = pure ridge (L2).
 l1_ratio: f64,
 /// Maximum number of coordinate-descent sweeps over all features.
 max_iter: usize,
 /// Convergence tolerance on the largest parameter change within a sweep.
 tol: f64,
 /// Fitted weights, one per feature; empty until a successful fit.
 coefficients: Vec<f64>,
 /// Fitted bias term; 0.0 until a successful fit.
 intercept: f64,
 /// Set to true by a successful fit; `predict` errors while false.
 fitted: bool,
 /// Schema version stamped at construction; checked against the running
 /// library version in `predict`. `serde(default)` so old payloads load.
 #[cfg_attr(feature = "serde", serde(default))]
 _schema_version: u32,
}
55
56impl ElasticNet {
57 pub fn new() -> Self {
59 Self {
60 alpha: 1.0,
61 l1_ratio: 0.5,
62 max_iter: 1000,
63 tol: crate::constants::DEFAULT_TOL,
64 coefficients: Vec::new(),
65 intercept: 0.0,
66 fitted: false,
67 _schema_version: crate::version::SCHEMA_VERSION,
68 }
69 }
70
71 pub fn alpha(mut self, a: f64) -> Self {
73 self.alpha = a;
74 self
75 }
76
77 pub fn l1_ratio(mut self, r: f64) -> Self {
79 self.l1_ratio = r;
80 self
81 }
82
83 pub fn max_iter(mut self, n: usize) -> Self {
85 self.max_iter = n;
86 self
87 }
88
89 pub fn tol(mut self, t: f64) -> Self {
91 self.tol = t;
92 self
93 }
94
95 pub fn fit(&mut self, data: &Dataset) -> Result<()> {
97 data.validate_finite()?;
98 if let Some(csc) = data.sparse_csc() {
99 return self.fit_sparse(csc, &data.target);
100 }
101 let n = data.n_samples();
102 let m = data.n_features();
103 if n == 0 {
104 return Err(ScryLearnError::EmptyDataset);
105 }
106 if self.alpha < 0.0 {
107 return Err(ScryLearnError::InvalidParameter(
108 "alpha must be >= 0".into(),
109 ));
110 }
111 if !(0.0..=1.0).contains(&self.l1_ratio) {
112 return Err(ScryLearnError::InvalidParameter(
113 "l1_ratio must be in [0, 1]".into(),
114 ));
115 }
116
117 let y = &data.target;
118
119 let mut beta = vec![0.0; m];
120 let mut intercept = y.iter().sum::<f64>() / n as f64;
121
122 let mut col_norm_sq: Vec<f64> = vec![0.0; m];
124 for j in 0..m {
125 let col = &data.features[j];
126 let mut sq = 0.0;
127 for &x in col {
128 sq += x * x;
129 }
130 col_norm_sq[j] = sq / n as f64;
131 }
132
133 let n_f64 = n as f64;
134 let l1_pen = self.alpha * self.l1_ratio;
135 let l2_pen = self.alpha * (1.0 - self.l1_ratio);
136
137 let mut residuals: Vec<f64> = y.iter().map(|&yi| yi - intercept).collect();
139
140 for _iter in 0..self.max_iter {
141 let mut max_change = 0.0_f64;
142
143 let r_mean = residuals.iter().sum::<f64>() / n_f64;
145 let new_intercept = intercept + r_mean;
146 max_change = max_change.max((new_intercept - intercept).abs());
147 for r in &mut residuals {
148 *r -= r_mean;
149 }
150 intercept = new_intercept;
151
152 for j in 0..m {
154 if col_norm_sq[j] < crate::constants::NEAR_ZERO {
155 continue;
156 }
157
158 let old_beta_j = beta[j];
159 let col = &data.features[j];
160
161 if old_beta_j != 0.0 {
163 for i in 0..n {
164 residuals[i] += col[i] * old_beta_j;
165 }
166 }
167
168 let mut rho = 0.0;
170 for i in 0..n {
171 rho += col[i] * residuals[i];
172 }
173 rho /= n_f64;
174
175 let new_beta_j = soft_threshold(rho, l1_pen) / (col_norm_sq[j] + l2_pen);
177 max_change = max_change.max((new_beta_j - old_beta_j).abs());
178 beta[j] = new_beta_j;
179
180 if new_beta_j != 0.0 {
182 for i in 0..n {
183 residuals[i] -= col[i] * new_beta_j;
184 }
185 }
186 }
187
188 if max_change < self.tol {
189 break;
190 }
191 }
192
193 self.coefficients = beta;
194 self.intercept = intercept;
195 self.fitted = true;
196 Ok(())
197 }
198
199 pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
203 crate::version::check_schema_version(self._schema_version)?;
204 if !self.fitted {
205 return Err(ScryLearnError::NotFitted);
206 }
207 Ok(features
208 .iter()
209 .map(|row| {
210 row.iter()
211 .zip(self.coefficients.iter())
212 .map(|(x, b)| x * b)
213 .sum::<f64>()
214 + self.intercept
215 })
216 .collect())
217 }
218
219 #[allow(clippy::needless_range_loop)]
221 pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
222 let n = features.n_rows();
223 let m = features.n_cols();
224 if n == 0 {
225 return Err(ScryLearnError::EmptyDataset);
226 }
227 if target.len() != n {
228 return Err(ScryLearnError::InvalidParameter(format!(
229 "target length {} != n_rows {}",
230 target.len(),
231 n
232 )));
233 }
234 if self.alpha < 0.0 {
235 return Err(ScryLearnError::InvalidParameter(
236 "alpha must be >= 0".into(),
237 ));
238 }
239 if !(0.0..=1.0).contains(&self.l1_ratio) {
240 return Err(ScryLearnError::InvalidParameter(
241 "l1_ratio must be in [0, 1]".into(),
242 ));
243 }
244
245 let n_f64 = n as f64;
246 let l1_pen = self.alpha * self.l1_ratio;
247 let l2_pen = self.alpha * (1.0 - self.l1_ratio);
248
249 let mut beta = vec![0.0; m];
250 let mut intercept = target.iter().sum::<f64>() / n_f64;
251
252 let mut col_norm_sq = vec![0.0; m];
253 for j in 0..m {
254 let mut sq_sum = 0.0;
255 for (_, val) in features.col(j).iter() {
256 sq_sum += val * val;
257 }
258 col_norm_sq[j] = sq_sum / n_f64;
259 }
260
261 let mut residuals: Vec<f64> = target.iter().map(|&y| y - intercept).collect();
262
263 for _iter in 0..self.max_iter {
264 let mut max_change = 0.0_f64;
265
266 let r_mean = residuals.iter().sum::<f64>() / n_f64;
267 let new_intercept = intercept + r_mean;
268 max_change = max_change.max((new_intercept - intercept).abs());
269 for r in &mut residuals {
270 *r -= r_mean;
271 }
272 intercept = new_intercept;
273
274 for j in 0..m {
275 if col_norm_sq[j] < crate::constants::NEAR_ZERO {
276 continue;
277 }
278
279 let old_beta_j = beta[j];
280
281 if old_beta_j != 0.0 {
282 for (row_idx, val) in features.col(j).iter() {
283 residuals[row_idx] += val * old_beta_j;
284 }
285 }
286
287 let mut rho = 0.0;
288 for (row_idx, val) in features.col(j).iter() {
289 rho += val * residuals[row_idx];
290 }
291 rho /= n_f64;
292
293 let new_beta_j = soft_threshold(rho, l1_pen) / (col_norm_sq[j] + l2_pen);
294 max_change = max_change.max((new_beta_j - old_beta_j).abs());
295 beta[j] = new_beta_j;
296
297 if new_beta_j != 0.0 {
298 for (row_idx, val) in features.col(j).iter() {
299 residuals[row_idx] -= val * new_beta_j;
300 }
301 }
302 }
303
304 if max_change < self.tol {
305 break;
306 }
307 }
308
309 self.coefficients = beta;
310 self.intercept = intercept;
311 self.fitted = true;
312 Ok(())
313 }
314
315 pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
317 if !self.fitted {
318 return Err(ScryLearnError::NotFitted);
319 }
320 Ok((0..features.n_rows())
321 .map(|i| {
322 let mut y = self.intercept;
323 for (col, val) in features.row(i).iter() {
324 if col < self.coefficients.len() {
325 y += self.coefficients[col] * val;
326 }
327 }
328 y
329 })
330 .collect())
331 }
332
333 pub fn coefficients(&self) -> &[f64] {
335 &self.coefficients
336 }
337
338 pub fn intercept(&self) -> f64 {
340 self.intercept
341 }
342}
343
344impl Default for ElasticNet {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
/// Soft-thresholding operator: shrinks `z` toward zero by `gamma`, clamping
/// anything inside `(-gamma, gamma]` (inclusive of both bounds) to 0.0.
#[inline]
fn soft_threshold(z: f64, gamma: f64) -> f64 {
    match z {
        v if v > gamma => v - gamma,
        v if v < -gamma => v + gamma,
        _ => 0.0,
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elastic_net_fit_predict() {
        // Single feature, y = 2x + 1: a lightly regularized fit should
        // predict near the true line at x = 3.
        let cols = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let ys = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(cols, ys, vec!["x".into()], "y");

        let mut model = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        model.fit(&data).unwrap();

        let preds = model.predict(&[vec![3.0]]).unwrap();
        let err = (preds[0] - 7.0).abs();
        assert!(err < 0.5, "expected ~7.0, got {}", preds[0]);
    }

    #[test]
    fn test_elastic_net_pure_lasso() {
        // x2 is pure noise (y depends only on x1); pure lasso (l1_ratio = 1)
        // should drive its coefficient to essentially zero.
        let n = 50;
        let mut rng = crate::rng::FastRng::new(42);
        let mut x1 = Vec::with_capacity(n);
        let mut x2 = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);
        for _ in 0..n {
            // Draw order matters for reproducibility: v1 first, then v2.
            let v1 = rng.f64() * 10.0;
            let v2 = rng.f64() * 10.0;
            x1.push(v1);
            x2.push(v2);
            y.push(3.0 * v1 + 1.0);
        }

        let data = Dataset::new(vec![x1, x2], y, vec!["x1".into(), "x2".into()], "y");

        let mut model = ElasticNet::new().alpha(0.5).l1_ratio(1.0).max_iter(5000);
        model.fit(&data).unwrap();

        let noise_coef = model.coefficients()[1];
        assert!(
            noise_coef.abs() < 0.1,
            "x2 should be ~0 with pure lasso, got {}",
            noise_coef
        );
    }

    #[test]
    fn test_elastic_net_not_fitted() {
        // Predicting before any fit must error.
        let model = ElasticNet::new();
        assert!(model.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_sparse_elastic_net_matches_dense() {
        // The CSC path should converge to (nearly) the same coefficient as
        // the dense path on identical data.
        let cols = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let ys = vec![3.0, 5.0, 7.0, 9.0, 11.0];
        let data = Dataset::new(cols.clone(), ys.clone(), vec!["x".into()], "y");

        let mut dense_model = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        dense_model.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&cols);
        let mut sparse_model = ElasticNet::new().alpha(0.01).l1_ratio(0.5).max_iter(5000);
        sparse_model.fit_sparse(&csc, &ys).unwrap();

        let (d, s) = (dense_model.coefficients()[0], sparse_model.coefficients()[0]);
        assert!((d - s).abs() < 0.1, "Dense={} vs Sparse={}", d, s);
    }
}