1use crate::error::TypeError;
2use ndarray::ArrayView1;
3use ndarray_stats::QuantileExt;
4use num_traits::{Float, FromPrimitive};
5use pyo3::prelude::PyAnyMethods;
6use pyo3::{pyclass, pymethods, Bound, IntoPyObjectExt, PyAny, PyResult, Python};
7use serde::{Deserialize, Serialize};
8
9#[pyclass]
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub struct Manual {
12 #[pyo3(get, set)]
13 num_bins: usize,
14}
15
16#[pymethods]
17impl Manual {
18 #[new]
19 pub fn new(num_bins: usize) -> Self {
20 Manual { num_bins }
21 }
22}
23
24impl Manual {
25 pub fn num_bins(&self) -> usize {
26 self.num_bins
27 }
28}
29
30#[pyclass]
31#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32pub struct SquareRoot;
33
34impl Default for SquareRoot {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[pymethods]
41impl SquareRoot {
42 #[new]
43 pub fn new() -> Self {
44 SquareRoot
45 }
46}
47
48impl SquareRoot {
49 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
50 let n = arr.len() as f64;
51 n.sqrt().ceil() as usize
52 }
53}
54
55#[pyclass]
56#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
57pub struct Sturges;
58
59#[pymethods]
60impl Sturges {
61 #[new]
62 pub fn new() -> Self {
63 Sturges
64 }
65}
66
67impl Default for Sturges {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl Sturges {
74 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
75 let n = arr.len() as f64;
76 (n.log2().ceil() + 1.0) as usize
77 }
78}
79
80#[pyclass]
81#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
82pub struct Rice;
83
84#[pymethods]
85impl Rice {
86 #[new]
87 pub fn new() -> Self {
88 Rice
89 }
90}
91
92impl Default for Rice {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl Rice {
99 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
100 let n = arr.len() as f64;
101 (2.0 * n.powf(1.0 / 3.0)).ceil() as usize
102 }
103}
104#[pyclass]
105#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
106pub struct Doane;
107
108#[pymethods]
109impl Doane {
110 #[new]
111 pub fn new() -> Self {
112 Doane
113 }
114}
115
116impl Default for Doane {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl Doane {
123 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
124 where
125 F: Float,
126 {
127 let n = arr.len() as f64;
128 let data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
129 let mu = data.iter().sum::<f64>() / n;
130 let m2 = data.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / n;
131 let m3 = data.iter().map(|&x| (x - mu).powi(3)).sum::<f64>() / n;
132 let g1 = m3 / m2.powf(3.0 / 2.0);
133 let sigma_g1 = ((6.0 * (n - 2.0)) / ((n + 1.0) * (n + 3.0))).sqrt();
134 let k = 1.0 + n.log2() + (1.0 + g1.abs() / sigma_g1).log2();
135 k.round() as usize
136 }
137}
138#[pyclass]
139#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
140pub struct Scott;
141
142#[pymethods]
143impl Scott {
144 #[new]
145 pub fn new() -> Self {
146 Scott
147 }
148}
149
150impl Default for Scott {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156impl Scott {
157 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
158 where
159 F: Float + FromPrimitive,
160 {
161 let n = arr.len() as f64;
162
163 let std_dev = arr.std(F::from(0.0).unwrap()).to_f64().unwrap();
164
165 let bin_width = 3.49 * std_dev * n.powf(-1.0 / 3.0);
166
167 let min_val = *arr.min().unwrap();
168 let max_val = *arr.max().unwrap();
169 let range = (max_val - min_val).to_f64().unwrap();
170
171 (range / bin_width).ceil() as usize
172 }
173}
174#[pyclass]
175#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
176pub struct TerrellScott;
177
178#[pymethods]
179impl TerrellScott {
180 #[new]
181 pub fn new() -> Self {
182 TerrellScott
183 }
184}
185
186impl Default for TerrellScott {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl TerrellScott {
193 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize {
194 let n = arr.len() as f64;
195 (2.0 * n).powf(1.0 / 3.0).round() as usize
196 }
197}
198
199#[pyclass]
200#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
201pub struct FreedmanDiaconis;
202
203impl Default for FreedmanDiaconis {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209#[pymethods]
210impl FreedmanDiaconis {
211 #[new]
212 pub fn new() -> Self {
213 FreedmanDiaconis
214 }
215}
216
217impl FreedmanDiaconis {
218 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
219 where
220 F: Float,
221 {
222 let mut data: Vec<f64> = arr.iter().map(|&x| x.to_f64().unwrap()).collect();
223 let n = data.len() as f64;
224
225 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
226
227 let q1_index = (0.25 * (data.len() - 1) as f64) as usize;
228 let q3_index = (0.75 * (data.len() - 1) as f64) as usize;
229
230 let q1 = data[q1_index];
231 let q3 = data[q3_index];
232
233 let iqr = q3 - q1;
234
235 let bin_width = 2.0 * iqr / n.powf(1.0 / 3.0);
236
237 let min_val = *arr.min().unwrap();
238 let max_val = *arr.max().unwrap();
239 let range = (max_val - min_val).to_f64().unwrap();
240
241 (range / bin_width).ceil() as usize
242 }
243}
244
245#[pyclass]
246#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
247pub enum EqualWidthMethod {
248 Manual(Manual),
249 SquareRoot(SquareRoot),
250 Sturges(Sturges),
251 Rice(Rice),
252 Doane(Doane),
253 Scott(Scott),
254 TerrellScott(TerrellScott),
255 FreedmanDiaconis(FreedmanDiaconis),
256}
257
258impl EqualWidthMethod {
259 pub fn num_bins<F>(&self, arr: &ArrayView1<F>) -> usize
260 where
261 F: Float + FromPrimitive,
262 {
263 match &self {
264 EqualWidthMethod::Manual(m) => m.num_bins(),
265 EqualWidthMethod::SquareRoot(m) => m.num_bins(arr),
266 EqualWidthMethod::Sturges(m) => m.num_bins(arr),
267 EqualWidthMethod::Rice(m) => m.num_bins(arr),
268 EqualWidthMethod::Doane(m) => m.num_bins(arr),
269 EqualWidthMethod::Scott(m) => m.num_bins(arr),
270 EqualWidthMethod::TerrellScott(m) => m.num_bins(arr),
271 EqualWidthMethod::FreedmanDiaconis(m) => m.num_bins(arr),
272 }
273 }
274}
275
276impl Default for EqualWidthMethod {
277 fn default() -> Self {
278 EqualWidthMethod::Doane(Doane)
279 }
280}
281
282#[pyclass]
283#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
284pub struct EqualWidthBinning {
285 pub method: EqualWidthMethod,
286}
287
288#[pymethods]
289impl EqualWidthBinning {
290 #[new]
291 #[pyo3(signature = (method=None))]
292 pub fn new(method: Option<&Bound<'_, PyAny>>) -> Result<Self, TypeError> {
293 let method = match method {
294 None => EqualWidthMethod::default(),
295 Some(method_obj) => {
296 if method_obj.is_instance_of::<Manual>() {
297 EqualWidthMethod::Manual(method_obj.extract()?)
298 } else if method_obj.is_instance_of::<SquareRoot>() {
299 EqualWidthMethod::SquareRoot(method_obj.extract()?)
300 } else if method_obj.is_instance_of::<Rice>() {
301 EqualWidthMethod::Rice(method_obj.extract()?)
302 } else if method_obj.is_instance_of::<Sturges>() {
303 EqualWidthMethod::Sturges(method_obj.extract()?)
304 } else if method_obj.is_instance_of::<Doane>() {
305 EqualWidthMethod::Doane(method_obj.extract()?)
306 } else if method_obj.is_instance_of::<Scott>() {
307 EqualWidthMethod::Scott(method_obj.extract()?)
308 } else if method_obj.is_instance_of::<TerrellScott>() {
309 EqualWidthMethod::TerrellScott(method_obj.extract()?)
310 } else if method_obj.is_instance_of::<FreedmanDiaconis>() {
311 EqualWidthMethod::FreedmanDiaconis(method_obj.extract()?)
312 } else {
313 return Err(TypeError::InvalidEqualWidthBinningMethodError);
314 }
315 }
316 };
317
318 Ok(EqualWidthBinning { method })
319 }
320
321 #[getter]
322 pub fn method<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
323 match &self.method {
324 EqualWidthMethod::Manual(m) => m.clone().into_bound_py_any(py),
325 EqualWidthMethod::SquareRoot(m) => m.clone().into_bound_py_any(py),
326 EqualWidthMethod::Sturges(m) => m.clone().into_bound_py_any(py),
327 EqualWidthMethod::Rice(m) => m.clone().into_bound_py_any(py),
328 EqualWidthMethod::Doane(m) => m.clone().into_bound_py_any(py),
329 EqualWidthMethod::Scott(m) => m.clone().into_bound_py_any(py),
330 EqualWidthMethod::TerrellScott(m) => m.clone().into_bound_py_any(py),
331 EqualWidthMethod::FreedmanDiaconis(m) => m.clone().into_bound_py_any(py),
332 }
333 }
334}
335
336impl EqualWidthBinning {
337 pub fn compute_edges<F>(&self, arr: &ArrayView1<F>) -> Result<Vec<F>, TypeError>
338 where
339 F: Float + FromPrimitive,
340 {
341 let min_val = *arr.min().unwrap();
342 let max_val = *arr.max().unwrap();
343 let num_bins = self.method.num_bins(arr);
344
345 if num_bins < 2 {
346 return Err(TypeError::InvalidBinCountError(
347 format!("Specified Binning strategy did not return enough bins, at least 2 are needed, got {num_bins}")
348 ));
349 }
350
351 let range = max_val - min_val;
352 let bin_width = range / F::from_usize(num_bins).unwrap();
353
354 Ok((1..num_bins)
355 .map(|i| min_val + bin_width * F::from_usize(i).unwrap())
356 .collect())
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, Array1};
    use ndarray_rand::rand_distr::Normal;
    use ndarray_rand::RandomExt;

    /// Draws `n` unseeded samples from a normal distribution with the given
    /// mean and standard deviation.
    fn create_normal_data(n: usize, mean: f64, std: f64) -> Array1<f64> {
        Array1::random(n, Normal::new(mean, std).unwrap())
    }

    #[test]
    fn test_manual_basic() {
        let manual = Manual::new(10);
        assert_eq!(manual.num_bins(), 10);
        assert_eq!(manual.num_bins, 10);
    }

    #[test]
    fn test_square_root_known_values() {
        let sr = SquareRoot::new();

        let arr = arr1(&[1.0; 9]);
        assert_eq!(sr.num_bins(&arr.view()), 3);

        let arr = arr1(&[1.0; 100]);
        assert_eq!(sr.num_bins(&arr.view()), 10);

        let arr = arr1(&[1.0; 64]);
        assert_eq!(sr.num_bins(&arr.view()), 8);
    }

    #[test]
    fn test_square_root_non_perfect_squares() {
        let sr = SquareRoot::new();

        let arr = arr1(&[1.0; 10]);
        assert_eq!(sr.num_bins(&arr.view()), 4);

        let arr = arr1(&[1.0; 50]);
        assert_eq!(sr.num_bins(&arr.view()), 8);
    }

    #[test]
    fn test_sturges_known_values() {
        let sturges = Sturges::new();

        let arr = arr1(&[1.0; 16]);
        assert_eq!(sturges.num_bins(&arr.view()), 5);

        let arr = arr1(&[1.0; 32]);
        assert_eq!(sturges.num_bins(&arr.view()), 6);

        let arr = arr1(&[1.0; 128]);
        assert_eq!(sturges.num_bins(&arr.view()), 8);
    }

    #[test]
    fn test_scott_different_scales() {
        let scott = Scott::new();

        // Rescale the same sample instead of drawing two independent ones:
        // the range and standard deviation then scale together, so the
        // comparison checks scale invariance without depending on chance.
        let base = create_normal_data(100, 0.0, 1.0);
        let scaled = base.mapv(|x| x * 10.0);

        let bins_base = scott.num_bins(&base.view());
        let bins_scaled = scott.num_bins(&scaled.view());

        assert!((bins_base as i32 - bins_scaled as i32).abs() <= 2);
    }

    #[test]
    fn test_terrell_scott_known_values() {
        let ts = TerrellScott::new();

        let arr = arr1(&[1.0; 8]);
        assert_eq!(ts.num_bins(&arr.view()), 3);

        let arr = arr1(&[1.0; 125]);
        assert_eq!(ts.num_bins(&arr.view()), 6);
    }

    #[test]
    fn test_freedman_diaconis_heavy_tailed() {
        let fd = FreedmanDiaconis::new();

        // Widen the first ten samples to create heavier tails.
        // NOTE(review): the input is unseeded random data, so the bounds
        // asserted below are probabilistic rather than guaranteed.
        let mut arr = create_normal_data(200, 0.0, 3.0);
        arr.iter_mut().take(10).for_each(|v| *v *= 3.0);

        let bins = fd.num_bins(&arr.view());
        assert!(bins > 3 && bins < 30);
    }

    #[test]
    fn test_small_arrays() {
        let arr = arr1(&[1.0, 2.0, 3.0]);

        assert_eq!(SquareRoot::new().num_bins(&arr.view()), 2);
        assert_eq!(Sturges::new().num_bins(&arr.view()), 3);
        assert_eq!(Rice::new().num_bins(&arr.view()), 3);

        let doane_bins = Doane::new().num_bins(&arr.view());
        assert!((1..=5).contains(&doane_bins));
    }

    #[test]
    fn test_default_method() {
        assert!(
            matches!(EqualWidthMethod::default(), EqualWidthMethod::Doane(_)),
            "Default should be Doane method"
        );
    }

    #[test]
    fn test_equal_width_method_serialization() {
        let methods = vec![
            EqualWidthMethod::Manual(Manual::new(10)),
            EqualWidthMethod::SquareRoot(SquareRoot::new()),
            EqualWidthMethod::Sturges(Sturges::new()),
            EqualWidthMethod::Rice(Rice::new()),
            EqualWidthMethod::Doane(Doane::new()),
            EqualWidthMethod::Scott(Scott::new()),
            EqualWidthMethod::TerrellScott(TerrellScott::new()),
            EqualWidthMethod::FreedmanDiaconis(FreedmanDiaconis::new()),
        ];

        // Every variant must survive a JSON round trip unchanged.
        for method in methods {
            let serialized = serde_json::to_string(&method).unwrap();
            let deserialized: EqualWidthMethod = serde_json::from_str(&serialized).unwrap();
            assert_eq!(method, deserialized);
        }
    }

    #[test]
    fn test_extreme_ranges() {
        // Smoke test: every data-driven strategy must return without
        // panicking on a two-element array spanning twenty orders of
        // magnitude. The resulting counts are deliberately not asserted.
        let arr = arr1(&[1e-10, 1e10]);

        let _sqrt_bins = SquareRoot::new().num_bins(&arr.view());
        let _sturges_bins = Sturges::new().num_bins(&arr.view());
        let _rice_bins = Rice::new().num_bins(&arr.view());
        let _doane_bins = Doane::new().num_bins(&arr.view());
        let _scott_bins = Scott::new().num_bins(&arr.view());
        let _ts_bins = TerrellScott::new().num_bins(&arr.view());
        let _fd_bins = FreedmanDiaconis::new().num_bins(&arr.view());
    }
}