1use std::fmt;
7
/// Errors produced by the uncertainty-quantification APIs in this module.
#[derive(Debug, Clone)]
pub enum UncertaintyError {
    /// An input slice of predictions/samples was empty.
    EmptyPredictions,
    /// `num_samples` was below the minimum of 1.
    InvalidNumSamples(usize),
    /// A confidence level fell outside the open interval (0, 1).
    InvalidConfidenceLevel(f64),
    /// Two paired inputs disagreed in length.
    ShapeMismatch { expected: usize, got: usize },
    /// `num_bins` was below the minimum of 1.
    InvalidBins(usize),
    /// A failure during sampling, with a human-readable reason.
    SamplingError(String),
}

impl fmt::Display for UncertaintyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use UncertaintyError::*;
        match self {
            EmptyPredictions => f.write_str("predictions slice is empty"),
            InvalidNumSamples(n) => write!(f, "num_samples must be >= 1, got {n}"),
            InvalidConfidenceLevel(l) => {
                write!(f, "confidence_level must be in (0, 1), got {l}")
            }
            ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {expected}, got {got}")
            }
            InvalidBins(b) => write!(f, "num_bins must be >= 1, got {b}"),
            SamplingError(msg) => write!(f, "sampling error: {msg}"),
        }
    }
}

impl std::error::Error for UncertaintyError {}
45
/// Minimal xorshift64 PRNG used for deterministic Monte Carlo sampling.
///
/// Not cryptographically secure; seeded deterministically so estimates
/// are reproducible.
struct SimpleUncertaintyRng {
    // xorshift64 state; must never be zero (zero is a fixed point).
    state: u64,
}

impl SimpleUncertaintyRng {
    /// Creates a generator from `seed`, mixing it with a golden-ratio constant.
    ///
    /// Bug fix: a seed equal to `0x9e3779b97f4a7c15` previously produced
    /// state 0, which xorshift64 maps to itself, freezing the generator at
    /// zero forever. A zero mix now falls back to a nonzero constant.
    fn new(seed: u64) -> Self {
        let mixed = seed ^ 0x9e3779b97f4a7c15;
        Self {
            state: if mixed == 0 { 0x9e3779b97f4a7c15 } else { mixed },
        }
    }

