1use scirs2_core::ndarray::Array1;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use crate::error::{Result, TimeSeriesError};
11use crate::utils::{autocorrelation, moving_average};
12
13#[derive(Debug, Clone)]
15pub struct PeriodDetectionResult<F> {
16 pub periods: Vec<(usize, F)>, pub acf: Array1<F>,
20 pub periodogram: Option<Array1<F>>,
22 pub method: PeriodDetectionMethod,
24}
25
/// Strategy used by [`detect_periods`] to find candidate periods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PeriodDetectionMethod {
    /// Local maxima of the autocorrelation function.
    ACF,
    /// Peaks of the periodogram (computed with a direct discrete Fourier transform).
    FFT,
    /// Candidates from both the autocorrelation function and the periodogram, merged.
    Combined,
}

/// Settings for [`detect_periods`].
#[derive(Clone, Debug)]
pub struct PeriodDetectionOptions {
    /// Which detector to run.
    pub method: PeriodDetectionMethod,
    /// Upper bound on the number of periods reported.
    pub max_periods: usize,
    /// Shortest period considered; must be at least 2.
    pub min_period: usize,
    /// Longest period considered; `0` selects half the series length. An explicit value must
    /// not exceed half the series length.
    pub max_period: usize,
    /// Peak acceptance level. For the ACF detector it is the autocorrelation value a peak must
    /// exceed; for the periodogram detector it is the fraction of the largest power a peak must
    /// exceed.
    pub threshold: f64,
    /// Drop candidates whose period is a multiple or divisor of a stronger candidate's period.
    pub filter_harmonics: bool,
    /// Subtract a moving-average trend before detection.
    pub detrend: bool,
}

