scirs2_series/decomposition/
str.rs

1//! Seasonal-Trend decomposition using Regression (STR)
2
3use scirs2_core::ndarray::{s, Array1, Array2, ScalarOperand};
4use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
5use scirs2_linalg::{inv, solve};
6use std::fmt::Debug;
7
8use crate::error::{Result, TimeSeriesError};
9
/// Options for STR (Seasonal-Trend decomposition using Regression)
///
/// `Default` leaves `seasonal_periods` empty, and `str_decomposition` rejects
/// options without at least one seasonal period, so set it before use.
#[derive(Debug, Clone)]
pub struct STROptions {
    /// Type of regularization to use when solving for the basis coefficients
    pub regularization_type: RegularizationType,
    /// Regularization parameter for trend (must be non-negative)
    ///
    /// For ridge fitting, the coefficient of `t^i` is penalized by
    /// `trend_lambda * i^2` (the constant term is not penalized). For Elastic Net
    /// it is the L2 weight.
    pub trend_lambda: f64,
    /// Regularization parameter for seasonal components (must be non-negative)
    ///
    /// This is the ridge penalty on every Fourier coefficient, and the L1 weight
    /// for LASSO and Elastic Net.
    pub seasonal_lambda: f64,
    /// Seasonal periods (can include non-integer values); each must be greater than 1
    pub seasonal_periods: Vec<f64>,
    /// Whether to use robust estimation (less sensitive to outliers)
    ///
    /// NOTE(review): `str_decomposition` in this file never reads this flag, so it
    /// appears to have no effect at present. Confirm before relying on it.
    pub robust: bool,
    /// Whether to compute confidence intervals for the trend and seasonal components
    pub compute_confidence_intervals: bool,
    /// Confidence level (e.g., 0.95 for 95% confidence); must lie strictly between 0 and 1
    pub confidence_level: f64,
    /// Degree of the polynomial trend basis
    ///
    /// The trend is fitted with the columns `1, t, ..., t^trend_degrees`, i.e.
    /// `trend_degrees + 1` coefficients.
    pub trend_degrees: usize,
    /// Whether to allow the seasonal pattern to change over time
    ///
    /// NOTE(review): `str_decomposition` in this file never reads this flag, so it
    /// appears to have no effect at present. Confirm before relying on it.
    pub flexible_seasonal: bool,
    /// Number of harmonics for each seasonal component
    ///
    /// Entry `k` applies to `seasonal_periods[k]`. When this is `None`, or has no
    /// entry for a period, `max(1, floor(period / 2))` harmonics are used. Each
    /// harmonic contributes one sine and one cosine column to the design matrix.
    pub seasonal_harmonics: Option<Vec<usize>>,
}
34
/// Type of regularization to use in STR
///
/// Selects how `str_decomposition` solves for the trend and seasonal basis
/// coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegularizationType {
    /// Ridge regularization (L2 penalty)
    ///
    /// Solved directly from `(X^T X + R) beta = X^T y`. `R` is diagonal:
    /// - zero for the constant trend term;
    /// - `trend_lambda * i^2` for the `t^i` coefficient;
    /// - `seasonal_lambda` for every seasonal coefficient.
    Ridge,
    /// LASSO regularization (L1 penalty)
    ///
    /// Solved by coordinate descent with `seasonal_lambda` as the L1 weight on all
    /// coefficients. `trend_lambda` does not enter the coefficient fit.
    Lasso,
    /// Elastic Net regularization (combination of L1 and L2)
    ///
    /// Solved by coordinate descent with `seasonal_lambda` as the L1 weight and
    /// `trend_lambda` as the L2 weight. Both weights apply to all coefficients.
    ElasticNet,
}
45
46impl Default for STROptions {
47    fn default() -> Self {
48        Self {
49            regularization_type: RegularizationType::Ridge,
50            trend_lambda: 10.0,
51            seasonal_lambda: 0.5,
52            seasonal_periods: Vec::new(),
53            robust: false,
54            compute_confidence_intervals: false,
55            confidence_level: 0.95,
56            trend_degrees: 3,
57            flexible_seasonal: false,
58            seasonal_harmonics: None,
59        }
60    }
61}
62
/// Result of STR decomposition
///
/// The residual is computed as `original - trend - sum(seasonal_components)`.
/// The parts therefore add back up to `original`, up to floating-point rounding.
#[derive(Debug, Clone)]
pub struct STRResult<F> {
    /// Trend component (one value per input point)
    pub trend: Array1<F>,
    /// Seasonal components (one for each seasonal period, in the order given by
    /// `STROptions::seasonal_periods`)
    pub seasonal_components: Vec<Array1<F>>,
    /// Residual component (what the trend and seasonal components do not explain)
    pub residual: Array1<F>,
    /// Original time series
    pub original: Array1<F>,
    /// Confidence intervals for trend (if computed)
    ///
    /// This is `None` unless `STROptions::compute_confidence_intervals` is set. It
    /// can also be `None` when the intervals cannot be computed, for example when
    /// there are no more observations than coefficients.
    pub trend_ci: Option<(Array1<F>, Array1<F>)>, // (lower, upper)
    /// Confidence intervals for seasonal components (if computed); same
    /// availability rules as `trend_ci`
    pub seasonal_ci: Option<Vec<(Array1<F>, Array1<F>)>>, // (lower, upper) for each component
}
79
80/// Performs STR (Seasonal-Trend decomposition using Regression) on a time series
81///
82/// STR uses regularized regression to extract trend and seasonal components from
83/// a time series. It allows for multiple seasonal components, non-integer periods,
84/// and can provide confidence intervals for the components.
85///
86/// # Arguments
87///
88/// * `ts` - The time series to decompose
89/// * `options` - Options for STR decomposition
90///
91/// # Returns
92///
93/// * STR decomposition result
94///
95/// # Example
96///
97/// ```
98/// use scirs2_core::ndarray::array;
99/// use scirs2_series::decomposition::{str_decomposition, STROptions};
100///
101/// let ts = array![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0,
102///                 1.5, 2.5, 3.5, 2.5, 1.5, 2.5, 3.5, 2.5, 1.5, 2.5, 3.5, 2.5];
103///
104/// let mut options = STROptions::default();
105/// options.seasonal_periods = vec![4.0, 12.0]; // Both quarterly and yearly patterns
106///
107/// let result = str_decomposition(&ts, &options).unwrap();
108/// println!("Trend: {:?}", result.trend);
109/// println!("Seasonal Components: {:?}", result.seasonal_components);
110/// println!("Residual: {:?}", result.residual);
111/// ```
112#[allow(dead_code)]
113pub fn str_decomposition<F>(ts: &Array1<F>, options: &STROptions) -> Result<STRResult<F>>
114where
115    F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
116{
117    let n = ts.len();
118
119    // Check inputs
120    if n < 3 {
121        return Err(TimeSeriesError::DecompositionError(
122            "Time series must have at least 3 points for STR decomposition".to_string(),
123        ));
124    }
125
126    if options.seasonal_periods.is_empty() {
127        return Err(TimeSeriesError::DecompositionError(
128            "At least one seasonal period must be specified for STR".to_string(),
129        ));
130    }
131
132    for &period in &options.seasonal_periods {
133        if period <= 1.0 {
134            return Err(TimeSeriesError::DecompositionError(
135                "Seasonal periods must be greater than 1".to_string(),
136            ));
137        }
138    }
139
140    if options.trend_lambda < 0.0 || options.seasonal_lambda < 0.0 {
141        return Err(TimeSeriesError::DecompositionError(
142            "Regularization parameters must be non-negative".to_string(),
143        ));
144    }
145
146    if options.confidence_level <= 0.0 || options.confidence_level >= 1.0 {
147        return Err(TimeSeriesError::DecompositionError(
148            "Confidence level must be between 0 and 1".to_string(),
149        ));
150    }
151
152    // Step 1: Prepare design matrices for trend and seasonal components
153    let time_indices: Array1<F> = Array1::from_iter((0..n).map(|i| F::from_usize(i).unwrap()));
154
155    // Trend design matrix using polynomial basis functions
156    let trend_degree = options.trend_degrees;
157    let mut trend_basis = Array2::zeros((n, trend_degree + 1));
158
159    // Fill trend design matrix with polynomial terms (1, t, t^2, t^3, ...)
160    for i in 0..n {
161        for j in 0..=trend_degree {
162            if j == 0 {
163                trend_basis[[i, j]] = F::one(); // Constant term
164            } else {
165                let time_idx = time_indices[i];
166                trend_basis[[i, j]] = Float::powf(time_idx, F::from_usize(j).unwrap());
167            }
168        }
169    }
170
171    // Seasonal design matrices using Fourier basis functions for each seasonal component
172    let mut seasonal_bases = Vec::with_capacity(options.seasonal_periods.len());
173    let mut total_seasonal_cols = 0;
174
175    for (idx, &period) in options.seasonal_periods.iter().enumerate() {
176        // Number of harmonics for this seasonal component
177        let harmonics = if let Some(ref harms) = options.seasonal_harmonics {
178            harms
179                .get(idx)
180                .copied()
181                .unwrap_or(((period / 2.0).floor() as usize).max(1))
182        } else {
183            ((period / 2.0).floor() as usize).max(1)
184        };
185
186        let mut seasonal_basis = Array2::zeros((n, 2 * harmonics)); // 2 columns per harmonic (sin and cos)
187
188        for i in 0..n {
189            let t = time_indices[i];
190            for j in 0..harmonics {
191                let freq =
192                    F::from_f64(2.0 * std::f64::consts::PI * (j + 1) as f64 / period).unwrap();
193                // Sin term
194                seasonal_basis[[i, 2 * j]] = Float::sin(freq * t);
195                // Cos term
196                seasonal_basis[[i, 2 * j + 1]] = Float::cos(freq * t);
197            }
198        }
199
200        total_seasonal_cols += 2 * harmonics;
201        seasonal_bases.push(seasonal_basis);
202    }
203
204    // Step 2: Combine all design matrices
205    let total_cols = trend_degree + 1 + total_seasonal_cols;
206    let mut design_matrix = Array2::zeros((n, total_cols));
207
208    // Fill trend columns
209    design_matrix
210        .slice_mut(s![.., 0..=trend_degree])
211        .assign(&trend_basis);
212
213    // Fill seasonal columns
214    let mut col_offset = trend_degree + 1;
215    for seasonal_basis in &seasonal_bases {
216        let next_offset = col_offset + seasonal_basis.ncols();
217        design_matrix
218            .slice_mut(s![.., col_offset..next_offset])
219            .assign(seasonal_basis);
220        col_offset = next_offset;
221    }
222
223    // Step 3: Set up regularization matrix
224    let mut regularization_matrix = Array2::zeros((total_cols, total_cols));
225
226    // Trend regularization (penalize higher-order polynomial coefficients)
227    for i in 0..=trend_degree {
228        let weight = if i == 0 {
229            0.0 // Don't penalize the constant term
230        } else {
231            options.trend_lambda * (i as f64).powi(2)
232        };
233        regularization_matrix[[i, i]] = F::from_f64(weight).unwrap();
234    }
235
236    // Seasonal regularization
237    col_offset = trend_degree + 1;
238    for seasonal_basis in &seasonal_bases {
239        let seasonal_cols = seasonal_basis.ncols();
240        for i in 0..seasonal_cols {
241            regularization_matrix[[col_offset + i, col_offset + i]] =
242                F::from(options.seasonal_lambda).unwrap();
243        }
244        col_offset += seasonal_cols;
245    }
246
247    // Step 4: Solve regularized least squares problem
248    // (X^T X + λR) β = X^T y
249    let xtx = design_matrix.t().dot(&design_matrix);
250    let xty = design_matrix.t().dot(ts);
251
252    // Add regularization
253    let system_matrix = xtx + regularization_matrix;
254
255    // Solve the system
256    let coefficients = match options.regularization_type {
257        RegularizationType::Ridge => {
258            // Ridge regression: solve (X^T X + λI) β = X^T y
259            solve_regularized_system(&system_matrix, &xty)?
260        }
261        RegularizationType::Lasso => {
262            // LASSO regression using coordinate descent
263            solve_lasso(
264                &design_matrix,
265                ts,
266                options.seasonal_lambda,
267                1000,
268                F::from(1e-6).unwrap(),
269            )?
270        }
271        RegularizationType::ElasticNet => {
272            // Elastic Net regression using coordinate descent
273            solve_elastic_net(
274                &design_matrix,
275                ts,
276                options.seasonal_lambda,
277                options.trend_lambda,
278                1000,
279                F::from(1e-6).unwrap(),
280            )?
281        }
282    };
283
284    // Step 5: Extract components from coefficients
285    // Trend component
286    let trend_coeffs = coefficients.slice(s![0..=trend_degree]);
287    let trend = trend_basis.dot(&trend_coeffs);
288
289    // Seasonal components
290    let mut seasonal_components = Vec::with_capacity(options.seasonal_periods.len());
291    col_offset = trend_degree + 1;
292
293    for seasonal_basis in &seasonal_bases {
294        let seasonal_cols = seasonal_basis.ncols();
295        let seasonal_coeffs = coefficients.slice(s![col_offset..col_offset + seasonal_cols]);
296        let seasonal_component = seasonal_basis.dot(&seasonal_coeffs);
297        seasonal_components.push(seasonal_component);
298        col_offset += seasonal_cols;
299    }
300
301    // Compute residuals
302    let mut residual = ts.clone();
303    for i in 0..n {
304        residual[i] = residual[i] - trend[i];
305        for seasonal_component in &seasonal_components {
306            residual[i] = residual[i] - seasonal_component[i];
307        }
308    }
309
310    // Compute confidence intervals if requested
311    let (trend_ci, seasonal_ci) = if options.compute_confidence_intervals {
312        compute_confidence_intervals(
313            &design_matrix,
314            &system_matrix,
315            &residual,
316            &trend_basis,
317            &seasonal_bases,
318            options.confidence_level,
319        )?
320    } else {
321        (None, None)
322    };
323
324    // Create result
325    let result = STRResult {
326        trend,
327        seasonal_components,
328        residual,
329        original: ts.clone(),
330        trend_ci,
331        seasonal_ci,
332    };
333
334    Ok(result)
335}
336
337/// Type alias for confidence interval bounds (lower, upper)
338type ConfidenceInterval<F> = (Array1<F>, Array1<F>);
339
340/// Type alias for confidence intervals result
341type ConfidenceIntervalsResult<F> = Result<(
342    Option<ConfidenceInterval<F>>,
343    Option<Vec<ConfidenceInterval<F>>>,
344)>;
345
346/// Compute confidence intervals for STR components
347#[allow(dead_code)]
348fn compute_confidence_intervals<F>(
349    design_matrix: &Array2<F>,
350    system_matrix: &Array2<F>,
351    residual: &Array1<F>,
352    trend_basis: &Array2<F>,
353    seasonal_bases: &[Array2<F>],
354    confidence_level: f64,
355) -> ConfidenceIntervalsResult<F>
356where
357    F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
358{
359    let n = residual.len();
360    let p = design_matrix.ncols();
361
362    if n <= p {
363        return Ok((None, None));
364    }
365
366    // Estimate residual variance
367    let residual_variance = residual.mapv(|x| x * x).sum() / F::from_usize(n - p).unwrap();
368
369    // Compute covariance _matrix: σ² (X^T X + λR)^(-1)
370    let covariance_matrix = match matrix_inverse(system_matrix) {
371        Ok(inv) => inv.mapv(|x| x * residual_variance),
372        Err(_) => return Ok((None, None)), // Skip CI if _matrix is singular
373    };
374
375    // Get t-distribution critical value (approximation for large samples)
376    let alpha = 1.0 - confidence_level;
377    let df = n - p;
378    let t_critical = if df > 30 {
379        // Normal approximation for large df
380        match alpha {
381            a if a <= 0.01 => F::from(2.576).unwrap(), // 99% CI
382            a if a <= 0.05 => F::from(1.96).unwrap(),  // 95% CI
383            _ => F::from(1.645).unwrap(),              // 90% CI
384        }
385    } else {
386        // Simple t-distribution approximation
387        let base = F::from(2.0).unwrap();
388        base + F::from(df as f64).unwrap().recip()
389    };
390
391    // Compute standard errors for trend component
392    let trend_se = compute_component_standard_errors(trend_basis, &covariance_matrix)?;
393    let trend_margin = trend_se.mapv(|se| se * t_critical);
394    let trend_fitted = trend_basis.dot(&covariance_matrix.diag().slice(s![0..trend_basis.ncols()]));
395    let trend_lower = &trend_fitted - &trend_margin;
396    let trend_upper = &trend_fitted + &trend_margin;
397
398    // Compute standard errors for seasonal components
399    let mut seasonal_cis = Vec::new();
400    let mut col_offset = trend_basis.ncols();
401
402    for seasonal_basis in seasonal_bases {
403        let seasonal_cols = seasonal_basis.ncols();
404        let seasonal_cov = covariance_matrix.slice(s![
405            col_offset..col_offset + seasonal_cols,
406            col_offset..col_offset + seasonal_cols
407        ]);
408
409        let seasonal_se =
410            compute_component_standard_errors(seasonal_basis, &seasonal_cov.to_owned())?;
411        let seasonal_margin = seasonal_se.mapv(|se| se * t_critical);
412        let seasonal_fitted = seasonal_basis.dot(&seasonal_cov.diag());
413        let seasonal_lower = &seasonal_fitted - &seasonal_margin;
414        let seasonal_upper = &seasonal_fitted + &seasonal_margin;
415
416        seasonal_cis.push((seasonal_lower, seasonal_upper));
417        col_offset += seasonal_cols;
418    }
419
420    Ok((Some((trend_lower, trend_upper)), Some(seasonal_cis)))
421}
422
423/// Compute standard errors for a component given its basis and covariance matrix
424#[allow(dead_code)]
425fn compute_component_standard_errors<F>(
426    basis: &Array2<F>,
427    covariance: &Array2<F>,
428) -> Result<Array1<F>>
429where
430    F: Float + FromPrimitive + Debug + ScalarOperand + NumCast + std::iter::Sum,
431{
432    let n = basis.nrows();
433    let mut standard_errors = Array1::zeros(n);
434
435    for i in 0..n {
436        let basis_row = basis.row(i);
437        let variance = basis_row.dot(&covariance.dot(&basis_row));
438        standard_errors[i] = variance.sqrt();
439    }
440
441    Ok(standard_errors)
442}
443
444/// Matrix solve using scirs2-linalg
445#[allow(dead_code)]
446fn solve_regularized_system<F>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>>
447where
448    F: Float + FromPrimitive + ScalarOperand + NumCast + 'static,
449{
450    let n = a.shape()[0];
451    if n != a.shape()[1] || n != b.len() {
452        return Err(TimeSeriesError::DecompositionError(
453            "Matrix dimensions mismatch".to_string(),
454        ));
455    }
456
457    // Convert to f64 for scirs2-linalg computation
458    let a_f64 = a.mapv(|x| x.to_f64().unwrap_or(0.0));
459    let b_f64 = b.mapv(|x| x.to_f64().unwrap_or(0.0));
460
461    // Solve using scirs2-linalg
462    let x_f64 = solve(&a_f64.view(), &b_f64.view(), None)
463        .map_err(|e| TimeSeriesError::DecompositionError(format!("Linear solve failed: {e}")))?;
464
465    // Convert back to original type
466    let x = x_f64.mapv(|val| F::from_f64(val).unwrap_or_else(F::zero));
467
468    Ok(x)
469}
470
471/// LASSO regression using coordinate descent algorithm
472#[allow(dead_code)]
473fn solve_lasso<F>(
474    x: &Array2<F>,
475    y: &Array1<F>,
476    lambda: f64,
477    max_iter: usize,
478    tol: F,
479) -> Result<Array1<F>>
480where
481    F: Float + FromPrimitive + ScalarOperand + NumCast + std::iter::Sum,
482{
483    let (n, p) = (x.nrows(), x.ncols());
484    let mut beta = Array1::zeros(p);
485    let lambda_f = F::from(lambda).unwrap();
486
487    // Precompute X^T X diagonal (for efficiency)
488    let mut xtx_diag = Array1::zeros(p);
489    for j in 0..p {
490        xtx_diag[j] = x.column(j).dot(&x.column(j));
491    }
492
493    for _iter in 0..max_iter {
494        let beta_old = beta.clone();
495
496        for j in 0..p {
497            // Compute partial residual
498            let mut r = y.clone();
499            for k in 0..p {
500                if k != j {
501                    let x_k = x.column(k);
502                    for i in 0..n {
503                        r[i] = r[i] - beta[k] * x_k[i];
504                    }
505                }
506            }
507
508            // Compute coordinate update
509            let x_j = x.column(j);
510            let xty_j = x_j.dot(&r);
511
512            // Soft thresholding
513            let z = xty_j;
514            beta[j] = if z > lambda_f {
515                (z - lambda_f) / xtx_diag[j]
516            } else if z < -lambda_f {
517                (z + lambda_f) / xtx_diag[j]
518            } else {
519                F::zero()
520            };
521        }
522
523        // Check convergence
524        let mut diff = F::zero();
525        for j in 0..p {
526            diff = diff + (beta[j] - beta_old[j]).abs();
527        }
528
529        if diff < tol {
530            break;
531        }
532    }
533
534    Ok(beta)
535}
536
537/// Elastic Net regression using coordinate descent algorithm
538#[allow(dead_code)]
539fn solve_elastic_net<F>(
540    x: &Array2<F>,
541    y: &Array1<F>,
542    l1_lambda: f64,
543    l2_lambda: f64,
544    max_iter: usize,
545    tol: F,
546) -> Result<Array1<F>>
547where
548    F: Float + FromPrimitive + ScalarOperand + NumCast + std::iter::Sum,
549{
550    let (n, p) = (x.nrows(), x.ncols());
551    let mut beta = Array1::zeros(p);
552    let l1_lambda_f = F::from(l1_lambda).unwrap();
553    let l2_lambda_f = F::from(l2_lambda).unwrap();
554
555    // Precompute X^T X diagonal + L2 penalty
556    let mut xtx_diag = Array1::zeros(p);
557    for j in 0..p {
558        xtx_diag[j] = x.column(j).dot(&x.column(j)) + l2_lambda_f;
559    }
560
561    for _iter in 0..max_iter {
562        let beta_old = beta.clone();
563
564        for j in 0..p {
565            // Compute partial residual
566            let mut r = y.clone();
567            for k in 0..p {
568                if k != j {
569                    let x_k = x.column(k);
570                    for i in 0..n {
571                        r[i] = r[i] - beta[k] * x_k[i];
572                    }
573                }
574            }
575
576            // Compute coordinate update
577            let x_j = x.column(j);
578            let xty_j = x_j.dot(&r);
579
580            // Soft thresholding with L2 penalty
581            let z = xty_j;
582            beta[j] = if z > l1_lambda_f {
583                (z - l1_lambda_f) / xtx_diag[j]
584            } else if z < -l1_lambda_f {
585                (z + l1_lambda_f) / xtx_diag[j]
586            } else {
587                F::zero()
588            };
589        }
590
591        // Check convergence
592        let mut diff = F::zero();
593        for j in 0..p {
594            diff = diff + (beta[j] - beta_old[j]).abs();
595        }
596
597        if diff < tol {
598            break;
599        }
600    }
601
602    Ok(beta)
603}
604
605/// Matrix inversion using scirs2-linalg
606#[allow(dead_code)]
607fn matrix_inverse<F>(a: &Array2<F>) -> Result<Array2<F>>
608where
609    F: Float + FromPrimitive + ScalarOperand + NumCast + 'static,
610{
611    let n = a.shape()[0];
612    if n != a.shape()[1] {
613        return Err(TimeSeriesError::DecompositionError(
614            "Matrix must be square for inversion".to_string(),
615        ));
616    }
617
618    // Convert to f64 for scirs2-linalg computation
619    let a_f64 = a.mapv(|x| x.to_f64().unwrap_or(0.0));
620
621    // Compute inverse using scirs2-linalg
622    let inv_f64 = inv(&a_f64.view(), None).map_err(|e| {
623        TimeSeriesError::DecompositionError(format!("Matrix inversion failed: {e}"))
624    })?;
625
626    // Convert back to original type
627    let inverse = inv_f64.mapv(|val| F::from_f64(val).unwrap_or_else(F::zero));
628
629    Ok(inverse)
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Full-turn angle used to build the synthetic seasonal signals.
    const TAU: f64 = 2.0 * std::f64::consts::PI;

    #[test]
    fn test_str_basic() {
        // Linear trend + period-12 sine + a small deterministic "noise" term.
        let n = 50;
        let ts: Array1<f64> = (0..n)
            .map(|i| {
                let t = i as f64;
                0.1 * t + 2.0 * (TAU * t / 12.0).sin() + 0.1 * (t * 0.456).sin()
            })
            .collect();

        let options = STROptions {
            seasonal_periods: vec![12.0],
            trend_degrees: 2,
            trend_lambda: 1.0,
            seasonal_lambda: 0.1,
            ..Default::default()
        };

        let result = str_decomposition(&ts, &options).expect("STR decomposition should succeed");

        // Trend + seasonal + residual must reproduce every observation.
        for (i, &observed) in ts.iter().enumerate() {
            let reconstructed =
                result.trend[i] + result.seasonal_components[0][i] + result.residual[i];
            assert_abs_diff_eq!(reconstructed, observed, epsilon = 1e-10);
        }

        // One trend value per point and exactly one full-length seasonal component.
        assert_eq!(result.trend.len(), n);
        assert_eq!(result.seasonal_components.len(), 1);
        assert_eq!(result.seasonal_components[0].len(), n);
    }

    #[test]
    fn test_str_multiple_seasons() {
        // Linear trend + period-12 sine + period-4 cosine.
        let n = 100;
        let ts: Array1<f64> = (0..n)
            .map(|i| {
                let t = i as f64;
                0.05 * t + 3.0 * (TAU * t / 12.0).sin() + 1.5 * (TAU * t / 4.0).cos()
            })
            .collect();

        let options = STROptions {
            seasonal_periods: vec![12.0, 4.0],
            trend_degrees: 1,
            trend_lambda: 5.0,
            seasonal_lambda: 0.5,
            ..Default::default()
        };

        let result = str_decomposition(&ts, &options).expect("STR decomposition should succeed");

        // Trend + residual + every seasonal component must reproduce the input.
        for (i, &observed) in ts.iter().enumerate() {
            let reconstructed = result
                .seasonal_components
                .iter()
                .fold(result.trend[i] + result.residual[i], |acc, component| {
                    acc + component[i]
                });
            assert_abs_diff_eq!(reconstructed, observed, epsilon = 1e-10);
        }

        // One component per requested period.
        assert_eq!(result.seasonal_components.len(), 2);
    }

    #[test]
    fn test_str_edge_cases() {
        let shortest = array![1.0, 2.0, 3.0];
        let mut options = STROptions {
            seasonal_periods: vec![2.0],
            ..Default::default()
        };

        // Three points is the minimum accepted length.
        assert!(str_decomposition(&shortest, &options).is_ok());

        // A seasonal period must exceed 1.
        options.seasonal_periods = vec![0.5];
        assert!(str_decomposition(&shortest, &options).is_err());

        // At least one seasonal period is required.
        options.seasonal_periods = vec![];
        assert!(str_decomposition(&shortest, &options).is_err());

        // Two points are too few.
        let too_short = array![1.0, 2.0];
        options.seasonal_periods = vec![2.0];
        assert!(str_decomposition(&too_short, &options).is_err());
    }
}