sklears_model_selection/bayesian_model_selection.rs

//! Bayesian Model Selection with Evidence Estimation
//!
//! This module implements Bayesian model selection techniques that use the marginal
//! likelihood (evidence) to compare models. It provides several methods for
//! estimating the evidence:
//!
//! - Laplace approximation
//! - Bayesian Information Criterion (BIC) approximation
//! - Harmonic mean estimator
//! - Thermodynamic integration
//! - Nested sampling approximation
//! - Model averaging with Bayesian weights
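//!
//! # Example
//!
//! A minimal sketch of the intended workflow (it mirrors the unit tests at the
//! bottom of this file; the `use` path is assumed from the file location):
//!
//! ```ignore
//! use sklears_model_selection::bayesian_model_selection::{
//!     BayesianModelSelector, EvidenceEstimationMethod, ModelEvidenceData,
//! };
//!
//! // Summaries of two fitted models: max log-likelihood, #parameters, #data points
//! let m1 = ModelEvidenceData::new(-95.0, 3, 100);
//! let m2 = ModelEvidenceData::new(-105.0, 5, 100);
//!
//! let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));
//! let result = selector
//!     .compare_models(&[("Model1".to_string(), m1), ("Model2".to_string(), m2)])
//!     .unwrap();
//! assert_eq!(result.best_model(), "Model1");
//! ```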

use scirs2_core::ndarray::Array1;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use sklears_core::error::{Result, SklearsError};
/// Result of Bayesian model selection
#[derive(Debug, Clone)]
pub struct BayesianModelSelectionResult {
    /// Model identifiers
    pub model_names: Vec<String>,
    /// Log evidence for each model
    pub log_evidence: Vec<f64>,
    /// Model probabilities (Bayesian weights)
    pub model_probabilities: Vec<f64>,
    /// Bayes factors relative to the best model
    pub bayes_factors: Vec<f64>,
    /// Best model index
    pub best_model_index: usize,
    /// Evidence estimation method used
    pub method: EvidenceEstimationMethod,
}

impl BayesianModelSelectionResult {
    /// Get the best model name
    pub fn best_model(&self) -> &str {
        &self.model_names[self.best_model_index]
    }

    /// Get ranking of models by evidence, best first
    pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
        let mut ranking: Vec<(usize, &str, f64)> = self
            .model_names
            .iter()
            .enumerate()
            .map(|(i, name)| (i, name.as_str(), self.log_evidence[i]))
            .collect();

        // Sort descending; `total_cmp` is a total order on f64, so no panic on NaN
        ranking.sort_by(|a, b| b.2.total_cmp(&a.2));
        ranking
    }

    /// Interpret the strength of evidence for one model over another using
    /// thresholds on the natural-log Bayes factor (Kass & Raftery's scale)
    pub fn evidence_interpretation(&self, model1_idx: usize, model2_idx: usize) -> String {
        // Use the magnitude so the interpretation is symmetric in the two models
        let log_bf = (self.log_evidence[model1_idx] - self.log_evidence[model2_idx]).abs();

        match log_bf {
            x if x < 1.0 => "Weak evidence".to_string(),
            x if x < 3.0 => "Positive evidence".to_string(),
            x if x <= 5.0 => "Strong evidence".to_string(),
            _ => "Very strong evidence".to_string(),
        }
    }
}

/// Methods for estimating the evidence (marginal likelihood)
#[derive(Debug, Clone)]
pub enum EvidenceEstimationMethod {
    /// Laplace approximation (Gaussian approximation around MAP)
    LaplaceApproximation,
    /// BIC approximation (asymptotic approximation)
    BIC,
    /// AIC with correction for finite sample size
    AICc,
    /// Harmonic mean estimator
    HarmonicMean,
    /// Thermodynamic integration
    ThermodynamicIntegration { n_temperatures: usize },
    /// Nested sampling approximation
    NestedSampling { n_live_points: usize },
    /// Cross-validation based evidence approximation
    CrossValidationEvidence { n_folds: usize },
}

/// Bayesian model selector
pub struct BayesianModelSelector {
    /// Evidence estimation method
    method: EvidenceEstimationMethod,
    /// Prior model probabilities (if not uniform)
    prior_probabilities: Option<Vec<f64>>,
    /// Random number generator
    rng: StdRng,
}

impl BayesianModelSelector {
    /// Create a new Bayesian model selector
    pub fn new(method: EvidenceEstimationMethod, random_state: Option<u64>) -> Self {
        let rng = match random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
        };

        Self {
            method,
            prior_probabilities: None,
            rng,
        }
    }

    /// Set prior probabilities for models (must sum to 1)
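    ///
    /// # Example
    ///
    /// A minimal sketch (the three prior weights here are illustrative):
    ///
    /// ```ignore
    /// let selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42))
    ///     .with_prior_probabilities(vec![0.5, 0.3, 0.2])
    ///     .unwrap();
    /// ```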
    pub fn with_prior_probabilities(mut self, priors: Vec<f64>) -> Result<Self> {
        if priors.iter().any(|&p| p < 0.0) {
            return Err(SklearsError::InvalidInput(
                "Prior probabilities must be non-negative".to_string(),
            ));
        }
        let sum: f64 = priors.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(SklearsError::InvalidInput(
                "Prior probabilities must sum to 1".to_string(),
            ));
        }
        self.prior_probabilities = Some(priors);
        Ok(self)
    }

    /// Compare models using Bayesian evidence
    pub fn compare_models(
        &mut self,
        model_results: &[(String, ModelEvidenceData)],
    ) -> Result<BayesianModelSelectionResult> {
        if model_results.is_empty() {
            return Err(SklearsError::InvalidInput("No models provided".to_string()));
        }

        let model_names: Vec<String> = model_results.iter().map(|(name, _)| name.clone()).collect();

        // Estimate log evidence for each model
        let mut log_evidence = Vec::new();
        for (_, data) in model_results {
            let log_ev = self.estimate_evidence(data)?;
            log_evidence.push(log_ev);
        }

        // Calculate model probabilities
        let model_probabilities = self.calculate_model_probabilities(&log_evidence)?;

        // Calculate Bayes factors relative to the best model
        let best_log_evidence = log_evidence
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let bayes_factors: Vec<f64> = log_evidence
            .iter()
            .map(|&log_ev| (log_ev - best_log_evidence).exp())
            .collect();

        let best_model_index = log_evidence
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .expect("at least one model is present (checked above)");

        Ok(BayesianModelSelectionResult {
            model_names,
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index,
            method: self.method.clone(),
        })
    }

    /// Estimate evidence for a single model
    fn estimate_evidence(&mut self, data: &ModelEvidenceData) -> Result<f64> {
        match &self.method {
            EvidenceEstimationMethod::LaplaceApproximation => self.laplace_approximation(data),
            EvidenceEstimationMethod::BIC => self.bic_approximation(data),
            EvidenceEstimationMethod::AICc => self.aicc_approximation(data),
            EvidenceEstimationMethod::HarmonicMean => self.harmonic_mean_estimator(data),
            EvidenceEstimationMethod::ThermodynamicIntegration { n_temperatures } => {
                self.thermodynamic_integration(data, *n_temperatures)
            }
            EvidenceEstimationMethod::NestedSampling { n_live_points } => {
                self.nested_sampling_approximation(data, *n_live_points)
            }
            EvidenceEstimationMethod::CrossValidationEvidence { n_folds } => {
                self.cross_validation_evidence(data, *n_folds)
            }
        }
    }

    /// Laplace approximation to the evidence
    fn laplace_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        // Log-likelihood at the MAP estimate
        let log_likelihood_map = data.max_log_likelihood;

        // Log determinant of the Hessian at the MAP
        let log_det_hessian = data.hessian_log_determinant.unwrap_or_else(|| {
            // Rough fallback if the Hessian is not available; with this choice the
            // Laplace estimate below reduces to the standard Schwarz (BIC)
            // approximation log L - (k/2) log n
            n_params * (2.0 * std::f64::consts::PI).ln() + n_params * n_data.ln()
        });

        // Prior contribution (flat prior, log pi = 0, if none supplied)
        let log_prior = data.log_prior.unwrap_or(0.0);

        // Laplace approximation:
        // log Z ≈ log L(θ_MAP) + log π(θ_MAP) + (k/2) log(2π) - (1/2) log|H|
        let log_evidence =
            log_likelihood_map + log_prior + (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln()
                - 0.5 * log_det_hessian;

        Ok(log_evidence)
    }

    /// BIC approximation to the evidence
    fn bic_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        // BIC = -2 log L + k log n, giving log Z ≈ log L - (k/2) log n;
        // an additional constant -(k/2) log(2π) term is retained here
        let log_evidence = data.max_log_likelihood
            - (n_params / 2.0) * n_data.ln()
            - (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln();

        Ok(log_evidence)
    }

    /// AICc approximation to the evidence
    fn aicc_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let k = data.n_parameters as f64;
        let n = data.n_data_points as f64;

        if n <= k + 1.0 {
            return Err(SklearsError::InvalidInput(
                "AICc requires n > k + 1".to_string(),
            ));
        }

        // AICc = AIC + 2k(k+1)/(n-k-1), so log Z ≈ -AICc/2 = log L - k - k(k+1)/(n-k-1)
        let aicc_correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
        let log_evidence = data.max_log_likelihood - k - aicc_correction / 2.0;

        Ok(log_evidence)
    }

    /// Harmonic mean estimator
    fn harmonic_mean_estimator(&self, data: &ModelEvidenceData) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Harmonic mean estimator requires posterior samples".to_string(),
            ));
        }

        // Harmonic mean: 1/Z ≈ (1/N) Σ 1/L(θ_i), i.e.
        // log Z ≈ log N - logsumexp(-log L(θ_i)).
        // This estimator is known to be unreliable, but is included for completeness.
        let n_samples = data.posterior_samples.len() as f64;
        let neg_log_likelihoods: Vec<f64> = data
            .posterior_samples
            .iter()
            .map(|&log_likelihood| -log_likelihood)
            .collect();
        // Evaluate in log space to avoid overflow when log-likelihoods are large in magnitude
        let log_evidence = n_samples.ln() - log_sum_exp_vec(&neg_log_likelihoods);

        // Warn about reliability
        eprintln!("Warning: Harmonic mean estimator is known to be unreliable");

        Ok(log_evidence)
    }

    /// Thermodynamic integration
    fn thermodynamic_integration(
        &self,
        data: &ModelEvidenceData,
        n_temperatures: usize,
    ) -> Result<f64> {
        if n_temperatures == 0 {
            return Err(SklearsError::InvalidInput(
                "Thermodynamic integration requires at least one temperature step".to_string(),
            ));
        }
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Thermodynamic integration requires posterior samples".to_string(),
            ));
        }

        // Create a temperature ladder from 0 to 1
        let temperatures: Vec<f64> = (0..=n_temperatures)
            .map(|i| i as f64 / n_temperatures as f64)
            .collect();

        // Estimate the mean log-likelihood E_β[log L] at each temperature β
        let mut mean_log_likelihoods = Vec::new();

        for &temp in &temperatures {
            if temp == 0.0 {
                // At temperature 0, all weight is on the prior
                mean_log_likelihoods.push(0.0);
            } else {
                // Crude stand-in for tempered sampling: interpolate linearly between
                // 0 (prior) and the posterior mean log-likelihood
                let mean_ll = data.posterior_samples.iter().sum::<f64>()
                    / data.posterior_samples.len() as f64;
                mean_log_likelihoods.push(temp * mean_ll);
            }
        }

        // log Z = ∫₀¹ E_β[log L] dβ; integrate with the trapezoidal rule
        let mut integral = 0.0;
        for i in 1..temperatures.len() {
            let dt = temperatures[i] - temperatures[i - 1];
            integral += 0.5 * dt * (mean_log_likelihoods[i] + mean_log_likelihoods[i - 1]);
        }

        // Add the prior contribution
        let log_prior = data.log_prior.unwrap_or(0.0);
        let log_evidence = log_prior + integral;

        Ok(log_evidence)
    }

    /// Nested sampling approximation
    fn nested_sampling_approximation(
        &self,
        data: &ModelEvidenceData,
        n_live_points: usize,
    ) -> Result<f64> {
        // Simplified nested sampling approximation over existing posterior samples.
        // A full implementation would draw new live points under a rising
        // likelihood constraint.

        if n_live_points == 0 {
            return Err(SklearsError::InvalidInput(
                "Nested sampling requires at least one live point".to_string(),
            ));
        }
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Nested sampling requires posterior samples".to_string(),
            ));
        }

        let n_samples = data.posterior_samples.len();
        let max_iterations = n_samples.min(1000); // Limit iterations

        // Process samples in order of increasing likelihood
        let mut sorted_samples = data.posterior_samples.clone();
        sorted_samples.sort_by(|a, b| a.total_cmp(b));

        // Prior mass shrinks geometrically, X_i ≈ exp(-i/N), so the shell width is
        // w_i = X_{i-1} - X_i ≈ X_{i-1} (1 - e^{-1/N})
        let n_live = n_live_points as f64;
        let mut log_width = (1.0 - (-1.0 / n_live).exp()).ln();
        let mut log_evidence = f64::NEG_INFINITY;

        for &log_likelihood in sorted_samples.iter().take(max_iterations) {
            // Accumulate Z ≈ Σ w_i L_i in log space
            let log_weight = log_width + log_likelihood;
            log_evidence = log_sum_exp(log_evidence, log_weight);

            // Shrink the remaining prior mass by a factor of e^{-1/N}
            log_width -= 1.0 / n_live;
        }

        Ok(log_evidence)
    }

    /// Cross-validation based evidence approximation
    fn cross_validation_evidence(&self, data: &ModelEvidenceData, n_folds: usize) -> Result<f64> {
        let cv_log_likes = data.cv_log_likelihoods.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Cross-validation evidence requires CV log-likelihoods".to_string(),
            )
        })?;
        if cv_log_likes.is_empty() || cv_log_likes.len() != n_folds {
            return Err(SklearsError::InvalidInput(
                "Number of CV scores must match the (non-zero) number of folds".to_string(),
            ));
        }

        // Approximate the evidence by the mean held-out log-likelihood
        let mean_cv_log_likelihood = cv_log_likes.iter().sum::<f64>() / cv_log_likes.len() as f64;

        // Apply a correction for finite-sample effects
        let n_data = data.n_data_points as f64;
        let correction = (n_data / (n_data - 1.0)).ln();

        let log_evidence = mean_cv_log_likelihood + correction;

        Ok(log_evidence)
    }

    /// Calculate model probabilities from log evidence
    fn calculate_model_probabilities(&self, log_evidence: &[f64]) -> Result<Vec<f64>> {
        let n_models = log_evidence.len();

        // Get prior probabilities
        let log_priors = if let Some(ref priors) = self.prior_probabilities {
            if priors.len() != n_models {
                return Err(SklearsError::InvalidInput(
                    "Number of prior probabilities must match number of models".to_string(),
                ));
            }
            priors.iter().map(|&p| p.ln()).collect()
        } else {
            // Uniform priors
            vec![-(n_models as f64).ln(); n_models]
        };

        // Calculate log posterior probabilities
        let log_posteriors: Vec<f64> = log_evidence
            .iter()
            .zip(log_priors.iter())
            .map(|(&log_ev, &log_prior)| log_ev + log_prior)
            .collect();

        // Normalize using the log-sum-exp trick
        let log_normalizer = log_sum_exp_vec(&log_posteriors);
        let probabilities: Vec<f64> = log_posteriors
            .iter()
            .map(|&log_p| (log_p - log_normalizer).exp())
            .collect();

        Ok(probabilities)
    }
}

/// Data required for evidence estimation
#[derive(Debug, Clone)]
pub struct ModelEvidenceData {
    /// Maximum log-likelihood achieved
    pub max_log_likelihood: f64,
    /// Number of model parameters
    pub n_parameters: usize,
    /// Number of data points
    pub n_data_points: usize,
    /// Log determinant of the Hessian at the MAP (for the Laplace approximation)
    pub hessian_log_determinant: Option<f64>,
    /// Log prior probability at the MAP
    pub log_prior: Option<f64>,
    /// Posterior samples (log-likelihood values)
    pub posterior_samples: Vec<f64>,
    /// Cross-validation log-likelihoods
    pub cv_log_likelihoods: Option<Vec<f64>>,
}

impl ModelEvidenceData {
    /// Create new evidence data with the required fields
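    ///
    /// # Example
    ///
    /// A sketch of the builder-style usage (all numeric values are illustrative):
    ///
    /// ```ignore
    /// let data = ModelEvidenceData::new(-100.0, 5, 100)
    ///     .with_log_prior(-2.3)
    ///     .with_posterior_samples(vec![-101.0, -100.5, -100.2]);
    /// ```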
    pub fn new(max_log_likelihood: f64, n_parameters: usize, n_data_points: usize) -> Self {
        Self {
            max_log_likelihood,
            n_parameters,
            n_data_points,
            hessian_log_determinant: None,
            log_prior: None,
            posterior_samples: Vec::new(),
            cv_log_likelihoods: None,
        }
    }

    /// Add Hessian information for the Laplace approximation
    pub fn with_hessian_log_determinant(mut self, log_det: f64) -> Self {
        self.hessian_log_determinant = Some(log_det);
        self
    }

    /// Add prior information
    pub fn with_log_prior(mut self, log_prior: f64) -> Self {
        self.log_prior = Some(log_prior);
        self
    }

    /// Add posterior samples
    pub fn with_posterior_samples(mut self, samples: Vec<f64>) -> Self {
        self.posterior_samples = samples;
        self
    }

    /// Add cross-validation log-likelihoods
    pub fn with_cv_log_likelihoods(mut self, cv_scores: Vec<f64>) -> Self {
        self.cv_log_likelihoods = Some(cv_scores);
        self
    }
}

/// Model averaging using Bayesian weights
pub struct BayesianModelAverager {
    /// Model selection result containing weights
    selection_result: BayesianModelSelectionResult,
}

impl BayesianModelAverager {
    /// Create a new model averager from a selection result
    pub fn new(selection_result: BayesianModelSelectionResult) -> Self {
        Self { selection_result }
    }

    /// Make a prediction using Bayesian model averaging
    pub fn predict(&self, model_predictions: &[Array1<f64>]) -> Result<Array1<f64>> {
        if model_predictions.len() != self.selection_result.model_names.len() {
            return Err(SklearsError::InvalidInput(
                "Number of predictions must match number of models".to_string(),
            ));
        }

        if model_predictions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No predictions provided".to_string(),
            ));
        }

        let n_samples = model_predictions[0].len();

        // Check that all predictions have the same length
        for pred in model_predictions {
            if pred.len() != n_samples {
                return Err(SklearsError::InvalidInput(
                    "All predictions must have the same length".to_string(),
                ));
            }
        }

        // Weighted average of predictions: ŷ = Σ_m P(m | data) ŷ_m
        let mut averaged_prediction = Array1::zeros(n_samples);

        for (i, pred) in model_predictions.iter().enumerate() {
            let weight = self.selection_result.model_probabilities[i];
            averaged_prediction = averaged_prediction + pred * weight;
        }

        Ok(averaged_prediction)
    }

    /// Get the model weights
    pub fn get_weights(&self) -> &[f64] {
        &self.selection_result.model_probabilities
    }

    /// Get the effective number of models: exp(H) for the weight entropy
    /// H = -Σ p ln p, ranging from 1 (one dominant model) up to the number of
    /// models (uniform weights)
    pub fn effective_number_of_models(&self) -> f64 {
        let entropy: f64 = self
            .selection_result
            .model_probabilities
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum();
        entropy.exp()
    }
}

// Utility functions

/// Numerically stable log(exp(a) + exp(b))
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val == f64::NEG_INFINITY {
        return f64::NEG_INFINITY; // both inputs are log(0)
    }
    max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
}

/// Numerically stable log(Σ exp(x_i))
fn log_sum_exp_vec(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    if max_val == f64::NEG_INFINITY {
        return f64::NEG_INFINITY; // empty input or all values are log(0)
    }
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + sum.ln()
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_bic_approximation() {
        let data = ModelEvidenceData::new(-100.0, 5, 100);
        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let log_evidence = selector
            .estimate_evidence(&data)
            .expect("BIC approximation should succeed");
        assert!(log_evidence < 0.0); // penalized log-likelihood stays negative
    }

    #[test]
    fn test_model_comparison() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let result: BayesianModelSelectionResult = selector
            .compare_models(&models)
            .expect("model comparison should succeed");

        assert_eq!(result.model_names.len(), 2);
        assert_eq!(result.best_model(), "Model1"); // better likelihood, fewer parameters
        assert!((result.model_probabilities.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bayesian_model_averaging() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let selection_result = selector
            .compare_models(&models)
            .expect("model comparison should succeed");
        let averager = BayesianModelAverager::new(selection_result);

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let predictions = vec![pred1, pred2];

        let averaged = averager
            .predict(&predictions)
            .expect("model averaging should succeed");
        assert_eq!(averaged.len(), 3);

        // The effective number of models lies between 1 and the number of models
        let effective_n = averager.effective_number_of_models();
        assert!(effective_n >= 1.0 && effective_n <= 2.0);
    }

    #[test]
    fn test_evidence_interpretation() {
        let log_evidence = vec![-95.0, -100.0];
        let model_probabilities = vec![0.8, 0.2];
        let bayes_factors = vec![1.0, 0.007]; // exp(-100 - (-95)) ≈ 0.0067

        let result = BayesianModelSelectionResult {
            model_names: vec!["Model1".to_string(), "Model2".to_string()],
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index: 0,
            method: EvidenceEstimationMethod::BIC,
        };

        // log BF = 5.0, which falls in the "Strong evidence" band
        let interpretation = result.evidence_interpretation(0, 1);
        assert_eq!(interpretation, "Strong evidence");
    }

    #[test]
    fn test_model_ranking() {
        let result = BayesianModelSelectionResult {
            model_names: vec![
                "ModelA".to_string(),
                "ModelB".to_string(),
                "ModelC".to_string(),
            ],
            log_evidence: vec![-100.0, -95.0, -98.0],
            model_probabilities: vec![0.1, 0.7, 0.2],
            bayes_factors: vec![0.007, 1.0, 0.05],
            best_model_index: 1,
            method: EvidenceEstimationMethod::BIC,
        };

        let ranking = result.model_ranking();
        assert_eq!(ranking[0].1, "ModelB"); // best model
        assert_eq!(ranking[1].1, "ModelC"); // second best
        assert_eq!(ranking[2].1, "ModelA"); // worst model
    }
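
    // The following tests are additions sketching expected behaviour of the
    // remaining estimators; the numeric inputs are illustrative, not canonical.

    #[test]
    fn test_laplace_reduces_to_bic_without_hessian() {
        // Without a Hessian, the Laplace fallback should equal the standard
        // Schwarz approximation log L - (k/2) ln n.
        let data = ModelEvidenceData::new(-100.0, 5, 100);
        let mut selector =
            BayesianModelSelector::new(EvidenceEstimationMethod::LaplaceApproximation, Some(42));
        let log_evidence = selector
            .estimate_evidence(&data)
            .expect("Laplace approximation should succeed");
        let expected = -100.0 - (5.0 / 2.0) * (100.0_f64).ln();
        assert!((log_evidence - expected).abs() < 1e-9);
    }

    #[test]
    fn test_harmonic_mean_is_finite_for_large_negative_log_likelihoods() {
        // Evaluating in log space keeps the estimate finite even when
        // exp(-log L) would overflow an f64.
        let data = ModelEvidenceData::new(-1000.0, 2, 50)
            .with_posterior_samples(vec![-1000.0, -1001.0, -999.5]);
        let mut selector =
            BayesianModelSelector::new(EvidenceEstimationMethod::HarmonicMean, Some(42));
        let log_evidence = selector
            .estimate_evidence(&data)
            .expect("harmonic mean estimator should succeed");
        assert!(log_evidence.is_finite());
    }

    #[test]
    fn test_log_sum_exp_handles_log_zero() {
        // log(exp(-inf) + exp(x)) should be exactly x, and an empty sum is log(0)
        assert_eq!(log_sum_exp(f64::NEG_INFINITY, 1.5), 1.5);
        assert_eq!(log_sum_exp_vec(&[]), f64::NEG_INFINITY);
    }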
}