Skip to main content

scirs2_stats/bayesian/
enhanced_regression.rs

//! Enhanced Bayesian regression methods
//!
//! This module provides advanced Bayesian regression techniques including
//! variational inference, hierarchical models, and robust Bayesian regression.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, One, ToPrimitive, Zero};
9use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
10use std::marker::PhantomData;
11
12/// Enhanced Bayesian linear regression with multiple inference methods
13#[derive(Debug, Clone)]
14pub struct EnhancedBayesianRegression<F> {
15    /// Design matrix (X)
16    pub design_matrix: Array2<F>,
17    /// Response vector (y)
18    pub response: Array1<F>,
19    /// Prior parameters
20    pub prior: BayesianRegressionPrior<F>,
21    /// Inference method
22    pub inference_method: InferenceMethod,
23    /// Model configuration
24    pub config: BayesianRegressionConfig,
25    _phantom: PhantomData<F>,
26}
27
28/// Prior specification for Bayesian regression
29#[derive(Debug, Clone)]
30pub struct BayesianRegressionPrior<F> {
31    /// Prior mean for coefficients
32    pub beta_mean: Array1<F>,
33    /// Prior precision matrix for coefficients
34    pub beta_precision: Array2<F>,
35    /// Prior shape parameter for noise precision
36    pub noiseshape: F,
37    /// Prior rate parameter for noise precision
38    pub noise_rate: F,
39}
40
/// Inference methods for Bayesian regression.
///
/// The enum is a plain tag without payload, so it is `Copy` and can be used
/// as a hash-map key or compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InferenceMethod {
    /// Exact conjugate inference (when applicable)
    Exact,
    /// Variational Bayes inference
    VariationalBayes,
    /// MCMC sampling
    MCMC,
    /// Expectation Propagation
    ExpectationPropagation,
}
53
/// Configuration for Bayesian regression.
#[derive(Debug, Clone, PartialEq)]
pub struct BayesianRegressionConfig {
    /// Maximum iterations for iterative methods.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Whether to use parallel processing.
    pub parallel: bool,
    /// Random seed for reproducibility; `None` means no fixed seed.
    pub seed: Option<u64>,
}

