rusty_machine/learning/lin_reg.rs
1//! Linear Regression module
2//!
//! Contains an implementation of linear regression using
//! ordinary least squares (OLS) and gradient descent optimization.
5//!
6//! The regressor will automatically add the intercept term
7//! so you do not need to format the input matrices yourself.
8//!
9//! # Usage
10//!
11//! ```
12//! use rusty_machine::learning::lin_reg::LinRegressor;
13//! use rusty_machine::learning::SupModel;
14//! use rusty_machine::linalg::Matrix;
15//! use rusty_machine::linalg::Vector;
16//!
17//! let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
18//! let targets = Vector::new(vec![1.,5.,9.,13.]);
19//!
20//! let mut lin_mod = LinRegressor::default();
21//!
22//! // Train the model
23//! lin_mod.train(&inputs, &targets).unwrap();
24//!
25//! // Now we'll predict a new point
26//! let new_point = Matrix::new(1,1,vec![10.]);
27//! let output = lin_mod.predict(&new_point).unwrap();
28//!
//! // Hopefully we predicted our new point accurately!
30//! assert!(output[0] > 17f64, "Our regressor isn't very good!");
31//! ```
32
33use linalg::{Matrix, BaseMatrix};
34use linalg::Vector;
35use learning::{LearningResult, SupModel};
36use learning::toolkit::cost_fn::CostFunc;
37use learning::toolkit::cost_fn::MeanSqError;
38use learning::optim::grad_desc::GradientDesc;
39use learning::optim::{OptimAlgorithm, Optimizable};
40use learning::error::Error;
41
42/// Linear Regression Model.
43///
44/// Contains option for optimized parameter.
45#[derive(Debug)]
46pub struct LinRegressor {
47 /// The parameters for the regression model.
48 parameters: Option<Vector<f64>>,
49}
50
51impl Default for LinRegressor {
52 fn default() -> LinRegressor {
53 LinRegressor { parameters: None }
54 }
55}
56
57impl LinRegressor {
58 /// Get the parameters from the model.
59 ///
60 /// Returns an option that is None if the model has not been trained.
61 pub fn parameters(&self) -> Option<&Vector<f64>> {
62 self.parameters.as_ref()
63 }
64}
65
66impl SupModel<Matrix<f64>, Vector<f64>> for LinRegressor {
67 /// Train the linear regression model.
68 ///
69 /// Takes training data and output values as input.
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use rusty_machine::learning::lin_reg::LinRegressor;
75 /// use rusty_machine::linalg::Matrix;
76 /// use rusty_machine::linalg::Vector;
77 /// use rusty_machine::learning::SupModel;
78 ///
79 /// let mut lin_mod = LinRegressor::default();
80 /// let inputs = Matrix::new(3,1, vec![2.0, 3.0, 4.0]);
81 /// let targets = Vector::new(vec![5.0, 6.0, 7.0]);
82 ///
83 /// lin_mod.train(&inputs, &targets).unwrap();
84 /// ```
85 fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
86 let ones = Matrix::<f64>::ones(inputs.rows(), 1);
87 let full_inputs = ones.hcat(inputs);
88
89 let xt = full_inputs.transpose();
90 self.parameters = Some((&xt * full_inputs).solve(&xt * targets)
91 .expect("Unable to solve linear equation."));
92 Ok(())
93 }
94
95 /// Predict output value from input data.
96 ///
97 /// Model must be trained before prediction can be made.
98 fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
99 if let Some(ref v) = self.parameters {
100 let ones = Matrix::<f64>::ones(inputs.rows(), 1);
101 let full_inputs = ones.hcat(inputs);
102 Ok(full_inputs * v)
103 } else {
104 Err(Error::new_untrained())
105 }
106 }
107}
108
109impl Optimizable for LinRegressor {
110 type Inputs = Matrix<f64>;
111 type Targets = Vector<f64>;
112
113 fn compute_grad(&self,
114 params: &[f64],
115 inputs: &Matrix<f64>,
116 targets: &Vector<f64>)
117 -> (f64, Vec<f64>) {
118
119 let beta_vec = Vector::new(params.to_vec());
120 let outputs = inputs * beta_vec;
121
122 let cost = MeanSqError::cost(&outputs, targets);
123 let grad = (inputs.transpose() * (outputs - targets)) / (inputs.rows() as f64);
124
125 (cost, grad.into_vec())
126 }
127}
128
129impl LinRegressor {
130 /// Train the linear regressor using Gradient Descent.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use rusty_machine::learning::lin_reg::LinRegressor;
136 /// use rusty_machine::learning::SupModel;
137 /// use rusty_machine::linalg::Matrix;
138 /// use rusty_machine::linalg::Vector;
139 ///
140 /// let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
141 /// let targets = Vector::new(vec![1.,5.,9.,13.]);
142 ///
143 /// let mut lin_mod = LinRegressor::default();
144 ///
145 /// // Train the model
146 /// lin_mod.train_with_optimization(&inputs, &targets);
147 ///
148 /// // Now we'll predict a new point
149 /// let new_point = Matrix::new(1,1,vec![10.]);
150 /// let _ = lin_mod.predict(&new_point).unwrap();
151 /// ```
152 pub fn train_with_optimization(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) {
153 let ones = Matrix::<f64>::ones(inputs.rows(), 1);
154 let full_inputs = ones.hcat(inputs);
155
156 let initial_params = vec![0.; full_inputs.cols()];
157
158 let gd = GradientDesc::default();
159 let optimal_w = gd.optimize(self, &initial_params[..], &full_inputs, targets);
160 self.parameters = Some(Vector::new(optimal_w));
161 }
162}