1use scirs2_core::ndarray::{Array1, Array2};
8use sklears_core::{
9 error::{Result, SklearsError},
10 traits::{Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Settings that control how a [`Winsorizer`] derives and applies its clipping bounds.
#[derive(Debug, Clone)]
pub struct WinsorizerConfig {
    /// Percentile, on a 0-100 scale, used for the lower clipping bound.
    /// The `lower_percentile` builder only accepts values in `[0, 50)`; the default is `5.0`.
    pub lower_percentile: Float,
    /// Percentile, on a 0-100 scale, used for the upper clipping bound.
    /// The `upper_percentile` builder only accepts values in `(50, 100]`; the default is `95.0`.
    pub upper_percentile: Float,
    /// Whether bounds are meant to be derived per feature (column) rather than from
    /// all values pooled together; defaults to `true`.
    /// NOTE(review): confirm against `fit` how this flag is consumed.
    pub feature_wise: bool,
    /// How NaN inputs are treated when values are winsorized; defaults to `NanStrategy::Skip`.
    pub nan_strategy: NanStrategy,
}
30
/// How NaN inputs are treated by `Winsorizer::winsorize_single`.
#[derive(Debug, Clone, Copy, Default)]
pub enum NanStrategy {
    /// Leave NaN values unchanged (the default).
    #[default]
    Skip,
    /// Replace NaN with the midpoint of the feature's lower and upper bounds when both
    /// bounds are finite; otherwise the NaN is returned unchanged.
    Interpolate,
    /// Handled in the same match arm as `Interpolate`, so it behaves identically.
    Replace,
}
42
43impl Default for WinsorizerConfig {
44 fn default() -> Self {
45 Self {
46 lower_percentile: 5.0,
47 upper_percentile: 95.0,
48 feature_wise: true,
49 nan_strategy: NanStrategy::Skip,
50 }
51 }
52}
53
/// Transformer that clips values to percentile-derived lower and upper bounds.
///
/// The `State` type parameter (`Untrained` by default) records at the type level whether
/// `fit` has been called; `fit` consumes a `Winsorizer<Untrained>` and returns a
/// `Winsorizer<Trained>` with the bounds populated.
#[derive(Debug, Clone)]
pub struct Winsorizer<State = Untrained> {
    /// Percentile and NaN-handling settings.
    config: WinsorizerConfig,
    /// Zero-sized marker that carries the `State` type parameter.
    state: PhantomData<State>,
    /// Lower clipping bound for each feature; `None` until `fit` fills it in.
    lower_bounds_: Option<Array1<Float>>,
    /// Upper clipping bound for each feature; `None` until `fit` fills it in.
    upper_bounds_: Option<Array1<Float>>,
    /// Number of features (columns) seen by `fit`; `None` until then.
    n_features_in_: Option<usize>,
}
69
70impl Winsorizer<Untrained> {
71 pub fn new() -> Self {
73 Self {
74 config: WinsorizerConfig::default(),
75 state: PhantomData,
76 lower_bounds_: None,
77 upper_bounds_: None,
78 n_features_in_: None,
79 }
80 }
81
82 pub fn with_percentiles(lower: Float, upper: Float) -> Self {
84 Self::new().lower_percentile(lower).upper_percentile(upper)
85 }
86
87 pub fn with_iqr(multiplier: Float) -> Self {
89 let lower_perc = if multiplier >= 1.5 { 0.7 } else { 2.5 };
92 let upper_perc = if multiplier >= 1.5 { 99.3 } else { 97.5 };
93 Self::new()
94 .lower_percentile(lower_perc)
95 .upper_percentile(upper_perc)
96 }
97
98 pub fn lower_percentile(mut self, percentile: Float) -> Self {
100 if !(0.0..50.0).contains(&percentile) {
101 panic!("Lower percentile must be between 0 and 50");
102 }
103 self.config.lower_percentile = percentile;
104 self
105 }
106
107 pub fn upper_percentile(mut self, percentile: Float) -> Self {
109 if percentile <= 50.0 || percentile > 100.0 {
110 panic!("Upper percentile must be between 50 and 100");
111 }
112 self.config.upper_percentile = percentile;
113 self
114 }
115
116 pub fn feature_wise(mut self, feature_wise: bool) -> Self {
118 self.config.feature_wise = feature_wise;
119 self
120 }
121
122 pub fn nan_strategy(mut self, strategy: NanStrategy) -> Self {
124 self.config.nan_strategy = strategy;
125 self
126 }
127
128 fn compute_percentile(sorted_data: &[Float], percentile: Float) -> Float {
130 if sorted_data.is_empty() {
131 return Float::NAN;
132 }
133
134 if percentile <= 0.0 {
135 return sorted_data[0];
136 }
137 if percentile >= 100.0 {
138 return sorted_data[sorted_data.len() - 1];
139 }
140
141 let index = (percentile / 100.0) * (sorted_data.len() - 1) as Float;
142 let lower_index = index.floor() as usize;
143 let upper_index = index.ceil() as usize;
144
145 if lower_index == upper_index {
146 sorted_data[lower_index]
147 } else {
148 let weight = index - lower_index as Float;
149 sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
150 }
151 }
152
153 fn compute_feature_bounds(&self, feature_data: &Array1<Float>) -> Result<(Float, Float)> {
155 let mut valid_data: Vec<Float> = feature_data
157 .iter()
158 .filter(|&&x| x.is_finite())
159 .cloned()
160 .collect();
161
162 if valid_data.is_empty() {
163 return Ok((Float::NEG_INFINITY, Float::INFINITY));
164 }
165
166 if valid_data.len() == 1 {
167 let value = valid_data[0];
168 return Ok((value, value));
169 }
170
171 valid_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
173
174 let lower_bound = Self::compute_percentile(&valid_data, self.config.lower_percentile);
175 let upper_bound = Self::compute_percentile(&valid_data, self.config.upper_percentile);
176
177 Ok((lower_bound, upper_bound))
178 }
179}
180
181impl Winsorizer<Trained> {
182 pub fn lower_bounds(&self) -> &Array1<Float> {
184 self.lower_bounds_
185 .as_ref()
186 .expect("Winsorizer should be fitted")
187 }
188
189 pub fn upper_bounds(&self) -> &Array1<Float> {
191 self.upper_bounds_
192 .as_ref()
193 .expect("Winsorizer should be fitted")
194 }
195
196 pub fn n_features_in(&self) -> usize {
198 self.n_features_in_.expect("Winsorizer should be fitted")
199 }
200
201 pub fn winsorize_single(&self, feature_idx: usize, value: Float) -> Float {
203 if value.is_nan() {
204 match self.config.nan_strategy {
205 NanStrategy::Skip => value,
206 NanStrategy::Interpolate | NanStrategy::Replace => {
207 let lower = self.lower_bounds()[feature_idx];
209 let upper = self.upper_bounds()[feature_idx];
210 if lower.is_finite() && upper.is_finite() {
211 (lower + upper) / 2.0
212 } else {
213 value
214 }
215 }
216 }
217 } else {
218 let lower = self.lower_bounds()[feature_idx];
219 let upper = self.upper_bounds()[feature_idx];
220
221 if value < lower {
222 lower
223 } else if value > upper {
224 upper
225 } else {
226 value
227 }
228 }
229 }
230
231 pub fn get_winsorization_stats(&self, x: &Array2<Float>) -> Result<WinsorizationStats> {
233 let (n_samples, n_features) = x.dim();
234
235 if n_features != self.n_features_in() {
236 return Err(SklearsError::FeatureMismatch {
237 expected: self.n_features_in(),
238 actual: n_features,
239 });
240 }
241
242 let mut stats = WinsorizationStats {
243 n_samples,
244 n_features,
245 lower_clipped_per_feature: vec![0; n_features],
246 upper_clipped_per_feature: vec![0; n_features],
247 total_clipped: 0,
248 clipping_rate: 0.0,
249 };
250
251 #[cfg(feature = "parallel")]
252 {
253 if n_samples * n_features > 10000 {
255 let clipping_counts: Vec<(usize, usize, usize)> = (0..n_features)
256 .into_par_iter()
257 .map(|j| {
258 let lower = self.lower_bounds()[j];
259 let upper = self.upper_bounds()[j];
260 let mut lower_clipped = 0;
261 let mut upper_clipped = 0;
262
263 for i in 0..n_samples {
264 let value = x[[i, j]];
265 if value.is_finite() {
266 if value < lower {
267 lower_clipped += 1;
268 } else if value > upper {
269 upper_clipped += 1;
270 }
271 }
272 }
273 (j, lower_clipped, upper_clipped)
274 })
275 .collect();
276
277 for (j, lower_clipped, upper_clipped) in clipping_counts {
278 stats.lower_clipped_per_feature[j] = lower_clipped;
279 stats.upper_clipped_per_feature[j] = upper_clipped;
280 stats.total_clipped += lower_clipped + upper_clipped;
281 }
282 } else {
283 for i in 0..n_samples {
285 for j in 0..n_features {
286 let value = x[[i, j]];
287 if value.is_finite() {
288 let lower = self.lower_bounds()[j];
289 let upper = self.upper_bounds()[j];
290
291 if value < lower {
292 stats.lower_clipped_per_feature[j] += 1;
293 stats.total_clipped += 1;
294 } else if value > upper {
295 stats.upper_clipped_per_feature[j] += 1;
296 stats.total_clipped += 1;
297 }
298 }
299 }
300 }
301 }
302 }
303
304 #[cfg(not(feature = "parallel"))]
305 {
306 for i in 0..n_samples {
308 for j in 0..n_features {
309 let value = x[[i, j]];
310 if value.is_finite() {
311 let lower = self.lower_bounds()[j];
312 let upper = self.upper_bounds()[j];
313
314 if value < lower {
315 stats.lower_clipped_per_feature[j] += 1;
316 stats.total_clipped += 1;
317 } else if value > upper {
318 stats.upper_clipped_per_feature[j] += 1;
319 stats.total_clipped += 1;
320 }
321 }
322 }
323 }
324 }
325
326 stats.clipping_rate = stats.total_clipped as Float / (n_samples * n_features) as Float;
327 Ok(stats)
328 }
329}
330
/// Summary produced by `Winsorizer::get_winsorization_stats` for one input matrix.
#[derive(Debug, Clone)]
pub struct WinsorizationStats {
    /// Number of rows in the inspected matrix.
    pub n_samples: usize,
    /// Number of columns in the inspected matrix.
    pub n_features: usize,
    /// For each feature, the number of values counted as falling below its lower bound.
    pub lower_clipped_per_feature: Vec<usize>,
    /// For each feature, the number of values counted as falling above its upper bound.
    pub upper_clipped_per_feature: Vec<usize>,
    /// Sum of all lower and upper counts across features.
    pub total_clipped: usize,
    /// `total_clipped` relative to the total number of values (`n_samples * n_features`).
    pub clipping_rate: Float,
}
347
348impl Default for Winsorizer<Untrained> {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354impl Fit<Array2<Float>, ()> for Winsorizer<Untrained> {
355 type Fitted = Winsorizer<Trained>;
356
357 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
358 let (n_samples, n_features) = x.dim();
359
360 if n_samples == 0 {
361 return Err(SklearsError::InvalidInput(
362 "Cannot fit winsorizer on empty dataset".to_string(),
363 ));
364 }
365
366 if self.config.lower_percentile >= self.config.upper_percentile {
367 return Err(SklearsError::InvalidInput(
368 "Lower percentile must be less than upper percentile".to_string(),
369 ));
370 }
371
372 let mut lower_bounds = Array1::<Float>::zeros(n_features);
373 let mut upper_bounds = Array1::<Float>::zeros(n_features);
374
375 #[cfg(feature = "parallel")]
376 {
377 if n_features > 4 {
379 let bounds: Result<Vec<(Float, Float)>> = (0..n_features)
380 .into_par_iter()
381 .map(|j| {
382 let feature_data = x.column(j).to_owned();
383 self.compute_feature_bounds(&feature_data)
384 })
385 .collect();
386
387 match bounds {
388 Ok(bounds_vec) => {
389 for (j, (lower, upper)) in bounds_vec.into_iter().enumerate() {
390 lower_bounds[j] = lower;
391 upper_bounds[j] = upper;
392 }
393 }
394 Err(e) => return Err(e),
395 }
396 } else {
397 for j in 0..n_features {
399 let feature_data = x.column(j).to_owned();
400 let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
401 lower_bounds[j] = lower;
402 upper_bounds[j] = upper;
403 }
404 }
405 }
406
407 #[cfg(not(feature = "parallel"))]
408 {
409 for j in 0..n_features {
411 let feature_data = x.column(j).to_owned();
412 let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
413 lower_bounds[j] = lower;
414 upper_bounds[j] = upper;
415 }
416 }
417
418 Ok(Winsorizer {
419 config: self.config,
420 state: PhantomData,
421 lower_bounds_: Some(lower_bounds),
422 upper_bounds_: Some(upper_bounds),
423 n_features_in_: Some(n_features),
424 })
425 }
426}
427
428impl Transform<Array2<Float>, Array2<Float>> for Winsorizer<Trained> {
429 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
430 let (n_samples, n_features) = x.dim();
431
432 if n_features != self.n_features_in() {
433 return Err(SklearsError::FeatureMismatch {
434 expected: self.n_features_in(),
435 actual: n_features,
436 });
437 }
438
439 let mut result = x.clone();
440
441 #[cfg(feature = "parallel")]
442 {
443 if n_samples * n_features > 10000 {
445 result
446 .as_slice_mut()
447 .unwrap()
448 .par_iter_mut()
449 .enumerate()
450 .for_each(|(idx, value)| {
451 let i = idx / n_features;
452 let j = idx % n_features;
453 *value = self.winsorize_single(j, x[[i, j]]);
454 });
455 return Ok(result);
456 }
457 }
458
459 for i in 0..n_samples {
461 for j in 0..n_features {
462 result[[i, j]] = self.winsorize_single(j, x[[i, j]]);
463 }
464 }
465
466 Ok(result)
467 }
468}
469
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// The outlier in the last row is pulled inward; a mid-range row is left untouched.
    #[test]
    fn test_winsorizer_basic() {
        let data = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0],
            [6.0, 60.0],
            [7.0, 70.0],
            [8.0, 80.0],
            [9.0, 90.0],
            [100.0, 1000.0],
        ];

        let fitted = Winsorizer::with_percentiles(10.0, 90.0)
            .fit(&data, &())
            .unwrap();
        let clipped = fitted.transform(&data).unwrap();

        for col in 0..2 {
            assert!(clipped[[9, col]] < data[[9, col]]);
        }
        for col in 0..2 {
            assert_abs_diff_eq!(clipped[[4, col]], data[[4, col]], epsilon = 1e-10);
        }
    }

    /// Fitted bounds for 1..=10 at the 20th/80th percentiles land in the expected windows.
    #[test]
    fn test_winsorizer_percentiles() {
        let data = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];

        let fitted = Winsorizer::with_percentiles(20.0, 80.0)
            .fit(&data, &())
            .unwrap();

        let lower = fitted.lower_bounds()[0];
        let upper = fitted.upper_bounds()[0];

        assert!((2.0..=3.5).contains(&lower));
        assert!((7.5..=9.0).contains(&upper));
    }

    /// With `NanStrategy::Skip` a NaN stays NaN while outliers are still clipped.
    #[test]
    fn test_winsorizer_with_nans() {
        let data = array![
            [1.0, 10.0],
            [2.0, Float::NAN],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0],
        ];

        let fitted = Winsorizer::with_percentiles(25.0, 75.0)
            .nan_strategy(NanStrategy::Skip)
            .fit(&data, &())
            .unwrap();
        let clipped = fitted.transform(&data).unwrap();

        assert!(clipped[[1, 1]].is_nan());

        for col in 0..2 {
            assert!(clipped[[4, col]] < data[[4, col]]);
        }
    }

    /// `winsorize_single` keeps in-range values and snaps out-of-range ones to the bounds.
    #[test]
    fn test_winsorizer_single_value() {
        let data = array![[5.0], [15.0], [25.0], [100.0]];
        let fitted = Winsorizer::with_percentiles(10.0, 90.0)
            .fit(&data, &())
            .unwrap();

        let (lower_bound, upper_bound) = (fitted.lower_bounds()[0], fitted.upper_bounds()[0]);

        assert_eq!(fitted.winsorize_single(0, 15.0), 15.0);
        assert_eq!(fitted.winsorize_single(0, 3.0), lower_bound);

        let clipped_high = fitted.winsorize_single(0, 200.0);
        assert_eq!(clipped_high, upper_bound);
        assert!(fitted.winsorize_single(0, 200.0) < 200.0);
    }

    /// Statistics echo the input shape and report at least one clipped value.
    #[test]
    fn test_winsorization_stats() {
        let data = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0],
        ];

        let fitted = Winsorizer::with_percentiles(25.0, 75.0)
            .fit(&data, &())
            .unwrap();
        let stats = fitted.get_winsorization_stats(&data).unwrap();

        assert_eq!(stats.n_samples, 5);
        assert_eq!(stats.n_features, 2);
        assert!(stats.total_clipped > 0);
        assert!(stats.clipping_rate > 0.0);
    }

    /// A constant column and a single-sample input both pass through unchanged.
    #[test]
    fn test_winsorizer_edge_cases() {
        let constant_column = array![[5.0], [5.0], [5.0], [5.0]];
        let single_sample = array![[5.0]];

        for data in [constant_column, single_sample] {
            let fitted = Winsorizer::new().fit(&data, &()).unwrap();
            let clipped = fitted.transform(&data).unwrap();
            assert_eq!(clipped, data);
        }
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_winsorizer_feature_mismatch() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wider = array![[1.0, 2.0, 3.0]];

        let fitted = Winsorizer::new().fit(&train, &()).unwrap();

        assert!(fitted.transform(&wider).is_err());
    }

    /// Percentiles on the wrong side of 50 make the builders panic.
    #[test]
    fn test_winsorizer_invalid_percentiles() {
        let lower_too_high = std::panic::catch_unwind(|| {
            Winsorizer::new().lower_percentile(60.0);
        });
        assert!(lower_too_high.is_err());

        let upper_too_low = std::panic::catch_unwind(|| {
            Winsorizer::new().upper_percentile(40.0);
        });
        assert!(upper_too_low.is_err());
    }

    /// Fitting on a matrix with zero rows is rejected.
    #[test]
    fn test_winsorizer_empty_data() {
        let no_rows = Array2::<Float>::zeros((0, 2));
        assert!(Winsorizer::new().fit(&no_rows, &()).is_err());
    }
}