1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use std::marker::PhantomData;
8
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14
/// Target distribution of the values produced by the quantile transformer.
///
/// The type is a plain field-less enum, so it also derives `Eq` and `Hash`
/// and can be used as a map key or compared with total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantileOutput {
    /// Emit the interpolated reference level directly (values in `[0, 1]`).
    Uniform,
    /// Push the uniform value through the inverse normal CDF.
    Normal,
}
23
24#[derive(Debug, Clone)]
26pub struct QuantileTransformerConfig {
27 pub n_quantiles: usize,
29 pub output_distribution: QuantileOutput,
31 pub subsample: Option<usize>,
33 pub random_state: Option<u64>,
35 pub copy: bool,
37 pub clip: bool,
39 pub ignore_outliers: Option<(Float, Float)>,
41}
42
43impl Default for QuantileTransformerConfig {
44 fn default() -> Self {
45 Self {
46 n_quantiles: 1000,
47 output_distribution: QuantileOutput::Uniform,
48 subsample: Some(100_000),
49 random_state: None,
50 copy: true,
51 clip: true,
52 ignore_outliers: None,
53 }
54 }
55}
56
57pub struct QuantileTransformer<State = Untrained> {
60 config: QuantileTransformerConfig,
61 state: PhantomData<State>,
62 quantiles_: Option<Vec<Array1<Float>>>,
64 n_quantiles_: Option<usize>,
66 references_: Option<Array1<Float>>,
68}
69
70impl QuantileTransformer<Untrained> {
71 pub fn new() -> Self {
73 Self {
74 config: QuantileTransformerConfig::default(),
75 state: PhantomData,
76 quantiles_: None,
77 n_quantiles_: None,
78 references_: None,
79 }
80 }
81
82 pub fn n_quantiles(mut self, n_quantiles: usize) -> Self {
84 if n_quantiles < 2 {
85 panic!("n_quantiles must be at least 2");
86 }
87 self.config.n_quantiles = n_quantiles;
88 self
89 }
90
91 pub fn output_distribution(mut self, output_distribution: QuantileOutput) -> Self {
93 self.config.output_distribution = output_distribution;
94 self
95 }
96
97 pub fn subsample(mut self, subsample: Option<usize>) -> Self {
99 self.config.subsample = subsample;
100 self
101 }
102
103 pub fn clip(mut self, clip: bool) -> Self {
105 self.config.clip = clip;
106 self
107 }
108
109 pub fn ignore_outliers(mut self, range: Option<(Float, Float)>) -> Self {
111 if let Some((low, high)) = range {
112 assert!(
113 low >= 0.0 && low < high && high <= 1.0,
114 "Outlier range must be (low, high) where 0 <= low < high <= 1"
115 );
116 }
117 self.config.ignore_outliers = range;
118 self
119 }
120}
121
122impl Default for QuantileTransformer<Untrained> {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl Estimator for QuantileTransformer<Untrained> {
129 type Config = QuantileTransformerConfig;
130 type Error = SklearsError;
131 type Float = Float;
132
133 fn config(&self) -> &Self::Config {
134 &self.config
135 }
136}
137
138impl Estimator for QuantileTransformer<Trained> {
139 type Config = QuantileTransformerConfig;
140 type Error = SklearsError;
141 type Float = Float;
142
143 fn config(&self) -> &Self::Config {
144 &self.config
145 }
146}
147
148fn compute_quantiles(
150 data: &Array1<Float>,
151 n_quantiles: usize,
152 ignore_outliers: Option<(Float, Float)>,
153) -> (Array1<Float>, Array1<Float>) {
154 let mut sorted_data = data.to_vec();
155 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
156
157 let n_samples = sorted_data.len();
158
159 let (start_idx, end_idx) = if let Some((low_quantile, high_quantile)) = ignore_outliers {
161 let start = ((low_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
162 let end = ((high_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
163 (start, end + 1)
164 } else {
165 (0, n_samples)
166 };
167
168 let filtered_data = &sorted_data[start_idx..end_idx];
169 let filtered_n_samples = filtered_data.len();
170 let n_quantiles = n_quantiles.min(filtered_n_samples);
171
172 let mut quantiles = Vec::with_capacity(n_quantiles);
173 let mut references = Vec::with_capacity(n_quantiles);
174
175 for i in 0..n_quantiles {
176 let quantile = i as Float / (n_quantiles - 1) as Float;
177 let idx = (quantile * (filtered_n_samples - 1) as Float) as usize;
178
179 quantiles.push(filtered_data[idx]);
180 references.push(quantile);
181 }
182
183 (Array1::from_vec(quantiles), Array1::from_vec(references))
184}
185
186fn erfinv(x: Float) -> Float {
189 if x.abs() >= 1.0 {
190 return if x > 0.0 {
191 Float::INFINITY
192 } else {
193 Float::NEG_INFINITY
194 };
195 }
196
197 if x == 0.0 {
198 return 0.0;
199 }
200
201 let sign = if x > 0.0 { 1.0 } else { -1.0 };
202 let x = x.abs();
203
204 let a = 0.147;
206 let ln_term = (1.0 - x * x).ln();
207 let term1 = 2.0 / (std::f64::consts::PI * a) + ln_term / 2.0;
208 let term2 = ln_term / a;
209
210 let result = (term1 * term1 - term2).sqrt() - term1;
211 sign * result.sqrt()
212}
213
214fn erfinv_accurate(x: Float) -> Float {
217 if x.abs() >= 1.0 {
218 return if x > 0.0 {
219 Float::INFINITY
220 } else {
221 Float::NEG_INFINITY
222 };
223 }
224
225 if x == 0.0 {
226 return 0.0;
227 }
228
229 let sign = x.signum();
231 let abs_x = x.abs();
232
233 let p = (1.0 + abs_x) / 2.0;
239
240 let result = if p > 0.5 {
242 let q = p - 0.5;
244 let r = q * q;
245
246 let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
247 * r
248 + 138.357751867269)
249 * r
250 - 30.6647980661472)
251 * r
252 + 2.50662827745924)
253 * q;
254
255 let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
256 * r
257 + 66.8013118877197)
258 * r
259 - 13.2806815528857)
260 * r
261 + 1.0;
262
263 numerator / denominator
264 } else {
265 let q = 0.5 - p;
267 let r = q * q;
268
269 let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
270 * r
271 + 138.357751867269)
272 * r
273 - 30.6647980661472)
274 * r
275 + 2.50662827745924)
276 * q;
277
278 let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
279 * r
280 + 66.8013118877197)
281 * r
282 - 13.2806815528857)
283 * r
284 + 1.0;
285
286 -numerator / denominator
287 };
288
289 sign * result / std::f64::consts::SQRT_2
291}
292
293fn uniform_to_normal(uniform_value: Float, clip: bool) -> Float {
295 let clipped = if clip {
296 uniform_value.clamp(1e-7, 1.0 - 1e-7)
298 } else {
299 uniform_value
300 };
301
302 std::f64::consts::SQRT_2 * erfinv_accurate(2.0 * clipped - 1.0)
304}
305
306fn erf_accurate(x: Float) -> Float {
308 if x == 0.0 {
309 return 0.0;
310 }
311
312 let sign = x.signum();
313 let abs_x = x.abs();
314
315 let a1 = 0.254829592;
317 let a2 = -0.284496736;
318 let a3 = 1.421413741;
319 let a4 = -1.453152027;
320 let a5 = 1.061405429;
321 let p = 0.3275911;
322
323 let t = 1.0 / (1.0 + p * abs_x);
324 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-abs_x * abs_x).exp();
325
326 sign * y
327}
328
329fn normal_to_uniform(normal_value: Float, clip: bool) -> Float {
331 let erf_input = normal_value / std::f64::consts::SQRT_2;
333 let erf_val = erf_accurate(erf_input);
334 let uniform_val = (1.0 + erf_val) / 2.0;
335
336 if clip {
337 uniform_val.clamp(1e-7, 1.0 - 1e-7)
338 } else {
339 uniform_val
340 }
341}
342
343impl Fit<Array2<Float>, ()> for QuantileTransformer<Untrained> {
344 type Fitted = QuantileTransformer<Trained>;
345
346 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
347 let n_samples = x.nrows();
348 let n_features = x.ncols();
349
350 let n_quantiles = self.config.n_quantiles.min(n_samples);
352
353 let mut all_quantiles = Vec::with_capacity(n_features);
354 let mut all_references = None;
355
356 for j in 0..n_features {
358 let feature_data = x.column(j).to_owned();
359
360 let data_to_use = if let Some(subsample_size) = self.config.subsample {
362 if n_samples > subsample_size {
363 let mut subsampled = Vec::with_capacity(subsample_size);
365
366 let mut indexed_data: Vec<(usize, Float)> = feature_data
368 .iter()
369 .enumerate()
370 .map(|(i, &val)| (i, val))
371 .collect();
372 indexed_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
373
374 let step = n_samples as Float / subsample_size as Float;
376 for i in 0..subsample_size {
377 let idx = (i as Float * step) as usize;
378 if idx < n_samples {
379 subsampled.push(indexed_data[idx].1);
380 }
381 }
382
383 Array1::from_vec(subsampled)
384 } else {
385 feature_data
386 }
387 } else {
388 feature_data
389 };
390
391 let (quantiles, references) =
392 compute_quantiles(&data_to_use, n_quantiles, self.config.ignore_outliers);
393 all_quantiles.push(quantiles);
394
395 if all_references.is_none() {
396 all_references = Some(references);
397 }
398 }
399
400 Ok(QuantileTransformer {
401 config: self.config,
402 state: PhantomData,
403 quantiles_: Some(all_quantiles),
404 n_quantiles_: Some(n_quantiles),
405 references_: all_references,
406 })
407 }
408}
409
410fn interpolate_value(value: Float, quantiles: &Array1<Float>, references: &Array1<Float>) -> Float {
412 let n = quantiles.len();
413
414 if value <= quantiles[0] {
416 return references[0];
417 }
418 if value >= quantiles[n - 1] {
419 return references[n - 1];
420 }
421
422 let mut left = 0;
424 let mut right = n - 1;
425
426 while left < right - 1 {
427 let mid = (left + right) / 2;
428 if value < quantiles[mid] {
429 right = mid;
430 } else {
431 left = mid;
432 }
433 }
434
435 let x0 = quantiles[left];
437 let x1 = quantiles[right];
438 let y0 = references[left];
439 let y1 = references[right];
440
441 if (x1 - x0).abs() < Float::EPSILON {
442 y0
443 } else {
444 y0 + (value - x0) * (y1 - y0) / (x1 - x0)
445 }
446}
447
448impl Transform<Array2<Float>, Array2<Float>> for QuantileTransformer<Trained> {
449 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
450 let n_samples = x.nrows();
451 let n_features = x.ncols();
452 let quantiles = self.quantiles_.as_ref().unwrap();
453 let references = self.references_.as_ref().unwrap();
454
455 if n_features != quantiles.len() {
456 return Err(SklearsError::InvalidInput(format!(
457 "X has {} features, but QuantileTransformer is expecting {} features",
458 n_features,
459 quantiles.len()
460 )));
461 }
462
463 let mut result = Array2::zeros((n_samples, n_features));
464
465 for i in 0..n_samples {
466 for j in 0..n_features {
467 let value = x[[i, j]];
468 let uniform_value = interpolate_value(value, &quantiles[j], references);
469
470 result[[i, j]] = match self.config.output_distribution {
471 QuantileOutput::Uniform => uniform_value,
472 QuantileOutput::Normal => uniform_to_normal(uniform_value, self.config.clip),
473 };
474 }
475 }
476
477 Ok(result)
478 }
479}
480
481impl Transform<Array1<Float>, Array1<Float>> for QuantileTransformer<Trained> {
482 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
483 let x_2d = x.clone().insert_axis(Axis(1));
485 let result_2d = self.transform(&x_2d)?;
486 Ok(result_2d.column(0).to_owned())
487 }
488}
489
490impl QuantileTransformer<Trained> {
491 pub fn quantiles(&self) -> &Vec<Array1<Float>> {
493 self.quantiles_.as_ref().unwrap()
494 }
495
496 pub fn n_quantiles(&self) -> usize {
498 self.n_quantiles_.unwrap()
499 }
500
501 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
503 let n_samples = x.nrows();
504 let n_features = x.ncols();
505 let quantiles = self.quantiles_.as_ref().unwrap();
506 let references = self.references_.as_ref().unwrap();
507
508 if n_features != quantiles.len() {
509 return Err(SklearsError::InvalidInput(format!(
510 "X has {} features, but QuantileTransformer is expecting {} features",
511 n_features,
512 quantiles.len()
513 )));
514 }
515
516 let mut result = Array2::zeros((n_samples, n_features));
517
518 for i in 0..n_samples {
519 for j in 0..n_features {
520 let value = x[[i, j]];
521
522 let uniform_value = match self.config.output_distribution {
523 QuantileOutput::Uniform => value,
524 QuantileOutput::Normal => normal_to_uniform(value, self.config.clip),
525 };
526
527 result[[i, j]] = interpolate_value(uniform_value, references, &quantiles[j]);
529 }
530 }
531
532 Ok(result)
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that column 0 of `values` never decreases from one row to the
    /// next, allowing `slack` of numerical tolerance.
    ///
    /// Rows are counted with `nrows()`; `len()` on a two-dimensional array is the
    /// total element count and only matches the row count for one column.
    fn assert_first_column_non_decreasing(values: &Array2<Float>, slack: Float) {
        for i in 1..values.nrows() {
            assert!(
                values[[i, 0]] >= values[[i - 1, 0]] - slack,
                "Values at index {} ({}) should be >= values at index {} ({})",
                i,
                values[[i, 0]],
                i - 1,
                values[[i - 1, 0]]
            );
        }
    }

    #[test]
    fn test_quantile_transformer_uniform() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
        ];

        let qt = QuantileTransformer::new()
            .n_quantiles(10)
            .output_distribution(QuantileOutput::Uniform)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();

        // Uniform output must stay inside the unit interval.
        assert!(x_transformed.iter().all(|&v| (0.0..=1.0).contains(&v)));

        // Sorted input must produce sorted output.
        assert_first_column_non_decreasing(&x_transformed, 0.0);
    }

    #[test]
    fn test_quantile_transformer_normal() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let qt = QuantileTransformer::new()
            .n_quantiles(6)
            .output_distribution(QuantileOutput::Normal)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();

        assert_first_column_non_decreasing(&x_transformed, 1e-10);

        let min_val = x_transformed.iter().copied().fold(Float::INFINITY, Float::min);
        let max_val = x_transformed
            .iter()
            .copied()
            .fold(Float::NEG_INFINITY, Float::max);

        assert!(
            min_val > -5.0 && max_val < 5.0,
            "Normal transformed values should be in reasonable range, got [{}, {}]",
            min_val,
            max_val
        );
    }

    #[test]
    fn test_quantile_transformer_multivariate() {
        let x = array![
            [0.0, 10.0],
            [1.0, 20.0],
            [2.0, 30.0],
            [3.0, 40.0],
            [4.0, 50.0],
        ];

        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();

        assert_eq!(x_transformed.ncols(), 2);
        assert!(x_transformed.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_quantile_transformer_inverse() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0]];

        let qt = QuantileTransformer::new()
            .n_quantiles(5)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();
        let x_inverse = qt.inverse_transform(&x_transformed).unwrap();

        // A round trip through the transformer must reproduce the input.
        for i in 0..x.nrows() {
            assert_abs_diff_eq!(x[[i, 0]], x_inverse[[i, 0]], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantile_transformer_edge_cases() {
        // Heavily tied data yields repeated quantile landmarks.
        let x = array![[0.0], [0.0], [1.0], [1.0], [1.0]];

        let qt = QuantileTransformer::new()
            .n_quantiles(3)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();

        assert!(x_transformed.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_interpolate_value() {
        let quantiles = array![0.0, 1.0, 2.0, 3.0];
        let references = array![0.0, 0.33, 0.67, 1.0];

        // Exact landmarks at both ends.
        assert_abs_diff_eq!(interpolate_value(0.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(3.0, &quantiles, &references), 1.0);

        // Midpoints between landmarks.
        assert_abs_diff_eq!(
            interpolate_value(0.5, &quantiles, &references),
            0.165,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            interpolate_value(1.5, &quantiles, &references),
            0.5,
            epsilon = 1e-3
        );

        // Values outside the fitted range saturate.
        assert_abs_diff_eq!(interpolate_value(-1.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(4.0, &quantiles, &references), 1.0);
    }

    #[test]
    fn test_uniform_to_normal() {
        assert_abs_diff_eq!(uniform_to_normal(0.5, true), 0.0, epsilon = 1e-4);

        let values = [0.1, 0.3, 0.5, 0.7, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| uniform_to_normal(v, true)).collect();

        assert!(transformed.windows(2).all(|pair| pair[1] > pair[0]));
    }

    #[test]
    fn test_enhanced_quantile_transformer_with_clipping() {
        // The first and last rows are extreme outliers.
        let x = array![[-10.0], [0.0], [1.0], [2.0], [3.0], [100.0]];

        let qt = QuantileTransformer::new()
            .n_quantiles(6)
            .output_distribution(QuantileOutput::Normal)
            .clip(true)
            .fit(&x, &())
            .unwrap();

        let x_transformed = qt.transform(&x).unwrap();

        for value in x_transformed.iter() {
            assert!(value.is_finite(), "All transformed values should be finite");
        }

        assert_first_column_non_decreasing(&x_transformed, 0.0);
    }

    #[test]
    fn test_quantile_transformer_with_outlier_filtering() {
        // 100 regular values followed by two extreme outliers.
        let mut values: Vec<Float> = (0..100).map(|i| i as Float).collect();
        values.push(-1000.0);
        values.push(1000.0);

        let x = Array2::from_shape_vec((102, 1), values).unwrap();

        let qt_filtered = QuantileTransformer::new()
            .n_quantiles(50)
            .ignore_outliers(Some((0.01, 0.99)))
            .fit(&x, &())
            .unwrap();

        let qt_unfiltered = QuantileTransformer::new()
            .n_quantiles(50)
            .fit(&x, &())
            .unwrap();

        let test_value = array![[50.0]];
        let result_filtered = qt_filtered.transform(&test_value).unwrap();
        let result_unfiltered = qt_unfiltered.transform(&test_value).unwrap();

        assert!(result_filtered.iter().all(|&v| v.is_finite()));
        assert!(result_unfiltered.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_improved_inverse_error_function() {
        assert_abs_diff_eq!(erfinv_accurate(0.0), 0.0, epsilon = 1e-10);

        // The function must be odd.
        let test_val = 0.5;
        assert_abs_diff_eq!(
            erfinv_accurate(test_val),
            -erfinv_accurate(-test_val),
            epsilon = 1e-10
        );

        // The function must be strictly increasing.
        let values = [-0.9, -0.5, 0.0, 0.5, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| erfinv_accurate(v)).collect();

        assert!(transformed.windows(2).all(|pair| pair[1] > pair[0]));
    }

    #[test]
    fn test_normal_to_uniform_conversion() {
        let uniform_values = [0.1, 0.3, 0.5, 0.7, 0.9];

        for &uniform_val in &uniform_values {
            let normal_val = uniform_to_normal(uniform_val, true);
            let recovered_uniform = normal_to_uniform(normal_val, true);

            assert_abs_diff_eq!(uniform_val, recovered_uniform, epsilon = 1e-3);
        }

        assert_abs_diff_eq!(normal_to_uniform(0.0, true), 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_builder_methods() {
        let qt = QuantileTransformer::new()
            .n_quantiles(500)
            .output_distribution(QuantileOutput::Normal)
            .subsample(Some(1000))
            .clip(false)
            .ignore_outliers(Some((0.05, 0.95)));

        assert_eq!(qt.config.n_quantiles, 500);
        assert_eq!(qt.config.output_distribution, QuantileOutput::Normal);
        assert_eq!(qt.config.subsample, Some(1000));
        assert!(!qt.config.clip);
        assert_eq!(qt.config.ignore_outliers, Some((0.05, 0.95)));
    }

    #[test]
    #[should_panic(expected = "Outlier range must be")]
    fn test_invalid_outlier_range() {
        // low > high must be rejected by the builder.
        let _ = QuantileTransformer::new().ignore_outliers(Some((0.9, 0.1)));
    }
}