sklears_model_selection/
bayesian_model_averaging.rs

//! Bayesian Model Averaging (BMA) implementation
//!
//! This module provides Bayesian Model Averaging functionality for combining
//! multiple models' predictions weighted by their posterior probabilities.
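//!
//! A minimal usage sketch; the types and functions are the ones defined in
//! this module, the model id and values are illustrative, and the tests at
//! the bottom of the file show more complete examples:
//!
//! ```ignore
//! use scirs2_core::ndarray::arr1;
//!
//! let model = ModelInfo {
//!     model_id: "model_a".to_string(),
//!     complexity: 8,
//!     training_accuracy: 0.90,
//!     validation_accuracy: 0.85,
//!     log_likelihood: -120.0,
//!     n_parameters: 8,
//!     predictions: arr1(&[0.4, 0.6, 0.7]),
//!     prediction_variance: None,
//! };
//!
//! let mut averager = BayesianModelAverager::new(BMAConfig::default());
//! averager.add_model(model).unwrap();
//! let result = averager.compute_average(None).unwrap();
//! assert_eq!(result.averaged_predictions.len(), 3);
//! ```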

use scirs2_core::ndarray::{Array1, ArrayView1};
use sklears_core::prelude::*;
use std::collections::HashMap;

/// Build the `InvalidInput` error used for validation failures in this module.
fn bma_error(msg: &str) -> SklearsError {
    SklearsError::InvalidInput(msg.to_string())
}

/// Prior distribution placed over the candidate models.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PriorType {
    /// Equal prior probability for every model
    Uniform,
    /// Complexity-penalizing prior proportional to 1 / sqrt(complexity)
    Jeffreys,
    /// Exponential prior over model complexity with the given rate
    Exponential(f64),
    /// User-supplied prior weights (see `with_prior_weights`)
    Custom,
}

/// Approximation used for each model's (log) evidence.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvidenceMethod {
    /// Gaussian marginal likelihood of the predictions (requires ground truth)
    MarginalLikelihood,
    /// Bayesian Information Criterion approximation
    BIC,
    /// Akaike Information Criterion approximation
    AIC,
    /// Small-sample corrected AIC
    AICc,
    /// Deviance Information Criterion approximation
    DIC,
    /// Widely Applicable Information Criterion approximation
    WAIC,
    /// Cross-validation error proxy based on the stored validation accuracy
    CrossValidation,
    /// Bootstrap-style error proxy based on the stored training/validation accuracies
    BootstrapEstimate,
}

/// Configuration options for [`BayesianModelAverager`].
#[derive(Debug, Clone)]
pub struct BMAConfig {
    /// Prior placed over the candidate models
    pub prior_type: PriorType,
    /// How each model's evidence is approximated
    pub evidence_method: EvidenceMethod,
    /// Posterior weights below this threshold are truncated to zero
    pub min_weight_threshold: f64,
    /// Normalize the final weights so they sum to one
    pub normalize_weights: bool,
    /// Normalize posterior weights in log space for numerical stability
    pub use_log_space: bool,
    /// Complexity penalty used by the marginal-likelihood evidence
    pub regularization_lambda: f64,
    /// Number of bootstrap resamples for bootstrap-based evidence
    pub bootstrap_samples: usize,
    /// Number of cross-validation folds for CV-based evidence
    pub cv_folds: usize,
}

impl Default for BMAConfig {
    fn default() -> Self {
        Self {
            prior_type: PriorType::Uniform,
            evidence_method: EvidenceMethod::CrossValidation,
            min_weight_threshold: 1e-6,
            normalize_weights: true,
            use_log_space: true,
            regularization_lambda: 1e-3,
            bootstrap_samples: 100,
            cv_folds: 5,
        }
    }
}

/// Metadata and predictions for a single candidate model.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Identifier used to key weights and cached evidence
    pub model_id: String,
    /// Complexity measure used by complexity-based priors
    pub complexity: usize,
    /// Accuracy on the training data
    pub training_accuracy: f64,
    /// Accuracy on held-out validation data
    pub validation_accuracy: f64,
    /// Log-likelihood of the fitted model
    pub log_likelihood: f64,
    /// Number of free parameters (used by information criteria)
    pub n_parameters: usize,
    /// Predictions on the evaluation set
    pub predictions: Array1<f64>,
    /// Optional per-prediction variance estimates
    pub prediction_variance: Option<Array1<f64>>,
}

/// Output of [`BayesianModelAverager::compute_average`].
#[derive(Debug, Clone)]
pub struct BMAResult {
    /// Posterior-weighted average of the model predictions
    pub averaged_predictions: Array1<f64>,
    /// Combined between-model and within-model prediction variance
    pub prediction_variance: Array1<f64>,
    /// Final normalized weight of each model, keyed by model id
    pub model_weights: HashMap<String, f64>,
    /// Effective number of contributing models (1 / sum of squared weights)
    pub effective_model_count: f64,
    /// Log of the summed model evidences
    pub total_evidence: f64,
    /// Posterior probability of each model, keyed by model id
    pub model_posterior_probabilities: HashMap<String, f64>,
    /// exp(-MSE) pseudo-accuracy of the averaged predictions (0.0 without targets)
    pub ensemble_accuracy: f64,
}

/// Combines model predictions weighted by their posterior probabilities.
pub struct BayesianModelAverager {
    config: BMAConfig,
    models: Vec<ModelInfo>,
    prior_weights: Option<HashMap<String, f64>>,
    evidence_cache: HashMap<String, f64>,
}

impl BayesianModelAverager {
    /// Create an averager with the given configuration and no registered models.
    pub fn new(config: BMAConfig) -> Self {
        Self {
            config,
            models: Vec::new(),
            prior_weights: None,
            evidence_cache: HashMap::new(),
        }
    }

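    /// Attach custom prior weights keyed by model id, for use with
    /// [`PriorType::Custom`]. Negative weights are rejected.
    ///
    /// A minimal sketch; the ids and values are illustrative:
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut priors = HashMap::new();
    /// priors.insert("model1".to_string(), 0.7);
    /// priors.insert("model2".to_string(), 0.3);
    ///
    /// let config = BMAConfig {
    ///     prior_type: PriorType::Custom,
    ///     ..Default::default()
    /// };
    /// let averager = BayesianModelAverager::new(config)
    ///     .with_prior_weights(priors)
    ///     .unwrap();
    /// ```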
    pub fn with_prior_weights(mut self, weights: HashMap<String, f64>) -> Result<Self> {
        for weight in weights.values() {
            if *weight < 0.0 {
                return Err(bma_error("Prior weights cannot be negative"));
            }
        }
        self.prior_weights = Some(weights);
        Ok(self)
    }

    /// Register a model, checking that its predictions have the same length
    /// as those of previously added models.
    pub fn add_model(&mut self, model: ModelInfo) -> Result<()> {
        if !self.models.is_empty() {
            let expected_len = self.models[0].predictions.len();
            if model.predictions.len() != expected_len {
                return Err(bma_error(&format!(
                    "Inconsistent prediction dimensions: expected {}, got {}",
                    expected_len,
                    model.predictions.len()
                )));
            }
        }
        self.models.push(model);
        Ok(())
    }

    /// Register several models at once.
    pub fn add_models(&mut self, models: Vec<ModelInfo>) -> Result<()> {
        for model in models {
            self.add_model(model)?;
        }
        Ok(())
    }

    /// Compute the Bayesian model average over all registered models.
    ///
    /// If ground-truth targets are supplied they are used by evidence methods
    /// that need them and for the reported ensemble accuracy; otherwise the
    /// ensemble accuracy is reported as 0.0.
    pub fn compute_average(&mut self, y_true: Option<&ArrayView1<f64>>) -> Result<BMAResult> {
        if self.models.is_empty() {
            return Err(bma_error("No models provided"));
        }

        let posterior_weights = self.compute_posterior_weights(y_true)?;
        let averaged_predictions = self.compute_weighted_predictions(&posterior_weights)?;
        let prediction_variance =
            self.compute_prediction_variance(&posterior_weights, &averaged_predictions)?;
        let effective_model_count = self.compute_effective_model_count(&posterior_weights);
        let total_evidence = self.compute_total_evidence(y_true)?;

        let ensemble_accuracy = if let Some(y_true) = y_true {
            self.compute_ensemble_accuracy(&averaged_predictions, y_true)
        } else {
            0.0
        };

        let model_posterior_probabilities: HashMap<String, f64> = self
            .models
            .iter()
            .zip(posterior_weights.iter())
            .map(|(model, &weight)| (model.model_id.clone(), weight))
            .collect();

        let model_weights = model_posterior_probabilities.clone();

        Ok(BMAResult {
            averaged_predictions,
            prediction_variance,
            model_weights,
            effective_model_count,
            total_evidence,
            model_posterior_probabilities,
            ensemble_accuracy,
        })
    }

    fn compute_posterior_weights(&mut self, y_true: Option<&ArrayView1<f64>>) -> Result<Vec<f64>> {
        let n_models = self.models.len();
        let mut log_posteriors = vec![0.0; n_models];

        // Clone the model list so the `&mut self` evidence computation below
        // does not conflict with iterating over `self.models`.
        let models: Vec<_> = self.models.to_vec();

        for (i, model) in models.iter().enumerate() {
            let log_prior = self.compute_log_prior(model)?;
            let log_evidence = self.compute_log_evidence(model, y_true)?;

            log_posteriors[i] = log_prior + log_evidence;
        }

        if self.config.use_log_space {
            self.normalize_log_weights(&mut log_posteriors)
        } else {
            let posteriors: Vec<f64> = log_posteriors.iter().map(|&lp| lp.exp()).collect();
            self.normalize_weights(&posteriors)
        }
    }

    /// Log prior probability of `model` under the configured prior.
    fn compute_log_prior(&self, model: &ModelInfo) -> Result<f64> {
        match self.config.prior_type {
            PriorType::Uniform => Ok(-(self.models.len() as f64).ln()),
            PriorType::Jeffreys => {
                // Prior proportional to 1 / sqrt(complexity).
                let complexity = model.complexity as f64;
                Ok(-0.5 * complexity.ln())
            }
            PriorType::Exponential(lambda) => {
                // Exponential density over complexity: lambda * exp(-lambda * c).
                let complexity = model.complexity as f64;
                Ok(lambda.ln() - lambda * complexity)
            }
            PriorType::Custom => {
                if let Some(ref prior_weights) = self.prior_weights {
                    if let Some(&weight) = prior_weights.get(&model.model_id) {
                        Ok(weight.ln())
                    } else {
                        // Fall back to a uniform prior for models without a custom weight.
                        Ok(-(self.models.len() as f64).ln())
                    }
                } else {
                    Err(bma_error("Invalid prior specification"))
                }
            }
        }
    }

    /// Approximate log evidence for `model`, cached per model id.
    fn compute_log_evidence(
        &mut self,
        model: &ModelInfo,
        y_true: Option<&ArrayView1<f64>>,
    ) -> Result<f64> {
        if let Some(cached_evidence) = self.evidence_cache.get(&model.model_id) {
            return Ok(*cached_evidence);
        }

        let log_evidence = match self.config.evidence_method {
            EvidenceMethod::MarginalLikelihood => {
                if let Some(y_true) = y_true {
                    self.compute_marginal_likelihood(model, y_true)?
                } else {
                    // Without targets, fall back to the stored log-likelihood.
                    model.log_likelihood
                }
            }
            EvidenceMethod::BIC => {
                // log evidence approx. -BIC / 2 = ln L - (k / 2) ln n
                let n = model.predictions.len() as f64;
                let k = model.n_parameters as f64;
                model.log_likelihood - 0.5 * k * n.ln()
            }
            EvidenceMethod::AIC => {
                // log evidence approx. -AIC / 2 = ln L - k
                let k = model.n_parameters as f64;
                model.log_likelihood - k
            }
            EvidenceMethod::AICc => {
                // Same half-deviance scale as the AIC arm above, plus the
                // small-sample correction k (k + 1) / (n - k - 1).
                let n = model.predictions.len() as f64;
                let k = model.n_parameters as f64;
                let aic = model.log_likelihood - k;
                let correction = (k * (k + 1.0)) / (n - k - 1.0);
                aic - correction
            }
            EvidenceMethod::DIC => {
                // Effective parameter count proxied by the train/validation gap.
                let deviance = -2.0 * model.log_likelihood;
                let p_dic = 2.0 * (model.training_accuracy - model.validation_accuracy).abs();
                -(deviance + p_dic)
            }
            EvidenceMethod::WAIC => {
                if let Some(ref var) = model.prediction_variance {
                    // Rough WAIC: stored log-likelihood as lppd, summed
                    // prediction variance as the effective-parameter penalty.
                    let lppd = model.log_likelihood;
                    let p_waic = var.sum();
                    lppd - p_waic
                } else {
                    model.log_likelihood
                }
            }
            EvidenceMethod::CrossValidation => -self.compute_cv_error(model)?,
            EvidenceMethod::BootstrapEstimate => -self.compute_bootstrap_error(model)?,
        };

        self.evidence_cache
            .insert(model.model_id.clone(), log_evidence);
        Ok(log_evidence)
    }

    /// Gaussian log marginal likelihood of the predictions given `y_true`,
    /// with a complexity penalty controlled by `regularization_lambda`.
    fn compute_marginal_likelihood(
        &self,
        model: &ModelInfo,
        y_true: &ArrayView1<f64>,
    ) -> Result<f64> {
        let mut log_likelihood = 0.0;
        let n = y_true.len();

        for i in 0..n {
            let residual = y_true[i] - model.predictions[i];
            // Use the model's per-prediction variance when available, else unit variance.
            let variance = model
                .prediction_variance
                .as_ref()
                .map(|v| v[i])
                .unwrap_or(1.0);

            if variance <= 0.0 {
                return Err(bma_error("Numerical instability in posterior computation"));
            }

            // Gaussian log density of the residual.
            log_likelihood += -0.5
                * (residual.powi(2) / variance + variance.ln() + (2.0 * std::f64::consts::PI).ln());
        }

        let regularization = -0.5 * self.config.regularization_lambda * model.n_parameters as f64;
        Ok(log_likelihood + regularization)
    }

    /// Log of the validation error, used as a cross-validation error proxy.
    fn compute_cv_error(&self, model: &ModelInfo) -> Result<f64> {
        let validation_error = 1.0 - model.validation_accuracy;
        Ok(validation_error.max(1e-10).ln())
    }

    /// Log of the averaged training/validation error, used as a bootstrap proxy.
    fn compute_bootstrap_error(&self, model: &ModelInfo) -> Result<f64> {
        let training_error = 1.0 - model.training_accuracy;
        let validation_error = 1.0 - model.validation_accuracy;
        let bootstrap_error = (training_error + validation_error) / 2.0;
        Ok(bootstrap_error.max(1e-10).ln())
    }

    fn normalize_log_weights(&self, log_weights: &mut [f64]) -> Result<Vec<f64>> {
        let max_log_weight = log_weights
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);

        if max_log_weight.is_infinite() {
            return Err(bma_error("Numerical instability in posterior computation"));
        }

        // Subtract the maximum before exponentiating (log-sum-exp trick) to
        // avoid overflow and underflow.
        for w in log_weights.iter_mut() {
            *w -= max_log_weight;
        }

        let weights: Vec<f64> = log_weights.iter().map(|&lw| lw.exp()).collect();
        self.normalize_weights(&weights)
    }

    fn normalize_weights(&self, weights: &[f64]) -> Result<Vec<f64>> {
        let sum: f64 = weights.iter().sum();

        if sum == 0.0 || !sum.is_finite() {
            return Err(bma_error("Numerical instability in posterior computation"));
        }

        // Normalize, then truncate weights below the configured threshold to zero.
        let normalized: Vec<f64> = weights
            .iter()
            .map(|&w| w / sum)
            .map(|w| {
                if w < self.config.min_weight_threshold {
                    0.0
                } else {
                    w
                }
            })
            .collect();

        // Renormalize so the surviving weights sum to one.
        let final_sum: f64 = normalized.iter().sum();
        if final_sum == 0.0 {
            return Err(bma_error("Numerical instability in posterior computation"));
        }

        Ok(normalized.iter().map(|&w| w / final_sum).collect())
    }

    fn compute_weighted_predictions(&self, weights: &[f64]) -> Result<Array1<f64>> {
        if self.models.is_empty() {
            return Err(bma_error("No models provided"));
        }

        let n_predictions = self.models[0].predictions.len();
        let mut averaged = Array1::zeros(n_predictions);

        for (weight, model) in weights.iter().zip(self.models.iter()) {
            averaged = averaged + *weight * &model.predictions;
        }

        Ok(averaged)
    }

    /// Total predictive variance: between-model disagreement plus the
    /// weighted within-model variances (when provided).
    fn compute_prediction_variance(
        &self,
        weights: &[f64],
        averaged_predictions: &Array1<f64>,
    ) -> Result<Array1<f64>> {
        let n_predictions = averaged_predictions.len();
        let mut variance = Array1::zeros(n_predictions);

        for i in 0..n_predictions {
            let mut prediction_var = 0.0;
            let mut model_var = 0.0;

            for (weight, model) in weights.iter().zip(self.models.iter()) {
                let diff = model.predictions[i] - averaged_predictions[i];
                prediction_var += weight * diff.powi(2);

                if let Some(ref var) = model.prediction_variance {
                    model_var += weight * var[i];
                }
            }

            variance[i] = prediction_var + model_var;
        }

        Ok(variance)
    }

    /// Effective number of contributing models, 1 / sum(w_i^2).
    fn compute_effective_model_count(&self, weights: &[f64]) -> f64 {
        let sum_squares: f64 = weights.iter().map(|w| w.powi(2)).sum();
        if sum_squares > 0.0 {
            1.0 / sum_squares
        } else {
            0.0
        }
    }

    /// Log of the summed (cached) model evidences; models without cached
    /// evidence fall back to their stored log-likelihood.
    fn compute_total_evidence(&self, _y_true: Option<&ArrayView1<f64>>) -> Result<f64> {
        let mut total_evidence = 0.0;

        for model in &self.models {
            let evidence = self
                .evidence_cache
                .get(&model.model_id)
                .copied()
                .unwrap_or(model.log_likelihood);
            total_evidence += evidence.exp();
        }

        Ok(total_evidence.ln())
    }

    /// Map the ensemble's mean squared error to a (0, 1] score via exp(-MSE).
    fn compute_ensemble_accuracy(
        &self,
        predictions: &Array1<f64>,
        y_true: &ArrayView1<f64>,
    ) -> f64 {
        let mse: f64 = predictions
            .iter()
            .zip(y_true.iter())
            .map(|(pred, true_val)| (pred - true_val).powi(2))
            .sum::<f64>()
            / predictions.len() as f64;

        (-mse).exp()
    }

    /// Model ids and weights from `result`, sorted by descending weight.
    pub fn get_model_rankings(&self, result: &BMAResult) -> Vec<(String, f64)> {
        let mut rankings: Vec<_> = result
            .model_weights
            .iter()
            .map(|(id, &weight)| (id.clone(), weight))
            .collect();
        rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        rankings
    }

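    /// Drop models whose posterior weight (computed without ground truth)
    /// falls below `min_weight`; returns how many models were removed and
    /// clears the evidence cache.
    ///
    /// A minimal sketch, assuming `models` was built elsewhere:
    ///
    /// ```ignore
    /// let mut averager = BayesianModelAverager::new(BMAConfig::default());
    /// averager.add_models(models).unwrap();
    /// let removed = averager.prune_models(0.05);
    /// println!("pruned {removed} low-weight models");
    /// ```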
    pub fn prune_models(&mut self, min_weight: f64) -> usize {
        let weights_result = self.compute_posterior_weights(None);
        if let Ok(weights) = weights_result {
            let indices_to_keep: Vec<usize> = weights
                .iter()
                .enumerate()
                .filter(|(_, &w)| w >= min_weight)
                .map(|(i, _)| i)
                .collect();

            let mut new_models = Vec::new();
            for &idx in &indices_to_keep {
                new_models.push(self.models[idx].clone());
            }

            let pruned_count = self.models.len() - new_models.len();
            self.models = new_models;
            self.evidence_cache.clear();

            pruned_count
        } else {
            0
        }
    }
}

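/// Convenience wrapper: build a [`BayesianModelAverager`] from `config` (or
/// the default configuration), register `models`, and compute the average.
///
/// A minimal sketch; `models` and `targets` stand in for data built as in the
/// tests below:
///
/// ```ignore
/// let result = bayesian_model_average(models, Some(&targets.view()), None).unwrap();
/// for (id, weight) in &result.model_weights {
///     println!("{id}: {weight:.3}");
/// }
/// ```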
pub fn bayesian_model_average(
    models: Vec<ModelInfo>,
    y_true: Option<&ArrayView1<f64>>,
    config: Option<BMAConfig>,
) -> Result<BMAResult> {
    let config = config.unwrap_or_default();
    let mut averager = BayesianModelAverager::new(config);
    averager.add_models(models)?;
    averager.compute_average(y_true)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    fn create_test_models() -> Vec<ModelInfo> {
        vec![
            ModelInfo {
                model_id: "model1".to_string(),
                complexity: 10,
                training_accuracy: 0.85,
                validation_accuracy: 0.80,
                log_likelihood: -100.0,
                n_parameters: 10,
                predictions: arr1(&[0.8, 0.6, 0.9, 0.7, 0.5]),
                prediction_variance: Some(arr1(&[0.01, 0.02, 0.01, 0.03, 0.02])),
            },
            ModelInfo {
                model_id: "model2".to_string(),
                complexity: 15,
                training_accuracy: 0.90,
                validation_accuracy: 0.82,
                log_likelihood: -95.0,
                n_parameters: 15,
                predictions: arr1(&[0.9, 0.7, 0.8, 0.8, 0.6]),
                prediction_variance: Some(arr1(&[0.02, 0.01, 0.02, 0.01, 0.03])),
            },
            ModelInfo {
                model_id: "model3".to_string(),
                complexity: 5,
                training_accuracy: 0.75,
                validation_accuracy: 0.78,
                log_likelihood: -110.0,
                n_parameters: 5,
                predictions: arr1(&[0.7, 0.8, 0.7, 0.6, 0.7]),
                prediction_variance: Some(arr1(&[0.03, 0.02, 0.03, 0.02, 0.01])),
            },
        ]
    }

    #[test]
    fn test_basic_bma() {
        let models = create_test_models();
        let config = BMAConfig::default();
        let result = bayesian_model_average(models, None, Some(config)).unwrap();

        assert_eq!(result.averaged_predictions.len(), 5);
        assert_eq!(result.prediction_variance.len(), 5);
        assert_eq!(result.model_weights.len(), 3);
        assert!(result.effective_model_count > 0.0);
        assert!(result.total_evidence.is_finite());
    }

    #[test]
    fn test_bma_with_ground_truth() {
        let models = create_test_models();
        let y_true = arr1(&[0.8, 0.7, 0.8, 0.7, 0.6]);
        let config = BMAConfig::default();

        let result = bayesian_model_average(models, Some(&y_true.view()), Some(config)).unwrap();

        assert_eq!(result.averaged_predictions.len(), 5);
        assert!(result.ensemble_accuracy > 0.0);
        assert!(result.ensemble_accuracy <= 1.0);
    }

    #[test]
    fn test_uniform_prior() {
        let models = create_test_models();
        let config = BMAConfig {
            prior_type: PriorType::Uniform,
            evidence_method: EvidenceMethod::BIC,
            ..Default::default()
        };

        let result = bayesian_model_average(models, None, Some(config)).unwrap();
        assert!(result.model_weights.values().all(|&w| w > 0.0));
    }

    #[test]
    fn test_jeffreys_prior() {
        let models = create_test_models();
        let config = BMAConfig {
            prior_type: PriorType::Jeffreys,
            evidence_method: EvidenceMethod::AIC,
            ..Default::default()
        };

        let result = bayesian_model_average(models, None, Some(config)).unwrap();
        assert!(result.model_weights.values().all(|&w| w >= 0.0));
    }

    #[test]
    fn test_model_pruning() {
        let models = create_test_models();
        let config = BMAConfig::default();
        let mut averager = BayesianModelAverager::new(config);
        averager.add_models(models).unwrap();

        let initial_count = averager.models.len();
        let pruned = averager.prune_models(0.1);

        assert!(pruned <= initial_count);
        assert!(averager.models.len() <= initial_count);
    }

    #[test]
    fn test_inconsistent_dimensions() {
        let mut models = create_test_models();
        models[1].predictions = arr1(&[0.9, 0.7, 0.8]);

        let result = bayesian_model_average(models, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_models() {
        let models = Vec::new();
        let result = bayesian_model_average(models, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_custom_prior() {
        let models = create_test_models();
        let mut prior_weights = HashMap::new();
        prior_weights.insert("model1".to_string(), 0.5);
        prior_weights.insert("model2".to_string(), 0.3);
        prior_weights.insert("model3".to_string(), 0.2);

        let config = BMAConfig {
            prior_type: PriorType::Custom,
            ..Default::default()
        };

        let mut averager = BayesianModelAverager::new(config);
        averager = averager.with_prior_weights(prior_weights).unwrap();
        averager.add_models(models).unwrap();

        let result = averager.compute_average(None).unwrap();
        assert!(result.model_weights.values().all(|&w| w >= 0.0));
    }
}