sklears_model_selection/epistemic_uncertainty/uncertainty_quantifier.rs

1use super::aleatoric_quantifier::AleatoricUncertaintyQuantifier;
2use super::epistemic_quantifier::EpistemicUncertaintyQuantifier;
3use super::uncertainty_config::*;
4use super::uncertainty_decomposition::*;
5use super::uncertainty_results::*;
6use super::uncertainty_types::*;
7use scirs2_core::ndarray::{Array1, Array2};
8// use scirs2_core::numeric::Float;
9
10#[derive(Debug, Clone)]
11pub struct UncertaintyQuantifier {
12    config: UncertaintyQuantificationConfig,
13}
14
15impl UncertaintyQuantifier {
16    pub fn new() -> Self {
17        Self {
18            config: UncertaintyQuantificationConfig::default(),
19        }
20    }
21
22    pub fn with_config(config: UncertaintyQuantificationConfig) -> Self {
23        Self { config }
24    }
25
26    pub fn epistemic_config(mut self, config: EpistemicUncertaintyConfig) -> Self {
27        self.config.epistemic_config = config;
28        self
29    }
30
31    pub fn aleatoric_config(mut self, config: AleatoricUncertaintyConfig) -> Self {
32        self.config.aleatoric_config = config;
33        self
34    }
35
36    pub fn decomposition_method(mut self, method: UncertaintyDecompositionMethod) -> Self {
37        self.config.decomposition_method = method;
38        self
39    }
40
41    pub fn confidence_level(mut self, level: f64) -> Self {
42        self.config.confidence_level = level;
43        self
44    }
45
46    pub fn random_state(mut self, seed: u64) -> Self {
47        self.config.random_state = Some(seed);
48        self
49    }
50
51    pub fn quantify<E, P>(
52        &self,
53        models: &[E],
54        x: &Array2<f64>,
55        y_true: Option<&Array1<f64>>,
56    ) -> Result<UncertaintyQuantificationResult, Box<dyn std::error::Error>>
57    where
58        E: Clone,
59        P: Clone,
60    {
61        let epistemic_quantifier =
62            EpistemicUncertaintyQuantifier::with_config(self.config.epistemic_config.clone());
63        let aleatoric_quantifier =
64            AleatoricUncertaintyQuantifier::with_config(self.config.aleatoric_config.clone());
65
66        let epistemic_result = epistemic_quantifier.quantify::<E, P>(models, x, y_true)?;
67        let aleatoric_result = aleatoric_quantifier.quantify::<E, P>(models, x, y_true)?;
68
69        let total_uncertainty =
70            epistemic_result.uncertainties.clone() + &aleatoric_result.uncertainties;
71
72        let uncertainty_decomposition = self.decompose_uncertainty(
73            &epistemic_result.uncertainties,
74            &aleatoric_result.uncertainties,
75        )?;
76
77        let alpha = 1.0 - self.config.confidence_level;
78        let lower_quantile = alpha / 2.0;
79        let upper_quantile = 1.0 - alpha / 2.0;
80
81        let prediction_intervals = self.compute_combined_prediction_intervals(
82            &epistemic_result.predictions,
83            &total_uncertainty,
84            lower_quantile,
85            upper_quantile,
86        )?;
87
88        let calibration_score = (epistemic_result.calibration_score
89            + aleatoric_result.reliability_metrics.calibration_error)
90            / 2.0;
91
92        let reliability_metrics = ReliabilityMetrics {
93            calibration_error: calibration_score,
94            sharpness: total_uncertainty.mean().unwrap_or(0.0),
95            reliability_score: 1.0 - calibration_score,
96            coverage_probability: 0.95,
97            prediction_interval_score: 0.0,
98            continuous_ranked_probability_score: 0.0,
99        };
100
101        Ok(UncertaintyQuantificationResult {
102            predictions: epistemic_result.predictions.clone(),
103            total_uncertainty,
104            epistemic_uncertainty: epistemic_result.uncertainties.clone(),
105            aleatoric_uncertainty: aleatoric_result.uncertainties.clone(),
106            prediction_intervals,
107            uncertainty_decomposition,
108            calibration_score,
109            reliability_metrics,
110            epistemic_result,
111            aleatoric_result,
112        })
113    }
114
115    fn decompose_uncertainty(
116        &self,
117        epistemic_uncertainty: &Array1<f64>,
118        aleatoric_uncertainty: &Array1<f64>,
119    ) -> Result<UncertaintyDecomposition, Box<dyn std::error::Error>> {
120        let total_uncertainty = epistemic_uncertainty + aleatoric_uncertainty;
121        let _n = total_uncertainty.len();
122
123        let entropy_components = std::collections::HashMap::new();
124        let mutual_information = epistemic_uncertainty.mean().unwrap_or(0.0);
125        let explained_variance_ratio = epistemic_uncertainty.sum() / total_uncertainty.sum();
126        let uncertainty_ratios = epistemic_uncertainty / &total_uncertainty;
127
128        Ok(UncertaintyDecomposition {
129            total_uncertainty,
130            epistemic_uncertainty: epistemic_uncertainty.clone(),
131            aleatoric_uncertainty: aleatoric_uncertainty.clone(),
132            decomposition_method: format!("{:?}", self.config.decomposition_method),
133            entropy_components,
134            mutual_information,
135            explained_variance_ratio,
136            uncertainty_ratios,
137        })
138    }
139
140    fn compute_combined_prediction_intervals(
141        &self,
142        predictions: &Array1<f64>,
143        total_uncertainty: &Array1<f64>,
144        lower_quantile: f64,
145        upper_quantile: f64,
146    ) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
147        let n = predictions.len();
148        let mut intervals = Array2::<f64>::zeros((n, 2));
149
150        for i in 0..n {
151            let std_dev = total_uncertainty[i].sqrt();
152            let z_lower = normal_quantile(lower_quantile);
153            let z_upper = normal_quantile(upper_quantile);
154
155            intervals[[i, 0]] = predictions[i] + z_lower * std_dev;
156            intervals[[i, 1]] = predictions[i] + z_upper * std_dev;
157        }
158
159        Ok(intervals)
160    }
161
162    // Getter methods for testing
163    pub fn config(&self) -> &UncertaintyQuantificationConfig {
164        &self.config
165    }
166}
167
168impl Default for UncertaintyQuantifier {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
/// Approximate inverse of the standard normal CDF (the probit function).
///
/// Uses the rational approximation of Abramowitz & Stegun, formula 26.2.23,
/// whose absolute error is below 4.5e-4. Returns `-inf` for `p <= 0`, `+inf`
/// for `p >= 1`, exactly `0.0` for `p == 0.5`, and `NaN` for a `NaN` input.
fn normal_quantile(p: f64) -> f64 {
    // The quantile function diverges at the ends of (0, 1); the median is exact.
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    } else if p >= 1.0 {
        return f64::INFINITY;
    } else if p == 0.5 {
        return 0.0;
    }

    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    // The approximation is stated for an upper-tail probability, so work with
    // the smaller tail and restore the sign at the end.
    let below_median = p < 0.5;
    let tail = if below_median { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let num = C[0] + C[1] * t + C[2] * t * t;
    let den = 1.0 + D[0] * t + D[1] * t * t + D[2] * t * t * t;
    let magnitude = t - num / den;

    if below_median {
        -magnitude
    } else {
        magnitude
    }
}