// sklears_preprocessing/winsorization.rs

//! Winsorization utilities for capping extreme outliers.
//!
//! Winsorization is a statistical transformation that limits extreme values by
//! replacing outliers with specified percentile values rather than removing them.
//! This reduces the influence of outliers while keeping every data point.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use sklears_core::{
9    error::{Result, SklearsError},
10    traits::{Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18/// Configuration for winsorization
19#[derive(Debug, Clone)]
20pub struct WinsorizerConfig {
21    /// Lower percentile for winsorization (default: 5.0)
22    pub lower_percentile: Float,
23    /// Upper percentile for winsorization (default: 95.0)
24    pub upper_percentile: Float,
25    /// Whether to winsorize each feature independently (default: true)
26    pub feature_wise: bool,
27    /// Strategy for handling NaN values
28    pub nan_strategy: NanStrategy,
29}
30
/// Strategy for handling NaN values during winsorization.
///
/// Whatever the strategy, NaN (and infinite) entries are excluded when the percentile
/// bounds are learned; the strategy only decides what happens to NaN entries when a
/// fitted winsorizer transforms data.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NanStrategy {
    /// Leave NaN values unchanged.
    #[default]
    Skip,
    /// Replace NaN with the midpoint of the feature's bounds, `(lower + upper) / 2`.
    ///
    /// Despite the name, no interpolation between neighbouring samples is performed;
    /// this currently behaves exactly like [`NanStrategy::Replace`]. If either bound is
    /// infinite (the feature had no finite training values) the NaN is left in place.
    Interpolate,
    /// Replace NaN with the midpoint of the feature's bounds, `(lower + upper) / 2`.
    ///
    /// If either bound is infinite (the feature had no finite training values) the NaN
    /// is left in place.
    Replace,
}
42
43impl Default for WinsorizerConfig {
44    fn default() -> Self {
45        Self {
46            lower_percentile: 5.0,
47            upper_percentile: 95.0,
48            feature_wise: true,
49            nan_strategy: NanStrategy::Skip,
50        }
51    }
52}
53
54/// Winsorizer for capping extreme outliers
55///
56/// This transformer limits extreme values by replacing outliers with percentile values.
57/// For example, with lower_percentile=5 and upper_percentile=95, all values below the 5th
58/// percentile are set to the 5th percentile value, and all values above the 95th percentile
59/// are set to the 95th percentile value.
60#[derive(Debug, Clone)]
61pub struct Winsorizer<State = Untrained> {
62    config: WinsorizerConfig,
63    state: PhantomData<State>,
64    // Fitted parameters
65    lower_bounds_: Option<Array1<Float>>,
66    upper_bounds_: Option<Array1<Float>>,
67    n_features_in_: Option<usize>,
68}
69
70impl Winsorizer<Untrained> {
71    /// Create a new Winsorizer with default configuration
72    pub fn new() -> Self {
73        Self {
74            config: WinsorizerConfig::default(),
75            state: PhantomData,
76            lower_bounds_: None,
77            upper_bounds_: None,
78            n_features_in_: None,
79        }
80    }
81
82    /// Create a Winsorizer with specified percentiles
83    pub fn with_percentiles(lower: Float, upper: Float) -> Self {
84        Self::new().lower_percentile(lower).upper_percentile(upper)
85    }
86
87    /// Create a Winsorizer with IQR-based bounds
88    pub fn with_iqr(multiplier: Float) -> Self {
89        // IQR method: Q1 - multiplier*IQR, Q3 + multiplier*IQR
90        // Convert to approximate percentiles
91        let lower_perc = if multiplier >= 1.5 { 0.7 } else { 2.5 };
92        let upper_perc = if multiplier >= 1.5 { 99.3 } else { 97.5 };
93        Self::new()
94            .lower_percentile(lower_perc)
95            .upper_percentile(upper_perc)
96    }
97
98    /// Set the lower percentile
99    pub fn lower_percentile(mut self, percentile: Float) -> Self {
100        if !(0.0..50.0).contains(&percentile) {
101            panic!("Lower percentile must be between 0 and 50");
102        }
103        self.config.lower_percentile = percentile;
104        self
105    }
106
107    /// Set the upper percentile
108    pub fn upper_percentile(mut self, percentile: Float) -> Self {
109        if percentile <= 50.0 || percentile > 100.0 {
110            panic!("Upper percentile must be between 50 and 100");
111        }
112        self.config.upper_percentile = percentile;
113        self
114    }
115
116    /// Set whether to winsorize features independently
117    pub fn feature_wise(mut self, feature_wise: bool) -> Self {
118        self.config.feature_wise = feature_wise;
119        self
120    }
121
122    /// Set the NaN handling strategy
123    pub fn nan_strategy(mut self, strategy: NanStrategy) -> Self {
124        self.config.nan_strategy = strategy;
125        self
126    }
127
128    /// Compute percentile of sorted data
129    fn compute_percentile(sorted_data: &[Float], percentile: Float) -> Float {
130        if sorted_data.is_empty() {
131            return Float::NAN;
132        }
133
134        if percentile <= 0.0 {
135            return sorted_data[0];
136        }
137        if percentile >= 100.0 {
138            return sorted_data[sorted_data.len() - 1];
139        }
140
141        let index = (percentile / 100.0) * (sorted_data.len() - 1) as Float;
142        let lower_index = index.floor() as usize;
143        let upper_index = index.ceil() as usize;
144
145        if lower_index == upper_index {
146            sorted_data[lower_index]
147        } else {
148            let weight = index - lower_index as Float;
149            sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
150        }
151    }
152
153    /// Compute bounds for a single feature
154    fn compute_feature_bounds(&self, feature_data: &Array1<Float>) -> Result<(Float, Float)> {
155        // Filter out NaN values
156        let mut valid_data: Vec<Float> = feature_data
157            .iter()
158            .filter(|&&x| x.is_finite())
159            .cloned()
160            .collect();
161
162        if valid_data.is_empty() {
163            return Ok((Float::NEG_INFINITY, Float::INFINITY));
164        }
165
166        if valid_data.len() == 1 {
167            let value = valid_data[0];
168            return Ok((value, value));
169        }
170
171        // Sort data for percentile calculation
172        valid_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
173
174        let lower_bound = Self::compute_percentile(&valid_data, self.config.lower_percentile);
175        let upper_bound = Self::compute_percentile(&valid_data, self.config.upper_percentile);
176
177        Ok((lower_bound, upper_bound))
178    }
179}
180
181impl Winsorizer<Trained> {
182    /// Get the lower bounds
183    pub fn lower_bounds(&self) -> &Array1<Float> {
184        self.lower_bounds_
185            .as_ref()
186            .expect("Winsorizer should be fitted")
187    }
188
189    /// Get the upper bounds
190    pub fn upper_bounds(&self) -> &Array1<Float> {
191        self.upper_bounds_
192            .as_ref()
193            .expect("Winsorizer should be fitted")
194    }
195
196    /// Get the number of features seen during fitting
197    pub fn n_features_in(&self) -> usize {
198        self.n_features_in_.expect("Winsorizer should be fitted")
199    }
200
201    /// Apply winsorization to a single value for a specific feature
202    pub fn winsorize_single(&self, feature_idx: usize, value: Float) -> Float {
203        if value.is_nan() {
204            match self.config.nan_strategy {
205                NanStrategy::Skip => value,
206                NanStrategy::Interpolate | NanStrategy::Replace => {
207                    // For single values, we can't interpolate, so use the mean of bounds
208                    let lower = self.lower_bounds()[feature_idx];
209                    let upper = self.upper_bounds()[feature_idx];
210                    if lower.is_finite() && upper.is_finite() {
211                        (lower + upper) / 2.0
212                    } else {
213                        value
214                    }
215                }
216            }
217        } else {
218            let lower = self.lower_bounds()[feature_idx];
219            let upper = self.upper_bounds()[feature_idx];
220
221            if value < lower {
222                lower
223            } else if value > upper {
224                upper
225            } else {
226                value
227            }
228        }
229    }
230
231    /// Get statistics about the winsorization applied to the data
232    pub fn get_winsorization_stats(&self, x: &Array2<Float>) -> Result<WinsorizationStats> {
233        let (n_samples, n_features) = x.dim();
234
235        if n_features != self.n_features_in() {
236            return Err(SklearsError::FeatureMismatch {
237                expected: self.n_features_in(),
238                actual: n_features,
239            });
240        }
241
242        let mut stats = WinsorizationStats {
243            n_samples,
244            n_features,
245            lower_clipped_per_feature: vec![0; n_features],
246            upper_clipped_per_feature: vec![0; n_features],
247            total_clipped: 0,
248            clipping_rate: 0.0,
249        };
250
251        #[cfg(feature = "parallel")]
252        {
253            // Use parallel processing for large datasets
254            if n_samples * n_features > 10000 {
255                let clipping_counts: Vec<(usize, usize, usize)> = (0..n_features)
256                    .into_par_iter()
257                    .map(|j| {
258                        let lower = self.lower_bounds()[j];
259                        let upper = self.upper_bounds()[j];
260                        let mut lower_clipped = 0;
261                        let mut upper_clipped = 0;
262
263                        for i in 0..n_samples {
264                            let value = x[[i, j]];
265                            if value.is_finite() {
266                                if value < lower {
267                                    lower_clipped += 1;
268                                } else if value > upper {
269                                    upper_clipped += 1;
270                                }
271                            }
272                        }
273                        (j, lower_clipped, upper_clipped)
274                    })
275                    .collect();
276
277                for (j, lower_clipped, upper_clipped) in clipping_counts {
278                    stats.lower_clipped_per_feature[j] = lower_clipped;
279                    stats.upper_clipped_per_feature[j] = upper_clipped;
280                    stats.total_clipped += lower_clipped + upper_clipped;
281                }
282            } else {
283                // Sequential processing for smaller datasets
284                for i in 0..n_samples {
285                    for j in 0..n_features {
286                        let value = x[[i, j]];
287                        if value.is_finite() {
288                            let lower = self.lower_bounds()[j];
289                            let upper = self.upper_bounds()[j];
290
291                            if value < lower {
292                                stats.lower_clipped_per_feature[j] += 1;
293                                stats.total_clipped += 1;
294                            } else if value > upper {
295                                stats.upper_clipped_per_feature[j] += 1;
296                                stats.total_clipped += 1;
297                            }
298                        }
299                    }
300                }
301            }
302        }
303
304        #[cfg(not(feature = "parallel"))]
305        {
306            // Sequential processing when parallel feature is disabled
307            for i in 0..n_samples {
308                for j in 0..n_features {
309                    let value = x[[i, j]];
310                    if value.is_finite() {
311                        let lower = self.lower_bounds()[j];
312                        let upper = self.upper_bounds()[j];
313
314                        if value < lower {
315                            stats.lower_clipped_per_feature[j] += 1;
316                            stats.total_clipped += 1;
317                        } else if value > upper {
318                            stats.upper_clipped_per_feature[j] += 1;
319                            stats.total_clipped += 1;
320                        }
321                    }
322                }
323            }
324        }
325
326        stats.clipping_rate = stats.total_clipped as Float / (n_samples * n_features) as Float;
327        Ok(stats)
328    }
329}
330
331/// Statistics about winsorization applied to data
332#[derive(Debug, Clone)]
333pub struct WinsorizationStats {
334    /// Number of samples
335    pub n_samples: usize,
336    /// Number of features
337    pub n_features: usize,
338    /// Number of values clipped at lower bound per feature
339    pub lower_clipped_per_feature: Vec<usize>,
340    /// Number of values clipped at upper bound per feature
341    pub upper_clipped_per_feature: Vec<usize>,
342    /// Total number of values clipped
343    pub total_clipped: usize,
344    /// Overall clipping rate (proportion of values clipped)
345    pub clipping_rate: Float,
346}
347
348impl Default for Winsorizer<Untrained> {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
354impl Fit<Array2<Float>, ()> for Winsorizer<Untrained> {
355    type Fitted = Winsorizer<Trained>;
356
357    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
358        let (n_samples, n_features) = x.dim();
359
360        if n_samples == 0 {
361            return Err(SklearsError::InvalidInput(
362                "Cannot fit winsorizer on empty dataset".to_string(),
363            ));
364        }
365
366        if self.config.lower_percentile >= self.config.upper_percentile {
367            return Err(SklearsError::InvalidInput(
368                "Lower percentile must be less than upper percentile".to_string(),
369            ));
370        }
371
372        let mut lower_bounds = Array1::<Float>::zeros(n_features);
373        let mut upper_bounds = Array1::<Float>::zeros(n_features);
374
375        #[cfg(feature = "parallel")]
376        {
377            // Use parallel processing for multiple features
378            if n_features > 4 {
379                let bounds: Result<Vec<(Float, Float)>> = (0..n_features)
380                    .into_par_iter()
381                    .map(|j| {
382                        let feature_data = x.column(j).to_owned();
383                        self.compute_feature_bounds(&feature_data)
384                    })
385                    .collect();
386
387                match bounds {
388                    Ok(bounds_vec) => {
389                        for (j, (lower, upper)) in bounds_vec.into_iter().enumerate() {
390                            lower_bounds[j] = lower;
391                            upper_bounds[j] = upper;
392                        }
393                    }
394                    Err(e) => return Err(e),
395                }
396            } else {
397                // Sequential processing for few features
398                for j in 0..n_features {
399                    let feature_data = x.column(j).to_owned();
400                    let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
401                    lower_bounds[j] = lower;
402                    upper_bounds[j] = upper;
403                }
404            }
405        }
406
407        #[cfg(not(feature = "parallel"))]
408        {
409            // Sequential processing when parallel feature is disabled
410            for j in 0..n_features {
411                let feature_data = x.column(j).to_owned();
412                let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
413                lower_bounds[j] = lower;
414                upper_bounds[j] = upper;
415            }
416        }
417
418        Ok(Winsorizer {
419            config: self.config,
420            state: PhantomData,
421            lower_bounds_: Some(lower_bounds),
422            upper_bounds_: Some(upper_bounds),
423            n_features_in_: Some(n_features),
424        })
425    }
426}
427
428impl Transform<Array2<Float>, Array2<Float>> for Winsorizer<Trained> {
429    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
430        let (n_samples, n_features) = x.dim();
431
432        if n_features != self.n_features_in() {
433            return Err(SklearsError::FeatureMismatch {
434                expected: self.n_features_in(),
435                actual: n_features,
436            });
437        }
438
439        let mut result = x.clone();
440
441        #[cfg(feature = "parallel")]
442        {
443            // Use parallel processing for large datasets
444            if n_samples * n_features > 10000 {
445                result
446                    .as_slice_mut()
447                    .unwrap()
448                    .par_iter_mut()
449                    .enumerate()
450                    .for_each(|(idx, value)| {
451                        let i = idx / n_features;
452                        let j = idx % n_features;
453                        *value = self.winsorize_single(j, x[[i, j]]);
454                    });
455                return Ok(result);
456            }
457        }
458
459        // Sequential processing for smaller datasets or when parallel feature is disabled
460        for i in 0..n_samples {
461            for j in 0..n_features {
462                result[[i, j]] = self.winsorize_single(j, x[[i, j]]);
463            }
464        }
465
466        Ok(result)
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_winsorizer_basic() {
        let x = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0],
            [6.0, 60.0],
            [7.0, 70.0],
            [8.0, 80.0],
            [9.0, 90.0],
            [100.0, 1000.0], // Extreme outliers
        ];

        let winsorizer = Winsorizer::with_percentiles(10.0, 90.0)
            .fit(&x, &())
            .unwrap();

        let transformed = winsorizer.transform(&x).unwrap();

        // The outliers are capped at the interpolated 90th percentile:
        // rank 0.9 * 9 = 8.1 -> 9 * 0.9 + 100 * 0.1 = 18.1 and 90 * 0.9 + 1000 * 0.1 = 181.
        assert_abs_diff_eq!(transformed[[9, 0]], 18.1, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[9, 1]], 181.0, epsilon = 1e-6);

        // Values inside the bounds pass through unchanged.
        assert_abs_diff_eq!(transformed[[4, 0]], x[[4, 0]], epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[4, 1]], x[[4, 1]], epsilon = 1e-10);
    }

    #[test]
    fn test_winsorizer_percentiles() {
        let x = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];

        let winsorizer = Winsorizer::with_percentiles(20.0, 80.0)
            .fit(&x, &())
            .unwrap();

        // Linear interpolation over 10 sorted values:
        // rank 0.2 * 9 = 1.8 -> 2.8 and rank 0.8 * 9 = 7.2 -> 8.2.
        assert_abs_diff_eq!(winsorizer.lower_bounds()[0], 2.8, epsilon = 1e-6);
        assert_abs_diff_eq!(winsorizer.upper_bounds()[0], 8.2, epsilon = 1e-6);
    }

    #[test]
    fn test_winsorizer_with_nans() {
        let x = array![
            [1.0, 10.0],
            [2.0, Float::NAN],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0], // Outliers
        ];

        let winsorizer = Winsorizer::with_percentiles(25.0, 75.0)
            .nan_strategy(NanStrategy::Skip)
            .fit(&x, &())
            .unwrap();

        // Column 1 bounds come only from the finite values [10, 30, 40, 1000].
        assert_abs_diff_eq!(winsorizer.lower_bounds()[1], 25.0, epsilon = 1e-6);
        assert_abs_diff_eq!(winsorizer.upper_bounds()[1], 280.0, epsilon = 1e-6);

        let transformed = winsorizer.transform(&x).unwrap();

        // NaN stays NaN under the Skip strategy.
        assert!(transformed[[1, 1]].is_nan());

        // Outliers are capped at the upper bounds.
        assert_abs_diff_eq!(transformed[[4, 0]], 4.0, epsilon = 1e-6);
        assert_abs_diff_eq!(transformed[[4, 1]], 280.0, epsilon = 1e-6);
    }

    #[test]
    fn test_winsorizer_replace_nan_with_bound_midpoint() {
        let x = array![
            [1.0, 10.0],
            [2.0, Float::NAN],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0],
        ];

        let winsorizer = Winsorizer::with_percentiles(25.0, 75.0)
            .nan_strategy(NanStrategy::Replace)
            .fit(&x, &())
            .unwrap();
        let transformed = winsorizer.transform(&x).unwrap();

        // Column 1 bounds are 25 and 280, so NaN becomes (25 + 280) / 2.
        assert_abs_diff_eq!(transformed[[1, 1]], 152.5, epsilon = 1e-6);
    }

    #[test]
    fn test_winsorize_single() {
        let winsorizer = Winsorizer::with_percentiles(10.0, 90.0);
        let x = array![[5.0], [15.0], [25.0], [100.0]];

        let fitted = winsorizer.fit(&x, &()).unwrap();

        // 10th percentile: rank 0.3 -> 5 * 0.7 + 15 * 0.3 = 8.0;
        // 90th percentile: rank 2.7 -> 25 * 0.3 + 100 * 0.7 = 77.5.
        assert_abs_diff_eq!(fitted.lower_bounds()[0], 8.0, epsilon = 1e-6);
        assert_abs_diff_eq!(fitted.upper_bounds()[0], 77.5, epsilon = 1e-6);

        assert_eq!(fitted.winsorize_single(0, 15.0), 15.0); // Within bounds
        assert_eq!(fitted.winsorize_single(0, 3.0), fitted.lower_bounds()[0]); // Below lower bound
        assert_eq!(fitted.winsorize_single(0, 200.0), fitted.upper_bounds()[0]); // Above upper bound
    }

    #[test]
    fn test_winsorization_stats() {
        let x = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0], // Outliers
        ];

        let winsorizer = Winsorizer::with_percentiles(25.0, 75.0)
            .fit(&x, &())
            .unwrap();

        let stats = winsorizer.get_winsorization_stats(&x).unwrap();

        // Bounds are (2, 4) and (20, 40): exactly the first and last rows fall outside.
        assert_eq!(stats.n_samples, 5);
        assert_eq!(stats.n_features, 2);
        assert_eq!(stats.lower_clipped_per_feature, vec![1, 1]);
        assert_eq!(stats.upper_clipped_per_feature, vec![1, 1]);
        assert_eq!(stats.total_clipped, 4);
        assert_abs_diff_eq!(stats.clipping_rate, 0.4, epsilon = 1e-6);
    }

    #[test]
    fn test_winsorizer_edge_cases() {
        // Constant data is left untouched.
        let x = array![[5.0], [5.0], [5.0], [5.0]];
        let winsorizer = Winsorizer::new().fit(&x, &()).unwrap();
        let transformed = winsorizer.transform(&x).unwrap();
        assert_eq!(transformed, x);

        // So is a single sample.
        let x = array![[5.0]];
        let winsorizer = Winsorizer::new().fit(&x, &()).unwrap();
        let transformed = winsorizer.transform(&x).unwrap();
        assert_eq!(transformed, x);
    }

    #[test]
    fn test_winsorizer_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Extra feature

        let winsorizer = Winsorizer::new().fit(&x_train, &()).unwrap();
        assert!(winsorizer.transform(&x_test).is_err());
        assert!(winsorizer.get_winsorization_stats(&x_test).is_err());
    }

    #[test]
    fn test_winsorizer_invalid_percentiles() {
        let result = std::panic::catch_unwind(|| {
            Winsorizer::new().lower_percentile(60.0); // Invalid: > 50
        });
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| {
            Winsorizer::new().upper_percentile(40.0); // Invalid: < 50
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_winsorizer_empty_data() {
        let x = Array2::<Float>::zeros((0, 2));
        let result = Winsorizer::new().fit(&x, &());
        assert!(result.is_err());
    }
}