1use std::marker::PhantomData;
31
32use crate::api::{Transformer, UnsupervisedEstimator};
33use crate::error::{Failed, FailedError};
34use crate::linalg::basic::arrays::Array2;
35use crate::numbers::basenum::Number;
36use crate::numbers::realnum::RealNumber;
37
38#[cfg(feature = "serde")]
39use serde::{Deserialize, Serialize};
40
/// Configures the behaviour of [`StandardScaler`].
///
/// Both centring and scaling are enabled by default. Use
/// [`StandardScalerParameters::with_mean`] and [`StandardScalerParameters::with_std`] to
/// toggle them, e.g. `StandardScalerParameters::default().with_mean(false)`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub struct StandardScalerParameters {
    /// When `true`, the per-column mean learned during fitting is subtracted from each value.
    with_mean: bool,
    /// When `true`, each (optionally centred) value is divided by the per-column standard
    /// deviation learned during fitting.
    with_std: bool,
}

impl StandardScalerParameters {
    /// Enables (`true`) or disables (`false`) centring each column on its mean.
    #[must_use]
    pub fn with_mean(mut self, with_mean: bool) -> Self {
        self.with_mean = with_mean;
        self
    }

    /// Enables (`true`) or disables (`false`) scaling each column by its standard deviation.
    #[must_use]
    pub fn with_std(mut self, with_std: bool) -> Self {
        self.with_std = with_std;
        self
    }
}

impl Default for StandardScalerParameters {
    /// Returns parameters that both centre and scale every column.
    fn default() -> Self {
        StandardScalerParameters {
            with_mean: true,
            with_std: true,
        }
    }
}
58
59#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
65#[derive(Clone, Debug, Default, PartialEq)]
66pub struct StandardScaler<T: Number + RealNumber> {
67 means: Vec<f64>,
68 stds: Vec<f64>,
69 parameters: StandardScalerParameters,
70 _phantom: PhantomData<T>,
71}
72
73#[allow(dead_code)]
74impl<T: Number + RealNumber> StandardScaler<T> {
75 fn new(parameters: StandardScalerParameters) -> Self
76 where
77 T: Number + RealNumber,
78 {
79 Self {
80 means: vec![],
81 stds: vec![],
82 parameters: StandardScalerParameters {
83 with_mean: parameters.with_mean,
84 with_std: parameters.with_std,
85 },
86 _phantom: PhantomData,
87 }
88 }
89 fn adjust_column_mean(&self, mean: f64) -> f64 {
92 if self.parameters.with_mean {
93 mean
94 } else {
95 0f64
96 }
97 }
98 fn adjust_column_std(&self, std: f64) -> f64 {
101 if self.parameters.with_std {
102 ensure_std_valid(std)
103 } else {
104 1f64
105 }
106 }
107}
108
109fn ensure_std_valid<T: Number + RealNumber>(value: T) -> T {
114 value.max(T::min_positive_value())
115}
116
117impl<T: Number + RealNumber, M: Array2<T>> UnsupervisedEstimator<M, StandardScalerParameters>
119 for StandardScaler<T>
120{
121 fn fit(x: &M, parameters: StandardScalerParameters) -> Result<Self, Failed>
122 where
123 T: Number + RealNumber,
124 M: Array2<T>,
125 {
126 Ok(Self {
127 means: x.column_mean(),
128 stds: x.std_dev(0),
129 parameters,
130 _phantom: Default::default(),
131 })
132 }
133}
134
135impl<T: Number + RealNumber, M: Array2<T>> Transformer<M> for StandardScaler<T> {
139 fn transform(&self, x: &M) -> Result<M, Failed> {
140 let (_, n_cols) = x.shape();
141 if n_cols != self.means.len() {
142 return Err(Failed::because(
143 FailedError::TransformFailed,
144 &format!(
145 "Expected {} columns, but got {} columns instead.",
146 self.means.len(),
147 n_cols,
148 ),
149 ));
150 }
151
152 Ok(build_matrix_from_columns(
153 self.means
154 .iter()
155 .zip(self.stds.iter())
156 .enumerate()
157 .map(|(column_index, (column_mean, column_std))| {
158 x.take_column(column_index)
159 .sub_scalar(T::from(self.adjust_column_mean(*column_mean)).unwrap())
160 .div_scalar(T::from(self.adjust_column_std(*column_std)).unwrap())
161 })
162 .collect(),
163 )
164 .unwrap())
165 }
166}
167
168fn build_matrix_from_columns<T, M>(columns: Vec<M>) -> Option<M>
171where
172 T: Number + RealNumber,
173 M: Array2<T>,
174{
175 columns.first().cloned().map(|output_matrix| {
176 columns
177 .iter()
178 .skip(1)
179 .fold(output_matrix, |current_matrix, new_colum| {
180 current_matrix.h_stack(new_colum)
181 })
182 })
183}
184
#[cfg(test)]
mod tests {

