sklears_model_selection/bayesian_model_selection.rs

//! Bayesian Model Selection with Evidence Estimation
//!
//! This module implements Bayesian model selection techniques that use the marginal
//! likelihood (evidence) to compare models. It provides several methods for estimating
//! the evidence:
//!
//! - Laplace approximation
//! - Bayesian Information Criterion (BIC) approximation
//! - Harmonic mean estimator
//! - Thermodynamic integration
//! - Nested sampling approximation
//! - Model averaging with Bayesian weights
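//!
//! A minimal usage sketch (assuming this module is reachable as
//! `sklears_model_selection::bayesian_model_selection`; paths may differ in your crate layout):
//!
//! ```ignore
//! use sklears_model_selection::bayesian_model_selection::{
//!     BayesianModelSelector, EvidenceEstimationMethod, ModelEvidenceData,
//! };
//!
//! // Summaries of two fitted models: max log-likelihood, number of parameters, data points.
//! let linear = ModelEvidenceData::new(-95.0, 3, 100);
//! let quadratic = ModelEvidenceData::new(-93.0, 5, 100);
//!
//! let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));
//! let result = selector
//!     .compare_models::<f64>(&[
//!         ("linear".to_string(), linear),
//!         ("quadratic".to_string(), quadratic),
//!     ])
//!     .unwrap();
//!
//! println!("best model: {}", result.best_model());
//! ```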

use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float as FloatTrait, ToPrimitive};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::SeedableRng;
use sklears_core::error::{Result, SklearsError};

/// Result of Bayesian model selection
#[derive(Debug, Clone)]
pub struct BayesianModelSelectionResult {
    /// Model identifiers
    pub model_names: Vec<String>,
    /// Log evidence for each model
    pub log_evidence: Vec<f64>,
    /// Model probabilities (Bayesian weights)
    pub model_probabilities: Vec<f64>,
    /// Bayes factors relative to the best model
    pub bayes_factors: Vec<f64>,
    /// Best model index
    pub best_model_index: usize,
    /// Evidence estimation method used
    pub method: EvidenceEstimationMethod,
}

impl BayesianModelSelectionResult {
    /// Get the best model name
    pub fn best_model(&self) -> &str {
        &self.model_names[self.best_model_index]
    }

    /// Get ranking of models by evidence, best first
    pub fn model_ranking(&self) -> Vec<(usize, &str, f64)> {
        let mut ranking: Vec<(usize, &str, f64)> = self
            .model_names
            .iter()
            .enumerate()
            .map(|(i, name)| (i, name.as_str(), self.log_evidence[i]))
            .collect();

        ranking.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
        ranking
    }

    /// Interpret the strength of evidence for `model1_idx` over `model2_idx` using
    /// Jeffreys/Kass-Raftery style thresholds on the natural-log Bayes factor
    pub fn evidence_interpretation(&self, model1_idx: usize, model2_idx: usize) -> String {
        let log_bf = self.log_evidence[model1_idx] - self.log_evidence[model2_idx];

        match log_bf {
            x if x < 1.0 => "Weak evidence".to_string(),
            x if x < 2.5 => "Positive evidence".to_string(),
            x if x <= 5.0 => "Strong evidence".to_string(),
            _ => "Very strong evidence".to_string(),
        }
    }
}

/// Methods for estimating the evidence (marginal likelihood)
#[derive(Debug, Clone)]
pub enum EvidenceEstimationMethod {
    /// Laplace approximation (Gaussian approximation around MAP)
    LaplaceApproximation,
    /// BIC approximation (asymptotic approximation)
    BIC,
    /// AIC with correction for finite sample size
    AICc,
    /// Harmonic mean estimator
    HarmonicMean,
    /// Thermodynamic integration
    ThermodynamicIntegration { n_temperatures: usize },
    /// Nested sampling approximation
    NestedSampling { n_live_points: usize },
    /// Cross-validation based evidence approximation
    CrossValidationEvidence { n_folds: usize },
}

/// Bayesian model selector
pub struct BayesianModelSelector {
    /// Evidence estimation method
    method: EvidenceEstimationMethod,
    /// Prior model probabilities (if not uniform)
    prior_probabilities: Option<Vec<f64>>,
    /// Random number generator
    rng: StdRng,
}

impl BayesianModelSelector {
    /// Create a new Bayesian model selector
    pub fn new(method: EvidenceEstimationMethod, random_state: Option<u64>) -> Self {
        let rng = match random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
        };

        Self {
            method,
            prior_probabilities: None,
            rng,
        }
    }

    /// Set prior probabilities for models (must sum to 1)
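    ///
    /// A small illustrative sketch (not compiled; the weights here are arbitrary):
    ///
    /// ```ignore
    /// let selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(0))
    ///     // Favour the first of three candidate models a priori.
    ///     .with_prior_probabilities(vec![0.5, 0.25, 0.25])?;
    /// ```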
    pub fn with_prior_probabilities(mut self, priors: Vec<f64>) -> Result<Self> {
        let sum: f64 = priors.iter().sum();
        if (sum - 1.0).abs() > 1e-6 {
            return Err(SklearsError::InvalidInput(
                "Prior probabilities must sum to 1".to_string(),
            ));
        }
        self.prior_probabilities = Some(priors);
        Ok(self)
    }

    /// Compare models using Bayesian evidence
    pub fn compare_models<F>(
        &mut self,
        model_results: &[(String, ModelEvidenceData)],
    ) -> Result<BayesianModelSelectionResult>
    where
        F: FloatTrait + ToPrimitive,
    {
        if model_results.is_empty() {
            return Err(SklearsError::InvalidInput("No models provided".to_string()));
        }

        let model_names: Vec<String> =
            model_results.iter().map(|(name, _)| name.clone()).collect();

        // Estimate log evidence for each model
        let mut log_evidence = Vec::new();
        for (_, data) in model_results {
            let log_ev = self.estimate_evidence(data)?;
            log_evidence.push(log_ev);
        }

        // Calculate model probabilities
        let model_probabilities = self.calculate_model_probabilities(&log_evidence)?;

        // Calculate Bayes factors relative to best model
        let best_log_evidence = log_evidence
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let bayes_factors: Vec<f64> = log_evidence
            .iter()
            .map(|&log_ev| (log_ev - best_log_evidence).exp())
            .collect();

        let best_model_index = log_evidence
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();

        Ok(BayesianModelSelectionResult {
            model_names,
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index,
            method: self.method.clone(),
        })
    }

    /// Estimate evidence for a single model
    fn estimate_evidence(&mut self, data: &ModelEvidenceData) -> Result<f64> {
        match &self.method {
            EvidenceEstimationMethod::LaplaceApproximation => self.laplace_approximation(data),
            EvidenceEstimationMethod::BIC => self.bic_approximation(data),
            EvidenceEstimationMethod::AICc => self.aicc_approximation(data),
            EvidenceEstimationMethod::HarmonicMean => self.harmonic_mean_estimator(data),
            EvidenceEstimationMethod::ThermodynamicIntegration { n_temperatures } => {
                self.thermodynamic_integration(data, *n_temperatures)
            }
            EvidenceEstimationMethod::NestedSampling { n_live_points } => {
                self.nested_sampling_approximation(data, *n_live_points)
            }
            EvidenceEstimationMethod::CrossValidationEvidence { n_folds } => {
                self.cross_validation_evidence(data, *n_folds)
            }
        }
    }

    /// Laplace approximation to the evidence
    fn laplace_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        // Log likelihood at MAP estimate
        let log_likelihood_map = data.max_log_likelihood;

        // Approximate Hessian determinant (assuming it's available or estimated)
        let log_det_hessian = data.hessian_log_determinant.unwrap_or_else(|| {
            // Rough approximation if Hessian not available
            n_params * (2.0 * std::f64::consts::PI).ln() + n_params * n_data.ln()
        });

        // Prior contribution (assuming flat priors for simplicity)
        let log_prior = data.log_prior.unwrap_or(0.0);

        // Laplace approximation: log Z ≈ log L(θ_MAP) + log π(θ_MAP) + (k/2) log(2π) - (1/2) log|H|
        let log_evidence =
            log_likelihood_map + log_prior + (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln()
                - 0.5 * log_det_hessian;

        Ok(log_evidence)
    }

    /// BIC approximation to the evidence
    fn bic_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let n_params = data.n_parameters as f64;
        let n_data = data.n_data_points as f64;

        // BIC = -2 * log L + k * log(n), so log Z ≈ -BIC/2 = log L - (k/2) * log(n);
        // an extra -(k/2) * log(2π) constant is included here on top of the standard BIC term
        let log_evidence = data.max_log_likelihood
            - (n_params / 2.0) * n_data.ln()
            - (n_params / 2.0) * (2.0 * std::f64::consts::PI).ln();

        Ok(log_evidence)
    }

    /// AICc approximation to the evidence
    fn aicc_approximation(&self, data: &ModelEvidenceData) -> Result<f64> {
        let k = data.n_parameters as f64;
        let n = data.n_data_points as f64;

        if n <= k + 1.0 {
            return Err(SklearsError::InvalidInput(
                "AICc requires n > k + 1".to_string(),
            ));
        }

        // AICc = -2 * log L + 2k + 2k(k+1)/(n-k-1), so log Z ≈ -AICc/2 = log L - k - k(k+1)/(n-k-1)
        let aicc_correction = 2.0 * k * (k + 1.0) / (n - k - 1.0);
        let log_evidence = data.max_log_likelihood - k - aicc_correction / 2.0;

        Ok(log_evidence)
    }

    /// Harmonic mean estimator
    fn harmonic_mean_estimator(&self, data: &ModelEvidenceData) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Harmonic mean estimator requires posterior samples".to_string(),
            ));
        }

        // Harmonic mean: 1/Z ≈ (1/N) Σ 1/L(θ_i), i.e. log Z ≈ log N - logsumexp(-log L(θ_i)).
        // Computed in log space to avoid overflow; the estimator itself is known to be
        // unreliable but is included for completeness.
        let n_samples = data.posterior_samples.len() as f64;
        let negated_log_likelihoods: Vec<f64> = data
            .posterior_samples
            .iter()
            .map(|&log_likelihood| -log_likelihood)
            .collect();

        let log_evidence = n_samples.ln() - log_sum_exp_vec(&negated_log_likelihoods);

        // Add warning about reliability
        eprintln!("Warning: Harmonic mean estimator is known to be unreliable");

        Ok(log_evidence)
    }

    /// Thermodynamic integration
    fn thermodynamic_integration(
        &mut self,
        data: &ModelEvidenceData,
        n_temperatures: usize,
    ) -> Result<f64> {
        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Thermodynamic integration requires posterior samples".to_string(),
            ));
        }

        if n_temperatures == 0 {
            return Err(SklearsError::InvalidInput(
                "Thermodynamic integration requires at least one temperature step".to_string(),
            ));
        }

        // Create temperature ladder from 0 to 1
        let temperatures: Vec<f64> = (0..=n_temperatures)
            .map(|i| i as f64 / n_temperatures as f64)
            .collect();

        // Estimate the mean log-likelihood under the power posterior at each temperature
        let mut mean_log_likelihoods = Vec::new();

        for &temp in &temperatures {
            if temp == 0.0 {
                // At temperature 0 the expectation is taken under the prior;
                // it is crudely approximated as 0 here
                mean_log_likelihoods.push(0.0);
            } else {
                // Approximate using the available posterior samples, scaled linearly in
                // temperature (a strong simplification of the true tempered expectation)
                let mean_ll = data.posterior_samples.iter().sum::<f64>()
                    / data.posterior_samples.len() as f64;
                mean_log_likelihoods.push(temp * mean_ll);
            }
        }

        // Integrate over the temperature ladder using the trapezoidal rule
        let mut integral = 0.0;
        for i in 1..temperatures.len() {
            let dt = temperatures[i] - temperatures[i - 1];
            integral += 0.5 * dt * (mean_log_likelihoods[i] + mean_log_likelihoods[i - 1]);
        }

        // Add prior contribution
        let log_prior = data.log_prior.unwrap_or(0.0);
        let log_evidence = log_prior + integral;

        Ok(log_evidence)
    }

    /// Nested sampling approximation
    fn nested_sampling_approximation(
        &mut self,
        data: &ModelEvidenceData,
        n_live_points: usize,
    ) -> Result<f64> {
        // Simplified nested sampling approximation
        // In practice, this would require implementing the full nested sampling algorithm

        if data.posterior_samples.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Nested sampling requires posterior samples".to_string(),
            ));
        }

        let n_samples = data.posterior_samples.len();
        let max_iterations = n_samples.min(1000); // Limit iterations

        // Sort samples by likelihood (lowest first, the order in which nested sampling discards them)
        let mut sorted_samples = data.posterior_samples.clone();
        sorted_samples.sort_by(|a, b| a.partial_cmp(b).unwrap());

        // Approximate evidence using nested sampling logic: the prior mass enclosed by the
        // i-th likelihood contour is roughly X_i ≈ exp(-i / n_live_points), so each shell
        // starts with log width ≈ -ln(n_live_points) and shrinks by 1/n_live_points per step
        let mut log_evidence = f64::NEG_INFINITY;
        let mut log_width = -(n_live_points as f64).ln();

        for (i, &log_likelihood) in sorted_samples.iter().enumerate() {
            if i >= max_iterations {
                break;
            }

            let log_weight = log_width + log_likelihood;
            log_evidence = log_sum_exp(log_evidence, log_weight);

            // Update width (shrinkage)
            log_width -= 1.0 / n_live_points as f64;
        }

        Ok(log_evidence)
    }

    /// Cross-validation based evidence approximation
    fn cross_validation_evidence(&self, data: &ModelEvidenceData, n_folds: usize) -> Result<f64> {
        if data.cv_log_likelihoods.is_none() {
            return Err(SklearsError::InvalidInput(
                "Cross-validation evidence requires CV log-likelihoods".to_string(),
            ));
        }

        let cv_log_likes = data.cv_log_likelihoods.as_ref().unwrap();
        if cv_log_likes.len() != n_folds {
            return Err(SklearsError::InvalidInput(
                "Number of CV scores must match number of folds".to_string(),
            ));
        }

        // Approximate evidence using cross-validation
        let mean_cv_log_likelihood = cv_log_likes.iter().sum::<f64>() / cv_log_likes.len() as f64;

        // Apply correction for finite sample effects
        let n_data = data.n_data_points as f64;
        let correction = (n_data / (n_data - 1.0)).ln();

        let log_evidence = mean_cv_log_likelihood + correction;

        Ok(log_evidence)
    }

    /// Calculate model probabilities from log evidence
    fn calculate_model_probabilities(&self, log_evidence: &[f64]) -> Result<Vec<f64>> {
        let n_models = log_evidence.len();

        // Get prior probabilities
        let log_priors = if let Some(ref priors) = self.prior_probabilities {
            if priors.len() != n_models {
                return Err(SklearsError::InvalidInput(
                    "Number of prior probabilities must match number of models".to_string(),
                ));
            }
            priors.iter().map(|&p| p.ln()).collect()
        } else {
            // Uniform priors
            vec![-(n_models as f64).ln(); n_models]
        };

        // Calculate log posterior probabilities
        let log_posteriors: Vec<f64> = log_evidence
            .iter()
            .zip(log_priors.iter())
            .map(|(&log_ev, &log_prior)| log_ev + log_prior)
            .collect();

        // Normalize using log-sum-exp trick
        let log_normalizer = log_sum_exp_vec(&log_posteriors);
        let probabilities: Vec<f64> = log_posteriors
            .iter()
            .map(|&log_p| (log_p - log_normalizer).exp())
            .collect();

        Ok(probabilities)
    }
}

/// Data required for evidence estimation
#[derive(Debug, Clone)]
pub struct ModelEvidenceData {
    /// Maximum log-likelihood achieved
    pub max_log_likelihood: f64,
    /// Number of model parameters
    pub n_parameters: usize,
    /// Number of data points
    pub n_data_points: usize,
    /// Log determinant of Hessian at MAP (for Laplace approximation)
    pub hessian_log_determinant: Option<f64>,
    /// Log prior probability at MAP
    pub log_prior: Option<f64>,
    /// Posterior samples (log-likelihoods)
    pub posterior_samples: Vec<f64>,
    /// Cross-validation log-likelihoods
    pub cv_log_likelihoods: Option<Vec<f64>>,
}

impl ModelEvidenceData {
    /// Create new evidence data with required fields
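    ///
    /// A hypothetical construction sketch (the numbers are placeholders, not from a real fit):
    ///
    /// ```ignore
    /// let evidence_data = ModelEvidenceData::new(-120.5, 4, 250)
    ///     .with_log_prior(-2.0)
    ///     .with_hessian_log_determinant(7.3)
    ///     .with_posterior_samples(vec![-121.0, -120.8, -120.6]);
    /// ```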
    pub fn new(max_log_likelihood: f64, n_parameters: usize, n_data_points: usize) -> Self {
        Self {
            max_log_likelihood,
            n_parameters,
            n_data_points,
            hessian_log_determinant: None,
            log_prior: None,
            posterior_samples: Vec::new(),
            cv_log_likelihoods: None,
        }
    }

    /// Add Hessian information for Laplace approximation
    pub fn with_hessian_log_determinant(mut self, log_det: f64) -> Self {
        self.hessian_log_determinant = Some(log_det);
        self
    }

    /// Add prior information
    pub fn with_log_prior(mut self, log_prior: f64) -> Self {
        self.log_prior = Some(log_prior);
        self
    }

    /// Add posterior samples
    pub fn with_posterior_samples(mut self, samples: Vec<f64>) -> Self {
        self.posterior_samples = samples;
        self
    }

    /// Add cross-validation log-likelihoods
    pub fn with_cv_log_likelihoods(mut self, cv_scores: Vec<f64>) -> Self {
        self.cv_log_likelihoods = Some(cv_scores);
        self
    }
}

/// Model averaging using Bayesian weights
pub struct BayesianModelAverager {
    /// Model selection result containing weights
    selection_result: BayesianModelSelectionResult,
}

impl BayesianModelAverager {
    /// Create new model averager from selection result
    pub fn new(selection_result: BayesianModelSelectionResult) -> Self {
        Self { selection_result }
    }

    /// Make prediction using Bayesian model averaging
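    ///
    /// A brief sketch (it assumes `selection_result` came from `compare_models` over two models;
    /// the prediction vectors are illustrative):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let averager = BayesianModelAverager::new(selection_result);
    /// let blended = averager.predict(&[array![1.0, 2.0], array![1.2, 1.9]])?;
    /// ```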
    pub fn predict(&self, model_predictions: &[Array1<f64>]) -> Result<Array1<f64>> {
        if model_predictions.len() != self.selection_result.model_names.len() {
            return Err(SklearsError::InvalidInput(
                "Number of predictions must match number of models".to_string(),
            ));
        }

        if model_predictions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No predictions provided".to_string(),
            ));
        }

        let n_samples = model_predictions[0].len();

        // Check all predictions have same length
        for pred in model_predictions {
            if pred.len() != n_samples {
                return Err(SklearsError::InvalidInput(
                    "All predictions must have the same length".to_string(),
                ));
            }
        }

        // Weighted average of predictions
        let mut averaged_prediction = Array1::zeros(n_samples);

        for (i, pred) in model_predictions.iter().enumerate() {
            let weight = self.selection_result.model_probabilities[i];
            averaged_prediction = averaged_prediction + pred * weight;
        }

        Ok(averaged_prediction)
    }

    /// Get model weights
    pub fn get_weights(&self) -> &[f64] {
        &self.selection_result.model_probabilities
    }

    /// Get effective number of models (entropy-based measure)
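    ///
    /// Computed as `exp(H)` where `H = -Σ p_i ln p_i` is the entropy of the model weights:
    /// it equals 1.0 when a single model carries all the weight and equals the number of
    /// models when the weights are uniform.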
    pub fn effective_number_of_models(&self) -> f64 {
        let entropy: f64 = self
            .selection_result
            .model_probabilities
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum();
        entropy.exp()
    }
}

// Utility functions
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
}

fn log_sum_exp_vec(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let sum: f64 = values.iter().map(|&x| (x - max_val).exp()).sum();
    max_val + sum.ln()
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_bic_approximation() {
        let data = ModelEvidenceData::new(-100.0, 5, 100);
        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let log_evidence = selector.estimate_evidence(&data).unwrap();
        assert!(log_evidence < 0.0); // Should be negative
    }
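
    // An added sanity check (illustrative values only): the Laplace estimate should match
    // the closed form used above, log L + log prior + (k/2) ln(2π) - 0.5 ln|H|.
    #[test]
    fn test_laplace_approximation_matches_closed_form() {
        let data = ModelEvidenceData::new(-50.0, 2, 100)
            .with_log_prior(-1.0)
            .with_hessian_log_determinant(4.0);
        let mut selector = BayesianModelSelector::new(
            EvidenceEstimationMethod::LaplaceApproximation,
            Some(42),
        );

        let log_evidence = selector.estimate_evidence(&data).unwrap();
        // (k/2) = 1 for the two parameters used here
        let expected = -50.0 - 1.0 + (2.0 * std::f64::consts::PI).ln() - 0.5 * 4.0;
        assert!((log_evidence - expected).abs() < 1e-9);
    }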

    #[test]
    fn test_model_comparison() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let result: BayesianModelSelectionResult = selector.compare_models::<f64>(&models).unwrap();

        assert_eq!(result.model_names.len(), 2);
        assert_eq!(result.best_model(), "Model1"); // Better likelihood, fewer parameters
        assert!((result.model_probabilities.iter().sum::<f64>() - 1.0).abs() < 1e-6);
    }
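
    // An added check of the prior-weight validation in `with_prior_probabilities`:
    // weights must sum to 1 within the 1e-6 tolerance used above.
    #[test]
    fn test_prior_probabilities_must_sum_to_one() {
        let selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(0));
        assert!(selector.with_prior_probabilities(vec![0.7, 0.2]).is_err());

        let selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(0));
        assert!(selector.with_prior_probabilities(vec![0.7, 0.3]).is_ok());
    }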

    #[test]
    fn test_bayesian_model_averaging() {
        let data1 = ModelEvidenceData::new(-95.0, 3, 100);
        let data2 = ModelEvidenceData::new(-105.0, 5, 100);

        let models = vec![("Model1".to_string(), data1), ("Model2".to_string(), data2)];

        let mut selector = BayesianModelSelector::new(EvidenceEstimationMethod::BIC, Some(42));

        let selection_result = selector.compare_models::<f64>(&models).unwrap();
        let averager = BayesianModelAverager::new(selection_result);

        let pred1 = array![1.0, 2.0, 3.0];
        let pred2 = array![1.1, 2.1, 3.1];
        let predictions = vec![pred1, pred2];

        let averaged = averager.predict(&predictions).unwrap();
        assert_eq!(averaged.len(), 3);

        // Check effective number of models
        let effective_n = averager.effective_number_of_models();
        assert!(effective_n >= 1.0 && effective_n <= 2.0);
    }
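
    // An added illustrative check: with identical posterior log-likelihood samples the
    // harmonic mean estimate collapses to that common value.
    #[test]
    fn test_harmonic_mean_with_constant_samples() {
        let data = ModelEvidenceData::new(-10.0, 2, 50)
            .with_posterior_samples(vec![-10.0, -10.0, -10.0]);
        let mut selector =
            BayesianModelSelector::new(EvidenceEstimationMethod::HarmonicMean, Some(7));

        let log_evidence = selector.estimate_evidence(&data).unwrap();
        assert!((log_evidence - (-10.0)).abs() < 1e-9);
    }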

    #[test]
    fn test_evidence_interpretation() {
        let log_evidence = vec![-95.0, -100.0];
        let model_probabilities = vec![0.8, 0.2];
        let bayes_factors = vec![1.0, 0.007]; // second entry ≈ exp(-100 - (-95))

        let result = BayesianModelSelectionResult {
            model_names: vec!["Model1".to_string(), "Model2".to_string()],
            log_evidence,
            model_probabilities,
            bayes_factors,
            best_model_index: 0,
            method: EvidenceEstimationMethod::BIC,
        };

        let interpretation = result.evidence_interpretation(0, 1);
        assert!(interpretation.contains("Strong"));
    }

    #[test]
    fn test_model_ranking() {
        let result = BayesianModelSelectionResult {
            model_names: vec![
                "ModelA".to_string(),
                "ModelB".to_string(),
                "ModelC".to_string(),
            ],
            log_evidence: vec![-100.0, -95.0, -98.0],
            model_probabilities: vec![0.1, 0.7, 0.2],
            bayes_factors: vec![0.007, 1.0, 0.05],
            best_model_index: 1,
            method: EvidenceEstimationMethod::BIC,
        };

        let ranking = result.model_ranking();
        assert_eq!(ranking[0].1, "ModelB"); // Best model
        assert_eq!(ranking[1].1, "ModelC"); // Second best
        assert_eq!(ranking[2].1, "ModelA"); // Worst model
    }
659}