scouter_types/binning/
equal_width.rs

1use crate::error::TypeError;
2use ndarray::ArrayView1;
3use ndarray_stats::QuantileExt;
4use num_traits::{Float, FromPrimitive};
5use pyo3::prelude::PyAnyMethods;
6use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, PyAny, PyResult, Python};
7use serde::{Deserialize, Serialize};
8
9#[pyclass]
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub struct Manual {
12    #[pyo3(get, set)]
13    num_bins: usize,
14}
15
16#[pymethods]
17impl Manual {
18    #[new]
19    pub fn new(num_bins: usize) -> Self {
20        Manual { num_bins }
21    }
22}
23
24impl Manual {
25    pub fn num_bins(&self) -> usize {
26        self.num_bins
27    }
28}
29
30#[pyclass]
31#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32pub struct SquareRoot;
33
34impl Default for SquareRoot {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[pymethods]
41impl SquareRoot {
42    #[new]
43    pub fn new() -> Self {
44        SquareRoot
45    }
46}
47
48impl SquareRoot {
49    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
50        let n = arr.len() as f64;
51        n.sqrt().ceil() as usize
52    }
53}
54
55#[pyclass]
56#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
57pub struct Sturges;
58
59#[pymethods]
60impl Sturges {
61    #[new]
62    pub fn new() -> Self {
63        Sturges
64    }
65}
66
67impl Default for Sturges {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl Sturges {
74    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
75        let n = arr.len() as f64;
76        (n.log2().ceil() + 1.0) as usize
77    }
78}
79
80#[pyclass]
81#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
82pub struct Rice;
83
84#[pymethods]
85impl Rice {
86    #[new]
87    pub fn new() -> Self {
88        Rice
89    }
90}
91
92impl Default for Rice {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl Rice {
99    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
100        let n = arr.len() as f64;
101        (2.0 * n.powf(1.0 / 3.0)).ceil() as usize
102    }
103}
104#[pyclass]
105#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
106pub struct Doane;
107
108#[pymethods]
109impl Doane {
110    #[new]
111    pub fn new() -> Self {
112        Doane
113    }
114}
115
116impl Default for Doane {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl Doane {
123    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
124    where
125        F: Float,
126    {
127        let n = arr.len() as f64;
128        let data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
129        let mu = data.iter().sum::<f64>() / n;
130        let m2 = data.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / n;
131        let m3 = data.iter().map(|&x| (x - mu).powi(3)).sum::<f64>() / n;
132        let g1 = m3 / m2.powf(3.0 / 2.0);
133        let sigma_g1 = ((6.0 * (n - 2.0)) / ((n + 1.0) * (n + 3.0))).sqrt();
134        let k = 1.0 + n.log2() + (1.0 + g1.abs() / sigma_g1).log2();
135        k.round() as usize
136    }
137}
138#[pyclass]
139#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
140pub struct Scott;
141
142#[pymethods]
143impl Scott {
144    #[new]
145    pub fn new() -> Self {
146        Scott
147    }
148}
149
150impl Default for Scott {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl Scott {
157    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
158    where
159        F: Float + FromPrimitive,
160    {
161        let n = arr.len() as f64;
162
163        let std_dev = arr.std(F::from(0.0).unwrap()).to_f64().unwrap();
164
165        let bin_width = 3.49 * std_dev * n.powf(-1.0 / 3.0);
166
167        let min_val = *arr.min().unwrap();
168        let max_val = *arr.max().unwrap();
169        let range = (max_val - min_val).to_f64().unwrap();
170
171        (range / bin_width).ceil() as usize
172    }
173}
174#[pyclass]
175#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
176pub struct TerrellScott;
177
178#[pymethods]
179impl TerrellScott {
180    #[new]
181    pub fn new() -> Self {
182        TerrellScott
183    }
184}
185
186impl Default for TerrellScott {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl TerrellScott {
193    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
194        let n = arr.len() as f64;
195        (2.0 * n).powf(1.0 / 3.0).round() as usize
196    }
197}
198
199#[pyclass]
200#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
201pub struct FreedmanDiaconis;
202
203impl Default for FreedmanDiaconis {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209#[pymethods]
210impl FreedmanDiaconis {
211    #[new]
212    pub fn new() -> Self {
213        FreedmanDiaconis
214    }
215}
216
217impl FreedmanDiaconis {
218    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
219    where
220        F: Float,
221    {
222        let mut data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
223        let n = data.len() as f64;
224
225        data.sort_by(|a, b| a.partial_cmp(b).unwrap());
226
227        let q1_index = (0.25 * (data.len() - 1) as f64) as usize;
228        let q3_index = (0.75 * (data.len() - 1) as f64) as usize;
229
230        let q1 = data[q1_index];
231        let q3 = data[q3_index];
232
233        let iqr = q3 - q1;
234
235        let bin_width = 2.0 * iqr / n.powf(1.0 / 3.0);
236
237        let min_val = *arr.min().unwrap();
238        let max_val = *arr.max().unwrap();
239        let range = (max_val - min_val).to_f64().unwrap();
240
241        (range / bin_width).ceil() as usize
242    }
243}
244
245#[pyclass]
246#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
247pub enum EqualWidthMethod {
248    Manual(Manual),
249    SquareRoot(SquareRoot),
250    Sturges(Sturges),
251    Rice(Rice),
252    Doane(Doane),
253    Scott(Scott),
254    TerrellScott(TerrellScott),
255    FreedmanDiaconis(FreedmanDiaconis),
256}
257
258impl EqualWidthMethod {
259    pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
260    where
261        F: Float + FromPrimitive,
262    {
263        match &self {
264            EqualWidthMethod::Manual(m) => m.num_bins(),
265            EqualWidthMethod::SquareRoot(m) => m.num_bins(arr),
266            EqualWidthMethod::Sturges(m) => m.num_bins(arr),
267            EqualWidthMethod::Rice(m) => m.num_bins(arr),
268            EqualWidthMethod::Doane(m) => m.num_bins(arr),
269            EqualWidthMethod::Scott(m) => m.num_bins(arr),
270            EqualWidthMethod::TerrellScott(m) => m.num_bins(arr),
271            EqualWidthMethod::FreedmanDiaconis(m) => m.num_bins(arr),
272        }
273    }
274}
275
276impl Default for EqualWidthMethod {
277    fn default() -> Self {
278        EqualWidthMethod::Doane(Doane)
279    }
280}
281
282#[pyclass]
283#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
284pub struct EqualWidthBinning {
285    pub method: EqualWidthMethod,
286}
287
288#[pymethods]
289impl EqualWidthBinning {
290    #[new]
291    #[pyo3(signature = (method=None))]
292    pub fn new(method: Option<&Bound<'_, PyAny>>) -> Result<Self, TypeError> {
293        let method = match method {
294            None => EqualWidthMethod::default(),
295            Some(method_obj) => {
296                if method_obj.is_instance_of::<Manual>() {
297                    EqualWidthMethod::Manual(method_obj.extract()?)
298                } else if method_obj.is_instance_of::<SquareRoot>() {
299                    EqualWidthMethod::SquareRoot(method_obj.extract()?)
300                } else if method_obj.is_instance_of::<Rice>() {
301                    EqualWidthMethod::Rice(method_obj.extract()?)
302                } else if method_obj.is_instance_of::<Sturges>() {
303                    EqualWidthMethod::Sturges(method_obj.extract()?)
304                } else if method_obj.is_instance_of::<Doane>() {
305                    EqualWidthMethod::Doane(method_obj.extract()?)
306                } else if method_obj.is_instance_of::<Scott>() {
307                    EqualWidthMethod::Scott(method_obj.extract()?)
308                } else if method_obj.is_instance_of::<TerrellScott>() {
309                    EqualWidthMethod::TerrellScott(method_obj.extract()?)
310                } else if method_obj.is_instance_of::<FreedmanDiaconis>() {
311                    EqualWidthMethod::FreedmanDiaconis(method_obj.extract()?)
312                } else {
313                    return Err(TypeError::InvalidEqualWidthBinningMethodError);
314                }
315            }
316        };
317
318        Ok(EqualWidthBinning { method })
319    }
320
321    #[getter]
322    pub fn method<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
323        match &self.method {
324            EqualWidthMethod::Manual(m) => m.clone().into_bound_py_any(py),
325            EqualWidthMethod::SquareRoot(m) => m.clone().into_bound_py_any(py),
326            EqualWidthMethod::Sturges(m) => m.clone().into_bound_py_any(py),
327            EqualWidthMethod::Rice(m) => m.clone().into_bound_py_any(py),
328            EqualWidthMethod::Doane(m) => m.clone().into_bound_py_any(py),
329            EqualWidthMethod::Scott(m) => m.clone().into_bound_py_any(py),
330            EqualWidthMethod::TerrellScott(m) => m.clone().into_bound_py_any(py),
331            EqualWidthMethod::FreedmanDiaconis(m) => m.clone().into_bound_py_any(py),
332        }
333    }
334}
335
336impl EqualWidthBinning {
337    pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
338    where
339        F: Float + FromPrimitive,
340    {
341        let min_val = *arr.min().unwrap();
342        let max_val = *arr.max().unwrap();
343        let num_bins = self.method.num_bins(arr);
344
345        if num_bins < 2 {
346            return Err(TypeError::InvalidBinCountError(
347                format!("Specified Binning strategy did not return enough bins, at least 2 are needed, got {num_bins}")
348            ));
349        }
350
351        let range = max_val - min_val;
352        let bin_width = range / F::from_usize(num_bins).unwrap();
353
354        Ok((1..num_bins)
355            .map(|i| min_val + bin_width * F::from_usize(i).unwrap())
356            .collect())
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, Array1};
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    /// Draws `n` samples from a normal distribution with the given mean and
    /// standard deviation. The generator is not seeded, so tests using this
    /// helper must only assert properties that tolerate sampling noise.
    fn create_normal_data(n: usize, mean: f64, std: f64) -> Array1<f64> {
        Array1::random(n, Normal::new(mean, std).unwrap())
    }

    #[test]
    fn test_manual_basic() {
        let manual = Manual::new(10);
        assert_eq!(manual.num_bins(), 10);
        assert_eq!(manual.num_bins, 10);
    }

    // SquareRoot method tests
    #[test]
    fn test_square_root_known_values() {
        let sr = SquareRoot::new();

        // Perfect squares
        let arr = arr1(&[1.0; 9]);
        assert_eq!(sr.num_bins(&arr.view()), 3);

        let arr = arr1(&[1.0; 100]);
        assert_eq!(sr.num_bins(&arr.view()), 10);

        let arr = arr1(&[1.0; 64]);
        assert_eq!(sr.num_bins(&arr.view()), 8);
    }

    #[test]
    fn test_square_root_non_perfect_squares() {
        let sr = SquareRoot::new();

        let arr = arr1(&[1.0; 10]);
        assert_eq!(sr.num_bins(&arr.view()), 4); // ceil(sqrt(10)) = 4

        let arr = arr1(&[1.0; 50]);
        assert_eq!(sr.num_bins(&arr.view()), 8); // ceil(sqrt(50)) = 8
    }

    // Sturges method tests
    #[test]
    fn test_sturges_known_values() {
        let sturges = Sturges::new();

        let arr = arr1(&[1.0; 16]);
        assert_eq!(sturges.num_bins(&arr.view()), 5); // log2(16) + 1 = 5

        let arr = arr1(&[1.0; 32]);
        assert_eq!(sturges.num_bins(&arr.view()), 6); // log2(32) + 1 = 6

        let arr = arr1(&[1.0; 128]);
        assert_eq!(sturges.num_bins(&arr.view()), 8); // log2(128) + 1 = 8
    }

    #[test]
    fn test_scott_different_scales() {
        let scott = Scott::new();

        // Compare one sample against a scaled copy of itself. Two independent
        // unseeded draws would mix sampling noise into the comparison and
        // could fail spuriously; scaling isolates the property under test.
        let base = create_normal_data(100, 0.0, 1.0);
        let scaled = base.mapv(|x| x * 10.0);

        let bins_base = scott.num_bins(&base.view());
        let bins_scaled = scott.num_bins(&scaled.view());

        // Scott's rule divides the range by a width proportional to the
        // standard deviation, so rescaling should leave the count (nearly)
        // unchanged; the tolerance absorbs floating-point rounding.
        assert!((bins_base as i32 - bins_scaled as i32).abs() <= 2);
    }

    #[test]
    fn test_terrell_scott_known_values() {
        let ts = TerrellScott::new();

        let arr = arr1(&[1.0; 8]);
        assert_eq!(ts.num_bins(&arr.view()), 3); // round((2*8)^(1/3)) = 3

        let arr = arr1(&[1.0; 125]);
        assert_eq!(ts.num_bins(&arr.view()), 6); // round((2*125)^(1/3)) = 6
    }

    #[test]
    fn test_freedman_diaconis_heavy_tailed() {
        let fd = FreedmanDiaconis::new();
        // Test with heavy-tailed distribution (using scaled normal as approximation)
        let mut arr = create_normal_data(200, 0.0, 3.0);
        // Stretch the first ten samples to create heavy tails
        arr.iter_mut().take(10).for_each(|value| *value *= 3.0);

        let bins = fd.num_bins(&arr.view());
        assert!(bins > 3 && bins < 30);
    }

    #[test]
    fn test_small_arrays() {
        let arr = arr1(&[1.0, 2.0, 3.0]);

        assert_eq!(SquareRoot::new().num_bins(&arr.view()), 2);
        assert_eq!(Sturges::new().num_bins(&arr.view()), 3);
        assert_eq!(Rice::new().num_bins(&arr.view()), 3);

        let doane_bins = Doane::new().num_bins(&arr.view());
        assert!((1..=5).contains(&doane_bins));
    }

    #[test]
    fn test_default_method() {
        let default_method = EqualWidthMethod::default();
        match default_method {
            EqualWidthMethod::Doane(_) => {} // Expected
            _ => panic!("Default should be Doane method"),
        }
    }

    #[test]
    fn test_equal_width_method_serialization() {
        let methods = vec![
            EqualWidthMethod::Manual(Manual::new(10)),
            EqualWidthMethod::SquareRoot(SquareRoot::new()),
            EqualWidthMethod::Sturges(Sturges::new()),
            EqualWidthMethod::Rice(Rice::new()),
            EqualWidthMethod::Doane(Doane::new()),
            EqualWidthMethod::Scott(Scott::new()),
            EqualWidthMethod::TerrellScott(TerrellScott::new()),
            EqualWidthMethod::FreedmanDiaconis(FreedmanDiaconis::new()),
        ];

        // Every variant must survive a JSON round trip unchanged.
        for method in methods {
            let serialized = serde_json::to_string(&method).unwrap();
            let deserialized: EqualWidthMethod = serde_json::from_str(&serialized).unwrap();
            assert_eq!(method, deserialized);
        }
    }

    #[test]
    fn test_extreme_ranges() {
        let arr = arr1(&[1e-10, 1e10]);

        // Methods should handle extreme ranges without panic
        let _sqrt_bins = SquareRoot::new().num_bins(&arr.view());
        let _sturges_bins = Sturges::new().num_bins(&arr.view());
        let _rice_bins = Rice::new().num_bins(&arr.view());
        let _doane_bins = Doane::new().num_bins(&arr.view());
        let _scott_bins = Scott::new().num_bins(&arr.view());
        let _ts_bins = TerrellScott::new().num_bins(&arr.view());
        let _fd_bins = FreedmanDiaconis::new().num_bins(&arr.view());
    }
}
513}