rusty_machine/data/transforms/
standardize.rs

1//! The Standardizing Transformer
2//!
3//! This module contains the `Standardizer` transformer.
4//!
5//! The `Standardizer` transformer is used to transform input data
6//! so that the mean and standard deviation of each column are as
7//! specified. This is commonly used to transform the data to have `0` mean
8//! and a standard deviation of `1`.
9//!
10//! # Examples
11//!
12//! ```
13//! use rusty_machine::data::transforms::{Transformer, Standardizer};
14//! use rusty_machine::linalg::Matrix;
15//!
16//! // Constructs a new `Standardizer` to map to mean 0 and standard
17//! // deviation of 1.
18//! let mut transformer = Standardizer::default();
19//!
20//! let inputs = Matrix::new(2, 2, vec![-1.0, 2.0, 1.5, 3.0]);
21//!
22//! // Transform the inputs to get output data with required mean and
23//! // standard deviation.
24//! let transformed = transformer.transform(inputs).unwrap();
25//! ```
26
27use learning::error::{Error, ErrorKind};
28use linalg::{Matrix, Vector, Axes, BaseMatrix, BaseMatrixMut};
29use super::{Invertible, Transformer};
30
31use rulinalg::utils;
32
33use libnum::{Float, FromPrimitive};
34
35/// The Standardizer
36///
37/// The Standardizer provides an implementation of `Transformer`
38/// which allows us to transform the input data to have a new mean
39/// and standard deviation.
40///
41/// See the module description for more information.
42#[derive(Debug)]
43pub struct Standardizer<T: Float> {
44    /// Means per column of input data
45    means: Option<Vector<T>>,
46    /// Variances per column of input data
47    variances: Option<Vector<T>>,
48    /// The mean of the new data (default 0)
49    scaled_mean: T,
50    /// The standard deviation of the new data (default 1)
51    scaled_stdev: T,
52}
53
54/// Create a `Standardizer` with mean `0` and standard
55/// deviation `1`.
56impl<T: Float> Default for Standardizer<T> {
57    fn default() -> Standardizer<T> {
58        Standardizer {
59            means: None,
60            variances: None,
61            scaled_mean: T::zero(),
62            scaled_stdev: T::one(),
63        }
64    }
65}
66
67impl<T: Float> Standardizer<T> {
68    /// Constructs a new `Standardizer` with the given mean and variance
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use rusty_machine::data::transforms::Standardizer;
74    ///
75    /// // Constructs a new `Standardizer` which will give the data
76    /// // mean `0` and standard deviation `2`.
77    /// let transformer = Standardizer::new(0.0, 2.0);
78    /// ```
79    pub fn new(mean: T, stdev: T) -> Standardizer<T> {
80        Standardizer {
81            means: None,
82            variances: None,
83            scaled_mean: mean,
84            scaled_stdev: stdev,
85        }
86    }
87}
88
89impl<T: Float + FromPrimitive> Transformer<Matrix<T>> for Standardizer<T> {
90
91    fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
92        if inputs.rows() <= 1 {
93            Err(Error::new(ErrorKind::InvalidData,
94                           "Cannot standardize data with only one row."))
95        } else {
96            let mean = inputs.mean(Axes::Row);
97            let variance = try!(inputs.variance(Axes::Row).map_err(|_| {
98                Error::new(ErrorKind::InvalidData, "Cannot compute variance of data.")
99            }));
100
101            if mean.data().iter().any(|x| !x.is_finite()) {
102                return Err(Error::new(ErrorKind::InvalidData, "Some data point is non-finite."));
103            }
104            self.means = Some(mean);
105            self.variances = Some(variance);
106            Ok(())
107        }
108    }
109
110    fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
111        if let (&None, &None) = (&self.means, &self.variances) {
112            // if Transformer is not fitted to the data, fit for backward-compat.
113            try!(self.fit(&inputs));
114        }
115
116        if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
117            if means.size() != inputs.cols() {
118                Err(Error::new(ErrorKind::InvalidData,
119                               "Input data has different number of columns from fitted data."))
120            } else {
121                for row in inputs.iter_rows_mut() {
122                    // Subtract the mean
123                    utils::in_place_vec_bin_op(row, means.data(), |x, &y| *x = *x - y);
124                    utils::in_place_vec_bin_op(row, variances.data(), |x, &y| {
125                        *x = (*x * self.scaled_stdev / y.sqrt()) + self.scaled_mean
126                    });
127                }
128                Ok(inputs)
129            }
130        } else {
131            Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
132        }
133    }
134}
135
136impl<T: Float + FromPrimitive> Invertible<Matrix<T>> for Standardizer<T> {
137    fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
138        if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
139
140            let features = means.size();
141            if inputs.cols() != features {
142                return Err(Error::new(ErrorKind::InvalidData,
143                                      "Inputs have different feature count than transformer."));
144            }
145
146            for row in inputs.iter_rows_mut() {
147                utils::in_place_vec_bin_op(row, &variances.data(), |x, &y| {
148                    *x = (*x - self.scaled_mean) * y.sqrt() / self.scaled_stdev
149                });
150
151                // Add the mean
152                utils::in_place_vec_bin_op(row, &means.data(), |x, &y| *x = *x + y);
153            }
154
155            Ok(inputs)
156        } else {
157            Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::{Transformer, Invertible};
    use linalg::{Axes, Matrix};

    use std::f64;

    /// A single row has no defined variance, so fitting must fail.
    #[test]
    fn single_row_test() {
        let inputs = Matrix::new(1, 2, vec![1.0, 2.0]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// NaN input must be rejected rather than propagated.
    #[test]
    fn nan_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::NAN; 4]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// Infinite input must be rejected rather than propagated.
    #[test]
    fn inf_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::INFINITY; 4]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// The default transformer yields columns with mean 0 and variance 1.
    #[test]
    fn basic_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::default();
        let transformed = standardizer.transform(inputs).unwrap();

        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();

        // Compare the absolute *difference* to the tolerance; taking the
        // absolute value first would accept any value below the target.
        assert!(new_mean.data().iter().all(|x| x.abs() < 1e-5));
        assert!(new_var.data().iter().all(|x| (*x - 1.0).abs() < 1e-5));
    }

    /// A custom transformer yields the requested mean and variance
    /// (standard deviation 2, hence variance 4).
    #[test]
    fn custom_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::new(1.0, 2.0);
        let transformed = standardizer.transform(inputs).unwrap();

        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();

        assert!(new_mean.data().iter().all(|x| (*x - 1.0).abs() < 1e-5));
        assert!(new_var.data().iter().all(|x| (*x - 4.0).abs() < 1e-5));
    }

    /// Transforming and then inverting returns the original data.
    #[test]
    fn inv_transform_identity_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::new(1.0, 3.0);
        let transformed = standardizer.transform(inputs.clone()).unwrap();

        let original = standardizer.inv_transform(transformed).unwrap();

        assert!((inputs - original).data().iter().all(|x| x.abs() < 1e-5));
    }
}