rustkernel_temporal/
decomposition.rs

1//! Time series decomposition kernels.
2//!
3//! This module provides decomposition algorithms:
4//! - Seasonal decomposition (STL-like)
5//! - Trend extraction (various moving average methods)
6
7use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::messages::{
12    SeasonalDecompositionInput, SeasonalDecompositionOutput, TrendExtractionInput,
13    TrendExtractionOutput,
14};
15use crate::types::{DecompositionResult, TimeSeries, TrendMethod, TrendResult};
16use rustkernel_core::{
17    domain::Domain,
18    error::Result,
19    kernel::KernelMetadata,
20    traits::{BatchKernel, GpuKernel},
21};
22
23// ============================================================================
24// Seasonal Decomposition Kernel
25// ============================================================================
26
/// Seasonal decomposition kernel (STL-like).
///
/// Decomposes a time series into trend, seasonal, and residual components.
/// The decomposition routines on this type are associated functions that take
/// the series as an argument; the instance itself only stores kernel metadata.
#[derive(Debug, Clone)]
pub struct SeasonalDecomposition {
    /// Kernel descriptor built in `new()` and returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
34
35impl Default for SeasonalDecomposition {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl SeasonalDecomposition {
42    /// Create a new seasonal decomposition kernel.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            metadata: KernelMetadata::batch(
47                "temporal/seasonal-decomposition",
48                Domain::TemporalAnalysis,
49            )
50            .with_description("STL-style seasonal decomposition")
51            .with_throughput(10_000)
52            .with_latency_us(100.0),
53        }
54    }
55
56    /// Decompose a time series.
57    ///
58    /// # Arguments
59    /// * `series` - Input time series
60    /// * `period` - Seasonal period
61    /// * `robust` - Use robust (median-based) estimation
62    pub fn compute(series: &TimeSeries, period: usize, robust: bool) -> DecompositionResult {
63        let n = series.len();
64
65        if n < 2 * period || period < 2 {
66            return DecompositionResult {
67                trend: series.values.clone(),
68                seasonal: vec![0.0; n],
69                residual: vec![0.0; n],
70                n,
71                period,
72            };
73        }
74
75        // Step 1: Initial trend estimation using centered moving average
76        let trend = Self::centered_moving_average(&series.values, period);
77
78        // Step 2: Detrend the series
79        let detrended: Vec<f64> = series
80            .values
81            .iter()
82            .zip(trend.iter())
83            .map(|(v, t)| v - t)
84            .collect();
85
86        // Step 3: Estimate seasonal component
87        let seasonal_pattern = if robust {
88            Self::robust_seasonal(&detrended, period)
89        } else {
90            Self::mean_seasonal(&detrended, period)
91        };
92
93        // Extend seasonal pattern to full length
94        let seasonal: Vec<f64> = (0..n).map(|i| seasonal_pattern[i % period]).collect();
95
96        // Step 4: Refine trend by removing seasonality first
97        let deseasoned: Vec<f64> = series
98            .values
99            .iter()
100            .zip(seasonal.iter())
101            .map(|(v, s)| v - s)
102            .collect();
103
104        let refined_trend = Self::lowess_trend(&deseasoned, period);
105
106        // Step 5: Calculate residuals
107        let residual: Vec<f64> = series
108            .values
109            .iter()
110            .zip(refined_trend.iter())
111            .zip(seasonal.iter())
112            .map(|((v, t), s)| v - t - s)
113            .collect();
114
115        DecompositionResult {
116            trend: refined_trend,
117            seasonal,
118            residual,
119            n,
120            period,
121        }
122    }
123
124    /// Additive decomposition (simple version).
125    pub fn compute_additive(series: &TimeSeries, period: usize) -> DecompositionResult {
126        Self::compute(series, period, false)
127    }
128
129    /// Multiplicative decomposition.
130    ///
131    /// Y = T * S * R
132    pub fn compute_multiplicative(series: &TimeSeries, period: usize) -> DecompositionResult {
133        let n = series.len();
134
135        if n < 2 * period || period < 2 {
136            return DecompositionResult {
137                trend: series.values.clone(),
138                seasonal: vec![1.0; n],
139                residual: vec![1.0; n],
140                n,
141                period,
142            };
143        }
144
145        // Convert to log space if all values positive
146        let min_val = series.values.iter().cloned().fold(f64::INFINITY, f64::min);
147
148        if min_val <= 0.0 {
149            // Fall back to additive for non-positive data
150            return Self::compute(series, period, false);
151        }
152
153        let log_values: Vec<f64> = series.values.iter().map(|v| v.ln()).collect();
154        let log_series = TimeSeries::new(log_values);
155
156        let log_result = Self::compute(&log_series, period, false);
157
158        // Convert back from log space
159        DecompositionResult {
160            trend: log_result.trend.iter().map(|t| t.exp()).collect(),
161            seasonal: log_result.seasonal.iter().map(|s| s.exp()).collect(),
162            residual: log_result.residual.iter().map(|r| r.exp()).collect(),
163            n,
164            period,
165        }
166    }
167
168    /// Centered moving average.
169    #[allow(clippy::needless_range_loop)]
170    fn centered_moving_average(values: &[f64], window: usize) -> Vec<f64> {
171        let n = values.len();
172        let half_w = window / 2;
173        let mut result = vec![0.0; n];
174
175        for i in 0..n {
176            let start = i.saturating_sub(half_w);
177            let end = (i + half_w + 1).min(n);
178
179            // For even windows, use weighted average at boundaries
180            if window % 2 == 0 && i >= half_w && i + half_w < n {
181                let mut sum = 0.0;
182                let mut weight = 0.0;
183
184                for j in start..end {
185                    let w = if j == start || j == end - 1 { 0.5 } else { 1.0 };
186                    sum += values[j] * w;
187                    weight += w;
188                }
189                result[i] = sum / weight;
190            } else {
191                result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
192            }
193        }
194
195        result
196    }
197
198    /// Mean-based seasonal estimation.
199    fn mean_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
200        let mut seasonal = vec![0.0; period];
201        let mut counts = vec![0usize; period];
202
203        for (i, &d) in detrended.iter().enumerate() {
204            seasonal[i % period] += d;
205            counts[i % period] += 1;
206        }
207
208        for i in 0..period {
209            if counts[i] > 0 {
210                seasonal[i] /= counts[i] as f64;
211            }
212        }
213
214        // Center the seasonal component
215        let mean: f64 = seasonal.iter().sum::<f64>() / period as f64;
216        for s in &mut seasonal {
217            *s -= mean;
218        }
219
220        seasonal
221    }
222
223    /// Robust (median-based) seasonal estimation.
224    #[allow(clippy::needless_range_loop)]
225    fn robust_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
226        let mut seasonal = vec![0.0; period];
227
228        for s in 0..period {
229            let mut season_values: Vec<f64> = detrended
230                .iter()
231                .enumerate()
232                .filter(|(i, _)| i % period == s)
233                .map(|(_, &v)| v)
234                .collect();
235
236            if !season_values.is_empty() {
237                season_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
238                seasonal[s] = season_values[season_values.len() / 2];
239            }
240        }
241
242        // Center using median
243        let mut sorted = seasonal.clone();
244        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
245        let median = sorted[period / 2];
246
247        for s in &mut seasonal {
248            *s -= median;
249        }
250
251        seasonal
252    }
253
254    /// Lowess-style trend extraction (simplified).
255    #[allow(clippy::needless_range_loop)]
256    fn lowess_trend(values: &[f64], bandwidth: usize) -> Vec<f64> {
257        let n = values.len();
258        let mut trend = vec![0.0; n];
259
260        for i in 0..n {
261            let start = i.saturating_sub(bandwidth);
262            let end = (i + bandwidth + 1).min(n);
263
264            // Tricube weights
265            let mut weighted_sum = 0.0;
266            let mut weight_sum = 0.0;
267
268            for j in start..end {
269                let dist = (j as f64 - i as f64).abs() / bandwidth as f64;
270                let weight = if dist < 1.0 {
271                    (1.0 - dist.powi(3)).powi(3)
272                } else {
273                    0.0
274                };
275                weighted_sum += values[j] * weight;
276                weight_sum += weight;
277            }
278
279            trend[i] = if weight_sum > 0.0 {
280                weighted_sum / weight_sum
281            } else {
282                values[i]
283            };
284        }
285
286        trend
287    }
288}
289
impl GpuKernel for SeasonalDecomposition {
    /// Returns a reference to the metadata stored on this kernel instance.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
295
296#[async_trait]
297impl BatchKernel<SeasonalDecompositionInput, SeasonalDecompositionOutput>
298    for SeasonalDecomposition
299{
300    async fn execute(
301        &self,
302        input: SeasonalDecompositionInput,
303    ) -> Result<SeasonalDecompositionOutput> {
304        let start = Instant::now();
305        let result = Self::compute(&input.series, input.period, input.robust);
306        Ok(SeasonalDecompositionOutput {
307            result,
308            compute_time_us: start.elapsed().as_micros() as u64,
309        })
310    }
311}
312
313// ============================================================================
314// Trend Extraction Kernel
315// ============================================================================
316
/// Trend extraction kernel.
///
/// Extracts trend component using various moving average methods.
/// The smoothing routines on this type are associated functions that take the
/// data as an argument; the instance itself only stores kernel metadata.
#[derive(Debug, Clone)]
pub struct TrendExtraction {
    /// Kernel descriptor built in `new()` and returned by `GpuKernel::metadata`.
    metadata: KernelMetadata,
}
324
325impl Default for TrendExtraction {
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331impl TrendExtraction {
332    /// Create a new trend extraction kernel.
333    #[must_use]
334    pub fn new() -> Self {
335        Self {
336            metadata: KernelMetadata::batch("temporal/trend-extraction", Domain::TemporalAnalysis)
337                .with_description("Moving average trend extraction")
338                .with_throughput(50_000)
339                .with_latency_us(20.0),
340        }
341    }
342
343    /// Extract trend from a time series.
344    ///
345    /// # Arguments
346    /// * `series` - Input time series
347    /// * `method` - Trend extraction method
348    /// * `window` - Window size for moving average
349    pub fn compute(series: &TimeSeries, method: TrendMethod, window: usize) -> TrendResult {
350        if series.is_empty() {
351            return TrendResult {
352                trend: Vec::new(),
353                detrended: Vec::new(),
354                method,
355            };
356        }
357
358        let trend = match method {
359            TrendMethod::SimpleMovingAverage => Self::simple_ma(&series.values, window),
360            TrendMethod::ExponentialMovingAverage => Self::exponential_ma(&series.values, window),
361            TrendMethod::CenteredMovingAverage => Self::centered_ma(&series.values, window),
362            TrendMethod::Lowess => Self::lowess(&series.values, window),
363        };
364
365        let detrended: Vec<f64> = series
366            .values
367            .iter()
368            .zip(trend.iter())
369            .map(|(v, t)| v - t)
370            .collect();
371
372        TrendResult {
373            trend,
374            detrended,
375            method,
376        }
377    }
378
379    /// Simple moving average.
380    fn simple_ma(values: &[f64], window: usize) -> Vec<f64> {
381        let n = values.len();
382        let w = window.min(n).max(1);
383        let mut result = vec![0.0; n];
384
385        // Cumulative sum for efficient computation
386        let mut cumsum = vec![0.0; n + 1];
387        for (i, &v) in values.iter().enumerate() {
388            cumsum[i + 1] = cumsum[i] + v;
389        }
390
391        for i in 0..n {
392            let start = i.saturating_sub(w - 1);
393            let count = i - start + 1;
394            result[i] = (cumsum[i + 1] - cumsum[start]) / count as f64;
395        }
396
397        result
398    }
399
400    /// Exponential moving average.
401    fn exponential_ma(values: &[f64], span: usize) -> Vec<f64> {
402        let n = values.len();
403        if n == 0 {
404            return Vec::new();
405        }
406
407        let alpha = 2.0 / (span as f64 + 1.0);
408        let mut result = vec![0.0; n];
409        result[0] = values[0];
410
411        for i in 1..n {
412            result[i] = alpha * values[i] + (1.0 - alpha) * result[i - 1];
413        }
414
415        result
416    }
417
418    /// Centered moving average.
419    #[allow(clippy::needless_range_loop)]
420    fn centered_ma(values: &[f64], window: usize) -> Vec<f64> {
421        let n = values.len();
422        let half_w = window / 2;
423        let mut result = vec![0.0; n];
424
425        for i in 0..n {
426            let start = i.saturating_sub(half_w);
427            let end = (i + half_w + 1).min(n);
428            result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
429        }
430
431        result
432    }
433
434    /// Lowess (Locally Weighted Scatterplot Smoothing).
435    #[allow(clippy::needless_range_loop)]
436    fn lowess(values: &[f64], bandwidth: usize) -> Vec<f64> {
437        let n = values.len();
438        let mut result = vec![0.0; n];
439
440        for i in 0..n {
441            let start = i.saturating_sub(bandwidth);
442            let end = (i + bandwidth + 1).min(n);
443
444            // Fit local linear regression with tricube weights
445            let mut sum_w = 0.0;
446            let mut sum_wx = 0.0;
447            let mut sum_wy = 0.0;
448            let mut sum_wxx = 0.0;
449            let mut sum_wxy = 0.0;
450
451            for j in start..end {
452                let x = j as f64;
453                let y = values[j];
454                let dist = (j as f64 - i as f64).abs() / (bandwidth as f64 + 1.0);
455                let w = if dist < 1.0 {
456                    (1.0 - dist.powi(3)).powi(3)
457                } else {
458                    0.0
459                };
460
461                sum_w += w;
462                sum_wx += w * x;
463                sum_wy += w * y;
464                sum_wxx += w * x * x;
465                sum_wxy += w * x * y;
466            }
467
468            // Solve for local linear fit at point i
469            let det = sum_w * sum_wxx - sum_wx * sum_wx;
470            if det.abs() > 1e-10 {
471                let b0 = (sum_wxx * sum_wy - sum_wx * sum_wxy) / det;
472                let b1 = (sum_w * sum_wxy - sum_wx * sum_wy) / det;
473                result[i] = b0 + b1 * i as f64;
474            } else {
475                result[i] = if sum_w > 0.0 {
476                    sum_wy / sum_w
477                } else {
478                    values[i]
479                };
480            }
481        }
482
483        result
484    }
485
486    /// Double exponential smoothing (Holt's method).
487    pub fn holt_smoothing(values: &[f64], alpha: f64, beta: f64) -> (Vec<f64>, Vec<f64>) {
488        let n = values.len();
489        if n < 2 {
490            return (values.to_vec(), vec![0.0; n]);
491        }
492
493        let mut level = vec![0.0; n];
494        let mut trend = vec![0.0; n];
495
496        // Initialize
497        level[0] = values[0];
498        trend[0] = values[1] - values[0];
499
500        for i in 1..n {
501            level[i] = alpha * values[i] + (1.0 - alpha) * (level[i - 1] + trend[i - 1]);
502            trend[i] = beta * (level[i] - level[i - 1]) + (1.0 - beta) * trend[i - 1];
503        }
504
505        (level, trend)
506    }
507}
508
impl GpuKernel for TrendExtraction {
    /// Returns a reference to the metadata stored on this kernel instance.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
514
515#[async_trait]
516impl BatchKernel<TrendExtractionInput, TrendExtractionOutput> for TrendExtraction {
517    async fn execute(&self, input: TrendExtractionInput) -> Result<TrendExtractionOutput> {
518        let start = Instant::now();
519        let result = Self::compute(&input.series, input.method, input.window);
520        Ok(TrendExtractionOutput {
521            result,
522            compute_time_us: start.elapsed().as_micros() as u64,
523        })
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// 120 samples of a linear trend plus a sine with period 12 and amplitude 10.
    fn create_seasonal_series() -> TimeSeries {
        // Trend + seasonal pattern
        let period = 12;
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal =
                    10.0 * ((2.0 * std::f64::consts::PI * t as f64 / period as f64).sin());
                trend + seasonal
            })
            .collect();
        TimeSeries::new(values)
    }

    /// 100 samples of a steep linear trend with a small deterministic wiggle.
    fn create_trend_series() -> TimeSeries {
        // Pure trend with some noise
        TimeSeries::new(
            (0..100)
                .map(|t| 10.0 + 2.0 * t as f64 + (t as f64 * 0.3).sin())
                .collect(),
        )
    }

    #[test]
    fn test_decomposition_metadata() {
        let kernel = SeasonalDecomposition::new();
        assert_eq!(kernel.metadata().id, "temporal/seasonal-decomposition");
        assert_eq!(kernel.metadata().domain, Domain::TemporalAnalysis);
    }

    #[test]
    fn test_seasonal_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, 12, false);

        // Should have correct lengths
        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());
        assert_eq!(result.period, 12);

