// sklears_model_selection/epistemic_uncertainty/epistemic_quantifier.rs
use super::bayesian_methods::*;
use super::calibration::CalibrationMethod;
use super::ensemble_methods::*;
use super::monte_carlo_methods::*;
use super::uncertainty_config::EpistemicUncertaintyConfig;
use super::uncertainty_methods::EpistemicUncertaintyMethod;
use super::uncertainty_results::EpistemicUncertaintyResult;
use super::uncertainty_types::*;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::Random;

/// Quantifies epistemic (model) uncertainty for a collection of models.
///
/// Construct with [`EpistemicUncertaintyQuantifier::new`] or
/// [`EpistemicUncertaintyQuantifier::with_config`], then tune via the
/// builder-style methods and call `quantify`.
#[derive(Debug, Clone)]
pub struct EpistemicUncertaintyQuantifier {
    // Holds all settings: method, confidence level, calibration, RNG seed, etc.
    config: EpistemicUncertaintyConfig,
}
17
18impl EpistemicUncertaintyQuantifier {
19 pub fn new() -> Self {
20 Self {
21 config: EpistemicUncertaintyConfig::default(),
22 }
23 }
24
25 pub fn with_config(config: EpistemicUncertaintyConfig) -> Self {
26 Self { config }
27 }
28
29 pub fn method(mut self, method: EpistemicUncertaintyMethod) -> Self {
30 self.config.method = method;
31 self
32 }
33
34 pub fn confidence_level(mut self, level: f64) -> Self {
35 self.config.confidence_level = level;
36 self
37 }
38
39 pub fn random_state(mut self, seed: u64) -> Self {
40 self.config.random_state = Some(seed);
41 self
42 }
43
44 pub fn calibration_method(mut self, method: CalibrationMethod) -> Self {
45 self.config.calibration_method = method;
46 self
47 }
48
49 pub fn temperature_scaling(mut self, enable: bool) -> Self {
50 self.config.temperature_scaling = enable;
51 self
52 }
53
54 pub fn quantify<E, P>(
55 &self,
56 models: &[E],
57 x: &Array2<f64>,
58 y_true: Option<&Array1<f64>>,
59 ) -> Result<EpistemicUncertaintyResult, Box<dyn std::error::Error>>
60 where
61 E: Clone,
62 P: Clone,
63 {
64 let mut rng = match self.config.random_state {
65 Some(seed) => Random::seed(seed),
66 None => Random::seed(42),
67 };
68
69 let (predictions, uncertainties) = match &self.config.method {
70 EpistemicUncertaintyMethod::MonteCarloDropout {
71 dropout_rate,
72 n_samples,
73 } => monte_carlo_dropout_uncertainty(models, x, *dropout_rate, *n_samples, &mut rng)?,
74 EpistemicUncertaintyMethod::DeepEnsembles { n_models } => {
75 deep_ensemble_uncertainty(models, x, *n_models)?
76 }
77 EpistemicUncertaintyMethod::BayesianNeuralNetwork { n_samples } => {
78 bayesian_neural_network_uncertainty(models, x, *n_samples, &mut rng)?
79 }
80 EpistemicUncertaintyMethod::Bootstrap {
81 n_bootstrap,
82 sample_ratio,
83 } => bootstrap_uncertainty(models, x, *n_bootstrap, *sample_ratio, &mut rng)?,
84 EpistemicUncertaintyMethod::GaussianProcess { kernel_type } => {
85 gaussian_process_uncertainty(models, x, kernel_type)?
86 }
87 EpistemicUncertaintyMethod::VariationalInference { n_samples } => {
88 variational_inference_uncertainty(models, x, *n_samples, &mut rng)?
89 }
90 EpistemicUncertaintyMethod::LaplaceApproximation { hessian_method } => {
91 laplace_approximation_uncertainty(models, x, hessian_method)?
92 }
93 };
94
95 let alpha = 1.0 - self.config.confidence_level;
96 let lower_quantile = alpha / 2.0;
97 let upper_quantile = 1.0 - alpha / 2.0;
98
99 let prediction_intervals = self.compute_prediction_intervals(
100 &predictions,
101 &uncertainties,
102 lower_quantile,
103 upper_quantile,
104 )?;
105
106 let entropy = self.compute_entropy(&predictions)?;
107 let mutual_information = self.compute_mutual_information(&predictions)?;
108
109 let epistemic_uncertainty_components = UncertaintyComponents {
110 model_uncertainty: uncertainties.clone(),
111 data_uncertainty: Array1::zeros(uncertainties.len()),
112 parameter_uncertainty: uncertainties.clone(),
113 structural_uncertainty: Array1::zeros(uncertainties.len()),
114 approximation_uncertainty: Array1::zeros(uncertainties.len()),
115 };
116
117 let calibration_score = match y_true {
118 Some(y) => self.compute_calibration_score(&predictions, &uncertainties, y)?,
119 None => 0.0,
120 };
121
122 let reliability_metrics =
123 self.compute_reliability_metrics(&predictions, &uncertainties, y_true)?;
124
125 Ok(EpistemicUncertaintyResult {
126 predictions,
127 uncertainties,
128 prediction_intervals,
129 calibration_score,
130 entropy,
131 mutual_information,
132 epistemic_uncertainty_components,
133 reliability_metrics,
134 })
135 }
136
137 fn compute_prediction_intervals(
138 &self,
139 predictions: &Array1<f64>,
140 uncertainties: &Array1<f64>,
141 lower_quantile: f64,
142 upper_quantile: f64,
143 ) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
144 let n = predictions.len();
145 let mut intervals = Array2::<f64>::zeros((n, 2));
146
147 for i in 0..n {
148 let std_dev = uncertainties[i].sqrt();
149 let z_lower = normal_quantile(lower_quantile);
150 let z_upper = normal_quantile(upper_quantile);
151
152 intervals[[i, 0]] = predictions[i] + z_lower * std_dev;
153 intervals[[i, 1]] = predictions[i] + z_upper * std_dev;
154 }
155
156 Ok(intervals)
157 }
158
159 fn compute_entropy(
160 &self,
161 predictions: &Array1<f64>,
162 ) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
163 let entropy = predictions.mapv(|p| {
164 if p > 0.0 && p < 1.0 {
165 -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
166 } else {
167 0.0
168 }
169 });
170 Ok(entropy)
171 }
172
173 fn compute_mutual_information(
174 &self,
175 predictions: &Array1<f64>,
176 ) -> Result<f64, Box<dyn std::error::Error>> {
177 let mean_entropy = predictions
178 .iter()
179 .map(|&p| {
180 if p > 0.0 && p < 1.0 {
181 -p * p.ln() - (1.0 - p) * (1.0 - p).ln()
182 } else {
183 0.0
184 }
185 })
186 .sum::<f64>()
187 / predictions.len() as f64;
188
189 let mean_prediction = predictions.mean().unwrap_or(0.0);
190 let entropy_of_mean = if mean_prediction > 0.0 && mean_prediction < 1.0 {
191 -mean_prediction * mean_prediction.ln()
192 - (1.0 - mean_prediction) * (1.0 - mean_prediction).ln()
193 } else {
194 0.0
195 };
196
197 Ok(entropy_of_mean - mean_entropy)
198 }
199
200 fn compute_calibration_score(
201 &self,
202 predictions: &Array1<f64>,
203 uncertainties: &Array1<f64>,
204 y_true: &Array1<f64>,
205 ) -> Result<f64, Box<dyn std::error::Error>> {
206 let n_bins = 10;
207 let mut calibration_error = 0.0;
208
209 for bin_idx in 0..n_bins {
210 let lower_bound = bin_idx as f64 / n_bins as f64;
211 let upper_bound = (bin_idx + 1) as f64 / n_bins as f64;
212
213 let mut bin_predictions = Vec::new();
214 let mut bin_true_values = Vec::new();
215
216 for i in 0..predictions.len() {
217 let confidence = 1.0 - uncertainties[i];
218 if confidence > lower_bound && confidence <= upper_bound {
219 bin_predictions.push(predictions[i]);
220 bin_true_values.push(y_true[i]);
221 }
222 }
223
224 if !bin_predictions.is_empty() {
225 let bin_accuracy = bin_predictions
226 .iter()
227 .zip(bin_true_values.iter())
228 .map(|(&pred, &true_val)| {
229 if (pred - true_val).abs() < 0.1 {
230 1.0
231 } else {
232 0.0
233 }
234 })
235 .sum::<f64>()
236 / bin_predictions.len() as f64;
237
238 let bin_confidence = (lower_bound + upper_bound) / 2.0;
239 calibration_error += (bin_accuracy - bin_confidence).abs()
240 * bin_predictions.len() as f64
241 / predictions.len() as f64;
242 }
243 }
244
245 Ok(calibration_error)
246 }
247
248 fn compute_reliability_metrics(
249 &self,
250 predictions: &Array1<f64>,
251 uncertainties: &Array1<f64>,
252 y_true: Option<&Array1<f64>>,
253 ) -> Result<ReliabilityMetrics, Box<dyn std::error::Error>> {
254 let calibration_error = match y_true {
255 Some(y) => self.compute_calibration_score(predictions, uncertainties, y)?,
256 None => 0.0,
257 };
258
259 let sharpness = uncertainties.mean().unwrap_or(0.0);
260 let reliability_score = 1.0 - calibration_error;
261 let coverage_probability = 0.95; let prediction_interval_score = 0.0; let continuous_ranked_probability_score = 0.0; Ok(ReliabilityMetrics {
266 calibration_error,
267 sharpness,
268 reliability_score,
269 coverage_probability,
270 prediction_interval_score,
271 continuous_ranked_probability_score,
272 })
273 }
274
275 pub fn config(&self) -> &EpistemicUncertaintyConfig {
277 &self.config
278 }
279}
280
impl Default for EpistemicUncertaintyQuantifier {
    /// Equivalent to [`EpistemicUncertaintyQuantifier::new`]: a quantifier
    /// with the default configuration.
    fn default() -> Self {
        Self::new()
    }
}
286
/// Inverse CDF (quantile function) of the standard normal distribution,
/// using the Abramowitz & Stegun eq. 26.2.23 rational approximation
/// (absolute error below 4.5e-4).
///
/// Returns `-inf` for `p <= 0`, `+inf` for `p >= 1`, and exactly `0.0`
/// at the median `p == 0.5`.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // Rational-approximation coefficients from A&S eq. 26.2.23.
    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    // Evaluate in the lower tail and mirror the result for p > 0.5;
    // the approximation is symmetric about the median.
    let (tail_p, sign) = if p < 0.5 { (p, -1.0) } else { (1.0 - p, 1.0) };
    let t = (-2.0 * tail_p.ln()).sqrt();

    let numerator = C[0] + C[1] * t + C[2] * t * t;
    let denominator = 1.0 + D[0] * t + D[1] * t * t + D[2] * t * t * t;

    sign * (t - numerator / denominator)
}