Skip to main content

sklears_model_selection/
conformal_prediction.rs

//! Conformal prediction methods for uncertainty quantification
//!
//! This module provides conformal prediction algorithms that generate
//! prediction intervals with finite-sample validity guarantees.
5
6use sklears_core::error::{Result, SklearsError};
7use std::cmp::Ordering;
8use std::collections::HashMap;
9
10/// Configuration for conformal prediction
11#[derive(Debug, Clone)]
12pub struct ConformalPredictionConfig {
13    /// Significance level (1 - coverage probability)
14    pub alpha: f64,
15    /// Method for computing nonconformity scores
16    pub nonconformity_method: NonconformityMethod,
17    /// Whether to use normalized nonconformity scores
18    pub normalize: bool,
19    /// Method for handling class imbalance in classification
20    pub class_conditional: bool,
21    /// Random state for reproducible results
22    pub random_state: Option<u64>,
23    /// Whether to use inductive (split) conformal prediction
24    pub inductive: bool,
25    /// Fraction of data to use for calibration in inductive setting
26    pub calibration_fraction: f64,
27}
28
29impl Default for ConformalPredictionConfig {
30    fn default() -> Self {
31        Self {
32            alpha: 0.1, // 90% coverage
33            nonconformity_method: NonconformityMethod::AbsoluteError,
34            normalize: false,
35            class_conditional: false,
36            random_state: None,
37            inductive: true,
38            calibration_fraction: 0.2,
39        }
40    }
41}
42
/// Nonconformity score used to calibrate the predictor.
///
/// The first three variants and `Custom` are accepted by the regression
/// `fit`; `Margin` and `InverseProbability` are accepted by
/// `fit_classification`. Using a variant with the wrong fitting routine
/// yields an `InvalidInput` error.
#[derive(Clone, Debug)]
pub enum NonconformityMethod {
    /// Regression: `|y - ŷ|`.
    AbsoluteError,
    /// Regression: `(y - ŷ)²`.
    SquaredError,
    /// Regression: the signed residual `y - ŷ` (intended for quantile regression).
    SignedError,
    /// Classification: largest probability among the other classes minus the
    /// probability of the true class.
    Margin,
    /// Classification: `1 - p(true class)`.
    InverseProbability,
    /// Regression: user-supplied function called as `f(predictions, targets)`
    /// that returns the scores.
    Custom(fn(&[f64], &[f64]) -> Vec<f64>),
}
59
60/// Results from conformal prediction
61#[derive(Debug, Clone)]
62pub struct ConformalPredictionResult {
63    /// Prediction intervals (lower, upper) for regression
64    pub prediction_intervals: Option<Vec<(f64, f64)>>,
65    /// Prediction sets for classification
66    pub prediction_sets: Option<Vec<Vec<usize>>>,
67    /// Nonconformity scores used for calibration
68    pub calibration_scores: Vec<f64>,
69    /// Quantile threshold used for prediction intervals
70    pub quantile_threshold: f64,
71    /// Coverage statistics
72    pub coverage_stats: CoverageStatistics,
73    /// Efficiency metrics
74    pub efficiency_metrics: EfficiencyMetrics,
75}
76
/// How often the true value fell inside the predicted interval or set.
#[derive(Clone, Debug)]
pub struct CoverageStatistics {
    /// Fraction of evaluated samples that were covered.
    pub empirical_coverage: f64,
    /// Coverage the predictor aims for, i.e. `1 - alpha`.
    pub target_coverage: f64,
    /// `empirical_coverage - target_coverage`.
    pub coverage_gap: f64,
    /// Covered fraction per true class; populated by the classification evaluation only.
    pub class_coverage: Option<HashMap<usize, f64>>,
}
89
/// How tight the produced intervals or sets are.
///
/// Regression results fill the interval fields and leave the set fields
/// `None`; classification results do the opposite.
#[derive(Clone, Debug)]
pub struct EfficiencyMetrics {
    /// Mean of `upper - lower` over all intervals (regression).
    pub average_interval_width: Option<f64>,
    /// Mean number of classes per prediction set (classification).
    pub average_set_size: Option<f64>,
    /// Standard deviation of the interval widths (regression).
    pub interval_width_std: Option<f64>,
    /// Standard deviation of the set sizes (classification).
    pub set_size_std: Option<f64>,
    /// Fraction of prediction sets containing exactly one class.
    pub singleton_rate: Option<f64>,
    /// Fraction of prediction sets containing no class at all.
    pub empty_set_rate: Option<f64>,
}
106
107/// Conformal predictor for regression and classification
108#[derive(Debug, Clone)]
109pub struct ConformalPredictor {
110    config: ConformalPredictionConfig,
111    calibration_scores: Option<Vec<f64>>,
112    quantile_threshold: Option<f64>,
113    class_thresholds: Option<HashMap<usize, f64>>,
114}
115
116impl ConformalPredictor {
117    pub fn new(config: ConformalPredictionConfig) -> Self {
118        Self {
119            config,
120            calibration_scores: None,
121            quantile_threshold: None,
122            class_thresholds: None,
123        }
124    }
125
126    /// Fit conformal predictor on calibration data
127    pub fn fit(
128        &mut self,
129        calibration_predictions: &[f64],
130        calibration_targets: &[f64],
131    ) -> Result<()> {
132        if calibration_predictions.len() != calibration_targets.len() {
133            return Err(SklearsError::InvalidInput(
134                "Predictions and targets must have the same length".to_string(),
135            ));
136        }
137
138        // Compute nonconformity scores
139        let scores =
140            self.compute_nonconformity_scores(calibration_predictions, calibration_targets)?;
141
142        // Calculate quantile threshold
143        let quantile_level = 1.0 - self.config.alpha;
144        let threshold = self.compute_quantile(&scores, quantile_level);
145
146        self.calibration_scores = Some(scores);
147        self.quantile_threshold = Some(threshold);
148
149        Ok(())
150    }
151
152    /// Fit conformal predictor for classification
153    pub fn fit_classification(
154        &mut self,
155        calibration_probabilities: &[Vec<f64>],
156        calibration_labels: &[usize],
157    ) -> Result<()> {
158        if calibration_probabilities.len() != calibration_labels.len() {
159            return Err(SklearsError::InvalidInput(
160                "Probabilities and labels must have the same length".to_string(),
161            ));
162        }
163
164        let scores =
165            self.compute_classification_scores(calibration_probabilities, calibration_labels)?;
166
167        if self.config.class_conditional {
168            // Compute separate thresholds for each class
169            let mut class_thresholds = HashMap::new();
170            let unique_classes = self.get_unique_classes(calibration_labels);
171
172            for &class in &unique_classes {
173                let class_scores: Vec<f64> = scores
174                    .iter()
175                    .enumerate()
176                    .filter(|(i, _)| calibration_labels[*i] == class)
177                    .map(|(_, &score)| score)
178                    .collect();
179
180                if !class_scores.is_empty() {
181                    let quantile_level = 1.0 - self.config.alpha;
182                    let threshold = self.compute_quantile(&class_scores, quantile_level);
183                    class_thresholds.insert(class, threshold);
184                }
185            }
186
187            self.class_thresholds = Some(class_thresholds);
188        } else {
189            // Single threshold for all classes
190            let quantile_level = 1.0 - self.config.alpha;
191            let threshold = self.compute_quantile(&scores, quantile_level);
192            self.quantile_threshold = Some(threshold);
193        }
194
195        self.calibration_scores = Some(scores);
196
197        Ok(())
198    }
199
200    /// Generate prediction intervals for regression
201    pub fn predict_intervals(
202        &self,
203        predictions: &[f64],
204        prediction_errors: Option<&[f64]>,
205    ) -> Result<ConformalPredictionResult> {
206        if self.quantile_threshold.is_none() {
207            return Err(SklearsError::NotFitted {
208                operation: "making predictions".to_string(),
209            });
210        }
211
212        let threshold = self.quantile_threshold.expect("operation should succeed");
213        let mut intervals = Vec::new();
214
215        for (i, &pred) in predictions.iter().enumerate() {
216            let error_scale =
217                if let (true, Some(errors)) = (self.config.normalize, &prediction_errors) {
218                    errors[i].max(1e-8) // Avoid division by zero
219                } else {
220                    1.0
221                };
222
223            let margin = threshold * error_scale;
224            intervals.push((pred - margin, pred + margin));
225        }
226
227        // Calculate efficiency metrics
228        let average_width =
229            intervals.iter().map(|(l, u)| u - l).sum::<f64>() / intervals.len() as f64;
230        let width_std =
231            self.calculate_std(&intervals.iter().map(|(l, u)| u - l).collect::<Vec<_>>());
232
233        let efficiency_metrics = EfficiencyMetrics {
234            average_interval_width: Some(average_width),
235            average_set_size: None,
236            interval_width_std: Some(width_std),
237            set_size_std: None,
238            singleton_rate: None,
239            empty_set_rate: None,
240        };
241
242        let coverage_stats = CoverageStatistics {
243            empirical_coverage: 0.0, // Would need true labels to compute
244            target_coverage: 1.0 - self.config.alpha,
245            coverage_gap: 0.0,
246            class_coverage: None,
247        };
248
249        Ok(ConformalPredictionResult {
250            prediction_intervals: Some(intervals),
251            prediction_sets: None,
252            calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
253            quantile_threshold: threshold,
254            coverage_stats,
255            efficiency_metrics,
256        })
257    }
258
259    /// Generate prediction sets for classification
260    pub fn predict_sets(
261        &self,
262        prediction_probabilities: &[Vec<f64>],
263    ) -> Result<ConformalPredictionResult> {
264        if self.quantile_threshold.is_none() && self.class_thresholds.is_none() {
265            return Err(SklearsError::NotFitted {
266                operation: "making predictions".to_string(),
267            });
268        }
269
270        let mut prediction_sets = Vec::new();
271
272        for probs in prediction_probabilities {
273            let mut prediction_set = Vec::new();
274
275            for (class_idx, &prob) in probs.iter().enumerate() {
276                let threshold = if let Some(ref class_thresholds) = self.class_thresholds {
277                    class_thresholds.get(&class_idx).copied().unwrap_or(0.0)
278                } else {
279                    self.quantile_threshold.expect("operation should succeed")
280                };
281
282                // Include class if its score is below threshold
283                let score = match self.config.nonconformity_method {
284                    NonconformityMethod::InverseProbability => 1.0 - prob,
285                    _ => 1.0 - prob, // Default for classification
286                };
287
288                if score <= threshold {
289                    prediction_set.push(class_idx);
290                }
291            }
292
293            prediction_sets.push(prediction_set);
294        }
295
296        // Calculate efficiency metrics
297        let set_sizes: Vec<f64> = prediction_sets.iter().map(|s| s.len() as f64).collect();
298        let average_set_size = set_sizes.iter().sum::<f64>() / set_sizes.len() as f64;
299        let set_size_std = self.calculate_std(&set_sizes);
300        let singleton_rate =
301            set_sizes.iter().filter(|&&size| size == 1.0).count() as f64 / set_sizes.len() as f64;
302        let empty_set_rate =
303            set_sizes.iter().filter(|&&size| size == 0.0).count() as f64 / set_sizes.len() as f64;
304
305        let efficiency_metrics = EfficiencyMetrics {
306            average_interval_width: None,
307            average_set_size: Some(average_set_size),
308            interval_width_std: None,
309            set_size_std: Some(set_size_std),
310            singleton_rate: Some(singleton_rate),
311            empty_set_rate: Some(empty_set_rate),
312        };
313
314        let coverage_stats = CoverageStatistics {
315            empirical_coverage: 0.0, // Would need true labels to compute
316            target_coverage: 1.0 - self.config.alpha,
317            coverage_gap: 0.0,
318            class_coverage: None,
319        };
320
321        let threshold = self.quantile_threshold.unwrap_or(0.0);
322
323        Ok(ConformalPredictionResult {
324            prediction_intervals: None,
325            prediction_sets: Some(prediction_sets),
326            calibration_scores: self.calibration_scores.clone().unwrap_or_default(),
327            quantile_threshold: threshold,
328            coverage_stats,
329            efficiency_metrics,
330        })
331    }
332
333    /// Evaluate coverage and efficiency on test data
334    pub fn evaluate_coverage(
335        &self,
336        predictions: &[f64],
337        true_values: &[f64],
338        prediction_errors: Option<&[f64]>,
339    ) -> Result<CoverageStatistics> {
340        let result = self.predict_intervals(predictions, prediction_errors)?;
341        let intervals = result
342            .prediction_intervals
343            .expect("operation should succeed");
344
345        let mut covered = 0;
346        for (i, &true_val) in true_values.iter().enumerate() {
347            let (lower, upper) = intervals[i];
348            if true_val >= lower && true_val <= upper {
349                covered += 1;
350            }
351        }
352
353        let empirical_coverage = covered as f64 / true_values.len() as f64;
354        let target_coverage = 1.0 - self.config.alpha;
355        let coverage_gap = empirical_coverage - target_coverage;
356
357        Ok(CoverageStatistics {
358            empirical_coverage,
359            target_coverage,
360            coverage_gap,
361            class_coverage: None,
362        })
363    }
364
365    /// Evaluate classification coverage
366    pub fn evaluate_classification_coverage(
367        &self,
368        prediction_probabilities: &[Vec<f64>],
369        true_labels: &[usize],
370    ) -> Result<CoverageStatistics> {
371        let result = self.predict_sets(prediction_probabilities)?;
372        let prediction_sets = result.prediction_sets.expect("operation should succeed");
373
374        let mut covered = 0;
375        let mut class_coverage_counts: HashMap<usize, (usize, usize)> = HashMap::new();
376
377        for (i, &true_label) in true_labels.iter().enumerate() {
378            let prediction_set = &prediction_sets[i];
379            let is_covered = prediction_set.contains(&true_label);
380
381            if is_covered {
382                covered += 1;
383            }
384
385            // Track class-specific coverage
386            let (class_covered, class_total) =
387                class_coverage_counts.entry(true_label).or_insert((0, 0));
388            if is_covered {
389                *class_covered += 1;
390            }
391            *class_total += 1;
392        }
393
394        let empirical_coverage = covered as f64 / true_labels.len() as f64;
395        let target_coverage = 1.0 - self.config.alpha;
396        let coverage_gap = empirical_coverage - target_coverage;
397
398        // Calculate per-class coverage
399        let mut class_coverage = HashMap::new();
400        for (&class, &(covered_count, total_count)) in &class_coverage_counts {
401            class_coverage.insert(class, covered_count as f64 / total_count as f64);
402        }
403
404        Ok(CoverageStatistics {
405            empirical_coverage,
406            target_coverage,
407            coverage_gap,
408            class_coverage: Some(class_coverage),
409        })
410    }
411
412    /// Compute nonconformity scores
413    fn compute_nonconformity_scores(
414        &self,
415        predictions: &[f64],
416        targets: &[f64],
417    ) -> Result<Vec<f64>> {
418        match self.config.nonconformity_method {
419            NonconformityMethod::AbsoluteError => Ok(predictions
420                .iter()
421                .zip(targets.iter())
422                .map(|(&pred, &target)| (target - pred).abs())
423                .collect()),
424            NonconformityMethod::SquaredError => Ok(predictions
425                .iter()
426                .zip(targets.iter())
427                .map(|(&pred, &target)| (target - pred).powi(2))
428                .collect()),
429            NonconformityMethod::SignedError => Ok(predictions
430                .iter()
431                .zip(targets.iter())
432                .map(|(&pred, &target)| target - pred)
433                .collect()),
434            NonconformityMethod::Custom(func) => Ok(func(predictions, targets)),
435            _ => Err(SklearsError::InvalidInput(
436                "Invalid nonconformity method for regression".to_string(),
437            )),
438        }
439    }
440
441    /// Compute classification nonconformity scores
442    fn compute_classification_scores(
443        &self,
444        probabilities: &[Vec<f64>],
445        labels: &[usize],
446    ) -> Result<Vec<f64>> {
447        match self.config.nonconformity_method {
448            NonconformityMethod::InverseProbability => {
449                let scores = probabilities
450                    .iter()
451                    .zip(labels.iter())
452                    .map(|(probs, &label)| 1.0 - probs.get(label).copied().unwrap_or(0.0))
453                    .collect();
454                Ok(scores)
455            }
456            NonconformityMethod::Margin => {
457                let scores = probabilities
458                    .iter()
459                    .zip(labels.iter())
460                    .map(|(probs, &label)| {
461                        let true_class_prob = probs.get(label).copied().unwrap_or(0.0);
462                        let max_other_prob = probs
463                            .iter()
464                            .enumerate()
465                            .filter(|(i, _)| *i != label)
466                            .map(|(_, &prob)| prob)
467                            .fold(0.0, f64::max);
468                        max_other_prob - true_class_prob
469                    })
470                    .collect();
471                Ok(scores)
472            }
473            _ => Err(SklearsError::InvalidInput(
474                "Invalid nonconformity method for classification".to_string(),
475            )),
476        }
477    }
478
479    /// Compute quantile of a vector
480    fn compute_quantile(&self, values: &[f64], quantile: f64) -> f64 {
481        if values.is_empty() {
482            return 0.0;
483        }
484
485        let mut sorted_values = values.to_vec();
486        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
487
488        let n = sorted_values.len();
489        let index = (quantile * (n + 1) as f64).ceil() as usize;
490        let index = index.min(n).saturating_sub(1);
491
492        sorted_values[index]
493    }
494
495    /// Get unique classes from labels
496    fn get_unique_classes(&self, labels: &[usize]) -> Vec<usize> {
497        let mut unique_classes: Vec<usize> = labels.to_vec();
498        unique_classes.sort_unstable();
499        unique_classes.dedup();
500        unique_classes
501    }
502
503    /// Calculate standard deviation
504    fn calculate_std(&self, values: &[f64]) -> f64 {
505        if values.len() < 2 {
506            return 0.0;
507        }
508
509        let mean = values.iter().sum::<f64>() / values.len() as f64;
510        let variance =
511            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
512
513        variance.sqrt()
514    }
515}
516
517/// Jackknife+ conformal prediction for better efficiency
518#[derive(Debug, Clone)]
519pub struct JackknifeConformalPredictor {
520    base_predictor: ConformalPredictor,
521    jackknife_predictions: Option<Vec<Vec<f64>>>,
522}
523
524impl JackknifeConformalPredictor {
525    pub fn new(config: ConformalPredictionConfig) -> Self {
526        Self {
527            base_predictor: ConformalPredictor::new(config),
528            jackknife_predictions: None,
529        }
530    }
531
532    /// Fit using jackknife+ method
533    pub fn fit_jackknife(
534        &mut self,
535        all_predictions: &[Vec<f64>], // Predictions from leave-one-out models
536        targets: &[f64],
537    ) -> Result<()> {
538        if all_predictions.len() != targets.len() {
539            return Err(SklearsError::InvalidInput(
540                "Predictions and targets must have the same length".to_string(),
541            ));
542        }
543
544        // Compute residuals for each jackknife prediction
545        let mut residuals = Vec::new();
546        for (i, preds) in all_predictions.iter().enumerate() {
547            if !preds.is_empty() {
548                let residual = (targets[i] - preds[0]).abs(); // Use first prediction
549                residuals.push(residual);
550            }
551        }
552
553        // Use residuals as calibration scores
554        self.base_predictor.calibration_scores = Some(residuals.clone());
555        let quantile_level = 1.0 - self.base_predictor.config.alpha;
556        let threshold = self
557            .base_predictor
558            .compute_quantile(&residuals, quantile_level);
559        self.base_predictor.quantile_threshold = Some(threshold);
560        self.jackknife_predictions = Some(all_predictions.to_vec());
561
562        Ok(())
563    }
564
565    /// Generate jackknife+ prediction intervals
566    pub fn predict_jackknife_intervals(
567        &self,
568        predictions: &[f64],
569    ) -> Result<ConformalPredictionResult> {
570        self.base_predictor.predict_intervals(predictions, None)
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// A fitted regression predictor returns one positive-width interval per prediction.
    #[test]
    fn test_conformal_prediction_regression() {
        let mut predictor = ConformalPredictor::new(ConformalPredictionConfig::default());

        // Synthetic calibration data with small residuals.
        let cal_preds = [1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = [1.1, 1.9, 3.2, 3.8, 5.1];
        predictor
            .fit(&cal_preds, &cal_targets)
            .expect("operation should succeed");

        let test_preds = [2.5, 4.5];
        let result = predictor
            .predict_intervals(&test_preds, None)
            .expect("operation should succeed");

        assert!(result.prediction_intervals.is_some());
        let intervals = result
            .prediction_intervals
            .expect("operation should succeed");
        assert_eq!(intervals.len(), 2);

        for &(lower, upper) in intervals.iter() {
            assert!(upper > lower, "Interval should have positive width");
        }
    }

    /// A fitted classification predictor returns one prediction set per sample.
    #[test]
    fn test_conformal_prediction_classification() {
        let mut predictor = ConformalPredictor::new(ConformalPredictionConfig {
            nonconformity_method: NonconformityMethod::InverseProbability,
            ..ConformalPredictionConfig::default()
        });

        // Synthetic calibration data for three classes.
        let cal_probs = vec![
            vec![0.8, 0.1, 0.1],
            vec![0.2, 0.7, 0.1],
            vec![0.1, 0.2, 0.7],
            vec![0.6, 0.3, 0.1],
            vec![0.1, 0.1, 0.8],
        ];
        let cal_labels = vec![0, 1, 2, 0, 2];
        predictor
            .fit_classification(&cal_probs, &cal_labels)
            .expect("operation should succeed");

        let test_probs = vec![vec![0.5, 0.3, 0.2], vec![0.2, 0.6, 0.2]];
        let result = predictor
            .predict_sets(&test_probs)
            .expect("operation should succeed");

        assert!(result.prediction_sets.is_some());
        let sets = result.prediction_sets.expect("operation should succeed");
        assert_eq!(sets.len(), 2);
    }

    /// Coverage evaluation reports a rate in `[0, 1]` and the configured target.
    #[test]
    fn test_coverage_evaluation() {
        // alpha = 0.2 means an 80% target coverage.
        let mut predictor = ConformalPredictor::new(ConformalPredictionConfig {
            alpha: 0.2,
            ..Default::default()
        });

        // Calibration targets equal the predictions exactly.
        let cal_preds = [1.0, 2.0, 3.0, 4.0, 5.0];
        let cal_targets = [1.0, 2.0, 3.0, 4.0, 5.0];
        predictor
            .fit(&cal_preds, &cal_targets)
            .expect("operation should succeed");

        // Test targets also equal the predictions exactly.
        let test_preds = [1.5, 2.5];
        let test_targets = [1.5, 2.5];
        let coverage = predictor
            .evaluate_coverage(&test_preds, &test_targets, None)
            .expect("operation should succeed");

        assert!(coverage.empirical_coverage >= 0.0);
        assert!(coverage.empirical_coverage <= 1.0);
        assert_eq!(coverage.target_coverage, 0.8);
    }
}