1use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::messages::{
12 SeasonalDecompositionInput, SeasonalDecompositionOutput, TrendExtractionInput,
13 TrendExtractionOutput,
14};
15use crate::types::{DecompositionResult, TimeSeries, TrendMethod, TrendResult};
16use rustkernel_core::{
17 domain::Domain,
18 error::Result,
19 kernel::KernelMetadata,
20 traits::{BatchKernel, GpuKernel},
21};
22
23#[derive(Debug, Clone)]
31pub struct SeasonalDecomposition {
32 metadata: KernelMetadata,
33}
34
35impl Default for SeasonalDecomposition {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl SeasonalDecomposition {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 metadata: KernelMetadata::batch(
47 "temporal/seasonal-decomposition",
48 Domain::TemporalAnalysis,
49 )
50 .with_description("STL-style seasonal decomposition")
51 .with_throughput(10_000)
52 .with_latency_us(100.0),
53 }
54 }
55
56 pub fn compute(series: &TimeSeries, period: usize, robust: bool) -> DecompositionResult {
63 let n = series.len();
64
65 if n < 2 * period || period < 2 {
66 return DecompositionResult {
67 trend: series.values.clone(),
68 seasonal: vec![0.0; n],
69 residual: vec![0.0; n],
70 n,
71 period,
72 };
73 }
74
75 let trend = Self::centered_moving_average(&series.values, period);
77
78 let detrended: Vec<f64> = series
80 .values
81 .iter()
82 .zip(trend.iter())
83 .map(|(v, t)| v - t)
84 .collect();
85
86 let seasonal_pattern = if robust {
88 Self::robust_seasonal(&detrended, period)
89 } else {
90 Self::mean_seasonal(&detrended, period)
91 };
92
93 let seasonal: Vec<f64> = (0..n).map(|i| seasonal_pattern[i % period]).collect();
95
96 let deseasoned: Vec<f64> = series
98 .values
99 .iter()
100 .zip(seasonal.iter())
101 .map(|(v, s)| v - s)
102 .collect();
103
104 let refined_trend = Self::lowess_trend(&deseasoned, period);
105
106 let residual: Vec<f64> = series
108 .values
109 .iter()
110 .zip(refined_trend.iter())
111 .zip(seasonal.iter())
112 .map(|((v, t), s)| v - t - s)
113 .collect();
114
115 DecompositionResult {
116 trend: refined_trend,
117 seasonal,
118 residual,
119 n,
120 period,
121 }
122 }
123
124 pub fn compute_additive(series: &TimeSeries, period: usize) -> DecompositionResult {
126 Self::compute(series, period, false)
127 }
128
129 pub fn compute_multiplicative(series: &TimeSeries, period: usize) -> DecompositionResult {
133 let n = series.len();
134
135 if n < 2 * period || period < 2 {
136 return DecompositionResult {
137 trend: series.values.clone(),
138 seasonal: vec![1.0; n],
139 residual: vec![1.0; n],
140 n,
141 period,
142 };
143 }
144
145 let min_val = series.values.iter().cloned().fold(f64::INFINITY, f64::min);
147
148 if min_val <= 0.0 {
149 return Self::compute(series, period, false);
151 }
152
153 let log_values: Vec<f64> = series.values.iter().map(|v| v.ln()).collect();
154 let log_series = TimeSeries::new(log_values);
155
156 let log_result = Self::compute(&log_series, period, false);
157
158 DecompositionResult {
160 trend: log_result.trend.iter().map(|t| t.exp()).collect(),
161 seasonal: log_result.seasonal.iter().map(|s| s.exp()).collect(),
162 residual: log_result.residual.iter().map(|r| r.exp()).collect(),
163 n,
164 period,
165 }
166 }
167
168 #[allow(clippy::needless_range_loop)]
170 fn centered_moving_average(values: &[f64], window: usize) -> Vec<f64> {
171 let n = values.len();
172 let half_w = window / 2;
173 let mut result = vec![0.0; n];
174
175 for i in 0..n {
176 let start = i.saturating_sub(half_w);
177 let end = (i + half_w + 1).min(n);
178
179 if window % 2 == 0 && i >= half_w && i + half_w < n {
181 let mut sum = 0.0;
182 let mut weight = 0.0;
183
184 for j in start..end {
185 let w = if j == start || j == end - 1 { 0.5 } else { 1.0 };
186 sum += values[j] * w;
187 weight += w;
188 }
189 result[i] = sum / weight;
190 } else {
191 result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
192 }
193 }
194
195 result
196 }
197
198 fn mean_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
200 let mut seasonal = vec![0.0; period];
201 let mut counts = vec![0usize; period];
202
203 for (i, &d) in detrended.iter().enumerate() {
204 seasonal[i % period] += d;
205 counts[i % period] += 1;
206 }
207
208 for i in 0..period {
209 if counts[i] > 0 {
210 seasonal[i] /= counts[i] as f64;
211 }
212 }
213
214 let mean: f64 = seasonal.iter().sum::<f64>() / period as f64;
216 for s in &mut seasonal {
217 *s -= mean;
218 }
219
220 seasonal
221 }
222
223 #[allow(clippy::needless_range_loop)]
225 fn robust_seasonal(detrended: &[f64], period: usize) -> Vec<f64> {
226 let mut seasonal = vec![0.0; period];
227
228 for s in 0..period {
229 let mut season_values: Vec<f64> = detrended
230 .iter()
231 .enumerate()
232 .filter(|(i, _)| i % period == s)
233 .map(|(_, &v)| v)
234 .collect();
235
236 if !season_values.is_empty() {
237 season_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
238 seasonal[s] = season_values[season_values.len() / 2];
239 }
240 }
241
242 let mut sorted = seasonal.clone();
244 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
245 let median = sorted[period / 2];
246
247 for s in &mut seasonal {
248 *s -= median;
249 }
250
251 seasonal
252 }
253
254 #[allow(clippy::needless_range_loop)]
256 fn lowess_trend(values: &[f64], bandwidth: usize) -> Vec<f64> {
257 let n = values.len();
258 let mut trend = vec![0.0; n];
259
260 for i in 0..n {
261 let start = i.saturating_sub(bandwidth);
262 let end = (i + bandwidth + 1).min(n);
263
264 let mut weighted_sum = 0.0;
266 let mut weight_sum = 0.0;
267
268 for j in start..end {
269 let dist = (j as f64 - i as f64).abs() / bandwidth as f64;
270 let weight = if dist < 1.0 {
271 (1.0 - dist.powi(3)).powi(3)
272 } else {
273 0.0
274 };
275 weighted_sum += values[j] * weight;
276 weight_sum += weight;
277 }
278
279 trend[i] = if weight_sum > 0.0 {
280 weighted_sum / weight_sum
281 } else {
282 values[i]
283 };
284 }
285
286 trend
287 }
288}
289
290impl GpuKernel for SeasonalDecomposition {
291 fn metadata(&self) -> &KernelMetadata {
292 &self.metadata
293 }
294}
295
296#[async_trait]
297impl BatchKernel<SeasonalDecompositionInput, SeasonalDecompositionOutput>
298 for SeasonalDecomposition
299{
300 async fn execute(
301 &self,
302 input: SeasonalDecompositionInput,
303 ) -> Result<SeasonalDecompositionOutput> {
304 let start = Instant::now();
305 let result = Self::compute(&input.series, input.period, input.robust);
306 Ok(SeasonalDecompositionOutput {
307 result,
308 compute_time_us: start.elapsed().as_micros() as u64,
309 })
310 }
311}
312
313#[derive(Debug, Clone)]
321pub struct TrendExtraction {
322 metadata: KernelMetadata,
323}
324
325impl Default for TrendExtraction {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331impl TrendExtraction {
332 #[must_use]
334 pub fn new() -> Self {
335 Self {
336 metadata: KernelMetadata::batch("temporal/trend-extraction", Domain::TemporalAnalysis)
337 .with_description("Moving average trend extraction")
338 .with_throughput(50_000)
339 .with_latency_us(20.0),
340 }
341 }
342
343 pub fn compute(series: &TimeSeries, method: TrendMethod, window: usize) -> TrendResult {
350 if series.is_empty() {
351 return TrendResult {
352 trend: Vec::new(),
353 detrended: Vec::new(),
354 method,
355 };
356 }
357
358 let trend = match method {
359 TrendMethod::SimpleMovingAverage => Self::simple_ma(&series.values, window),
360 TrendMethod::ExponentialMovingAverage => Self::exponential_ma(&series.values, window),
361 TrendMethod::CenteredMovingAverage => Self::centered_ma(&series.values, window),
362 TrendMethod::Lowess => Self::lowess(&series.values, window),
363 };
364
365 let detrended: Vec<f64> = series
366 .values
367 .iter()
368 .zip(trend.iter())
369 .map(|(v, t)| v - t)
370 .collect();
371
372 TrendResult {
373 trend,
374 detrended,
375 method,
376 }
377 }
378
379 fn simple_ma(values: &[f64], window: usize) -> Vec<f64> {
381 let n = values.len();
382 let w = window.min(n).max(1);
383 let mut result = vec![0.0; n];
384
385 let mut cumsum = vec![0.0; n + 1];
387 for (i, &v) in values.iter().enumerate() {
388 cumsum[i + 1] = cumsum[i] + v;
389 }
390
391 for i in 0..n {
392 let start = i.saturating_sub(w - 1);
393 let count = i - start + 1;
394 result[i] = (cumsum[i + 1] - cumsum[start]) / count as f64;
395 }
396
397 result
398 }
399
400 fn exponential_ma(values: &[f64], span: usize) -> Vec<f64> {
402 let n = values.len();
403 if n == 0 {
404 return Vec::new();
405 }
406
407 let alpha = 2.0 / (span as f64 + 1.0);
408 let mut result = vec![0.0; n];
409 result[0] = values[0];
410
411 for i in 1..n {
412 result[i] = alpha * values[i] + (1.0 - alpha) * result[i - 1];
413 }
414
415 result
416 }
417
418 #[allow(clippy::needless_range_loop)]
420 fn centered_ma(values: &[f64], window: usize) -> Vec<f64> {
421 let n = values.len();
422 let half_w = window / 2;
423 let mut result = vec![0.0; n];
424
425 for i in 0..n {
426 let start = i.saturating_sub(half_w);
427 let end = (i + half_w + 1).min(n);
428 result[i] = values[start..end].iter().sum::<f64>() / (end - start) as f64;
429 }
430
431 result
432 }
433
434 #[allow(clippy::needless_range_loop)]
436 fn lowess(values: &[f64], bandwidth: usize) -> Vec<f64> {
437 let n = values.len();
438 let mut result = vec![0.0; n];
439
440 for i in 0..n {
441 let start = i.saturating_sub(bandwidth);
442 let end = (i + bandwidth + 1).min(n);
443
444 let mut sum_w = 0.0;
446 let mut sum_wx = 0.0;
447 let mut sum_wy = 0.0;
448 let mut sum_wxx = 0.0;
449 let mut sum_wxy = 0.0;
450
451 for j in start..end {
452 let x = j as f64;
453 let y = values[j];
454 let dist = (j as f64 - i as f64).abs() / (bandwidth as f64 + 1.0);
455 let w = if dist < 1.0 {
456 (1.0 - dist.powi(3)).powi(3)
457 } else {
458 0.0
459 };
460
461 sum_w += w;
462 sum_wx += w * x;
463 sum_wy += w * y;
464 sum_wxx += w * x * x;
465 sum_wxy += w * x * y;
466 }
467
468 let det = sum_w * sum_wxx - sum_wx * sum_wx;
470 if det.abs() > 1e-10 {
471 let b0 = (sum_wxx * sum_wy - sum_wx * sum_wxy) / det;
472 let b1 = (sum_w * sum_wxy - sum_wx * sum_wy) / det;
473 result[i] = b0 + b1 * i as f64;
474 } else {
475 result[i] = if sum_w > 0.0 {
476 sum_wy / sum_w
477 } else {
478 values[i]
479 };
480 }
481 }
482
483 result
484 }
485
486 pub fn holt_smoothing(values: &[f64], alpha: f64, beta: f64) -> (Vec<f64>, Vec<f64>) {
488 let n = values.len();
489 if n < 2 {
490 return (values.to_vec(), vec![0.0; n]);
491 }
492
493 let mut level = vec![0.0; n];
494 let mut trend = vec![0.0; n];
495
496 level[0] = values[0];
498 trend[0] = values[1] - values[0];
499
500 for i in 1..n {
501 level[i] = alpha * values[i] + (1.0 - alpha) * (level[i - 1] + trend[i - 1]);
502 trend[i] = beta * (level[i] - level[i - 1]) + (1.0 - beta) * trend[i - 1];
503 }
504
505 (level, trend)
506 }
507}
508
509impl GpuKernel for TrendExtraction {
510 fn metadata(&self) -> &KernelMetadata {
511 &self.metadata
512 }
513}
514
515#[async_trait]
516impl BatchKernel<TrendExtractionInput, TrendExtractionOutput> for TrendExtraction {
517 async fn execute(&self, input: TrendExtractionInput) -> Result<TrendExtractionOutput> {
518 let start = Instant::now();
519 let result = Self::compute(&input.series, input.method, input.window);
520 Ok(TrendExtractionOutput {
521 result,
522 compute_time_us: start.elapsed().as_micros() as u64,
523 })
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// 120 samples: linear trend `100 + 0.5 t` plus a sine of amplitude 10 and
    /// period 12.
    fn create_seasonal_series() -> TimeSeries {
        let period = 12;
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal =
                    10.0 * ((2.0 * std::f64::consts::PI * t as f64 / period as f64).sin());
                trend + seasonal
            })
            .collect();
        TimeSeries::new(values)
    }

    /// 100 samples of a steep linear trend with a small sine wiggle.
    fn create_trend_series() -> TimeSeries {
        TimeSeries::new(
            (0..100)
                .map(|t| 10.0 + 2.0 * t as f64 + (t as f64 * 0.3).sin())
                .collect(),
        )
    }

    /// Asserts that `trend + seasonal + residual` reproduces every input value.
    fn assert_additive_reconstruction(series: &TimeSeries, result: &DecompositionResult) {
        for (i, &v) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] + result.seasonal[i] + result.residual[i];
            assert!(
                (rebuilt - v).abs() < 1e-9,
                "Additive reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                v
            );
        }
    }

    #[test]
    fn test_decomposition_metadata() {
        let kernel = SeasonalDecomposition::new();
        assert_eq!(kernel.metadata().id, "temporal/seasonal-decomposition");
        assert_eq!(kernel.metadata().domain, Domain::TemporalAnalysis);
    }

    #[test]
    fn test_seasonal_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, 12, false);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());
        assert_eq!(result.period, 12);

        // The seasonal component must repeat with the requested period.
        for i in 0..result.seasonal.len() - 12 {
            let diff = (result.seasonal[i] - result.seasonal[i + 12]).abs();
            assert!(diff < 0.01, "Seasonal not periodic at {}: diff={}", i, diff);
        }

        assert_additive_reconstruction(&series, &result);
    }

    #[test]
    fn test_robust_decomposition() {
        let series = create_seasonal_series();
        let result = SeasonalDecomposition::compute(&series, 12, true);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.seasonal.len(), series.len());
        assert_eq!(result.residual.len(), series.len());

        for i in 0..result.seasonal.len() - 12 {
            let diff = (result.seasonal[i] - result.seasonal[i + 12]).abs();
            assert!(diff < 1e-12, "Robust seasonal not periodic at {}: diff={}", i, diff);
        }

        assert_additive_reconstruction(&series, &result);
    }

    #[test]
    fn test_multiplicative_decomposition() {
        let values: Vec<f64> = (0..120)
            .map(|t| {
                let trend = 100.0 + 0.5 * t as f64;
                let seasonal = 1.0 + 0.1 * ((2.0 * std::f64::consts::PI * t as f64 / 12.0).sin());
                trend * seasonal
            })
            .collect();
        let series = TimeSeries::new(values);

        let result = SeasonalDecomposition::compute_multiplicative(&series, 12);

        assert_eq!(result.trend.len(), series.len());
        assert!(result.seasonal.iter().all(|&s| s > 0.0));

        // Multiplicative components must reconstruct the input.
        for (i, &v) in series.values.iter().enumerate() {
            let rebuilt = result.trend[i] * result.seasonal[i] * result.residual[i];
            assert!(
                ((rebuilt - v) / v).abs() < 1e-9,
                "Multiplicative reconstruction failed at {}: {} vs {}",
                i,
                rebuilt,
                v
            );
        }
    }

    #[test]
    fn test_short_series_is_passed_through() {
        let series = TimeSeries::new(vec![1.0, 2.0, 3.0]);

        let additive = SeasonalDecomposition::compute(&series, 12, false);
        assert_eq!(additive.trend, series.values);
        assert!(additive.seasonal.iter().all(|&s| s == 0.0));
        assert!(additive.residual.iter().all(|&r| r == 0.0));

        let multiplicative = SeasonalDecomposition::compute_multiplicative(&series, 12);
        assert_eq!(multiplicative.trend, series.values);
        assert!(multiplicative.seasonal.iter().all(|&s| s == 1.0));
        assert!(multiplicative.residual.iter().all(|&r| r == 1.0));

        // A period below 2 is degenerate regardless of series length.
        let long = create_seasonal_series();
        let degenerate = SeasonalDecomposition::compute(&long, 1, false);
        assert_eq!(degenerate.trend, long.values);
    }

    #[test]
    fn test_trend_extraction_metadata() {
        let kernel = TrendExtraction::new();
        assert_eq!(kernel.metadata().id, "temporal/trend-extraction");
    }

    #[test]
    fn test_simple_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::SimpleMovingAverage, 5);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::SimpleMovingAverage);

        // Smoothing must not increase total variation.
        let original_var: f64 = series.values.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        let trend_var: f64 = result.trend.windows(2).map(|w| (w[1] - w[0]).abs()).sum();
        assert!(trend_var <= original_var);
    }

    #[test]
    fn test_simple_ma_zero_window_is_identity() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::SimpleMovingAverage, 0);

        for (t, v) in result.trend.iter().zip(&series.values) {
            assert!((t - v).abs() < 1e-9);
        }
    }

    #[test]
    fn test_exponential_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::ExponentialMovingAverage, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::ExponentialMovingAverage);
    }

    #[test]
    fn test_ema_with_span_one_is_identity() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::ExponentialMovingAverage, 1);

        // span 1 => alpha 1 => each output is exactly its input.
        assert_eq!(result.trend, series.values);
    }

    #[test]
    fn test_centered_moving_average() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 7);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::CenteredMovingAverage);
    }

    #[test]
    fn test_lowess_trend() {
        let series = create_trend_series();
        let result = TrendExtraction::compute(&series, TrendMethod::Lowess, 10);

        assert_eq!(result.trend.len(), series.len());
        assert_eq!(result.method, TrendMethod::Lowess);
    }

    #[test]
    fn test_holt_smoothing() {
        let values: Vec<f64> = (0..50).map(|t| 10.0 + 2.0 * t as f64).collect();
        let (level, trend) = TrendExtraction::holt_smoothing(&values, 0.3, 0.1);

        assert_eq!(level.len(), values.len());
        assert_eq!(trend.len(), values.len());

        // A perfectly linear input should yield a trend close to its slope.
        assert!((trend.last().unwrap() - 2.0).abs() < 1.0);
    }

    #[test]
    fn test_holt_smoothing_short_input() {
        let (level, trend) = TrendExtraction::holt_smoothing(&[5.0], 0.3, 0.1);
        assert_eq!(level, vec![5.0]);
        assert_eq!(trend, vec![0.0]);

        let (level, trend) = TrendExtraction::holt_smoothing(&[], 0.3, 0.1);
        assert!(level.is_empty());
        assert!(trend.is_empty());
    }

    #[test]
    fn test_detrended_sums_to_zero_ish() {
        let series = create_seasonal_series();
        let result = TrendExtraction::compute(&series, TrendMethod::CenteredMovingAverage, 12);

        let detrended_mean: f64 =
            result.detrended.iter().sum::<f64>() / result.detrended.len() as f64;
        assert!(
            detrended_mean.abs() < 1.0,
            "Detrended mean: {}",
            detrended_mean
        );
    }

    #[test]
    fn test_empty_series() {
        let empty = TimeSeries::new(Vec::new());

        let decomp = SeasonalDecomposition::compute(&empty, 12, false);
        assert!(decomp.trend.is_empty());

        for method in [
            TrendMethod::SimpleMovingAverage,
            TrendMethod::ExponentialMovingAverage,
            TrendMethod::CenteredMovingAverage,
            TrendMethod::Lowess,
        ] {
            let trend = TrendExtraction::compute(&empty, method, 5);
            assert!(trend.trend.is_empty());
            assert!(trend.detrended.is_empty());
        }
    }
}