tea_rolling/
norm.rs

1use tea_core::prelude::*;
2
3/// Trait for rolling window normalization operations on valid (non-None) elements.
4pub trait RollingValidNorm<T: IsNone>: Vec1View<T> {
5    /// Calculates the rolling z-score (standard score) for valid elements within a window.
6    ///
7    /// # Arguments
8    ///
9    /// * `window` - The size of the rolling window.
10    /// * `min_periods` - The minimum number of observations in window required to have a value.
11    /// * `out` - Optional output buffer to store the results.
12    ///
13    /// # Returns
14    ///
15    /// A vector containing the rolling z-scores.
16    ///
17    /// # Notes
18    ///
19    /// The z-score is calculated as (x - mean) / standard_deviation.
20    /// If the standard deviation is zero or if there are fewer than `min_periods` valid observations,
21    /// the result will be NaN.
22    #[no_out]
23    fn ts_vzscore<O: Vec1<U>, U>(
24        &self,
25        window: usize,
26        min_periods: Option<usize>,
27        out: Option<O::UninitRefMut<'_>>,
28    ) -> O
29    where
30        T::Inner: Number,
31        f64: Cast<U>,
32    {
33        let mut sum = 0.;
34        let mut sum2 = 0.;
35        let mut n = 0;
36        let min_periods = min_periods.unwrap_or(window / 2).min(window);
37        self.rolling_apply(
38            window,
39            |v_rm, v| {
40                let res = if v.not_none() {
41                    n += 1;
42                    let v = v.unwrap().f64();
43                    sum += v;
44                    sum2 += v * v;
45                    if n >= min_periods {
46                        let n_f64 = n.f64();
47                        let mut var = sum2 / n_f64;
48                        let mean = sum / n_f64;
49                        var -= mean.powi(2);
50                        if var > EPS {
51                            (v - mean) / (var * n_f64 / (n - 1).f64()).sqrt()
52                        } else {
53                            f64::NAN
54                        }
55                    } else {
56                        f64::NAN
57                    }
58                } else {
59                    f64::NAN
60                };
61                if let Some(v) = v_rm {
62                    if v.not_none() {
63                        let v = v.unwrap().f64();
64                        n -= 1;
65                        sum -= v;
66                        sum2 -= v * v
67                    };
68                }
69                res.cast()
70            },
71            out,
72        )
73    }
74
75    /// Calculates the rolling min-max normalization for valid elements within a window.
76    ///
77    /// # Arguments
78    ///
79    /// * `window` - The size of the rolling window.
80    /// * `min_periods` - The minimum number of observations in window required to have a value.
81    /// * `out` - Optional output buffer to store the results.
82    ///
83    /// # Returns
84    ///
85    /// A vector containing the rolling min-max normalized values.
86    ///
87    /// # Notes
88    ///
89    /// The min-max normalization is calculated as (x - min) / (max - min).
90    /// If max equals min or if there are fewer than `min_periods` valid observations,
91    /// the result will be NaN.
92    #[no_out]
93    fn ts_vminmaxnorm<O: Vec1<U>, U>(
94        &self,
95        window: usize,
96        min_periods: Option<usize>,
97        out: Option<O::UninitRefMut<'_>>,
98    ) -> O
99    where
100        T::Inner: Number,
101        f64: Cast<U>,
102    {
103        let mut max = T::Inner::min_();
104        let mut max_idx = 0;
105        let mut min = T::Inner::max_();
106        let mut min_idx = 0;
107        let mut n = 0;
108        let min_periods = min_periods.unwrap_or(window / 2).min(window);
109        self.rolling_apply_idx(
110            window,
111            |start, end, v| {
112                if let Some(start) = start {
113                    match (max_idx < start, min_idx < start) {
114                        (true, false) => {
115                            // max value is invalid, find max value again
116                            max = T::Inner::min_();
117                            for i in start..end {
118                                let v = unsafe { self.uget(i) };
119                                if v.not_none() {
120                                    let v = v.unwrap();
121                                    if v >= max {
122                                        (max, max_idx) = (v, i);
123                                    }
124                                }
125                            }
126                        },
127                        (false, true) => {
128                            // min value is invalid, find min value again
129                            min = T::Inner::max_();
130                            for i in start..end {
131                                let v = unsafe { self.uget(i) };
132                                if v.not_none() {
133                                    let v = v.unwrap();
134                                    if v <= min {
135                                        (min, min_idx) = (v, i);
136                                    }
137                                }
138                            }
139                        },
140                        (true, true) => {
141                            // both max and min value are invalid, find max and min value again
142                            (max, min) = (T::Inner::min_(), T::Inner::max_());
143                            for i in start..end {
144                                let v = unsafe { self.uget(i) };
145                                if v.not_none() {
146                                    let v = v.unwrap();
147                                    if v >= max {
148                                        (max, max_idx) = (v, i);
149                                    }
150                                    if v <= min {
151                                        (min, min_idx) = (v, i);
152                                    }
153                                }
154                            }
155                        },
156                        (false, false) => (), // we don't need to find max and min value again
157                    }
158                }
159                // check if end position is max or min value
160                let res = if v.not_none() {
161                    n += 1;
162                    let v = v.unwrap();
163                    if v >= max {
164                        (max, max_idx) = (v, end);
165                    }
166                    if v <= min {
167                        (min, min_idx) = (v, end);
168                    }
169                    if (n >= min_periods) & (max != min) {
170                        ((v - min).f64() / (max - min).f64()).cast()
171                    } else {
172                        f64::NAN.cast()
173                    }
174                } else {
175                    f64::NAN.cast()
176                };
177                if let Some(start) = start {
178                    let v = unsafe { self.uget(start) };
179                    if v.not_none() {
180                        n -= 1;
181                    }
182                }
183                res
184            },
185            out,
186        )
187    }
188}
189
190impl<T: IsNone, I: Vec1View<T>> RollingValidNorm<T> for I {}
191
#[cfg(test)]
mod tests {
    use tea_core::testing::assert_vec1d_equal_numeric;

    use super::*;

    #[test]
    fn test_ts_zscore() {
        let data = vec![1., 2., 3., f64::NAN, 5., 6., 7., f64::NAN, 9., 10.];
        let res: Vec<f64> = data.ts_vzscore(4, None);
        let expect = vec![
            f64::NAN,
            0.707107,
            1.0,
            f64::NAN,
            1.091089,
            0.872872,
            1.0,
            f64::NAN,
            1.091089,
            0.872872,
        ];
        assert_vec1d_equal_numeric(&res, &expect, Some(1e-5))
    }

    #[test]
    fn test_ts_minmaxnorm() {
        // Non-monotonic data with a gap, so that both a stale minimum (indices 5, 6) and a
        // stale maximum (index 9) force a rescan of the window.
        let data = vec![3., 1., 4., f64::NAN, 5., 9., 2., 6., 5., 3.];
        let res: Vec<f64> = data.ts_vminmaxnorm(4, None);
        let expect = vec![
            f64::NAN, // only one valid value (< default min_periods of 2)
            0.0,
            1.0,
            f64::NAN, // invalid input
            1.0,
            1.0,
            0.0,
            4.0 / 7.0, // window [5, 9, 2, 6]
            3.0 / 7.0, // window [9, 2, 6, 5]
            0.25,      // window [2, 6, 5, 3]
        ];
        assert_vec1d_equal_numeric(&res, &expect, Some(1e-10))
    }
}