    /// Returns a uniform f64 in [0, 1) via one xorshift64 step.
    fn next_f64(&mut self) -> f64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        // Dividing by 2^64 maps the full u64 range into [0, 1).
        (self.state as f64) / (u64::MAX as f64 + 1.0)
    }

    /// Returns a standard-normal sample using the Box-Muller transform.
    fn next_normal(&mut self) -> f64 {
        // Clamp u1 away from zero so ln(u1) stays finite.
        let u1 = self.next_f64().max(1e-15);
        let u2 = self.next_f64();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = std::f64::consts::TAU * u2;
        r * theta.cos()
    }
}
80
/// Strategy used to construct a `ConfidenceInterval`.
#[derive(Debug, Clone, PartialEq)]
pub enum IntervalMethod {
    /// Empirical bounds taken from the sample quantiles.
    Percentile,
    /// Gaussian approximation: `mean ± z * std`.
    Normal,
}
93
/// A two-sided confidence interval for a scalar quantity.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Nominal confidence level in (0, 1), e.g. 0.95.
    pub level: f64,
    /// How the interval was computed.
    pub method: IntervalMethod,
}
102
103impl ConfidenceInterval {
104 pub fn percentile(samples: &[f64], level: f64) -> Result<Self, UncertaintyError> {
106 if samples.is_empty() {
107 return Err(UncertaintyError::EmptyPredictions);
108 }
109 if level <= 0.0 || level >= 1.0 {
110 return Err(UncertaintyError::InvalidConfidenceLevel(level));
111 }
112 let mut sorted = samples.to_vec();
113 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114 let alpha = (1.0 - level) / 2.0;
115 let lower = quantile_sorted(&sorted, alpha);
116 let upper = quantile_sorted(&sorted, 1.0 - alpha);
117 Ok(Self {
118 lower,
119 upper,
120 level,
121 method: IntervalMethod::Percentile,
122 })
123 }
124
125 pub fn normal(mean: f64, std: f64, level: f64) -> Self {
127 let z = z_score(level);
128 Self {
129 lower: mean - z * std,
130 upper: mean + z * std,
131 level,
132 method: IntervalMethod::Normal,
133 }
134 }
135
136 pub fn width(&self) -> f64 {
138 self.upper - self.lower
139 }
140
141 pub fn contains(&self, value: f64) -> bool {
143 value >= self.lower && value <= self.upper
144 }
145}
146
/// Summary statistics describing the uncertainty of a scalar prediction.
#[derive(Debug, Clone)]
pub struct UncertaintyEstimate {
    /// Sample mean of the predictions.
    pub mean: f64,
    /// Population (1/n) variance of the predictions.
    pub variance: f64,
    /// Standard deviation (square root of `variance`).
    pub std_dev: f64,
    /// Percentile confidence interval over the samples.
    pub confidence_interval: ConfidenceInterval,
    /// Shannon entropy (nats) of a 10-bin histogram of the samples.
    pub entropy: f64,
    /// Model (epistemic) uncertainty; currently set to the sample variance.
    pub epistemic_uncertainty: f64,
    /// Data (aleatoric) uncertainty; currently always 0.0 (not estimated).
    pub aleatoric_uncertainty: f64,
}
164
165impl UncertaintyEstimate {
166 pub fn from_samples(samples: &[f64], confidence_level: f64) -> Result<Self, UncertaintyError> {
168 if samples.is_empty() {
169 return Err(UncertaintyError::EmptyPredictions);
170 }
171 if confidence_level <= 0.0 || confidence_level >= 1.0 {
172 return Err(UncertaintyError::InvalidConfidenceLevel(confidence_level));
173 }
174 let n = samples.len() as f64;
175 let mean = samples.iter().sum::<f64>() / n;
176 let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
177 let std_dev = variance.sqrt();
178 let confidence_interval = ConfidenceInterval::percentile(samples, confidence_level)?;
179 let entropy = histogram_entropy(samples, 10);
180
181 let epistemic_uncertainty = variance;
183 let aleatoric_uncertainty = 0.0;
184
185 Ok(Self {
186 mean,
187 variance,
188 std_dev,
189 confidence_interval,
190 entropy,
191 epistemic_uncertainty,
192 aleatoric_uncertainty,
193 })
194 }
195
196 pub fn is_confident(&self, threshold: f64) -> bool {
198 self.std_dev < threshold
199 }
200
201 pub fn summary(&self) -> String {
203 format!(
204 "UncertaintyEstimate {{ mean: {:.4}, std: {:.4}, CI [{:.4}, {:.4}] @{:.0}%, \
205 entropy: {:.4}, epistemic: {:.4}, aleatoric: {:.4} }}",
206 self.mean,
207 self.std_dev,
208 self.confidence_interval.lower,
209 self.confidence_interval.upper,
210 self.confidence_interval.level * 100.0,
211 self.entropy,
212 self.epistemic_uncertainty,
213 self.aleatoric_uncertainty,
214 )
215 }
216}
217
/// Monte Carlo uncertainty estimator driven by a deterministic xorshift RNG.
pub struct MonteCarloEstimator {
    /// Number of Monte Carlo samples drawn per estimate.
    pub num_samples: usize,
    /// Confidence level in (0, 1) for the percentile interval.
    pub confidence_level: f64,
    // Private so repeated estimates continue the same random stream.
    rng: SimpleUncertaintyRng,
}
228
229impl MonteCarloEstimator {
230 pub fn new(
236 num_samples: usize,
237 confidence_level: f64,
238 seed: u64,
239 ) -> Result<Self, UncertaintyError> {
240 if num_samples < 1 {
241 return Err(UncertaintyError::InvalidNumSamples(num_samples));
242 }
243 if confidence_level <= 0.0 || confidence_level >= 1.0 {
244 return Err(UncertaintyError::InvalidConfidenceLevel(confidence_level));
245 }
246 Ok(Self {
247 num_samples,
248 confidence_level,
249 rng: SimpleUncertaintyRng::new(seed),
250 })
251 }
252
253 pub fn with_defaults() -> Self {
255 Self {
257 num_samples: 100,
258 confidence_level: 0.95,
259 rng: SimpleUncertaintyRng::new(42),
260 }
261 }
262
263 pub fn estimate<F>(&mut self, f: F) -> Result<UncertaintyEstimate, UncertaintyError>
265 where
266 F: Fn(f64) -> f64,
267 {
268 let samples: Vec<f64> = (0..self.num_samples)
269 .map(|_| {
270 let noise = self.rng.next_normal();
271 f(noise)
272 })
273 .collect();
274 UncertaintyEstimate::from_samples(&samples, self.confidence_level)
275 }
276
277 pub fn estimate_vector<F>(
281 &mut self,
282 dim: usize,
283 f: F,
284 ) -> Result<Vec<UncertaintyEstimate>, UncertaintyError>
285 where
286 F: Fn(f64) -> Vec<f64>,
287 {
288 if dim == 0 {
289 return Err(UncertaintyError::ShapeMismatch {
290 expected: 1,
291 got: 0,
292 });
293 }
294 let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(self.num_samples);
296 for _ in 0..self.num_samples {
297 let noise = self.rng.next_normal();
298 let row = f(noise);
299 if row.len() != dim {
300 return Err(UncertaintyError::ShapeMismatch {
301 expected: dim,
302 got: row.len(),
303 });
304 }
305 matrix.push(row);
306 }
307
308 let mut estimates = Vec::with_capacity(dim);
310 for col in 0..dim {
311 let col_samples: Vec<f64> = matrix.iter().map(|row| row[col]).collect();
312 let est = UncertaintyEstimate::from_samples(&col_samples, self.confidence_level)?;
313 estimates.push(est);
314 }
315 Ok(estimates)
316 }
317}
318
/// Statistics for a single confidence bin of a reliability diagram.
#[derive(Debug, Clone)]
pub struct CalibrationBin {
    /// Lower edge of the bin's confidence range.
    pub confidence_lower: f64,
    /// Upper edge of the bin's confidence range.
    pub confidence_upper: f64,
    /// Number of predictions that fell into this bin.
    pub count: usize,
    /// Mean predicted probability within the bin (0.0 when empty).
    pub avg_confidence: f64,
    /// Empirical accuracy (mean label) within the bin (0.0 when empty).
    pub accuracy: f64,
    /// Calibration gap `avg_confidence - accuracy`; positive = overconfident.
    pub gap: f64,
}
334
/// Aggregate calibration metrics over all confidence bins.
#[derive(Debug, Clone)]
pub struct CalibrationMetrics {
    /// Expected Calibration Error: count-weighted mean |confidence - accuracy|.
    pub ece: f64,
    /// Maximum Calibration Error: largest per-bin |gap| over non-empty bins.
    pub mce: f64,
    /// Mean positive gap across overconfident bins (0.0 if none).
    pub overconfidence: f64,
    /// Mean gap magnitude across underconfident bins (0.0 if none).
    pub underconfidence: f64,
    /// Number of equal-width bins used.
    pub num_bins: usize,
    /// Per-bin breakdown, including empty bins.
    pub bin_stats: Vec<CalibrationBin>,
}
349
350impl CalibrationMetrics {
351 pub fn compute(
357 predicted_probs: &[f64],
358 true_labels: &[u8],
359 num_bins: usize,
360 ) -> Result<Self, UncertaintyError> {
361 if predicted_probs.is_empty() {
362 return Err(UncertaintyError::EmptyPredictions);
363 }
364 if num_bins < 1 {
365 return Err(UncertaintyError::InvalidBins(num_bins));
366 }
367 if predicted_probs.len() != true_labels.len() {
368 return Err(UncertaintyError::ShapeMismatch {
369 expected: predicted_probs.len(),
370 got: true_labels.len(),
371 });
372 }
373
374 let total = predicted_probs.len() as f64;
375 let bin_width = 1.0 / num_bins as f64;
376
377 let mut bin_conf_sum = vec![0.0_f64; num_bins];
379 let mut bin_acc_sum = vec![0.0_f64; num_bins];
380 let mut bin_count = vec![0usize; num_bins];
381
382 for (p, y) in predicted_probs.iter().zip(true_labels.iter()) {
383 let p = p.clamp(0.0, 1.0);
384 let bin_idx = ((p / bin_width).floor() as usize).min(num_bins - 1);
385 bin_conf_sum[bin_idx] += p;
386 bin_acc_sum[bin_idx] += *y as f64;
387 bin_count[bin_idx] += 1;
388 }
389
390 let mut bin_stats = Vec::with_capacity(num_bins);
391 let mut ece = 0.0_f64;
392 let mut mce = 0.0_f64;
393 let mut over_gaps = Vec::new();
394 let mut under_gaps = Vec::new();
395
396 for i in 0..num_bins {
397 let count = bin_count[i];
398 let conf_lower = i as f64 * bin_width;
399 let conf_upper = conf_lower + bin_width;
400 let (avg_confidence, accuracy, gap) = if count == 0 {
401 (0.0, 0.0, 0.0)
402 } else {
403 let avg_conf = bin_conf_sum[i] / count as f64;
404 let acc = bin_acc_sum[i] / count as f64;
405 (avg_conf, acc, avg_conf - acc)
406 };
407
408 if count > 0 {
409 let weight = count as f64 / total;
410 ece += weight * gap.abs();
411 mce = mce.max(gap.abs());
412 if gap > 0.0 {
413 over_gaps.push(gap);
414 } else if gap < 0.0 {
415 under_gaps.push(-gap);
416 }
417 }
418
419 bin_stats.push(CalibrationBin {
420 confidence_lower: conf_lower,
421 confidence_upper: conf_upper,
422 count,
423 avg_confidence,
424 accuracy,
425 gap,
426 });
427 }
428
429 let overconfidence = if over_gaps.is_empty() {
430 0.0
431 } else {
432 over_gaps.iter().sum::<f64>() / over_gaps.len() as f64
433 };
434 let underconfidence = if under_gaps.is_empty() {
435 0.0
436 } else {
437 under_gaps.iter().sum::<f64>() / under_gaps.len() as f64
438 };
439
440 Ok(Self {
441 ece,
442 mce,
443 overconfidence,
444 underconfidence,
445 num_bins,
446 bin_stats,
447 })
448 }
449
450 pub fn is_well_calibrated(&self, ece_threshold: f64) -> bool {
452 self.ece < ece_threshold
453 }
454
455 pub fn format_reliability_diagram(&self) -> String {
457 let mut lines = vec!["Reliability Diagram (conf → accuracy):".to_string()];
458 for bin in &self.bin_stats {
459 if bin.count == 0 {
460 continue;
461 }
462 let bar_len = (bin.accuracy * 20.0).round() as usize;
463 let bar = "#".repeat(bar_len);
464 lines.push(format!(
465 "[{:.2},{:.2}] n={:4} acc={:.3} conf={:.3} gap={:+.3} |{}|",
466 bin.confidence_lower,
467 bin.confidence_upper,
468 bin.count,
469 bin.accuracy,
470 bin.avg_confidence,
471 bin.gap,
472 bar,
473 ));
474 }
475 lines.join("\n")
476 }
477
478 pub fn summary(&self) -> String {
480 format!(
481 "CalibrationMetrics {{ ECE: {:.4}, MCE: {:.4}, over: {:.4}, under: {:.4}, bins: {} }}",
482 self.ece, self.mce, self.overconfidence, self.underconfidence, self.num_bins
483 )
484 }
485}
486
487pub fn temperature_scale(logits: &[f64], temperature: f64) -> Vec<f64> {
493 let scaled: Vec<f64> = logits.iter().map(|l| l / temperature).collect();
494 softmax_vec(&scaled)
495}
496
497pub fn find_optimal_temperature(
502 logits: &[f64],
503 true_labels: &[u8],
504 temperatures: &[f64],
505) -> Result<f64, UncertaintyError> {
506 if logits.is_empty() {
507 return Err(UncertaintyError::EmptyPredictions);
508 }
509 if logits.len() != true_labels.len() {
510 return Err(UncertaintyError::ShapeMismatch {
511 expected: logits.len(),
512 got: true_labels.len(),
513 });
514 }
515 if temperatures.is_empty() {
516 return Err(UncertaintyError::SamplingError(
517 "temperatures slice is empty".to_string(),
518 ));
519 }
520
521 let mut best_temp = temperatures[0];
522 let mut best_nll = f64::INFINITY;
523
524 for &t in temperatures {
525 if t <= 0.0 {
526 continue;
527 }
528 let nll = compute_nll(logits, true_labels, t);
529 if nll < best_nll {
530 best_nll = nll;
531 best_temp = t;
532 }
533 }
534 Ok(best_temp)
535}
536
/// Point predictions with per-point lower/upper interval bounds.
#[derive(Debug, Clone)]
pub struct PredictionInterval {
    /// Midpoint predictions, `(lower + upper) / 2` per point.
    pub predictions: Vec<f64>,
    /// Per-point lower bounds.
    pub lower_bounds: Vec<f64>,
    /// Per-point upper bounds.
    pub upper_bounds: Vec<f64>,
    /// Fraction of actuals inside their interval; 0.0 when no actuals given.
    pub coverage: f64,
    /// Mean absolute interval width.
    pub avg_width: f64,
}
552
553impl PredictionInterval {
554 pub fn from_quantile_predictions(
558 lower_preds: Vec<f64>,
559 upper_preds: Vec<f64>,
560 actuals: Option<&[f64]>,
561 ) -> Result<Self, UncertaintyError> {
562 if lower_preds.is_empty() {
563 return Err(UncertaintyError::EmptyPredictions);
564 }
565 if lower_preds.len() != upper_preds.len() {
566 return Err(UncertaintyError::ShapeMismatch {
567 expected: lower_preds.len(),
568 got: upper_preds.len(),
569 });
570 }
571
572 let predictions: Vec<f64> = lower_preds
574 .iter()
575 .zip(upper_preds.iter())
576 .map(|(lo, hi)| (lo + hi) / 2.0)
577 .collect();
578
579 let avg_width = lower_preds
580 .iter()
581 .zip(upper_preds.iter())
582 .map(|(lo, hi)| (hi - lo).abs())
583 .sum::<f64>()
584 / lower_preds.len() as f64;
585
586 let coverage = match actuals {
587 None => 0.0,
588 Some(act) => {
589 if act.len() != lower_preds.len() {
590 return Err(UncertaintyError::ShapeMismatch {
591 expected: lower_preds.len(),
592 got: act.len(),
593 });
594 }
595 let covered = lower_preds
596 .iter()
597 .zip(upper_preds.iter())
598 .zip(act.iter())
599 .filter(|((lo, hi), y)| *y >= *lo && *y <= *hi)
600 .count();
601 covered as f64 / act.len() as f64
602 }
603 };
604
605 Ok(Self {
606 predictions,
607 lower_bounds: lower_preds,
608 upper_bounds: upper_preds,
609 coverage,
610 avg_width,
611 })
612 }
613
614 pub fn summary(&self) -> String {
616 format!(
617 "PredictionInterval {{ n: {}, avg_width: {:.4}, coverage: {:.4} }}",
618 self.predictions.len(),
619 self.avg_width,
620 self.coverage,
621 )
622 }
623}
624
/// Linear-interpolation quantile of an ascending-sorted slice.
///
/// `p` is clamped to [0, 1] (out-of-range values previously indexed out of
/// bounds, e.g. p = 2.0 on a 3-element slice). Returns NaN for an empty
/// slice instead of panicking.
fn quantile_sorted(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 0 {
        // Previously `sorted[lo]` panicked here; NaN signals "undefined".
        return f64::NAN;
    }
    if n == 1 {
        return sorted[0];
    }
    let idx = p.clamp(0.0, 1.0) * (n as f64 - 1.0);
    let lo = idx.floor() as usize;
    let hi = (lo + 1).min(n - 1);
    let frac = idx - lo as f64;
    // Interpolate between the two neighboring order statistics.
    sorted[lo] * (1.0 - frac) + sorted[hi] * frac
}
641
/// Approximate two-sided standard-normal critical value for a confidence level.
///
/// Recognizes the common levels exactly (previously only 0.90 and 0.99 were
/// special-cased, so e.g. 0.80 silently received the 95% value). Any
/// unrecognized level still falls back to 1.96, preserving the original
/// default behavior.
fn z_score(level: f64) -> f64 {
    const EPS: f64 = 1e-9;
    // (confidence level, two-sided z) lookup for common levels.
    const TABLE: [(f64, f64); 7] = [
        (0.50, 0.674),
        (0.80, 1.282),
        (0.85, 1.440),
        (0.90, 1.645),
        (0.95, 1.96),
        (0.98, 2.326),
        (0.99, 2.576),
    ];
    TABLE
        .iter()
        .find(|(l, _)| (level - l).abs() < EPS)
        .map(|&(_, z)| z)
        .unwrap_or(1.96)
}
653
/// Shannon entropy (in nats) of an equal-width histogram over `samples`.
///
/// Returns 0.0 for empty input, zero bins, or an effectively constant
/// sample set (range below machine epsilon).
fn histogram_entropy(samples: &[f64], num_bins: usize) -> f64 {
    if samples.is_empty() || num_bins == 0 {
        return 0.0;
    }
    // Find the data range in a single pass.
    let mut lo = f64::INFINITY;
    let mut hi = f64::NEG_INFINITY;
    for &x in samples {
        lo = lo.min(x);
        hi = hi.max(x);
    }
    if (hi - lo).abs() < f64::EPSILON {
        return 0.0;
    }
    // Count samples per equal-width bin; `min` folds x == max into the last.
    let bin_width = (hi - lo) / num_bins as f64;
    let mut counts = vec![0usize; num_bins];
    for &x in samples {
        let bin = (((x - lo) / bin_width).floor() as usize).min(num_bins - 1);
        counts[bin] += 1;
    }
    // Entropy over the empirical bin probabilities (empty bins contribute 0).
    let total = samples.len() as f64;
    let mut entropy = 0.0;
    for &c in &counts {
        if c > 0 {
            let p = c as f64 / total;
            entropy -= p * p.ln();
        }
    }
    entropy
}
680
/// Numerically stable softmax over `logits`.
///
/// Returns an empty vector for empty input, and a uniform distribution if
/// the exponential sum degenerates to zero.
fn softmax_vec(logits: &[f64]) -> Vec<f64> {
    if logits.is_empty() {
        return Vec::new();
    }
    // Subtract the max logit before exponentiating to avoid overflow.
    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut exps = Vec::with_capacity(logits.len());
    let mut total = 0.0_f64;
    for &logit in logits {
        let e = (logit - peak).exp();
        total += e;
        exps.push(e);
    }
    if total == 0.0 {
        // Defensive fallback: spread mass uniformly.
        let uniform = 1.0 / logits.len() as f64;
        return vec![uniform; logits.len()];
    }
    exps.into_iter().map(|e| e / total).collect()
}
694
/// Mean binary negative log-likelihood of temperature-scaled logits.
///
/// Probabilities are clamped away from {0, 1} so the logarithm stays
/// finite. Assumes `logits` and `true_labels` have equal length (callers
/// validate); a label is treated as positive iff it equals 1.
fn compute_nll(logits: &[f64], true_labels: &[u8], temperature: f64) -> f64 {
    let mut total = 0.0_f64;
    for (&logit, &label) in logits.iter().zip(true_labels) {
        let p = sigmoid(logit / temperature).clamp(1e-15, 1.0 - 1e-15);
        // Binary cross-entropy term for this observation.
        total -= if label == 1 { p.ln() } else { (1.0 - p).ln() };
    }
    total / logits.len() as f64
}

