Skip to main content

scirs2_stats/conformal/
types.rs

//! Core types for conformal prediction.
//!
//! Defines the configuration structures, the prediction-set representation, and
//! the result types used throughout the conformal prediction framework, plus the
//! `conformal_quantile` helper for split conformal calibration.
5
/// Nonconformity score type for conformal prediction.
///
/// Each variant determines how nonconformity scores are computed during
/// calibration and inference. The default is [`ScoreType::AbsResidual`].
///
/// The enum is fieldless, so it is `Copy`, `Eq` and `Hash` and can be used
/// directly as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum ScoreType {
    /// |y - ŷ| for regression tasks.
    #[default]
    AbsResidual,
    /// Conformal Quantile Regression (Romano et al. 2019):
    /// s_i = max(q̂_lo(x_i) - y_i, y_i - q̂_hi(x_i))
    QuantileRegression,
    /// |y - ŷ| / σ̂  where σ̂ is a local difficulty estimate.
    NormalizedResidual,
    /// Highest predictive density score: s = 1 - p(y | x)
    Hpd,
    /// Regularized Adaptive Prediction Sets (RAPS, Angelopoulos et al. 2021)
    /// for multi-class classification.
    Raps,
}
32
33/// Configuration for split/inductive conformal prediction
34#[derive(Debug, Clone)]
35pub struct ConformalConfig {
36    /// Significance level α ∈ (0, 1).  Coverage target is 1 − α.
37    pub alpha: f64,
38    /// Nonconformity score type to use.
39    pub score_fn: ScoreType,
40}
41
42impl Default for ConformalConfig {
43    fn default() -> Self {
44        Self {
45            alpha: 0.1,
46            score_fn: ScoreType::AbsResidual,
47        }
48    }
49}
50
/// A conformal prediction set for one test point.
///
/// Regression sets describe the closed interval `[lower, upper]` and leave
/// `set` empty. Classification sets list the admitted class indices in `set`
/// and carry the unbounded interval `(-∞, +∞)` in `lower`/`upper`.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionSet {
    /// Lower end of the regression interval (`-∞` for classification sets).
    pub lower: f64,
    /// Upper end of the regression interval (`+∞` for classification sets).
    pub upper: f64,
    /// Class indices admitted to the set (empty for regression sets).
    pub set: Vec<usize>,
}

impl PredictionSet {
    /// Build a regression prediction set covering `[lower, upper]`.
    ///
    /// The bounds are stored exactly as given; their order is not checked.
    pub fn interval(lower: f64, upper: f64) -> Self {
        PredictionSet {
            set: vec![],
            lower,
            upper,
        }
    }

    /// Build a classification prediction set from the admitted class indices.
    ///
    /// The interval bounds are set to `(-∞, +∞)`.
    pub fn classification(set: Vec<usize>) -> Self {
        let unbounded = Self::interval(f64::NEG_INFINITY, f64::INFINITY);
        Self { set, ..unbounded }
    }

    /// Whether `value` lies in the closed interval `[lower, upper]`.
    ///
    /// Always `false` for a NaN `value`. For a set built with
    /// [`PredictionSet::classification`] every non-NaN value is contained,
    /// because the bounds are infinite.
    pub fn contains_value(&self, value: f64) -> bool {
        (self.lower..=self.upper).contains(&value)
    }

    /// Whether class index `class` is one of the admitted classes.
    pub fn contains_class(&self, class: usize) -> bool {
        self.set.contains(&class)
    }

    /// Width `upper - lower` of the regression interval.
    ///
    /// A set with a non-empty `set` is treated as a classification set and
    /// reports `f64::INFINITY`. An empty classification set also reports `+∞`
    /// because its bounds are infinite. The difference is not clamped, so an
    /// interval with `lower > upper` yields a negative width.
    pub fn width(&self) -> f64 {
        if !self.set.is_empty() {
            return f64::INFINITY;
        }
        self.upper - self.lower
    }
}
104
105/// Aggregated results for a batch of conformal predictions.
106#[derive(Debug, Clone)]
107pub struct ConformalResult {
108    /// One [`PredictionSet`] per test point.
109    pub sets: Vec<PredictionSet>,
110    /// Empirical coverage of the prediction sets over the test batch
111    /// (fraction of sets that contain the true label).
112    pub coverage: f64,
113    /// Average width (regression) or average set size (classification).
114    pub avg_width: f64,
115}
116
/// Tuning parameters for RAPS (Regularized Adaptive Prediction Sets).
///
/// Defaults: `k_reg = 5`, `lambda = 0.01`.
#[derive(Debug, Clone)]
pub struct RapsConfig {
    /// Rank threshold: classes ranked beyond `k_reg` incur the regularization penalty.
    pub k_reg: usize,
    /// Penalty strength λ; larger values favour smaller prediction sets.
    pub lambda: f64,
}

