rustkernel_temporal/
decomposition.rs

//! Time series decomposition kernels.
//!
//! This module provides two decomposition kernels:
//! - Seasonal decomposition (STL-like)
//! - Trend extraction (various moving average methods)
6
7use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::messages::{
12    SeasonalDecompositionInput, SeasonalDecompositionOutput, TrendExtractionInput,
13    TrendExtractionOutput,
14};
15use crate::types::{DecompositionResult, TimeSeries, TrendMethod, TrendResult};
16use rustkernel_core::{
17    domain::Domain,
18    error::Result,
19    kernel::KernelMetadata,
20    traits::{BatchKernel, GpuKernel},
21};
22
// ============================================================================
// Seasonal Decomposition Kernel
// ============================================================================
26
27/// Seasonal decomposition kernel (STL-like).
28///
29/// Decomposes a time series into trend, seasonal, and residual components.
30#[derive(Debug, Clone)]
31pub struct SeasonalDecomposition {
32    metadata: KernelMetadata,
33}
34
35impl Default for SeasonalDecomposition {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl SeasonalDecomposition {
42    /// Create a new seasonal decomposition kernel.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            metadata: KernelMetadata::batch(
47                "temporal/seasonal-decomposition",
48                Domain::TemporalAnalysis,
49            )
50            .with_description("STL-style seasonal decomposition")
51            .with_throughput(10_000)
52            .with_latency_us(100.0),
53        }
54    }
55
56    /// Decompose a time series.
57    ///
58    /// # Arguments
59    /// * `series` - Input time series
60    /// * `period` - Seasonal period
61    /// * `robust` - Use robust (median-based) estimation
62    pub fn compute(series: &TimeSeries, period: usize, robust: bool) -> DecompositionResult {
63        let n = series.len();
64
65        if n < 2 * period || period < 2 {
66            return DecompositionResult {
67                trend: series.values.clone(),
68                seasonal: vec![0.0; n],
69                residual: vec![0.0; n],
70                n,
71                period,
72            };
73        }
74
75        // Step 1: Initial trend estimation using centered moving average
76        let trend = Self::centered_moving_average(&series.values, period);
77
78        // Step 2: Detrend the series
79        let detrended: Vec<f64> = series
80            .values
81            .iter()
82            .zip(trend.iter())
83            .map(|(v, t)| v - t)
84            .collect();
85
86        // Step 3: Estimate seasonal component
87        let seasonal_pattern = if robust {
88            Self::robust_seasonal(&detrended, period)
89        } else {
90            Self::mean_seasonal(&detrended, period)
91        };
92
93        // Extend seasonal pattern to full length
94        let seasonal: Vec<f64> = (0..n).map(|i| seasonal_pattern[i % period]).collect();
95
96        // Step 4: Refine trend by removing seasonality first
97        let deseasoned: Vec<f64> = series
98            .values
99            .iter()
100            .zip(seasonal.iter())
101            .map(|(v, s)| v - s)
102            .collect();
103
104        let refined_trend = Self::lowess_trend(&deseasoned, period);
105
106        // Step 5: Calculate residuals
107        let residual: Vec<f64> = series
108            .values
109            .iter()
110            .zip(refined_trend.iter())
111            .zip(seasonal.iter())
112            .map(|((v, t), s)| v - t - s)
113            .collect();
114
115        DecompositionResult {
116            trend: refined_trend,
117            seasonal,
118            residual,
119            n,
120            period,
121        }
122    }
123
124    /// Additive decomposition (simple version).
125    pub fn compute_additive(series: &TimeSeries, period: usize) -> DecompositionResult {
126        Self::compute(series, period, false)
127    }
128
129    /// Multiplicative decomposition.
130    ///
131    /// Y = T * S * R
132    pub fn compute_multiplicative(series: &TimeSeries, period: usize) -> DecompositionResult {
133        let n = series.len();
134
135        if n < 2 * period || period < 2 {
136            return DecompositionResult {
137                trend: series.values.clone(),
138                seasonal: vec![1.0; n],
139                residual: vec![1.0; n],
140                n,
141                period,
142            };
143        }
144
145        // Convert to log space if all values positive
146        let min_val = series.values.iter().cloned().fold(f64::INFINITY, f64::min);
147
148        if min_val <= 0.0 {
149            // Fall back to additive for non-positive data
150            return Self::compute(series, period, false);
151        }
152
153        let log_values: Vec<f64> = series.values.iter().map(|v| v.ln()).collect();
154        let log_series = TimeSeries::new(log_values);
155
156        let log_result = Self::compute(&log_series, period, false);
157
158        // Convert back from log space
159        DecompositionResult {
160            trend: log_result.trend.iter().map(|t| t.exp()).collect(),
161            seasonal: log_result.seasonal.iter().map(|s| s.exp()).collect(),
162            residual: log_result.residual.iter().map(|r| r.exp()).collect(),
163            n,
164            period,
165        }
166    }
167
168    /// Centered moving average.
169    fn centered_moving_average(values: &[f64], window: usize) -> Vec<f64> {
170        let n = values.len();
171        let half_w = window / 2;
172        let mut result = vec![0.0; n];
173
174        for i in 0..n {
175            let start = i.saturating_sub(half_w);
176            let end = (i + half_w + 1).min(n);
177
178            // For even windows, use weighted average at boundaries
179            if window % 2 == 0 && i >= half_w && i + half_w < n {
180                let mut sum = 0.0;
181                let mut weight = 0.0;
182
183                for j in start..end {
184                    let w = if j == start || j == end - 1 { 0.5 } else { 1.0 };
185                    sum += values[j] * w;
186                    weight += w;
187                }
188                result[i] = sum / weight;
189            } else {
190                result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
191            }
192        }
193
194        result
195    }
196
197    /// Mean-based seasonal estimation.
198    fn mean_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
199        let mut seasonal = vec![0.0; period];
200        let mut counts = vec![0usize; period];
201
202        for (i, &d) in detrended.iter().enumerate() {
203            seasonal[i % period] += d;
204            counts[i % period] += 1;
205        }
206
207        for i in 0..period {
208            if counts[i] > 0 {
209                seasonal[i] /= counts[i] as f64;
210            }
211        }
212
213        // Center the seasonal component
214        let mean: f64 = seasonal.iter().sum::<f64>() / period as f64;
215        for s in &mut seasonal {
216            *s -= mean;
217        }
218
219        seasonal
220    }
221
222    /// Robust (median-based) seasonal estimation.
223    fn robust_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
224        let mut seasonal = vec![0.0; period];
225
226        for s in 0..period {
227            let mut season_values: Vec<f64> = detrended
228                .iter()
229                .enumerate()
230                .filter(|(i, _)| i % period == s)
231                .map(|(_, &v)| v)
232                .collect();
233
234            if !season_values.is_empty() {
235                season_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
236                seasonal[s] = season_values[season_values.len() / 2];
237            }
238        }
239
240        // Center using median
241        let mut sorted = seasonal.clone();
242        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
243        let median = sorted[period / 2];
244
245        for s in &mut seasonal {
246            *s -= median;
247        }
248
249        seasonal
250    }
251
252    /// Lowess-style trend extraction (simplified).
253    fn lowess_trend(values: &[f64], bandwidth: usize) -> Vec<f64> {
254        let n = values.len();
255        let mut trend = vec![0.0; n];
256
257        for i in 0..n {
258            let start = i.saturating_sub(bandwidth);
259            let end = (i + bandwidth + 1).min(n);
260
261            // Tricube weights
262            let mut weighted_sum = 0.0;
263            let mut weight_sum = 0.0;
264
265            for j in start..end {
266                let dist = (j as f64 - i as f64).abs() / bandwidth as f64;
267                let weight = if dist < 1.0 {
268                    (1.0 - dist.powi(3)).powi(3)
269                } else {
270                    0.0
271                };
272                weighted_sum += values[j] * weight;
273                weight_sum += weight;
274            }
275
276            trend[i] = if weight_sum > 0.0 {
277                weighted_sum / weight_sum
278            } else {
279                values[i]
280            };
281        }
282
283        trend
284    }
285}
286
287impl GpuKernel for SeasonalDecomposition {
288    fn metadata(&self) -> &KernelMetadata {
289        &self.metadata
290    }
291}
292
293#[async_trait]
294impl BatchKernel<SeasonalDecompositionInput, SeasonalDecompositionOutput>
295    for SeasonalDecomposition
296{
297    async fn execute(
298        &self,
299        input: SeasonalDecompositionInput,
300    ) -> Result<SeasonalDecompositionOutput> {
301        let start = Instant::now();
302        let result = Self::compute(&input.series, input.period, input.robust);
303        Ok(SeasonalDecompositionOutput {
304            result,
305            compute_time_us: start.elapsed().as_micros() as u64,
306        })
307    }
308}
309
// ============================================================================
// Trend Extraction Kernel
// ============================================================================
313
314/// Trend extraction kernel.
315///
316/// Extracts trend component using various moving average methods.
317#[derive(Debug, Clone)]
318pub struct TrendExtraction {
319    metadata: KernelMetadata,
320}
321
322impl Default for TrendExtraction {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328impl TrendExtraction {
329    /// Create a new trend extraction kernel.
330    #[must_use]
331    pub fn new() -> Self {
332        Self {
333            metadata: KernelMetadata::batch("temporal/trend-extraction", Domain::TemporalAnalysis)
334                .with_description("Moving average trend extraction")
335                .with_throughput(50_000)
336                .with_latency_us(20.0),
337        }
338    }
339
340    /// Extract trend from a time series.
341    ///
342    /// # Arguments
343    /// * `series` - Input time series
344    /// * `method` - Trend extraction method
345    /// * `window` - Window size for moving average
346    pub fn compute(series: &TimeSeries, method: TrendMethod, window: usize) -> TrendResult {
347        if series.is_empty() {
348            return TrendResult {
349                trend: Vec::new(),
350                detrended: Vec::new(),
351                method,
352            };
353        }
354
355        let trend = match method {
356            TrendMethod::SimpleMovingAverage => Self::simple_ma(&series.values, window),
357            TrendMethod::ExponentialMovingAverage => Self::exponential_ma(&series.values, window),
358            TrendMethod::CenteredMovingAverage => Self::centered_ma(&series.values, window),
359            TrendMethod::Lowess => Self::lowess(&series.values, window),
360        };
361
362        let detrended: Vec<f64> = series
363            .values
364            .iter()
365            .zip(trend.iter())
366            .map(|(v, t)| v - t)
367            .collect();
368
369        TrendResult {
370            trend,
371            detrended,
372            method,
373        }
374    }
375
376    /// Simple moving average.
377    fn simple_ma(values: &[f64], window: usize) -> Vec<f64> {
378        let n = values.len();
379        let w = window.min(n).max(1);
380        let mut result = vec![0.0; n];
381
382        // Cumulative sum for efficient computation
383        let mut cumsum = vec![0.0; n + 1];
384        for (i, &v) in values.iter().enumerate() {
385            cumsum[i + 1] = cumsum[i] + v;
386        }
387
388        for i in 0..n {
389            let start = i.saturating_sub(w - 1);
390            let count = i - start + 1;
391            result[i] = (cumsum[i + 1] - cumsum[start]) / count as f64;
392        }
393
394        result
395    }
396
397    /// Exponential moving average.
398    fn exponential_ma(values: &[f64], span: usize) -> Vec<f64> {
399        let n = values.len();
400        if n == 0 {
401            return Vec::new();
402        }
403
404        let alpha = 2.0 / (span as f64 + 1.0);
405        let mut result = vec![0.0; n];
406        result[0] = values[0];
407
408        for i in 1..n {
409            result[i] = alpha * values[i] + (1.0 - alpha) * result[i - 1];
410        }
411
412        result
413    }
414
415    /// Centered moving average.
416    fn centered_ma(values: &[f64], window: usize) -> Vec<f64> {
417        let n = values.len();
418        let half_w = window / 2;
419        let mut result = vec![0.0; n];
420
421        for i in 0..n {
422            let start = i.saturating_sub(half_w);
423            let end = (i + half_w + 1).min(n);
424            result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
425        }
426
427        result
428    }
429
430    /// Lowess (Locally Weighted Scatterplot Smoothing).
431    fn lowess(values: &[f64], bandwidth: usize) -> Vec<f64> {
432        let n = values.len();
433        let mut result = vec![0.0; n];
434
435        for i in 0..n {
436            let start = i.saturating_sub(bandwidth);
437            let end = (i + bandwidth + 1).min(n);
438
439            // Fit local linear regression with tricube weights
440            let mut sum_w = 0.0;
441            let mut sum_wx = 0.0;
442            let mut sum_wy = 0.0;
443            let mut sum_wxx = 0.0;
444            let mut sum_wxy = 0.0;
445
446            for j in start..end {
447                let x = j as f64;
448                let y = values[j];
449                let dist = (j as f64 - i as f64).abs() / (bandwidth as f64 + 1.0);
450                let w = if dist < 1.0 {
451                    (1.0 - dist.powi(3)).powi(3)
452                } else {
453                    0.0
454                };
455
456                sum_w += w;
457                sum_wx += w * x;
458                sum_wy += w * y;
459                sum_wxx += w * x * x;
460                sum_wxy += w * x * y;
461            }
462
463            // Solve for local linear fit at point i
464            let det = sum_w * sum_wxx - sum_wx * sum_wx;
465            if det.abs() > 1e-10 {
466                let b0 = (sum_wxx * sum_wy - sum_wx * sum_wxy) / det;
467                let b1 = (sum_w * sum_wxy - sum_wx * sum_wy) / det;
468                result[i] = b0 + b1 * i as f64;
469            } else {
470                result[i] = if sum_w > 0.0 {
471                    sum_wy / sum_w
472                } else {
473                    values[i]
474                };
475            }
476        }
477
478        result
479    }
480
481    /// Double exponential smoothing (Holt's method).
482    pub fn holt_smoothing(values: &[f64], alpha: f64, beta: f64) -> (Vec<f64>, Vec<f64>) {
483        let n = values.len();
484        if n < 2 {
485            return (values.to_vec(), vec![0.0; n]);
486        }
487
488        let mut level = vec![0.0; n];
489        let mut trend = vec![0.0; n];
490
491        // Initialize
492        level[0] = values[0];
493        trend[0] = values[1] - values[0];
494
495        for i in 1..n {
496            level[i] = alpha * values[i] + (1.0 - alpha) * (level[i - 1] + trend[i - 1]);
497            trend[i] = beta * (level[i] - level[i - 1]) + (1.0 - beta) * trend[i - 1];
498        }
499
500        (level, trend)
501    }
502}
503
504impl GpuKernel for TrendExtraction {
505    fn metadata(&self) -> &KernelMetadata {
506        &self.metadata
507    }
508}
509
510#[async_trait]
511impl BatchKernel<TrendExtractionInput, TrendExtractionOutput> for TrendExtraction {
512    async fn execute(&self, input: TrendExtractionInput) -> Result<TrendExtractionOutput> {
513        let start = Instant::now();
514        let result = Self::compute(&input.series, input.method, input.window);
515        Ok(TrendExtractionOutput {
516            result,
517            compute_time_us: start.elapsed().as_micros() as u64,
518        })
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Seasonal period shared by the synthetic fixtures.
    const PERIOD: usize = 12;

    /// Linear trend plus a sinusoidal seasonal pattern of period `PERIOD`.
    fn create_seasonal_series() -> TimeSeries {
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal =
                    10.0 * ((2.0 * std::f64::consts::PI * t as f64 / PERIOD as f64).sin());
                trend + seasonal
            })
            .collect();
        TimeSeries::new(values)
    }

    /// Steep linear trend with a small smooth wiggle on top.
    fn create_trend_series() -> TimeSeries {
        TimeSeries::new(
            (0..100)
                .map(|t| 10.0 + 2.0 * t as f64 + (t as f64 * 0.3).sin())
                .collect(),
        )
    }

    /// All three components must have the length of the input series.
    fn assert_component_lengths(series: &TimeSeries, result: &DecompositionResult) {
        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());
    }

    /// The seasonal component must repeat with the requested period.
    fn assert_seasonal_periodic(result: &DecompositionResult, period: usize) {
        for i in 0..result.seasonal.len() - period {
            let diff = (result.seasonal[i] - result.seasonal[i + period]).abs();
            assert!(diff < 0.01, "Seasonal not periodic at {}: diff={}", i, diff);
        }
    }

    /// trend + seasonal + residual must add back up to the input.
    fn assert_additive_identity(series: &TimeSeries, result: &DecompositionResult) {
        for (i, value) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] + result.seasonal[i] + result.residual[i];
            assert!(
                (rebuilt - value).abs() < 1e-9,
                "Additive reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                value
            );
        }
    }

    /// trend + detrended must add back up to the input.
    fn assert_trend_identity(series: &TimeSeries, result: &TrendResult) {
        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.detrended.len(), series.len());
        for (i, value) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] + result.detrended[i];
            assert!(
                (rebuilt - value).abs() < 1e-9,
                "Trend reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                value
            );
        }
    }

    #[test]
    fn test_decomposition_metadata() {
        let kernel = SeasonalDecomposition::new();
        assert_eq!(kernel.metadata().id, "temporal/seasonal-decomposition");
        assert_eq!(kernel.metadata().domain, Domain::TemporalAnalysis);
    }

    #[test]
    fn test_seasonal_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, PERIOD, false);

        assert_component_lengths(&series, &result);
        assert_eq!(result.period, PERIOD);
        assert_eq!(result.n, series.len());
        assert_seasonal_periodic(&result, PERIOD);
        assert_additive_identity(&series, &result);
    }

    #[test]
    fn test_robust_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, PERIOD, true);

        // The robust variant must satisfy the same structural guarantees.
        assert_component_lengths(&series, &result);
        assert_eq!(result.period, PERIOD);
        assert_seasonal_periodic(&result, PERIOD);
        assert_additive_identity(&series, &result);
    }

    #[test]
    fn test_additive_matches_non_robust_compute() {
        let series = create_seasonal_series();
        let additive = SeasonalDecomposition::compute_additive(&series, PERIOD);
        let plain = SeasonalDecomposition::compute(&series, PERIOD, false);

        assert_eq!(additive.trend, plain.trend);
        assert_eq!(additive.seasonal, plain.seasonal);
        assert_eq!(additive.residual, plain.residual);
    }

    #[test]
    fn test_multiplicative_decomposition() {
        // Create multiplicative seasonal pattern
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal = 1.0 + 0.1 * ((2.0 * std::f64::consts::PI * t as f64 / 12.0).sin());
                trend * seasonal
            })
            .collect();
        let series = TimeSeries::new(values);

        let result = SeasonalDecomposition::compute_multiplicative(&series, PERIOD);

        assert_component_lengths(&series, &result);

        // Seasonal and residual components are multiplicative factors: positive,
        // and the product of all three components rebuilds the input.
        for (i, value) in series.values.iter().enumerate() {
            assert!(result.seasonal[i] > 0.0, "Non-positive seasonal factor at {}", i);
            assert!(result.residual[i] > 0.0, "Non-positive residual factor at {}", i);
            let rebuilt = result.trend[i] * result.seasonal[i] * result.residual[i];
            assert!(
                (rebuilt - value).abs() < 1e-6,
                "Multiplicative reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                value
            );
        }
    }

    #[test]
    fn test_short_series_falls_back() {
        // Fewer than two full periods: the input is returned as the trend.
        let series = TimeSeries::new((0..10).map(|t| t as f64).collect());
        let result = SeasonalDecomposition::compute(&series, PERIOD, false);

        assert_eq!(result.trend, series.values);
        assert!(result.seasonal.iter().all(|s| *s == 0.0));
        assert!(result.residual.iter().all(|r| *r == 0.0));

        let multiplicative = SeasonalDecomposition::compute_multiplicative(&series, PERIOD);
        assert_eq!(multiplicative.trend, series.values);
        assert!(multiplicative.seasonal.iter().all(|s| *s == 1.0));
        assert!(multiplicative.residual.iter().all(|r| *r == 1.0));
    }

    #[test]
    fn test_trend_extraction_metadata() {
        let kernel = TrendExtraction::new();
        assert_eq!(kernel.metadata().id, "temporal/trend-extraction");
    }

    #[test]
    fn test_simple_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::SimpleMovingAverage, 5);

        assert_trend_identity(&series, &result);
        assert_eq!(result.method, TrendMethod::SimpleMovingAverage);

        // Trend should be smoother than original
        let original_var: f64 = series.values.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        let trend_var: f64 = result.trend.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        assert!(trend_var <= original_var);
    }

    #[test]
    fn test_exponential_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::ExponentialMovingAverage, 10);

        assert_trend_identity(&series, &result);
        assert_eq!(result.method, TrendMethod::ExponentialMovingAverage);

        // The recursion is seeded with the first observation.
        assert_eq!(result.trend[0], series.values[0]);
    }

    #[test]
    fn test_centered_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 7);

        assert_trend_identity(&series, &result);
        assert_eq!(result.method, TrendMethod::CenteredMovingAverage);
    }

    #[test]
    fn test_lowess_trend() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::Lowess, 10);

        assert_trend_identity(&series, &result);
        assert_eq!(result.method, TrendMethod::Lowess);
    }

    #[test]
    fn test_holt_smoothing() {
        let values: Vec<f64> = (0..50).map(|t| 10.0 + 2.0 * t as f64).collect();
        let (level, trend) = TrendExtraction::holt_smoothing(&values, 0.3, 0.1);

        assert_eq!(level.len(), values.len());
        assert_eq!(trend.len(), values.len());

        // Trend should be approximately 2.0 (the slope)
        assert!((trend.last().unwrap() - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_holt_smoothing_short_input() {
        let (level, trend) = TrendExtraction::holt_smoothing(&[7.0], 0.3, 0.1);
        assert_eq!(level, vec![7.0]);
        assert_eq!(trend, vec![0.0]);
    }

    #[test]
    fn test_detrended_sums_to_zero_ish() {
        let series = create_seasonal_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 12);

        // Detrended should roughly sum to zero (mean-centered)
        let detrended_mean: f64 =
            result.detrended.iter().sum::<f64>() / result.detrended.len() as f64;
        assert!(
            detrended_mean.abs() < 1.0,
            "Detrended mean: {}",
            detrended_mean
        );
    }

    #[test]
    fn test_empty_series() {
        let empty = TimeSeries::new(Vec::new());

        let decomp = SeasonalDecomposition::compute(&empty, 12, false);
        assert!(decomp.trend.is_empty());
        assert!(decomp.seasonal.is_empty());
        assert!(decomp.residual.is_empty());

        let trend = TrendExtraction::compute(&empty, TrendMethod::SimpleMovingAverage, 5);
        assert!(trend.trend.is_empty());
        assert!(trend.detrended.is_empty());
    }
}