1#![allow(clippy::needless_range_loop)]
2use std::fmt::Debug;
58use std::marker::PhantomData;
59
60#[cfg(feature = "serde")]
61use serde::{Deserialize, Serialize};
62
63use crate::api::{Predictor, SupervisedEstimator};
64use crate::error::Failed;
65use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
66use crate::numbers::basenum::Number;
67use crate::numbers::floatnum::FloatNumber;
68use crate::numbers::realnum::RealNumber;
69
70use crate::linear::lasso_optimizer::InteriorPointOptimizer;
71
/// Hyper-parameters accepted by `ElasticNet::fit`.
///
/// With the `serde` feature enabled every field carries `serde(default)`, so a field that is
/// absent from the input deserializes to its *type's* default (`0.0`, `false`, `0`) rather than
/// to the value produced by `ElasticNetParameters::default()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct ElasticNetParameters {
    /// Overall regularization strength. `fit` turns it into an L1 penalty of
    /// `alpha * l1_ratio * n` and an L2 penalty of `alpha * (1 - l1_ratio) * n`, where `n` is
    /// the number of rows of the training matrix.
    #[cfg_attr(feature = "serde", serde(default))]
    pub alpha: f64,
    /// Share of `alpha` assigned to the L1 penalty; the remainder goes to the L2 penalty.
    #[cfg_attr(feature = "serde", serde(default))]
    pub l1_ratio: f64,
    /// When `true`, `fit` rescales the columns of the training matrix with their per-column
    /// mean and standard deviation before solving.
    #[cfg_attr(feature = "serde", serde(default))]
    pub normalize: bool,
    /// Tolerance handed to the optimizer.
    #[cfg_attr(feature = "serde", serde(default))]
    pub tol: f64,
    /// Maximum number of iterations handed to the optimizer.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_iter: usize,
}
94
95#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
97#[derive(Debug)]
98pub struct ElasticNet<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
99 coefficients: Option<X>,
100 intercept: Option<TX>,
101 _phantom_ty: PhantomData<TY>,
102 _phantom_y: PhantomData<Y>,
103}
104
105impl ElasticNetParameters {
106 pub fn with_alpha(mut self, alpha: f64) -> Self {
108 self.alpha = alpha;
109 self
110 }
111 pub fn with_l1_ratio(mut self, l1_ratio: f64) -> Self {
115 self.l1_ratio = l1_ratio;
116 self
117 }
118 pub fn with_normalize(mut self, normalize: bool) -> Self {
120 self.normalize = normalize;
121 self
122 }
123 pub fn with_tol(mut self, tol: f64) -> Self {
125 self.tol = tol;
126 self
127 }
128 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
130 self.max_iter = max_iter;
131 self
132 }
133}
134
135impl Default for ElasticNetParameters {
136 fn default() -> Self {
137 ElasticNetParameters {
138 alpha: 1.0,
139 l1_ratio: 0.5,
140 normalize: true,
141 tol: 1e-4,
142 max_iter: 1000,
143 }
144 }
145}
146
/// Grid of candidate values for each `ElasticNetParameters` field.
///
/// Converting the grid with `into_iter()` yields `ElasticNetParameters` built from the listed
/// values. With the `serde` feature enabled every field carries `serde(default)`, so a field
/// that is absent from the input deserializes to an empty vector.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct ElasticNetSearchParameters {
    /// Candidate values for `ElasticNetParameters::alpha`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub alpha: Vec<f64>,
    /// Candidate values for `ElasticNetParameters::l1_ratio`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub l1_ratio: Vec<f64>,
    /// Candidate values for `ElasticNetParameters::normalize`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub normalize: Vec<bool>,
    /// Candidate values for `ElasticNetParameters::tol`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub tol: Vec<f64>,
    /// Candidate values for `ElasticNetParameters::max_iter`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_iter: Vec<usize>,
}
169
170pub struct ElasticNetSearchParametersIterator {
172 lasso_regression_search_parameters: ElasticNetSearchParameters,
173 current_alpha: usize,
174 current_l1_ratio: usize,
175 current_normalize: usize,
176 current_tol: usize,
177 current_max_iter: usize,
178}
179
180impl IntoIterator for ElasticNetSearchParameters {
181 type Item = ElasticNetParameters;
182 type IntoIter = ElasticNetSearchParametersIterator;
183
184 fn into_iter(self) -> Self::IntoIter {
185 ElasticNetSearchParametersIterator {
186 lasso_regression_search_parameters: self,
187 current_alpha: 0,
188 current_l1_ratio: 0,
189 current_normalize: 0,
190 current_tol: 0,
191 current_max_iter: 0,
192 }
193 }
194}
195
196impl Iterator for ElasticNetSearchParametersIterator {
197 type Item = ElasticNetParameters;
198
199 fn next(&mut self) -> Option<Self::Item> {
200 if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
201 && self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
202 && self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
203 && self.current_tol == self.lasso_regression_search_parameters.tol.len()
204 && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
205 {
206 return None;
207 }
208
209 let next = ElasticNetParameters {
210 alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
211 l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
212 normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
213 tol: self.lasso_regression_search_parameters.tol[self.current_tol],
214 max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
215 };
216
217 if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
218 self.current_alpha += 1;
219 } else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
220 {
221 self.current_alpha = 0;
222 self.current_l1_ratio += 1;
223 } else if self.current_normalize + 1
224 < self.lasso_regression_search_parameters.normalize.len()
225 {
226 self.current_alpha = 0;
227 self.current_l1_ratio = 0;
228 self.current_normalize += 1;
229 } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
230 self.current_alpha = 0;
231 self.current_l1_ratio = 0;
232 self.current_normalize = 0;
233 self.current_tol += 1;
234 } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
235 {
236 self.current_alpha = 0;
237 self.current_l1_ratio = 0;
238 self.current_normalize = 0;
239 self.current_tol = 0;
240 self.current_max_iter += 1;
241 } else {
242 self.current_alpha += 1;
243 self.current_l1_ratio += 1;
244 self.current_normalize += 1;
245 self.current_tol += 1;
246 self.current_max_iter += 1;
247 }
248
249 Some(next)
250 }
251}
252
253impl Default for ElasticNetSearchParameters {
254 fn default() -> Self {
255 let default_params = ElasticNetParameters::default();
256
257 ElasticNetSearchParameters {
258 alpha: vec![default_params.alpha],
259 l1_ratio: vec![default_params.l1_ratio],
260 normalize: vec![default_params.normalize],
261 tol: vec![default_params.tol],
262 max_iter: vec![default_params.max_iter],
263 }
264 }
265}
266
267impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
268 for ElasticNet<TX, TY, X, Y>
269{
270 fn eq(&self, other: &Self) -> bool {
271 if self.intercept() != other.intercept() {
272 return false;
273 }
274 if self.coefficients().shape() != other.coefficients().shape() {
275 return false;
276 }
277 self.coefficients()
278 .iterator(0)
279 .zip(other.coefficients().iterator(0))
280 .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
281 }
282}
283
284impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
285 SupervisedEstimator<X, Y, ElasticNetParameters> for ElasticNet<TX, TY, X, Y>
286{
287 fn new() -> Self {
288 Self {
289 coefficients: Option::None,
290 intercept: Option::None,
291 _phantom_ty: PhantomData,
292 _phantom_y: PhantomData,
293 }
294 }
295
296 fn fit(x: &X, y: &Y, parameters: ElasticNetParameters) -> Result<Self, Failed> {
297 ElasticNet::fit(x, y, parameters)
298 }
299}
300
301impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
302 for ElasticNet<TX, TY, X, Y>
303{
304 fn predict(&self, x: &X) -> Result<Y, Failed> {
305 self.predict(x)
306 }
307}
308
309impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
310 ElasticNet<TX, TY, X, Y>
311{
312 pub fn fit(
317 x: &X,
318 y: &Y,
319 parameters: ElasticNetParameters,
320 ) -> Result<ElasticNet<TX, TY, X, Y>, Failed> {
321 let (n, p) = x.shape();
322
323 if y.shape() != n {
324 return Err(Failed::fit("Number of rows in X should = len(y)"));
325 }
326
327 let n_float = n as f64;
328
329 let l1_reg = TX::from_f64(parameters.alpha * parameters.l1_ratio * n_float).unwrap();
330 let l2_reg =
331 TX::from_f64(parameters.alpha * (1.0 - parameters.l1_ratio) * n_float).unwrap();
332
333 let y_mean = TX::from_f64(y.mean_by()).unwrap();
334
335 let (w, b) = if parameters.normalize {
336 let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
337
338 let (x, y, gamma) = Self::augment_x_and_y(&scaled_x, y, l2_reg);
339
340 let mut optimizer = InteriorPointOptimizer::new(&x, p);
341
342 let mut w = optimizer.optimize(
343 &x,
344 &y,
345 l1_reg * gamma,
346 parameters.max_iter,
347 TX::from_f64(parameters.tol).unwrap(),
348 )?;
349
350 for i in 0..p {
351 w.set(i, gamma * *w.get(i) / col_std[i]);
352 }
353
354 let mut b = TX::zero();
355
356 for i in 0..p {
357 b += *w.get(i) * col_mean[i];
358 }
359
360 b = y_mean - b;
361
362 (X::from_column(&w), b)
363 } else {
364 let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);
365
366 let mut optimizer = InteriorPointOptimizer::new(&x, p);
367
368 let mut w = optimizer.optimize(
369 &x,
370 &y,
371 l1_reg * gamma,
372 parameters.max_iter,
373 TX::from_f64(parameters.tol).unwrap(),
374 )?;
375
376 for i in 0..p {
377 w.set(i, gamma * *w.get(i));
378 }
379
380 (X::from_column(&w), y_mean)
381 };
382
383 Ok(ElasticNet {
384 intercept: Some(b),
385 coefficients: Some(w),
386 _phantom_ty: PhantomData,
387 _phantom_y: PhantomData,
388 })
389 }
390
391 pub fn predict(&self, x: &X) -> Result<Y, Failed> {
394 let (nrows, _) = x.shape();
395 let mut y_hat = x.matmul(self.coefficients.as_ref().unwrap());
396 let bias = X::fill(nrows, 1, self.intercept.unwrap());
397 y_hat.add_mut(&bias);
398 Ok(Y::from_iterator(
399 y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
400 nrows,
401 ))
402 }
403
404 pub fn coefficients(&self) -> &X {
406 self.coefficients.as_ref().unwrap()
407 }
408
409 pub fn intercept(&self) -> &TX {
411 self.intercept.as_ref().unwrap()
412 }
413
414 fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
415 let col_mean: Vec<TX> = x
416 .mean_by(0)
417 .iter()
418 .map(|&v| TX::from_f64(v).unwrap())
419 .collect();
420 let col_std: Vec<TX> = x
421 .std_dev(0)
422 .iter()
423 .map(|&v| TX::from_f64(v).unwrap())
424 .collect();
425
426 for (i, col_std_i) in col_std.iter().enumerate() {
427 if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
428 return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
429 }
430 }
431
432 let mut scaled_x = x.clone();
433 scaled_x.scale_mut(&col_mean, &col_std, 0);
434 Ok((scaled_x, col_mean, col_std))
435 }
436
437 fn augment_x_and_y(x: &X, y: &Y, l2_reg: TX) -> (X, Vec<TX>, TX) {
438 let (n, p) = x.shape();
439
440 let gamma = TX::one() / (TX::one() + l2_reg).sqrt();
441 let padding = gamma * l2_reg.sqrt();
442
443 let mut y2 = Vec::<TX>::zeros(n + p);
444 for i in 0..y.shape() {
445 y2.set(i, TX::from(*y.get(i)).unwrap());
446 }
447
448 let mut x2 = X::zeros(n + p, p);
449
450 for j in 0..p {
451 for i in 0..n {
452 x2.set((i, j), gamma * *x.get((i, j)));
453 }
454
455 x2.set((j + n, j), padding);
456 }
457
458 (x2, y2, gamma)
459 }
460}
461
462#[cfg(test)]
463mod tests {
464 use super::*;
465 use crate::linalg::basic::matrix::DenseMatrix;
466 use crate::metrics::mean_absolute_error;
467
468 #[test]
469 fn search_parameters() {
470 let parameters = ElasticNetSearchParameters {
471 alpha: vec![0., 1.],
472 max_iter: vec![10, 100],
473 ..Default::default()
474 };
475 let mut iter = parameters.into_iter();
476 let next = iter.next().unwrap();
477 assert_eq!(next.alpha, 0.);
478 assert_eq!(next.max_iter, 10);
479 let next = iter.next().unwrap();
480 assert_eq!(next.alpha, 1.);
481 assert_eq!(next.max_iter, 10);
482 let next = iter.next().unwrap();
483 assert_eq!(next.alpha, 0.);
484 assert_eq!(next.max_iter, 100);
485 let next = iter.next().unwrap();
486 assert_eq!(next.alpha, 1.);
487 assert_eq!(next.max_iter, 100);
488 assert!(iter.next().is_none());
489 }
490
491 #[cfg_attr(
492 all(target_arch = "wasm32", not(target_os = "wasi")),
493 wasm_bindgen_test::wasm_bindgen_test
494 )]
495 #[test]
496 fn elasticnet_longley() {
497 let x = DenseMatrix::from_2d_array(&[
498 &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
499 &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
500 &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
501 &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
502 &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
503 &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
504 &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
505 &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
506 &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
507 &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
508 &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
509 &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
510 &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
511 &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
512 &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
513 &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
514 ])
515 .unwrap();
516
517 let y: Vec<f64> = vec![
518 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
519 114.2, 115.7, 116.9,
520 ];
521
522 let y_hat = ElasticNet::fit(
523 &x,
524 &y,
525 ElasticNetParameters {
526 alpha: 1.0,
527 l1_ratio: 0.5,
528 normalize: false,
529 tol: 1e-4,
530 max_iter: 1000,
531 },
532 )
533 .and_then(|lr| lr.predict(&x))
534 .unwrap();
535
536 assert!(mean_absolute_error(&y_hat, &y) < 30.0);
537 }
538
539 #[cfg_attr(
540 all(target_arch = "wasm32", not(target_os = "wasi")),
541 wasm_bindgen_test::wasm_bindgen_test
542 )]
543 #[test]
544 fn elasticnet_fit_predict1() {
545 let x = DenseMatrix::from_2d_array(&[
546 &[0.0, 1931.0, 1.2232755825400514],
547 &[1.0, 1933.0, 1.1379726120972395],
548 &[2.0, 1920.0, 1.4366265120543429],
549 &[3.0, 1918.0, 1.206005737827858],
550 &[4.0, 1934.0, 1.436613542400669],
551 &[5.0, 1918.0, 1.1594588621640636],
552 &[6.0, 1933.0, 1.19809994745985],
553 &[7.0, 1918.0, 1.3396363871645678],
554 &[8.0, 1931.0, 1.2535342096493207],
555 &[9.0, 1933.0, 1.3101281563456293],
556 &[10.0, 1922.0, 1.3585833349920762],
557 &[11.0, 1930.0, 1.4830786699709897],
558 &[12.0, 1916.0, 1.4919891143094546],
559 &[13.0, 1915.0, 1.259655137451551],
560 &[14.0, 1932.0, 1.3979191428724789],
561 &[15.0, 1917.0, 1.3686634746782371],
562 &[16.0, 1932.0, 1.381658454569724],
563 &[17.0, 1918.0, 1.4054969025700674],
564 &[18.0, 1929.0, 1.3271699396384906],
565 &[19.0, 1915.0, 1.1373332337674806],
566 ])
567 .unwrap();
568
569 let y: Vec<f64> = vec![
570 1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
571 10.2, 7.92, 7.62, 8.06, 9.06, 9.29,
572 ];
573
574 let l1_model = ElasticNet::fit(
575 &x,
576 &y,
577 ElasticNetParameters {
578 alpha: 1.0,
579 l1_ratio: 1.0,
580 normalize: true,
581 tol: 1e-4,
582 max_iter: 1000,
583 },
584 )
585 .unwrap();
586
587 let l2_model = ElasticNet::fit(
588 &x,
589 &y,
590 ElasticNetParameters {
591 alpha: 1.0,
592 l1_ratio: 0.0,
593 normalize: true,
594 tol: 1e-4,
595 max_iter: 1000,
596 },
597 )
598 .unwrap();
599
600 let mae_l1 = mean_absolute_error(&l1_model.predict(&x).unwrap(), &y);
601 let mae_l2 = mean_absolute_error(&l2_model.predict(&x).unwrap(), &y);
602
603 assert!(mae_l1 < 2.0);
604 assert!(mae_l2 < 2.0);
605
606 assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((1, 0)));
607 assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((2, 0)));
608 }
609
610 }