tea_rolling/
binary.rs

1use tea_core::prelude::*;
2
3/// Trait for performing rolling binary operations on valid elements in vectors.
4///
5/// This trait provides methods for calculating rolling covariance and correlation
6/// between two vectors of potentially nullable elements.
7pub trait RollingValidBinary<T: IsNone>: Vec1View<T> {
8    /// Calculates the rolling covariance between two vectors.
9    ///
10    /// # Arguments
11    ///
12    /// * `other` - The other vector to calculate covariance with.
13    /// * `window` - The size of the rolling window.
14    /// * `min_periods` - The minimum number of observations in window required to have a value.
15    /// * `out` - Optional output buffer to store the results.
16    ///
17    /// # Returns
18    ///
19    /// A vector containing the rolling covariance values.
20    #[no_out]
21    fn ts_vcov<O: Vec1<U>, U, V2: Vec1View<T2>, T2: IsNone>(
22        &self,
23        other: &V2,
24        window: usize,
25        min_periods: Option<usize>,
26        out: Option<O::UninitRefMut<'_>>,
27    ) -> O
28    where
29        T::Inner: Number,
30        T2::Inner: Number,
31        f64: Cast<U>,
32    {
33        let min_periods = min_periods.unwrap_or(window / 2).min(window);
34        let mut sum_a = 0.;
35        let mut sum_b = 0.;
36        let mut sum_ab = 0.;
37        let mut n = 0;
38        self.rolling2_apply(
39            other,
40            window,
41            |remove_values, (va, vb)| {
42                if va.not_none() && vb.not_none() {
43                    n += 1;
44                    let (va, vb) = (va.unwrap().f64(), vb.unwrap().f64());
45                    sum_a += va;
46                    sum_b += vb;
47                    sum_ab += va * vb;
48                };
49                let res = if n >= min_periods {
50                    (sum_ab - (sum_a * sum_b) / n.f64()) / (n - 1).f64()
51                } else {
52                    f64::NAN
53                };
54                if let Some((va, vb)) = remove_values {
55                    if va.not_none() && vb.not_none() {
56                        n -= 1;
57                        let (va, vb) = (va.unwrap().f64(), vb.unwrap().f64());
58                        sum_a -= va;
59                        sum_b -= vb;
60                        sum_ab -= va * vb;
61                    };
62                }
63                res.cast()
64            },
65            out,
66        )
67    }
68
69    /// Calculates the rolling correlation between two vectors.
70    ///
71    /// # Arguments
72    ///
73    /// * `other` - The other vector to calculate correlation with.
74    /// * `window` - The size of the rolling window.
75    /// * `min_periods` - The minimum number of observations in window required to have a value.
76    /// * `out` - Optional output buffer to store the results.
77    ///
78    /// # Returns
79    ///
80    /// A vector containing the rolling correlation values.
81    #[no_out]
82    fn ts_vcorr<O: Vec1<U>, U, V2: Vec1View<T2>, T2: IsNone>(
83        &self,
84        other: &V2,
85        window: usize,
86        min_periods: Option<usize>,
87        out: Option<O::UninitRefMut<'_>>,
88    ) -> O
89    where
90        T::Inner: Number,
91        T2::Inner: Number,
92        f64: Cast<U>,
93    {
94        let mut sum_a = 0.;
95        let mut sum2_a = 0.;
96        let mut sum_b = 0.;
97        let mut sum2_b = 0.;
98        let mut sum_ab = 0.;
99        let mut n = 0;
100        let min_periods = min_periods.unwrap_or(window / 2).min(window);
101        self.rolling2_apply(
102            other,
103            window,
104            |remove_values, (va, vb)| {
105                if va.not_none() && vb.not_none() {
106                    n += 1;
107                    let (va, vb) = (va.unwrap().f64(), vb.unwrap().f64());
108                    sum_a += va;
109                    sum2_a += va * va;
110                    sum_b += vb;
111                    sum2_b += vb * vb;
112                    sum_ab += va * vb;
113                };
114                let res = if n >= min_periods {
115                    let n_f64 = n.f64();
116                    let mean_a = sum_a / n_f64;
117                    let mut var_a = sum2_a / n_f64;
118                    let mean_b = sum_b / n_f64;
119                    let mut var_b = sum2_b / n_f64;
120                    var_a -= mean_a.powi(2);
121                    var_b -= mean_b.powi(2);
122                    if (var_a > EPS) & (var_b > EPS) {
123                        let exy = sum_ab / n_f64;
124                        let exey = sum_a * sum_b / n_f64.powi(2);
125                        (exy - exey) / (var_a * var_b).sqrt()
126                    } else {
127                        f64::NAN
128                    }
129                } else {
130                    f64::NAN
131                };
132                if let Some((va, vb)) = remove_values {
133                    if va.not_none() && vb.not_none() {
134                        n -= 1;
135                        let (va, vb) = (va.unwrap().f64(), vb.unwrap().f64());
136                        sum_a -= va;
137                        sum2_a -= va * va;
138                        sum_b -= vb;
139                        sum2_b -= vb * vb;
140                        sum_ab -= va * vb;
141                    };
142                }
143                res.cast()
144            },
145            out,
146        )
147    }
148}
149
150impl<T: IsNone, I: Vec1View<T>> RollingValidBinary<T> for I {}
151
#[cfg(test)]
mod tests {
    use super::*;
    use tea_core::testing::assert_vec1d_equal_numeric;

    /// The incremental rolling covariance must match a per-window `vcov`
    /// and the hand-computed sample covariances.
    #[test]
    fn test_cov() {
        let lhs = vec![1, 5, 3, 2, 5];
        let rhs = vec![2, 5, 4, 3, 6];
        let rolling: Vec<f64> = lhs.ts_vcov(&rhs, 3, Some(2));
        let per_window: Vec<f64> = lhs
            .rolling2_custom(&rhs, 3, |w1, w2| w1.titer().vcov(w2.titer(), 2), None)
            .unwrap();
        let expected = vec![f64::NAN, 6., 3., 1.5, 2.333333333333332];
        assert_vec1d_equal_numeric(&rolling, &per_window, None);
        assert_vec1d_equal_numeric(&rolling, &expected, None);
    }

    /// The incremental rolling correlation must match a per-window
    /// `vcorr_pearson` and the hand-computed Pearson coefficients.
    #[test]
    fn test_corr() {
        let lhs = vec![1, 5, 3, 2, 5];
        let rhs = vec![2, 5, 4, 3, 6];
        let rolling: Vec<f64> = lhs.ts_vcorr(&rhs, 3, Some(2));
        let per_window: Vec<f64> = lhs
            .rolling2_custom(
                &rhs,
                3,
                |w1, w2| w1.titer().vcorr_pearson(w2.titer(), 2),
                None,
            )
            .unwrap();
        let expected = vec![f64::NAN, 1., 0.9819805060619652, 0.9819805060619652, 1.];
        assert_vec1d_equal_numeric(&rolling, &per_window, None);
        assert_vec1d_equal_numeric(&rolling, &expected, None);
    }
}
189}