rustkernel_temporal/
forecasting.rs

1//! Forecasting kernels.
2//!
3//! This module provides time series forecasting:
4//! - ARIMA (AutoRegressive Integrated Moving Average)
5//! - Prophet-style decomposition forecasting
6
7use std::time::Instant;
8
9use async_trait::async_trait;
10
11use crate::messages::{
12    ARIMAForecastInput, ARIMAForecastOutput, ProphetDecompositionInput, ProphetDecompositionOutput,
13};
14use crate::types::{ARIMAParams, ARIMAResult, ProphetResult, TimeSeries};
15use rustkernel_core::{
16    domain::Domain,
17    error::Result,
18    kernel::KernelMetadata,
19    traits::{BatchKernel, GpuKernel},
20};
21
22// ============================================================================
23// ARIMA Forecast Kernel
24// ============================================================================
25
26/// ARIMA forecasting kernel.
27///
28/// Fits an ARIMA(p,d,q) model and generates forecasts.
29#[derive(Debug, Clone)]
30pub struct ARIMAForecast {
31    metadata: KernelMetadata,
32}
33
34impl Default for ARIMAForecast {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl ARIMAForecast {
41    /// Create a new ARIMA forecast kernel.
42    #[must_use]
43    pub fn new() -> Self {
44        Self {
45            metadata: KernelMetadata::batch("temporal/arima-forecast", Domain::TemporalAnalysis)
46                .with_description("ARIMA model fitting and forecasting")
47                .with_throughput(10_000)
48                .with_latency_us(100.0),
49        }
50    }
51
52    /// Fit ARIMA model and generate forecasts.
53    ///
54    /// # Arguments
55    /// * `series` - Input time series
56    /// * `params` - ARIMA(p,d,q) parameters
57    /// * `horizon` - Number of steps to forecast
58    pub fn compute(series: &TimeSeries, params: ARIMAParams, horizon: usize) -> ARIMAResult {
59        if series.is_empty() {
60            return ARIMAResult {
61                ar_coefficients: Vec::new(),
62                ma_coefficients: Vec::new(),
63                intercept: 0.0,
64                fitted: Vec::new(),
65                residuals: Vec::new(),
66                forecast: Vec::new(),
67                aic: f64::INFINITY,
68            };
69        }
70
71        // Difference the series d times
72        let mut diff_series = series.values.clone();
73        for _ in 0..params.d {
74            diff_series = Self::difference(&diff_series);
75        }
76
77        if diff_series.len() < params.p.max(params.q) + 1 {
78            return ARIMAResult {
79                ar_coefficients: vec![0.0; params.p],
80                ma_coefficients: vec![0.0; params.q],
81                intercept: series.mean(),
82                fitted: series.values.clone(),
83                residuals: vec![0.0; series.len()],
84                forecast: vec![series.mean(); horizon],
85                aic: f64::INFINITY,
86            };
87        }
88
89        // Fit AR coefficients using Yule-Walker equations (simplified)
90        let ar_coefficients = if params.p > 0 {
91            Self::fit_ar(&diff_series, params.p)
92        } else {
93            Vec::new()
94        };
95
96        // Calculate residuals from AR fit
97        let ar_fitted = Self::apply_ar(&diff_series, &ar_coefficients);
98        let residuals: Vec<f64> = diff_series
99            .iter()
100            .zip(ar_fitted.iter())
101            .map(|(y, yhat)| y - yhat)
102            .collect();
103
104        // Fit MA coefficients (simplified - innovation algorithm)
105        let ma_coefficients = if params.q > 0 {
106            Self::fit_ma(&residuals, params.q)
107        } else {
108            Vec::new()
109        };
110
111        // Calculate intercept
112        let intercept = diff_series.iter().sum::<f64>() / diff_series.len() as f64;
113
114        // Generate fitted values
115        let fitted =
116            Self::generate_fitted(&diff_series, &ar_coefficients, &ma_coefficients, intercept);
117
118        // Integrate back to original scale
119        let fitted_integrated = Self::integrate(&fitted, &series.values, params.d);
120
121        // Calculate residuals on original scale
122        let final_residuals: Vec<f64> = series
123            .values
124            .iter()
125            .zip(fitted_integrated.iter())
126            .map(|(y, yhat)| y - yhat)
127            .collect();
128
129        // Generate forecasts
130        let forecast = Self::forecast_ahead(
131            &diff_series,
132            &ar_coefficients,
133            &ma_coefficients,
134            intercept,
135            horizon,
136        );
137
138        // Integrate forecasts
139        let forecast_integrated = Self::integrate_forecast(&forecast, &series.values, params.d);
140
141        // Calculate AIC
142        let n = series.len() as f64;
143        let k = (params.p + params.q + 1) as f64;
144        let rss: f64 = final_residuals.iter().map(|r| r.powi(2)).sum();
145        let aic = n * (rss / n).ln() + 2.0 * k;
146
147        ARIMAResult {
148            ar_coefficients,
149            ma_coefficients,
150            intercept,
151            fitted: fitted_integrated,
152            residuals: final_residuals,
153            forecast: forecast_integrated,
154            aic,
155        }
156    }
157
158    /// Difference a series.
159    fn difference(series: &[f64]) -> Vec<f64> {
160        if series.len() < 2 {
161            return Vec::new();
162        }
163        series.windows(2).map(|w| w[1] - w[0]).collect()
164    }
165
166    /// Fit AR coefficients using Yule-Walker equations.
167    fn fit_ar(series: &[f64], p: usize) -> Vec<f64> {
168        let n = series.len();
169        if n <= p {
170            return vec![0.0; p];
171        }
172
173        // Calculate autocorrelations
174        let mean: f64 = series.iter().sum::<f64>() / n as f64;
175        let var: f64 = series.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
176
177        if var < 1e-10 {
178            return vec![0.0; p];
179        }
180
181        let mut acf = vec![1.0; p + 1];
182        for k in 1..=p {
183            let cov: f64 = (0..n - k)
184                .map(|i| (series[i] - mean) * (series[i + k] - mean))
185                .sum::<f64>()
186                / n as f64;
187            acf[k] = cov / var;
188        }
189
190        // Solve Yule-Walker using Levinson-Durbin
191        Self::levinson_durbin(&acf, p)
192    }
193
194    /// Levinson-Durbin algorithm for solving Yule-Walker equations.
195    fn levinson_durbin(acf: &[f64], p: usize) -> Vec<f64> {
196        let mut phi = vec![vec![0.0; p + 1]; p + 1];
197        let mut sigma = vec![0.0; p + 1];
198
199        sigma[0] = acf[0];
200
201        for k in 1..=p {
202            let mut num = acf[k];
203            for j in 1..k {
204                num -= phi[k - 1][j] * acf[k - j];
205            }
206            phi[k][k] = num / sigma[k - 1];
207
208            for j in 1..k {
209                phi[k][j] = phi[k - 1][j] - phi[k][k] * phi[k - 1][k - j];
210            }
211
212            sigma[k] = sigma[k - 1] * (1.0 - phi[k][k].powi(2));
213        }
214
215        (1..=p).map(|j| phi[p][j]).collect()
216    }
217
218    /// Apply AR model to get fitted values.
219    fn apply_ar(series: &[f64], coefficients: &[f64]) -> Vec<f64> {
220        let n = series.len();
221        let p = coefficients.len();
222        let mut fitted = vec![0.0; n];
223
224        for i in p..n {
225            for (j, &coef) in coefficients.iter().enumerate() {
226                fitted[i] += coef * series[i - j - 1];
227            }
228        }
229
230        fitted
231    }
232
233    /// Fit MA coefficients using CSS (Conditional Sum of Squares) optimization.
234    ///
235    /// This implements an iterative optimization approach that minimizes
236    /// the sum of squared residuals when fitting MA(q) parameters.
237    fn fit_ma(residuals: &[f64], q: usize) -> Vec<f64> {
238        let n = residuals.len();
239        if n <= q + 10 {
240            return vec![0.0; q];
241        }
242
243        let var: f64 = residuals.iter().map(|x| x.powi(2)).sum::<f64>() / n as f64;
244        if var < 1e-10 {
245            return vec![0.0; q];
246        }
247
248        // Initialize with autocorrelation-based estimates
249        let mean: f64 = residuals.iter().sum::<f64>() / n as f64;
250        let centered: Vec<f64> = residuals.iter().map(|&x| x - mean).collect();
251
252        let mut ma_coefs: Vec<f64> = Vec::with_capacity(q);
253        let c0: f64 = centered.iter().map(|x| x.powi(2)).sum::<f64>() / n as f64;
254
255        for k in 1..=q {
256            if c0 > 1e-10 {
257                let ck: f64 = (0..n - k)
258                    .map(|i| centered[i] * centered[i + k])
259                    .sum::<f64>()
260                    / n as f64;
261                // Initial estimate (bounded for stability)
262                ma_coefs.push((ck / c0).clamp(-0.9, 0.9));
263            } else {
264                ma_coefs.push(0.0);
265            }
266        }
267
268        // CSS optimization: iteratively refine MA coefficients
269        let mut best_sse = Self::calculate_ma_sse(residuals, &ma_coefs);
270
271        for _iter in 0..20 {
272            let mut improved = false;
273
274            for j in 0..q {
275                // Try perturbations for coefficient j
276                let original = ma_coefs[j];
277                let step_size = 0.05 * (1.0 - original.abs()).max(0.1);
278
279                for delta in [-step_size, step_size, -step_size * 0.5, step_size * 0.5] {
280                    let new_val = (original + delta).clamp(-0.95, 0.95);
281                    ma_coefs[j] = new_val;
282
283                    // Check invertibility: sum of |theta| < 1
284                    let sum_abs: f64 = ma_coefs.iter().map(|c| c.abs()).sum();
285                    if sum_abs >= 0.99 {
286                        ma_coefs[j] = original;
287                        continue;
288                    }
289
290                    let new_sse = Self::calculate_ma_sse(residuals, &ma_coefs);
291
292                    if new_sse < best_sse && new_sse.is_finite() {
293                        best_sse = new_sse;
294                        improved = true;
295                        break;
296                    } else {
297                        ma_coefs[j] = original;
298                    }
299                }
300            }
301
302            if !improved {
303                break;
304            }
305        }
306
307        ma_coefs
308    }
309
310    /// Calculate sum of squared errors for MA model.
311    fn calculate_ma_sse(residuals: &[f64], ma_coefs: &[f64]) -> f64 {
312        let n = residuals.len();
313        let q = ma_coefs.len();
314
315        // Reconstruct innovations (errors in the MA sense)
316        let mut innovations = vec![0.0; n];
317
318        for t in 0..n {
319            // e_t = r_t - sum(theta_j * e_{t-j})
320            let mut innovation = residuals[t];
321            for j in 0..q {
322                if t > j {
323                    innovation -= ma_coefs[j] * innovations[t - j - 1];
324                }
325            }
326            innovations[t] = innovation;
327        }
328
329        // Skip initial transient period
330        innovations[q..].iter().map(|e| e.powi(2)).sum()
331    }
332
333    /// Generate fitted values from ARMA model.
334    fn generate_fitted(
335        series: &[f64],
336        ar_coefs: &[f64],
337        ma_coefs: &[f64],
338        intercept: f64,
339    ) -> Vec<f64> {
340        let n = series.len();
341        let p = ar_coefs.len();
342        let q = ma_coefs.len();
343        let start = p.max(q);
344
345        let mut fitted = vec![series.iter().sum::<f64>() / n as f64; n];
346        let mut errors = vec![0.0; n];
347
348        for i in start..n {
349            let mut yhat = intercept;
350
351            // AR terms
352            for (j, &coef) in ar_coefs.iter().enumerate() {
353                yhat += coef * series[i - j - 1];
354            }
355
356            // MA terms
357            for (j, &coef) in ma_coefs.iter().enumerate() {
358                if i > j {
359                    yhat += coef * errors[i - j - 1];
360                }
361            }
362
363            fitted[i] = yhat;
364            errors[i] = series[i] - yhat;
365        }
366
367        fitted
368    }
369
370    /// Integrate differenced series back to original scale.
371    fn integrate(diff_fitted: &[f64], original: &[f64], d: usize) -> Vec<f64> {
372        if d == 0 || original.is_empty() {
373            return diff_fitted.to_vec();
374        }
375
376        let mut result = diff_fitted.to_vec();
377
378        for i in 0..d {
379            let start_val = if i < original.len() { original[i] } else { 0.0 };
380
381            let mut integrated = vec![start_val];
382            for &diff in &result {
383                integrated.push(integrated.last().unwrap() + diff);
384            }
385            result = integrated;
386        }
387
388        // Trim to original length
389        result.truncate(original.len());
390        result
391    }
392
393    /// Generate forecasts.
394    fn forecast_ahead(
395        series: &[f64],
396        ar_coefs: &[f64],
397        _ma_coefs: &[f64],
398        intercept: f64,
399        horizon: usize,
400    ) -> Vec<f64> {
401        let _p = ar_coefs.len();
402        let mut forecasts = Vec::with_capacity(horizon);
403        let mut extended = series.to_vec();
404
405        for _ in 0..horizon {
406            let mut yhat = intercept;
407
408            // AR terms using most recent values
409            for (j, &coef) in ar_coefs.iter().enumerate() {
410                let idx = extended.len().saturating_sub(j + 1);
411                yhat += coef * extended[idx];
412            }
413
414            // MA terms fade out as we forecast further ahead
415            forecasts.push(yhat);
416            extended.push(yhat);
417        }
418
419        forecasts
420    }
421
422    /// Integrate forecasts.
423    fn integrate_forecast(forecasts: &[f64], original: &[f64], d: usize) -> Vec<f64> {
424        if d == 0 || original.is_empty() {
425            return forecasts.to_vec();
426        }
427
428        let mut result = forecasts.to_vec();
429        let last_val = *original.last().unwrap_or(&0.0);
430
431        for _ in 0..d {
432            let mut integrated = vec![last_val];
433            for &diff in &result {
434                integrated.push(integrated.last().unwrap() + diff);
435            }
436            result = integrated[1..].to_vec(); // Skip the initial value
437        }
438
439        result
440    }
441}
442
443impl GpuKernel for ARIMAForecast {
444    fn metadata(&self) -> &KernelMetadata {
445        &self.metadata
446    }
447}
448
449#[async_trait]
450impl BatchKernel<ARIMAForecastInput, ARIMAForecastOutput> for ARIMAForecast {
451    async fn execute(&self, input: ARIMAForecastInput) -> Result<ARIMAForecastOutput> {
452        let start = Instant::now();
453        let result = Self::compute(&input.series, input.params, input.horizon);
454        Ok(ARIMAForecastOutput {
455            result,
456            compute_time_us: start.elapsed().as_micros() as u64,
457        })
458    }
459}
460
461// ============================================================================
462// Prophet-style Decomposition Forecast Kernel
463// ============================================================================
464
465/// Prophet-style decomposition and forecasting kernel.
466///
467/// Decomposes time series into trend + seasonality and forecasts.
468#[derive(Debug, Clone)]
469pub struct ProphetDecomposition {
470    metadata: KernelMetadata,
471}
472
473impl Default for ProphetDecomposition {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
479impl ProphetDecomposition {
480    /// Create a new Prophet decomposition kernel.
481    #[must_use]
482    pub fn new() -> Self {
483        Self {
484            metadata: KernelMetadata::batch(
485                "temporal/prophet-decomposition",
486                Domain::TemporalAnalysis,
487            )
488            .with_description("Prophet-style trend/seasonal decomposition")
489            .with_throughput(5_000)
490            .with_latency_us(200.0),
491        }
492    }
493
494    /// Decompose and forecast time series.
495    ///
496    /// # Arguments
497    /// * `series` - Input time series
498    /// * `period` - Seasonal period (e.g., 12 for monthly, 7 for daily)
499    /// * `horizon` - Forecast horizon
500    pub fn compute(series: &TimeSeries, period: Option<usize>, horizon: usize) -> ProphetResult {
501        if series.is_empty() {
502            return ProphetResult {
503                trend: Vec::new(),
504                seasonal: None,
505                holidays: None,
506                residuals: Vec::new(),
507                forecast: Vec::new(),
508            };
509        }
510
511        let n = series.len();
512
513        // Extract trend using centered moving average
514        let window = period.unwrap_or(1);
515        let trend = Self::extract_trend(&series.values, window);
516
517        // Extract seasonality if period specified
518        let seasonal = if let Some(p) = period {
519            if p > 1 && n > p {
520                Some(Self::extract_seasonal(&series.values, &trend, p))
521            } else {
522                None
523            }
524        } else {
525            None
526        };
527
528        // Calculate residuals
529        let residuals: Vec<f64> = series
530            .values
531            .iter()
532            .enumerate()
533            .map(|(i, &y)| {
534                let t = trend[i];
535                let s = seasonal.as_ref().map(|s| s[i % s.len()]).unwrap_or(0.0);
536                y - t - s
537            })
538            .collect();
539
540        // Generate forecasts
541        let forecast = Self::forecast(&trend, seasonal.as_ref(), &residuals, horizon);
542
543        ProphetResult {
544            trend,
545            seasonal,
546            holidays: None,
547            residuals,
548            forecast,
549        }
550    }
551
552    /// Extract trend using centered moving average.
553    #[allow(clippy::needless_range_loop)]
554    fn extract_trend(values: &[f64], window: usize) -> Vec<f64> {
555        let n = values.len();
556        let w = window.max(1);
557        let half_w = w / 2;
558
559        let mut trend = vec![0.0; n];
560
561        for i in 0..n {
562            let start = i.saturating_sub(half_w);
563            let end = (i + half_w + 1).min(n);
564            let count = end - start;
565
566            trend[i] = values[start..end].iter().sum::<f64>() / count as f64;
567        }
568
569        trend
570    }
571
572    /// Extract seasonal component.
573    fn extract_seasonal(values: &[f64], trend: &[f64], period: usize) -> Vec<f64> {
574        let _n = values.len();
575
576        // Detrend
577        let detrended: Vec<f64> = values
578            .iter()
579            .zip(trend.iter())
580            .map(|(v, t)| v - t)
581            .collect();
582
583        // Average by season
584        let mut seasonal = vec![0.0; period];
585        let mut counts = vec![0usize; period];
586
587        for (i, &d) in detrended.iter().enumerate() {
588            let s = i % period;
589            seasonal[s] += d;
590            counts[s] += 1;
591        }
592
593        for (s, &c) in counts.iter().enumerate() {
594            if c > 0 {
595                seasonal[s] /= c as f64;
596            }
597        }
598
599        // Center seasonal (subtract mean)
600        let mean: f64 = seasonal.iter().sum::<f64>() / period as f64;
601        for s in &mut seasonal {
602            *s -= mean;
603        }
604
605        seasonal
606    }
607
608    /// Generate forecasts.
609    fn forecast(
610        trend: &[f64],
611        seasonal: Option<&Vec<f64>>,
612        _residuals: &[f64],
613        horizon: usize,
614    ) -> Vec<f64> {
615        let n = trend.len();
616        if n < 2 {
617            return vec![trend.last().copied().unwrap_or(0.0); horizon];
618        }
619
620        // Extrapolate trend linearly
621        let slope = trend[n - 1] - trend[n - 2];
622        let last_trend = trend[n - 1];
623
624        let mut forecasts = Vec::with_capacity(horizon);
625
626        for h in 1..=horizon {
627            let trend_forecast = last_trend + slope * h as f64;
628            let seasonal_forecast = seasonal.map(|s| s[(n + h - 1) % s.len()]).unwrap_or(0.0);
629            forecasts.push(trend_forecast + seasonal_forecast);
630        }
631
632        forecasts
633    }
634}
635
636impl GpuKernel for ProphetDecomposition {
637    fn metadata(&self) -> &KernelMetadata {
638        &self.metadata
639    }
640}
641
642#[async_trait]
643impl BatchKernel<ProphetDecompositionInput, ProphetDecompositionOutput> for ProphetDecomposition {
644    async fn execute(
645        &self,
646        input: ProphetDecompositionInput,
647    ) -> Result<ProphetDecompositionOutput> {
648        let start = Instant::now();
649        let result = Self::compute(&input.series, input.period, input.horizon);
650        Ok(ProphetDecompositionOutput {
651            result,
652            compute_time_us: start.elapsed().as_micros() as u64,
653        })
654    }
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Noise-free straight line `y = 10 + 2t` sampled at `t = 0..50`.
    fn create_trend_series() -> TimeSeries {
        let values: Vec<f64> = (0..50).map(|t| 10.0 + 2.0 * t as f64).collect();
        TimeSeries::new(values)
    }

    /// Gentle upward trend plus a sine wave with a period of 12 samples,
    /// 60 samples in total.
    fn create_seasonal_series() -> TimeSeries {
        let period = 12;
        let mut values = Vec::with_capacity(60);
        for t in 0..60 {
            let trend = 100.0 + 0.5 * t as f64;
            let seasonal = 10.0 * ((2.0 * std::f64::consts::PI * t as f64 / period as f64).sin());
            values.push(trend + seasonal);
        }
        TimeSeries::new(values)
    }

    /// The ARIMA kernel registers under the expected id and domain.
    #[test]
    fn test_arima_metadata() {
        let kernel = ARIMAForecast::new();
        let meta = kernel.metadata();
        assert_eq!(meta.id, "temporal/arima-forecast");
        assert_eq!(meta.domain, Domain::TemporalAnalysis);
    }

    /// ARIMA(1,1,0) on a straight line keeps forecasting near the line.
    #[test]
    fn test_arima_forecast_trend() {
        let series = create_trend_series();
        let params = ARIMAParams::new(1, 1, 0); // AR(1) with differencing
        let result = ARIMAForecast::compute(&series, params, 5);

        assert_eq!(result.forecast.len(), 5);
        assert!(params.p == 0 || !result.ar_coefficients.is_empty());

        // Forecasts should stay reasonably close to a continuation of the trend
        let floor = *series.values.last().unwrap() * 0.8;
        assert!(result.forecast.iter().all(|&f| f > floor));
    }

    /// The Prophet kernel registers under the expected id.
    #[test]
    fn test_prophet_metadata() {
        let kernel = ProphetDecomposition::new();
        assert_eq!(kernel.metadata().id, "temporal/prophet-decomposition");
    }

    /// With a period given, every component is produced with the right size.
    #[test]
    fn test_prophet_decomposition() {
        let series = create_seasonal_series();
        let result = ProphetDecomposition::compute(&series, Some(12), 12);

        assert_eq!(result.trend.len(), series.len());
        assert!(result.seasonal.is_some());
        assert_eq!(result.seasonal.as_ref().unwrap().len(), 12);
        assert_eq!(result.forecast.len(), 12);
    }

    /// Without a period there is no seasonal component, but still a forecast.
    #[test]
    fn test_prophet_no_seasonality() {
        let series = create_trend_series();
        let result = ProphetDecomposition::compute(&series, None, 5);

        assert!(result.seasonal.is_none());
        assert_eq!(result.forecast.len(), 5);
    }

    /// Both kernels return an empty forecast for an empty series.
    #[test]
    fn test_empty_series() {
        let empty = TimeSeries::new(Vec::new());

        let arima = ARIMAForecast::compute(&empty, ARIMAParams::new(1, 0, 0), 5);
        assert!(arima.forecast.is_empty());

        let prophet = ProphetDecomposition::compute(&empty, Some(12), 5);
        assert!(prophet.forecast.is_empty());
    }
}