sklears_feature_selection/
bayesian.rs

1//! Bayesian feature selection algorithms
2//!
3//! This module provides Bayesian approaches to feature selection, including
4//! spike-and-slab priors, Bayesian model averaging, and variational inference.
5
6use crate::base::{FeatureSelector, SelectorMixin};
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use sklears_core::{
9    error::{validate, Result as SklResult, SklearsError},
10    traits::{Estimator, Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
15/// Prior type for Bayesian feature selection
16#[derive(Debug, Clone)]
17pub enum PriorType {
18    /// Spike-and-slab prior with given spike and slab variances
19    SpikeAndSlab { spike_var: Float, slab_var: Float },
20    /// Horseshoe prior for sparse feature selection
21    Horseshoe { tau: Float },
22    /// Laplace prior (equivalent to L1 regularization)
23    Laplace { scale: Float },
24    /// Independent normal priors for features
25    Normal { var: Float },
26}
27
28/// Inference method for Bayesian feature selection
29#[derive(Debug, Clone)]
30pub enum BayesianInferenceMethod {
31    /// Variational Bayes with mean-field approximation
32    VariationalBayes { max_iter: usize, tol: Float },
33    /// Gibbs sampling MCMC
34    GibbsSampling { n_samples: usize, burn_in: usize },
35    /// Expectation-Maximization algorithm
36    ExpectationMaximization { max_iter: usize, tol: Float },
37    /// Laplace approximation for posterior
38    LaplaceApproximation,
39}
40
41/// Bayesian variable selection with spike-and-slab priors
42#[derive(Debug, Clone)]
43pub struct BayesianVariableSelector<State = Untrained> {
44    prior: PriorType,
45    inference: BayesianInferenceMethod,
46    n_features_select: Option<usize>,
47    inclusion_threshold: Float,
48    random_state: Option<u64>,
49    state: PhantomData<State>,
50    // Trained state
51    posterior_inclusion_probs_: Option<Array1<Float>>,
52    feature_coefficients_: Option<Array1<Float>>,
53    selected_features_: Option<Vec<usize>>,
54    n_features_: Option<usize>,
55    evidence_: Option<Float>,
56}
57
58impl Default for BayesianVariableSelector<Untrained> {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl BayesianVariableSelector<Untrained> {
65    /// Create a new Bayesian variable selector
66    pub fn new() -> Self {
67        Self {
68            prior: PriorType::SpikeAndSlab {
69                spike_var: 0.01,
70                slab_var: 1.0,
71            },
72            inference: BayesianInferenceMethod::VariationalBayes {
73                max_iter: 100,
74                tol: 1e-4,
75            },
76            n_features_select: None,
77            inclusion_threshold: 0.5,
78            random_state: None,
79            state: PhantomData,
80            posterior_inclusion_probs_: None,
81            feature_coefficients_: None,
82            selected_features_: None,
83            n_features_: None,
84            evidence_: None,
85        }
86    }
87
88    /// Set the prior type
89    pub fn prior(mut self, prior: PriorType) -> Self {
90        self.prior = prior;
91        self
92    }
93
94    /// Set the inference method
95    pub fn inference(mut self, inference: BayesianInferenceMethod) -> Self {
96        self.inference = inference;
97        self
98    }
99
100    /// Set the number of features to select (if None, use threshold)
101    pub fn n_features_select(mut self, n_features: usize) -> Self {
102        self.n_features_select = Some(n_features);
103        self
104    }
105
106    /// Set the inclusion probability threshold
107    pub fn inclusion_threshold(mut self, threshold: Float) -> Self {
108        self.inclusion_threshold = threshold;
109        self
110    }
111
112    /// Set the random state for reproducibility
113    pub fn random_state(mut self, seed: u64) -> Self {
114        self.random_state = Some(seed);
115        self
116    }
117
118    /// Perform Bayesian inference for feature selection
119    fn fit_bayesian(
120        &self,
121        features: &Array2<Float>,
122        target: &Array1<Float>,
123    ) -> SklResult<(Array1<Float>, Array1<Float>, Float)> {
124        match &self.inference {
125            BayesianInferenceMethod::VariationalBayes { max_iter, tol } => {
126                self.variational_bayes_inference(features, target, *max_iter, *tol)
127            }
128            BayesianInferenceMethod::GibbsSampling { n_samples, burn_in } => {
129                self.gibbs_sampling_inference(features, target, *n_samples, *burn_in)
130            }
131            BayesianInferenceMethod::ExpectationMaximization { max_iter, tol } => {
132                self.em_inference(features, target, *max_iter, *tol)
133            }
134            BayesianInferenceMethod::LaplaceApproximation => {
135                self.laplace_approximation_inference(features, target)
136            }
137        }
138    }
139
140    /// Variational Bayes inference with mean-field approximation
141    fn variational_bayes_inference(
142        &self,
143        features: &Array2<Float>,
144        target: &Array1<Float>,
145        max_iter: usize,
146        tol: Float,
147    ) -> SklResult<(Array1<Float>, Array1<Float>, Float)> {
148        let n_features = features.ncols();
149        let n_samples = features.nrows();
150
151        // Initialize variational parameters
152        let mut gamma = Array1::from_elem(n_features, 0.5); // Inclusion probabilities
153        let mut mu = Array1::zeros(n_features); // Mean of coefficients
154        let mut sigma2 = Array1::from_elem(n_features, 1.0); // Variance of coefficients
155
156        let (spike_var, slab_var) = match &self.prior {
157            PriorType::SpikeAndSlab {
158                spike_var,
159                slab_var,
160            } => (*spike_var, *slab_var),
161            _ => (0.01, 1.0), // Default values
162        };
163
164        for _iter in 0..max_iter {
165            let gamma_old = gamma.clone();
166
167            // Update coefficient parameters
168            for j in 0..n_features {
169                let feature_col = features.column(j);
170
171                // Compute residual excluding feature j
172                let mut residual = target.clone();
173                for k in 0..n_features {
174                    if k != j {
175                        let feature_k = features.column(k);
176                        for i in 0..n_samples {
177                            residual[i] -= gamma[k] * mu[k] * feature_k[i];
178                        }
179                    }
180                }
181
182                // Update mean and variance
183                let feature_norm = feature_col.dot(&feature_col);
184                let xy = feature_col.dot(&residual);
185
186                // Handle degenerate case where feature is all zeros
187                if feature_norm < 1e-10 {
188                    // For zero features, use uninformative prior
189                    gamma[j] = 0.5; // Neutral inclusion probability
190                    mu[j] = 0.0; // Zero coefficient
191                    sigma2[j] = slab_var; // Default variance
192                    continue;
193                }
194
195                let precision_spike = 1.0 / spike_var + feature_norm;
196                let precision_slab = 1.0 / slab_var + feature_norm;
197
198                let mu_spike = xy / precision_spike;
199                let mu_slab = xy / precision_slab;
200
201                let sigma2_spike = 1.0 / precision_spike;
202                let sigma2_slab = 1.0 / precision_slab;
203
204                // Update inclusion probability using Bayes rule
205                let log_prob_spike =
206                    -0.5 * (mu_spike * mu_spike / sigma2_spike + (sigma2_spike / spike_var).ln());
207                let log_prob_slab =
208                    -0.5 * (mu_slab * mu_slab / sigma2_slab + (sigma2_slab / slab_var).ln());
209
210                let max_log = log_prob_spike.max(log_prob_slab);
211                let exp_spike = (log_prob_spike - max_log).exp();
212                let exp_slab = (log_prob_slab - max_log).exp();
213
214                let denom = exp_spike + exp_slab;
215                if denom > 1e-10 {
216                    gamma[j] = exp_slab / denom;
217                } else {
218                    gamma[j] = 0.5; // Fallback for numerical issues
219                }
220
221                // Ensure gamma is in valid range [0, 1]
222                gamma[j] = gamma[j].clamp(0.0, 1.0);
223
224                mu[j] = gamma[j] * mu_slab + (1.0 - gamma[j]) * mu_spike;
225                sigma2[j] = gamma[j] * (sigma2_slab + mu_slab * mu_slab)
226                    + (1.0 - gamma[j]) * (sigma2_spike + mu_spike * mu_spike)
227                    - mu[j] * mu[j];
228            }
229
230            // Check convergence
231            let diff = (&gamma - &gamma_old).mapv(|x| x.abs()).sum();
232            if diff < tol {
233                break;
234            }
235        }
236
237        // Compute evidence (approximate)
238        let evidence = self.compute_evidence(features, target, &gamma, &mu);
239
240        Ok((gamma, mu, evidence))
241    }
242
243    /// Gibbs sampling MCMC inference
244    fn gibbs_sampling_inference(
245        &self,
246        features: &Array2<Float>,
247        target: &Array1<Float>,
248        n_samples: usize,
249        burn_in: usize,
250    ) -> SklResult<(Array1<Float>, Array1<Float>, Float)> {
251        let n_features = features.ncols();
252        let n_obs = features.nrows();
253
254        // Initialize parameters
255        let mut gamma = Array1::from_elem(n_features, 0.5);
256        let mut coefficients = Array1::zeros(n_features);
257
258        // Storage for samples
259        let mut gamma_samples = Array2::zeros((n_samples, n_features));
260        let mut coeff_samples = Array2::zeros((n_samples, n_features));
261
262        let (_spike_var, slab_var) = match &self.prior {
263            PriorType::SpikeAndSlab {
264                spike_var,
265                slab_var,
266            } => (*spike_var, *slab_var),
267            _ => (0.01, 1.0),
268        };
269
270        // Simple pseudo-random number generation (for demonstration)
271        let mut rng_state = self.random_state.unwrap_or(42);
272
273        for sample_idx in 0..(n_samples + burn_in) {
274            // Sample each feature indicator and coefficient
275            for j in 0..n_features {
276                let feature_col = features.column(j);
277
278                // Compute residual excluding feature j
279                let mut residual = target.clone();
280                for k in 0..n_features {
281                    if k != j && gamma[k] > 0.5 {
282                        let feature_k = features.column(k);
283                        for i in 0..n_obs {
284                            residual[i] -= coefficients[k] * feature_k[i];
285                        }
286                    }
287                }
288
289                // Sample inclusion indicator
290                let xy = feature_col.dot(&residual);
291                let xx = feature_col.dot(&feature_col);
292
293                let log_prob_in = -0.5 * xx * slab_var / (1.0 + xx * slab_var)
294                    * (xy / (1.0 + xx * slab_var)).powi(2);
295                let log_prob_out = 0.0;
296
297                let prob_in = 1.0 / (1.0 + (log_prob_out - log_prob_in).exp());
298
299                // Simple random number (for demonstration - would use proper RNG in practice)
300                rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
301                let u = (rng_state as Float) / (u32::MAX as Float);
302
303                gamma[j] = if u < prob_in { 1.0 } else { 0.0 };
304
305                // Sample coefficient if included
306                if gamma[j] > 0.5 {
307                    let var_post = 1.0 / (1.0 / slab_var + xx);
308                    let mean_post = var_post * xy;
309
310                    // Simple normal approximation (would use proper sampling in practice)
311                    coefficients[j] = mean_post;
312                } else {
313                    coefficients[j] = 0.0;
314                }
315            }
316
317            // Store samples after burn-in
318            if sample_idx >= burn_in {
319                let store_idx = sample_idx - burn_in;
320                for j in 0..n_features {
321                    gamma_samples[[store_idx, j]] = gamma[j];
322                    coeff_samples[[store_idx, j]] = coefficients[j];
323                }
324            }
325        }
326
327        // Compute posterior inclusion probabilities and coefficient estimates
328        let inclusion_probs = gamma_samples.mean_axis(Axis(0)).unwrap();
329        let coeff_estimates = coeff_samples.mean_axis(Axis(0)).unwrap();
330        let evidence = 0.0; // Would compute marginal likelihood from samples
331
332        Ok((inclusion_probs, coeff_estimates, evidence))
333    }
334
335    /// Expectation-Maximization inference
336    fn em_inference(
337        &self,
338        features: &Array2<Float>,
339        target: &Array1<Float>,
340        max_iter: usize,
341        tol: Float,
342    ) -> SklResult<(Array1<Float>, Array1<Float>, Float)> {
343        let n_features = features.ncols();
344
345        // Initialize parameters
346        let mut inclusion_probs = Array1::from_elem(n_features, 0.5);
347        let mut coefficients = Array1::zeros(n_features);
348        let mut noise_var = 1.0;
349
350        let (spike_var, slab_var) = match &self.prior {
351            PriorType::SpikeAndSlab {
352                spike_var,
353                slab_var,
354            } => (*spike_var, *slab_var),
355            _ => (0.01, 1.0),
356        };
357
358        for _iter in 0..max_iter {
359            let inclusion_probs_old = inclusion_probs.clone();
360
361            // E-step: Update posterior inclusion probabilities
362            for j in 0..n_features {
363                let feature_col = features.column(j);
364
365                // Compute likelihood under both models
366                let mut residual = target.clone();
367                for k in 0..n_features {
368                    if k != j {
369                        let feature_k = features.column(k);
370                        for i in 0..features.nrows() {
371                            residual[i] -= inclusion_probs[k] * coefficients[k] * feature_k[i];
372                        }
373                    }
374                }
375
376                let xy = feature_col.dot(&residual);
377                let xx = feature_col.dot(&feature_col);
378
379                // Handle degenerate case where feature is all zeros
380                if xx < 1e-10 {
381                    // For zero features, use uninformative prior
382                    inclusion_probs[j] = 0.5; // Neutral inclusion probability
383                    coefficients[j] = 0.0; // Zero coefficient
384                    continue;
385                }
386
387                // Bayes factor for inclusion vs exclusion
388                let precision_in = 1.0 / slab_var + xx / noise_var;
389                let precision_out = 1.0 / spike_var + xx / noise_var;
390
391                let mean_in = (xy / noise_var) / precision_in;
392                let mean_out = (xy / noise_var) / precision_out;
393
394                let log_bf = 0.5 * (precision_out / precision_in).ln()
395                    + 0.5
396                        * (mean_in * mean_in * precision_in - mean_out * mean_out * precision_out);
397
398                let prob = 1.0 / (1.0 + (-log_bf).exp());
399                inclusion_probs[j] = prob.clamp(0.0, 1.0); // Ensure valid probability
400                coefficients[j] = inclusion_probs[j] * mean_in;
401            }
402
403            // M-step: Update noise variance
404            let mut sse = 0.0;
405            for i in 0..features.nrows() {
406                let mut pred = 0.0;
407                for j in 0..n_features {
408                    pred += inclusion_probs[j] * coefficients[j] * features[[i, j]];
409                }
410                sse += (target[i] - pred).powi(2);
411            }
412            // Add minimum variance constraint to prevent numerical instability
413            // when data is all zeros or nearly zero
414            noise_var = (sse / features.nrows() as Float).max(1e-10);
415
416            // Check convergence
417            let diff = (&inclusion_probs - &inclusion_probs_old)
418                .mapv(|x| x.abs())
419                .sum();
420            if diff < tol {
421                break;
422            }
423        }
424
425        let evidence = self.compute_evidence(features, target, &inclusion_probs, &coefficients);
426        Ok((inclusion_probs, coefficients, evidence))
427    }
428
429    /// Laplace approximation inference
430    fn laplace_approximation_inference(
431        &self,
432        features: &Array2<Float>,
433        target: &Array1<Float>,
434    ) -> SklResult<(Array1<Float>, Array1<Float>, Float)> {
435        let n_features = features.ncols();
436
437        // For Laplace approximation, we find the MAP estimate and approximate the posterior
438        let mut coefficients = Array1::zeros(n_features);
439        let mut inclusion_probs = Array1::from_elem(n_features, 0.5);
440
441        // Simple optimization for MAP estimate (would use proper optimization in practice)
442        let xtx = features.t().dot(features);
443        let xty = features.t().dot(target);
444
445        // Ridge regression solution as approximation
446        let lambda = 0.1; // Regularization parameter
447        let mut xtx_reg = xtx.clone();
448        for i in 0..n_features {
449            xtx_reg[[i, i]] += lambda;
450        }
451
452        // Solve linear system (simplified - would use proper linear algebra)
453        for j in 0..n_features {
454            if xtx_reg[[j, j]] != 0.0 {
455                coefficients[j] = xty[j] / xtx_reg[[j, j]];
456            }
457        }
458
459        // Compute inclusion probabilities based on coefficient magnitudes
460        let coeff_threshold = coefficients.mapv(|x| x.abs()).mean().unwrap_or(0.0);
461        for j in 0..n_features {
462            inclusion_probs[j] = if coefficients[j].abs() > coeff_threshold {
463                0.8
464            } else {
465                0.2
466            };
467        }
468
469        let evidence = self.compute_evidence(features, target, &inclusion_probs, &coefficients);
470        Ok((inclusion_probs, coefficients, evidence))
471    }
472
473    /// Compute model evidence (marginal likelihood)
474    fn compute_evidence(
475        &self,
476        features: &Array2<Float>,
477        target: &Array1<Float>,
478        inclusion_probs: &Array1<Float>,
479        coefficients: &Array1<Float>,
480    ) -> Float {
481        let n_samples = features.nrows() as Float;
482        let mut sse = 0.0;
483
484        for i in 0..features.nrows() {
485            let mut pred = 0.0;
486            for j in 0..features.ncols() {
487                pred += inclusion_probs[j] * coefficients[j] * features[[i, j]];
488            }
489            sse += (target[i] - pred).powi(2);
490        }
491
492        // Simplified BIC approximation
493        let k = inclusion_probs.sum(); // Effective number of parameters
494
495        // Handle edge case where sse is 0 or very small to avoid log(0)
496        let log_likelihood = if sse < 1e-10 {
497            // Perfect fit or near-perfect fit case
498            -1e-10_f64.ln() // Use a small positive value instead of 0
499        } else {
500            -(sse / n_samples).ln()
501        };
502
503        -0.5 * n_samples * log_likelihood - 0.5 * k * n_samples.ln()
504    }
505
506    /// Select features based on inclusion probabilities
507    fn select_features_from_probabilities(&self, inclusion_probs: &Array1<Float>) -> Vec<usize> {
508        if let Some(n_select) = self.n_features_select {
509            // Select top n features by inclusion probability
510            let mut indices: Vec<usize> = (0..inclusion_probs.len()).collect();
511            indices.sort_by(|&a, &b| {
512                inclusion_probs[b]
513                    .partial_cmp(&inclusion_probs[a])
514                    .unwrap_or(std::cmp::Ordering::Equal)
515            });
516            indices.truncate(n_select);
517            indices
518        } else {
519            // Select features above threshold
520            inclusion_probs
521                .iter()
522                .enumerate()
523                .filter(|(_, &prob)| prob >= self.inclusion_threshold)
524                .map(|(idx, _)| idx)
525                .collect()
526        }
527    }
528}
529
530impl Estimator for BayesianVariableSelector<Untrained> {
531    type Config = ();
532    type Error = SklearsError;
533    type Float = Float;
534
535    fn config(&self) -> &Self::Config {
536        &()
537    }
538}
539
540impl Fit<Array2<Float>, Array1<Float>> for BayesianVariableSelector<Untrained> {
541    type Fitted = BayesianVariableSelector<Trained>;
542
543    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
544        validate::check_consistent_length(features, target)?;
545
546        let n_features = features.ncols();
547        if n_features == 0 {
548            return Err(SklearsError::InvalidInput(
549                "No features provided".to_string(),
550            ));
551        }
552
553        let (inclusion_probs, coefficients, evidence) = self.fit_bayesian(features, target)?;
554        let selected_features = self.select_features_from_probabilities(&inclusion_probs);
555
556        if selected_features.is_empty() {
557            return Err(SklearsError::InvalidInput(
558                "No features selected with current threshold".to_string(),
559            ));
560        }
561
562        Ok(BayesianVariableSelector {
563            prior: self.prior,
564            inference: self.inference,
565            n_features_select: self.n_features_select,
566            inclusion_threshold: self.inclusion_threshold,
567            random_state: self.random_state,
568            state: PhantomData,
569            posterior_inclusion_probs_: Some(inclusion_probs),
570            feature_coefficients_: Some(coefficients),
571            selected_features_: Some(selected_features),
572            n_features_: Some(n_features),
573            evidence_: Some(evidence),
574        })
575    }
576}
577
578impl Transform<Array2<Float>> for BayesianVariableSelector<Trained> {
579    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
580        validate::check_n_features(x, self.n_features_.unwrap())?;
581
582        let selected_features = self.selected_features_.as_ref().unwrap();
583        let n_samples = x.nrows();
584        let n_selected = selected_features.len();
585        let mut x_new = Array2::zeros((n_samples, n_selected));
586
587        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
588            x_new.column_mut(new_idx).assign(&x.column(old_idx));
589        }
590
591        Ok(x_new)
592    }
593}
594
595impl SelectorMixin for BayesianVariableSelector<Trained> {
596    fn get_support(&self) -> SklResult<Array1<bool>> {
597        let n_features = self.n_features_.unwrap();
598        let selected_features = self.selected_features_.as_ref().unwrap();
599        let mut support = Array1::from_elem(n_features, false);
600
601        for &idx in selected_features {
602            support[idx] = true;
603        }
604
605        Ok(support)
606    }
607
608    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
609        let selected_features = self.selected_features_.as_ref().unwrap();
610        Ok(indices
611            .iter()
612            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
613            .collect())
614    }
615}
616
617impl FeatureSelector for BayesianVariableSelector<Trained> {
618    fn selected_features(&self) -> &Vec<usize> {
619        self.selected_features_.as_ref().unwrap()
620    }
621}
622
623impl BayesianVariableSelector<Trained> {
624    /// Get posterior inclusion probabilities
625    pub fn inclusion_probabilities(&self) -> &Array1<Float> {
626        self.posterior_inclusion_probs_.as_ref().unwrap()
627    }
628
629    /// Get feature coefficients
630    pub fn coefficients(&self) -> &Array1<Float> {
631        self.feature_coefficients_.as_ref().unwrap()
632    }
633
634    /// Get model evidence (log marginal likelihood)
635    pub fn evidence(&self) -> Float {
636        self.evidence_.unwrap()
637    }
638
639    /// Get the number of selected features
640    pub fn n_features_out(&self) -> usize {
641        self.selected_features_.as_ref().unwrap().len()
642    }
643
644    /// Check if a feature was selected
645    pub fn is_feature_selected(&self, feature_idx: usize) -> bool {
646        self.selected_features_
647            .as_ref()
648            .unwrap()
649            .contains(&feature_idx)
650    }
651}
652
653/// Bayesian Model Averaging for feature selection
654#[derive(Debug, Clone)]
655pub struct BayesianModelAveraging<State = Untrained> {
656    max_models: usize,
657    prior_inclusion_prob: Float,
658    inference_method: BayesianInferenceMethod,
659    random_state: Option<u64>,
660    state: PhantomData<State>,
661    // Trained state
662    model_probabilities_: Option<Vec<Float>>,
663    model_features_: Option<Vec<Vec<usize>>>,
664    averaged_inclusion_probs_: Option<Array1<Float>>,
665    selected_features_: Option<Vec<usize>>,
666    n_features_: Option<usize>,
667}
668
669impl Default for BayesianModelAveraging<Untrained> {
670    fn default() -> Self {
671        Self::new()
672    }
673}
674
675impl BayesianModelAveraging<Untrained> {
676    pub fn new() -> Self {
677        Self {
678            max_models: 1000,
679            prior_inclusion_prob: 0.5,
680            inference_method: BayesianInferenceMethod::VariationalBayes {
681                max_iter: 50,
682                tol: 1e-3,
683            },
684            random_state: None,
685            state: PhantomData,
686            model_probabilities_: None,
687            model_features_: None,
688            averaged_inclusion_probs_: None,
689            selected_features_: None,
690            n_features_: None,
691        }
692    }
693
694    pub fn max_models(mut self, max_models: usize) -> Self {
695        self.max_models = max_models;
696        self
697    }
698
699    pub fn prior_inclusion_prob(mut self, prob: Float) -> Self {
700        self.prior_inclusion_prob = prob;
701        self
702    }
703
704    pub fn inference_method(mut self, method: BayesianInferenceMethod) -> Self {
705        self.inference_method = method;
706        self
707    }
708
709    pub fn random_state(mut self, seed: u64) -> Self {
710        self.random_state = Some(seed);
711        self
712    }
713
714    /// Enumerate and evaluate models for Bayesian model averaging
715    fn enumerate_models(
716        &self,
717        features: &Array2<Float>,
718        target: &Array1<Float>,
719    ) -> SklResult<(Vec<Vec<usize>>, Vec<Float>, Array1<Float>)> {
720        let n_features = features.ncols();
721        let mut models = Vec::new();
722        let mut model_probs = Vec::new();
723
724        // For demonstration, use a subset of all possible models
725        // In practice, would use more sophisticated model enumeration
726        let max_features_per_model = (n_features / 3).clamp(1, 10);
727
728        // Simple random model generation
729        let mut rng_state = self.random_state.unwrap_or(42);
730
731        for _ in 0..self.max_models.min(1000) {
732            let mut model_features = Vec::new();
733
734            // Randomly select features for this model
735            for j in 0..n_features {
736                rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
737                let u = (rng_state as Float) / (u32::MAX as Float);
738
739                if u < self.prior_inclusion_prob && model_features.len() < max_features_per_model {
740                    model_features.push(j);
741                }
742            }
743
744            if model_features.is_empty() {
745                model_features.push(0); // Ensure at least one feature
746            }
747
748            // Evaluate model
749            let model_evidence = self.evaluate_model(features, target, &model_features)?;
750
751            models.push(model_features);
752            model_probs.push(model_evidence);
753        }
754
755        // Normalize model probabilities with numerical stability
756        if model_probs.is_empty() {
757            return Err(SklearsError::InvalidInput(
758                "No models generated".to_string(),
759            ));
760        }
761
762        let max_evidence = model_probs
763            .iter()
764            .cloned()
765            .fold(f64::NEG_INFINITY, f64::max);
766
767        // If all evidences are negative infinity, assign uniform probabilities
768        if !max_evidence.is_finite() {
769            let uniform_prob = 1.0 / model_probs.len() as Float;
770            model_probs.fill(uniform_prob);
771        } else {
772            let mut total_prob = 0.0;
773            for prob in &mut model_probs {
774                *prob = (*prob - max_evidence).exp();
775                total_prob += *prob;
776            }
777
778            // Ensure total_prob is positive and finite
779            if total_prob <= 0.0 || !total_prob.is_finite() {
780                let uniform_prob = 1.0 / model_probs.len() as Float;
781                model_probs.fill(uniform_prob);
782            } else {
783                for prob in &mut model_probs {
784                    *prob /= total_prob;
785                    // Ensure non-negative probabilities
786                    *prob = prob.max(0.0);
787                }
788
789                // Renormalize to ensure sum = 1 after clamping
790                let new_total: Float = model_probs.iter().sum();
791                if new_total > 0.0 {
792                    for prob in &mut model_probs {
793                        *prob /= new_total;
794                    }
795                }
796            }
797        }
798
799        // Compute averaged inclusion probabilities
800        let mut inclusion_probs = Array1::<Float>::zeros(n_features);
801        for (model, &prob) in models.iter().zip(model_probs.iter()) {
802            for &feature in model {
803                inclusion_probs[feature] += prob;
804            }
805        }
806
807        // Ensure inclusion probabilities are properly bounded [0, 1]
808        for prob in inclusion_probs.iter_mut() {
809            *prob = prob.clamp(0.0, 1.0);
810        }
811
812        Ok((models, model_probs, inclusion_probs))
813    }
814
815    /// Evaluate a single model using marginal likelihood
816    fn evaluate_model(
817        &self,
818        features: &Array2<Float>,
819        target: &Array1<Float>,
820        model_features: &[usize],
821    ) -> SklResult<Float> {
822        if model_features.is_empty() {
823            return Ok(f64::NEG_INFINITY);
824        }
825
826        // Extract features for this model
827        let mut model_x = Array2::zeros((features.nrows(), model_features.len()));
828        for (new_idx, &old_idx) in model_features.iter().enumerate() {
829            model_x
830                .column_mut(new_idx)
831                .assign(&features.column(old_idx));
832        }
833
834        // Compute marginal likelihood (simplified)
835        let n = features.nrows() as Float;
836        let k = model_features.len() as Float;
837
838        // Check for degenerate cases
839        let target_var = target.var(0.0);
840        if target_var < 1e-12 {
841            // If target has zero variance, all models are equally poor
842            return Ok(-1000.0 - k * 10.0); // Penalize complexity when no signal
843        }
844
845        // Check if all features are zero
846        let feature_norm: Float = model_x.iter().map(|&x| x * x).sum();
847        if feature_norm < 1e-12 {
848            // If features are all zero, they have no predictive power
849            return Ok(-1000.0 - k * 10.0);
850        }
851
852        // Bayesian linear regression marginal likelihood approximation
853        let xtx = model_x.t().dot(&model_x);
854        let _xty = model_x.t().dot(target);
855
856        // Add regularization for numerical stability
857        let mut xtx_reg = xtx.clone();
858        for i in 0..model_features.len() {
859            xtx_reg[[i, i]] += 1e-6;
860        }
861
862        // Compute SSE for this model (simplified)
863        let sse = target.dot(target);
864        let sse_normalized = (sse / n).max(1e-12); // Prevent log(0)
865
866        // Simplified BIC-like score with numerical stability
867        let log_likelihood = -0.5 * n * sse_normalized.ln();
868        let penalty = -0.5 * k * n.ln();
869
870        let evidence = log_likelihood + penalty;
871
872        // Ensure finite result
873        if evidence.is_finite() {
874            Ok(evidence)
875        } else {
876            Ok(-1000.0 - k * 10.0)
877        }
878    }
879}
880
881impl Estimator for BayesianModelAveraging<Untrained> {
882    type Config = ();
883    type Error = SklearsError;
884    type Float = Float;
885
886    fn config(&self) -> &Self::Config {
887        &()
888    }
889}
890
891impl Fit<Array2<Float>, Array1<Float>> for BayesianModelAveraging<Untrained> {
892    type Fitted = BayesianModelAveraging<Trained>;
893
894    fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
895        validate::check_consistent_length(features, target)?;
896
897        let n_features = features.ncols();
898        if n_features == 0 {
899            return Err(SklearsError::InvalidInput(
900                "No features provided".to_string(),
901            ));
902        }
903
904        let (models, model_probs, inclusion_probs) = self.enumerate_models(features, target)?;
905
906        // Select features with highest averaged inclusion probabilities
907        let threshold = 0.5; // Could be made configurable
908        let selected_features: Vec<usize> = inclusion_probs
909            .iter()
910            .enumerate()
911            .filter(|(_, &prob)| prob >= threshold)
912            .map(|(idx, _)| idx)
913            .collect();
914
915        if selected_features.is_empty() {
916            // Fallback: select top features
917            let mut indices: Vec<usize> = (0..n_features).collect();
918            indices.sort_by(|&a, &b| {
919                inclusion_probs[b]
920                    .partial_cmp(&inclusion_probs[a])
921                    .unwrap_or(std::cmp::Ordering::Equal)
922            });
923            let selected_features = indices.into_iter().take(1).collect();
924
925            return Ok(BayesianModelAveraging {
926                max_models: self.max_models,
927                prior_inclusion_prob: self.prior_inclusion_prob,
928                inference_method: self.inference_method,
929                random_state: self.random_state,
930                state: PhantomData,
931                model_probabilities_: Some(model_probs),
932                model_features_: Some(models),
933                averaged_inclusion_probs_: Some(inclusion_probs),
934                selected_features_: Some(selected_features),
935                n_features_: Some(n_features),
936            });
937        }
938
939        Ok(BayesianModelAveraging {
940            max_models: self.max_models,
941            prior_inclusion_prob: self.prior_inclusion_prob,
942            inference_method: self.inference_method,
943            random_state: self.random_state,
944            state: PhantomData,
945            model_probabilities_: Some(model_probs),
946            model_features_: Some(models),
947            averaged_inclusion_probs_: Some(inclusion_probs),
948            selected_features_: Some(selected_features),
949            n_features_: Some(n_features),
950        })
951    }
952}
953
954impl Transform<Array2<Float>> for BayesianModelAveraging<Trained> {
955    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
956        validate::check_n_features(x, self.n_features_.unwrap())?;
957
958        let selected_features = self.selected_features_.as_ref().unwrap();
959        let n_samples = x.nrows();
960        let n_selected = selected_features.len();
961        let mut x_new = Array2::zeros((n_samples, n_selected));
962
963        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
964            x_new.column_mut(new_idx).assign(&x.column(old_idx));
965        }
966
967        Ok(x_new)
968    }
969}
970
971impl SelectorMixin for BayesianModelAveraging<Trained> {
972    fn get_support(&self) -> SklResult<Array1<bool>> {
973        let n_features = self.n_features_.unwrap();
974        let selected_features = self.selected_features_.as_ref().unwrap();
975        let mut support = Array1::from_elem(n_features, false);
976
977        for &idx in selected_features {
978            support[idx] = true;
979        }
980
981        Ok(support)
982    }
983
984    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
985        let selected_features = self.selected_features_.as_ref().unwrap();
986        Ok(indices
987            .iter()
988            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
989            .collect())
990    }
991}
992
993impl FeatureSelector for BayesianModelAveraging<Trained> {
994    fn selected_features(&self) -> &Vec<usize> {
995        self.selected_features_.as_ref().unwrap()
996    }
997}
998
999impl BayesianModelAveraging<Trained> {
1000    /// Get model probabilities
1001    pub fn model_probabilities(&self) -> &[Float] {
1002        self.model_probabilities_.as_ref().unwrap()
1003    }
1004
1005    /// Get features for each model
1006    pub fn model_features(&self) -> &[Vec<usize>] {
1007        self.model_features_.as_ref().unwrap()
1008    }
1009
1010    /// Get averaged inclusion probabilities
1011    pub fn inclusion_probabilities(&self) -> &Array1<Float> {
1012        self.averaged_inclusion_probs_.as_ref().unwrap()
1013    }
1014
1015    /// Get the number of selected features
1016    pub fn n_features_out(&self) -> usize {
1017        self.selected_features_.as_ref().unwrap().len()
1018    }
1019}
1020
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic synthetic data: 100 samples x 10 features in `[0, 1)`.
    /// The target is driven by features 0 and 1 plus a small sinusoidal term.
    fn create_test_data() -> (Array2<Float>, Array1<Float>) {
        // Create synthetic data with some signal
        let n_samples = 100;
        let n_features = 10;
        let mut features = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        // Fill with some structured data
        for i in 0..n_samples {
            for j in 0..n_features {
                features[[i, j]] = (i as Float * 0.1 + j as Float * 0.01) % 1.0;
            }
            // Make first few features predictive
            target[i] =
                features[[i, 0]] + 0.5 * features[[i, 1]] + 0.1 * ((i as Float) * 0.01).sin();
        }

        (features, target)
    }

    #[test]
    fn test_bayesian_variable_selector_variational() {
        let (features, target) = create_test_data();

        let selector = BayesianVariableSelector::new()
            .prior(PriorType::SpikeAndSlab {
                spike_var: 0.01,
                slab_var: 1.0,
            })
            .inference(BayesianInferenceMethod::VariationalBayes {
                max_iter: 10,
                tol: 1e-3,
            })
            .n_features_select(3);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);
        assert_eq!(trained.inclusion_probabilities().len(), features.ncols());
    }

    #[test]
    fn test_bayesian_variable_selector_em() {
        let (features, target) = create_test_data();

        let selector = BayesianVariableSelector::new()
            .inference(BayesianInferenceMethod::ExpectationMaximization {
                max_iter: 10,
                tol: 1e-3,
            })
            .inclusion_threshold(0.3);

        let trained = selector.fit(&features, &target).unwrap();
        assert!(trained.n_features_out() > 0);
    }

    #[test]
    fn test_bayesian_model_averaging() {
        let (features, target) = create_test_data();

        let selector = BayesianModelAveraging::new()
            .max_models(50)
            .prior_inclusion_prob(0.3);

        let trained = selector.fit(&features, &target).unwrap();
        assert!(trained.n_features_out() > 0);
        assert!(!trained.model_probabilities().is_empty());

        // Averaged inclusion probabilities cover every feature and are in [0, 1].
        let inclusion_probs = trained.inclusion_probabilities();
        assert_eq!(inclusion_probs.len(), features.ncols());
        assert!(inclusion_probs.iter().all(|p| (0.0..=1.0).contains(p)));

        // Every selected index refers to an existing input column.
        assert!(trained
            .selected_features()
            .iter()
            .all(|&idx| idx < features.ncols()));
    }

    #[test]
    fn test_transform() {
        let (features, target) = create_test_data();

        let selector = BayesianVariableSelector::new().n_features_select(4);

        let trained = selector.fit(&features, &target).unwrap();
        let transformed = trained.transform(&features).unwrap();

        assert_eq!(transformed.ncols(), 4);
        assert_eq!(transformed.nrows(), features.nrows());
    }

    #[test]
    fn test_horseshoe_prior() {
        let (features, target) = create_test_data();

        let selector = BayesianVariableSelector::new()
            .prior(PriorType::Horseshoe { tau: 0.1 })
            .n_features_select(3);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);
    }

    #[test]
    fn test_selector_mixin() {
        let (features, target) = create_test_data();

        let selector = BayesianVariableSelector::new().n_features_select(5);

        let trained = selector.fit(&features, &target).unwrap();
        let support = trained.get_support().unwrap();

        assert_eq!(support.len(), features.ncols());
        assert_eq!(support.iter().filter(|&&x| x).count(), 5);
    }

    // Property-based tests for Bayesian feature selection
    mod proptests {
        use super::*;

        fn valid_features() -> impl Strategy<Value = Array2<Float>> {
            (3usize..10, 20usize..50).prop_flat_map(|(n_cols, n_rows)| {
                prop::collection::vec(-5.0..5.0f64, n_rows * n_cols).prop_map(move |values| {
                    Array2::from_shape_vec((n_rows, n_cols), values).unwrap()
                })
            })
        }

        fn valid_target(n_samples: usize) -> impl Strategy<Value = Array1<Float>> {
            prop::collection::vec(-10.0..10.0f64, n_samples).prop_map(Array1::from_vec)
        }

        /// Feature matrices paired with a random target of matching length.
        /// Unlike the all-zero targets used elsewhere, these do not trigger
        /// the zero-variance shortcut, so the real model-scoring path runs.
        fn valid_features_and_target() -> impl Strategy<Value = (Array2<Float>, Array1<Float>)> {
            valid_features().prop_flat_map(|features| {
                let n_samples = features.nrows();
                (Just(features), valid_target(n_samples))
            })
        }

        proptest! {
            #[test]
            fn prop_bayesian_selector_respects_feature_count(
                features in valid_features(),
                n_features in 1usize..8
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                let selector = BayesianVariableSelector::new()
                    .n_features_select(n_select)
                    .inference(BayesianInferenceMethod::VariationalBayes { max_iter: 5, tol: 1e-2 });

                if let Ok(trained) = selector.fit(&features, &target) {
                    prop_assert_eq!(trained.n_features_out(), n_select);
                    prop_assert_eq!(trained.selected_features().len(), n_select);

                    // All selected features should be valid indices
                    for &idx in trained.selected_features() {
                        prop_assert!(idx < features.ncols());
                    }
                }
            }

            #[test]
            fn prop_bayesian_selector_inclusion_probabilities_valid(
                features in valid_features(),
                n_features in 1usize..5
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                let selector = BayesianVariableSelector::new()
                    .n_features_select(n_select)
                    .inference(BayesianInferenceMethod::ExpectationMaximization { max_iter: 5, tol: 1e-2 });

                if let Ok(trained) = selector.fit(&features, &target) {
                    let inclusion_probs = trained.inclusion_probabilities();

                    // All inclusion probabilities should be between 0 and 1
                    for &prob in inclusion_probs.iter() {
                        prop_assert!(prob >= 0.0);
                        prop_assert!(prob <= 1.0);
                    }

                    // Sanity check: at least as many features reach the
                    // weakest selected probability as were selected.
                    let selected_features = trained.selected_features();
                    if !selected_features.is_empty() {
                        let min_selected_prob = selected_features.iter()
                            .map(|&idx| inclusion_probs[idx])
                            .fold(f64::INFINITY, f64::min);

                        let count_above_min = inclusion_probs.iter()
                            .filter(|&&prob| prob >= min_selected_prob)
                            .count();
                        prop_assert!(count_above_min >= selected_features.len());
                    }
                }
            }

            #[test]
            fn prop_bayesian_selector_transform_preserves_shape(
                features in valid_features(),
                n_features in 1usize..5
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                let selector = BayesianVariableSelector::new()
                    .n_features_select(n_select);

                if let Ok(trained) = selector.fit(&features, &target) {
                    if let Ok(transformed) = trained.transform(&features) {
                        prop_assert_eq!(transformed.nrows(), features.nrows());
                        prop_assert_eq!(transformed.ncols(), n_select);

                        // Transformed values should match original features
                        for (sample_idx, row) in transformed.rows().into_iter().enumerate() {
                            for (new_feat_idx, &value) in row.iter().enumerate() {
                                let orig_feat_idx = trained.selected_features()[new_feat_idx];
                                let expected = features[[sample_idx, orig_feat_idx]];
                                prop_assert!((value - expected).abs() < 1e-10);
                            }
                        }
                    }
                }
            }

            #[test]
            fn prop_bayesian_model_averaging_probabilities_sum_to_one(
                features in valid_features(),
                max_models in 10usize..50
            ) {
                let target = Array1::zeros(features.nrows());

                let selector = BayesianModelAveraging::new()
                    .max_models(max_models)
                    .prior_inclusion_prob(0.3);

                if let Ok(trained) = selector.fit(&features, &target) {
                    let model_probs = trained.model_probabilities();

                    if !model_probs.is_empty() {
                        // All probabilities should be non-negative
                        for &prob in model_probs {
                            prop_assert!(prob >= 0.0);
                        }

                        // Probabilities should approximately sum to 1
                        let sum: Float = model_probs.iter().sum();
                        prop_assert!((sum - 1.0).abs() < 1e-6);
                    }
                }
            }

            #[test]
            fn prop_bayesian_model_averaging_inclusion_probs_valid(
                features in valid_features(),
                max_models in 5usize..20
            ) {
                let target = Array1::zeros(features.nrows());

                let selector = BayesianModelAveraging::new()
                    .max_models(max_models)
                    .prior_inclusion_prob(0.4);

                if let Ok(trained) = selector.fit(&features, &target) {
                    let inclusion_probs = trained.inclusion_probabilities();

                    // All inclusion probabilities should be between 0 and 1
                    for &prob in inclusion_probs.iter() {
                        prop_assert!(prob >= 0.0);
                        prop_assert!(prob <= 1.0);
                    }

                    // Should have same length as number of features
                    prop_assert_eq!(inclusion_probs.len(), features.ncols());
                }
            }

            #[test]
            fn prop_bayesian_model_averaging_with_random_target(
                data in valid_features_and_target(),
                max_models in 5usize..20
            ) {
                let (features, target) = data;

                let selector = BayesianModelAveraging::new()
                    .max_models(max_models)
                    .prior_inclusion_prob(0.3);

                if let Ok(trained) = selector.fit(&features, &target) {
                    // Model probabilities form a distribution.
                    let model_probs = trained.model_probabilities();
                    for &prob in model_probs {
                        prop_assert!(prob >= 0.0);
                    }
                    if !model_probs.is_empty() {
                        let sum: Float = model_probs.iter().sum();
                        prop_assert!((sum - 1.0).abs() < 1e-6);
                    }

                    // Inclusion probabilities cover every feature and are in [0, 1].
                    let inclusion_probs = trained.inclusion_probabilities();
                    prop_assert_eq!(inclusion_probs.len(), features.ncols());
                    for &prob in inclusion_probs.iter() {
                        prop_assert!((0.0..=1.0).contains(&prob));
                    }

                    // The fallback guarantees at least one selected feature,
                    // and every selected index must be a valid column.
                    prop_assert!(trained.n_features_out() >= 1);
                    for &idx in trained.selected_features() {
                        prop_assert!(idx < features.ncols());
                    }
                }
            }

            #[test]
            fn prop_prior_types_affect_selection(
                features in valid_features(),
                n_features in 1usize..3
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                // Test different prior types
                let priors = vec![
                    PriorType::SpikeAndSlab { spike_var: 0.01, slab_var: 1.0 },
                    PriorType::Horseshoe { tau: 0.1 },
                    PriorType::Laplace { scale: 1.0 },
                    PriorType::Normal { var: 1.0 },
                ];

                for prior in priors {
                    let selector = BayesianVariableSelector::new()
                        .prior(prior)
                        .n_features_select(n_select)
                        .inference(BayesianInferenceMethod::LaplaceApproximation);

                    if let Ok(trained) = selector.fit(&features, &target) {
                        prop_assert_eq!(trained.n_features_out(), n_select);
                        prop_assert_eq!(trained.inclusion_probabilities().len(), features.ncols());
                    }
                }
            }

            #[test]
            fn prop_bayesian_selector_deterministic_with_same_seed(
                features in valid_features(),
                n_features in 1usize..4,
                seed in 1u64..1000
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                let selector1 = BayesianVariableSelector::new()
                    .n_features_select(n_select)
                    .random_state(seed)
                    .inference(BayesianInferenceMethod::GibbsSampling { n_samples: 10, burn_in: 5 });

                let selector2 = BayesianVariableSelector::new()
                    .n_features_select(n_select)
                    .random_state(seed)
                    .inference(BayesianInferenceMethod::GibbsSampling { n_samples: 10, burn_in: 5 });

                if let (Ok(trained1), Ok(trained2)) = (selector1.fit(&features, &target), selector2.fit(&features, &target)) {
                    // Same seed should produce same results
                    prop_assert_eq!(trained1.selected_features(), trained2.selected_features());
                }
            }

            #[test]
            fn prop_bayesian_selector_evidence_is_finite(
                features in valid_features(),
                n_features in 1usize..3
            ) {
                let target = Array1::zeros(features.nrows());
                let n_select = n_features.min(features.ncols());

                let selector = BayesianVariableSelector::new()
                    .n_features_select(n_select);

                if let Ok(trained) = selector.fit(&features, &target) {
                    let evidence = trained.evidence();

                    // Evidence should be finite
                    prop_assert!(evidence.is_finite());
                }
            }
        }
    }
}