1use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::messages::{
12 SeasonalDecompositionInput, SeasonalDecompositionOutput, TrendExtractionInput,
13 TrendExtractionOutput,
14};
15use crate::types::{DecompositionResult, TimeSeries, TrendMethod, TrendResult};
16use rustkernel_core::{
17 domain::Domain,
18 error::Result,
19 kernel::KernelMetadata,
20 traits::{BatchKernel, GpuKernel},
21};
22
23#[derive(Debug, Clone)]
31pub struct SeasonalDecomposition {
32 metadata: KernelMetadata,
33}
34
35impl Default for SeasonalDecomposition {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SeasonalDecomposition {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 metadata: KernelMetadata::batch(
47 "temporal/seasonal-decomposition",
48 Domain::TemporalAnalysis,
49 )
50 .with_description("STL-style seasonal decomposition")
51 .with_throughput(10_000)
52 .with_latency_us(100.0),
53 }
54 }
55
56 pub fn compute(series: &TimeSeries, period: usize, robust: bool) -> DecompositionResult {
63 let n = series.len();
64
65 if n < 2 * period || period < 2 {
66 return DecompositionResult {
67 trend: series.values.clone(),
68 seasonal: vec![0.0; n],
69 residual: vec![0.0; n],
70 n,
71 period,
72 };
73 }
74
75 let trend = Self::centered_moving_average(&series.values, period);
77
78 let detrended: Vec<f64> = series
80 .values
81 .iter()
82 .zip(trend.iter())
83 .map(|(v, t)| v - t)
84 .collect();
85
86 let seasonal_pattern = if robust {
88 Self::robust_seasonal(&detrended, period)
89 } else {
90 Self::mean_seasonal(&detrended, period)
91 };
92
93 let seasonal: Vec<f64> = (0..n).map(|i| seasonal_pattern[i % period]).collect();
95
96 let deseasoned: Vec<f64> = series
98 .values
99 .iter()
100 .zip(seasonal.iter())
101 .map(|(v, s)| v - s)
102 .collect();
103
104 let refined_trend = Self::lowess_trend(&deseasoned, period);
105
106 let residual: Vec<f64> = series
108 .values
109 .iter()
110 .zip(refined_trend.iter())
111 .zip(seasonal.iter())
112 .map(|((v, t), s)| v - t - s)
113 .collect();
114
115 DecompositionResult {
116 trend: refined_trend,
117 seasonal,
118 residual,
119 n,
120 period,
121 }
122 }
123
124 pub fn compute_additive(series: &TimeSeries, period: usize) -> DecompositionResult {
126 Self::compute(series, period, false)
127 }
128
129 pub fn compute_multiplicative(series: &TimeSeries, period: usize) -> DecompositionResult {
133 let n = series.len();
134
135 if n < 2 * period || period < 2 {
136 return DecompositionResult {
137 trend: series.values.clone(),
138 seasonal: vec![1.0; n],
139 residual: vec![1.0; n],
140 n,
141 period,
142 };
143 }
144
145 let min_val = series.values.iter().cloned().fold(f64::INFINITY, f64::min);
147
148 if min_val <= 0.0 {
149 return Self::compute(series, period, false);
151 }
152
153 let log_values: Vec<f64> = series.values.iter().map(|v| v.ln()).collect();
154 let log_series = TimeSeries::new(log_values);
155
156 let log_result = Self::compute(&log_series, period, false);
157
158 DecompositionResult {
160 trend: log_result.trend.iter().map(|t| t.exp()).collect(),
161 seasonal: log_result.seasonal.iter().map(|s| s.exp()).collect(),
162 residual: log_result.residual.iter().map(|r| r.exp()).collect(),
163 n,
164 period,
165 }
166 }
167
/// Centered moving average of `values` with the given `window`.
///
/// For an even `window`, positions that have `window / 2` neighbours on both
/// sides use `window + 1` samples in which the two outermost samples count
/// half. Every other position (odd window, or too close to either end) uses
/// the plain mean of the samples within `window / 2` of it that exist.
fn centered_moving_average(values: &[f64], window: usize) -> Vec<f64> {
    let len = values.len();
    let reach = window / 2;
    let even_window = window % 2 == 0;

    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(reach);
            let hi = (center + reach + 1).min(len);
            let full_span = center >= reach && center + reach < len;

            if even_window && full_span {
                let last = hi - 1;
                let (total, mass) =
                    (lo..hi).fold((0.0_f64, 0.0_f64), |(total, mass), idx| {
                        // The two end samples straddle the window boundary.
                        let w = if idx == lo || idx == last { 0.5 } else { 1.0 };
                        (total + values[idx] * w, mass + w)
                    });
                total / mass
            } else {
                let span = &values[lo..hi];
                span.iter().sum::<f64>() / span.len() as f64
            }
        })
        .collect()
}
196
/// One-cycle seasonal pattern from the per-phase means of `detrended`.
///
/// Sample `i` belongs to phase `i % period`. Each phase is averaged (a phase
/// without samples stays at `0.0`) and the mean of the resulting pattern is
/// then subtracted from every entry. Returns `period` values.
fn mean_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
    // Per phase: (running sum, number of samples).
    let mut buckets: Vec<(f64, usize)> = vec![(0.0, 0); period];
    for (idx, sample) in detrended.iter().copied().enumerate() {
        let bucket = &mut buckets[idx % period];
        bucket.0 += sample;
        bucket.1 += 1;
    }

    let mut pattern: Vec<f64> = buckets
        .into_iter()
        .map(|(total, hits)| {
            if hits > 0 {
                total / hits as f64
            } else {
                total
            }
        })
        .collect();

    // Re-centre the pattern around zero.
    let offset = pattern.iter().sum::<f64>() / period as f64;
    pattern.iter_mut().for_each(|entry| *entry -= offset);

    pattern
}
221
/// One-cycle seasonal pattern from the per-phase medians of `detrended`.
///
/// Sample `i` belongs to phase `i % period`. Each phase is reduced to its
/// median (a phase without samples yields `0.0`) and the median of the
/// resulting pattern is then subtracted from every entry.
///
/// The median of an even number of samples is the mean of the two central
/// ones. Ordering uses `f64::total_cmp`, so NaN inputs cannot break the
/// sort's total-order requirement.
///
/// Returns `period` values, or an empty vector when `period` is `0`.
fn robust_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
    // Sorts `xs` in place and returns its median (`0.0` when empty).
    fn median_in_place(xs: &mut [f64]) -> f64 {
        if xs.is_empty() {
            return 0.0;
        }
        xs.sort_unstable_by(f64::total_cmp);
        let mid = xs.len() / 2;
        if xs.len() % 2 == 0 {
            (xs[mid - 1] + xs[mid]) / 2.0
        } else {
            xs[mid]
        }
    }

    if period == 0 {
        return Vec::new();
    }

    // Single pass: group the samples by their position within the cycle.
    let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); period];
    for (i, &v) in detrended.iter().enumerate() {
        buckets[i % period].push(v);
    }

    let mut seasonal: Vec<f64> = buckets
        .iter_mut()
        .map(|bucket| median_in_place(bucket))
        .collect();

    // Re-centre on the pattern's own median; sort a copy to keep phase order.
    let mut scratch = seasonal.clone();
    let center = median_in_place(&mut scratch);
    for s in &mut seasonal {
        *s -= center;
    }

    seasonal
}
251
/// Smooths `values` with a tricube-weighted local average.
///
/// For each position the samples at most `bandwidth` steps away are averaged
/// with weight `(1 - d^3)^3`, where `d` is the index distance divided by
/// `bandwidth`; samples with `d >= 1` get weight zero. If no sample receives
/// a positive weight the original value is kept.
fn lowess_trend(values: &[f64], bandwidth: usize) -> Vec<f64> {
    let len = values.len();
    let scale = bandwidth as f64;

    let tricube = |offset: f64| -> f64 {
        let d = offset.abs() / scale;
        if d < 1.0 {
            (1.0 - d.powi(3)).powi(3)
        } else {
            0.0
        }
    };

    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(bandwidth);
            let hi = (center + bandwidth + 1).min(len);

            let (numerator, denominator) =
                (lo..hi).fold((0.0_f64, 0.0_f64), |(num, den), idx| {
                    let w = tricube(idx as f64 - center as f64);
                    (num + values[idx] * w, den + w)
                });

            if denominator > 0.0 {
                numerator / denominator
            } else {
                values[center]
            }
        })
        .collect()
}
285}
286
287impl GpuKernel for SeasonalDecomposition {
288 fn metadata(&self) -> &KernelMetadata {
289 &self.metadata
290 }
291}
292
293#[async_trait]
294impl BatchKernel<SeasonalDecompositionInput, SeasonalDecompositionOutput>
295 for SeasonalDecomposition
296{
297 async fn execute(
298 &self,
299 input: SeasonalDecompositionInput,
300 ) -> Result<SeasonalDecompositionOutput> {
301 let start = Instant::now();
302 let result = Self::compute(&input.series, input.period, input.robust);
303 Ok(SeasonalDecompositionOutput {
304 result,
305 compute_time_us: start.elapsed().as_micros() as u64,
306 })
307 }
308}
309
310#[derive(Debug, Clone)]
318pub struct TrendExtraction {
319 metadata: KernelMetadata,
320}
321
322impl Default for TrendExtraction {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl TrendExtraction {
329 #[must_use]
331 pub fn new() -> Self {
332 Self {
333 metadata: KernelMetadata::batch("temporal/trend-extraction", Domain::TemporalAnalysis)
334 .with_description("Moving average trend extraction")
335 .with_throughput(50_000)
336 .with_latency_us(20.0),
337 }
338 }
339
340 pub fn compute(series: &TimeSeries, method: TrendMethod, window: usize) -> TrendResult {
347 if series.is_empty() {
348 return TrendResult {
349 trend: Vec::new(),
350 detrended: Vec::new(),
351 method,
352 };
353 }
354
355 let trend = match method {
356 TrendMethod::SimpleMovingAverage => Self::simple_ma(&series.values, window),
357 TrendMethod::ExponentialMovingAverage => Self::exponential_ma(&series.values, window),
358 TrendMethod::CenteredMovingAverage => Self::centered_ma(&series.values, window),
359 TrendMethod::Lowess => Self::lowess(&series.values, window),
360 };
361
362 let detrended: Vec<f64> = series
363 .values
364 .iter()
365 .zip(trend.iter())
366 .map(|(v, t)| v - t)
367 .collect();
368
369 TrendResult {
370 trend,
371 detrended,
372 method,
373 }
374 }
375
/// Trailing moving average computed from prefix sums.
///
/// Entry `i` is the mean of the last `window` samples ending at `i`, or of
/// all samples seen so far while fewer than `window` are available. The
/// window is clamped to the range `1..=values.len()`.
fn simple_ma(values: &[f64], window: usize) -> Vec<f64> {
    let len = values.len();
    let span = window.min(len).max(1);

    // prefix[k] holds the sum of the first k samples.
    let mut prefix = Vec::with_capacity(len + 1);
    let mut running = 0.0_f64;
    prefix.push(running);
    for &sample in values {
        running += sample;
        prefix.push(running);
    }

    (0..len)
        .map(|idx| {
            let first = idx.saturating_sub(span - 1);
            let samples = idx - first + 1;
            (prefix[idx + 1] - prefix[first]) / samples as f64
        })
        .collect()
}
396
/// Exponential moving average with smoothing factor `2 / (span + 1)`.
///
/// The first output equals the first sample; each later output is
/// `alpha * value + (1 - alpha) * previous_output`.
///
/// `span` is clamped to at least `1`. A span of `0` would give `alpha = 2`,
/// which lies outside `(0, 1]` and makes the recursion oscillate instead of
/// smooth; with the clamp it gives `alpha = 1`, i.e. the input is returned
/// unchanged.
///
/// Returns an empty vector for empty input.
fn exponential_ma(values: &[f64], span: usize) -> Vec<f64> {
    let (first, rest) = match values.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return Vec::new(),
    };

    let alpha = 2.0 / (span.max(1) as f64 + 1.0);

    let mut result = Vec::with_capacity(values.len());
    let mut previous = first;
    result.push(previous);

    for &value in rest {
        previous = alpha * value + (1.0 - alpha) * previous;
        result.push(previous);
    }

    result
}
414
/// Unweighted centered moving average.
///
/// Entry `i` is the plain mean of the samples at most `window / 2` steps
/// away from `i` that exist, so the span shrinks near both ends. Note that
/// an even `window` therefore averages `window + 1` samples in the interior.
fn centered_ma(values: &[f64], window: usize) -> Vec<f64> {
    let len = values.len();
    let reach = window / 2;

    (0..len)
        .map(|center| {
            let lo = center.saturating_sub(reach);
            let hi = (center + reach + 1).min(len);
            let span = &values[lo..hi];
            span.iter().sum::<f64>() / span.len() as f64
        })
        .collect()
}
429
/// Locally weighted linear regression (LOWESS-style) smoother.
///
/// For each position `i` a straight line is fitted by weighted least squares
/// to the samples at most `bandwidth` steps away, using tricube weights
/// `(1 - d^3)^3` with `d = |j - i| / (bandwidth + 1)`. The fitted value at
/// `i` becomes the output.
///
/// The regression uses the offset `j - i` as its x coordinate rather than
/// the absolute index. The fit is mathematically the same, but the normal
/// equations stay well conditioned on long series, and the fitted value at
/// `i` is simply the intercept.
///
/// When the system is singular (for example only one sample carries weight)
/// the weighted mean is used; if no sample carries weight the original
/// value is kept.
fn lowess(values: &[f64], bandwidth: usize) -> Vec<f64> {
    let n = values.len();
    let scale = bandwidth as f64 + 1.0;

    (0..n)
        .map(|i| {
            let start = i.saturating_sub(bandwidth);
            // Saturating: a huge `bandwidth` must not overflow the index math.
            let end = i.saturating_add(bandwidth).saturating_add(1).min(n);

            let mut sum_w = 0.0_f64;
            let mut sum_wx = 0.0_f64;
            let mut sum_wy = 0.0_f64;
            let mut sum_wxx = 0.0_f64;
            let mut sum_wxy = 0.0_f64;

            for (offset, &y) in values[start..end].iter().enumerate() {
                // x is measured relative to the evaluation point.
                let x = (start + offset) as f64 - i as f64;
                let dist = x.abs() / scale;
                let w = if dist < 1.0 {
                    (1.0 - dist.powi(3)).powi(3)
                } else {
                    0.0
                };

                sum_w += w;
                sum_wx += w * x;
                sum_wy += w * y;
                sum_wxx += w * x * x;
                sum_wxy += w * x * y;
            }

            let det = sum_w * sum_wxx - sum_wx * sum_wx;
            if det.abs() > 1e-10 {
                // Intercept of the weighted fit = fitted value at x = 0.
                (sum_wxx * sum_wy - sum_wx * sum_wxy) / det
            } else if sum_w > 0.0 {
                sum_wy / sum_w
            } else {
                values[i]
            }
        })
        .collect()
}
480
/// Holt's double exponential smoothing.
///
/// Returns `(level, trend)`, both as long as `values`. The level starts at
/// the first sample and the trend at the difference of the first two
/// samples; afterwards
///
/// * `level[i] = alpha * values[i] + (1 - alpha) * (level[i-1] + trend[i-1])`
/// * `trend[i] = beta * (level[i] - level[i-1]) + (1 - beta) * trend[i-1]`
///
/// With fewer than two samples the input is returned as the level and the
/// trend is all zeros. `alpha` and `beta` are used as given.
pub fn holt_smoothing(values: &[f64], alpha: f64, beta: f64) -> (Vec<f64>, Vec<f64>) {
    let (first, second) = match values {
        [a, b, ..] => (*a, *b),
        _ => return (values.to_vec(), vec![0.0; values.len()]),
    };

    let mut levels = Vec::with_capacity(values.len());
    let mut slopes = Vec::with_capacity(values.len());

    let mut prev_level = first;
    let mut prev_slope = second - first;
    levels.push(prev_level);
    slopes.push(prev_slope);

    for &observation in &values[1..] {
        let level = alpha * observation + (1.0 - alpha) * (prev_level + prev_slope);
        let slope = beta * (level - prev_level) + (1.0 - beta) * prev_slope;
        levels.push(level);
        slopes.push(slope);
        prev_level = level;
        prev_slope = slope;
    }

    (levels, slopes)
}
502}
503
504impl GpuKernel for TrendExtraction {
505 fn metadata(&self) -> &KernelMetadata {
506 &self.metadata
507 }
508}
509
510#[async_trait]
511impl BatchKernel<TrendExtractionInput, TrendExtractionOutput> for TrendExtraction {
512 async fn execute(&self, input: TrendExtractionInput) -> Result<TrendExtractionOutput> {
513 let start = Instant::now();
514 let result = Self::compute(&input.series, input.method, input.window);
515 Ok(TrendExtractionOutput {
516 result,
517 compute_time_us: start.elapsed().as_micros() as u64,
518 })
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// 120 samples of `100 + 0.5 t` plus a sine of amplitude 10 and period 12.
    fn create_seasonal_series() -> TimeSeries {
        let period = 12;
        let samples: Vec<f64> = (0..120)
            .map(|t| {
                let t = t as f64;
                let angle = 2.0 * std::f64::consts::PI * t / period as f64;
                (100.0 + 0.5 * t) + 10.0 * angle.sin()
            })
            .collect();
        TimeSeries::new(samples)
    }

    /// 100 samples of `10 + 2 t` with a small sinusoidal wobble.
    fn create_trend_series() -> TimeSeries {
        let samples: Vec<f64> = (0..100)
            .map(|t| {
                let t = t as f64;
                10.0 + 2.0 * t + (t * 0.3).sin()
            })
            .collect();
        TimeSeries::new(samples)
    }

    #[test]
    fn test_decomposition_metadata() {
        let kernel = SeasonalDecomposition::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "temporal/seasonal-decomposition");
        assert_eq!(meta.domain, Domain::TemporalAnalysis);
    }

    #[test]
    fn test_seasonal_decomposition() {
        let series = create_seasonal_series();
        let expected_len = series.len();
        let result = SeasonalDecomposition::compute(&series, 12, false);

        // Every component spans the whole series.
        assert_eq!(result.trend.len(), expected_len);
        assert_eq!(result.seasonal.len(), expected_len);
        assert_eq!(result.residual.len(), expected_len);
        assert_eq!(result.period, 12);

        // The seasonal component must repeat with the requested period.
        let one_period_later = result.seasonal.iter().skip(12);
        for (i, (now, later)) in result.seasonal.iter().zip(one_period_later).enumerate() {
            let diff = (now - later).abs();
            assert!(diff < 0.01, "Seasonal not periodic at {}: diff={}", i, diff);
        }
    }

    #[test]
    fn test_robust_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, 12, true);

        assert_eq!(result.trend.len(), series.len());
    }

    #[test]
    fn test_multiplicative_decomposition() {
        // Trend scaled by a +/-10% seasonal factor with period 12.
        let samples: Vec<f64> = (0..120)
            .map(|t| {
                let t = t as f64;
                let factor = 1.0 + 0.1 * ((2.0 * std::f64::consts::PI * t / 12.0).sin());
                (100.0 + 0.5 * t) * factor
            })
            .collect();
        let series = TimeSeries::new(samples);

        let result = SeasonalDecomposition::compute_multiplicative(&series, 12);

        assert_eq!(result.trend.len(), series.len());
    }

    #[test]
    fn test_trend_extraction_metadata() {
        let kernel = TrendExtraction::new();
        assert_eq!(kernel.metadata().id, "temporal/trend-extraction");
    }

    #[test]
    fn test_simple_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::SimpleMovingAverage, 5);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::SimpleMovingAverage);

        // Smoothing must not increase the total step-to-step variation.
        let total_variation =
            |xs: &[f64]| -> f64 { xs.windows(2).map(|pair| (pair[1] - pair[0]).abs()).sum() };
        let original_var = total_variation(&series.values);
        let trend_var = total_variation(&result.trend);
        assert!(trend_var <= original_var);
    }

    #[test]
    fn test_exponential_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::ExponentialMovingAverage, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::ExponentialMovingAverage);
    }

    #[test]
    fn test_centered_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 7);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::CenteredMovingAverage);
    }

    #[test]
    fn test_lowess_trend() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::Lowess, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::Lowess);
    }

    #[test]
    fn test_holt_smoothing() {
        let values: Vec<f64> = (0..50).map(|t| 10.0 + 2.0 * t as f64).collect();
        let (level, trend) = TrendExtraction::holt_smoothing(&values, 0.3, 0.1);

        assert_eq!(level.len(), values.len());
        assert_eq!(trend.len(), values.len());

        // The input has slope 2; the final trend estimate should be near it.
        let final_slope = *trend.last().unwrap();
        assert!((final_slope - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_detrended_sums_to_zero_ish() {
        let series = create_seasonal_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 12);

        let count = result.detrended.len() as f64;
        let detrended_mean: f64 = result.detrended.iter().sum::<f64>() / count;
        assert!(
            detrended_mean.abs() < 1.0,
            "Detrended mean: {}",
            detrended_mean
        );
    }

    #[test]
    fn test_empty_series() {
        let empty = TimeSeries::new(Vec::new());

        let decomp = SeasonalDecomposition::compute(&empty, 12, false);
        assert!(decomp.trend.is_empty());

        let trend = TrendExtraction::compute(&empty, TrendMethod::SimpleMovingAverage, 5);
        assert!(trend.trend.is_empty());
    }
}