        // Seasonal should be periodic
        for i in 0..result.seasonal.len() - 12 {
            let diff = (result.seasonal[i] - result.seasonal[i + 12]).abs();
            assert!(diff < 0.01, "Seasonal not periodic at {}: diff={}", i, diff);
        }
    }

    #[test]
    fn test_robust_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, 12, true);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());

        // Robust version should also produce valid decomposition:
        // the three components must add back up to the input.
        for (i, v) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] + result.seasonal[i] + result.residual[i];
            assert!(
                (rebuilt - v).abs() < 1e-9,
                "Additive reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                v
            );
        }
    }

    #[test]
    fn test_multiplicative_decomposition() {
        // Create multiplicative seasonal pattern
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal = 1.0 + 0.1 * ((2.0 * std::f64::consts::PI * t as f64 / 12.0).sin());
                trend * seasonal
            })
            .collect();
        let series = TimeSeries::new(values);

        let result = SeasonalDecomposition::compute_multiplicative(&series, 12);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());

        // In multiplicative, seasonal should be multiplicative factors:
        // strictly positive, and T * S * R must reproduce the input.
        assert!(result.seasonal.iter().all(|s| *s > 0.0));
        for (i, v) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] * result.seasonal[i] * result.residual[i];
            assert!(
                ((rebuilt - v) / v).abs() < 1e-9,
                "Multiplicative reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                v
            );
        }
    }

    #[test]
    fn test_trend_extraction_metadata() {
        let kernel = TrendExtraction::new();
        assert_eq!(kernel.metadata().id, "temporal/trend-extraction");
    }

    #[test]
    fn test_simple_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::SimpleMovingAverage, 5);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::SimpleMovingAverage);

        // Trend should be smoother than original
        // (measured as total absolute step-to-step change).
        let original_var: f64 = series.values.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        let trend_var: f64 = result.trend.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        assert!(trend_var <= original_var);
    }

    #[test]
    fn test_exponential_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::ExponentialMovingAverage, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::ExponentialMovingAverage);
    }

    #[test]
    fn test_centered_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 7);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::CenteredMovingAverage);
    }

    #[test]
    fn test_lowess_trend() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::Lowess, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::Lowess);
    }

    #[test]
    fn test_holt_smoothing() {
        let values: Vec<f64> = (0..50).map(|t| 10.0 + 2.0 * t as f64).collect();
        let (level, trend) = TrendExtraction::holt_smoothing(&values, 0.3, 0.1);

        assert_eq!(level.len(), values.len());
        assert_eq!(trend.len(), values.len());

        // Trend should be approximately 2.0 (the slope)
        assert!((trend.last().unwrap() - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_detrended_sums_to_zero_ish() {
        let series = create_seasonal_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 12);

        // Detrended should roughly sum to zero (mean-centered)
        let detrended_mean: f64 =
            result.detrended.iter().sum::<f64>() / result.detrended.len() as f64;
        assert!(
            detrended_mean.abs() < 1.0,
            "Detrended mean: {}",
            detrended_mean
        );
    }

    #[test]
    fn test_empty_series() {
        let empty = TimeSeries::new(Vec::new());

        let decomp = SeasonalDecomposition::compute(&empty, 12, false);
        assert!(decomp.trend.is_empty());

        let trend = TrendExtraction::compute(&empty, TrendMethod::SimpleMovingAverage, 5);
        assert!(trend.trend.is_empty());
    }
}