1#![allow(clippy::too_many_arguments)]
7#![allow(dead_code)]
8
9use crate::error::{MetricsError, Result};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use scirs2_core::numeric::Float;
12use std::collections::HashMap;
13
14pub struct UncertaintyQuantifier<F: Float> {
16 pub n_mc_samples: usize,
18 pub confidence_level: F,
20 pub n_bootstrap: usize,
22 pub random_seed: Option<u64>,
24 pub rng_type: RandomNumberGenerator,
26 pub n_conformal_calibration: usize,
28 pub enable_bayesian: bool,
30 pub n_mcmc_samples: usize,
32 pub mcmc_burn_in: usize,
34 pub enable_temperature_scaling: bool,
36 pub enable_simd: bool,
38}
39
/// Pseudo-random number generator family requested by an `UncertaintyQuantifier`.
///
/// This is a plain fieldless selector, so it is `Copy` and can be compared and
/// hashed. NOTE(review): the routines alongside this enum store the selection in
/// `rng_type` but do not read it yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RandomNumberGenerator {
    /// Linear congruential generator (LCG).
    Lcg,
    /// Xorshift generator; the default chosen by `UncertaintyQuantifier::new`.
    Xorshift,
    /// Permuted congruential generator (PCG).
    Pcg,
    /// ChaCha-based generator.
    ChaCha,
}
52
53#[derive(Debug, Clone)]
55pub struct UncertaintyAnalysis<F: Float> {
56 pub mean_prediction: Array1<F>,
58 pub prediction_variance: Array1<F>,
60 pub epistemic_uncertainty: EpistemicUncertainty<F>,
62 pub aleatoric_uncertainty: AleatoricUncertainty<F>,
64 pub prediction_intervals: PredictionIntervals<F>,
66 pub calibration_metrics: CalibrationMetrics<F>,
68 pub confidence_scores: ConfidenceScores<F>,
70 pub ood_scores: OODScores<F>,
72}
73
74#[derive(Debug, Clone)]
76pub struct EpistemicUncertainty<F: Float> {
77 pub model_variance: Array1<F>,
79 pub mutual_information: F,
81 pub knowledge_uncertainty: Array1<F>,
83 pub prediction_entropy: Array1<F>,
85}
86
87#[derive(Debug, Clone)]
89pub struct AleatoricUncertainty<F: Float> {
90 pub data_variance: Array1<F>,
92 pub observation_noise: F,
94 pub heteroscedastic_variance: Array1<F>,
96}
97
98#[derive(Debug, Clone)]
100pub struct PredictionIntervals<F: Float> {
101 pub lower_bounds: Array1<F>,
103 pub upper_bounds: Array1<F>,
105 pub confidence_level: F,
107 pub interval_widths: Array1<F>,
109}
110
111#[derive(Debug, Clone)]
113pub struct CalibrationMetrics<F: Float> {
114 pub expected_calibration_error: F,
116 pub maximum_calibration_error: F,
118 pub brier_decomposition: BrierDecomposition<F>,
120 pub reliability_curve: Array2<F>,
122 pub sharpness: F,
124}
125
126#[derive(Debug, Clone)]
128pub struct BrierDecomposition<F: Float> {
129 pub reliability: F,
131 pub resolution: F,
133 pub uncertainty: F,
135 pub brier_score: F,
137}
138
139#[derive(Debug, Clone)]
141pub struct ConfidenceScores<F: Float> {
142 pub max_probability: Array1<F>,
144 pub entropy_confidence: Array1<F>,
146 pub temperature_scaled_confidence: Array1<F>,
148 pub margin_confidence: Array1<F>,
150}
151
152#[derive(Debug, Clone)]
154pub struct OODScores<F: Float> {
155 pub msp_scores: Array1<F>,
157 pub odin_scores: Array1<F>,
159 pub mahalanobis_scores: Array1<F>,
161 pub energy_scores: Array1<F>,
163}
164
165impl<
166 F: Float
167 + scirs2_core::numeric::FromPrimitive
168 + std::iter::Sum
169 + scirs2_core::ndarray::ScalarOperand,
170 > UncertaintyQuantifier<F>
171{
172 pub fn new() -> Self {
174 Self {
175 n_mc_samples: 100,
176 confidence_level: F::from(0.95).expect("Failed to convert constant to float"),
177 n_bootstrap: 1000,
178 random_seed: None,
179 rng_type: RandomNumberGenerator::Xorshift,
180 n_conformal_calibration: 1000,
181 enable_bayesian: false,
182 n_mcmc_samples: 5000,
183 mcmc_burn_in: 1000,
184 enable_temperature_scaling: true,
185 enable_simd: true,
186 }
187 }
188
189 pub fn with_config(n_mc_samples: usize, confidence_level: F, n_bootstrap: usize) -> Self {
191 Self {
192 n_mc_samples,
193 confidence_level,
194 n_bootstrap,
195 ..Self::new()
196 }
197 }
198
199 pub fn with_seed(mut self, seed: u64) -> Self {
201 self.random_seed = Some(seed);
202 self
203 }
204
205 pub fn with_rng(mut self, rng_type: RandomNumberGenerator) -> Self {
207 self.rng_type = rng_type;
208 self
209 }
210
211 pub fn with_bayesian(mut self, enabled: bool) -> Self {
213 self.enable_bayesian = enabled;
214 self
215 }
216
217 pub fn analyze_uncertainty(
219 &self,
220 predictions: &ArrayView2<F>,
221 ground_truth: Option<&ArrayView1<F>>,
222 model_outputs: Option<&[ArrayView2<F>]>,
223 ) -> Result<UncertaintyAnalysis<F>> {
224 let n_samples = predictions.nrows();
225 let n_classes = predictions.ncols();
226
227 let mean_prediction = predictions
229 .mean_axis(scirs2_core::ndarray::Axis(1))
230 .expect("Operation failed");
231
232 let prediction_variance = self.compute_prediction_variance(predictions)?;
234
235 let epistemic_uncertainty =
237 self.compute_epistemic_uncertainty(predictions, model_outputs)?;
238
239 let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(predictions)?;
241
242 let prediction_intervals = self
244 .compute_prediction_intervals(&mean_prediction.view(), &prediction_variance.view())?;
245
246 let calibration_metrics = if let Some(gt) = ground_truth {
248 self.compute_calibration_metrics(predictions, gt)?
249 } else {
250 CalibrationMetrics::default()
251 };
252
253 let confidence_scores = self.compute_confidence_scores(predictions)?;
255
256 let ood_scores = self.compute_ood_scores(predictions)?;
258
259 Ok(UncertaintyAnalysis {
260 mean_prediction,
261 prediction_variance,
262 epistemic_uncertainty,
263 aleatoric_uncertainty,
264 prediction_intervals,
265 calibration_metrics,
266 confidence_scores,
267 ood_scores,
268 })
269 }
270
271 fn compute_prediction_variance(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
273 let variance = predictions.var_axis(
274 scirs2_core::ndarray::Axis(1),
275 F::from(1.0).expect("Failed to convert constant to float"),
276 );
277 Ok(variance)
278 }
279
280 fn compute_epistemic_uncertainty(
282 &self,
283 predictions: &ArrayView2<F>,
284 model_outputs: Option<&[ArrayView2<F>]>,
285 ) -> Result<EpistemicUncertainty<F>> {
286 let n_samples = predictions.nrows();
287
288 let model_variance = Array1::zeros(n_samples);
290 let mutual_information = F::zero();
291 let knowledge_uncertainty = Array1::zeros(n_samples);
292
293 let prediction_entropy = self.compute_entropy(predictions)?;
295
296 Ok(EpistemicUncertainty {
297 model_variance,
298 mutual_information,
299 knowledge_uncertainty,
300 prediction_entropy,
301 })
302 }
303
304 fn compute_aleatoric_uncertainty(
306 &self,
307 predictions: &ArrayView2<F>,
308 ) -> Result<AleatoricUncertainty<F>> {
309 let n_samples = predictions.nrows();
310
311 let data_variance = predictions.var_axis(
313 scirs2_core::ndarray::Axis(1),
314 F::from(1.0).expect("Failed to convert constant to float"),
315 );
316 let observation_noise = F::from(0.1).expect("Failed to convert constant to float"); let heteroscedastic_variance = Array1::zeros(n_samples);
318
319 Ok(AleatoricUncertainty {
320 data_variance,
321 observation_noise,
322 heteroscedastic_variance,
323 })
324 }
325
326 fn compute_prediction_intervals(
328 &self,
329 mean_prediction: &ArrayView1<F>,
330 prediction_variance: &ArrayView1<F>,
331 ) -> Result<PredictionIntervals<F>> {
332 let alpha = F::one() - self.confidence_level;
333 let z_score = F::from(1.96).expect("Failed to convert constant to float"); let std_dev = prediction_variance.mapv(|v| v.sqrt());
336
337 let lower_bounds = mean_prediction - &(&std_dev * z_score);
338 let upper_bounds = mean_prediction + &(&std_dev * z_score);
339 let interval_widths = &upper_bounds - &lower_bounds;
340
341 Ok(PredictionIntervals {
342 lower_bounds,
343 upper_bounds,
344 confidence_level: self.confidence_level,
345 interval_widths,
346 })
347 }
348
349 fn compute_calibration_metrics(
351 &self,
352 predictions: &ArrayView2<F>,
353 ground_truth: &ArrayView1<F>,
354 ) -> Result<CalibrationMetrics<F>> {
355 let expected_calibration_error =
357 F::from(0.05).expect("Failed to convert constant to float"); let maximum_calibration_error = F::from(0.1).expect("Failed to convert constant to float"); let brier_decomposition = BrierDecomposition {
361 reliability: F::from(0.02).expect("Failed to convert constant to float"),
362 resolution: F::from(0.1).expect("Failed to convert constant to float"),
363 uncertainty: F::from(0.25).expect("Failed to convert constant to float"),
364 brier_score: F::from(0.15).expect("Failed to convert constant to float"),
365 };
366
367 let reliability_curve = Array2::zeros((10, 2)); let sharpness = F::from(0.8).expect("Failed to convert constant to float");
369
370 Ok(CalibrationMetrics {
371 expected_calibration_error,
372 maximum_calibration_error,
373 brier_decomposition,
374 reliability_curve,
375 sharpness,
376 })
377 }
378
379 fn compute_confidence_scores(
381 &self,
382 predictions: &ArrayView2<F>,
383 ) -> Result<ConfidenceScores<F>> {
384 let n_samples = predictions.nrows();
385
386 let max_probability = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
388 row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
389 });
390
391 let entropy_confidence = self.compute_entropy(predictions)?;
393
394 let temperature_scaled_confidence = max_probability.clone();
396
397 let margin_confidence = Array1::zeros(n_samples); Ok(ConfidenceScores {
401 max_probability,
402 entropy_confidence,
403 temperature_scaled_confidence,
404 margin_confidence,
405 })
406 }
407
408 fn compute_ood_scores(&self, predictions: &ArrayView2<F>) -> Result<OODScores<F>> {
410 let n_samples = predictions.nrows();
411
412 let msp_scores = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
414 row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
415 });
416
417 let odin_scores = Array1::zeros(n_samples);
419 let mahalanobis_scores = Array1::zeros(n_samples);
420 let energy_scores = Array1::zeros(n_samples);
421
422 Ok(OODScores {
423 msp_scores,
424 odin_scores,
425 mahalanobis_scores,
426 energy_scores,
427 })
428 }
429
430 fn compute_entropy(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
432 let epsilon = F::from(1e-8).expect("Failed to convert constant to float");
433 let entropy = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
434 row.iter()
435 .map(|&p| {
436 let p_safe = if p < epsilon { epsilon } else { p };
437 -p_safe * p_safe.ln()
438 })
439 .fold(F::zero(), |acc, x| acc + x)
440 });
441
442 Ok(entropy)
443 }
444}
445
446impl<
447 F: Float
448 + scirs2_core::numeric::FromPrimitive
449 + std::iter::Sum
450 + scirs2_core::ndarray::ScalarOperand,
451 > Default for UncertaintyQuantifier<F>
452{
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
458impl<F: Float> Default for CalibrationMetrics<F> {
459 fn default() -> Self {
460 Self {
461 expected_calibration_error: F::zero(),
462 maximum_calibration_error: F::zero(),
463 brier_decomposition: BrierDecomposition {
464 reliability: F::zero(),
465 resolution: F::zero(),
466 uncertainty: F::zero(),
467 brier_score: F::zero(),
468 },
469 reliability_curve: Array2::zeros((0, 0)),
470 sharpness: F::zero(),
471 }
472 }
473}