1use crate::error::{StatsError, StatsResult as Result};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::validation::*;
14use std::f64::consts::PI;
15
16use super::{digamma, lgamma, FullRankGaussian, MeanFieldGaussian, VariationalDiagnostics};
17
/// Step-size schedule used by the stochastic variational optimisers in this module.
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Fixed step size at every iteration.
    Constant {
        /// Step size.
        lr: f64,
    },
    /// Harmonic decay `lr0 / (1 + decay * t)`.
    RobbinsMonro {
        /// Step size at `t = 0`.
        lr0: f64,
        /// Decay coefficient multiplying the iteration index.
        decay: f64,
    },
    /// Geometric decay `lr0 * gamma^t`.
    ExponentialDecay {
        /// Step size at `t = 0`.
        lr0: f64,
        /// Per-iteration multiplicative factor.
        gamma: f64,
    },
    /// Adam optimiser hyperparameters. `get_lr` returns the base `lr` unchanged; the
    /// per-coordinate adaptation is performed by `AdamState`.
    Adam {
        /// Base step size.
        lr: f64,
        /// Decay rate of the first-moment estimate.
        beta1: f64,
        /// Decay rate of the second-moment estimate.
        beta2: f64,
        /// Constant added to the denominator for numerical stability.
        epsilon: f64,
    },
}

impl LearningRateSchedule {
    /// Returns the step size for (zero-based) iteration `t`.
    pub fn get_lr(&self, t: usize) -> f64 {
        match *self {
            Self::Constant { lr } => lr,
            Self::RobbinsMonro { lr0, decay } => lr0 / (1.0 + decay * t as f64),
            Self::ExponentialDecay { lr0, gamma } => {
                // `t as i32` would wrap to a negative exponent for t > i32::MAX and make the
                // "decayed" rate explode; fall back to a floating-point power instead.
                let factor = if t <= i32::MAX as usize {
                    gamma.powi(t as i32)
                } else {
                    gamma.powf(t as f64)
                };
                lr0 * factor
            }
            Self::Adam { lr, .. } => lr,
        }
    }