    mod helper_functionality {
        use super::super::{build_matrix_from_columns, ensure_std_valid};
        use crate::linalg::basic::matrix::DenseMatrix;

        #[test]
        fn combine_three_columns() {
            assert_eq!(
                build_matrix_from_columns(vec![
                    DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]).unwrap(),
                    DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]).unwrap(),
                    DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],]).unwrap()
                ]),
                Some(
                    DenseMatrix::from_2d_vec(&vec![
                        vec![1.0, 2.0, 3.0],
                        vec![1.0, 2.0, 3.0],
                        vec![1.0, 2.0, 3.0]
                    ])
                    .unwrap()
                )
            )
        }

        #[test]
        fn combining_no_columns_yields_none() {
            assert_eq!(
                build_matrix_from_columns::<f64, DenseMatrix<f64>>(vec![]),
                None
            )
        }

        #[test]
        fn negative_value_should_be_replaced_with_minimal_positive_value() {
            assert_eq!(ensure_std_valid(-1.0), f64::MIN_POSITIVE)
        }

        #[test]
        fn zero_should_be_replaced_with_minimal_positive_value() {
            assert_eq!(ensure_std_valid(0.0), f64::MIN_POSITIVE)
        }

        #[test]
        fn nan_should_be_replaced_with_minimal_positive_value() {
            assert_eq!(ensure_std_valid(f64::NAN), f64::MIN_POSITIVE)
        }
    }
    mod standard_scaler {
        use super::super::{StandardScaler, StandardScalerParameters};
        use crate::api::{Transformer, UnsupervisedEstimator};
        use crate::linalg::basic::arrays::Array2;
        use crate::linalg::basic::matrix::DenseMatrix;

        #[test]
        fn dont_adjust_mean_if_used() {
            assert_eq!(
                (StandardScaler::<f64>::new(StandardScalerParameters {
                    with_mean: true,
                    with_std: true
                }))
                .adjust_column_mean(1.0),
                1.0
            )
        }
        #[test]
        fn replace_mean_with_zero_if_not_used() {
            assert_eq!(
                (StandardScaler::<f64>::new(StandardScalerParameters {
                    with_mean: false,
                    with_std: true
                }))
                .adjust_column_mean(1.0),
                0.0
            )
        }
        #[test]
        fn dont_adjust_std_if_used() {
            assert_eq!(
                (StandardScaler::<f64>::new(StandardScalerParameters {
                    with_mean: true,
                    with_std: true
                }))
                .adjust_column_std(10.0),
                10.0
            )
        }
        #[test]
        fn replace_std_with_one_if_not_used() {
            assert_eq!(
                (StandardScaler::<f64>::new(StandardScalerParameters {
                    with_mean: true,
                    with_std: false
                }))
                .adjust_column_std(10.0),
                1.0
            )
        }

        /// Fits a default scaler on `values_to_be_transformed` and transforms the same data.
        fn fit_transform_with_default_standard_scaler(
            values_to_be_transformed: &DenseMatrix<f64>,
        ) -> DenseMatrix<f64> {
            StandardScaler::fit(
                values_to_be_transformed,
                StandardScalerParameters::default(),
            )
            .unwrap()
            .transform(values_to_be_transformed)
            .unwrap()
        }

        #[test]
        fn fit_transform_random_values() {
            let transformed_values = fit_transform_with_default_standard_scaler(
                &DenseMatrix::from_2d_array(&[
                    &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
                    &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
                    &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
                    &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
                ])
                .unwrap(),
            );
            // The expected values are given to 10 decimal places, so a tolerance of 1e-8
            // leaves ample headroom while still catching real regressions (the previous
            // tolerance of 1.0 accepted almost any output).
            assert!(transformed_values.approximate_eq(
                &DenseMatrix::from_2d_array(&[
                    &[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866],
                    &[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
                    &[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
                    &[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
                ])
                .unwrap(),
                1e-8
            ))
        }

        #[test]
        fn fit_transform_with_zero_variance() {
            assert_eq!(
                fit_transform_with_default_standard_scaler(
                    &DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]).unwrap()
                ),
                DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]).unwrap(),
                "When scaling values with zero variance, zero is expected as return value"
            )
        }

        #[test]
        fn fit_for_simple_values() {
            assert_eq!(
                StandardScaler::fit(
                    &DenseMatrix::from_2d_array(&[
                        &[1.0, 1.0, 1.0],
                        &[1.0, 2.0, 5.0],
                        &[1.0, 1.0, 1.0],
                        &[1.0, 2.0, 5.0]
                    ])
                    .unwrap(),
                    StandardScalerParameters::default(),
                ),
                Ok(StandardScaler {
                    means: vec![1.0, 1.5, 3.0],
                    stds: vec![0.0, 0.5, 2.0],
                    parameters: StandardScalerParameters {
                        with_mean: true,
                        with_std: true
                    },
                    _phantom: Default::default(),
                })
            )
        }
        #[test]
        fn fit_for_random_values() {
            let fitted_scaler = StandardScaler::fit(
                &DenseMatrix::from_2d_array(&[
                    &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
                    &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
                    &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
                    &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
                ])
                .unwrap(),
                StandardScalerParameters::default(),
            )
            .unwrap();

            assert_eq!(
                fitted_scaler.means,
                vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
            );

            assert!(&DenseMatrix::<f64>::from_2d_vec(&vec![fitted_scaler.stds])
                .unwrap()
                .approximate_eq(
                    &DenseMatrix::from_2d_array(&[&[
                        0.29426447500954,
                        0.16758497615485,
                        0.20820945786863,
                        0.23329718831165
                    ],])
                    .unwrap(),
                    0.00000000000001
                ))
        }

        #[test]
        fn transform_without_std() {
            let standard_scaler = StandardScaler {
                means: vec![1.0, 3.0],
                stds: vec![1.0, 2.0],
                parameters: StandardScalerParameters {
                    with_mean: true,
                    with_std: false,
                },
                _phantom: Default::default(),
            };

            assert_eq!(
                standard_scaler
                    .transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]]).unwrap()),
                Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]).unwrap())
            )
        }

        #[test]
        fn transform_without_mean() {
            let standard_scaler = StandardScaler {
                means: vec![1.0, 2.0],
                stds: vec![2.0, 3.0],
                parameters: StandardScalerParameters {
                    with_mean: false,
                    with_std: true,
                },
                _phantom: Default::default(),
            };

            assert_eq!(
                standard_scaler
                    .transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]]).unwrap()),
                Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]).unwrap())
            )
        }

        #[cfg_attr(
            all(target_arch = "wasm32", not(target_os = "wasi")),
            wasm_bindgen_test::wasm_bindgen_test
        )]
        #[test]
        #[cfg(feature = "serde")]
        fn serde_fit_for_random_values() {
            let fitted_scaler = StandardScaler::fit(
                &DenseMatrix::from_2d_array(&[
                    &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
                    &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
                    &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
                    &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
                ])
                .unwrap(),
                StandardScalerParameters::default(),
            )
            .unwrap();

            let deserialized_scaler: StandardScaler<f64> =
                serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();

            assert_eq!(
                deserialized_scaler.means,
                vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
            );

            assert!(&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds])
                .unwrap()
                .approximate_eq(
                    &DenseMatrix::from_2d_array(&[&[
                        0.29426447500954,
                        0.16758497615485,
                        0.20820945786863,
                        0.23329718831165
                    ],])
                    .unwrap(),
                    0.00000000000001
                ))
        }
    }
}