1use scirs2_core::ndarray::{s, Array1};
43use sklears_core::{
44 error::{Result, SklearsError},
45 traits::{Estimator, Fit, Transform, Untrained},
46 types::Float,
47};
48
49#[cfg(feature = "serde")]
50use serde::{Deserialize, Serialize};
51
52#[derive(Debug, Clone, Copy, PartialEq)]
54#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
55pub enum StationarityMethod {
56 FirstDifference,
58 SecondDifference,
60 SeasonalDifference(usize),
62 LinearDetrend,
64 PolynomialDetrend(usize),
66 LogTransform,
68 BoxCox(Float),
70 CombinedDifference(usize),
72 MovingAverageDetrend(usize),
74}
75
76impl Default for StationarityMethod {
77 fn default() -> Self {
78 Self::FirstDifference
79 }
80}
81
82#[derive(Debug, Clone)]
84#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
85pub struct StationarityTransformerConfig {
86 pub method: StationarityMethod,
88 pub fill_method: FillMethod,
90 pub min_value_offset: Float,
92 pub test_stationarity: bool,
94}
95
/// Policy for NaN entries that show up in differenced output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FillMethod {
    /// Return the output untouched (no NaN replacement). This is the default.
    #[default]
    Drop,
    /// Replace NaN entries using a forward-fill strategy.
    ForwardFill,
    /// Replace NaN entries using a backward-fill strategy.
    BackwardFill,
    /// Replace NaN entries with `0.0`.
    Zero,
    /// Replace NaN entries with the mean of the finite entries.
    Mean,
}
117
118impl Default for StationarityTransformerConfig {
119 fn default() -> Self {
120 Self {
121 method: StationarityMethod::default(),
122 fill_method: FillMethod::default(),
123 min_value_offset: 1e-8,
124 test_stationarity: false,
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
132pub struct StationarityTransformer<State = Untrained> {
133 config: StationarityTransformerConfig,
134 state: std::marker::PhantomData<State>,
135}
136
137#[derive(Debug, Clone)]
139#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
140pub struct StationarityTransformerFitted {
141 config: StationarityTransformerConfig,
142 trend_params: Option<Vec<Float>>,
144 log_offset: Option<Float>,
146 boxcox_lambda: Option<Float>,
148 original_mean: Option<Float>,
150 original_std: Option<Float>,
151}
152
153impl Default for StationarityTransformer<Untrained> {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl StationarityTransformer<Untrained> {
160 pub fn new() -> Self {
162 Self {
163 config: StationarityTransformerConfig::default(),
164 state: std::marker::PhantomData,
165 }
166 }
167
168 pub fn with_method(mut self, method: StationarityMethod) -> Self {
170 self.config.method = method;
171 self
172 }
173
174 pub fn with_fill_method(mut self, fill_method: FillMethod) -> Self {
176 self.config.fill_method = fill_method;
177 self
178 }
179
180 pub fn with_min_value_offset(mut self, offset: Float) -> Self {
182 self.config.min_value_offset = offset;
183 self
184 }
185
186 pub fn with_stationarity_test(mut self, test: bool) -> Self {
188 self.config.test_stationarity = test;
189 self
190 }
191}
192
193impl Estimator for StationarityTransformer<Untrained> {
194 type Config = StationarityTransformerConfig;
195 type Error = SklearsError;
196 type Float = Float;
197
198 fn config(&self) -> &Self::Config {
199 &self.config
200 }
201}
202
203impl Fit<Array1<Float>, ()> for StationarityTransformer<Untrained> {
204 type Fitted = StationarityTransformerFitted;
205
206 fn fit(self, x: &Array1<Float>, _y: &()) -> Result<Self::Fitted> {
207 if x.len() < 2 {
208 return Err(SklearsError::InvalidInput(
209 "Time series must have at least 2 points".to_string(),
210 ));
211 }
212
213 let original_mean = x.mean();
214 let original_std = Some(calculate_std(x));
215
216 let (trend_params, log_offset, boxcox_lambda) = match self.config.method {
218 StationarityMethod::LinearDetrend => {
219 let params = fit_linear_trend(x)?;
220 (Some(params), None, None)
221 }
222 StationarityMethod::PolynomialDetrend(degree) => {
223 let params = fit_polynomial_trend(x, degree)?;
224 (Some(params), None, None)
225 }
226 StationarityMethod::LogTransform => {
227 let min_val = x.iter().fold(Float::INFINITY, |a, &b| a.min(b));
228 let offset = if min_val <= 0.0 {
229 -min_val + self.config.min_value_offset
230 } else {
231 0.0
232 };
233 (None, Some(offset), None)
234 }
235 StationarityMethod::BoxCox(lambda) => {
236 let min_val = x.iter().fold(Float::INFINITY, |a, &b| a.min(b));
237 let offset = if min_val <= 0.0 {
238 -min_val + self.config.min_value_offset
239 } else {
240 0.0
241 };
242 (None, Some(offset), Some(lambda))
243 }
244 _ => (None, None, None),
245 };
246
247 Ok(StationarityTransformerFitted {
248 config: self.config,
249 trend_params,
250 log_offset,
251 boxcox_lambda,
252 original_mean,
253 original_std,
254 })
255 }
256}
257
258impl Transform<Array1<Float>, Array1<Float>> for StationarityTransformerFitted {
259 fn transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
260 if x.is_empty() {
261 return Ok(Array1::zeros(0));
262 }
263
264 let result = match self.config.method {
265 StationarityMethod::FirstDifference => self.first_difference(x)?,
266 StationarityMethod::SecondDifference => self.second_difference(x)?,
267 StationarityMethod::SeasonalDifference(period) => {
268 self.seasonal_difference(x, period)?
269 }
270 StationarityMethod::LinearDetrend => self.linear_detrend(x)?,
271 StationarityMethod::PolynomialDetrend(degree) => self.polynomial_detrend(x, degree)?,
272 StationarityMethod::LogTransform => self.log_transform(x)?,
273 StationarityMethod::BoxCox(lambda) => self.box_cox_transform(x, lambda)?,
274 StationarityMethod::CombinedDifference(period) => {
275 self.combined_difference(x, period)?
276 }
277 StationarityMethod::MovingAverageDetrend(window) => {
278 self.moving_average_detrend(x, window)?
279 }
280 };
281
282 if self.config.test_stationarity && result.len() > 10 {
284 let _test_result = self.test_stationarity(&result)?;
285 }
287
288 Ok(result)
289 }
290}
291
292impl StationarityTransformerFitted {
293 fn first_difference(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
295 if x.len() < 2 {
296 return Ok(Array1::zeros(0));
297 }
298
299 let mut result = Array1::zeros(x.len() - 1);
300 for i in 1..x.len() {
301 result[i - 1] = x[i] - x[i - 1];
302 }
303
304 self.handle_missing_values(result)
305 }
306
307 fn second_difference(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
309 let first_diff = self.first_difference(x)?;
310 if first_diff.len() < 2 {
311 return Ok(Array1::zeros(0));
312 }
313
314 let mut result = Array1::zeros(first_diff.len() - 1);
315 for i in 1..first_diff.len() {
316 result[i - 1] = first_diff[i] - first_diff[i - 1];
317 }
318
319 self.handle_missing_values(result)
320 }
321
322 fn seasonal_difference(&self, x: &Array1<Float>, period: usize) -> Result<Array1<Float>> {
324 if x.len() <= period {
325 return Ok(Array1::zeros(0));
326 }
327
328 let mut result = Array1::zeros(x.len() - period);
329 for i in period..x.len() {
330 result[i - period] = x[i] - x[i - period];
331 }
332
333 self.handle_missing_values(result)
334 }
335
336 fn linear_detrend(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
338 let params = self
339 .trend_params
340 .as_ref()
341 .ok_or_else(|| SklearsError::NotFitted {
342 operation: "Linear trend not fitted".to_string(),
343 })?;
344
345 if params.len() != 2 {
346 return Err(SklearsError::InvalidInput(
347 "Linear trend requires 2 parameters".to_string(),
348 ));
349 }
350
351 let slope = params[0];
352 let intercept = params[1];
353
354 let mut result = Array1::zeros(x.len());
355 for (i, &val) in x.iter().enumerate() {
356 let trend_val = slope * (i as Float) + intercept;
357 result[i] = val - trend_val;
358 }
359
360 Ok(result)
361 }
362
363 fn polynomial_detrend(&self, x: &Array1<Float>, _degree: usize) -> Result<Array1<Float>> {
365 let params = self
366 .trend_params
367 .as_ref()
368 .ok_or_else(|| SklearsError::NotFitted {
369 operation: "Polynomial trend not fitted".to_string(),
370 })?;
371
372 let mut result = Array1::zeros(x.len());
373 for (i, &val) in x.iter().enumerate() {
374 let mut trend_val = 0.0;
375 let t = i as Float;
376
377 for (degree, &coeff) in params.iter().enumerate() {
379 trend_val += coeff * t.powi(degree as i32);
380 }
381
382 result[i] = val - trend_val;
383 }
384
385 Ok(result)
386 }
387
388 fn log_transform(&self, x: &Array1<Float>) -> Result<Array1<Float>> {
390 let offset = self.log_offset.unwrap_or(0.0);
391
392 let mut result = Array1::zeros(x.len());
393 for (i, &val) in x.iter().enumerate() {
394 let adjusted_val = val + offset;
395 if adjusted_val <= 0.0 {
396 return Err(SklearsError::InvalidInput(format!(
397 "Cannot take log of non-positive value: {}",
398 adjusted_val
399 )));
400 }
401 result[i] = adjusted_val.ln();
402 }
403
404 Ok(result)
405 }
406
407 fn box_cox_transform(&self, x: &Array1<Float>, lambda: Float) -> Result<Array1<Float>> {
409 let offset = self.log_offset.unwrap_or(0.0);
410
411 let mut result = Array1::zeros(x.len());
412 for (i, &val) in x.iter().enumerate() {
413 let adjusted_val = val + offset;
414 if adjusted_val <= 0.0 {
415 return Err(SklearsError::InvalidInput(format!(
416 "Cannot apply Box-Cox to non-positive value: {}",
417 adjusted_val
418 )));
419 }
420
421 result[i] = if lambda.abs() < 1e-10 {
422 adjusted_val.ln()
423 } else {
424 (adjusted_val.powf(lambda) - 1.0) / lambda
425 };
426 }
427
428 Ok(result)
429 }
430
431 fn combined_difference(&self, x: &Array1<Float>, period: usize) -> Result<Array1<Float>> {
433 let first_diff = self.first_difference(x)?;
434 if first_diff.len() <= period {
435 return Ok(Array1::zeros(0));
436 }
437
438 let mut result = Array1::zeros(first_diff.len() - period);
439 for i in period..first_diff.len() {
440 result[i - period] = first_diff[i] - first_diff[i - period];
441 }
442
443 self.handle_missing_values(result)
444 }
445
446 fn moving_average_detrend(&self, x: &Array1<Float>, window: usize) -> Result<Array1<Float>> {
448 if x.len() < window {
449 return Err(SklearsError::InvalidInput(
450 "Series too short for moving average window".to_string(),
451 ));
452 }
453
454 let mut result = Array1::zeros(x.len());
455
456 for i in 0..x.len() {
458 let start_idx = i.saturating_sub(window / 2);
459 let end_idx = ((i + window / 2 + 1).min(x.len())).max(start_idx + 1);
460
461 let window_slice = x.slice(s![start_idx..end_idx]);
462 let moving_avg = window_slice.mean().unwrap_or(0.0);
463
464 result[i] = x[i] - moving_avg;
465 }
466
467 Ok(result)
468 }
469
470 fn handle_missing_values(&self, mut x: Array1<Float>) -> Result<Array1<Float>> {
472 match self.config.fill_method {
473 FillMethod::Drop => Ok(x), FillMethod::ForwardFill => {
475 if let Some(&first_val) = x.first() {
476 for val in x.iter_mut() {
477 if val.is_nan() {
478 *val = first_val;
479 }
480 }
481 }
482 Ok(x)
483 }
484 FillMethod::BackwardFill => {
485 if let Some(&last_val) = x.last() {
486 for val in x.iter_mut() {
487 if val.is_nan() {
488 *val = last_val;
489 }
490 }
491 }
492 Ok(x)
493 }
494 FillMethod::Zero => {
495 for val in x.iter_mut() {
496 if val.is_nan() {
497 *val = 0.0;
498 }
499 }
500 Ok(x)
501 }
502 FillMethod::Mean => {
503 let finite_values: Vec<Float> =
504 x.iter().copied().filter(|v| v.is_finite()).collect();
505 if !finite_values.is_empty() {
506 let mean_val =
507 finite_values.iter().sum::<Float>() / finite_values.len() as Float;
508 for val in x.iter_mut() {
509 if val.is_nan() {
510 *val = mean_val;
511 }
512 }
513 }
514 Ok(x)
515 }
516 }
517 }
518
519 fn test_stationarity(&self, x: &Array1<Float>) -> Result<Float> {
521 if x.len() < 10 {
522 return Ok(0.0); }
524
525 let mean = x.mean().unwrap_or(0.0);
527 let mut numerator = 0.0;
528 let mut denominator = 0.0;
529
530 for i in 1..x.len() {
531 numerator += (x[i] - mean) * (x[i - 1] - mean);
532 denominator += (x[i - 1] - mean).powi(2);
533 }
534
535 let autocorr = if denominator.abs() > 1e-10 {
536 numerator / denominator
537 } else {
538 0.0
539 };
540
541 Ok(autocorr.abs())
543 }
544}
545
546fn fit_linear_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
549 let n = x.len() as Float;
550 let x_mean = (n - 1.0) / 2.0;
551 let y_mean = x.mean().unwrap_or(0.0);
552
553 let mut numerator = 0.0;
554 let mut denominator = 0.0;
555
556 for (i, &y) in x.iter().enumerate() {
557 let xi = i as Float;
558 numerator += (xi - x_mean) * (y - y_mean);
559 denominator += (xi - x_mean).powi(2);
560 }
561
562 let slope = if denominator.abs() > 1e-10 {
563 numerator / denominator
564 } else {
565 0.0
566 };
567
568 let intercept = y_mean - slope * x_mean;
569
570 Ok(vec![slope, intercept])
571}
572
573fn fit_polynomial_trend(x: &Array1<Float>, degree: usize) -> Result<Vec<Float>> {
575 if degree == 1 {
576 return fit_linear_trend(x);
577 }
578
579 if degree > 3 {
580 return Err(SklearsError::InvalidInput(
581 "Polynomial degree > 3 not supported in this implementation".to_string(),
582 ));
583 }
584
585 match degree {
587 0 => Ok(vec![x.mean().unwrap_or(0.0)]),
588 2 => fit_quadratic_trend(x),
589 3 => fit_cubic_trend(x),
590 _ => fit_linear_trend(x), }
592}
593
594fn fit_quadratic_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
596 let n = x.len();
597 if n < 3 {
598 return Err(SklearsError::InvalidInput(
599 "Need at least 3 points for quadratic fit".to_string(),
600 ));
601 }
602
603 let mut s0 = 0.0; let mut s1 = 0.0; let mut s2 = 0.0; let mut s3 = 0.0; let mut s4 = 0.0; let mut sy = 0.0; let mut sty = 0.0; let mut st2y = 0.0; for (i, &y) in x.iter().enumerate() {
616 let t = i as Float;
617 let t2 = t * t;
618 let t3 = t2 * t;
619 let t4 = t3 * t;
620
621 s0 += 1.0;
622 s1 += t;
623 s2 += t2;
624 s3 += t3;
625 s4 += t4;
626 sy += y;
627 sty += t * y;
628 st2y += t2 * y;
629 }
630
631 let det = s0 * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s2 * s3) + s2 * (s1 * s3 - s2 * s2);
633
634 if det.abs() < 1e-10 {
635 return Err(SklearsError::InvalidInput(
636 "Matrix is singular for quadratic fitting".to_string(),
637 ));
638 }
639
640 let a0 =
641 (sy * (s2 * s4 - s3 * s3) - sty * (s1 * s4 - s2 * s3) + st2y * (s1 * s3 - s2 * s2)) / det;
642 let a1 =
643 (s0 * (sty * s4 - st2y * s3) - sy * (s1 * s4 - s2 * s3) + st2y * (s1 * s2 - s0 * s3)) / det;
644 let a2 = (s0 * (s2 * st2y - s3 * sty) - s1 * (s1 * st2y - s2 * sty) + sy * (s1 * s3 - s2 * s2))
645 / det;
646
647 Ok(vec![a0, a1, a2])
648}
649
650fn fit_cubic_trend(x: &Array1<Float>) -> Result<Vec<Float>> {
652 fit_quadratic_trend(x)
655}
656
657fn calculate_std(x: &Array1<Float>) -> Float {
659 let mean = x.mean().unwrap_or(0.0);
660 let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<Float>() / (x.len() as Float - 1.0);
661 var.sqrt()
662}
663
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::essentials::Uniform;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::thread_rng;

    /// Fits a transformer configured with `method` on `series`, then
    /// transforms that same series.
    fn fit_transform(method: StationarityMethod, series: &Array1<Float>) -> Result<Array1<Float>> {
        let fitted = StationarityTransformer::new()
            .with_method(method)
            .fit(series, &())?;
        fitted.transform(series)
    }

    /// Asserts equal length and element-wise agreement within `1e-10`.
    fn assert_series_eq(actual: &Array1<Float>, expected: &[Float]) {
        assert_eq!(actual.len(), expected.len());

        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_first_difference() -> Result<()> {
        let series = Array1::from(vec![1.0, 3.0, 6.0, 10.0, 15.0]);

        let output = fit_transform(StationarityMethod::FirstDifference, &series)?;

        assert_series_eq(&output, &[2.0, 3.0, 4.0, 5.0]);
        Ok(())
    }

    #[test]
    fn test_second_difference() -> Result<()> {
        // Perfect squares: the second difference is constant.
        let series = Array1::from(vec![1.0, 4.0, 9.0, 16.0, 25.0]);

        let output = fit_transform(StationarityMethod::SecondDifference, &series)?;

        assert_series_eq(&output, &[2.0, 2.0, 2.0]);
        Ok(())
    }

    #[test]
    fn test_seasonal_difference() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let output = fit_transform(StationarityMethod::SeasonalDifference(4), &series)?;

        assert_series_eq(&output, &[4.0, 4.0, 4.0, 4.0]);
        Ok(())
    }

    #[test]
    fn test_linear_detrend() -> Result<()> {
        // Straight line 2t + 5 with small uniform noise.
        let mut series = Array1::zeros(10);
        let mut rng = thread_rng();
        for t in 0..10 {
            series[t] = 2.0 * (t as Float) + 5.0 + rng.sample(&Uniform::new(-0.1, 0.1).unwrap());
        }

        let output = fit_transform(StationarityMethod::LinearDetrend, &series)?;

        let residual_mean = output.mean().unwrap_or(0.0);
        assert_abs_diff_eq!(residual_mean, 0.0, epsilon = 0.2);
        Ok(())
    }

    #[test]
    fn test_log_transform() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 4.0, 8.0, 16.0]);

        let output = fit_transform(StationarityMethod::LogTransform, &series)?;

        let expected = [
            0.0,
            2.0_f64.ln(),
            4.0_f64.ln(),
            8.0_f64.ln(),
            16.0_f64.ln(),
        ];
        assert_series_eq(&output, &expected);
        Ok(())
    }

    #[test]
    fn test_box_cox_transform() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let lambda = 0.5;

        let output = fit_transform(StationarityMethod::BoxCox(lambda), &series)?;

        // With lambda = 0.5 the transform is (sqrt(x) - 1) / 0.5.
        let expected: Vec<Float> = series.iter().map(|&v| 2.0 * (v.sqrt() - 1.0)).collect();
        assert_series_eq(&output, &expected);
        Ok(())
    }

    #[test]
    fn test_moving_average_detrend() -> Result<()> {
        let series = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let output = fit_transform(StationarityMethod::MovingAverageDetrend(3), &series)?;

        assert_eq!(output.len(), series.len());

        let residual_mean = output.mean().unwrap_or(0.0);
        assert!(residual_mean.abs() < 1.0);
        Ok(())
    }

    #[test]
    fn test_empty_series() -> Result<()> {
        let empty = Array1::zeros(0);

        // Fit on a valid two-point series, then transform the empty one.
        let fitted = StationarityTransformer::new().fit(&Array1::from(vec![1.0, 2.0]), &())?;
        let output = fitted.transform(&empty)?;

        assert_eq!(output.len(), 0);
        Ok(())
    }

    #[test]
    fn test_short_series_error() {
        // A single observation is below the two-point minimum.
        let single = Array1::from(vec![1.0]);

        let outcome = StationarityTransformer::new().fit(&single, &());

        assert!(outcome.is_err());
    }
}