Skip to main content

scirs2_metrics/explainability/uncertainty_quantification/
core.rs

//! Core uncertainty quantification types and analyzer.
//!
//! This module provides the main uncertainty quantification framework
//! and the core result types used when estimating prediction uncertainty.
5
6#![allow(clippy::too_many_arguments)]
7#![allow(dead_code)]
8
9use crate::error::{MetricsError, Result};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use scirs2_core::numeric::Float;
12use std::collections::HashMap;
13
14/// Uncertainty quantification analyzer
15pub struct UncertaintyQuantifier<F: Float> {
16    /// Number of Monte Carlo samples
17    pub n_mc_samples: usize,
18    /// Confidence level for intervals
19    pub confidence_level: F,
20    /// Bootstrap samples for confidence estimation
21    pub n_bootstrap: usize,
22    /// Random seed
23    pub random_seed: Option<u64>,
24    /// Random number generator type
25    pub rng_type: RandomNumberGenerator,
26    /// Number of conformal calibration samples
27    pub n_conformal_calibration: usize,
28    /// Enable Bayesian uncertainty estimation
29    pub enable_bayesian: bool,
30    /// Number of MCMC samples
31    pub n_mcmc_samples: usize,
32    /// MCMC burn-in samples
33    pub mcmc_burn_in: usize,
34    /// Enable temperature scaling
35    pub enable_temperature_scaling: bool,
36    /// Enable SIMD acceleration
37    pub enable_simd: bool,
38}
39
/// Random number generator families that a quantifier can be configured with.
///
/// The enum is a plain selector with no payload, so it also supports equality
/// comparison and hashing (e.g. for use as a map key or in `match` guards).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RandomNumberGenerator {
    /// Linear Congruential Generator (fast, basic quality)
    Lcg,
    /// Xorshift (good balance of speed and quality)
    Xorshift,
    /// Permuted Congruential Generator (high quality)
    Pcg,
    /// ChaCha (cryptographically secure)
    ChaCha,
}
52
53/// Uncertainty analysis results
54#[derive(Debug, Clone)]
55pub struct UncertaintyAnalysis<F: Float> {
56    /// Mean prediction
57    pub mean_prediction: Array1<F>,
58    /// Prediction variance
59    pub prediction_variance: Array1<F>,
60    /// Epistemic uncertainty (model uncertainty)
61    pub epistemic_uncertainty: EpistemicUncertainty<F>,
62    /// Aleatoric uncertainty (data uncertainty)
63    pub aleatoric_uncertainty: AleatoricUncertainty<F>,
64    /// Prediction intervals
65    pub prediction_intervals: PredictionIntervals<F>,
66    /// Calibration metrics
67    pub calibration_metrics: CalibrationMetrics<F>,
68    /// Confidence scores
69    pub confidence_scores: ConfidenceScores<F>,
70    /// Out-of-distribution scores
71    pub ood_scores: OODScores<F>,
72}
73
74/// Epistemic uncertainty (model uncertainty)
75#[derive(Debug, Clone)]
76pub struct EpistemicUncertainty<F: Float> {
77    /// Model variance across ensemble
78    pub model_variance: Array1<F>,
79    /// Mutual information
80    pub mutual_information: F,
81    /// Knowledge uncertainty
82    pub knowledge_uncertainty: Array1<F>,
83    /// Prediction entropy
84    pub prediction_entropy: Array1<F>,
85}
86
87/// Aleatoric uncertainty (data uncertainty)
88#[derive(Debug, Clone)]
89pub struct AleatoricUncertainty<F: Float> {
90    /// Data noise variance
91    pub data_variance: Array1<F>,
92    /// Observation noise
93    pub observation_noise: F,
94    /// Input-dependent variance
95    pub heteroscedastic_variance: Array1<F>,
96}
97
98/// Prediction intervals
99#[derive(Debug, Clone)]
100pub struct PredictionIntervals<F: Float> {
101    /// Lower bounds
102    pub lower_bounds: Array1<F>,
103    /// Upper bounds
104    pub upper_bounds: Array1<F>,
105    /// Confidence level
106    pub confidence_level: F,
107    /// Interval widths
108    pub interval_widths: Array1<F>,
109}
110
111/// Calibration metrics
112#[derive(Debug, Clone)]
113pub struct CalibrationMetrics<F: Float> {
114    /// Expected calibration error
115    pub expected_calibration_error: F,
116    /// Maximum calibration error
117    pub maximum_calibration_error: F,
118    /// Brier score decomposition
119    pub brier_decomposition: BrierDecomposition<F>,
120    /// Reliability curve
121    pub reliability_curve: Array2<F>,
122    /// Sharpness measure
123    pub sharpness: F,
124}
125
126/// Brier score decomposition
127#[derive(Debug, Clone)]
128pub struct BrierDecomposition<F: Float> {
129    /// Reliability component
130    pub reliability: F,
131    /// Resolution component
132    pub resolution: F,
133    /// Uncertainty component
134    pub uncertainty: F,
135    /// Overall Brier score
136    pub brier_score: F,
137}
138
139/// Confidence scores
140#[derive(Debug, Clone)]
141pub struct ConfidenceScores<F: Float> {
142    /// Maximum predicted probability
143    pub max_probability: Array1<F>,
144    /// Entropy-based confidence
145    pub entropy_confidence: Array1<F>,
146    /// Temperature-scaled confidence
147    pub temperature_scaled_confidence: Array1<F>,
148    /// Margin-based confidence
149    pub margin_confidence: Array1<F>,
150}
151
152/// Out-of-distribution detection scores
153#[derive(Debug, Clone)]
154pub struct OODScores<F: Float> {
155    /// Maximum softmax probability
156    pub msp_scores: Array1<F>,
157    /// ODIN scores
158    pub odin_scores: Array1<F>,
159    /// Mahalanobis distance scores
160    pub mahalanobis_scores: Array1<F>,
161    /// Energy scores
162    pub energy_scores: Array1<F>,
163}
164
165impl<
166        F: Float
167            + scirs2_core::numeric::FromPrimitive
168            + std::iter::Sum
169            + scirs2_core::ndarray::ScalarOperand,
170    > UncertaintyQuantifier<F>
171{
172    /// Create new uncertainty quantifier
173    pub fn new() -> Self {
174        Self {
175            n_mc_samples: 100,
176            confidence_level: F::from(0.95).expect("Failed to convert constant to float"),
177            n_bootstrap: 1000,
178            random_seed: None,
179            rng_type: RandomNumberGenerator::Xorshift,
180            n_conformal_calibration: 1000,
181            enable_bayesian: false,
182            n_mcmc_samples: 5000,
183            mcmc_burn_in: 1000,
184            enable_temperature_scaling: true,
185            enable_simd: true,
186        }
187    }
188
189    /// Create uncertainty quantifier with custom configuration
190    pub fn with_config(n_mc_samples: usize, confidence_level: F, n_bootstrap: usize) -> Self {
191        Self {
192            n_mc_samples,
193            confidence_level,
194            n_bootstrap,
195            ..Self::new()
196        }
197    }
198
199    /// Set random seed
200    pub fn with_seed(mut self, seed: u64) -> Self {
201        self.random_seed = Some(seed);
202        self
203    }
204
205    /// Set RNG type
206    pub fn with_rng(mut self, rng_type: RandomNumberGenerator) -> Self {
207        self.rng_type = rng_type;
208        self
209    }
210
211    /// Enable Bayesian uncertainty estimation
212    pub fn with_bayesian(mut self, enabled: bool) -> Self {
213        self.enable_bayesian = enabled;
214        self
215    }
216
217    /// Compute uncertainty analysis for predictions
218    pub fn analyze_uncertainty(
219        &self,
220        predictions: &ArrayView2<F>,
221        ground_truth: Option<&ArrayView1<F>>,
222        model_outputs: Option<&[ArrayView2<F>]>,
223    ) -> Result<UncertaintyAnalysis<F>> {
224        let n_samples = predictions.nrows();
225        let n_classes = predictions.ncols();
226
227        // Compute mean prediction
228        let mean_prediction = predictions
229            .mean_axis(scirs2_core::ndarray::Axis(1))
230            .expect("Operation failed");
231
232        // Compute prediction variance
233        let prediction_variance = self.compute_prediction_variance(predictions)?;
234
235        // Compute epistemic uncertainty
236        let epistemic_uncertainty =
237            self.compute_epistemic_uncertainty(predictions, model_outputs)?;
238
239        // Compute aleatoric uncertainty
240        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(predictions)?;
241
242        // Compute prediction intervals
243        let prediction_intervals = self
244            .compute_prediction_intervals(&mean_prediction.view(), &prediction_variance.view())?;
245
246        // Compute calibration metrics
247        let calibration_metrics = if let Some(gt) = ground_truth {
248            self.compute_calibration_metrics(predictions, gt)?
249        } else {
250            CalibrationMetrics::default()
251        };
252
253        // Compute confidence scores
254        let confidence_scores = self.compute_confidence_scores(predictions)?;
255
256        // Compute OOD scores
257        let ood_scores = self.compute_ood_scores(predictions)?;
258
259        Ok(UncertaintyAnalysis {
260            mean_prediction,
261            prediction_variance,
262            epistemic_uncertainty,
263            aleatoric_uncertainty,
264            prediction_intervals,
265            calibration_metrics,
266            confidence_scores,
267            ood_scores,
268        })
269    }
270
271    /// Compute prediction variance
272    fn compute_prediction_variance(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
273        let variance = predictions.var_axis(
274            scirs2_core::ndarray::Axis(1),
275            F::from(1.0).expect("Failed to convert constant to float"),
276        );
277        Ok(variance)
278    }
279
280    /// Compute epistemic uncertainty
281    fn compute_epistemic_uncertainty(
282        &self,
283        predictions: &ArrayView2<F>,
284        model_outputs: Option<&[ArrayView2<F>]>,
285    ) -> Result<EpistemicUncertainty<F>> {
286        let n_samples = predictions.nrows();
287
288        // Default values
289        let model_variance = Array1::zeros(n_samples);
290        let mutual_information = F::zero();
291        let knowledge_uncertainty = Array1::zeros(n_samples);
292
293        // Compute prediction entropy
294        let prediction_entropy = self.compute_entropy(predictions)?;
295
296        Ok(EpistemicUncertainty {
297            model_variance,
298            mutual_information,
299            knowledge_uncertainty,
300            prediction_entropy,
301        })
302    }
303
304    /// Compute aleatoric uncertainty
305    fn compute_aleatoric_uncertainty(
306        &self,
307        predictions: &ArrayView2<F>,
308    ) -> Result<AleatoricUncertainty<F>> {
309        let n_samples = predictions.nrows();
310
311        // Simplified aleatoric uncertainty computation
312        let data_variance = predictions.var_axis(
313            scirs2_core::ndarray::Axis(1),
314            F::from(1.0).expect("Failed to convert constant to float"),
315        );
316        let observation_noise = F::from(0.1).expect("Failed to convert constant to float"); // Default noise level
317        let heteroscedastic_variance = Array1::zeros(n_samples);
318
319        Ok(AleatoricUncertainty {
320            data_variance,
321            observation_noise,
322            heteroscedastic_variance,
323        })
324    }
325
326    /// Compute prediction intervals
327    fn compute_prediction_intervals(
328        &self,
329        mean_prediction: &ArrayView1<F>,
330        prediction_variance: &ArrayView1<F>,
331    ) -> Result<PredictionIntervals<F>> {
332        let alpha = F::one() - self.confidence_level;
333        let z_score = F::from(1.96).expect("Failed to convert constant to float"); // 95% confidence interval
334
335        let std_dev = prediction_variance.mapv(|v| v.sqrt());
336
337        let lower_bounds = mean_prediction - &(&std_dev * z_score);
338        let upper_bounds = mean_prediction + &(&std_dev * z_score);
339        let interval_widths = &upper_bounds - &lower_bounds;
340
341        Ok(PredictionIntervals {
342            lower_bounds,
343            upper_bounds,
344            confidence_level: self.confidence_level,
345            interval_widths,
346        })
347    }
348
349    /// Compute calibration metrics
350    fn compute_calibration_metrics(
351        &self,
352        predictions: &ArrayView2<F>,
353        ground_truth: &ArrayView1<F>,
354    ) -> Result<CalibrationMetrics<F>> {
355        // Simplified calibration computation
356        let expected_calibration_error =
357            F::from(0.05).expect("Failed to convert constant to float"); // Placeholder
358        let maximum_calibration_error = F::from(0.1).expect("Failed to convert constant to float"); // Placeholder
359
360        let brier_decomposition = BrierDecomposition {
361            reliability: F::from(0.02).expect("Failed to convert constant to float"),
362            resolution: F::from(0.1).expect("Failed to convert constant to float"),
363            uncertainty: F::from(0.25).expect("Failed to convert constant to float"),
364            brier_score: F::from(0.15).expect("Failed to convert constant to float"),
365        };
366
367        let reliability_curve = Array2::zeros((10, 2)); // Placeholder
368        let sharpness = F::from(0.8).expect("Failed to convert constant to float");
369
370        Ok(CalibrationMetrics {
371            expected_calibration_error,
372            maximum_calibration_error,
373            brier_decomposition,
374            reliability_curve,
375            sharpness,
376        })
377    }
378
379    /// Compute confidence scores
380    fn compute_confidence_scores(
381        &self,
382        predictions: &ArrayView2<F>,
383    ) -> Result<ConfidenceScores<F>> {
384        let n_samples = predictions.nrows();
385
386        // Maximum probability
387        let max_probability = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
388            row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
389        });
390
391        // Entropy-based confidence
392        let entropy_confidence = self.compute_entropy(predictions)?;
393
394        // Temperature-scaled confidence (simplified)
395        let temperature_scaled_confidence = max_probability.clone();
396
397        // Margin-based confidence (difference between top two predictions)
398        let margin_confidence = Array1::zeros(n_samples); // Simplified
399
400        Ok(ConfidenceScores {
401            max_probability,
402            entropy_confidence,
403            temperature_scaled_confidence,
404            margin_confidence,
405        })
406    }
407
408    /// Compute OOD scores
409    fn compute_ood_scores(&self, predictions: &ArrayView2<F>) -> Result<OODScores<F>> {
410        let n_samples = predictions.nrows();
411
412        // Maximum softmax probability (MSP)
413        let msp_scores = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
414            row.fold(F::neg_infinity(), |acc, &x| if x > acc { x } else { acc })
415        });
416
417        // Simplified scores for other methods
418        let odin_scores = Array1::zeros(n_samples);
419        let mahalanobis_scores = Array1::zeros(n_samples);
420        let energy_scores = Array1::zeros(n_samples);
421
422        Ok(OODScores {
423            msp_scores,
424            odin_scores,
425            mahalanobis_scores,
426            energy_scores,
427        })
428    }
429
430    /// Compute entropy of predictions
431    fn compute_entropy(&self, predictions: &ArrayView2<F>) -> Result<Array1<F>> {
432        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");
433        let entropy = predictions.map_axis(scirs2_core::ndarray::Axis(1), |row| {
434            row.iter()
435                .map(|&p| {
436                    let p_safe = if p < epsilon { epsilon } else { p };
437                    -p_safe * p_safe.ln()
438                })
439                .fold(F::zero(), |acc, x| acc + x)
440        });
441
442        Ok(entropy)
443    }
444}
445
446impl<
447        F: Float
448            + scirs2_core::numeric::FromPrimitive
449            + std::iter::Sum
450            + scirs2_core::ndarray::ScalarOperand,
451    > Default for UncertaintyQuantifier<F>
452{
453    fn default() -> Self {
454        Self::new()
455    }
456}
457
458impl<F: Float> Default for CalibrationMetrics<F> {
459    fn default() -> Self {
460        Self {
461            expected_calibration_error: F::zero(),
462            maximum_calibration_error: F::zero(),
463            brier_decomposition: BrierDecomposition {
464                reliability: F::zero(),
465                resolution: F::zero(),
466                uncertainty: F::zero(),
467                brier_score: F::zero(),
468            },
469            reliability_curve: Array2::zeros((0, 0)),
470            sharpness: F::zero(),
471        }
472    }
473}