1use crate::dataset::Dataset;
5use crate::error::{Result, ScryLearnError};
6use crate::preprocess::Transformer;
7use crate::sparse::CscMatrix;
8
/// Scales features to zero mean and unit variance (z-score normalization).
///
/// `fit` learns per-feature means and population standard deviations
/// (dividing by `n`); `transform` applies `(x - mean) / std`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct StandardScaler {
    // Per-feature means learned during fit (empty until fitted).
    means: Vec<f64>,
    // Per-feature standard deviations learned during fit (empty until fitted).
    stds: Vec<f64>,
    // True once fit/fit_sparse has completed; transform errors otherwise.
    fitted: bool,
    // Schema version stamped at construction; `transform` checks it via
    // `crate::version::check_schema_version`. `serde(default)` lets older
    // serialized payloads (without this field) still deserialize.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
23
24impl StandardScaler {
25 pub fn new() -> Self {
27 Self {
28 means: Vec::new(),
29 stds: Vec::new(),
30 fitted: false,
31 _schema_version: crate::version::SCHEMA_VERSION,
32 }
33 }
34}
35
36impl StandardScaler {
37 pub fn fit_sparse(&mut self, features: &CscMatrix) -> Result<()> {
42 let n = features.n_rows();
43 if n == 0 {
44 return Err(ScryLearnError::EmptyDataset);
45 }
46 let n_f64 = n as f64;
47 self.means = Vec::with_capacity(features.n_cols());
48 self.stds = Vec::with_capacity(features.n_cols());
49
50 for j in 0..features.n_cols() {
51 let col = features.col(j);
52 let sum: f64 = col.iter().map(|(_, v)| v).sum();
53 let mean = sum / n_f64;
54 let mut var = 0.0;
55 let mut nnz_count = 0usize;
56 for (_, val) in col.iter() {
57 var += (val - mean).powi(2);
58 nnz_count += 1;
59 }
60 let n_zeros = n - nnz_count;
62 var += n_zeros as f64 * mean * mean;
63 var /= n_f64;
64 self.means.push(mean);
65 self.stds.push(var.sqrt());
66 }
67 self.fitted = true;
68 Ok(())
69 }
70
71 pub fn transform_sparse(&self, features: &CscMatrix) -> Result<CscMatrix> {
76 if !self.fitted {
77 return Err(ScryLearnError::NotFitted);
78 }
79 let mut cols: Vec<Vec<f64>> = Vec::with_capacity(features.n_cols());
81 for j in 0..features.n_cols() {
82 let std = self.stds[j];
83 let mut col = vec![0.0; features.n_rows()];
84 if std > 1e-12 {
85 for (row_idx, val) in features.col(j).iter() {
86 col[row_idx] = val / std;
87 }
88 }
89 cols.push(col);
90 }
91 Ok(CscMatrix::from_dense(&cols))
92 }
93}
94
impl StandardScaler {
    /// Returns `true` once `fit` (or `fit_sparse`) has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }

    /// Per-feature means learned during fit (empty before fitting).
    pub fn means(&self) -> &[f64] {
        &self.means
    }

