ruvector_math/utils/
sorting.rs

//! Sorting, quantile, and CDF helpers for optimal transport.
2
/// Returns the permutation of indices that sorts `data` in ascending order.
///
/// The sort is stable: elements that compare equal (including `0.0` and
/// `-0.0`) keep their original relative order. `NaN` entries are treated as
/// equal to each other and greater than every other value, so they end up at
/// the back of the returned permutation.
///
/// Returns an empty vector when `data` is empty.
pub fn argsort(data: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..data.len()).collect();
    // `partial_cmp` yields `None` as soon as a NaN is involved. Mapping that
    // to `Equal` makes the comparator inconsistent (1 < 2, yet 1 == NaN and
    // NaN == 2), which leaves the resulting order unspecified. Ranking NaN
    // after everything else restores a consistent ordering while leaving the
    // result for NaN-free input unchanged.
    indices.sort_by(|&a, &b| {
        let (x, y) = (data[a], data[b]);
        x.partial_cmp(&y)
            .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
    });
    indices
}
9
10/// Sort with indices: returns (sorted_data, original_indices)
11pub fn sort_with_indices(data: &[f64]) -> (Vec<f64>, Vec<usize>) {
12    let indices = argsort(data);
13    let sorted: Vec<f64> = indices.iter().map(|&i| data[i]).collect();
14    (sorted, indices)
15}
16
/// Linearly interpolated quantile of data that is already sorted ascending.
///
/// `q` is clamped to `[0.0, 1.0]` and mapped to the fractional position
/// `q * (n - 1)`; the result interpolates between the two elements that
/// surround that position. When the position falls exactly on an element,
/// that element is returned as-is.
///
/// Returns `0.0` for empty input and the sole element for single-element
/// input. The slice is not checked for sortedness.
pub fn quantile_sorted(sorted_data: &[f64], q: f64) -> f64 {
    if sorted_data.is_empty() {
        return 0.0;
    }

    let last = sorted_data.len() - 1;
    if last == 0 {
        return sorted_data[0];
    }

    let position = q.clamp(0.0, 1.0) * last as f64;
    let lower = position.floor() as usize;
    let frac = position - lower as f64;

    // Exact hit: return the element directly. Blending in the neighbour with
    // weight zero would compute `neighbour * 0.0`, which is NaN when the
    // neighbour is infinite or NaN.
    if frac == 0.0 {
        return sorted_data[lower];
    }

    let upper = (lower + 1).min(last);
    sorted_data[lower] * (1.0 - frac) + sorted_data[upper] * frac
}
37
/// Computes the normalised cumulative distribution of `weights`.
///
/// Entry `i` of the result is the sum of `weights[..=i]` divided by the sum
/// of all weights. The weights do not need to be normalised beforehand.
///
/// The final entry is exactly `1.0` whenever the total is finite and
/// non-zero. If the weights sum to zero there is nothing to normalise by, and
/// a vector of zeros is returned instead of dividing by zero. Empty input
/// yields an empty vector.
pub fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    // First pass: raw prefix sums.
    let mut cdf = Vec::with_capacity(weights.len());
    let mut running = 0.0_f64;
    for &w in weights {
        running += w;
        cdf.push(running);
    }

    // The last prefix sum is the total. Using it (rather than a separately
    // computed sum) guarantees the final entry normalises to exactly 1.0.
    let total = running;
    if total == 0.0 {
        return vec![0.0; weights.len()];
    }

    // Second pass: a single division per entry avoids accumulating the
    // rounding error of many `w / total` terms.
    for c in &mut cdf {
        *c /= total;
    }

    cdf
}
51
52/// Weighted quantile
53pub fn weighted_quantile(values: &[f64], weights: &[f64], q: f64) -> f64 {
54    if values.is_empty() {
55        return 0.0;
56    }
57
58    let indices = argsort(values);
59    let sorted_values: Vec<f64> = indices.iter().map(|&i| values[i]).collect();
60    let sorted_weights: Vec<f64> = indices.iter().map(|&i| weights[i]).collect();
61
62    let cdf = compute_cdf(&sorted_weights);
63    let q = q.clamp(0.0, 1.0);
64
65    // Find the value at quantile q
66    for (i, &c) in cdf.iter().enumerate() {
67        if c >= q {
68            return sorted_values[i];
69        }
70    }
71
72    sorted_values[sorted_values.len() - 1]
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons in this module.
    const EPS: f64 = 1e-10;

    /// True when `actual` lies strictly within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_argsort() {
        // Smallest element is at index 1, then index 2, then index 0.
        let order = argsort(&[3.0, 1.0, 2.0]);
        assert_eq!(order, vec![1, 2, 0]);
    }

    #[test]
    fn test_quantile() {
        let sorted = [1.0, 2.0, 3.0, 4.0, 5.0];

        // (quantile, expected value): minimum, median, maximum.
        for &(q, expected) in &[(0.0, 1.0), (0.5, 3.0), (1.0, 5.0)] {
            assert!(close(quantile_sorted(&sorted, q), expected));
        }
    }

    #[test]
    fn test_cdf() {
        let cdf = compute_cdf(&[0.25; 4]);
        let expected = [0.25, 0.50, 0.75, 1.00];

        for (i, &want) in expected.iter().enumerate() {
            assert!(close(cdf[i], want));
        }
    }
}