1mod advi;
17mod families;
18mod svi;
19
20pub use advi::*;
21pub use families::*;
22pub use svi::*;
23
24use crate::error::{StatsError, StatsResult as Result};
25use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
26use scirs2_core::validation::*;
27use statrs::statistics::Statistics;
28use std::f64::consts::PI;
29
/// Bookkeeping for a variational optimisation run.
///
/// The `record_*` methods append to the trace vectors; the `check_*` methods and
/// `elbo_summary` only read them.
#[derive(Debug, Clone)]
pub struct VariationalDiagnostics {
    /// ELBO values, in the order they were passed to `record_elbo`.
    pub elbo_trace: Vec<f64>,
    /// Gradient norms, in the order they were passed to `record_gradient_norm`.
    pub gradient_norms: Vec<f64>,
    /// Parameter-change norms, in the order they were passed to `record_param_change`.
    pub param_change_norms: Vec<f64>,
    /// Convergence flag. `new` initialises it to `false`; none of the methods of this
    /// type write to it, so presumably the optimiser driving the run sets it directly.
    pub converged: bool,
    /// Number of recorded ELBO values; `record_elbo` keeps it equal to `elbo_trace.len()`.
    pub n_iterations: usize,
    /// Most recently recorded ELBO; `f64::NEG_INFINITY` until `record_elbo` is first called.
    pub final_elbo: f64,
}
50
51impl VariationalDiagnostics {
52 pub fn new() -> Self {
54 Self {
55 elbo_trace: Vec::new(),
56 gradient_norms: Vec::new(),
57 param_change_norms: Vec::new(),
58 converged: false,
59 n_iterations: 0,
60 final_elbo: f64::NEG_INFINITY,
61 }
62 }
63
64 pub fn record_elbo(&mut self, elbo: f64) {
66 self.elbo_trace.push(elbo);
67 self.final_elbo = elbo;
68 self.n_iterations = self.elbo_trace.len();
69 }
70
71 pub fn record_gradient_norm(&mut self, norm: f64) {
73 self.gradient_norms.push(norm);
74 }
75
76 pub fn record_param_change(&mut self, norm: f64) {
78 self.param_change_norms.push(norm);
79 }
80
81 pub fn check_elbo_convergence(&self, tol: f64) -> bool {
83 if self.elbo_trace.len() < 2 {
84 return false;
85 }
86 let n = self.elbo_trace.len();
87 (self.elbo_trace[n - 1] - self.elbo_trace[n - 2]).abs() < tol
88 }
89
90 pub fn check_gradient_convergence(&self, tol: f64) -> bool {
92 if let Some(&last_norm) = self.gradient_norms.last() {
93 last_norm < tol
94 } else {
95 false
96 }
97 }
98
99 pub fn check_param_convergence(&self, tol: f64) -> bool {
101 if let Some(&last_change) = self.param_change_norms.last() {
102 last_change < tol
103 } else {
104 false
105 }
106 }
107
108 pub fn relative_elbo_change(&self, window: usize) -> Option<f64> {
110 let n = self.elbo_trace.len();
111 if n < window + 1 {
112 return None;
113 }
114 let recent = self.elbo_trace[n - 1];
115 let earlier = self.elbo_trace[n - 1 - window];
116 if earlier.abs() < 1e-15 {
117 return Some(f64::INFINITY);
118 }
119 Some((recent - earlier).abs() / earlier.abs())
120 }
121
122 pub fn elbo_summary(&self) -> ElboSummary {
124 let n = self.elbo_trace.len();
125 if n == 0 {
126 return ElboSummary {
127 min: f64::NAN,
128 max: f64::NAN,
129 final_value: f64::NAN,
130 mean_change: f64::NAN,
131 monotonic: true,
132 };
133 }
134
135 let min = self
136 .elbo_trace
137 .iter()
138 .copied()
139 .fold(f64::INFINITY, f64::min);
140 let max = self
141 .elbo_trace
142 .iter()
143 .copied()
144 .fold(f64::NEG_INFINITY, f64::max);
145
146 let mut monotonic = true;
147 let mut total_change = 0.0;
148 for i in 1..n {
149 let change = self.elbo_trace[i] - self.elbo_trace[i - 1];
150 total_change += change.abs();
151 if change < -1e-10 {
152 monotonic = false;
153 }
154 }
155
156 let mean_change = if n > 1 {
157 total_change / (n - 1) as f64
158 } else {
159 0.0
160 };
161
162 ElboSummary {
163 min,
164 max,
165 final_value: self.elbo_trace[n - 1],
166 mean_change,
167 monotonic,
168 }
169 }
170}
171
172impl Default for VariationalDiagnostics {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
/// Summary statistics of an ELBO trace, as produced by
/// `VariationalDiagnostics::elbo_summary`.
///
/// For an empty trace all numeric fields are NaN and `monotonic` is `true`.
#[derive(Debug, Clone)]
pub struct ElboSummary {
    /// Smallest recorded ELBO.
    pub min: f64,
    /// Largest recorded ELBO.
    pub max: f64,
    /// Last recorded ELBO.
    pub final_value: f64,
    /// Mean absolute difference between consecutive ELBO values (0 for a single value).
    pub mean_change: f64,
    /// `false` if any consecutive step decreased by more than `1e-10`.
    pub monotonic: bool,
}
192
/// Gaussian with independent components, parameterised per dimension by a mean and
/// the logarithm of the standard deviation.
#[derive(Debug, Clone)]
pub struct MeanFieldGaussian {
    /// Per-dimension means.
    pub means: Array1<f64>,
    /// Per-dimension log standard deviations (`stds()` exponentiates them).
    pub log_stds: Array1<f64>,
    /// Number of dimensions; the constructors set it to the length of `means`.
    pub dim: usize,
}
212
213impl MeanFieldGaussian {
214 pub fn new(dim: usize) -> Result<Self> {
216 check_positive(dim, "dim")?;
217 Ok(Self {
218 means: Array1::zeros(dim),
219 log_stds: Array1::zeros(dim), dim,
221 })
222 }
223
224 pub fn from_params(means: Array1<f64>, log_stds: Array1<f64>) -> Result<Self> {
226 if means.len() != log_stds.len() {
227 return Err(StatsError::DimensionMismatch(format!(
228 "means length ({}) must match log_stds length ({})",
229 means.len(),
230 log_stds.len()
231 )));
232 }
233 checkarray_finite(&means, "means")?;
234 checkarray_finite(&log_stds, "log_stds")?;
235 let dim = means.len();
236 Ok(Self {
237 means,
238 log_stds,
239 dim,
240 })
241 }
242
243 pub fn stds(&self) -> Array1<f64> {
245 self.log_stds.mapv(f64::exp)
246 }
247
248 pub fn variances(&self) -> Array1<f64> {
250 self.log_stds.mapv(|ls| (2.0 * ls).exp())
251 }
252
253 pub fn sample(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
257 if epsilon.len() != self.dim {
258 return Err(StatsError::DimensionMismatch(format!(
259 "epsilon length ({}) must match dimension ({})",
260 epsilon.len(),
261 self.dim
262 )));
263 }
264 let stds = self.stds();
265 Ok(&self.means + &(&stds * epsilon))
266 }
267
268 pub fn entropy(&self) -> f64 {
271 let base = 0.5 * (1.0 + (2.0 * PI).ln());
272 self.log_stds.iter().map(|&ls| base + ls).sum::<f64>()
273 }
274
275 pub fn log_prob(&self, z: &Array1<f64>) -> Result<f64> {
277 if z.len() != self.dim {
278 return Err(StatsError::DimensionMismatch(format!(
279 "z length ({}) must match dimension ({})",
280 z.len(),
281 self.dim
282 )));
283 }
284 let stds = self.stds();
285 let mut log_prob = 0.0;
286 for i in 0..self.dim {
287 let diff = z[i] - self.means[i];
288 log_prob += -0.5 * (2.0 * PI).ln() - self.log_stds[i] - 0.5 * (diff / stds[i]).powi(2);
289 }
290 Ok(log_prob)
291 }
292
293 pub fn n_params(&self) -> usize {
295 2 * self.dim
296 }
297
298 pub fn get_params(&self) -> Array1<f64> {
300 let mut params = Array1::zeros(2 * self.dim);
301 for i in 0..self.dim {
302 params[i] = self.means[i];
303 params[self.dim + i] = self.log_stds[i];
304 }
305 params
306 }
307
308 pub fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
310 if params.len() != 2 * self.dim {
311 return Err(StatsError::DimensionMismatch(format!(
312 "params length ({}) must be 2 * dim ({})",
313 params.len(),
314 2 * self.dim
315 )));
316 }
317 for i in 0..self.dim {
318 self.means[i] = params[i];
319 self.log_stds[i] = params[self.dim + i];
320 }
321 Ok(())
322 }
323}
324
/// Gaussian with a full covariance matrix, stored as a mean vector plus a factor `L`
/// such that the covariance is `L Lᵀ` (see `covariance`).
#[derive(Debug, Clone)]
pub struct FullRankGaussian {
    /// Mean vector.
    pub mean: Array1<f64>,
    /// Covariance factor `L`. `get_params`/`set_params` only read and write its lower
    /// triangle, and the log-determinant is taken from its diagonal, so it is treated
    /// as lower-triangular — NOTE(review): `from_params` does not enforce this.
    pub chol_factor: Array2<f64>,
    /// Number of dimensions; the constructors set it to the length of `mean`.
    pub dim: usize,
}
341
342impl FullRankGaussian {
343 pub fn new(dim: usize) -> Result<Self> {
345 check_positive(dim, "dim")?;
346 Ok(Self {
347 mean: Array1::zeros(dim),
348 chol_factor: Array2::eye(dim), dim,
350 })
351 }
352
353 pub fn from_params(mean: Array1<f64>, chol_factor: Array2<f64>) -> Result<Self> {
355 let dim = mean.len();
356 if chol_factor.nrows() != dim || chol_factor.ncols() != dim {
357 return Err(StatsError::DimensionMismatch(format!(
358 "chol_factor shape ({},{}) must be ({},{})",
359 chol_factor.nrows(),
360 chol_factor.ncols(),
361 dim,
362 dim
363 )));
364 }
365 checkarray_finite(&mean, "mean")?;
366 checkarray_finite(&chol_factor, "chol_factor")?;
367 Ok(Self {
368 mean,
369 chol_factor,
370 dim,
371 })
372 }
373
374 pub fn covariance(&self) -> Array2<f64> {
376 self.chol_factor.dot(&self.chol_factor.t())
377 }
378
379 pub fn precision(&self) -> Result<Array2<f64>> {
381 let cov = self.covariance();
382 scirs2_linalg::inv(&cov.view(), None).map_err(|e| {
383 StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
384 })
385 }
386
387 pub fn sample(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
391 if epsilon.len() != self.dim {
392 return Err(StatsError::DimensionMismatch(format!(
393 "epsilon length ({}) must match dimension ({})",
394 epsilon.len(),
395 self.dim
396 )));
397 }
398 Ok(&self.mean + &self.chol_factor.dot(epsilon))
399 }
400
401 pub fn entropy(&self) -> f64 {
404 let base = 0.5 * self.dim as f64 * (1.0 + (2.0 * PI).ln());
405 let log_det: f64 = (0..self.dim)
406 .map(|i| self.chol_factor[[i, i]].abs().ln())
407 .sum();
408 base + log_det
409 }
410
411 pub fn log_prob(&self, z: &Array1<f64>) -> Result<f64> {
413 if z.len() != self.dim {
414 return Err(StatsError::DimensionMismatch(format!(
415 "z length ({}) must match dimension ({})",
416 z.len(),
417 self.dim
418 )));
419 }
420 let precision = self.precision()?;
421 let diff = z - &self.mean;
422 let mahal = diff.dot(&precision.dot(&diff));
423 let log_det: f64 = (0..self.dim)
424 .map(|i| self.chol_factor[[i, i]].abs().ln())
425 .sum();
426 Ok(-0.5 * self.dim as f64 * (2.0 * PI).ln() - log_det - 0.5 * mahal)
427 }
428
429 pub fn n_params(&self) -> usize {
431 self.dim + self.dim * (self.dim + 1) / 2
432 }
433
434 pub fn get_params(&self) -> Array1<f64> {
437 let n_tril = self.dim * (self.dim + 1) / 2;
438 let mut params = Array1::zeros(self.dim + n_tril);
439 for i in 0..self.dim {
441 params[i] = self.mean[i];
442 }
443 let mut idx = self.dim;
445 for i in 0..self.dim {
446 for j in 0..=i {
447 params[idx] = self.chol_factor[[i, j]];
448 idx += 1;
449 }
450 }
451 params
452 }
453
454 pub fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
456 let n_tril = self.dim * (self.dim + 1) / 2;
457 let expected = self.dim + n_tril;
458 if params.len() != expected {
459 return Err(StatsError::DimensionMismatch(format!(
460 "params length ({}) must be {}",
461 params.len(),
462 expected
463 )));
464 }
465 for i in 0..self.dim {
467 self.mean[i] = params[i];
468 }
469 let mut idx = self.dim;
471 self.chol_factor = Array2::zeros((self.dim, self.dim));
472 for i in 0..self.dim {
473 for j in 0..=i {
474 self.chol_factor[[i, j]] = params[idx];
475 idx += 1;
476 }
477 }
478 Ok(())
479 }
480}
481
/// Variational family built from a mean-field Gaussian base pushed through a
/// sequence of flow layers.
#[derive(Debug, Clone)]
pub struct NormalizingFlowVI {
    /// Base distribution that `sample` draws from before applying the flows.
    pub base: MeanFieldGaussian,
    /// Flow layers, applied in order by `transform`.
    pub flows: Vec<FlowLayer>,
    /// Dimensionality shared by the base distribution and every flow layer.
    pub dim: usize,
}
502
/// A single invertible transformation used by `NormalizingFlowVI`.
#[derive(Debug, Clone)]
pub enum FlowLayer {
    /// Planar layer; `apply_flow_layer` maps `z` to `z + u * tanh(w·z + b)`.
    Planar {
        /// Direction added to `z`, scaled by the `tanh` activation.
        u: Array1<f64>,
        /// Weight vector of the scalar activation `w·z + b`.
        w: Array1<f64>,
        /// Bias of the scalar activation.
        b: f64,
    },
    /// Radial layer; `apply_flow_layer` maps `z` to `z + beta * h * (z - z0)` with
    /// `h = 1 / (exp(log_alpha) + |z - z0|)`.
    Radial {
        /// Reference point of the radial transformation.
        z0: Array1<f64>,
        /// Logarithm of `alpha`; exponentiated before use.
        log_alpha: f64,
        /// Scale of the displacement along `z - z0`.
        beta: f64,
    },
}
525
526impl NormalizingFlowVI {
527 pub fn new(dim: usize, n_flows: usize) -> Result<Self> {
529 check_positive(dim, "dim")?;
530 let base = MeanFieldGaussian::new(dim)?;
531
532 let mut flows = Vec::with_capacity(n_flows);
534 for _ in 0..n_flows {
535 let u = Array1::from_elem(dim, 0.01);
536 let w = Array1::from_elem(dim, 0.01);
537 flows.push(FlowLayer::Planar { u, w, b: 0.0 });
538 }
539
540 Ok(Self { base, flows, dim })
541 }
542
543 pub fn add_planar_flow(&mut self, u: Array1<f64>, w: Array1<f64>, b: f64) -> Result<()> {
545 if u.len() != self.dim || w.len() != self.dim {
546 return Err(StatsError::DimensionMismatch(format!(
547 "u ({}) and w ({}) must have dimension {}",
548 u.len(),
549 w.len(),
550 self.dim
551 )));
552 }
553 self.flows.push(FlowLayer::Planar { u, w, b });
554 Ok(())
555 }
556
557 pub fn add_radial_flow(&mut self, z0: Array1<f64>, log_alpha: f64, beta: f64) -> Result<()> {
559 if z0.len() != self.dim {
560 return Err(StatsError::DimensionMismatch(format!(
561 "z0 ({}) must have dimension {}",
562 z0.len(),
563 self.dim
564 )));
565 }
566 self.flows.push(FlowLayer::Radial {
567 z0,
568 log_alpha,
569 beta,
570 });
571 Ok(())
572 }
573
574 pub fn transform(&self, z0: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
577 if z0.len() != self.dim {
578 return Err(StatsError::DimensionMismatch(format!(
579 "z0 length ({}) must match dimension ({})",
580 z0.len(),
581 self.dim
582 )));
583 }
584 let mut z = z0.clone();
585 let mut sum_log_det_jac = 0.0;
586
587 for flow in &self.flows {
588 let (z_new, log_det) = apply_flow_layer(flow, &z)?;
589 z = z_new;
590 sum_log_det_jac += log_det;
591 }
592
593 Ok((z, sum_log_det_jac))
594 }
595
596 pub fn sample(&self, epsilon: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
598 let z0 = self.base.sample(epsilon)?;
599 let (z_k, sum_log_det) = self.transform(&z0)?;
600 let log_q0 = self.base.log_prob(&z0)?;
601 let log_q_k = log_q0 - sum_log_det;
603 Ok((z_k, log_q_k))
604 }
605
606 pub fn n_flow_params(&self) -> usize {
608 self.flows
609 .iter()
610 .map(|f| match f {
611 FlowLayer::Planar { u, w, .. } => u.len() + w.len() + 1,
612 FlowLayer::Radial { z0, .. } => z0.len() + 2,
613 })
614 .sum()
615 }
616}
617
618fn apply_flow_layer(flow: &FlowLayer, z: &Array1<f64>) -> Result<(Array1<f64>, f64)> {
620 match flow {
621 FlowLayer::Planar { u, w, b } => {
622 let activation = w.dot(z) + b;
624 let tanh_val = activation.tanh();
625 let z_new = z + &(u * tanh_val);
626
627 let dtanh = 1.0 - tanh_val * tanh_val;
629 let psi = w * dtanh;
630 let det = 1.0 + u.dot(&psi);
631 let log_det = det.abs().ln();
632
633 Ok((z_new, log_det))
634 }
635 FlowLayer::Radial {
636 z0,
637 log_alpha,
638 beta,
639 } => {
640 let alpha = log_alpha.exp();
641 let diff = z - z0;
642 let r = diff.dot(&diff).sqrt();
643 let h = 1.0 / (alpha + r);
644 let z_new = z + &(&diff * (*beta * h));
645
646 let d = z.len() as f64;
648 let h_prime = -1.0 / ((alpha + r) * (alpha + r));
649 let term1 = (1.0 + beta * h).powi(d as i32 - 1);
650 let term2 = 1.0 + beta * h + beta * h_prime * r;
651 let det = term1 * term2;
652 let log_det = det.abs().ln();
653
654 Ok((z_new, log_det))
655 }
656 }
657}
658
/// Linear regression fitted by coordinate-ascent variational inference, with a
/// Gaussian factor over the weights and a Gamma factor over the noise precision `tau`.
#[derive(Debug, Clone)]
pub struct VariationalBayesianRegression {
    /// Mean of the variational Gaussian over the weights.
    pub mean_beta: Array1<f64>,
    /// Covariance of the variational Gaussian over the weights.
    pub cov_beta: Array2<f64>,
    /// Shape of the variational Gamma over `tau`.
    pub shape_tau: f64,
    /// Rate of the variational Gamma over `tau` (`E[tau] = shape_tau / rate_tau`).
    pub rate_tau: f64,
    /// Prior mean of the weights.
    pub prior_mean_beta: Array1<f64>,
    /// Prior covariance of the weights (inverted once at the start of `fit`).
    pub prior_cov_beta: Array2<f64>,
    /// Prior Gamma shape for `tau`.
    pub priorshape_tau: f64,
    /// Prior Gamma rate for `tau`.
    pub prior_rate_tau: f64,
    /// Number of feature columns the model expects.
    pub n_features: usize,
    /// When `true`, `fit` centres `x` and `y` by their means before fitting.
    pub fit_intercept: bool,
}
687
688impl VariationalBayesianRegression {
689 pub fn new(n_features: usize, fit_intercept: bool) -> Result<Self> {
691 check_positive(n_features, "n_features")?;
692
693 let prior_mean_beta = Array1::zeros(n_features);
695 let prior_cov_beta = Array2::eye(n_features) * 100.0; let priorshape_tau = 1e-3;
697 let prior_rate_tau = 1e-3;
698
699 Ok(Self {
700 mean_beta: prior_mean_beta.clone(),
701 cov_beta: prior_cov_beta.clone(),
702 shape_tau: priorshape_tau,
703 rate_tau: prior_rate_tau,
704 prior_mean_beta,
705 prior_cov_beta,
706 priorshape_tau,
707 prior_rate_tau,
708 n_features,
709 fit_intercept,
710 })
711 }
712
713 pub fn with_priors(
715 mut self,
716 prior_mean_beta: Array1<f64>,
717 prior_cov_beta: Array2<f64>,
718 priorshape_tau: f64,
719 prior_rate_tau: f64,
720 ) -> Result<Self> {
721 checkarray_finite(&prior_mean_beta, "prior_mean_beta")?;
722 checkarray_finite(&prior_cov_beta, "prior_cov_beta")?;
723 check_positive(priorshape_tau, "priorshape_tau")?;
724 check_positive(prior_rate_tau, "prior_rate_tau")?;
725
726 self.prior_mean_beta = prior_mean_beta.clone();
727 self.prior_cov_beta = prior_cov_beta.clone();
728 self.priorshape_tau = priorshape_tau;
729 self.prior_rate_tau = prior_rate_tau;
730 self.mean_beta = prior_mean_beta;
731 self.cov_beta = prior_cov_beta;
732 self.shape_tau = priorshape_tau;
733 self.rate_tau = prior_rate_tau;
734
735 Ok(self)
736 }
737
738 pub fn fit(
740 &mut self,
741 x: ArrayView2<f64>,
742 y: ArrayView1<f64>,
743 max_iter: usize,
744 tol: f64,
745 ) -> Result<VariationalRegressionResult> {
746 checkarray_finite(&x, "x")?;
747 checkarray_finite(&y, "y")?;
748 check_positive(max_iter, "max_iter")?;
749 check_positive(tol, "tol")?;
750
751 let (n_samples_, n_features) = x.dim();
752 if y.len() != n_samples_ {
753 return Err(StatsError::DimensionMismatch(format!(
754 "y length ({}) must match x rows ({})",
755 y.len(),
756 n_samples_
757 )));
758 }
759
760 if n_features != self.n_features {
761 return Err(StatsError::DimensionMismatch(format!(
762 "x features ({}) must match model features ({})",
763 n_features, self.n_features
764 )));
765 }
766
767 let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
769 let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
770 let y_mean = y.mean();
771
772 let mut x_centered = x.to_owned();
773 for mut row in x_centered.rows_mut() {
774 row -= &x_mean;
775 }
776 let y_centered = &y.to_owned() - y_mean;
777
778 (x_centered, y_centered, Some(x_mean), Some(y_mean))
779 } else {
780 (x.to_owned(), y.to_owned(), None, None)
781 };
782
783 let xtx = x_centered.t().dot(&x_centered);
785 let xty = x_centered.t().dot(&y_centered);
786 let yty = y_centered.dot(&y_centered);
787
788 let prior_precision =
790 scirs2_linalg::inv(&self.prior_cov_beta.view(), None).map_err(|e| {
791 StatsError::ComputationError(format!("Failed to invert prior covariance: {}", e))
792 })?;
793
794 let mut prev_elbo = f64::NEG_INFINITY;
795 let mut elbo_history = Vec::new();
796
797 for _iter in 0..max_iter {
798 self.update_beta_variational(&xtx, &xty, &prior_precision)?;
800
801 self.update_tau_variational(n_samples_ as f64, &xtx, yty)?;
803
804 let elbo = self.compute_elbo(n_samples_ as f64, &xtx, &xty, yty, &prior_precision)?;
806 elbo_history.push(elbo);
807
808 if _iter > 0 && (elbo - prev_elbo).abs() < tol {
810 break;
811 }
812
813 prev_elbo = elbo;
814 }
815
816 Ok(VariationalRegressionResult {
817 mean_beta: self.mean_beta.clone(),
818 cov_beta: self.cov_beta.clone(),
819 shape_tau: self.shape_tau,
820 rate_tau: self.rate_tau,
821 elbo: prev_elbo,
822 elbo_history: elbo_history.clone(),
823 n_samples_,
824 n_features: self.n_features,
825 x_mean,
826 y_mean,
827 converged: elbo_history.len() < max_iter,
828 })
829 }
830
831 fn update_beta_variational(
833 &mut self,
834 xtx: &Array2<f64>,
835 xty: &Array1<f64>,
836 prior_precision: &Array2<f64>,
837 ) -> Result<()> {
838 let expected_tau = self.shape_tau / self.rate_tau;
840
841 let precision_beta = prior_precision + &(xtx * expected_tau);
843
844 self.cov_beta = scirs2_linalg::inv(&precision_beta.view(), None).map_err(|e| {
846 StatsError::ComputationError(format!("Failed to invert precision: {}", e))
847 })?;
848
849 let prior_contrib = prior_precision.dot(&self.prior_mean_beta);
851 let data_contrib = xty * expected_tau;
852 self.mean_beta = self.cov_beta.dot(&(prior_contrib + data_contrib));
853
854 Ok(())
855 }
856
857 fn update_tau_variational(
859 &mut self,
860 n_samples_: f64,
861 xtx: &Array2<f64>,
862 yty: f64,
863 ) -> Result<()> {
864 self.shape_tau = self.priorshape_tau + n_samples_ / 2.0;
866
867 let expected_beta_outer = &self.cov_beta + outer_product(&self.mean_beta);
869 let trace_term = (xtx * &expected_beta_outer).sum();
870 let quadratic_term = 2.0 * self.mean_beta.dot(&xtx.dot(&self.mean_beta));
871
872 self.rate_tau = self.prior_rate_tau + 0.5 * (yty - quadratic_term + trace_term);
873
874 Ok(())
875 }
876
877 fn compute_elbo(
879 &self,
880 n_samples_: f64,
881 xtx: &Array2<f64>,
882 xty: &Array1<f64>,
883 yty: f64,
884 prior_precision: &Array2<f64>,
885 ) -> Result<f64> {
886 let expected_tau = self.shape_tau / self.rate_tau;
887 let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
888
889 let diff =
891 yty - 2.0 * self.mean_beta.dot(xty) + self.mean_beta.dot(&xtx.dot(&self.mean_beta));
892 let trace_term = (xtx * &self.cov_beta).sum();
893 let likelihood_term = 0.5 * n_samples_ * expected_log_tau
894 - 0.5 * n_samples_ * (2.0_f64 * PI).ln()
895 - 0.5 * expected_tau * (diff + trace_term);
896
897 let beta_diff = &self.mean_beta - &self.prior_mean_beta;
899 let beta_quad = beta_diff.dot(&prior_precision.dot(&beta_diff));
900 let beta_trace = (prior_precision * &self.cov_beta).sum();
901
902 let prior_det = scirs2_linalg::det(&prior_precision.view(), None).map_err(|e| {
903 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
904 })?;
905
906 let beta_prior_term = 0.5 * prior_det.ln()
907 - 0.5 * self.n_features as f64 * (2.0_f64 * PI).ln()
908 - 0.5 * (beta_quad + beta_trace);
909
910 let tau_prior_term = self.priorshape_tau * self.prior_rate_tau.ln()
912 - lgamma(self.priorshape_tau)
913 + (self.priorshape_tau - 1.0) * expected_log_tau
914 - self.prior_rate_tau * expected_tau;
915
916 let var_det = scirs2_linalg::det(&self.cov_beta.view(), None).map_err(|e| {
918 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
919 })?;
920 let beta_entropy =
921 0.5 * self.n_features as f64 * (1.0 + (2.0_f64 * PI).ln()) + 0.5 * var_det.ln();
922
923 let tau_entropy = self.shape_tau - self.rate_tau.ln()
925 + lgamma(self.shape_tau)
926 + (1.0 - self.shape_tau) * digamma(self.shape_tau);
927
928 Ok(likelihood_term + beta_prior_term + tau_prior_term + beta_entropy + tau_entropy)
929 }
930
931 pub fn predict(
933 &self,
934 x: ArrayView2<f64>,
935 result: &VariationalRegressionResult,
936 ) -> Result<VariationalPredictionResult> {
937 checkarray_finite(&x, "x")?;
938 let (n_test, n_features) = x.dim();
939
940 if n_features != result.n_features {
941 return Err(StatsError::DimensionMismatch(format!(
942 "x has {} features, expected {}",
943 n_features, result.n_features
944 )));
945 }
946
947 let x_centered = if let Some(ref x_mean) = result.x_mean {
949 let mut x_c = x.to_owned();
950 for mut row in x_c.rows_mut() {
951 row -= x_mean;
952 }
953 x_c
954 } else {
955 x.to_owned()
956 };
957
958 let y_pred_centered = x_centered.dot(&result.mean_beta);
960 let y_pred = if let Some(y_mean) = result.y_mean {
961 &y_pred_centered + y_mean
962 } else {
963 y_pred_centered.clone()
964 };
965
966 let expected_noise_variance = result.rate_tau / result.shape_tau;
968 let mut predictive_variance = Array1::zeros(n_test);
969
970 for i in 0..n_test {
971 let x_row = x_centered.row(i);
972 let model_variance = x_row.dot(&result.cov_beta.dot(&x_row));
973 predictive_variance[i] = expected_noise_variance + model_variance;
974 }
975
976 Ok(VariationalPredictionResult {
977 mean: y_pred,
978 variance: predictive_variance.clone(),
979 model_uncertainty: predictive_variance.mapv(|v| (v - expected_noise_variance).max(0.0)),
980 noise_variance: expected_noise_variance,
981 })
982 }
983}
984
/// Snapshot of a fitted `VariationalBayesianRegression`, returned by its `fit` method.
#[derive(Debug, Clone)]
pub struct VariationalRegressionResult {
    /// Mean of the variational Gaussian over the weights.
    pub mean_beta: Array1<f64>,
    /// Covariance of the variational Gaussian over the weights.
    pub cov_beta: Array2<f64>,
    /// Shape of the variational Gamma over the noise precision.
    pub shape_tau: f64,
    /// Rate of the variational Gamma over the noise precision.
    pub rate_tau: f64,
    /// ELBO value reported by `fit`.
    pub elbo: f64,
    /// ELBO after each iteration, in order.
    pub elbo_history: Vec<f64>,
    /// Number of rows in the training data.
    pub n_samples_: usize,
    /// Number of feature columns of the model.
    pub n_features: usize,
    /// Column means subtracted from `x` when fitting with `fit_intercept`; `None` otherwise.
    pub x_mean: Option<Array1<f64>>,
    /// Mean subtracted from `y` when fitting with `fit_intercept`; `None` otherwise.
    pub y_mean: Option<f64>,
    /// Convergence flag reported by `fit`.
    pub converged: bool,
}
1011
1012impl VariationalRegressionResult {
1013 pub fn std_beta(&self) -> Array1<f64> {
1015 self.cov_beta.diag().mapv(f64::sqrt)
1016 }
1017
1018 pub fn precision_stats(&self) -> (f64, f64) {
1020 let mean = self.shape_tau / self.rate_tau;
1021 let variance = self.shape_tau / (self.rate_tau * self.rate_tau);
1022 (mean, variance.sqrt())
1023 }
1024
1025 pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
1027 check_probability(confidence, "confidence")?;
1028
1029 let n_features = self.mean_beta.len();
1030 let mut intervals = Array2::zeros((n_features, 2));
1031 let alpha = (1.0 - confidence) / 2.0;
1032
1033 for i in 0..n_features {
1035 let mean = self.mean_beta[i];
1036 let std = self.cov_beta[[i, i]].sqrt();
1037
1038 let z_critical = normal_ppf(1.0 - alpha)?;
1040 intervals[[i, 0]] = mean - z_critical * std;
1041 intervals[[i, 1]] = mean + z_critical * std;
1042 }
1043
1044 Ok(intervals)
1045 }
1046}
1047
/// Predictions produced by `VariationalBayesianRegression::predict`.
#[derive(Debug, Clone)]
pub struct VariationalPredictionResult {
    /// Predictive mean for each input row.
    pub mean: Array1<f64>,
    /// Predictive variance for each input row (noise variance plus model variance).
    pub variance: Array1<f64>,
    /// Part of the predictive variance attributed to weight uncertainty, clamped at zero.
    pub model_uncertainty: Array1<f64>,
    /// Noise variance used for every row (`rate_tau / shape_tau` of the fitted result).
    pub noise_variance: f64,
}
1060
1061impl VariationalPredictionResult {
1062 pub fn std(&self) -> Array1<f64> {
1064 self.variance.mapv(f64::sqrt)
1065 }
1066
1067 pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
1069 check_probability(confidence, "confidence")?;
1070
1071 let n_predictions = self.mean.len();
1072 let mut intervals = Array2::zeros((n_predictions, 2));
1073 let alpha = (1.0 - confidence) / 2.0;
1074
1075 let z_critical = normal_ppf(1.0 - alpha)?;
1076
1077 for i in 0..n_predictions {
1078 let mean = self.mean[i];
1079 let std = self.variance[i].sqrt();
1080 intervals[[i, 0]] = mean - z_critical * std;
1081 intervals[[i, 1]] = mean + z_critical * std;
1082 }
1083
1084 Ok(intervals)
1085 }
1086}
1087
/// Linear regression with automatic relevance determination: each weight has its own
/// precision `alpha_i` with a Gamma factor, the noise precision `tau` has a Gamma
/// factor, and the weights use a fully factorised Gaussian.
#[derive(Debug, Clone)]
pub struct VariationalARD {
    /// Variational mean of each weight.
    pub mean_beta: Array1<f64>,
    /// Variational variance of each weight.
    pub var_beta: Array1<f64>,
    /// Shape of the Gamma factor over each weight precision.
    pub shape_alpha: Array1<f64>,
    /// Rate of the Gamma factor over each weight precision
    /// (`E[alpha_i] = shape_alpha[i] / rate_alpha[i]`).
    pub rate_alpha: Array1<f64>,
    /// Shape of the Gamma factor over the noise precision.
    pub shape_tau: f64,
    /// Rate of the Gamma factor over the noise precision.
    pub rate_tau: f64,
    /// Prior Gamma shape shared by all weight precisions.
    pub priorshape_alpha: f64,
    /// Prior Gamma rate shared by all weight precisions.
    pub prior_rate_alpha: f64,
    /// Prior Gamma shape for the noise precision.
    pub priorshape_tau: f64,
    /// Prior Gamma rate for the noise precision.
    pub prior_rate_tau: f64,
    /// Number of feature columns the model expects.
    pub n_features: usize,
    /// When `true`, `fit` centres `x` and `y` by their means before fitting.
    pub fit_intercept: bool,
}
1116
1117impl VariationalARD {
1118 pub fn new(n_features: usize, fit_intercept: bool) -> Result<Self> {
1120 check_positive(n_features, "n_features")?;
1121
1122 let priorshape_alpha = 1e-3;
1124 let prior_rate_alpha = 1e-3;
1125 let priorshape_tau = 1e-3;
1126 let prior_rate_tau = 1e-3;
1127
1128 Ok(Self {
1129 mean_beta: Array1::zeros(n_features),
1130 var_beta: Array1::from_elem(n_features, 1.0),
1131 shape_alpha: Array1::from_elem(n_features, priorshape_alpha),
1132 rate_alpha: Array1::from_elem(n_features, prior_rate_alpha),
1133 shape_tau: priorshape_tau,
1134 rate_tau: prior_rate_tau,
1135 priorshape_alpha,
1136 prior_rate_alpha,
1137 priorshape_tau,
1138 prior_rate_tau,
1139 n_features,
1140 fit_intercept,
1141 })
1142 }
1143
1144 pub fn fit(
1146 &mut self,
1147 x: ArrayView2<f64>,
1148 y: ArrayView1<f64>,
1149 max_iter: usize,
1150 tol: f64,
1151 ) -> Result<VariationalARDResult> {
1152 checkarray_finite(&x, "x")?;
1153 checkarray_finite(&y, "y")?;
1154 check_positive(max_iter, "max_iter")?;
1155 check_positive(tol, "tol")?;
1156
1157 let (n_samples_, n_features) = x.dim();
1158 if y.len() != n_samples_ {
1159 return Err(StatsError::DimensionMismatch(format!(
1160 "y length ({}) must match x rows ({})",
1161 y.len(),
1162 n_samples_
1163 )));
1164 }
1165
1166 let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
1168 let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
1169 let y_mean = y.mean();
1170
1171 let mut x_centered = x.to_owned();
1172 for mut row in x_centered.rows_mut() {
1173 row -= &x_mean;
1174 }
1175 let y_centered = &y.to_owned() - y_mean;
1176
1177 (x_centered, y_centered, Some(x_mean), Some(y_mean))
1178 } else {
1179 (x.to_owned(), y.to_owned(), None, None)
1180 };
1181
1182 let xtx = x_centered.t().dot(&x_centered);
1184 let xty = x_centered.t().dot(&y_centered);
1185 let yty = y_centered.dot(&y_centered);
1186
1187 let mut prev_elbo = f64::NEG_INFINITY;
1188 let mut elbo_history = Vec::new();
1189
1190 for _iter in 0..max_iter {
1191 self.update_beta_ard(&xtx, &xty)?;
1193
1194 self.update_alpha_ard()?;
1196
1197 self.update_tau_ard(n_samples_ as f64, &xtx, yty)?;
1199
1200 let elbo = self.compute_elbo_ard(n_samples_ as f64, &xtx, &xty, yty)?;
1202 elbo_history.push(elbo);
1203
1204 if _iter > 0 && (elbo - prev_elbo).abs() < tol {
1206 break;
1207 }
1208
1209 if _iter % 10 == 0 {
1211 self.prune_features()?;
1212 }
1213
1214 prev_elbo = elbo;
1215 }
1216
1217 Ok(VariationalARDResult {
1218 mean_beta: self.mean_beta.clone(),
1219 var_beta: self.var_beta.clone(),
1220 shape_alpha: self.shape_alpha.clone(),
1221 rate_alpha: self.rate_alpha.clone(),
1222 shape_tau: self.shape_tau,
1223 rate_tau: self.rate_tau,
1224 elbo: prev_elbo,
1225 elbo_history: elbo_history.clone(),
1226 n_samples_,
1227 n_features: self.n_features,
1228 x_mean,
1229 y_mean,
1230 converged: elbo_history.len() < max_iter,
1231 })
1232 }
1233
1234 fn update_beta_ard(&mut self, xtx: &Array2<f64>, xty: &Array1<f64>) -> Result<()> {
1236 let expected_tau = self.shape_tau / self.rate_tau;
1237 let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1238
1239 for i in 0..self.n_features {
1241 let precision_i = expected_alpha[i] + expected_tau * xtx[[i, i]];
1242 self.var_beta[i] = 1.0 / precision_i;
1243 }
1244
1245 for i in 0..self.n_features {
1247 let sum_j = (0..self.n_features)
1248 .filter(|&j| j != i)
1249 .map(|j| xtx[[i, j]] * self.mean_beta[j])
1250 .sum::<f64>();
1251
1252 self.mean_beta[i] = expected_tau * self.var_beta[i] * (xty[i] - sum_j);
1253 }
1254
1255 Ok(())
1256 }
1257
1258 fn update_alpha_ard(&mut self) -> Result<()> {
1260 for i in 0..self.n_features {
1261 self.shape_alpha[i] = self.priorshape_alpha + 0.5;
1262 self.rate_alpha[i] =
1263 self.prior_rate_alpha + 0.5 * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1264 }
1265
1266 Ok(())
1267 }
1268
1269 fn update_tau_ard(&mut self, n_samples_: f64, xtx: &Array2<f64>, yty: f64) -> Result<()> {
1271 self.shape_tau = self.priorshape_tau + n_samples_ / 2.0;
1272
1273 let mut quadratic_term = 0.0;
1274 for i in 0..self.n_features {
1275 for j in 0..self.n_features {
1276 if i == j {
1277 quadratic_term += xtx[[i, j]] * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1278 } else {
1279 quadratic_term += xtx[[i, j]] * self.mean_beta[i] * self.mean_beta[j];
1280 }
1281 }
1282 }
1283
1284 self.rate_tau = self.prior_rate_tau
1285 + 0.5 * (yty - 2.0 * self.mean_beta.dot(&xtx.dot(&self.mean_beta)) + quadratic_term);
1286
1287 Ok(())
1288 }
1289
1290 fn compute_elbo_ard(
1292 &self,
1293 n_samples_: f64,
1294 xtx: &Array2<f64>,
1295 xty: &Array1<f64>,
1296 yty: f64,
1297 ) -> Result<f64> {
1298 let expected_tau = self.shape_tau / self.rate_tau;
1299 let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
1300
1301 let mut quadratic_form = yty - 2.0 * self.mean_beta.dot(xty);
1303 for i in 0..self.n_features {
1304 for j in 0..self.n_features {
1305 if i == j {
1306 quadratic_form += xtx[[i, j]] * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1307 } else {
1308 quadratic_form += xtx[[i, j]] * self.mean_beta[i] * self.mean_beta[j];
1309 }
1310 }
1311 }
1312
1313 let likelihood_term = 0.5 * n_samples_ * expected_log_tau
1314 - 0.5 * n_samples_ * (2.0_f64 * PI).ln()
1315 - 0.5 * expected_tau * quadratic_form;
1316
1317 let mut prior_term = 0.0;
1319 for i in 0..self.n_features {
1320 let expected_alpha_i = self.shape_alpha[i] / self.rate_alpha[i];
1321 let expected_log_alpha_i = digamma(self.shape_alpha[i]) - self.rate_alpha[i].ln();
1322
1323 prior_term += 0.5 * expected_log_alpha_i
1324 - 0.5 * (2.0_f64 * PI).ln()
1325 - 0.5 * expected_alpha_i * (self.mean_beta[i].powi(2) + self.var_beta[i]);
1326 }
1327
1328 let mut entropy_term = 0.0;
1330 for i in 0..self.n_features {
1331 entropy_term += 0.5 * (1.0 + (2.0 * PI * self.var_beta[i]).ln());
1332 }
1333
1334 Ok(likelihood_term + prior_term + entropy_term)
1335 }
1336
1337 fn prune_features(&mut self) -> Result<()> {
1339 let threshold = 1e12; for i in 0..self.n_features {
1342 let expected_alpha = self.shape_alpha[i] / self.rate_alpha[i];
1343 if expected_alpha > threshold {
1344 self.mean_beta[i] = 0.0;
1346 self.var_beta[i] = 1e-12;
1347 }
1348 }
1349
1350 Ok(())
1351 }
1352
1353 pub fn feature_relevance(&self) -> Array1<f64> {
1355 let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1356 expected_alpha.mapv(|alpha| 1.0 / alpha)
1358 }
1359}
1360
/// Snapshot of a fitted `VariationalARD` model, returned by its `fit` method.
#[derive(Debug, Clone)]
pub struct VariationalARDResult {
    /// Variational mean of each weight.
    pub mean_beta: Array1<f64>,
    /// Variational variance of each weight.
    pub var_beta: Array1<f64>,
    /// Shape of the Gamma factor over each weight precision.
    pub shape_alpha: Array1<f64>,
    /// Rate of the Gamma factor over each weight precision.
    pub rate_alpha: Array1<f64>,
    /// Shape of the Gamma factor over the noise precision.
    pub shape_tau: f64,
    /// Rate of the Gamma factor over the noise precision.
    pub rate_tau: f64,
    /// ELBO value reported by `fit`.
    pub elbo: f64,
    /// ELBO after each iteration, in order.
    pub elbo_history: Vec<f64>,
    /// Number of rows in the training data.
    pub n_samples_: usize,
    /// Number of feature columns of the model.
    pub n_features: usize,
    /// Column means subtracted from `x` when fitting with `fit_intercept`; `None` otherwise.
    pub x_mean: Option<Array1<f64>>,
    /// Mean subtracted from `y` when fitting with `fit_intercept`; `None` otherwise.
    pub y_mean: Option<f64>,
    /// Convergence flag reported by `fit`.
    pub converged: bool,
}
1391
1392impl VariationalARDResult {
1393 pub fn selected_features(&self, threshold: f64) -> Vec<usize> {
1395 let expected_alpha = &self.shape_alpha / &self.rate_alpha;
1396 expected_alpha
1397 .iter()
1398 .enumerate()
1399 .filter(|(_, &alpha)| alpha < threshold) .map(|(i, _)| i)
1401 .collect()
1402 }
1403
1404 pub fn feature_importance(&self) -> Array1<f64> {
1406 self.mean_beta.mapv(f64::abs)
1407 }
1408}
1409
1410pub(crate) fn outer_product(v: &Array1<f64>) -> Array2<f64> {
1416 let n = v.len();
1417 let mut result = Array2::zeros((n, n));
1418 for i in 0..n {
1419 for j in 0..n {
1420 result[[i, j]] = v[i] * v[j];
1421 }
1422 }
1423 result
1424}
1425
1426pub(crate) fn normal_ppf(p: f64) -> Result<f64> {
1428 if p <= 0.0 || p >= 1.0 {
1429 return Err(StatsError::InvalidArgument(
1430 "p must be between 0 and 1".to_string(),
1431 ));
1432 }
1433
1434 let a = [
1435 -3.969683028665376e+01,
1436 2.209460984245205e+02,
1437 -2.759285104469687e+02,
1438 1.383_577_518_672_69e2,
1439 -3.066479806614716e+01,
1440 2.506628277459239e+00,
1441 ];
1442
1443 let b = [
1444 -5.447609879822406e+01,
1445 1.615858368580409e+02,
1446 -1.556989798598866e+02,
1447 6.680131188771972e+01,
1448 -1.328068155288572e+01,
1449 ];
1450
1451 let c = [
1452 -7.784894002430293e-03,
1453 -3.223964580411365e-01,
1454 -2.400758277161838e+00,
1455 -2.549732539343734e+00,
1456 4.374664141464968e+00,
1457 2.938163982698783e+00,
1458 ];
1459
1460 let d = [
1461 7.784695709041462e-03,
1462 3.224671290700398e-01,
1463 2.445134137142996e+00,
1464 3.754408661907416e+00,
1465 ];
1466
1467 let p_low = 0.02425;
1468 let p_high = 1.0 - p_low;
1469
1470 if p < p_low {
1471 let q = (-2.0 * p.ln()).sqrt();
1472 Ok(
1473 (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
1474 / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0),
1475 )
1476 } else if p <= p_high {
1477 let q = p - 0.5;
1478 let r = q * q;
1479 Ok(
1480 (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q
1481 / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0),
1482 )
1483 } else {
1484 let q = (-2.0 * (1.0 - p).ln()).sqrt();
1485 Ok(
1486 (-((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
1487 / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0),
1488 )
1489 }
1490}
1491
/// Digamma function `ψ(x)` for positive arguments.
///
/// Arguments below 8 are shifted upward with the recurrence
/// `ψ(x) = ψ(x + 1) - 1/x`; once the argument is at least 8 the asymptotic
/// series `ln x - 1/(2x) - 1/(12x²) + 1/(120x⁴) - 1/(252x⁶)` is evaluated.
///
/// Returns `f64::NEG_INFINITY` for every `x <= 0.0`.
pub(crate) fn digamma(x: f64) -> f64 {
    // Smallest argument at which the asymptotic series is used directly.
    const SERIES_START: f64 = 8.0;

    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Shift the argument up to the series region, remembering each `1/x`
    // correction. A positive argument needs at most eight shifts.
    let mut shifted = x;
    let mut corrections = [0.0_f64; 8];
    let mut shift_count = 0;
    while shifted < SERIES_START {
        corrections[shift_count] = 1.0 / shifted;
        shift_count += 1;
        shifted += 1.0;
    }

    let r = 1.0 / shifted;
    let r2 = r * r;
    let r4 = r2 * r2;
    let r6 = r4 * r2;
    let mut value = shifted.ln() - 0.5 * r - r2 / 12.0 + r4 / 120.0 - r6 / 252.0;

    // Undo the shifts from the largest argument back down to the original one.
    for correction in corrections[..shift_count].iter().rev() {
        value -= correction;
    }
    value
}
1508
/// Natural logarithm of the gamma function for positive arguments.
///
/// Arguments in `(0, 0.5)` go through the reflection formula
/// `ln Γ(x) = ln(π / sin(πx)) - ln Γ(1 - x)`; everything else is evaluated
/// with a nine-term Lanczos approximation (`g = 7`).
///
/// Returns `f64::NEG_INFINITY` for every `x <= 0.0`.
pub(crate) fn lgamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if x < 0.5 {
        // Reflect small arguments onto (0.5, 1), where Lanczos is applied.
        let reflected = lgamma(1.0 - x);
        return (PI / (PI * x).sin()).ln() - reflected;
    }

    // The Lanczos formula is stated for Γ(z + 1).
    let z = x - 1.0;
    let t = z + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (k, &coeff)| {
            acc + coeff / (z + k as f64)
        });

    let half_log_two_pi = 0.5 * (2.0 * PI).ln();
    half_log_two_pi + (z + 0.5) * t.ln() - t + series.ln()
}
1544
/// Trigamma function `ψ₁(x)` for positive arguments.
///
/// Arguments below 8 are shifted upward with the recurrence
/// `ψ₁(x) = ψ₁(x + 1) + 1/x²`; once the argument is at least 8 the
/// asymptotic series `1/x + 1/(2x²) + 1/(6x³) - 1/(30x⁵) + 1/(42x⁷)` is
/// evaluated.
///
/// Returns `f64::INFINITY` for every `x <= 0.0`.
pub(crate) fn trigamma(x: f64) -> f64 {
    // Smallest argument at which the asymptotic series is used directly.
    const SERIES_START: f64 = 8.0;

    if x <= 0.0 {
        return f64::INFINITY;
    }

    // Shift the argument up to the series region, remembering each `1/x²`
    // correction. A positive argument needs at most eight shifts.
    let mut shifted = x;
    let mut corrections = [0.0_f64; 8];
    let mut shift_count = 0;
    while shifted < SERIES_START {
        corrections[shift_count] = 1.0 / (shifted * shifted);
        shift_count += 1;
        shifted += 1.0;
    }

    let r = 1.0 / shifted;
    let r2 = r * r;
    let r3 = r2 * r;
    let r4 = r2 * r2;
    let r5 = r4 * r;
    let r7 = r4 * r2 * r;
    let mut value = r + 0.5 * r2 + r3 / 6.0 - r5 / 30.0 + r7 / 42.0;

    // Undo the shifts from the largest argument back down to the original one.
    for correction in corrections[..shift_count].iter().rev() {
        value += correction;
    }
    value
}
1561
/// Unit tests for the variational families, the diagnostics container, the
/// Bayesian regression model and the special-function helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly built mean-field Gaussian exposes one mean and one log-std per dimension.
    #[test]
    fn test_mean_field_gaussian_creation() {
        let mf = MeanFieldGaussian::new(5).expect("should create mean-field Gaussian");
        assert_eq!(mf.dim, 5);
        assert_eq!(mf.means.len(), 5);
        assert_eq!(mf.log_stds.len(), 5);
        assert_eq!(mf.n_params(), 10);
    }

    /// The default 2-D mean-field Gaussian has the entropy of two unit-variance normals.
    #[test]
    fn test_mean_field_gaussian_entropy() {
        let mf = MeanFieldGaussian::new(2).expect("should create");
        let entropy = mf.entropy();
        let expected = 2.0 * 0.5 * (1.0 + (2.0 * PI).ln());
        assert!((entropy - expected).abs() < 1e-10);
    }

    /// With default parameters, sampling returns the noise vector unchanged.
    #[test]
    fn test_mean_field_gaussian_sample() {
        let mf = MeanFieldGaussian::new(3).expect("should create");
        let epsilon = Array1::from_vec(vec![0.5, -0.3, 1.0]);
        let sample = mf.sample(&epsilon).expect("should sample");
        assert_eq!(sample.len(), 3);
        for i in 0..3 {
            assert!((sample[i] - epsilon[i]).abs() < 1e-10);
        }
    }

    /// Parameters written with `set_params` are read back unchanged by `get_params`.
    #[test]
    fn test_mean_field_gaussian_params_roundtrip() {
        let mut mf = MeanFieldGaussian::new(3).expect("should create");
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.5, -0.3, 0.1]);
        mf.set_params(&params).expect("should set params");
        let retrieved = mf.get_params();
        for i in 0..6 {
            assert!((retrieved[i] - params[i]).abs() < 1e-10);
        }
    }

    /// A 3-D full-rank Gaussian reports nine parameters.
    #[test]
    fn test_full_rank_gaussian_creation() {
        let fr = FullRankGaussian::new(3).expect("should create full-rank Gaussian");
        assert_eq!(fr.dim, 3);
        assert_eq!(fr.mean.len(), 3);
        assert_eq!(fr.n_params(), 9);
    }

    /// The default 2-D full-rank Gaussian has the entropy of two unit-variance normals.
    #[test]
    fn test_full_rank_gaussian_entropy() {
        let fr = FullRankGaussian::new(2).expect("should create");
        let entropy = fr.entropy();
        let expected = 2.0 * 0.5 * (1.0 + (2.0 * PI).ln());
        assert!((entropy - expected).abs() < 1e-10);
    }

    /// With default parameters, sampling returns the noise vector unchanged.
    #[test]
    fn test_full_rank_gaussian_sample() {
        let fr = FullRankGaussian::new(2).expect("should create");
        let epsilon = Array1::from_vec(vec![1.0, -1.0]);
        let sample = fr.sample(&epsilon).expect("should sample");
        assert_eq!(sample.len(), 2);
        for i in 0..2 {
            assert!((sample[i] - epsilon[i]).abs() < 1e-10);
        }
    }

    /// The flow stores the requested dimension and number of flow layers.
    #[test]
    fn test_normalizing_flow_creation() {
        let nf = NormalizingFlowVI::new(3, 2).expect("should create");
        assert_eq!(nf.dim, 3);
        assert_eq!(nf.flows.len(), 2);
    }

    /// Transforming a point keeps its dimension and yields a finite log-determinant.
    #[test]
    fn test_normalizing_flow_transform() {
        let nf = NormalizingFlowVI::new(2, 1).expect("should create");
        let z0 = Array1::from_vec(vec![0.5, -0.5]);
        let (z_k, log_det) = nf.transform(&z0).expect("should transform");
        assert_eq!(z_k.len(), 2);
        assert!(log_det.is_finite());
    }

    /// Recording ELBO values updates the iteration count, convergence checks and summary.
    #[test]
    fn test_diagnostics() {
        let mut diag = VariationalDiagnostics::new();
        diag.record_elbo(-100.0);
        diag.record_elbo(-90.0);
        diag.record_elbo(-85.0);
        diag.record_gradient_norm(10.0);
        diag.record_gradient_norm(5.0);

        assert_eq!(diag.n_iterations, 3);
        // The last ELBO step is 5.0: not converged at tol 1.0, converged at tol 10.0.
        assert!(!diag.check_elbo_convergence(1.0));
        assert!(diag.check_elbo_convergence(10.0));

        let summary = diag.elbo_summary();
        assert!((summary.min - (-100.0)).abs() < 1e-10);
        assert!((summary.max - (-85.0)).abs() < 1e-10);
        assert!(summary.monotonic);
    }

    /// Fitting `y = 2x + 1` plus small deterministic noise recovers a slope near 2.
    #[test]
    fn test_variational_bayesian_regression() {
        let n = 50;
        let mut x_data = Vec::with_capacity(n);
        let mut y_data = Vec::with_capacity(n);

        for i in 0..n {
            let xi = i as f64 / n as f64;
            x_data.push(xi);
            // Deterministic pseudo-noise in roughly [-0.1, 0.1] keeps the test reproducible.
            y_data.push(2.0 * xi + 1.0 + 0.1 * ((i * 7 % 13) as f64 - 6.0) / 6.0);
        }

        let x = Array2::from_shape_fn((n, 1), |(i, _)| x_data[i]);
        let y = Array1::from_vec(y_data);

        let mut model = VariationalBayesianRegression::new(1, true).expect("should create model");
        let result = model
            .fit(x.view(), y.view(), 100, 1e-6)
            .expect("should fit");

        assert!(
            (result.mean_beta[0] - 2.0).abs() < 0.5,
            "beta should be close to 2.0, got {}",
            result.mean_beta[0]
        );
    }

    /// `trigamma(1)` equals `π² / 6`.
    #[test]
    fn test_trigamma() {
        let expected = PI * PI / 6.0;
        let computed = trigamma(1.0);
        assert!(
            (computed - expected).abs() < 0.01,
            "trigamma(1) should be close to pi^2/6, got {}",
            computed
        );
    }
}