rusty_machine/learning/lin_reg.rs

//! Linear Regression module
//!
//! Contains an implementation of linear regression using
//! OLS and gradient descent optimization.
//!
//! The regressor will automatically add the intercept term
//! so you do not need to format the input matrices yourself.
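//!
//! Calling `train` fits the parameters in closed form by solving the
//! normal equations `(XᵀX)β = Xᵀy`, where `X` is the input matrix with the
//! intercept column prepended. Alternatively, `train_with_optimization`
//! fits the same parameters iteratively using gradient descent.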
//!
//! # Usage
//!
//! ```
//! use rusty_machine::learning::lin_reg::LinRegressor;
//! use rusty_machine::learning::SupModel;
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::linalg::Vector;
//!
//! let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
//! let targets = Vector::new(vec![1.,5.,9.,13.]);
//!
//! let mut lin_mod = LinRegressor::default();
//!
//! // Train the model
//! lin_mod.train(&inputs, &targets).unwrap();
//!
//! // Now we'll predict a new point
//! let new_point = Matrix::new(1,1,vec![10.]);
//! let output = lin_mod.predict(&new_point).unwrap();
//!
//! // Hopefully we predicted our new point correctly!
//! assert!(output[0] > 17f64, "Our regressor isn't very good!");
//! ```

use linalg::{Matrix, BaseMatrix};
use linalg::Vector;
use learning::{LearningResult, SupModel};
use learning::toolkit::cost_fn::{CostFunc, MeanSqError};
use learning::optim::grad_desc::GradientDesc;
use learning::optim::{OptimAlgorithm, Optimizable};
use learning::error::Error;

/// Linear Regression Model.
///
/// Stores the learned parameters as an `Option`, which is `None`
/// until the model has been trained.
#[derive(Debug)]
pub struct LinRegressor {
    /// The parameters for the regression model.
    parameters: Option<Vector<f64>>,
}

impl Default for LinRegressor {
    fn default() -> LinRegressor {
        LinRegressor { parameters: None }
    }
}

impl LinRegressor {
    /// Get the parameters from the model.
    ///
    /// Returns an option that is None if the model has not been trained.
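    ///
    /// # Examples
    ///
    /// A minimal sketch of checking an untrained model:
    ///
    /// ```
    /// use rusty_machine::learning::lin_reg::LinRegressor;
    ///
    /// let lin_mod = LinRegressor::default();
    /// assert!(lin_mod.parameters().is_none());
    /// ```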
    pub fn parameters(&self) -> Option<&Vector<f64>> {
        self.parameters.as_ref()
    }
}

impl SupModel<Matrix<f64>, Vector<f64>> for LinRegressor {
    /// Train the linear regression model.
    ///
    /// Takes training data and output values as input.
    ///
    /// # Examples
    ///
    /// ```
    /// use rusty_machine::learning::lin_reg::LinRegressor;
    /// use rusty_machine::linalg::Matrix;
    /// use rusty_machine::linalg::Vector;
    /// use rusty_machine::learning::SupModel;
    ///
    /// let mut lin_mod = LinRegressor::default();
    /// let inputs = Matrix::new(3,1, vec![2.0, 3.0, 4.0]);
    /// let targets = Vector::new(vec![5.0, 6.0, 7.0]);
    ///
    /// lin_mod.train(&inputs, &targets).unwrap();
    /// ```
    fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
        // Prepend a column of ones so the first parameter acts as the intercept.
        let ones = Matrix::<f64>::ones(inputs.rows(), 1);
        let full_inputs = ones.hcat(inputs);

        // Solve the normal equations (XᵀX)β = Xᵀy for the parameters β.
        let xt = full_inputs.transpose();
        self.parameters = Some((&xt * full_inputs).solve(&xt * targets)
                                                  .expect("Unable to solve linear equation."));
        Ok(())
    }

    /// Predict output values from input data.
    ///
    /// Model must be trained before prediction can be made.
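    ///
    /// # Examples
    ///
    /// A minimal sketch showing the untrained case:
    ///
    /// ```
    /// use rusty_machine::learning::lin_reg::LinRegressor;
    /// use rusty_machine::learning::SupModel;
    /// use rusty_machine::linalg::Matrix;
    ///
    /// let lin_mod = LinRegressor::default();
    /// let new_point = Matrix::new(1,1,vec![10.]);
    ///
    /// // Predicting before training returns an untrained-model error.
    /// assert!(lin_mod.predict(&new_point).is_err());
    /// ```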
    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
        if let Some(ref v) = self.parameters {
            // Add the intercept column and compute ŷ = Xβ.
            let ones = Matrix::<f64>::ones(inputs.rows(), 1);
            let full_inputs = ones.hcat(inputs);
            Ok(full_inputs * v)
        } else {
            Err(Error::new_untrained())
        }
    }
}

impl Optimizable for LinRegressor {
    type Inputs = Matrix<f64>;
    type Targets = Vector<f64>;
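    /// Computes the cost and gradient of the mean squared error at `params`.
    ///
    /// For parameters `β`, the gradient used here is `Xᵀ(Xβ - y) / n`,
    /// which corresponds to a cost of the form `‖Xβ - y‖² / 2n`, where
    /// `n` is the number of rows in `inputs`.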
    fn compute_grad(&self,
                    params: &[f64],
                    inputs: &Matrix<f64>,
                    targets: &Vector<f64>)
                    -> (f64, Vec<f64>) {

        let beta_vec = Vector::new(params.to_vec());
        let outputs = inputs * beta_vec;

        let cost = MeanSqError::cost(&outputs, targets);
        // Gradient of the mean squared error: Xᵀ(Xβ - y) / n.
        let grad = (inputs.transpose() * (outputs - targets)) / (inputs.rows() as f64);

        (cost, grad.into_vec())
    }
}

impl LinRegressor {
    /// Train the linear regressor using Gradient Descent.
    ///
    /// # Examples
    ///
    /// ```
    /// use rusty_machine::learning::lin_reg::LinRegressor;
    /// use rusty_machine::learning::SupModel;
    /// use rusty_machine::linalg::Matrix;
    /// use rusty_machine::linalg::Vector;
    ///
    /// let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
    /// let targets = Vector::new(vec![1.,5.,9.,13.]);
    ///
    /// let mut lin_mod = LinRegressor::default();
    ///
    /// // Train the model
    /// lin_mod.train_with_optimization(&inputs, &targets);
    ///
    /// // Now we'll predict a new point
    /// let new_point = Matrix::new(1,1,vec![10.]);
    /// let _ = lin_mod.predict(&new_point).unwrap();
    /// ```
    pub fn train_with_optimization(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) {
        // Prepend the intercept column, as in `train`.
        let ones = Matrix::<f64>::ones(inputs.rows(), 1);
        let full_inputs = ones.hcat(inputs);

        // Start gradient descent from the zero vector.
        let initial_params = vec![0.; full_inputs.cols()];

        let gd = GradientDesc::default();
        let optimal_w = gd.optimize(self, &initial_params[..], &full_inputs, targets);
        self.parameters = Some(Vector::new(optimal_w));
    }
}
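
// A minimal test sketch (an assumption, not part of the original module): it
// reuses the crate-internal paths imported above and checks that OLS recovers
// the slope and intercept of points lying exactly on the line y = 2x - 1.
#[cfg(test)]
mod tests {
    use super::LinRegressor;
    use learning::SupModel;
    use linalg::{Matrix, Vector};

    #[test]
    fn ols_recovers_exact_line() {
        let inputs = Matrix::new(4, 1, vec![1.0, 3.0, 5.0, 7.0]);
        let targets = Vector::new(vec![1., 5., 9., 13.]);

        let mut model = LinRegressor::default();
        model.train(&inputs, &targets).unwrap();

        // Parameters are ordered [intercept, slope]; the data satisfy y = 2x - 1.
        let params = model.parameters().unwrap();
        assert!((params[0] + 1.0).abs() < 1e-6);
        assert!((params[1] - 2.0).abs() < 1e-6);
    }
}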