// scirs2_integrate/analysis/ml_prediction/uncertainty.rs

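//! Uncertainty quantification and performance metrics.
//!
//! This module defines configuration types for estimating prediction
//! uncertainty (Bayesian neural networks, Monte Carlo dropout, ensembles, and
//! conformal prediction), together with containers for recording model
//! performance across training, validation, and test evaluation.
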
6use scirs2_core::ndarray::{Array1, Array2};
7
8/// Model performance metrics
9#[derive(Debug, Clone, Default)]
10pub struct PerformanceMetrics {
11    /// Training metrics
12    pub training_metrics: Vec<EpochMetrics>,
13    /// Validation metrics
14    pub validation_metrics: Vec<EpochMetrics>,
15    /// Test metrics
16    pub test_metrics: Option<TestMetrics>,
17    /// Confusion matrix (for classification)
18    pub confusion_matrix: Option<Array2<usize>>,
19    /// Feature importance scores
20    pub feature_importance: Option<Array1<f64>>,
21}
22
/// Metrics recorded for a single training epoch.
///
/// Classification-specific statistics are `Option`s and may be `None` when
/// they were not computed.
///
/// `PartialEq` is derived so records can be compared directly (e.g. in tests).
/// Because the comparison uses `f64` equality, a NaN in any float field makes
/// two records compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochMetrics {
    /// Epoch number.
    pub epoch: usize,
    /// Loss value for this epoch.
    pub loss: f64,
    /// Accuracy (classification only).
    pub accuracy: Option<f64>,
    /// Precision score per class (classification only).
    pub precision: Option<Vec<f64>>,
    /// Recall score per class (classification only).
    pub recall: Option<Vec<f64>>,
    /// F1 score per class (classification only).
    pub f1_score: Option<Vec<f64>>,
    /// Learning rate used during this epoch.
    pub learning_rate: f64,
}

/// Evaluation metrics computed on the test set.
///
/// The per-class vectors (`precision`, `recall`, `f1_score`) presumably share
/// one class ordering. TODO: confirm against the code that fills them in.
///
/// `PartialEq` is derived so results can be compared directly. Because the
/// comparison uses `f64` equality, a NaN in any field makes two records
/// compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct TestMetrics {
    /// Overall accuracy.
    pub accuracy: f64,
    /// Precision per class.
    pub precision: Vec<f64>,
    /// Recall per class.
    pub recall: Vec<f64>,
    /// F1 score per class.
    pub f1_score: Vec<f64>,
    /// Area under the ROC curve.
    pub auc_roc: f64,
    /// Area under the precision-recall curve.
    pub auc_pr: f64,
    /// Matthews correlation coefficient.
    pub mcc: f64,
}

61/// Uncertainty quantification for predictions
62#[derive(Debug, Clone, Default)]
63pub struct UncertaintyQuantification {
64    /// Bayesian neural network configuration
65    pub bayesian_config: Option<BayesianConfig>,
66    /// Monte Carlo dropout configuration
67    pub mc_dropout_config: Option<MCDropoutConfig>,
68    /// Ensemble configuration
69    pub ensemble_config: Option<EnsembleConfig>,
70    /// Conformal prediction configuration
71    pub conformal_config: Option<ConformalConfig>,
72}
73
74/// Bayesian neural network configuration
75#[derive(Debug, Clone)]
76pub struct BayesianConfig {
77    /// Prior distribution parameters
78    pub prior_params: PriorParams,
79    /// Variational inference method
80    pub variational_method: VariationalMethod,
81    /// Number of Monte Carlo samples
82    pub mc_samples: usize,
83    /// KL divergence weight
84    pub kl_weight: f64,
85}
86
/// Prior distribution parameters for a Bayesian network's weights and biases.
///
/// Each parameter group is described by a mean and a standard deviation.
/// `PartialEq` is derived so configurations can be compared directly. It uses
/// `f64` equality, so NaN values never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct PriorParams {
    /// Mean of the weight prior.
    pub weight_mean: f64,
    /// Standard deviation of the weight prior.
    pub weight_std: f64,
    /// Mean of the bias prior.
    pub bias_mean: f64,
    /// Standard deviation of the bias prior.
    pub bias_std: f64,
}

/// Variational inference method selected for a Bayesian network.
///
/// `PartialEq`, `Eq` and `Hash` are derived alongside `Copy` so the method can
/// be compared with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariationalMethod {
    /// Mean-field variational inference.
    MeanField,
    /// Matrix-variate Gaussian.
    MatrixVariate,
    /// Normalizing flows.
    NormalizingFlows,
}

/// Monte Carlo dropout settings.
///
/// `PartialEq` is derived so configurations can be compared directly. It uses
/// `f64` equality for `dropoutrate`.
#[derive(Debug, Clone, PartialEq)]
pub struct MCDropoutConfig {
    /// Dropout rate applied during inference. It is presumably a probability
    /// in `[0, 1)`, but this struct does not validate it.
    ///
    /// NOTE: the field is spelled `dropoutrate` (no underscore). Renaming it
    /// would break existing callers, so it is kept as is.
    pub dropoutrate: f64,
    /// Number of forward passes.
    pub num_samples: usize,
    /// Whether different dropout masks are used.
    pub stochastic_masks: bool,
}

122/// Ensemble configuration
123#[derive(Debug, Clone)]
124pub struct EnsembleConfig {
125    /// Number of models in ensemble
126    pub num_models: usize,
127    /// Ensemble aggregation method
128    pub aggregation_method: EnsembleAggregation,
129    /// Diversity encouragement method
130    pub diversity_method: DiversityMethod,
131}
132
/// Method used to combine the predictions of ensemble members.
///
/// `PartialEq`, `Eq` and `Hash` are derived alongside `Copy` so the method can
/// be compared with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleAggregation {
    /// Simple averaging.
    Average,
    /// Weighted averaging.
    WeightedAverage,
    /// Voting (for classification).
    Voting,
    /// Stacking with a meta-learner.
    Stacking,
}

/// Method used to encourage diversity among ensemble members.
///
/// `PartialEq`, `Eq` and `Hash` are derived alongside `Copy` so the method can
/// be compared with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiversityMethod {
    /// Bootstrap aggregating.
    Bagging,
    /// Different random initializations.
    RandomInit,
    /// Different architectures.
    DifferentArchitectures,
    /// Adversarial training.
    AdversarialTraining,
}

159/// Conformal prediction configuration
160#[derive(Debug, Clone)]
161pub struct ConformalConfig {
162    /// Confidence level (e.g., 0.95 for 95% confidence)
163    pub confidence_level: f64,
164    /// Conformity score function
165    pub score_function: ConformityScore,
166    /// Calibration set size
167    pub calibration_size: usize,
168}
169
/// Conformity score function used by conformal prediction.
///
/// `PartialEq`, `Eq` and `Hash` are derived alongside `Copy` so the score
/// function can be compared with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConformityScore {
    /// Absolute residuals (for regression).
    AbsoluteResiduals,
    /// Normalized residuals.
    NormalizedResiduals,
    /// Softmax scores (for classification).
    SoftmaxScores,
    /// Margin scores.
    MarginScores,
}