1use ndarray::Array1;
26use rand::Rng;
27use so_core::error::{Error, Result};
28use std::collections::HashMap;
29
/// Accuracy metrics describing how well a forecast matched the observed values.
///
/// Instances are normally built with `ForecastMetrics::new`, which fills every
/// field except `custom` (left empty until `with_custom` is called).
#[derive(Debug, Clone)]
pub struct ForecastMetrics {
    /// Mean absolute error.
    pub mae: f64,
    /// Mean squared error.
    pub mse: f64,
    /// Root mean squared error (square root of `mse`).
    pub rmse: f64,
    /// Mean absolute percentage error, expressed in percent.
    pub mape: f64,
    /// Symmetric mean absolute percentage error, expressed in percent.
    pub smape: f64,
    /// Mean absolute scaled error: `mae` divided by the mean absolute
    /// one-step change of the actual series.
    pub mase: f64,
    /// Theil's U statistic.
    pub theils_u: f64,
    /// Coefficient of determination (R²).
    pub r_squared: f64,
    /// Number of observations the metrics were computed from.
    pub n: usize,
    /// Additional caller-supplied metrics, keyed by name.
    pub custom: HashMap<String, f64>,
}
54
55impl ForecastMetrics {
56 pub fn new(actual: &Array1<f64>, predicted: &Array1<f64>) -> Result<Self> {
58 let n = actual.len();
59 if n != predicted.len() {
60 return Err(Error::DataError(format!(
61 "Actual and predicted lengths differ: {} vs {}",
62 n,
63 predicted.len()
64 )));
65 }
66
67 if n == 0 {
68 return Err(Error::DataError(
69 "Empty data for forecast evaluation".to_string(),
70 ));
71 }
72
73 let mut errors = Array1::zeros(n);
75 let mut abs_errors = Array1::zeros(n);
76 let mut squared_errors = Array1::zeros(n);
77 let mut abs_percentage_errors = Array1::zeros(n);
78 let mut symmetric_errors = Array1::zeros(n);
79
80 for i in 0..n {
81 let error = actual[i] - predicted[i];
82 errors[i] = error;
83 abs_errors[i] = error.abs();
84 squared_errors[i] = error.powi(2);
85
86 if actual[i] != 0.0 {
87 abs_percentage_errors[i] = (error.abs() / actual[i].abs()) * 100.0;
88 symmetric_errors[i] =
89 (error.abs() / (actual[i].abs() + predicted[i].abs())) * 200.0;
90 }
91 }
92
93 let mae = abs_errors.mean().unwrap_or(0.0);
95 let mse = squared_errors.mean().unwrap_or(0.0);
96 let rmse = mse.sqrt();
97
98 let mape = if abs_percentage_errors.iter().any(|&x| x.is_finite()) {
100 abs_percentage_errors
101 .iter()
102 .filter(|&&x| x.is_finite())
103 .sum::<f64>()
104 / abs_percentage_errors
105 .iter()
106 .filter(|&&x| x.is_finite())
107 .count() as f64
108 } else {
109 0.0
110 };
111
112 let smape = if symmetric_errors.iter().any(|&x| x.is_finite()) {
113 symmetric_errors
114 .iter()
115 .filter(|&&x| x.is_finite())
116 .sum::<f64>()
117 / symmetric_errors.iter().filter(|&&x| x.is_finite()).count() as f64
118 } else {
119 0.0
120 };
121
122 let mase = if n > 1 {
124 let mut naive_errors = Array1::zeros(n - 1);
125 for i in 1..n {
126 naive_errors[i - 1] = (actual[i] - actual[i - 1]).abs();
127 }
128 let mean_naive_error = naive_errors.mean().unwrap_or(1.0);
129 if mean_naive_error > 0.0 {
130 mae / mean_naive_error
131 } else {
132 0.0
133 }
134 } else {
135 0.0
136 };
137
138 let theils_u = if actual.var(1.0) > 0.0 && predicted.var(1.0) > 0.0 {
140 rmse / (actual.var(1.0).sqrt() + predicted.var(1.0).sqrt())
141 } else {
142 0.0
143 };
144
145 let ss_res = squared_errors.sum();
147 let ss_tot = actual.var(1.0) * n as f64;
148 let r_squared = if ss_tot > 0.0 {
149 1.0 - ss_res / ss_tot
150 } else {
151 0.0
152 };
153
154 Ok(Self {
155 mae,
156 mse,
157 rmse,
158 mape,
159 smape,
160 mase,
161 theils_u,
162 r_squared,
163 n,
164 custom: HashMap::new(),
165 })
166 }
167
168 pub fn with_custom(mut self, name: &str, value: f64) -> Self {
170 self.custom.insert(name.to_string(), value);
171 self
172 }
173
174 pub fn summary(&self) -> String {
176 let mut summary = String::new();
177 summary.push_str("Forecast Evaluation Metrics\n");
178 summary.push_str("===========================\n");
179 summary.push_str(&format!("Observations: {}\n", self.n));
180 summary.push_str(&format!("MAE: {:.4}\n", self.mae));
181 summary.push_str(&format!("MSE: {:.4}\n", self.mse));
182 summary.push_str(&format!("RMSE: {:.4}\n", self.rmse));
183 summary.push_str(&format!("MAPE: {:.2}%\n", self.mape));
184 summary.push_str(&format!("sMAPE: {:.2}%\n", self.smape));
185 summary.push_str(&format!("MASE: {:.4}\n", self.mase));
186 summary.push_str(&format!("Theil's U: {:.4}\n", self.theils_u));
187 summary.push_str(&format!("R²: {:.4}\n", self.r_squared));
188
189 if !self.custom.is_empty() {
190 summary.push_str("\nCustom Metrics:\n");
191 for (name, value) in &self.custom {
192 summary.push_str(&format!(" {}: {:.4}\n", name, value));
193 }
194 }
195
196 summary.push_str("\nInterpretation:\n");
198 if self.mape < 10.0 {
199 summary.push_str(" MAPE < 10%: Highly accurate forecast\n");
200 } else if self.mape < 20.0 {
201 summary.push_str(" MAPE < 20%: Good forecast\n");
202 } else if self.mape < 50.0 {
203 summary.push_str(" MAPE < 50%: Reasonable forecast\n");
204 } else {
205 summary.push_str(" MAPE ≥ 50%: Inaccurate forecast\n");
206 }
207
208 if self.mase < 1.0 {
209 summary.push_str(" MASE < 1: Better than naive forecast\n");
210 } else {
211 summary.push_str(" MASE ≥ 1: Worse than naive forecast\n");
212 }
213
214 summary
215 }
216
217 pub fn compare(&self, other: &Self, name_a: &str, name_b: &str) -> String {
219 let mut comparison = String::new();
220 comparison.push_str(&format!("Forecast Comparison: {} vs {}\n", name_a, name_b));
221 comparison.push_str("===================================\n");
222
223 comparison.push_str(&format!(
224 "MAE: {:.4} vs {:.4} ({:+.2}%)\n",
225 self.mae,
226 other.mae,
227 (other.mae - self.mae) / self.mae.max(1e-10) * 100.0
228 ));
229 comparison.push_str(&format!(
230 "RMSE: {:.4} vs {:.4} ({:+.2}%)\n",
231 self.rmse,
232 other.rmse,
233 (other.rmse - self.rmse) / self.rmse.max(1e-10) * 100.0
234 ));
235 comparison.push_str(&format!(
236 "MAPE: {:.2}% vs {:.2}% ({:+.2}pp)\n",
237 self.mape,
238 other.mape,
239 other.mape - self.mape
240 ));
241 comparison.push_str(&format!(
242 "MASE: {:.4} vs {:.4} ({:+.2}%)\n",
243 self.mase,
244 other.mase,
245 (other.mase - self.mase) / self.mase.max(1e-10) * 100.0
246 ));
247
248 comparison
249 }
250}
251
/// Identifies how a prediction interval was constructed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IntervalMethod {
    /// Symmetric interval from a normal quantile times a standard deviation
    /// (set by `PredictionIntervals::normal`).
    Normal,
    /// Interval from quantiles of observed residuals
    /// (set by `PredictionIntervals::empirical`).
    Empirical,
    /// Interval from quantiles of resampled residuals
    /// (set by `PredictionIntervals::bootstrap`).
    Bootstrap,
    /// Conformal prediction; no constructor in the code visible here sets this
    /// variant.
    Conformal,
}
264
/// A prediction interval around a single point forecast.
#[derive(Debug, Clone)]
pub struct PredictionInterval {
    /// The point forecast.
    pub point: f64,
    /// Lower bound of the interval (treated as inclusive by `contains`).
    pub lower: f64,
    /// Upper bound of the interval (treated as inclusive by `contains`).
    pub upper: f64,
    /// Nominal level as a fraction; it is multiplied by 100 when displayed as
    /// a percentage.
    pub level: f64,
    /// How the interval was constructed.
    pub method: IntervalMethod,
}
279
280impl PredictionInterval {
281 pub fn contains(&self, actual: f64) -> bool {
283 actual >= self.lower && actual <= self.upper
284 }
285
286 pub fn width(&self) -> f64 {
288 self.upper - self.lower
289 }
290
291 pub fn to_string(&self) -> String {
293 format!(
294 "{:.4} [{:.4}, {:.4}] ({}%)",
295 self.point,
296 self.lower,
297 self.upper,
298 (self.level * 100.0) as i32
299 )
300 }
301}
302
/// Prediction intervals for a sequence of point forecasts.
///
/// The constructors in this file produce `lower` and `upper` with the same
/// length as `points`; element `i` of each array describes forecast `i`.
#[derive(Debug, Clone)]
pub struct PredictionIntervals {
    /// The point forecasts.
    pub points: Array1<f64>,
    /// Lower bound for each forecast.
    pub lower: Array1<f64>,
    /// Upper bound for each forecast.
    pub upper: Array1<f64>,
    /// Nominal level as a fraction; the constructors place `(1 - level) / 2`
    /// of the probability mass in each tail.
    pub level: f64,
    /// How the intervals were constructed.
    pub method: IntervalMethod,
}
317
318impl PredictionIntervals {
319 pub fn normal(points: &Array1<f64>, std_dev: f64, level: f64) -> Self {
321 let z = normal_quantile(1.0 - (1.0 - level) / 2.0);
322 let margin = z * std_dev;
323
324 let lower = points - margin;
325 let upper = points + margin;
326
327 Self {
328 points: points.clone(),
329 lower,
330 upper,
331 level,
332 method: IntervalMethod::Normal,
333 }
334 }
335
336 pub fn empirical(points: &Array1<f64>, residuals: &Array1<f64>, level: f64) -> Self {
338 let n = residuals.len();
339 let mut sorted_residuals: Vec<f64> = residuals.iter().copied().collect();
340 sorted_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap());
341
342 let lower_idx = ((1.0 - level) / 2.0 * n as f64).floor() as usize;
343 let upper_idx = ((1.0 + level) / 2.0 * n as f64).floor() as usize;
344
345 let lower_quantile = sorted_residuals[lower_idx.min(n - 1)];
346 let upper_quantile = sorted_residuals[upper_idx.min(n - 1)];
347
348 let lower = points + lower_quantile;
349 let upper = points + upper_quantile;
350
351 Self {
352 points: points.clone(),
353 lower,
354 upper,
355 level,
356 method: IntervalMethod::Empirical,
357 }
358 }
359
360 pub fn bootstrap(
362 points: &Array1<f64>,
363 residuals: &Array1<f64>,
364 level: f64,
365 n_bootstrap: usize,
366 ) -> Self {
367 let n = points.len();
368 let r = residuals.len();
369
370 let mut bootstrap_forecasts = Vec::new();
371
372 for _ in 0..n_bootstrap {
373 let mut boot_points = Array1::zeros(n);
374
375 for i in 0..n {
376 let idx = rand::rng().random_range(0..r);
378 let boot_error = residuals[idx];
379 boot_points[i] = points[i] + boot_error;
380 }
381
382 bootstrap_forecasts.push(boot_points);
383 }
384
385 let mut lower = Array1::zeros(n);
387 let mut upper = Array1::zeros(n);
388
389 for i in 0..n {
390 let mut values: Vec<f64> = bootstrap_forecasts.iter().map(|arr| arr[i]).collect();
391 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
392
393 let lower_idx = ((1.0 - level) / 2.0 * n_bootstrap as f64).floor() as usize;
394 let upper_idx = ((1.0 + level) / 2.0 * n_bootstrap as f64).floor() as usize;
395
396 lower[i] = values[lower_idx.min(n_bootstrap - 1)];
397 upper[i] = values[upper_idx.min(n_bootstrap - 1)];
398 }
399
400 Self {
401 points: points.clone(),
402 lower,
403 upper,
404 level,
405 method: IntervalMethod::Bootstrap,
406 }
407 }
408
409 pub fn coverage(&self, actual: &Array1<f64>) -> f64 {
411 let n = actual.len();
412 let mut count = 0;
413
414 for i in 0..n.min(self.points.len()) {
415 if actual[i] >= self.lower[i] && actual[i] <= self.upper[i] {
416 count += 1;
417 }
418 }
419
420 count as f64 / n.min(self.points.len()) as f64
421 }
422
423 pub fn average_width(&self) -> f64 {
425 let n = self.points.len();
426 let mut total = 0.0;
427
428 for i in 0..n {
429 total += self.upper[i] - self.lower[i];
430 }
431
432 total / n as f64
433 }
434}
435
/// Configuration for time-series cross-validation.
///
/// Derives `Debug` and `Clone` like the other public types in this file so a
/// configuration can be logged and reused.
#[derive(Debug, Clone)]
pub struct TimeSeriesCV {
    /// Maximum number of folds to evaluate.
    pub n_folds: usize,
    /// Number of leading observations in the first fold's training set.
    pub min_train_size: usize,
    /// Number of observations the end of the training set advances per fold.
    pub step_size: usize,
    /// Window-mode flag — presumably expanding vs. sliding training window.
    /// NOTE(review): no code visible in this file reads this field; confirm
    /// the intended behaviour.
    pub expanding: bool,
}
447
448impl Default for TimeSeriesCV {
449 fn default() -> Self {
450 Self {
451 n_folds: 5,
452 min_train_size: 20,
453 step_size: 1,
454 expanding: false,
455 }
456 }
457}
458
459impl TimeSeriesCV {
460 pub fn new(n_folds: usize) -> Self {
462 Self {
463 n_folds,
464 ..Default::default()
465 }
466 }
467
468 pub fn cross_validate<F>(
470 &self,
471 data: &Array1<f64>,
472 forecast_fn: F,
473 ) -> Result<Vec<ForecastMetrics>>
474 where
475 F: Fn(&Array1<f64>, usize) -> Result<Array1<f64>>,
476 {
477 let n = data.len();
478 let mut results = Vec::new();
479
480 let test_size = (n - self.min_train_size) / self.n_folds.max(1);
482 if test_size == 0 {
483 return Err(Error::DataError(
484 "Not enough data for cross-validation".to_string(),
485 ));
486 }
487
488 for fold in 0..self.n_folds {
489 let train_end = self.min_train_size + fold * self.step_size;
490 if train_end >= n {
491 break;
492 }
493
494 let test_end = (train_end + test_size).min(n);
495
496 let train_data = data.slice(ndarray::s![..train_end]).to_owned();
498 let test_data = data.slice(ndarray::s![train_end..test_end]).to_owned();
499
500 let horizon = test_data.len();
502 let forecasts = forecast_fn(&train_data, horizon)?;
503
504 if forecasts.len() == test_data.len() {
506 let metrics = ForecastMetrics::new(&test_data, &forecasts)?;
507 results.push(metrics);
508 }
509 }
510
511 Ok(results)
512 }
513
514 pub fn aggregate_metrics(&self, metrics: &[ForecastMetrics]) -> ForecastMetrics {
516 let n = metrics.len();
517 let mut aggregated = ForecastMetrics {
518 mae: 0.0,
519 mse: 0.0,
520 rmse: 0.0,
521 mape: 0.0,
522 smape: 0.0,
523 mase: 0.0,
524 theils_u: 0.0,
525 r_squared: 0.0,
526 n: metrics.iter().map(|m| m.n).sum(),
527 custom: HashMap::new(),
528 };
529
530 for metric in metrics {
531 aggregated.mae += metric.mae;
532 aggregated.mse += metric.mse;
533 aggregated.rmse += metric.rmse;
534 aggregated.mape += metric.mape;
535 aggregated.smape += metric.smape;
536 aggregated.mase += metric.mase;
537 aggregated.theils_u += metric.theils_u;
538 aggregated.r_squared += metric.r_squared;
539 }
540
541 aggregated.mae /= n as f64;
542 aggregated.mse /= n as f64;
543 aggregated.rmse /= n as f64;
544 aggregated.mape /= n as f64;
545 aggregated.smape /= n as f64;
546 aggregated.mase /= n as f64;
547 aggregated.theils_u /= n as f64;
548 aggregated.r_squared /= n as f64;
549
550 aggregated
551 }
552}
553
/// Approximates the quantile function (inverse CDF) of the standard normal
/// distribution using a rational approximation in
/// `t = sqrt(-2 ln(min(p, 1 - p)))`.
///
/// Returns negative infinity for `p == 0`, positive infinity for `p == 1`, and
/// NaN when `p` is NaN or outside `[0, 1]`.
fn normal_quantile(p: f64) -> f64 {
    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    if p.is_nan() || !(0.0..=1.0).contains(&p) {
        return f64::NAN;
    }
    // At the end points `t` is infinite and the rational term would evaluate to
    // inf / inf = NaN, so return the limits explicitly.
    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }

    // Work in the tail that holds less than half of the mass and restore the
    // sign at the end.
    let in_lower_half = p <= 0.5;
    let tail = if in_lower_half { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let num = C0 + C1 * t + C2 * t.powi(2);
    let den = 1.0 + D1 * t + D2 * t.powi(2) + D3 * t.powi(3);
    let magnitude = t - num / den;

    if in_lower_half {
        -magnitude
    } else {
        magnitude
    }
}