/// Numerically stable logistic function.
///
/// Split by the sign of `x` so `exp` is only ever called on a non-positive
/// argument, avoiding overflow for large |x|.
fn sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        let e = x.exp();
        e / (1.0 + e)
    } else {
        1.0 / (1.0 + (-x).exp())
    }
}
722
#[cfg(test)]
mod tests {
    use super::*;

    // ---- UncertaintyEstimate ----

    #[test]
    fn test_uncertainty_estimate_from_samples_basic() {
        // Mean of 0..=99 is 49.5; variance/std must be strictly positive.
        let samples: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        assert!((est.mean - 49.5).abs() < 0.01, "mean={}", est.mean);
        assert!(est.variance > 0.0);
        assert!(est.std_dev > 0.0);
    }

    #[test]
    fn test_uncertainty_estimate_confident() {
        // Constant samples → zero spread → confident at any positive threshold.
        let samples = vec![1.0_f64; 50];
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        assert!(est.is_confident(0.1), "std_dev should be ~0");
    }

    #[test]
    fn test_uncertainty_estimate_not_confident() {
        // Widely spread samples should fail a tight confidence threshold.
        let samples: Vec<f64> = (0..100).map(|i| i as f64 * 10.0).collect();
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        assert!(!est.is_confident(1.0), "high variance → not confident");
    }

