Skip to main content

sklears_preprocessing/
quantile_transformer.rs

1//! Quantile Transformer
2//!
3//! This module provides QuantileTransformer which transforms features to follow
4//! a uniform or a normal distribution using quantiles information.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use std::marker::PhantomData;
8
9use sklears_core::{
10    error::{Result, SklearsError},
11    traits::{Estimator, Fit, Trained, Transform, Untrained},
12    types::Float,
13};
14
/// Target distribution that [`QuantileTransformer`] maps each feature onto.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantileOutput {
    /// Emit the interpolated quantile level itself, i.e. a value on `[0, 1]`.
    Uniform,
    /// Push the quantile level through the inverse standard normal CDF so the
    /// output follows a standard normal distribution.
    Normal,
}
23
24/// Configuration for QuantileTransformer
25#[derive(Debug, Clone)]
26pub struct QuantileTransformerConfig {
27    /// Number of quantiles to be computed
28    pub n_quantiles: usize,
29    /// Desired output distribution
30    pub output_distribution: QuantileOutput,
31    /// Subsample size for large datasets
32    pub subsample: Option<usize>,
33    /// Random state for subsampling
34    pub random_state: Option<u64>,
35    /// Whether to copy the input data
36    pub copy: bool,
37    /// Whether to clip transformed values to avoid infinities
38    pub clip: bool,
39    /// Ignore outliers beyond this quantile range
40    pub ignore_outliers: Option<(Float, Float)>,
41}
42
43impl Default for QuantileTransformerConfig {
44    fn default() -> Self {
45        Self {
46            n_quantiles: 1000,
47            output_distribution: QuantileOutput::Uniform,
48            subsample: Some(100_000),
49            random_state: None,
50            copy: true,
51            clip: true,
52            ignore_outliers: None,
53        }
54    }
55}
56
57/// QuantileTransformer applies a non-linear transformation to make data follow
58/// a uniform or normal distribution
59pub struct QuantileTransformer<State = Untrained> {
60    config: QuantileTransformerConfig,
61    state: PhantomData<State>,
62    /// The quantiles for each feature
63    quantiles_: Option<Vec<Array1<Float>>>,
64    /// The actual number of quantiles used
65    n_quantiles_: Option<usize>,
66    /// Reference values for inverse transform
67    references_: Option<Array1<Float>>,
68}
69
70impl QuantileTransformer<Untrained> {
71    /// Create a new QuantileTransformer with default configuration
72    pub fn new() -> Self {
73        Self {
74            config: QuantileTransformerConfig::default(),
75            state: PhantomData,
76            quantiles_: None,
77            n_quantiles_: None,
78            references_: None,
79        }
80    }
81
82    /// Set the number of quantiles
83    pub fn n_quantiles(mut self, n_quantiles: usize) -> Result<Self> {
84        if n_quantiles < 2 {
85            return Err(SklearsError::InvalidParameter {
86                name: "n_quantiles".to_string(),
87                reason: "must be at least 2".to_string(),
88            });
89        }
90        self.config.n_quantiles = n_quantiles;
91        Ok(self)
92    }
93
94    /// Set the output distribution
95    pub fn output_distribution(mut self, output_distribution: QuantileOutput) -> Self {
96        self.config.output_distribution = output_distribution;
97        self
98    }
99
100    /// Set the subsample size
101    pub fn subsample(mut self, subsample: Option<usize>) -> Self {
102        self.config.subsample = subsample;
103        self
104    }
105
106    /// Set whether to clip values to avoid infinities
107    pub fn clip(mut self, clip: bool) -> Self {
108        self.config.clip = clip;
109        self
110    }
111
112    /// Set outlier quantile range to ignore during fitting
113    pub fn ignore_outliers(mut self, range: Option<(Float, Float)>) -> Self {
114        if let Some((low, high)) = range {
115            assert!(
116                low >= 0.0 && low < high && high <= 1.0,
117                "Outlier range must be (low, high) where 0 <= low < high <= 1"
118            );
119        }
120        self.config.ignore_outliers = range;
121        self
122    }
123}
124
125impl Default for QuantileTransformer<Untrained> {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl Estimator for QuantileTransformer<Untrained> {
132    type Config = QuantileTransformerConfig;
133    type Error = SklearsError;
134    type Float = Float;
135
136    fn config(&self) -> &Self::Config {
137        &self.config
138    }
139}
140
141impl Estimator for QuantileTransformer<Trained> {
142    type Config = QuantileTransformerConfig;
143    type Error = SklearsError;
144    type Float = Float;
145
146    fn config(&self) -> &Self::Config {
147        &self.config
148    }
149}
150
151/// Compute quantiles for a feature with optional outlier filtering
152fn compute_quantiles(
153    data: &Array1<Float>,
154    n_quantiles: usize,
155    ignore_outliers: Option<(Float, Float)>,
156) -> (Array1<Float>, Array1<Float>) {
157    let mut sorted_data = data.to_vec();
158    sorted_data.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
159
160    let n_samples = sorted_data.len();
161
162    // Apply outlier filtering if specified
163    let (start_idx, end_idx) = if let Some((low_quantile, high_quantile)) = ignore_outliers {
164        let start = ((low_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
165        let end = ((high_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
166        (start, end + 1)
167    } else {
168        (0, n_samples)
169    };
170
171    let filtered_data = &sorted_data[start_idx..end_idx];
172    let filtered_n_samples = filtered_data.len();
173    let n_quantiles = n_quantiles.min(filtered_n_samples);
174
175    let mut quantiles = Vec::with_capacity(n_quantiles);
176    let mut references = Vec::with_capacity(n_quantiles);
177
178    for i in 0..n_quantiles {
179        let quantile = i as Float / (n_quantiles - 1) as Float;
180        let idx = (quantile * (filtered_n_samples - 1) as Float) as usize;
181
182        quantiles.push(filtered_data[idx]);
183        references.push(quantile);
184    }
185
186    (Array1::from_vec(quantiles), Array1::from_vec(references))
187}
188
189/// Compute the inverse of the error function (for normal distribution)
190/// Uses an accurate rational approximation by Winitzki (2008)
191fn erfinv(x: Float) -> Float {
192    if x.abs() >= 1.0 {
193        return if x > 0.0 {
194            Float::INFINITY
195        } else {
196            Float::NEG_INFINITY
197        };
198    }
199
200    if x == 0.0 {
201        return 0.0;
202    }
203
204    let sign = if x > 0.0 { 1.0 } else { -1.0 };
205    let x = x.abs();
206
207    // Use Winitzki's approximation
208    let a = 0.147;
209    let ln_term = (1.0 - x * x).ln();
210    let term1 = 2.0 / (std::f64::consts::PI * a) + ln_term / 2.0;
211    let term2 = ln_term / a;
212
213    let result = (term1 * term1 - term2).sqrt() - term1;
214    sign * result.sqrt()
215}
216
217/// Robust inverse error function implementation using proven approximations
218/// This ensures monotonicity by construction using Beasley-Springer-Moro algorithm
219fn erfinv_accurate(x: Float) -> Float {
220    if x.abs() >= 1.0 {
221        return if x > 0.0 {
222            Float::INFINITY
223        } else {
224            Float::NEG_INFINITY
225        };
226    }
227
228    if x == 0.0 {
229        return 0.0;
230    }
231
232    // Use symmetric property
233    let sign = x.signum();
234    let abs_x = x.abs();
235
236    // Use the Beasley-Springer-Moro algorithm for standard normal quantiles
237    // This is monotonic and highly accurate
238
239    // Convert erf^-1 to standard normal quantile function
240    // If x = erf(y), then sqrt(2) * y = Φ^-1((1+x)/2)
241    let p = (1.0 + abs_x) / 2.0;
242
243    // Apply rational approximation for inverse normal CDF
244    let result = if p > 0.5 {
245        // Upper tail
246        let q = p - 0.5;
247        let r = q * q;
248
249        let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
250            * r
251            + 138.357751867269)
252            * r
253            - 30.6647980661472)
254            * r
255            + 2.50662827745924)
256            * q;
257
258        let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
259            * r
260            + 66.8013118877197)
261            * r
262            - 13.2806815528857)
263            * r
264            + 1.0;
265
266        numerator / denominator
267    } else {
268        // Lower tail - use symmetry
269        let q = 0.5 - p;
270        let r = q * q;
271
272        let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
273            * r
274            + 138.357751867269)
275            * r
276            - 30.6647980661472)
277            * r
278            + 2.50662827745924)
279            * q;
280
281        let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
282            * r
283            + 66.8013118877197)
284            * r
285            - 13.2806815528857)
286            * r
287            + 1.0;
288
289        -numerator / denominator
290    };
291
292    // Convert from standard normal quantile to erf^-1
293    sign * result / std::f64::consts::SQRT_2
294}
295
296/// Convert uniform values to normal distribution
297fn uniform_to_normal(uniform_value: Float, clip: bool) -> Float {
298    let clipped = if clip {
299        // Clip to avoid infinities
300        uniform_value.clamp(1e-7, 1.0 - 1e-7)
301    } else {
302        uniform_value
303    };
304
305    // Use more accurate inverse error function to convert uniform to normal
306    std::f64::consts::SQRT_2 * erfinv_accurate(2.0 * clipped - 1.0)
307}
308
309/// Accurate error function implementation using Abramowitz and Stegun approximation
310fn erf_accurate(x: Float) -> Float {
311    if x == 0.0 {
312        return 0.0;
313    }
314
315    let sign = x.signum();
316    let abs_x = x.abs();
317
318    // Abramowitz and Stegun approximation with high accuracy
319    let a1 = 0.254829592;
320    let a2 = -0.284496736;
321    let a3 = 1.421413741;
322    let a4 = -1.453152027;
323    let a5 = 1.061405429;
324    let p = 0.3275911;
325
326    let t = 1.0 / (1.0 + p * abs_x);
327    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-abs_x * abs_x).exp();
328
329    sign * y
330}
331
332/// Convert normal values to uniform distribution
333fn normal_to_uniform(normal_value: Float, clip: bool) -> Float {
334    // Use accurate error function to convert normal to uniform
335    let erf_input = normal_value / std::f64::consts::SQRT_2;
336    let erf_val = erf_accurate(erf_input);
337    let uniform_val = (1.0 + erf_val) / 2.0;
338
339    if clip {
340        uniform_val.clamp(1e-7, 1.0 - 1e-7)
341    } else {
342        uniform_val
343    }
344}
345
346impl Fit<Array2<Float>, ()> for QuantileTransformer<Untrained> {
347    type Fitted = QuantileTransformer<Trained>;
348
349    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
350        let n_samples = x.nrows();
351        let n_features = x.ncols();
352
353        // Determine actual number of quantiles
354        let n_quantiles = self.config.n_quantiles.min(n_samples);
355
356        let mut all_quantiles = Vec::with_capacity(n_features);
357        let mut all_references = None;
358
359        // Compute quantiles for each feature
360        for j in 0..n_features {
361            let feature_data = x.column(j).to_owned();
362
363            // Subsample if needed with better strategy
364            let data_to_use = if let Some(subsample_size) = self.config.subsample {
365                if n_samples > subsample_size {
366                    // Use stratified subsampling for better representation
367                    let mut subsampled = Vec::with_capacity(subsample_size);
368
369                    // Sort data to enable stratified sampling
370                    let mut indexed_data: Vec<(usize, Float)> = feature_data
371                        .iter()
372                        .enumerate()
373                        .map(|(i, &val)| (i, val))
374                        .collect();
375                    indexed_data
376                        .sort_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"));
377
378                    // Take stratified samples across the sorted data
379                    let step = n_samples as Float / subsample_size as Float;
380                    for i in 0..subsample_size {
381                        let idx = (i as Float * step) as usize;
382                        if idx < n_samples {
383                            subsampled.push(indexed_data[idx].1);
384                        }
385                    }
386
387                    Array1::from_vec(subsampled)
388                } else {
389                    feature_data
390                }
391            } else {
392                feature_data
393            };
394
395            let (quantiles, references) =
396                compute_quantiles(&data_to_use, n_quantiles, self.config.ignore_outliers);
397            all_quantiles.push(quantiles);
398
399            if all_references.is_none() {
400                all_references = Some(references);
401            }
402        }
403
404        Ok(QuantileTransformer {
405            config: self.config,
406            state: PhantomData,
407            quantiles_: Some(all_quantiles),
408            n_quantiles_: Some(n_quantiles),
409            references_: all_references,
410        })
411    }
412}
413
414/// Interpolate value based on quantiles
415fn interpolate_value(value: Float, quantiles: &Array1<Float>, references: &Array1<Float>) -> Float {
416    let n = quantiles.len();
417
418    // Handle edge cases
419    if value <= quantiles[0] {
420        return references[0];
421    }
422    if value >= quantiles[n - 1] {
423        return references[n - 1];
424    }
425
426    // Binary search for the interval
427    let mut left = 0;
428    let mut right = n - 1;
429
430    while left < right - 1 {
431        let mid = (left + right) / 2;
432        if value < quantiles[mid] {
433            right = mid;
434        } else {
435            left = mid;
436        }
437    }
438
439    // Linear interpolation
440    let x0 = quantiles[left];
441    let x1 = quantiles[right];
442    let y0 = references[left];
443    let y1 = references[right];
444
445    if (x1 - x0).abs() < Float::EPSILON {
446        y0
447    } else {
448        y0 + (value - x0) * (y1 - y0) / (x1 - x0)
449    }
450}
451
452impl Transform<Array2<Float>, Array2<Float>> for QuantileTransformer<Trained> {
453    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
454        let n_samples = x.nrows();
455        let n_features = x.ncols();
456        let quantiles = self.quantiles_.as_ref().expect("operation should succeed");
457        let references = self.references_.as_ref().expect("operation should succeed");
458
459        if n_features != quantiles.len() {
460            return Err(SklearsError::InvalidInput(format!(
461                "X has {} features, but QuantileTransformer is expecting {} features",
462                n_features,
463                quantiles.len()
464            )));
465        }
466
467        let mut result = Array2::zeros((n_samples, n_features));
468
469        for i in 0..n_samples {
470            for j in 0..n_features {
471                let value = x[[i, j]];
472                let uniform_value = interpolate_value(value, &quantiles[j], references);
473
474                result[[i, j]] = match self.config.output_distribution {
475                    QuantileOutput::Uniform => uniform_value,
476                    QuantileOutput::Normal => uniform_to_normal(uniform_value, self.config.clip),
477                };
478            }
479        }
480
481        Ok(result)
482    }
483}
484
485impl Transform<Array1<Float>, Array1<Float>> for QuantileTransformer<Trained> {
486    fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
487        // Transform 1D array by treating it as a column vector
488        let x_2d = x.clone().insert_axis(Axis(1));
489        let result_2d = self.transform(&x_2d)?;
490        Ok(result_2d.column(0).to_owned())
491    }
492}
493
494impl QuantileTransformer<Trained> {
495    /// Get the quantiles for each feature
496    pub fn quantiles(&self) -> &Vec<Array1<Float>> {
497        self.quantiles_.as_ref().expect("operation should succeed")
498    }
499
500    /// Get the number of quantiles used
501    pub fn n_quantiles(&self) -> usize {
502        self.n_quantiles_.expect("operation should succeed")
503    }
504
505    /// Inverse transform data back to original distribution
506    pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
507        let n_samples = x.nrows();
508        let n_features = x.ncols();
509        let quantiles = self.quantiles_.as_ref().expect("operation should succeed");
510        let references = self.references_.as_ref().expect("operation should succeed");
511
512        if n_features != quantiles.len() {
513            return Err(SklearsError::InvalidInput(format!(
514                "X has {} features, but QuantileTransformer is expecting {} features",
515                n_features,
516                quantiles.len()
517            )));
518        }
519
520        let mut result = Array2::zeros((n_samples, n_features));
521
522        for i in 0..n_samples {
523            for j in 0..n_features {
524                let value = x[[i, j]];
525
526                let uniform_value = match self.config.output_distribution {
527                    QuantileOutput::Uniform => value,
528                    QuantileOutput::Normal => normal_to_uniform(value, self.config.clip),
529                };
530
531                // Interpolate back to original distribution
532                result[[i, j]] = interpolate_value(uniform_value, references, &quantiles[j]);
533            }
534        }
535
536        Ok(result)
537    }
538}
539
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Uniform output stays in `[0, 1]` and preserves the order of the input.
    #[test]
    fn test_quantile_transformer_uniform() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
        ];

        let qt = QuantileTransformer::new()
            .n_quantiles(10)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Uniform)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Check that values are in [0, 1]
        for value in x_transformed.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }

        // Check that transformation is monotonic
        for i in 1..x_transformed.nrows() {
            assert!(x_transformed[[i, 0]] >= x_transformed[[i - 1, 0]]);
        }
    }

    /// Normal output is monotonic and stays inside a bounded range.
    #[test]
    fn test_quantile_transformer_normal() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0],];

        let qt = QuantileTransformer::new()
            .n_quantiles(6)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Normal)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Check that transformation is monotonic (more important than exact values)
        for i in 1..x_transformed.nrows() {
            assert!(
                x_transformed[[i, 0]] >= x_transformed[[i - 1, 0]] - 1e-10,
                "Values at index {} ({}) should be >= values at index {} ({})",
                i,
                x_transformed[[i, 0]],
                i - 1,
                x_transformed[[i - 1, 0]]
            );
        }

        // Check that the range is reasonable for normal distribution
        let min_val = x_transformed.iter().fold(Float::INFINITY, |a, &b| a.min(b));
        let max_val = x_transformed
            .iter()
            .fold(Float::NEG_INFINITY, |a, &b| a.max(b));

        // Normal transformed values should be in a reasonable range
        assert!(
            min_val > -5.0 && max_val < 5.0,
            "Normal transformed values should be in reasonable range, got [{}, {}]",
            min_val,
            max_val
        );
    }

    /// Each column is transformed independently and stays in `[0, 1]`.
    #[test]
    fn test_quantile_transformer_multivariate() {
        let x = array![
            [0.0, 10.0],
            [1.0, 20.0],
            [2.0, 30.0],
            [3.0, 40.0],
            [4.0, 50.0],
        ];

        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .expect("valid parameter")
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Each feature should be transformed independently
        assert_eq!(x_transformed.ncols(), 2);

        // Values should be in [0, 1]
        for value in x_transformed.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    /// `inverse_transform` undoes `transform` on the training data.
    #[test]
    fn test_quantile_transformer_inverse() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0],];

        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .expect("valid parameter")
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");
        let x_inverse = qt
            .inverse_transform(&x_transformed)
            .expect("operation should succeed");

        // Check that inverse transform recovers original values
        for i in 0..x.nrows() {
            assert_abs_diff_eq!(x[[i, 0]], x_inverse[[i, 0]], epsilon = 1e-6);
        }
    }

    /// Heavily tied data still produces finite output.
    #[test]
    fn test_quantile_transformer_edge_cases() {
        let x = array![[0.0], [0.0], [1.0], [1.0], [1.0],];

        let qt = QuantileTransformer::new()
            .n_quantiles(3)
            .expect("valid parameter")
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // All values should be valid
        for value in x_transformed.iter() {
            assert!(value.is_finite());
        }
    }

    /// Exact knots, interior points and out-of-range values of the lookup.
    #[test]
    fn test_interpolate_value() {
        let quantiles = array![0.0, 1.0, 2.0, 3.0];
        let references = array![0.0, 0.33, 0.67, 1.0];

        // Test exact matches
        assert_abs_diff_eq!(interpolate_value(0.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(3.0, &quantiles, &references), 1.0);

        // Test interpolation
        assert_abs_diff_eq!(
            interpolate_value(0.5, &quantiles, &references),
            0.165,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            interpolate_value(1.5, &quantiles, &references),
            0.5,
            epsilon = 1e-3
        );

        // Values outside the table are clamped to the end references
        assert_abs_diff_eq!(interpolate_value(-1.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(4.0, &quantiles, &references), 1.0);
    }

    /// The uniform-to-normal mapping is centred at 0.5 and strictly increasing.
    #[test]
    fn test_uniform_to_normal() {
        // Test some known values
        assert_abs_diff_eq!(uniform_to_normal(0.5, true), 0.0, epsilon = 1e-4);

        // Test that it's monotonic
        let values = [0.1, 0.3, 0.5, 0.7, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| uniform_to_normal(v, true)).collect();

        for i in 1..transformed.len() {
            assert!(transformed[i] > transformed[i - 1]);
        }
    }

    /// With clipping, extreme training values still map to finite outputs.
    #[test]
    fn test_enhanced_quantile_transformer_with_clipping() {
        let x = array![
            [-10.0], // Extreme outlier
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [100.0], // Extreme outlier
        ];

        let qt = QuantileTransformer::new()
            .n_quantiles(6)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Normal)
            .clip(true)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Check that all values are finite
        for value in x_transformed.iter() {
            assert!(value.is_finite(), "All transformed values should be finite");
        }

        // Check that transformation is monotonic
        for i in 1..x_transformed.nrows() {
            assert!(x_transformed[[i, 0]] >= x_transformed[[i - 1, 0]]);
        }
    }

    /// Fitting with and without an outlier window both yield finite results.
    #[test]
    fn test_quantile_transformer_with_outlier_filtering() {
        // 100 regular values followed by two extreme outliers.
        let mut values: Vec<Float> = (0..100).map(|i| i as Float).collect();
        values.push(-1000.0);
        values.push(1000.0);

        let x = Array2::from_shape_vec((values.len(), 1), values)
            .expect("shape and data length should match");

        // Test with outlier filtering (ignore bottom 1% and top 1%)
        let qt_filtered = QuantileTransformer::new()
            .n_quantiles(50)
            .expect("valid parameter")
            .ignore_outliers(Some((0.01, 0.99)))
            .fit(&x, &())
            .expect("operation should succeed");

        // Test without outlier filtering
        let qt_unfiltered = QuantileTransformer::new()
            .n_quantiles(50)
            .expect("valid parameter")
            .fit(&x, &())
            .expect("operation should succeed");

        let test_value = array![[50.0]];
        let result_filtered = qt_filtered
            .transform(&test_value)
            .expect("transformation should succeed");
        let result_unfiltered = qt_unfiltered
            .transform(&test_value)
            .expect("transformation should succeed");

        // Qualitative check only: both configurations must give valid output.
        assert!(result_filtered.iter().all(|&v| v.is_finite()));
        assert!(result_unfiltered.iter().all(|&v| v.is_finite()));
    }

    /// Zero, odd symmetry and monotonicity of `erfinv_accurate`.
    #[test]
    fn test_improved_inverse_error_function() {
        // Test erfinv_accurate with known values
        assert_abs_diff_eq!(erfinv_accurate(0.0), 0.0, epsilon = 1e-10);

        // Test symmetry
        let test_val = 0.5;
        assert_abs_diff_eq!(
            erfinv_accurate(test_val),
            -erfinv_accurate(-test_val),
            epsilon = 1e-10
        );

        // Test that it's monotonic
        let values = [-0.9, -0.5, 0.0, 0.5, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| erfinv_accurate(v)).collect();

        for i in 1..transformed.len() {
            assert!(transformed[i] > transformed[i - 1]);
        }
    }

    /// uniform -> normal -> uniform returns (approximately) the starting level.
    #[test]
    fn test_normal_to_uniform_conversion() {
        // Test round-trip conversion
        let uniform_values = [0.1, 0.3, 0.5, 0.7, 0.9];

        for &uniform_val in &uniform_values {
            let normal_val = uniform_to_normal(uniform_val, true);
            let recovered_uniform = normal_to_uniform(normal_val, true);

            assert_abs_diff_eq!(uniform_val, recovered_uniform, epsilon = 1e-3);
        }

        // Test edge cases
        assert_abs_diff_eq!(normal_to_uniform(0.0, true), 0.5, epsilon = 1e-3);
    }

    /// Every builder method stores its argument in the configuration.
    #[test]
    fn test_builder_methods() {
        let qt = QuantileTransformer::new()
            .n_quantiles(500)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Normal)
            .subsample(Some(1000))
            .clip(false)
            .ignore_outliers(Some((0.05, 0.95)));

        assert_eq!(qt.config.n_quantiles, 500);
        assert_eq!(qt.config.output_distribution, QuantileOutput::Normal);
        assert_eq!(qt.config.subsample, Some(1000));
        assert!(!qt.config.clip);
        assert_eq!(qt.config.ignore_outliers, Some((0.05, 0.95)));
    }

    /// A reversed outlier window is rejected with a panic.
    #[test]
    #[should_panic(expected = "Outlier range must be")]
    fn test_invalid_outlier_range() {
        // Invalid: low > high
        let _ = QuantileTransformer::new().ignore_outliers(Some((0.9, 0.1)));
    }
}