tea_agg/
vec_valid.rs

1use tea_core::prelude::*;
/// Enum representing different methods for calculating quantiles.
///
/// When the requested quantile falls between two ranked values `i < j`, the
/// method decides how those two values are combined into the result.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum QuantileMethod {
    /// Linear interpolation between closest ranks.
    Linear,
    /// Use the lower of the two nearest ranks.
    Lower,
    /// Use the higher of the two nearest ranks.
    Higher,
    /// Use the average of the two nearest ranks.
    MidPoint,
}
14
15/// Extension trait providing additional aggregation methods for vectors with potentially invalid (None) values.
16pub trait VecAggValidExt<T: IsNone>: Vec1View<T> {
17    /// Calculate the quantile of the vector, ignoring NaN or None values.
18    ///
19    /// # Arguments
20    ///
21    /// * `q` - The quantile to calculate, must be between 0 and 1.
22    /// * `method` - The method to use for quantile calculation.
23    ///
24    /// # Returns
25    ///
26    /// Returns a `TResult<f64>` containing the calculated quantile value.
27    ///
28    /// # Errors
29    ///
30    /// Returns an error if `q` is not between 0 and 1.
31    fn vquantile(&self, q: f64, method: QuantileMethod) -> TResult<f64>
32    where
33        T: Cast<f64>,
34        T::Inner: Number,
35    {
36        tensure!(
37            (0. ..=1.).contains(&q),
38            "q must be between 0 and 1, find {}",
39            q
40        );
41        use QuantileMethod::*;
42        let mut out_c: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
43        let slc = out_c.try_as_slice_mut().unwrap();
44        let n = self.titer().count_valid();
45        // fast path for special cases
46        if n == 0 {
47            return Ok(f64::NAN);
48        } else if n == 1 {
49            return Ok(slc[0].clone().cast());
50        }
51        let len_1 = (n - 1).f64();
52        let (q, i, j, vi, vj) = if q <= 0.5 {
53            let q_idx = len_1 * q;
54            let (i, j) = (q_idx.floor().usize(), q_idx.ceil().usize());
55            let (head, m, _tail) = slc.select_nth_unstable_by(j, |va, vb| va.sort_cmp(vb));
56            if i != j {
57                let vi: f64 = head.titer().vmax().map(|v| v.f64()).cast();
58                (q, i, j, vi, m.clone().cast())
59            } else {
60                return Ok(m.clone().cast());
61            }
62        } else {
63            // sort from largest to smallest
64            let q = 1. - q;
65            let q_idx = len_1 * q;
66            let (i, j) = (q_idx.floor().usize(), q_idx.ceil().usize());
67            let (head, m, _tail) = slc.select_nth_unstable_by(j, |va, vb| va.sort_cmp_rev(vb));
68            if i != j {
69                let vi: f64 = head.titer().vmin().map(|v| v.f64()).cast();
70                match method {
71                    Lower => {
72                        return Ok(m.clone().cast());
73                    },
74                    Higher => {
75                        return Ok(vi);
76                    },
77                    _ => {},
78                };
79                (q, i, j, vi, m.clone().cast())
80            } else {
81                return Ok(m.clone().cast());
82            }
83        };
84        match method {
85            Linear => {
86                // `i + (j - i) * fraction`, where `fraction` is the
87                // fractional part of the index surrounded by `i` and `j`.
88                let (qi, qj) = (i.f64() / len_1, j.f64() / len_1);
89                let fraction = (q - qi) / (qj - qi);
90                Ok(vi + (vj - vi) * fraction)
91            },
92            Lower => Ok(vi),                // i
93            Higher => Ok(vj),               // j
94            MidPoint => Ok((vi + vj) / 2.), // (i + j) / 2.
95        }
96    }
97
98    /// Calculate the median of the vector, ignoring NaN or None values.
99    ///
100    /// # Returns
101    ///
102    /// Returns the median value as an `f64`.
103    #[inline]
104    fn vmedian(&self) -> f64
105    where
106        T: Cast<f64>,
107        T::Inner: Number,
108    {
109        self.vquantile(0.5, QuantileMethod::Linear).unwrap()
110    }
111}
112impl<V: Vec1View<T>, T: IsNone> VecAggValidExt<T> for V {}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every quantile method against hand-computed values for `1..=10`.
    #[test]
    fn test_quantile() {
        use QuantileMethod::*;
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        // (q, method, expected quantile)
        let expectations = [
            (0.5, Linear, 5.5),
            (0.5, Lower, 5.),
            (0.5, Higher, 6.),
            (0.5, MidPoint, 5.5),
            (0.25, Linear, 3.25),
            (0.25, Lower, 3.),
            (0.25, Higher, 4.),
            (0.25, MidPoint, 3.5),
            (0.75, Linear, 7.75),
            (0.75, Lower, 7.),
            (0.75, Higher, 8.),
            (0.75, MidPoint, 7.5),
            (0.22, Linear, 2.98),
            (0.78, Linear, 8.02),
        ];
        for &(q, method, expected) in expectations.iter() {
            assert_eq!(data.vquantile(q, method).unwrap(), expected);
        }
    }
}