smartcore/linear/elastic_net.rs

#![allow(clippy::needless_range_loop)]
//! # Elastic Net
//!
//! Elastic net is an extension of [linear regression](../linear_regression/index.html) that adds regularization penalties to the loss function during training.
//! Just like in ordinary linear regression, you assume a linear relationship between the input variables and the target variable;
//! unlike ordinary linear regression, however, the loss function is penalized during training.
//! In particular, the elastic net coefficient estimates \\(\beta\\) are the values that minimize
//!
//! \\[L(\alpha, \beta) = \vert \boldsymbol{y} - \boldsymbol{X}\beta\vert^2 + \lambda_1 \vert \beta \vert_1 + \lambda_2 \vert \beta \vert^2\\]
//!
//! where \\(\lambda_1 = \alpha l_{1r}\\), \\(\lambda_2 = \alpha (1 - l_{1r})\\) and \\(l_{1r}\\) is the l1 ratio, the elastic net mixing parameter.
//! Setting \\(l_{1r} = 1\\) recovers the lasso penalty, and \\(l_{1r} = 0\\) the ridge penalty.
//!
//! In essence, elastic net combines both the [L1](../lasso/index.html) and [L2](../ridge_regression/index.html) penalties during training,
//! which can result in better performance than a model with either penalty alone on some problems.
//! The elastic net is particularly useful when the number of predictors (p) is much larger than the number of observations (n).
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::elastic_net::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!               &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
//!               &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!               &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!               &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
//!               &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!               &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
//!               &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
//!               &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
//!               &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!               &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
//!               &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!               &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
//!               &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!               &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!               &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!               &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//!          ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//!           100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
//!
//! let y_hat = ElasticNet::fit(&x, &y, Default::default())
//!                 .and_then(|lr| lr.predict(&x)).unwrap();
//! ```
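//!
//! Non-default hyperparameters can be set with the builder methods on
//! `ElasticNetParameters` (a sketch; the values below are illustrative, not
//! tuned recommendations):
//!
//! ```
//! use smartcore::linear::elastic_net::ElasticNetParameters;
//!
//! let params = ElasticNetParameters::default()
//!     .with_alpha(0.5)
//!     .with_l1_ratio(0.3)
//!     .with_normalize(true)
//!     .with_tol(1e-5)
//!     .with_max_iter(500);
//! assert_eq!(params.alpha, 0.5);
//! ```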
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 6.2. Shrinkage Methods](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["Regularization and variable selection via the elastic net", Hui Zou and Trevor Hastie](https://web.stanford.edu/~hastie/Papers/B67.2%20(2005)%20301-320%20Zou%20&%20Hastie.pdf)
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;

use crate::linear::lasso_optimizer::InteriorPointOptimizer;

/// Elastic net parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Regularization parameter.
    pub alpha: f64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
    /// For l1_ratio = 0 the penalty is an L2 penalty.
    /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
    pub l1_ratio: f64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
    pub normalize: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The tolerance for the optimization
    pub tol: f64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The maximum number of iterations
    pub max_iter: usize,
}

/// Elastic net
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ElasticNet<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    coefficients: Option<X>,
    intercept: Option<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}

impl ElasticNetParameters {
    /// Regularization parameter.
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }
    /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
    /// For l1_ratio = 0 the penalty is an L2 penalty.
    /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
    pub fn with_l1_ratio(mut self, l1_ratio: f64) -> Self {
        self.l1_ratio = l1_ratio;
        self
    }
    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }
    /// The tolerance for the optimization
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }
    /// The maximum number of iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
}

impl Default for ElasticNetParameters {
    fn default() -> Self {
        ElasticNetParameters {
            alpha: 1.0,
            l1_ratio: 0.5,
            normalize: true,
            tol: 1e-4,
            max_iter: 1000,
        }
    }
}

/// ElasticNet grid search parameters
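///
/// Each field holds a list of candidate values; iterating over the search
/// parameters yields every combination of the candidates as a complete
/// `ElasticNetParameters`. A small sketch (the candidate values here are
/// illustrative):
///
/// ```
/// use smartcore::linear::elastic_net::ElasticNetSearchParameters;
///
/// let params = ElasticNetSearchParameters {
///     alpha: vec![0.1, 1.0],
///     l1_ratio: vec![0.25, 0.75],
///     ..Default::default()
/// };
/// // 2 alphas x 2 l1_ratios = 4 candidate parameter sets
/// assert_eq!(params.into_iter().count(), 4);
/// ```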
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Regularization parameter.
    pub alpha: Vec<f64>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
    /// For l1_ratio = 0 the penalty is an L2 penalty.
    /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
    pub l1_ratio: Vec<f64>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
    pub normalize: Vec<bool>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The tolerance for the optimization
    pub tol: Vec<f64>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The maximum number of iterations
    pub max_iter: Vec<usize>,
}

/// ElasticNet grid search iterator
pub struct ElasticNetSearchParametersIterator {
    elastic_net_search_parameters: ElasticNetSearchParameters,
    current_alpha: usize,
    current_l1_ratio: usize,
    current_normalize: usize,
    current_tol: usize,
    current_max_iter: usize,
}

impl IntoIterator for ElasticNetSearchParameters {
    type Item = ElasticNetParameters;
    type IntoIter = ElasticNetSearchParametersIterator;

    fn into_iter(self) -> Self::IntoIter {
        ElasticNetSearchParametersIterator {
            elastic_net_search_parameters: self,
            current_alpha: 0,
            current_l1_ratio: 0,
            current_normalize: 0,
            current_tol: 0,
            current_max_iter: 0,
        }
    }
}

impl Iterator for ElasticNetSearchParametersIterator {
    type Item = ElasticNetParameters;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_alpha == self.elastic_net_search_parameters.alpha.len()
            && self.current_l1_ratio == self.elastic_net_search_parameters.l1_ratio.len()
            && self.current_normalize == self.elastic_net_search_parameters.normalize.len()
            && self.current_tol == self.elastic_net_search_parameters.tol.len()
            && self.current_max_iter == self.elastic_net_search_parameters.max_iter.len()
        {
            return None;
        }

        let next = ElasticNetParameters {
            alpha: self.elastic_net_search_parameters.alpha[self.current_alpha],
            l1_ratio: self.elastic_net_search_parameters.l1_ratio[self.current_l1_ratio],
            normalize: self.elastic_net_search_parameters.normalize[self.current_normalize],
            tol: self.elastic_net_search_parameters.tol[self.current_tol],
            max_iter: self.elastic_net_search_parameters.max_iter[self.current_max_iter],
        };

        // Advance the indices like an odometer: bump the first index with room
        // left, resetting every index before it; once all candidate lists are
        // exhausted, move every index past the end so the next call returns None.
        if self.current_alpha + 1 < self.elastic_net_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_l1_ratio + 1 < self.elastic_net_search_parameters.l1_ratio.len() {
            self.current_alpha = 0;
            self.current_l1_ratio += 1;
        } else if self.current_normalize + 1 < self.elastic_net_search_parameters.normalize.len() {
            self.current_alpha = 0;
            self.current_l1_ratio = 0;
            self.current_normalize += 1;
        } else if self.current_tol + 1 < self.elastic_net_search_parameters.tol.len() {
            self.current_alpha = 0;
            self.current_l1_ratio = 0;
            self.current_normalize = 0;
            self.current_tol += 1;
        } else if self.current_max_iter + 1 < self.elastic_net_search_parameters.max_iter.len() {
            self.current_alpha = 0;
            self.current_l1_ratio = 0;
            self.current_normalize = 0;
            self.current_tol = 0;
            self.current_max_iter += 1;
        } else {
            self.current_alpha += 1;
            self.current_l1_ratio += 1;
            self.current_normalize += 1;
            self.current_tol += 1;
            self.current_max_iter += 1;
        }

        Some(next)
    }
}

impl Default for ElasticNetSearchParameters {
    fn default() -> Self {
        let default_params = ElasticNetParameters::default();

        ElasticNetSearchParameters {
            alpha: vec![default_params.alpha],
            l1_ratio: vec![default_params.l1_ratio],
            normalize: vec![default_params.normalize],
            tol: vec![default_params.tol],
            max_iter: vec![default_params.max_iter],
        }
    }
}

impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for ElasticNet<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if self.intercept() != other.intercept() {
            return false;
        }
        if self.coefficients().shape() != other.coefficients().shape() {
            return false;
        }
        self.coefficients()
            .iterator(0)
            .zip(other.coefficients().iterator(0))
            .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
    }
}

impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ElasticNetParameters> for ElasticNet<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    fn fit(x: &X, y: &Y, parameters: ElasticNetParameters) -> Result<Self, Failed> {
        ElasticNet::fit(x, y, parameters)
    }
}

impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
    for ElasticNet<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    ElasticNet<TX, TY, X, Y>
{
    /// Fits elastic net regression to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - target values
    /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: ElasticNetParameters,
    ) -> Result<ElasticNet<TX, TY, X, Y>, Failed> {
        let (n, p) = x.shape();

        if y.shape() != n {
            return Err(Failed::fit("Number of rows in X should = len(y)"));
        }

        let n_float = n as f64;

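        // Penalty strengths from the mixing parameter: the L1 term gets
        // alpha * l1_ratio and the L2 term alpha * (1 - l1_ratio), each
        // scaled by the number of samples.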
        let l1_reg = TX::from_f64(parameters.alpha * parameters.l1_ratio * n_float).unwrap();
        let l2_reg =
            TX::from_f64(parameters.alpha * (1.0 - parameters.l1_ratio) * n_float).unwrap();

        let y_mean = TX::from_f64(y.mean_by()).unwrap();

        let (w, b) = if parameters.normalize {
            let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;

            let (x, y, gamma) = Self::augment_x_and_y(&scaled_x, y, l2_reg);

            let mut optimizer = InteriorPointOptimizer::new(&x, p);

            let mut w = optimizer.optimize(
                &x,
                &y,
                l1_reg * gamma,
                parameters.max_iter,
                TX::from_f64(parameters.tol).unwrap(),
            )?;

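            // Map the solution back to the original feature scale:
            // w_i <- gamma * w_i / std_i, then recover the intercept as
            // b = mean(y) - sum(w_i * mean(x_i)).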
            for i in 0..p {
                w.set(i, gamma * *w.get(i) / col_std[i]);
            }

            let mut b = TX::zero();

            for i in 0..p {
                b += *w.get(i) * col_mean[i];
            }

            b = y_mean - b;

            (X::from_column(&w), b)
        } else {
            let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);

            let mut optimizer = InteriorPointOptimizer::new(&x, p);

            let mut w = optimizer.optimize(
                &x,
                &y,
                l1_reg * gamma,
                parameters.max_iter,
                TX::from_f64(parameters.tol).unwrap(),
            )?;

            for i in 0..p {
                w.set(i, gamma * *w.get(i));
            }

            (X::from_column(&w), y_mean)
        };

        Ok(ElasticNet {
            intercept: Some(b),
            coefficients: Some(w),
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        })
    }

    /// Predict target values from `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (nrows, _) = x.shape();
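        // y_hat = X * w, then add the intercept to every row.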
        let mut y_hat = x.matmul(self.coefficients.as_ref().unwrap());
        let bias = X::fill(nrows, 1, self.intercept.unwrap());
        y_hat.add_mut(&bias);
        Ok(Y::from_iterator(
            y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
            nrows,
        ))
    }

    /// Get the estimated regression coefficients
    pub fn coefficients(&self) -> &X {
        self.coefficients.as_ref().unwrap()
    }

    /// Get the estimated intercept
    pub fn intercept(&self) -> &TX {
        self.intercept.as_ref().unwrap()
    }

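    /// Standardize each column of `x` to zero mean and unit variance, and
    /// return the scaled matrix together with the per-column means and
    /// standard deviations (needed later to undo the scaling).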
    fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
        let col_mean: Vec<TX> = x
            .mean_by(0)
            .iter()
            .map(|&v| TX::from_f64(v).unwrap())
            .collect();
        let col_std: Vec<TX> = x
            .std_dev(0)
            .iter()
            .map(|&v| TX::from_f64(v).unwrap())
            .collect();

        for (i, col_std_i) in col_std.iter().enumerate() {
            if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
                return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
            }
        }

        let mut scaled_x = x.clone();
        scaled_x.scale_mut(&col_mean, &col_std, 0);
        Ok((scaled_x, col_mean, col_std))
    }

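    /// Reduce the elastic net problem to an equivalent lasso problem using the
    /// augmentation of Zou & Hastie (2005): append \\(\sqrt{\lambda_2} I_p\\)
    /// below `x`, pad `y` with `p` zeros, and scale the design matrix by
    /// \\(\gamma = 1 / \sqrt{1 + \lambda_2}\\), so that the result can be
    /// solved with the L1-only interior point optimizer.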
    fn augment_x_and_y(x: &X, y: &Y, l2_reg: TX) -> (X, Vec<TX>, TX) {
        let (n, p) = x.shape();

        let gamma = TX::one() / (TX::one() + l2_reg).sqrt();
        let padding = gamma * l2_reg.sqrt();

        let mut y2 = Vec::<TX>::zeros(n + p);
        for i in 0..y.shape() {
            y2.set(i, TX::from(*y.get(i)).unwrap());
        }

        let mut x2 = X::zeros(n + p, p);

        for j in 0..p {
            for i in 0..n {
                x2.set((i, j), gamma * *x.get((i, j)));
            }

            x2.set((j + n, j), padding);
        }

        (x2, y2, gamma)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_absolute_error;

    #[test]
    fn search_parameters() {
        let parameters = ElasticNetSearchParameters {
            alpha: vec![0., 1.],
            max_iter: vec![10, 100],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        let next = iter.next().unwrap();
        assert_eq!(next.alpha, 0.);
        assert_eq!(next.max_iter, 10);
        let next = iter.next().unwrap();
        assert_eq!(next.alpha, 1.);
        assert_eq!(next.max_iter, 10);
        let next = iter.next().unwrap();
        assert_eq!(next.alpha, 0.);
        assert_eq!(next.max_iter, 100);
        let next = iter.next().unwrap();
        assert_eq!(next.alpha, 1.);
        assert_eq!(next.max_iter, 100);
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn elasticnet_longley() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
            &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();

        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let y_hat = ElasticNet::fit(
            &x,
            &y,
            ElasticNetParameters {
                alpha: 1.0,
                l1_ratio: 0.5,
                normalize: false,
                tol: 1e-4,
                max_iter: 1000,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y_hat, &y) < 30.0);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn elasticnet_fit_predict1() {
        let x = DenseMatrix::from_2d_array(&[
            &[0.0, 1931.0, 1.2232755825400514],
            &[1.0, 1933.0, 1.1379726120972395],
            &[2.0, 1920.0, 1.4366265120543429],
            &[3.0, 1918.0, 1.206005737827858],
            &[4.0, 1934.0, 1.436613542400669],
            &[5.0, 1918.0, 1.1594588621640636],
            &[6.0, 1933.0, 1.19809994745985],
            &[7.0, 1918.0, 1.3396363871645678],
            &[8.0, 1931.0, 1.2535342096493207],
            &[9.0, 1933.0, 1.3101281563456293],
            &[10.0, 1922.0, 1.3585833349920762],
            &[11.0, 1930.0, 1.4830786699709897],
            &[12.0, 1916.0, 1.4919891143094546],
            &[13.0, 1915.0, 1.259655137451551],
            &[14.0, 1932.0, 1.3979191428724789],
            &[15.0, 1917.0, 1.3686634746782371],
            &[16.0, 1932.0, 1.381658454569724],
            &[17.0, 1918.0, 1.4054969025700674],
            &[18.0, 1929.0, 1.3271699396384906],
            &[19.0, 1915.0, 1.1373332337674806],
        ])
        .unwrap();

        let y: Vec<f64> = vec![
            1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
            10.2, 7.92, 7.62, 8.06, 9.06, 9.29,
        ];

        let l1_model = ElasticNet::fit(
            &x,
            &y,
            ElasticNetParameters {
                alpha: 1.0,
                l1_ratio: 1.0,
                normalize: true,
                tol: 1e-4,
                max_iter: 1000,
            },
        )
        .unwrap();

        let l2_model = ElasticNet::fit(
            &x,
            &y,
            ElasticNetParameters {
                alpha: 1.0,
                l1_ratio: 0.0,
                normalize: true,
                tol: 1e-4,
                max_iter: 1000,
            },
        )
        .unwrap();

        let mae_l1 = mean_absolute_error(&l1_model.predict(&x).unwrap(), &y);
        let mae_l2 = mean_absolute_error(&l2_model.predict(&x).unwrap(), &y);

        assert!(mae_l1 < 2.0);
        assert!(mae_l2 < 2.0);

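        // The L1-penalized model should assign its largest coefficient to the
        // first feature, a linear trend that tracks the target.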
        assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((1, 0)));
        assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((2, 0)));
    }

    // TODO: serialization for the new DenseMatrix needs to be implemented
    // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
    // #[test]
    // #[cfg(feature = "serde")]
    // fn serde() {
    //     let x = DenseMatrix::from_2d_array(&[
    //         &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
    //         &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
    //         &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
    //         &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
    //         &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
    //         &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
    //         &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
    //         &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
    //         &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
    //         &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
    //         &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
    //         &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
    //         &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
    //         &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
    //         &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
    //         &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
    //     ]).unwrap();

    //     let y = vec![
    //         83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
    //         114.2, 115.7, 116.9,
    //     ];

    //     let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();

    //     let deserialized_lr: ElasticNet<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    //         serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();

    //     assert_eq!(lr, deserialized_lr);
    // }
}