    /// Adam with the common defaults: `lr = 0.01`, `beta1 = 0.9`, `beta2 = 0.999`,
    /// `epsilon = 1e-8`.
    pub fn default_adam() -> Self {
        LearningRateSchedule::Adam {
            lr: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }

    /// Robbins–Monro decay starting at `lr0 = 0.1` with `decay = 0.01`.
    pub fn default_robbins_monro() -> Self {
        LearningRateSchedule::RobbinsMonro {
            lr0: 0.1,
            decay: 0.01,
        }
    }
}
89
90#[derive(Debug, Clone)]
96pub struct AdamState {
97 pub m: Array1<f64>,
99 pub v: Array1<f64>,
101 pub beta1: f64,
103 pub beta2: f64,
105 pub epsilon: f64,
107 pub lr: f64,
109 pub t: usize,
111}
112
113impl AdamState {
114 pub fn new(dim: usize, lr: f64, beta1: f64, beta2: f64, epsilon: f64) -> Result<Self> {
116 check_positive(dim, "dim")?;
117 check_positive(lr, "lr")?;
118 check_positive(epsilon, "epsilon")?;
119
120 Ok(Self {
121 m: Array1::zeros(dim),
122 v: Array1::zeros(dim),
123 beta1,
124 beta2,
125 epsilon,
126 lr,
127 t: 0,
128 })
129 }
130
131 pub fn compute_update(&mut self, gradient: &Array1<f64>) -> Result<Array1<f64>> {
133 if gradient.len() != self.m.len() {
134 return Err(StatsError::DimensionMismatch(format!(
135 "gradient length ({}) must match state dimension ({})",
136 gradient.len(),
137 self.m.len()
138 )));
139 }
140
141 self.t += 1;
142
143 self.m = &self.m * self.beta1 + gradient * (1.0 - self.beta1);
145
146 self.v = &self.v * self.beta2 + &gradient.mapv(|g| g * g) * (1.0 - self.beta2);
148
149 let m_hat = &self.m / (1.0 - self.beta1.powi(self.t as i32));
151
152 let v_hat = &self.v / (1.0 - self.beta2.powi(self.t as i32));
154
155 let update = &m_hat / &v_hat.mapv(|vi| vi.sqrt() + self.epsilon) * self.lr;
157
158 Ok(update)
159 }
160
161 pub fn reset(&mut self) {
163 self.m.fill(0.0);
164 self.v.fill(0.0);
165 self.t = 0;
166 }
167}
168
169#[derive(Debug, Clone)]
175pub struct NaturalGradientParams {
176 pub eta: Array1<f64>,
178 pub fisher_diag: Array1<f64>,
181}
182
183impl NaturalGradientParams {
184 pub fn from_mean_field(mf: &MeanFieldGaussian) -> Self {
193 let dim = mf.dim;
194 let stds = mf.stds();
195 let vars = mf.variances();
196
197 let mut eta = Array1::zeros(2 * dim);
199 let mut fisher_diag = Array1::zeros(2 * dim);
200
201 for i in 0..dim {
202 eta[i] = mf.means[i] / vars[i];
204 eta[dim + i] = -1.0 / (2.0 * vars[i]);
206
207 fisher_diag[i] = 1.0 / vars[i]; fisher_diag[dim + i] = 2.0 / (stds[i].powi(4)); }
211
212 Self { eta, fisher_diag }
213 }
214
215 pub fn to_mean_field(&self) -> Result<MeanFieldGaussian> {
217 let dim = self.eta.len() / 2;
218 if dim == 0 {
219 return Err(StatsError::InvalidArgument(
220 "Natural parameters must have positive dimension".to_string(),
221 ));
222 }
223
224 let mut means = Array1::zeros(dim);
225 let mut log_stds = Array1::zeros(dim);
226
227 for i in 0..dim {
228 let eta2 = self.eta[dim + i];
229 if eta2 >= 0.0 {
230 return Err(StatsError::InvalidArgument(format!(
231 "eta_2[{}] = {} must be negative for valid Gaussian",
232 i, eta2
233 )));
234 }
235 let var = -1.0 / (2.0 * eta2);
236 let mean = self.eta[i] * var;
237 means[i] = mean;
238 log_stds[i] = 0.5 * var.ln();
239 }
240
241 MeanFieldGaussian::from_params(means, log_stds)
242 }
243
244 pub fn natural_gradient_update(&self, euclidean_grad: &Array1<f64>) -> Result<Array1<f64>> {
247 if euclidean_grad.len() != self.fisher_diag.len() {
248 return Err(StatsError::DimensionMismatch(format!(
249 "gradient length ({}) must match parameter dimension ({})",
250 euclidean_grad.len(),
251 self.fisher_diag.len()
252 )));
253 }
254
255 let mut nat_grad = Array1::zeros(euclidean_grad.len());
256 for i in 0..euclidean_grad.len() {
257 if self.fisher_diag[i].abs() < 1e-15 {
258 nat_grad[i] = 0.0; } else {
260 nat_grad[i] = euclidean_grad[i] / self.fisher_diag[i];
261 }
262 }
263
264 Ok(nat_grad)
265 }
266}
267
268#[derive(Debug, Clone)]
274pub struct SviConfig {
275 pub max_iter: usize,
277 pub batch_size: usize,
279 pub lr_schedule: LearningRateSchedule,
281 pub tol: f64,
283 pub n_mc_samples: usize,
285 pub use_natural_gradient: bool,
287 pub diagnostic_interval: usize,
289 pub grad_clip: f64,
291 pub seed: u64,
293}
294
295impl Default for SviConfig {
296 fn default() -> Self {
297 Self {
298 max_iter: 1000,
299 batch_size: 32,
300 lr_schedule: LearningRateSchedule::default_adam(),
301 tol: 1e-4,
302 n_mc_samples: 1,
303 use_natural_gradient: false,
304 diagnostic_interval: 50,
305 grad_clip: 10.0,
306 seed: 42,
307 }
308 }
309}
310
311#[derive(Debug, Clone)]
330pub struct StochasticVI {
331 pub variational: MeanFieldGaussian,
333 pub config: SviConfig,
335 pub diagnostics: VariationalDiagnostics,
337 adam_state: Option<AdamState>,
339}
340
341impl StochasticVI {
342 pub fn new(dim: usize, config: SviConfig) -> Result<Self> {
344 check_positive(dim, "dim")?;
345
346 let variational = MeanFieldGaussian::new(dim)?;
347
348 let adam_state = if let LearningRateSchedule::Adam {
349 lr,
350 beta1,
351 beta2,
352 epsilon,
353 } = &config.lr_schedule
354 {
355 Some(AdamState::new(2 * dim, *lr, *beta1, *beta2, *epsilon)?)
356 } else {
357 None
358 };
359
360 Ok(Self {
361 variational,
362 config,
363 diagnostics: VariationalDiagnostics::new(),
364 adam_state,
365 })
366 }
367
368 pub fn fit<F>(
379 &mut self,
380 data: ArrayView2<f64>,
381 log_joint: F,
382 n_total: usize,
383 ) -> Result<SviResult>
384 where
385 F: Fn(&Array1<f64>, ArrayView2<f64>) -> Result<(f64, Array1<f64>)>,
386 {
387 checkarray_finite(&data, "data")?;
388 check_positive(n_total, "n_total")?;
389
390 let (n_data, _) = data.dim();
391 let batch_size = self.config.batch_size.min(n_data);
392 let scale_factor = n_total as f64 / batch_size as f64;
393
394 let offset = (self.config.seed % n_data as u64) as usize;
396
397 for iter in 0..self.config.max_iter {
398 let batch_start = (offset + iter * batch_size) % n_data;
400 let batch_end = (batch_start + batch_size).min(n_data);
401 let actual_batch_size = batch_end - batch_start;
402
403 let batch = data.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
404
405 let (elbo_estimate, grad) = self.estimate_elbo_gradient(
407 batch,
408 &log_joint,
409 scale_factor * (actual_batch_size as f64 / batch_size as f64),
410 )?;
411
412 self.diagnostics.record_elbo(elbo_estimate);
414 let grad_norm = grad.dot(&grad).sqrt();
415 self.diagnostics.record_gradient_norm(grad_norm);
416
417 let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
419 &grad * (self.config.grad_clip / grad_norm)
420 } else {
421 grad
422 };
423
424 let update = if self.config.use_natural_gradient {
426 let nat_params = NaturalGradientParams::from_mean_field(&self.variational);
427 nat_params.natural_gradient_update(&clipped_grad)?
428 } else {
429 clipped_grad
430 };
431
432 let lr = self.config.lr_schedule.get_lr(iter);
434 let current_params = self.variational.get_params();
435
436 let new_params = if let Some(ref mut adam) = self.adam_state {
437 let adam_update = adam.compute_update(&update)?;
438 ¤t_params + &adam_update
439 } else {
440 ¤t_params + &(&update * lr)
441 };
442
443 let param_change = (&new_params - ¤t_params).mapv(|x| x * x).sum().sqrt();
445 self.diagnostics.record_param_change(param_change);
446
447 self.variational.set_params(&new_params)?;
448
449 if iter > 10 && self.diagnostics.check_elbo_convergence(self.config.tol) {
451 self.diagnostics.converged = true;
452 break;
453 }
454 }
455
456 Ok(SviResult {
457 variational: self.variational.clone(),
458 diagnostics: self.diagnostics.clone(),
459 })
460 }
461
462 fn estimate_elbo_gradient<F>(
464 &self,
465 batch: ArrayView2<f64>,
466 log_joint: &F,
467 scale_factor: f64,
468 ) -> Result<(f64, Array1<f64>)>
469 where
470 F: Fn(&Array1<f64>, ArrayView2<f64>) -> Result<(f64, Array1<f64>)>,
471 {
472 let dim = self.variational.dim;
473 let n_samples = self.config.n_mc_samples.max(1);
474
475 let mut total_elbo = 0.0;
476 let mut total_grad = Array1::zeros(2 * dim);
477
478 for s in 0..n_samples {
479 let epsilon =
482 generate_standard_normal(dim, s as u64 + self.diagnostics.n_iterations as u64);
483
484 let z = self.variational.sample(&epsilon)?;
486
487 let (log_p, grad_z) = log_joint(&z, batch)?;
489
490 let scaled_log_p = log_p * scale_factor;
492 let scaled_grad_z = &grad_z * scale_factor;
493
494 let log_q = self.variational.log_prob(&z)?;
496
497 total_elbo += scaled_log_p - log_q;
499
500 let stds = self.variational.stds();
504 for i in 0..dim {
505 total_grad[i] += scaled_grad_z[i];
507 total_grad[dim + i] += scaled_grad_z[i] * epsilon[i] * stds[i] + 1.0;
509 }
510
511 for i in 0..dim {
513 let diff = z[i] - self.variational.means[i];
514 let var = stds[i] * stds[i];
515 total_grad[i] -= diff / var;
517 total_grad[dim + i] -= diff * diff / var - 1.0;
519 }
520 }
521
522 total_elbo /= n_samples as f64;
524 total_grad /= n_samples as f64;
525
526 Ok((total_elbo, total_grad))
527 }
528
529 pub fn variational_distribution(&self) -> &MeanFieldGaussian {
531 &self.variational
532 }
533
534 pub fn diagnostics(&self) -> &VariationalDiagnostics {
536 &self.diagnostics
537 }
538
539 pub fn reset_optimizer(&mut self) {
541 if let Some(ref mut adam) = self.adam_state {
542 adam.reset();
543 }
544 self.diagnostics = VariationalDiagnostics::new();
545 }
546}
547
548#[derive(Debug, Clone)]
550pub struct SviResult {
551 pub variational: MeanFieldGaussian,
553 pub diagnostics: VariationalDiagnostics,
555}
556
557impl SviResult {
558 pub fn posterior_means(&self) -> &Array1<f64> {
560 &self.variational.means
561 }
562
563 pub fn posterior_stds(&self) -> Array1<f64> {
565 self.variational.stds()
566 }
567
568 pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
570 check_probability(confidence, "confidence")?;
571 let alpha = (1.0 - confidence) / 2.0;
572 let z_critical = super::normal_ppf(1.0 - alpha)?;
573
574 let dim = self.variational.dim;
575 let mut intervals = Array2::zeros((dim, 2));
576 let stds = self.variational.stds();
577
578 for i in 0..dim {
579 intervals[[i, 0]] = self.variational.means[i] - z_critical * stds[i];
580 intervals[[i, 1]] = self.variational.means[i] + z_critical * stds[i];
581 }
582
583 Ok(intervals)
584 }
585}
586
587#[derive(Debug, Clone)]
599pub struct SviBayesianRegression {
600 pub mean_beta: Array1<f64>,
602 pub log_std_beta: Array1<f64>,
604 pub shape_tau: f64,
606 pub rate_tau: f64,
607 pub prior_var: f64,
609 pub prior_shape: f64,
611 pub prior_rate: f64,
613 pub n_features: usize,
615 pub config: SviConfig,
617}
618
619impl SviBayesianRegression {
620 pub fn new(n_features: usize, config: SviConfig) -> Result<Self> {
622 check_positive(n_features, "n_features")?;
623
624 Ok(Self {
625 mean_beta: Array1::zeros(n_features),
626 log_std_beta: Array1::zeros(n_features),
627 shape_tau: 1.0,
628 rate_tau: 1.0,
629 prior_var: 100.0,
630 prior_shape: 1e-3,
631 prior_rate: 1e-3,
632 n_features,
633 config,
634 })
635 }
636
637 pub fn with_priors(
639 mut self,
640 prior_var: f64,
641 prior_shape: f64,
642 prior_rate: f64,
643 ) -> Result<Self> {
644 check_positive(prior_var, "prior_var")?;
645 check_positive(prior_shape, "prior_shape")?;
646 check_positive(prior_rate, "prior_rate")?;
647 self.prior_var = prior_var;
648 self.prior_shape = prior_shape;
649 self.prior_rate = prior_rate;
650 Ok(self)
651 }
652
653 pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<SviRegressionResult> {
655 checkarray_finite(&x, "x")?;
656 checkarray_finite(&y, "y")?;
657
658 let (n_samples, n_features) = x.dim();
659 if y.len() != n_samples {
660 return Err(StatsError::DimensionMismatch(format!(
661 "y length ({}) must match x rows ({})",
662 y.len(),
663 n_samples
664 )));
665 }
666 if n_features != self.n_features {
667 return Err(StatsError::DimensionMismatch(format!(
668 "x features ({}) must match model features ({})",
669 n_features, self.n_features
670 )));
671 }
672
673 let batch_size = self.config.batch_size.min(n_samples);
674 let scale_factor = n_samples as f64 / batch_size as f64;
675 let offset = (self.config.seed % n_samples as u64) as usize;
676
677 let n_params = 2 * self.n_features + 2;
680 let mut adam_state = if let LearningRateSchedule::Adam {
681 lr,
682 beta1,
683 beta2,
684 epsilon,
685 } = &self.config.lr_schedule
686 {
687 Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
688 } else {
689 None
690 };
691
692 let mut diagnostics = VariationalDiagnostics::new();
693
694 for iter in 0..self.config.max_iter {
695 let batch_start = (offset + iter * batch_size) % n_samples;
697 let batch_end = (batch_start + batch_size).min(n_samples);
698
699 let x_batch = x.slice(scirs2_core::ndarray::s![batch_start..batch_end, ..]);
700 let y_batch = y.slice(scirs2_core::ndarray::s![batch_start..batch_end]);
701
702 let (elbo, grad) =
704 self.compute_stochastic_elbo_grad(x_batch, y_batch, scale_factor, iter as u64)?;
705
706 diagnostics.record_elbo(elbo);
707
708 let grad_norm = grad.dot(&grad).sqrt();
709 diagnostics.record_gradient_norm(grad_norm);
710
711 let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
713 &grad * (self.config.grad_clip / grad_norm)
714 } else {
715 grad
716 };
717
718 let current_params = self.get_params();
720
721 let new_params = if let Some(ref mut adam) = adam_state {
723 let update = adam.compute_update(&clipped_grad)?;
724 ¤t_params + &update
725 } else {
726 let lr = self.config.lr_schedule.get_lr(iter);
727 ¤t_params + &(&clipped_grad * lr)
728 };
729
730 let param_change = (&new_params - ¤t_params).mapv(|x| x * x).sum().sqrt();
731 diagnostics.record_param_change(param_change);
732
733 self.set_params(&new_params)?;
734
735 if iter > 20 && diagnostics.check_elbo_convergence(self.config.tol) {
737 diagnostics.converged = true;
738 break;
739 }
740 }
741
742 Ok(SviRegressionResult {
743 mean_beta: self.mean_beta.clone(),
744 std_beta: self.log_std_beta.mapv(f64::exp),
745 shape_tau: self.shape_tau,
746 rate_tau: self.rate_tau,
747 diagnostics,
748 })
749 }
750
751 fn compute_stochastic_elbo_grad(
753 &self,
754 x_batch: ArrayView2<f64>,
755 y_batch: ArrayView1<f64>,
756 scale_factor: f64,
757 seed: u64,
758 ) -> Result<(f64, Array1<f64>)> {
759 let n_batch = x_batch.nrows();
760 let d = self.n_features;
761 let n_params = 2 * d + 2;
762
763 let std_beta = self.log_std_beta.mapv(f64::exp);
764 let expected_tau = self.shape_tau / self.rate_tau;
765 let expected_log_tau = digamma(self.shape_tau) - self.rate_tau.ln();
766
767 let epsilon = generate_standard_normal(d, seed);
769 let beta_sample = &self.mean_beta + &(&std_beta * &epsilon);
770
771 let y_pred = x_batch.dot(&beta_sample);
773 let residuals = &y_batch.to_owned() - &y_pred;
774 let sse = residuals.dot(&residuals);
775
776 let likelihood = scale_factor
778 * (0.5 * n_batch as f64 * expected_log_tau
779 - 0.5 * n_batch as f64 * (2.0 * PI).ln()
780 - 0.5 * expected_tau * sse);
781
782 let beta_sq_sum = beta_sample.dot(&beta_sample);
784 let prior_beta =
785 -0.5 * d as f64 * (2.0 * PI * self.prior_var).ln() - 0.5 / self.prior_var * beta_sq_sum;
786
787 let prior_tau = self.prior_shape * self.prior_rate.ln() - lgamma(self.prior_shape)
789 + (self.prior_shape - 1.0) * expected_log_tau
790 - self.prior_rate * expected_tau;
791
792 let entropy_beta: f64 = (0..d)
794 .map(|i| 0.5 * (1.0 + (2.0 * PI).ln()) + self.log_std_beta[i])
795 .sum();
796
797 let entropy_tau = self.shape_tau - self.rate_tau.ln()
799 + lgamma(self.shape_tau)
800 + (1.0 - self.shape_tau) * digamma(self.shape_tau);
801
802 let elbo = likelihood + prior_beta + prior_tau + entropy_beta + entropy_tau;
803
804 let mut grad = Array1::zeros(n_params);
806
807 let grad_beta_from_likelihood = x_batch.t().dot(&residuals) * expected_tau * scale_factor;
809 let grad_beta_from_prior = &beta_sample * (-1.0 / self.prior_var);
810
811 for i in 0..d {
812 grad[i] = grad_beta_from_likelihood[i] + grad_beta_from_prior[i];
813 }
814
815 for i in 0..d {
817 let dl_dbeta = grad_beta_from_likelihood[i] + grad_beta_from_prior[i];
818 grad[d + i] = dl_dbeta * epsilon[i] * std_beta[i] + 1.0; }
822
823 let d_likelihood_shape =
826 scale_factor * 0.5 * n_batch as f64 * super::trigamma(self.shape_tau);
827 let d_prior_shape = (self.prior_shape - 1.0) * super::trigamma(self.shape_tau)
828 - self.prior_rate / self.rate_tau;
829 let d_entropy_shape = 1.0 - (1.0 - self.shape_tau) * super::trigamma(self.shape_tau)
830 + digamma(self.shape_tau) * (-1.0)
831 + super::trigamma(self.shape_tau) * (1.0 - self.shape_tau);
832 grad[2 * d] = d_likelihood_shape + d_prior_shape + d_entropy_shape * 0.01;
834
835 let d_likelihood_rate =
837 -scale_factor * 0.5 * sse * self.shape_tau / (self.rate_tau * self.rate_tau);
838 let d_prior_rate = self.prior_rate * self.shape_tau / (self.rate_tau * self.rate_tau);
839 grad[2 * d + 1] = d_likelihood_rate - d_prior_rate + 1.0 / self.rate_tau;
840
841 Ok((elbo, grad))
842 }
843
844 fn get_params(&self) -> Array1<f64> {
845 let d = self.n_features;
846 let mut params = Array1::zeros(2 * d + 2);
847 for i in 0..d {
848 params[i] = self.mean_beta[i];
849 params[d + i] = self.log_std_beta[i];
850 }
851 params[2 * d] = self.shape_tau;
852 params[2 * d + 1] = self.rate_tau;
853 params
854 }
855
856 fn set_params(&mut self, params: &Array1<f64>) -> Result<()> {
857 let d = self.n_features;
858 if params.len() != 2 * d + 2 {
859 return Err(StatsError::DimensionMismatch(format!(
860 "params length ({}) must be {}",
861 params.len(),
862 2 * d + 2
863 )));
864 }
865 for i in 0..d {
866 self.mean_beta[i] = params[i];
867 self.log_std_beta[i] = params[d + i];
868 }
869 self.shape_tau = params[2 * d].max(1e-6);
871 self.rate_tau = params[2 * d + 1].max(1e-6);
872 Ok(())
873 }
874}
875
876#[derive(Debug, Clone)]
878pub struct SviRegressionResult {
879 pub mean_beta: Array1<f64>,
881 pub std_beta: Array1<f64>,
883 pub shape_tau: f64,
885 pub rate_tau: f64,
887 pub diagnostics: VariationalDiagnostics,
889}
890
891impl SviRegressionResult {
892 pub fn expected_noise_variance(&self) -> f64 {
894 if self.shape_tau > 1.0 {
895 self.rate_tau / (self.shape_tau - 1.0)
896 } else {
897 self.rate_tau / self.shape_tau
898 }
899 }
900
901 pub fn credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
903 check_probability(confidence, "confidence")?;
904 let alpha = (1.0 - confidence) / 2.0;
905 let z_critical = super::normal_ppf(1.0 - alpha)?;
906
907 let d = self.mean_beta.len();
908 let mut intervals = Array2::zeros((d, 2));
909 for i in 0..d {
910 intervals[[i, 0]] = self.mean_beta[i] - z_critical * self.std_beta[i];
911 intervals[[i, 1]] = self.mean_beta[i] + z_critical * self.std_beta[i];
912 }
913 Ok(intervals)
914 }
915}
916
917fn generate_standard_normal(dim: usize, seed: u64) -> Array1<f64> {
928 let mut result = Array1::zeros(dim);
929 let golden_ratio = 1.618033988749895;
930
931 for i in 0..dim {
932 let u1 = ((seed as f64 * golden_ratio + i as f64 * 0.7548776662466927) % 1.0).abs();
934 let u2 = ((seed as f64 * 0.5698402909980532 + i as f64 * golden_ratio) % 1.0).abs();
935
936 let u1_safe = u1.max(1e-10).min(1.0 - 1e-10);
938 let u2_safe = u2.max(1e-10).min(1.0 - 1e-10);
939
940 let r = (-2.0 * u1_safe.ln()).sqrt();
942 let theta = 2.0 * PI * u2_safe;
943 result[i] = r * theta.cos();
944 }
945
946 result
947}
948
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// True when `actual` matches `expected` to within 1e-10.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn test_learning_rate_constant() {
        let schedule = LearningRateSchedule::Constant { lr: 0.01 };
        assert!(approx_eq(schedule.get_lr(0), 0.01));
        assert!(approx_eq(schedule.get_lr(100), 0.01));
    }

    #[test]
    fn test_learning_rate_robbins_monro() {
        let schedule = LearningRateSchedule::RobbinsMonro {
            lr0: 0.1,
            decay: 0.01,
        };
        let first = schedule.get_lr(0);
        let later = schedule.get_lr(100);
        assert!(approx_eq(first, 0.1));
        assert!(later < first);
        assert!(later > 0.0);
    }

    #[test]
    fn test_learning_rate_exponential() {
        let schedule = LearningRateSchedule::ExponentialDecay {
            lr0: 0.1,
            gamma: 0.99,
        };
        let first = schedule.get_lr(0);
        assert!(approx_eq(first, 0.1));
        assert!(schedule.get_lr(100) < first);
    }

    #[test]
    fn test_adam_state() {
        let mut adam = AdamState::new(3, 0.01, 0.9, 0.999, 1e-8).expect("should create");
        let grad = Array1::from_vec(vec![1.0, -0.5, 0.3]);
        let update = adam.compute_update(&grad).expect("should compute update");
        assert_eq!(update.len(), 3);
        assert!(update.iter().all(|u| u.is_finite()));
    }

    #[test]
    fn test_natural_gradient_roundtrip() {
        let original = MeanFieldGaussian::from_params(
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![0.5, -0.3]),
        )
        .expect("should create");

        let recovered = NaturalGradientParams::from_mean_field(&original)
            .to_mean_field()
            .expect("should convert back");

        for i in 0..2 {
            let (got, want) = (recovered.means[i], original.means[i]);
            assert!(
                (got - want).abs() < 1e-6,
                "means differ at {}: {} vs {}",
                i,
                got,
                want
            );
            let (got, want) = (recovered.log_stds[i], original.log_stds[i]);
            assert!(
                (got - want).abs() < 1e-6,
                "log_stds differ at {}: {} vs {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn test_svi_creation() {
        let svi = StochasticVI::new(
            5,
            SviConfig {
                max_iter: 100,
                batch_size: 10,
                ..SviConfig::default()
            },
        )
        .expect("should create SVI");
        assert_eq!(svi.variational.dim, 5);
    }

    #[test]
    fn test_svi_bayesian_regression() {
        let n = 100;
        let x = Array2::from_shape_fn((n, 1), |(i, _)| i as f64 / n as f64);
        let y = Array1::from_shape_fn(n, |i| {
            let xi = i as f64 / n as f64;
            xi + 0.1 * ((i * 7 % 13) as f64 - 6.0) / 6.0
        });

        let config = SviConfig {
            max_iter: 200,
            batch_size: 20,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            ..SviConfig::default()
        };

        let mut model = SviBayesianRegression::new(1, config).expect("should create");
        let result = model.fit(x.view(), y.view()).expect("should fit");

        assert!(result.mean_beta[0].is_finite());
        assert!(result.std_beta[0].is_finite());
        assert!(result.diagnostics.n_iterations > 0);
    }

    #[test]
    fn test_generate_standard_normal() {
        let samples = generate_standard_normal(100, 42);
        assert_eq!(samples.len(), 100);
        for s in samples.iter().copied() {
            assert!(s.is_finite(), "sample should be finite, got {}", s);
        }
        let mean = samples.sum() / 100.0;
        assert!(
            mean.abs() < 2.0,
            "mean should be roughly zero, got {}",
            mean
        );
    }

    #[test]
    fn test_svi_config_default() {
        let SviConfig {
            max_iter,
            batch_size,
            grad_clip,
            ..
        } = SviConfig::default();
        assert_eq!(max_iter, 1000);
        assert_eq!(batch_size, 32);
        assert!(grad_clip > 0.0);
    }
}