scirs2_integrate/analysis/ml_prediction/uncertainty.rs
1//! Uncertainty Quantification and Performance Metrics
2//!
3//! This module contains structures for uncertainty quantification, performance tracking,
4//! and various methods for estimating prediction uncertainty.
5
6use scirs2_core::ndarray::{Array1, Array2};
7
8/// Model performance metrics
9#[derive(Debug, Clone, Default)]
10pub struct PerformanceMetrics {
11 /// Training metrics
12 pub training_metrics: Vec<EpochMetrics>,
13 /// Validation metrics
14 pub validation_metrics: Vec<EpochMetrics>,
15 /// Test metrics
16 pub test_metrics: Option<TestMetrics>,
17 /// Confusion matrix (for classification)
18 pub confusion_matrix: Option<Array2<usize>>,
19 /// Feature importance scores
20 pub feature_importance: Option<Array1<f64>>,
21}
22
/// Metrics recorded for a single training epoch.
///
/// Derives `PartialEq` so records can be compared directly (for example in
/// tests or when detecting duplicate entries). Comparison follows `f64`
/// semantics: a record containing `NaN` never compares equal, not even to
/// itself.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochMetrics {
    /// Epoch number.
    pub epoch: usize,
    /// Loss value for this epoch.
    pub loss: f64,
    /// Accuracy (classification only); `None` when not set.
    pub accuracy: Option<f64>,
    /// Precision scores, one entry per class; `None` when not set.
    pub precision: Option<Vec<f64>>,
    /// Recall scores, one entry per class; `None` when not set.
    pub recall: Option<Vec<f64>>,
    /// F1 scores, one entry per class; `None` when not set.
    pub f1_score: Option<Vec<f64>>,
    /// Learning rate used during this epoch.
    pub learning_rate: f64,
}
41
/// Evaluation metrics computed on the test set.
///
/// Derives `PartialEq` so two evaluation results can be compared directly.
/// Comparison follows `f64` semantics: any `NaN` field makes the values
/// unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct TestMetrics {
    /// Overall accuracy.
    pub accuracy: f64,
    /// Precision, one entry per class.
    pub precision: Vec<f64>,
    /// Recall, one entry per class.
    pub recall: Vec<f64>,
    /// F1 score, one entry per class.
    pub f1_score: Vec<f64>,
    /// Area under the ROC curve.
    pub auc_roc: f64,
    /// Area under the precision-recall curve.
    pub auc_pr: f64,
    /// Matthews correlation coefficient.
    pub mcc: f64,
}
60
61/// Uncertainty quantification for predictions
62#[derive(Debug, Clone, Default)]
63pub struct UncertaintyQuantification {
64 /// Bayesian neural network configuration
65 pub bayesian_config: Option<BayesianConfig>,
66 /// Monte Carlo dropout configuration
67 pub mc_dropout_config: Option<MCDropoutConfig>,
68 /// Ensemble configuration
69 pub ensemble_config: Option<EnsembleConfig>,
70 /// Conformal prediction configuration
71 pub conformal_config: Option<ConformalConfig>,
72}
73
74/// Bayesian neural network configuration
75#[derive(Debug, Clone)]
76pub struct BayesianConfig {
77 /// Prior distribution parameters
78 pub prior_params: PriorParams,
79 /// Variational inference method
80 pub variational_method: VariationalMethod,
81 /// Number of Monte Carlo samples
82 pub mc_samples: usize,
83 /// KL divergence weight
84 pub kl_weight: f64,
85}
86
/// Parameters of the prior distribution over weights and biases.
///
/// Derives `PartialEq` so prior settings can be compared directly.
/// Comparison follows `f64` semantics: any `NaN` field makes the values
/// unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct PriorParams {
    /// Mean of the weight prior.
    pub weight_mean: f64,
    /// Standard deviation of the weight prior.
    pub weight_std: f64,
    /// Mean of the bias prior.
    pub bias_mean: f64,
    /// Standard deviation of the bias prior.
    pub bias_std: f64,
}
99
/// Variational inference methods.
///
/// Derives `PartialEq`, `Eq` and `Hash` so a method can be compared with
/// `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariationalMethod {
    /// Mean-field variational inference.
    MeanField,
    /// Matrix-variate Gaussian.
    MatrixVariate,
    /// Normalizing flows.
    NormalizingFlows,
}
110
/// Settings for Monte Carlo dropout.
///
/// Derives `PartialEq` so configurations can be compared directly.
/// Comparison follows `f64` semantics: a `NaN` dropout rate makes the values
/// unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct MCDropoutConfig {
    /// Dropout rate applied during inference.
    ///
    /// The field name is kept as `dropoutrate` (no underscore) so existing
    /// callers continue to compile.
    pub dropoutrate: f64,
    /// Number of forward passes.
    pub num_samples: usize,
    /// Whether to use different dropout masks.
    pub stochastic_masks: bool,
}
121
122/// Ensemble configuration
123#[derive(Debug, Clone)]
124pub struct EnsembleConfig {
125 /// Number of models in ensemble
126 pub num_models: usize,
127 /// Ensemble aggregation method
128 pub aggregation_method: EnsembleAggregation,
129 /// Diversity encouragement method
130 pub diversity_method: DiversityMethod,
131}
132
/// Ensemble aggregation methods.
///
/// Derives `PartialEq`, `Eq` and `Hash` so a method can be compared with
/// `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleAggregation {
    /// Simple averaging.
    Average,
    /// Weighted averaging.
    WeightedAverage,
    /// Voting (for classification).
    Voting,
    /// Stacking with a meta-learner.
    Stacking,
}
145
/// Methods to encourage diversity among ensemble members.
///
/// Derives `PartialEq`, `Eq` and `Hash` so a method can be compared with
/// `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiversityMethod {
    /// Bootstrap aggregating.
    Bagging,
    /// Different random initializations.
    RandomInit,
    /// Different architectures.
    DifferentArchitectures,
    /// Adversarial training.
    AdversarialTraining,
}
158
159/// Conformal prediction configuration
160#[derive(Debug, Clone)]
161pub struct ConformalConfig {
162 /// Confidence level (e.g., 0.95 for 95% confidence)
163 pub confidence_level: f64,
164 /// Conformity score function
165 pub score_function: ConformityScore,
166 /// Calibration set size
167 pub calibration_size: usize,
168}
169
/// Conformity score functions.
///
/// Derives `PartialEq`, `Eq` and `Hash` so a score function can be compared
/// with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConformityScore {
    /// Absolute residuals (for regression).
    AbsoluteResiduals,
    /// Normalized residuals.
    NormalizedResiduals,
    /// Softmax scores (for classification).
    SoftmaxScores,
    /// Margin scores.
    MarginScores,
}