scouter_types/binning/
quantile.rs

1use crate::error::TypeError;
2use ndarray::ArrayView1;
3use num_traits::{Float, FromPrimitive};
4use pyo3::{pyclass, pymethods, PyResult};
5use serde::{Deserialize, Serialize};
6
7#[pyclass]
8#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
9pub struct QuantileBinning {
10    #[pyo3(get, set)]
11    pub num_bins: usize,
12}
13
14#[pymethods]
15impl QuantileBinning {
16    #[new]
17    #[pyo3(signature = (num_bins=10))]
18    pub fn new(num_bins: usize) -> PyResult<Self> {
19        Ok(QuantileBinning { num_bins })
20    }
21}
22
23impl Default for QuantileBinning {
24    fn default() -> Self {
25        QuantileBinning { num_bins: 10 }
26    }
27}
28
29impl QuantileBinning {
30    /// Computes quantile edges for binning using the R-7 method (Hyndman & Fan Type 7).
31    ///
32    /// This implementation follows the R-7 quantile definition from:
33    /// Hyndman, R. J. and Fan, Y. (1996) "Sample quantiles in statistical packages,"
34    /// The American Statistician, 50(4), pp. 361-365.
35    ///
36    /// The R-7 method uses the formula:
37    /// - m = 1 - p
38    /// - j = floor(np + m)
39    /// - h = np + m - j
40    /// - Q(p) = (1 - h) × x[j] + h × x[j+1]
41    ///
42    /// This method is the default in many statistical packages, median-unbiased
43    /// quantile estimates that are approximately unbiased for normal distributions.
44    ///
45    /// # Arguments
46    /// * `arr` - Sorted array of data values
47    ///
48    /// # Returns
49    /// * `Ok(Vec<F>)` - Vector of quantile edge values for binning
50    /// * `Err(DriftError)` - If insufficient data points for quantile calculation
51    ///
52    /// # Reference
53    /// PDF: https://www.amherst.edu/media/view/129116/original/Sample+Quantiles.pdf
54    pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
55    where
56        F: Float + FromPrimitive,
57    {
58        if self.num_bins < 2 {
59            return Err(TypeError::InvalidParameterError(
60                "num_bins must be at least 2".to_string(),
61            ));
62        }
63
64        let mut data: Vec<F> = arr.to_vec();
65        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
66
67        let mut edges = Vec::new();
68        let n = data.len();
69
70        for i in 1..self.num_bins {
71            let p = i as f64 / self.num_bins as f64;
72
73            // R-7 Formula Implementation
74            // Step 1: Calculate m (R-7 parameter: m = 1 - p)
75            let m = 1.0 - p;
76
77            // Step 2: Calculate np + m
78            let np_plus_m = (n as f64) * p + m;
79
80            // Step 3: Calculate j = floor(np + m)
81            let j = np_plus_m.floor() as usize;
82
83            // Step 4: Calculate h = np + m - j (fractional part)
84            let h = np_plus_m - (j as f64);
85
86            // Step 5: Convert j from 1-indexed (paper) to 0-indexed
87            let j_zero_indexed = if j > 0 { j - 1 } else { 0 };
88            let j_plus_1_zero_indexed = std::cmp::min(j_zero_indexed + 1, n - 1);
89
90            // Step 6: Apply R-7 interpolation formula
91            // Q(p) = (1 - h) × x[j] + h × x[j+1]
92            let one_minus_h = F::from_f64(1.0 - h).unwrap();
93            let h_f = F::from_f64(h).unwrap();
94
95            let quantile = one_minus_h * data[j_zero_indexed] + h_f * data[j_plus_1_zero_indexed];
96
97            edges.push(quantile);
98        }
99
100        Ok(edges)
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;

    /// A bin count below two must be rejected with `InvalidParameterError`.
    #[test]
    fn test_invalid_num_bins() {
        let values = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let outcome = QuantileBinning { num_bins: 1 }.compute_edges(&values.view());

        assert!(outcome.is_err());
        if let TypeError::InvalidParameterError(msg) = outcome.unwrap_err() {
            assert_eq!(msg, "num_bins must be at least 2");
        } else {
            panic!("Expected InvalidParameterError");
        }
    }

    /// Quartile edges of 1..=8 under the R-7 definition.
    #[test]
    fn test_quartiles_simple_case() {
        let values = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let edges = QuantileBinning { num_bins: 4 }
            .compute_edges(&values.view())
            .unwrap();

        // Four bins are separated by three interior edges.
        assert_eq!(edges.len(), 3);

        // R-7 with 8 data points: Q1 (p=0.25), Q2 (p=0.50), Q3 (p=0.75).
        let expected = [2.75, 4.5, 6.25];
        for (edge, want) in edges.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*edge, *want, epsilon = 1e-10);
        }
    }

    /// Input order must not matter: edges come out strictly increasing.
    #[test]
    fn test_unsorted_data_produces_monotonic_edges() {
        // Deliberately unsorted data
        let values = Array1::from(vec![
            12.0, 8.0, 17.0, 33.0, 123.0, 6.0, 9.23, 123.43, 1.9, 4.0, 11.0, 2.0, 5.6,
        ]);
        let edges = QuantileBinning { num_bins: 5 }
            .compute_edges(&values.view())
            .unwrap();

        assert_eq!(edges.len(), 4);

        // Each adjacent pair of edges must be strictly increasing.
        for pair in edges.windows(2) {
            assert!(
                pair[1] > pair[0],
                "Quantile edges should be monotonically increasing even with unsorted input: {} > {}",
                pair[1],
                pair[0]
            );
        }
    }
}