Skip to main content

sklears_preprocessing/
winsorization.rs

1//! Winsorization utilities for capping extreme outliers
2//!
3//! Winsorization is a statistical transformation that limits extreme values by
4//! replacing outliers with specified percentile values rather than removing them.
5//! This helps reduce the impact of outliers while preserving data points.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use sklears_core::{
9    error::{Result, SklearsError},
10    traits::{Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18/// Configuration for winsorization
19#[derive(Debug, Clone)]
20pub struct WinsorizerConfig {
21    /// Lower percentile for winsorization (default: 5.0)
22    pub lower_percentile: Float,
23    /// Upper percentile for winsorization (default: 95.0)
24    pub upper_percentile: Float,
25    /// Whether to winsorize each feature independently (default: true)
26    pub feature_wise: bool,
27    /// Strategy for handling NaN values
28    pub nan_strategy: NanStrategy,
29}
30
/// Strategy for handling NaN values during winsorization
///
/// The strategy only affects how NaN inputs are transformed. Bounds are always
/// computed from the finite values of the training data.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NanStrategy {
    /// Leave NaN values unchanged
    #[default]
    Skip,
    /// Treat NaN values as missing and fill them in.
    ///
    /// No neighbour-based interpolation is performed; the transform currently
    /// handles this variant exactly like [`NanStrategy::Replace`].
    Interpolate,
    /// Replace NaN values with the midpoint of the feature's lower and upper
    /// bounds. When either bound is not finite the NaN is kept.
    Replace,
}
42
43impl Default for WinsorizerConfig {
44    fn default() -> Self {
45        Self {
46            lower_percentile: 5.0,
47            upper_percentile: 95.0,
48            feature_wise: true,
49            nan_strategy: NanStrategy::Skip,
50        }
51    }
52}
53
54/// Winsorizer for capping extreme outliers
55///
56/// This transformer limits extreme values by replacing outliers with percentile values.
57/// For example, with lower_percentile=5 and upper_percentile=95, all values below the 5th
58/// percentile are set to the 5th percentile value, and all values above the 95th percentile
59/// are set to the 95th percentile value.
60#[derive(Debug, Clone)]
61pub struct Winsorizer<State = Untrained> {
62    config: WinsorizerConfig,
63    state: PhantomData<State>,
64    // Fitted parameters
65    lower_bounds_: Option<Array1<Float>>,
66    upper_bounds_: Option<Array1<Float>>,
67    n_features_in_: Option<usize>,
68}
69
70impl Winsorizer<Untrained> {
71    /// Create a new Winsorizer with default configuration
72    pub fn new() -> Self {
73        Self {
74            config: WinsorizerConfig::default(),
75            state: PhantomData,
76            lower_bounds_: None,
77            upper_bounds_: None,
78            n_features_in_: None,
79        }
80    }
81
82    /// Create a Winsorizer with specified percentiles
83    pub fn with_percentiles(lower: Float, upper: Float) -> Result<Self> {
84        Self::new().lower_percentile(lower)?.upper_percentile(upper)
85    }
86
87    /// Create a Winsorizer with IQR-based bounds
88    pub fn with_iqr(multiplier: Float) -> Result<Self> {
89        // IQR method: Q1 - multiplier*IQR, Q3 + multiplier*IQR
90        // Convert to approximate percentiles
91        let lower_perc = if multiplier >= 1.5 { 0.7 } else { 2.5 };
92        let upper_perc = if multiplier >= 1.5 { 99.3 } else { 97.5 };
93        Self::new()
94            .lower_percentile(lower_perc)?
95            .upper_percentile(upper_perc)
96    }
97
98    /// Set the lower percentile
99    pub fn lower_percentile(mut self, percentile: Float) -> Result<Self> {
100        if !(0.0..50.0).contains(&percentile) {
101            return Err(SklearsError::InvalidParameter {
102                name: "lower_percentile".to_string(),
103                reason: "must be between 0 and 50".to_string(),
104            });
105        }
106        self.config.lower_percentile = percentile;
107        Ok(self)
108    }
109
110    /// Set the upper percentile
111    pub fn upper_percentile(mut self, percentile: Float) -> Result<Self> {
112        if percentile <= 50.0 || percentile > 100.0 {
113            return Err(SklearsError::InvalidParameter {
114                name: "upper_percentile".to_string(),
115                reason: "must be between 50 and 100".to_string(),
116            });
117        }
118        self.config.upper_percentile = percentile;
119        Ok(self)
120    }
121
122    /// Set whether to winsorize features independently
123    pub fn feature_wise(mut self, feature_wise: bool) -> Self {
124        self.config.feature_wise = feature_wise;
125        self
126    }
127
128    /// Set the NaN handling strategy
129    pub fn nan_strategy(mut self, strategy: NanStrategy) -> Self {
130        self.config.nan_strategy = strategy;
131        self
132    }
133
134    /// Compute percentile of sorted data
135    fn compute_percentile(sorted_data: &[Float], percentile: Float) -> Float {
136        if sorted_data.is_empty() {
137            return Float::NAN;
138        }
139
140        if percentile <= 0.0 {
141            return sorted_data[0];
142        }
143        if percentile >= 100.0 {
144            return sorted_data[sorted_data.len() - 1];
145        }
146
147        let index = (percentile / 100.0) * (sorted_data.len() - 1) as Float;
148        let lower_index = index.floor() as usize;
149        let upper_index = index.ceil() as usize;
150
151        if lower_index == upper_index {
152            sorted_data[lower_index]
153        } else {
154            let weight = index - lower_index as Float;
155            sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
156        }
157    }
158
159    /// Compute bounds for a single feature
160    fn compute_feature_bounds(&self, feature_data: &Array1<Float>) -> Result<(Float, Float)> {
161        // Filter out NaN values
162        let mut valid_data: Vec<Float> = feature_data
163            .iter()
164            .filter(|&&x| x.is_finite())
165            .cloned()
166            .collect();
167
168        if valid_data.is_empty() {
169            return Ok((Float::NEG_INFINITY, Float::INFINITY));
170        }
171
172        if valid_data.len() == 1 {
173            let value = valid_data[0];
174            return Ok((value, value));
175        }
176
177        // Sort data for percentile calculation
178        valid_data.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
179
180        let lower_bound = Self::compute_percentile(&valid_data, self.config.lower_percentile);
181        let upper_bound = Self::compute_percentile(&valid_data, self.config.upper_percentile);
182
183        Ok((lower_bound, upper_bound))
184    }
185}
186
187impl Winsorizer<Trained> {
188    /// Get the lower bounds
189    pub fn lower_bounds(&self) -> &Array1<Float> {
190        self.lower_bounds_
191            .as_ref()
192            .expect("Winsorizer should be fitted")
193    }
194
195    /// Get the upper bounds
196    pub fn upper_bounds(&self) -> &Array1<Float> {
197        self.upper_bounds_
198            .as_ref()
199            .expect("Winsorizer should be fitted")
200    }
201
202    /// Get the number of features seen during fitting
203    pub fn n_features_in(&self) -> usize {
204        self.n_features_in_.expect("Winsorizer should be fitted")
205    }
206
207    /// Apply winsorization to a single value for a specific feature
208    pub fn winsorize_single(&self, feature_idx: usize, value: Float) -> Float {
209        if value.is_nan() {
210            match self.config.nan_strategy {
211                NanStrategy::Skip => value,
212                NanStrategy::Interpolate | NanStrategy::Replace => {
213                    // For single values, we can't interpolate, so use the mean of bounds
214                    let lower = self.lower_bounds()[feature_idx];
215                    let upper = self.upper_bounds()[feature_idx];
216                    if lower.is_finite() && upper.is_finite() {
217                        (lower + upper) / 2.0
218                    } else {
219                        value
220                    }
221                }
222            }
223        } else {
224            let lower = self.lower_bounds()[feature_idx];
225            let upper = self.upper_bounds()[feature_idx];
226
227            if value < lower {
228                lower
229            } else if value > upper {
230                upper
231            } else {
232                value
233            }
234        }
235    }
236
237    /// Get statistics about the winsorization applied to the data
238    pub fn get_winsorization_stats(&self, x: &Array2<Float>) -> Result<WinsorizationStats> {
239        let (n_samples, n_features) = x.dim();
240
241        if n_features != self.n_features_in() {
242            return Err(SklearsError::FeatureMismatch {
243                expected: self.n_features_in(),
244                actual: n_features,
245            });
246        }
247
248        let mut stats = WinsorizationStats {
249            n_samples,
250            n_features,
251            lower_clipped_per_feature: vec![0; n_features],
252            upper_clipped_per_feature: vec![0; n_features],
253            total_clipped: 0,
254            clipping_rate: 0.0,
255        };
256
257        #[cfg(feature = "parallel")]
258        {
259            // Use parallel processing for large datasets
260            if n_samples * n_features > 10000 {
261                let clipping_counts: Vec<(usize, usize, usize)> = (0..n_features)
262                    .into_par_iter()
263                    .map(|j| {
264                        let lower = self.lower_bounds()[j];
265                        let upper = self.upper_bounds()[j];
266                        let mut lower_clipped = 0;
267                        let mut upper_clipped = 0;
268
269                        for i in 0..n_samples {
270                            let value = x[[i, j]];
271                            if value.is_finite() {
272                                if value < lower {
273                                    lower_clipped += 1;
274                                } else if value > upper {
275                                    upper_clipped += 1;
276                                }
277                            }
278                        }
279                        (j, lower_clipped, upper_clipped)
280                    })
281                    .collect();
282
283                for (j, lower_clipped, upper_clipped) in clipping_counts {
284                    stats.lower_clipped_per_feature[j] = lower_clipped;
285                    stats.upper_clipped_per_feature[j] = upper_clipped;
286                    stats.total_clipped += lower_clipped + upper_clipped;
287                }
288            } else {
289                // Sequential processing for smaller datasets
290                for i in 0..n_samples {
291                    for j in 0..n_features {
292                        let value = x[[i, j]];
293                        if value.is_finite() {
294                            let lower = self.lower_bounds()[j];
295                            let upper = self.upper_bounds()[j];
296
297                            if value < lower {
298                                stats.lower_clipped_per_feature[j] += 1;
299                                stats.total_clipped += 1;
300                            } else if value > upper {
301                                stats.upper_clipped_per_feature[j] += 1;
302                                stats.total_clipped += 1;
303                            }
304                        }
305                    }
306                }
307            }
308        }
309
310        #[cfg(not(feature = "parallel"))]
311        {
312            // Sequential processing when parallel feature is disabled
313            for i in 0..n_samples {
314                for j in 0..n_features {
315                    let value = x[[i, j]];
316                    if value.is_finite() {
317                        let lower = self.lower_bounds()[j];
318                        let upper = self.upper_bounds()[j];
319
320                        if value < lower {
321                            stats.lower_clipped_per_feature[j] += 1;
322                            stats.total_clipped += 1;
323                        } else if value > upper {
324                            stats.upper_clipped_per_feature[j] += 1;
325                            stats.total_clipped += 1;
326                        }
327                    }
328                }
329            }
330        }
331
332        stats.clipping_rate = stats.total_clipped as Float / (n_samples * n_features) as Float;
333        Ok(stats)
334    }
335}
336
337/// Statistics about winsorization applied to data
338#[derive(Debug, Clone)]
339pub struct WinsorizationStats {
340    /// Number of samples
341    pub n_samples: usize,
342    /// Number of features
343    pub n_features: usize,
344    /// Number of values clipped at lower bound per feature
345    pub lower_clipped_per_feature: Vec<usize>,
346    /// Number of values clipped at upper bound per feature
347    pub upper_clipped_per_feature: Vec<usize>,
348    /// Total number of values clipped
349    pub total_clipped: usize,
350    /// Overall clipping rate (proportion of values clipped)
351    pub clipping_rate: Float,
352}
353
354impl Default for Winsorizer<Untrained> {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360impl Fit<Array2<Float>, ()> for Winsorizer<Untrained> {
361    type Fitted = Winsorizer<Trained>;
362
363    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
364        let (n_samples, n_features) = x.dim();
365
366        if n_samples == 0 {
367            return Err(SklearsError::InvalidInput(
368                "Cannot fit winsorizer on empty dataset".to_string(),
369            ));
370        }
371
372        if self.config.lower_percentile >= self.config.upper_percentile {
373            return Err(SklearsError::InvalidInput(
374                "Lower percentile must be less than upper percentile".to_string(),
375            ));
376        }
377
378        let mut lower_bounds = Array1::<Float>::zeros(n_features);
379        let mut upper_bounds = Array1::<Float>::zeros(n_features);
380
381        #[cfg(feature = "parallel")]
382        {
383            // Use parallel processing for multiple features
384            if n_features > 4 {
385                let bounds: Result<Vec<(Float, Float)>> = (0..n_features)
386                    .into_par_iter()
387                    .map(|j| {
388                        let feature_data = x.column(j).to_owned();
389                        self.compute_feature_bounds(&feature_data)
390                    })
391                    .collect();
392
393                match bounds {
394                    Ok(bounds_vec) => {
395                        for (j, (lower, upper)) in bounds_vec.into_iter().enumerate() {
396                            lower_bounds[j] = lower;
397                            upper_bounds[j] = upper;
398                        }
399                    }
400                    Err(e) => return Err(e),
401                }
402            } else {
403                // Sequential processing for few features
404                for j in 0..n_features {
405                    let feature_data = x.column(j).to_owned();
406                    let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
407                    lower_bounds[j] = lower;
408                    upper_bounds[j] = upper;
409                }
410            }
411        }
412
413        #[cfg(not(feature = "parallel"))]
414        {
415            // Sequential processing when parallel feature is disabled
416            for j in 0..n_features {
417                let feature_data = x.column(j).to_owned();
418                let (lower, upper) = self.compute_feature_bounds(&feature_data)?;
419                lower_bounds[j] = lower;
420                upper_bounds[j] = upper;
421            }
422        }
423
424        Ok(Winsorizer {
425            config: self.config,
426            state: PhantomData,
427            lower_bounds_: Some(lower_bounds),
428            upper_bounds_: Some(upper_bounds),
429            n_features_in_: Some(n_features),
430        })
431    }
432}
433
434impl Transform<Array2<Float>, Array2<Float>> for Winsorizer<Trained> {
435    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
436        let (n_samples, n_features) = x.dim();
437
438        if n_features != self.n_features_in() {
439            return Err(SklearsError::FeatureMismatch {
440                expected: self.n_features_in(),
441                actual: n_features,
442            });
443        }
444
445        let mut result = x.clone();
446
447        #[cfg(feature = "parallel")]
448        {
449            // Use parallel processing for large datasets
450            if n_samples * n_features > 10000 {
451                result
452                    .as_slice_mut()
453                    .expect("operation should succeed")
454                    .par_iter_mut()
455                    .enumerate()
456                    .for_each(|(idx, value)| {
457                        let i = idx / n_features;
458                        let j = idx % n_features;
459                        *value = self.winsorize_single(j, x[[i, j]]);
460                    });
461                return Ok(result);
462            }
463        }
464
465        // Sequential processing for smaller datasets or when parallel feature is disabled
466        for i in 0..n_samples {
467            for j in 0..n_features {
468                result[[i, j]] = self.winsorize_single(j, x[[i, j]]);
469            }
470        }
471
472        Ok(result)
473    }
474}
475
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Outliers in the last row are pulled in; mid-range values are untouched.
    #[test]
    fn test_winsorizer_basic() {
        let data = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [5.0, 50.0],
            [6.0, 60.0],
            [7.0, 70.0],
            [8.0, 80.0],
            [9.0, 90.0],
            [100.0, 1000.0], // extreme row
        ];

        let fitted = Winsorizer::with_percentiles(10.0, 90.0)
            .expect("valid parameter")
            .fit(&data, &())
            .expect("operation should succeed");
        let capped = fitted
            .transform(&data)
            .expect("transformation should succeed");

        // The extreme row must shrink in both columns
        for col in 0..2 {
            assert!(capped[[9, col]] < data[[9, col]]);
        }

        // A central row must pass through unchanged
        for col in 0..2 {
            assert_abs_diff_eq!(capped[[4, col]], data[[4, col]], epsilon = 1e-10);
        }
    }

    /// Fitted bounds land near the expected interpolated percentiles.
    #[test]
    fn test_winsorizer_percentiles() {
        let data = array![
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
            [10.0]
        ];

        let fitted = Winsorizer::with_percentiles(20.0, 80.0)
            .expect("valid parameter")
            .fit(&data, &())
            .expect("operation should succeed");

        // Expect roughly 2.8 for the 20th and 8.2 for the 80th percentile
        let low = fitted.lower_bounds()[0];
        let high = fitted.upper_bounds()[0];

        assert!((2.0..=3.5).contains(&low));
        assert!((7.5..=9.0).contains(&high));
    }

    /// With the Skip strategy NaN survives while outliers are still capped.
    #[test]
    fn test_winsorizer_with_nans() {
        let data = array![
            [1.0, 10.0],
            [2.0, Float::NAN],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0], // extreme row
        ];

        let fitted = Winsorizer::with_percentiles(25.0, 75.0)
            .expect("valid parameter")
            .nan_strategy(NanStrategy::Skip)
            .fit(&data, &())
            .expect("operation should succeed");
        let capped = fitted
            .transform(&data)
            .expect("transformation should succeed");

        // The NaN cell stays NaN
        assert!(capped[[1, 1]].is_nan());

        // The extreme row shrinks in both columns
        for col in 0..2 {
            assert!(capped[[4, col]] < data[[4, col]]);
        }
    }

    /// `winsorize_single` clamps to the fitted bounds.
    #[test]
    fn test_winsorizer_single_value() {
        let data = array![[5.0], [15.0], [25.0], [100.0]];
        let fitted = Winsorizer::with_percentiles(10.0, 90.0)
            .expect("valid parameter")
            .fit(&data, &())
            .expect("model fitting should succeed");

        let low = fitted.lower_bounds()[0];
        let high = fitted.upper_bounds()[0];

        // In-range values pass through; out-of-range values snap to a bound
        assert_eq!(fitted.winsorize_single(0, 15.0), 15.0);
        assert_eq!(fitted.winsorize_single(0, 3.0), low);
        assert_eq!(fitted.winsorize_single(0, 200.0), high);
        assert!(fitted.winsorize_single(0, 200.0) < 200.0);
    }

    /// Statistics report the input shape and a non-zero amount of clipping.
    #[test]
    fn test_winsorization_stats() {
        let data = array![
            [1.0, 10.0],
            [2.0, 20.0],
            [3.0, 30.0],
            [4.0, 40.0],
            [100.0, 1000.0], // extreme row
        ];

        let fitted = Winsorizer::with_percentiles(25.0, 75.0)
            .expect("valid parameter")
            .fit(&data, &())
            .expect("operation should succeed");
        let summary = fitted
            .get_winsorization_stats(&data)
            .expect("operation should succeed");

        assert_eq!(summary.n_samples, 5);
        assert_eq!(summary.n_features, 2);
        assert!(summary.total_clipped > 0);
        assert!(summary.clipping_rate > 0.0);
    }

    /// Constant data and single-sample data are returned unchanged.
    #[test]
    fn test_winsorizer_edge_cases() {
        // Constant column
        let constant = array![[5.0], [5.0], [5.0], [5.0]];
        let fitted = Winsorizer::new()
            .fit(&constant, &())
            .expect("model fitting should succeed");
        let output = fitted
            .transform(&constant)
            .expect("transformation should succeed");
        assert_eq!(output, constant);

        // Single sample
        let lone = array![[5.0]];
        let fitted = Winsorizer::new()
            .fit(&lone, &())
            .expect("model fitting should succeed");
        let output = fitted
            .transform(&lone)
            .expect("transformation should succeed");
        assert_eq!(output, lone);
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_winsorizer_feature_mismatch() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wider = array![[1.0, 2.0, 3.0]]; // one column too many

        let fitted = Winsorizer::new()
            .fit(&train, &())
            .expect("model fitting should succeed");
        assert!(fitted.transform(&wider).is_err());
    }

    /// Percentile setters reject values on the wrong side of 50.
    #[test]
    fn test_winsorizer_invalid_percentiles() {
        // 60 is not a valid lower percentile
        assert!(Winsorizer::new().lower_percentile(60.0).is_err());

        // 40 is not a valid upper percentile
        assert!(Winsorizer::new().upper_percentile(40.0).is_err());
    }

    /// Fitting on a matrix without rows is an error.
    #[test]
    fn test_winsorizer_empty_data() {
        let no_rows = Array2::<Float>::zeros((0, 2));
        assert!(Winsorizer::new().fit(&no_rows, &()).is_err());
    }
}