1use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
6use scirs2_core::numeric::{Float, FromPrimitive};
7use std::fmt::Display;
8
9use crate::error::{Result, TimeSeriesError};
10use statrs::statistics::Statistics;
11
12#[allow(dead_code)]
14pub fn check_positive<F: Float + Display>(value: F, name: &str) -> Result<()> {
15 if value <= F::zero() {
16 return Err(TimeSeriesError::InvalidParameter {
17 name: name.to_string(),
18 message: format!("Must be positive, got {value}"),
19 });
20 }
21 Ok(())
22}
23
24#[allow(dead_code)]
26pub fn check_non_negative<F: Float + Display>(value: F, name: &str) -> Result<()> {
27 if value < F::zero() {
28 return Err(TimeSeriesError::InvalidParameter {
29 name: name.to_string(),
30 message: format!("Must be non-negative, got {value}"),
31 });
32 }
33 Ok(())
34}
35
36#[allow(dead_code)]
38pub fn check_probability<F: Float + Display>(value: F, name: &str) -> Result<()> {
39 if value < F::zero() || value > F::one() {
40 return Err(TimeSeriesError::InvalidParameter {
41 name: name.to_string(),
42 message: format!("Must be in [0, 1], got {value}"),
43 });
44 }
45 Ok(())
46}
47
48#[allow(dead_code)]
50pub fn check_in_range<F: Float + Display>(value: F, min: F, max: F, name: &str) -> Result<()> {
51 if value < min || value > max {
52 return Err(TimeSeriesError::InvalidParameter {
53 name: name.to_string(),
54 message: format!("Must be in [{min}, {max}], got {value}"),
55 });
56 }
57 Ok(())
58}
59
60#[allow(dead_code)]
62pub fn check_array_length<S, F>(
63 data: &ArrayBase<S, Ix1>,
64 min_length: usize,
65 operation: &str,
66) -> Result<()>
67where
68 S: Data<Elem = F>,
69 F: Float,
70{
71 if data.len() < min_length {
72 return Err(TimeSeriesError::InsufficientData {
73 message: format!("for {operation}"),
74 required: min_length,
75 actual: data.len(),
76 });
77 }
78 Ok(())
79}
80
81#[allow(dead_code)]
83pub fn check_same_length<S1, S2, F>(
84 arr1: &ArrayBase<S1, Ix1>,
85 arr2: &ArrayBase<S2, Ix1>,
86 _name1: &str,
87 name2: &str,
88) -> Result<()>
89where
90 S1: Data<Elem = F>,
91 S2: Data<Elem = F>,
92 F: Float,
93{
94 if arr1.len() != arr2.len() {
95 return Err(TimeSeriesError::DimensionMismatch {
96 expected: arr1.len(),
97 actual: arr2.len(),
98 });
99 }
100 Ok(())
101}
102
103#[allow(dead_code)]
105pub fn validate_arima_orders(p: usize, d: usize, q: usize) -> Result<()> {
106 if p > 10 {
107 return Err(TimeSeriesError::InvalidParameter {
108 name: "p".to_string(),
109 message: format!("AR order too large: {p}"),
110 });
111 }
112 if d > 3 {
113 return Err(TimeSeriesError::InvalidParameter {
114 name: "d".to_string(),
115 message: format!("Differencing order too large: {d}"),
116 });
117 }
118 if q > 10 {
119 return Err(TimeSeriesError::InvalidParameter {
120 name: "q".to_string(),
121 message: format!("MA order too large: {q}"),
122 });
123 }
124 Ok(())
125}
126
127#[allow(dead_code)]
129pub fn validate_seasonal_arima_orders(
130 p: usize,
131 d: usize,
132 q: usize,
133 p_seasonal: usize,
134 d_seasonal: usize,
135 q_seasonal: usize,
136 period: usize,
137) -> Result<()> {
138 validate_arima_orders(p, d, q)?;
139
140 if p_seasonal > 5 {
141 return Err(TimeSeriesError::InvalidParameter {
142 name: "p_seasonal".to_string(),
143 message: format!("Seasonal AR order too large: {p_seasonal}"),
144 });
145 }
146 if d_seasonal > 2 {
147 return Err(TimeSeriesError::InvalidParameter {
148 name: "d_seasonal".to_string(),
149 message: format!("Seasonal differencing order too large: {d_seasonal}"),
150 });
151 }
152 if q_seasonal > 5 {
153 return Err(TimeSeriesError::InvalidParameter {
154 name: "q_seasonal".to_string(),
155 message: format!("Seasonal MA order too large: {q_seasonal}"),
156 });
157 }
158 if period < 2 {
159 return Err(TimeSeriesError::InvalidParameter {
160 name: "period".to_string(),
161 message: format!("Period must be at least 2, got {period}"),
162 });
163 }
164 if period > 365 {
165 return Err(TimeSeriesError::InvalidParameter {
166 name: "period".to_string(),
167 message: format!("Period too large: {period}"),
168 });
169 }
170
171 Ok(())
172}
173
174#[allow(dead_code)]
176pub fn validate_forecast_horizon(_steps: usize, maxreasonable: Option<usize>) -> Result<()> {
177 if _steps == 0 {
178 return Err(TimeSeriesError::InvalidParameter {
179 name: "_steps".to_string(),
180 message: "Forecast horizon must be positive".to_string(),
181 });
182 }
183
184 let max = maxreasonable.unwrap_or(10000);
185 if _steps > max {
186 return Err(TimeSeriesError::InvalidParameter {
187 name: "_steps".to_string(),
188 message: format!("Forecast horizon too large: {_steps}"),
189 });
190 }
191
192 Ok(())
193}
194
195#[allow(dead_code)]
197pub fn validate_window_size(_window: usize, datalength: usize) -> Result<()> {
198 if _window == 0 {
199 return Err(TimeSeriesError::InvalidParameter {
200 name: "_window".to_string(),
201 message: "Window size must be positive".to_string(),
202 });
203 }
204
205 if _window > datalength {
206 return Err(TimeSeriesError::InvalidParameter {
207 name: "_window".to_string(),
208 message: format!("Window size {_window} exceeds data _length {datalength}"),
209 });
210 }
211
212 Ok(())
213}
214
215#[allow(dead_code)]
217pub fn validate_lag(_lag: usize, datalength: usize) -> Result<()> {
218 if _lag >= datalength {
219 return Err(TimeSeriesError::InvalidParameter {
220 name: "_lag".to_string(),
221 message: format!("Lag {_lag} must be less than data _length {datalength}"),
222 });
223 }
224 Ok(())
225}
226
227#[allow(dead_code)]
229pub fn check_no_missing<S, F>(data: &ArrayBase<S, Ix1>) -> Result<()>
230where
231 S: Data<Elem = F>,
232 F: Float,
233{
234 for (i, &x) in data.iter().enumerate() {
235 if x.is_nan() || x.is_infinite() {
236 return Err(TimeSeriesError::InvalidInput(format!(
237 "Non-finite value at index {i}"
238 )));
239 }
240 }
241 Ok(())
242}
243
244#[allow(dead_code)]
246pub fn check_stationarity_basic<S, F>(data: &ArrayBase<S, Ix1>) -> Result<bool>
247where
248 S: Data<Elem = F>,
249 F: Float + FromPrimitive,
250{
251 check_array_length(data, 10, "stationarity check")?;
252
253 let mid = data.len() / 2;
255 let first_half = data.slice(scirs2_core::ndarray::s![..mid]);
256 let second_half = data.slice(scirs2_core::ndarray::s![mid..]);
257
258 let mean1 = first_half.mean().unwrap_or(F::zero());
260 let mean2 = second_half.mean().unwrap_or(F::zero());
261
262 let var1 = first_half
263 .mapv(|x| (x - mean1) * (x - mean1))
264 .mean()
265 .unwrap_or(F::zero());
266 let var2 = second_half
267 .mapv(|x| (x - mean2) * (x - mean2))
268 .mean()
269 .unwrap_or(F::zero());
270
271 let mean_diff = (mean1 - mean2).abs();
273 let var_ratio = if var1 > F::zero() && var2 > F::zero() {
274 (var1 / var2).max(var2 / var1)
275 } else {
276 F::one()
277 };
278
279 let mean_threshold =
281 F::from(0.2).unwrap() * (var1.sqrt() + var2.sqrt()) / F::from(2.0).unwrap();
282 let var_threshold = F::from(2.0).unwrap();
283
284 Ok(mean_diff < mean_threshold && var_ratio < var_threshold)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    #[test]
    fn test_check_positive() {
        assert!(check_positive(1.0, "value").is_ok());
        assert!(check_positive(0.0, "value").is_err());
        assert!(check_positive(-1.0, "value").is_err());
    }

    #[test]
    fn test_check_non_negative() {
        assert!(check_non_negative(0.0, "value").is_ok());
        assert!(check_non_negative(2.5, "value").is_ok());
        assert!(check_non_negative(-0.5, "value").is_err());
    }

    #[test]
    fn test_check_probability() {
        assert!(check_probability(0.5, "prob").is_ok());
        assert!(check_probability(0.0, "prob").is_ok());
        assert!(check_probability(1.0, "prob").is_ok());
        assert!(check_probability(1.1, "prob").is_err());
        assert!(check_probability(-0.1, "prob").is_err());
    }

    #[test]
    fn test_check_in_range() {
        assert!(check_in_range(5.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(0.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(10.0, 0.0, 10.0, "x").is_ok());
        assert!(check_in_range(-0.1, 0.0, 10.0, "x").is_err());
        assert!(check_in_range(10.1, 0.0, 10.0, "x").is_err());
    }

    #[test]
    fn test_check_array_length() {
        let arr = array![1.0, 2.0, 3.0];
        assert!(check_array_length(&arr, 3, "test").is_ok());
        assert!(check_array_length(&arr, 4, "test").is_err());
    }

    #[test]
    fn test_check_same_length() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let c = array![7.0, 8.0];
        assert!(check_same_length(&a, &b, "a", "b").is_ok());
        assert!(check_same_length(&a, &c, "a", "c").is_err());
    }

    #[test]
    fn test_validate_arima_orders() {
        assert!(validate_arima_orders(2, 1, 2).is_ok());
        assert!(validate_arima_orders(11, 1, 1).is_err());
        assert!(validate_arima_orders(1, 4, 1).is_err());
        assert!(validate_arima_orders(1, 1, 11).is_err());
    }

    #[test]
    fn test_validate_seasonal_arima_orders() {
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 12).is_ok());
        assert!(validate_seasonal_arima_orders(11, 1, 1, 1, 1, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 6, 1, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 3, 1, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 6, 12).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 1).is_err());
        assert!(validate_seasonal_arima_orders(1, 1, 1, 1, 1, 1, 366).is_err());
    }

    #[test]
    fn test_validate_forecast_horizon() {
        assert!(validate_forecast_horizon(10, None).is_ok());
        assert!(validate_forecast_horizon(10_000, None).is_ok());
        assert!(validate_forecast_horizon(0, None).is_err());
        assert!(validate_forecast_horizon(10_001, None).is_err());
        assert!(validate_forecast_horizon(5, Some(4)).is_err());
    }

    #[test]
    fn test_validate_window_size() {
        assert!(validate_window_size(5, 10).is_ok());
        assert!(validate_window_size(10, 10).is_ok());
        assert!(validate_window_size(0, 10).is_err());
        assert!(validate_window_size(11, 10).is_err());
    }

    #[test]
    fn test_validate_lag() {
        assert!(validate_lag(0, 10).is_ok());
        assert!(validate_lag(9, 10).is_ok());
        assert!(validate_lag(10, 10).is_err());
    }

    #[test]
    fn test_check_no_missing() {
        assert!(check_no_missing(&array![1.0, 2.0, 3.0]).is_ok());
        assert!(check_no_missing(&array![1.0, f64::NAN]).is_err());
        assert!(check_no_missing(&array![f64::INFINITY, 1.0]).is_err());
    }

    #[test]
    fn test_check_stationarity_basic() {
        // Both halves have mean 0 and variance 1.
        let alternating: Array1<f64> = (0..20)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        assert!(matches!(check_stationarity_basic(&alternating), Ok(true)));

        // A linear trend shifts the mean far beyond the threshold.
        let trend: Array1<f64> = (0..20).map(f64::from).collect();
        assert!(matches!(check_stationarity_basic(&trend), Ok(false)));

        let short = array![1.0, 2.0, 3.0];
        assert!(check_stationarity_basic(&short).is_err());
    }
}