// sklears_model_selection/epistemic_uncertainty/aleatoric_quantifier.rs

use super::uncertainty_config::AleatoricUncertaintyConfig;
use super::uncertainty_methods::AleatoricUncertaintyMethod;
use super::uncertainty_results::AleatoricUncertaintyResult;
use super::uncertainty_types::*;
use super::variance_estimation::*;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use scirs2_core::random::Random;
use std::collections::HashMap;
10
/// Quantifies aleatoric (data-inherent, irreducible) uncertainty of model
/// predictions.
///
/// Configure via [`AleatoricUncertaintyConfig`] (estimation method,
/// confidence level, RNG seed, variance floor) using either
/// `with_config` or the fluent builder setters, then call `quantify`.
#[derive(Debug, Clone)]
pub struct AleatoricUncertaintyQuantifier {
    // Full quantification configuration; mutated in place by the builder setters.
    config: AleatoricUncertaintyConfig,
}
15
16impl AleatoricUncertaintyQuantifier {
17 pub fn new() -> Self {
18 Self {
19 config: AleatoricUncertaintyConfig::default(),
20 }
21 }
22
23 pub fn with_config(config: AleatoricUncertaintyConfig) -> Self {
24 Self { config }
25 }
26
27 pub fn method(mut self, method: AleatoricUncertaintyMethod) -> Self {
28 self.config.method = method;
29 self
30 }
31
32 pub fn confidence_level(mut self, level: f64) -> Self {
33 self.config.confidence_level = level;
34 self
35 }
36
37 pub fn random_state(mut self, seed: u64) -> Self {
38 self.config.random_state = Some(seed);
39 self
40 }
41
42 pub fn noise_regularization(mut self, reg: f64) -> Self {
43 self.config.noise_regularization = reg;
44 self
45 }
46
47 pub fn min_variance(mut self, min_var: f64) -> Self {
48 self.config.min_variance = min_var;
49 self
50 }
51
52 pub fn quantify<E, P>(
53 &self,
54 models: &[E],
55 x: &Array2<f64>,
56 y_true: Option<&Array1<f64>>,
57 ) -> Result<AleatoricUncertaintyResult, Box<dyn std::error::Error>>
58 where
59 E: Clone,
60 P: Clone,
61 {
62 let _rng = match self.config.random_state {
63 Some(seed) => Random::seed(seed),
64 None => Random::seed(42),
65 };
66
67 let (predictions, uncertainties, variance_estimates, noise_estimates) =
68 match &self.config.method {
69 AleatoricUncertaintyMethod::HeteroskedasticRegression { n_ensemble } => {
70 heteroskedastic_regression_uncertainty(models, x, *n_ensemble)?
71 }
72 AleatoricUncertaintyMethod::MixtureDensityNetwork { n_components } => {
73 mixture_density_network_uncertainty(models, x, *n_components)?
74 }
75 AleatoricUncertaintyMethod::QuantileRegression { quantiles } => {
76 quantile_regression_uncertainty(models, x, quantiles)?
77 }
78 AleatoricUncertaintyMethod::ParametricUncertainty { distribution } => {
79 parametric_uncertainty_estimation(models, x, distribution)?
80 }
81 AleatoricUncertaintyMethod::InputDependentNoise { noise_model } => {
82 input_dependent_noise_uncertainty(models, x, noise_model)?
83 }
84 AleatoricUncertaintyMethod::ResidualBasedUncertainty { window_size } => {
85 residual_based_uncertainty(models, x, y_true, *window_size)?
86 }
87 AleatoricUncertaintyMethod::EnsembleAleatoric {
88 n_models,
89 noise_estimation,
90 } => ensemble_aleatoric_uncertainty(models, x, *n_models, noise_estimation)?,
91 };
92
93 let alpha = 1.0 - self.config.confidence_level;
94 let lower_quantile = alpha / 2.0;
95 let upper_quantile = 1.0 - alpha / 2.0;
96
97 let prediction_intervals = self.compute_prediction_intervals(
98 &predictions,
99 &uncertainties,
100 lower_quantile,
101 upper_quantile,
102 )?;
103
104 let heteroskedastic_weights = self.compute_heteroskedastic_weights(&variance_estimates)?;
105 let distributional_parameters =
106 self.compute_distributional_parameters(&predictions, &variance_estimates)?;
107
108 let reliability_metrics =
109 self.compute_reliability_metrics(&predictions, &uncertainties, y_true)?;
110
111 Ok(AleatoricUncertaintyResult {
112 predictions,
113 uncertainties,
114 prediction_intervals,
115 noise_estimates,
116 variance_estimates,
117 heteroskedastic_weights,
118 distributional_parameters,
119 reliability_metrics,
120 })
121 }
122
123 fn compute_prediction_intervals(
124 &self,
125 predictions: &Array1<f64>,
126 uncertainties: &Array1<f64>,
127 lower_quantile: f64,
128 upper_quantile: f64,
129 ) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
130 let n = predictions.len();
131 let mut intervals = Array2::<f64>::zeros((n, 2));
132
133 for i in 0..n {
134 let std_dev = uncertainties[i].sqrt().max(self.config.min_variance.sqrt());
135 let z_lower = normal_quantile(lower_quantile);
136 let z_upper = normal_quantile(upper_quantile);
137
138 intervals[[i, 0]] = predictions[i] + z_lower * std_dev;
139 intervals[[i, 1]] = predictions[i] + z_upper * std_dev;
140 }
141
142 Ok(intervals)
143 }
144
145 fn compute_heteroskedastic_weights(
146 &self,
147 variance_estimates: &Array1<f64>,
148 ) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
149 let mean_variance = variance_estimates.mean().unwrap_or(1.0);
150 let weights =
151 variance_estimates.mapv(|var| if var > 0.0 { mean_variance / var } else { 1.0 });
152 Ok(weights)
153 }
154
155 fn compute_distributional_parameters(
156 &self,
157 predictions: &Array1<f64>,
158 variance_estimates: &Array1<f64>,
159 ) -> Result<HashMap<String, Array1<f64>>, Box<dyn std::error::Error>> {
160 let mut parameters = HashMap::new();
161
162 parameters.insert("mean".to_string(), predictions.clone());
163 parameters.insert("variance".to_string(), variance_estimates.clone());
164 parameters.insert("std_dev".to_string(), variance_estimates.mapv(|v| v.sqrt()));
165
166 let shape_params = variance_estimates.mapv(|v| {
167 let shape = predictions.mean().unwrap_or(1.0).powi(2) / v.max(self.config.min_variance);
168 shape.max(1e-6)
169 });
170 parameters.insert("shape".to_string(), shape_params);
171
172 let scale_params =
173 variance_estimates.mapv(|v| v / predictions.mean().unwrap_or(1.0).max(1e-6));
174 parameters.insert("scale".to_string(), scale_params);
175
176 Ok(parameters)
177 }
178
179 fn compute_reliability_metrics(
180 &self,
181 predictions: &Array1<f64>,
182 uncertainties: &Array1<f64>,
183 y_true: Option<&Array1<f64>>,
184 ) -> Result<ReliabilityMetrics, Box<dyn std::error::Error>> {
185 let calibration_error = match y_true {
186 Some(y) => self.compute_calibration_score(predictions, uncertainties, y)?,
187 None => 0.0,
188 };
189
190 let sharpness = uncertainties.mean().unwrap_or(0.0);
191 let reliability_score = 1.0 - calibration_error;
192 let coverage_probability = 0.95; let prediction_interval_score = 0.0; let continuous_ranked_probability_score = 0.0; Ok(ReliabilityMetrics {
197 calibration_error,
198 sharpness,
199 reliability_score,
200 coverage_probability,
201 prediction_interval_score,
202 continuous_ranked_probability_score,
203 })
204 }
205
206 fn compute_calibration_score(
207 &self,
208 predictions: &Array1<f64>,
209 uncertainties: &Array1<f64>,
210 y_true: &Array1<f64>,
211 ) -> Result<f64, Box<dyn std::error::Error>> {
212 let n_bins = 10;
213 let mut calibration_error = 0.0;
214
215 for bin_idx in 0..n_bins {
216 let lower_bound = bin_idx as f64 / n_bins as f64;
217 let upper_bound = (bin_idx + 1) as f64 / n_bins as f64;
218
219 let mut bin_predictions = Vec::new();
220 let mut bin_true_values = Vec::new();
221 let mut bin_uncertainties = Vec::new();
222
223 for i in 0..predictions.len() {
224 let normalized_uncertainty =
225 uncertainties[i] / uncertainties.iter().fold(0.0, |max, &x| max.max(x));
226 if normalized_uncertainty > lower_bound && normalized_uncertainty <= upper_bound {
227 bin_predictions.push(predictions[i]);
228 bin_true_values.push(y_true[i]);
229 bin_uncertainties.push(uncertainties[i]);
230 }
231 }
232
233 if !bin_predictions.is_empty() {
234 let bin_mse = bin_predictions
235 .iter()
236 .zip(bin_true_values.iter())
237 .map(|(&pred, &true_val)| (pred - true_val).powi(2))
238 .sum::<f64>()
239 / bin_predictions.len() as f64;
240
241 let expected_mse =
242 bin_uncertainties.iter().sum::<f64>() / bin_uncertainties.len() as f64;
243 calibration_error += (bin_mse - expected_mse).abs() * bin_predictions.len() as f64
244 / predictions.len() as f64;
245 }
246 }
247
248 Ok(calibration_error)
249 }
250
251 pub fn config(&self) -> &AleatoricUncertaintyConfig {
253 &self.config
254 }
255}
256
/// `Default` delegates to [`AleatoricUncertaintyQuantifier::new`], i.e. a
/// quantifier with the default `AleatoricUncertaintyConfig`.
impl Default for AleatoricUncertaintyQuantifier {
    fn default() -> Self {
        Self::new()
    }
}
262
/// Approximate quantile (inverse CDF) of the standard normal distribution.
///
/// Implements the Abramowitz & Stegun rational approximation 26.2.23
/// (absolute error below ~4.5e-4). Boundary inputs map to +/- infinity and
/// the median maps exactly to 0.0; the upper tail is obtained from the lower
/// tail by symmetry.
fn normal_quantile(p: f64) -> f64 {
    // Numerator and denominator coefficients of the rational correction term.
    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    if p <= 0.0 {
        f64::NEG_INFINITY
    } else if p >= 1.0 {
        f64::INFINITY
    } else if p == 0.5 {
        0.0
    } else {
        // Work with the smaller tail probability; mirror the sign afterwards.
        let tail = p.min(1.0 - p);
        let t = (-2.0 * tail.ln()).sqrt();
        let numerator = C[0] + C[1] * t + C[2] * t * t;
        let denominator = 1.0 + D[0] * t + D[1] * t * t + D[2] * t * t * t;
        let magnitude = t - numerator / denominator;
        if p < 0.5 {
            -magnitude
        } else {
            magnitude
        }
    }
}