scouter_types/binning/
quantile.rs1use crate::error::TypeError;
2use ndarray::ArrayView1;
3use num_traits::{Float, FromPrimitive};
4use pyo3::{pyclass, pymethods, PyResult};
5use serde::{Deserialize, Serialize};
6
7#[pyclass]
8#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
9pub struct QuantileBinning {
10 #[pyo3(get, set)]
11 pub num_bins: usize,
12}
13
14#[pymethods]
15impl QuantileBinning {
16 #[new]
17 #[pyo3(signature = (num_bins=10))]
18 pub fn new(num_bins: usize) -> PyResult<Self> {
19 Ok(QuantileBinning { num_bins })
20 }
21}
22
23impl Default for QuantileBinning {
24 fn default() -> Self {
25 QuantileBinning { num_bins: 10 }
26 }
27}
28
29impl QuantileBinning {
30 pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
55 where
56 F: Float + FromPrimitive,
57 {
58 if self.num_bins < 2 {
59 return Err(TypeError::InvalidParameterError(
60 "num_bins must be at least 2".to_string(),
61 ));
62 }
63
64 let mut data: Vec<F> = arr.to_vec();
65 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
66
67 let mut edges = Vec::new();
68 let n = data.len();
69
70 for i in 1..self.num_bins {
71 let p = i as f64 / self.num_bins as f64;
72
73 let m = 1.0 - p;
76
77 let np_plus_m = (n as f64) * p + m;
79
80 let j = np_plus_m.floor() as usize;
82
83 let h = np_plus_m - (j as f64);
85
86 let j_zero_indexed = if j > 0 { j - 1 } else { 0 };
88 let j_plus_1_zero_indexed = std::cmp::min(j_zero_indexed + 1, n - 1);
89
90 let one_minus_h = F::from_f64(1.0 - h).unwrap();
93 let h_f = F::from_f64(h).unwrap();
94
95 let quantile = one_minus_h * data[j_zero_indexed] + h_f * data[j_plus_1_zero_indexed];
96
97 edges.push(quantile);
98 }
99
100 Ok(edges)
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;

    /// Fewer than two bins must be rejected with a descriptive error.
    #[test]
    fn test_invalid_num_bins() {
        let binning = QuantileBinning { num_bins: 1 };
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let result = binning.compute_edges(&data.view());

        assert!(result.is_err());
        if let TypeError::InvalidParameterError(msg) = result.unwrap_err() {
            assert_eq!(msg, "num_bins must be at least 2");
        } else {
            panic!("Expected InvalidParameterError");
        }
    }

    /// Quartiles of 1..=8 follow the linear-interpolation rule.
    #[test]
    fn test_quartiles_simple_case() {
        let binning = QuantileBinning { num_bins: 4 };
        let data = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let edges = binning.compute_edges(&data.view()).unwrap();

        assert_eq!(edges.len(), 3);
        assert_abs_diff_eq!(edges[0], 2.75, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[1], 4.5, epsilon = 1e-10);
        assert_abs_diff_eq!(edges[2], 6.25, epsilon = 1e-10);
    }

    /// Input order must not matter: edges come out strictly increasing.
    #[test]
    fn test_unsorted_data_produces_monotonic_edges() {
        let binning = QuantileBinning { num_bins: 5 };
        let data = Array1::from(vec![
            12.0, 8.0, 17.0, 33.0, 123.0, 6.0, 9.23, 123.43, 1.9, 4.0, 11.0, 2.0, 5.6,
        ]);
        let edges = binning.compute_edges(&data.view()).unwrap();

        assert_eq!(edges.len(), 4);

        // Compare each edge with its predecessor.
        for pair in edges.windows(2) {
            assert!(
                pair[1] > pair[0],
                "Quantile edges should be monotonically increasing even with unsorted input: {} > {}",
                pair[1],
                pair[0]
            );
        }
    }
}