tea_rolling/cmp.rs
1use std::cmp::{Ordering, min};
2
3use tea_core::prelude::*;
4/// Trait for performing rolling comparison operations on valid elements in vectors.
5///
6/// This trait provides methods for calculating rolling minimum, maximum, argmin, argmax,
7/// and rank operations on vectors of potentially nullable elements.
8pub trait RollingValidCmp<T: IsNone>: Vec1View<T> {
9 /// Calculates the rolling argmin (index of minimum value) for the vector.
10 ///
11 /// # Arguments
12 ///
13 /// * `window` - The size of the rolling window.
14 /// * `min_periods` - The minimum number of observations in window required to have a value.
15 /// * `out` - Optional output buffer to store the results.
16 ///
17 /// # Returns
18 ///
19 /// A vector containing the rolling argmin values.
20 #[no_out]
21 fn ts_vargmin<O: Vec1<U>, U>(
22 &self,
23 window: usize,
24 min_periods: Option<usize>,
25 out: Option<O::UninitRefMut<'_>>,
26 ) -> O
27 where
28 T::Inner: Number,
29 f64: Cast<U>,
30 {
31 let window = min(self.len(), window);
32 let mut min: Option<T::Inner> = None;
33 let mut min_idx: Option<usize> = None;
34 let mut n = 0;
35 let min_periods = min_periods.unwrap_or(window / 2);
36 self.rolling_apply_idx(
37 window,
38 |start, end, v| {
39 let v = v.to_opt();
40 unsafe {
41 if v.is_some() {
42 n += 1;
43 if min_idx.is_none() {
44 min_idx = Some(end);
45 min = Some(v.unwrap());
46 }
47 }
48 if min_idx < start {
49 // the minimum value has expired, find the minimum value again
50 let start = start.unwrap();
51 min = self.uget(start).to_opt();
52 for i in start..=end {
53 let v_ = self.uget(i).to_opt();
54 match v_.sort_cmp(&min) {
55 Ordering::Less | Ordering::Equal => {
56 (min, min_idx) = (v_, Some(i));
57 },
58 _ => {},
59 }
60 }
61 } else {
62 match v.sort_cmp(&min) {
63 Ordering::Less | Ordering::Equal => {
64 (min, min_idx) = (v, Some(end));
65 },
66 _ => {},
67 }
68 }
69 let out = if n >= min_periods {
70 min_idx
71 .map(|min_idx| (min_idx - start.unwrap_or(0) + 1).f64())
72 .unwrap_or(f64::NAN)
73 .cast()
74 } else {
75 f64::NAN.cast()
76 };
77 if start.is_some() && self.uget(start.unwrap()).not_none() {
78 n -= 1;
79 }
80 out
81 }
82 },
83 out,
84 )
85 }
86
87 /// Calculates the rolling minimum for the vector.
88 ///
89 /// # Arguments
90 ///
91 /// * `window` - The size of the rolling window.
92 /// * `min_periods` - The minimum number of observations in window required to have a value.
93 /// * `out` - Optional output buffer to store the results.
94 ///
95 /// # Returns
96 ///
97 /// A vector containing the rolling minimum values.
98 #[no_out]
99 fn ts_vmin<O: Vec1<U>, U>(
100 &self,
101 window: usize,
102 min_periods: Option<usize>,
103 out: Option<O::UninitRefMut<'_>>,
104 ) -> O
105 where
106 T::Inner: Number,
107 Option<T::Inner>: Cast<U>,
108 {
109 let window = min(self.len(), window);
110 let mut min: Option<T::Inner> = None;
111 let mut min_idx: Option<usize> = None;
112 let mut n = 0;
113 let min_periods = min_periods.unwrap_or(window / 2);
114 self.rolling_apply_idx(
115 window,
116 |start, end, v| {
117 let v = v.to_opt();
118 unsafe {
119 if v.is_some() {
120 n += 1;
121 if min_idx.is_none() {
122 (min, min_idx) = (v, Some(end));
123 }
124 }
125 if min_idx < start {
126 // the minimum value has expired, find the minimum value again
127 let start = start.unwrap();
128 min = self.uget(start).to_opt();
129 for i in start..=end {
130 let v_ = self.uget(i).to_opt();
131 match v_.sort_cmp(&min) {
132 Ordering::Less | Ordering::Equal => {
133 (min, min_idx) = (v_, Some(i));
134 },
135 _ => {},
136 }
137 }
138 } else {
139 match v.sort_cmp(&min) {
140 Ordering::Less | Ordering::Equal => {
141 (min, min_idx) = (v, Some(end));
142 },
143 _ => {},
144 }
145 }
146 let out = if n >= min_periods {
147 min.cast()
148 } else {
149 None.cast()
150 };
151 if start.is_some() && self.uget(start.unwrap()).not_none() {
152 n -= 1;
153 }
154 out
155 }
156 },
157 out,
158 )
159 }
160
161 /// Calculates the rolling argmax (index of maximum value) for the vector.
162 ///
163 /// # Arguments
164 ///
165 /// * `window` - The size of the rolling window.
166 /// * `min_periods` - The minimum number of observations in window required to have a value.
167 /// * `out` - Optional output buffer to store the results.
168 ///
169 /// # Returns
170 ///
171 /// A vector containing the rolling argmax values.
172 #[no_out]
173 fn ts_vargmax<O: Vec1<U>, U>(
174 &self,
175 window: usize,
176 min_periods: Option<usize>,
177 out: Option<O::UninitRefMut<'_>>,
178 ) -> O
179 where
180 T::Inner: Number,
181 f64: Cast<U>,
182 {
183 let window = min(self.len(), window);
184 let mut max: Option<T::Inner> = None;
185 let mut max_idx: Option<usize> = None;
186 let mut n = 0;
187 let min_periods = min_periods.unwrap_or(window / 2);
188 self.rolling_apply_idx(
189 window,
190 |start, end, v| {
191 let v = v.to_opt();
192 unsafe {
193 if v.is_some() {
194 n += 1;
195 if max_idx.is_none() {
196 max_idx = Some(end);
197 max = Some(v.unwrap());
198 }
199 }
200 if max_idx < start {
201 // the minimum value has expired, find the minimum value again
202 let start = start.unwrap();
203 max = self.uget(start).to_opt();
204 for i in start..=end {
205 let v_ = self.uget(i).to_opt();
206 match v_.sort_cmp_rev(&max) {
207 Ordering::Less | Ordering::Equal => {
208 (max, max_idx) = (v_, Some(i));
209 },
210 _ => {},
211 }
212 }
213 } else {
214 match v.sort_cmp_rev(&max) {
215 Ordering::Less | Ordering::Equal => {
216 (max, max_idx) = (v, Some(end));
217 },
218 _ => {},
219 }
220 }
221 let out = if n >= min_periods {
222 max_idx
223 .map(|max_idx| (max_idx - start.unwrap_or(0) + 1).f64())
224 .unwrap_or(f64::NAN)
225 .cast()
226 } else {
227 f64::NAN.cast()
228 };
229 if start.is_some() && self.uget(start.unwrap()).not_none() {
230 n -= 1;
231 }
232 out
233 }
234 },
235 out,
236 )
237 }
238
239 /// Calculates the rolling maximum for the vector.
240 ///
241 /// # Arguments
242 ///
243 /// * `window` - The size of the rolling window.
244 /// * `min_periods` - The minimum number of observations in window required to have a value.
245 /// * `out` - Optional output buffer to store the results.
246 ///
247 /// # Returns
248 ///
249 /// A vector containing the rolling maximum values.
250 #[no_out]
251 fn ts_vmax<O: Vec1<U>, U>(
252 &self,
253 window: usize,
254 min_periods: Option<usize>,
255 out: Option<O::UninitRefMut<'_>>,
256 ) -> O
257 where
258 T::Inner: Number,
259 Option<T::Inner>: Cast<U>,
260 {
261 let window = min(self.len(), window);
262 let mut max: Option<T::Inner> = None;
263 let mut max_idx: Option<usize> = None;
264 let mut n = 0;
265 let min_periods = min_periods.unwrap_or(window / 2);
266 self.rolling_apply_idx(
267 window,
268 |start, end, v| {
269 let v = v.to_opt();
270 unsafe {
271 if v.is_some() {
272 n += 1;
273 if max_idx.is_none() {
274 (max, max_idx) = (v, Some(end));
275 }
276 }
277 if max_idx < start {
278 // the minimum value has expired, find the minimum value again
279 let start = start.unwrap();
280 max = self.uget(start).to_opt();
281 for i in start..=end {
282 let v_ = self.uget(i).to_opt();
283 match v_.sort_cmp_rev(&max) {
284 Ordering::Less | Ordering::Equal => {
285 (max, max_idx) = (v_, Some(i));
286 },
287 _ => {},
288 }
289 }
290 } else {
291 match v.sort_cmp_rev(&max) {
292 Ordering::Less | Ordering::Equal => {
293 (max, max_idx) = (v, Some(end));
294 },
295 _ => {},
296 }
297 }
298 let out = if n >= min_periods {
299 max.cast()
300 } else {
301 None.cast()
302 };
303 if start.is_some() && self.uget(start.unwrap()).not_none() {
304 n -= 1;
305 }
306 out
307 }
308 },
309 out,
310 )
311 }
312
313 /// Calculates the rolling rank for the vector.
314 ///
315 /// # Arguments
316 ///
317 /// * `window` - The size of the rolling window.
318 /// * `min_periods` - The minimum number of observations in window required to have a value.
319 /// * `pct` - If true, return percentage rank, otherwise return absolute rank.
320 /// * `rev` - If true, rank in descending order, otherwise rank in ascending order.
321 /// * `out` - Optional output buffer to store the results.
322 ///
323 /// # Returns
324 ///
325 /// A vector containing the rolling rank values.
326 #[no_out]
327 fn ts_vrank<O: Vec1<U>, U>(
328 &self,
329 window: usize,
330 min_periods: Option<usize>,
331 pct: bool,
332 rev: bool,
333 out: Option<O::UninitRefMut<'_>>,
334 ) -> O
335 where
336 T::Inner: Number,
337 f64: Cast<U>,
338 {
339 let window = min(self.len(), window);
340 let min_periods = min_periods.unwrap_or(window / 2);
341 let w_m1 = window - 1; // window minus one
342 let mut n = 0usize; // keep the num of valid elements
343 self.rolling_apply_idx(
344 window,
345 |start, end, v| {
346 let mut n_repeat = 1; // repeat count of the current value
347 let mut rank = 1.; // assume that the first element is the smallest, the rank goes up if we find a smaller element
348 if v.not_none() {
349 n += 1;
350 let v = v.unwrap();
351 for i in start.unwrap_or(0)..end {
352 let a = unsafe { self.uget(i) };
353 if a.not_none() {
354 let a = a.unwrap();
355 if a < v {
356 rank += 1.
357 } else if a == v {
358 n_repeat += 1
359 }
360 }
361 }
362 } else {
363 rank = f64::NAN
364 }
365 let out: f64;
366 if n >= min_periods {
367 let res = if !rev {
368 rank + 0.5 * (n_repeat - 1) as f64 // method for repeated values: average
369 } else {
370 (n + 1) as f64 - rank - 0.5 * (n_repeat - 1) as f64
371 };
372 if pct {
373 out = res / n as f64;
374 } else {
375 out = res;
376 }
377 } else {
378 out = f64::NAN;
379 }
380 if end >= w_m1 && unsafe { self.uget(start.unwrap()) }.not_none() {
381 n -= 1;
382 }
383 out.cast()
384 },
385 out,
386 )
387 }
388}
389
390pub trait RollingCmp<T>: Vec1View<T> {}
391
392impl<T: IsNone, I: Vec1View<T>> RollingValidCmp<T> for I {}
393impl<T, I: Vec1View<T>> RollingCmp<T> for I {}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers `ts_vargmin` and `ts_vmin` on plain integers, `Option<f64>` values
    /// and an `opt()` view of integers.
    #[test]
    fn test_ts_vmin() {
        // ts_vargmin on plain integers, window of 2 with a single required observation
        let ints = vec![19, 0, 1, 2, 3, 4, 5];
        let argmin: Vec<f64> = ints.ts_vargmin(2, Some(1));
        assert_eq!(argmin, vec![1., 2., 1., 1., 1., 1., 1.]);

        // increasing series: the minimum is always the first element of the window
        let rising = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        // ts_vargmin with the default min_periods
        let argmin: Vec<Option<f64>> = rising.ts_vargmin(3, None);
        assert_eq!(
            argmin,
            vec![Some(1.), Some(1.), Some(1.), Some(1.), Some(1.)]
        );
        // ts_vmin with the default min_periods
        let minimum: Vec<Option<f64>> =
            rising.ts_vmin::<Vec<Option<f64>>, Option<f64>>(3, None);
        assert_eq!(
            minimum,
            vec![Some(1.), Some(1.), Some(1.0), Some(2.0), Some(3.0)]
        );

        let mixed = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        // ts_vargmin requiring a full window of three observations
        let argmin: Vec<Option<i32>> = mixed.opt().ts_vargmin(3, Some(3));
        let expected_argmin = vec![
            None,
            None,
            Some(1),
            Some(2),
            Some(1),
            Some(3),
            Some(2),
            Some(1),
            Some(3),
        ];
        assert_eq!(argmin, expected_argmin);
        // ts_vmin requiring a full window of three observations
        let minimum: Vec<Option<i32>> = mixed.opt().ts_vmin(3, Some(3));
        let expected_min = vec![
            None,
            None,
            Some(1),
            Some(2),
            Some(2),
            Some(1),
            Some(1),
            Some(1),
            Some(3),
        ];
        assert_eq!(minimum, expected_min);
    }

    /// Covers `ts_vargmax` and `ts_vmax` on `Option<f64>` values and an `opt()`
    /// view of integers.
    #[test]
    fn test_ts_vmax() {
        // increasing series: the maximum is always the newest element of the window
        let rising = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        // ts_vargmax with the default min_periods
        let argmax: Vec<f64> = rising.ts_vargmax(3, None);
        assert_eq!(argmax, vec![1., 2., 3., 3., 3.]);
        // ts_vmax with the default min_periods
        let maximum: Vec<f64> = rising.ts_vmax(3, None);
        assert_eq!(maximum, vec![1., 2., 3., 4., 5.]);

        let mixed = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        // ts_vargmax requiring a full window of three observations
        let argmax: Vec<Option<f64>> = mixed.opt().ts_vargmax(3, Some(3));
        let expected_argmax = vec![
            None,
            None,
            Some(2.),
            Some(3.),
            Some(2.),
            Some(1.),
            Some(3.),
            Some(3.),
            Some(2.),
        ];
        assert_eq!(argmax, expected_argmax);
        // ts_vmax requiring a full window of three observations
        let maximum: Vec<Option<i32>> = mixed.opt().ts_vmax(3, Some(3));
        let expected_max = vec![
            None,
            None,
            Some(3),
            Some(5),
            Some(5),
            Some(5),
            Some(5),
            Some(7),
            Some(7),
        ];
        assert_eq!(maximum, expected_max);
    }

    /// Covers ascending, absolute `ts_vrank` on `Option<f64>` values and plain integers.
    #[test]
    fn test_ts_vrank() {
        // increasing series: the newest element always ranks last in its window
        let rising = vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)];
        let rank: Vec<f64> = rising.ts_vrank(3, None, false, false);
        assert_eq!(rank, vec![1., 2., 3., 3., 3.]);

        // ts_vrank requiring a full window of three observations
        let mixed = vec![1, 3, 2, 5, 3, 1, 5, 7, 3];
        let rank: Vec<Option<f64>> = mixed.ts_vrank(3, Some(3), false, false);
        let expected_rank = vec![
            None,
            None,
            Some(2.),
            Some(3.),
            Some(2.),
            Some(1.),
            Some(3.),
            Some(3.),
            Some(1.),
        ];
        assert_eq!(rank, expected_rank);
    }
}