Skip to main content

scirs2_metrics/temporal/
mod.rs

1//! Temporal Metrics: Dynamic Time Warping and Probabilistic Forecast Skill Scores
2//!
3//! This module provides metrics for evaluating time series models and probabilistic
4//! forecasting systems:
5//!
6//! - **Dynamic Time Warping (DTW)**: Elastic similarity measure for time series
7//!   - Full O(n·m) DP computation
8//!   - Sakoe-Chiba band constraint for O(n·w) speedup
9//!   - Normalized DTW (divide by optimal path length)
10//! - **Brier Score**: MSE for probabilistic binary forecasts
11//! - **Brier Skill Score (BSS)**: Relative improvement over climatology
12//! - **CRPS**: Continuous Ranked Probability Score for distributional forecasts
13//!   - Closed-form for Gaussian predictive distributions
14//!   - Empirical approximation for general distributions
15//! - **Directional Accuracy**: Fraction of forecasts with correct sign of change
16//! - **Diebold-Mariano Test**: Statistical test for comparing two forecast series
17//!
18//! # Examples
19//!
20//! ```
21//! use scirs2_metrics::temporal::{
22//!     dtw, dtw_windowed, brier_score_temporal, crps_gaussian, directional_accuracy,
23//! };
24//!
25//! // DTW of identical series = 0
26//! let x = vec![1.0, 2.0, 3.0, 2.0, 1.0];
27//! let d = dtw(&x, &x).expect("should succeed");
28//! assert!(d.abs() < 1e-10);
29//!
30//! // CRPS for Gaussian predictive
31//! let mu = vec![0.0, 1.0, 2.0];
32//! let sigma = vec![1.0, 1.0, 1.0];
33//! let obs = vec![0.5, 1.5, 1.8];
34//! let crps = crps_gaussian(&mu, &sigma, &obs).expect("should succeed");
35//! assert!(crps >= 0.0);
36//! ```
37
38use crate::error::{MetricsError, Result};
39
40// ────────────────────────────────────────────────────────────────────────────
41// Dynamic Time Warping
42// ────────────────────────────────────────────────────────────────────────────
43
44/// Computes the Dynamic Time Warping (DTW) distance between two time series.
45///
46/// DTW finds the optimal alignment between two sequences by allowing elastic
47/// stretching/compression in the time dimension. The standard O(n·m) DP
48/// recurrence is:
49///
50/// ```text
51/// DTW[i][j] = dist(x[i], y[j]) + min(DTW[i-1][j], DTW[i][j-1], DTW[i-1][j-1])
52/// ```
53///
54/// # Arguments
55///
56/// * `x` - First time series
57/// * `y` - Second time series
58///
59/// # Returns
60///
61/// The DTW distance (non-negative).
62pub fn dtw(x: &[f64], y: &[f64]) -> Result<f64> {
63    dtw_full(x, y, None)
64}
65
66/// Computes DTW with a Sakoe-Chiba band constraint.
67///
68/// The Sakoe-Chiba band restricts warping to a window of ±`window` steps,
69/// reducing complexity from O(n·m) to O(n·w).
70///
71/// When `window` is large enough to cover the full matrix, this is equivalent
72/// to unconstrained DTW.
73///
74/// # Arguments
75///
76/// * `x` - First time series
77/// * `y` - Second time series
78/// * `window` - Half-width of Sakoe-Chiba band
79///
80/// # Returns
81///
82/// The constrained DTW distance.
83pub fn dtw_windowed(x: &[f64], y: &[f64], window: usize) -> Result<f64> {
84    dtw_full(x, y, Some(window))
85}
86
87/// Computes normalized DTW distance (divided by warping path length).
88///
89/// Normalization makes DTW comparable across time series of different lengths.
90///
91/// # Arguments
92///
93/// * `x` - First time series
94/// * `y` - Second time series
95/// * `window` - Optional Sakoe-Chiba window
96pub fn dtw_normalized(x: &[f64], y: &[f64], window: Option<usize>) -> Result<f64> {
97    if x.is_empty() || y.is_empty() {
98        return Err(MetricsError::InvalidInput(
99            "time series must not be empty".to_string(),
100        ));
101    }
102
103    let (dist, path_len) = dtw_with_path_length(x, y, window)?;
104
105    if path_len == 0 {
106        return Ok(0.0);
107    }
108
109    Ok(dist / path_len as f64)
110}
111
112/// Full DTW computation returning both distance and path length.
113fn dtw_with_path_length(x: &[f64], y: &[f64], window: Option<usize>) -> Result<(f64, usize)> {
114    let n = x.len();
115    let m = y.len();
116
117    if n == 0 || m == 0 {
118        return Err(MetricsError::InvalidInput(
119            "time series must not be empty".to_string(),
120        ));
121    }
122
123    // Effective window size: if None, use full matrix
124    let w = window.unwrap_or(n.max(m));
125    let w = w.max((n as isize - m as isize).unsigned_abs()); // must cover length difference
126
127    let inf = f64::INFINITY;
128    // Use flattened 2D array: dp[i][j] = DTW(x[0..=i], y[0..=j])
129    let mut dp = vec![inf; n * m];
130
131    let idx = |i: usize, j: usize| i * m + j;
132
133    for i in 0..n {
134        for j in 0..m {
135            // Sakoe-Chiba constraint
136            if (i as isize - j as isize).unsigned_abs() > w {
137                continue;
138            }
139
140            let cost = (x[i] - y[j]).abs();
141
142            let prev = if i == 0 && j == 0 {
143                0.0
144            } else if i == 0 {
145                dp[idx(0, j - 1)]
146            } else if j == 0 {
147                dp[idx(i - 1, 0)]
148            } else {
149                let d_ij = dp[idx(i - 1, j - 1)];
150                let d_i1j = dp[idx(i - 1, j)];
151                let d_ij1 = dp[idx(i, j - 1)];
152                d_ij.min(d_i1j).min(d_ij1)
153            };
154
155            if prev.is_infinite() && !(i == 0 && j == 0) {
156                dp[idx(i, j)] = inf;
157            } else {
158                dp[idx(i, j)] = cost + if i == 0 && j == 0 { 0.0 } else { prev };
159            }
160        }
161    }
162
163    let total = dp[idx(n - 1, m - 1)];
164    if total.is_infinite() {
165        return Err(MetricsError::CalculationError(
166            "DTW path not found within Sakoe-Chiba window; increase window size".to_string(),
167        ));
168    }
169
170    // Traceback to find path length
171    let mut path_len = 0usize;
172    let mut i = n - 1;
173    let mut j = m - 1;
174
175    loop {
176        path_len += 1;
177        if i == 0 && j == 0 {
178            break;
179        }
180        if i == 0 {
181            j -= 1;
182        } else if j == 0 {
183            i -= 1;
184        } else {
185            let d_ij = dp[idx(i - 1, j - 1)];
186            let d_i1j = dp[idx(i - 1, j)];
187            let d_ij1 = dp[idx(i, j - 1)];
188            let min_prev = d_ij.min(d_i1j).min(d_ij1);
189            if min_prev == d_ij {
190                i -= 1;
191                j -= 1;
192            } else if min_prev == d_i1j {
193                i -= 1;
194            } else {
195                j -= 1;
196            }
197        }
198    }
199
200    Ok((total, path_len))
201}
202
203/// Internal full DTW (handles window=None as unconstrained).
204fn dtw_full(x: &[f64], y: &[f64], window: Option<usize>) -> Result<f64> {
205    if x.is_empty() || y.is_empty() {
206        return Err(MetricsError::InvalidInput(
207            "time series must not be empty".to_string(),
208        ));
209    }
210    let (dist, _) = dtw_with_path_length(x, y, window)?;
211    Ok(dist)
212}
213
214// ────────────────────────────────────────────────────────────────────────────
215// Brier Score and Brier Skill Score
216// ────────────────────────────────────────────────────────────────────────────
217
218/// Computes the Brier Score for probabilistic binary forecasts.
219///
220/// BS = (1/n) * sum_i (p_i - o_i)²
221///
222/// where p_i is the forecast probability and o_i ∈ {0, 1} is the observation.
223///
224/// # Arguments
225///
226/// * `forecasts` - Predicted probabilities in [0, 1]
227/// * `observations` - Observed binary outcomes (0.0 or 1.0)
228///
229/// # Returns
230///
231/// Brier score in [0, 1]. 0 = perfect, 0.25 = no skill (climatology at 0.5).
232pub fn brier_score_temporal(forecasts: &[f64], observations: &[f64]) -> Result<f64> {
233    if forecasts.len() != observations.len() {
234        return Err(MetricsError::DimensionMismatch(format!(
235            "forecasts ({}) and observations ({}) must have the same length",
236            forecasts.len(),
237            observations.len()
238        )));
239    }
240    if forecasts.is_empty() {
241        return Err(MetricsError::InvalidInput(
242            "inputs must not be empty".to_string(),
243        ));
244    }
245
246    let n = forecasts.len();
247    let bs: f64 = (0..n)
248        .map(|i| (forecasts[i] - observations[i]).powi(2))
249        .sum::<f64>()
250        / n as f64;
251    Ok(bs)
252}
253
254/// Computes the Brier Skill Score (BSS) relative to a climatological reference.
255///
256/// BSS = 1 - BS(model) / BS(reference)
257///
258/// The reference is the climatological frequency (base rate) for all forecasts.
259/// BSS = 1: perfect forecast. BSS = 0: no skill over climatology. BSS < 0: worse.
260///
261/// # Arguments
262///
263/// * `forecasts` - Predicted probabilities
264/// * `observations` - Observed binary outcomes
265pub fn brier_skill_score_temporal(forecasts: &[f64], observations: &[f64]) -> Result<f64> {
266    let bs_model = brier_score_temporal(forecasts, observations)?;
267    let base_rate = observations.iter().sum::<f64>() / observations.len() as f64;
268    let ref_fc: Vec<f64> = vec![base_rate; observations.len()];
269    let bs_ref = brier_score_temporal(&ref_fc, observations)?;
270
271    if bs_ref <= f64::EPSILON {
272        // Degenerate case: all observations same
273        if bs_model <= f64::EPSILON {
274            return Ok(1.0);
275        }
276        return Ok(f64::NEG_INFINITY);
277    }
278
279    Ok(1.0 - bs_model / bs_ref)
280}
281
282// ────────────────────────────────────────────────────────────────────────────
283// CRPS (Continuous Ranked Probability Score)
284// ────────────────────────────────────────────────────────────────────────────
285
286/// Computes the mean CRPS for Gaussian predictive distributions.
287///
288/// For a Gaussian forecast N(μ, σ²), the CRPS has a closed-form solution:
289/// ```text
290/// CRPS(N(μ,σ), y) = σ * [ (y-μ)/σ * (2Φ((y-μ)/σ) - 1)
291///                         + 2φ((y-μ)/σ) - 1/√π ]
292/// ```
293/// where Φ is the standard normal CDF and φ is the PDF.
294///
295/// # Arguments
296///
297/// * `mu` - Predictive means, one per observation
298/// * `sigma` - Predictive standard deviations (must be > 0)
299/// * `observations` - Observed values
300///
301/// # Returns
302///
303/// Mean CRPS (lower is better; 0 is optimal).
304pub fn crps_gaussian(mu: &[f64], sigma: &[f64], observations: &[f64]) -> Result<f64> {
305    let n = mu.len();
306    if n != sigma.len() || n != observations.len() {
307        return Err(MetricsError::DimensionMismatch(format!(
308            "mu ({}), sigma ({}), observations ({}) must have the same length",
309            n,
310            sigma.len(),
311            observations.len()
312        )));
313    }
314    if n == 0 {
315        return Err(MetricsError::InvalidInput(
316            "inputs must not be empty".to_string(),
317        ));
318    }
319
320    let mut total = 0.0f64;
321    for i in 0..n {
322        if sigma[i] <= 0.0 {
323            return Err(MetricsError::InvalidInput(format!(
324                "sigma[{i}] must be positive, got {}",
325                sigma[i]
326            )));
327        }
328        let z = (observations[i] - mu[i]) / sigma[i];
329        // CRPS = sigma * [ z*(2*Phi(z)-1) + 2*phi(z) - 1/sqrt(pi) ]
330        let phi_z = standard_normal_pdf(z);
331        let big_phi_z = standard_normal_cdf(z);
332        let crps_i = sigma[i]
333            * (z * (2.0 * big_phi_z - 1.0) + 2.0 * phi_z - 1.0 / std::f64::consts::PI.sqrt());
334        total += crps_i;
335    }
336
337    Ok(total / n as f64)
338}
339
340/// Computes CRPS for an empirical predictive distribution (ensemble).
341///
342/// Given an ensemble of forecasts for each observation, the CRPS is:
343/// ```text
344/// CRPS = E[|X - y|] - (1/2) * E[|X - X'|]
345/// ```
346/// where X, X' are independent draws from the ensemble.
347///
348/// # Arguments
349///
350/// * `ensemble` - Shape (n_observations, n_members) ensemble forecasts
351/// * `observations` - Observed values, length n_observations
352pub fn crps_ensemble(ensemble: &[Vec<f64>], observations: &[f64]) -> Result<f64> {
353    let n = ensemble.len();
354    if n != observations.len() {
355        return Err(MetricsError::DimensionMismatch(format!(
356            "ensemble ({}) and observations ({}) must have the same length",
357            n,
358            observations.len()
359        )));
360    }
361    if n == 0 {
362        return Err(MetricsError::InvalidInput(
363            "inputs must not be empty".to_string(),
364        ));
365    }
366
367    let mut total = 0.0f64;
368    for i in 0..n {
369        let members = &ensemble[i];
370        let m = members.len();
371        if m == 0 {
372            return Err(MetricsError::InvalidInput(format!(
373                "ensemble[{i}] must not be empty"
374            )));
375        }
376
377        // E[|X - y|]
378        let e_xy: f64 = members
379            .iter()
380            .map(|&x| (x - observations[i]).abs())
381            .sum::<f64>()
382            / m as f64;
383
384        // E[|X - X'|] using sorted trick
385        let mut sorted = members.to_vec();
386        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
387        let mut prefix = 0.0f64;
388        let mut e_xx = 0.0f64;
389        for (k, &xk) in sorted.iter().enumerate() {
390            e_xx += xk * k as f64 - prefix;
391            prefix += xk;
392        }
393        let pairs = m as f64 * (m as f64 - 1.0) / 2.0;
394        let e_xx_mean = if pairs > 0.0 { e_xx / pairs } else { 0.0 };
395
396        total += e_xy - 0.5 * e_xx_mean;
397    }
398
399    Ok(total / n as f64)
400}
401
402// ────────────────────────────────────────────────────────────────────────────
403// Directional Accuracy
404// ────────────────────────────────────────────────────────────────────────────
405
406/// Computes the directional accuracy of forecasts.
407///
408/// Directional accuracy measures the fraction of forecasts that correctly
409/// predict the sign of change from the previous period:
410/// ```text
411/// DA = (1/(n-1)) * sum_{i=1}^{n-1} 1[sign(pred[i] - pred[i-1]) == sign(obs[i] - obs[i-1])]
412/// ```
413///
414/// # Arguments
415///
416/// * `forecasts` - Forecast values (continuous)
417/// * `observations` - Observed values (same length as forecasts)
418///
419/// # Returns
420///
421/// Directional accuracy in [0, 1]. 1.0 = all directions correct.
422pub fn directional_accuracy(forecasts: &[f64], observations: &[f64]) -> Result<f64> {
423    let n = forecasts.len();
424    if n != observations.len() {
425        return Err(MetricsError::DimensionMismatch(format!(
426            "forecasts ({}) and observations ({}) must have the same length",
427            n,
428            observations.len()
429        )));
430    }
431    if n < 2 {
432        return Err(MetricsError::InvalidInput(
433            "at least 2 data points required for directional accuracy".to_string(),
434        ));
435    }
436
437    let correct = (1..n)
438        .filter(|&i| {
439            let obs_dir = (observations[i] - observations[i - 1]).signum();
440            let fc_dir = (forecasts[i] - forecasts[i - 1]).signum();
441            obs_dir == fc_dir
442        })
443        .count();
444
445    Ok(correct as f64 / (n - 1) as f64)
446}
447
448// ────────────────────────────────────────────────────────────────────────────
449// Diebold-Mariano Test
450// ────────────────────────────────────────────────────────────────────────────
451
/// Result of [`diebold_mariano_test`].
#[derive(Debug, Clone, PartialEq)]
pub struct DieboldMarianoResult {
    /// DM test statistic `d̄ / sqrt(V̂ / n)`, where `V̂` is the Bartlett-weighted
    /// (Newey-West) long-run variance of the loss differentials.
    pub statistic: f64,
    /// Approximate two-sided p-value, referenced to a t-distribution with
    /// `n - 1` degrees of freedom.
    pub p_value: f64,
    /// Loss differential series `d_t = L(e1_t) - L(e2_t)`, where
    /// `e_k_t = actual_t - forecast_k_t`.
    pub loss_differentials: Vec<f64>,
    /// Sample mean `d̄` of the loss differentials. Negative values mean model 1
    /// had the lower average loss.
    pub mean_differential: f64,
}
464
/// Loss `L(e)` applied to each forecast error `e = actual - forecast` by
/// [`diebold_mariano_test`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum DmLossFunction {
    /// Quadratic loss: `L(e) = e²`.
    SquaredError,
    /// Linear loss: `L(e) = |e|`.
    AbsoluteError,
    /// Squared error plus a linear term with coefficient `a`: `L(e) = e² + a·e`.
    /// A positive `a` penalises positive errors (under-forecasts) more; a
    /// negative `a` penalises over-forecasts more.
    Asymmetric(f64),
}
476
477/// Performs the Diebold-Mariano test to compare predictive accuracy of two forecasters.
478///
479/// The DM statistic tests H₀: E\[d_t\] = 0 vs H₁: E\[d_t\] ≠ 0, where
480/// d_t = L(e1_t) - L(e2_t) is the loss differential.
481///
482/// Test statistic: DM = d̄ / sqrt(V̂(d̄))
483/// where V̂(d̄) is the long-run variance of d̄ estimated via HAC (Newey-West).
484///
485/// Under H₀, DM ~ t(n-1) approximately.
486///
487/// # Arguments
488///
489/// * `actual` - Actual observed values
490/// * `forecast1` - Forecasts from model 1
491/// * `forecast2` - Forecasts from model 2
492/// * `loss_fn` - Loss function to use
493/// * `h` - Forecast horizon (for Newey-West bandwidth selection: `h - 1`)
494///
495/// # Returns
496///
497/// `DieboldMarianoResult` with test statistic and p-value.
498pub fn diebold_mariano_test(
499    actual: &[f64],
500    forecast1: &[f64],
501    forecast2: &[f64],
502    loss_fn: DmLossFunction,
503    h: usize,
504) -> Result<DieboldMarianoResult> {
505    let n = actual.len();
506    if n != forecast1.len() || n != forecast2.len() {
507        return Err(MetricsError::DimensionMismatch(
508            "actual, forecast1, forecast2 must have the same length".to_string(),
509        ));
510    }
511    if n < 3 {
512        return Err(MetricsError::InvalidInput(
513            "at least 3 observations required for DM test".to_string(),
514        ));
515    }
516
517    let loss = |e: f64| -> f64 {
518        match loss_fn {
519            DmLossFunction::SquaredError => e * e,
520            DmLossFunction::AbsoluteError => e.abs(),
521            DmLossFunction::Asymmetric(a) => e * e + a * e,
522        }
523    };
524
525    // Compute loss differentials d_t = L(e1_t) - L(e2_t)
526    let d: Vec<f64> = (0..n)
527        .map(|i| {
528            let e1 = actual[i] - forecast1[i];
529            let e2 = actual[i] - forecast2[i];
530            loss(e1) - loss(e2)
531        })
532        .collect();
533
534    let d_mean = d.iter().sum::<f64>() / n as f64;
535
536    // HAC variance estimate (Newey-West with bandwidth M = h - 1)
537    let m = h.saturating_sub(1);
538
539    let gamma_0: f64 = d.iter().map(|&di| (di - d_mean).powi(2)).sum::<f64>() / n as f64;
540
541    let mut variance = gamma_0;
542    for lag in 1..=m {
543        if lag >= n {
544            break;
545        }
546        let gamma_l: f64 = (lag..n)
547            .map(|t| (d[t] - d_mean) * (d[t - lag] - d_mean))
548            .sum::<f64>()
549            / n as f64;
550        let weight = 1.0 - lag as f64 / (m + 1) as f64; // Bartlett kernel
551        variance += 2.0 * weight * gamma_l;
552    }
553
554    // Ensure positive variance
555    if variance <= 0.0 {
556        variance = gamma_0.max(f64::EPSILON);
557    }
558
559    let se = (variance / n as f64).sqrt();
560    if se <= f64::EPSILON {
561        return Err(MetricsError::CalculationError(
562            "variance of loss differentials is effectively zero".to_string(),
563        ));
564    }
565
566    let dm_stat = d_mean / se;
567
568    // Two-sided p-value from t-distribution with (n-1) degrees of freedom
569    let p_value = two_sided_t_pvalue(dm_stat, n - 1);
570
571    Ok(DieboldMarianoResult {
572        statistic: dm_stat,
573        p_value,
574        loss_differentials: d,
575        mean_differential: d_mean,
576    })
577}
578
579// ────────────────────────────────────────────────────────────────────────────
580// Forecast Skill Summary Struct
581// ────────────────────────────────────────────────────────────────────────────
582
/// Bundle of forecast-skill scores produced by [`ForecastSkillMetrics::compute`].
#[derive(Debug, Clone, PartialEq)]
pub struct ForecastSkillMetrics {
    /// Brier score of the binary probability forecasts (lower is better).
    pub brier_score: f64,
    /// Brier Skill Score relative to the climatological base rate
    /// (1 is perfect, 0 is no skill, negative is worse than climatology).
    pub bss: f64,
    /// Mean CRPS of the Gaussian predictive distributions (lower is better).
    pub crps: f64,
    /// Fraction of periods whose direction of change was forecast correctly.
    pub directional_accuracy: f64,
}
595
596impl ForecastSkillMetrics {
597    /// Compute all forecast skill metrics for Gaussian predictive distributions.
598    ///
599    /// # Arguments
600    ///
601    /// * `prob_forecasts` - Probabilistic binary forecasts (for Brier score)
602    /// * `binary_obs` - Binary observations for Brier score
603    /// * `mu` - Predictive means (for CRPS)
604    /// * `sigma` - Predictive standard deviations (for CRPS)
605    /// * `point_forecasts` - Point forecasts (for directional accuracy)
606    /// * `observations` - Continuous observations
607    pub fn compute(
608        prob_forecasts: &[f64],
609        binary_obs: &[f64],
610        mu: &[f64],
611        sigma: &[f64],
612        point_forecasts: &[f64],
613        observations: &[f64],
614    ) -> Result<Self> {
615        let brier_score = brier_score_temporal(prob_forecasts, binary_obs)?;
616        let bss = brier_skill_score_temporal(prob_forecasts, binary_obs)?;
617        let crps = crps_gaussian(mu, sigma, observations)?;
618        let da = directional_accuracy(point_forecasts, observations)?;
619
620        Ok(Self {
621            brier_score,
622            bss,
623            crps,
624            directional_accuracy: da,
625        })
626    }
627}
628
629// ────────────────────────────────────────────────────────────────────────────
630// Statistical utilities
631// ────────────────────────────────────────────────────────────────────────────
632
/// Standard normal probability density, `φ(z) = exp(-z²/2) / √(2π)`.
fn standard_normal_pdf(z: f64) -> f64 {
    // TAU is exactly 2π in f64, so this matches sqrt(2.0 * PI) bit for bit.
    let normaliser = std::f64::consts::TAU.sqrt();
    let kernel = (-0.5 * z * z).exp();
    kernel / normaliser
}
637
638/// Standard normal CDF Φ(z) using Abramowitz & Stegun 26.2.17 approximation.
639fn standard_normal_cdf(z: f64) -> f64 {
640    // Using rational approximation (max error 7.5e-8)
641    if z >= 0.0 {
642        1.0 - standard_normal_cdf_positive(z)
643    } else {
644        standard_normal_cdf_positive(-z)
645    }
646}
647
648fn standard_normal_cdf_positive(z: f64) -> f64 {
649    // Complementary CDF for z >= 0
650    let t = 1.0 / (1.0 + 0.2316419 * z);
651    let poly = t
652        * (0.319381530
653            + t * (-0.356563782 + t * (1.781477937 + t * (-1.821255978 + t * 1.330274429))));
654    standard_normal_pdf(z) * poly
655}
656
657/// Approximate two-sided p-value for t-distribution.
658/// Uses normal approximation for df > 30, otherwise Bailey (1994) approximation.
659fn two_sided_t_pvalue(t: f64, df: usize) -> f64 {
660    let t_abs = t.abs();
661    if df == 0 {
662        return 1.0;
663    }
664
665    let p_one_sided = if df as f64 > 30.0 {
666        // Normal approximation
667        standard_normal_cdf(-t_abs)
668    } else {
669        // Incomplete beta function approximation for t-distribution
670        let x = df as f64 / (df as f64 + t_abs * t_abs);
671        0.5 * regularized_incomplete_beta(df as f64 / 2.0, 0.5, x)
672    };
673
674    (2.0 * p_one_sided).min(1.0).max(0.0)
675}
676
677/// Regularized incomplete beta function I_x(a, b) via continued fraction.
678fn regularized_incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
679    if x <= 0.0 {
680        return 0.0;
681    }
682    if x >= 1.0 {
683        return 1.0;
684    }
685
686    // Use symmetry: I_x(a,b) = 1 - I_{1-x}(b,a) when x > (a+1)/(a+b+2)
687    if x > (a + 1.0) / (a + b + 2.0) {
688        return 1.0 - regularized_incomplete_beta(b, a, 1.0 - x);
689    }
690
691    // Lentz's continued fraction
692    let lbeta = lgamma(a) + lgamma(b) - lgamma(a + b);
693    let front = (a * x.ln() + b * (1.0 - x).ln() - lbeta).exp() / a;
694
695    // Continued fraction via Lentz's method
696    let max_iter = 200;
697    let tol = 1e-10;
698    let mut c = 1.0f64;
699    let raw_d = 1.0 - (a + b) * x / (a + 1.0);
700    let mut d = if raw_d.abs() < f64::MIN_POSITIVE {
701        f64::MIN_POSITIVE
702    } else {
703        1.0 / raw_d
704    };
705    let mut f = d;
706
707    for m in 1..=max_iter {
708        let m = m as f64;
709        // Even step
710        let num_even = m * (b - m) * x / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
711        d = 1.0 + num_even * d;
712        d = if d.abs() < f64::MIN_POSITIVE {
713            f64::MIN_POSITIVE
714        } else {
715            d
716        };
717        c = 1.0 + num_even / c;
718        c = if c.abs() < f64::MIN_POSITIVE {
719            f64::MIN_POSITIVE
720        } else {
721            c
722        };
723        d = 1.0 / d;
724        f *= c * d;
725
726        // Odd step
727        let num_odd = -(a + m) * (a + b + m) * x / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));
728        d = 1.0 + num_odd * d;
729        d = if d.abs() < f64::MIN_POSITIVE {
730            f64::MIN_POSITIVE
731        } else {
732            d
733        };
734        c = 1.0 + num_odd / c;
735        c = if c.abs() < f64::MIN_POSITIVE {
736            f64::MIN_POSITIVE
737        } else {
738            c
739        };
740        d = 1.0 / d;
741        let delta = c * d;
742        f *= delta;
743
744        if (delta - 1.0).abs() < tol {
745            break;
746        }
747    }
748
749    front * f
750}
751
/// Natural logarithm of the absolute gamma function, `ln|Γ(x)|`.
///
/// Uses the Lanczos approximation (g = 7, nine coefficients) for `x >= 0.5`.
/// For `x < 0.5` it uses the reflection formula `Γ(x)·Γ(1-x) = π / sin(πx)`.
/// Non-positive integers give `+∞`; NaN propagates.
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    // Keep the `x < 0.5` test (not `x >= 0.5`): NaN must take the Lanczos
    // branch, otherwise the reflection would recurse on NaN forever.
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().abs().ln() - lgamma(1.0 - x);
    }

    let xm1 = x - 1.0;
    let series = COEFFS[1..]
        .iter()
        .zip(1u32..)
        .fold(COEFFS[0], |acc, (&coef, k)| acc + coef / (xm1 + f64::from(k)));
    let t = xm1 + G + 0.5;
    0.5 * (2.0 * PI).ln() + (xm1 + 0.5) * t.ln() - t + series.ln()
}
780
781// ────────────────────────────────────────────────────────────────────────────
782// Tests
783// ────────────────────────────────────────────────────────────────────────────
784
785#[cfg(test)]
786mod tests {
787    use super::*;
788
789    // ── DTW tests ─────────────────────────────────────────────────────────
790
791    #[test]
792    fn test_dtw_identical() {
793        let x = vec![1.0, 2.0, 3.0, 2.0, 1.0];
794        let d = dtw(&x, &x).expect("should succeed");
795        assert!(d.abs() < 1e-10, "DTW(x,x) should be 0, got {d}");
796    }
797
798    #[test]
799    fn test_dtw_shift() {
800        // DTW should handle time-shifted signals well
801        let x = vec![0.0, 1.0, 2.0, 1.0, 0.0];
802        let y = vec![0.0, 0.0, 1.0, 2.0, 1.0]; // shifted by 1
803        let d = dtw(&x, &y).expect("should succeed");
804        // DTW is allowed to warp, so distance should be small
805        assert!(d >= 0.0, "DTW must be non-negative");
806        assert!(d < 2.0, "DTW should handle shifted signals: got {d}");
807    }
808
809    #[test]
810    fn test_dtw_windowed_matches_full_for_large_window() {
811        let x = vec![1.0, 2.0, 3.0, 2.0, 1.0];
812        let y = vec![1.5, 2.5, 3.5, 2.5, 1.5];
813        let d_full = dtw(&x, &y).expect("full DTW");
814        let d_win = dtw_windowed(&x, &y, 5).expect("windowed DTW");
815        assert!(
816            (d_full - d_win).abs() < 1e-10,
817            "large window should match full DTW"
818        );
819    }
820
821    #[test]
822    fn test_dtw_windowed_constraint() {
823        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
824        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // same — DTW should be 0
825        let d = dtw_windowed(&x, &y, 2).expect("windowed DTW");
826        assert!(d.abs() < 1e-10, "DTW(x,x) with window should be 0");
827    }
828
829    #[test]
830    fn test_dtw_normalized() {
831        let x = vec![1.0, 2.0, 3.0];
832        let y = vec![1.0, 2.0, 3.0];
833        let d = dtw_normalized(&x, &y, None).expect("normalized DTW");
834        assert!(d.abs() < 1e-10, "normalized DTW(x,x) should be 0");
835    }
836
837    #[test]
838    fn test_dtw_different_lengths() {
839        let x = vec![1.0, 2.0, 3.0];
840        let y = vec![1.0, 1.5, 2.0, 2.5, 3.0];
841        let d = dtw(&x, &y).expect("DTW with different lengths");
842        assert!(d >= 0.0);
843    }
844
845    // ── Brier Score tests ──────────────────────────────────────────────────
846
847    #[test]
848    fn test_brier_score_perfect() {
849        let obs = vec![1.0, 0.0, 1.0, 0.0, 1.0];
850        let fc = vec![1.0, 0.0, 1.0, 0.0, 1.0]; // perfect
851        let bs = brier_score_temporal(&fc, &obs).expect("should succeed");
852        assert!(
853            bs.abs() < 1e-10,
854            "perfect Brier Score should be 0, got {bs}"
855        );
856    }
857
858    #[test]
859    fn test_brier_score_worst() {
860        let obs = vec![1.0, 0.0, 1.0, 0.0];
861        let fc = vec![0.0, 1.0, 0.0, 1.0]; // completely wrong
862        let bs = brier_score_temporal(&fc, &obs).expect("should succeed");
863        assert!(
864            (bs - 1.0).abs() < 1e-10,
865            "worst Brier Score should be 1.0, got {bs}"
866        );
867    }
868
869    #[test]
870    fn test_brier_skill_score_no_skill() {
871        // Climatological forecast should have BSS = 0
872        let obs = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
873        let base_rate = 0.5;
874        let fc = vec![base_rate; obs.len()];
875        let bss = brier_skill_score_temporal(&fc, &obs).expect("should succeed");
876        assert!(
877            bss.abs() < 1e-10,
878            "climatological BSS should be 0, got {bss}"
879        );
880    }
881
882    // ── CRPS tests ─────────────────────────────────────────────────────────
883
884    #[test]
885    fn test_crps_gaussian_perfect() {
886        // When obs exactly at mu and sigma→0, CRPS → 0
887        let mu = vec![1.0, 2.0, 3.0];
888        let sigma = vec![0.001, 0.001, 0.001]; // very sharp
889        let obs = vec![1.0, 2.0, 3.0]; // exact
890        let crps = crps_gaussian(&mu, &sigma, &obs).expect("should succeed");
891        assert!(crps >= 0.0);
892        assert!(
893            crps < 0.01,
894            "CRPS for near-perfect Gaussian should be small, got {crps}"
895        );
896    }
897
898    #[test]
899    fn test_crps_gaussian_nonnegative() {
900        let mu = vec![0.0, 1.0, 2.0, -1.0];
901        let sigma = vec![1.0, 2.0, 0.5, 1.5];
902        let obs = vec![0.5, 1.5, 1.8, -0.5];
903        let crps = crps_gaussian(&mu, &sigma, &obs).expect("should succeed");
904        assert!(crps >= 0.0, "CRPS must be non-negative, got {crps}");
905    }
906
907    #[test]
908    fn test_crps_gaussian_known_value() {
909        // For N(0, 1) and observation y=0:
910        // CRPS = sigma * [ z*(2*Phi(z)-1) + 2*phi(z) - 1/sqrt(pi) ]
911        //      = 1 * [ 0 + 2*(1/sqrt(2*pi)) - 1/sqrt(pi) ]
912        //      = 2/sqrt(2*pi) - 1/sqrt(pi)
913        //      = (sqrt(2) - 1) / sqrt(pi)
914        //      ≈ 0.2338
915        let mu = vec![0.0];
916        let sigma = vec![1.0];
917        let obs = vec![0.0];
918        let crps = crps_gaussian(&mu, &sigma, &obs).expect("should succeed");
919        let expected = (2.0_f64.sqrt() - 1.0) / std::f64::consts::PI.sqrt();
920        assert!(
921            (crps - expected).abs() < 1e-4,
922            "CRPS(N(0,1), y=0) ≈ {expected:.4}, got {crps:.4}"
923        );
924    }
925
926    #[test]
927    fn test_crps_ensemble_identical() {
928        // Ensemble forecast exactly at observation
929        let ensemble = vec![vec![2.0, 2.0, 2.0], vec![5.0, 5.0, 5.0]];
930        let obs = vec![2.0, 5.0];
931        let crps = crps_ensemble(&ensemble, &obs).expect("should succeed");
932        assert!(crps >= 0.0);
933        assert!(
934            crps < 1e-6,
935            "perfect ensemble CRPS should be ~0, got {crps}"
936        );
937    }
938
939    // ── Directional Accuracy tests ─────────────────────────────────────────
940
941    #[test]
942    fn test_directional_accuracy_all_correct() {
943        let obs = vec![1.0, 2.0, 3.0, 2.0, 1.0];
944        let fc = vec![1.1, 2.1, 3.1, 1.9, 0.9]; // same directions
945        let da = directional_accuracy(&fc, &obs).expect("should succeed");
946        assert!(
947            (da - 1.0).abs() < 1e-10,
948            "all correct DA should be 1.0, got {da}"
949        );
950    }
951
952    #[test]
953    fn test_directional_accuracy_all_wrong() {
954        let obs = vec![1.0, 2.0, 3.0, 4.0];
955        let fc = vec![4.0, 3.0, 2.0, 1.0]; // opposite directions
956        let da = directional_accuracy(&fc, &obs).expect("should succeed");
957        assert!(
958            (da - 0.0).abs() < 1e-10,
959            "all wrong DA should be 0.0, got {da}"
960        );
961    }
962
963    #[test]
964    fn test_directional_accuracy_half() {
965        // Design: n=5, 4 changes, need exactly 2 correct
966        // obs: 1->2 (+), 2->1 (-), 1->2 (+), 2->1 (-)
967        // fc:  1->2 (+), 2->2 (0), 1->2 (+), 2->1 (-)  →  +,0,+,- vs +,-,+,- → 3 correct
968        // Better: forecast exactly opposite for first 2 changes, same for last 2
969        // obs: 0->1(+), 1->0(-), 0->1(+), 1->0(-)
970        // fc:  0->0(0), 0->1(+), 1->0(-), 0->1(+) →  0,+,-,+ vs +,-,+,- → 0 matches
971        // Simplest: 2 correct, 2 wrong
972        // obs changes: -, +, -, +  (obs=[4,2,6,1,5])
973        // fc  changes: -, +, +, -  (fc =[4,2,6,8,3]) → match: +,+,-,- vs correct: -,+,-,+
974        // i=1: obs=-,fc=- MATCH; i=2: obs=+,fc=+ MATCH; i=3: obs=-,fc=+ NO; i=4: obs=+,fc=- NO
975        let obs = vec![4.0, 2.0, 6.0, 1.0, 5.0];
976        let fc = vec![4.0, 2.0, 6.0, 8.0, 3.0];
977        let da = directional_accuracy(&fc, &obs).expect("should succeed");
978        assert!(
979            (da - 0.5).abs() < 1e-10,
980            "half-correct DA should be 0.5, got {da}"
981        );
982    }
983
984    // ── Diebold-Mariano tests ──────────────────────────────────────────────
985
986    #[test]
987    fn test_diebold_mariano_identical_forecasts() {
988        // If both forecasts are identical, d_t = 0 for all t
989        // Mean differential should be exactly 0
990        let actual = vec![1.0, 2.0, 3.0, 4.0, 5.0];
991        let fc1 = vec![1.1, 2.1, 3.1, 4.1, 5.1];
992        let fc2 = fc1.clone();
993        let result = diebold_mariano_test(&actual, &fc1, &fc2, DmLossFunction::SquaredError, 1);
994        // Loss differentials are all 0 → variance is 0 → error, or statistic is 0
995        match result {
996            Err(_) => {} // expected: zero variance triggers error
997            Ok(r) => {
998                // If it doesn't error, mean differential must be 0
999                assert!(
1000                    r.mean_differential.abs() < 1e-10,
1001                    "mean differential for identical forecasts should be 0"
1002                );
1003            }
1004        }
1005    }
1006
1007    #[test]
1008    fn test_diebold_mariano_clearly_different() {
1009        let actual: Vec<f64> = (0..20).map(|i| i as f64).collect();
1010        // Model 1: very poor forecasts
1011        let fc1: Vec<f64> = actual.iter().map(|&x| x + 5.0).collect();
1012        // Model 2: much better
1013        let fc2: Vec<f64> = actual.iter().map(|&x| x + 0.1).collect();
1014        let result = diebold_mariano_test(&actual, &fc1, &fc2, DmLossFunction::SquaredError, 1)
1015            .expect("should succeed");
1016        // fc1 has larger loss, so mean differential should be positive
1017        assert!(
1018            result.mean_differential > 0.0,
1019            "model1 should have higher loss"
1020        );
1021        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1022    }
1023
1024    #[test]
1025    fn test_forecast_skill_metrics_compute() {
1026        let n = 10;
1027        let prob_fc: Vec<f64> = (0..n).map(|i| if i < 5 { 0.9 } else { 0.1 }).collect();
1028        let bin_obs: Vec<f64> = (0..n).map(|i| if i < 5 { 1.0 } else { 0.0 }).collect();
1029        let mu: Vec<f64> = (0..n).map(|i| i as f64).collect();
1030        let sigma = vec![1.0; n];
1031        let point_fc: Vec<f64> = mu.iter().map(|&x| x + 0.1).collect();
1032        let obs: Vec<f64> = mu.clone();
1033
1034        let metrics =
1035            ForecastSkillMetrics::compute(&prob_fc, &bin_obs, &mu, &sigma, &point_fc, &obs)
1036                .expect("should succeed");
1037
1038        assert!(metrics.brier_score >= 0.0);
1039        assert!(metrics.crps >= 0.0);
1040        assert!(metrics.directional_accuracy >= 0.0 && metrics.directional_accuracy <= 1.0);
1041    }
1042}