1use scirs2_core::ndarray::Array1;
15use scirs2_core::numeric::{Float as FloatTrait, ToPrimitive};
16use scirs2_core::random::rngs::StdRng;
17use scirs2_core::random::SeedableRng;
18use sklears_core::error::{Result, SklearsError};
19use std::fmt::Debug;
20
/// Outcome of comparing a set of candidate models via Bayesian evidence.
///
/// All vectors are parallel: index `i` refers to the same model in
/// `model_names`, `log_evidence`, `model_probabilities` and `bayes_factors`.
#[derive(Debug, Clone)]
pub struct BayesianModelSelectionResult {
    /// Name of each candidate model, in input order.
    pub model_names: Vec<String>,
    /// Estimated log-evidence (log marginal likelihood) per model.
    pub log_evidence: Vec<f64>,
    /// Posterior model probabilities, normalized to sum to 1.
    pub model_probabilities: Vec<f64>,
    /// Bayes factors relative to the best model (the best model gets 1.0).
    pub bayes_factors: Vec<f64>,
    /// Index of the model with the highest log-evidence.
    pub best_model_index: usize,
    /// Evidence-estimation method that produced these results.
    pub method: EvidenceEstimationMethod,
}
37
38impl BayesianModelSelectionResult {
39 pub fn best_model(&self) -> &str {
41 &self.model_names[self.best_model_index]
42 }
43
44 pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
46 let mut ranking: Vec<(usize, &str, f64)> = self
47 .model_names
48 .iter()
49 .enumerate()
50 .map(|(i, name)| (i, name.as_str(), self.log_evidence[i]))
51 .collect();
52
53 ranking.sort_by(|a, b| b.2.partial_cmp(&a.2).expect("operation should succeed"));
54 ranking
55 }
56
57 pub fn evidence_interpretation(&self, model1_idx: usize, model2_idx: usize) -> String {
59 let log_bf = self.log_evidence[model1_idx] - self.log_evidence[model2_idx];
60 let _bf = log_bf.exp();
61
62 match log_bf {
63 x if x < 1.0 => "Weak evidence".to_string(),
64 x if x < 2.5 => "Positive evidence".to_string(),
65 x if x <= 5.0 => "Strong evidence".to_string(),
66 _ => "Very strong evidence".to_string(),
67 }
68 }
69}
70
/// Strategy used to estimate each model's log-evidence.
#[derive(Debug, Clone)]
pub enum EvidenceEstimationMethod {
    /// Laplace (Gaussian) approximation around the maximum of the posterior.
    LaplaceApproximation,
    /// Bayesian Information Criterion approximation.
    BIC,
    /// Corrected Akaike Information Criterion (small-sample AIC variant).
    AICc,
    /// Harmonic mean of posterior likelihoods (known to be unreliable).
    HarmonicMean,
    /// Thermodynamic integration over a ladder of `n_temperatures` steps.
    ThermodynamicIntegration { n_temperatures: usize },
    /// Simplified nested-sampling estimate with `n_live_points` live points.
    NestedSampling { n_live_points: usize },
    /// Mean cross-validated log-likelihood over `n_folds` folds.
    CrossValidationEvidence { n_folds: usize },
}
89
/// Compares candidate models by estimating their Bayesian evidence.
pub struct BayesianModelSelector {
    /// How log-evidence is estimated for each model.
    method: EvidenceEstimationMethod,
    /// Optional model priors; uniform priors are assumed when `None`.
    prior_probabilities: Option<Vec<f64>>,
    /// RNG for stochastic estimators. NOTE(review): none of the estimators
    /// visible in this file currently draw from it — confirm before removing.
    rng: StdRng,
}
99
100impl BayesianModelSelector {
101 pub fn new(method: EvidenceEstimationMethod, random_state: Option<u64>) -> Self {
103 let rng = match random_state {
104 Some(seed) => StdRng::seed_from_u64(seed),
105 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
106 };
107
108 Self {
109 method,
110 prior_probabilities: None,
111 rng,
112 }
113 }
114
115 pub fn with_prior_probabilities(mut self, priors: Vec<f64>) -> Result<Self> {
117 let sum: f64 = priors.iter().sum();
118 if (sum - 1.0).abs() > 1e-6 {
119 return Err(SklearsError::InvalidInput(
120 "Prior probabilities must sum to 1".to_string(),
121 ));
122 }
123 self.prior_probabilities = Some(priors);
124 Ok(self)
125 }
126
127 pub fn compare_models<F>(
129 &mut self,
130 model_results: &[(String, ModelEvidenceData)],
131 ) -> Result<BayesianModelSelectionResult>
132 where
133 F: FloatTrait + ToPrimitive,
134 {
135 if model_results.is_empty() {
136 return Err(SklearsError::InvalidInput("No models provided".to_string()));
137 }
138
139 let model_names: Vec<String> = model_results.iter().map(|(name, _)| name.clone()).collect();
140 let _n_models = model_names.len();
141
142 let mut log_evidence = Vec::new();
144 for (_, data) in model_results {
145 let log_ev = self.estimate_evidence(data)?;
146 log_evidence.push(log_ev);
147 }
148
149 let model_probabilities = self.calculate_model_probabilities(&log_evidence)?;
151
152 let best_log_evidence = log_evidence
154 .iter()
155 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
156 let bayes_factors: Vec<f64> = log_evidence
157 .iter()
158 .map(|&log_ev| (log_ev - best_log_evidence).exp())
159 .collect();
160
161 let best_model_index = log_evidence
162 .iter()
163 .enumerate()
164 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
165 .map(|(i, _)| i)
166 .expect("operation should succeed");
167
168 Ok(BayesianModelSelectionResult {
169 model_names,
170 log_evidence,
171 model_probabilities,
172 bayes_factors,
173 best_model_index,
174 method: self.method.clone(),
175 })
176 }
177
178 fn estimate_evidence(&mut self, data: &ModelEvidenceData) -> Result<f64> {
180 match &self.method {
181 EvidenceEstimationMethod::LaplaceApproximation => self.laplace_approximation(data),
182 EvidenceEstimationMethod::BIC => self.bic_approximation(data),
183 EvidenceEstimationMethod::AICc => self.aicc_approximation(data),
184 EvidenceEstimationMethod::HarmonicMean => self.harmonic_mean_estimator(data),
185 EvidenceEstimationMethod::ThermodynamicIntegration { n_temperatures } => {
186 self.thermodynamic_integration(data, *n_temperatures)
187 }
188 EvidenceEstimationMethod::NestedSampling { n_live_points } => {
189 self.nested_sampling_approximation(data, *n_live_points)
190 }
191 EvidenceEstimationMethod::CrossValidationEvidence { n_folds } => {
192 self.cross_validation_evidence(data, *n_folds)
193 }
194 }
195 }
196
197 fn laplace_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
199 let n_params = data.n_parameters as f64;
200 let n_data = data.n_data_points as f64;
201
202 let log_likelihood_map = data.max_log_likelihood;
204
205 let log_det_hessian = data.hessian_log_determinant.unwrap_or_else(|| {
207 n_params * (2.0 * std::f64::consts::PI).ln() + n_params * n_data.ln()
209 });
210
211 let log_prior = data.log_prior.unwrap_or(0.0);
213
214 let log_evidence =
216 log_likelihood_map + log_prior + (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln()
217 - 0.5 * log_det_hessian;
218
219 Ok(log_evidence)
220 }
221
222 fn bic_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
224 let n_params = data.n_parameters as f64;
225 let n_data = data.n_data_points as f64;
226
227 let log_evidence = data.max_log_likelihood
230 - (n_params / 2.0) * n_data.ln()
231 - (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln();
232
233 Ok(log_evidence)
234 }
235
236 fn aicc_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
238 let k = data.n_parameters as f64;
239 let n = data.n_data_points as f64;
240
241 if n <= k + 1.0 {
242 return Err(SklearsError::InvalidInput(
243 "AICc requires n > k + 1".to_string(),
244 ));
245 }
246
247 let aicc_correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
249 let log_evidence = data.max_log_likelihood - k - aicc_correction / 2.0;
250
251 Ok(log_evidence)
252 }
253
254 fn harmonic_mean_estimator(&self, data: &ModelEvidenceData) -> Result<f64> {
256 if data.posterior_samples.is_empty() {
257 return Err(SklearsError::InvalidInput(
258 "Harmonic mean estimator requires posterior samples".to_string(),
259 ));
260 }
261
262 let n_samples = data.posterior_samples.len() as f64;
265 let harmonic_mean: f64 = data
266 .posterior_samples
267 .iter()
268 .map(|&log_likelihood| (-log_likelihood).exp())
269 .sum::<f64>()
270 / n_samples;
271
272 let log_evidence = -harmonic_mean.ln();
273
274 eprintln!("Warning: Harmonic mean estimator is known to be unreliable");
276
277 Ok(log_evidence)
278 }
279
280 fn thermodynamic_integration(
282 &mut self,
283 data: &ModelEvidenceData,
284 n_temperatures: usize,
285 ) -> Result<f64> {
286 if data.posterior_samples.is_empty() {
287 return Err(SklearsError::InvalidInput(
288 "Thermodynamic integration requires posterior samples".to_string(),
289 ));
290 }
291
292 let temperatures: Vec<f64> = (0..=n_temperatures)
294 .map(|i| i as f64 / n_temperatures as f64)
295 .collect();
296
297 let mut mean_log_likelihoods = Vec::new();
299
300 for &temp in &temperatures {
301 if temp == 0.0 {
302 mean_log_likelihoods.push(0.0);
304 } else {
305 let mean_ll = data.posterior_samples.iter().sum::<f64>()
307 / data.posterior_samples.len() as f64;
308 mean_log_likelihoods.push(temp * mean_ll);
309 }
310 }
311
312 let mut integral = 0.0;
314 for i in 1..temperatures.len() {
315 let dt = temperatures[i] - temperatures[i - 1];
316 integral += 0.5 * dt * (mean_log_likelihoods[i] + mean_log_likelihoods[i - 1]);
317 }
318
319 let log_prior = data.log_prior.unwrap_or(0.0);
321 let log_evidence = log_prior + integral;
322
323 Ok(log_evidence)
324 }
325
326 fn nested_sampling_approximation(
328 &mut self,
329 data: &ModelEvidenceData,
330 n_live_points: usize,
331 ) -> Result<f64> {
332 if data.posterior_samples.is_empty() {
336 return Err(SklearsError::InvalidInput(
337 "Nested sampling requires posterior samples".to_string(),
338 ));
339 }
340
341 let n_samples = data.posterior_samples.len();
342 let max_iterations = n_samples.min(1000); let mut sorted_samples = data.posterior_samples.clone();
346 sorted_samples.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
347
348 let mut log_evidence = f64::NEG_INFINITY;
350 let mut log_width = -(1.0 / n_live_points as f64).ln();
351
352 for (i, &log_likelihood) in sorted_samples.iter().enumerate() {
353 if i >= max_iterations {
354 break;
355 }
356
357 let log_weight = log_width + log_likelihood;
358 log_evidence = log_sum_exp(log_evidence, log_weight);
359
360 log_width -= (n_live_points as f64).ln();
362 }
363
364 Ok(log_evidence)
365 }
366
367 fn cross_validation_evidence(&self, data: &ModelEvidenceData, n_folds: usize) -> Result<f64> {
369 if data.cv_log_likelihoods.is_none() {
370 return Err(SklearsError::InvalidInput(
371 "Cross-validation evidence requires CV log-likelihoods".to_string(),
372 ));
373 }
374
375 let cv_log_likes = data
376 .cv_log_likelihoods
377 .as_ref()
378 .expect("operation should succeed");
379 if cv_log_likes.len() != n_folds {
380 return Err(SklearsError::InvalidInput(
381 "Number of CV scores must match number of folds".to_string(),
382 ));
383 }
384
385 let mean_cv_log_likelihood = cv_log_likes.iter().sum::<f64>() / cv_log_likes.len() as f64;
387
388 let n_data = data.n_data_points as f64;
390 let correction = (n_data / (n_data - 1.0)).ln();
391
392 let log_evidence = mean_cv_log_likelihood + correction;
393
394 Ok(log_evidence)
395 }
396
397 fn calculate_model_probabilities(&self, log_evidence: &[f64]) -> Result<Vec<f64>> {
399 let n_models = log_evidence.len();
400
401 let log_priors = if let Some(ref priors) = self.prior_probabilities {
403 if priors.len() != n_models {
404 return Err(SklearsError::InvalidInput(
405 "Number of prior probabilities must match number of models".to_string(),
406 ));
407 }
408 priors.iter().map(|&p| p.ln()).collect()
409 } else {
410 vec![-(n_models as f64).ln(); n_models]
412 };
413
414 let log_posteriors: Vec<f64> = log_evidence
416 .iter()
417 .zip(log_priors.iter())
418 .map(|(&log_ev, &log_prior)| log_ev + log_prior)
419 .collect();
420
421 let log_normalizer = log_sum_exp_vec(&log_posteriors);
423 let probabilities: Vec<f64> = log_posteriors
424 .iter()
425 .map(|&log_p| (log_p - log_normalizer).exp())
426 .collect();
427
428 Ok(probabilities)
429 }
430}
431
/// Per-model inputs required by the evidence estimators.
#[derive(Debug, Clone)]
pub struct ModelEvidenceData {
    /// Maximum log-likelihood achieved by the model (at the MAP/MLE point).
    pub max_log_likelihood: f64,
    /// Number of free parameters of the model.
    pub n_parameters: usize,
    /// Number of data points the model was fit on.
    pub n_data_points: usize,
    /// Log-determinant of the Hessian at the mode, if available
    /// (used by the Laplace approximation; a surrogate is used when `None`).
    pub hessian_log_determinant: Option<f64>,
    /// Log prior density at the mode, if available (0.0 assumed when `None`).
    pub log_prior: Option<f64>,
    /// Posterior log-likelihood samples (used by the sampling-based
    /// estimators; empty by default).
    pub posterior_samples: Vec<f64>,
    /// Per-fold cross-validation log-likelihoods, if available.
    pub cv_log_likelihoods: Option<Vec<f64>>,
}
450
451impl ModelEvidenceData {
452 pub fn new(max_log_likelihood: f64, n_parameters: usize, n_data_points: usize) -> Self {
454 Self {
455 max_log_likelihood,
456 n_parameters,
457 n_data_points,
458 hessian_log_determinant: None,
459 log_prior: None,
460 posterior_samples: Vec::new(),
461 cv_log_likelihoods: None,
462 }
463 }
464
465 pub fn with_hessian_log_determinant(mut self, log_det: f64) -> Self {
467 self.hessian_log_determinant = Some(log_det);
468 self
469 }
470
471 pub fn with_log_prior(mut self, log_prior: f64) -> Self {
473 self.log_prior = Some(log_prior);
474 self
475 }
476
477 pub fn with_posterior_samples(mut self, samples: Vec<f64>) -> Self {
479 self.posterior_samples = samples;
480 self
481 }
482
483 pub fn with_cv_log_likelihoods(mut self, cv_scores: Vec<f64>) -> Self {
485 self.cv_log_likelihoods = Some(cv_scores);
486 self
487 }
488}
489
/// Combines per-model predictions, weighted by posterior model probabilities.
pub struct BayesianModelAverager {
    /// Selection result supplying model names and the averaging weights.
    selection_result: BayesianModelSelectionResult,
}
495
496impl BayesianModelAverager {
497 pub fn new(selection_result: BayesianModelSelectionResult) -> Self {
499 Self { selection_result }
500 }
501
502 pub fn predict(&self, model_predictions: &[Array1<f64>]) -> Result<Array1<f64>> {
504 if model_predictions.len() != self.selection_result.model_names.len() {
505 return Err(SklearsError::InvalidInput(
506 "Number of predictions must match number of models".to_string(),
507 ));
508 }
509
510 if model_predictions.is_empty() {
511 return Err(SklearsError::InvalidInput(
512 "No predictions provided".to_string(),
513 ));
514 }
515
516 let n_samples = model_predictions[0].len();
517
518 for pred in model_predictions {
520 if pred.len() != n_samples {
521 return Err(SklearsError::InvalidInput(
522 "All predictions must have the same length".to_string(),
523 ));
524 }
525 }
526
527 let mut averaged_prediction = Array1::zeros(n_samples);
529
530 for (i, pred) in model_predictions.iter().enumerate() {
531 let weight = self.selection_result.model_probabilities[i];
532 averaged_prediction = averaged_prediction + pred * weight;
533 }
534
535 Ok(averaged_prediction)
536 }
537
538 pub fn get_weights(&self) -> &[f64] {
540 &self.selection_result.model_probabilities
541 }
542
543 pub fn effective_number_of_models(&self) -> f64 {
545 let entropy: f64 = self
546 .selection_result
547 .model_probabilities
548 .iter()
549 .filter(|&&p| p > 0.0)
550 .map(|&p| -p * p.ln())
551 .sum();
552 entropy.exp()
553 }
554}
555
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// Returns `f64::NEG_INFINITY` when both inputs are `-inf`; the naive
/// formula would otherwise produce `NaN` via `-inf - -inf`. This case
/// occurs when accumulating from an initial `NEG_INFINITY` evidence.
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val == f64::NEG_INFINITY {
        // ln(0 + 0) = -inf; guard against NaN from (-inf) - (-inf).
        return f64::NEG_INFINITY;
    }
    max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
}
561
/// Numerically stable `ln(Σ exp(values[i]))`.
///
/// Returns `f64::NEG_INFINITY` for an empty slice or when every entry is
/// `-inf`; without the guard, the all-`-inf` case yields `NaN` from
/// `(-inf) - (-inf)` inside the shifted exponentials.
fn log_sum_exp_vec(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    if max_val == f64::NEG_INFINITY {
        // Empty input or all entries are -inf: the sum is 0, so ln is -inf.
        return f64::NEG_INFINITY;
    }
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + sum.ln()
}
567
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // BIC evidence for a model with strongly negative max log-likelihood
    // should itself be negative.
    #[test]
    fn test_bic_approximation() {
        let data = ModelEvidenceData::new(-100.0, 5, 100);
        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let log_evidence = selector
            .estimate_evidence(&data)
            .expect("operation should succeed");
        assert!(log_evidence < 0.0);
    }

    // The model with higher likelihood and fewer parameters should win,
    // and posterior model probabilities should be normalized.
    #[test]
    fn test_model_comparison() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let result: BayesianModelSelectionResult = selector
            .compare_models::<f64>(&models)
            .expect("operation should succeed");

        assert_eq!(result.model_names.len(), 2);
        assert_eq!(result.best_model(), "Model1");
        assert!((result.model_probabilities.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    // Averaged predictions keep the sample length, and the effective model
    // count stays within [1, number of models].
    #[test]
    fn test_bayesian_model_averaging() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let selection_result = selector
            .compare_models::<f64>(&models)
            .expect("operation should succeed");
        let averager = BayesianModelAverager::new(selection_result);

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let predictions = vec![pred1, pred2];

        let averaged = averager
            .predict(&predictions)
            .expect("operation should succeed");
        assert_eq!(averaged.len(), 3);

        let effective_n = averager.effective_number_of_models();
        assert!(effective_n >= 1.0 && effective_n <= 2.0);
    }

    // A log-evidence gap of 5.0 should be labelled "Strong evidence".
    #[test]
    fn test_evidence_interpretation() {
        let log_evidence = vec![-95.0, -100.0];
        let model_probabilities = vec![0.8, 0.2];
        let bayes_factors = vec![1.0, 0.007];
        let result = BayesianModelSelectionResult {
            model_names: vec!["Model1".to_string(), "Model2".to_string()],
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index: 0,
            method: EvidenceEstimationMethod::BIC,
        };

        let interpretation = result.evidence_interpretation(0, 1);
        assert!(interpretation.contains("Strong"));
    }

    // Ranking is descending in log-evidence regardless of input order.
    #[test]
    fn test_model_ranking() {
        let result = BayesianModelSelectionResult {
            model_names: vec![
                "ModelA".to_string(),
                "ModelB".to_string(),
                "ModelC".to_string(),
            ],
            log_evidence: vec![-100.0, -95.0, -98.0],
            model_probabilities: vec![0.1, 0.7, 0.2],
            bayes_factors: vec![0.007, 1.0, 0.05],
            best_model_index: 1,
            method: EvidenceEstimationMethod::BIC,
        };

        let ranking = result.model_ranking();
        assert_eq!(ranking[0].1, "ModelB");
        assert_eq!(ranking[1].1, "ModelC");
        assert_eq!(ranking[2].1, "ModelA");
    }
}