// scirs2_stats/conformal/types.rs
/// Selects the conformity (nonconformity) score a conformal predictor computes.
///
/// Marked `#[non_exhaustive]` so new score types can be added without a breaking
/// change; downstream `match` expressions must include a wildcard arm. The scoring
/// code itself is not defined in this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ScoreType {
    /// Absolute-residual score (presumably `|y - y_hat|`; confirm in the scorer).
    AbsResidual,
    /// Score for quantile-regression-based intervals (presumably CQR-style; confirm in the scorer).
    QuantileRegression,
    /// Residual normalized by a spread estimate; the normalizer is chosen by the scorer.
    NormalizedResidual,
    /// Highest-posterior-density score; semantics are defined by the scorer.
    Hpd,
    /// RAPS classification score (presumably parameterized by `RapsConfig`; confirm in the scorer).
    Raps,
}

impl Default for ScoreType {
    /// Returns [`ScoreType::AbsResidual`].
    fn default() -> Self {
        Self::AbsResidual
    }
}

33#[derive(Debug, Clone)]
35pub struct ConformalConfig {
36 pub alpha: f64,
38 pub score_fn: ScoreType,
40}
41
42impl Default for ConformalConfig {
43 fn default() -> Self {
44 Self {
45 alpha: 0.1,
46 score_fn: ScoreType::AbsResidual,
47 }
48 }
49}
50
/// A conformal prediction set for a single test point.
///
/// Two shapes share this type:
/// * interval sets (built by [`PredictionSet::interval`]) use `[lower, upper]` and
///   leave `set` empty;
/// * classification sets (built by [`PredictionSet::classification`]) store class
///   indices in `set` and use the unbounded bounds `(-inf, +inf)`.
///
/// An *empty* classification set is therefore structurally identical to the
/// unbounded interval `(-inf, +inf)`.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionSet {
    /// Lower bound of the interval; `-inf` for classification sets.
    pub lower: f64,
    /// Upper bound of the interval; `+inf` for classification sets.
    pub upper: f64,
    /// Class indices in the set; empty for interval sets.
    pub set: Vec<usize>,
}

impl PredictionSet {
    /// Creates the interval set `[lower, upper]` with no class labels.
    ///
    /// The bounds are stored as given; `lower > upper` is not rejected.
    pub fn interval(lower: f64, upper: f64) -> Self {
        Self {
            lower,
            upper,
            set: Vec::new(),
        }
    }

    /// Creates a classification set holding the given class indices.
    ///
    /// Bounds are `(-inf, +inf)`, so [`contains_value`](Self::contains_value) is true
    /// for every non-NaN value.
    pub fn classification(set: Vec<usize>) -> Self {
        Self {
            lower: f64::NEG_INFINITY,
            upper: f64::INFINITY,
            set,
        }
    }

    /// Returns `true` if `value` lies in the closed interval `[lower, upper]`.
    ///
    /// Both ends are inclusive; NaN (as the value or as a bound) is never contained.
    pub fn contains_value(&self, value: f64) -> bool {
        (self.lower..=self.upper).contains(&value)
    }

    /// Returns `true` if `class` is one of the indices in `set`.
    ///
    /// Linear scan over `set`; always `false` for interval sets.
    pub fn contains_class(&self, class: usize) -> bool {
        self.set.contains(&class)
    }

    /// Width of the set.
    ///
    /// `upper - lower` when `set` is empty (interval sets, and also empty
    /// classification sets, whose width is then `+inf`); `f64::INFINITY` for any
    /// non-empty classification set. The interval width is not clamped, so an
    /// inverted interval yields a negative value.
    pub fn width(&self) -> f64 {
        if self.set.is_empty() {
            self.upper - self.lower
        } else {
            f64::INFINITY
        }
    }
}

105#[derive(Debug, Clone)]
107pub struct ConformalResult {
108 pub sets: Vec<PredictionSet>,
110 pub coverage: f64,
113 pub avg_width: f64,
115}
116
/// Regularization parameters for RAPS-style classification prediction sets.
#[derive(Debug, Clone, PartialEq)]
pub struct RapsConfig {
    /// Rank threshold for the regularization penalty (default `5`); presumably ranks
    /// beyond `k_reg` are penalized — confirm in the RAPS scorer.
    pub k_reg: usize,
    /// Penalty weight (default `0.01`); presumably scales the excess rank beyond `k_reg`.
    pub lambda: f64,
}

impl Default for RapsConfig {
    /// `k_reg = 5`, `lambda = 0.01`.
    fn default() -> Self {
        Self {
            k_reg: 5,
            lambda: 0.01,
        }
    }
}

