
tensorlogic_infer/uncertainty.rs

//! Uncertainty estimation for probabilistic predictions.
//!
//! Provides Monte Carlo sampling, calibration metrics, confidence intervals,
//! and prediction intervals for regression and classification tasks.
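//!
//! # Quick start
//!
//! A minimal sketch of the Monte Carlo workflow, assuming the crate exposes
//! this module as `tensorlogic_infer::uncertainty`:
//!
//! ```
//! use tensorlogic_infer::uncertainty::MonteCarloEstimator;
//!
//! // 100 samples, 95% confidence interval, fixed seed for reproducibility.
//! let mut mc = MonteCarloEstimator::new(100, 0.95, 42).unwrap();
//! // Propagate N(0,1) noise through a toy model: f(noise) = 2 + 0.1 * noise.
//! let est = mc.estimate(|noise| 2.0 + 0.1 * noise).unwrap();
//! assert!((est.mean - 2.0).abs() < 0.1);
//! println!("{}", est.summary());
//! ```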

use std::fmt;

// ---------------------------------------------------------------------------
// Error type
// ---------------------------------------------------------------------------

/// Errors that can occur during uncertainty estimation.
#[derive(Debug, Clone)]
pub enum UncertaintyError {
    EmptyPredictions,
    InvalidNumSamples(usize),
    InvalidConfidenceLevel(f64),
    ShapeMismatch { expected: usize, got: usize },
    InvalidBins(usize),
    SamplingError(String),
}

impl fmt::Display for UncertaintyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            UncertaintyError::EmptyPredictions => write!(f, "predictions slice is empty"),
            UncertaintyError::InvalidNumSamples(n) => {
                write!(f, "num_samples must be >= 1, got {n}")
            }
            UncertaintyError::InvalidConfidenceLevel(l) => {
                write!(f, "confidence_level must be in (0, 1), got {l}")
            }
            UncertaintyError::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {expected}, got {got}")
            }
            UncertaintyError::InvalidBins(b) => {
                write!(f, "num_bins must be >= 1, got {b}")
            }
            UncertaintyError::SamplingError(msg) => write!(f, "sampling error: {msg}"),
        }
    }
}

impl std::error::Error for UncertaintyError {}

// ---------------------------------------------------------------------------
// Simple xorshift-based RNG (no external rand crate)
// ---------------------------------------------------------------------------

struct SimpleUncertaintyRng {
    state: u64,
}

impl SimpleUncertaintyRng {
    fn new(seed: u64) -> Self {
        // Zero is a fixed point of xorshift64, so guard against a zero state.
        let state = seed ^ 0x9e3779b97f4a7c15;
        Self {
            state: if state == 0 { 0x9e3779b97f4a7c15 } else { state },
        }
    }

    /// Returns a uniformly distributed f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        // Xorshift64 for better quality than a plain LCG
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        // Use the top 53 bits so the result stays strictly below 1.0;
        // casting the full u64 to f64 can round up to 2^64 and yield exactly 1.0.
        (self.state >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Box-Muller transform: returns a standard normal sample N(0,1).
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-15); // avoid log(0)
        let u2 = self.next_f64();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = std::f64::consts::TAU * u2;
        r * theta.cos()
    }
}

// ---------------------------------------------------------------------------
// ConfidenceInterval
// ---------------------------------------------------------------------------

/// Method used to construct a confidence interval.
#[derive(Debug, Clone, PartialEq)]
pub enum IntervalMethod {
    /// Empirical interval taken directly from sample quantiles.
    Percentile,
    /// Gaussian approximation: mean ± z * std.
    Normal,
}

/// A confidence interval [lower, upper] at a given level (e.g. 0.95).
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    pub lower: f64,
    pub upper: f64,
    pub level: f64,
    pub method: IntervalMethod,
}

impl ConfidenceInterval {
    /// Construct a percentile-based CI by sorting `samples` and taking quantiles.
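    ///
    /// A minimal sketch (module path assumed as in the crate-level docs):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::ConfidenceInterval;
    ///
    /// let samples: Vec<f64> = (0..1000).map(|i| i as f64).collect();
    /// let ci = ConfidenceInterval::percentile(&samples, 0.95).unwrap();
    /// // Roughly the 2.5th and 97.5th sample percentiles of 0..=999.
    /// assert!(ci.lower > 20.0 && ci.lower < 30.0);
    /// assert!(ci.upper > 970.0 && ci.upper < 980.0);
    /// ```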
    pub fn percentile(samples: &[f64], level: f64) -> Result<Self, UncertaintyError> {
        if samples.is_empty() {
            return Err(UncertaintyError::EmptyPredictions);
        }
        if level <= 0.0 || level >= 1.0 {
            return Err(UncertaintyError::InvalidConfidenceLevel(level));
        }
        let mut sorted = samples.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let alpha = (1.0 - level) / 2.0;
        let lower = quantile_sorted(&sorted, alpha);
        let upper = quantile_sorted(&sorted, 1.0 - alpha);
        Ok(Self {
            lower,
            upper,
            level,
            method: IntervalMethod::Percentile,
        })
    }

    /// Construct a Normal CI using mean ± z * std.
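    ///
    /// Illustrative sketch: a 95% interval spans mean ± 1.96 · std
    /// (module path assumed):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::ConfidenceInterval;
    ///
    /// let ci = ConfidenceInterval::normal(10.0, 2.0, 0.95);
    /// assert!((ci.lower - (10.0 - 1.96 * 2.0)).abs() < 1e-9);
    /// assert!((ci.width() - 2.0 * 1.96 * 2.0).abs() < 1e-9);
    /// ```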
    pub fn normal(mean: f64, std: f64, level: f64) -> Self {
        let z = z_score(level);
        Self {
            lower: mean - z * std,
            upper: mean + z * std,
            level,
            method: IntervalMethod::Normal,
        }
    }

    /// Width of the interval.
    pub fn width(&self) -> f64 {
        self.upper - self.lower
    }

    /// Whether `value` lies within [lower, upper].
    pub fn contains(&self, value: f64) -> bool {
        value >= self.lower && value <= self.upper
    }
}

// ---------------------------------------------------------------------------
// UncertaintyEstimate
// ---------------------------------------------------------------------------

/// A single uncertainty estimate for a prediction.
#[derive(Debug, Clone)]
pub struct UncertaintyEstimate {
    pub mean: f64,
    pub variance: f64,
    pub std_dev: f64,
    pub confidence_interval: ConfidenceInterval,
    pub entropy: f64,
    /// Uncertainty due to the model (epistemic).
    pub epistemic_uncertainty: f64,
    /// Uncertainty inherent in the data (aleatoric).
    pub aleatoric_uncertainty: f64,
}

impl UncertaintyEstimate {
    /// Compute an estimate from a slice of scalar samples.
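    ///
    /// A small sketch (module path assumed):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::UncertaintyEstimate;
    ///
    /// let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let est = UncertaintyEstimate::from_samples(&samples, 0.9).unwrap();
    /// assert!((est.mean - 3.0).abs() < 1e-9);
    /// // Population variance of 1..=5 is 2.0.
    /// assert!((est.variance - 2.0).abs() < 1e-9);
    /// ```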
    pub fn from_samples(samples: &[f64], confidence_level: f64) -> Result<Self, UncertaintyError> {
        if samples.is_empty() {
            return Err(UncertaintyError::EmptyPredictions);
        }
        if confidence_level <= 0.0 || confidence_level >= 1.0 {
            return Err(UncertaintyError::InvalidConfidenceLevel(confidence_level));
        }
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        // Population variance (divide by n, not n - 1).
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();
        let confidence_interval = ConfidenceInterval::percentile(samples, confidence_level)?;
        let entropy = histogram_entropy(samples, 10);

        // With a single vector of samples we treat the full variance as epistemic.
        let epistemic_uncertainty = variance;
        let aleatoric_uncertainty = 0.0;

        Ok(Self {
            mean,
            variance,
            std_dev,
            confidence_interval,
            entropy,
            epistemic_uncertainty,
            aleatoric_uncertainty,
        })
    }

    /// Returns `true` when the standard deviation is below `threshold`.
    pub fn is_confident(&self, threshold: f64) -> bool {
        self.std_dev < threshold
    }

    /// Human-readable summary string.
    pub fn summary(&self) -> String {
        format!(
            "UncertaintyEstimate {{ mean: {:.4}, std: {:.4}, CI [{:.4}, {:.4}] @{:.0}%, \
             entropy: {:.4}, epistemic: {:.4}, aleatoric: {:.4} }}",
            self.mean,
            self.std_dev,
            self.confidence_interval.lower,
            self.confidence_interval.upper,
            self.confidence_interval.level * 100.0,
            self.entropy,
            self.epistemic_uncertainty,
            self.aleatoric_uncertainty,
        )
    }
}

// ---------------------------------------------------------------------------
// MonteCarloEstimator
// ---------------------------------------------------------------------------

/// Monte Carlo estimator: run a function N times with injected Gaussian noise.
pub struct MonteCarloEstimator {
    pub num_samples: usize,
    pub confidence_level: f64,
    rng: SimpleUncertaintyRng,
}

impl MonteCarloEstimator {
    /// Create a new estimator.
    ///
    /// * `num_samples` – must be >= 1.
    /// * `confidence_level` – must be in (0, 1).
    /// * `seed` – RNG seed for reproducibility.
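    ///
    /// Arguments are validated eagerly; a sketch (module path assumed):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::MonteCarloEstimator;
    ///
    /// // A confidence level outside (0, 1) or zero samples is rejected up front.
    /// assert!(MonteCarloEstimator::new(100, 1.5, 42).is_err());
    /// assert!(MonteCarloEstimator::new(0, 0.95, 42).is_err());
    /// let mc = MonteCarloEstimator::new(500, 0.9, 7).unwrap();
    /// assert_eq!(mc.num_samples, 500);
    /// ```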
    pub fn new(
        num_samples: usize,
        confidence_level: f64,
        seed: u64,
    ) -> Result<Self, UncertaintyError> {
        if num_samples < 1 {
            return Err(UncertaintyError::InvalidNumSamples(num_samples));
        }
        if confidence_level <= 0.0 || confidence_level >= 1.0 {
            return Err(UncertaintyError::InvalidConfidenceLevel(confidence_level));
        }
        Ok(Self {
            num_samples,
            confidence_level,
            rng: SimpleUncertaintyRng::new(seed),
        })
    }

    /// Convenience constructor with defaults: 100 samples, 0.95 CI, seed = 42.
    pub fn with_defaults() -> Self {
        // Safe: these defaults always pass validation.
        Self {
            num_samples: 100,
            confidence_level: 0.95,
            rng: SimpleUncertaintyRng::new(42),
        }
    }

    /// Run `f(noise)` `num_samples` times, injecting N(0,1) noise each call.
    pub fn estimate<F>(&mut self, f: F) -> Result<UncertaintyEstimate, UncertaintyError>
    where
        F: Fn(f64) -> f64,
    {
        let samples: Vec<f64> = (0..self.num_samples)
            .map(|_| {
                let noise = self.rng.next_normal();
                f(noise)
            })
            .collect();
        UncertaintyEstimate::from_samples(&samples, self.confidence_level)
    }

    /// Run `f(noise)` `num_samples` times, collecting a vector per run and
    /// building one estimate per element.
    ///
    /// * `dim` – expected length of the vector returned by `f`.
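    ///
    /// A sketch with a 3-dimensional output (module path assumed):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::MonteCarloEstimator;
    ///
    /// let mut mc = MonteCarloEstimator::new(100, 0.95, 7).unwrap();
    /// // Each run returns a vector; element i is offset by i.
    /// let ests = mc
    ///     .estimate_vector(3, |noise| vec![noise, 1.0 + noise, 2.0 + noise])
    ///     .unwrap();
    /// assert_eq!(ests.len(), 3);
    /// assert!(ests[1].mean > ests[0].mean);
    /// ```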
    pub fn estimate_vector<F>(
        &mut self,
        dim: usize,
        f: F,
    ) -> Result<Vec<UncertaintyEstimate>, UncertaintyError>
    where
        F: Fn(f64) -> Vec<f64>,
    {
        if dim == 0 {
            return Err(UncertaintyError::ShapeMismatch {
                expected: 1,
                got: 0,
            });
        }
        // Collect num_samples × dim matrix
        let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(self.num_samples);
        for _ in 0..self.num_samples {
            let noise = self.rng.next_normal();
            let row = f(noise);
            if row.len() != dim {
                return Err(UncertaintyError::ShapeMismatch {
                    expected: dim,
                    got: row.len(),
                });
            }
            matrix.push(row);
        }

        // Transpose: for each dimension, collect samples across runs
        let mut estimates = Vec::with_capacity(dim);
        for col in 0..dim {
            let col_samples: Vec<f64> = matrix.iter().map(|row| row[col]).collect();
            let est = UncertaintyEstimate::from_samples(&col_samples, self.confidence_level)?;
            estimates.push(est);
        }
        Ok(estimates)
    }
}

// ---------------------------------------------------------------------------
// CalibrationMetrics
// ---------------------------------------------------------------------------

/// Statistics for a single calibration bin.
#[derive(Debug, Clone)]
pub struct CalibrationBin {
    pub confidence_lower: f64,
    pub confidence_upper: f64,
    pub count: usize,
    pub avg_confidence: f64,
    pub accuracy: f64,
    /// avg_confidence − accuracy (positive → overconfident).
    pub gap: f64,
}

/// Calibration metrics for probabilistic classifiers.
#[derive(Debug, Clone)]
pub struct CalibrationMetrics {
    /// Expected Calibration Error.
    pub ece: f64,
    /// Maximum Calibration Error.
    pub mce: f64,
    /// Average gap in overconfident bins.
    pub overconfidence: f64,
    /// Average gap in underconfident bins.
    pub underconfidence: f64,
    pub num_bins: usize,
    pub bin_stats: Vec<CalibrationBin>,
}

impl CalibrationMetrics {
    /// Compute calibration metrics from predicted probabilities and true binary labels.
    ///
    /// * `predicted_probs` – probability of positive class, in [0, 1].
    /// * `true_labels` – 0 or 1.
    /// * `num_bins` – number of equal-width bins.
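    ///
    /// A small sketch (module path assumed; values illustrative):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::CalibrationMetrics;
    ///
    /// let probs = vec![0.9, 0.8, 0.2, 0.1];
    /// let labels = vec![1u8, 1, 0, 0];
    /// let metrics = CalibrationMetrics::compute(&probs, &labels, 10).unwrap();
    /// println!("{}", metrics.summary());
    /// println!("{}", metrics.format_reliability_diagram());
    /// assert!(metrics.ece >= 0.0);
    /// ```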
    pub fn compute(
        predicted_probs: &[f64],
        true_labels: &[u8],
        num_bins: usize,
    ) -> Result<Self, UncertaintyError> {
        if predicted_probs.is_empty() {
            return Err(UncertaintyError::EmptyPredictions);
        }
        if num_bins < 1 {
            return Err(UncertaintyError::InvalidBins(num_bins));
        }
        if predicted_probs.len() != true_labels.len() {
            return Err(UncertaintyError::ShapeMismatch {
                expected: predicted_probs.len(),
                got: true_labels.len(),
            });
        }

        let total = predicted_probs.len() as f64;
        let bin_width = 1.0 / num_bins as f64;

        // Accumulate per-bin sums
        let mut bin_conf_sum = vec![0.0_f64; num_bins];
        let mut bin_acc_sum = vec![0.0_f64; num_bins];
        let mut bin_count = vec![0usize; num_bins];

        for (p, y) in predicted_probs.iter().zip(true_labels.iter()) {
            let p = p.clamp(0.0, 1.0);
            let bin_idx = ((p / bin_width).floor() as usize).min(num_bins - 1);
            bin_conf_sum[bin_idx] += p;
            bin_acc_sum[bin_idx] += *y as f64;
            bin_count[bin_idx] += 1;
        }

        let mut bin_stats = Vec::with_capacity(num_bins);
        let mut ece = 0.0_f64;
        let mut mce = 0.0_f64;
        let mut over_gaps = Vec::new();
        let mut under_gaps = Vec::new();

        for i in 0..num_bins {
            let count = bin_count[i];
            let conf_lower = i as f64 * bin_width;
            let conf_upper = conf_lower + bin_width;
            let (avg_confidence, accuracy, gap) = if count == 0 {
                (0.0, 0.0, 0.0)
            } else {
                let avg_conf = bin_conf_sum[i] / count as f64;
                let acc = bin_acc_sum[i] / count as f64;
                (avg_conf, acc, avg_conf - acc)
            };

            if count > 0 {
                let weight = count as f64 / total;
                ece += weight * gap.abs();
                mce = mce.max(gap.abs());
                if gap > 0.0 {
                    over_gaps.push(gap);
                } else if gap < 0.0 {
                    under_gaps.push(-gap);
                }
            }

            bin_stats.push(CalibrationBin {
                confidence_lower: conf_lower,
                confidence_upper: conf_upper,
                count,
                avg_confidence,
                accuracy,
                gap,
            });
        }

        let overconfidence = if over_gaps.is_empty() {
            0.0
        } else {
            over_gaps.iter().sum::<f64>() / over_gaps.len() as f64
        };
        let underconfidence = if under_gaps.is_empty() {
            0.0
        } else {
            under_gaps.iter().sum::<f64>() / under_gaps.len() as f64
        };

        Ok(Self {
            ece,
            mce,
            overconfidence,
            underconfidence,
            num_bins,
            bin_stats,
        })
    }

    /// True when ECE is below `ece_threshold`.
    pub fn is_well_calibrated(&self, ece_threshold: f64) -> bool {
        self.ece < ece_threshold
    }

    /// ASCII reliability diagram (confidence vs accuracy per bin).
    pub fn format_reliability_diagram(&self) -> String {
        let mut lines = vec!["Reliability Diagram (conf → accuracy):".to_string()];
        for bin in &self.bin_stats {
            if bin.count == 0 {
                continue;
            }
            let bar_len = (bin.accuracy * 20.0).round() as usize;
            let bar = "#".repeat(bar_len);
            lines.push(format!(
                "[{:.2},{:.2}] n={:4}  acc={:.3}  conf={:.3}  gap={:+.3}  |{}|",
                bin.confidence_lower,
                bin.confidence_upper,
                bin.count,
                bin.accuracy,
                bin.avg_confidence,
                bin.gap,
                bar,
            ));
        }
        lines.join("\n")
    }

    /// Short summary string.
    pub fn summary(&self) -> String {
        format!(
            "CalibrationMetrics {{ ECE: {:.4}, MCE: {:.4}, over: {:.4}, under: {:.4}, bins: {} }}",
            self.ece, self.mce, self.overconfidence, self.underconfidence, self.num_bins
        )
    }
}

// ---------------------------------------------------------------------------
// Temperature scaling
// ---------------------------------------------------------------------------

/// Apply temperature scaling to raw logits: softmax(logits / T).
/// `temperature` must be positive; T > 1 softens the distribution, T < 1 sharpens it.
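///
/// A small sketch of that softening effect (module path assumed):
///
/// ```
/// use tensorlogic_infer::uncertainty::temperature_scale;
///
/// let logits = vec![2.0, 1.0, 0.0];
/// let sharp = temperature_scale(&logits, 0.5);
/// let flat = temperature_scale(&logits, 10.0);
/// // The top class keeps more probability mass at low temperature.
/// assert!(sharp[0] > flat[0]);
/// ```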
pub fn temperature_scale(logits: &[f64], temperature: f64) -> Vec<f64> {
    let scaled: Vec<f64> = logits.iter().map(|l| l / temperature).collect();
    softmax_vec(&scaled)
}

/// Find the temperature that minimises the negative log-likelihood on the
/// provided logits and true labels.
///
/// * `temperatures` – a grid of candidate temperatures to evaluate;
///   non-positive candidates are skipped.
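///
/// A grid-search sketch over a few candidates (module path assumed):
///
/// ```
/// use tensorlogic_infer::uncertainty::find_optimal_temperature;
///
/// let logits = vec![2.0, -1.5, 0.5, -2.0];
/// let labels = vec![1u8, 0, 1, 0];
/// let grid = vec![0.5, 1.0, 2.0, 4.0];
/// let t = find_optimal_temperature(&logits, &labels, &grid).unwrap();
/// assert!(grid.contains(&t));
/// ```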
pub fn find_optimal_temperature(
    logits: &[f64],
    true_labels: &[u8],
    temperatures: &[f64],
) -> Result<f64, UncertaintyError> {
    if logits.is_empty() {
        return Err(UncertaintyError::EmptyPredictions);
    }
    if logits.len() != true_labels.len() {
        return Err(UncertaintyError::ShapeMismatch {
            expected: logits.len(),
            got: true_labels.len(),
        });
    }
    if temperatures.is_empty() {
        return Err(UncertaintyError::SamplingError(
            "temperatures slice is empty".to_string(),
        ));
    }

    let mut best_temp = None;
    let mut best_nll = f64::INFINITY;

    for &t in temperatures {
        if t <= 0.0 {
            continue;
        }
        let nll = compute_nll(logits, true_labels, t);
        if nll < best_nll {
            best_nll = nll;
            best_temp = Some(t);
        }
    }
    // Error out rather than silently returning an invalid candidate when the
    // grid contains no positive temperature.
    best_temp.ok_or_else(|| {
        UncertaintyError::SamplingError("no positive temperature candidates".to_string())
    })
}

// ---------------------------------------------------------------------------
// PredictionInterval
// ---------------------------------------------------------------------------

/// Prediction intervals for regression tasks.
#[derive(Debug, Clone)]
pub struct PredictionInterval {
    pub predictions: Vec<f64>,
    pub lower_bounds: Vec<f64>,
    pub upper_bounds: Vec<f64>,
    /// Empirical coverage (fraction of actuals inside the interval);
    /// 0.0 when no actuals are provided.
    pub coverage: f64,
    /// Average interval width.
    pub avg_width: f64,
}

impl PredictionInterval {
    /// Construct from quantile predictions.
    ///
    /// If `actuals` is `Some`, compute empirical coverage.
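    ///
    /// A small sketch with known coverage (module path assumed):
    ///
    /// ```
    /// use tensorlogic_infer::uncertainty::PredictionInterval;
    ///
    /// let lower = vec![0.0, 1.0, 2.0];
    /// let upper = vec![1.0, 2.0, 3.0];
    /// let actuals = vec![0.5, 1.5, 9.0]; // last point falls outside its interval
    /// let pi = PredictionInterval::from_quantile_predictions(lower, upper, Some(&actuals))
    ///     .unwrap();
    /// assert!((pi.coverage - 2.0 / 3.0).abs() < 1e-9);
    /// assert!((pi.avg_width - 1.0).abs() < 1e-9);
    /// ```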
    pub fn from_quantile_predictions(
        lower_preds: Vec<f64>,
        upper_preds: Vec<f64>,
        actuals: Option<&[f64]>,
    ) -> Result<Self, UncertaintyError> {
        if lower_preds.is_empty() {
            return Err(UncertaintyError::EmptyPredictions);
        }
        if lower_preds.len() != upper_preds.len() {
            return Err(UncertaintyError::ShapeMismatch {
                expected: lower_preds.len(),
                got: upper_preds.len(),
            });
        }

        // Use midpoint of [lower, upper] as point prediction
        let predictions: Vec<f64> = lower_preds
            .iter()
            .zip(upper_preds.iter())
            .map(|(lo, hi)| (lo + hi) / 2.0)
            .collect();

        let avg_width = lower_preds
            .iter()
            .zip(upper_preds.iter())
            .map(|(lo, hi)| (hi - lo).abs())
            .sum::<f64>()
            / lower_preds.len() as f64;

        let coverage = match actuals {
            None => 0.0,
            Some(act) => {
                if act.len() != lower_preds.len() {
                    return Err(UncertaintyError::ShapeMismatch {
                        expected: lower_preds.len(),
                        got: act.len(),
                    });
                }
                let covered = lower_preds
                    .iter()
                    .zip(upper_preds.iter())
                    .zip(act.iter())
                    .filter(|((lo, hi), y)| *y >= *lo && *y <= *hi)
                    .count();
                covered as f64 / act.len() as f64
            }
        };

        Ok(Self {
            predictions,
            lower_bounds: lower_preds,
            upper_bounds: upper_preds,
            coverage,
            avg_width,
        })
    }

    /// Short summary string.
    pub fn summary(&self) -> String {
        format!(
            "PredictionInterval {{ n: {}, avg_width: {:.4}, coverage: {:.4} }}",
            self.predictions.len(),
            self.avg_width,
            self.coverage,
        )
    }
}

// ---------------------------------------------------------------------------
// Private helpers
// ---------------------------------------------------------------------------

/// Linear interpolation quantile on a sorted slice.
fn quantile_sorted(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 1 {
        return sorted[0];
    }
    let idx = p * (n as f64 - 1.0);
    let lo = idx.floor() as usize;
    let hi = (lo + 1).min(n - 1);
    let frac = idx - lo as f64;
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}

/// Return a z-score for a given confidence level (two-tailed).
/// Only 0.90, 0.95, and 0.99 are tabulated; any other level falls back to
/// the 95% value.
fn z_score(level: f64) -> f64 {
    if (level - 0.99).abs() < 1e-9 {
        2.576
    } else if (level - 0.90).abs() < 1e-9 {
        1.645
    } else {
        // Default to 1.96 (≈ 95%)
        1.96
    }
}

/// Entropy from an empirical histogram of `samples` with `num_bins` bins.
fn histogram_entropy(samples: &[f64], num_bins: usize) -> f64 {
    if samples.is_empty() || num_bins == 0 {
        return 0.0;
    }
    let min = samples.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = samples.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if (max - min).abs() < f64::EPSILON {
        return 0.0;
    }
    let width = (max - min) / num_bins as f64;
    let mut counts = vec![0usize; num_bins];
    for &x in samples {
        let idx = (((x - min) / width).floor() as usize).min(num_bins - 1);
        counts[idx] += 1;
    }
    let n = samples.len() as f64;
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / n;
            -p * p.ln()
        })
        .sum()
}

/// Numerically stable softmax over a vector.
fn softmax_vec(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    if sum == 0.0 {
        return vec![1.0 / logits.len() as f64; logits.len()];
    }
    exps.iter().map(|e| e / sum).collect()
}

/// Compute the mean NLL for binary classification under temperature scaling.
/// Each logit is treated as the log-odds of the positive class; a sigmoid
/// recovers the probability.
fn compute_nll(logits: &[f64], true_labels: &[u8], temperature: f64) -> f64 {
    let mut nll = 0.0_f64;
    for (&l, &y) in logits.iter().zip(true_labels.iter()) {
        let scaled = l / temperature;
        // Sigmoid probability, clamped away from 0 and 1 so ln() stays finite.
        let p = sigmoid(scaled);
        let p_clamped = p.clamp(1e-15, 1.0 - 1e-15);
        if y == 1 {
            nll -= p_clamped.ln();
        } else {
            nll -= (1.0 - p_clamped).ln();
        }
    }
    nll / logits.len() as f64
}

/// Numerically stable logistic sigmoid.
fn sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        let e = (-x).exp();
        1.0 / (1.0 + e)
    } else {
        let e = x.exp();
        e / (1.0 + e)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // ---- UncertaintyEstimate -----------------------------------------------

    #[test]
    fn test_uncertainty_estimate_from_samples_basic() {
        let samples: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        // Mean of 0..99 = 49.5
        assert!((est.mean - 49.5).abs() < 0.01, "mean={}", est.mean);
        assert!(est.variance > 0.0);
        assert!(est.std_dev > 0.0);
    }

    #[test]
    fn test_uncertainty_estimate_confident() {
        // Constant samples → zero std_dev
        let samples = vec![1.0_f64; 50];
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        assert!(est.is_confident(0.1), "std_dev should be ~0");
    }

    #[test]
    fn test_uncertainty_estimate_not_confident() {
        let samples: Vec<f64> = (0..100).map(|i| i as f64 * 10.0).collect();
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        assert!(!est.is_confident(1.0), "high variance → not confident");
    }

    #[test]
    fn test_uncertainty_estimate_summary_nonempty() {
        let samples: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        let s = est.summary();
        assert!(!s.is_empty());
        assert!(s.contains("mean"));
    }

    // ---- ConfidenceInterval ------------------------------------------------

    #[test]
    fn test_confidence_interval_percentile() {
        let samples: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let ci = ConfidenceInterval::percentile(&samples, 0.95).unwrap();
        assert!(ci.lower < ci.upper, "lower={} upper={}", ci.lower, ci.upper);
        assert_eq!(ci.method, IntervalMethod::Percentile);
    }

    #[test]
    fn test_confidence_interval_normal_width() {
        // 95% CI: width = 2 * 1.96 * std
        let mean = 0.0;
        let std = 1.0;
        let ci = ConfidenceInterval::normal(mean, std, 0.95);
        let expected_width = 2.0 * 1.96 * std;
        assert!(
            (ci.width() - expected_width).abs() < 1e-9,
            "width={}",
            ci.width()
        );
    }

    #[test]
    fn test_confidence_interval_contains() {
        let samples: Vec<f64> = (0..1000).map(|i| i as f64 / 10.0).collect();
        let ci = ConfidenceInterval::percentile(&samples, 0.95).unwrap();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(ci.contains(mean), "mean should be inside CI");
    }

    #[test]
    fn test_confidence_interval_width_positive() {
        let ci = ConfidenceInterval::normal(5.0, 2.0, 0.95);
        assert!(ci.width() > 0.0);
    }

    // ---- MonteCarloEstimator -----------------------------------------------

    #[test]
    fn test_mc_estimator_with_defaults() {
        let est = MonteCarloEstimator::with_defaults();
        assert_eq!(est.num_samples, 100);
        assert!((est.confidence_level - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_mc_estimator_estimate_constant_fn() {
        let mut mc = MonteCarloEstimator::new(200, 0.95, 1).unwrap();
        // f(noise) = 5.0 regardless of noise
        let est = mc.estimate(|_noise| 5.0).unwrap();
        assert!((est.mean - 5.0).abs() < 1e-9, "mean={}", est.mean);
        assert!(est.std_dev < 1e-9, "std_dev={}", est.std_dev);
    }

    #[test]
    fn test_mc_estimator_estimate_linear_fn() {
        let mut mc = MonteCarloEstimator::new(2000, 0.95, 7).unwrap();
        // f(noise) = noise ~ N(0,1) → mean ≈ 0, std ≈ 1
        let est = mc.estimate(|noise| noise).unwrap();
        assert!(est.mean.abs() < 0.15, "mean should be ~0, got {}", est.mean);
        assert!(
            (est.std_dev - 1.0).abs() < 0.15,
            "std_dev should be ~1, got {}",
            est.std_dev
        );
    }

    #[test]
    fn test_mc_estimator_estimate_vector() {
        let mut mc = MonteCarloEstimator::new(50, 0.95, 99).unwrap();
        let dim = 4;
        let estimates = mc.estimate_vector(dim, |noise| vec![noise; dim]).unwrap();
        assert_eq!(estimates.len(), dim);
    }

    // ---- CalibrationMetrics ------------------------------------------------

    #[test]
    fn test_calibration_metrics_compute_perfect() {
        // When predicted probability equals true label (perfect, deterministic)
        // ECE should be 0 (or very close).
        let predicted: Vec<f64> = vec![1.0; 100];
        let labels: Vec<u8> = vec![1u8; 100];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 10).unwrap();
        assert!(
            metrics.ece < 1e-9,
            "ECE should be 0 for perfect preds, got {}",
            metrics.ece
        );
    }

    #[test]
    fn test_calibration_metrics_compute_uniform() {
        // Uniformly random predictions → ECE > 0
        let mut rng = SimpleUncertaintyRng::new(42);
        let n = 200;
        let predicted: Vec<f64> = (0..n).map(|_| rng.next_f64()).collect();
        let labels: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 10).unwrap();
        // ECE might be small by chance but the test just checks it is non-negative and computable
        assert!(metrics.ece >= 0.0);
        assert!(metrics.num_bins == 10);
    }

    #[test]
    fn test_calibration_metrics_bins() {
        let predicted = vec![0.1, 0.5, 0.9];
        let labels = vec![0u8, 1, 1];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 5).unwrap();
        assert_eq!(metrics.num_bins, 5);
        assert_eq!(metrics.bin_stats.len(), 5);
    }

    #[test]
    fn test_calibration_is_well_calibrated() {
        // Perfect calibration
        let predicted = vec![1.0_f64; 50];
        let labels = vec![1u8; 50];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 5).unwrap();
        assert!(metrics.is_well_calibrated(0.01));
    }

    // ---- Temperature scaling -----------------------------------------------

    #[test]
    fn test_temperature_scale_identity() {
        let logits = vec![1.0, 2.0, 3.0];
        let scaled = temperature_scale(&logits, 1.0);
        let direct = {
            let exps: Vec<f64> = logits.iter().map(|l| l.exp()).collect();
            let s: f64 = exps.iter().sum();
            exps.iter().map(|e| e / s).collect::<Vec<_>>()
        };
        for (a, b) in scaled.iter().zip(direct.iter()) {
            assert!((a - b).abs() < 1e-9, "a={a} b={b}");
        }
    }

    #[test]
    fn test_temperature_scale_high_temp() {
        // High temperature should push towards uniform distribution
        let logits = vec![10.0, 0.0, 0.0];
        let high_t = temperature_scale(&logits, 100.0);
        // Each should be close to 1/3
        for p in &high_t {
            assert!((p - 1.0 / 3.0).abs() < 0.1, "p={p}");
        }
    }

    #[test]
    fn test_find_optimal_temperature() {
        let logits: Vec<f64> = vec![2.0, -1.0, 0.5, -2.0, 1.0];
        let labels: Vec<u8> = vec![1, 0, 1, 0, 1];
        let temps: Vec<f64> = vec![0.5, 1.0, 2.0, 4.0];
        let opt_t = find_optimal_temperature(&logits, &labels, &temps).unwrap();
        assert!(temps.contains(&opt_t), "optimal temp not in candidates");
    }

    // ---- PredictionInterval ------------------------------------------------

    #[test]
    fn test_prediction_interval_basic() {
        let lower = vec![0.0, 1.0, 2.0];
        let upper = vec![1.0, 2.0, 3.0];
        let pi = PredictionInterval::from_quantile_predictions(lower, upper, None).unwrap();
        assert_eq!(pi.predictions.len(), 3);
        assert!((pi.avg_width - 1.0).abs() < 1e-9);
        let s = pi.summary();
        assert!(!s.is_empty());
    }

    #[test]
    fn test_prediction_interval_coverage() {
        let lower = vec![0.0, 1.0, 2.0, 3.0];
        let upper = vec![1.0, 2.0, 3.0, 4.0];
        // All actuals inside the interval
        let actuals = vec![0.5, 1.5, 2.5, 3.5];
        let pi =
            PredictionInterval::from_quantile_predictions(lower, upper, Some(&actuals)).unwrap();
        assert!((pi.coverage - 1.0).abs() < 1e-9, "coverage={}", pi.coverage);
    }

    #[test]
    fn test_prediction_interval_partial_coverage() {
        let lower = vec![0.0, 0.0, 0.0, 0.0];
        let upper = vec![1.0, 1.0, 1.0, 1.0];
        // Half inside, half outside
        let actuals = vec![0.5, 0.5, 2.0, 2.0];
        let pi =
            PredictionInterval::from_quantile_predictions(lower, upper, Some(&actuals)).unwrap();
        assert!((pi.coverage - 0.5).abs() < 1e-9, "coverage={}", pi.coverage);
    }
}