Skip to main content

ruvector_math/utils/
sorting.rs

//! Sorting utilities for optimal transport
2
/// Argsort: returns the indices that would sort `data` in ascending order.
///
/// The sort is stable, so equal values keep their original relative order
/// (`-0.0` and `0.0` compare equal). NaN values are ordered after every
/// non-NaN value, and NaNs among themselves keep their original order.
///
/// The comparator is a genuine total order even when `data` contains NaN.
/// Treating an unordered comparison as `Equal` would break transitivity. That
/// leaves the resulting order unspecified and can make `sort_by` panic on
/// Rust 1.81+.
pub fn argsort(data: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..data.len()).collect();
    indices.sort_by(|&a, &b| {
        let (x, y) = (data[a], data[b]);
        // `partial_cmp` is `None` only when at least one side is NaN; rank NaN
        // above every number and treat two NaNs as equal.
        x.partial_cmp(&y)
            .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
    });
    indices
}
13
14/// Sort with indices: returns (sorted_data, original_indices)
15pub fn sort_with_indices(data: &[f64]) -> (Vec<f64>, Vec<usize>) {
16    let indices = argsort(data);
17    let sorted: Vec<f64> = indices.iter().map(|&i| data[i]).collect();
18    (sorted, indices)
19}
20
/// Linearly interpolated quantile of already-sorted data.
///
/// `q` is clamped to `[0.0, 1.0]` and mapped to the fractional position
/// `q * (n - 1)`. The result interpolates linearly between the two
/// neighbouring elements. `sorted_data` is assumed to be in ascending order;
/// this is not checked.
///
/// Returns `0.0` for an empty slice and the only element for a
/// single-element slice.
pub fn quantile_sorted(sorted_data: &[f64], q: f64) -> f64 {
    let n = sorted_data.len();
    match n {
        0 => return 0.0,
        1 => return sorted_data[0],
        _ => {}
    }

    let q = q.clamp(0.0, 1.0);
    let pos = q * (n - 1) as f64;
    let lo = pos.floor() as usize;
    let frac = pos - lo as f64;

    // Exact hit: return the element itself. Blending with a zero weight would
    // compute `next * 0.0`, which is NaN when the neighbour is infinite.
    if frac == 0.0 {
        return sorted_data[lo];
    }

    let hi = (lo + 1).min(n - 1);
    sorted_data[lo] * (1.0 - frac) + sorted_data[hi] * frac
}
41
/// Compute cumulative distribution function values from (unnormalized) weights.
///
/// Entry `i` is `(w[0] + ... + w[i]) / total`, where `total` is the sum of
/// all weights. An empty input yields an empty vector.
///
/// The raw prefix sums are accumulated first and then divided by the final
/// prefix sum. Summing pre-divided terms would accumulate rounding error.
/// With this approach the last entry is exactly `1.0` whenever the total is
/// finite and non-zero.
///
/// Weights are not validated. If they sum to zero (or are non-finite), the
/// entries are NaN or infinite.
pub fn compute_cdf(weights: &[f64]) -> Vec<f64> {
    let mut cdf = Vec::with_capacity(weights.len());
    let mut running = 0.0;

    for &w in weights {
        running += w;
        cdf.push(running);
    }

    // `running` is now the total of all weights (the last prefix sum).
    let total = running;
    for c in &mut cdf {
        *c /= total;
    }

    cdf
}
55
56/// Weighted quantile
57pub fn weighted_quantile(values: &[f64], weights: &[f64], q: f64) -> f64 {
58    if values.is_empty() {
59        return 0.0;
60    }
61
62    let indices = argsort(values);
63    let sorted_values: Vec<f64> = indices.iter().map(|&i| values[i]).collect();
64    let sorted_weights: Vec<f64> = indices.iter().map(|&i| weights[i]).collect();
65
66    let cdf = compute_cdf(&sorted_weights);
67    let q = q.clamp(0.0, 1.0);
68
69    // Find the value at quantile q
70    for (i, &c) in cdf.iter().enumerate() {
71        if c >= q {
72            return sorted_values[i];
73        }
74    }
75
76    sorted_values[sorted_values.len() - 1]
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_argsort() {
        let data = vec![3.0, 1.0, 2.0];
        let indices = argsort(&data);
        assert_eq!(indices, vec![1, 2, 0]);
    }

    #[test]
    fn test_argsort_is_stable_for_ties() {
        // Equal keys keep their original relative order.
        assert_eq!(argsort(&[2.0, 1.0, 2.0]), vec![1, 0, 2]);
        assert!(argsort(&[]).is_empty());
    }

    #[test]
    fn test_sort_with_indices() {
        let (sorted, indices) = sort_with_indices(&[3.0, 1.0, 2.0]);
        assert_eq!(sorted, vec![1.0, 2.0, 3.0]);
        assert_eq!(indices, vec![1, 2, 0]);
    }

    #[test]
    fn test_quantile() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert!((quantile_sorted(&data, 0.0) - 1.0).abs() < 1e-10);
        assert!((quantile_sorted(&data, 0.5) - 3.0).abs() < 1e-10);
        assert!((quantile_sorted(&data, 1.0) - 5.0).abs() < 1e-10);
        // Interpolates between neighbours.
        assert!((quantile_sorted(&data, 0.125) - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_quantile_edge_cases() {
        assert_eq!(quantile_sorted(&[], 0.5), 0.0);
        assert_eq!(quantile_sorted(&[7.0], 0.3), 7.0);
        // `q` is clamped to [0, 1].
        assert_eq!(quantile_sorted(&[1.0, 3.0], 2.0), 3.0);
        assert_eq!(quantile_sorted(&[1.0, 3.0], -1.0), 1.0);
    }

    #[test]
    fn test_cdf() {
        let weights = vec![0.25, 0.25, 0.25, 0.25];
        let cdf = compute_cdf(&weights);

        assert!((cdf[0] - 0.25).abs() < 1e-10);
        assert!((cdf[1] - 0.50).abs() < 1e-10);
        assert!((cdf[2] - 0.75).abs() < 1e-10);
        assert!((cdf[3] - 1.00).abs() < 1e-10);
        assert!(compute_cdf(&[]).is_empty());
    }

    #[test]
    fn test_weighted_quantile() {
        // Sorted values [1, 2, 3] carry weights [0.5, 0.25, 0.25] -> CDF [0.5, 0.75, 1.0].
        let values = [3.0, 1.0, 2.0];
        let weights = [0.25, 0.5, 0.25];

        assert_eq!(weighted_quantile(&values, &weights, 0.0), 1.0);
        assert_eq!(weighted_quantile(&values, &weights, 0.5), 1.0);
        assert_eq!(weighted_quantile(&values, &weights, 0.6), 2.0);
        assert_eq!(weighted_quantile(&values, &weights, 0.75), 2.0);
        assert_eq!(weighted_quantile(&values, &weights, 0.8), 3.0);
        assert_eq!(weighted_quantile(&values, &weights, 1.5), 3.0);
        assert_eq!(weighted_quantile(&[], &[], 0.5), 0.0);
    }
}