Skip to main content

so_tsa/
forecast.rs

//! Forecasting evaluation and prediction intervals
//!
//! This module provides tools for evaluating forecast accuracy
//! and constructing prediction intervals.
//!
//! # Evaluation Metrics
//!
//! - **MAE**: Mean Absolute Error
//! - **MSE**: Mean Squared Error
//! - **RMSE**: Root Mean Squared Error
//! - **MAPE**: Mean Absolute Percentage Error
//! - **SMAPE**: Symmetric Mean Absolute Percentage Error
//! - **MASE**: Mean Absolute Scaled Error
//! - **Theil's U**: Theil's inequality coefficient
//! - **R²**: Coefficient of determination of the forecast
//!
//! # Prediction Intervals
//!
//! Methods for constructing prediction intervals:
//!
//! 1. **Normal approximation**: assumes normally distributed errors
//! 2. **Empirical quantiles**: uses the empirical distribution of residuals
//! 3. **Bootstrapping**: resamples residuals to estimate uncertainty
//! 4. **Conformal prediction**: distribution-free intervals (the
//!    `IntervalMethod::Conformal` variant is declared, but this module does
//!    not yet provide a constructor for it)
24
25use ndarray::Array1;
26use rand::Rng;
27use so_core::error::{Error, Result};
28use std::collections::HashMap;
29
/// Accuracy scores for one forecast compared against the observed values.
///
/// Usually produced by [`ForecastMetrics::new`]; caller-defined scores can be
/// attached afterwards with [`ForecastMetrics::with_custom`].
#[derive(Debug, Clone)]
pub struct ForecastMetrics {
    /// Mean Absolute Error: the average of `|actual - predicted|`.
    pub mae: f64,
    /// Mean Squared Error: the average of `(actual - predicted)^2`.
    pub mse: f64,
    /// Root Mean Squared Error, the square root of `mse`.
    pub rmse: f64,
    /// Mean Absolute Percentage Error, in percent.
    pub mape: f64,
    /// Symmetric Mean Absolute Percentage Error, in percent.
    pub smape: f64,
    /// Mean Absolute Scaled Error (MAE relative to a naive forecast).
    pub mase: f64,
    /// Theil's U statistic.
    pub theils_u: f64,
    /// Coefficient of determination (R²) of the forecast.
    pub r_squared: f64,
    /// Number of observations the metrics summarise.
    pub n: usize,
    /// Extra named metrics supplied by the caller.
    pub custom: HashMap<String, f64>,
}
54
55impl ForecastMetrics {
56    /// Create new forecast metrics from actual and predicted values
57    pub fn new(actual: &Array1<f64>, predicted: &Array1<f64>) -> Result<Self> {
58        let n = actual.len();
59        if n != predicted.len() {
60            return Err(Error::DataError(format!(
61                "Actual and predicted lengths differ: {} vs {}",
62                n,
63                predicted.len()
64            )));
65        }
66
67        if n == 0 {
68            return Err(Error::DataError(
69                "Empty data for forecast evaluation".to_string(),
70            ));
71        }
72
73        // Calculate errors
74        let mut errors = Array1::zeros(n);
75        let mut abs_errors = Array1::zeros(n);
76        let mut squared_errors = Array1::zeros(n);
77        let mut abs_percentage_errors = Array1::zeros(n);
78        let mut symmetric_errors = Array1::zeros(n);
79
80        for i in 0..n {
81            let error = actual[i] - predicted[i];
82            errors[i] = error;
83            abs_errors[i] = error.abs();
84            squared_errors[i] = error.powi(2);
85
86            if actual[i] != 0.0 {
87                abs_percentage_errors[i] = (error.abs() / actual[i].abs()) * 100.0;
88                symmetric_errors[i] =
89                    (error.abs() / (actual[i].abs() + predicted[i].abs())) * 200.0;
90            }
91        }
92
93        // Basic metrics
94        let mae = abs_errors.mean().unwrap_or(0.0);
95        let mse = squared_errors.mean().unwrap_or(0.0);
96        let rmse = mse.sqrt();
97
98        // Percentage errors (handle zeros)
99        let mape = if abs_percentage_errors.iter().any(|&x| x.is_finite()) {
100            abs_percentage_errors
101                .iter()
102                .filter(|&&x| x.is_finite())
103                .sum::<f64>()
104                / abs_percentage_errors
105                    .iter()
106                    .filter(|&&x| x.is_finite())
107                    .count() as f64
108        } else {
109            0.0
110        };
111
112        let smape = if symmetric_errors.iter().any(|&x| x.is_finite()) {
113            symmetric_errors
114                .iter()
115                .filter(|&&x| x.is_finite())
116                .sum::<f64>()
117                / symmetric_errors.iter().filter(|&&x| x.is_finite()).count() as f64
118        } else {
119            0.0
120        };
121
122        // MASE - need naive forecast errors
123        let mase = if n > 1 {
124            let mut naive_errors = Array1::zeros(n - 1);
125            for i in 1..n {
126                naive_errors[i - 1] = (actual[i] - actual[i - 1]).abs();
127            }
128            let mean_naive_error = naive_errors.mean().unwrap_or(1.0);
129            if mean_naive_error > 0.0 {
130                mae / mean_naive_error
131            } else {
132                0.0
133            }
134        } else {
135            0.0
136        };
137
138        // Theil's U
139        let theils_u = if actual.var(1.0) > 0.0 && predicted.var(1.0) > 0.0 {
140            rmse / (actual.var(1.0).sqrt() + predicted.var(1.0).sqrt())
141        } else {
142            0.0
143        };
144
145        // R-squared
146        let ss_res = squared_errors.sum();
147        let ss_tot = actual.var(1.0) * n as f64;
148        let r_squared = if ss_tot > 0.0 {
149            1.0 - ss_res / ss_tot
150        } else {
151            0.0
152        };
153
154        Ok(Self {
155            mae,
156            mse,
157            rmse,
158            mape,
159            smape,
160            mase,
161            theils_u,
162            r_squared,
163            n,
164            custom: HashMap::new(),
165        })
166    }
167
168    /// Add custom metric
169    pub fn with_custom(mut self, name: &str, value: f64) -> Self {
170        self.custom.insert(name.to_string(), value);
171        self
172    }
173
174    /// Create summary string
175    pub fn summary(&self) -> String {
176        let mut summary = String::new();
177        summary.push_str("Forecast Evaluation Metrics\n");
178        summary.push_str("===========================\n");
179        summary.push_str(&format!("Observations: {}\n", self.n));
180        summary.push_str(&format!("MAE:  {:.4}\n", self.mae));
181        summary.push_str(&format!("MSE:  {:.4}\n", self.mse));
182        summary.push_str(&format!("RMSE: {:.4}\n", self.rmse));
183        summary.push_str(&format!("MAPE: {:.2}%\n", self.mape));
184        summary.push_str(&format!("sMAPE: {:.2}%\n", self.smape));
185        summary.push_str(&format!("MASE: {:.4}\n", self.mase));
186        summary.push_str(&format!("Theil's U: {:.4}\n", self.theils_u));
187        summary.push_str(&format!("R²:   {:.4}\n", self.r_squared));
188
189        if !self.custom.is_empty() {
190            summary.push_str("\nCustom Metrics:\n");
191            for (name, value) in &self.custom {
192                summary.push_str(&format!("  {}: {:.4}\n", name, value));
193            }
194        }
195
196        // Interpretation
197        summary.push_str("\nInterpretation:\n");
198        if self.mape < 10.0 {
199            summary.push_str("  MAPE < 10%: Highly accurate forecast\n");
200        } else if self.mape < 20.0 {
201            summary.push_str("  MAPE < 20%: Good forecast\n");
202        } else if self.mape < 50.0 {
203            summary.push_str("  MAPE < 50%: Reasonable forecast\n");
204        } else {
205            summary.push_str("  MAPE ≥ 50%: Inaccurate forecast\n");
206        }
207
208        if self.mase < 1.0 {
209            summary.push_str("  MASE < 1: Better than naive forecast\n");
210        } else {
211            summary.push_str("  MASE ≥ 1: Worse than naive forecast\n");
212        }
213
214        summary
215    }
216
217    /// Compare two forecast methods
218    pub fn compare(&self, other: &Self, name_a: &str, name_b: &str) -> String {
219        let mut comparison = String::new();
220        comparison.push_str(&format!("Forecast Comparison: {} vs {}\n", name_a, name_b));
221        comparison.push_str("===================================\n");
222
223        comparison.push_str(&format!(
224            "MAE:  {:.4} vs {:.4} ({:+.2}%)\n",
225            self.mae,
226            other.mae,
227            (other.mae - self.mae) / self.mae.max(1e-10) * 100.0
228        ));
229        comparison.push_str(&format!(
230            "RMSE: {:.4} vs {:.4} ({:+.2}%)\n",
231            self.rmse,
232            other.rmse,
233            (other.rmse - self.rmse) / self.rmse.max(1e-10) * 100.0
234        ));
235        comparison.push_str(&format!(
236            "MAPE: {:.2}% vs {:.2}% ({:+.2}pp)\n",
237            self.mape,
238            other.mape,
239            other.mape - self.mape
240        ));
241        comparison.push_str(&format!(
242            "MASE: {:.4} vs {:.4} ({:+.2}%)\n",
243            self.mase,
244            other.mase,
245            (other.mase - self.mase) / self.mase.max(1e-10) * 100.0
246        ));
247
248        comparison
249    }
250}
251
/// Strategy used to construct a prediction interval.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IntervalMethod {
    /// Gaussian approximation: `point ± z * σ`.
    Normal,
    /// Quantiles taken directly from the residual distribution.
    Empirical,
    /// Quantiles of forecasts perturbed by resampled (bootstrapped) residuals.
    Bootstrap,
    /// Conformal prediction.
    ///
    /// NOTE(review): no constructor in this module produces this variant — confirm
    /// before relying on it.
    Conformal,
}
264
265/// Prediction interval
266#[derive(Debug, Clone)]
267pub struct PredictionInterval {
268    /// Point forecast
269    pub point: f64,
270    /// Lower bound
271    pub lower: f64,
272    /// Upper bound
273    pub upper: f64,
274    /// Confidence level (e.g., 0.95 for 95%)
275    pub level: f64,
276    /// Method used
277    pub method: IntervalMethod,
278}
279
280impl PredictionInterval {
281    /// Check if actual value is within interval
282    pub fn contains(&self, actual: f64) -> bool {
283        actual >= self.lower && actual <= self.upper
284    }
285
286    /// Interval width
287    pub fn width(&self) -> f64 {
288        self.upper - self.lower
289    }
290
291    /// Interval as string
292    pub fn to_string(&self) -> String {
293        format!(
294            "{:.4} [{:.4}, {:.4}] ({}%)",
295            self.point,
296            self.lower,
297            self.upper,
298            (self.level * 100.0) as i32
299        )
300    }
301}
302
303/// Prediction intervals for multiple forecasts
304#[derive(Debug, Clone)]
305pub struct PredictionIntervals {
306    /// Point forecasts
307    pub points: Array1<f64>,
308    /// Lower bounds
309    pub lower: Array1<f64>,
310    /// Upper bounds
311    pub upper: Array1<f64>,
312    /// Confidence level
313    pub level: f64,
314    /// Method used
315    pub method: IntervalMethod,
316}
317
318impl PredictionIntervals {
319    /// Create normal approximation intervals
320    pub fn normal(points: &Array1<f64>, std_dev: f64, level: f64) -> Self {
321        let z = normal_quantile(1.0 - (1.0 - level) / 2.0);
322        let margin = z * std_dev;
323
324        let lower = points - margin;
325        let upper = points + margin;
326
327        Self {
328            points: points.clone(),
329            lower,
330            upper,
331            level,
332            method: IntervalMethod::Normal,
333        }
334    }
335
336    /// Create intervals from empirical residuals
337    pub fn empirical(points: &Array1<f64>, residuals: &Array1<f64>, level: f64) -> Self {
338        let n = residuals.len();
339        let mut sorted_residuals: Vec<f64> = residuals.iter().copied().collect();
340        sorted_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap());
341
342        let lower_idx = ((1.0 - level) / 2.0 * n as f64).floor() as usize;
343        let upper_idx = ((1.0 + level) / 2.0 * n as f64).floor() as usize;
344
345        let lower_quantile = sorted_residuals[lower_idx.min(n - 1)];
346        let upper_quantile = sorted_residuals[upper_idx.min(n - 1)];
347
348        let lower = points + lower_quantile;
349        let upper = points + upper_quantile;
350
351        Self {
352            points: points.clone(),
353            lower,
354            upper,
355            level,
356            method: IntervalMethod::Empirical,
357        }
358    }
359
360    /// Create bootstrapped intervals
361    pub fn bootstrap(
362        points: &Array1<f64>,
363        residuals: &Array1<f64>,
364        level: f64,
365        n_bootstrap: usize,
366    ) -> Self {
367        let n = points.len();
368        let r = residuals.len();
369
370        let mut bootstrap_forecasts = Vec::new();
371
372        for _ in 0..n_bootstrap {
373            let mut boot_points = Array1::zeros(n);
374
375            for i in 0..n {
376                // Sample residual with replacement
377                let idx = rand::rng().random_range(0..r);
378                let boot_error = residuals[idx];
379                boot_points[i] = points[i] + boot_error;
380            }
381
382            bootstrap_forecasts.push(boot_points);
383        }
384
385        // Calculate quantiles
386        let mut lower = Array1::zeros(n);
387        let mut upper = Array1::zeros(n);
388
389        for i in 0..n {
390            let mut values: Vec<f64> = bootstrap_forecasts.iter().map(|arr| arr[i]).collect();
391            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
392
393            let lower_idx = ((1.0 - level) / 2.0 * n_bootstrap as f64).floor() as usize;
394            let upper_idx = ((1.0 + level) / 2.0 * n_bootstrap as f64).floor() as usize;
395
396            lower[i] = values[lower_idx.min(n_bootstrap - 1)];
397            upper[i] = values[upper_idx.min(n_bootstrap - 1)];
398        }
399
400        Self {
401            points: points.clone(),
402            lower,
403            upper,
404            level,
405            method: IntervalMethod::Bootstrap,
406        }
407    }
408
409    /// Check coverage (proportion of actual values within intervals)
410    pub fn coverage(&self, actual: &Array1<f64>) -> f64 {
411        let n = actual.len();
412        let mut count = 0;
413
414        for i in 0..n.min(self.points.len()) {
415            if actual[i] >= self.lower[i] && actual[i] <= self.upper[i] {
416                count += 1;
417            }
418        }
419
420        count as f64 / n.min(self.points.len()) as f64
421    }
422
423    /// Average interval width
424    pub fn average_width(&self) -> f64 {
425        let n = self.points.len();
426        let mut total = 0.0;
427
428        for i in 0..n {
429            total += self.upper[i] - self.lower[i];
430        }
431
432        total / n as f64
433    }
434}
435
/// Configuration for rolling-origin cross-validation of time series forecasts.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so configurations can be logged,
/// copied and compared like the module's other public types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeSeriesCV {
    /// Number of folds.
    pub n_folds: usize,
    /// Minimum training size (observations before the first forecast origin).
    pub min_train_size: usize,
    /// Step size between successive fold origins.
    pub step_size: usize,
    /// Whether to use an expanding (rather than fixed-size rolling) training window.
    pub expanding: bool,
}
447
448impl Default for TimeSeriesCV {
449    fn default() -> Self {
450        Self {
451            n_folds: 5,
452            min_train_size: 20,
453            step_size: 1,
454            expanding: false,
455        }
456    }
457}
458
459impl TimeSeriesCV {
460    /// Create new time series cross-validator
461    pub fn new(n_folds: usize) -> Self {
462        Self {
463            n_folds,
464            ..Default::default()
465        }
466    }
467
468    /// Perform cross-validation
469    pub fn cross_validate<F>(
470        &self,
471        data: &Array1<f64>,
472        forecast_fn: F,
473    ) -> Result<Vec<ForecastMetrics>>
474    where
475        F: Fn(&Array1<f64>, usize) -> Result<Array1<f64>>,
476    {
477        let n = data.len();
478        let mut results = Vec::new();
479
480        // Determine fold boundaries
481        let test_size = (n - self.min_train_size) / self.n_folds.max(1);
482        if test_size == 0 {
483            return Err(Error::DataError(
484                "Not enough data for cross-validation".to_string(),
485            ));
486        }
487
488        for fold in 0..self.n_folds {
489            let train_end = self.min_train_size + fold * self.step_size;
490            if train_end >= n {
491                break;
492            }
493
494            let test_end = (train_end + test_size).min(n);
495
496            // Split data
497            let train_data = data.slice(ndarray::s![..train_end]).to_owned();
498            let test_data = data.slice(ndarray::s![train_end..test_end]).to_owned();
499
500            // Generate forecasts
501            let horizon = test_data.len();
502            let forecasts = forecast_fn(&train_data, horizon)?;
503
504            // Evaluate
505            if forecasts.len() == test_data.len() {
506                let metrics = ForecastMetrics::new(&test_data, &forecasts)?;
507                results.push(metrics);
508            }
509        }
510
511        Ok(results)
512    }
513
514    /// Aggregate cross-validation results
515    pub fn aggregate_metrics(&self, metrics: &[ForecastMetrics]) -> ForecastMetrics {
516        let n = metrics.len();
517        let mut aggregated = ForecastMetrics {
518            mae: 0.0,
519            mse: 0.0,
520            rmse: 0.0,
521            mape: 0.0,
522            smape: 0.0,
523            mase: 0.0,
524            theils_u: 0.0,
525            r_squared: 0.0,
526            n: metrics.iter().map(|m| m.n).sum(),
527            custom: HashMap::new(),
528        };
529
530        for metric in metrics {
531            aggregated.mae += metric.mae;
532            aggregated.mse += metric.mse;
533            aggregated.rmse += metric.rmse;
534            aggregated.mape += metric.mape;
535            aggregated.smape += metric.smape;
536            aggregated.mase += metric.mase;
537            aggregated.theils_u += metric.theils_u;
538            aggregated.r_squared += metric.r_squared;
539        }
540
541        aggregated.mae /= n as f64;
542        aggregated.mse /= n as f64;
543        aggregated.rmse /= n as f64;
544        aggregated.mape /= n as f64;
545        aggregated.smape /= n as f64;
546        aggregated.mase /= n as f64;
547        aggregated.theils_u /= n as f64;
548        aggregated.r_squared /= n as f64;
549
550        aggregated
551    }
552}
553
/// Inverse of the standard normal CDF (quantile function), `Φ⁻¹(p)`.
///
/// Uses the rational approximation from Abramowitz & Stegun, formula 26.2.23. Its
/// stated absolute error is below `4.5e-4` for `0 < p < 1`, which is adequate for
/// interval multipliers such as `z ≈ 1.96` at `p = 0.975`.
///
/// Special inputs:
/// - `p == 0` returns `-∞` and `p == 1` returns `+∞`, the limits of the quantile
///   function. The raw formula would produce `∞/∞ = NaN` at these points.
/// - `NaN`, or `p` outside `[0, 1]`, returns `NaN`.
fn normal_quantile(p: f64) -> f64 {
    // `contains` is false for NaN, so NaN is rejected here as well.
    if !(0.0..=1.0).contains(&p) {
        return f64::NAN;
    }
    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }

    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    // The formula yields the upper-tail quantile for the smaller tail probability;
    // reflect it back for the lower half.
    let tail = if p <= 0.5 { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let num = C[0] + C[1] * t + C[2] * t.powi(2);
    let den = 1.0 + D[0] * t + D[1] * t.powi(2) + D[2] * t.powi(3);
    let upper_tail_quantile = t - num / den;

    if p <= 0.5 {
        -upper_tail_quantile
    } else {
        upper_tail_quantile
    }
}