rusty_machine/learning/logistic_reg.rs
//! Logistic Regression module
//!
//! Contains an implementation of logistic regression
//! trained by gradient descent optimization.
//!
//! The regressor will automatically add the intercept term,
//! so you do not need to format the input matrices yourself.
//!
//! # Usage
//!
//! ```
//! use rusty_machine::learning::logistic_reg::LogisticRegressor;
//! use rusty_machine::learning::SupModel;
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::linalg::Vector;
//!
//! let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
//! let targets = Vector::new(vec![0.,0.,1.,1.]);
//!
//! let mut log_mod = LogisticRegressor::default();
//!
//! // Train the model
//! log_mod.train(&inputs, &targets).unwrap();
//!
//! // Now we'll predict a new point
//! let new_point = Matrix::new(1,1,vec![10.]);
//! let output = log_mod.predict(&new_point).unwrap();
//!
//! // Hopefully we classified our new point correctly!
//! assert!(output[0] > 0.5, "Our classifier isn't very good!");
//! ```
//!
//! We could have been more specific about how the model learns
//! by using the `new` constructor instead. This allows us to provide
//! a `GradientDesc` object with custom parameters.
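//!
//! For example, a minimal sketch (assuming `GradientDesc::new` takes a step
//! size followed by an iteration count; the values below are illustrative):
//!
//! ```
//! use rusty_machine::learning::logistic_reg::LogisticRegressor;
//! use rusty_machine::learning::optim::grad_desc::GradientDesc;
//!
//! // Gradient descent with a custom step size and iteration count.
//! let gd = GradientDesc::new(0.1, 250);
//! let mut log_mod = LogisticRegressor::new(gd);
//! ```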

use linalg::{Matrix, BaseMatrix};
use linalg::Vector;
use learning::{LearningResult, SupModel};
use learning::toolkit::activ_fn::{ActivationFunc, Sigmoid};
use learning::toolkit::cost_fn::{CostFunc, CrossEntropyError};
use learning::optim::grad_desc::GradientDesc;
use learning::optim::{OptimAlgorithm, Optimizable};
use learning::error::Error;

/// Logistic Regression Model.
///
/// Contains the optional learned parameters
/// (`None` until the model has been trained).
#[derive(Debug)]
pub struct LogisticRegressor<A>
    where A: OptimAlgorithm<BaseLogisticRegressor>
{
    base: BaseLogisticRegressor,
    alg: A,
}

/// Constructs a default Logistic Regression model
/// using standard gradient descent.
impl Default for LogisticRegressor<GradientDesc> {
    fn default() -> LogisticRegressor<GradientDesc> {
        LogisticRegressor {
            base: BaseLogisticRegressor::new(),
            alg: GradientDesc::default(),
        }
    }
}

impl<A: OptimAlgorithm<BaseLogisticRegressor>> LogisticRegressor<A> {
    /// Constructs an untrained logistic regression model
    /// using the given optimization algorithm.
    ///
    /// # Examples
    ///
    /// ```
    /// use rusty_machine::learning::logistic_reg::LogisticRegressor;
    /// use rusty_machine::learning::optim::grad_desc::GradientDesc;
    ///
    /// let gd = GradientDesc::default();
    /// let mut logistic_mod = LogisticRegressor::new(gd);
    /// ```
    pub fn new(alg: A) -> LogisticRegressor<A> {
        LogisticRegressor {
            base: BaseLogisticRegressor::new(),
            alg: alg,
        }
    }

    /// Get the parameters from the model.
    ///
    /// Returns `None` if the model has not been trained.
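    ///
    /// # Examples
    ///
    /// A minimal sketch; an untrained model has no parameters yet:
    ///
    /// ```
    /// use rusty_machine::learning::logistic_reg::LogisticRegressor;
    ///
    /// let logistic_mod = LogisticRegressor::default();
    /// assert!(logistic_mod.parameters().is_none());
    /// ```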
    pub fn parameters(&self) -> Option<&Vector<f64>> {
        self.base.parameters()
    }
}

impl<A> SupModel<Matrix<f64>, Vector<f64>> for LogisticRegressor<A>
    where A: OptimAlgorithm<BaseLogisticRegressor>
{
    /// Train the logistic regression model.
    ///
    /// Takes the training data and binary (0/1) target values as input.
    ///
    /// # Examples
    ///
    /// ```
    /// use rusty_machine::learning::logistic_reg::LogisticRegressor;
    /// use rusty_machine::linalg::Matrix;
    /// use rusty_machine::linalg::Vector;
    /// use rusty_machine::learning::SupModel;
    ///
    /// let mut logistic_mod = LogisticRegressor::default();
    /// let inputs = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
    /// let targets = Vector::new(vec![0.0, 0.0, 1.0]);
    ///
    /// logistic_mod.train(&inputs, &targets).unwrap();
    /// ```
    fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
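        // Prepend a column of ones so the first parameter acts as the intercept term.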
        let ones = Matrix::<f64>::ones(inputs.rows(), 1);
        let full_inputs = ones.hcat(inputs);
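
        // A constant initial guess for every parameter; the optimizer refines it.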
        let initial_params = vec![0.5; full_inputs.cols()];

        let optimal_w = self.alg.optimize(&self.base, &initial_params[..], &full_inputs, targets);
        self.base.set_parameters(Vector::new(optimal_w));
        Ok(())
    }

    /// Predict output values from input data.
    ///
    /// The model must be trained before predictions can be made.
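    ///
    /// # Examples
    ///
    /// A sketch of the full round trip, reusing the training data as test points:
    ///
    /// ```
    /// use rusty_machine::learning::logistic_reg::LogisticRegressor;
    /// use rusty_machine::learning::SupModel;
    /// use rusty_machine::linalg::{Matrix, Vector};
    ///
    /// let mut log_mod = LogisticRegressor::default();
    /// let inputs = Matrix::new(4, 1, vec![1.0, 3.0, 5.0, 7.0]);
    /// let targets = Vector::new(vec![0., 0., 1., 1.]);
    /// log_mod.train(&inputs, &targets).unwrap();
    ///
    /// // Each prediction is a probability in (0, 1).
    /// let probs = log_mod.predict(&inputs).unwrap();
    /// assert!(probs.data().iter().all(|&p| p > 0.0 && p < 1.0));
    /// ```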
    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
        if let Some(v) = self.base.parameters() {
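            // Add the intercept column, mirroring the transformation used in `train`.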
            let ones = Matrix::<f64>::ones(inputs.rows(), 1);
            let full_inputs = ones.hcat(inputs);
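            // Map the linear scores through the sigmoid to obtain probabilities.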
            Ok((full_inputs * v).apply(&Sigmoid::func))
        } else {
            Err(Error::new_untrained())
        }
    }
}

/// The base Logistic Regression model.
///
/// This struct cannot be instantiated directly and is used internally only.
#[derive(Debug)]
pub struct BaseLogisticRegressor {
    parameters: Option<Vector<f64>>,
}

impl BaseLogisticRegressor {
    /// Construct a new BaseLogisticRegressor
    /// with parameters set to None.
    fn new() -> BaseLogisticRegressor {
        BaseLogisticRegressor { parameters: None }
    }

    /// Returns a reference to the parameters.
    fn parameters(&self) -> Option<&Vector<f64>> {
        self.parameters.as_ref()
    }

    /// Set the parameters to `Some` vector.
    fn set_parameters(&mut self, params: Vector<f64>) {
        self.parameters = Some(params);
    }
}

/// Computing the gradient of the underlying Logistic
/// Regression model.
///
/// The gradient is given by
///
/// X<sup>T</sup>(h(Xb) - y) / m
///
/// where `h` is the sigmoid function, `b` the underlying model
/// parameters, and `m` the number of training examples.
impl Optimizable for BaseLogisticRegressor {
    type Inputs = Matrix<f64>;
    type Targets = Vector<f64>;

    fn compute_grad(&self,
                    params: &[f64],
                    inputs: &Matrix<f64>,
                    targets: &Vector<f64>)
                    -> (f64, Vec<f64>) {
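
        // h(Xb): predicted probabilities under the current parameter vector.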
        let beta_vec = Vector::new(params.to_vec());
        let outputs = (inputs * beta_vec).apply(&Sigmoid::func);
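
        // Cross-entropy cost and its gradient, X^T(h(Xb) - y) / m.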
        let cost = CrossEntropyError::cost(&outputs, targets);
        let grad = (inputs.transpose() * (outputs - targets)) / (inputs.rows() as f64);

        (cost, grad.into_vec())
    }
}
195}