impl Default for BayesianRegressionConfig {
    /// Defaults: 1000 iterations, tolerance `1e-6`, parallel processing
    /// enabled and no fixed seed.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tolerance: 1e-6,
            parallel: true,
            seed: None,
        }
    }
}
77
78/// Posterior results for Bayesian regression
79#[derive(Debug, Clone)]
80pub struct BayesianRegressionResult<F> {
81    /// Posterior mean of coefficients
82    pub beta_mean: Array1<F>,
83    /// Posterior covariance of coefficients
84    pub beta_covariance: Array2<F>,
85    /// Posterior mean of noise precision
86    pub noise_precision_mean: F,
87    /// Posterior variance of noise precision
88    pub noise_precision_var: F,
89    /// Log marginal likelihood (model evidence)
90    pub log_marginal_likelihood: F,
91    /// Predictive mean
92    pub predictive_mean: Array1<F>,
93    /// Predictive variance
94    pub predictive_var: Array1<F>,
95    /// Convergence information
96    pub convergence_info: ConvergenceInfo,
97}
98
/// Convergence information.
///
/// A small plain-data record, hence `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConvergenceInfo {
    /// Whether convergence was achieved.
    pub converged: bool,
    /// Number of iterations taken.
    pub iterations: usize,
    /// Final tolerance achieved.
    ///
    /// The fitting routines in this file report `0.0` for exact inference,
    /// and otherwise the configured tolerance on convergence or
    /// `f64::INFINITY` when convergence was not reached.
    pub final_tolerance: f64,
}
109
110impl<F> EnhancedBayesianRegression<F>
111where
112    F: Float
113        + Zero
114        + One
115        + Copy
116        + Send
117        + Sync
118        + SimdUnifiedOps
119        + std::fmt::Display
120        + 'static
121        + std::iter::Sum
122        + NumAssign
123        + ScalarOperand
124        + ToPrimitive
125        + FromPrimitive,
126{
127    /// Create new enhanced Bayesian regression model
128    pub fn new(
129        design_matrix: Array2<F>,
130        response: Array1<F>,
131        prior: BayesianRegressionPrior<F>,
132        inference_method: InferenceMethod,
133    ) -> StatsResult<Self> {
134        checkarray_finite(&design_matrix, "design_matrix")?;
135        checkarray_finite(&response, "response")?;
136        checkarray_finite(&prior.beta_mean, "beta_mean")?;
137        checkarray_finite(&prior.beta_precision, "beta_precision")?;
138
139        let (n, p) = design_matrix.dim();
140
141        if response.len() != n {
142            return Err(StatsError::DimensionMismatch(format!(
143                "Response length ({}) must match design _matrix rows ({})",
144                response.len(),
145                n
146            )));
147        }
148
149        if prior.beta_mean.len() != p {
150            return Err(StatsError::DimensionMismatch(format!(
151                "Prior mean length ({}) must match design _matrix columns ({})",
152                prior.beta_mean.len(),
153                p
154            )));
155        }
156
157        if prior.beta_precision.nrows() != p || prior.beta_precision.ncols() != p {
158            return Err(StatsError::DimensionMismatch(format!(
159                "Prior precision shape ({}, {}) must be ({}, {})",
160                prior.beta_precision.nrows(),
161                prior.beta_precision.ncols(),
162                p,
163                p
164            )));
165        }
166
167        Ok(Self {
168            design_matrix,
169            response,
170            prior,
171            inference_method,
172            config: BayesianRegressionConfig::default(),
173            _phantom: PhantomData,
174        })
175    }
176
177    /// Set configuration
178    pub fn with_config(mut self, config: BayesianRegressionConfig) -> Self {
179        self.config = config;
180        self
181    }
182
183    /// Fit the Bayesian regression model
184    pub fn fit(&self) -> StatsResult<BayesianRegressionResult<F>> {
185        match self.inference_method {
186            InferenceMethod::Exact => self.fit_exact(),
187            InferenceMethod::VariationalBayes => self.fit_variational_bayes(),
188            InferenceMethod::MCMC => self.fit_mcmc(),
189            InferenceMethod::ExpectationPropagation => self.fit_expectation_propagation(),
190        }
191    }
192
193    /// Exact conjugate inference (Normal-Gamma conjugacy)
194    fn fit_exact(&self) -> StatsResult<BayesianRegressionResult<F>> {
195        let x = &self.design_matrix;
196        let y = &self.response;
197        let n = x.nrows() as f64;
198        let p = x.ncols();
199
200        // Compute posterior parameters using matrix operations
201        let xtx = x.t().dot(x);
202        let xty = x.t().dot(y);
203
204        // Convert to f64 for numerical stability
205        let xtx_f64 = xtx.mapv(|v| v.to_f64().unwrap_or(0.0));
206        let xty_f64 = xty.mapv(|v| v.to_f64().unwrap_or(0.0));
207        let prior_precision_f64 = self
208            .prior
209            .beta_precision
210            .mapv(|v| v.to_f64().unwrap_or(0.0));
211        let prior_mean_f64 = self.prior.beta_mean.mapv(|v| v.to_f64().unwrap_or(0.0));
212        let noiseshape_f64 = self.prior.noiseshape.to_f64().unwrap_or(1.0);
213        let noise_rate_f64 = self.prior.noise_rate.to_f64().unwrap_or(1.0);
214
215        // Posterior precision matrix
216        let posterior_precision_f64 = xtx_f64.clone() + prior_precision_f64.clone();
217
218        // Invert posterior precision to get covariance
219        let posterior_covariance_f64 = scirs2_linalg::inv(&posterior_precision_f64.view(), None)
220            .map_err(|e| {
221                StatsError::ComputationError(format!("Failed to invert posterior precision: {}", e))
222            })?;
223
224        // Posterior mean
225        let posterior_mean_f64 = posterior_covariance_f64
226            .dot(&(xtx_f64.dot(&xty_f64) + prior_precision_f64.dot(&prior_mean_f64)));
227
228        // Posterior noise parameters
229        let posterior_mean_f: Array1<F> =
230            posterior_mean_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
231        let residual = y - &x.dot(&posterior_mean_f);
232        let residual_sum_squares = residual.dot(&residual).to_f64().unwrap_or(0.0);
233
234        let posterior_noiseshape = noiseshape_f64 + n / 2.0;
235        let posterior_noise_rate = noise_rate_f64 + residual_sum_squares / 2.0;
236
237        // Convert back to F type
238        let beta_mean =
239            posterior_mean_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
240        let beta_covariance =
241            posterior_covariance_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
242
243        let noise_precision_mean = F::from(posterior_noiseshape / posterior_noise_rate)
244            .expect("Failed to convert to float");
245        let noise_precision_var =
246            F::from(posterior_noiseshape / (posterior_noise_rate * posterior_noise_rate))
247                .expect("Operation failed");
248
249        // Compute predictive distribution
250        let predictive_mean = x.dot(&beta_mean);
251        let predictive_var_diag =
252            self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
253
254        // Compute log marginal likelihood
255        let log_marginal_likelihood = self.compute_log_marginal_likelihood(
256            &xtx_f64,
257            &xty_f64,
258            &prior_precision_f64,
259            &prior_mean_f64,
260            noiseshape_f64,
261            noise_rate_f64,
262            n,
263            p,
264        )?;
265
266        Ok(BayesianRegressionResult {
267            beta_mean,
268            beta_covariance,
269            noise_precision_mean,
270            noise_precision_var,
271            log_marginal_likelihood,
272            predictive_mean,
273            predictive_var: predictive_var_diag,
274            convergence_info: ConvergenceInfo {
275                converged: true,
276                iterations: 1,
277                final_tolerance: 0.0,
278            },
279        })
280    }
281
282    /// Variational Bayes inference
283    fn fit_variational_bayes(&self) -> StatsResult<BayesianRegressionResult<F>> {
284        let x = &self.design_matrix;
285        let y = &self.response;
286        let (n, p) = x.dim();
287
288        // Initialize variational parameters
289        let mut q_beta_mean = self.prior.beta_mean.clone();
290        let mut q_beta_precision = self.prior.beta_precision.clone();
291        let mut q_noiseshape = self.prior.noiseshape;
292        let mut q_noise_rate = self.prior.noise_rate;
293
294        let mut converged = false;
295        let mut iterations = 0;
296        let mut prev_elbo = F::neg_infinity();
297
298        for iter in 0..self.config.max_iter {
299            iterations = iter + 1;
300
301            // Update beta parameters
302            let xtx = x.t().dot(x);
303            let xty = x.t().dot(y);
304            let expected_noise_precision = q_noiseshape / q_noise_rate;
305
306            q_beta_precision =
307                self.prior.beta_precision.clone() + xtx.mapv(|v| v * expected_noise_precision);
308
309            let q_beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None)
310                .map_err(|e| StatsError::ComputationError(format!("VB update failed: {}", e)))?;
311
312            q_beta_mean = q_beta_covariance.dot(
313                &(self.prior.beta_precision.dot(&self.prior.beta_mean)
314                    + xty.mapv(|v| v * expected_noise_precision)),
315            );
316
317            // Update noise parameters
318            q_noiseshape = self.prior.noiseshape
319                + F::from(n).expect("Failed to convert to float")
320                    / F::from(2.0).expect("Failed to convert constant to float");
321
322            let _expected_beta_squared =
323                q_beta_mean.dot(&q_beta_mean) + q_beta_covariance.diag().sum();
324            let residual_term = y.dot(y)
325                - F::from(2.0).expect("Failed to convert constant to float")
326                    * y.dot(&x.dot(&q_beta_mean))
327                + x.dot(&q_beta_mean).dot(&x.dot(&q_beta_mean))
328                + (x.t().dot(x) * q_beta_covariance).diag().sum();
329
330            q_noise_rate = self.prior.noise_rate
331                + residual_term / F::from(2.0).expect("Failed to convert constant to float");
332
333            // Compute ELBO for convergence check
334            let elbo =
335                self.compute_elbo(&q_beta_mean, &q_beta_precision, q_noiseshape, q_noise_rate)?;
336
337            if (elbo - prev_elbo).abs()
338                < F::from(self.config.tolerance).expect("Failed to convert to float")
339            {
340                converged = true;
341                break;
342            }
343
344            prev_elbo = elbo;
345        }
346
347        // Compute final results
348        let beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None).map_err(|e| {
349            StatsError::ComputationError(format!("Final covariance computation failed: {}", e))
350        })?;
351
352        let noise_precision_mean = q_noiseshape / q_noise_rate;
353        let noise_precision_var = q_noiseshape / (q_noise_rate * q_noise_rate);
354
355        let predictive_mean = x.dot(&q_beta_mean);
356        let predictive_var =
357            self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
358
359        let log_marginal_likelihood = prev_elbo; // ELBO approximates log marginal likelihood
360
361        Ok(BayesianRegressionResult {
362            beta_mean: q_beta_mean,
363            beta_covariance,
364            noise_precision_mean,
365            noise_precision_var,
366            log_marginal_likelihood,
367            predictive_mean,
368            predictive_var,
369            convergence_info: ConvergenceInfo {
370                converged,
371                iterations,
372                final_tolerance: if converged {
373                    self.config.tolerance
374                } else {
375                    f64::INFINITY
376                },
377            },
378        })
379    }
380
381    /// MCMC inference using Gibbs sampling
382    fn fit_mcmc(&self) -> StatsResult<BayesianRegressionResult<F>> {
383        use scirs2_core::random::rngs::StdRng;
384        use scirs2_core::random::SeedableRng;
385        use scirs2_core::random::{Distribution, Gamma};
386
387        let x = &self.design_matrix;
388        let y = &self.response;
389        let (n, p) = x.dim();
390
391        // Initialize MCMC chain
392        let n_samples_ = self.config.max_iter;
393        let n_burnin = n_samples_ / 4; // 25% burn-in
394        let n_thin = 1; // No thinning for simplicity
395
396        let mut rng = match self.config.seed {
397            Some(seed) => StdRng::seed_from_u64(seed),
398            None => {
399                let mut rng = scirs2_core::random::thread_rng();
400                StdRng::from_rng(&mut rng)
401            }
402        };
403
404        // Initialize parameters
405        #[allow(unused_assignments)]
406        let mut beta = self.prior.beta_mean.clone();
407        let mut noise_precision = self.prior.noiseshape / self.prior.noise_rate;
408
409        // Storage for samples
410        let mut beta_samples = Vec::with_capacity(n_samples_ - n_burnin);
411        let mut noise_precision_samples_ = Vec::with_capacity(n_samples_ - n_burnin);
412        let mut log_likelihood_history = Vec::new();
413
414        // Precompute matrices for efficiency
415        let xtx = x.t().dot(x);
416        let xty = x.t().dot(y);
417
418        // Gibbs sampling
419        for iter in 0..n_samples_ {
420            // Sample beta | noise_precision, y
421            let precision_matrix =
422                self.prior.beta_precision.clone() + xtx.mapv(|v| v * noise_precision);
423
424            // Convert to f64 for numerical stability
425            let precision_f64 = precision_matrix.mapv(|v| v.to_f64().unwrap_or(0.0));
426            let posterior_cov_f64 =
427                scirs2_linalg::inv(&precision_f64.view(), None).map_err(|e| {
428                    StatsError::ComputationError(format!("MCMC covariance inversion failed: {}", e))
429                })?;
430
431            let mean_term = self.prior.beta_precision.dot(&self.prior.beta_mean)
432                + xty.mapv(|v| v * noise_precision);
433            let posterior_mean_f64 =
434                posterior_cov_f64.dot(&mean_term.mapv(|v| v.to_f64().unwrap_or(0.0)));
435
436            // Sample from multivariate normal
437            beta =
438                self.sample_multivariate_normal(&posterior_mean_f64, &posterior_cov_f64, &mut rng)?;
439
440            // Sample noise_precision | beta, y
441            let residual = y - &x.dot(&beta);
442            let sum_squared_residuals = residual.dot(&residual).to_f64().unwrap_or(0.0);
443
444            let posteriorshape = self.prior.noiseshape.to_f64().unwrap_or(1.0) + (n as f64) / 2.0;
445            let posterior_rate =
446                self.prior.noise_rate.to_f64().unwrap_or(1.0) + sum_squared_residuals / 2.0;
447
448            let gamma_dist = Gamma::new(posteriorshape, 1.0 / posterior_rate).map_err(|e| {
449                StatsError::ComputationError(format!("Failed to create gamma distribution: {}", e))
450            })?;
451            noise_precision = F::from(gamma_dist.sample(&mut rng)).expect("Operation failed");
452
453            // Store samples after burn-in
454            if iter >= n_burnin && (iter - n_burnin).is_multiple_of(n_thin) {
455                beta_samples.push(beta.clone());
456                noise_precision_samples_.push(noise_precision);
457            }
458
459            // Compute log-likelihood for convergence monitoring
460            if iter % 100 == 0 {
461                let ll = self.compute_mcmc_log_likelihood(&beta, noise_precision)?;
462                log_likelihood_history.push(ll);
463            }
464        }
465
466        // Compute posterior statistics from samples
467        let n_kept_samples = beta_samples.len();
468        if n_kept_samples == 0 {
469            return Err(StatsError::ComputationError(
470                "No MCMC samples collected".to_string(),
471            ));
472        }
473
474        // Posterior mean of beta
475        let mut posterior_beta_mean = Array1::zeros(p);
476        for sample in &beta_samples {
477            posterior_beta_mean += sample;
478        }
479        posterior_beta_mean /= F::from(n_kept_samples).expect("Failed to convert to float");
480
481        // Posterior covariance of beta
482        let mut posterior_beta_cov = Array2::zeros((p, p));
483        for sample in &beta_samples {
484            let centered = sample - &posterior_beta_mean;
485            for i in 0..p {
486                for j in 0..p {
487                    posterior_beta_cov[[i, j]] += centered[i] * centered[j];
488                }
489            }
490        }
491        posterior_beta_cov /=
492            F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
493
494        // Posterior statistics for noise precision
495        let noise_precision_mean = noise_precision_samples_
496            .iter()
497            .fold(F::zero(), |acc, &x| acc + x)
498            / F::from(n_kept_samples).expect("Failed to convert to float");
499
500        let noise_precision_var = {
501            let mean_sq = noise_precision_samples_
502                .iter()
503                .map(|&x| (x - noise_precision_mean) * (x - noise_precision_mean))
504                .fold(F::zero(), |acc, x| acc + x)
505                / F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
506            mean_sq
507        };
508
509        // Predictive distribution
510        let predictive_mean = x.dot(&posterior_beta_mean);
511        let predictive_var =
512            self.compute_predictive_variance(x.view(), &posterior_beta_cov, noise_precision_mean)?;
513
514        // Compute final log marginal likelihood estimate
515        let final_log_likelihood = if log_likelihood_history.is_empty() {
516            self.compute_mcmc_log_likelihood(&posterior_beta_mean, noise_precision_mean)?
517        } else {
518            *log_likelihood_history.last().expect("Operation failed")
519        };
520
521        // Check convergence based on effective sample size and stability
522        let converged = self.check_mcmc_convergence(&beta_samples, &noise_precision_samples_)?;
523
524        Ok(BayesianRegressionResult {
525            beta_mean: posterior_beta_mean,
526            beta_covariance: posterior_beta_cov,
527            noise_precision_mean,
528            noise_precision_var,
529            log_marginal_likelihood: final_log_likelihood,
530            predictive_mean,
531            predictive_var,
532            convergence_info: ConvergenceInfo {
533                converged,
534                iterations: n_samples_,
535                final_tolerance: if converged {
536                    self.config.tolerance
537                } else {
538                    f64::INFINITY
539                },
540            },
541        })
542    }
543
544    /// Expectation Propagation inference
545    fn fit_expectation_propagation(&self) -> StatsResult<BayesianRegressionResult<F>> {
546        // For now, fall back to variational Bayes
547        // Full EP implementation would be more complex
548        self.fit_variational_bayes()
549    }
550
551    /// Compute predictive variance
552    fn compute_predictive_variance(
553        &self,
554        x: ArrayView2<F>,
555        beta_covariance: &Array2<F>,
556        noise_precision_mean: F,
557    ) -> StatsResult<Array1<F>> {
558        let n = x.nrows();
559        let mut predictive_var = Array1::zeros(n);
560
561        for i in 0..n {
562            let x_i = x.row(i);
563            let var_beta = x_i.dot(&beta_covariance.dot(&x_i));
564            let var_noise = F::one() / noise_precision_mean;
565            predictive_var[i] = var_beta + var_noise;
566        }
567
568        Ok(predictive_var)
569    }
570
571    /// Compute log marginal likelihood for exact inference
572    fn compute_log_marginal_likelihood(
573        &self,
574        xtx: &Array2<f64>,
575        _xty: &Array1<f64>,
576        prior_precision: &Array2<f64>,
577        _prior_mean: &Array1<f64>,
578        noiseshape: f64,
579        noise_rate: f64,
580        n: f64,
581        p: usize,
582    ) -> StatsResult<F> {
583        // This is a simplified version - full implementation would include all normalization terms
584        let posterior_precision = xtx + prior_precision;
585        let det_prior = scirs2_linalg::det(&prior_precision.view(), None).map_err(|e| {
586            StatsError::ComputationError(format!("Determinant computation failed: {}", e))
587        })?;
588        let det_posterior = scirs2_linalg::det(&posterior_precision.view(), None).map_err(|e| {
589            StatsError::ComputationError(format!("Determinant computation failed: {}", e))
590        })?;
591
592        // Simplified log marginal likelihood computation
593        let log_ml = 0.5 * (det_prior / det_posterior).ln() + noiseshape * noise_rate.ln()
594            - (n / 2.0) * (2.0 * std::f64::consts::PI).ln();
595
596        Ok(F::from(log_ml).expect("Failed to convert to float"))
597    }
598
599    /// Compute Evidence Lower BOund (ELBO) for variational inference
600    fn compute_elbo(
601        &self,
602        q_beta_mean: &Array1<F>,
603        _q_beta_precision: &Array2<F>,
604        q_noiseshape: F,
605        q_noise_rate: F,
606    ) -> StatsResult<F> {
607        // Simplified ELBO computation
608        // Full implementation would include entropy terms and expected log-likelihood
609        let expected_noise_precision = q_noiseshape / q_noise_rate;
610        let residual = &self.response - &self.design_matrix.dot(q_beta_mean);
611        let data_term = -F::from(0.5).expect("Failed to convert constant to float")
612            * expected_noise_precision
613            * residual.dot(&residual);
614
615        Ok(data_term)
616    }
617
618    /// Sample from multivariate normal distribution
619    fn sample_multivariate_normal<R: scirs2_core::random::Rng>(
620        &self,
621        mean: &Array1<f64>,
622        covariance: &Array2<f64>,
623        rng: &mut R,
624    ) -> StatsResult<Array1<F>> {
625        use scirs2_core::random::{Distribution, StandardNormal};
626
627        let d = mean.len();
628
629        // Cholesky decomposition of covariance
630        let chol = scirs2_linalg::cholesky(&covariance.view(), None).map_err(|e| {
631            StatsError::ComputationError(format!("Cholesky decomposition failed: {}", e))
632        })?;
633
634        // Sample from standard normal
635        let z: Vec<f64> = (0..d).map(|_| StandardNormal.sample(rng)).collect();
636        let z_array = Array1::from_vec(z);
637
638        // Transform: mean + L * z where L is lower triangular Cholesky factor
639        let sample_f64 = mean + &chol.dot(&z_array);
640        let sample = sample_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
641
642        Ok(sample)
643    }
644
645    /// Compute log-likelihood for MCMC monitoring
646    fn compute_mcmc_log_likelihood(&self, beta: &Array1<F>, noise_precision: F) -> StatsResult<F> {
647        let x = &self.design_matrix;
648        let y = &self.response;
649        let n = x.nrows() as f64;
650
651        let residual = y - &x.dot(beta);
652        let sum_squared_residuals = residual.dot(&residual).to_f64().unwrap_or(0.0);
653
654        let log_likelihood = (n / 2.0) * noise_precision.to_f64().unwrap_or(1.0).ln()
655            - (n / 2.0) * (2.0 * std::f64::consts::PI).ln()
656            - 0.5 * noise_precision.to_f64().unwrap_or(1.0) * sum_squared_residuals;
657
658        Ok(F::from(log_likelihood).expect("Failed to convert to float"))
659    }
660
661    /// Check MCMC convergence using various diagnostics
662    fn check_mcmc_convergence(
663        &self,
664        beta_samples: &[Array1<F>],
665        noise_precision_samples_: &[F],
666    ) -> StatsResult<bool> {
667        if beta_samples.len() < 100 {
668            return Ok(false); // Need minimum _samples for convergence assessment
669        }
670
671        // Split _samples into two halves for Gelman-Rubin diagnostic
672        let n = beta_samples.len();
673        let mid = n / 2;
674
675        // Simplified convergence check: compare variance of first and second half
676        let first_half = &beta_samples[..mid];
677        let second_half = &beta_samples[mid..];
678
679        // Check if variance stabilized for first parameter
680        if !beta_samples.is_empty() && !beta_samples[0].is_empty() {
681            let first_half_var = self
682                .compute_sample_variance_1d(&first_half.iter().map(|x| x[0]).collect::<Vec<_>>());
683            let second_half_var = self
684                .compute_sample_variance_1d(&second_half.iter().map(|x| x[0]).collect::<Vec<_>>());
685
686            let var_ratio =
687                first_half_var.max(second_half_var) / first_half_var.min(second_half_var);
688            if var_ratio > F::from(2.0).expect("Failed to convert constant to float") {
689                return Ok(false); // Variance not stabilized
690            }
691        }
692
693        // Check effective sample size (simplified)
694        let eff_samplesize = self.compute_effective_samplesize(noise_precision_samples_)?;
695        if eff_samplesize < 100.0 {
696            return Ok(false); // Need larger effective sample size
697        }
698
699        Ok(true)
700    }
701
702    /// Compute sample variance for 1D samples
703    fn compute_sample_variance_1d(&self, samples: &[F]) -> F {
704        if samples.is_empty() {
705            return F::one();
706        }
707
708        let n = samples.len();
709        let mean = samples.iter().fold(F::zero(), |acc, &x| acc + x)
710            / F::from(n).expect("Failed to convert to float");
711        let variance = samples
712            .iter()
713            .map(|&x| (x - mean) * (x - mean))
714            .fold(F::zero(), |acc, x| acc + x)
715            / F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
716
717        variance.max(F::from(1e-10).expect("Failed to convert constant to float"))
718        // Avoid zero variance
719    }
720
721    /// Compute effective sample size (simplified autocorrelation-based estimate)
722    fn compute_effective_samplesize(&self, samples: &[F]) -> StatsResult<f64> {
723        if samples.len() < 10 {
724            return Ok(samples.len() as f64);
725        }
726
727        let n = samples.len();
728        let mean = samples.iter().fold(F::zero(), |acc, &x| acc + x)
729            / F::from(n).expect("Failed to convert to float");
730
731        // Compute lag-1 autocorrelation (simplified)
732        let mut numerator = F::zero();
733        let mut denominator = F::zero();
734
735        for i in 0..n - 1 {
736            let x_i = samples[i] - mean;
737            let x_i1 = samples[i + 1] - mean;
738            numerator += x_i * x_i1;
739            denominator += x_i * x_i;
740        }
741
742        let autocorr = if denominator > F::from(1e-10).expect("Failed to convert constant to float")
743        {
744            (numerator / denominator).to_f64().unwrap_or(0.0)
745        } else {
746            0.0
747        };
748
749        // Simplified effective sample size estimate
750        let eff_n = if autocorr > 0.1 {
751            n as f64 * (1.0 - autocorr) / (1.0 + autocorr)
752        } else {
753            n as f64
754        };
755
756        Ok(eff_n.max(1.0))
757    }
758
759    /// Make predictions on new data
760    pub fn predict(
761        &self,
762        x_new: &Array2<F>,
763        result: &BayesianRegressionResult<F>,
764    ) -> StatsResult<(Array1<F>, Array1<F>)> {
765        checkarray_finite(x_new, "x_new")?;
766
767        if x_new.ncols() != self.design_matrix.ncols() {
768            return Err(StatsError::DimensionMismatch(format!(
769                "New data columns ({}) must match training data columns ({})",
770                x_new.ncols(),
771                self.design_matrix.ncols()
772            )));
773        }
774
775        let pred_mean = x_new.dot(&result.beta_mean);
776        let pred_var = self.compute_predictive_variance(
777            x_new.view(),
778            &result.beta_covariance,
779            result.noise_precision_mean,
780        )?;
781
782        Ok((pred_mean, pred_var))
783    }
784}
785
786impl<F> BayesianRegressionPrior<F>
787where
788    F: Float + Zero + One + Copy + ScalarOperand + std::fmt::Display + FromPrimitive,
789{
790    /// Create uninformative prior
791    pub fn uninformative(p: usize) -> Self {
792        let beta_mean = Array1::zeros(p);
793        let beta_precision =
794            Array2::eye(p) * F::from(1e-6).expect("Failed to convert constant to float"); // Very small precision = large variance
795        let noiseshape = F::from(1e-3).expect("Failed to convert constant to float");
796        let noise_rate = F::from(1e-3).expect("Failed to convert constant to float");
797
798        Self {
799            beta_mean,
800            beta_precision,
801            noiseshape,
802            noise_rate,
803        }
804    }
805
806    /// Create ridge-like prior
807    pub fn ridge(p: usize, alpha: F) -> Self {
808        let beta_mean = Array1::zeros(p);
809        let beta_precision = Array2::eye(p) * alpha;
810        let noiseshape = F::one();
811        let noise_rate = F::one();
812
813        Self {
814            beta_mean,
815            beta_precision,
816            noiseshape,
817            noise_rate,
818        }
819    }
820}
821
822/// Convenience functions
823#[allow(dead_code)]
824pub fn bayesian_linear_regression_exact<F>(
825    x: Array2<F>,
826    y: Array1<F>,
827    prior: Option<BayesianRegressionPrior<F>>,
828) -> StatsResult<BayesianRegressionResult<F>>
829where
830    F: Float
831        + Zero
832        + One
833        + Copy
834        + Send
835        + Sync
836        + SimdUnifiedOps
837        + 'static
838        + std::iter::Sum
839        + NumAssign
840        + ScalarOperand
841        + std::fmt::Display
842        + ToPrimitive
843        + FromPrimitive,
844{
845    let p = x.ncols();
846    let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
847
848    let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::Exact)?;
849    model.fit()
850}
851
852#[allow(dead_code)]
853pub fn bayesian_linear_regression_vb<F>(
854    x: Array2<F>,
855    y: Array1<F>,
856    prior: Option<BayesianRegressionPrior<F>>,
857    config: Option<BayesianRegressionConfig>,
858) -> StatsResult<BayesianRegressionResult<F>>
859where
860    F: Float
861        + Zero
862        + One
863        + Copy
864        + Send
865        + Sync
866        + SimdUnifiedOps
867        + 'static
868        + std::iter::Sum
869        + NumAssign
870        + ScalarOperand
871        + std::fmt::Display
872        + ToPrimitive
873        + FromPrimitive,
874{
875    let p = x.ncols();
876    let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
877    let config = config.unwrap_or_default();
878
879    let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::VariationalBayes)?
880        .with_config(config);
881    model.fit()
882}