scirs2_core/ndarray_ext/stats/distribution.rs
1//! Distribution-related functions for ndarray arrays
2//!
3//! This module provides functions for working with data distributions,
4//! including histograms, binning, and quantile calculations.
5
6use ::ndarray::{Array, ArrayView, Ix1, Ix2};
7use num_traits::{Float, FromPrimitive};
8
9/// Result type for histogram function
10pub type HistogramResult<T> = Result<(Array<usize, Ix1>, Array<T, Ix1>), &'static str>;
11
12/// Result type for histogram2d function
13pub type Histogram2dResult<T> =
14 Result<(Array<usize, Ix2>, Array<T, Ix1>, Array<T, Ix1>), &'static str>;
15
16/// Calculate a histogram of data
17///
18/// # Arguments
19///
20/// * `array` - The input 1D array
21/// * `bins` - The number of bins
22/// * `range` - Optional tuple (min, max) to use. If None, the range is based on data
23/// * `weights` - Optional array of weights for each data point
24///
25/// # Returns
26///
27/// A tuple containing (histogram, bin_edges)
28///
29/// # Examples
30///
31/// ```
32/// use ::ndarray::array;
33/// use scirs2_core::ndarray_ext::stats::histogram;
34///
35/// let data = array![0.1, 0.5, 1.1, 1.5, 2.2, 2.9, 3.1, 3.8, 4.1, 4.9];
36/// let (hist, bin_edges) = histogram(data.view(), 5, None, None).expect("Operation failed");
37///
38/// assert_eq!(hist.len(), 5);
39/// assert_eq!(bin_edges.len(), 6);
40/// ```
41#[allow(dead_code)]
42pub fn histogram<T>(
43 array: ArrayView<T, Ix1>,
44 bins: usize,
45 range: Option<(T, T)>,
46 weights: Option<ArrayView<T, Ix1>>,
47) -> HistogramResult<T>
48where
49 T: Clone + Float + FromPrimitive,
50{
51 if array.is_empty() {
52 return Err("Cannot compute histogram of an empty array");
53 }
54
55 if bins == 0 {
56 return Err("Number of bins must be positive");
57 }
58
59 // Get range (min, max) of the data
60 let (min_val, max_val) = match range {
61 Some(r) => r,
62 None => {
63 let mut min_val = T::infinity();
64 let mut max_val = T::neg_infinity();
65
66 for &val in array.iter() {
67 if val < min_val {
68 min_val = val;
69 }
70 if val > max_val {
71 max_val = val;
72 }
73 }
74 (min_val, max_val)
75 }
76 };
77
78 if min_val >= max_val {
79 return Err("Range must be (min, max) with min < max");
80 }
81
82 // Create bin edges
83 let mut bin_edges = Array::<T, Ix1>::zeros(bins + 1);
84 let bin_width = (max_val - min_val) / T::from_usize(bins).expect("Operation failed");
85
86 for i in 0..=bins {
87 bin_edges[i] = min_val + bin_width * T::from_usize(i).expect("Operation failed");
88 }
89
90 // Ensure the last bin edge is exactly max_val
91 bin_edges[bins] = max_val;
92
93 // Initialize histogram array
94 let mut hist = Array::<usize, Ix1>::zeros(bins);
95
96 // Fill histogram
97 match weights {
98 Some(w) => {
99 if w.len() != array.len() {
100 return Err("Weights array must have the same length as the data array");
101 }
102
103 for (&val, &weight) in array.iter().zip(w.iter()) {
104 // Skip values outside the range
105 if val < min_val || val > max_val {
106 continue;
107 }
108
109 // Handle edge case where val == max_val (include in the last bin)
110 if val == max_val {
111 hist[bins - 1] += 1;
112 continue;
113 }
114
115 // Find bin index
116 let scaled_val = (val - min_val) / bin_width;
117 let bin_idx = scaled_val.to_usize().unwrap_or(0);
118 let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
119
120 // Add to histogram (with weight)
121 let weight_int = weight.to_usize().unwrap_or(1);
122 hist[bin_idx] += weight_int;
123 }
124 }
125 None => {
126 for &val in array.iter() {
127 // Skip values outside the range
128 if val < min_val || val > max_val {
129 continue;
130 }
131
132 // Handle edge case where val == max_val (include in the last bin)
133 if val == max_val {
134 hist[bins - 1] += 1;
135 continue;
136 }
137
138 // Find bin index
139 let scaled_val = (val - min_val) / bin_width;
140 let bin_idx = scaled_val.to_usize().unwrap_or(0);
141 let bin_idx = bin_idx.min(bins - 1); // Ensure index is in bounds
142
143 // Add to histogram
144 hist[bin_idx] += 1;
145 }
146 }
147 }
148
149 Ok((hist, bin_edges))
150}
151
152/// Calculate a 2D histogram of data
153///
154/// # Arguments
155///
156/// * `x` - The x coordinates of the data points
157/// * `y` - The y coordinates of the data points
158/// * `bins` - Either a tuple (x_bins, y_bins) for the number of bins, or None for 10 bins in each direction
159/// * `range` - Optional tuple ((x_min, x_max), (y_min, y_max)) to use. If None, the range is based on data
160/// * `weights` - Optional array of weights for each data point
161///
162/// # Returns
163///
164/// A tuple containing (histogram, x_edges, y_edges)
165///
166/// # Examples
167///
168/// ```
169/// use ::ndarray::array;
170/// use scirs2_core::ndarray_ext::stats::histogram2d;
171///
172/// let x = array![0.1, 0.5, 1.3, 2.5, 3.1, 3.8, 4.2, 4.9];
173/// let y = array![0.2, 0.8, 1.5, 2.0, 3.0, 3.2, 3.5, 4.5];
174/// let (hist, x_edges, y_edges) = histogram2d(x.view(), y.view(), Some((4, 4)), None, None).expect("Operation failed");
175///
176/// assert_eq!(hist.shape(), &[4, 4]);
177/// assert_eq!(x_edges.len(), 5);
178/// assert_eq!(y_edges.len(), 5);
179/// ```
180#[allow(dead_code)]
181pub fn histogram2d<T>(
182 x: ArrayView<T, Ix1>,
183 y: ArrayView<T, Ix1>,
184 bins: Option<(usize, usize)>,
185 range: Option<((T, T), (T, T))>,
186 weights: Option<ArrayView<T, Ix1>>,
187) -> Histogram2dResult<T>
188where
189 T: Clone + Float + FromPrimitive,
190{
191 if x.is_empty() || y.is_empty() {
192 return Err("Cannot compute histogram of empty arrays");
193 }
194
195 if x.len() != y.len() {
196 return Err("x and y arrays must have the same length");
197 }
198
199 // Default to 10 bins in each direction if not specified
200 let (x_bins, y_bins) = bins.unwrap_or((10, 10));
201
202 if x_bins == 0 || y_bins == 0 {
203 return Err("Number of bins must be positive");
204 }
205
206 // Get range for x and y
207 let ((x_min, x_max), (y_min, y_max)) = match range {
208 Some(r) => r,
209 None => {
210 let mut x_min = T::infinity();
211 let mut x_max = T::neg_infinity();
212 let mut y_min = T::infinity();
213 let mut y_max = T::neg_infinity();
214
215 for (&x_val, &y_val) in x.iter().zip(y.iter()) {
216 if x_val < x_min {
217 x_min = x_val;
218 }
219 if x_val > x_max {
220 x_max = x_val;
221 }
222 if y_val < y_min {
223 y_min = y_val;
224 }
225 if y_val > y_max {
226 y_max = y_val;
227 }
228 }
229 ((x_min, x_max), (y_min, y_max))
230 }
231 };
232
233 if x_min >= x_max || y_min >= y_max {
234 return Err("Range must be (min, max) with min < max");
235 }
236
237 // Create bin edges
238 let mut x_edges = Array::<T, Ix1>::zeros(x_bins + 1);
239 let mut y_edges = Array::<T, Ix1>::zeros(y_bins + 1);
240
241 let x_bin_width = (x_max - x_min) / T::from_usize(x_bins).expect("Operation failed");
242 let y_bin_width = (y_max - y_min) / T::from_usize(y_bins).expect("Operation failed");
243
244 for i in 0..=x_bins {
245 x_edges[i] = x_min + x_bin_width * T::from_usize(i).expect("Operation failed");
246 }
247
248 for i in 0..=y_bins {
249 y_edges[i] = y_min + y_bin_width * T::from_usize(i).expect("Operation failed");
250 }
251
252 // Ensure the last bin edges are exactly max values
253 x_edges[x_bins] = x_max;
254 y_edges[y_bins] = y_max;
255
256 // Initialize histogram array
257 let mut hist = Array::<usize, Ix2>::zeros((y_bins, x_bins));
258
259 // Fill histogram
260 match weights {
261 Some(w) => {
262 if w.len() != x.len() {
263 return Err("Weights array must have the same length as the data arrays");
264 }
265
266 for ((&x_val, &y_val), &weight) in x.iter().zip(y.iter()).zip(w.iter()) {
267 // Skip values outside the range
268 if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
269 continue;
270 }
271
272 // Find bin indices
273 let x_scaled = (x_val - x_min) / x_bin_width;
274 let y_scaled = (y_val - y_min) / y_bin_width;
275
276 let mut x_idx = x_scaled.to_usize().unwrap_or(0);
277 let mut y_idx = y_scaled.to_usize().unwrap_or(0);
278
279 // Handle edge cases where val == max_val
280 if x_val == x_max {
281 x_idx = x_bins - 1;
282 } else {
283 x_idx = x_idx.min(x_bins - 1);
284 }
285
286 if y_val == y_max {
287 y_idx = y_bins - 1;
288 } else {
289 y_idx = y_idx.min(y_bins - 1);
290 }
291
292 // Add to histogram (with weight)
293 let weight_int = weight.to_usize().unwrap_or(1);
294 hist[[y_idx, x_idx]] += weight_int;
295 }
296 }
297 None => {
298 for (&x_val, &y_val) in x.iter().zip(y.iter()) {
299 // Skip values outside the range
300 if x_val < x_min || x_val > x_max || y_val < y_min || y_val > y_max {
301 continue;
302 }
303
304 // Find bin indices
305 let x_scaled = (x_val - x_min) / x_bin_width;
306 let y_scaled = (y_val - y_min) / y_bin_width;
307
308 let mut x_idx = x_scaled.to_usize().unwrap_or(0);
309 let mut y_idx = y_scaled.to_usize().unwrap_or(0);
310
311 // Handle edge cases where val == max_val
312 if x_val == x_max {
313 x_idx = x_bins - 1;
314 } else {
315 x_idx = x_idx.min(x_bins - 1);
316 }
317
318 if y_val == y_max {
319 y_idx = y_bins - 1;
320 } else {
321 y_idx = y_idx.min(y_bins - 1);
322 }
323
324 // Add to histogram
325 hist[[y_idx, x_idx]] += 1;
326 }
327 }
328 }
329
330 Ok((hist, x_edges, y_edges))
331}
332
333/// Calculate the quantile values from a 1D array
334///
335/// # Arguments
336///
337/// * `array` - The input 1D array
338/// * `q` - The quantile or array of quantiles to compute (between 0 and 1)
339/// * `method` - The interpolation method to use: "linear" (default), "lower", "higher", "midpoint", or "nearest"
340///
341/// # Returns
342///
343/// An array containing the quantile values
344///
345/// # Examples
346///
347/// ```
348/// use ::ndarray::array;
349/// use scirs2_core::ndarray_ext::stats::quantile;
350///
351/// let data = array![1.0, 3.0, 5.0, 7.0, 9.0];
352///
353/// // Median (50th percentile)
354/// let median = quantile(data.view(), array![0.5].view(), Some("linear")).expect("Operation failed");
355/// assert_eq!(median[0], 5.0);
356///
357/// // Multiple quantiles
358/// let quartiles = quantile(data.view(), array![0.25, 0.5, 0.75].view(), None).expect("Operation failed");
359/// assert_eq!(quartiles[0], 3.0); // 25th percentile
360/// assert_eq!(quartiles[1], 5.0); // 50th percentile
361/// assert_eq!(quartiles[2], 7.0); // 75th percentile
362/// ```
363///
364/// This function is equivalent to ``NumPy``'s `np.quantile` function.
365#[allow(dead_code)]
366pub fn quantile<T>(
367 array: ArrayView<T, Ix1>,
368 q: ArrayView<T, Ix1>,
369 method: Option<&str>,
370) -> Result<Array<T, Ix1>, &'static str>
371where
372 T: Clone + Float + FromPrimitive,
373{
374 if array.is_empty() {
375 return Err("Cannot compute quantile of an empty array");
376 }
377
378 // Validate q values
379 for &val in q.iter() {
380 if val < T::from_f64(0.0).expect("Operation failed")
381 || val > T::from_f64(1.0).expect("Operation failed")
382 {
383 return Err("Quantile values must be between 0 and 1");
384 }
385 }
386
387 // Clone and sort the array
388 let mut sorted: Vec<T> = array.iter().copied().collect();
389 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
390
391 let n = sorted.len();
392 let mut result = Array::<T, Ix1>::zeros(q.len());
393
394 // The interpolation method to use
395 let method = method.unwrap_or("linear");
396
397 for (i, &q_val) in q.iter().enumerate() {
398 if q_val == T::from_f64(0.0).expect("Operation failed") {
399 result[i] = sorted[0];
400 continue;
401 }
402
403 if q_val == T::from_f64(1.0).expect("Operation failed") {
404 result[i] = sorted[n - 1];
405 continue;
406 }
407
408 // Calculate the position in the sorted array
409 let h = T::from_usize(n - 1).expect("Operation failed") * q_val;
410 let h_floor = h.floor();
411 let idx_low = h_floor.to_usize().unwrap_or(0).min(n - 1);
412 let idx_high = (idx_low + 1).min(n - 1);
413
414 match method {
415 "linear" => {
416 let weight = h - h_floor;
417 result[i] = sorted[idx_low] * (T::from_f64(1.0).expect("Operation failed") - weight) + sorted[idx_high] * weight;
418 }
419 "lower" => {
420 result[i] = sorted[idx_low];
421 }
422 "higher" => {
423 result[i] = sorted[idx_high];
424 }
425 "midpoint" => {
426 result[i] = (sorted[idx_low] + sorted[idx_high]) / T::from_f64(2.0).expect("Operation failed");
427 }
428 "nearest" => {
429 let weight = h - h_floor;
430 if weight < T::from_f64(0.5).expect("Operation failed") {
431 result[i] = sorted[idx_low];
432 } else {
433 result[i] = sorted[idx_high];
434 }
435 }
436 _ => return Err("Invalid interpolation method. Use 'linear', 'lower', 'higher', 'midpoint', or 'nearest'"),
437 }
438 }
439
440 Ok(result)
441}
442
443/// Count number of occurrences of each value in array of non-negative ints.
444///
445/// # Arguments
446///
447/// * `array` - Input array of non-negative integers
448/// * `minlength` - Minimum number of bins for the output array. If None, the output array length is determined by the maximum value in `array`.
449/// * `weights` - Optional weights array. If specified, must have same shape as `array`.
450///
451/// # Returns
452///
453/// An array of counts
454///
455/// # Examples
456///
457/// ```
458/// use ::ndarray::array;
459/// use scirs2_core::ndarray_ext::stats::bincount;
460///
461/// let data = array![1, 2, 3, 1, 2, 1, 0, 1, 3, 2];
462/// let counts = bincount(data.view(), None, None).expect("Operation failed");
463/// assert_eq!(counts.len(), 4);
464/// assert_eq!(counts[0], 1.0); // '0' occurs once
465/// assert_eq!(counts[1], 4.0); // '1' occurs four times
466/// assert_eq!(counts[2], 3.0); // '2' occurs three times
467/// assert_eq!(counts[3], 2.0); // '3' occurs twice
468/// ```
469///
470/// This function is equivalent to ``NumPy``'s `np.bincount` function.
471#[allow(dead_code)]
472pub fn bincount(
473 array: ArrayView<usize, Ix1>,
474 minlength: Option<usize>,
475 weights: Option<ArrayView<f64, Ix1>>,
476) -> Result<Array<f64, Ix1>, &'static str> {
477 if array.is_empty() {
478 return Err("Cannot compute bincount of an empty array");
479 }
480
481 // Find maximum value to determine number of bins
482 let mut max_val = 0;
483 for &val in array.iter() {
484 if val > max_val {
485 max_val = val;
486 }
487 }
488
489 // Determine length of output array
490 let length = if let Some(min_len) = minlength {
491 max_val.max(min_len - 1) + 1
492 } else {
493 max_val + 1
494 };
495
496 let mut result = Array::<f64, Ix1>::zeros(length);
497
498 match weights {
499 Some(w) => {
500 if w.len() != array.len() {
501 return Err("Weights array must have same length as input array");
502 }
503 for (&idx, &weight) in array.iter().zip(w.iter()) {
504 result[idx] += weight;
505 }
506 }
507 None => {
508 for &idx in array.iter() {
509 result[idx] += 1.0;
510 }
511 }
512 }
513
514 Ok(result)
515}
516
517/// Return the indices of the bins to which each value in input array belongs.
518///
519/// # Arguments
520///
521/// * `array` - Input array
522/// * `bins` - Array of bin edges
523/// * `right` - Indicates whether the intervals include the right or left bin edge
524/// * `result_type` - Whether to return the indices ('indices') or the bin values ('values')
525///
526/// # Returns
527///
528/// Array of indices or values depending on result_type
529///
530/// # Examples
531///
532/// ```
533/// use ::ndarray::array;
534/// use scirs2_core::ndarray_ext::stats::digitize;
535///
536/// let data = array![1.2, 3.5, 5.1, 0.8, 2.9];
537/// let bins = array![1.0, 3.0, 5.0];
538/// let indices = digitize(data.view(), bins.view(), false, "indices").expect("Operation failed");
539///
540/// assert_eq!(indices[0], 1); // 1.2 is in the first bin (1.0 <= x < 3.0)
541/// assert_eq!(indices[1], 2); // 3.5 is in the second bin (3.0 <= x < 5.0)
542/// assert_eq!(indices[2], 3); // 5.1 is after the last bin (>= 5.0)
543/// assert_eq!(indices[3], 0); // 0.8 is before the first bin (< 1.0)
544/// assert_eq!(indices[4], 1); // 2.9 is in the first bin (1.0 <= x < 3.0)
545/// ```
546///
547/// This function is equivalent to ``NumPy``'s `np.digitize` function.
548#[allow(dead_code)]
549pub fn digitize<T>(
550 array: ArrayView<T, Ix1>,
551 bins: ArrayView<T, Ix1>,
552 right: bool,
553 result_type: &str,
554) -> Result<Array<usize, Ix1>, &'static str>
555where
556 T: Clone + Float + FromPrimitive,
557{
558 if array.is_empty() {
559 return Err("Cannot digitize an empty array");
560 }
561
562 if bins.is_empty() {
563 return Err("Bins array cannot be empty");
564 }
565
566 // Check that bins are monotonically increasing
567 for i in 1..bins.len() {
568 if bins[i] <= bins[i.saturating_sub(1)] {
569 return Err("Bins must be monotonically increasing");
570 }
571 }
572
573 let mut result = Array::<usize, Ix1>::zeros(array.len());
574
575 for (i, &val) in array.iter().enumerate() {
576 let mut bin_idx = 0;
577
578 if right {
579 // Right inclusive: val <= edge
580 for j in 0..bins.len() {
581 if val <= bins[j] {
582 bin_idx = j;
583 break;
584 }
585 bin_idx = bins.len();
586 }
587 } else {
588 // Left inclusive: val < edge
589 for j in 0..bins.len() {
590 if val < bins[j] {
591 bin_idx = j;
592 break;
593 }
594 bin_idx = bins.len();
595 }
596 }
597
598 result[i] = bin_idx;
599 }
600
601 if result_type == "indices" {
602 Ok(result)
603 } else if result_type == "values" {
604 Err("'values' result_type is not yet implemented")
605 } else {
606 Err("result_type must be 'indices' or 'values'")
607 }
608}