use scirs2_core::ndarray::Array1;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use sklears_core::error::{Result, SklearsError};

/// Result of a Bayesian model comparison.
#[derive(Debug, Clone)]
pub struct BayesianModelSelectionResult {
    /// Names of the compared models.
    pub model_names: Vec<String>,
    /// Log marginal likelihood (evidence) estimate for each model.
    pub log_evidence: Vec<f64>,
    /// Posterior probability of each model.
    pub model_probabilities: Vec<f64>,
    /// Bayes factor of each model relative to the best model.
    pub bayes_factors: Vec<f64>,
    /// Index of the model with the highest log evidence.
    pub best_model_index: usize,
    /// Evidence estimation method that produced these results.
    pub method: EvidenceEstimationMethod,
}

impl BayesianModelSelectionResult {
    /// Returns the name of the best model (highest log evidence).
    pub fn best_model(&self) -> &str {
        &self.model_names[self.best_model_index]
    }

    /// Returns the models ranked by log evidence, best first.
    pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
        let mut ranking: Vec<(usize, &str, f64)> = self
            .model_names
            .iter()
            .enumerate()
            .map(|(i, name)| (i, name.as_str(), self.log_evidence[i]))
            .collect();

        ranking.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
        ranking
    }

    /// Returns a qualitative interpretation of the evidence for `model1_idx`
    /// over `model2_idx`, based on the natural-log Bayes factor (a banding in
    /// the spirit of the Jeffreys / Kass-Raftery scales). Positive values
    /// favor the first model.
    pub fn evidence_interpretation(&self, model1_idx: usize, model2_idx: usize) -> String {
        let log_bf = self.log_evidence[model1_idx] - self.log_evidence[model2_idx];

        match log_bf {
            x if x < 1.0 => "Weak evidence".to_string(),
            x if x < 2.5 => "Positive evidence".to_string(),
            x if x <= 5.0 => "Strong evidence".to_string(),
            _ => "Very strong evidence".to_string(),
        }
    }
}

/// Method used to estimate the log marginal likelihood (model evidence).
#[derive(Debug, Clone)]
pub enum EvidenceEstimationMethod {
    /// Laplace (Gaussian) approximation around the MAP estimate.
    LaplaceApproximation,
    /// Bayesian Information Criterion approximation.
    BIC,
    /// Small-sample corrected Akaike Information Criterion approximation.
    AICc,
    /// Harmonic mean of posterior likelihoods (known to be unstable).
    HarmonicMean,
    /// Thermodynamic integration over a ladder of temperatures.
    ThermodynamicIntegration { n_temperatures: usize },
    /// Simplified nested sampling over posterior samples.
    NestedSampling { n_live_points: usize },
    /// Evidence approximation from cross-validation log-likelihoods.
    CrossValidationEvidence { n_folds: usize },
}

/// Compares candidate models by estimating their marginal likelihoods.
pub struct BayesianModelSelector {
    /// Evidence estimation method to use.
    method: EvidenceEstimationMethod,
    /// Optional prior probabilities over models (uniform if `None`).
    prior_probabilities: Option<Vec<f64>>,
    /// Random number generator reserved for stochastic estimators.
    rng: StdRng,
}

impl BayesianModelSelector {
    /// Creates a selector with the given evidence estimation method and an
    /// optional seed for reproducibility.
    pub fn new(method: EvidenceEstimationMethod, random_state: Option<u64>) -> Self {
        let rng = match random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
        };

        Self {
            method,
            prior_probabilities: None,
            rng,
        }
    }

    /// Sets prior model probabilities; they must sum to 1 (within 1e-6).
    pub fn with_prior_probabilities(mut self, priors: Vec<f64>) -> Result<Self> {
        let sum: f64 = priors.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(SklearsError::InvalidInput(
                "Prior probabilities must sum to 1".to_string(),
            ));
        }
        self.prior_probabilities = Some(priors);
        Ok(self)
    }

    /// Compares the given models, returning log evidence, posterior model
    /// probabilities, and Bayes factors relative to the best model.
    pub fn compare_models(
        &mut self,
        model_results: &[(String, ModelEvidenceData)],
    ) -> Result<BayesianModelSelectionResult> {
        if model_results.is_empty() {
            return Err(SklearsError::InvalidInput("No models provided".to_string()));
        }

        let model_names: Vec<String> = model_results.iter().map(|(name, _)| name.clone()).collect();

        // Estimate the log evidence of each model with the configured method.
        let mut log_evidence = Vec::new();
        for (_, data) in model_results {
            let log_ev = self.estimate_evidence(data)?;
            log_evidence.push(log_ev);
        }

        let model_probabilities = self.calculate_model_probabilities(&log_evidence)?;

        // Bayes factors relative to the best model (the best model gets 1.0).
        let best_log_evidence = log_evidence
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let bayes_factors: Vec<f64> = log_evidence
            .iter()
            .map(|&log_ev| (log_ev - best_log_evidence).exp())
            .collect();

        let best_model_index = log_evidence
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();

        Ok(BayesianModelSelectionResult {
            model_names,
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index,
            method: self.method.clone(),
        })
    }

    /// Dispatches to the configured evidence estimator.
    fn estimate_evidence(&mut self, data: &ModelEvidenceData) -> Result<f64> {
        match &self.method {
            EvidenceEstimationMethod::LaplaceApproximation => self.laplace_approximation(data),
            EvidenceEstimationMethod::BIC => self.bic_approximation(data),
            EvidenceEstimationMethod::AICc => self.aicc_approximation(data),
            EvidenceEstimationMethod::HarmonicMean => self.harmonic_mean_estimator(data),
            EvidenceEstimationMethod::ThermodynamicIntegration { n_temperatures } => {
                self.thermodynamic_integration(data, *n_temperatures)
            }
            EvidenceEstimationMethod::NestedSampling { n_live_points } => {
                self.nested_sampling_approximation(data, *n_live_points)
            }
            EvidenceEstimationMethod::CrossValidationEvidence { n_folds } => {
                self.cross_validation_evidence(data, *n_folds)
            }
        }
    }

    /// Laplace (Gaussian) approximation to the log evidence around the MAP
    /// estimate: log Z ≈ log L(θ̂) + log p(θ̂) + (k/2) ln(2π) - (1/2) ln|H|.
    fn laplace_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        let log_likelihood_map = data.max_log_likelihood;

        // Without a supplied Hessian log-determinant, fall back to a
        // BIC-style surrogate that scales with the parameter count.
        let log_det_hessian = data.hessian_log_determinant.unwrap_or_else(|| {
            n_params * (2.0 * std::f64::consts::PI).ln() + n_params * n_data.ln()
        });

        let log_prior = data.log_prior.unwrap_or(0.0);

        let log_evidence =
            log_likelihood_map + log_prior + (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln()
                - 0.5 * log_det_hessian;

        Ok(log_evidence)
    }

    /// BIC-based approximation to the log evidence:
    /// log Z ≈ log L(θ̂) - (k/2) ln n.
    fn bic_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        let log_evidence = data.max_log_likelihood - (n_params / 2.0) * n_data.ln();

        Ok(log_evidence)
    }

    /// AICc-based approximation: log Z ≈ -AICc / 2, where
    /// AICc = -2 log L(θ̂) + 2k + 2k(k + 1)/(n - k - 1).
    fn aicc_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let k = data.n_parameters as f64;
        let n = data.n_data_points as f64;

        if n <= k + 1.0 {
            return Err(SklearsError::InvalidInput(
                "AICc requires n > k + 1".to_string(),
            ));
        }

        let aicc_correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
        let log_evidence = data.max_log_likelihood - k - aicc_correction / 2.0;

        Ok(log_evidence)
    }

    /// Harmonic mean estimator: 1/Z ≈ (1/N) Σ exp(-log L_i). Evaluated in
    /// log space to avoid overflow when log-likelihoods are large in
    /// magnitude. Known to be unreliable (its variance can be infinite).
    fn harmonic_mean_estimator(&self, data: &ModelEvidenceData) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Harmonic mean estimator requires posterior samples".to_string(),
            ));
        }

        let n_samples = data.posterior_samples.len() as f64;
        let neg_log_likelihoods: Vec<f64> = data
            .posterior_samples
            .iter()
            .map(|&log_likelihood| -log_likelihood)
            .collect();

        // log Z = log N - log Σ exp(-log L_i), computed stably.
        let log_evidence = n_samples.ln() - log_sum_exp_vec(&neg_log_likelihoods);

        eprintln!("Warning: Harmonic mean estimator is known to be unreliable");

        Ok(log_evidence)
    }

    /// Simplified thermodynamic integration. Tempered chains are not
    /// available here, so E_t[log L] is approximated linearly as
    /// t * mean(log L); log Z = ∫₀¹ E_t[log L] dt is then computed with the
    /// trapezoidal rule over the temperature ladder.
    fn thermodynamic_integration(
        &mut self,
        data: &ModelEvidenceData,
        n_temperatures: usize,
    ) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Thermodynamic integration requires posterior samples".to_string(),
            ));
        }

        // Temperature ladder from 0 (prior) to 1 (posterior).
        let temperatures: Vec<f64> = (0..=n_temperatures)
            .map(|i| i as f64 / n_temperatures as f64)
            .collect();

        let mean_ll =
            data.posterior_samples.iter().sum::<f64>() / data.posterior_samples.len() as f64;
        let mean_log_likelihoods: Vec<f64> =
            temperatures.iter().map(|&temp| temp * mean_ll).collect();

        // Trapezoidal integration over the ladder.
        let mut integral = 0.0;
        for i in 1..temperatures.len() {
            let dt = temperatures[i] - temperatures[i - 1];
            integral += 0.5 * dt * (mean_log_likelihoods[i] + mean_log_likelihoods[i - 1]);
        }

        let log_prior = data.log_prior.unwrap_or(0.0);
        let log_evidence = log_prior + integral;

        Ok(log_evidence)
    }

    /// Simplified nested-sampling estimate. Posterior samples stand in for
    /// dead points, consumed in ascending likelihood order while the prior
    /// volume shrinks geometrically, X_i ≈ exp(-i / n_live_points).
    fn nested_sampling_approximation(
        &mut self,
        data: &ModelEvidenceData,
        n_live_points: usize,
    ) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Nested sampling requires posterior samples".to_string(),
            ));
        }

        let n_samples = data.posterior_samples.len();
        let max_iterations = n_samples.min(1000);

        // Dead points are consumed lowest-likelihood first.
        let mut sorted_samples = data.posterior_samples.clone();
        sorted_samples.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mut log_evidence = f64::NEG_INFINITY;
        // Initial shell width: X_0 - X_1 = 1 - exp(-1/n_live_points).
        let log_shrinkage = 1.0 / n_live_points as f64;
        let mut log_width = (1.0 - (-log_shrinkage).exp()).ln();

        for (i, &log_likelihood) in sorted_samples.iter().enumerate() {
            if i >= max_iterations {
                break;
            }

            // Accumulate Z += w_i * L_i in log space.
            let log_weight = log_width + log_likelihood;
            log_evidence = log_sum_exp(log_evidence, log_weight);

            // Each iteration shrinks the remaining prior volume by exp(-1/n).
            log_width -= log_shrinkage;
        }

        Ok(log_evidence)
    }

    /// Cross-validation surrogate for the log evidence: the mean held-out
    /// log-likelihood plus a small-sample correction of ln(n / (n - 1)).
    fn cross_validation_evidence(&self, data: &ModelEvidenceData, n_folds: usize) -> Result<f64> {
        if data.cv_log_likelihoods.is_none() {
            return Err(SklearsError::InvalidInput(
                "Cross-validation evidence requires CV log-likelihoods".to_string(),
            ));
        }

        let cv_log_likes = data.cv_log_likelihoods.as_ref().unwrap();
        if cv_log_likes.len() != n_folds {
            return Err(SklearsError::InvalidInput(
                "Number of CV scores must match number of folds".to_string(),
            ));
        }

        let mean_cv_log_likelihood = cv_log_likes.iter().sum::<f64>() / cv_log_likes.len() as f64;

        let n_data = data.n_data_points as f64;
        let correction = (n_data / (n_data - 1.0)).ln();

        let log_evidence = mean_cv_log_likelihood + correction;

        Ok(log_evidence)
    }

    /// Posterior model probabilities: a softmax of log evidence plus log
    /// prior, normalized in log space for numerical stability.
    fn calculate_model_probabilities(&self, log_evidence: &[f64]) -> Result<Vec<f64>> {
        let n_models = log_evidence.len();

        let log_priors = if let Some(ref priors) = self.prior_probabilities {
            if priors.len() != n_models {
                return Err(SklearsError::InvalidInput(
                    "Number of prior probabilities must match number of models".to_string(),
                ));
            }
            priors.iter().map(|&p| p.ln()).collect()
        } else {
            // Uniform prior over models.
            vec![-(n_models as f64).ln(); n_models]
        };

        let log_posteriors: Vec<f64> = log_evidence
            .iter()
            .zip(log_priors.iter())
            .map(|(&log_ev, &log_prior)| log_ev + log_prior)
            .collect();

        let log_normalizer = log_sum_exp_vec(&log_posteriors);
        let probabilities: Vec<f64> = log_posteriors
            .iter()
            .map(|&log_p| (log_p - log_normalizer).exp())
            .collect();

        Ok(probabilities)
    }
}

/// Inputs required by the evidence estimators for a single model.
#[derive(Debug, Clone)]
pub struct ModelEvidenceData {
    /// Maximum log-likelihood of the fitted model.
    pub max_log_likelihood: f64,
    /// Number of free parameters.
    pub n_parameters: usize,
    /// Number of data points used for fitting.
    pub n_data_points: usize,
    /// Log-determinant of the Hessian at the MAP estimate, if available.
    pub hessian_log_determinant: Option<f64>,
    /// Log prior density at the MAP estimate, if available.
    pub log_prior: Option<f64>,
    /// Posterior log-likelihood samples for sampling-based estimators.
    pub posterior_samples: Vec<f64>,
    /// Held-out log-likelihoods from cross-validation, if available.
    pub cv_log_likelihoods: Option<Vec<f64>>,
}

impl ModelEvidenceData {
    /// Creates evidence data from the quantities every estimator needs.
    pub fn new(max_log_likelihood: f64, n_parameters: usize, n_data_points: usize) -> Self {
        Self {
            max_log_likelihood,
            n_parameters,
            n_data_points,
            hessian_log_determinant: None,
            log_prior: None,
            posterior_samples: Vec::new(),
            cv_log_likelihoods: None,
        }
    }

    /// Sets the Hessian log-determinant (used by the Laplace approximation).
    pub fn with_hessian_log_determinant(mut self, log_det: f64) -> Self {
        self.hessian_log_determinant = Some(log_det);
        self
    }

    /// Sets the log prior density at the MAP estimate.
    pub fn with_log_prior(mut self, log_prior: f64) -> Self {
        self.log_prior = Some(log_prior);
        self
    }

    /// Sets posterior log-likelihood samples for sampling-based estimators.
    pub fn with_posterior_samples(mut self, samples: Vec<f64>) -> Self {
        self.posterior_samples = samples;
        self
    }

    /// Sets held-out log-likelihoods for cross-validation evidence.
    pub fn with_cv_log_likelihoods(mut self, cv_scores: Vec<f64>) -> Self {
        self.cv_log_likelihoods = Some(cv_scores);
        self
    }
}

/// Combines per-model predictions, weighting each model by its posterior
/// probability (Bayesian model averaging).
pub struct BayesianModelAverager {
    selection_result: BayesianModelSelectionResult,
}

impl BayesianModelAverager {
    /// Creates an averager from a model selection result.
    pub fn new(selection_result: BayesianModelSelectionResult) -> Self {
        Self { selection_result }
    }

    /// Returns the probability-weighted average of the model predictions.
    pub fn predict(&self, model_predictions: &[Array1<f64>]) -> Result<Array1<f64>> {
        if model_predictions.len() != self.selection_result.model_names.len() {
            return Err(SklearsError::InvalidInput(
                "Number of predictions must match number of models".to_string(),
            ));
        }

        if model_predictions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No predictions provided".to_string(),
            ));
        }

        let n_samples = model_predictions[0].len();

        for pred in model_predictions {
            if pred.len() != n_samples {
                return Err(SklearsError::InvalidInput(
                    "All predictions must have the same length".to_string(),
                ));
            }
        }

        // Weighted sum of predictions, using posterior model probabilities
        // as weights.
        let mut averaged_prediction = Array1::zeros(n_samples);
        for (i, pred) in model_predictions.iter().enumerate() {
            let weight = self.selection_result.model_probabilities[i];
            averaged_prediction = averaged_prediction + pred * weight;
        }

        Ok(averaged_prediction)
    }

    /// Posterior model probabilities used as averaging weights.
    pub fn get_weights(&self) -> &[f64] {
        &self.selection_result.model_probabilities
    }

    /// Effective number of models: the exponential of the entropy of the
    /// weights (1 when one model dominates, the model count when uniform).
    pub fn effective_number_of_models(&self) -> f64 {
        let entropy: f64 = self
            .selection_result
            .model_probabilities
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum();
        entropy.exp()
    }
}

/// Numerically stable computation of ln(e^a + e^b).
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
}

/// Numerically stable computation of ln(Σᵢ e^{xᵢ}).
fn log_sum_exp_vec(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + sum.ln()
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_bic_approximation() {
        let data = ModelEvidenceData::new(-100.0, 5, 100);
        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let log_evidence = selector.estimate_evidence(&data).unwrap();
        // The parameter penalty pushes the log evidence below zero here.
        assert!(log_evidence < 0.0);
    }
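
    #[test]
    fn test_log_sum_exp_stability() {
        // Sanity check for the log-space helpers used throughout this
        // module: they should match the analytic values even at magnitudes
        // where a naive exp() would underflow to zero.
        let expected = -1000.0 + 2.0_f64.ln();
        assert!((log_sum_exp(-1000.0, -1000.0) - expected).abs() < 1e-12);

        let expected_vec = -1000.0 + 4.0_f64.ln();
        assert!((log_sum_exp_vec(&[-1000.0; 4]) - expected_vec).abs() < 1e-12);
    }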

    #[test]
    fn test_model_comparison() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let result = selector.compare_models(&models).unwrap();

        assert_eq!(result.model_names.len(), 2);
        // Model1 has the higher likelihood and fewer parameters.
        assert_eq!(result.best_model(), "Model1");
        assert!((result.model_probabilities.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }
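
    #[test]
    fn test_prior_probabilities() {
        // Priors that do not sum to 1 are rejected by the builder.
        let selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));
        assert!(selector.with_prior_probabilities(vec![0.5, 0.4]).is_err());

        // With identical evidence for every model, the posterior model
        // probabilities should reproduce the prior exactly.
        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42))
            .with_prior_probabilities(vec![0.9, 0.1])
            .unwrap();

        let data = ModelEvidenceData::new(-100.0, 3, 100);
        let models = vec![("A".to_string(), data.clone()), ("B".to_string(), data)];
        let result = selector.compare_models(&models).unwrap();

        assert!((result.model_probabilities[0] - 0.9).abs() < 1e-6);
        assert!((result.model_probabilities[1] - 0.1).abs() < 1e-6);
    }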

    #[test]
    fn test_bayesian_model_averaging() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let selection_result = selector.compare_models(&models).unwrap();
        let averager = BayesianModelAverager::new(selection_result);

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let predictions = vec![pred1, pred2];

        let averaged = averager.predict(&predictions).unwrap();
        assert_eq!(averaged.len(), 3);

        // Between 1 (one dominant model) and the model count (uniform).
        let effective_n = averager.effective_number_of_models();
        assert!(effective_n >= 1.0 && effective_n <= 2.0);
    }
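
    #[test]
    fn test_effective_number_of_models_uniform() {
        // With uniform weights the exponentiated entropy equals the model
        // count exactly: exp(-(0.5 ln 0.5 + 0.5 ln 0.5)) = exp(ln 2) = 2.
        let result = BayesianModelSelectionResult {
            model_names: vec!["A".to_string(), "B".to_string()],
            log_evidence: vec![-100.0, -100.0],
            model_probabilities: vec![0.5, 0.5],
            bayes_factors: vec![1.0, 1.0],
            best_model_index: 0,
            method: EvidenceEstimationMethod::BIC,
        };
        let averager = BayesianModelAverager::new(result);

        assert!((averager.effective_number_of_models() - 2.0).abs() < 1e-12);
        assert!((averager.get_weights().iter().sum::<f64>() - 1.0).abs() < 1e-12);
    }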

    #[test]
    fn test_evidence_interpretation() {
        let log_evidence = vec![-95.0, -100.0];
        let model_probabilities = vec![0.8, 0.2];
        // exp(-5) ≈ 0.007 relative to the best model.
        let bayes_factors = vec![1.0, 0.007];

        let result = BayesianModelSelectionResult {
            model_names: vec!["Model1".to_string(), "Model2".to_string()],
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index: 0,
            method: EvidenceEstimationMethod::BIC,
        };

        // A log Bayes factor of 5.0 falls in the "Strong evidence" band.
        let interpretation = result.evidence_interpretation(0, 1);
        assert!(interpretation.contains("Strong"));
    }
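
    #[test]
    fn test_harmonic_mean_estimator_constant_samples() {
        // With identical posterior log-likelihood samples the harmonic mean
        // estimate collapses to that common value:
        // log Z = log N - log(N * exp(-L)) = L.
        let data =
            ModelEvidenceData::new(-50.0, 2, 100).with_posterior_samples(vec![-50.0; 10]);
        let mut selector =
            BayesianModelSelector::new(EvidenceEstimationMethod::HarmonicMean, Some(0));

        let log_ev = selector.estimate_evidence(&data).unwrap();
        assert!((log_ev - (-50.0)).abs() < 1e-9);
    }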

    #[test]
    fn test_model_ranking() {
        let result = BayesianModelSelectionResult {
            model_names: vec![
                "ModelA".to_string(),
                "ModelB".to_string(),
                "ModelC".to_string(),
            ],
            log_evidence: vec![-100.0, -95.0, -98.0],
            model_probabilities: vec![0.1, 0.7, 0.2],
            bayes_factors: vec![0.007, 1.0, 0.05],
            best_model_index: 1,
            method: EvidenceEstimationMethod::BIC,
        };

        let ranking = result.model_ranking();
        // Ranked by log evidence, best first.
        assert_eq!(ranking[0].1, "ModelB");
        assert_eq!(ranking[1].1, "ModelC");
        assert_eq!(ranking[2].1, "ModelA");
    }
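
    #[test]
    fn test_nested_sampling_smoke() {
        // Smoke test for the simplified nested-sampling estimator: it should
        // produce a finite log evidence below the maximum log-likelihood,
        // since the accumulated prior-volume weights sum to less than 1.
        let data = ModelEvidenceData::new(-10.0, 2, 50)
            .with_posterior_samples(vec![-12.0, -11.0, -10.5, -10.2, -10.0]);
        let mut selector = BayesianModelSelector::new(
            EvidenceEstimationMethod::NestedSampling { n_live_points: 10 },
            Some(7),
        );

        let log_ev = selector.estimate_evidence(&data).unwrap();
        assert!(log_ev.is_finite());
        assert!(log_ev < -10.0);
    }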
}