    /// Per-feature standard deviations learned during fit (empty before fitting).
    pub fn stds(&self) -> &[f64] {
        &self.stds
    }
}
111
112impl Default for StandardScaler {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl Transformer for StandardScaler {
119 fn fit(&mut self, data: &Dataset) -> Result<()> {
120 data.validate_finite()?;
121 if let Some(csc) = data.sparse_csc() {
122 return self.fit_sparse(csc);
123 }
124 let n = data.n_samples() as f64;
125 if n == 0.0 {
126 return Err(ScryLearnError::EmptyDataset);
127 }
128 let mat = data.matrix();
129 self.means = Vec::with_capacity(data.n_features());
130 self.stds = Vec::with_capacity(data.n_features());
131
132 for j in 0..data.n_features() {
133 let col = mat.col(j);
134 let mean = col.iter().sum::<f64>() / n;
135 let var = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
136 self.means.push(mean);
137 self.stds.push(var.sqrt());
138 }
139 self.fitted = true;
140 Ok(())
141 }
142
143 fn transform(&self, data: &mut Dataset) -> Result<()> {
144 crate::version::check_schema_version(self._schema_version)?;
145 if !self.fitted {
146 return Err(ScryLearnError::NotFitted);
147 }
148 for (j, col) in data.features.iter_mut().enumerate() {
149 let mean = self.means[j];
150 let std = self.stds[j];
151 if std > 1e-12 {
152 for x in col.iter_mut() {
153 *x = (*x - mean) / std;
154 }
155 }
156 }
157 data.sync_matrix();
158 Ok(())
159 }
160
161 fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
162 if !self.fitted {
163 return Err(ScryLearnError::NotFitted);
164 }
165 for (j, col) in data.features.iter_mut().enumerate() {
166 let mean = self.means[j];
167 let std = self.stds[j];
168 if std > 1e-12 {
169 for x in col.iter_mut() {
170 *x = *x * std + mean;
171 }
172 }
173 }
176 data.sync_matrix();
177 Ok(())
178 }
179}
180
/// Rescales each feature into the `[0, 1]` interval.
///
/// `fit` learns per-feature minima and maxima; `transform` applies
/// `(x - min) / (max - min)`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MinMaxScaler {
    // Per-feature minima learned during fit (empty until fitted).
    mins: Vec<f64>,
    // Per-feature maxima learned during fit (empty until fitted).
    maxs: Vec<f64>,
    // True once fit/fit_sparse has completed; transform errors otherwise.
    fitted: bool,
    // Schema version stamped at construction; `transform` checks it via
    // `crate::version::check_schema_version`. `serde(default)` lets older
    // serialized payloads (without this field) still deserialize.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
