sklears_model_selection/epistemic_uncertainty/
aleatoric_quantifier.rs

1use super::uncertainty_config::AleatoricUncertaintyConfig;
2use super::uncertainty_methods::AleatoricUncertaintyMethod;
3use super::uncertainty_results::AleatoricUncertaintyResult;
4use super::uncertainty_types::*;
5use super::variance_estimation::*;
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::Float;
8use scirs2_core::random::Random;
9use std::collections::HashMap;
10
/// Quantifies aleatoric uncertainty (the noise inherent in the data itself,
/// as opposed to epistemic/model uncertainty) for a collection of fitted
/// models, using the strategy selected in its configuration.
#[derive(Debug, Clone)]
pub struct AleatoricUncertaintyQuantifier {
    // Holds method selection, confidence level, RNG seed, and variance floors.
    config: AleatoricUncertaintyConfig,
}
15
16impl AleatoricUncertaintyQuantifier {
17    pub fn new() -> Self {
18        Self {
19            config: AleatoricUncertaintyConfig::default(),
20        }
21    }
22
23    pub fn with_config(config: AleatoricUncertaintyConfig) -> Self {
24        Self { config }
25    }
26
27    pub fn method(mut self, method: AleatoricUncertaintyMethod) -> Self {
28        self.config.method = method;
29        self
30    }
31
32    pub fn confidence_level(mut self, level: f64) -> Self {
33        self.config.confidence_level = level;
34        self
35    }
36
37    pub fn random_state(mut self, seed: u64) -> Self {
38        self.config.random_state = Some(seed);
39        self
40    }
41
42    pub fn noise_regularization(mut self, reg: f64) -> Self {
43        self.config.noise_regularization = reg;
44        self
45    }
46
47    pub fn min_variance(mut self, min_var: f64) -> Self {
48        self.config.min_variance = min_var;
49        self
50    }
51
52    pub fn quantify<E, P>(
53        &self,
54        models: &[E],
55        x: &Array2<f64>,
56        y_true: Option<&Array1<f64>>,
57    ) -> Result<AleatoricUncertaintyResult, Box<dyn std::error::Error>>
58    where
59        E: Clone,
60        P: Clone,
61    {
62        let _rng = match self.config.random_state {
63            Some(seed) => Random::seed(seed),
64            None => Random::seed(42),
65        };
66
67        let (predictions, uncertainties, variance_estimates, noise_estimates) =
68            match &self.config.method {
69                AleatoricUncertaintyMethod::HeteroskedasticRegression { n_ensemble } => {
70                    heteroskedastic_regression_uncertainty(models, x, *n_ensemble)?
71                }
72                AleatoricUncertaintyMethod::MixtureDensityNetwork { n_components } => {
73                    mixture_density_network_uncertainty(models, x, *n_components)?
74                }
75                AleatoricUncertaintyMethod::QuantileRegression { quantiles } => {
76                    quantile_regression_uncertainty(models, x, quantiles)?
77                }
78                AleatoricUncertaintyMethod::ParametricUncertainty { distribution } => {
79                    parametric_uncertainty_estimation(models, x, distribution)?
80                }
81                AleatoricUncertaintyMethod::InputDependentNoise { noise_model } => {
82                    input_dependent_noise_uncertainty(models, x, noise_model)?
83                }
84                AleatoricUncertaintyMethod::ResidualBasedUncertainty { window_size } => {
85                    residual_based_uncertainty(models, x, y_true, *window_size)?
86                }
87                AleatoricUncertaintyMethod::EnsembleAleatoric {
88                    n_models,
89                    noise_estimation,
90                } => ensemble_aleatoric_uncertainty(models, x, *n_models, noise_estimation)?,
91            };
92
93        let alpha = 1.0 - self.config.confidence_level;
94        let lower_quantile = alpha / 2.0;
95        let upper_quantile = 1.0 - alpha / 2.0;
96
97        let prediction_intervals = self.compute_prediction_intervals(
98            &predictions,
99            &uncertainties,
100            lower_quantile,
101            upper_quantile,
102        )?;
103
104        let heteroskedastic_weights = self.compute_heteroskedastic_weights(&variance_estimates)?;
105        let distributional_parameters =
106            self.compute_distributional_parameters(&predictions, &variance_estimates)?;
107
108        let reliability_metrics =
109            self.compute_reliability_metrics(&predictions, &uncertainties, y_true)?;
110
111        Ok(AleatoricUncertaintyResult {
112            predictions,
113            uncertainties,
114            prediction_intervals,
115            noise_estimates,
116            variance_estimates,
117            heteroskedastic_weights,
118            distributional_parameters,
119            reliability_metrics,
120        })
121    }
122
123    fn compute_prediction_intervals(
124        &self,
125        predictions: &Array1<f64>,
126        uncertainties: &Array1<f64>,
127        lower_quantile: f64,
128        upper_quantile: f64,
129    ) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
130        let n = predictions.len();
131        let mut intervals = Array2::<f64>::zeros((n, 2));
132
133        for i in 0..n {
134            let std_dev = uncertainties[i].sqrt().max(self.config.min_variance.sqrt());
135            let z_lower = normal_quantile(lower_quantile);
136            let z_upper = normal_quantile(upper_quantile);
137
138            intervals[[i, 0]] = predictions[i] + z_lower * std_dev;
139            intervals[[i, 1]] = predictions[i] + z_upper * std_dev;
140        }
141
142        Ok(intervals)
143    }
144
145    fn compute_heteroskedastic_weights(
146        &self,
147        variance_estimates: &Array1<f64>,
148    ) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
149        let mean_variance = variance_estimates.mean().unwrap_or(1.0);
150        let weights =
151            variance_estimates.mapv(|var| if var > 0.0 { mean_variance / var } else { 1.0 });
152        Ok(weights)
153    }
154
155    fn compute_distributional_parameters(
156        &self,
157        predictions: &Array1<f64>,
158        variance_estimates: &Array1<f64>,
159    ) -> Result<HashMap<String, Array1<f64>>, Box<dyn std::error::Error>> {
160        let mut parameters = HashMap::new();
161
162        parameters.insert("mean".to_string(), predictions.clone());
163        parameters.insert("variance".to_string(), variance_estimates.clone());
164        parameters.insert("std_dev".to_string(), variance_estimates.mapv(|v| v.sqrt()));
165
166        let shape_params = variance_estimates.mapv(|v| {
167            let shape = predictions.mean().unwrap_or(1.0).powi(2) / v.max(self.config.min_variance);
168            shape.max(1e-6)
169        });
170        parameters.insert("shape".to_string(), shape_params);
171
172        let scale_params =
173            variance_estimates.mapv(|v| v / predictions.mean().unwrap_or(1.0).max(1e-6));
174        parameters.insert("scale".to_string(), scale_params);
175
176        Ok(parameters)
177    }
178
179    fn compute_reliability_metrics(
180        &self,
181        predictions: &Array1<f64>,
182        uncertainties: &Array1<f64>,
183        y_true: Option<&Array1<f64>>,
184    ) -> Result<ReliabilityMetrics, Box<dyn std::error::Error>> {
185        let calibration_error = match y_true {
186            Some(y) => self.compute_calibration_score(predictions, uncertainties, y)?,
187            None => 0.0,
188        };
189
190        let sharpness = uncertainties.mean().unwrap_or(0.0);
191        let reliability_score = 1.0 - calibration_error;
192        let coverage_probability = 0.95; // Placeholder
193        let prediction_interval_score = 0.0; // Placeholder
194        let continuous_ranked_probability_score = 0.0; // Placeholder
195
196        Ok(ReliabilityMetrics {
197            calibration_error,
198            sharpness,
199            reliability_score,
200            coverage_probability,
201            prediction_interval_score,
202            continuous_ranked_probability_score,
203        })
204    }
205
206    fn compute_calibration_score(
207        &self,
208        predictions: &Array1<f64>,
209        uncertainties: &Array1<f64>,
210        y_true: &Array1<f64>,
211    ) -> Result<f64, Box<dyn std::error::Error>> {
212        let n_bins = 10;
213        let mut calibration_error = 0.0;
214
215        for bin_idx in 0..n_bins {
216            let lower_bound = bin_idx as f64 / n_bins as f64;
217            let upper_bound = (bin_idx + 1) as f64 / n_bins as f64;
218
219            let mut bin_predictions = Vec::new();
220            let mut bin_true_values = Vec::new();
221            let mut bin_uncertainties = Vec::new();
222
223            for i in 0..predictions.len() {
224                let normalized_uncertainty =
225                    uncertainties[i] / uncertainties.iter().fold(0.0, |max, &x| max.max(x));
226                if normalized_uncertainty > lower_bound && normalized_uncertainty <= upper_bound {
227                    bin_predictions.push(predictions[i]);
228                    bin_true_values.push(y_true[i]);
229                    bin_uncertainties.push(uncertainties[i]);
230                }
231            }
232
233            if !bin_predictions.is_empty() {
234                let bin_mse = bin_predictions
235                    .iter()
236                    .zip(bin_true_values.iter())
237                    .map(|(&pred, &true_val)| (pred - true_val).powi(2))
238                    .sum::<f64>()
239                    / bin_predictions.len() as f64;
240
241                let expected_mse =
242                    bin_uncertainties.iter().sum::<f64>() / bin_uncertainties.len() as f64;
243                calibration_error += (bin_mse - expected_mse).abs() * bin_predictions.len() as f64
244                    / predictions.len() as f64;
245            }
246        }
247
248        Ok(calibration_error)
249    }
250
251    // Getter methods for testing
252    pub fn config(&self) -> &AleatoricUncertaintyConfig {
253        &self.config
254    }
255}
256
257impl Default for AleatoricUncertaintyQuantifier {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
/// Inverse CDF (quantile function) of the standard normal distribution.
///
/// Uses the Abramowitz & Stegun 26.2.23 rational approximation (absolute
/// error < 4.5e-4). Out-of-range probabilities map to +/- infinity and the
/// median maps exactly to zero.
fn normal_quantile(p: f64) -> f64 {
    // Degenerate and exactly-symmetric cases first.
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // Rational-approximation coefficients (A&S 26.2.23).
    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    // Work with the tail probability so both halves share one formula,
    // then restore the sign: lower tail => negative quantile.
    let tail = if p < 0.5 { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let numerator = C0 + C1 * t + C2 * t * t;
    let denominator = 1.0 + D1 * t + D2 * t * t + D3 * t * t * t;
    let magnitude = t - numerator / denominator;

    if p < 0.5 {
        -magnitude
    } else {
        magnitude
    }
}