scirs2_core/ndarray_ext/stats/
distribution.rs

1//! Distribution-related functions for ndarray arrays
2//!
3//! This module provides functions for working with data distributions,
4//! including histograms, binning, and quantile calculations.
5
6use ndarray::{Array, ArrayView, Ix1, Ix2};
7use num_traits::{Float, FromPrimitive};
8
/// Result type for the [`histogram`] function.
///
/// On success the tuple is `(counts, bin_edges)`: `counts` has one entry per
/// bin and `bin_edges` has one more entry than `counts`. On failure the error
/// is a static message describing the rejected input.
pub type HistogramResult<T> = Result<(Array<usize, Ix1>, Array<T, Ix1>), &'static str>;

/// Result type for the [`histogram2d`] function.
///
/// On success the tuple is `(counts, x_edges, y_edges)`. `counts` is a 2D
/// array with shape `(y_bins, x_bins)`; each edge array has one more entry
/// than the number of bins along its axis. On failure the error is a static
/// message describing the rejected input.
pub type Histogram2dResult<T> =
    Result<(Array<usize, Ix2>, Array<T, Ix1>, Array<T, Ix1>), &'static str>;
15
16/// Calculate a histogram of data
17///
18/// # Arguments
19///
20/// * `array` - The input 1D array
21/// * `bins` - The number of bins
22/// * `range` - Optional tuple (min, max) to use. If None, the range is based on data
23/// * `weights` - Optional array of weights for each data point
24///
25/// # Returns
26///
27/// A tuple containing (histogram, bin_edges)
28///
29/// # Examples
30///
31/// ```
32/// use ndarray::array;
33/// use scirs2_core::ndarray_ext::stats::histogram;
34///
35/// let data = array![0.1, 0.5, 1.1, 1.5, 2.2, 2.9, 3.1, 3.8, 4.1, 4.9];
36/// let (hist, bin_edges) = histogram(data.view(), 5, None, None).unwrap();
37///
38/// assert_eq!(hist.len(), 5);
39/// assert_eq!(bin_edges.len(), 6);
40/// ```
41#[allow(dead_code)]
42pub fn histogram<T>(
43    array: ArrayView<T, Ix1>,
44    bins: usize,
45    range: Option<(T, T)>,
46    weights: Option<ArrayView<T, Ix1>>,
47) -> HistogramResult<T>
48where
49    T: Clone + Float + FromPrimitive,
50{
51    if array.is_empty() {
52        return Err("Cannot compute histogram of an empty array");
53    }
54
55    if bins == 0 {
56        return Err("Number of bins must be positive");
57    }
58
59    // Get range (min, max) of the data
60    let (min_val, max_val) = match range {
61        Some(r) => r,
62        None => {
63            let mut min_val = T::infinity();
64            let mut max_val = T::neg_infinity();
65
66            for &val in array.iter() {
67                if val < min_val {
68                    min_val = val;
69                }
70                if val > max_val {
71                    max_val = val;
72                }
73            }
74            (min_val, max_val)
75        }
76    };
77
78    if min_val >= max_val {
79        return Err("Range must be (min, max) with min < max");
80    }
81
82    // Create bin edges
83    let mut bin_edges = Array::<T, Ix1>::zeros(bins + 1);
84    let bin_width = (max_val - min_val) / T::from_usize(bins).unwrap();
85
86    for i in 0..=bins {
87        bin_edges[i] = min_val + bin_width * T::from_usize(i).unwrap();
88    }
89
90    // Ensure the last bin edge is exactly max_val
91    bin_edges[bins] = max_val;
92
93    // Initialize histogram array
94    let mut hist = Array::<usize, Ix1>::zeros(bins);
95
96    // Fill histogram
97    match weights {
98        Some(w) => {
99            if w.len() != array.len() {
100                return Err("Weights array must have the same length as the data array");
101            }
102
103            for (&val, &weight) in array.iter().zip(w.iter()) {
104                // Skip values outside the range
105                if val < min_val || val > max_val {
106                    continue;
107                }
108
109                // Handle edge case where val == max_val (include in the last bin)
110                if val == max_val {
111                    hist[bins - 1] += 1;
112                    continue;
113                }
114
115                // Find bin index
116                let scaled_val = (val - min_val) / bin_width;
117                let bin_idx = scaled_val.to_usize().unwrap_or(0);
118                let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
119
120                // Add to histogram (with weight)
121                let weight_int = weight.to_usize().unwrap_or(1);
122                hist[bin_idx] += weight_int;
123            }
124        }
125        None => {
126            for &val in array.iter() {
127                // Skip values outside the range
128                if val < min_val || val > max_val {
129                    continue;
130                }
131
132                // Handle edge case where val == max_val (include in the last bin)
133                if val == max_val {
134                    hist[bins - 1] += 1;
135                    continue;
136                }
137
138                // Find bin index
139                let scaled_val = (val - min_val) / bin_width;
140                let bin_idx = scaled_val.to_usize().unwrap_or(0);
141                let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
142
143                // Add to histogram
144                hist[bin_idx] += 1;
145            }
146        }
147    }
148
149    Ok((hist, bin_edges))
150}
151
152/// Calculate a 2D histogram of data
153///
154/// # Arguments
155///
156/// * `x` - The x coordinates of the data points
157/// * `y` - The y coordinates of the data points
158/// * `bins` - Either a tuple (x_bins, y_bins) for the number of bins, or None for 10 bins in each direction
159/// * `range` - Optional tuple ((x_min, x_max), (y_min, y_max)) to use. If None, the range is based on data
160/// * `weights` - Optional array of weights for each data point
161///
162/// # Returns
163///
164/// A tuple containing (histogram, x_edges, y_edges)
165///
166/// # Examples
167///
168/// ```
169/// use ndarray::array;
170/// use scirs2_core::ndarray_ext::stats::histogram2d;
171///
172/// let x = array![0.1, 0.5, 1.3, 2.5, 3.1, 3.8, 4.2, 4.9];
173/// let y = array![0.2, 0.8, 1.5, 2.0, 3.0, 3.2, 3.5, 4.5];
174/// let (hist, x_edges, y_edges) = histogram2d(x.view(), y.view(), Some((4, 4)), None, None).unwrap();
175///
176/// assert_eq!(hist.shape(), &[4, 4]);
177/// assert_eq!(x_edges.len(), 5);
178/// assert_eq!(y_edges.len(), 5);
179/// ```
180#[allow(dead_code)]
181pub fn histogram2d<T>(
182    x: ArrayView<T, Ix1>,
183    y: ArrayView<T, Ix1>,
184    bins: Option<(usize, usize)>,
185    range: Option<((T, T), (T, T))>,
186    weights: Option<ArrayView<T, Ix1>>,
187) -> Histogram2dResult<T>
188where
189    T: Clone + Float + FromPrimitive,
190{
191    if x.is_empty() || y.is_empty() {
192        return Err("Cannot compute histogram of empty arrays");
193    }
194
195    if x.len() != y.len() {
196        return Err("x and y arrays must have the same length");
197    }
198
199    // Default to 10 bins in each direction if not specified
200    let (x_bins, y_bins) = bins.unwrap_or((10, 10));
201
202    if x_bins == 0 || y_bins == 0 {
203        return Err("Number of bins must be positive");
204    }
205
206    // Get range for x and y
207    let ((x_min, x_max), (y_min, y_max)) = match range {
208        Some(r) => r,
209        None => {
210            let mut x_min = T::infinity();
211            let mut x_max = T::neg_infinity();
212            let mut y_min = T::infinity();
213            let mut y_max = T::neg_infinity();
214
215            for (&x_val, &y_val) in x.iter().zip(y.iter()) {
216                if x_val < x_min {
217                    x_min = x_val;
218                }
219                if x_val > x_max {
220                    x_max = x_val;
221                }
222                if y_val < y_min {
223                    y_min = y_val;
224                }
225                if y_val > y_max {
226                    y_max = y_val;
227                }
228            }
229            ((x_min, x_max), (y_min, y_max))
230        }
231    };
232
233    if x_min >= x_max || y_min >= y_max {
234        return Err("Range must be (min, max) with min < max");
235    }
236
237    // Create bin edges
238    let mut x_edges = Array::<T, Ix1>::zeros(x_bins + 1);
239    let mut y_edges = Array::<T, Ix1>::zeros(y_bins + 1);
240
241    let x_bin_width = (x_max - x_min) / T::from_usize(x_bins).unwrap();
242    let y_bin_width = (y_max - y_min) / T::from_usize(y_bins).unwrap();
243
244    for i in 0..=x_bins {
245        x_edges[i] = x_min + x_bin_width * T::from_usize(i).unwrap();
246    }
247
248    for i in 0..=y_bins {
249        y_edges[i] = y_min + y_bin_width * T::from_usize(i).unwrap();
250    }
251
252    // Ensure the last bin edges are exactly max values
253    x_edges[x_bins] = x_max;
254    y_edges[y_bins] = y_max;
255
256    // Initialize histogram array
257    let mut hist = Array::<usize, Ix2>::zeros((y_bins, x_bins));
258
259    // Fill histogram
260    match weights {
261        Some(w) => {
262            if w.len() != x.len() {
263                return Err("Weights array must have the same length as the data arrays");
264            }
265
266            for ((&x_val, &y_val), &weight) in x.iter().zip(y.iter()).zip(w.iter()) {
267                // Skip values outside the range
268                if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
269                    continue;
270                }
271
272                // Find bin indices
273                let x_scaled = (x_val - x_min) / x_bin_width;
274                let y_scaled = (y_val - y_min) / y_bin_width;
275
276                let mut x_idx = x_scaled.to_usize().unwrap_or(0);
277                let mut y_idx = y_scaled.to_usize().unwrap_or(0);
278
279                // Handle edge cases where val == max_val
280                if x_val == x_max {
281                    x_idx = x_bins - 1;
282                } else {
283                    x_idx = x_idx.min(x_bins - 1);
284                }
285
286                if y_val == y_max {
287                    y_idx = y_bins - 1;
288                } else {
289                    y_idx = y_idx.min(y_bins - 1);
290                }
291
292                // Add to histogram (with weight)
293                let weight_int = weight.to_usize().unwrap_or(1);
294                hist[[y_idx, x_idx]] += weight_int;
295            }
296        }
297        None => {
298            for (&x_val, &y_val) in x.iter().zip(y.iter()) {
299                // Skip values outside the range
300                if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
301                    continue;
302                }
303
304                // Find bin indices
305                let x_scaled = (x_val - x_min) / x_bin_width;
306                let y_scaled = (y_val - y_min) / y_bin_width;
307
308                let mut x_idx = x_scaled.to_usize().unwrap_or(0);
309                let mut y_idx = y_scaled.to_usize().unwrap_or(0);
310
311                // Handle edge cases where val == max_val
312                if x_val == x_max {
313                    x_idx = x_bins - 1;
314                } else {
315                    x_idx = x_idx.min(x_bins - 1);
316                }
317
318                if y_val == y_max {
319                    y_idx = y_bins - 1;
320                } else {
321                    y_idx = y_idx.min(y_bins - 1);
322                }
323
324                // Add to histogram
325                hist[[y_idx, x_idx]] += 1;
326            }
327        }
328    }
329
330    Ok((hist, x_edges, y_edges))
331}
332
333/// Calculate the quantile values from a 1D array
334///
335/// # Arguments
336///
337/// * `array` - The input 1D array
338/// * `q` - The quantile or array of quantiles to compute (between 0 and 1)
339/// * `method` - The interpolation method to use: "linear" (default), "lower", "higher", "midpoint", or "nearest"
340///
341/// # Returns
342///
343/// An array containing the quantile values
344///
345/// # Examples
346///
347/// ```
348/// use ndarray::array;
349/// use scirs2_core::ndarray_ext::stats::quantile;
350///
351/// let data = array![1.0, 3.0, 5.0, 7.0, 9.0];
352///
353/// // Median (50th percentile)
354/// let median = quantile(data.view(), array![0.5].view(), Some("linear")).unwrap();
355/// assert_eq!(median[0], 5.0);
356///
357/// // Multiple quantiles
358/// let quartiles = quantile(data.view(), array![0.25, 0.5, 0.75].view(), None).unwrap();
359/// assert_eq!(quartiles[0], 3.0);  // 25th percentile
360/// assert_eq!(quartiles[1], 5.0);  // 50th percentile
361/// assert_eq!(quartiles[2], 7.0);  // 75th percentile
362/// ```
363///
364/// This function is equivalent to ``NumPy``'s `np.quantile` function.
365#[allow(dead_code)]
366pub fn quantile<T>(
367    array: ArrayView<T, Ix1>,
368    q: ArrayView<T, Ix1>,
369    method: Option<&str>,
370) -> Result<Array<T, Ix1>, &'static str>
371where
372    T: Clone + Float + FromPrimitive,
373{
374    if array.is_empty() {
375        return Err("Cannot compute quantile of an empty array");
376    }
377
378    // Validate q values
379    for &val in q.iter() {
380        if val < T::from_f64(0.0).unwrap() || val > T::from_f64(1.0).unwrap() {
381            return Err("Quantile values must be between 0 and 1");
382        }
383    }
384
385    // Clone and sort the array
386    let mut sorted: Vec<T> = array.iter().copied().collect();
387    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
388
389    let n = sorted.len();
390    let mut result = Array::<T, Ix1>::zeros(q.len());
391
392    // The interpolation method to use
393    let method = method.unwrap_or("linear");
394
395    for (i, &q_val) in q.iter().enumerate() {
396        if q_val == T::from_f64(0.0).unwrap() {
397            result[i] = sorted[0];
398            continue;
399        }
400
401        if q_val == T::from_f64(1.0).unwrap() {
402            result[i] = sorted[n - 1];
403            continue;
404        }
405
406        // Calculate the position in the sorted array
407        let h = T::from_usize(n - 1).unwrap() * q_val;
408        let h_floor = h.floor();
409        let idx_low = h_floor.to_usize().unwrap_or(0).min(n - 1);
410        let idx_high = (idx_low + 1).min(n - 1);
411
412        match method {
413            "linear" => {
414                let weight = h - h_floor;
415                result[i] = sorted[idx_low] * (T::from_f64(1.0).unwrap() - weight) + sorted[idx_high] * weight;
416            }
417            "lower" => {
418                result[i] = sorted[idx_low];
419            }
420            "higher" => {
421                result[i] = sorted[idx_high];
422            }
423            "midpoint" => {
424                result[i] = (sorted[idx_low] + sorted[idx_high]) / T::from_f64(2.0).unwrap();
425            }
426            "nearest" => {
427                let weight = h - h_floor;
428                if weight < T::from_f64(0.5).unwrap() {
429                    result[i] = sorted[idx_low];
430                } else {
431                    result[i] = sorted[idx_high];
432                }
433            }
434            _ => return Err("Invalid interpolation method. Use 'linear', 'lower', 'higher', 'midpoint', or 'nearest'"),
435        }
436    }
437
438    Ok(result)
439}
440
441/// Count number of occurrences of each value in array of non-negative ints.
442///
443/// # Arguments
444///
445/// * `array` - Input array of non-negative integers
446/// * `minlength` - Minimum number of bins for the output array. If None, the output array length is determined by the maximum value in `array`.
447/// * `weights` - Optional weights array. If specified, must have same shape as `array`.
448///
449/// # Returns
450///
451/// An array of counts
452///
453/// # Examples
454///
455/// ```
456/// use ndarray::array;
457/// use scirs2_core::ndarray_ext::stats::bincount;
458///
459/// let data = array![1, 2, 3, 1, 2, 1, 0, 1, 3, 2];
460/// let counts = bincount(data.view(), None, None).unwrap();
461/// assert_eq!(counts.len(), 4);
462/// assert_eq!(counts[0], 1.0); // '0' occurs once
463/// assert_eq!(counts[1], 4.0); // '1' occurs four times
464/// assert_eq!(counts[2], 3.0); // '2' occurs three times
465/// assert_eq!(counts[3], 2.0); // '3' occurs twice
466/// ```
467///
468/// This function is equivalent to ``NumPy``'s `np.bincount` function.
469#[allow(dead_code)]
470pub fn bincount(
471    array: ArrayView<usize, Ix1>,
472    minlength: Option<usize>,
473    weights: Option<ArrayView<f64, Ix1>>,
474) -> Result<Array<f64, Ix1>, &'static str> {
475    if array.is_empty() {
476        return Err("Cannot compute bincount of an empty array");
477    }
478
479    // Find maximum value to determine number of bins
480    let mut max_val = 0;
481    for &val in array.iter() {
482        if val > max_val {
483            max_val = val;
484        }
485    }
486
487    // Determine length of output array
488    let length = if let Some(min_len) = minlength {
489        max_val.max(min_len - 1) + 1
490    } else {
491        max_val + 1
492    };
493
494    let mut result = Array::<f64, Ix1>::zeros(length);
495
496    match weights {
497        Some(w) => {
498            if w.len() != array.len() {
499                return Err("Weights array must have same length as input array");
500            }
501            for (&idx, &weight) in array.iter().zip(w.iter()) {
502                result[idx] += weight;
503            }
504        }
505        None => {
506            for &idx in array.iter() {
507                result[idx] += 1.0;
508            }
509        }
510    }
511
512    Ok(result)
513}
514
515/// Return the indices of the bins to which each value in input array belongs.
516///
517/// # Arguments
518///
519/// * `array` - Input array
520/// * `bins` - Array of bin edges
521/// * `right` - Indicates whether the intervals include the right or left bin edge
522/// * `result_type` - Whether to return the indices ('indices') or the bin values ('values')
523///
524/// # Returns
525///
526/// Array of indices or values depending on result_type
527///
528/// # Examples
529///
530/// ```
531/// use ndarray::array;
532/// use scirs2_core::ndarray_ext::stats::digitize;
533///
534/// let data = array![1.2, 3.5, 5.1, 0.8, 2.9];
535/// let bins = array![1.0, 3.0, 5.0];
536/// let indices = digitize(data.view(), bins.view(), false, "indices").unwrap();
537///
538/// assert_eq!(indices[0], 1); // 1.2 is in the first bin (1.0 <= x < 3.0)
539/// assert_eq!(indices[1], 2); // 3.5 is in the second bin (3.0 <= x < 5.0)
540/// assert_eq!(indices[2], 3); // 5.1 is after the last bin (>= 5.0)
541/// assert_eq!(indices[3], 0); // 0.8 is before the first bin (< 1.0)
542/// assert_eq!(indices[4], 1); // 2.9 is in the first bin (1.0 <= x < 3.0)
543/// ```
544///
545/// This function is equivalent to ``NumPy``'s `np.digitize` function.
546#[allow(dead_code)]
547pub fn digitize<T>(
548    array: ArrayView<T, Ix1>,
549    bins: ArrayView<T, Ix1>,
550    right: bool,
551    result_type: &str,
552) -> Result<Array<usize, Ix1>, &'static str>
553where
554    T: Clone + Float + FromPrimitive,
555{
556    if array.is_empty() {
557        return Err("Cannot digitize an empty array");
558    }
559
560    if bins.is_empty() {
561        return Err("Bins array cannot be empty");
562    }
563
564    // Check that bins are monotonically increasing
565    for i in 1..bins.len() {
566        if bins[i] <= bins[i.saturating_sub(1)] {
567            return Err("Bins must be monotonically increasing");
568        }
569    }
570
571    let mut result = Array::<usize, Ix1>::zeros(array.len());
572
573    for (i, &val) in array.iter().enumerate() {
574        let mut bin_idx = 0;
575
576        if right {
577            // Right inclusive: val <= edge
578            for j in 0..bins.len() {
579                if val <= bins[j] {
580                    bin_idx = j;
581                    break;
582                }
583                bin_idx = bins.len();
584            }
585        } else {
586            // Left inclusive: val < edge
587            for j in 0..bins.len() {
588                if val < bins[j] {
589                    bin_idx = j;
590                    break;
591                }
592                bin_idx = bins.len();
593            }
594        }
595
596        result[i] = bin_idx;
597    }
598
599    if result_type == "indices" {
600        Ok(result)
601    } else if result_type == "values" {
602        Err("'values' result_type is not yet implemented")
603    } else {
604        Err("result_type must be 'indices' or 'values'")
605    }
606}