    #[test]
    fn test_uncertainty_estimate_summary_nonempty() {
        let samples: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let est = UncertaintyEstimate::from_samples(&samples, 0.95).unwrap();
        let s = est.summary();
        assert!(!s.is_empty());
        assert!(s.contains("mean"));
    }

    // ---- ConfidenceInterval ----

    #[test]
    fn test_confidence_interval_percentile() {
        let samples: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let ci = ConfidenceInterval::percentile(&samples, 0.95).unwrap();
        assert!(ci.lower < ci.upper, "lower={} upper={}", ci.lower, ci.upper);
        assert_eq!(ci.method, IntervalMethod::Percentile);
    }

    #[test]
    fn test_confidence_interval_normal_width() {
        // 95% two-sided z is 1.96, so width = 2 * 1.96 * std.
        let mean = 0.0;
        let std = 1.0;
        let ci = ConfidenceInterval::normal(mean, std, 0.95);
        let expected_width = 2.0 * 1.96 * std;
        assert!(
            (ci.width() - expected_width).abs() < 1e-9,
            "width={}",
            ci.width()
        );
    }

    #[test]
    fn test_confidence_interval_contains() {
        // A central 95% interval of a symmetric sample must contain the mean.
        let samples: Vec<f64> = (0..1000).map(|i| i as f64 / 10.0).collect();
        let ci = ConfidenceInterval::percentile(&samples, 0.95).unwrap();
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(ci.contains(mean), "mean should be inside CI");
    }

