1use scirs2_core::ndarray::{Array1, ArrayView1};
7use sklears_core::prelude::*;
8use std::collections::HashMap;
9
10fn bma_error(msg: &str) -> SklearsError {
11 SklearsError::InvalidInput(msg.to_string())
12}
13
/// Prior distribution placed over candidate models before evidence is seen.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PriorType {
    /// Equal prior probability for every model (1 / n_models).
    Uniform,
    /// Complexity-penalizing prior proportional to complexity^(-1/2).
    Jeffreys,
    /// Exponentially decaying prior in model complexity with the given rate.
    Exponential(f64),
    /// User-supplied per-model weights, installed via `with_prior_weights`.
    Custom,
}
25
/// Strategy for estimating each model's (log) evidence.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvidenceMethod {
    /// Gaussian marginal likelihood against ground truth (when available).
    MarginalLikelihood,
    /// Bayesian Information Criterion: log L - (k/2) ln n.
    BIC,
    /// Akaike Information Criterion: log L - k.
    AIC,
    /// Small-sample-corrected AIC.
    AICc,
    /// Deviance Information Criterion approximation from the train/val gap.
    DIC,
    /// WAIC approximation penalized by total predictive variance.
    WAIC,
    /// Log of the validation error (cross-validation proxy).
    CrossValidation,
    /// Log of the averaged training/validation error (bootstrap proxy).
    BootstrapEstimate,
}
45
/// Configuration for Bayesian model averaging.
#[derive(Debug, Clone)]
pub struct BMAConfig {
    /// Prior over models.
    pub prior_type: PriorType,
    /// Evidence estimation strategy.
    pub evidence_method: EvidenceMethod,
    /// Normalized weights below this value are zeroed out.
    pub min_weight_threshold: f64,
    /// Whether weights should be normalized to sum to one.
    /// NOTE(review): not consulted by the current weight computation — verify.
    pub normalize_weights: bool,
    /// Normalize in log space (max-shift trick) instead of raw exp space.
    pub use_log_space: bool,
    /// L2-style penalty on parameter count in the marginal likelihood.
    pub regularization_lambda: f64,
    /// Number of bootstrap resamples.
    /// NOTE(review): unused by the visible code — confirm against other chunks.
    pub bootstrap_samples: usize,
    /// Number of cross-validation folds.
    /// NOTE(review): unused by the visible code — confirm against other chunks.
    pub cv_folds: usize,
}
57
impl Default for BMAConfig {
    /// Sensible defaults: uniform prior, cross-validation evidence,
    /// log-space normalization, and a tiny pruning threshold.
    fn default() -> Self {
        Self {
            prior_type: PriorType::Uniform,
            evidence_method: EvidenceMethod::CrossValidation,
            min_weight_threshold: 1e-6,
            normalize_weights: true,
            use_log_space: true,
            regularization_lambda: 1e-3,
            bootstrap_samples: 100,
            cv_folds: 5,
        }
    }
}
72
/// Per-model metadata and predictions fed into the averager.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Unique identifier; also the evidence-cache key.
    pub model_id: String,
    /// Model complexity measure used by complexity-sensitive priors.
    pub complexity: usize,
    /// Accuracy on the training set, in [0, 1].
    pub training_accuracy: f64,
    /// Accuracy on the validation set, in [0, 1].
    pub validation_accuracy: f64,
    /// Stored log-likelihood; fallback evidence when no ground truth is given.
    pub log_likelihood: f64,
    /// Parameter count used by AIC/BIC-style penalties.
    pub n_parameters: usize,
    /// Point predictions; all models must agree on length.
    pub predictions: Array1<f64>,
    /// Optional per-point predictive variance (same length as `predictions`
    /// is assumed — TODO confirm, it is not validated on insertion).
    pub prediction_variance: Option<Array1<f64>>,
}
84
/// Output of a Bayesian model averaging run.
#[derive(Debug, Clone)]
pub struct BMAResult {
    /// Posterior-weighted ensemble predictions.
    pub averaged_predictions: Array1<f64>,
    /// Between-model plus within-model predictive variance, per point.
    pub prediction_variance: Array1<f64>,
    /// Final normalized weight per model id (duplicate of the posteriors).
    pub model_weights: HashMap<String, f64>,
    /// Inverse participation ratio 1 / sum(w^2) of the weights.
    pub effective_model_count: f64,
    /// Log of the total evidence across models.
    pub total_evidence: f64,
    /// Posterior probability per model id.
    pub model_posterior_probabilities: HashMap<String, f64>,
    /// exp(-MSE) against ground truth; 0.0 when no ground truth was supplied.
    pub ensemble_accuracy: f64,
}
95
/// Stateful averager: accumulates models, then computes posterior-weighted
/// ensemble predictions on demand.
pub struct BayesianModelAverager {
    /// Averaging configuration.
    config: BMAConfig,
    /// Registered models, in insertion order.
    models: Vec<ModelInfo>,
    /// Optional custom priors keyed by model id (for `PriorType::Custom`).
    prior_weights: Option<HashMap<String, f64>>,
    /// Memoized log-evidence per model id.
    evidence_cache: HashMap<String, f64>,
}
102
103impl BayesianModelAverager {
104 pub fn new(config: BMAConfig) -> Self {
105 Self {
106 config,
107 models: Vec::new(),
108 prior_weights: None,
109 evidence_cache: HashMap::new(),
110 }
111 }
112
113 pub fn with_prior_weights(mut self, weights: HashMap<String, f64>) -> Result<Self> {
114 for weight in weights.values() {
115 if *weight < 0.0 {
116 return Err(bma_error("Prior weights cannot be negative"));
117 }
118 }
119 self.prior_weights = Some(weights);
120 Ok(self)
121 }
122
123 pub fn add_model(&mut self, model: ModelInfo) -> Result<()> {
124 if !self.models.is_empty() {
125 let expected_len = self.models[0].predictions.len();
126 if model.predictions.len() != expected_len {
127 return Err(bma_error(&format!(
128 "Inconsistent prediction dimensions: expected {}, got {}",
129 expected_len,
130 model.predictions.len()
131 )));
132 }
133 }
134 self.models.push(model);
135 Ok(())
136 }
137
138 pub fn add_models(&mut self, models: Vec<ModelInfo>) -> Result<()> {
139 for model in models {
140 self.add_model(model)?;
141 }
142 Ok(())
143 }
144
145 pub fn compute_average(&mut self, y_true: Option<&ArrayView1<f64>>) -> Result<BMAResult> {
146 if self.models.is_empty() {
147 return Err(bma_error("No models provided"));
148 }
149
150 let posterior_weights = self.compute_posterior_weights(y_true)?;
151 let averaged_predictions = self.compute_weighted_predictions(&posterior_weights)?;
152 let prediction_variance =
153 self.compute_prediction_variance(&posterior_weights, &averaged_predictions)?;
154 let effective_model_count = self.compute_effective_model_count(&posterior_weights);
155 let total_evidence = self.compute_total_evidence(y_true)?;
156
157 let ensemble_accuracy = if let Some(y_true) = y_true {
158 self.compute_ensemble_accuracy(&averaged_predictions, y_true)
159 } else {
160 0.0
161 };
162
163 let model_posterior_probabilities: HashMap<String, f64> = self
164 .models
165 .iter()
166 .zip(posterior_weights.iter())
167 .map(|(model, &weight)| (model.model_id.clone(), weight))
168 .collect();
169
170 let model_weights = model_posterior_probabilities.clone();
171
172 Ok(BMAResult {
173 averaged_predictions,
174 prediction_variance,
175 model_weights,
176 effective_model_count,
177 total_evidence,
178 model_posterior_probabilities,
179 ensemble_accuracy,
180 })
181 }
182
183 fn compute_posterior_weights(&mut self, y_true: Option<&ArrayView1<f64>>) -> Result<Vec<f64>> {
184 let n_models = self.models.len();
185 let mut log_posteriors = vec![0.0; n_models];
186
187 let models: Vec<_> = self.models.to_vec();
189
190 for (i, model) in models.iter().enumerate() {
191 let log_prior = self.compute_log_prior(model)?;
192 let log_evidence = self.compute_log_evidence(model, y_true)?;
193
194 log_posteriors[i] = log_prior + log_evidence;
195 }
196
197 if self.config.use_log_space {
198 self.normalize_log_weights(&mut log_posteriors)
199 } else {
200 let posteriors: Vec<f64> = log_posteriors.iter().map(|&lp| lp.exp()).collect();
201 self.normalize_weights(&posteriors)
202 }
203 }
204
205 fn compute_log_prior(&self, model: &ModelInfo) -> Result<f64> {
206 match self.config.prior_type {
207 PriorType::Uniform => Ok(-(self.models.len() as f64).ln()),
208 PriorType::Jeffreys => {
209 let complexity = model.complexity as f64;
210 Ok(-0.5 * complexity.ln())
211 }
212 PriorType::Exponential(lambda) => {
213 let complexity = model.complexity as f64;
214 Ok(lambda.ln() - lambda * complexity)
215 }
216 PriorType::Custom => {
217 if let Some(ref prior_weights) = self.prior_weights {
218 if let Some(&weight) = prior_weights.get(&model.model_id) {
219 Ok(weight.ln())
220 } else {
221 Ok(-(self.models.len() as f64).ln())
222 }
223 } else {
224 Err(bma_error("Invalid prior specification"))
225 }
226 }
227 }
228 }
229
230 fn compute_log_evidence(
231 &mut self,
232 model: &ModelInfo,
233 y_true: Option<&ArrayView1<f64>>,
234 ) -> Result<f64> {
235 if let Some(cached_evidence) = self.evidence_cache.get(&model.model_id) {
236 return Ok(*cached_evidence);
237 }
238
239 let log_evidence = match self.config.evidence_method {
240 EvidenceMethod::MarginalLikelihood => {
241 if let Some(y_true) = y_true {
242 self.compute_marginal_likelihood(model, y_true)?
243 } else {
244 model.log_likelihood
245 }
246 }
247 EvidenceMethod::BIC => {
248 let n = model.predictions.len() as f64;
249 let k = model.n_parameters as f64;
250 model.log_likelihood - 0.5 * k * n.ln()
251 }
252 EvidenceMethod::AIC => {
253 let k = model.n_parameters as f64;
254 model.log_likelihood - k
255 }
256 EvidenceMethod::AICc => {
257 let n = model.predictions.len() as f64;
258 let k = model.n_parameters as f64;
259 let aic = model.log_likelihood - k;
260 let correction = (2.0 * k * (k + 1.0)) / (n - k - 1.0);
261 aic - correction
262 }
263 EvidenceMethod::DIC => {
264 let deviance = -2.0 * model.log_likelihood;
265 let p_dic = 2.0 * (model.training_accuracy - model.validation_accuracy).abs();
266 -(deviance + p_dic)
267 }
268 EvidenceMethod::WAIC => {
269 if let Some(ref var) = model.prediction_variance {
270 let lppd = model.log_likelihood;
271 let p_waic = var.sum();
272 lppd - p_waic
273 } else {
274 model.log_likelihood
275 }
276 }
277 EvidenceMethod::CrossValidation => -self.compute_cv_error(model)?,
278 EvidenceMethod::BootstrapEstimate => -self.compute_bootstrap_error(model)?,
279 };
280
281 self.evidence_cache
282 .insert(model.model_id.clone(), log_evidence);
283 Ok(log_evidence)
284 }
285
286 fn compute_marginal_likelihood(
287 &self,
288 model: &ModelInfo,
289 y_true: &ArrayView1<f64>,
290 ) -> Result<f64> {
291 let mut log_likelihood = 0.0;
292 let n = y_true.len();
293
294 for i in 0..n {
295 let residual = y_true[i] - model.predictions[i];
296 let variance = model
297 .prediction_variance
298 .as_ref()
299 .map(|v| v[i])
300 .unwrap_or(1.0);
301
302 if variance <= 0.0 {
303 return Err(bma_error("Numerical instability in posterior computation"));
304 }
305
306 log_likelihood += -0.5
307 * (residual.powi(2) / variance + variance.ln() + (2.0 * std::f64::consts::PI).ln());
308 }
309
310 let regularization = -0.5 * self.config.regularization_lambda * model.n_parameters as f64;
311 Ok(log_likelihood + regularization)
312 }
313
314 fn compute_cv_error(&self, model: &ModelInfo) -> Result<f64> {
315 let validation_error = 1.0 - model.validation_accuracy;
316 Ok(validation_error.max(1e-10).ln())
317 }
318
319 fn compute_bootstrap_error(&self, model: &ModelInfo) -> Result<f64> {
320 let training_error = 1.0 - model.training_accuracy;
321 let validation_error = 1.0 - model.validation_accuracy;
322 let bootstrap_error = (training_error + validation_error) / 2.0;
323 Ok(bootstrap_error.max(1e-10).ln())
324 }
325
326 fn normalize_log_weights(&self, log_weights: &mut [f64]) -> Result<Vec<f64>> {
327 let max_log_weight = log_weights
328 .iter()
329 .cloned()
330 .fold(f64::NEG_INFINITY, f64::max);
331
332 if max_log_weight.is_infinite() {
333 return Err(bma_error("Numerical instability in posterior computation"));
334 }
335
336 for w in log_weights.iter_mut() {
337 *w -= max_log_weight;
338 }
339
340 let weights: Vec<f64> = log_weights.iter().map(|&lw| lw.exp()).collect();
341 self.normalize_weights(&weights)
342 }
343
344 fn normalize_weights(&self, weights: &[f64]) -> Result<Vec<f64>> {
345 let sum: f64 = weights.iter().sum();
346
347 if sum == 0.0 || !sum.is_finite() {
348 return Err(bma_error("Numerical instability in posterior computation"));
349 }
350
351 let normalized: Vec<f64> = weights
352 .iter()
353 .map(|&w| w / sum)
354 .map(|w| {
355 if w < self.config.min_weight_threshold {
356 0.0
357 } else {
358 w
359 }
360 })
361 .collect();
362
363 let final_sum: f64 = normalized.iter().sum();
364 if final_sum == 0.0 {
365 return Err(bma_error("Numerical instability in posterior computation"));
366 }
367
368 Ok(normalized.iter().map(|&w| w / final_sum).collect())
369 }
370
371 fn compute_weighted_predictions(&self, weights: &[f64]) -> Result<Array1<f64>> {
372 if self.models.is_empty() {
373 return Err(bma_error("No models provided"));
374 }
375
376 let n_predictions = self.models[0].predictions.len();
377 let mut averaged = Array1::zeros(n_predictions);
378
379 for (weight, model) in weights.iter().zip(self.models.iter()) {
380 averaged = averaged + *weight * &model.predictions;
381 }
382
383 Ok(averaged)
384 }
385
386 fn compute_prediction_variance(
387 &self,
388 weights: &[f64],
389 averaged_predictions: &Array1<f64>,
390 ) -> Result<Array1<f64>> {
391 let n_predictions = averaged_predictions.len();
392 let mut variance = Array1::zeros(n_predictions);
393
394 for i in 0..n_predictions {
395 let mut prediction_var = 0.0;
396 let mut model_var = 0.0;
397
398 for (weight, model) in weights.iter().zip(self.models.iter()) {
399 let diff = model.predictions[i] - averaged_predictions[i];
400 prediction_var += weight * diff.powi(2);
401
402 if let Some(ref var) = model.prediction_variance {
403 model_var += weight * var[i];
404 }
405 }
406
407 variance[i] = prediction_var + model_var;
408 }
409
410 Ok(variance)
411 }
412
413 fn compute_effective_model_count(&self, weights: &[f64]) -> f64 {
414 let sum_squares: f64 = weights.iter().map(|w| w.powi(2)).sum();
415 if sum_squares > 0.0 {
416 1.0 / sum_squares
417 } else {
418 0.0
419 }
420 }
421
422 fn compute_total_evidence(&self, _y_true: Option<&ArrayView1<f64>>) -> Result<f64> {
423 let mut total_evidence = 0.0;
424
425 for model in &self.models {
426 let evidence = self
427 .evidence_cache
428 .get(&model.model_id)
429 .copied()
430 .unwrap_or(model.log_likelihood);
431 total_evidence += evidence.exp();
432 }
433
434 Ok(total_evidence.ln())
435 }
436
437 fn compute_ensemble_accuracy(
438 &self,
439 predictions: &Array1<f64>,
440 y_true: &ArrayView1<f64>,
441 ) -> f64 {
442 let mse: f64 = predictions
443 .iter()
444 .zip(y_true.iter())
445 .map(|(pred, true_val)| (pred - true_val).powi(2))
446 .sum::<f64>()
447 / predictions.len() as f64;
448
449 (-mse).exp()
450 }
451
452 pub fn get_model_rankings(&self, result: &BMAResult) -> Vec<(String, f64)> {
453 let mut rankings: Vec<_> = result
454 .model_weights
455 .iter()
456 .map(|(id, &weight)| (id.clone(), weight))
457 .collect();
458 rankings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
459 rankings
460 }
461
462 pub fn prune_models(&mut self, min_weight: f64) -> usize {
463 let weights_result = self.compute_posterior_weights(None);
464 if let Ok(weights) = weights_result {
465 let indices_to_keep: Vec<usize> = weights
466 .iter()
467 .enumerate()
468 .filter(|(_, &w)| w >= min_weight)
469 .map(|(i, _)| i)
470 .collect();
471
472 let mut new_models = Vec::new();
473 for &idx in &indices_to_keep {
474 new_models.push(self.models[idx].clone());
475 }
476
477 let pruned_count = self.models.len() - new_models.len();
478 self.models = new_models;
479 self.evidence_cache.clear();
480
481 pruned_count
482 } else {
483 0
484 }
485 }
486}
487
488pub fn bayesian_model_average(
489 models: Vec<ModelInfo>,
490 y_true: Option<&ArrayView1<f64>>,
491 config: Option<BMAConfig>,
492) -> Result<BMAResult> {
493 let config = config.unwrap_or_default();
494 let mut averager = BayesianModelAverager::new(config);
495 averager.add_models(models)?;
496 averager.compute_average(y_true)
497}
498
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Three toy models with matching 5-point prediction vectors and
    /// per-point variances, spanning a range of complexities/accuracies.
    fn create_test_models() -> Vec<ModelInfo> {
        vec![
            ModelInfo {
                model_id: "model1".to_string(),
                complexity: 10,
                training_accuracy: 0.85,
                validation_accuracy: 0.80,
                log_likelihood: -100.0,
                n_parameters: 10,
                predictions: arr1(&[0.8, 0.6, 0.9, 0.7, 0.5]),
                prediction_variance: Some(arr1(&[0.01, 0.02, 0.01, 0.03, 0.02])),
            },
            ModelInfo {
                model_id: "model2".to_string(),
                complexity: 15,
                training_accuracy: 0.90,
                validation_accuracy: 0.82,
                log_likelihood: -95.0,
                n_parameters: 15,
                predictions: arr1(&[0.9, 0.7, 0.8, 0.8, 0.6]),
                prediction_variance: Some(arr1(&[0.02, 0.01, 0.02, 0.01, 0.03])),
            },
            ModelInfo {
                model_id: "model3".to_string(),
                complexity: 5,
                training_accuracy: 0.75,
                validation_accuracy: 0.78,
                log_likelihood: -110.0,
                n_parameters: 5,
                predictions: arr1(&[0.7, 0.8, 0.7, 0.6, 0.7]),
                prediction_variance: Some(arr1(&[0.03, 0.02, 0.03, 0.02, 0.01])),
            },
        ]
    }

    // Default config, no ground truth: result shapes and summary stats.
    #[test]
    fn test_basic_bma() {
        let models = create_test_models();
        let config = BMAConfig::default();
        let result = bayesian_model_average(models, None, Some(config)).unwrap();

        assert_eq!(result.averaged_predictions.len(), 5);
        assert_eq!(result.prediction_variance.len(), 5);
        assert_eq!(result.model_weights.len(), 3);
        assert!(result.effective_model_count > 0.0);
        assert!(result.total_evidence.is_finite());
    }

    // With ground truth, ensemble accuracy must land in (0, 1].
    #[test]
    fn test_bma_with_ground_truth() {
        let models = create_test_models();
        let y_true = arr1(&[0.8, 0.7, 0.8, 0.7, 0.6]);
        let config = BMAConfig::default();

        let result = bayesian_model_average(models, Some(&y_true.view()), Some(config)).unwrap();

        assert_eq!(result.averaged_predictions.len(), 5);
        assert!(result.ensemble_accuracy > 0.0);
        assert!(result.ensemble_accuracy <= 1.0);
    }

    // Uniform prior + BIC evidence: all models retain strictly positive weight.
    #[test]
    fn test_uniform_prior() {
        let models = create_test_models();
        let config = BMAConfig {
            prior_type: PriorType::Uniform,
            evidence_method: EvidenceMethod::BIC,
            ..Default::default()
        };

        let result = bayesian_model_average(models, None, Some(config)).unwrap();
        assert!(result.model_weights.values().all(|&w| w > 0.0));
    }

    // Jeffreys prior + AIC: weights are valid (non-negative) probabilities.
    #[test]
    fn test_jeffreys_prior() {
        let models = create_test_models();
        let config = BMAConfig {
            prior_type: PriorType::Jeffreys,
            evidence_method: EvidenceMethod::AIC,
            ..Default::default()
        };

        let result = bayesian_model_average(models, None, Some(config)).unwrap();
        assert!(result.model_weights.values().all(|&w| w >= 0.0));
    }

    // Pruning never removes more models than exist.
    #[test]
    fn test_model_pruning() {
        let models = create_test_models();
        let config = BMAConfig::default();
        let mut averager = BayesianModelAverager::new(config);
        averager.add_models(models).unwrap();

        let initial_count = averager.models.len();
        let pruned = averager.prune_models(0.1);

        assert!(pruned <= initial_count);
        assert!(averager.models.len() <= initial_count);
    }

    // Mismatched prediction lengths must be rejected at registration.
    #[test]
    fn test_inconsistent_dimensions() {
        let mut models = create_test_models();
        models[1].predictions = arr1(&[0.9, 0.7, 0.8]);

        let result = bayesian_model_average(models, None, None);
        assert!(result.is_err());
    }

    // Averaging over zero models is an error.
    #[test]
    fn test_empty_models() {
        let models = Vec::new();
        let result = bayesian_model_average(models, None, None);
        assert!(result.is_err());
    }

    // Custom prior weights flow through to valid posterior weights.
    #[test]
    fn test_custom_prior() {
        let models = create_test_models();
        let mut prior_weights = HashMap::new();
        prior_weights.insert("model1".to_string(), 0.5);
        prior_weights.insert("model2".to_string(), 0.3);
        prior_weights.insert("model3".to_string(), 0.2);

        let config = BMAConfig {
            prior_type: PriorType::Custom,
            ..Default::default()
        };

        let mut averager = BayesianModelAverager::new(config);
        averager = averager.with_prior_weights(prior_weights).unwrap();
        averager.add_models(models).unwrap();

        let result = averager.compute_average(None).unwrap();
        assert!(result.model_weights.values().all(|&w| w >= 0.0));
    }
}
642}