Skip to main content

vector_ta/indicators/
utility_functions.rs

1use std::collections::VecDeque;
2use thiserror::Error;
3
4#[derive(Debug, Error)]
5pub enum RollingError {
6    #[error("rolling: Empty data provided.")]
7    EmptyData,
8    #[error("rolling: Invalid period: period={period}, data length={data_len}")]
9    InvalidPeriod { period: usize, data_len: usize },
10    #[error("rolling: All values are NaN.")]
11    AllValuesNaN,
12    #[error("rolling: Not enough valid data: needed={needed}, valid={valid}")]
13    NotEnoughValidData { needed: usize, valid: usize },
14}
15
16#[inline]
17pub fn sum_rolling(data: &[f64], period: usize) -> Result<Vec<f64>, RollingError> {
18    if data.is_empty() {
19        return Err(RollingError::EmptyData);
20    }
21    if period == 0 || period > data.len() {
22        return Err(RollingError::InvalidPeriod {
23            period,
24            data_len: data.len(),
25        });
26    }
27
28    let first_valid_idx = match data.iter().position(|&x| !x.is_nan()) {
29        Some(idx) => idx,
30        None => return Err(RollingError::AllValuesNaN),
31    };
32
33    if (data.len() - first_valid_idx) < period {
34        return Err(RollingError::NotEnoughValidData {
35            needed: period,
36            valid: data.len() - first_valid_idx,
37        });
38    }
39
40    let valid_len = data.len() - first_valid_idx;
41    let mut prefix = Vec::with_capacity(valid_len + 1);
42    prefix.push(0.0);
43
44    for &val in &data[first_valid_idx..] {
45        let prev = *prefix.last().unwrap();
46        prefix.push(prev + val);
47    }
48
49    let mut output = vec![f64::NAN; data.len()];
50    let start_idx = first_valid_idx + period - 1;
51    for i in start_idx..data.len() {
52        let prefix_end = i - first_valid_idx + 1;
53        let prefix_start = prefix_end - period;
54        output[i] = prefix[prefix_end] - prefix[prefix_start];
55    }
56
57    Ok(output)
58}
59
60#[inline]
61pub fn max_rolling(data: &[f64], period: usize) -> Result<Vec<f64>, RollingError> {
62    if data.is_empty() {
63        return Err(RollingError::EmptyData);
64    }
65    if period == 0 || period > data.len() {
66        return Err(RollingError::InvalidPeriod {
67            period,
68            data_len: data.len(),
69        });
70    }
71
72    let first_valid_idx = match data.iter().position(|&x| !x.is_nan()) {
73        Some(idx) => idx,
74        None => return Err(RollingError::AllValuesNaN),
75    };
76    if (data.len() - first_valid_idx) < period {
77        return Err(RollingError::NotEnoughValidData {
78            needed: period,
79            valid: data.len() - first_valid_idx,
80        });
81    }
82
83    let mut output = vec![f64::NAN; data.len()];
84    let mut deque: VecDeque<usize> = VecDeque::with_capacity(period);
85
86    let start_idx = first_valid_idx + period - 1;
87
88    for i in 0..data.len() {
89        if i < first_valid_idx {
90            continue;
91        }
92        let window_start = i.saturating_sub(period - 1);
93
94        while let Some(&front_idx) = deque.front() {
95            if front_idx < window_start {
96                deque.pop_front();
97            } else {
98                break;
99            }
100        }
101
102        let val = data[i];
103        while let Some(&back_idx) = deque.back() {
104            if data[back_idx] <= val {
105                deque.pop_back();
106            } else {
107                break;
108            }
109        }
110        deque.push_back(i);
111
112        if i >= start_idx {
113            let max_idx = *deque.front().unwrap();
114            output[i] = data[max_idx];
115        }
116    }
117
118    Ok(output)
119}
120
121#[inline]
122pub fn min_rolling(data: &[f64], period: usize) -> Result<Vec<f64>, RollingError> {
123    if data.is_empty() {
124        return Err(RollingError::EmptyData);
125    }
126    if period == 0 || period > data.len() {
127        return Err(RollingError::InvalidPeriod {
128            period,
129            data_len: data.len(),
130        });
131    }
132
133    let first_valid_idx = match data.iter().position(|&x| !x.is_nan()) {
134        Some(idx) => idx,
135        None => return Err(RollingError::AllValuesNaN),
136    };
137
138    if (data.len() - first_valid_idx) < period {
139        return Err(RollingError::NotEnoughValidData {
140            needed: period,
141            valid: data.len() - first_valid_idx,
142        });
143    }
144
145    let mut output = vec![f64::NAN; data.len()];
146    let mut deque: VecDeque<usize> = VecDeque::with_capacity(period);
147
148    let start_idx = first_valid_idx + period - 1;
149
150    for i in 0..data.len() {
151        if i < first_valid_idx {
152            continue;
153        }
154        let window_start = i.saturating_sub(period - 1);
155
156        while let Some(&front) = deque.front() {
157            if front < window_start {
158                deque.pop_front();
159            } else {
160                break;
161            }
162        }
163
164        let val = data[i];
165        while let Some(&back_idx) = deque.back() {
166            if data[back_idx] >= val {
167                deque.pop_back();
168            } else {
169                break;
170            }
171        }
172        deque.push_back(i);
173
174        if i >= start_idx {
175            let min_idx = *deque.front().unwrap();
176            output[i] = data[min_idx];
177        }
178    }
179
180    Ok(output)
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts element-wise equality of two slices, treating NaN as equal to NaN
    /// (the rolling functions use NaN to mark warm-up outputs).
    fn assert_nan_eq(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a.is_nan() && e.is_nan()) || a == e,
                "index {}: expected {}, got {}",
                i,
                e,
                a
            );
        }
    }

    #[test]
    fn test_sum_rolling_basic() {
        let result = sum_rolling(&[1.0, 2.0, 3.0, 4.0, 5.0], 3).unwrap();
        assert_nan_eq(&result, &[f64::NAN, f64::NAN, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_sum_rolling_leading_nan() {
        let result = sum_rolling(&[f64::NAN, 1.0, 2.0, 3.0], 2).unwrap();
        assert_nan_eq(&result, &[f64::NAN, f64::NAN, 3.0, 5.0]);
    }

    #[test]
    fn test_sum_rolling_zero_period() {
        let err = sum_rolling(&[1.0, 2.0, 3.0], 0).unwrap_err();
        assert!(
            matches!(err, RollingError::InvalidPeriod { period: 0, data_len: 3 }),
            "Expected InvalidPeriod error, got: {}",
            err
        );
    }

    #[test]
    fn test_sum_rolling_period_exceeding_data_length() {
        let err = sum_rolling(&[1.0, 2.0], 5).unwrap_err();
        assert!(
            matches!(err, RollingError::InvalidPeriod { period: 5, data_len: 2 }),
            "Expected InvalidPeriod error, got: {}",
            err
        );
    }

    #[test]
    fn test_error_display() {
        // The Display text is part of the public surface; pin it once here.
        let err = sum_rolling(&[1.0, 2.0, 3.0], 0).unwrap_err();
        assert_eq!(
            err.to_string(),
            "rolling: Invalid period: period=0, data length=3"
        );
    }

    #[test]
    fn test_max_rolling_basic() {
        let result = max_rolling(&[2.0, 5.0, 3.0, 8.0, 1.0], 2).unwrap();
        assert_nan_eq(&result, &[f64::NAN, 5.0, 5.0, 8.0, 8.0]);
    }

    #[test]
    fn test_max_rolling_all_nan() {
        let err = max_rolling(&[f64::NAN, f64::NAN, f64::NAN], 2).unwrap_err();
        assert!(
            matches!(err, RollingError::AllValuesNaN),
            "Expected AllValuesNaN, got {}",
            err
        );
    }

    #[test]
    fn test_min_rolling_basic() {
        let result = min_rolling(&[5.0, 2.0, 3.0, 1.0, 4.0], 2).unwrap();
        assert_nan_eq(&result, &[f64::NAN, 2.0, 2.0, 1.0, 1.0]);
    }

    #[test]
    fn test_min_rolling_nan_handling() {
        let result = min_rolling(&[f64::NAN, 5.0, 2.0], 2).unwrap();
        assert_nan_eq(&result, &[f64::NAN, f64::NAN, 2.0]);
    }

    #[test]
    fn test_min_rolling_empty_data() {
        let data: [f64; 0] = [];
        let err = min_rolling(&data, 3).unwrap_err();
        assert!(
            matches!(err, RollingError::EmptyData),
            "Expected EmptyData, got {}",
            err
        );
    }

    #[test]
    fn test_not_enough_valid_data() {
        // Only one non-NaN element follows the leading NaNs, but period is 2.
        let data = [f64::NAN, f64::NAN, 1.0];
        let functions: [fn(&[f64], usize) -> Result<Vec<f64>, RollingError>; 3] =
            [sum_rolling, max_rolling, min_rolling];
        for f in functions.iter() {
            let err = f(&data, 2).unwrap_err();
            assert!(
                matches!(err, RollingError::NotEnoughValidData { needed: 2, valid: 1 }),
                "Expected NotEnoughValidData, got {}",
                err
            );
        }
    }
}