    #[test]
    fn test_confidence_interval_width_positive() {
        let ci = ConfidenceInterval::normal(5.0, 2.0, 0.95);
        assert!(ci.width() > 0.0);
    }

    // ---- MonteCarloEstimator ----

    #[test]
    fn test_mc_estimator_with_defaults() {
        let est = MonteCarloEstimator::with_defaults();
        assert_eq!(est.num_samples, 100);
        assert!((est.confidence_level - 0.95).abs() < 1e-9);
    }

    #[test]
    fn test_mc_estimator_estimate_constant_fn() {
        // A constant function must produce zero spread regardless of noise.
        let mut mc = MonteCarloEstimator::new(200, 0.95, 1).unwrap();
        let est = mc.estimate(|_noise| 5.0).unwrap();
        assert!((est.mean - 5.0).abs() < 1e-9, "mean={}", est.mean);
        assert!(est.std_dev < 1e-9, "std_dev={}", est.std_dev);
    }

    #[test]
    fn test_mc_estimator_estimate_linear_fn() {
        // Identity of N(0,1) noise → mean ~0, std ~1 (loose tolerances).
        let mut mc = MonteCarloEstimator::new(2000, 0.95, 7).unwrap();
        let est = mc.estimate(|noise| noise).unwrap();
        assert!(est.mean.abs() < 0.15, "mean should be ~0, got {}", est.mean);
        assert!(
            (est.std_dev - 1.0).abs() < 0.15,
            "std_dev should be ~1, got {}",
            est.std_dev
        );
    }

