scirs2_core/ndarray_ext/stats/
distribution.rs

1//! Distribution-related functions for ndarray arrays
2//!
3//! This module provides functions for working with data distributions,
4//! including histograms, binning, and quantile calculations.
5
6use ::ndarray::{Array, ArrayView, Ix1, Ix2};
7use num_traits::{Float, FromPrimitive};
8
/// Result type for the `histogram` function.
///
/// On success holds `(counts, bin_edges)`: one `usize` count per bin and the
/// bin edges in the element type `T`. On failure holds a static message
/// describing the rejected input.
pub type HistogramResult<T> = Result<(Array<usize, Ix1>, Array<T, Ix1>), &'static str>;

/// Result type for the `histogram2d` function.
///
/// On success holds `(counts, x_edges, y_edges)`: a 2D array of `usize`
/// counts followed by the bin edges along x and along y. On failure holds a
/// static message describing the rejected input.
pub type Histogram2dResult<T> =
    Result<(Array<usize, Ix2>, Array<T, Ix1>, Array<T, Ix1>), &'static str>;
15
16/// Calculate a histogram of data
17///
18/// # Arguments
19///
20/// * `array` - The input 1D array
21/// * `bins` - The number of bins
22/// * `range` - Optional tuple (min, max) to use. If None, the range is based on data
23/// * `weights` - Optional array of weights for each data point
24///
25/// # Returns
26///
27/// A tuple containing (histogram, bin_edges)
28///
29/// # Examples
30///
31/// ```
32/// use ::ndarray::array;
33/// use scirs2_core::ndarray_ext::stats::histogram;
34///
35/// let data = array![0.1, 0.5, 1.1, 1.5, 2.2, 2.9, 3.1, 3.8, 4.1, 4.9];
36/// let (hist, bin_edges) = histogram(data.view(), 5, None, None).expect("Operation failed");
37///
38/// assert_eq!(hist.len(), 5);
39/// assert_eq!(bin_edges.len(), 6);
40/// ```
41#[allow(dead_code)]
42pub fn histogram<T>(
43    array: ArrayView<T, Ix1>,
44    bins: usize,
45    range: Option<(T, T)>,
46    weights: Option<ArrayView<T, Ix1>>,
47) -> HistogramResult<T>
48where
49    T: Clone + Float + FromPrimitive,
50{
51    if array.is_empty() {
52        return Err("Cannot compute histogram of an empty array");
53    }
54
55    if bins == 0 {
56        return Err("Number of bins must be positive");
57    }
58
59    // Get range (min, max) of the data
60    let (min_val, max_val) = match range {
61        Some(r) => r,
62        None => {
63            let mut min_val = T::infinity();
64            let mut max_val = T::neg_infinity();
65
66            for &val in array.iter() {
67                if val < min_val {
68                    min_val = val;
69                }
70                if val > max_val {
71                    max_val = val;
72                }
73            }
74            (min_val, max_val)
75        }
76    };
77
78    if min_val >= max_val {
79        return Err("Range must be (min, max) with min < max");
80    }
81
82    // Create bin edges
83    let mut bin_edges = Array::<T, Ix1>::zeros(bins + 1);
84    let bin_width = (max_val - min_val) / T::from_usize(bins).expect("Operation failed");
85
86    for i in 0..=bins {
87        bin_edges[i] = min_val + bin_width * T::from_usize(i).expect("Operation failed");
88    }
89
90    // Ensure the last bin edge is exactly max_val
91    bin_edges[bins] = max_val;
92
93    // Initialize histogram array
94    let mut hist = Array::<usize, Ix1>::zeros(bins);
95
96    // Fill histogram
97    match weights {
98        Some(w) => {
99            if w.len() != array.len() {
100                return Err("Weights array must have the same length as the data array");
101            }
102
103            for (&val, &weight) in array.iter().zip(w.iter()) {
104                // Skip values outside the range
105                if val < min_val || val > max_val {
106                    continue;
107                }
108
109                // Handle edge case where val == max_val (include in the last bin)
110                if val == max_val {
111                    hist[bins - 1] += 1;
112                    continue;
113                }
114
115                // Find bin index
116                let scaled_val = (val - min_val) / bin_width;
117                let bin_idx = scaled_val.to_usize().unwrap_or(0);
118                let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
119
120                // Add to histogram (with weight)
121                let weight_int = weight.to_usize().unwrap_or(1);
122                hist[bin_idx] += weight_int;
123            }
124        }
125        None => {
126            for &val in array.iter() {
127                // Skip values outside the range
128                if val < min_val || val > max_val {
129                    continue;
130                }
131
132                // Handle edge case where val == max_val (include in the last bin)
133                if val == max_val {
134                    hist[bins - 1] += 1;
135                    continue;
136                }
137
138                // Find bin index
139                let scaled_val = (val - min_val) / bin_width;
140                let bin_idx = scaled_val.to_usize().unwrap_or(0);
141                let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
142
143                // Add to histogram
144                hist[bin_idx] += 1;
145            }
146        }
147    }
148
149    Ok((hist, bin_edges))
150}
151
152/// Calculate a 2D histogram of data
153///
154/// # Arguments
155///
156/// * `x` - The x coordinates of the data points
157/// * `y` - The y coordinates of the data points
158/// * `bins` - Either a tuple (x_bins, y_bins) for the number of bins, or None for 10 bins in each direction
159/// * `range` - Optional tuple ((x_min, x_max), (y_min, y_max)) to use. If None, the range is based on data
160/// * `weights` - Optional array of weights for each data point
161///
162/// # Returns
163///
164/// A tuple containing (histogram, x_edges, y_edges)
165///
166/// # Examples
167///
168/// ```
169/// use ::ndarray::array;
170/// use scirs2_core::ndarray_ext::stats::histogram2d;
171///
172/// let x = array![0.1, 0.5, 1.3, 2.5, 3.1, 3.8, 4.2, 4.9];
173/// let y = array![0.2, 0.8, 1.5, 2.0, 3.0, 3.2, 3.5, 4.5];
174/// let (hist, x_edges, y_edges) = histogram2d(x.view(), y.view(), Some((4, 4)), None, None).expect("Operation failed");
175///
176/// assert_eq!(hist.shape(), &[4, 4]);
177/// assert_eq!(x_edges.len(), 5);
178/// assert_eq!(y_edges.len(), 5);
179/// ```
180#[allow(dead_code)]
181pub fn histogram2d<T>(
182    x: ArrayView<T, Ix1>,
183    y: ArrayView<T, Ix1>,
184    bins: Option<(usize, usize)>,
185    range: Option<((T, T), (T, T))>,
186    weights: Option<ArrayView<T, Ix1>>,
187) -> Histogram2dResult<T>
188where
189    T: Clone + Float + FromPrimitive,
190{
191    if x.is_empty() || y.is_empty() {
192        return Err("Cannot compute histogram of empty arrays");
193    }
194
195    if x.len() != y.len() {
196        return Err("x and y arrays must have the same length");
197    }
198
199    // Default to 10 bins in each direction if not specified
200    let (x_bins, y_bins) = bins.unwrap_or((10, 10));
201
202    if x_bins == 0 || y_bins == 0 {
203        return Err("Number of bins must be positive");
204    }
205
206    // Get range for x and y
207    let ((x_min, x_max), (y_min, y_max)) = match range {
208        Some(r) => r,
209        None => {
210            let mut x_min = T::infinity();
211            let mut x_max = T::neg_infinity();
212            let mut y_min = T::infinity();
213            let mut y_max = T::neg_infinity();
214
215            for (&x_val, &y_val) in x.iter().zip(y.iter()) {
216                if x_val < x_min {
217                    x_min = x_val;
218                }
219                if x_val > x_max {
220                    x_max = x_val;
221                }
222                if y_val < y_min {
223                    y_min = y_val;
224                }
225                if y_val > y_max {
226                    y_max = y_val;
227                }
228            }
229            ((x_min, x_max), (y_min, y_max))
230        }
231    };
232
233    if x_min >= x_max || y_min >= y_max {
234        return Err("Range must be (min, max) with min < max");
235    }
236
237    // Create bin edges
238    let mut x_edges = Array::<T, Ix1>::zeros(x_bins + 1);
239    let mut y_edges = Array::<T, Ix1>::zeros(y_bins + 1);
240
241    let x_bin_width = (x_max - x_min) / T::from_usize(x_bins).expect("Operation failed");
242    let y_bin_width = (y_max - y_min) / T::from_usize(y_bins).expect("Operation failed");
243
244    for i in 0..=x_bins {
245        x_edges[i] = x_min + x_bin_width * T::from_usize(i).expect("Operation failed");
246    }
247
248    for i in 0..=y_bins {
249        y_edges[i] = y_min + y_bin_width * T::from_usize(i).expect("Operation failed");
250    }
251
252    // Ensure the last bin edges are exactly max values
253    x_edges[x_bins] = x_max;
254    y_edges[y_bins] = y_max;
255
256    // Initialize histogram array
257    let mut hist = Array::<usize, Ix2>::zeros((y_bins, x_bins));
258
259    // Fill histogram
260    match weights {
261        Some(w) => {
262            if w.len() != x.len() {
263                return Err("Weights array must have the same length as the data arrays");
264            }
265
266            for ((&x_val, &y_val), &weight) in x.iter().zip(y.iter()).zip(w.iter()) {
267                // Skip values outside the range
268                if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
269                    continue;
270                }
271
272                // Find bin indices
273                let x_scaled = (x_val - x_min) / x_bin_width;
274                let y_scaled = (y_val - y_min) / y_bin_width;
275
276                let mut x_idx = x_scaled.to_usize().unwrap_or(0);
277                let mut y_idx = y_scaled.to_usize().unwrap_or(0);
278
279                // Handle edge cases where val == max_val
280                if x_val == x_max {
281                    x_idx = x_bins - 1;
282                } else {
283                    x_idx = x_idx.min(x_bins - 1);
284                }
285
286                if y_val == y_max {
287                    y_idx = y_bins - 1;
288                } else {
289                    y_idx = y_idx.min(y_bins - 1);
290                }
291
292                // Add to histogram (with weight)
293                let weight_int = weight.to_usize().unwrap_or(1);
294                hist[[y_idx, x_idx]] += weight_int;
295            }
296        }
297        None => {
298            for (&x_val, &y_val) in x.iter().zip(y.iter()) {
299                // Skip values outside the range
300                if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
301                    continue;
302                }
303
304                // Find bin indices
305                let x_scaled = (x_val - x_min) / x_bin_width;
306                let y_scaled = (y_val - y_min) / y_bin_width;
307
308                let mut x_idx = x_scaled.to_usize().unwrap_or(0);
309                let mut y_idx = y_scaled.to_usize().unwrap_or(0);
310
311                // Handle edge cases where val == max_val
312                if x_val == x_max {
313                    x_idx = x_bins - 1;
314                } else {
315                    x_idx = x_idx.min(x_bins - 1);
316                }
317
318                if y_val == y_max {
319                    y_idx = y_bins - 1;
320                } else {
321                    y_idx = y_idx.min(y_bins - 1);
322                }
323
324                // Add to histogram
325                hist[[y_idx, x_idx]] += 1;
326            }
327        }
328    }
329
330    Ok((hist, x_edges, y_edges))
331}
332
333/// Calculate the quantile values from a 1D array
334///
335/// # Arguments
336///
337/// * `array` - The input 1D array
338/// * `q` - The quantile or array of quantiles to compute (between 0 and 1)
339/// * `method` - The interpolation method to use: "linear" (default), "lower", "higher", "midpoint", or "nearest"
340///
341/// # Returns
342///
343/// An array containing the quantile values
344///
345/// # Examples
346///
347/// ```
348/// use ::ndarray::array;
349/// use scirs2_core::ndarray_ext::stats::quantile;
350///
351/// let data = array![1.0, 3.0, 5.0, 7.0, 9.0];
352///
353/// // Median (50th percentile)
354/// let median = quantile(data.view(), array![0.5].view(), Some("linear")).expect("Operation failed");
355/// assert_eq!(median[0], 5.0);
356///
357/// // Multiple quantiles
358/// let quartiles = quantile(data.view(), array![0.25, 0.5, 0.75].view(), None).expect("Operation failed");
359/// assert_eq!(quartiles[0], 3.0);  // 25th percentile
360/// assert_eq!(quartiles[1], 5.0);  // 50th percentile
361/// assert_eq!(quartiles[2], 7.0);  // 75th percentile
362/// ```
363///
364/// This function is equivalent to ``NumPy``'s `np.quantile` function.
365#[allow(dead_code)]
366pub fn quantile<T>(
367    array: ArrayView<T, Ix1>,
368    q: ArrayView<T, Ix1>,
369    method: Option<&str>,
370) -> Result<Array<T, Ix1>, &'static str>
371where
372    T: Clone + Float + FromPrimitive,
373{
374    if array.is_empty() {
375        return Err("Cannot compute quantile of an empty array");
376    }
377
378    // Validate q values
379    for &val in q.iter() {
380        if val < T::from_f64(0.0).expect("Operation failed")
381            || val > T::from_f64(1.0).expect("Operation failed")
382        {
383            return Err("Quantile values must be between 0 and 1");
384        }
385    }
386
387    // Clone and sort the array
388    let mut sorted: Vec<T> = array.iter().copied().collect();
389    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
390
391    let n = sorted.len();
392    let mut result = Array::<T, Ix1>::zeros(q.len());
393
394    // The interpolation method to use
395    let method = method.unwrap_or("linear");
396
397    for (i, &q_val) in q.iter().enumerate() {
398        if q_val == T::from_f64(0.0).expect("Operation failed") {
399            result[i] = sorted[0];
400            continue;
401        }
402
403        if q_val == T::from_f64(1.0).expect("Operation failed") {
404            result[i] = sorted[n - 1];
405            continue;
406        }
407
408        // Calculate the position in the sorted array
409        let h = T::from_usize(n - 1).expect("Operation failed") * q_val;
410        let h_floor = h.floor();
411        let idx_low = h_floor.to_usize().unwrap_or(0).min(n - 1);
412        let idx_high = (idx_low + 1).min(n - 1);
413
414        match method {
415            "linear" => {
416                let weight = h - h_floor;
417                result[i] = sorted[idx_low] * (T::from_f64(1.0).expect("Operation failed") - weight) + sorted[idx_high] * weight;
418            }
419            "lower" => {
420                result[i] = sorted[idx_low];
421            }
422            "higher" => {
423                result[i] = sorted[idx_high];
424            }
425            "midpoint" => {
426                result[i] = (sorted[idx_low] + sorted[idx_high]) / T::from_f64(2.0).expect("Operation failed");
427            }
428            "nearest" => {
429                let weight = h - h_floor;
430                if weight < T::from_f64(0.5).expect("Operation failed") {
431                    result[i] = sorted[idx_low];
432                } else {
433                    result[i] = sorted[idx_high];
434                }
435            }
436            _ => return Err("Invalid interpolation method. Use 'linear', 'lower', 'higher', 'midpoint', or 'nearest'"),
437        }
438    }
439
440    Ok(result)
441}
442
443/// Count number of occurrences of each value in array of non-negative ints.
444///
445/// # Arguments
446///
447/// * `array` - Input array of non-negative integers
448/// * `minlength` - Minimum number of bins for the output array. If None, the output array length is determined by the maximum value in `array`.
449/// * `weights` - Optional weights array. If specified, must have same shape as `array`.
450///
451/// # Returns
452///
453/// An array of counts
454///
455/// # Examples
456///
457/// ```
458/// use ::ndarray::array;
459/// use scirs2_core::ndarray_ext::stats::bincount;
460///
461/// let data = array![1, 2, 3, 1, 2, 1, 0, 1, 3, 2];
462/// let counts = bincount(data.view(), None, None).expect("Operation failed");
463/// assert_eq!(counts.len(), 4);
464/// assert_eq!(counts[0], 1.0); // '0' occurs once
465/// assert_eq!(counts[1], 4.0); // '1' occurs four times
466/// assert_eq!(counts[2], 3.0); // '2' occurs three times
467/// assert_eq!(counts[3], 2.0); // '3' occurs twice
468/// ```
469///
470/// This function is equivalent to ``NumPy``'s `np.bincount` function.
471#[allow(dead_code)]
472pub fn bincount(
473    array: ArrayView<usize, Ix1>,
474    minlength: Option<usize>,
475    weights: Option<ArrayView<f64, Ix1>>,
476) -> Result<Array<f64, Ix1>, &'static str> {
477    if array.is_empty() {
478        return Err("Cannot compute bincount of an empty array");
479    }
480
481    // Find maximum value to determine number of bins
482    let mut max_val = 0;
483    for &val in array.iter() {
484        if val > max_val {
485            max_val = val;
486        }
487    }
488
489    // Determine length of output array
490    let length = if let Some(min_len) = minlength {
491        max_val.max(min_len - 1) + 1
492    } else {
493        max_val + 1
494    };
495
496    let mut result = Array::<f64, Ix1>::zeros(length);
497
498    match weights {
499        Some(w) => {
500            if w.len() != array.len() {
501                return Err("Weights array must have same length as input array");
502            }
503            for (&idx, &weight) in array.iter().zip(w.iter()) {
504                result[idx] += weight;
505            }
506        }
507        None => {
508            for &idx in array.iter() {
509                result[idx] += 1.0;
510            }
511        }
512    }
513
514    Ok(result)
515}
516
517/// Return the indices of the bins to which each value in input array belongs.
518///
519/// # Arguments
520///
521/// * `array` - Input array
522/// * `bins` - Array of bin edges
523/// * `right` - Indicates whether the intervals include the right or left bin edge
524/// * `result_type` - Whether to return the indices ('indices') or the bin values ('values')
525///
526/// # Returns
527///
528/// Array of indices or values depending on result_type
529///
530/// # Examples
531///
532/// ```
533/// use ::ndarray::array;
534/// use scirs2_core::ndarray_ext::stats::digitize;
535///
536/// let data = array![1.2, 3.5, 5.1, 0.8, 2.9];
537/// let bins = array![1.0, 3.0, 5.0];
538/// let indices = digitize(data.view(), bins.view(), false, "indices").expect("Operation failed");
539///
540/// assert_eq!(indices[0], 1); // 1.2 is in the first bin (1.0 <= x < 3.0)
541/// assert_eq!(indices[1], 2); // 3.5 is in the second bin (3.0 <= x < 5.0)
542/// assert_eq!(indices[2], 3); // 5.1 is after the last bin (>= 5.0)
543/// assert_eq!(indices[3], 0); // 0.8 is before the first bin (< 1.0)
544/// assert_eq!(indices[4], 1); // 2.9 is in the first bin (1.0 <= x < 3.0)
545/// ```
546///
547/// This function is equivalent to ``NumPy``'s `np.digitize` function.
548#[allow(dead_code)]
549pub fn digitize<T>(
550    array: ArrayView<T, Ix1>,
551    bins: ArrayView<T, Ix1>,
552    right: bool,
553    result_type: &str,
554) -> Result<Array<usize, Ix1>, &'static str>
555where
556    T: Clone + Float + FromPrimitive,
557{
558    if array.is_empty() {
559        return Err("Cannot digitize an empty array");
560    }
561
562    if bins.is_empty() {
563        return Err("Bins array cannot be empty");
564    }
565
566    // Check that bins are monotonically increasing
567    for i in 1..bins.len() {
568        if bins[i] <= bins[i.saturating_sub(1)] {
569            return Err("Bins must be monotonically increasing");
570        }
571    }
572
573    let mut result = Array::<usize, Ix1>::zeros(array.len());
574
575    for (i, &val) in array.iter().enumerate() {
576        let mut bin_idx = 0;
577
578        if right {
579            // Right inclusive: val <= edge
580            for j in 0..bins.len() {
581                if val <= bins[j] {
582                    bin_idx = j;
583                    break;
584                }
585                bin_idx = bins.len();
586            }
587        } else {
588            // Left inclusive: val < edge
589            for j in 0..bins.len() {
590                if val < bins[j] {
591                    bin_idx = j;
592                    break;
593                }
594                bin_idx = bins.len();
595            }
596        }
597
598        result[i] = bin_idx;
599    }
600
601    if result_type == "indices" {
602        Ok(result)
603    } else if result_type == "values" {
604        Err("'values' result_type is not yet implemented")
605    } else {
606        Err("result_type must be 'indices' or 'values'")
607    }
608}