scouter_types/binning/
equal_width.rs

1use crate::error::TypeError;
2use ndarray::ArrayView1;
3use ndarray_stats::QuantileExt;
4use num_traits::{Float, FromPrimitive};
5use pyo3::prelude::PyAnyMethods;
6use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, PyAny, PyResult, Python};
7use serde::{Deserialize, Serialize};
8
9#[pyclass]
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub struct Manual {
12    #[pyo3(get, set)]
13    num_bins: usize,
14}
15
16#[pymethods]
17impl Manual {
18    #[new]
19    pub fn new(num_bins: usize) -> Self {
20        Manual { num_bins }
21    }
22}
23
24impl Manual {
25    pub fn num_bins(&self) -> usize {
26        self.num_bins
27    }
28}
29
30#[pyclass]
31#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32pub struct SquareRoot;
33
34impl Default for SquareRoot {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[pymethods]
41impl SquareRoot {
42    #[new]
43    pub fn new() -> Self {
44        SquareRoot
45    }
46}
47
48impl SquareRoot {
49    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
50        let n = arr.len() as f64;
51        n.sqrt().ceil() as usize
52    }
53}
54
55#[pyclass]
56#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
57pub struct Sturges;
58
59#[pymethods]
60impl Sturges {
61    #[new]
62    pub fn new() -> Self {
63        Sturges
64    }
65}
66
67impl Default for Sturges {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl Sturges {
74    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
75        let n = arr.len() as f64;
76        (n.log2().ceil() + 1.0) as usize
77    }
78}
79
80#[pyclass]
81#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
82pub struct Rice;
83
84#[pymethods]
85impl Rice {
86    #[new]
87    pub fn new() -> Self {
88        Rice
89    }
90}
91
92impl Default for Rice {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl Rice {
99    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
100        let n = arr.len() as f64;
101        (2.0 * n.powf(1.0 / 3.0)).ceil() as usize
102    }
103}
104#[pyclass]
105#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
106pub struct Doane;
107
108#[pymethods]
109impl Doane {
110    #[new]
111    pub fn new() -> Self {
112        Doane
113    }
114}
115
116impl Default for Doane {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl Doane {
123    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
124    where
125        F: Float,
126    {
127        let n = arr.len() as f64;
128        let data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
129        let mu = data.iter().sum::<f64>() / n;
130        let m2 = data.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / n;
131        let m3 = data.iter().map(|&x| (x - mu).powi(3)).sum::<f64>() / n;
132        let g1 = m3 / m2.powf(3.0 / 2.0);
133        let sigma_g1 = ((6.0 * (n - 2.0)) / ((n + 1.0) * (n + 3.0))).sqrt();
134        let k = 1.0 + n.log2() + (1.0 + g1.abs() / sigma_g1).log2();
135        k.round() as usize
136    }
137}
138#[pyclass]
139#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
140pub struct Scott;
141
142#[pymethods]
143impl Scott {
144    #[new]
145    pub fn new() -> Self {
146        Scott
147    }
148}
149
150impl Default for Scott {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl Scott {
157    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
158    where
159        F: Float + FromPrimitive,
160    {
161        let n = arr.len() as f64;
162
163        let std_dev = arr.std(F::from(0.0).unwrap()).to_f64().unwrap();
164
165        let bin_width = 3.49 * std_dev * n.powf(-1.0 / 3.0);
166
167        let min_val = *arr.min().unwrap();
168        let max_val = *arr.max().unwrap();
169        let range = (max_val - min_val).to_f64().unwrap();
170
171        (range / bin_width).ceil() as usize
172    }
173}
174#[pyclass]
175#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
176pub struct TerrellScott;
177
178#[pymethods]
179impl TerrellScott {
180    #[new]
181    pub fn new() -> Self {
182        TerrellScott
183    }
184}
185
186impl Default for TerrellScott {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl TerrellScott {
193    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
194        let n = arr.len() as f64;
195        (2.0 * n).powf(1.0 / 3.0).round() as usize
196    }
197}
198
199#[pyclass]
200#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
201pub struct FreedmanDiaconis;
202
203impl Default for FreedmanDiaconis {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209#[pymethods]
210impl FreedmanDiaconis {
211    #[new]
212    pub fn new() -> Self {
213        FreedmanDiaconis
214    }
215}
216
217impl FreedmanDiaconis {
218    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
219    where
220        F: Float,
221    {
222        let mut data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
223        let n = data.len() as f64;
224
225        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
226
227        let q1_index = (0.25 * (data.len() - 1) as f64) as usize;
228        let q3_index = (0.75 * (data.len() - 1) as f64) as usize;
229
230        let q1 = data[q1_index];
231        let q3 = data[q3_index];
232
233        let iqr = q3 - q1;
234
235        let bin_width = 2.0 * iqr / n.powf(1.0 / 3.0);
236
237        let min_val = *arr.min().unwrap();
238        let max_val = *arr.max().unwrap();
239        let range = (max_val - min_val).to_f64().unwrap();
240
241        (range / bin_width).ceil() as usize
242    }
243}
244
245#[pyclass]
246#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
247pub enum EqualWidthMethod {
248    Manual(Manual),
249    SquareRoot(SquareRoot),
250    Sturges(Sturges),
251    Rice(Rice),
252    Doane(Doane),
253    Scott(Scott),
254    TerrellScott(TerrellScott),
255    FreedmanDiaconis(FreedmanDiaconis),
256}
257
258impl EqualWidthMethod {
259    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
260    where
261        F: Float + FromPrimitive,
262    {
263        match &self {
264            EqualWidthMethod::Manual(m) => m.num_bins(),
265            EqualWidthMethod::SquareRoot(m) => m.num_bins(arr),
266            EqualWidthMethod::Sturges(m) => m.num_bins(arr),
267            EqualWidthMethod::Rice(m) => m.num_bins(arr),
268            EqualWidthMethod::Doane(m) => m.num_bins(arr),
269            EqualWidthMethod::Scott(m) => m.num_bins(arr),
270            EqualWidthMethod::TerrellScott(m) => m.num_bins(arr),
271            EqualWidthMethod::FreedmanDiaconis(m) => m.num_bins(arr),
272        }
273    }
274}
275
276impl Default for EqualWidthMethod {
277    fn default() -> Self {
278        EqualWidthMethod::Doane(Doane)
279    }
280}
281
282#[pyclass]
283#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
284pub struct EqualWidthBinning {
285    pub method: EqualWidthMethod,
286}
287
288#[pymethods]
289impl EqualWidthBinning {
290    #[new]
291    #[pyo3(signature = (method=None))]
292    pub fn new(method: Option<&Bound<'_, PyAny>>) -> Result<Self, TypeError> {
293        let method = match method {
294            None => EqualWidthMethod::default(),
295            Some(method_obj) => {
296                if method_obj.is_instance_of::<Manual>() {
297                    EqualWidthMethod::Manual(method_obj.extract()?)
298                } else if method_obj.is_instance_of::<SquareRoot>() {
299                    EqualWidthMethod::SquareRoot(method_obj.extract()?)
300                } else if method_obj.is_instance_of::<Rice>() {
301                    EqualWidthMethod::Rice(method_obj.extract()?)
302                } else if method_obj.is_instance_of::<Sturges>() {
303                    EqualWidthMethod::Sturges(method_obj.extract()?)
304                } else if method_obj.is_instance_of::<Doane>() {
305                    EqualWidthMethod::Doane(method_obj.extract()?)
306                } else if method_obj.is_instance_of::<Scott>() {
307                    EqualWidthMethod::Scott(method_obj.extract()?)
308                } else if method_obj.is_instance_of::<TerrellScott>() {
309                    EqualWidthMethod::TerrellScott(method_obj.extract()?)
310                } else if method_obj.is_instance_of::<FreedmanDiaconis>() {
311                    EqualWidthMethod::FreedmanDiaconis(method_obj.extract()?)
312                } else {
313                    return Err(TypeError::InvalidEqualWidthBinningMethodError);
314                }
315            }
316        };
317
318        Ok(EqualWidthBinning { method })
319    }
320
321    #[getter]
322    pub fn method<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
323        match &self.method {
324            EqualWidthMethod::Manual(m) => m.clone().into_bound_py_any(py),
325            EqualWidthMethod::SquareRoot(m) => m.clone().into_bound_py_any(py),
326            EqualWidthMethod::Sturges(m) => m.clone().into_bound_py_any(py),
327            EqualWidthMethod::Rice(m) => m.clone().into_bound_py_any(py),
328            EqualWidthMethod::Doane(m) => m.clone().into_bound_py_any(py),
329            EqualWidthMethod::Scott(m) => m.clone().into_bound_py_any(py),
330            EqualWidthMethod::TerrellScott(m) => m.clone().into_bound_py_any(py),
331            EqualWidthMethod::FreedmanDiaconis(m) => m.clone().into_bound_py_any(py),
332        }
333    }
334}
335
336impl EqualWidthBinning {
337    pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
338    where
339        F: Float + FromPrimitive,
340    {
341        let min_val = *arr.min().unwrap();
342        let max_val = *arr.max().unwrap();
343        let num_bins = self.method.num_bins(arr);
344
345        if num_bins < 2 {
346            return Err(TypeError::InvalidBinCountError(
347                format!("Specified Binning strategy did not return enough bins, at least 2 are needed, got {num_bins}")
348            ));
349        }
350
351        let range = max_val - min_val;
352        let bin_width = range / F::from_usize(num_bins).unwrap();
353
354        Ok((1..num_bins)
355            .map(|i| min_val + bin_width * F::from_usize(i).unwrap())
356            .collect())
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, Array1};
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;
    use test_utils::retry_flaky_test_sync;

    /// Draws `n` samples from a normal distribution with the given mean and
    /// standard deviation.
    fn create_normal_data(n: usize, mean: f64, std: f64) -> Array1<f64> {
        let distribution = Normal::new(mean, std).unwrap();
        Array1::random(n, distribution)
    }

    /// Builds an array of `len` identical values; only its length matters to
    /// the size-based rules.
    fn constant_samples(len: usize) -> Array1<f64> {
        Array1::from_elem(len, 1.0)
    }

    #[test]
    fn test_manual_basic() {
        let strategy = Manual::new(10);
        assert_eq!(strategy.num_bins(), 10);
        assert_eq!(strategy.num_bins, 10);
    }

    // SquareRoot method tests
    #[test]
    fn test_square_root_known_values() {
        let rule = SquareRoot::new();

        // Perfect squares: (sample size, expected bins)
        let cases: [(usize, usize); 3] = [(9, 3), (100, 10), (64, 8)];
        for (len, expected) in cases {
            let samples = constant_samples(len);
            assert_eq!(rule.num_bins(&samples.view()), expected);
        }
    }

    #[test]
    fn test_square_root_non_perfect_squares() {
        let rule = SquareRoot::new();

        // ceil(sqrt(10)) = 4, ceil(sqrt(50)) = 8
        let cases: [(usize, usize); 2] = [(10, 4), (50, 8)];
        for (len, expected) in cases {
            let samples = constant_samples(len);
            assert_eq!(rule.num_bins(&samples.view()), expected);
        }
    }

    // Sturges method tests
    #[test]
    fn test_sturges_known_values() {
        let rule = Sturges::new();

        // log2(16) + 1 = 5, log2(32) + 1 = 6, log2(128) + 1 = 8
        let cases: [(usize, usize); 3] = [(16, 5), (32, 6), (128, 8)];
        for (len, expected) in cases {
            let samples = constant_samples(len);
            assert_eq!(rule.num_bins(&samples.view()), expected);
        }
    }

    #[test]
    fn test_scott_different_scales() {
        retry_flaky_test_sync!(3, 1000, {
            let rule = Scott::new();

            // Same distribution shape but different scales
            let unit_scale = create_normal_data(100, 0.0, 1.0);
            let wide_scale = create_normal_data(100, 0.0, 10.0);

            let bins1 = rule.num_bins(&unit_scale.view());
            let bins2 = rule.num_bins(&wide_scale.view());

            // Both should give similar bin counts since Scott's rule accounts for scale
            assert!((bins1 as i32 - bins2 as i32).abs() <= 2);
        });
    }

    #[test]
    fn test_terrell_scott_known_values() {
        let rule = TerrellScott::new();

        // round((2*8)^(1/3)) = 3, round((2*125)^(1/3)) = 6
        let cases: [(usize, usize); 2] = [(8, 3), (125, 6)];
        for (len, expected) in cases {
            let samples = constant_samples(len);
            assert_eq!(rule.num_bins(&samples.view()), expected);
        }
    }

    #[test]
    fn test_freedman_diaconis_heavy_tailed() {
        retry_flaky_test_sync!(3, 1000, {
            let rule = FreedmanDiaconis::new();
            // Approximate a heavy-tailed distribution with a scaled normal sample
            let mut samples = create_normal_data(200, 0.0, 3.0);
            // Stretch the first ten values to create heavy tails
            samples.iter_mut().take(10).for_each(|value| *value *= 3.0);

            let bins = rule.num_bins(&samples.view());
            assert!(
                bins > 3 && bins < 30,
                "Expected bins between 3 and 30, got {}",
                bins
            );
        });
    }

    #[test]
    fn test_small_arrays() {
        let samples = arr1(&[1.0, 2.0, 3.0]);
        let view = samples.view();

        assert_eq!(SquareRoot::new().num_bins(&view), 2);
        assert_eq!(Sturges::new().num_bins(&view), 3);
        assert_eq!(Rice::new().num_bins(&view), 3);

        let doane_bins = Doane::new().num_bins(&view);
        assert!((1..=5).contains(&doane_bins));
    }

    #[test]
    fn test_default_method() {
        if let EqualWidthMethod::Doane(_) = EqualWidthMethod::default() {
            // Expected
        } else {
            panic!("Default should be Doane method");
        }
    }

    #[test]
    fn test_equal_width_method_serialization() {
        let methods = [
            EqualWidthMethod::Manual(Manual::new(10)),
            EqualWidthMethod::SquareRoot(SquareRoot::new()),
            EqualWidthMethod::Sturges(Sturges::new()),
            EqualWidthMethod::Rice(Rice::new()),
            EqualWidthMethod::Doane(Doane::new()),
            EqualWidthMethod::Scott(Scott::new()),
            EqualWidthMethod::TerrellScott(TerrellScott::new()),
            EqualWidthMethod::FreedmanDiaconis(FreedmanDiaconis::new()),
        ];

        // Every variant must survive a JSON round trip unchanged.
        for original in &methods {
            let json = serde_json::to_string(original).unwrap();
            let restored: EqualWidthMethod = serde_json::from_str(&json).unwrap();
            assert_eq!(*original, restored);
        }
    }

    #[test]
    fn test_extreme_ranges() {
        let samples = arr1(&[1e-10, 1e10]);
        let view = samples.view();

        // Methods should handle extreme ranges without panic
        let _sqrt_bins = SquareRoot::new().num_bins(&view);
        let _sturges_bins = Sturges::new().num_bins(&view);
        let _rice_bins = Rice::new().num_bins(&view);
        let _doane_bins = Doane::new().num_bins(&view);
        let _scott_bins = Scott::new().num_bins(&view);
        let _ts_bins = TerrellScott::new().num_bins(&view);
        let _fd_bins = FreedmanDiaconis::new().num_bins(&view);
    }
}