rusty_machine/learning/glm.rs

//! Generalized Linear Model module
//!
//! <i>This model is likely to undergo changes in the near future.
//! These changes will improve the learning algorithm.</i>
//!
//! Contains an implementation of generalized linear models using
//! iteratively reweighted least squares.
//!
//! The model will automatically add the intercept term to the
//! input data.
//!
//! # Usage
//!
//! ```
//! use rusty_machine::learning::glm::{GenLinearModel, Bernoulli};
//! use rusty_machine::learning::SupModel;
//! use rusty_machine::linalg::Matrix;
//! use rusty_machine::linalg::Vector;
//!
//! let inputs = Matrix::new(4,1,vec![1.0,3.0,5.0,7.0]);
//! let targets = Vector::new(vec![0.,0.,1.,1.]);
//!
//! // Construct a GLM with a Bernoulli distribution
//! // This is equivalent to a logistic regression model.
//! let mut log_mod = GenLinearModel::new(Bernoulli);
//!
//! // Train the model
//! log_mod.train(&inputs, &targets).unwrap();
//!
//! // Now we'll predict a new point
//! let new_point = Matrix::new(1,1,vec![10.]);
//! let output = log_mod.predict(&new_point).unwrap();
//!
//! // Hopefully we classified our new point correctly!
//! assert!(output[0] > 0.5, "Our classifier isn't very good!");
//! ```

use linalg::Vector;
use linalg::{Matrix, BaseMatrix};

use learning::{LearningResult, SupModel};
use learning::error::{Error, ErrorKind};

/// The Generalized Linear Model
///
/// The model is generic over a `Criterion`
/// which specifies the distribution family and
/// the link function.
#[derive(Debug)]
pub struct GenLinearModel<C: Criterion> {
    parameters: Option<Vector<f64>>,
    criterion: C,
}

impl<C: Criterion> GenLinearModel<C> {
    /// Constructs a new Generalized Linear Model.
    ///
    /// Takes a `Criterion` which fully specifies the family
    /// and the link function used by the GLM.
    ///
    /// ```
    /// use rusty_machine::learning::glm::GenLinearModel;
    /// use rusty_machine::learning::glm::Bernoulli;
    ///
    /// let glm = GenLinearModel::new(Bernoulli);
    /// ```
    pub fn new(criterion: C) -> GenLinearModel<C> {
        GenLinearModel {
            parameters: None,
            criterion: criterion,
        }
    }
}

/// Supervised model trait for the GLM.
///
/// Predictions are made from the model by computing `g^-1(Xb)`,
/// where `g` is the link function.
///
/// The model is trained using iteratively reweighted least squares (IRLS).
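///
/// For example, with the `Normal` criterion (identity link) the GLM reduces
/// to ordinary linear regression. A minimal sketch, with data chosen for
/// illustration to lie exactly on the line y = 1 + 2x:
///
/// ```
/// use rusty_machine::learning::glm::{GenLinearModel, Normal};
/// use rusty_machine::learning::SupModel;
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::linalg::Vector;
///
/// let inputs = Matrix::new(4, 1, vec![0.0, 1.0, 2.0, 3.0]);
/// let targets = Vector::new(vec![1.0, 3.0, 5.0, 7.0]);
///
/// let mut lin_mod = GenLinearModel::new(Normal);
/// lin_mod.train(&inputs, &targets).unwrap();
///
/// // The fitted model should recover the line almost exactly.
/// let output = lin_mod.predict(&Matrix::new(1, 1, vec![4.0])).unwrap();
/// assert!((output[0] - 9.0).abs() < 1e-8);
/// ```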
impl<C: Criterion> SupModel<Matrix<f64>, Vector<f64>> for GenLinearModel<C> {
    /// Predict output from inputs.
    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
        if let Some(ref v) = self.parameters {
            // Prepend a column of ones to match the intercept
            // term added during training.
            let ones = Matrix::<f64>::ones(inputs.rows(), 1);
            let full_inputs = ones.hcat(inputs);
            Ok(self.criterion.apply_link_inv(full_inputs * v))
        } else {
            Err(Error::new_untrained())
        }
    }

    /// Train the model using inputs and targets.
    fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
        let n = inputs.rows();

        if n != targets.size() {
            return Err(Error::new(ErrorKind::InvalidData,
                                  "Training inputs and targets do not have the same number of rows."));
        }

        // Construct initial estimate for mu
        let mut mu = Vector::new(self.criterion.initialize_mu(targets.data()));
        let mut z = mu.clone();
        let mut beta: Vector<f64> = Vector::new(vec![0f64; inputs.cols() + 1]);

        // Add the intercept column to the design matrix.
        let ones = Matrix::<f64>::ones(inputs.rows(), 1);
        let full_inputs = ones.hcat(inputs);
        let x_t = full_inputs.transpose();

        // Iterate the IRLS updates to convergence (capped at 8 iterations).
        for _ in 0..8 {
            let w_diag = self.criterion.compute_working_weight(mu.data());
            let y_bar_data = self.criterion.compute_y_bar(targets.data(), mu.data());

            let w = Matrix::from_diag(&w_diag);
            let y_bar = Vector::new(y_bar_data);

            let x_t_w = &x_t * w;

            // Weighted least squares update: beta = (X^T W X)^-1 X^T W z
            let new_beta = (&x_t_w * &full_inputs)
                .inverse()
                .expect("Could not compute input data inverse.") *
                           x_t_w * z;
            let diff = (beta - &new_beta).apply(&|x| x.abs()).sum();
            beta = new_beta;

            if diff < 1e-10 {
                break;
            }

            // Update z and mu
            let fitted = &full_inputs * &beta;
            z = y_bar + &fitted;
            mu = self.criterion.apply_link_inv(fitted);
        }

        self.parameters = Some(beta);
        Ok(())
    }
}

/// The criterion for the Generalized Linear Model.
///
/// This trait specifies the link function and the model variance,
/// which together define the regression family. The remaining
/// methods have default implementations which can be overridden
/// to control the optimization.
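///
/// # Examples
///
/// Only the `Link` type and the model variance must be provided. A minimal
/// sketch (the `Gaussian` name is illustrative, mirroring the `Normal`
/// criterion defined later in this module):
///
/// ```
/// use rusty_machine::learning::glm::{Criterion, Identity};
///
/// struct Gaussian;
///
/// impl Criterion for Gaussian {
///     type Link = Identity;
///
///     // A constant variance gives ordinary least squares behaviour.
///     fn model_variance(&self, _: f64) -> f64 {
///         1f64
///     }
/// }
/// ```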
pub trait Criterion {
    /// The link function of the GLM Criterion.
    type Link: LinkFunc;

    /// The variance of the regression family.
    fn model_variance(&self, mu: f64) -> f64;

    /// Initializes the mean value.
    ///
    /// By default the mean is initialized to the training target values.
    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        y.to_vec()
    }

    /// Computes the working weights that make up the diagonal
    /// of the `W` matrix used in the iteratively reweighted least squares
    /// algorithm.
    ///
    /// This is equal to:
    ///
    /// `1 / (Var(u) * g'(u) * g'(u))`
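    ///
    /// For illustration, with the logit link `g'(u) = 1 / (u(1 - u))` and
    /// Bernoulli variance `Var(u) = u(1 - u)`, the weight simplifies to
    /// `u(1 - u)` (the `Bernoulli` criterion below overrides this method
    /// with exactly that simplification):
    ///
    /// ```
    /// use rusty_machine::learning::glm::{Criterion, Bernoulli};
    ///
    /// // At u = 0.5: Var(u) = 0.25 and g'(u) = 4, so 1 / (0.25 * 16) = 0.25.
    /// let w = Bernoulli.compute_working_weight(&[0.5]);
    /// assert!((w[0] - 0.25).abs() < 1e-12);
    /// ```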
    fn compute_working_weight(&self, mu: &[f64]) -> Vec<f64> {
        let mut working_weights_vec = Vec::with_capacity(mu.len());

        for m in mu {
            let grad = self.link_grad(*m);
            working_weights_vec.push(1f64 / (self.model_variance(*m) * grad * grad));
        }

        working_weights_vec
    }

    /// Computes the adjustment to the fitted values used during
    /// fitting.
    ///
    /// This is equal to:
    ///
    /// `g'(u) * (y - u)`
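    ///
    /// For illustration, with the logit link at `u = 0.5` and target `y = 1`,
    /// the adjustment is `g'(0.5) * 0.5 = 4 * 0.5 = 2`:
    ///
    /// ```
    /// use rusty_machine::learning::glm::{Criterion, Bernoulli};
    ///
    /// let y_bar = Bernoulli.compute_y_bar(&[1.0], &[0.5]);
    /// assert!((y_bar[0] - 2.0).abs() < 1e-12);
    /// ```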
    fn compute_y_bar(&self, y: &[f64], mu: &[f64]) -> Vec<f64> {
        let mut y_bar_vec = Vec::with_capacity(mu.len());

        for (idx, m) in mu.iter().enumerate() {
            y_bar_vec.push(self.link_grad(*m) * (y[idx] - m));
        }

        y_bar_vec
    }

    /// Applies the link function to a vector.
    fn apply_link_func(&self, vec: Vector<f64>) -> Vector<f64> {
        vec.apply(&Self::Link::func)
    }

    /// Applies the inverse of the link function to a vector.
    fn apply_link_inv(&self, vec: Vector<f64>) -> Vector<f64> {
        vec.apply(&Self::Link::func_inv)
    }

    /// Computes the gradient of the link function.
    fn link_grad(&self, mu: f64) -> f64 {
        Self::Link::func_grad(mu)
    }
}

/// Link functions.
///
/// Used within Generalized Linear Regression models.
pub trait LinkFunc {
    /// The link function.
    fn func(x: f64) -> f64;

    /// The gradient of the link function.
    fn func_grad(x: f64) -> f64;

    /// The inverse of the link function.
    /// Often called the 'mean' function.
    fn func_inv(x: f64) -> f64;
}

/// The Logit link function.
///
/// Used primarily as the canonical link in Binomial Regression.
#[derive(Clone, Copy, Debug)]
pub struct Logit;

/// The Logit link function.
///
/// `g(u) = ln(u / (1 - u))`
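///
/// For illustration, `func_inv` (the logistic function) inverts `func`:
///
/// ```
/// use rusty_machine::learning::glm::{LinkFunc, Logit};
///
/// let p = 0.3;
/// let log_odds = Logit::func(p);
/// assert!((Logit::func_inv(log_odds) - p).abs() < 1e-10);
/// ```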
impl LinkFunc for Logit {
    fn func(x: f64) -> f64 {
        (x / (1f64 - x)).ln()
    }

    fn func_grad(x: f64) -> f64 {
        1f64 / (x * (1f64 - x))
    }

    fn func_inv(x: f64) -> f64 {
        // The logistic (sigmoid) function.
        1.0 / (1.0 + (-x).exp())
    }
}

/// The log link function.
///
/// Used primarily as the canonical link in Poisson Regression.
#[derive(Clone, Copy, Debug)]
pub struct Log;

/// The log link function.
///
/// `g(u) = ln(u)`
impl LinkFunc for Log {
    fn func(x: f64) -> f64 {
        x.ln()
    }

    fn func_grad(x: f64) -> f64 {
        1f64 / x
    }

    fn func_inv(x: f64) -> f64 {
        x.exp()
    }
}

/// The Identity link function.
///
/// Used primarily as the canonical link in Linear Regression.
#[derive(Clone, Copy, Debug)]
pub struct Identity;

/// The Identity link function.
///
/// `g(u) = u`
impl LinkFunc for Identity {
    fn func(x: f64) -> f64 {
        x
    }

    fn func_grad(_: f64) -> f64 {
        1f64
    }

    fn func_inv(x: f64) -> f64 {
        x
    }
}

/// The Bernoulli regression family.
///
/// This is equivalent to logistic regression.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli;

impl Criterion for Bernoulli {
    type Link = Logit;

    fn model_variance(&self, mu: f64) -> f64 {
        let var = mu * (1f64 - mu);

        // Clamp the variance away from zero to keep the
        // IRLS working weights finite.
        if var.abs() < 1e-10 {
            1e-10
        } else {
            var
        }
    }

    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        let mut mu_data = Vec::with_capacity(y.len());

        // Keep the initial mean strictly inside (0, 1) so that
        // the logit link is well defined.
        for y_val in y {
            mu_data.push(if *y_val < 1e-10 {
                1e-10
            } else if *y_val > 1f64 - 1e-10 {
                1f64 - 1e-10
            } else {
                *y_val
            });
        }

        mu_data
    }

    fn compute_working_weight(&self, mu: &[f64]) -> Vec<f64> {
        let mut working_weights_vec = Vec::with_capacity(mu.len());

        // For the logit link, 1 / (Var(u) * g'(u)^2) simplifies
        // to u(1 - u), i.e. the model variance itself.
        for m in mu {
            let var = self.model_variance(*m);

            working_weights_vec.push(if var.abs() < 1e-5 {
                1e-5
            } else {
                var
            });
        }

        working_weights_vec
    }

    fn compute_y_bar(&self, y: &[f64], mu: &[f64]) -> Vec<f64> {
        let mut y_bar_vec = Vec::with_capacity(y.len());

        // Zero out vanishingly small residuals for numerical stability.
        for (idx, m) in mu.iter().enumerate() {
            let target_diff = y[idx] - m;

            y_bar_vec.push(if target_diff.abs() < 1e-15 {
                0f64
            } else {
                self.link_grad(*m) * target_diff
            });
        }

        y_bar_vec
    }
}

/// The Binomial regression family.
#[derive(Debug)]
pub struct Binomial {
    weights: Vec<f64>,
}

impl Criterion for Binomial {
    type Link = Logit;

    fn model_variance(&self, mu: f64) -> f64 {
        let var = mu * (1f64 - mu);

        // Clamp the variance away from zero to keep the
        // IRLS working weights finite.
        if var.abs() < 1e-10 {
            1e-10
        } else {
            var
        }
    }

    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        let mut mu_data = Vec::with_capacity(y.len());

        // Keep the initial mean strictly inside (0, 1) so that
        // the logit link is well defined.
        for y_val in y {
            mu_data.push(if *y_val < 1e-10 {
                1e-10
            } else if *y_val > 1f64 - 1e-10 {
                1f64 - 1e-10
            } else {
                *y_val
            });
        }

        mu_data
    }

    fn compute_working_weight(&self, mu: &[f64]) -> Vec<f64> {
        let mut working_weights_vec = Vec::with_capacity(mu.len());

        // The model variance is divided by the per-observation weight.
        for (idx, m) in mu.iter().enumerate() {
            let var = self.model_variance(*m) / self.weights[idx];

            working_weights_vec.push(if var.abs() < 1e-5 {
                1e-5
            } else {
                var
            });
        }

        working_weights_vec
    }

    fn compute_y_bar(&self, y: &[f64], mu: &[f64]) -> Vec<f64> {
        let mut y_bar_vec = Vec::with_capacity(y.len());

        // Zero out vanishingly small residuals for numerical stability.
        for (idx, m) in mu.iter().enumerate() {
            let target_diff = y[idx] - m;

            y_bar_vec.push(if target_diff.abs() < 1e-15 {
                0f64
            } else {
                self.link_grad(*m) * target_diff
            });
        }

        y_bar_vec
    }
}

/// The Normal regression family.
///
/// This is equivalent to the Linear Regression model.
#[derive(Clone, Copy, Debug)]
pub struct Normal;

impl Criterion for Normal {
    type Link = Identity;

    fn model_variance(&self, _: f64) -> f64 {
        1f64
    }
}

/// The Poisson regression family.
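///
/// A short illustrative sketch for count data (the values are chosen
/// arbitrarily; with the log link, predicted counts are always positive):
///
/// ```
/// use rusty_machine::learning::glm::{GenLinearModel, Poisson};
/// use rusty_machine::learning::SupModel;
/// use rusty_machine::linalg::Matrix;
/// use rusty_machine::linalg::Vector;
///
/// let inputs = Matrix::new(4, 1, vec![1.0, 2.0, 3.0, 4.0]);
/// let targets = Vector::new(vec![2.0, 3.0, 4.0, 6.0]);
///
/// let mut pois_mod = GenLinearModel::new(Poisson);
/// pois_mod.train(&inputs, &targets).unwrap();
///
/// // The log link guarantees strictly positive predicted counts.
/// let output = pois_mod.predict(&inputs).unwrap();
/// assert!(output.data().iter().all(|&x| x > 0.0));
/// ```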
#[derive(Clone, Copy, Debug)]
pub struct Poisson;

impl Criterion for Poisson {
    type Link = Log;

    fn model_variance(&self, mu: f64) -> f64 {
        mu
    }

    fn initialize_mu(&self, y: &[f64]) -> Vec<f64> {
        let mut mu_data = Vec::with_capacity(y.len());

        // Clamp zero counts so that the log link is well defined.
        for y_val in y {
            mu_data.push(if *y_val < 1e-10 {
                1e-10
            } else {
                *y_val
            });
        }

        mu_data
    }

    fn compute_working_weight(&self, mu: &[f64]) -> Vec<f64> {
        // For the log link, 1 / (Var(u) * g'(u)^2) simplifies to u.
        mu.to_vec()
    }
}