/// Conformal-predictor settings expressed as a target coverage.
#[derive(Debug, Clone, PartialEq)]
pub struct CpConfig {
    /// Target coverage level (default `0.9`).
    pub coverage_target: f64,
    /// Enables adaptive behavior in whatever consumes this config (default `false`);
    /// the flag is not interpreted by this type itself.
    pub adaptive: bool,
}

impl Default for CpConfig {
    /// `coverage_target = 0.9`, `adaptive = false`.
    fn default() -> Self {
        Self {
            coverage_target: 0.9,
            adaptive: false,
        }
    }
}

/// Finite-sample conformal quantile of calibration `scores` at miscoverage `alpha`.
///
/// With `n = scores.len()`, returns the `k`-th smallest score where
/// `k = ceil(n * min((n + 1) * (1 - alpha) / n, 1))`, clamped to `1..=n`. In particular:
/// * if `(n + 1) * (1 - alpha) > n`, the rank is capped and the **maximum** score is
///   returned (not `+inf`);
/// * `alpha >= 1` gives rank 0, which is clamped to the minimum score;
/// * a NaN `alpha` makes the level `1.0` (`f64::min` ignores NaN), giving the maximum.
///
/// Empty `scores` return `f64::INFINITY`.
///
/// Scores are ordered with [`f64::total_cmp`], so a positive NaN sorts above `+inf`
/// and a negative NaN below `-inf`. Selection is `O(n)` on average and the input
/// slice is not modified.
pub fn conformal_quantile(scores: &[f64], alpha: f64) -> f64 {
    if scores.is_empty() {
        // No calibration data: the only safe threshold is unbounded.
        return f64::INFINITY;
    }
    let n = scores.len();
    // Finite-sample corrected quantile level, capped at 1 (i.e. the largest score).
    let level = ((n + 1) as f64 * (1.0 - alpha) / n as f64).min(1.0);
    // 1-based rank -> 0-based index. `as usize` saturates negative values to 0, and
    // the clamps keep the index inside 0..n.
    let idx = ((level * n as f64).ceil() as usize)
        .saturating_sub(1)
        .min(n - 1);
    // `total_cmp` is a genuine total order; `partial_cmp(..).unwrap_or(Equal)` is not
    // once a NaN is present, and sorting with such a comparator may panic or produce
    // an unspecified order. Selecting the k-th element avoids a full sort.
    let mut buf = scores.to_vec();
    let (_, kth, _) = buf.select_nth_unstable_by(idx, f64::total_cmp);
    *kth
}

// Unit tests for the config defaults, `PredictionSet`, and `conformal_quantile`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conformal_config_default() {
        let cfg = ConformalConfig::default();
        assert!((cfg.alpha - 0.1).abs() < 1e-10);
        assert_eq!(cfg.score_fn, ScoreType::AbsResidual);
    }

    #[test]
    fn test_cp_config_default() {
        let cfg = CpConfig::default();
        assert!((cfg.coverage_target - 0.9).abs() < 1e-10);
        assert!(!cfg.adaptive);
    }

    #[test]
    fn test_raps_config_default() {
        let cfg = RapsConfig::default();
        assert_eq!(cfg.k_reg, 5);
        assert!(cfg.lambda > 0.0);
    }

    #[test]
    fn test_prediction_set_contains_value() {
        let ps = PredictionSet::interval(1.0, 3.0);
        assert!(ps.contains_value(2.0));
        // Both bounds are inclusive.
        assert!(ps.contains_value(1.0));
        assert!(ps.contains_value(3.0));
        assert!(!ps.contains_value(0.5));
        assert!((ps.width() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_prediction_set_classification() {
        let ps = PredictionSet::classification(vec![0, 2]);
        assert!(ps.contains_class(0));
        assert!(!ps.contains_class(1));
        // Classification sets report an infinite width.
        assert!(ps.width().is_infinite());
    }

    #[test]
    fn test_conformal_quantile_basic() {
        let scores: Vec<f64> = (1..=10).map(|x| x as f64).collect();
        let q = conformal_quantile(&scores, 0.1);
        // Rank ceil(11 * 0.9) = 10 -> the largest of the ten scores.
        assert!((q - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_conformal_quantile_unsorted_input() {
        // n = 4, alpha = 0.5: rank ceil(5 * 0.5) = 3 -> third-smallest score.
        let q = conformal_quantile(&[4.0, 1.0, 3.0, 2.0], 0.5);
        assert!((q - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_conformal_quantile_empty() {
        let q = conformal_quantile(&[], 0.1);
        assert!(q.is_infinite());
    }
}