1use sklears_core::error::{Result, SklearsError};
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
/// Configuration for conformal prediction.
#[derive(Debug, Clone)]
pub struct ConformalPredictionConfig {
    /// Miscoverage level; the target coverage is `1 - alpha` (e.g. 0.1 -> 90%).
    pub alpha: f64,
    /// How nonconformity scores are computed from predictions and targets.
    pub nonconformity_method: NonconformityMethod,
    /// Scale interval widths by per-sample error estimates when provided.
    pub normalize: bool,
    /// Calibrate a separate threshold per class (Mondrian / class-conditional).
    pub class_conditional: bool,
    /// Seed for randomized steps. NOTE(review): not read anywhere in this
    /// file — confirm it is consumed elsewhere.
    pub random_state: Option<u64>,
    /// Use inductive (split) conformal prediction. NOTE(review): not read in
    /// this file.
    pub inductive: bool,
    /// Fraction of data reserved for calibration. NOTE(review): not read in
    /// this file.
    pub calibration_fraction: f64,
}
28
29impl Default for ConformalPredictionConfig {
30 fn default() -> Self {
31 Self {
32 alpha: 0.1, nonconformity_method: NonconformityMethod::AbsoluteError,
34 normalize: false,
35 class_conditional: false,
36 random_state: None,
37 inductive: true,
38 calibration_fraction: 0.2,
39 }
40 }
41}
42
/// Strategies for turning model outputs into nonconformity scores.
#[derive(Debug, Clone)]
pub enum NonconformityMethod {
    /// `|target - prediction|` (regression).
    AbsoluteError,
    /// `(target - prediction)^2` (regression).
    SquaredError,
    /// `target - prediction`, sign preserved (regression).
    SignedError,
    /// Highest competing-class probability minus the true-class probability
    /// (classification).
    Margin,
    /// `1 - p(true class)` (classification).
    InverseProbability,
    /// User-supplied score function over `(predictions, targets)` (regression).
    Custom(fn(&[f64], &[f64]) -> Vec<f64>),
}
59
/// Output of a conformal prediction call: intervals for regression or sets
/// for classification (the unused one is `None`).
#[derive(Debug, Clone)]
pub struct ConformalPredictionResult {
    /// `(lower, upper)` interval per sample; `None` for classification.
    pub prediction_intervals: Option<Vec<(f64, f64)>>,
    /// Candidate class indices per sample; `None` for regression.
    pub prediction_sets: Option<Vec<Vec<usize>>>,
    /// Nonconformity scores recorded at calibration time (empty if unavailable).
    pub calibration_scores: Vec<f64>,
    /// Calibrated quantile threshold used to build the intervals/sets.
    pub quantile_threshold: f64,
    /// Coverage summary. Empirical fields are placeholders (0.0) here; they
    /// are computed by the `evaluate_*` methods against ground truth.
    pub coverage_stats: CoverageStatistics,
    /// Interval-width / set-size efficiency summary.
    pub efficiency_metrics: EfficiencyMetrics,
}
76
/// Empirical vs. target coverage summary.
#[derive(Debug, Clone)]
pub struct CoverageStatistics {
    /// Fraction of samples whose true value fell inside the interval/set.
    pub empirical_coverage: f64,
    /// Desired coverage, `1 - alpha`.
    pub target_coverage: f64,
    /// `empirical_coverage - target_coverage` (positive means over-coverage).
    pub coverage_gap: f64,
    /// Per-class empirical coverage (classification only).
    pub class_coverage: Option<HashMap<usize, f64>>,
}
89
/// Efficiency (sharpness) metrics. Interval fields apply to regression and
/// set fields to classification; the inapplicable half is `None`.
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics {
    /// Mean width of the prediction intervals.
    pub average_interval_width: Option<f64>,
    /// Mean number of classes per prediction set.
    pub average_set_size: Option<f64>,
    /// Sample standard deviation of interval widths.
    pub interval_width_std: Option<f64>,
    /// Sample standard deviation of set sizes.
    pub set_size_std: Option<f64>,
    /// Fraction of sets containing exactly one class.
    pub singleton_rate: Option<f64>,
    /// Fraction of empty prediction sets.
    pub empty_set_rate: Option<f64>,
}
106
/// Split (inductive) conformal predictor: calibrate on held-out predictions,
/// then produce prediction intervals (regression) or sets (classification).
#[derive(Debug, Clone)]
pub struct ConformalPredictor {
    // User-supplied settings.
    config: ConformalPredictionConfig,
    // Nonconformity scores from the last fit; `None` until fitted.
    calibration_scores: Option<Vec<f64>>,
    // Marginal threshold (regression, or non-class-conditional classification).
    quantile_threshold: Option<f64>,
    // Per-class thresholds when class-conditional calibration is used.
    class_thresholds: Option<HashMap<usize, f64>>,
}
115
116impl ConformalPredictor {
117 pub fn new(config: ConformalPredictionConfig) -> Self {
118 Self {
119 config,
120 calibration_scores: None,
121 quantile_threshold: None,
122 class_thresholds: None,
123 }
124 }
125
126 pub fn fit(
128 &mut self,
129 calibration_predictions: &[f64],
130 calibration_targets: &[f64],
131 ) -> Result<()> {
132 if calibration_predictions.len() != calibration_targets.len() {
133 return Err(SklearsError::InvalidInput(
134 "Predictions and targets must have the same length".to_string(),
135 ));
136 }
137
138 let scores =
140 self.compute_nonconformity_scores(calibration_predictions, calibration_targets)?;
141
142 let quantile_level = 1.0 - self.config.alpha;
144 let threshold = self.compute_quantile(&scores, quantile_level);
145
146 self.calibration_scores = Some(scores);
147 self.quantile_threshold = Some(threshold);
148
149 Ok(())
150 }
151
152 pub fn fit_classification(
154 &mut self,
155 calibration_probabilities: &[Vec<f64>],
156 calibration_labels: &[usize],
157 ) -> Result<()> {
158 if calibration_probabilities.len() != calibration_labels.len() {
159 return Err(SklearsError::InvalidInput(
160 "Probabilities and labels must have the same length".to_string(),
161 ));
162 }
163
164 let scores =
165 self.compute_classification_scores(calibration_probabilities, calibration_labels)?;
166
167 if self.config.class_conditional {
168 let mut class_thresholds = HashMap::new();
170 let unique_classes = self.get_unique_classes(calibration_labels);
171
172 for &class in &unique_classes {
173 let class_scores: Vec<f64> = scores
174 .iter()
175 .enumerate()
176 .filter(|(i, _)| calibration_labels[*i] == class)
177 .map(|(_, &score)| score)
178 .collect();
179
180 if !class_scores.is_empty() {
181 let quantile_level = 1.0 - self.config.alpha;
182 let threshold = self.compute_quantile(&class_scores, quantile_level);
183 class_thresholds.insert(class, threshold);
184 }
185 }
186
187 self.class_thresholds = Some(class_thresholds);
188 } else {
189 let quantile_level = 1.0 - self.config.alpha;
191 let threshold = self.compute_quantile(&scores, quantile_level);
192 self.quantile_threshold = Some(threshold);
193 }
194
195 self.calibration_scores = Some(scores);
196
197 Ok(())
198 }
199
200 pub fn predict_intervals(
202 &self,
203 predictions: &[f64],
204 prediction_errors: Option<&[f64]>,
205 ) -> Result<ConformalPredictionResult> {
206 if self.quantile_threshold.is_none() {
207 return Err(SklearsError::NotFitted {
208 operation: "making predictions".to_string(),
209 });
210 }
211
212 let threshold = self.quantile_threshold.unwrap();
213 let mut intervals = Vec::new();
214
215 for (i, &pred) in predictions.iter().enumerate() {
216 let error_scale =
217 if let (true, Some(errors)) = (self.config.normalize, &prediction_errors) {
218 errors[i].max(1e-8) } else {
220 1.0
221 };
222
223 let margin = threshold * error_scale;
224 intervals.push((pred - margin, pred + margin));
225 }
226
227 let average_width =
229 intervals.iter().map(|(l, u)| u - l).sum::<f64>() / intervals.len() as f64;
230 let width_std =
231 self.calculate_std(&intervals.iter().map(|(l, u)| u - l).collect::<Vec<_>>());
232
233 let efficiency_metrics = EfficiencyMetrics {
234 average_interval_width: Some(average_width),
235 average_set_size: None,
236 interval_width_std: Some(width_std),
237 set_size_std: None,
238 singleton_rate: None,
239 empty_set_rate: None,
240 };
241
242 let coverage_stats = CoverageStatistics {
243 empirical_coverage: 0.0, target_coverage: 1.0 - self.config.alpha,
245 coverage_gap: 0.0,
246 class_coverage: None,
247 };
248
249 Ok(ConformalPredictionResult {
250 prediction_intervals: Some(intervals),
251 prediction_sets: None,
252 calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
253 quantile_threshold: threshold,
254 coverage_stats,
255 efficiency_metrics,
256 })
257 }
258
259 pub fn predict_sets(
261 &self,
262 prediction_probabilities: &[Vec<f64>],
263 ) -> Result<ConformalPredictionResult> {
264 if self.quantile_threshold.is_none() && self.class_thresholds.is_none() {
265 return Err(SklearsError::NotFitted {
266 operation: "making predictions".to_string(),
267 });
268 }
269
270 let mut prediction_sets = Vec::new();
271
272 for probs in prediction_probabilities {
273 let mut prediction_set = Vec::new();
274
275 for (class_idx, &prob) in probs.iter().enumerate() {
276 let threshold = if let Some(ref class_thresholds) = self.class_thresholds {
277 class_thresholds.get(&class_idx).copied().unwrap_or(0.0)
278 } else {
279 self.quantile_threshold.unwrap()
280 };
281
282 let score = match self.config.nonconformity_method {
284 NonconformityMethod::InverseProbability => 1.0 - prob,
285 _ => 1.0 - prob, };
287
288 if score <= threshold {
289 prediction_set.push(class_idx);
290 }
291 }
292
293 prediction_sets.push(prediction_set);
294 }
295
296 let set_sizes: Vec<f64> = prediction_sets.iter().map(|s| s.len() as f64).collect();
298 let average_set_size = set_sizes.iter().sum::<f64>() / set_sizes.len() as f64;
299 let set_size_std = self.calculate_std(&set_sizes);
300 let singleton_rate =
301 set_sizes.iter().filter(|&&size| size == 1.0).count() as f64 / set_sizes.len() as f64;
302 let empty_set_rate =
303 set_sizes.iter().filter(|&&size| size == 0.0).count() as f64 / set_sizes.len() as f64;
304
305 let efficiency_metrics = EfficiencyMetrics {
306 average_interval_width: None,
307 average_set_size: Some(average_set_size),
308 interval_width_std: None,
309 set_size_std: Some(set_size_std),
310 singleton_rate: Some(singleton_rate),
311 empty_set_rate: Some(empty_set_rate),
312 };
313
314 let coverage_stats = CoverageStatistics {
315 empirical_coverage: 0.0, target_coverage: 1.0 - self.config.alpha,
317 coverage_gap: 0.0,
318 class_coverage: None,
319 };
320
321 let threshold = self.quantile_threshold.unwrap_or(0.0);
322
323 Ok(ConformalPredictionResult {
324 prediction_intervals: None,
325 prediction_sets: Some(prediction_sets),
326 calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
327 quantile_threshold: threshold,
328 coverage_stats,
329 efficiency_metrics,
330 })
331 }
332
333 pub fn evaluate_coverage(
335 &self,
336 predictions: &[f64],
337 true_values: &[f64],
338 prediction_errors: Option<&[f64]>,
339 ) -> Result<CoverageStatistics> {
340 let result = self.predict_intervals(predictions, prediction_errors)?;
341 let intervals = result.prediction_intervals.unwrap();
342
343 let mut covered = 0;
344 for (i, &true_val) in true_values.iter().enumerate() {
345 let (lower, upper) = intervals[i];
346 if true_val >= lower && true_val <= upper {
347 covered += 1;
348 }
349 }
350
351 let empirical_coverage = covered as f64 / true_values.len() as f64;
352 let target_coverage = 1.0 - self.config.alpha;
353 let coverage_gap = empirical_coverage - target_coverage;
354
355 Ok(CoverageStatistics {
356 empirical_coverage,
357 target_coverage,
358 coverage_gap,
359 class_coverage: None,
360 })
361 }
362
363 pub fn evaluate_classification_coverage(
365 &self,
366 prediction_probabilities: &[Vec<f64>],
367 true_labels: &[usize],
368 ) -> Result<CoverageStatistics> {
369 let result = self.predict_sets(prediction_probabilities)?;
370 let prediction_sets = result.prediction_sets.unwrap();
371
372 let mut covered = 0;
373 let mut class_coverage_counts: HashMap<usize, (usize, usize)> = HashMap::new();
374
375 for (i, &true_label) in true_labels.iter().enumerate() {
376 let prediction_set = &prediction_sets[i];
377 let is_covered = prediction_set.contains(&true_label);
378
379 if is_covered {
380 covered += 1;
381 }
382
383 let (class_covered, class_total) =
385 class_coverage_counts.entry(true_label).or_insert((0, 0));
386 if is_covered {
387 *class_covered += 1;
388 }
389 *class_total += 1;
390 }
391
392 let empirical_coverage = covered as f64 / true_labels.len() as f64;
393 let target_coverage = 1.0 - self.config.alpha;
394 let coverage_gap = empirical_coverage - target_coverage;
395
396 let mut class_coverage = HashMap::new();
398 for (&class, &(covered_count, total_count)) in &class_coverage_counts {
399 class_coverage.insert(class, covered_count as f64 / total_count as f64);
400 }
401
402 Ok(CoverageStatistics {
403 empirical_coverage,
404 target_coverage,
405 coverage_gap,
406 class_coverage: Some(class_coverage),
407 })
408 }
409
410 fn compute_nonconformity_scores(
412 &self,
413 predictions: &[f64],
414 targets: &[f64],
415 ) -> Result<Vec<f64>> {
416 match self.config.nonconformity_method {
417 NonconformityMethod::AbsoluteError => Ok(predictions
418 .iter()
419 .zip(targets.iter())
420 .map(|(&pred, &target)| (target - pred).abs())
421 .collect()),
422 NonconformityMethod::SquaredError => Ok(predictions
423 .iter()
424 .zip(targets.iter())
425 .map(|(&pred, &target)| (target - pred).powi(2))
426 .collect()),
427 NonconformityMethod::SignedError => Ok(predictions
428 .iter()
429 .zip(targets.iter())
430 .map(|(&pred, &target)| target - pred)
431 .collect()),
432 NonconformityMethod::Custom(func) => Ok(func(predictions, targets)),
433 _ => Err(SklearsError::InvalidInput(
434 "Invalid nonconformity method for regression".to_string(),
435 )),
436 }
437 }
438
439 fn compute_classification_scores(
441 &self,
442 probabilities: &[Vec<f64>],
443 labels: &[usize],
444 ) -> Result<Vec<f64>> {
445 match self.config.nonconformity_method {
446 NonconformityMethod::InverseProbability => {
447 let scores = probabilities
448 .iter()
449 .zip(labels.iter())
450 .map(|(probs, &label)| 1.0 - probs.get(label).copied().unwrap_or(0.0))
451 .collect();
452 Ok(scores)
453 }
454 NonconformityMethod::Margin => {
455 let scores = probabilities
456 .iter()
457 .zip(labels.iter())
458 .map(|(probs, &label)| {
459 let true_class_prob = probs.get(label).copied().unwrap_or(0.0);
460 let max_other_prob = probs
461 .iter()
462 .enumerate()
463 .filter(|(i, _)| *i != label)
464 .map(|(_, &prob)| prob)
465 .fold(0.0, f64::max);
466 max_other_prob - true_class_prob
467 })
468 .collect();
469 Ok(scores)
470 }
471 _ => Err(SklearsError::InvalidInput(
472 "Invalid nonconformity method for classification".to_string(),
473 )),
474 }
475 }
476
477 fn compute_quantile(&self, values: &[f64], quantile: f64) -> f64 {
479 if values.is_empty() {
480 return 0.0;
481 }
482
483 let mut sorted_values = values.to_vec();
484 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
485
486 let n = sorted_values.len();
487 let index = (quantile * (n + 1) as f64).ceil() as usize;
488 let index = index.min(n).saturating_sub(1);
489
490 sorted_values[index]
491 }
492
493 fn get_unique_classes(&self, labels: &[usize]) -> Vec<usize> {
495 let mut unique_classes: Vec<usize> = labels.to_vec();
496 unique_classes.sort_unstable();
497 unique_classes.dedup();
498 unique_classes
499 }
500
501 fn calculate_std(&self, values: &[f64]) -> f64 {
503 if values.len() < 2 {
504 return 0.0;
505 }
506
507 let mean = values.iter().sum::<f64>() / values.len() as f64;
508 let variance =
509 values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
510
511 variance.sqrt()
512 }
513}
514
/// Conformal predictor calibrated from jackknife (leave-one-out) residuals.
#[derive(Debug, Clone)]
pub struct JackknifeConformalPredictor {
    // Underlying predictor that stores the calibrated threshold.
    base_predictor: ConformalPredictor,
    // The leave-one-out predictions supplied at fit time.
    jackknife_predictions: Option<Vec<Vec<f64>>>,
}
521
522impl JackknifeConformalPredictor {
523 pub fn new(config: ConformalPredictionConfig) -> Self {
524 Self {
525 base_predictor: ConformalPredictor::new(config),
526 jackknife_predictions: None,
527 }
528 }
529
530 pub fn fit_jackknife(
532 &mut self,
533 all_predictions: &[Vec<f64>], targets: &[f64],
535 ) -> Result<()> {
536 if all_predictions.len() != targets.len() {
537 return Err(SklearsError::InvalidInput(
538 "Predictions and targets must have the same length".to_string(),
539 ));
540 }
541
542 let mut residuals = Vec::new();
544 for (i, preds) in all_predictions.iter().enumerate() {
545 if !preds.is_empty() {
546 let residual = (targets[i] - preds[0]).abs(); residuals.push(residual);
548 }
549 }
550
551 self.base_predictor.calibration_scores = Some(residuals.clone());
553 let quantile_level = 1.0 - self.base_predictor.config.alpha;
554 let threshold = self
555 .base_predictor
556 .compute_quantile(&residuals, quantile_level);
557 self.base_predictor.quantile_threshold = Some(threshold);
558 self.jackknife_predictions = Some(all_predictions.to_vec());
559
560 Ok(())
561 }
562
563 pub fn predict_jackknife_intervals(
565 &self,
566 predictions: &[f64],
567 ) -> Result<ConformalPredictionResult> {
568 self.base_predictor.predict_intervals(predictions, None)
569 }
570}
571
// Removed the dead `#[allow(non_snake_case)]`: every test fn below is
// already snake_case, so the attribute suppressed nothing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression path: fit on a small calibration set, then check that the
    /// produced intervals have positive width.
    #[test]
    fn test_conformal_prediction_regression() {
        let config = ConformalPredictionConfig::default();
        let mut predictor = ConformalPredictor::new(config);

        let cal_preds = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = vec![1.1, 1.9, 3.2, 3.8, 5.1];

        predictor.fit(&cal_preds, &cal_targets).unwrap();

        let test_preds = vec![2.5, 4.5];
        let result = predictor.predict_intervals(&test_preds, None).unwrap();

        assert!(result.prediction_intervals.is_some());
        let intervals = result.prediction_intervals.unwrap();
        assert_eq!(intervals.len(), 2);

        for (lower, upper) in intervals {
            assert!(upper > lower, "Interval should have positive width");
        }
    }

    /// Classification path: fit with inverse-probability scores and check the
    /// prediction-set shape.
    #[test]
    fn test_conformal_prediction_classification() {
        let config = ConformalPredictionConfig {
            nonconformity_method: NonconformityMethod::InverseProbability,
            ..ConformalPredictionConfig::default()
        };
        let mut predictor = ConformalPredictor::new(config);

        let cal_probs = vec![
            vec![0.8, 0.1, 0.1],
            vec![0.2, 0.7, 0.1],
            vec![0.1, 0.2, 0.7],
            vec![0.6, 0.3, 0.1],
            vec![0.1, 0.1, 0.8],
        ];
        let cal_labels = vec![0, 1, 2, 0, 2];

        predictor
            .fit_classification(&cal_probs, &cal_labels)
            .unwrap();

        let test_probs = vec![vec![0.5, 0.3, 0.2], vec![0.2, 0.6, 0.2]];
        let result = predictor.predict_sets(&test_probs).unwrap();

        assert!(result.prediction_sets.is_some());
        let sets = result.prediction_sets.unwrap();
        assert_eq!(sets.len(), 2);
    }

    /// Coverage evaluation: empirical coverage is a valid fraction and the
    /// target coverage reflects `1 - alpha`.
    #[test]
    fn test_coverage_evaluation() {
        let config = ConformalPredictionConfig {
            alpha: 0.2,
            ..Default::default()
        };
        let mut predictor = ConformalPredictor::new(config);

        let cal_preds = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        predictor.fit(&cal_preds, &cal_targets).unwrap();

        let test_preds = vec![1.5, 2.5];
        let test_targets = vec![1.5, 2.5];
        let coverage = predictor
            .evaluate_coverage(&test_preds, &test_targets, None)
            .unwrap();

        assert!(coverage.empirical_coverage >= 0.0);
        assert!(coverage.empirical_coverage <= 1.0);
        assert_eq!(coverage.target_coverage, 0.8);
    }
}