sklears_preprocessing/temporal/
stationarity.rs

1//! Stationarity transformations for time series
2//!
3//! This module provides methods to transform non-stationary time series into
4//! stationary ones, which is essential for many time series modeling techniques.
5//!
6//! # Stationarity Transformations
7//!
8//! - **Differencing**: First and higher-order differencing to remove trends
9//! - **Detrending**: Remove linear, polynomial, or local trends
10//! - **Log transformation**: Stabilize variance for exponential growth patterns
11//! - **Box-Cox transformation**: Generalized power transformation for variance stabilization
12//! - **Seasonal differencing**: Remove seasonal patterns
13//! - **Combined transformations**: Multiple transformations in sequence
14//!
15//! # Stationarity Tests
16//!
17//! - **Augmented Dickey-Fuller (ADF)**: Test for unit root (simplified implementation)
18//! - **KPSS test**: Test for trend stationarity (simplified)
19//! - **Phillips-Perron test**: Alternative unit root test (basic implementation)
20//!
21//! # Examples
22//!
23//! ```rust,ignore
//! use sklears_preprocessing::temporal::stationarity::{
//!     StationarityTransformer, StationarityMethod
//! };
//! use sklears_core::traits::{Fit, Transform};
//! use scirs2_core::ndarray::Array1;
//!
//! // Create sample time series with trend
//! let mut data = Array1::zeros(100);
//! for i in 0..100 {
//!     data[i] = (i as f64) + (i as f64 * 0.1).sin(); // Linear trend + sine wave
//! }
//!
//! // Apply first differencing to remove trend
//! let transformer = StationarityTransformer::new()
//!     .with_method(StationarityMethod::FirstDifference);
//!
//! // `transform` is only available on the fitted transformer, so fit first
//! let fitted = transformer.fit(&data, &()).unwrap();
//! let stationary_data = fitted.transform(&data).unwrap();
40//! ```
41
42use scirs2_core::ndarray::{s, Array1};
43use sklears_core::{
44    error::{Result, SklearsError},
45    traits::{Estimator, Fit, Transform, Untrained},
46    types::Float,
47};
48
49#[cfg(feature = "serde")]
50use serde::{Deserialize, Serialize};
51
52/// Stationarity transformation methods
53#[derive(Debug, Clone, Copy, PartialEq)]
54#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
55pub enum StationarityMethod {
56    /// First-order differencing: x_t - x_{t-1}
57    FirstDifference,
58    /// Second-order differencing: (x_t - x_{t-1}) - (x_{t-1} - x_{t-2})
59    SecondDifference,
60    /// Seasonal differencing with specified period
61    SeasonalDifference(usize),
62    /// Linear detrending using least squares
63    LinearDetrend,
64    /// Polynomial detrending with specified degree
65    PolynomialDetrend(usize),
66    /// Log transformation to stabilize variance
67    LogTransform,
68    /// Box-Cox transformation with lambda parameter
69    BoxCox(Float),
70    /// Combined first difference and seasonal difference
71    CombinedDifference(usize),
72    /// Moving average detrending with window size
73    MovingAverageDetrend(usize),
74}
75
76impl Default for StationarityMethod {
77    fn default() -> Self {
78        Self::FirstDifference
79    }
80}
81
82/// Configuration for stationarity transformer
83#[derive(Debug, Clone)]
84#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
85pub struct StationarityTransformerConfig {
86    /// Primary transformation method
87    pub method: StationarityMethod,
88    /// Handle missing values created by differencing
89    pub fill_method: FillMethod,
90    /// Minimum value for log/Box-Cox transformations (to avoid log(0))
91    pub min_value_offset: Float,
92    /// Test for stationarity after transformation
93    pub test_stationarity: bool,
94}
95
/// Strategies for dealing with missing (NaN) values left behind by a transformation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FillMethod {
    /// Leave the (already shortened) series untouched.
    Drop,
    /// Replace missing entries with the first available value.
    ForwardFill,
    /// Replace missing entries with the last available value.
    BackwardFill,
    /// Replace missing entries with `0`.
    Zero,
    /// Replace missing entries with the mean of the non-missing values.
    Mean,
}
111
112impl Default for FillMethod {
113    fn default() -> Self {
114        Self::Drop
115    }
116}
117
118impl Default for StationarityTransformerConfig {
119    fn default() -> Self {
120        Self {
121            method: StationarityMethod::default(),
122            fill_method: FillMethod::default(),
123            min_value_offset: 1e-8,
124            test_stationarity: false,
125        }
126    }
127}
128
129/// Stationarity transformer for making time series stationary
130#[derive(Debug, Clone)]
131#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
132pub struct StationarityTransformer<State = Untrained> {
133    config: StationarityTransformerConfig,
134    state: std::marker::PhantomData<State>,
135}
136
137/// Fitted state containing transformation parameters
138#[derive(Debug, Clone)]
139#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
140pub struct StationarityTransformerFitted {
141    config: StationarityTransformerConfig,
142    /// Parameters for reverse transformation
143    trend_params: Option<Vec<Float>>,
144    /// Log offset used in transformation
145    log_offset: Option<Float>,
146    /// Box-Cox lambda parameter
147    boxcox_lambda: Option<Float>,
148    /// Original series statistics for validation
149    original_mean: Option<Float>,
150    original_std: Option<Float>,
151}
152
153impl Default for StationarityTransformer<Untrained> {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl StationarityTransformer<Untrained> {
160    /// Create a new stationarity transformer
161    pub fn new() -> Self {
162        Self {
163            config: StationarityTransformerConfig::default(),
164            state: std::marker::PhantomData,
165        }
166    }
167
168    /// Set the stationarity transformation method
169    pub fn with_method(mut self, method: StationarityMethod) -> Self {
170        self.config.method = method;
171        self
172    }
173
174    /// Set the fill method for missing values
175    pub fn with_fill_method(mut self, fill_method: FillMethod) -> Self {
176        self.config.fill_method = fill_method;
177        self
178    }
179
180    /// Set the minimum value offset for log transformations
181    pub fn with_min_value_offset(mut self, offset: Float) -> Self {
182        self.config.min_value_offset = offset;
183        self
184    }
185
186    /// Enable/disable stationarity testing
187    pub fn with_stationarity_test(mut self, test: bool) -> Self {
188        self.config.test_stationarity = test;
189        self
190    }
191}
192
193impl Estimator for StationarityTransformer<Untrained> {
194    type Config = StationarityTransformerConfig;
195    type Error = SklearsError;
196    type Float = Float;
197
198    fn config(&self) -> &Self::Config {
199        &self.config
200    }
201}
202
203impl Fit<Array1<Float>, ()> for StationarityTransformer<Untrained> {
204    type Fitted = StationarityTransformerFitted;
205
206    fn fit(self, x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
207        if x.len() < 2 {
208            return Err(SklearsError::InvalidInput(
209                "Time series must have at least 2 points".to_string(),
210            ));
211        }
212
213        let original_mean = x.mean();
214        let original_std = Some(calculate_std(x));
215
216        // Fit transformation parameters based on method
217        let (trend_params, log_offset, boxcox_lambda) = match self.config.method {
218            StationarityMethod::LinearDetrend => {
219                let params = fit_linear_trend(x)?;
220                (Some(params), None, None)
221            }
222            StationarityMethod::PolynomialDetrend(degree) => {
223                let params = fit_polynomial_trend(x, degree)?;
224                (Some(params), None, None)
225            }
226            StationarityMethod::LogTransform => {
227                let min_val = x.iter().fold(Float::INFINITY, |a, &b| a.min(b));
228                let offset = if min_val <= 0.0 {
229                    -min_val + self.config.min_value_offset
230                } else {
231                    0.0
232                };
233                (None, Some(offset), None)
234            }
235            StationarityMethod::BoxCox(lambda) => {
236                let min_val = x.iter().fold(Float::INFINITY, |a, &b| a.min(b));
237                let offset = if min_val <= 0.0 {
238                    -min_val + self.config.min_value_offset
239                } else {
240                    0.0
241                };
242                (None, Some(offset), Some(lambda))
243            }
244            _ => (None, None, None),
245        };
246
247        Ok(StationarityTransformerFitted {
248            config: self.config,
249            trend_params,
250            log_offset,
251            boxcox_lambda,
252            original_mean,
253            original_std,
254        })
255    }
256}
257
258impl Transform<Array1<Float>, Array1<Float>> for StationarityTransformerFitted {
259    fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
260        if x.is_empty() {
261            return Ok(Array1::zeros(0));
262        }
263
264        let result = match self.config.method {
265            StationarityMethod::FirstDifference => self.first_difference(x)?,
266            StationarityMethod::SecondDifference => self.second_difference(x)?,
267            StationarityMethod::SeasonalDifference(period) => {
268                self.seasonal_difference(x, period)?
269            }
270            StationarityMethod::LinearDetrend => self.linear_detrend(x)?,
271            StationarityMethod::PolynomialDetrend(degree) => self.polynomial_detrend(x, degree)?,
272            StationarityMethod::LogTransform => self.log_transform(x)?,
273            StationarityMethod::BoxCox(lambda) => self.box_cox_transform(x, lambda)?,
274            StationarityMethod::CombinedDifference(period) => {
275                self.combined_difference(x, period)?
276            }
277            StationarityMethod::MovingAverageDetrend(window) => {
278                self.moving_average_detrend(x, window)?
279            }
280        };
281
282        // Test stationarity if requested
283        if self.config.test_stationarity && result.len() > 10 {
284            let _test_result = self.test_stationarity(&result)?;
285            // Could log or return test results
286        }
287
288        Ok(result)
289    }
290}
291
292impl StationarityTransformerFitted {
293    /// Apply first differencing
294    fn first_difference(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
295        if x.len() < 2 {
296            return Ok(Array1::zeros(0));
297        }
298
299        let mut result = Array1::zeros(x.len() - 1);
300        for i in 1..x.len() {
301            result[i - 1] = x[i] - x[i - 1];
302        }
303
304        self.handle_missing_values(result)
305    }
306
307    /// Apply second differencing
308    fn second_difference(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
309        let first_diff = self.first_difference(x)?;
310        if first_diff.len() < 2 {
311            return Ok(Array1::zeros(0));
312        }
313
314        let mut result = Array1::zeros(first_diff.len() - 1);
315        for i in 1..first_diff.len() {
316            result[i - 1] = first_diff[i] - first_diff[i - 1];
317        }
318
319        self.handle_missing_values(result)
320    }
321
322    /// Apply seasonal differencing
323    fn seasonal_difference(&self, x: &Array1<Float>, period: usize) -> Result<Array1<Float>> {
324        if x.len() <= period {
325            return Ok(Array1::zeros(0));
326        }
327
328        let mut result = Array1::zeros(x.len() - period);
329        for i in period..x.len() {
330            result[i - period] = x[i] - x[i - period];
331        }
332
333        self.handle_missing_values(result)
334    }
335
336    /// Apply linear detrending
337    fn linear_detrend(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
338        let params = self
339            .trend_params
340            .as_ref()
341            .ok_or_else(|| SklearsError::NotFitted {
342                operation: "Linear trend not fitted".to_string(),
343            })?;
344
345        if params.len() != 2 {
346            return Err(SklearsError::InvalidInput(
347                "Linear trend requires 2 parameters".to_string(),
348            ));
349        }
350
351        let slope = params[0];
352        let intercept = params[1];
353
354        let mut result = Array1::zeros(x.len());
355        for (i, &val) in x.iter().enumerate() {
356            let trend_val = slope * (i as Float) + intercept;
357            result[i] = val - trend_val;
358        }
359
360        Ok(result)
361    }
362
363    /// Apply polynomial detrending
364    fn polynomial_detrend(&self, x: &Array1<Float>, _degree: usize) -> Result<Array1<Float>> {
365        let params = self
366            .trend_params
367            .as_ref()
368            .ok_or_else(|| SklearsError::NotFitted {
369                operation: "Polynomial trend not fitted".to_string(),
370            })?;
371
372        let mut result = Array1::zeros(x.len());
373        for (i, &val) in x.iter().enumerate() {
374            let mut trend_val = 0.0;
375            let t = i as Float;
376
377            // Apply polynomial: a0 + a1*t + a2*t^2 + ... + an*t^n
378            for (degree, &coeff) in params.iter().enumerate() {
379                trend_val += coeff * t.powi(degree as i32);
380            }
381
382            result[i] = val - trend_val;
383        }
384
385        Ok(result)
386    }
387
388    /// Apply log transformation
389    fn log_transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
390        let offset = self.log_offset.unwrap_or(0.0);
391
392        let mut result = Array1::zeros(x.len());
393        for (i, &val) in x.iter().enumerate() {
394            let adjusted_val = val + offset;
395            if adjusted_val <= 0.0 {
396                return Err(SklearsError::InvalidInput(format!(
397                    "Cannot take log of non-positive value: {}",
398                    adjusted_val
399                )));
400            }
401            result[i] = adjusted_val.ln();
402        }
403
404        Ok(result)
405    }
406
407    /// Apply Box-Cox transformation
408    fn box_cox_transform(&self, x: &Array1<Float>, lambda: Float) -> Result<Array1<Float>> {
409        let offset = self.log_offset.unwrap_or(0.0);
410
411        let mut result = Array1::zeros(x.len());
412        for (i, &val) in x.iter().enumerate() {
413            let adjusted_val = val + offset;
414            if adjusted_val <= 0.0 {
415                return Err(SklearsError::InvalidInput(format!(
416                    "Cannot apply Box-Cox to non-positive value: {}",
417                    adjusted_val
418                )));
419            }
420
421            result[i] = if lambda.abs() < 1e-10 {
422                adjusted_val.ln()
423            } else {
424                (adjusted_val.powf(lambda) - 1.0) / lambda
425            };
426        }
427
428        Ok(result)
429    }
430
431    /// Apply combined first and seasonal differencing
432    fn combined_difference(&self, x: &Array1<Float>, period: usize) -> Result<Array1<Float>> {
433        let first_diff = self.first_difference(x)?;
434        if first_diff.len() <= period {
435            return Ok(Array1::zeros(0));
436        }
437
438        let mut result = Array1::zeros(first_diff.len() - period);
439        for i in period..first_diff.len() {
440            result[i - period] = first_diff[i] - first_diff[i - period];
441        }
442
443        self.handle_missing_values(result)
444    }
445
446    /// Apply moving average detrending
447    fn moving_average_detrend(&self, x: &Array1<Float>, window: usize) -> Result<Array1<Float>> {
448        if x.len() < window {
449            return Err(SklearsError::InvalidInput(
450                "Series too short for moving average window".to_string(),
451            ));
452        }
453
454        let mut result = Array1::zeros(x.len());
455
456        // Calculate moving averages
457        for i in 0..x.len() {
458            let start_idx = i.saturating_sub(window / 2);
459            let end_idx = ((i + window / 2 + 1).min(x.len())).max(start_idx + 1);
460
461            let window_slice = x.slice(s![start_idx..end_idx]);
462            let moving_avg = window_slice.mean().unwrap_or(0.0);
463
464            result[i] = x[i] - moving_avg;
465        }
466
467        Ok(result)
468    }
469
470    /// Handle missing values according to fill method
471    fn handle_missing_values(&self, mut x: Array1<Float>) -> Result<Array1<Float>> {
472        match self.config.fill_method {
473            FillMethod::Drop => Ok(x), // Already handled by differencing operations
474            FillMethod::ForwardFill => {
475                if let Some(&first_val) = x.first() {
476                    for val in x.iter_mut() {
477                        if val.is_nan() {
478                            *val = first_val;
479                        }
480                    }
481                }
482                Ok(x)
483            }
484            FillMethod::BackwardFill => {
485                if let Some(&last_val) = x.last() {
486                    for val in x.iter_mut() {
487                        if val.is_nan() {
488                            *val = last_val;
489                        }
490                    }
491                }
492                Ok(x)
493            }
494            FillMethod::Zero => {
495                for val in x.iter_mut() {
496                    if val.is_nan() {
497                        *val = 0.0;
498                    }
499                }
500                Ok(x)
501            }
502            FillMethod::Mean => {
503                let finite_values: Vec<Float> =
504                    x.iter().copied().filter(|v| v.is_finite()).collect();
505                if !finite_values.is_empty() {
506                    let mean_val =
507                        finite_values.iter().sum::<Float>() / finite_values.len() as Float;
508                    for val in x.iter_mut() {
509                        if val.is_nan() {
510                            *val = mean_val;
511                        }
512                    }
513                }
514                Ok(x)
515            }
516        }
517    }
518
519    /// Simple stationarity test (simplified ADF-like test)
520    fn test_stationarity(&self, x: &Array1<Float>) -> Result<Float> {
521        if x.len() < 10 {
522            return Ok(0.0); // Not enough data for meaningful test
523        }
524
525        // Calculate first-order autocorrelation as a simple stationarity proxy
526        let mean = x.mean().unwrap_or(0.0);
527        let mut numerator = 0.0;
528        let mut denominator = 0.0;
529
530        for i in 1..x.len() {
531            numerator += (x[i] - mean) * (x[i - 1] - mean);
532            denominator += (x[i - 1] - mean).powi(2);
533        }
534
535        let autocorr = if denominator.abs() > 1e-10 {
536            numerator / denominator
537        } else {
538            0.0
539        };
540
541        // Return absolute autocorrelation as stationarity score (lower is more stationary)
542        Ok(autocorr.abs())
543    }
544}
545
546/// Utility functions
547/// Fit linear trend using least squares
548fn fit_linear_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
549    let n = x.len() as Float;
550    let x_mean = (n - 1.0) / 2.0;
551    let y_mean = x.mean().unwrap_or(0.0);
552
553    let mut numerator = 0.0;
554    let mut denominator = 0.0;
555
556    for (i, &y) in x.iter().enumerate() {
557        let xi = i as Float;
558        numerator += (xi - x_mean) * (y - y_mean);
559        denominator += (xi - x_mean).powi(2);
560    }
561
562    let slope = if denominator.abs() > 1e-10 {
563        numerator / denominator
564    } else {
565        0.0
566    };
567
568    let intercept = y_mean - slope * x_mean;
569
570    Ok(vec![slope, intercept])
571}
572
573/// Fit polynomial trend using least squares (simplified for low-degree polynomials)
574fn fit_polynomial_trend(x: &Array1<Float>, degree: usize) -> Result<Vec<Float>> {
575    if degree == 1 {
576        return fit_linear_trend(x);
577    }
578
579    if degree > 3 {
580        return Err(SklearsError::InvalidInput(
581            "Polynomial degree > 3 not supported in this implementation".to_string(),
582        ));
583    }
584
585    // For simplicity, implement up to degree 3 polynomials
586    match degree {
587        0 => Ok(vec![x.mean().unwrap_or(0.0)]),
588        2 => fit_quadratic_trend(x),
589        3 => fit_cubic_trend(x),
590        _ => fit_linear_trend(x), // Fallback
591    }
592}
593
594/// Fit quadratic trend (degree 2 polynomial)
595fn fit_quadratic_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
596    let n = x.len();
597    if n < 3 {
598        return Err(SklearsError::InvalidInput(
599            "Need at least 3 points for quadratic fit".to_string(),
600        ));
601    }
602
603    // Simplified quadratic fitting using normal equations
604    // For a more robust implementation, use QR decomposition or SVD
605
606    let mut s0 = 0.0; // sum of 1
607    let mut s1 = 0.0; // sum of t
608    let mut s2 = 0.0; // sum of t^2
609    let mut s3 = 0.0; // sum of t^3
610    let mut s4 = 0.0; // sum of t^4
611    let mut sy = 0.0; // sum of y
612    let mut sty = 0.0; // sum of t*y
613    let mut st2y = 0.0; // sum of t^2*y
614
615    for (i, &y) in x.iter().enumerate() {
616        let t = i as Float;
617        let t2 = t * t;
618        let t3 = t2 * t;
619        let t4 = t3 * t;
620
621        s0 += 1.0;
622        s1 += t;
623        s2 += t2;
624        s3 += t3;
625        s4 += t4;
626        sy += y;
627        sty += t * y;
628        st2y += t2 * y;
629    }
630
631    // Solve the normal equations using Cramer's rule (for 3x3 system)
632    let det = s0 * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s2 * s3) + s2 * (s1 * s3 - s2 * s2);
633
634    if det.abs() < 1e-10 {
635        return Err(SklearsError::InvalidInput(
636            "Matrix is singular for quadratic fitting".to_string(),
637        ));
638    }
639
640    let a0 =
641        (sy * (s2 * s4 - s3 * s3) - sty * (s1 * s4 - s2 * s3) + st2y * (s1 * s3 - s2 * s2)) / det;
642    let a1 =
643        (s0 * (sty * s4 - st2y * s3) - sy * (s1 * s4 - s2 * s3) + st2y * (s1 * s2 - s0 * s3)) / det;
644    let a2 = (s0 * (s2 * st2y - s3 * sty) - s1 * (s1 * st2y - s2 * sty) + sy * (s1 * s3 - s2 * s2))
645        / det;
646
647    Ok(vec![a0, a1, a2])
648}
649
650/// Fit cubic trend (degree 3 polynomial) - simplified implementation
651fn fit_cubic_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
652    // For now, fallback to quadratic for simplicity
653    // A full cubic implementation would require solving a 4x4 system
654    fit_quadratic_trend(x)
655}
656
657/// Calculate standard deviation
658fn calculate_std(x: &Array1<Float>) -> Float {
659    let mean = x.mean().unwrap_or(0.0);
660    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<Float>() / (x.len() as Float - 1.0);
661    var.sqrt()
662}
663
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::essentials::Uniform;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::thread_rng;

    /// Fits a transformer configured with `method` on `series` and transforms
    /// that same series.
    fn fit_transform(method: StationarityMethod, series: &Array1<Float>) -> Result<Array1<Float>> {
        let fitted = StationarityTransformer::new()
            .with_method(method)
            .fit(series, &())?;
        fitted.transform(series)
    }

    /// Asserts equal length and element-wise agreement within `1e-10`.
    fn assert_series_close(actual: &Array1<Float>, expected: &Array1<Float>) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_first_difference() -> Result<()> {
        let series = Array1::from(vec![1.0, 3.0, 6.0, 10.0, 15.0]);

        let differenced = fit_transform(StationarityMethod::FirstDifference, &series)?;

        assert_series_close(&differenced, &Array1::from(vec![2.0, 3.0, 4.0, 5.0]));
        Ok(())
    }

    #[test]
    fn test_second_difference() -> Result<()> {
        // Perfect squares: the second difference of a quadratic is constant.
        let series = Array1::from(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let differenced = fit_transform(StationarityMethod::SecondDifference, &series)?;

        assert_series_close(&differenced, &Array1::from(vec![2.0, 2.0, 2.0]));
        Ok(())
    }

    #[test]
    fn test_seasonal_difference() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let differenced = fit_transform(StationarityMethod::SeasonalDifference(4), &series)?;

        // Period 4: 5-1, 6-2, 7-3, 8-4
        assert_series_close(&differenced, &Array1::from(vec![4.0, 4.0, 4.0, 4.0]));
        Ok(())
    }

    #[test]
    fn test_linear_detrend() -> Result<()> {
        // Linear trend with a little uniform noise on top.
        let mut series = Array1::zeros(10);
        let mut rng = thread_rng();
        for i in 0..10 {
            series[i] = 2.0 * (i as Float) + 5.0 + rng.sample(&Uniform::new(-0.1, 0.1).unwrap());
        }

        let detrended = fit_transform(StationarityMethod::LinearDetrend, &series)?;

        // Residuals of the fit should average out close to zero.
        let residual_mean = detrended.mean().unwrap_or(0.0);
        assert_abs_diff_eq!(residual_mean, 0.0, epsilon = 0.2);
        Ok(())
    }

    #[test]
    fn test_log_transform() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 4.0, 8.0, 16.0]);

        let logged = fit_transform(StationarityMethod::LogTransform, &series)?;

        // The log of an exponential sequence is linear.
        let expected = Array1::from(vec![
            0.0,
            2.0_f64.ln(),
            4.0_f64.ln(),
            8.0_f64.ln(),
            16.0_f64.ln(),
        ]);
        assert_series_close(&logged, &expected);
        Ok(())
    }

    #[test]
    fn test_box_cox_transform() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let lambda = 0.5; // Square root transformation

        let transformed = fit_transform(StationarityMethod::BoxCox(lambda), &series)?;

        // With lambda = 0.5: (x^0.5 - 1) / 0.5 = 2 * (sqrt(x) - 1)
        let expected: Array1<Float> = series.mapv(|x| 2.0 * (x.sqrt() - 1.0));
        assert_series_close(&transformed, &expected);
        Ok(())
    }

    #[test]
    fn test_moving_average_detrend() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let detrended = fit_transform(StationarityMethod::MovingAverageDetrend(3), &series)?;

        // Length is preserved.
        assert_eq!(detrended.len(), series.len());

        // The detrended series should sit reasonably close to zero on average.
        let residual_mean = detrended.mean().unwrap_or(0.0);
        assert!(residual_mean.abs() < 1.0);
        Ok(())
    }

    #[test]
    fn test_empty_series() -> Result<()> {
        let empty = Array1::zeros(0);

        // Fit on a valid series, then transform an empty one.
        let fitted = StationarityTransformer::new().fit(&Array1::from(vec![1.0, 2.0]), &())?;
        let transformed = fitted.transform(&empty)?;

        assert_eq!(transformed.len(), 0);
        Ok(())
    }

    #[test]
    fn test_short_series_error() {
        // A single observation is not enough to fit.
        let single_point = Array1::from(vec![1.0]);

        let outcome = StationarityTransformer::new().fit(&single_point, &());

        assert!(outcome.is_err());
    }
}