scouter_types/binning/
strategy.rs1use crate::binning::equal_width::EqualWidthBinning;
2use crate::binning::quantile::QuantileBinning;
3use crate::error::TypeError;
4use ndarray::{Array1, ArrayView1};
5use num_traits::{Float, FromPrimitive};
6use pyo3::{pyclass, Bound, IntoPyObjectExt, PyAny, PyResult, Python};
7use serde::{Deserialize, Serialize};
8
9#[pyclass]
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub enum BinningStrategy {
12 QuantileBinning(QuantileBinning),
13 EqualWidthBinning(EqualWidthBinning),
14}
15
16impl Default for BinningStrategy {
17 fn default() -> Self {
18 BinningStrategy::QuantileBinning(QuantileBinning::default())
19 }
20}
21
22impl BinningStrategy {
23 pub fn strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
24 match self {
25 BinningStrategy::QuantileBinning(strategy) => strategy.clone().into_bound_py_any(py),
26 BinningStrategy::EqualWidthBinning(strategy) => strategy.clone().into_bound_py_any(py),
27 }
28 }
29
30 pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
31 where
32 F: Float + FromPrimitive,
33 {
34 let clean_arr = Array1::from(
35 arr.iter()
36 .filter(|&&x| x.is_finite())
37 .cloned()
38 .collect::<Vec<F>>(),
39 );
40
41 if clean_arr.is_empty() {
42 return Err(TypeError::EmptyArrayError(
43 "unable to compute bin edges".to_string(),
44 ));
45 }
46
47 match self {
48 BinningStrategy::QuantileBinning(b) => b.compute_edges(arr),
49 BinningStrategy::EqualWidthBinning(b) => b.compute_edges(arr),
50 }
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// An empty input has no finite values and must be rejected.
    #[test]
    fn test_empty_array() {
        let binning_strategy = BinningStrategy::QuantileBinning(QuantileBinning { num_bins: 4 });
        let data = Array1::<f64>::from(vec![]);
        let result = binning_strategy.compute_edges(&data.view());

        assert!(result.is_err());
        match result.unwrap_err() {
            TypeError::EmptyArrayError(msg) => {
                assert_eq!(msg, "unable to compute bin edges");
            }
            _ => panic!("Expected EmptyArrayError"),
        }
    }

    /// A non-empty input made only of NaN/±inf is empty after filtering and
    /// must be rejected the same way as an empty input.
    #[test]
    fn test_all_non_finite_array() {
        let binning_strategy = BinningStrategy::QuantileBinning(QuantileBinning { num_bins: 4 });
        let data = Array1::from(vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY]);
        let result = binning_strategy.compute_edges(&data.view());

        match result {
            Err(TypeError::EmptyArrayError(msg)) => {
                assert_eq!(msg, "unable to compute bin edges");
            }
            _ => panic!("Expected EmptyArrayError"),
        }
    }

    /// The default strategy is quantile binning.
    #[test]
    fn test_default_method() {
        let default_method = BinningStrategy::default();
        assert!(
            matches!(default_method, BinningStrategy::QuantileBinning(_)),
            "Default should be QuantileBinning"
        );
    }
}