195
196impl MinMaxScaler {
197 pub fn new() -> Self {
199 Self {
200 mins: Vec::new(),
201 maxs: Vec::new(),
202 fitted: false,
203 _schema_version: crate::version::SCHEMA_VERSION,
204 }
205 }
206}
207
208impl MinMaxScaler {
209 pub fn fit_sparse(&mut self, features: &CscMatrix) -> Result<()> {
213 let n = features.n_rows();
214 if n == 0 {
215 return Err(ScryLearnError::EmptyDataset);
216 }
217 self.mins = Vec::with_capacity(features.n_cols());
218 self.maxs = Vec::with_capacity(features.n_cols());
219
220 for j in 0..features.n_cols() {
221 let col = features.col(j);
222 let nnz = col.nnz();
223 if nnz == 0 {
224 self.mins.push(0.0);
226 self.maxs.push(0.0);
227 } else {
228 let mut min = f64::INFINITY;
229 let mut max = f64::NEG_INFINITY;
230 for (_, val) in col.iter() {
231 if val < min {
232 min = val;
233 }
234 if val > max {
235 max = val;
236 }
237 }
238 if nnz < n {
240 if 0.0 < min {
241 min = 0.0;
242 }
243 if 0.0 > max {
244 max = 0.0;
245 }
246 }
247 self.mins.push(min);
248 self.maxs.push(max);
249 }
250 }
251 self.fitted = true;
252 Ok(())
253 }
254
255 pub fn transform_sparse(&self, features: &CscMatrix) -> Result<CscMatrix> {
257 if !self.fitted {
258 return Err(ScryLearnError::NotFitted);
259 }
260 let mut cols: Vec<Vec<f64>> = Vec::with_capacity(features.n_cols());
261 for j in 0..features.n_cols() {
262 let min = self.mins[j];
263 let range = self.maxs[j] - min;
264 let mut col = vec![0.0; features.n_rows()];
265 if range > 1e-12 {
266 let zero_mapped = (0.0 - min) / range;
268 col.fill(zero_mapped);
269 for (row_idx, val) in features.col(j).iter() {
270 col[row_idx] = (val - min) / range;
271 }
272 }
273 cols.push(col);
274 }
275 Ok(CscMatrix::from_dense(&cols))
276 }
277}
278
279impl Default for MinMaxScaler {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285impl Transformer for MinMaxScaler {
286 fn fit(&mut self, data: &Dataset) -> Result<()> {
287 data.validate_finite()?;
288 if let Some(csc) = data.sparse_csc() {
289 return self.fit_sparse(csc);
290 }
291 if data.n_samples() == 0 {
292 return Err(ScryLearnError::EmptyDataset);
293 }
294 let mat = data.matrix();
295 self.mins = Vec::with_capacity(data.n_features());
296 self.maxs = Vec::with_capacity(data.n_features());
297
298 for j in 0..data.n_features() {
299 let col = mat.col(j);
300 let min = col.iter().copied().fold(f64::INFINITY, f64::min);
301 let max = col.iter().copied().fold(f64::NEG_INFINITY, f64::max);
302 self.mins.push(min);
303 self.maxs.push(max);
304 }
305 self.fitted = true;
306 Ok(())
307 }
308
309 fn transform(&self, data: &mut Dataset) -> Result<()> {
310 crate::version::check_schema_version(self._schema_version)?;
311 if !self.fitted {
312 return Err(ScryLearnError::NotFitted);
313 }
314 for (j, col) in data.features.iter_mut().enumerate() {
315 let min = self.mins[j];
316 let range = self.maxs[j] - min;
317 if range > 1e-12 {
318 for x in col.iter_mut() {
319 *x = (*x - min) / range;
320 }
321 } else {
322 for x in col.iter_mut() {
323 *x = 0.0;
324 }
325 }
326 }
327 data.sync_matrix();
328 Ok(())
329 }
330
331 fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
332 if !self.fitted {
333 return Err(ScryLearnError::NotFitted);
334 }
335 for (j, col) in data.features.iter_mut().enumerate() {
336 let min = self.mins[j];
337 let range = self.maxs[j] - min;
338 for x in col.iter_mut() {
339 *x = *x * range + min;
340 }
341 }
342 data.sync_matrix();
343 Ok(())
344 }
345}
346
/// Linear-interpolation quantile of an ascending-sorted slice.
///
/// `q` is the quantile in `[0, 1]`; the result is interpolated between the
/// two nearest order statistics at position `q * (len - 1)`. The slice must
/// be non-empty and already sorted (checked only in debug builds).
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    debug_assert!(!sorted.is_empty());
    match sorted {
        [only] => *only,
        _ => {
            let pos = q * (sorted.len() - 1) as f64;
            let below = pos.floor() as usize;
            let above = pos.ceil() as usize;
            let t = pos - below as f64;
            sorted[below] * (1.0 - t) + sorted[above] * t
        }
    }
}
361
/// Scales features using statistics robust to outliers.
///
/// `fit` learns per-feature medians and interquartile ranges (Q3 - Q1);
/// `transform` applies `(x - median) / iqr`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct RobustScaler {
    // Per-feature medians learned during fit (empty until fitted).
    medians: Vec<f64>,
    // Per-feature interquartile ranges (Q3 - Q1) learned during fit.
    iqrs: Vec<f64>,
    // True once fit has completed; transform errors otherwise.
    fitted: bool,
    // Schema version stamped at construction; `transform` checks it via
    // `crate::version::check_schema_version`. `serde(default)` lets older
    // serialized payloads (without this field) still deserialize.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
