rusty_machine/data/transforms/
standardize.rs1use learning::error::{Error, ErrorKind};
28use linalg::{Matrix, Vector, Axes, BaseMatrix, BaseMatrixMut};
29use super::{Invertible, Transformer};
30
31use rulinalg::utils;
32
33use libnum::{Float, FromPrimitive};
34
35#[derive(Debug)]
43pub struct Standardizer<T: Float> {
44 means: Option<Vector<T>>,
46 variances: Option<Vector<T>>,
48 scaled_mean: T,
50 scaled_stdev: T,
52}
53
54impl<T: Float> Default for Standardizer<T> {
57 fn default() -> Standardizer<T> {
58 Standardizer {
59 means: None,
60 variances: None,
61 scaled_mean: T::zero(),
62 scaled_stdev: T::one(),
63 }
64 }
65}
66
67impl<T: Float> Standardizer<T> {
68 pub fn new(mean: T, stdev: T) -> Standardizer<T> {
80 Standardizer {
81 means: None,
82 variances: None,
83 scaled_mean: mean,
84 scaled_stdev: stdev,
85 }
86 }
87}
88
89impl<T: Float + FromPrimitive> Transformer<Matrix<T>> for Standardizer<T> {
90
91 fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
92 if inputs.rows() <= 1 {
93 Err(Error::new(ErrorKind::InvalidData,
94 "Cannot standardize data with only one row."))
95 } else {
96 let mean = inputs.mean(Axes::Row);
97 let variance = try!(inputs.variance(Axes::Row).map_err(|_| {
98 Error::new(ErrorKind::InvalidData, "Cannot compute variance of data.")
99 }));
100
101 if mean.data().iter().any(|x| !x.is_finite()) {
102 return Err(Error::new(ErrorKind::InvalidData, "Some data point is non-finite."));
103 }
104 self.means = Some(mean);
105 self.variances = Some(variance);
106 Ok(())
107 }
108 }
109
110 fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
111 if let (&None, &None) = (&self.means, &self.variances) {
112 try!(self.fit(&inputs));
114 }
115
116 if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
117 if means.size() != inputs.cols() {
118 Err(Error::new(ErrorKind::InvalidData,
119 "Input data has different number of columns from fitted data."))
120 } else {
121 for row in inputs.iter_rows_mut() {
122 utils::in_place_vec_bin_op(row, means.data(), |x, &y| *x = *x - y);
124 utils::in_place_vec_bin_op(row, variances.data(), |x, &y| {
125 *x = (*x * self.scaled_stdev / y.sqrt()) + self.scaled_mean
126 });
127 }
128 Ok(inputs)
129 }
130 } else {
131 Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
132 }
133 }
134}
135
136impl<T: Float + FromPrimitive> Invertible<Matrix<T>> for Standardizer<T> {
137 fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
138 if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
139
140 let features = means.size();
141 if inputs.cols() != features {
142 return Err(Error::new(ErrorKind::InvalidData,
143 "Inputs have different feature count than transformer."));
144 }
145
146 for row in inputs.iter_rows_mut() {
147 utils::in_place_vec_bin_op(row, &variances.data(), |x, &y| {
148 *x = (*x - self.scaled_mean) * y.sqrt() / self.scaled_stdev
149 });
150
151 utils::in_place_vec_bin_op(row, &means.data(), |x, &y| *x = *x + y);
153 }
154
155 Ok(inputs)
156 } else {
157 Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::{Transformer, Invertible};
    use linalg::{Axes, Matrix};

    use std::f64;

    /// A single row has no spread to standardize, so fitting must fail.
    #[test]
    fn single_row_test() {
        let inputs = Matrix::new(1, 2, vec![1.0, 2.0]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// NaN data must be rejected rather than silently propagated.
    #[test]
    fn nan_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::NAN; 4]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// Infinite data must be rejected rather than silently propagated.
    #[test]
    fn inf_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::INFINITY; 4]);

        let mut standardizer = Standardizer::default();

        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// The default transformer yields columns with mean 0 and variance 1.
    #[test]
    fn basic_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::default();
        let transformed = standardizer.transform(inputs).unwrap();

        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();

        assert!(new_mean.data().iter().all(|&x| x.abs() < 1e-5));
        // Compare the distance to the target: `(x.abs() - 1.0) < 1e-5` would
        // also accept any value whose magnitude is below one.
        assert!(new_var.data().iter().all(|&x| (x - 1.0).abs() < 1e-5));
    }

    /// Custom targets: mean 1 and standard deviation 2 (i.e. variance 4).
    #[test]
    fn custom_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::new(1.0, 2.0);
        let transformed = standardizer.transform(inputs).unwrap();

        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();

        // Two-sided tolerance checks around the expected values.
        assert!(new_mean.data().iter().all(|&x| (x - 1.0).abs() < 1e-5));
        assert!(new_var.data().iter().all(|&x| (x - 4.0).abs() < 1e-5));
    }

    /// `inv_transform` applied after `transform` recovers the original data.
    #[test]
    fn inv_transform_identity_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);

        let mut standardizer = Standardizer::new(1.0, 3.0);
        let transformed = standardizer.transform(inputs.clone()).unwrap();

        let original = standardizer.inv_transform(transformed).unwrap();

        assert!((inputs - original).data().iter().all(|x| x.abs() < 1e-5));
    }
}