sklears_model_selection/
conformal_prediction.rs

1//! Conformal prediction methods for uncertainty quantification
2//!
3//! This module provides conformal prediction algorithms that generate
4//! prediction intervals with finite-sample validity guarantees.
5
6use sklears_core::error::{Result, SklearsError};
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
/// Configuration for conformal prediction
///
/// Controls the target coverage level, the nonconformity scoring method,
/// and the calibration strategy used by `ConformalPredictor`.
#[derive(Debug, Clone)]
pub struct ConformalPredictionConfig {
    /// Significance level (1 - coverage probability); e.g. 0.1 targets 90% coverage
    pub alpha: f64,
    /// Method for computing nonconformity scores
    pub nonconformity_method: NonconformityMethod,
    /// Whether to use normalized nonconformity scores (interval widths are
    /// scaled by per-sample prediction errors when provided)
    pub normalize: bool,
    /// Method for handling class imbalance in classification: when true,
    /// a separate threshold is computed per class (class-conditional CP)
    pub class_conditional: bool,
    /// Random state for reproducible results
    /// NOTE(review): not read by any method in this module — confirm intent
    pub random_state: Option<u64>,
    /// Whether to use inductive (split) conformal prediction
    /// NOTE(review): not read by any method in this module — confirm intent
    pub inductive: bool,
    /// Fraction of data to use for calibration in inductive setting
    /// NOTE(review): not read by any method in this module — confirm intent
    pub calibration_fraction: f64,
}
28
29impl Default for ConformalPredictionConfig {
30    fn default() -> Self {
31        Self {
32            alpha: 0.1, // 90% coverage
33            nonconformity_method: NonconformityMethod::AbsoluteError,
34            normalize: false,
35            class_conditional: false,
36            random_state: None,
37            inductive: true,
38            calibration_fraction: 0.2,
39        }
40    }
41}
42
/// Methods for computing nonconformity scores
///
/// `AbsoluteError`, `SquaredError`, `SignedError`, and `Custom` apply to
/// regression; `Margin` and `InverseProbability` apply to classification.
#[derive(Debug, Clone)]
pub enum NonconformityMethod {
    /// Absolute error: |y - ŷ|
    AbsoluteError,
    /// Squared error: (y - ŷ)²
    SquaredError,
    /// Signed error: y - ŷ (for quantile regression)
    SignedError,
    /// Margin-based (for classification): max other-class probability
    /// minus true-class probability
    Margin,
    /// Inverse probability score (for classification): 1 - p(true class)
    InverseProbability,
    /// Custom nonconformity function mapping (predictions, targets) to scores
    Custom(fn(&[f64], &[f64]) -> Vec<f64>),
}
59
/// Results from conformal prediction
///
/// Exactly one of `prediction_intervals` (regression) or `prediction_sets`
/// (classification) is populated, depending on which predict method ran.
#[derive(Debug, Clone)]
pub struct ConformalPredictionResult {
    /// Prediction intervals (lower, upper) for regression
    pub prediction_intervals: Option<Vec<(f64, f64)>>,
    /// Prediction sets (included class indices) for classification
    pub prediction_sets: Option<Vec<Vec<usize>>>,
    /// Nonconformity scores used for calibration
    pub calibration_scores: Vec<f64>,
    /// Quantile threshold used for prediction intervals
    pub quantile_threshold: f64,
    /// Coverage statistics (empirical coverage is 0.0 here since no true
    /// labels are seen at predict time; use the `evaluate_*` methods)
    pub coverage_stats: CoverageStatistics,
    /// Efficiency metrics (interval widths or set sizes)
    pub efficiency_metrics: EfficiencyMetrics,
}
76
/// Coverage statistics for conformal prediction
#[derive(Debug, Clone)]
pub struct CoverageStatistics {
    /// Empirical coverage rate (fraction of true values falling inside
    /// their interval / prediction set)
    pub empirical_coverage: f64,
    /// Target coverage rate (1 - alpha)
    pub target_coverage: f64,
    /// Coverage difference from target (empirical - target)
    pub coverage_gap: f64,
    /// Coverage per class (for classification; None for regression)
    pub class_coverage: Option<HashMap<usize, f64>>,
}
89
/// Efficiency metrics for conformal prediction
///
/// Regression runs populate the interval-width fields; classification runs
/// populate the set-size fields. The remaining fields are None.
#[derive(Debug, Clone)]
pub struct EfficiencyMetrics {
    /// Average interval width (regression)
    pub average_interval_width: Option<f64>,
    /// Average set size (classification)
    pub average_set_size: Option<f64>,
    /// Interval width variability (sample standard deviation)
    pub interval_width_std: Option<f64>,
    /// Set size variability (sample standard deviation)
    pub set_size_std: Option<f64>,
    /// Singleton rate (classification: fraction with single prediction)
    pub singleton_rate: Option<f64>,
    /// Empty set rate (classification: fraction with no predictions)
    pub empty_set_rate: Option<f64>,
}
106
/// Conformal predictor for regression and classification
///
/// Must be calibrated via `fit` (regression) or `fit_classification`
/// (classification) before any predict/evaluate call.
#[derive(Debug, Clone)]
pub struct ConformalPredictor {
    // User-supplied configuration (alpha, nonconformity method, ...).
    config: ConformalPredictionConfig,
    // Nonconformity scores stored by the most recent fit.
    calibration_scores: Option<Vec<f64>>,
    // Shared conformal threshold: set by `fit`, or by `fit_classification`
    // when class-conditional mode is off.
    quantile_threshold: Option<f64>,
    // Per-class thresholds: set only by class-conditional classification.
    class_thresholds: Option<HashMap<usize, f64>>,
}
115
116impl ConformalPredictor {
117    pub fn new(config: ConformalPredictionConfig) -> Self {
118        Self {
119            config,
120            calibration_scores: None,
121            quantile_threshold: None,
122            class_thresholds: None,
123        }
124    }
125
126    /// Fit conformal predictor on calibration data
127    pub fn fit(
128        &mut self,
129        calibration_predictions: &[f64],
130        calibration_targets: &[f64],
131    ) -> Result<()> {
132        if calibration_predictions.len() != calibration_targets.len() {
133            return Err(SklearsError::InvalidInput(
134                "Predictions and targets must have the same length".to_string(),
135            ));
136        }
137
138        // Compute nonconformity scores
139        let scores =
140            self.compute_nonconformity_scores(calibration_predictions, calibration_targets)?;
141
142        // Calculate quantile threshold
143        let quantile_level = 1.0 - self.config.alpha;
144        let threshold = self.compute_quantile(&scores, quantile_level);
145
146        self.calibration_scores = Some(scores);
147        self.quantile_threshold = Some(threshold);
148
149        Ok(())
150    }
151
152    /// Fit conformal predictor for classification
153    pub fn fit_classification(
154        &mut self,
155        calibration_probabilities: &[Vec<f64>],
156        calibration_labels: &[usize],
157    ) -> Result<()> {
158        if calibration_probabilities.len() != calibration_labels.len() {
159            return Err(SklearsError::InvalidInput(
160                "Probabilities and labels must have the same length".to_string(),
161            ));
162        }
163
164        let scores =
165            self.compute_classification_scores(calibration_probabilities, calibration_labels)?;
166
167        if self.config.class_conditional {
168            // Compute separate thresholds for each class
169            let mut class_thresholds = HashMap::new();
170            let unique_classes = self.get_unique_classes(calibration_labels);
171
172            for &class in &unique_classes {
173                let class_scores: Vec<f64> = scores
174                    .iter()
175                    .enumerate()
176                    .filter(|(i, _)| calibration_labels[*i] == class)
177                    .map(|(_, &score)| score)
178                    .collect();
179
180                if !class_scores.is_empty() {
181                    let quantile_level = 1.0 - self.config.alpha;
182                    let threshold = self.compute_quantile(&class_scores, quantile_level);
183                    class_thresholds.insert(class, threshold);
184                }
185            }
186
187            self.class_thresholds = Some(class_thresholds);
188        } else {
189            // Single threshold for all classes
190            let quantile_level = 1.0 - self.config.alpha;
191            let threshold = self.compute_quantile(&scores, quantile_level);
192            self.quantile_threshold = Some(threshold);
193        }
194
195        self.calibration_scores = Some(scores);
196
197        Ok(())
198    }
199
200    /// Generate prediction intervals for regression
201    pub fn predict_intervals(
202        &self,
203        predictions: &[f64],
204        prediction_errors: Option<&[f64]>,
205    ) -> Result<ConformalPredictionResult> {
206        if self.quantile_threshold.is_none() {
207            return Err(SklearsError::NotFitted {
208                operation: "making predictions".to_string(),
209            });
210        }
211
212        let threshold = self.quantile_threshold.unwrap();
213        let mut intervals = Vec::new();
214
215        for (i, &pred) in predictions.iter().enumerate() {
216            let error_scale =
217                if let (true, Some(errors)) = (self.config.normalize, &prediction_errors) {
218                    errors[i].max(1e-8) // Avoid division by zero
219                } else {
220                    1.0
221                };
222
223            let margin = threshold * error_scale;
224            intervals.push((pred - margin, pred + margin));
225        }
226
227        // Calculate efficiency metrics
228        let average_width =
229            intervals.iter().map(|(l, u)| u - l).sum::<f64>() / intervals.len() as f64;
230        let width_std =
231            self.calculate_std(&intervals.iter().map(|(l, u)| u - l).collect::<Vec<_>>());
232
233        let efficiency_metrics = EfficiencyMetrics {
234            average_interval_width: Some(average_width),
235            average_set_size: None,
236            interval_width_std: Some(width_std),
237            set_size_std: None,
238            singleton_rate: None,
239            empty_set_rate: None,
240        };
241
242        let coverage_stats = CoverageStatistics {
243            empirical_coverage: 0.0, // Would need true labels to compute
244            target_coverage: 1.0 - self.config.alpha,
245            coverage_gap: 0.0,
246            class_coverage: None,
247        };
248
249        Ok(ConformalPredictionResult {
250            prediction_intervals: Some(intervals),
251            prediction_sets: None,
252            calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
253            quantile_threshold: threshold,
254            coverage_stats,
255            efficiency_metrics,
256        })
257    }
258
259    /// Generate prediction sets for classification
260    pub fn predict_sets(
261        &self,
262        prediction_probabilities: &[Vec<f64>],
263    ) -> Result<ConformalPredictionResult> {
264        if self.quantile_threshold.is_none() && self.class_thresholds.is_none() {
265            return Err(SklearsError::NotFitted {
266                operation: "making predictions".to_string(),
267            });
268        }
269
270        let mut prediction_sets = Vec::new();
271
272        for probs in prediction_probabilities {
273            let mut prediction_set = Vec::new();
274
275            for (class_idx, &prob) in probs.iter().enumerate() {
276                let threshold = if let Some(ref class_thresholds) = self.class_thresholds {
277                    class_thresholds.get(&class_idx).copied().unwrap_or(0.0)
278                } else {
279                    self.quantile_threshold.unwrap()
280                };
281
282                // Include class if its score is below threshold
283                let score = match self.config.nonconformity_method {
284                    NonconformityMethod::InverseProbability => 1.0 - prob,
285                    _ => 1.0 - prob, // Default for classification
286                };
287
288                if score <= threshold {
289                    prediction_set.push(class_idx);
290                }
291            }
292
293            prediction_sets.push(prediction_set);
294        }
295
296        // Calculate efficiency metrics
297        let set_sizes: Vec<f64> = prediction_sets.iter().map(|s| s.len() as f64).collect();
298        let average_set_size = set_sizes.iter().sum::<f64>() / set_sizes.len() as f64;
299        let set_size_std = self.calculate_std(&set_sizes);
300        let singleton_rate =
301            set_sizes.iter().filter(|&&size| size == 1.0).count() as f64 / set_sizes.len() as f64;
302        let empty_set_rate =
303            set_sizes.iter().filter(|&&size| size == 0.0).count() as f64 / set_sizes.len() as f64;
304
305        let efficiency_metrics = EfficiencyMetrics {
306            average_interval_width: None,
307            average_set_size: Some(average_set_size),
308            interval_width_std: None,
309            set_size_std: Some(set_size_std),
310            singleton_rate: Some(singleton_rate),
311            empty_set_rate: Some(empty_set_rate),
312        };
313
314        let coverage_stats = CoverageStatistics {
315            empirical_coverage: 0.0, // Would need true labels to compute
316            target_coverage: 1.0 - self.config.alpha,
317            coverage_gap: 0.0,
318            class_coverage: None,
319        };
320
321        let threshold = self.quantile_threshold.unwrap_or(0.0);
322
323        Ok(ConformalPredictionResult {
324            prediction_intervals: None,
325            prediction_sets: Some(prediction_sets),
326            calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
327            quantile_threshold: threshold,
328            coverage_stats,
329            efficiency_metrics,
330        })
331    }
332
333    /// Evaluate coverage and efficiency on test data
334    pub fn evaluate_coverage(
335        &self,
336        predictions: &[f64],
337        true_values: &[f64],
338        prediction_errors: Option<&[f64]>,
339    ) -> Result<CoverageStatistics> {
340        let result = self.predict_intervals(predictions, prediction_errors)?;
341        let intervals = result.prediction_intervals.unwrap();
342
343        let mut covered = 0;
344        for (i, &true_val) in true_values.iter().enumerate() {
345            let (lower, upper) = intervals[i];
346            if true_val >= lower && true_val <= upper {
347                covered += 1;
348            }
349        }
350
351        let empirical_coverage = covered as f64 / true_values.len() as f64;
352        let target_coverage = 1.0 - self.config.alpha;
353        let coverage_gap = empirical_coverage - target_coverage;
354
355        Ok(CoverageStatistics {
356            empirical_coverage,
357            target_coverage,
358            coverage_gap,
359            class_coverage: None,
360        })
361    }
362
363    /// Evaluate classification coverage
364    pub fn evaluate_classification_coverage(
365        &self,
366        prediction_probabilities: &[Vec<f64>],
367        true_labels: &[usize],
368    ) -> Result<CoverageStatistics> {
369        let result = self.predict_sets(prediction_probabilities)?;
370        let prediction_sets = result.prediction_sets.unwrap();
371
372        let mut covered = 0;
373        let mut class_coverage_counts: HashMap<usize, (usize, usize)> = HashMap::new();
374
375        for (i, &true_label) in true_labels.iter().enumerate() {
376            let prediction_set = &prediction_sets[i];
377            let is_covered = prediction_set.contains(&true_label);
378
379            if is_covered {
380                covered += 1;
381            }
382
383            // Track class-specific coverage
384            let (class_covered, class_total) =
385                class_coverage_counts.entry(true_label).or_insert((0, 0));
386            if is_covered {
387                *class_covered += 1;
388            }
389            *class_total += 1;
390        }
391
392        let empirical_coverage = covered as f64 / true_labels.len() as f64;
393        let target_coverage = 1.0 - self.config.alpha;
394        let coverage_gap = empirical_coverage - target_coverage;
395
396        // Calculate per-class coverage
397        let mut class_coverage = HashMap::new();
398        for (&class, &(covered_count, total_count)) in &class_coverage_counts {
399            class_coverage.insert(class, covered_count as f64 / total_count as f64);
400        }
401
402        Ok(CoverageStatistics {
403            empirical_coverage,
404            target_coverage,
405            coverage_gap,
406            class_coverage: Some(class_coverage),
407        })
408    }
409
410    /// Compute nonconformity scores
411    fn compute_nonconformity_scores(
412        &self,
413        predictions: &[f64],
414        targets: &[f64],
415    ) -> Result<Vec<f64>> {
416        match self.config.nonconformity_method {
417            NonconformityMethod::AbsoluteError => Ok(predictions
418                .iter()
419                .zip(targets.iter())
420                .map(|(&pred, &target)| (target - pred).abs())
421                .collect()),
422            NonconformityMethod::SquaredError => Ok(predictions
423                .iter()
424                .zip(targets.iter())
425                .map(|(&pred, &target)| (target - pred).powi(2))
426                .collect()),
427            NonconformityMethod::SignedError => Ok(predictions
428                .iter()
429                .zip(targets.iter())
430                .map(|(&pred, &target)| target - pred)
431                .collect()),
432            NonconformityMethod::Custom(func) => Ok(func(predictions, targets)),
433            _ => Err(SklearsError::InvalidInput(
434                "Invalid nonconformity method for regression".to_string(),
435            )),
436        }
437    }
438
439    /// Compute classification nonconformity scores
440    fn compute_classification_scores(
441        &self,
442        probabilities: &[Vec<f64>],
443        labels: &[usize],
444    ) -> Result<Vec<f64>> {
445        match self.config.nonconformity_method {
446            NonconformityMethod::InverseProbability => {
447                let scores = probabilities
448                    .iter()
449                    .zip(labels.iter())
450                    .map(|(probs, &label)| 1.0 - probs.get(label).copied().unwrap_or(0.0))
451                    .collect();
452                Ok(scores)
453            }
454            NonconformityMethod::Margin => {
455                let scores = probabilities
456                    .iter()
457                    .zip(labels.iter())
458                    .map(|(probs, &label)| {
459                        let true_class_prob = probs.get(label).copied().unwrap_or(0.0);
460                        let max_other_prob = probs
461                            .iter()
462                            .enumerate()
463                            .filter(|(i, _)| *i != label)
464                            .map(|(_, &prob)| prob)
465                            .fold(0.0, f64::max);
466                        max_other_prob - true_class_prob
467                    })
468                    .collect();
469                Ok(scores)
470            }
471            _ => Err(SklearsError::InvalidInput(
472                "Invalid nonconformity method for classification".to_string(),
473            )),
474        }
475    }
476
477    /// Compute quantile of a vector
478    fn compute_quantile(&self, values: &[f64], quantile: f64) -> f64 {
479        if values.is_empty() {
480            return 0.0;
481        }
482
483        let mut sorted_values = values.to_vec();
484        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
485
486        let n = sorted_values.len();
487        let index = (quantile * (n + 1) as f64).ceil() as usize;
488        let index = index.min(n).saturating_sub(1);
489
490        sorted_values[index]
491    }
492
493    /// Get unique classes from labels
494    fn get_unique_classes(&self, labels: &[usize]) -> Vec<usize> {
495        let mut unique_classes: Vec<usize> = labels.to_vec();
496        unique_classes.sort_unstable();
497        unique_classes.dedup();
498        unique_classes
499    }
500
501    /// Calculate standard deviation
502    fn calculate_std(&self, values: &[f64]) -> f64 {
503        if values.len() < 2 {
504            return 0.0;
505        }
506
507        let mean = values.iter().sum::<f64>() / values.len() as f64;
508        let variance =
509            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
510
511        variance.sqrt()
512    }
513}
514
/// Jackknife+ conformal prediction for better efficiency
///
/// Wraps a `ConformalPredictor` and calibrates it from leave-one-out
/// residuals supplied by the caller via `fit_jackknife`.
#[derive(Debug, Clone)]
pub struct JackknifeConformalPredictor {
    // Underlying predictor holding calibration scores and the threshold.
    base_predictor: ConformalPredictor,
    // Leave-one-out predictions retained from `fit_jackknife`.
    // NOTE(review): stored but never read by any method here — confirm intent.
    jackknife_predictions: Option<Vec<Vec<f64>>>,
}
521
522impl JackknifeConformalPredictor {
523    pub fn new(config: ConformalPredictionConfig) -> Self {
524        Self {
525            base_predictor: ConformalPredictor::new(config),
526            jackknife_predictions: None,
527        }
528    }
529
530    /// Fit using jackknife+ method
531    pub fn fit_jackknife(
532        &mut self,
533        all_predictions: &[Vec<f64>], // Predictions from leave-one-out models
534        targets: &[f64],
535    ) -> Result<()> {
536        if all_predictions.len() != targets.len() {
537            return Err(SklearsError::InvalidInput(
538                "Predictions and targets must have the same length".to_string(),
539            ));
540        }
541
542        // Compute residuals for each jackknife prediction
543        let mut residuals = Vec::new();
544        for (i, preds) in all_predictions.iter().enumerate() {
545            if !preds.is_empty() {
546                let residual = (targets[i] - preds[0]).abs(); // Use first prediction
547                residuals.push(residual);
548            }
549        }
550
551        // Use residuals as calibration scores
552        self.base_predictor.calibration_scores = Some(residuals.clone());
553        let quantile_level = 1.0 - self.base_predictor.config.alpha;
554        let threshold = self
555            .base_predictor
556            .compute_quantile(&residuals, quantile_level);
557        self.base_predictor.quantile_threshold = Some(threshold);
558        self.jackknife_predictions = Some(all_predictions.to_vec());
559
560        Ok(())
561    }
562
    /// Generate jackknife+ prediction intervals
    ///
    /// Delegates to the base predictor's interval construction using the
    /// threshold calibrated from jackknife residuals; no per-sample
    /// normalization is applied (errors argument is `None`).
    ///
    /// # Errors
    /// Returns `SklearsError::NotFitted` if `fit_jackknife` was not called.
    pub fn predict_jackknife_intervals(
        &self,
        predictions: &[f64],
    ) -> Result<ConformalPredictionResult> {
        self.base_predictor.predict_intervals(predictions, None)
    }
570}
571
// Removed the `#[allow(non_snake_case)]` attribute: every identifier in
// this module is already snake_case, so the allow was dead and unexplained.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conformal_prediction_regression() {
        let config = ConformalPredictionConfig::default();
        let mut predictor = ConformalPredictor::new(config);

        // Create synthetic calibration data
        let cal_preds = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = vec![1.1, 1.9, 3.2, 3.8, 5.1];

        predictor.fit(&cal_preds, &cal_targets).unwrap();

        // Make predictions
        let test_preds = vec![2.5, 4.5];
        let result = predictor.predict_intervals(&test_preds, None).unwrap();

        assert!(result.prediction_intervals.is_some());
        let intervals = result.prediction_intervals.unwrap();
        assert_eq!(intervals.len(), 2);

        // Check that intervals have positive width
        for (lower, upper) in intervals {
            assert!(upper > lower, "Interval should have positive width");
        }
    }

    #[test]
    fn test_conformal_prediction_classification() {
        let config = ConformalPredictionConfig {
            nonconformity_method: NonconformityMethod::InverseProbability,
            ..ConformalPredictionConfig::default()
        };
        let mut predictor = ConformalPredictor::new(config);

        // Create synthetic calibration data (3 classes)
        let cal_probs = vec![
            vec![0.8, 0.1, 0.1],
            vec![0.2, 0.7, 0.1],
            vec![0.1, 0.2, 0.7],
            vec![0.6, 0.3, 0.1],
            vec![0.1, 0.1, 0.8],
        ];
        let cal_labels = vec![0, 1, 2, 0, 2];

        predictor
            .fit_classification(&cal_probs, &cal_labels)
            .unwrap();

        // Make predictions
        let test_probs = vec![vec![0.5, 0.3, 0.2], vec![0.2, 0.6, 0.2]];
        let result = predictor.predict_sets(&test_probs).unwrap();

        assert!(result.prediction_sets.is_some());
        let sets = result.prediction_sets.unwrap();
        assert_eq!(sets.len(), 2);
    }

    #[test]
    fn test_coverage_evaluation() {
        let config = ConformalPredictionConfig {
            alpha: 0.2, // 80% coverage
            ..Default::default()
        };
        let mut predictor = ConformalPredictor::new(config);

        // Create calibration data
        let cal_preds = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // Perfect predictions

        predictor.fit(&cal_preds, &cal_targets).unwrap();

        // Evaluate on test data
        let test_preds = vec![1.5, 2.5];
        let test_targets = vec![1.5, 2.5]; // Also perfect
        let coverage = predictor
            .evaluate_coverage(&test_preds, &test_targets, None)
            .unwrap();

        assert!(coverage.empirical_coverage >= 0.0);
        assert!(coverage.empirical_coverage <= 1.0);
        assert_eq!(coverage.target_coverage, 0.8);
    }
}