1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use std::marker::PhantomData;
8
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14
/// Target distribution of the values produced by a `QuantileTransformer`.
///
/// The type is a plain, field-less tag, so it also implements `Eq` and `Hash`
/// and can be used as a map key or in exhaustive equality checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantileOutput {
    /// Emit the interpolated reference level itself, which lies in `[0, 1]`.
    Uniform,
    /// Convert the interpolated reference level to a normal score through the
    /// inverse error function.
    Normal,
}
23
24#[derive(Debug, Clone)]
26pub struct QuantileTransformerConfig {
27 pub n_quantiles: usize,
29 pub output_distribution: QuantileOutput,
31 pub subsample: Option<usize>,
33 pub random_state: Option<u64>,
35 pub copy: bool,
37 pub clip: bool,
39 pub ignore_outliers: Option<(Float, Float)>,
41}
42
43impl Default for QuantileTransformerConfig {
44 fn default() -> Self {
45 Self {
46 n_quantiles: 1000,
47 output_distribution: QuantileOutput::Uniform,
48 subsample: Some(100_000),
49 random_state: None,
50 copy: true,
51 clip: true,
52 ignore_outliers: None,
53 }
54 }
55}
56
57pub struct QuantileTransformer<State = Untrained> {
60 config: QuantileTransformerConfig,
61 state: PhantomData<State>,
62 quantiles_: Option<Vec<Array1<Float>>>,
64 n_quantiles_: Option<usize>,
66 references_: Option<Array1<Float>>,
68}
69
70impl QuantileTransformer<Untrained> {
71 pub fn new() -> Self {
73 Self {
74 config: QuantileTransformerConfig::default(),
75 state: PhantomData,
76 quantiles_: None,
77 n_quantiles_: None,
78 references_: None,
79 }
80 }
81
82 pub fn n_quantiles(mut self, n_quantiles: usize) -> Result<Self> {
84 if n_quantiles < 2 {
85 return Err(SklearsError::InvalidParameter {
86 name: "n_quantiles".to_string(),
87 reason: "must be at least 2".to_string(),
88 });
89 }
90 self.config.n_quantiles = n_quantiles;
91 Ok(self)
92 }
93
94 pub fn output_distribution(mut self, output_distribution: QuantileOutput) -> Self {
96 self.config.output_distribution = output_distribution;
97 self
98 }
99
100 pub fn subsample(mut self, subsample: Option<usize>) -> Self {
102 self.config.subsample = subsample;
103 self
104 }
105
106 pub fn clip(mut self, clip: bool) -> Self {
108 self.config.clip = clip;
109 self
110 }
111
112 pub fn ignore_outliers(mut self, range: Option<(Float, Float)>) -> Self {
114 if let Some((low, high)) = range {
115 assert!(
116 low >= 0.0 && low < high && high <= 1.0,
117 "Outlier range must be (low, high) where 0 <= low < high <= 1"
118 );
119 }
120 self.config.ignore_outliers = range;
121 self
122 }
123}
124
125impl Default for QuantileTransformer<Untrained> {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl Estimator for QuantileTransformer<Untrained> {
132 type Config = QuantileTransformerConfig;
133 type Error = SklearsError;
134 type Float = Float;
135
136 fn config(&self) -> &Self::Config {
137 &self.config
138 }
139}
140
141impl Estimator for QuantileTransformer<Trained> {
142 type Config = QuantileTransformerConfig;
143 type Error = SklearsError;
144 type Float = Float;
145
146 fn config(&self) -> &Self::Config {
147 &self.config
148 }
149}
150
151fn compute_quantiles(
153 data: &Array1<Float>,
154 n_quantiles: usize,
155 ignore_outliers: Option<(Float, Float)>,
156) -> (Array1<Float>, Array1<Float>) {
157 let mut sorted_data = data.to_vec();
158 sorted_data.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
159
160 let n_samples = sorted_data.len();
161
162 let (start_idx, end_idx) = if let Some((low_quantile, high_quantile)) = ignore_outliers {
164 let start = ((low_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
165 let end = ((high_quantile * (n_samples - 1) as Float) as usize).min(n_samples - 1);
166 (start, end + 1)
167 } else {
168 (0, n_samples)
169 };
170
171 let filtered_data = &sorted_data[start_idx..end_idx];
172 let filtered_n_samples = filtered_data.len();
173 let n_quantiles = n_quantiles.min(filtered_n_samples);
174
175 let mut quantiles = Vec::with_capacity(n_quantiles);
176 let mut references = Vec::with_capacity(n_quantiles);
177
178 for i in 0..n_quantiles {
179 let quantile = i as Float / (n_quantiles - 1) as Float;
180 let idx = (quantile * (filtered_n_samples - 1) as Float) as usize;
181
182 quantiles.push(filtered_data[idx]);
183 references.push(quantile);
184 }
185
186 (Array1::from_vec(quantiles), Array1::from_vec(references))
187}
188
189fn erfinv(x: Float) -> Float {
192 if x.abs() >= 1.0 {
193 return if x > 0.0 {
194 Float::INFINITY
195 } else {
196 Float::NEG_INFINITY
197 };
198 }
199
200 if x == 0.0 {
201 return 0.0;
202 }
203
204 let sign = if x > 0.0 { 1.0 } else { -1.0 };
205 let x = x.abs();
206
207 let a = 0.147;
209 let ln_term = (1.0 - x * x).ln();
210 let term1 = 2.0 / (std::f64::consts::PI * a) + ln_term / 2.0;
211 let term2 = ln_term / a;
212
213 let result = (term1 * term1 - term2).sqrt() - term1;
214 sign * result.sqrt()
215}
216
217fn erfinv_accurate(x: Float) -> Float {
220 if x.abs() >= 1.0 {
221 return if x > 0.0 {
222 Float::INFINITY
223 } else {
224 Float::NEG_INFINITY
225 };
226 }
227
228 if x == 0.0 {
229 return 0.0;
230 }
231
232 let sign = x.signum();
234 let abs_x = x.abs();
235
236 let p = (1.0 + abs_x) / 2.0;
242
243 let result = if p > 0.5 {
245 let q = p - 0.5;
247 let r = q * q;
248
249 let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
250 * r
251 + 138.357751867269)
252 * r
253 - 30.6647980661472)
254 * r
255 + 2.50662827745924)
256 * q;
257
258 let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
259 * r
260 + 66.8013118877197)
261 * r
262 - 13.2806815528857)
263 * r
264 + 1.0;
265
266 numerator / denominator
267 } else {
268 let q = 0.5 - p;
270 let r = q * q;
271
272 let numerator = (((((-39.6968302866538 * r + 220.946098424521) * r - 275.928510446969)
273 * r
274 + 138.357751867269)
275 * r
276 - 30.6647980661472)
277 * r
278 + 2.50662827745924)
279 * q;
280
281 let denominator = ((((-54.4760987982241 * r + 161.585836858041) * r - 155.698979859887)
282 * r
283 + 66.8013118877197)
284 * r
285 - 13.2806815528857)
286 * r
287 + 1.0;
288
289 -numerator / denominator
290 };
291
292 sign * result / std::f64::consts::SQRT_2
294}
295
296fn uniform_to_normal(uniform_value: Float, clip: bool) -> Float {
298 let clipped = if clip {
299 uniform_value.clamp(1e-7, 1.0 - 1e-7)
301 } else {
302 uniform_value
303 };
304
305 std::f64::consts::SQRT_2 * erfinv_accurate(2.0 * clipped - 1.0)
307}
308
309fn erf_accurate(x: Float) -> Float {
311 if x == 0.0 {
312 return 0.0;
313 }
314
315 let sign = x.signum();
316 let abs_x = x.abs();
317
318 let a1 = 0.254829592;
320 let a2 = -0.284496736;
321 let a3 = 1.421413741;
322 let a4 = -1.453152027;
323 let a5 = 1.061405429;
324 let p = 0.3275911;
325
326 let t = 1.0 / (1.0 + p * abs_x);
327 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-abs_x * abs_x).exp();
328
329 sign * y
330}
331
332fn normal_to_uniform(normal_value: Float, clip: bool) -> Float {
334 let erf_input = normal_value / std::f64::consts::SQRT_2;
336 let erf_val = erf_accurate(erf_input);
337 let uniform_val = (1.0 + erf_val) / 2.0;
338
339 if clip {
340 uniform_val.clamp(1e-7, 1.0 - 1e-7)
341 } else {
342 uniform_val
343 }
344}
345
346impl Fit<Array2<Float>, ()> for QuantileTransformer<Untrained> {
347 type Fitted = QuantileTransformer<Trained>;
348
349 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
350 let n_samples = x.nrows();
351 let n_features = x.ncols();
352
353 let n_quantiles = self.config.n_quantiles.min(n_samples);
355
356 let mut all_quantiles = Vec::with_capacity(n_features);
357 let mut all_references = None;
358
359 for j in 0..n_features {
361 let feature_data = x.column(j).to_owned();
362
363 let data_to_use = if let Some(subsample_size) = self.config.subsample {
365 if n_samples > subsample_size {
366 let mut subsampled = Vec::with_capacity(subsample_size);
368
369 let mut indexed_data: Vec<(usize, Float)> = feature_data
371 .iter()
372 .enumerate()
373 .map(|(i, &val)| (i, val))
374 .collect();
375 indexed_data
376 .sort_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"));
377
378 let step = n_samples as Float / subsample_size as Float;
380 for i in 0..subsample_size {
381 let idx = (i as Float * step) as usize;
382 if idx < n_samples {
383 subsampled.push(indexed_data[idx].1);
384 }
385 }
386
387 Array1::from_vec(subsampled)
388 } else {
389 feature_data
390 }
391 } else {
392 feature_data
393 };
394
395 let (quantiles, references) =
396 compute_quantiles(&data_to_use, n_quantiles, self.config.ignore_outliers);
397 all_quantiles.push(quantiles);
398
399 if all_references.is_none() {
400 all_references = Some(references);
401 }
402 }
403
404 Ok(QuantileTransformer {
405 config: self.config,
406 state: PhantomData,
407 quantiles_: Some(all_quantiles),
408 n_quantiles_: Some(n_quantiles),
409 references_: all_references,
410 })
411 }
412}
413
414fn interpolate_value(value: Float, quantiles: &Array1<Float>, references: &Array1<Float>) -> Float {
416 let n = quantiles.len();
417
418 if value <= quantiles[0] {
420 return references[0];
421 }
422 if value >= quantiles[n - 1] {
423 return references[n - 1];
424 }
425
426 let mut left = 0;
428 let mut right = n - 1;
429
430 while left < right - 1 {
431 let mid = (left + right) / 2;
432 if value < quantiles[mid] {
433 right = mid;
434 } else {
435 left = mid;
436 }
437 }
438
439 let x0 = quantiles[left];
441 let x1 = quantiles[right];
442 let y0 = references[left];
443 let y1 = references[right];
444
445 if (x1 - x0).abs() < Float::EPSILON {
446 y0
447 } else {
448 y0 + (value - x0) * (y1 - y0) / (x1 - x0)
449 }
450}
451
452impl Transform<Array2<Float>, Array2<Float>> for QuantileTransformer<Trained> {
453 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
454 let n_samples = x.nrows();
455 let n_features = x.ncols();
456 let quantiles = self.quantiles_.as_ref().expect("operation should succeed");
457 let references = self.references_.as_ref().expect("operation should succeed");
458
459 if n_features != quantiles.len() {
460 return Err(SklearsError::InvalidInput(format!(
461 "X has {} features, but QuantileTransformer is expecting {} features",
462 n_features,
463 quantiles.len()
464 )));
465 }
466
467 let mut result = Array2::zeros((n_samples, n_features));
468
469 for i in 0..n_samples {
470 for j in 0..n_features {
471 let value = x[[i, j]];
472 let uniform_value = interpolate_value(value, &quantiles[j], references);
473
474 result[[i, j]] = match self.config.output_distribution {
475 QuantileOutput::Uniform => uniform_value,
476 QuantileOutput::Normal => uniform_to_normal(uniform_value, self.config.clip),
477 };
478 }
479 }
480
481 Ok(result)
482 }
483}
484
485impl Transform<Array1<Float>, Array1<Float>> for QuantileTransformer<Trained> {
486 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
487 let x_2d = x.clone().insert_axis(Axis(1));
489 let result_2d = self.transform(&x_2d)?;
490 Ok(result_2d.column(0).to_owned())
491 }
492}
493
494impl QuantileTransformer<Trained> {
495 pub fn quantiles(&self) -> &Vec<Array1<Float>> {
497 self.quantiles_.as_ref().expect("operation should succeed")
498 }
499
500 pub fn n_quantiles(&self) -> usize {
502 self.n_quantiles_.expect("operation should succeed")
503 }
504
505 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
507 let n_samples = x.nrows();
508 let n_features = x.ncols();
509 let quantiles = self.quantiles_.as_ref().expect("operation should succeed");
510 let references = self.references_.as_ref().expect("operation should succeed");
511
512 if n_features != quantiles.len() {
513 return Err(SklearsError::InvalidInput(format!(
514 "X has {} features, but QuantileTransformer is expecting {} features",
515 n_features,
516 quantiles.len()
517 )));
518 }
519
520 let mut result = Array2::zeros((n_samples, n_features));
521
522 for i in 0..n_samples {
523 for j in 0..n_features {
524 let value = x[[i, j]];
525
526 let uniform_value = match self.config.output_distribution {
527 QuantileOutput::Uniform => value,
528 QuantileOutput::Normal => normal_to_uniform(value, self.config.clip),
529 };
530
531 result[[i, j]] = interpolate_value(uniform_value, references, &quantiles[j]);
533 }
534 }
535
536 Ok(result)
537 }
538}
539
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Fits a transformer with the given landmark count and output
    /// distribution, leaving every other setting at its default.
    fn fit_transformer(
        x: &Array2<Float>,
        n_quantiles: usize,
        output: QuantileOutput,
    ) -> QuantileTransformer<Trained> {
        QuantileTransformer::new()
            .n_quantiles(n_quantiles)
            .expect("valid parameter")
            .output_distribution(output)
            .fit(x, &())
            .expect("operation should succeed")
    }

    /// Asserts that every entry lies in the closed unit interval.
    fn assert_in_unit_interval(values: &Array2<Float>) {
        for value in values.iter() {
            assert!(*value >= 0.0 && *value <= 1.0);
        }
    }

    /// Asserts that the first column never decreases from one row to the next.
    fn assert_first_column_non_decreasing(values: &Array2<Float>) {
        for i in 1..values.nrows() {
            assert!(values[[i, 0]] >= values[[i - 1, 0]]);
        }
    }

    #[test]
    fn test_quantile_transformer_uniform() {
        let x = array![
            [0.0],
            [1.0],
            [2.0],
            [3.0],
            [4.0],
            [5.0],
            [6.0],
            [7.0],
            [8.0],
            [9.0],
        ];

        let qt = fit_transformer(&x, 10, QuantileOutput::Uniform);
        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Uniform output stays in [0, 1] and preserves the input ordering.
        assert_in_unit_interval(&x_transformed);
        assert_first_column_non_decreasing(&x_transformed);
    }

    #[test]
    fn test_quantile_transformer_normal() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]];

        let qt = fit_transformer(&x, 6, QuantileOutput::Normal);
        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        // Ordering is preserved up to a small numerical tolerance.
        for i in 1..x_transformed.nrows() {
            assert!(
                x_transformed[[i, 0]] >= x_transformed[[i - 1, 0]] - 1e-10,
                "Values at index {} ({}) should be >= values at index {} ({})",
                i,
                x_transformed[[i, 0]],
                i - 1,
                x_transformed[[i - 1, 0]]
            );
        }

        let mut min_val = Float::INFINITY;
        let mut max_val = Float::NEG_INFINITY;
        for &value in x_transformed.iter() {
            min_val = min_val.min(value);
            max_val = max_val.max(value);
        }

        assert!(
            min_val > -5.0 && max_val < 5.0,
            "Normal transformed values should be in reasonable range, got [{}, {}]",
            min_val,
            max_val
        );
    }

    #[test]
    fn test_quantile_transformer_multivariate() {
        let x = array![
            [0.0, 10.0],
            [1.0, 20.0],
            [2.0, 30.0],
            [3.0, 40.0],
            [4.0, 50.0],
        ];

        let qt = fit_transformer(&x, 5, QuantileOutput::Uniform);
        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        assert_eq!(x_transformed.ncols(), 2);
        assert_in_unit_interval(&x_transformed);
    }

    #[test]
    fn test_quantile_transformer_inverse() {
        let x = array![[0.0], [1.0], [2.0], [3.0], [4.0]];

        let qt = fit_transformer(&x, 5, QuantileOutput::Uniform);
        let x_transformed = qt.transform(&x).expect("transformation should succeed");
        let x_inverse = qt
            .inverse_transform(&x_transformed)
            .expect("operation should succeed");

        // A round trip through the transform recovers the input.
        for (original, recovered) in x.column(0).iter().zip(x_inverse.column(0).iter()) {
            assert_abs_diff_eq!(*original, *recovered, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantile_transformer_edge_cases() {
        // Heavily tied data must still produce finite output.
        let x = array![[0.0], [0.0], [1.0], [1.0], [1.0]];

        let qt = fit_transformer(&x, 3, QuantileOutput::Uniform);
        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        assert!(x_transformed.iter().all(|value| value.is_finite()));
    }

    #[test]
    fn test_interpolate_value() {
        let quantiles = array![0.0, 1.0, 2.0, 3.0];
        let references = array![0.0, 0.33, 0.67, 1.0];

        // Exact hits on the first and last landmarks.
        assert_abs_diff_eq!(interpolate_value(0.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(3.0, &quantiles, &references), 1.0);

        // Interior points are interpolated linearly.
        assert_abs_diff_eq!(
            interpolate_value(0.5, &quantiles, &references),
            0.165,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            interpolate_value(1.5, &quantiles, &references),
            0.5,
            epsilon = 1e-3
        );

        // Values outside the landmark range saturate at the end references.
        assert_abs_diff_eq!(interpolate_value(-1.0, &quantiles, &references), 0.0);
        assert_abs_diff_eq!(interpolate_value(4.0, &quantiles, &references), 1.0);
    }

    #[test]
    fn test_uniform_to_normal() {
        assert_abs_diff_eq!(uniform_to_normal(0.5, true), 0.0, epsilon = 1e-4);

        let values = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| uniform_to_normal(v, true)).collect();

        for pair in transformed.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    #[test]
    fn test_enhanced_quantile_transformer_with_clipping() {
        let x = array![[-10.0], [0.0], [1.0], [2.0], [3.0], [100.0]];

        let qt = QuantileTransformer::new()
            .n_quantiles(6)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Normal)
            .clip(true)
            .fit(&x, &())
            .expect("operation should succeed");

        let x_transformed = qt.transform(&x).expect("transformation should succeed");

        for value in x_transformed.iter() {
            assert!(value.is_finite(), "All transformed values should be finite");
        }

        assert_first_column_non_decreasing(&x_transformed);
    }

    #[test]
    fn test_quantile_transformer_with_outlier_filtering() {
        // 0..100 followed by one extreme value at each end.
        let values: Vec<Float> = (0..100)
            .map(|i| i as Float)
            .chain(vec![-1000.0, 1000.0])
            .collect();

        let x = Array2::from_shape_vec((102, 1), values)
            .expect("shape and data length should match");

        let qt_filtered = QuantileTransformer::new()
            .n_quantiles(50)
            .expect("valid parameter")
            .ignore_outliers(Some((0.01, 0.99)))
            .fit(&x, &())
            .expect("operation should succeed");

        let qt_unfiltered = QuantileTransformer::new()
            .n_quantiles(50)
            .expect("valid parameter")
            .fit(&x, &())
            .expect("operation should succeed");

        let test_value = array![[50.0]];
        let result_filtered = qt_filtered
            .transform(&test_value)
            .expect("transformation should succeed");
        let result_unfiltered = qt_unfiltered
            .transform(&test_value)
            .expect("transformation should succeed");

        assert!(result_filtered.iter().all(|&v| v.is_finite()));
        assert!(result_unfiltered.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_improved_inverse_error_function() {
        assert_abs_diff_eq!(erfinv_accurate(0.0), 0.0, epsilon = 1e-10);

        // Odd symmetry.
        let test_val = 0.5;
        assert_abs_diff_eq!(
            erfinv_accurate(test_val),
            -erfinv_accurate(-test_val),
            epsilon = 1e-10
        );

        // Strictly increasing on the sampled points.
        let values = vec![-0.9, -0.5, 0.0, 0.5, 0.9];
        let transformed: Vec<Float> = values.iter().map(|&v| erfinv_accurate(v)).collect();

        for pair in transformed.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    #[test]
    fn test_normal_to_uniform_conversion() {
        let uniform_values = vec![0.1, 0.3, 0.5, 0.7, 0.9];

        // Converting to a normal score and back recovers the level.
        for &uniform_val in &uniform_values {
            let normal_val = uniform_to_normal(uniform_val, true);
            let recovered_uniform = normal_to_uniform(normal_val, true);

            assert_abs_diff_eq!(uniform_val, recovered_uniform, epsilon = 1e-3);
        }

        assert_abs_diff_eq!(normal_to_uniform(0.0, true), 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_builder_methods() {
        let qt = QuantileTransformer::new()
            .n_quantiles(500)
            .expect("valid parameter")
            .output_distribution(QuantileOutput::Normal)
            .subsample(Some(1000))
            .clip(false)
            .ignore_outliers(Some((0.05, 0.95)));

        let config = &qt.config;
        assert_eq!(config.n_quantiles, 500);
        assert_eq!(config.output_distribution, QuantileOutput::Normal);
        assert_eq!(config.subsample, Some(1000));
        assert!(!config.clip);
        assert_eq!(config.ignore_outliers, Some((0.05, 0.95)));
    }

    #[test]
    #[should_panic(expected = "Outlier range must be")]
    fn test_invalid_outlier_range() {
        // low > high must be rejected by the builder.
        QuantileTransformer::new().ignore_outliers(Some((0.9, 0.1)));
    }
}