scirs2_core/ndarray_ext/stats/distribution.rs
1//! Distribution-related functions for ndarray arrays
2//!
3//! This module provides functions for working with data distributions,
4//! including histograms, binning, and quantile calculations.
5
6use ndarray::{Array, ArrayView, Ix1, Ix2};
7use num_traits::{Float, FromPrimitive};
8
/// Result type returned by [`histogram`].
///
/// On success it holds `(counts, bin_edges)`: the per-bin counts as `usize` and the
/// bin edges in the element type `T`. On failure it holds a static error message.
pub type HistogramResult<T> = Result<(Array<usize, Ix1>, Array<T, Ix1>), &'static str>;
11
/// Result type returned by [`histogram2d`].
///
/// On success it holds `(counts, x_edges, y_edges)`: the two-dimensional count grid
/// as `usize` and the bin edges along each axis in the element type `T`. On failure
/// it holds a static error message.
pub type Histogram2dResult<T> =
    Result<(Array<usize, Ix2>, Array<T, Ix1>, Array<T, Ix1>), &'static str>;
15
16/// Calculate a histogram of data
17///
18/// # Arguments
19///
20/// * `array` - The input 1D array
21/// * `bins` - The number of bins
22/// * `range` - Optional tuple (min, max) to use. If None, the range is based on data
23/// * `weights` - Optional array of weights for each data point
24///
25/// # Returns
26///
27/// A tuple containing (histogram, bin_edges)
28///
29/// # Examples
30///
31/// ```
32/// use ndarray::array;
33/// use scirs2_core::ndarray_ext::stats::histogram;
34///
35/// let data = array![0.1, 0.5, 1.1, 1.5, 2.2, 2.9, 3.1, 3.8, 4.1, 4.9];
36/// let (hist, bin_edges) = histogram(data.view(), 5, None, None).unwrap();
37///
38/// assert_eq!(hist.len(), 5);
39/// assert_eq!(bin_edges.len(), 6);
40/// ```
41#[allow(dead_code)]
42pub fn histogram<T>(
43 array: ArrayView<T, Ix1>,
44 bins: usize,
45 range: Option<(T, T)>,
46 weights: Option<ArrayView<T, Ix1>>,
47) -> HistogramResult<T>
48where
49 T: Clone + Float + FromPrimitive,
50{
51 if array.is_empty() {
52 return Err("Cannot compute histogram of an empty array");
53 }
54
55 if bins == 0 {
56 return Err("Number of bins must be positive");
57 }
58
59 // Get range (min, max) of the data
60 let (min_val, max_val) = match range {
61 Some(r) => r,
62 None => {
63 let mut min_val = T::infinity();
64 let mut max_val = T::neg_infinity();
65
66 for &val in array.iter() {
67 if val < min_val {
68 min_val = val;
69 }
70 if val > max_val {
71 max_val = val;
72 }
73 }
74 (min_val, max_val)
75 }
76 };
77
78 if min_val >= max_val {
79 return Err("Range must be (min, max) with min < max");
80 }
81
82 // Create bin edges
83 let mut bin_edges = Array::<T, Ix1>::zeros(bins + 1);
84 let bin_width = (max_val - min_val) / T::from_usize(bins).unwrap();
85
86 for i in 0..=bins {
87 bin_edges[i] = min_val + bin_width * T::from_usize(i).unwrap();
88 }
89
90 // Ensure the last bin edge is exactly max_val
91 bin_edges[bins] = max_val;
92
93 // Initialize histogram array
94 let mut hist = Array::<usize, Ix1>::zeros(bins);
95
96 // Fill histogram
97 match weights {
98 Some(w) => {
99 if w.len() != array.len() {
100 return Err("Weights array must have the same length as the data array");
101 }
102
103 for (&val, &weight) in array.iter().zip(w.iter()) {
104 // Skip values outside the range
105 if val < min_val || val > max_val {
106 continue;
107 }
108
109 // Handle edge case where val == max_val (include in the last bin)
110 if val == max_val {
111 hist[bins - 1] += 1;
112 continue;
113 }
114
115 // Find bin index
116 let scaled_val = (val - min_val) / bin_width;
117 let bin_idx = scaled_val.to_usize().unwrap_or(0);
118 let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
119
120 // Add to histogram (with weight)
121 let weight_int = weight.to_usize().unwrap_or(1);
122 hist[bin_idx] += weight_int;
123 }
124 }
125 None => {
126 for &val in array.iter() {
127 // Skip values outside the range
128 if val < min_val || val > max_val {
129 continue;
130 }
131
132 // Handle edge case where val == max_val (include in the last bin)
133 if val == max_val {
134 hist[bins - 1] += 1;
135 continue;
136 }
137
138 // Find bin index
139 let scaled_val = (val - min_val) / bin_width;
140 let bin_idx = scaled_val.to_usize().unwrap_or(0);
141 let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
142
143 // Add to histogram
144 hist[bin_idx] += 1;
145 }
146 }
147 }
148
149 Ok((hist, bin_edges))
150}
151
152/// Calculate a 2D histogram of data
153///
154/// # Arguments
155///
156/// * `x` - The x coordinates of the data points
157/// * `y` - The y coordinates of the data points
158/// * `bins` - Either a tuple (x_bins, y_bins) for the number of bins, or None for 10 bins in each direction
159/// * `range` - Optional tuple ((x_min, x_max), (y_min, y_max)) to use. If None, the range is based on data
160/// * `weights` - Optional array of weights for each data point
161///
162/// # Returns
163///
164/// A tuple containing (histogram, x_edges, y_edges)
165///
166/// # Examples
167///
168/// ```
169/// use ndarray::array;
170/// use scirs2_core::ndarray_ext::stats::histogram2d;
171///
172/// let x = array![0.1, 0.5, 1.3, 2.5, 3.1, 3.8, 4.2, 4.9];
173/// let y = array![0.2, 0.8, 1.5, 2.0, 3.0, 3.2, 3.5, 4.5];
174/// let (hist, x_edges, y_edges) = histogram2d(x.view(), y.view(), Some((4, 4)), None, None).unwrap();
175///
176/// assert_eq!(hist.shape(), &[4, 4]);
177/// assert_eq!(x_edges.len(), 5);
178/// assert_eq!(y_edges.len(), 5);
179/// ```
180#[allow(dead_code)]
181pub fn histogram2d<T>(
182 x: ArrayView<T, Ix1>,
183 y: ArrayView<T, Ix1>,
184 bins: Option<(usize, usize)>,
185 range: Option<((T, T), (T, T))>,
186 weights: Option<ArrayView<T, Ix1>>,
187) -> Histogram2dResult<T>
188where
189 T: Clone + Float + FromPrimitive,
190{
191 if x.is_empty() || y.is_empty() {
192 return Err("Cannot compute histogram of empty arrays");
193 }
194
195 if x.len() != y.len() {
196 return Err("x and y arrays must have the same length");
197 }
198
199 // Default to 10 bins in each direction if not specified
200 let (x_bins, y_bins) = bins.unwrap_or((10, 10));
201
202 if x_bins == 0 || y_bins == 0 {
203 return Err("Number of bins must be positive");
204 }
205
206 // Get range for x and y
207 let ((x_min, x_max), (y_min, y_max)) = match range {
208 Some(r) => r,
209 None => {
210 let mut x_min = T::infinity();
211 let mut x_max = T::neg_infinity();
212 let mut y_min = T::infinity();
213 let mut y_max = T::neg_infinity();
214
215 for (&x_val, &y_val) in x.iter().zip(y.iter()) {
216 if x_val < x_min {
217 x_min = x_val;
218 }
219 if x_val > x_max {
220 x_max = x_val;
221 }
222 if y_val < y_min {
223 y_min = y_val;
224 }
225 if y_val > y_max {
226 y_max = y_val;
227 }
228 }
229 ((x_min, x_max), (y_min, y_max))
230 }
231 };
232
233 if x_min >= x_max || y_min >= y_max {
234 return Err("Range must be (min, max) with min < max");
235 }
236
237 // Create bin edges
238 let mut x_edges = Array::<T, Ix1>::zeros(x_bins + 1);
239 let mut y_edges = Array::<T, Ix1>::zeros(y_bins + 1);
240
241 let x_bin_width = (x_max - x_min) / T::from_usize(x_bins).unwrap();
242 let y_bin_width = (y_max - y_min) / T::from_usize(y_bins).unwrap();
243
244 for i in 0..=x_bins {
245 x_edges[i] = x_min + x_bin_width * T::from_usize(i).unwrap();
246 }
247
248 for i in 0..=y_bins {
249 y_edges[i] = y_min + y_bin_width * T::from_usize(i).unwrap();
250 }
251
252 // Ensure the last bin edges are exactly max values
253 x_edges[x_bins] = x_max;
254 y_edges[y_bins] = y_max;
255
256 // Initialize histogram array
257 let mut hist = Array::<usize, Ix2>::zeros((y_bins, x_bins));
258
259 // Fill histogram
260 match weights {
261 Some(w) => {
262 if w.len() != x.len() {
263 return Err("Weights array must have the same length as the data arrays");
264 }
265
266 for ((&x_val, &y_val), &weight) in x.iter().zip(y.iter()).zip(w.iter()) {
267 // Skip values outside the range
268 if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
269 continue;
270 }
271
272 // Find bin indices
273 let x_scaled = (x_val - x_min) / x_bin_width;
274 let y_scaled = (y_val - y_min) / y_bin_width;
275
276 let mut x_idx = x_scaled.to_usize().unwrap_or(0);
277 let mut y_idx = y_scaled.to_usize().unwrap_or(0);
278
279 // Handle edge cases where val == max_val
280 if x_val == x_max {
281 x_idx = x_bins - 1;
282 } else {
283 x_idx = x_idx.min(x_bins - 1);
284 }
285
286 if y_val == y_max {
287 y_idx = y_bins - 1;
288 } else {
289 y_idx = y_idx.min(y_bins - 1);
290 }
291
292 // Add to histogram (with weight)
293 let weight_int = weight.to_usize().unwrap_or(1);
294 hist[[y_idx, x_idx]] += weight_int;
295 }
296 }
297 None => {
298 for (&x_val, &y_val) in x.iter().zip(y.iter()) {
299 // Skip values outside the range
300 if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
301 continue;
302 }
303
304 // Find bin indices
305 let x_scaled = (x_val - x_min) / x_bin_width;
306 let y_scaled = (y_val - y_min) / y_bin_width;
307
308 let mut x_idx = x_scaled.to_usize().unwrap_or(0);
309 let mut y_idx = y_scaled.to_usize().unwrap_or(0);
310
311 // Handle edge cases where val == max_val
312 if x_val == x_max {
313 x_idx = x_bins - 1;
314 } else {
315 x_idx = x_idx.min(x_bins - 1);
316 }
317
318 if y_val == y_max {
319 y_idx = y_bins - 1;
320 } else {
321 y_idx = y_idx.min(y_bins - 1);
322 }
323
324 // Add to histogram
325 hist[[y_idx, x_idx]] += 1;
326 }
327 }
328 }
329
330 Ok((hist, x_edges, y_edges))
331}
332
333/// Calculate the quantile values from a 1D array
334///
335/// # Arguments
336///
337/// * `array` - The input 1D array
338/// * `q` - The quantile or array of quantiles to compute (between 0 and 1)
339/// * `method` - The interpolation method to use: "linear" (default), "lower", "higher", "midpoint", or "nearest"
340///
341/// # Returns
342///
343/// An array containing the quantile values
344///
345/// # Examples
346///
347/// ```
348/// use ndarray::array;
349/// use scirs2_core::ndarray_ext::stats::quantile;
350///
351/// let data = array![1.0, 3.0, 5.0, 7.0, 9.0];
352///
353/// // Median (50th percentile)
354/// let median = quantile(data.view(), array![0.5].view(), Some("linear")).unwrap();
355/// assert_eq!(median[0], 5.0);
356///
357/// // Multiple quantiles
358/// let quartiles = quantile(data.view(), array![0.25, 0.5, 0.75].view(), None).unwrap();
359/// assert_eq!(quartiles[0], 3.0); // 25th percentile
360/// assert_eq!(quartiles[1], 5.0); // 50th percentile
361/// assert_eq!(quartiles[2], 7.0); // 75th percentile
362/// ```
363///
364/// This function is equivalent to ``NumPy``'s `np.quantile` function.
365#[allow(dead_code)]
366pub fn quantile<T>(
367 array: ArrayView<T, Ix1>,
368 q: ArrayView<T, Ix1>,
369 method: Option<&str>,
370) -> Result<Array<T, Ix1>, &'static str>
371where
372 T: Clone + Float + FromPrimitive,
373{
374 if array.is_empty() {
375 return Err("Cannot compute quantile of an empty array");
376 }
377
378 // Validate q values
379 for &val in q.iter() {
380 if val < T::from_f64(0.0).unwrap() || val > T::from_f64(1.0).unwrap() {
381 return Err("Quantile values must be between 0 and 1");
382 }
383 }
384
385 // Clone and sort the array
386 let mut sorted: Vec<T> = array.iter().copied().collect();
387 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
388
389 let n = sorted.len();
390 let mut result = Array::<T, Ix1>::zeros(q.len());
391
392 // The interpolation method to use
393 let method = method.unwrap_or("linear");
394
395 for (i, &q_val) in q.iter().enumerate() {
396 if q_val == T::from_f64(0.0).unwrap() {
397 result[i] = sorted[0];
398 continue;
399 }
400
401 if q_val == T::from_f64(1.0).unwrap() {
402 result[i] = sorted[n - 1];
403 continue;
404 }
405
406 // Calculate the position in the sorted array
407 let h = T::from_usize(n - 1).unwrap() * q_val;
408 let h_floor = h.floor();
409 let idx_low = h_floor.to_usize().unwrap_or(0).min(n - 1);
410 let idx_high = (idx_low + 1).min(n - 1);
411
412 match method {
413 "linear" => {
414 let weight = h - h_floor;
415 result[i] = sorted[idx_low] * (T::from_f64(1.0).unwrap() - weight) + sorted[idx_high] * weight;
416 }
417 "lower" => {
418 result[i] = sorted[idx_low];
419 }
420 "higher" => {
421 result[i] = sorted[idx_high];
422 }
423 "midpoint" => {
424 result[i] = (sorted[idx_low] + sorted[idx_high]) / T::from_f64(2.0).unwrap();
425 }
426 "nearest" => {
427 let weight = h - h_floor;
428 if weight < T::from_f64(0.5).unwrap() {
429 result[i] = sorted[idx_low];
430 } else {
431 result[i] = sorted[idx_high];
432 }
433 }
434 _ => return Err("Invalid interpolation method. Use 'linear', 'lower', 'higher', 'midpoint', or 'nearest'"),
435 }
436 }
437
438 Ok(result)
439}
440
441/// Count number of occurrences of each value in array of non-negative ints.
442///
443/// # Arguments
444///
445/// * `array` - Input array of non-negative integers
446/// * `minlength` - Minimum number of bins for the output array. If None, the output array length is determined by the maximum value in `array`.
447/// * `weights` - Optional weights array. If specified, must have same shape as `array`.
448///
449/// # Returns
450///
451/// An array of counts
452///
453/// # Examples
454///
455/// ```
456/// use ndarray::array;
457/// use scirs2_core::ndarray_ext::stats::bincount;
458///
459/// let data = array![1, 2, 3, 1, 2, 1, 0, 1, 3, 2];
460/// let counts = bincount(data.view(), None, None).unwrap();
461/// assert_eq!(counts.len(), 4);
462/// assert_eq!(counts[0], 1.0); // '0' occurs once
463/// assert_eq!(counts[1], 4.0); // '1' occurs four times
464/// assert_eq!(counts[2], 3.0); // '2' occurs three times
465/// assert_eq!(counts[3], 2.0); // '3' occurs twice
466/// ```
467///
468/// This function is equivalent to ``NumPy``'s `np.bincount` function.
469#[allow(dead_code)]
470pub fn bincount(
471 array: ArrayView<usize, Ix1>,
472 minlength: Option<usize>,
473 weights: Option<ArrayView<f64, Ix1>>,
474) -> Result<Array<f64, Ix1>, &'static str> {
475 if array.is_empty() {
476 return Err("Cannot compute bincount of an empty array");
477 }
478
479 // Find maximum value to determine number of bins
480 let mut max_val = 0;
481 for &val in array.iter() {
482 if val > max_val {
483 max_val = val;
484 }
485 }
486
487 // Determine length of output array
488 let length = if let Some(min_len) = minlength {
489 max_val.max(min_len - 1) + 1
490 } else {
491 max_val + 1
492 };
493
494 let mut result = Array::<f64, Ix1>::zeros(length);
495
496 match weights {
497 Some(w) => {
498 if w.len() != array.len() {
499 return Err("Weights array must have same length as input array");
500 }
501 for (&idx, &weight) in array.iter().zip(w.iter()) {
502 result[idx] += weight;
503 }
504 }
505 None => {
506 for &idx in array.iter() {
507 result[idx] += 1.0;
508 }
509 }
510 }
511
512 Ok(result)
513}
514
515/// Return the indices of the bins to which each value in input array belongs.
516///
517/// # Arguments
518///
519/// * `array` - Input array
520/// * `bins` - Array of bin edges
521/// * `right` - Indicates whether the intervals include the right or left bin edge
522/// * `result_type` - Whether to return the indices ('indices') or the bin values ('values')
523///
524/// # Returns
525///
526/// Array of indices or values depending on result_type
527///
528/// # Examples
529///
530/// ```
531/// use ndarray::array;
532/// use scirs2_core::ndarray_ext::stats::digitize;
533///
534/// let data = array![1.2, 3.5, 5.1, 0.8, 2.9];
535/// let bins = array![1.0, 3.0, 5.0];
536/// let indices = digitize(data.view(), bins.view(), false, "indices").unwrap();
537///
538/// assert_eq!(indices[0], 1); // 1.2 is in the first bin (1.0 <= x < 3.0)
539/// assert_eq!(indices[1], 2); // 3.5 is in the second bin (3.0 <= x < 5.0)
540/// assert_eq!(indices[2], 3); // 5.1 is after the last bin (>= 5.0)
541/// assert_eq!(indices[3], 0); // 0.8 is before the first bin (< 1.0)
542/// assert_eq!(indices[4], 1); // 2.9 is in the first bin (1.0 <= x < 3.0)
543/// ```
544///
545/// This function is equivalent to ``NumPy``'s `np.digitize` function.
546#[allow(dead_code)]
547pub fn digitize<T>(
548 array: ArrayView<T, Ix1>,
549 bins: ArrayView<T, Ix1>,
550 right: bool,
551 result_type: &str,
552) -> Result<Array<usize, Ix1>, &'static str>
553where
554 T: Clone + Float + FromPrimitive,
555{
556 if array.is_empty() {
557 return Err("Cannot digitize an empty array");
558 }
559
560 if bins.is_empty() {
561 return Err("Bins array cannot be empty");
562 }
563
564 // Check that bins are monotonically increasing
565 for i in 1..bins.len() {
566 if bins[i] <= bins[i.saturating_sub(1)] {
567 return Err("Bins must be monotonically increasing");
568 }
569 }
570
571 let mut result = Array::<usize, Ix1>::zeros(array.len());
572
573 for (i, &val) in array.iter().enumerate() {
574 let mut bin_idx = 0;
575
576 if right {
577 // Right inclusive: val <= edge
578 for j in 0..bins.len() {
579 if val <= bins[j] {
580 bin_idx = j;
581 break;
582 }
583 bin_idx = bins.len();
584 }
585 } else {
586 // Left inclusive: val < edge
587 for j in 0..bins.len() {
588 if val < bins[j] {
589 bin_idx = j;
590 break;
591 }
592 bin_idx = bins.len();
593 }
594 }
595
596 result[i] = bin_idx;
597 }
598
599 if result_type == "indices" {
600 Ok(result)
601 } else if result_type == "values" {
602 Err("'values' result_type is not yet implemented")
603 } else {
604 Err("result_type must be 'indices' or 'values'")
605 }
606}