    #[test]
    fn test_mc_estimator_estimate_vector() {
        let mut mc = MonteCarloEstimator::new(50, 0.95, 99).unwrap();
        let dim = 4;
        let estimates = mc.estimate_vector(dim, |noise| vec![noise; dim]).unwrap();
        assert_eq!(estimates.len(), dim);
    }

    // ---- CalibrationMetrics ----

    #[test]
    fn test_calibration_metrics_compute_perfect() {
        // All predictions 1.0 with all-positive labels: zero gap everywhere.
        let predicted: Vec<f64> = vec![1.0; 100];
        let labels: Vec<u8> = vec![1u8; 100];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 10).unwrap();
        assert!(
            metrics.ece < 1e-9,
            "ECE should be 0 for perfect preds, got {}",
            metrics.ece
        );
    }

    #[test]
    fn test_calibration_metrics_compute_uniform() {
        // Random confidences vs alternating labels: ECE is just non-negative.
        let mut rng = SimpleUncertaintyRng::new(42);
        let n = 200;
        let predicted: Vec<f64> = (0..n).map(|_| rng.next_f64()).collect();
        let labels: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 10).unwrap();
        assert!(metrics.ece >= 0.0);
        assert!(metrics.num_bins == 10);
    }

    #[test]
    fn test_calibration_metrics_bins() {
        // bin_stats always has one entry per bin, including empty bins.
        let predicted = vec![0.1, 0.5, 0.9];
        let labels = vec![0u8, 1, 1];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 5).unwrap();
        assert_eq!(metrics.num_bins, 5);
        assert_eq!(metrics.bin_stats.len(), 5);
    }

    #[test]
    fn test_calibration_is_well_calibrated() {
        // Perfectly calibrated data passes even a very tight ECE threshold.
        let predicted = vec![1.0_f64; 50];
        let labels = vec![1u8; 50];
        let metrics = CalibrationMetrics::compute(&predicted, &labels, 5).unwrap();
        assert!(metrics.is_well_calibrated(0.01));
    }

    // ---- Temperature scaling ----

    #[test]
    fn test_temperature_scale_identity() {
        // Temperature 1.0 must match a directly-computed softmax.
        let logits = vec![1.0, 2.0, 3.0];
        let scaled = temperature_scale(&logits, 1.0);
        let direct = {
            let exps: Vec<f64> = logits.iter().map(|l| l.exp()).collect();
            let s: f64 = exps.iter().sum();
            exps.iter().map(|e| e / s).collect::<Vec<_>>()
        };
        for (a, b) in scaled.iter().zip(direct.iter()) {
            assert!((a - b).abs() < 1e-9, "a={a} b={b}");
        }
    }

    #[test]
    fn test_temperature_scale_high_temp() {
        // Very high temperature flattens the distribution toward uniform.
        let logits = vec![10.0, 0.0, 0.0];
        let high_t = temperature_scale(&logits, 100.0);
        for p in &high_t {
            assert!((p - 1.0 / 3.0).abs() < 0.1, "p={p}");
        }
    }

    #[test]
    fn test_find_optimal_temperature() {
        let logits: Vec<f64> = vec![2.0, -1.0, 0.5, -2.0, 1.0];
        let labels: Vec<u8> = vec![1, 0, 1, 0, 1];
        let temps: Vec<f64> = vec![0.5, 1.0, 2.0, 4.0];
        let opt_t = find_optimal_temperature(&logits, &labels, &temps).unwrap();
        assert!(temps.contains(&opt_t), "optimal temp not in candidates");
    }

    // ---- PredictionInterval ----

    #[test]
    fn test_prediction_interval_basic() {
        let lower = vec![0.0, 1.0, 2.0];
        let upper = vec![1.0, 2.0, 3.0];
        let pi = PredictionInterval::from_quantile_predictions(lower, upper, None).unwrap();
        assert_eq!(pi.predictions.len(), 3);
        assert!((pi.avg_width - 1.0).abs() < 1e-9);
        let s = pi.summary();
        assert!(!s.is_empty());
    }

    #[test]
    fn test_prediction_interval_coverage() {
        // Every actual sits inside its interval → coverage 1.0.
        let lower = vec![0.0, 1.0, 2.0, 3.0];
        let upper = vec![1.0, 2.0, 3.0, 4.0];
        let actuals = vec![0.5, 1.5, 2.5, 3.5];
        let pi =
            PredictionInterval::from_quantile_predictions(lower, upper, Some(&actuals)).unwrap();
        assert!((pi.coverage - 1.0).abs() < 1e-9, "coverage={}", pi.coverage);
    }

    #[test]
    fn test_prediction_interval_partial_coverage() {
        // Two of four actuals fall outside [0, 1] → coverage 0.5.
        let lower = vec![0.0, 0.0, 0.0, 0.0];
        let upper = vec![1.0, 1.0, 1.0, 1.0];
        let actuals = vec![0.5, 0.5, 2.0, 2.0];
        let pi =
            PredictionInterval::from_quantile_predictions(lower, upper, Some(&actuals)).unwrap();
        assert!((pi.coverage - 0.5).abs() < 1e-9, "coverage={}", pi.coverage);
    }
}