scouter_types/binning/strategy.rs

1use crate::binning::equal_width::EqualWidthBinning;
2use crate::binning::quantile::QuantileBinning;
3use crate::error::TypeError;
4use ndarray::{Array1, ArrayView1};
5use num_traits::{Float, FromPrimitive};
6use pyo3::{pyclass, Bound, IntoPyObjectExt, PyAny, PyResult, Python};
7use serde::{Deserialize, Serialize};
8
9#[pyclass]
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub enum BinningStrategy {
12    QuantileBinning(QuantileBinning),
13    EqualWidthBinning(EqualWidthBinning),
14}
15
16impl Default for BinningStrategy {
17    fn default() -> Self {
18        BinningStrategy::QuantileBinning(QuantileBinning::default())
19    }
20}
21
22impl BinningStrategy {
23    pub fn strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
24        match self {
25            BinningStrategy::QuantileBinning(strategy) => strategy.clone().into_bound_py_any(py),
26            BinningStrategy::EqualWidthBinning(strategy) => strategy.clone().into_bound_py_any(py),
27        }
28    }
29
30    pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
31    where
32        F: Float + FromPrimitive,
33    {
34        let clean_arr = Array1::from(
35            arr.iter()
36                .filter(|&&x| x.is_finite())
37                .cloned()
38                .collect::<Vec<F>>(),
39        );
40
41        if clean_arr.is_empty() {
42            return Err(TypeError::EmptyArrayError(
43                "unable to compute bin edges".to_string(),
44            ));
45        }
46
47        match self {
48            BinningStrategy::QuantileBinning(b) => b.compute_edges(arr),
49            BinningStrategy::EqualWidthBinning(b) => b.compute_edges(arr),
50        }
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// An input with no values must be rejected with `EmptyArrayError`.
    #[test]
    fn test_empty_array() {
        let strategy = BinningStrategy::QuantileBinning(QuantileBinning { num_bins: 4 });
        let empty = Array1::<f64>::from(Vec::new());

        let err = strategy
            .compute_edges(&empty.view())
            .expect_err("compute_edges should fail on an empty array");

        let TypeError::EmptyArrayError(msg) = err else {
            panic!("Expected EmptyArrayError");
        };
        assert_eq!(msg, "unable to compute bin edges");
    }

    /// `BinningStrategy::default()` must select quantile binning.
    #[test]
    fn test_default_method() {
        assert!(
            matches!(BinningStrategy::default(), BinningStrategy::QuantileBinning(_)),
            "Default should be QuantileBinning"
        );
    }
}