use crate::preprocessing::incremental::{IncrementalScaler, WelfordState};
use crate::{Result, TreeBoostError};

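/// Common interface for scalers over row-major `f32` feature matrices,
/// where `data.len()` must be a multiple of `num_features`.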
pub trait Scaler {
    /// Computes scaling statistics from `data`.
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()>;

    /// Applies the fitted statistics to `data` in place.
    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()>;

    /// Fits on `data`, then transforms it in place.
    fn fit_transform(&mut self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.fit(data, num_features)?;
        self.transform(data, num_features)?;
        Ok(())
    }

    /// Returns `true` once the scaler has been fitted.
    fn is_fitted(&self) -> bool;
}

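/// Standardizes each feature to zero mean and unit variance.
///
/// Supports batch fitting ([`Scaler::fit`]), streaming updates backed by
/// Welford accumulators ([`IncrementalScaler::partial_fit`]), and an
/// optional EMA mode for drifting data (see [`Self::with_forget_factor`]).
///
/// A minimal usage sketch (the fence is `ignore` because the crate and
/// module path here are assumptions):
///
/// ```ignore
/// use treeboost::preprocessing::scalers::{Scaler, StandardScaler};
///
/// // Two rows of three features each, in row-major order.
/// let mut data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let mut scaler = StandardScaler::new();
/// scaler.fit_transform(&mut data, 3).unwrap();
/// assert!(scaler.is_fitted());
/// ```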
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StandardScaler {
    /// Per-feature means.
    pub means: Vec<f32>,
    /// Per-feature standard deviations (1.0 for near-constant features).
    pub stds: Vec<f32>,
    fitted: bool,
    /// Streaming statistics for incremental fitting.
    #[serde(default)]
    welford_states: Vec<WelfordState>,
    /// EMA weight in `[0, 1]`; `None` selects cumulative Welford updates.
    #[serde(default)]
    forget_factor: Option<f32>,
}

impl StandardScaler {
    pub fn new() -> Self {
        Self {
            means: Vec::new(),
            stds: Vec::new(),
            fitted: false,
            welford_states: Vec::new(),
            forget_factor: None,
        }
    }

    /// Creates a scaler whose `partial_fit` blends statistics with an
    /// exponential moving average of weight `forget_factor` (clamped to
    /// `[0, 1]`); e.g. with weight 0.3, a fitted mean of 10 updated by a
    /// batch mean of 100 becomes 0.7 * 10 + 0.3 * 100 = 37.
    pub fn with_forget_factor(forget_factor: f32) -> Self {
        Self {
            means: Vec::new(),
            stds: Vec::new(),
            fitted: false,
            welford_states: Vec::new(),
            forget_factor: Some(forget_factor.clamp(0.0, 1.0)),
        }
    }

    pub fn set_forget_factor(&mut self, factor: Option<f32>) {
        self.forget_factor = factor.map(|f| f.clamp(0.0, 1.0));
    }

    pub fn forget_factor(&self) -> Option<f32> {
        self.forget_factor
    }

    pub fn means(&self) -> &[f32] {
        &self.means
    }

    pub fn stds(&self) -> &[f32] {
        &self.stds
    }

    fn sync_from_welford(&mut self) {
        let num_features = self.welford_states.len();
        self.means.resize(num_features, 0.0);
        self.stds.resize(num_features, 1.0);

        for (i, state) in self.welford_states.iter().enumerate() {
            self.means[i] = state.mean as f32;
            let std = state.std() as f32;
            // Near-zero stds fall back to 1.0 so transform never divides by ~0.
            self.stds[i] = if std < 1e-8 { 1.0 } else { std };
        }
    }

    /// Computes per-feature `(mean, population variance)` in `f64`.
    fn compute_batch_stats(data: &[f32], num_features: usize) -> Vec<(f64, f64)> {
        let num_rows = data.len() / num_features;
        let mut stats = vec![(0.0f64, 0.0f64); num_features];

        if num_rows == 0 {
            return stats;
        }

        for feat in 0..num_features {
            let mut sum = 0.0f64;
            for row in 0..num_rows {
                sum += data[row * num_features + feat] as f64;
            }
            stats[feat].0 = sum / num_rows as f64;
        }

        for feat in 0..num_features {
            let mean = stats[feat].0;
            let mut variance = 0.0f64;
            for row in 0..num_rows {
                let x = data[row * num_features + feat] as f64;
                variance += (x - mean).powi(2);
            }
            stats[feat].1 = variance / num_rows as f64;
        }

        stats
    }

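    /// EMA-mode `partial_fit`: blends the existing statistics with each
    /// batch's statistics as `new = (1 - alpha) * old + alpha * batch`,
    /// for both the mean and the variance.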
    fn partial_fit_ema(&mut self, data: &[f32], num_features: usize, alpha: f32) -> Result<()> {
        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Ok(());
        }

        let batch_stats = Self::compute_batch_stats(data, num_features);

        // First batch (or unfitted scaler): adopt the batch statistics directly.
        if self.means.is_empty() || !self.fitted {
            self.means = vec![0.0; num_features];
            self.stds = vec![1.0; num_features];
            self.welford_states = vec![WelfordState::new(); num_features];

            for feat in 0..num_features {
                let (mean, var) = batch_stats[feat];
                self.means[feat] = mean as f32;
                let std = var.sqrt() as f32;
                self.stds[feat] = if std < 1e-8 { 1.0 } else { std };

                // Seed the Welford state so n_samples() and merge() stay meaningful.
                self.welford_states[feat].n = num_rows as u64;
                self.welford_states[feat].mean = mean;
                self.welford_states[feat].m2 = var * num_rows as f64;
            }
            self.fitted = true;
            return Ok(());
        }

        if self.means.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.means.len(),
                num_features
            )));
        }

        let alpha_64 = alpha as f64;
        let decay = 1.0 - alpha_64;

        for feat in 0..num_features {
            let (batch_mean, batch_var) = batch_stats[feat];

            // new_mean = (1 - alpha) * old_mean + alpha * batch_mean
            let old_mean = self.means[feat] as f64;
            let new_mean = decay * old_mean + alpha_64 * batch_mean;
            self.means[feat] = new_mean as f32;

            // Blend variances the same way, then re-derive the std.
            let old_var = (self.stds[feat] as f64).powi(2);
            let new_var = decay * old_var + alpha_64 * batch_var;
            let new_std = new_var.sqrt() as f32;
            self.stds[feat] = if new_std < 1e-8 { 1.0 } else { new_std };
        }

        // Keep the sample count current even in EMA mode, guarding against a
        // scaler that was batch-fitted and therefore has no Welford state.
        if self.welford_states.len() == num_features {
            for state in &mut self.welford_states {
                state.n += num_rows as u64;
            }
        }

        Ok(())
    }
}

impl Default for StandardScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl Scaler for StandardScaler {
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;

        if num_rows == 0 {
            return Err(TreeBoostError::Data("No rows to fit".into()));
        }

        self.means = vec![0.0; num_features];
        self.stds = vec![0.0; num_features];

        for feat in 0..num_features {
            let mut sum = 0.0;
            for row in 0..num_rows {
                sum += data[row * num_features + feat];
            }
            self.means[feat] = sum / num_rows as f32;
        }

        for feat in 0..num_features {
            let mean = self.means[feat];
            let mut variance = 0.0;
            for row in 0..num_rows {
                let x = data[row * num_features + feat];
                variance += (x - mean).powi(2);
            }
            // Population standard deviation.
            let std = (variance / num_rows as f32).sqrt();

            // Near-constant features get std 1.0 so transform is a pure shift.
            self.stds[feat] = if std < 1e-8 { 1.0 } else { std };
        }

        self.fitted = true;
        Ok(())
    }

    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Data(
                "StandardScaler not fitted. Call fit() first.".into(),
            ));
        }

        if num_features != self.means.len() {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: fit with {}, transform with {}",
                self.means.len(),
                num_features
            )));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;

        for feat in 0..num_features {
            let mean = self.means[feat];
            let std = self.stds[feat];
            for row in 0..num_rows {
                let idx = row * num_features + feat;
                data[idx] = (data[idx] - mean) / std;
            }
        }

        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

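// Streaming fit for StandardScaler: cumulative Welford updates by default,
// EMA blending when a forget factor is set. A sketch of shard-and-merge use
// (the shard names are illustrative):
//
//     let mut left = StandardScaler::new();
//     left.partial_fit(&shard_a, num_features)?;
//     let mut right = StandardScaler::new();
//     right.partial_fit(&shard_b, num_features)?;
//     left.merge(&right)?; // same statistics as one pass over both shards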
impl IncrementalScaler for StandardScaler {
    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Ok(());
        }

        // EMA mode takes over entirely when a forget factor is set.
        if let Some(alpha) = self.forget_factor {
            return self.partial_fit_ema(data, num_features, alpha);
        }

        if self.welford_states.is_empty() {
            self.welford_states = vec![WelfordState::new(); num_features];
        } else if self.welford_states.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.welford_states.len(),
                num_features
            )));
        }

        // Non-finite values are skipped rather than poisoning the statistics.
        for row in 0..num_rows {
            for feat in 0..num_features {
                let x = data[row * num_features + feat] as f64;
                if x.is_finite() {
                    self.welford_states[feat].update(x);
                }
            }
        }

        self.sync_from_welford();
        self.fitted = true;

        Ok(())
    }

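    /// Returns the number of samples seen, read from the first feature's
    /// Welford state (0 if nothing has been fitted incrementally).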
    fn n_samples(&self) -> u64 {
        self.welford_states.first().map(|s| s.n).unwrap_or(0)
    }

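    /// Merges another scaler's streaming state into `self` by combining the
    /// per-feature Welford accumulators; the order of merging does not
    /// affect the result beyond floating-point rounding.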
    fn merge(&mut self, other: &Self) -> Result<()> {
        if self.welford_states.is_empty() {
            self.welford_states = other.welford_states.clone();
            self.sync_from_welford();
            self.fitted = other.fitted;
            return Ok(());
        }

        if other.welford_states.is_empty() {
            return Ok(());
        }

        if self.welford_states.len() != other.welford_states.len() {
            return Err(TreeBoostError::Data(format!(
                "Cannot merge scalers with different num_features: {} vs {}",
                self.welford_states.len(),
                other.welford_states.len()
            )));
        }

        // Pairwise combination of the Welford accumulators, delegated to
        // `WelfordState::merge`.
        for (self_state, other_state) in self.welford_states.iter_mut().zip(&other.welford_states) {
            self_state.merge(other_state);
        }

        self.sync_from_welford();
        Ok(())
    }
}

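/// Scales each feature linearly into `feature_range` (default `[0, 1]`);
/// at transform time, out-of-range inputs are clamped to the range.
///
/// A minimal sketch (the fence is `ignore` because the crate and module
/// path here are assumptions):
///
/// ```ignore
/// use treeboost::preprocessing::scalers::{Scaler, MinMaxScaler};
///
/// // One feature, three rows.
/// let mut data = vec![0.0_f32, 5.0, 10.0];
/// let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
/// scaler.fit_transform(&mut data, 1).unwrap();
/// assert_eq!(data, vec![-1.0, 0.0, 1.0]);
/// ```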
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MinMaxScaler {
    /// Per-feature minimums.
    pub mins: Vec<f32>,
    /// Per-feature maximums.
    pub maxs: Vec<f32>,
    /// Target output range `(low, high)`.
    pub feature_range: (f32, f32),
    fitted: bool,
    #[serde(default)]
    n_samples: u64,
}

impl MinMaxScaler {
    pub fn new() -> Self {
        Self {
            mins: Vec::new(),
            maxs: Vec::new(),
            feature_range: (0.0, 1.0),
            fitted: false,
            n_samples: 0,
        }
    }

    /// Sets the target output range (builder style).
    pub fn with_range(mut self, min: f32, max: f32) -> Self {
        self.feature_range = (min, max);
        self
    }
}

impl Default for MinMaxScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl Scaler for MinMaxScaler {
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;

        if num_rows == 0 {
            return Err(TreeBoostError::Data("No rows to fit".into()));
        }

        self.mins = vec![f32::INFINITY; num_features];
        self.maxs = vec![f32::NEG_INFINITY; num_features];

        for feat in 0..num_features {
            for row in 0..num_rows {
                let val = data[row * num_features + feat];
                self.mins[feat] = self.mins[feat].min(val);
                self.maxs[feat] = self.maxs[feat].max(val);
            }

            // Widen degenerate (constant-feature) ranges to avoid dividing by zero.
            if (self.maxs[feat] - self.mins[feat]).abs() < 1e-8 {
                self.maxs[feat] = self.mins[feat] + 1.0;
            }
        }

        self.fitted = true;
        Ok(())
    }

    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Data(
                "MinMaxScaler not fitted. Call fit() first.".into(),
            ));
        }

        if num_features != self.mins.len() {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: fit with {}, transform with {}",
                self.mins.len(),
                num_features
            )));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;
        let (a, b) = self.feature_range;

        for feat in 0..num_features {
            let min = self.mins[feat];
            let max = self.maxs[feat];
            let scale = b - a;

            for row in 0..num_rows {
                let idx = row * num_features + feat;
                data[idx] = (data[idx] - min) / (max - min) * scale + a;

                // Inputs outside the fitted range are clamped.
                data[idx] = data[idx].clamp(a, b);
            }
        }

        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

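// Streaming fit for MinMaxScaler: running per-feature min/max plus a sample
// count; `merge` takes the elementwise min/max of two scalers' ranges.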
impl IncrementalScaler for MinMaxScaler {
    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Ok(());
        }

        if self.mins.is_empty() {
            self.mins = vec![f32::INFINITY; num_features];
            self.maxs = vec![f32::NEG_INFINITY; num_features];
        } else if self.mins.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.mins.len(),
                num_features
            )));
        }

        // Non-finite values are skipped rather than corrupting the ranges.
        for row in 0..num_rows {
            for feat in 0..num_features {
                let val = data[row * num_features + feat];
                if val.is_finite() {
                    self.mins[feat] = self.mins[feat].min(val);
                    self.maxs[feat] = self.maxs[feat].max(val);
                }
            }
        }

        // Widen degenerate ranges, as in `fit`.
        for feat in 0..num_features {
            if (self.maxs[feat] - self.mins[feat]).abs() < 1e-8 {
                self.maxs[feat] = self.mins[feat] + 1.0;
            }
        }

        self.n_samples += num_rows as u64;
        self.fitted = true;

        Ok(())
    }

    fn n_samples(&self) -> u64 {
        self.n_samples
    }

    fn merge(&mut self, other: &Self) -> Result<()> {
        if self.mins.is_empty() {
            self.mins = other.mins.clone();
            self.maxs = other.maxs.clone();
            self.n_samples = other.n_samples;
            self.fitted = other.fitted;
            return Ok(());
        }

        if other.mins.is_empty() {
            return Ok(());
        }

        if self.mins.len() != other.mins.len() {
            return Err(TreeBoostError::Data(format!(
                "Cannot merge scalers with different num_features: {} vs {}",
                self.mins.len(),
                other.mins.len()
            )));
        }

        // Elementwise min/max combination.
        for i in 0..self.mins.len() {
            self.mins[i] = self.mins[i].min(other.mins[i]);
            self.maxs[i] = self.maxs[i].max(other.maxs[i]);
        }

        self.n_samples += other.n_samples;
        Ok(())
    }
}

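/// Scales features as `(x - median) / IQR`, which is less sensitive to
/// outliers than mean/std standardization. Medians and quartiles are
/// approximated with a t-digest sketch, so `fit` never sorts or stores a
/// full feature column.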
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RobustScaler {
    /// Per-feature medians.
    pub medians: Vec<f32>,
    /// Per-feature interquartile ranges (1.0 for near-constant features).
    pub iqrs: Vec<f32>,
    fitted: bool,
}

impl RobustScaler {
    pub fn new() -> Self {
        Self {
            medians: Vec::new(),
            iqrs: Vec::new(),
            fitted: false,
        }
    }
}

impl Default for RobustScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl Scaler for RobustScaler {
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;

        if num_rows == 0 {
            return Err(TreeBoostError::Data("No rows to fit".into()));
        }

        self.medians = vec![0.0; num_features];
        self.iqrs = vec![0.0; num_features];

        use tdigest::TDigest;

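        // A t-digest summarizes each feature's distribution in one pass,
        // giving approximate quantiles without sorting the column.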
        for feat in 0..num_features {
            let mut digest = TDigest::new_with_size(100);
            for row in 0..num_rows {
                let value = data[row * num_features + feat] as f64;
                if value.is_finite() {
                    digest = digest.merge_unsorted(vec![value]);
                }
            }

            let q1 = digest.estimate_quantile(0.25) as f32;
            let median = digest.estimate_quantile(0.50) as f32;
            let q3 = digest.estimate_quantile(0.75) as f32;

            self.medians[feat] = median;

            let iqr = q3 - q1;

            // Near-constant features get IQR 1.0 so transform is a pure shift.
            self.iqrs[feat] = if iqr < 1e-8 { 1.0 } else { iqr };
        }

        self.fitted = true;
        Ok(())
    }

    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Data(
                "RobustScaler not fitted. Call fit() first.".into(),
            ));
        }

        if num_features != self.medians.len() {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: fit with {}, transform with {}",
                self.medians.len(),
                num_features
            )));
        }

        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let num_rows = data.len() / num_features;

        for feat in 0..num_features {
            let median = self.medians[feat];
            let iqr = self.iqrs[feat];
            for row in 0..num_rows {
                let idx = row * num_features + feat;
                data[idx] = (data[idx] - median) / iqr;
            }
        }

        Ok(())
    }

    fn is_fitted(&self) -> bool {
        self.fitted
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_standard_scaler_basic() {
        // 2 rows x 3 features, row-major.
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_features = 3;

        let mut scaler = StandardScaler::new();
        assert!(!scaler.is_fitted());

        scaler.fit(&data, num_features).unwrap();
        assert!(scaler.is_fitted());

        assert_eq!(scaler.means(), &[2.5, 3.5, 4.5]);

        scaler.transform(&mut data, num_features).unwrap();

        // With two samples per feature, standardized values are -1 and +1.
        for feat in 0..num_features {
            assert!((data[feat] + 1.0).abs() < 1e-6);
            assert!((data[num_features + feat] - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_standard_scaler_zero_variance() {
        // Feature 0 is constant (5.0) across both rows.
        let mut data = vec![5.0, 1.0, 2.0, 5.0, 3.0, 4.0];
        let num_features = 3;

        let mut scaler = StandardScaler::new();
        scaler.fit(&data, num_features).unwrap();

        // Zero-variance features fall back to std 1.0.
        assert_eq!(scaler.stds[0], 1.0);
        assert_eq!(scaler.means[0], 5.0);

        scaler.transform(&mut data, num_features).unwrap();

        // The constant feature standardizes to (5 - 5) / 1 = 0.
        assert_eq!(data[0], 0.0);
        assert_eq!(data[3], 0.0);
    }

    #[test]
    fn test_minmax_scaler_basic() {
        // 3 rows x 2 features, row-major.
        let mut data = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0];
        let num_features = 2;

        let mut scaler = MinMaxScaler::new();
        scaler.fit(&data, num_features).unwrap();

        assert_eq!(scaler.mins, vec![1.0, 10.0]);
        assert_eq!(scaler.maxs, vec![3.0, 30.0]);

        scaler.transform(&mut data, num_features).unwrap();

        // Feature 0 maps 1 -> 0.0, 2 -> 0.5, 3 -> 1.0.
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.5).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);

        // Feature 1 maps 10 -> 0.0, 20 -> 0.5, 30 -> 1.0.
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[3] - 0.5).abs() < 1e-6);
        assert!((data[5] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_scaler_custom_range() {
        let mut data = vec![1.0, 2.0, 3.0];
        let num_features = 1;

        let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
        scaler.fit(&data, num_features).unwrap();
        scaler.transform(&mut data, num_features).unwrap();

        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_robust_scaler_basic() {
        // 2 rows x 2 features; 100.0 is an outlier in feature 1.
        let mut data = vec![1.0, 2.0, 3.0, 100.0];
        let num_features = 2;

        let mut scaler = RobustScaler::new();
        scaler.fit(&data, num_features).unwrap();

        assert!((scaler.medians[0] - 2.0).abs() < 1e-6);

        // Exact outputs depend on t-digest estimates; just check it runs.
        scaler.transform(&mut data, num_features).unwrap();
    }

    #[test]
    fn test_scaler_not_fitted_error() {
        let mut data = vec![1.0, 2.0, 3.0];
        let scaler = StandardScaler::new();

        let result = scaler.transform(&mut data, 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not fitted"));
    }

    #[test]
    fn test_scaler_feature_mismatch_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut scaler = StandardScaler::new();

        scaler.fit(&data, 2).unwrap();

        let mut test_data = vec![5.0, 6.0, 7.0];
        let result = scaler.transform(&mut test_data, 3);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    #[test]
    fn test_standard_scaler_incremental_equivalence() {
        let all_data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let num_features = 1;

        // Batch fit over all data at once.
        let mut scaler_a = StandardScaler::new();
        scaler_a.fit(&all_data, num_features).unwrap();

        // Incremental fit over 100-element chunks.
        let mut scaler_b = StandardScaler::new();
        for chunk in all_data.chunks(100) {
            scaler_b.partial_fit(chunk, num_features).unwrap();
        }

        assert!(
            (scaler_a.means[0] - scaler_b.means[0]).abs() < 1e-3,
            "Means differ: {} vs {}",
            scaler_a.means[0],
            scaler_b.means[0]
        );
        assert!(
            (scaler_a.stds[0] - scaler_b.stds[0]).abs() < 1e-3,
            "Stds differ: {} vs {}",
            scaler_a.stds[0],
            scaler_b.stds[0]
        );

        assert_eq!(scaler_b.n_samples(), 1000);
    }

    #[test]
    fn test_standard_scaler_welford_stability() {
        // Large offsets stress numerical stability; Welford accumulation
        // keeps the mean accurate where naive accumulation would drift.
        let offset = 1e8_f32;
        let data: Vec<f32> = (0..100).map(|i| offset + i as f32).collect();
        let num_features = 1;

        let mut scaler = StandardScaler::new();
        scaler.partial_fit(&data, num_features).unwrap();

        let expected_mean = offset + 49.5;
        assert!(
            (scaler.means[0] - expected_mean).abs() < 1.0,
            "Mean with large offset: got {}, expected {}",
            scaler.means[0],
            expected_mean
        );
    }

    #[test]
    fn test_standard_scaler_merge() {
        let num_features = 2;

        let mut scaler_a = StandardScaler::new();
        scaler_a
            .partial_fit(&[1.0, 10.0, 2.0, 20.0], num_features)
            .unwrap();

        let mut scaler_b = StandardScaler::new();
        scaler_b
            .partial_fit(&[3.0, 30.0, 4.0, 40.0], num_features)
            .unwrap();

        scaler_a.merge(&scaler_b).unwrap();

        assert_eq!(scaler_a.n_samples(), 4);

        // Feature 0 saw 1, 2, 3, 4 -> mean 2.5.
        assert!((scaler_a.means[0] - 2.5).abs() < 1e-5);
        // Feature 1 saw 10, 20, 30, 40 -> mean 25.
        assert!((scaler_a.means[1] - 25.0).abs() < 1e-4);
    }

    #[test]
    fn test_minmax_scaler_incremental() {
        let num_features = 2;

        let mut scaler = MinMaxScaler::new();

        scaler
            .partial_fit(&[0.0, 0.0, 50.0, 100.0], num_features)
            .unwrap();
        assert_eq!(scaler.mins, vec![0.0, 0.0]);
        assert_eq!(scaler.maxs, vec![50.0, 100.0]);

        scaler
            .partial_fit(&[25.0, 50.0, 100.0, 200.0], num_features)
            .unwrap();

        // Mins are unchanged, maxs expand, and both batches are counted.
        assert_eq!(scaler.mins, vec![0.0, 0.0]);
        assert_eq!(scaler.maxs, vec![100.0, 200.0]);
        assert_eq!(scaler.n_samples(), 4);
    }

    #[test]
    fn test_minmax_scaler_merge() {
        let num_features = 1;

        let mut scaler_a = MinMaxScaler::new();
        scaler_a.partial_fit(&[10.0, 20.0], num_features).unwrap();

        let mut scaler_b = MinMaxScaler::new();
        scaler_b.partial_fit(&[5.0, 30.0], num_features).unwrap();

        scaler_a.merge(&scaler_b).unwrap();

        assert_eq!(scaler_a.mins, vec![5.0]);
        assert_eq!(scaler_a.maxs, vec![30.0]);
        assert_eq!(scaler_a.n_samples(), 4);
    }

    #[test]
    fn test_standard_scaler_forget_factor_creation() {
        let scaler = StandardScaler::with_forget_factor(0.1);
        assert_eq!(scaler.forget_factor(), Some(0.1));

        let mut scaler2 = StandardScaler::new();
        assert_eq!(scaler2.forget_factor(), None);

        scaler2.set_forget_factor(Some(0.5));
        assert_eq!(scaler2.forget_factor(), Some(0.5));

        scaler2.set_forget_factor(None);
        assert_eq!(scaler2.forget_factor(), None);
    }

    #[test]
    fn test_standard_scaler_forget_factor_clamping() {
        let scaler = StandardScaler::with_forget_factor(-0.5);
        assert_eq!(scaler.forget_factor(), Some(0.0));

        let scaler2 = StandardScaler::with_forget_factor(1.5);
        assert_eq!(scaler2.forget_factor(), Some(1.0));
    }

    #[test]
    fn test_standard_scaler_ema_single_batch() {
        let num_features = 1;
        let data = vec![10.0, 20.0, 30.0, 40.0];

        let mut scaler = StandardScaler::with_forget_factor(0.1);
        scaler.partial_fit(&data, num_features).unwrap();

        assert!(scaler.is_fitted());
        // The first batch initializes statistics directly: mean = 25.
        assert!((scaler.means()[0] - 25.0).abs() < 0.01);
    }

    #[test]
    fn test_standard_scaler_ema_decay() {
        let num_features = 1;

        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![98.0, 100.0, 102.0];

        let mut scaler = StandardScaler::with_forget_factor(0.3);
        scaler.partial_fit(&batch1, num_features).unwrap();

        let mean_after_batch1 = scaler.means()[0];
        assert!((mean_after_batch1 - 10.0).abs() < 0.01);

        scaler.partial_fit(&batch2, num_features).unwrap();

        // EMA update: 0.7 * 10 + 0.3 * 100 = 37.
        let mean_after_batch2 = scaler.means()[0];
        assert!(
            (mean_after_batch2 - 37.0).abs() < 0.5,
            "Expected ~37, got {}",
            mean_after_batch2
        );
    }

    #[test]
    fn test_standard_scaler_ema_vs_cumulative() {
        let num_features = 1;

        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![98.0, 100.0, 102.0];

        let mut cumulative = StandardScaler::new();
        cumulative.partial_fit(&batch1, num_features).unwrap();
        cumulative.partial_fit(&batch2, num_features).unwrap();

        let mut ema = StandardScaler::with_forget_factor(0.5);
        ema.partial_fit(&batch1, num_features).unwrap();
        ema.partial_fit(&batch2, num_features).unwrap();

        // Cumulative: mean of all six values = 55.
        let cumulative_mean = cumulative.means()[0];
        // EMA with alpha = 0.5: 0.5 * 10 + 0.5 * 100 = 55 as well.
        let ema_mean = ema.means()[0];

        assert!((cumulative_mean - 55.0).abs() < 1.0);
        assert!((ema_mean - 55.0).abs() < 1.0);
    }

    #[test]
    fn test_standard_scaler_ema_adapts_to_drift() {
        let num_features = 1;

        // Batch means drift upward: 10, 30, 50, 70, 90.
        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![28.0, 30.0, 32.0];
        let batch3 = vec![48.0, 50.0, 52.0];
        let batch4 = vec![68.0, 70.0, 72.0];
        let batch5 = vec![88.0, 90.0, 92.0];

        let mut scaler = StandardScaler::with_forget_factor(0.5);

        scaler.partial_fit(&batch1, num_features).unwrap();
        assert!((scaler.means()[0] - 10.0).abs() < 1.0);

        scaler.partial_fit(&batch2, num_features).unwrap();
        // 0.5 * 10 + 0.5 * 30 = 20
        assert!(
            (scaler.means()[0] - 20.0).abs() < 1.0,
            "Expected ~20, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch3, num_features).unwrap();
        // 0.5 * 20 + 0.5 * 50 = 35
        assert!(
            (scaler.means()[0] - 35.0).abs() < 1.0,
            "Expected ~35, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch4, num_features).unwrap();
        // 0.5 * 35 + 0.5 * 70 = 52.5
        assert!(
            (scaler.means()[0] - 52.5).abs() < 1.0,
            "Expected ~52.5, got {}",
            scaler.means()[0]
        );

        scaler.partial_fit(&batch5, num_features).unwrap();
        // 0.5 * 52.5 + 0.5 * 90 = 71.25
        assert!(
            (scaler.means()[0] - 71.25).abs() < 1.5,
            "Expected ~71.25, got {}",
            scaler.means()[0]
        );
    }

    #[test]
    fn test_standard_scaler_ema_variance_decay() {
        let num_features = 1;

        let batch1 = vec![9.9, 10.0, 10.1];
        let batch2 = vec![0.0, 10.0, 20.0];

        let mut scaler = StandardScaler::with_forget_factor(0.3);

        scaler.partial_fit(&batch1, num_features).unwrap();
        let std_after_batch1 = scaler.stds()[0];
        assert!(
            std_after_batch1 < 1.0,
            "Std should be small after low-variance batch"
        );

        scaler.partial_fit(&batch2, num_features).unwrap();
        let std_after_batch2 = scaler.stds()[0];

        assert!(
            std_after_batch2 > std_after_batch1,
            "Std should increase after high-variance batch"
        );
    }
}