impl Default for PeriodDetectionOptions {
    /// Combined detection of up to 3 periods from 2 up to half the series length, with a 0.3
    /// threshold and both harmonic filtering and detrending enabled.
    fn default() -> Self {
        PeriodDetectionOptions {
            detrend: true,
            filter_harmonics: true,
            threshold: 0.3,
            max_period: 0,
            min_period: 2,
            max_periods: 3,
            method: PeriodDetectionMethod::Combined,
        }
    }
}
69
70#[allow(dead_code)]
102pub fn detect_periods<F>(
103 ts: &Array1<F>,
104 options: &PeriodDetectionOptions,
105) -> Result<PeriodDetectionResult<F>>
106where
107 F: Float + FromPrimitive + Debug,
108{
109 let n = ts.len();
110
111 if n < 8 {
113 return Err(TimeSeriesError::InvalidInput(
114 "Time series must have at least 8 points for period detection".to_string(),
115 ));
116 }
117
118 let max_period = if options.max_period == 0 {
119 n / 2
121 } else {
122 options.max_period
123 };
124
125 if options.min_period < 2 {
126 return Err(TimeSeriesError::InvalidInput(
127 "Minimum period must be at least 2".to_string(),
128 ));
129 }
130
131 if max_period <= options.min_period {
132 return Err(TimeSeriesError::InvalidInput(
133 "Maximum period must be greater than minimum period".to_string(),
134 ));
135 }
136
137 if max_period > n / 2 {
138 return Err(TimeSeriesError::InvalidInput(
139 "Maximum period cannot exceed half the length of the time series".to_string(),
140 ));
141 }
142
143 let detrended_ts = if options.detrend {
145 let window_size = std::cmp::min(n / 10, 21);
147 let window_size = if window_size.is_multiple_of(2) {
148 window_size + 1
149 } else {
150 window_size
151 };
152 let trend = moving_average(ts, window_size)?;
153
154 let mut detrended = Array1::zeros(n);
155 for i in 0..n {
156 detrended[i] = ts[i] - trend[i];
157 }
158 detrended
159 } else {
160 ts.clone()
161 };
162
163 match options.method {
165 PeriodDetectionMethod::ACF => detect_periods_acf(&detrended_ts, options),
166 PeriodDetectionMethod::FFT => detect_periods_fft(&detrended_ts, options),
167 PeriodDetectionMethod::Combined => detect_periods_combined(&detrended_ts, options),
168 }
169}
170
171#[allow(dead_code)]
173fn detect_periods_acf<F>(
174 ts: &Array1<F>,
175 options: &PeriodDetectionOptions,
176) -> Result<PeriodDetectionResult<F>>
177where
178 F: Float + FromPrimitive + Debug,
179{
180 let n = ts.len();
181 let max_lag = std::cmp::min(options.max_period, n / 2);
182
183 let acf = autocorrelation(ts, Some(max_lag))?;
185
186 let mut peaks = Vec::new();
188 let threshold = F::from_f64(options.threshold).unwrap();
189
190 let mut max_acf = F::min_value();
192 let mut max_lag = 0;
193
194 for lag in options.min_period..=std::cmp::min(options.max_period, acf.len() - 1) {
196 if acf[lag] > max_acf {
197 max_acf = acf[lag];
198 max_lag = lag;
199 }
200
201 if lag > 0
203 && lag < acf.len() - 1
204 && acf[lag] > acf[lag - 1]
205 && acf[lag] > acf[lag + 1]
206 && acf[lag] > threshold
207 {
208 peaks.push((lag, acf[lag]));
209 }
210 }
211
212 if peaks.is_empty() && max_lag > 0 {
214 peaks.push((max_lag, max_acf));
215 }
216
217 let filtered_peaks = if options.filter_harmonics {
219 filter_harmonics(peaks, options.threshold)
220 } else {
221 peaks
222 };
223
224 let mut sorted_peaks = filtered_peaks;
226 sorted_peaks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
227
228 let top_periods = sorted_peaks.into_iter().take(options.max_periods).collect();
230
231 Ok(PeriodDetectionResult {
232 periods: top_periods,
233 acf,
234 periodogram: None,
235 method: PeriodDetectionMethod::ACF,
236 })
237}
238
239#[allow(dead_code)]
241fn detect_periods_fft<F>(
242 ts: &Array1<F>,
243 options: &PeriodDetectionOptions,
244) -> Result<PeriodDetectionResult<F>>
245where
246 F: Float + FromPrimitive + Debug,
247{
248 let n = ts.len();
249
250 let mut periodogram = Array1::zeros(n / 2 + 1);
253
254 let mean = ts.iter().fold(F::zero(), |acc, &x| acc + x) / F::from_usize(n).unwrap();
256 let centered_ts = Array1::from_shape_fn(n, |i| ts[i] - mean);
257
258 for k in 0..=n / 2 {
260 let mut real_part = F::zero();
261 let mut imag_part = F::zero();
262
263 for (j, &x) in centered_ts.iter().enumerate() {
264 let angle =
265 F::from_f64(-2.0 * std::f64::consts::PI * k as f64 * j as f64 / n as f64).unwrap();
266 real_part = real_part + x * angle.cos();
267 imag_part = imag_part + x * angle.sin();
268 }
269
270 let power = (real_part * real_part + imag_part * imag_part) / F::from_usize(n).unwrap();
272 periodogram[k] = power;
273 }
274
275 let acf = autocorrelation(¢ered_ts, Some(n / 2))?;
277
278 let mut peaks = Vec::new();
280 let max_power = periodogram.iter().fold(F::zero(), |acc, &x| acc.max(x));
281 let threshold = F::from_f64(options.threshold * max_power.to_f64().unwrap()).unwrap();
282
283 let mut max_period = 0;
285 let mut max_period_power = F::min_value();
286
287 for i in 1..=std::cmp::min(n / options.min_period, n / 2) {
288 let period = n / i;
290
291 if period >= options.min_period && period <= options.max_period {
292 if i < periodogram.len() && periodogram[i] > max_period_power {
294 max_period_power = periodogram[i];
295 max_period = period;
296 }
297
298 if i > 0
300 && i < periodogram.len() - 1
301 && periodogram[i] > periodogram[i - 1]
302 && periodogram[i] > periodogram[i + 1]
303 && periodogram[i] > threshold
304 {
305 peaks.push((period, periodogram[i]));
306 }
307 }
308 }
309
310 if peaks.is_empty() && max_period > 0 {
312 peaks.push((max_period, max_period_power));
313 }
314
315 let filtered_peaks = if options.filter_harmonics {
317 filter_harmonics(peaks, options.threshold)
318 } else {
319 peaks
320 };
321
322 let mut sorted_peaks = filtered_peaks;
324 sorted_peaks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
325
326 let top_periods = sorted_peaks.into_iter().take(options.max_periods).collect();
328
329 Ok(PeriodDetectionResult {
330 periods: top_periods,
331 acf,
332 periodogram: Some(periodogram),
333 method: PeriodDetectionMethod::FFT,
334 })
335}
336
337#[allow(dead_code)]
339fn detect_periods_combined<F>(
340 ts: &Array1<F>,
341 options: &PeriodDetectionOptions,
342) -> Result<PeriodDetectionResult<F>>
343where
344 F: Float + FromPrimitive + Debug,
345{
346 let acf_result = detect_periods_acf(ts, options)?;
348 let fft_result = detect_periods_fft(ts, options)?;
349
350 let mut all_periods = Vec::new();
352
353 for &(period, strength) in &acf_result.periods {
355 all_periods.push((period, strength));
356 }
357
358 for &(period, strength) in &fft_result.periods {
360 let exists = all_periods.iter().any(|&(p_, _)| p_ == period);
362 if !exists {
363 all_periods.push((period, strength));
364 }
365 }
366
367 all_periods.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
369
370 let top_periods = all_periods.into_iter().take(options.max_periods).collect();
372
373 Ok(PeriodDetectionResult {
374 periods: top_periods,
375 acf: acf_result.acf,
376 periodogram: fft_result.periodogram,
377 method: PeriodDetectionMethod::Combined,
378 })
379}
380
381#[allow(dead_code)]
383fn filter_harmonics<F>(periods: Vec<(usize, F)>, _threshold_factor: f64) -> Vec<(usize, F)>
384where
385 F: Float + FromPrimitive + Debug,
386{
387 if periods.is_empty() {
388 return periods;
389 }
390
391 let mut filtered = Vec::new();
392 let mut used = vec![false; periods.len()];
393
394 let mut sorted_periods = periods.clone();
396 sorted_periods.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
397
398 for i in 0..sorted_periods.len() {
399 if used[i] {
400 continue;
401 }
402
403 let (period, strength) = sorted_periods[i];
404 filtered.push((period, strength));
405 used[i] = true;
406
407 for j in 0..sorted_periods.len() {
409 if i != j && !used[j] {
410 let (other_period_, _) = sorted_periods[j];
411
412 if other_period_ % period == 0 || period % other_period_ == 0 {
414 used[j] = true;
415 }
416 }
417 }
418 }
419
420 filtered
421}
422
/// Decomposition algorithm that [`detect_and_decompose`] applies once the periods are found.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecompositionType {
    /// Use `crate::decomposition::mstl_decomposition` with the detected integer periods.
    MSTL,
    /// Use `crate::decomposition::tbats_decomposition` with the detected periods as `f64`.
    TBATS,
    /// Use `crate::decomposition::str_decomposition` with the detected periods as `f64`.
    STR,
}
492
493#[derive(Debug, Clone)]
495pub struct AutoDecompositionResult<F> {
496 pub periods: Vec<(usize, F)>,
498 pub decomposition: AutoDecomposition<F>,
500}
501
502#[derive(Debug, Clone)]
504pub enum AutoDecomposition<F> {
505 MSTL(crate::decomposition::MultiSeasonalDecompositionResult<F>),
507 TBATS(crate::decomposition::TBATSResult<F>),
509 STR(crate::decomposition::STRResult<F>),
511}
512
513#[allow(dead_code)]
515pub fn detect_and_decompose<F>(
516 ts: &Array1<F>,
517 detection_options: &PeriodDetectionOptions,
518 method: DecompositionType,
519) -> Result<AutoDecompositionResult<F>>
520where
521 F: Float
522 + FromPrimitive
523 + Debug
524 + std::iter::Sum
525 + scirs2_core::ndarray::ScalarOperand
526 + scirs2_core::numeric::NumCast,
527{
528 let period_result = detect_periods(ts, detection_options)?;
530
531 let periods = period_result.periods.clone();
533
534 if periods.is_empty() {
536 return Err(TimeSeriesError::DecompositionError(
537 "No significant periods detected in the time series".to_string(),
538 ));
539 }
540
541 match method {
543 DecompositionType::MSTL => {
544 let _options = crate::decomposition::MSTLOptions {
545 seasonal_periods: periods.iter().map(|&(p_, _)| p_).collect(),
546 ..Default::default()
547 };
548
549 let mstl_result = crate::decomposition::mstl_decomposition(ts, &_options)?;
550
551 Ok(AutoDecompositionResult {
552 periods,
553 decomposition: AutoDecomposition::MSTL(mstl_result),
554 })
555 }
556 DecompositionType::TBATS => {
557 let _options = crate::decomposition::TBATSOptions {
558 seasonal_periods: periods.iter().map(|&(p_, _)| p_ as f64).collect(),
559 ..Default::default()
560 };
561
562 let tbats_result = crate::decomposition::tbats_decomposition(ts, &_options)?;
563
564 Ok(AutoDecompositionResult {
565 periods,
566 decomposition: AutoDecomposition::TBATS(tbats_result),
567 })
568 }
569 DecompositionType::STR => {
570 let _options = crate::decomposition::STROptions {
571 seasonal_periods: periods.iter().map(|&(p_, _)| p_ as f64).collect(),
572 ..Default::default()
573 };
574
575 let str_result = crate::decomposition::str_decomposition(ts, &_options)?;
576
577 Ok(AutoDecompositionResult {
578 periods,
579 decomposition: AutoDecomposition::STR(str_result),
580 })
581 }
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::ToPrimitive;

    #[test]
    fn test_detect_periods_acf() {
        // Sawtooth with period 7.
        let mut ts = Array1::zeros(100);
        for i in 0..100 {
            ts[i] = (i % 7) as f64;
        }

        let acf = autocorrelation(&ts, Some(50)).unwrap();

        // The ACF is normalised so that lag 0 equals 1.
        assert!((acf[0] - 1.0).abs() < 1e-10);

        let lag7 = acf[7].to_f64().unwrap();
        let lag14 = if acf.len() > 14 {
            acf[14].to_f64().unwrap()
        } else {
            0.0
        };

        assert!(
            lag7 > 0.5 || lag14 > 0.5,
            "Neither lag 7 nor lag 14 has high autocorrelation: lag7={lag7}, lag14={lag14}"
        );
    }

    #[test]
    fn test_detect_periods_fft() {
        // Pure sine with period 4.
        let mut ts = Array1::zeros(100);
        for i in 0..100 {
            ts[i] = (2.0 * std::f64::consts::PI * (i as f64) / 4.0).sin();
        }

        let acf = autocorrelation(&ts, Some(50)).unwrap();

        // Cosine transform of the ACF as a spectral estimate.
        let n = ts.len();
        let mut periodogram = Array1::zeros(n / 2 + 1);
        for i in 0..=n / 2 {
            let mut power = 0.0;
            for j in 1..acf.len() {
                let cos_term = (2.0 * std::f64::consts::PI * j as f64 * i as f64 / n as f64).cos();
                power += acf[j].to_f64().unwrap() * cos_term;
            }
            periodogram[i] = power.abs();
        }

        let mut max_power_idx = 0;
        let mut max_power = 0.0;
        for i in 1..periodogram.len() {
            if periodogram[i] > max_power {
                max_power = periodogram[i];
                max_power_idx = i;
            }
        }

        let detected_period = if max_power_idx > 0 {
            n / max_power_idx
        } else {
            0
        };

        assert!(
            detected_period == 4
                || detected_period % 4 == 0
                || 4 % detected_period == 0
                || detected_period == 2
                || detected_period == 8,
            "Detected period {detected_period} is not related to expected period 4"
        );
    }

    #[test]
    fn test_detect_periods_fft_finds_exact_sinusoid() {
        // 64 samples of a sine with period 8 put all spectral power in frequency bin 8.
        let ts =
            Array1::from_shape_fn(64, |i| (2.0 * std::f64::consts::PI * i as f64 / 8.0).sin());
        let options = PeriodDetectionOptions {
            method: PeriodDetectionMethod::FFT,
            max_period: 32,
            detrend: false,
            ..Default::default()
        };

        let result = detect_periods(&ts, &options).unwrap();
        assert_eq!(result.method, PeriodDetectionMethod::FFT);
        assert!(result.periodogram.is_some());
        assert_eq!(result.periods.first().map(|&(p, _)| p), Some(8));
    }

    #[test]
    fn test_detect_periods_rejects_invalid_input() {
        let short = Array1::<f64>::zeros(7);
        assert!(detect_periods(&short, &PeriodDetectionOptions::default()).is_err());

        let ts = Array1::from_shape_fn(40, |i| (i % 5) as f64);

        let too_small_min = PeriodDetectionOptions {
            min_period: 1,
            ..Default::default()
        };
        assert!(detect_periods(&ts, &too_small_min).is_err());

        let empty_range = PeriodDetectionOptions {
            min_period: 10,
            max_period: 10,
            ..Default::default()
        };
        assert!(detect_periods(&ts, &empty_range).is_err());

        let too_large_max = PeriodDetectionOptions {
            max_period: 21,
            ..Default::default()
        };
        assert!(detect_periods(&ts, &too_large_max).is_err());
    }

    #[test]
    fn test_filter_harmonics_keeps_strongest_of_each_family() {
        assert!(filter_harmonics::<f64>(Vec::new(), 0.3).is_empty());

        // 7 is a divisor of the stronger 21 and is dropped; 14 and 5 are unrelated to 21.
        let filtered = filter_harmonics(vec![(7, 0.9), (14, 0.8), (5, 0.7), (21, 0.95)], 0.0);
        assert_eq!(filtered, vec![(21, 0.95), (14, 0.8), (5, 0.7)]);
    }

    #[test]
    fn test_detect_and_decompose() {
        // Linear step trend plus a period-12 triangular seasonal pattern.
        let mut ts = Array1::zeros(100);
        for i in 0..100 {
            ts[i] = ((i / 10) as f64) + 2.0 * ((i % 12) as f64 - 6.0).abs() / 6.0;
        }

        let options = PeriodDetectionOptions {
            threshold: 0.05,
            ..Default::default()
        };

        let forced_period = 12;

        let mstl_options = crate::decomposition::MSTLOptions {
            seasonal_periods: vec![forced_period],
            ..Default::default()
        };
        let mstl_result = crate::decomposition::mstl_decomposition(&ts, &mstl_options).unwrap();
        assert_eq!(mstl_result.trend.len(), ts.len());
        assert_eq!(mstl_result.seasonal_components.len(), 1);

        let tbats_options = crate::decomposition::TBATSOptions {
            seasonal_periods: vec![forced_period as f64],
            ..Default::default()
        };
        let tbats_result = crate::decomposition::tbats_decomposition(&ts, &tbats_options).unwrap();
        assert_eq!(tbats_result.trend.len(), ts.len());
        assert_eq!(tbats_result.seasonal_components.len(), 1);

        let str_options = crate::decomposition::STROptions {
            seasonal_periods: vec![forced_period as f64],
            ..Default::default()
        };
        let str_result = crate::decomposition::str_decomposition(&ts, &str_options).unwrap();
        assert_eq!(str_result.trend.len(), ts.len());
        assert_eq!(str_result.seasonal_components.len(), 1);

        let auto_result = detect_periods(&ts, &options);
        if let Ok(period_result) = auto_result {
            if !period_result.periods.is_empty() {
                let mstl_auto = detect_and_decompose(&ts, &options, DecompositionType::MSTL);
                if let Ok(result) = mstl_auto {
                    match result.decomposition {
                        AutoDecomposition::MSTL(mstl) => {
                            assert_eq!(mstl.trend.len(), ts.len());
                            assert_eq!(mstl.seasonal_components.len(), result.periods.len());
                        }
                        _ => panic!("Expected MSTL result"),
                    }
                }
            }
        }
    }
}