impl Default for RapsConfig {
    fn default() -> Self {
        const DEFAULT_K_REG: usize = 5;
        const DEFAULT_LAMBDA: f64 = 0.01;
        RapsConfig {
            k_reg: DEFAULT_K_REG,
            lambda: DEFAULT_LAMBDA,
        }
    }
}
134
/// High-level settings for adaptive conformal prediction.
///
/// Defaults: a 90% coverage target with adaptive scoring disabled.
#[derive(Debug, Clone)]
pub struct CpConfig {
    /// Target marginal coverage probability, e.g. `0.9` for 90% coverage.
    pub coverage_target: f64,
    /// If `true`, locally-adaptive scores (normalized / RAPS) are used.
    pub adaptive: bool,
}

impl Default for CpConfig {
    fn default() -> Self {
        CpConfig {
            adaptive: false,
            coverage_target: 0.9,
        }
    }
}
152
/// Compute the split-conformal quantile of the calibration `scores` at level `alpha`.
///
/// Returns the ⌈(n + 1)(1 − α)⌉-th smallest score (1-based), where `n` is
/// `scores.len()`. This is the empirical (1 − α)-quantile with the
/// finite-sample correction (1 + 1/n) used in split conformal inference.
///
/// # Edge cases
///
/// * Returns `f64::INFINITY` if `scores` is empty. More generally it returns
///   `f64::INFINITY` whenever ⌈(n + 1)(1 − α)⌉ > n, which is equivalent to
///   α < 1/(n + 1). With that few calibration points, no finite threshold
///   certifies 1 − α coverage, so the prediction set must be unbounded.
///   Returning the largest score instead would silently under-cover.
/// * The rank is clamped to at least 1, so `alpha >= 1` returns the smallest
///   score.
/// * NaN scores (of either sign) are ordered after every other value, i.e.
///   treated as maximally nonconforming.
/// * A NaN `alpha` yields `f64::NAN`.
pub fn conformal_quantile(scores: &[f64], alpha: f64) -> f64 {
    let n = scores.len();
    if n == 0 {
        return f64::INFINITY;
    }
    if alpha.is_nan() {
        return f64::NAN;
    }

    // 1-based rank of the required order statistic. Computed directly rather
    // than as a level `rank / n` multiplied back by `n`; that round trip can
    // land just above an integer and shift the ceiling by one.
    let rank = ((n + 1) as f64 * (1.0 - alpha)).ceil();
    if rank > n as f64 {
        return f64::INFINITY;
    }
    // `as usize` saturates negative ranks to 0; the first order statistic is index 0.
    let idx = (rank as usize).saturating_sub(1);

    // Partial selection is enough; a full sort is unnecessary. The comparator
    // must be a total order (std may panic otherwise). Every NaN compares
    // greater than every number and equal to other NaNs, rather than mapping
    // NaN comparisons to `Equal`.
    let mut buf = scores.to_vec();
    let (_, kth, _) = buf.select_nth_unstable_by(idx, |a, b| {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    });
    *kth
}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conformal_config_default() {
        let cfg = ConformalConfig::default();
        assert!((cfg.alpha - 0.1).abs() < 1e-10);
        assert_eq!(cfg.score_fn, ScoreType::AbsResidual);
    }

    #[test]
    fn test_score_type_default() {
        assert_eq!(ScoreType::default(), ScoreType::AbsResidual);
    }

    #[test]
    fn test_cp_config_default() {
        let cfg = CpConfig::default();
        assert!((cfg.coverage_target - 0.9).abs() < 1e-10);
        assert!(!cfg.adaptive);
    }

    #[test]
    fn test_raps_config_default() {
        let cfg = RapsConfig::default();
        assert_eq!(cfg.k_reg, 5);
        assert!((cfg.lambda - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_prediction_set_contains_value() {
        let ps = PredictionSet::interval(1.0, 3.0);
        assert!(ps.contains_value(2.0));
        // Both endpoints are inclusive.
        assert!(ps.contains_value(1.0));
        assert!(ps.contains_value(3.0));
        assert!(!ps.contains_value(0.5));
        assert!(!ps.contains_value(f64::NAN));
        assert!((ps.width() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_set_classification() {
        let ps = PredictionSet::classification(vec![0, 2]);
        assert!(ps.contains_class(0));
        assert!(ps.contains_class(2));
        assert!(!ps.contains_class(1));
        assert!(ps.width().is_infinite());
    }

    #[test]
    fn test_conformal_quantile_basic() {
        let scores: Vec<f64> = (1..=10).map(|x| x as f64).collect();
        // n = 10, α = 0.1: rank ⌈(n + 1)(1 − α)⌉ = ⌈9.9⌉ = 10 → 10th smallest score.
        assert_eq!(conformal_quantile(&scores, 0.1), 10.0);
        // α = 0.5: rank ⌈5.5⌉ = 6 → 6th smallest score.
        assert_eq!(conformal_quantile(&scores, 0.5), 6.0);
    }

    #[test]
    fn test_conformal_quantile_unsorted_input() {
        let scores = [3.0, 1.0, 2.0, 5.0, 4.0];
        // n = 5, α = 0.5: rank ⌈6 × 0.5⌉ = 3 → 3rd smallest score.
        assert_eq!(conformal_quantile(&scores, 0.5), 3.0);
    }

    #[test]
    fn test_conformal_quantile_empty() {
        let q = conformal_quantile(&[], 0.1);
        assert!(q.is_infinite());
    }
}