tea_rolling/
cmp.rs

1use std::cmp::{Ordering, min};
2
3use tea_core::prelude::*;
4/// Trait for performing rolling comparison operations on valid elements in vectors.
5///
6/// This trait provides methods for calculating rolling minimum, maximum, argmin, argmax,
7/// and rank operations on vectors of potentially nullable elements.
8pub trait RollingValidCmp<T: IsNone>: Vec1View<T> {
9    /// Calculates the rolling argmin (index of minimum value) for the vector.
10    ///
11    /// # Arguments
12    ///
13    /// * `window` - The size of the rolling window.
14    /// * `min_periods` - The minimum number of observations in window required to have a value.
15    /// * `out` - Optional output buffer to store the results.
16    ///
17    /// # Returns
18    ///
19    /// A vector containing the rolling argmin values.
20    #[no_out]
21    fn ts_vargmin<O: Vec1<U>, U>(
22        &self,
23        window: usize,
24        min_periods: Option<usize>,
25        out: Option<O::UninitRefMut<'_>>,
26    ) -> O
27    where
28        T::Inner: Number,
29        f64: Cast<U>,
30    {
31        let window = min(self.len(), window);
32        let mut min: Option<T::Inner> = None;
33        let mut min_idx: Option<usize> = None;
34        let mut n = 0;
35        let min_periods = min_periods.unwrap_or(window / 2);
36        self.rolling_apply_idx(
37            window,
38            |start, end, v| {
39                let v = v.to_opt();
40                unsafe {
41                    if v.is_some() {
42                        n += 1;
43                        if min_idx.is_none() {
44                            min_idx = Some(end);
45                            min = Some(v.unwrap());
46                        }
47                    }
48                    if min_idx < start {
49                        // the minimum value has expired, find the minimum value again
50                        let start = start.unwrap();
51                        min = self.uget(start).to_opt();
52                        for i in start..=end {
53                            let v_ = self.uget(i).to_opt();
54                            match v_.sort_cmp(&min) {
55                                Ordering::Less | Ordering::Equal => {
56                                    (min, min_idx) = (v_, Some(i));
57                                },
58                                _ => {},
59                            }
60                        }
61                    } else {
62                        match v.sort_cmp(&min) {
63                            Ordering::Less | Ordering::Equal => {
64                                (min, min_idx) = (v, Some(end));
65                            },
66                            _ => {},
67                        }
68                    }
69                    let out = if n >= min_periods {
70                        min_idx
71                            .map(|min_idx| (min_idx - start.unwrap_or(0) + 1).f64())
72                            .unwrap_or(f64::NAN)
73                            .cast()
74                    } else {
75                        f64::NAN.cast()
76                    };
77                    if start.is_some() && self.uget(start.unwrap()).not_none() {
78                        n -= 1;
79                    }
80                    out
81                }
82            },
83            out,
84        )
85    }
86
87    /// Calculates the rolling minimum for the vector.
88    ///
89    /// # Arguments
90    ///
91    /// * `window` - The size of the rolling window.
92    /// * `min_periods` - The minimum number of observations in window required to have a value.
93    /// * `out` - Optional output buffer to store the results.
94    ///
95    /// # Returns
96    ///
97    /// A vector containing the rolling minimum values.
98    #[no_out]
99    fn ts_vmin<O: Vec1<U>, U>(
100        &self,
101        window: usize,
102        min_periods: Option<usize>,
103        out: Option<O::UninitRefMut<'_>>,
104    ) -> O
105    where
106        T::Inner: Number,
107        Option<T::Inner>: Cast<U>,
108    {
109        let window = min(self.len(), window);
110        let mut min: Option<T::Inner> = None;
111        let mut min_idx: Option<usize> = None;
112        let mut n = 0;
113        let min_periods = min_periods.unwrap_or(window / 2);
114        self.rolling_apply_idx(
115            window,
116            |start, end, v| {
117                let v = v.to_opt();
118                unsafe {
119                    if v.is_some() {
120                        n += 1;
121                        if min_idx.is_none() {
122                            (min, min_idx) = (v, Some(end));
123                        }
124                    }
125                    if min_idx < start {
126                        // the minimum value has expired, find the minimum value again
127                        let start = start.unwrap();
128                        min = self.uget(start).to_opt();
129                        for i in start..=end {
130                            let v_ = self.uget(i).to_opt();
131                            match v_.sort_cmp(&min) {
132                                Ordering::Less | Ordering::Equal => {
133                                    (min, min_idx) = (v_, Some(i));
134                                },
135                                _ => {},
136                            }
137                        }
138                    } else {
139                        match v.sort_cmp(&min) {
140                            Ordering::Less | Ordering::Equal => {
141                                (min, min_idx) = (v, Some(end));
142                            },
143                            _ => {},
144                        }
145                    }
146                    let out = if n >= min_periods {
147                        min.cast()
148                    } else {
149                        None.cast()
150                    };
151                    if start.is_some() && self.uget(start.unwrap()).not_none() {
152                        n -= 1;
153                    }
154                    out
155                }
156            },
157            out,
158        )
159    }
160
161    /// Calculates the rolling argmax (index of maximum value) for the vector.
162    ///
163    /// # Arguments
164    ///
165    /// * `window` - The size of the rolling window.
166    /// * `min_periods` - The minimum number of observations in window required to have a value.
167    /// * `out` - Optional output buffer to store the results.
168    ///
169    /// # Returns
170    ///
171    /// A vector containing the rolling argmax values.
172    #[no_out]
173    fn ts_vargmax<O: Vec1<U>, U>(
174        &self,
175        window: usize,
176        min_periods: Option<usize>,
177        out: Option<O::UninitRefMut<'_>>,
178    ) -> O
179    where
180        T::Inner: Number,
181        f64: Cast<U>,
182    {
183        let window = min(self.len(), window);
184        let mut max: Option<T::Inner> = None;
185        let mut max_idx: Option<usize> = None;
186        let mut n = 0;
187        let min_periods = min_periods.unwrap_or(window / 2);
188        self.rolling_apply_idx(
189            window,
190            |start, end, v| {
191                let v = v.to_opt();
192                unsafe {
193                    if v.is_some() {
194                        n += 1;
195                        if max_idx.is_none() {
196                            max_idx = Some(end);
197                            max = Some(v.unwrap());
198                        }
199                    }
200                    if max_idx < start {
201                        // the minimum value has expired, find the minimum value again
202                        let start = start.unwrap();
203                        max = self.uget(start).to_opt();
204                        for i in start..=end {
205                            let v_ = self.uget(i).to_opt();
206                            match v_.sort_cmp_rev(&max) {
207                                Ordering::Less | Ordering::Equal => {
208                                    (max, max_idx) = (v_, Some(i));
209                                },
210                                _ => {},
211                            }
212                        }
213                    } else {
214                        match v.sort_cmp_rev(&max) {
215                            Ordering::Less | Ordering::Equal => {
216                                (max, max_idx) = (v, Some(end));
217                            },
218                            _ => {},
219                        }
220                    }
221                    let out = if n >= min_periods {
222                        max_idx
223                            .map(|max_idx| (max_idx - start.unwrap_or(0) + 1).f64())
224                            .unwrap_or(f64::NAN)
225                            .cast()
226                    } else {
227                        f64::NAN.cast()
228                    };
229                    if start.is_some() && self.uget(start.unwrap()).not_none() {
230                        n -= 1;
231                    }
232                    out
233                }
234            },
235            out,
236        )
237    }
238
239    /// Calculates the rolling maximum for the vector.
240    ///
241    /// # Arguments
242    ///
243    /// * `window` - The size of the rolling window.
244    /// * `min_periods` - The minimum number of observations in window required to have a value.
245    /// * `out` - Optional output buffer to store the results.
246    ///
247    /// # Returns
248    ///
249    /// A vector containing the rolling maximum values.
250    #[no_out]
251    fn ts_vmax<O: Vec1<U>, U>(
252        &self,
253        window: usize,
254        min_periods: Option<usize>,
255        out: Option<O::UninitRefMut<'_>>,
256    ) -> O
257    where
258        T::Inner: Number,
259        Option<T::Inner>: Cast<U>,
260    {
261        let window = min(self.len(), window);
262        let mut max: Option<T::Inner> = None;
263        let mut max_idx: Option<usize> = None;
264        let mut n = 0;
265        let min_periods = min_periods.unwrap_or(window / 2);
266        self.rolling_apply_idx(
267            window,
268            |start, end, v| {
269                let v = v.to_opt();
270                unsafe {
271                    if v.is_some() {
272                        n += 1;
273                        if max_idx.is_none() {
274                            (max, max_idx) = (v, Some(end));
275                        }
276                    }
277                    if max_idx < start {
278                        // the minimum value has expired, find the minimum value again
279                        let start = start.unwrap();
280                        max = self.uget(start).to_opt();
281                        for i in start..=end {
282                            let v_ = self.uget(i).to_opt();
283                            match v_.sort_cmp_rev(&max) {
284                                Ordering::Less | Ordering::Equal => {
285                                    (max, max_idx) = (v_, Some(i));
286                                },
287                                _ => {},
288                            }
289                        }
290                    } else {
291                        match v.sort_cmp_rev(&max) {
292                            Ordering::Less | Ordering::Equal => {
293                                (max, max_idx) = (v, Some(end));
294                            },
295                            _ => {},
296                        }
297                    }
298                    let out = if n >= min_periods {
299                        max.cast()
300                    } else {
301                        None.cast()
302                    };
303                    if start.is_some() && self.uget(start.unwrap()).not_none() {
304                        n -= 1;
305                    }
306                    out
307                }
308            },
309            out,
310        )
311    }
312
313    /// Calculates the rolling rank for the vector.
314    ///
315    /// # Arguments
316    ///
317    /// * `window` - The size of the rolling window.
318    /// * `min_periods` - The minimum number of observations in window required to have a value.
319    /// * `pct` - If true, return percentage rank, otherwise return absolute rank.
320    /// * `rev` - If true, rank in descending order, otherwise rank in ascending order.
321    /// * `out` - Optional output buffer to store the results.
322    ///
323    /// # Returns
324    ///
325    /// A vector containing the rolling rank values.
326    #[no_out]
327    fn ts_vrank<O: Vec1<U>, U>(
328        &self,
329        window: usize,
330        min_periods: Option<usize>,
331        pct: bool,
332        rev: bool,
333        out: Option<O::UninitRefMut<'_>>,
334    ) -> O
335    where
336        T::Inner: Number,
337        f64: Cast<U>,
338    {
339        let window = min(self.len(), window);
340        let min_periods = min_periods.unwrap_or(window / 2);
341        let w_m1 = window - 1; // window minus one
342        let mut n = 0usize; // keep the num of valid elements
343        self.rolling_apply_idx(
344            window,
345            |start, end, v| {
346                let mut n_repeat = 1; // repeat count of the current value
347                let mut rank = 1.; // assume that the first element is the smallest, the rank goes up if we find a smaller element
348                if v.not_none() {
349                    n += 1;
350                    let v = v.unwrap();
351                    for i in start.unwrap_or(0)..end {
352                        let a = unsafe { self.uget(i) };
353                        if a.not_none() {
354                            let a = a.unwrap();
355                            if a < v {
356                                rank += 1.
357                            } else if a == v {
358                                n_repeat += 1
359                            }
360                        }
361                    }
362                } else {
363                    rank = f64::NAN
364                }
365                let out: f64;
366                if n >= min_periods {
367                    let res = if !rev {
368                        rank + 0.5 * (n_repeat - 1) as f64 // method for repeated values: average
369                    } else {
370                        (n + 1) as f64 - rank - 0.5 * (n_repeat - 1) as f64
371                    };
372                    if pct {
373                        out = res / n as f64;
374                    } else {
375                        out = res;
376                    }
377                } else {
378                    out = f64::NAN;
379                }
380                if end >= w_m1 && unsafe { self.uget(start.unwrap()) }.not_none() {
381                    n -= 1;
382                }
383                out.cast()
384            },
385            out,
386        )
387    }
388}
389
390pub trait RollingCmp<T>: Vec1View<T> {}
391
392impl<T: IsNone, I: Vec1View<T>> RollingValidCmp<T> for I {}
393impl<T, I: Vec1View<T>> RollingCmp<T> for I {}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ts_vmin() {
        let v = vec![19, 0, 1, 2, 3, 4, 5];
        // test ts_vargmin on a plain integer vector
        let res: Vec<f64> = v.ts_vargmin(2, Some(1));
        assert_eq!(res, vec![1., 2., 1., 1., 1., 1., 1.]);
        let v = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        // test ts_vargmin
        let res: Vec<Option<f64>> = v.ts_vargmin(3, None);
        assert_eq!(res, vec![Some(1.), Some(1.), Some(1.), Some(1.), Some(1.)]);
        // test ts_vmin
        let res: Vec<Option<f64>> = v.ts_vmin::<Vec<Option<f64>>, Option<f64>>(3, None);
        assert_eq!(
            res,
            vec![Some(1.), Some(1.), Some(1.0), Some(2.0), Some(3.0)]
        );
        let v = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        // test ts_vargmin with min_periods equal to the window
        let res: Vec<Option<i32>> = v.opt().ts_vargmin(3, Some(3));
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(1),
                Some(2),
                Some(1),
                Some(3),
                Some(2),
                Some(1),
                Some(3)
            ]
        );
        // test ts_vmin with min_periods equal to the window
        let res: Vec<Option<i32>> = v.opt().ts_vmin(3, Some(3));
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(1),
                Some(2),
                Some(2),
                Some(1),
                Some(1),
                Some(1),
                Some(3)
            ]
        );
    }

    #[test]
    fn test_ts_vmax() {
        let v = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        // test ts_vargmax
        let res: Vec<f64> = v.ts_vargmax(3, None);
        assert_eq!(res, vec![1., 2., 3., 3., 3.]);
        // test ts_vmax
        let res: Vec<f64> = v.ts_vmax(3, None);
        assert_eq!(res, vec![1., 2., 3., 4., 5.]);
        let v = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        // test ts_vargmax with min_periods equal to the window
        let res: Vec<Option<f64>> = v.opt().ts_vargmax(3, Some(3));
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(2.),
                Some(3.),
                Some(2.),
                Some(1.),
                Some(3.),
                Some(3.),
                Some(2.)
            ]
        );
        // test ts_vmax with min_periods equal to the window
        let res: Vec<Option<i32>> = v.opt().ts_vmax(3, Some(3));
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(3),
                Some(5),
                Some(5),
                Some(5),
                Some(5),
                Some(7),
                Some(7)
            ]
        );
    }

    #[test]
    fn test_ts_vrank() {
        let v = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        // test ts_vrank on a strictly increasing series
        let res: Vec<f64> = v.ts_vrank(3, None, false, false);
        assert_eq!(res, vec![1., 2., 3., 3., 3.]);
        let v = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        // test ts_vrank with min_periods equal to the window
        let res: Vec<Option<f64>> = v.ts_vrank(3, Some(3), false, false);
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(2.),
                Some(3.),
                Some(2.),
                Some(1.),
                Some(3.),
                Some(3.),
                Some(1.)
            ]
        );
        // test ts_vrank in descending order (rev)
        let res: Vec<Option<f64>> = v.ts_vrank(3, Some(3), false, true);
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(2.),
                Some(1.),
                Some(2.),
                Some(3.),
                Some(1.),
                Some(1.),
                Some(3.)
            ]
        );
        // test ts_vrank as a percentage of the valid observations (pct)
        let res: Vec<Option<f64>> = v.ts_vrank(3, Some(3), true, false);
        assert_eq!(
            res,
            vec![
                None,
                None,
                Some(2. / 3.),
                Some(1.),
                Some(2. / 3.),
                Some(1. / 3.),
                Some(1.),
                Some(1.),
                Some(1. / 3.)
            ]
        );
        // test ts_vrank tie handling (average method)
        let v = vec![1, 1, 1];
        let res: Vec<f64> = v.ts_vrank(3, Some(1), false, false);
        assert_eq!(res, vec![1., 1.5, 2.]);
    }
}