//! Epistemic uncertainty quantifier.
//!
//! Path: sklears_model_selection/epistemic_uncertainty/epistemic_quantifier.rs

1use super::bayesian_methods::*;
2use super::calibration::CalibrationMethod;
3use super::ensemble_methods::*;
4use super::monte_carlo_methods::*;
5use super::uncertainty_config::EpistemicUncertaintyConfig;
6use super::uncertainty_methods::EpistemicUncertaintyMethod;
7use super::uncertainty_results::EpistemicUncertaintyResult;
8use super::uncertainty_types::*;
9use scirs2_core::ndarray::{Array1, Array2};
10// use scirs2_core::numeric::Float;
11use scirs2_core::random::Random;
12
/// Quantifies epistemic (model) uncertainty for a set of trained models,
/// dispatching to the estimation method selected in its configuration.
#[derive(Debug, Clone)]
pub struct EpistemicUncertaintyQuantifier {
    // Holds the method choice, confidence level, calibration settings and RNG seed.
    config: EpistemicUncertaintyConfig,
}
17
18impl EpistemicUncertaintyQuantifier {
19    pub fn new() -> Self {
20        Self {
21            config: EpistemicUncertaintyConfig::default(),
22        }
23    }
24
25    pub fn with_config(config: EpistemicUncertaintyConfig) -> Self {
26        Self { config }
27    }
28
29    pub fn method(mut self, method: EpistemicUncertaintyMethod) -> Self {
30        self.config.method = method;
31        self
32    }
33
34    pub fn confidence_level(mut self, level: f64) -> Self {
35        self.config.confidence_level = level;
36        self
37    }
38
39    pub fn random_state(mut self, seed: u64) -> Self {
40        self.config.random_state = Some(seed);
41        self
42    }
43
44    pub fn calibration_method(mut self, method: CalibrationMethod) -> Self {
45        self.config.calibration_method = method;
46        self
47    }
48
49    pub fn temperature_scaling(mut self, enable: bool) -> Self {
50        self.config.temperature_scaling = enable;
51        self
52    }
53
54    pub fn quantify<E, P>(
55        &self,
56        models: &[E],
57        x: &Array2<f64>,
58        y_true: Option<&Array1<f64>>,
59    ) -> Result<EpistemicUncertaintyResult, Box<dyn std::error::Error>>
60    where
61        E: Clone,
62        P: Clone,
63    {
64        let mut rng = match self.config.random_state {
65            Some(seed) => Random::seed(seed),
66            None => Random::seed(42),
67        };
68
69        let (predictions, uncertainties) = match &self.config.method {
70            EpistemicUncertaintyMethod::MonteCarloDropout {
71                dropout_rate,
72                n_samples,
73            } => monte_carlo_dropout_uncertainty(models, x, *dropout_rate, *n_samples, &mut rng)?,
74            EpistemicUncertaintyMethod::DeepEnsembles { n_models } => {
75                deep_ensemble_uncertainty(models, x, *n_models)?
76            }
77            EpistemicUncertaintyMethod::BayesianNeuralNetwork { n_samples } => {
78                bayesian_neural_network_uncertainty(models, x, *n_samples, &mut rng)?
79            }
80            EpistemicUncertaintyMethod::Bootstrap {
81                n_bootstrap,
82                sample_ratio,
83            } => bootstrap_uncertainty(models, x, *n_bootstrap, *sample_ratio, &mut rng)?,
84            EpistemicUncertaintyMethod::GaussianProcess { kernel_type } => {
85                gaussian_process_uncertainty(models, x, kernel_type)?
86            }
87            EpistemicUncertaintyMethod::VariationalInference { n_samples } => {
88                variational_inference_uncertainty(models, x, *n_samples, &mut rng)?
89            }
90            EpistemicUncertaintyMethod::LaplaceApproximation { hessian_method } => {
91                laplace_approximation_uncertainty(models, x, hessian_method)?
92            }
93        };
94
95        let alpha = 1.0 - self.config.confidence_level;
96        let lower_quantile = alpha / 2.0;
97        let upper_quantile = 1.0 - alpha / 2.0;
98
99        let prediction_intervals = self.compute_prediction_intervals(
100            &predictions,
101            &uncertainties,
102            lower_quantile,
103            upper_quantile,
104        )?;
105
106        let entropy = self.compute_entropy(&predictions)?;
107        let mutual_information = self.compute_mutual_information(&predictions)?;
108
109        let epistemic_uncertainty_components = UncertaintyComponents {
110            model_uncertainty: uncertainties.clone(),
111            data_uncertainty: Array1::zeros(uncertainties.len()),
112            parameter_uncertainty: uncertainties.clone(),
113            structural_uncertainty: Array1::zeros(uncertainties.len()),
114            approximation_uncertainty: Array1::zeros(uncertainties.len()),
115        };
116
117        let calibration_score = match y_true {
118            Some(y) => self.compute_calibration_score(&predictions, &uncertainties, y)?,
119            None => 0.0,
120        };
121
122        let reliability_metrics =
123            self.compute_reliability_metrics(&predictions, &uncertainties, y_true)?;
124
125        Ok(EpistemicUncertaintyResult {
126            predictions,
127            uncertainties,
128            prediction_intervals,
129            calibration_score,
130            entropy,
131            mutual_information,
132            epistemic_uncertainty_components,
133            reliability_metrics,
134        })
135    }
136
137    fn compute_prediction_intervals(
138        &self,
139        predictions: &Array1<f64>,
140        uncertainties: &Array1<f64>,
141        lower_quantile: f64,
142        upper_quantile: f64,
143    ) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
144        let n = predictions.len();
145        let mut intervals = Array2::<f64>::zeros((n, 2));
146
147        for i in 0..n {
148            let std_dev = uncertainties[i].sqrt();
149            let z_lower = normal_quantile(lower_quantile);
150            let z_upper = normal_quantile(upper_quantile);
151
152            intervals[[i, 0]] = predictions[i] + z_lower * std_dev;
153            intervals[[i, 1]] = predictions[i] + z_upper * std_dev;
154        }
155
156        Ok(intervals)
157    }
158
159    fn compute_entropy(
160        &self,
161        predictions: &Array1<f64>,
162    ) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
163        let entropy = predictions.mapv(|p| {
164            if p > 0.0 && p < 1.0 {
165                -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
166            } else {
167                0.0
168            }
169        });
170        Ok(entropy)
171    }
172
173    fn compute_mutual_information(
174        &self,
175        predictions: &Array1<f64>,
176    ) -> Result<f64, Box<dyn std::error::Error>> {
177        let mean_entropy = predictions
178            .iter()
179            .map(|&p| {
180                if p > 0.0 && p < 1.0 {
181                    -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
182                } else {
183                    0.0
184                }
185            })
186            .sum::<f64>()
187            / predictions.len() as f64;
188
189        let mean_prediction = predictions.mean().unwrap_or(0.0);
190        let entropy_of_mean = if mean_prediction > 0.0 && mean_prediction < 1.0 {
191            -mean_prediction * mean_prediction.ln()
192                - (1.0 - mean_prediction) * (1.0 - mean_prediction).ln()
193        } else {
194            0.0
195        };
196
197        Ok(entropy_of_mean - mean_entropy)
198    }
199
200    fn compute_calibration_score(
201        &self,
202        predictions: &Array1<f64>,
203        uncertainties: &Array1<f64>,
204        y_true: &Array1<f64>,
205    ) -> Result<f64, Box<dyn std::error::Error>> {
206        let n_bins = 10;
207        let mut calibration_error = 0.0;
208
209        for bin_idx in 0..n_bins {
210            let lower_bound = bin_idx as f64 / n_bins as f64;
211            let upper_bound = (bin_idx + 1) as f64 / n_bins as f64;
212
213            let mut bin_predictions = Vec::new();
214            let mut bin_true_values = Vec::new();
215
216            for i in 0..predictions.len() {
217                let confidence = 1.0 - uncertainties[i];
218                if confidence > lower_bound && confidence <= upper_bound {
219                    bin_predictions.push(predictions[i]);
220                    bin_true_values.push(y_true[i]);
221                }
222            }
223
224            if !bin_predictions.is_empty() {
225                let bin_accuracy = bin_predictions
226                    .iter()
227                    .zip(bin_true_values.iter())
228                    .map(|(&pred, &true_val)| {
229                        if (pred - true_val).abs() < 0.1 {
230                            1.0
231                        } else {
232                            0.0
233                        }
234                    })
235                    .sum::<f64>()
236                    / bin_predictions.len() as f64;
237
238                let bin_confidence = (lower_bound + upper_bound) / 2.0;
239                calibration_error += (bin_accuracy - bin_confidence).abs()
240                    * bin_predictions.len() as f64
241                    / predictions.len() as f64;
242            }
243        }
244
245        Ok(calibration_error)
246    }
247
248    fn compute_reliability_metrics(
249        &self,
250        predictions: &Array1<f64>,
251        uncertainties: &Array1<f64>,
252        y_true: Option<&Array1<f64>>,
253    ) -> Result<ReliabilityMetrics, Box<dyn std::error::Error>> {
254        let calibration_error = match y_true {
255            Some(y) => self.compute_calibration_score(predictions, uncertainties, y)?,
256            None => 0.0,
257        };
258
259        let sharpness = uncertainties.mean().unwrap_or(0.0);
260        let reliability_score = 1.0 - calibration_error;
261        let coverage_probability = 0.95; // Placeholder
262        let prediction_interval_score = 0.0; // Placeholder
263        let continuous_ranked_probability_score = 0.0; // Placeholder
264
265        Ok(ReliabilityMetrics {
266            calibration_error,
267            sharpness,
268            reliability_score,
269            coverage_probability,
270            prediction_interval_score,
271            continuous_ranked_probability_score,
272        })
273    }
274
275    // Getter methods for testing
276    pub fn config(&self) -> &EpistemicUncertaintyConfig {
277        &self.config
278    }
279}
280
281impl Default for EpistemicUncertaintyQuantifier {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
/// Approximates the standard normal quantile function (inverse CDF) Φ⁻¹(p).
///
/// Uses the Abramowitz & Stegun rational approximation (Handbook of
/// Mathematical Functions, formula 26.2.23), whose absolute error is below
/// 4.5e-4. The previous comment called this "Box-Muller", which is a sampling
/// transform, not a quantile approximation.
///
/// Returns `-inf` for `p <= 0`, `+inf` for `p >= 1`, and exactly `0.0` at
/// `p == 0.5`; NaN input propagates to a NaN result.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // Coefficients from Abramowitz & Stegun 26.2.23.
    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    // Work in the lower tail (q = min(p, 1 - p)), then mirror the sign:
    // the approximation is stated for the tail probability only.
    let q = if p < 0.5 { p } else { 1.0 - p };
    let t = (-2.0 * q.ln()).sqrt();
    let numerator = C0 + C1 * t + C2 * t * t;
    let denominator = 1.0 + D1 * t + D2 * t * t + D3 * t * t * t;
    let result = t - numerator / denominator;

    if p < 0.5 {
        -result
    } else {
        result
    }
}