1use scirs2_core::ndarray::{s, Array1, Array2, ScalarOperand};
4use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
5use scirs2_linalg::{inv, solve};
6use std::fmt::Debug;
7
8use crate::error::{Result, TimeSeriesError};
9
/// Configuration for `str_decomposition`.
#[derive(Debug, Clone)]
pub struct STROptions {
    /// Penalty / solver used to estimate the trend and seasonal coefficients.
    pub regularization_type: RegularizationType,
    /// Trend smoothing weight. With `Ridge`, the coefficient of `t^i` is penalised by
    /// `trend_lambda * i²` (the intercept is unpenalised); with `ElasticNet` it is the
    /// L2 weight. Must be non-negative.
    pub trend_lambda: f64,
    /// Seasonal weight. With `Ridge`, every Fourier coefficient is penalised by this
    /// value; with `Lasso` / `ElasticNet` it is the L1 weight. Must be non-negative.
    pub seasonal_lambda: f64,
    /// Seasonal periods, measured in observations. At least one is required and each
    /// must be greater than 1.
    pub seasonal_periods: Vec<f64>,
    /// NOTE(review): not read by any code in this file; currently has no effect.
    pub robust: bool,
    /// Whether to populate `STRResult::trend_ci` and `STRResult::seasonal_ci`.
    pub compute_confidence_intervals: bool,
    /// Confidence level of the intervals; must lie strictly between 0 and 1.
    pub confidence_level: f64,
    /// Degree of the polynomial trend (basis `1, t, …, t^trend_degrees`).
    pub trend_degrees: usize,
    /// NOTE(review): not read by any code in this file; currently has no effect.
    pub flexible_seasonal: bool,
    /// Optional number of Fourier harmonics per seasonal period, matched by index to
    /// `seasonal_periods`. `None`, or a missing entry, defaults to
    /// `max(1, ⌊period / 2⌋)` harmonics.
    pub seasonal_harmonics: Option<Vec<usize>>,
}
34
/// Penalty used by `str_decomposition` when estimating the basis coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegularizationType {
    /// Quadratic (L2) penalty, solved directly from the penalised normal equations.
    Ridge,
    /// L1 penalty solved by coordinate descent; `str_decomposition` uses
    /// `STROptions::seasonal_lambda` as the L1 weight for every coefficient.
    Lasso,
    /// Combined L1 + L2 penalty solved by coordinate descent; `str_decomposition` uses
    /// `seasonal_lambda` as the L1 weight and `trend_lambda` as the L2 weight.
    ElasticNet,
}
45
46impl Default for STROptions {
47 fn default() -> Self {
48 Self {
49 regularization_type: RegularizationType::Ridge,
50 trend_lambda: 10.0,
51 seasonal_lambda: 0.5,
52 seasonal_periods: Vec::new(),
53 robust: false,
54 compute_confidence_intervals: false,
55 confidence_level: 0.95,
56 trend_degrees: 3,
57 flexible_seasonal: false,
58 seasonal_harmonics: None,
59 }
60 }
61}
62
/// Output of `str_decomposition`.
#[derive(Debug, Clone)]
pub struct STRResult<F> {
    /// Fitted polynomial trend, one value per observation.
    pub trend: Array1<F>,
    /// Fitted seasonal components, one per entry of `STROptions::seasonal_periods`,
    /// in the same order.
    pub seasonal_components: Vec<Array1<F>>,
    /// `original − trend − Σ seasonal_components`.
    pub residual: Array1<F>,
    /// Copy of the input series.
    pub original: Array1<F>,
    /// `(lower, upper)` bounds for the trend when confidence intervals were requested
    /// and could be computed; `None` otherwise.
    pub trend_ci: Option<(Array1<F>, Array1<F>)>,
    /// `(lower, upper)` bounds for each seasonal component, in the same order as
    /// `seasonal_components`; `None` under the same conditions as `trend_ci`.
    pub seasonal_ci: Option<Vec<(Array1<F>, Array1<F>)>>,
}
79
80#[allow(dead_code)]
113pub fn str_decomposition<F>(ts: &Array1<F>, options: &STROptions) -> Result<STRResult<F>>
114where
115 F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
116{
117 let n = ts.len();
118
119 if n < 3 {
121 return Err(TimeSeriesError::DecompositionError(
122 "Time series must have at least 3 points for STR decomposition".to_string(),
123 ));
124 }
125
126 if options.seasonal_periods.is_empty() {
127 return Err(TimeSeriesError::DecompositionError(
128 "At least one seasonal period must be specified for STR".to_string(),
129 ));
130 }
131
132 for &period in &options.seasonal_periods {
133 if period <= 1.0 {
134 return Err(TimeSeriesError::DecompositionError(
135 "Seasonal periods must be greater than 1".to_string(),
136 ));
137 }
138 }
139
140 if options.trend_lambda < 0.0 || options.seasonal_lambda < 0.0 {
141 return Err(TimeSeriesError::DecompositionError(
142 "Regularization parameters must be non-negative".to_string(),
143 ));
144 }
145
146 if options.confidence_level <= 0.0 || options.confidence_level >= 1.0 {
147 return Err(TimeSeriesError::DecompositionError(
148 "Confidence level must be between 0 and 1".to_string(),
149 ));
150 }
151
152 let time_indices: Array1<F> = Array1::from_iter((0..n).map(|i| F::from_usize(i).unwrap()));
154
155 let trend_degree = options.trend_degrees;
157 let mut trend_basis = Array2::zeros((n, trend_degree + 1));
158
159 for i in 0..n {
161 for j in 0..=trend_degree {
162 if j == 0 {
163 trend_basis[[i, j]] = F::one(); } else {
165 let time_idx = time_indices[i];
166 trend_basis[[i, j]] = Float::powf(time_idx, F::from_usize(j).unwrap());
167 }
168 }
169 }
170
171 let mut seasonal_bases = Vec::with_capacity(options.seasonal_periods.len());
173 let mut total_seasonal_cols = 0;
174
175 for (idx, &period) in options.seasonal_periods.iter().enumerate() {
176 let harmonics = if let Some(ref harms) = options.seasonal_harmonics {
178 harms
179 .get(idx)
180 .copied()
181 .unwrap_or(((period / 2.0).floor() as usize).max(1))
182 } else {
183 ((period / 2.0).floor() as usize).max(1)
184 };
185
186 let mut seasonal_basis = Array2::zeros((n, 2 * harmonics)); for i in 0..n {
189 let t = time_indices[i];
190 for j in 0..harmonics {
191 let freq =
192 F::from_f64(2.0 * std::f64::consts::PI * (j + 1) as f64 / period).unwrap();
193 seasonal_basis[[i, 2 * j]] = Float::sin(freq * t);
195 seasonal_basis[[i, 2 * j + 1]] = Float::cos(freq * t);
197 }
198 }
199
200 total_seasonal_cols += 2 * harmonics;
201 seasonal_bases.push(seasonal_basis);
202 }
203
204 let total_cols = trend_degree + 1 + total_seasonal_cols;
206 let mut design_matrix = Array2::zeros((n, total_cols));
207
208 design_matrix
210 .slice_mut(s![.., 0..=trend_degree])
211 .assign(&trend_basis);
212
213 let mut col_offset = trend_degree + 1;
215 for seasonal_basis in &seasonal_bases {
216 let next_offset = col_offset + seasonal_basis.ncols();
217 design_matrix
218 .slice_mut(s![.., col_offset..next_offset])
219 .assign(seasonal_basis);
220 col_offset = next_offset;
221 }
222
223 let mut regularization_matrix = Array2::zeros((total_cols, total_cols));
225
226 for i in 0..=trend_degree {
228 let weight = if i == 0 {
229 0.0 } else {
231 options.trend_lambda * (i as f64).powi(2)
232 };
233 regularization_matrix[[i, i]] = F::from_f64(weight).unwrap();
234 }
235
236 col_offset = trend_degree + 1;
238 for seasonal_basis in &seasonal_bases {
239 let seasonal_cols = seasonal_basis.ncols();
240 for i in 0..seasonal_cols {
241 regularization_matrix[[col_offset + i, col_offset + i]] =
242 F::from(options.seasonal_lambda).unwrap();
243 }
244 col_offset += seasonal_cols;
245 }
246
247 let xtx = design_matrix.t().dot(&design_matrix);
250 let xty = design_matrix.t().dot(ts);
251
252 let system_matrix = xtx + regularization_matrix;
254
255 let coefficients = match options.regularization_type {
257 RegularizationType::Ridge => {
258 solve_regularized_system(&system_matrix, &xty)?
260 }
261 RegularizationType::Lasso => {
262 solve_lasso(
264 &design_matrix,
265 ts,
266 options.seasonal_lambda,
267 1000,
268 F::from(1e-6).unwrap(),
269 )?
270 }
271 RegularizationType::ElasticNet => {
272 solve_elastic_net(
274 &design_matrix,
275 ts,
276 options.seasonal_lambda,
277 options.trend_lambda,
278 1000,
279 F::from(1e-6).unwrap(),
280 )?
281 }
282 };
283
284 let trend_coeffs = coefficients.slice(s![0..=trend_degree]);
287 let trend = trend_basis.dot(&trend_coeffs);
288
289 let mut seasonal_components = Vec::with_capacity(options.seasonal_periods.len());
291 col_offset = trend_degree + 1;
292
293 for seasonal_basis in &seasonal_bases {
294 let seasonal_cols = seasonal_basis.ncols();
295 let seasonal_coeffs = coefficients.slice(s![col_offset..col_offset + seasonal_cols]);
296 let seasonal_component = seasonal_basis.dot(&seasonal_coeffs);
297 seasonal_components.push(seasonal_component);
298 col_offset += seasonal_cols;
299 }
300
301 let mut residual = ts.clone();
303 for i in 0..n {
304 residual[i] = residual[i] - trend[i];
305 for seasonal_component in &seasonal_components {
306 residual[i] = residual[i] - seasonal_component[i];
307 }
308 }
309
310 let (trend_ci, seasonal_ci) = if options.compute_confidence_intervals {
312 compute_confidence_intervals(
313 &design_matrix,
314 &system_matrix,
315 &residual,
316 &trend_basis,
317 &seasonal_bases,
318 options.confidence_level,
319 )?
320 } else {
321 (None, None)
322 };
323
324 let result = STRResult {
326 trend,
327 seasonal_components,
328 residual,
329 original: ts.clone(),
330 trend_ci,
331 seasonal_ci,
332 };
333
334 Ok(result)
335}
336
337type ConfidenceInterval<F> = (Array1<F>, Array1<F>);
339
340type ConfidenceIntervalsResult<F> = Result<(
342 Option<ConfidenceInterval<F>>,
343 Option<Vec<ConfidenceInterval<F>>>,
344)>;
345
346#[allow(dead_code)]
348fn compute_confidence_intervals<F>(
349 design_matrix: &Array2<F>,
350 system_matrix: &Array2<F>,
351 residual: &Array1<F>,
352 trend_basis: &Array2<F>,
353 seasonal_bases: &[Array2<F>],
354 confidence_level: f64,
355) -> ConfidenceIntervalsResult<F>
356where
357 F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
358{
359 let n = residual.len();
360 let p = design_matrix.ncols();
361
362 if n <= p {
363 return Ok((None, None));
364 }
365
366 let residual_variance = residual.mapv(|x| x * x).sum() / F::from_usize(n - p).unwrap();
368
369 let covariance_matrix = match matrix_inverse(system_matrix) {
371 Ok(inv) => inv.mapv(|x| x * residual_variance),
372 Err(_) => return Ok((None, None)), };
374
375 let alpha = 1.0 - confidence_level;
377 let df = n - p;
378 let t_critical = if df > 30 {
379 match alpha {
381 a if a <= 0.01 => F::from(2.576).unwrap(), a if a <= 0.05 => F::from(1.96).unwrap(), _ => F::from(1.645).unwrap(), }
385 } else {
386 let base = F::from(2.0).unwrap();
388 base + F::from(df as f64).unwrap().recip()
389 };
390
391 let trend_se = compute_component_standard_errors(trend_basis, &covariance_matrix)?;
393 let trend_margin = trend_se.mapv(|se| se * t_critical);
394 let trend_fitted = trend_basis.dot(&covariance_matrix.diag().slice(s![0..trend_basis.ncols()]));
395 let trend_lower = &trend_fitted - &trend_margin;
396 let trend_upper = &trend_fitted + &trend_margin;
397
398 let mut seasonal_cis = Vec::new();
400 let mut col_offset = trend_basis.ncols();
401
402 for seasonal_basis in seasonal_bases {
403 let seasonal_cols = seasonal_basis.ncols();
404 let seasonal_cov = covariance_matrix.slice(s![
405 col_offset..col_offset + seasonal_cols,
406 col_offset..col_offset + seasonal_cols
407 ]);
408
409 let seasonal_se =
410 compute_component_standard_errors(seasonal_basis, &seasonal_cov.to_owned())?;
411 let seasonal_margin = seasonal_se.mapv(|se| se * t_critical);
412 let seasonal_fitted = seasonal_basis.dot(&seasonal_cov.diag());
413 let seasonal_lower = &seasonal_fitted - &seasonal_margin;
414 let seasonal_upper = &seasonal_fitted + &seasonal_margin;
415
416 seasonal_cis.push((seasonal_lower, seasonal_upper));
417 col_offset += seasonal_cols;
418 }
419
420 Ok((Some((trend_lower, trend_upper)), Some(seasonal_cis)))
421}
422
423#[allow(dead_code)]
425fn compute_component_standard_errors<F>(
426 basis: &Array2<F>,
427 covariance: &Array2<F>,
428) -> Result<Array1<F>>
429where
430 F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
431{
432 let n = basis.nrows();
433 let mut standard_errors = Array1::zeros(n);
434
435 for i in 0..n {
436 let basis_row = basis.row(i);
437 let variance = basis_row.dot(&covariance.dot(&basis_row));
438 standard_errors[i] = variance.sqrt();
439 }
440
441 Ok(standard_errors)
442}
443
444#[allow(dead_code)]
446fn solve_regularized_system<F>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>>
447where
448 F: Float + FromPrimitive + ScalarOperand + NumCast + 'static,
449{
450 let n = a.shape()[0];
451 if n != a.shape()[1] || n != b.len() {
452 return Err(TimeSeriesError::DecompositionError(
453 "Matrix dimensions mismatch".to_string(),
454 ));
455 }
456
457 let a_f64 = a.mapv(|x| x.to_f64().unwrap_or(0.0));
459 let b_f64 = b.mapv(|x| x.to_f64().unwrap_or(0.0));
460
461 let x_f64 = solve(&a_f64.view(), &b_f64.view(), None)
463 .map_err(|e| TimeSeriesError::DecompositionError(format!("Linear solve failed: {e}")))?;
464
465 let x = x_f64.mapv(|val| F::from_f64(val).unwrap_or_else(F::zero));
467
468 Ok(x)
469}
470
471#[allow(dead_code)]
473fn solve_lasso<F>(
474 x: &Array2<F>,
475 y: &Array1<F>,
476 lambda: f64,
477 max_iter: usize,
478 tol: F,
479) -> Result<Array1<F>>
480where
481 F: Float + FromPrimitive + ScalarOperand + NumCast + std::iter::Sum,
482{
483 let (n, p) = (x.nrows(), x.ncols());
484 let mut beta = Array1::zeros(p);
485 let lambda_f = F::from(lambda).unwrap();
486
487 let mut xtx_diag = Array1::zeros(p);
489 for j in 0..p {
490 xtx_diag[j] = x.column(j).dot(&x.column(j));
491 }
492
493 for _iter in 0..max_iter {
494 let beta_old = beta.clone();
495
496 for j in 0..p {
497 let mut r = y.clone();
499 for k in 0..p {
500 if k != j {
501 let x_k = x.column(k);
502 for i in 0..n {
503 r[i] = r[i] - beta[k] * x_k[i];
504 }
505 }
506 }
507
508 let x_j = x.column(j);
510 let xty_j = x_j.dot(&r);
511
512 let z = xty_j;
514 beta[j] = if z > lambda_f {
515 (z - lambda_f) / xtx_diag[j]
516 } else if z < -lambda_f {
517 (z + lambda_f) / xtx_diag[j]
518 } else {
519 F::zero()
520 };
521 }
522
523 let mut diff = F::zero();
525 for j in 0..p {
526 diff = diff + (beta[j] - beta_old[j]).abs();
527 }
528
529 if diff < tol {
530 break;
531 }
532 }
533
534 Ok(beta)
535}
536
537#[allow(dead_code)]
539fn solve_elastic_net<F>(
540 x: &Array2<F>,
541 y: &Array1<F>,
542 l1_lambda: f64,
543 l2_lambda: f64,
544 max_iter: usize,
545 tol: F,
546) -> Result<Array1<F>>
547where
548 F: Float + FromPrimitive + ScalarOperand + NumCast + std::iter::Sum,
549{
550 let (n, p) = (x.nrows(), x.ncols());
551 let mut beta = Array1::zeros(p);
552 let l1_lambda_f = F::from(l1_lambda).unwrap();
553 let l2_lambda_f = F::from(l2_lambda).unwrap();
554
555 let mut xtx_diag = Array1::zeros(p);
557 for j in 0..p {
558 xtx_diag[j] = x.column(j).dot(&x.column(j)) + l2_lambda_f;
559 }
560
561 for _iter in 0..max_iter {
562 let beta_old = beta.clone();
563
564 for j in 0..p {
565 let mut r = y.clone();
567 for k in 0..p {
568 if k != j {
569 let x_k = x.column(k);
570 for i in 0..n {
571 r[i] = r[i] - beta[k] * x_k[i];
572 }
573 }
574 }
575
576 let x_j = x.column(j);
578 let xty_j = x_j.dot(&r);
579
580 let z = xty_j;
582 beta[j] = if z > l1_lambda_f {
583 (z - l1_lambda_f) / xtx_diag[j]
584 } else if z < -l1_lambda_f {
585 (z + l1_lambda_f) / xtx_diag[j]
586 } else {
587 F::zero()
588 };
589 }
590
591 let mut diff = F::zero();
593 for j in 0..p {
594 diff = diff + (beta[j] - beta_old[j]).abs();
595 }
596
597 if diff < tol {
598 break;
599 }
600 }
601
602 Ok(beta)
603}
604
605#[allow(dead_code)]
607fn matrix_inverse<F>(a: &Array2<F>) -> Result<Array2<F>>
608where
609 F: Float + FromPrimitive + ScalarOperand + NumCast + 'static,
610{
611 let n = a.shape()[0];
612 if n != a.shape()[1] {
613 return Err(TimeSeriesError::DecompositionError(
614 "Matrix must be square for inversion".to_string(),
615 ));
616 }
617
618 let a_f64 = a.mapv(|x| x.to_f64().unwrap_or(0.0));
620
621 let inv_f64 = inv(&a_f64.view(), None).map_err(|e| {
623 TimeSeriesError::DecompositionError(format!("Matrix inversion failed: {e}"))
624 })?;
625
626 let inverse = inv_f64.mapv(|val| F::from_f64(val).unwrap_or_else(F::zero));
628
629 Ok(inverse)
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use std::f64::consts::PI;

    /// Asserts that trend + every seasonal component + residual reproduces `ts`.
    fn assert_additive_reconstruction(result: &STRResult<f64>, ts: &Array1<f64>) {
        for (i, &observed) in ts.iter().enumerate() {
            let seasonal_total: f64 = result.seasonal_components.iter().map(|c| c[i]).sum();
            let rebuilt = result.trend[i] + seasonal_total + result.residual[i];
            assert_abs_diff_eq!(rebuilt, observed, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_str_basic() {
        let n = 50;
        // Linear trend + period-12 sinusoid + small deterministic wiggle.
        let ts = Array1::from_shape_fn(n, |i| {
            let t = i as f64;
            let trend = 0.1 * t;
            let seasonal = 2.0 * (2.0 * PI * t / 12.0).sin();
            let noise = 0.1 * (t * 0.456).sin();
            trend + seasonal + noise
        });

        let options = STROptions {
            seasonal_periods: vec![12.0],
            trend_degrees: 2,
            trend_lambda: 1.0,
            seasonal_lambda: 0.1,
            ..Default::default()
        };

        let result = str_decomposition(&ts, &options).unwrap();

        assert_additive_reconstruction(&result, &ts);
        assert_eq!(result.trend.len(), n);
        assert_eq!(result.seasonal_components.len(), 1);
        assert_eq!(result.seasonal_components[0].len(), n);
    }

    #[test]
    fn test_str_multiple_seasons() {
        let n = 100;
        // Gentle trend plus two superimposed seasonal cycles (periods 12 and 4).
        let ts = Array1::from_shape_fn(n, |i| {
            let t = i as f64;
            let trend = 0.05 * t;
            let seasonal1 = 3.0 * (2.0 * PI * t / 12.0).sin();
            let seasonal2 = 1.5 * (2.0 * PI * t / 4.0).cos();
            trend + seasonal1 + seasonal2
        });

        let options = STROptions {
            seasonal_periods: vec![12.0, 4.0],
            trend_degrees: 1,
            trend_lambda: 5.0,
            seasonal_lambda: 0.5,
            ..Default::default()
        };

        let result = str_decomposition(&ts, &options).unwrap();

        assert_additive_reconstruction(&result, &ts);
        assert_eq!(result.seasonal_components.len(), 2);
    }

    #[test]
    fn test_str_edge_cases() {
        let minimal = array![1.0, 2.0, 3.0];
        let mut options = STROptions {
            seasonal_periods: vec![2.0],
            ..Default::default()
        };

        // Three points with a valid period is the smallest accepted input.
        assert!(str_decomposition(&minimal, &options).is_ok());

        // Periods must exceed 1.
        options.seasonal_periods = vec![0.5];
        assert!(str_decomposition(&minimal, &options).is_err());

        // At least one period is required.
        options.seasonal_periods = vec![];
        assert!(str_decomposition(&minimal, &options).is_err());

        // Fewer than three observations is rejected.
        let too_short = array![1.0, 2.0];
        options.seasonal_periods = vec![2.0];
        assert!(str_decomposition(&too_short, &options).is_err());
    }
}