386
387impl RobustScaler {
388 pub fn new() -> Self {
390 Self {
391 medians: Vec::new(),
392 iqrs: Vec::new(),
393 fitted: false,
394 _schema_version: crate::version::SCHEMA_VERSION,
395 }
396 }
397}
398
399impl Default for RobustScaler {
400 fn default() -> Self {
401 Self::new()
402 }
403}
404
405impl Transformer for RobustScaler {
406 fn fit(&mut self, data: &Dataset) -> Result<()> {
407 data.validate_finite()?;
408 if data.n_samples() == 0 {
409 return Err(ScryLearnError::EmptyDataset);
410 }
411 let mat = data.matrix();
412 self.medians = Vec::with_capacity(data.n_features());
413 self.iqrs = Vec::with_capacity(data.n_features());
414
415 for j in 0..data.n_features() {
416 let col = mat.col(j);
417 let mut sorted = col.to_vec();
418 sorted.sort_unstable_by(|a, b| a.total_cmp(b));
419 let median = quantile_sorted(&sorted, 0.5);
420 let q1 = quantile_sorted(&sorted, 0.25);
421 let q3 = quantile_sorted(&sorted, 0.75);
422 self.medians.push(median);
423 self.iqrs.push(q3 - q1);
424 }
425 self.fitted = true;
426 Ok(())
427 }
428
429 fn transform(&self, data: &mut Dataset) -> Result<()> {
430 crate::version::check_schema_version(self._schema_version)?;
431 if !self.fitted {
432 return Err(ScryLearnError::NotFitted);
433 }
434 for (j, col) in data.features.iter_mut().enumerate() {
435 let median = self.medians[j];
436 let iqr = self.iqrs[j];
437 if iqr > 1e-12 {
438 for x in col.iter_mut() {
439 *x = (*x - median) / iqr;
440 }
441 } else {
442 for x in col.iter_mut() {
443 *x -= median;
444 }
445 }
446 }
447 data.sync_matrix();
448 Ok(())
449 }
450
451 fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
452 if !self.fitted {
453 return Err(ScryLearnError::NotFitted);
454 }
455 for (j, col) in data.features.iter_mut().enumerate() {
456 let median = self.medians[j];
457 let iqr = self.iqrs[j];
458 if iqr > 1e-12 {
459 for x in col.iter_mut() {
460 *x = *x * iqr + median;
461 }
462 } else {
463 for x in col.iter_mut() {
466 *x += median;
467 }
468 }
469 }
470 data.sync_matrix();
471 Ok(())
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    // StandardScaler on 1..=5 must yield zero mean and unit (population)
    // variance after fit_transform.
    #[test]
    fn test_standard_scaler_zero_mean() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        let mean: f64 = ds.features[0].iter().sum::<f64>() / 5.0;
        assert!((mean).abs() < 1e-10, "mean should be ~0, got {mean}");

        // Population variance (divide by n), matching how fit computes it.
        let var: f64 = ds.features[0].iter().map(|x| x.powi(2)).sum::<f64>() / 5.0;
        assert!(
            (var - 1.0).abs() < 1e-10,
            "variance should be ~1, got {var}"
        );
    }

    // MinMaxScaler must map the column minimum to 0 and maximum to 1.
    #[test]
    fn test_minmax_scaler_range() {
        let mut ds = Dataset::new(
            vec![vec![10.0, 20.0, 30.0]],
            vec![0.0; 3],
            vec!["x".into()],
            "y",
        );
        let mut scaler = MinMaxScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        assert!((ds.features[0][0]).abs() < 1e-10);
        assert!((ds.features[0][2] - 1.0).abs() < 1e-10);
    }

    // Calling transform before fit must fail (NotFitted).
    #[test]
    fn test_standard_scaler_not_fitted() {
        let scaler = StandardScaler::new();
        let mut ds = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "y");
        assert!(scaler.transform(&mut ds).is_err());
    }

    // transform followed by inverse_transform must reproduce the input.
    #[test]
    fn test_standard_scaler_roundtrip() {
        let original = vec![2.0, 4.0, 6.0, 8.0];
        let mut ds = Dataset::new(vec![original.clone()], vec![0.0; 4], vec!["x".into()], "y");
        let mut scaler = StandardScaler::new();
        scaler.fit_transform(&mut ds).unwrap();
        scaler.inverse_transform(&mut ds).unwrap();

        for (a, b) in ds.features[0].iter().zip(original.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // The median element (3.0 in 1..=5) must map to exactly 0.
    #[test]
    fn test_robust_scaler_median_centering() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]],
            vec![0.0; 5],
            vec!["x".into()],
            "y",
        );
        let mut scaler = RobustScaler::new();
        scaler.fit_transform(&mut ds).unwrap();

        assert!(
            ds.features[0][2].abs() < 1e-10,
            "median should map to 0, got {}",
            ds.features[0][2]
        );
    }

    // With an extreme outlier (1000.0), the robust scaler should compress
    // the inliers far less than the standard scaler does, because median
    // and IQR are insensitive to the outlier while mean/std are not.
    #[test]
    fn test_robust_scaler_outlier_tolerance() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 1000.0];

        let mut ds_std = Dataset::new(vec![data.clone()], vec![0.0; 5], vec!["x".into()], "y");
        let mut std_scaler = StandardScaler::new();
        std_scaler.fit_transform(&mut ds_std).unwrap();

        let mut ds_rob = Dataset::new(vec![data], vec![0.0; 5], vec!["x".into()], "y");
        let mut rob_scaler = RobustScaler::new();
        rob_scaler.fit_transform(&mut ds_rob).unwrap();

        // Compare the spread of the non-outlier values (indices 0..=3).
        let robust_range = ds_rob.features[0][3] - ds_rob.features[0][0];
        let std_range = ds_std.features[0][3] - ds_std.features[0][0];
        assert!(
            robust_range > std_range,
            "RobustScaler should give wider spread to non-outliers: robust={robust_range:.4} vs std={std_range:.4}"
        );
    }

    // transform followed by inverse_transform must reproduce the input.
    #[test]
    fn test_robust_scaler_roundtrip() {
        let original = vec![2.0, 4.0, 6.0, 8.0];
        let mut ds = Dataset::new(vec![original.clone()], vec![0.0; 4], vec!["x".into()], "y");
        let mut scaler = RobustScaler::new();
        scaler.fit_transform(&mut ds).unwrap();
        scaler.inverse_transform(&mut ds).unwrap();

        for (a, b) in ds.features[0].iter().zip(original.iter()) {
            assert!((a - b).abs() < 1e-10, "roundtrip failed: {a} != {b}");
        }
    }

    // fit_sparse must learn the same mean as the dense fit path on
    // equivalent data.
    #[test]
    fn test_standard_scaler_sparse_fit() {
        let cols = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let csc = CscMatrix::from_dense(&cols);

        let mut scaler = StandardScaler::new();
        scaler.fit_sparse(&csc).unwrap();

        let ds = Dataset::new(cols, vec![0.0; 5], vec!["x".into()], "y");
        let mut scaler_d = StandardScaler::new();
        scaler_d.fit(&ds).unwrap();

        assert!(
            (scaler.means[0] - scaler_d.means[0]).abs() < 1e-10,
            "Sparse mean={} vs Dense mean={}",
            scaler.means[0],
            scaler_d.means[0]
        );
    }

    // Implicit zeros must count toward the sparse min (0.0 here) even
    // though only 5.0 and 10.0 are stored.
    #[test]
    fn test_minmax_scaler_sparse_fit() {
        let cols = vec![vec![0.0, 5.0, 0.0, 10.0, 0.0]];
        let csc = CscMatrix::from_dense(&cols);

        let mut scaler = MinMaxScaler::new();
        scaler.fit_sparse(&csc).unwrap();

        assert!((scaler.mins[0] - 0.0).abs() < 1e-10);
        assert!((scaler.maxs[0] - 10.0).abs() < 1e-10);
    }
}