1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, One, ToPrimitive, Zero};
9use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
10use std::marker::PhantomData;
11
12#[derive(Debug, Clone)]
14pub struct EnhancedBayesianRegression<F> {
15 pub design_matrix: Array2<F>,
17 pub response: Array1<F>,
19 pub prior: BayesianRegressionPrior<F>,
21 pub inference_method: InferenceMethod,
23 pub config: BayesianRegressionConfig,
25 _phantom: PhantomData<F>,
26}
27
28#[derive(Debug, Clone)]
30pub struct BayesianRegressionPrior<F> {
31 pub beta_mean: Array1<F>,
33 pub beta_precision: Array2<F>,
35 pub noiseshape: F,
37 pub noise_rate: F,
39}
40
/// Algorithm used to compute the regression posterior.
///
/// The enum is a plain tag with no payload, so it is `Copy`, `Eq` and `Hash`
/// in addition to the comparison and debug traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InferenceMethod {
    /// Closed-form conjugate update.
    Exact,
    /// Iterative mean-field variational approximation.
    VariationalBayes,
    /// Gibbs sampling of coefficients and noise precision.
    MCMC,
    /// Expectation propagation (currently delegates to variational Bayes).
    ExpectationPropagation,
}
53
/// Tuning knobs shared by the iterative inference methods.
#[derive(Debug, Clone)]
pub struct BayesianRegressionConfig {
    /// Iteration cap for variational Bayes; total number of draws for MCMC.
    pub max_iter: usize,
    /// Threshold on the change of the variational objective between
    /// iterations below which variational Bayes stops.
    pub tolerance: f64,
    /// Parallel-execution flag (not consulted by the code in this file).
    pub parallel: bool,
    /// Optional seed for the MCMC random number generator.
    pub seed: Option<u64>,
}
66
67impl Default for BayesianRegressionConfig {
68 fn default() -> Self {
69 Self {
70 max_iter: 1000,
71 tolerance: 1e-6,
72 parallel: true,
73 seed: None,
74 }
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct BayesianRegressionResult<F> {
81 pub beta_mean: Array1<F>,
83 pub beta_covariance: Array2<F>,
85 pub noise_precision_mean: F,
87 pub noise_precision_var: F,
89 pub log_marginal_likelihood: F,
91 pub predictive_mean: Array1<F>,
93 pub predictive_var: Array1<F>,
95 pub convergence_info: ConvergenceInfo,
97}
98
/// Convergence diagnostics attached to a fit result.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Whether the inference method judged itself converged.
    pub converged: bool,
    /// Number of iterations (or MCMC draws) performed.
    pub iterations: usize,
    /// Tolerance reported at termination; infinite when not converged.
    pub final_tolerance: f64,
}
109
110impl<F> EnhancedBayesianRegression<F>
111where
112 F: Float
113 + Zero
114 + One
115 + Copy
116 + Send
117 + Sync
118 + SimdUnifiedOps
119 + std::fmt::Display
120 + 'static
121 + std::iter::Sum
122 + NumAssign
123 + ScalarOperand
124 + ToPrimitive
125 + FromPrimitive,
126{
127 pub fn new(
129 design_matrix: Array2<F>,
130 response: Array1<F>,
131 prior: BayesianRegressionPrior<F>,
132 inference_method: InferenceMethod,
133 ) -> StatsResult<Self> {
134 checkarray_finite(&design_matrix, "design_matrix")?;
135 checkarray_finite(&response, "response")?;
136 checkarray_finite(&prior.beta_mean, "beta_mean")?;
137 checkarray_finite(&prior.beta_precision, "beta_precision")?;
138
139 let (n, p) = design_matrix.dim();
140
141 if response.len() != n {
142 return Err(StatsError::DimensionMismatch(format!(
143 "Response length ({}) must match design _matrix rows ({})",
144 response.len(),
145 n
146 )));
147 }
148
149 if prior.beta_mean.len() != p {
150 return Err(StatsError::DimensionMismatch(format!(
151 "Prior mean length ({}) must match design _matrix columns ({})",
152 prior.beta_mean.len(),
153 p
154 )));
155 }
156
157 if prior.beta_precision.nrows() != p || prior.beta_precision.ncols() != p {
158 return Err(StatsError::DimensionMismatch(format!(
159 "Prior precision shape ({}, {}) must be ({}, {})",
160 prior.beta_precision.nrows(),
161 prior.beta_precision.ncols(),
162 p,
163 p
164 )));
165 }
166
167 Ok(Self {
168 design_matrix,
169 response,
170 prior,
171 inference_method,
172 config: BayesianRegressionConfig::default(),
173 _phantom: PhantomData,
174 })
175 }
176
177 pub fn with_config(mut self, config: BayesianRegressionConfig) -> Self {
179 self.config = config;
180 self
181 }
182
183 pub fn fit(&self) -> StatsResult<BayesianRegressionResult<F>> {
185 match self.inference_method {
186 InferenceMethod::Exact => self.fit_exact(),
187 InferenceMethod::VariationalBayes => self.fit_variational_bayes(),
188 InferenceMethod::MCMC => self.fit_mcmc(),
189 InferenceMethod::ExpectationPropagation => self.fit_expectation_propagation(),
190 }
191 }
192
193 fn fit_exact(&self) -> StatsResult<BayesianRegressionResult<F>> {
195 let x = &self.design_matrix;
196 let y = &self.response;
197 let n = x.nrows() as f64;
198 let p = x.ncols();
199
200 let xtx = x.t().dot(x);
202 let xty = x.t().dot(y);
203
204 let xtx_f64 = xtx.mapv(|v| v.to_f64().unwrap_or(0.0));
206 let xty_f64 = xty.mapv(|v| v.to_f64().unwrap_or(0.0));
207 let prior_precision_f64 = self
208 .prior
209 .beta_precision
210 .mapv(|v| v.to_f64().unwrap_or(0.0));
211 let prior_mean_f64 = self.prior.beta_mean.mapv(|v| v.to_f64().unwrap_or(0.0));
212 let noiseshape_f64 = self.prior.noiseshape.to_f64().unwrap_or(1.0);
213 let noise_rate_f64 = self.prior.noise_rate.to_f64().unwrap_or(1.0);
214
215 let posterior_precision_f64 = xtx_f64.clone() + prior_precision_f64.clone();
217
218 let posterior_covariance_f64 = scirs2_linalg::inv(&posterior_precision_f64.view(), None)
220 .map_err(|e| {
221 StatsError::ComputationError(format!("Failed to invert posterior precision: {}", e))
222 })?;
223
224 let posterior_mean_f64 = posterior_covariance_f64
226 .dot(&(xtx_f64.dot(&xty_f64) + prior_precision_f64.dot(&prior_mean_f64)));
227
228 let posterior_mean_f: Array1<F> =
230 posterior_mean_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
231 let residual = y - &x.dot(&posterior_mean_f);
232 let residual_sum_squares = residual.dot(&residual).to_f64().unwrap_or(0.0);
233
234 let posterior_noiseshape = noiseshape_f64 + n / 2.0;
235 let posterior_noise_rate = noise_rate_f64 + residual_sum_squares / 2.0;
236
237 let beta_mean =
239 posterior_mean_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
240 let beta_covariance =
241 posterior_covariance_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
242
243 let noise_precision_mean = F::from(posterior_noiseshape / posterior_noise_rate)
244 .expect("Failed to convert to float");
245 let noise_precision_var =
246 F::from(posterior_noiseshape / (posterior_noise_rate * posterior_noise_rate))
247 .expect("Operation failed");
248
249 let predictive_mean = x.dot(&beta_mean);
251 let predictive_var_diag =
252 self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
253
254 let log_marginal_likelihood = self.compute_log_marginal_likelihood(
256 &xtx_f64,
257 &xty_f64,
258 &prior_precision_f64,
259 &prior_mean_f64,
260 noiseshape_f64,
261 noise_rate_f64,
262 n,
263 p,
264 )?;
265
266 Ok(BayesianRegressionResult {
267 beta_mean,
268 beta_covariance,
269 noise_precision_mean,
270 noise_precision_var,
271 log_marginal_likelihood,
272 predictive_mean,
273 predictive_var: predictive_var_diag,
274 convergence_info: ConvergenceInfo {
275 converged: true,
276 iterations: 1,
277 final_tolerance: 0.0,
278 },
279 })
280 }
281
282 fn fit_variational_bayes(&self) -> StatsResult<BayesianRegressionResult<F>> {
284 let x = &self.design_matrix;
285 let y = &self.response;
286 let (n, p) = x.dim();
287
288 let mut q_beta_mean = self.prior.beta_mean.clone();
290 let mut q_beta_precision = self.prior.beta_precision.clone();
291 let mut q_noiseshape = self.prior.noiseshape;
292 let mut q_noise_rate = self.prior.noise_rate;
293
294 let mut converged = false;
295 let mut iterations = 0;
296 let mut prev_elbo = F::neg_infinity();
297
298 for iter in 0..self.config.max_iter {
299 iterations = iter + 1;
300
301 let xtx = x.t().dot(x);
303 let xty = x.t().dot(y);
304 let expected_noise_precision = q_noiseshape / q_noise_rate;
305
306 q_beta_precision =
307 self.prior.beta_precision.clone() + xtx.mapv(|v| v * expected_noise_precision);
308
309 let q_beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None)
310 .map_err(|e| StatsError::ComputationError(format!("VB update failed: {}", e)))?;
311
312 q_beta_mean = q_beta_covariance.dot(
313 &(self.prior.beta_precision.dot(&self.prior.beta_mean)
314 + xty.mapv(|v| v * expected_noise_precision)),
315 );
316
317 q_noiseshape = self.prior.noiseshape
319 + F::from(n).expect("Failed to convert to float")
320 / F::from(2.0).expect("Failed to convert constant to float");
321
322 let _expected_beta_squared =
323 q_beta_mean.dot(&q_beta_mean) + q_beta_covariance.diag().sum();
324 let residual_term = y.dot(y)
325 - F::from(2.0).expect("Failed to convert constant to float")
326 * y.dot(&x.dot(&q_beta_mean))
327 + x.dot(&q_beta_mean).dot(&x.dot(&q_beta_mean))
328 + (x.t().dot(x) * q_beta_covariance).diag().sum();
329
330 q_noise_rate = self.prior.noise_rate
331 + residual_term / F::from(2.0).expect("Failed to convert constant to float");
332
333 let elbo =
335 self.compute_elbo(&q_beta_mean, &q_beta_precision, q_noiseshape, q_noise_rate)?;
336
337 if (elbo - prev_elbo).abs()
338 < F::from(self.config.tolerance).expect("Failed to convert to float")
339 {
340 converged = true;
341 break;
342 }
343
344 prev_elbo = elbo;
345 }
346
347 let beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None).map_err(|e| {
349 StatsError::ComputationError(format!("Final covariance computation failed: {}", e))
350 })?;
351
352 let noise_precision_mean = q_noiseshape / q_noise_rate;
353 let noise_precision_var = q_noiseshape / (q_noise_rate * q_noise_rate);
354
355 let predictive_mean = x.dot(&q_beta_mean);
356 let predictive_var =
357 self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
358
359 let log_marginal_likelihood = prev_elbo; Ok(BayesianRegressionResult {
362 beta_mean: q_beta_mean,
363 beta_covariance,
364 noise_precision_mean,
365 noise_precision_var,
366 log_marginal_likelihood,
367 predictive_mean,
368 predictive_var,
369 convergence_info: ConvergenceInfo {
370 converged,
371 iterations,
372 final_tolerance: if converged {
373 self.config.tolerance
374 } else {
375 f64::INFINITY
376 },
377 },
378 })
379 }
380
381 fn fit_mcmc(&self) -> StatsResult<BayesianRegressionResult<F>> {
383 use scirs2_core::random::rngs::StdRng;
384 use scirs2_core::random::SeedableRng;
385 use scirs2_core::random::{Distribution, Gamma};
386
387 let x = &self.design_matrix;
388 let y = &self.response;
389 let (n, p) = x.dim();
390
391 let n_samples_ = self.config.max_iter;
393 let n_burnin = n_samples_ / 4; let n_thin = 1; let mut rng = match self.config.seed {
397 Some(seed) => StdRng::seed_from_u64(seed),
398 None => {
399 let mut rng = scirs2_core::random::thread_rng();
400 StdRng::from_rng(&mut rng)
401 }
402 };
403
404 #[allow(unused_assignments)]
406 let mut beta = self.prior.beta_mean.clone();
407 let mut noise_precision = self.prior.noiseshape / self.prior.noise_rate;
408
409 let mut beta_samples = Vec::with_capacity(n_samples_ - n_burnin);
411 let mut noise_precision_samples_ = Vec::with_capacity(n_samples_ - n_burnin);
412 let mut log_likelihood_history = Vec::new();
413
414 let xtx = x.t().dot(x);
416 let xty = x.t().dot(y);
417
418 for iter in 0..n_samples_ {
420 let precision_matrix =
422 self.prior.beta_precision.clone() + xtx.mapv(|v| v * noise_precision);
423
424 let precision_f64 = precision_matrix.mapv(|v| v.to_f64().unwrap_or(0.0));
426 let posterior_cov_f64 =
427 scirs2_linalg::inv(&precision_f64.view(), None).map_err(|e| {
428 StatsError::ComputationError(format!("MCMC covariance inversion failed: {}", e))
429 })?;
430
431 let mean_term = self.prior.beta_precision.dot(&self.prior.beta_mean)
432 + xty.mapv(|v| v * noise_precision);
433 let posterior_mean_f64 =
434 posterior_cov_f64.dot(&mean_term.mapv(|v| v.to_f64().unwrap_or(0.0)));
435
436 beta =
438 self.sample_multivariate_normal(&posterior_mean_f64, &posterior_cov_f64, &mut rng)?;
439
440 let residual = y - &x.dot(&beta);
442 let sum_squared_residuals = residual.dot(&residual).to_f64().unwrap_or(0.0);
443
444 let posteriorshape = self.prior.noiseshape.to_f64().unwrap_or(1.0) + (n as f64) / 2.0;
445 let posterior_rate =
446 self.prior.noise_rate.to_f64().unwrap_or(1.0) + sum_squared_residuals / 2.0;
447
448 let gamma_dist = Gamma::new(posteriorshape, 1.0 / posterior_rate).map_err(|e| {
449 StatsError::ComputationError(format!("Failed to create gamma distribution: {}", e))
450 })?;
451 noise_precision = F::from(gamma_dist.sample(&mut rng)).expect("Operation failed");
452
453 if iter >= n_burnin && (iter - n_burnin).is_multiple_of(n_thin) {
455 beta_samples.push(beta.clone());
456 noise_precision_samples_.push(noise_precision);
457 }
458
459 if iter % 100 == 0 {
461 let ll = self.compute_mcmc_log_likelihood(&beta, noise_precision)?;
462 log_likelihood_history.push(ll);
463 }
464 }
465
466 let n_kept_samples = beta_samples.len();
468 if n_kept_samples == 0 {
469 return Err(StatsError::ComputationError(
470 "No MCMC samples collected".to_string(),
471 ));
472 }
473
474 let mut posterior_beta_mean = Array1::zeros(p);
476 for sample in &beta_samples {
477 posterior_beta_mean += sample;
478 }
479 posterior_beta_mean /= F::from(n_kept_samples).expect("Failed to convert to float");
480
481 let mut posterior_beta_cov = Array2::zeros((p, p));
483 for sample in &beta_samples {
484 let centered = sample - &posterior_beta_mean;
485 for i in 0..p {
486 for j in 0..p {
487 posterior_beta_cov[[i, j]] += centered[i] * centered[j];
488 }
489 }
490 }
491 posterior_beta_cov /=
492 F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
493
494 let noise_precision_mean = noise_precision_samples_
496 .iter()
497 .fold(F::zero(), |acc, &x| acc + x)
498 / F::from(n_kept_samples).expect("Failed to convert to float");
499
500 let noise_precision_var = {
501 let mean_sq = noise_precision_samples_
502 .iter()
503 .map(|&x| (x - noise_precision_mean) * (x - noise_precision_mean))
504 .fold(F::zero(), |acc, x| acc + x)
505 / F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
506 mean_sq
507 };
508
509 let predictive_mean = x.dot(&posterior_beta_mean);
511 let predictive_var =
512 self.compute_predictive_variance(x.view(), &posterior_beta_cov, noise_precision_mean)?;
513
514 let final_log_likelihood = if log_likelihood_history.is_empty() {
516 self.compute_mcmc_log_likelihood(&posterior_beta_mean, noise_precision_mean)?
517 } else {
518 *log_likelihood_history.last().expect("Operation failed")
519 };
520
521 let converged = self.check_mcmc_convergence(&beta_samples, &noise_precision_samples_)?;
523
524 Ok(BayesianRegressionResult {
525 beta_mean: posterior_beta_mean,
526 beta_covariance: posterior_beta_cov,
527 noise_precision_mean,
528 noise_precision_var,
529 log_marginal_likelihood: final_log_likelihood,
530 predictive_mean,
531 predictive_var,
532 convergence_info: ConvergenceInfo {
533 converged,
534 iterations: n_samples_,
535 final_tolerance: if converged {
536 self.config.tolerance
537 } else {
538 f64::INFINITY
539 },
540 },
541 })
542 }
543
544 fn fit_expectation_propagation(&self) -> StatsResult<BayesianRegressionResult<F>> {
546 self.fit_variational_bayes()
549 }
550
551 fn compute_predictive_variance(
553 &self,
554 x: ArrayView2<F>,
555 beta_covariance: &Array2<F>,
556 noise_precision_mean: F,
557 ) -> StatsResult<Array1<F>> {
558 let n = x.nrows();
559 let mut predictive_var = Array1::zeros(n);
560
561 for i in 0..n {
562 let x_i = x.row(i);
563 let var_beta = x_i.dot(&beta_covariance.dot(&x_i));
564 let var_noise = F::one() / noise_precision_mean;
565 predictive_var[i] = var_beta + var_noise;
566 }
567
568 Ok(predictive_var)
569 }
570
571 fn compute_log_marginal_likelihood(
573 &self,
574 xtx: &Array2<f64>,
575 _xty: &Array1<f64>,
576 prior_precision: &Array2<f64>,
577 _prior_mean: &Array1<f64>,
578 noiseshape: f64,
579 noise_rate: f64,
580 n: f64,
581 p: usize,
582 ) -> StatsResult<F> {
583 let posterior_precision = xtx + prior_precision;
585 let det_prior = scirs2_linalg::det(&prior_precision.view(), None).map_err(|e| {
586 StatsError::ComputationError(format!("Determinant computation failed: {}", e))
587 })?;
588 let det_posterior = scirs2_linalg::det(&posterior_precision.view(), None).map_err(|e| {
589 StatsError::ComputationError(format!("Determinant computation failed: {}", e))
590 })?;
591
592 let log_ml = 0.5 * (det_prior / det_posterior).ln() + noiseshape * noise_rate.ln()
594 - (n / 2.0) * (2.0 * std::f64::consts::PI).ln();
595
596 Ok(F::from(log_ml).expect("Failed to convert to float"))
597 }
598
599 fn compute_elbo(
601 &self,
602 q_beta_mean: &Array1<F>,
603 _q_beta_precision: &Array2<F>,
604 q_noiseshape: F,
605 q_noise_rate: F,
606 ) -> StatsResult<F> {
607 let expected_noise_precision = q_noiseshape / q_noise_rate;
610 let residual = &self.response - &self.design_matrix.dot(q_beta_mean);
611 let data_term = -F::from(0.5).expect("Failed to convert constant to float")
612 * expected_noise_precision
613 * residual.dot(&residual);
614
615 Ok(data_term)
616 }
617
618 fn sample_multivariate_normal<R: scirs2_core::random::Rng>(
620 &self,
621 mean: &Array1<f64>,
622 covariance: &Array2<f64>,
623 rng: &mut R,
624 ) -> StatsResult<Array1<F>> {
625 use scirs2_core::random::{Distribution, StandardNormal};
626
627 let d = mean.len();
628
629 let chol = scirs2_linalg::cholesky(&covariance.view(), None).map_err(|e| {
631 StatsError::ComputationError(format!("Cholesky decomposition failed: {}", e))
632 })?;
633
634 let z: Vec<f64> = (0..d).map(|_| StandardNormal.sample(rng)).collect();
636 let z_array = Array1::from_vec(z);
637
638 let sample_f64 = mean + &chol.dot(&z_array);
640 let sample = sample_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
641
642 Ok(sample)
643 }
644
645 fn compute_mcmc_log_likelihood(&self, beta: &Array1<F>, noise_precision: F) -> StatsResult<F> {
647 let x = &self.design_matrix;
648 let y = &self.response;
649 let n = x.nrows() as f64;
650
651 let residual = y - &x.dot(beta);
652 let sum_squared_residuals = residual.dot(&residual).to_f64().unwrap_or(0.0);
653
654 let log_likelihood = (n / 2.0) * noise_precision.to_f64().unwrap_or(1.0).ln()
655 - (n / 2.0) * (2.0 * std::f64::consts::PI).ln()
656 - 0.5 * noise_precision.to_f64().unwrap_or(1.0) * sum_squared_residuals;
657
658 Ok(F::from(log_likelihood).expect("Failed to convert to float"))
659 }
660
661 fn check_mcmc_convergence(
663 &self,
664 beta_samples: &[Array1<F>],
665 noise_precision_samples_: &[F],
666 ) -> StatsResult<bool> {
667 if beta_samples.len() < 100 {
668 return Ok(false); }
670
671 let n = beta_samples.len();
673 let mid = n / 2;
674
675 let first_half = &beta_samples[..mid];
677 let second_half = &beta_samples[mid..];
678
679 if !beta_samples.is_empty() && !beta_samples[0].is_empty() {
681 let first_half_var = self
682 .compute_sample_variance_1d(&first_half.iter().map(|x| x[0]).collect::<Vec<_>>());
683 let second_half_var = self
684 .compute_sample_variance_1d(&second_half.iter().map(|x| x[0]).collect::<Vec<_>>());
685
686 let var_ratio =
687 first_half_var.max(second_half_var) / first_half_var.min(second_half_var);
688 if var_ratio > F::from(2.0).expect("Failed to convert constant to float") {
689 return Ok(false); }
691 }
692
693 let eff_samplesize = self.compute_effective_samplesize(noise_precision_samples_)?;
695 if eff_samplesize < 100.0 {
696 return Ok(false); }
698
699 Ok(true)
700 }
701
702 fn compute_sample_variance_1d(&self, samples: &[F]) -> F {
704 if samples.is_empty() {
705 return F::one();
706 }
707
708 let n = samples.len();
709 let mean = samples.iter().fold(F::zero(), |acc, &x| acc + x)
710 / F::from(n).expect("Failed to convert to float");
711 let variance = samples
712 .iter()
713 .map(|&x| (x - mean) * (x - mean))
714 .fold(F::zero(), |acc, x| acc + x)
715 / F::from(n.saturating_sub(1).max(1)).expect("Operation failed");
716
717 variance.max(F::from(1e-10).expect("Failed to convert constant to float"))
718 }
720
721 fn compute_effective_samplesize(&self, samples: &[F]) -> StatsResult<f64> {
723 if samples.len() < 10 {
724 return Ok(samples.len() as f64);
725 }
726
727 let n = samples.len();
728 let mean = samples.iter().fold(F::zero(), |acc, &x| acc + x)
729 / F::from(n).expect("Failed to convert to float");
730
731 let mut numerator = F::zero();
733 let mut denominator = F::zero();
734
735 for i in 0..n - 1 {
736 let x_i = samples[i] - mean;
737 let x_i1 = samples[i + 1] - mean;
738 numerator += x_i * x_i1;
739 denominator += x_i * x_i;
740 }
741
742 let autocorr = if denominator > F::from(1e-10).expect("Failed to convert constant to float")
743 {
744 (numerator / denominator).to_f64().unwrap_or(0.0)
745 } else {
746 0.0
747 };
748
749 let eff_n = if autocorr > 0.1 {
751 n as f64 * (1.0 - autocorr) / (1.0 + autocorr)
752 } else {
753 n as f64
754 };
755
756 Ok(eff_n.max(1.0))
757 }
758
759 pub fn predict(
761 &self,
762 x_new: &Array2<F>,
763 result: &BayesianRegressionResult<F>,
764 ) -> StatsResult<(Array1<F>, Array1<F>)> {
765 checkarray_finite(x_new, "x_new")?;
766
767 if x_new.ncols() != self.design_matrix.ncols() {
768 return Err(StatsError::DimensionMismatch(format!(
769 "New data columns ({}) must match training data columns ({})",
770 x_new.ncols(),
771 self.design_matrix.ncols()
772 )));
773 }
774
775 let pred_mean = x_new.dot(&result.beta_mean);
776 let pred_var = self.compute_predictive_variance(
777 x_new.view(),
778 &result.beta_covariance,
779 result.noise_precision_mean,
780 )?;
781
782 Ok((pred_mean, pred_var))
783 }
784}
785
786impl<F> BayesianRegressionPrior<F>
787where
788 F: Float + Zero + One + Copy + ScalarOperand + std::fmt::Display + FromPrimitive,
789{
790 pub fn uninformative(p: usize) -> Self {
792 let beta_mean = Array1::zeros(p);
793 let beta_precision =
794 Array2::eye(p) * F::from(1e-6).expect("Failed to convert constant to float"); let noiseshape = F::from(1e-3).expect("Failed to convert constant to float");
796 let noise_rate = F::from(1e-3).expect("Failed to convert constant to float");
797
798 Self {
799 beta_mean,
800 beta_precision,
801 noiseshape,
802 noise_rate,
803 }
804 }
805
806 pub fn ridge(p: usize, alpha: F) -> Self {
808 let beta_mean = Array1::zeros(p);
809 let beta_precision = Array2::eye(p) * alpha;
810 let noiseshape = F::one();
811 let noise_rate = F::one();
812
813 Self {
814 beta_mean,
815 beta_precision,
816 noiseshape,
817 noise_rate,
818 }
819 }
820}
821
822#[allow(dead_code)]
824pub fn bayesian_linear_regression_exact<F>(
825 x: Array2<F>,
826 y: Array1<F>,
827 prior: Option<BayesianRegressionPrior<F>>,
828) -> StatsResult<BayesianRegressionResult<F>>
829where
830 F: Float
831 + Zero
832 + One
833 + Copy
834 + Send
835 + Sync
836 + SimdUnifiedOps
837 + 'static
838 + std::iter::Sum
839 + NumAssign
840 + ScalarOperand
841 + std::fmt::Display
842 + ToPrimitive
843 + FromPrimitive,
844{
845 let p = x.ncols();
846 let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
847
848 let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::Exact)?;
849 model.fit()
850}
851
852#[allow(dead_code)]
853pub fn bayesian_linear_regression_vb<F>(
854 x: Array2<F>,
855 y: Array1<F>,
856 prior: Option<BayesianRegressionPrior<F>>,
857 config: Option<BayesianRegressionConfig>,
858) -> StatsResult<BayesianRegressionResult<F>>
859where
860 F: Float
861 + Zero
862 + One
863 + Copy
864 + Send
865 + Sync
866 + SimdUnifiedOps
867 + 'static
868 + std::iter::Sum
869 + NumAssign
870 + ScalarOperand
871 + std::fmt::Display
872 + ToPrimitive
873 + FromPrimitive,
874{
875 let p = x.ncols();
876 let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
877 let config = config.unwrap_or_default();
878
879 let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::VariationalBayes)?
880 .with_config(config);
881 model.fit()
882}