smartcore/linear/linear_regression.rs

//! # Linear Regression
//!
//! Linear regression is a straightforward approach for predicting a quantitative response \\(y\\) on the basis of a linear combination of explanatory variables \\(X\\).
//! Linear regression assumes that there is approximately a linear relationship between \\(X\\) and \\(y\\). Formally, we can write this linear relationship as
//!
//! \\[y = \beta_0 + \sum_{i=1}^n \beta_iX_i + \epsilon\\]
//!
//! where \\(\epsilon\\) is a mean-zero random error term and the regression coefficients \\(\beta_0, \beta_1, \dots, \beta_n\\) are unknown and must be estimated.
//!
//! While the regression coefficients can be estimated directly by solving the normal equation
//!
//! \\[\hat{\beta} = (X^TX)^{-1}X^Ty\\]
//!
//! computing the \\((X^TX)^{-1}\\) term is both computationally expensive and numerically unstable. An alternative approach is to use a matrix decomposition to avoid this operation.
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decompositions to find estimates of \\(\hat{\beta}\\).
//! The QR decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but it does not work for all data matrices. Unlike the QR decomposition, every matrix has an SVD decomposition.
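//!
//! For instance, once \\(X\\) is factored as \\(X = QR\\), with \\(Q\\) orthogonal and \\(R\\) upper triangular, the normal equation reduces to
//!
//! \\[R\hat{\beta} = Q^Ty\\]
//!
//! which can be solved by back substitution, without ever forming \\(X^TX\\).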
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::linear_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!     &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!     &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!     &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
//!     &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
//!     &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!     &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!     &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//!     100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
//!
//! let lr = LinearRegression::fit(&x, &y,
//!     LinearRegressionParameters::default()
//!         .with_solver(LinearRegressionSolverName::QR)).unwrap();
//!
//! let y_hat = lr.predict(&x).unwrap();
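//!
//! // The fitted model also exposes the estimated parameters
//! let coefficients = lr.coefficients();
//! let intercept = lr.intercept();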
//! ```
//!
//! ## References:
//!
//! * ["Pattern Recognition and Machine Learning", C.M. Bishop, Linear Models for Regression](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 3. Linear Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T., Flannery B.P., 3rd ed., Section 15.4 General Linear Least Squares](http://numerical.recipes/)
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
    /// QR decomposition, see [QR](../../linalg/qr/index.html)
    QR,
    /// SVD decomposition, see [SVD](../../linalg/svd/index.html)
    #[default]
    SVD,
}

/// Linear Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Solver to use for estimation of regression coefficients.
    pub solver: LinearRegressionSolverName,
}

impl Default for LinearRegressionParameters {
    fn default() -> Self {
        LinearRegressionParameters {
            solver: LinearRegressionSolverName::SVD,
        }
    }
}

/// Linear Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<
    TX: Number + RealNumber,
    TY: Number,
    X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
    Y: Array1<TY>,
> {
    coefficients: Option<X>,
    intercept: Option<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}

impl LinearRegressionParameters {
    /// Solver to use for estimation of regression coefficients.
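    ///
    /// Builder-style usage, as in the module-level example:
    /// `LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR)`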
    pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
        self.solver = solver;
        self
    }
}

/// Linear Regression grid search parameters
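///
/// A minimal sketch of iterating over the candidate parameters:
/// ```
/// use smartcore::linear::linear_regression::*;
///
/// let grid = LinearRegressionSearchParameters {
///     solver: vec![LinearRegressionSolverName::QR, LinearRegressionSolverName::SVD],
/// };
/// for candidate in grid {
///     println!("{:?}", candidate.solver);
/// }
/// ```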
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Solver to use for estimation of regression coefficients.
    pub solver: Vec<LinearRegressionSolverName>,
}

/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
    linear_regression_search_parameters: LinearRegressionSearchParameters,
    current_solver: usize,
}

impl IntoIterator for LinearRegressionSearchParameters {
    type Item = LinearRegressionParameters;
    type IntoIter = LinearRegressionSearchParametersIterator;

    fn into_iter(self) -> Self::IntoIter {
        LinearRegressionSearchParametersIterator {
            linear_regression_search_parameters: self,
            current_solver: 0,
        }
    }
}

impl Iterator for LinearRegressionSearchParametersIterator {
    type Item = LinearRegressionParameters;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_solver == self.linear_regression_search_parameters.solver.len() {
            return None;
        }

        let next = LinearRegressionParameters {
            solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
        };

        self.current_solver += 1;

        Some(next)
    }
}

impl Default for LinearRegressionSearchParameters {
    fn default() -> Self {
        let default_params = LinearRegressionParameters::default();

        LinearRegressionSearchParameters {
            solver: vec![default_params.solver],
        }
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > PartialEq for LinearRegression<TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        self.intercept == other.intercept
            && self.coefficients().shape() == other.coefficients().shape()
            && self
                .coefficients()
                .iterator(0)
                .zip(other.coefficients().iterator(0))
                .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > SupervisedEstimator<X, Y, LinearRegressionParameters> for LinearRegression<TX, TY, X, Y>
{
    fn new() -> Self {
        // An unfitted estimator; `fit` populates the coefficients and intercept.
        Self {
            coefficients: None,
            intercept: None,
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    fn fit(x: &X, y: &Y, parameters: LinearRegressionParameters) -> Result<Self, Failed> {
        LinearRegression::fit(x, y, parameters)
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > Predictor<X, Y> for LinearRegression<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > LinearRegression<TX, TY, X, Y>
{
    /// Fits Linear Regression to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - target values
    /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: LinearRegressionParameters,
    ) -> Result<LinearRegression<TX, TY, X, Y>, Failed> {
        // Copy the targets into an Nx1 matrix so the decomposition solvers can consume them.
        let b = X::from_iterator(
            y.iterator(0).map(|&v| TX::from(v).unwrap()),
            y.shape(),
            1,
            0,
        );
        let (x_nrows, num_attributes) = x.shape();
        let (y_nrows, _) = b.shape();

        if x_nrows != y_nrows {
            return Err(Failed::fit(
                "Number of rows of X doesn't match number of rows of Y",
            ));
        }

        // Append a column of ones to X so the intercept is estimated together with the coefficients.
        let a = x.h_stack(&X::ones(x_nrows, 1));

        // Solve the least-squares system with the requested decomposition.
        let w = match parameters.solver {
            LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
            LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
        };

        // The first `num_attributes` entries of the solution are the coefficients;
        // the last entry corresponds to the column of ones, i.e. the intercept.
        let weights = X::from_slice(w.slice(0..num_attributes, 0..1).as_ref());

        Ok(LinearRegression {
            intercept: Some(*w.get((num_attributes, 0))),
            coefficients: Some(weights),
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        })
    }

    /// Predict target values from `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (nrows, _) = x.shape();
        // y_hat = X * coefficients + intercept
        let bias = X::fill(nrows, 1, *self.intercept());
        let mut y_hat = x.matmul(self.coefficients());
        y_hat.add_mut(&bias);
        Ok(Y::from_iterator(
            y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
            nrows,
        ))
    }

    /// Get estimates of the regression coefficients
    pub fn coefficients(&self) -> &X {
        self.coefficients.as_ref().unwrap()
    }

    /// Get the estimate of the intercept
    pub fn intercept(&self) -> &TX {
        self.intercept.as_ref().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    #[test]
    fn search_parameters() {
        let parameters = LinearRegressionSearchParameters {
            solver: vec![
                LinearRegressionSolverName::QR,
                LinearRegressionSolverName::SVD,
            ],
        };
        let mut iter = parameters.into_iter();
        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn ols_fit_predict() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();

        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
        ];

        let y_hat_qr = LinearRegression::fit(
            &x,
            &y,
            LinearRegressionParameters {
                solver: LinearRegressionSolverName::QR,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
            .and_then(|lr| lr.predict(&x))
            .unwrap();

        assert!(y
            .iter()
            .zip(y_hat_qr.iter())
            .all(|(&a, &b)| (a - b).abs() <= 5.0));
        assert!(y
            .iter()
            .zip(y_hat_svd.iter())
            .all(|(&a, &b)| (a - b).abs() <= 5.0));
    }

    // TODO: serialization for the new DenseMatrix needs to be implemented
    // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
    // #[test]
    // #[cfg(feature = "serde")]
    // fn serde() {
    //     let x = DenseMatrix::from_2d_array(&[
    //         &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
    //         &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
    //         &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
    //         &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
    //         &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
    //         &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
    //         &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
    //         &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
    //         &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
    //         &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
    //         &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
    //         &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
    //         &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
    //         &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
    //         &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
    //         &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
    //     ]).unwrap();

    //     let y = vec![
    //         83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
    //         114.2, 115.7, 116.9,
    //     ];

    //     let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();

    //     let deserialized_lr: LinearRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    //         serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();

    //     assert_eq!(lr, deserialized_lr);

    //     let default = LinearRegressionParameters::default();
    //     let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
    //     assert_eq!(parameters.solver, default.solver);
    // }
}