1use crate::error::{StatsError, StatsResult};
15use scirs2_core::ndarray::{Array1, Array2};
16use std::f64::consts::PI;
17
18use super::{PosteriorResult, VariationalInference};
19
/// Bijection between an unconstrained real coordinate `eta` and a constrained
/// model parameter `theta`.
///
/// Each variant describes the *forward* direction (`eta -> theta`); see the
/// methods on this type for the inverse and the Jacobian terms.
///
/// The enum is plain data (at most two `f64` fields), so it is `Copy` and
/// comparable, mirroring [`AdviApproximation`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AdviTransform {
    /// `theta = eta`; for parameters that are already unconstrained.
    Identity,
    /// `theta = exp(eta)`; for strictly positive parameters.
    Log,
    /// `theta = sigmoid(eta)`; for parameters in the open interval `(0, 1)`.
    Logit,
    /// `theta = lower + (upper - lower) * sigmoid(eta)`; for parameters in the
    /// open interval `(lower, upper)`.
    Bounded {
        /// Lower end of the interval (exclusive).
        lower: f64,
        /// Upper end of the interval (exclusive).
        upper: f64,
    },
}
41
42impl AdviTransform {
43 pub fn forward(&self, eta: f64) -> f64 {
45 match self {
46 AdviTransform::Identity => eta,
47 AdviTransform::Log => eta.exp(),
48 AdviTransform::Logit => 1.0 / (1.0 + (-eta).exp()),
49 AdviTransform::Bounded { lower, upper } => {
50 let s = 1.0 / (1.0 + (-eta).exp());
51 lower + (upper - lower) * s
52 }
53 }
54 }
55
56 pub fn inverse(&self, theta: f64) -> StatsResult<f64> {
58 match self {
59 AdviTransform::Identity => Ok(theta),
60 AdviTransform::Log => {
61 if theta <= 0.0 {
62 return Err(StatsError::InvalidArgument(format!(
63 "Log transform requires positive value, got {}",
64 theta
65 )));
66 }
67 Ok(theta.ln())
68 }
69 AdviTransform::Logit => {
70 if theta <= 0.0 || theta >= 1.0 {
71 return Err(StatsError::InvalidArgument(format!(
72 "Logit transform requires value in (0, 1), got {}",
73 theta
74 )));
75 }
76 Ok((theta / (1.0 - theta)).ln())
77 }
78 AdviTransform::Bounded { lower, upper } => {
79 if theta <= *lower || theta >= *upper {
80 return Err(StatsError::InvalidArgument(format!(
81 "Bounded transform requires value in ({}, {}), got {}",
82 lower, upper, theta
83 )));
84 }
85 let s = (theta - lower) / (upper - lower);
86 Ok((s / (1.0 - s)).ln())
87 }
88 }
89 }
90
91 pub fn log_det_jacobian(&self, eta: f64) -> f64 {
94 match self {
95 AdviTransform::Identity => 0.0,
96 AdviTransform::Log => eta,
97 AdviTransform::Logit => {
98 let sp = softplus(eta);
103 eta - 2.0 * sp
104 }
105 AdviTransform::Bounded { lower, upper } => {
106 let log_range = (upper - lower).ln();
107 let sp = softplus(eta);
108 log_range + eta - 2.0 * sp
109 }
110 }
111 }
112
113 pub fn grad_log_det_jacobian(&self, eta: f64) -> f64 {
115 match self {
116 AdviTransform::Identity => 0.0,
117 AdviTransform::Log => 1.0,
118 AdviTransform::Logit | AdviTransform::Bounded { .. } => {
119 let s = sigmoid(eta);
121 1.0 - 2.0 * s
122 }
123 }
124 }
125
126 pub fn forward_grad(&self, eta: f64) -> f64 {
128 match self {
129 AdviTransform::Identity => 1.0,
130 AdviTransform::Log => eta.exp(),
131 AdviTransform::Logit => {
132 let s = sigmoid(eta);
133 s * (1.0 - s)
134 }
135 AdviTransform::Bounded { lower, upper } => {
136 let s = sigmoid(eta);
137 (upper - lower) * s * (1.0 - s)
138 }
139 }
140 }
141}
142
/// Numerically stable `softplus(x) = ln(1 + exp(x))`.
///
/// Uses the identity `softplus(x) = max(x, 0) + ln(1 + exp(-|x|))`:
/// the exponential argument is never positive, so it cannot overflow, and
/// `ln_1p` keeps full precision when `exp(-|x|)` is tiny. Evaluating
/// `(1.0 + x.exp()).ln()` directly would round `1 + exp(x)` first and lose
/// most significant digits for strongly negative `x`.
///
/// NaN inputs yield NaN; `-inf` yields `0` and `+inf` yields `+inf`.
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
153
/// Logistic function `1 / (1 + exp(-x))`, evaluated without overflow.
///
/// Only `exp(-|x|)` is ever computed, which lies in `[0, 1]`. For
/// non-negative `x` the result is `1 / (1 + exp(-x))`; for negative `x` the
/// algebraically equal form `exp(x) / (1 + exp(x))` is used instead.
fn sigmoid(x: f64) -> f64 {
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x >= 0.0 {
        1.0 / denom
    } else {
        decay / denom
    }
}
163
/// Family of Gaussian variational approximation used in the unconstrained
/// space.
///
/// The enum carries no data, so equality is total: `Eq` and `Hash` are derived
/// in addition to `PartialEq`, which lets values be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdviApproximation {
    /// Independent Gaussian per dimension (a mean and a log standard deviation
    /// for each coordinate).
    MeanField,
    /// Gaussian with a dense covariance, parameterised by a lower-triangular
    /// factor.
    FullRank,
}
176
177#[derive(Debug, Clone)]
183struct AdviAdamState {
184 m: Array1<f64>,
185 v: Array1<f64>,
186 t: usize,
187 beta1: f64,
188 beta2: f64,
189 epsilon: f64,
190}
191
192impl AdviAdamState {
193 fn new(n_params: usize) -> Self {
194 Self {
195 m: Array1::zeros(n_params),
196 v: Array1::zeros(n_params),
197 t: 0,
198 beta1: 0.9,
199 beta2: 0.999,
200 epsilon: 1e-8,
201 }
202 }
203
204 fn update(&mut self, grad: &Array1<f64>) -> Array1<f64> {
206 self.t += 1;
207 let n = grad.len();
208 let mut direction = Array1::zeros(n);
209 for i in 0..n {
210 self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
211 self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
212 let m_hat = self.m[i] / (1.0 - self.beta1.powi(self.t as i32));
213 let v_hat = self.v[i] / (1.0 - self.beta2.powi(self.t as i32));
214 direction[i] = m_hat / (v_hat.sqrt() + self.epsilon);
215 }
216 direction
217 }
218}
219
220#[derive(Debug, Clone)]
226pub struct AdviConfig {
227 pub approximation: AdviApproximation,
229 pub transforms: Vec<AdviTransform>,
231 pub num_samples: usize,
233 pub learning_rate: f64,
235 pub max_iterations: usize,
237 pub tolerance: f64,
239 pub seed: u64,
241 pub convergence_window: usize,
243}
244
245impl Default for AdviConfig {
246 fn default() -> Self {
247 Self {
248 approximation: AdviApproximation::MeanField,
249 transforms: Vec::new(),
250 num_samples: 10,
251 learning_rate: 0.01,
252 max_iterations: 10000,
253 tolerance: 1e-4,
254 seed: 42,
255 convergence_window: 50,
256 }
257 }
258}
259
260#[derive(Debug, Clone)]
286pub struct Advi {
287 pub config: AdviConfig,
289}
290
291impl Advi {
292 pub fn new(config: AdviConfig) -> Self {
294 Self { config }
295 }
296
297 fn generate_epsilon(&self, dim: usize, seed: u64) -> Array1<f64> {
300 let golden = 1.618033988749895_f64;
301 let plastic = 1.324717957244746_f64;
302 Array1::from_shape_fn(dim, |i| {
303 let u1 = ((seed as f64 * golden + i as f64 * plastic) % 1.0).abs();
304 let u2 = ((seed as f64 * plastic + i as f64 * golden) % 1.0).abs();
305 let u1 = u1.max(1e-10).min(1.0 - 1e-10);
306 let u2 = u2.max(1e-10).min(1.0 - 1e-10);
307 let r = (-2.0 * u1.ln()).sqrt();
308 r * (2.0 * PI * u2).cos()
309 })
310 }
311
312 fn get_transform(&self, i: usize) -> &AdviTransform {
314 if i < self.config.transforms.len() {
315 &self.config.transforms[i]
316 } else {
317 &AdviTransform::Identity
319 }
320 }
321
322 fn transform_to_constrained(&self, eta: &Array1<f64>) -> Array1<f64> {
324 Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward(eta[i]))
325 }
326
327 fn total_log_det_jacobian(&self, eta: &Array1<f64>) -> f64 {
329 (0..eta.len())
330 .map(|i| self.get_transform(i).log_det_jacobian(eta[i]))
331 .sum()
332 }
333
334 fn grad_log_det_jacobian(&self, eta: &Array1<f64>) -> Array1<f64> {
336 Array1::from_shape_fn(eta.len(), |i| {
337 self.get_transform(i).grad_log_det_jacobian(eta[i])
338 })
339 }
340
341 fn forward_grad(&self, eta: &Array1<f64>) -> Array1<f64> {
343 Array1::from_shape_fn(eta.len(), |i| self.get_transform(i).forward_grad(eta[i]))
344 }
345
346 fn fit_mean_field<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
348 where
349 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
350 {
351 let n_params = 2 * dim;
353 let mut mu = Array1::zeros(dim);
354 let mut log_sigma = Array1::zeros(dim); let mut adam = AdviAdamState::new(n_params);
357 let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
358 let mut converged = false;
359
360 for iter in 0..self.config.max_iterations {
361 let mut elbo_sum = 0.0;
362 let mut grad_mu_sum = Array1::zeros(dim);
363 let mut grad_log_sigma_sum = Array1::zeros(dim);
364
365 for s in 0..self.config.num_samples {
366 let seed = self
367 .config
368 .seed
369 .wrapping_add(iter as u64 * 1000)
370 .wrapping_add(s as u64);
371 let epsilon = self.generate_epsilon(dim, seed);
372
373 let sigma = log_sigma.mapv(f64::exp);
375 let eta = &mu + &(&sigma * &epsilon);
376
377 let theta = self.transform_to_constrained(&eta);
379
380 let (log_p, grad_theta) = log_joint(&theta)?;
382
383 let ldj = self.total_log_det_jacobian(&eta);
385 let grad_ldj = self.grad_log_det_jacobian(&eta);
386
387 let fwd_grad = self.forward_grad(&eta);
389 let grad_eta: Array1<f64> =
390 Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
391
392 let elbo_s = log_p + ldj;
394 elbo_sum += elbo_s;
395
396 for i in 0..dim {
400 grad_mu_sum[i] += grad_eta[i];
401 grad_log_sigma_sum[i] += grad_eta[i] * sigma[i] * epsilon[i];
402 }
403 }
404
405 let n_s = self.config.num_samples as f64;
406 elbo_sum /= n_s;
407 grad_mu_sum /= n_s;
408 grad_log_sigma_sum /= n_s;
409
410 for i in 0..dim {
413 grad_log_sigma_sum[i] += 1.0;
414 }
415
416 let entropy: f64 = (0..dim)
418 .map(|i| 0.5 * (1.0 + (2.0 * PI).ln()) + log_sigma[i])
419 .sum();
420 elbo_sum += entropy;
421
422 elbo_history.push(elbo_sum);
423
424 let mut full_grad = Array1::zeros(n_params);
426 for i in 0..dim {
427 full_grad[i] = grad_mu_sum[i];
428 full_grad[dim + i] = grad_log_sigma_sum[i];
429 }
430
431 let direction = adam.update(&full_grad);
433 let lr = self.config.learning_rate;
434 for i in 0..dim {
435 mu[i] += lr * direction[i];
436 log_sigma[i] += lr * direction[dim + i];
437 log_sigma[i] = log_sigma[i].max(-10.0).min(10.0);
439 }
440
441 if elbo_history.len() >= self.config.convergence_window {
443 let n = elbo_history.len();
444 let w = self.config.convergence_window;
445 let recent_avg: f64 =
446 elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
447 let earlier_avg: f64 =
448 elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
449 if (recent_avg - earlier_avg).abs() < self.config.tolerance {
450 converged = true;
451 break;
452 }
453 }
454 }
455
456 let sigma = log_sigma.mapv(f64::exp);
458 let constrained_means = self.transform_to_constrained(&mu);
459
460 let fwd_grad = self.forward_grad(&mu);
463 let constrained_stds = Array1::from_shape_fn(dim, |i| (fwd_grad[i] * sigma[i]).abs());
464
465 Ok(PosteriorResult {
466 means: constrained_means,
467 std_devs: constrained_stds,
468 elbo_history: elbo_history.clone(),
469 iterations: elbo_history.len(),
470 converged,
471 samples: None,
472 })
473 }
474
475 fn fit_full_rank<F>(&self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
477 where
478 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
479 {
480 let n_tril = dim * (dim + 1) / 2;
482 let n_params = dim + n_tril;
483 let mut mu = Array1::zeros(dim);
484 let mut l_entries = Array1::zeros(n_tril);
486 {
487 let mut idx = 0;
488 for row in 0..dim {
489 for col in 0..=row {
490 if row == col {
491 l_entries[idx] = 1.0; }
493 idx += 1;
494 }
495 }
496 }
497
498 let mut adam = AdviAdamState::new(n_params);
499 let mut elbo_history = Vec::with_capacity(self.config.max_iterations);
500 let mut converged = false;
501
502 for iter in 0..self.config.max_iterations {
503 let l_mat = tril_to_matrix(dim, &l_entries);
505
506 let mut elbo_sum = 0.0;
507 let mut grad_mu_sum = Array1::zeros(dim);
508 let mut grad_l_sum = Array1::zeros(n_tril);
509
510 for s in 0..self.config.num_samples {
511 let seed = self
512 .config
513 .seed
514 .wrapping_add(iter as u64 * 1000)
515 .wrapping_add(s as u64);
516 let epsilon = self.generate_epsilon(dim, seed);
517
518 let l_eps = l_mat.dot(&epsilon);
520 let eta = &mu + &l_eps;
521
522 let theta = self.transform_to_constrained(&eta);
524
525 let (log_p, grad_theta) = log_joint(&theta)?;
527
528 let ldj = self.total_log_det_jacobian(&eta);
530 let grad_ldj = self.grad_log_det_jacobian(&eta);
531
532 let fwd_grad = self.forward_grad(&eta);
534 let grad_eta: Array1<f64> =
535 Array1::from_shape_fn(dim, |i| grad_theta[i] * fwd_grad[i] + grad_ldj[i]);
536
537 let elbo_s = log_p + ldj;
538 elbo_sum += elbo_s;
539
540 for i in 0..dim {
542 grad_mu_sum[i] += grad_eta[i];
543 }
544
545 let mut idx = 0;
548 for row in 0..dim {
549 for col in 0..=row {
550 grad_l_sum[idx] += grad_eta[row] * epsilon[col];
551 idx += 1;
552 }
553 }
554 }
555
556 let n_s = self.config.num_samples as f64;
557 elbo_sum /= n_s;
558 grad_mu_sum /= n_s;
559 grad_l_sum /= n_s;
560
561 let mut entropy = 0.5 * dim as f64 * (1.0 + (2.0 * PI).ln());
564 {
565 let mut idx = 0;
566 for row in 0..dim {
567 for col in 0..=row {
568 if row == col {
569 entropy += l_entries[idx].abs().max(1e-15).ln();
570 let l_ii = l_entries[idx];
572 if l_ii.abs() > 1e-15 {
573 grad_l_sum[idx] += 1.0 / l_ii;
574 }
575 }
576 idx += 1;
577 }
578 }
579 }
580 elbo_sum += entropy;
581 elbo_history.push(elbo_sum);
582
583 let mut full_grad = Array1::zeros(n_params);
585 for i in 0..dim {
586 full_grad[i] = grad_mu_sum[i];
587 }
588 for i in 0..n_tril {
589 full_grad[dim + i] = grad_l_sum[i];
590 }
591
592 let direction = adam.update(&full_grad);
594 let lr = self.config.learning_rate;
595 for i in 0..dim {
596 mu[i] += lr * direction[i];
597 }
598 for i in 0..n_tril {
599 l_entries[i] += lr * direction[dim + i];
600 }
601
602 {
604 let mut idx = 0;
605 for row in 0..dim {
606 for col in 0..=row {
607 if row == col {
608 l_entries[idx] = l_entries[idx].abs().max(1e-6);
609 }
610 l_entries[idx] = l_entries[idx].max(-10.0).min(10.0);
612 idx += 1;
613 }
614 }
615 }
616
617 if elbo_history.len() >= self.config.convergence_window {
619 let n = elbo_history.len();
620 let w = self.config.convergence_window;
621 let recent_avg: f64 =
622 elbo_history[n - w / 2..n].iter().sum::<f64>() / (w / 2) as f64;
623 let earlier_avg: f64 =
624 elbo_history[n - w..n - w / 2].iter().sum::<f64>() / (w / 2) as f64;
625 if (recent_avg - earlier_avg).abs() < self.config.tolerance {
626 converged = true;
627 break;
628 }
629 }
630 }
631
632 let l_mat = tril_to_matrix(dim, &l_entries);
634 let constrained_means = self.transform_to_constrained(&mu);
635
636 let cov = l_mat.dot(&l_mat.t());
638
639 let fwd_grad = self.forward_grad(&mu);
641 let constrained_stds =
642 Array1::from_shape_fn(dim, |i| (fwd_grad[i] * fwd_grad[i] * cov[[i, i]]).sqrt());
643
644 Ok(PosteriorResult {
645 means: constrained_means,
646 std_devs: constrained_stds,
647 elbo_history: elbo_history.clone(),
648 iterations: elbo_history.len(),
649 converged,
650 samples: None,
651 })
652 }
653}
654
655impl VariationalInference for Advi {
656 fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
657 where
658 F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>,
659 {
660 if dim == 0 {
661 return Err(StatsError::InvalidArgument(
662 "Dimension must be at least 1".to_string(),
663 ));
664 }
665 if self.config.num_samples == 0 {
666 return Err(StatsError::InvalidArgument(
667 "num_samples must be at least 1".to_string(),
668 ));
669 }
670 if self.config.learning_rate <= 0.0 {
671 return Err(StatsError::InvalidArgument(
672 "learning_rate must be positive".to_string(),
673 ));
674 }
675
676 match self.config.approximation {
677 AdviApproximation::MeanField => self.fit_mean_field(log_joint, dim),
678 AdviApproximation::FullRank => self.fit_full_rank(log_joint, dim),
679 }
680 }
681}
682
683fn tril_to_matrix(dim: usize, entries: &Array1<f64>) -> Array2<f64> {
689 let mut mat = Array2::zeros((dim, dim));
690 let mut idx = 0;
691 for row in 0..dim {
692 for col in 0..=row {
693 mat[[row, col]] = entries[idx];
694 idx += 1;
695 }
696 }
697 mat
698}
699
#[cfg(test)]
mod tests {
    use super::*;

    /// Conjugate normal-normal model: the fitted mean and standard deviation
    /// should be close to the closed-form posterior.
    #[test]
    fn test_advi_gaussian_posterior_recovery() {
        let data_mean = 3.0_f64;
        let n_data = 10.0_f64;
        let prior_mean = 0.0_f64;
        let prior_var = 1.0_f64;
        let lik_var = 1.0_f64;

        let config = AdviConfig {
            approximation: AdviApproximation::MeanField,
            transforms: vec![AdviTransform::Identity],
            num_samples: 20,
            learning_rate: 0.05,
            max_iterations: 3000,
            tolerance: 1e-5,
            seed: 123,
            convergence_window: 100,
        };

        let mut advi = Advi::new(config);
        let result = advi
            .fit(
                move |theta: &Array1<f64>| {
                    let mu = theta[0];
                    let log_prior = -0.5 * (mu - prior_mean).powi(2) / prior_var;
                    let log_lik = -n_data / 2.0 * (mu - data_mean).powi(2) / lik_var;
                    let log_p = log_prior + log_lik;
                    let grad_prior = -(mu - prior_mean) / prior_var;
                    let grad_lik = -n_data * (mu - data_mean) / lik_var;
                    let grad = Array1::from_vec(vec![grad_prior + grad_lik]);
                    Ok((log_p, grad))
                },
                1,
            )
            .expect("ADVI should not fail");

        // Closed-form posterior of the conjugate model.
        let expected_mean = (n_data * data_mean / lik_var + prior_mean / prior_var)
            / (n_data / lik_var + 1.0 / prior_var);
        let expected_std = (1.0 / (n_data / lik_var + 1.0 / prior_var)).sqrt();

        assert!(
            (result.means[0] - expected_mean).abs() < 0.3,
            "Mean should be close to {}, got {}",
            expected_mean,
            result.means[0]
        );
        assert!(
            (result.std_devs[0] - expected_std).abs() < 0.2,
            "Std should be close to {}, got {}",
            expected_std,
            result.std_devs[0]
        );
    }

    /// The ELBO trace of a simple 2-D Gaussian target should not degrade
    /// (beyond a small slack) between the first and last 50 iterations.
    #[test]
    fn test_advi_elbo_increases() {
        let config = AdviConfig {
            approximation: AdviApproximation::MeanField,
            transforms: vec![AdviTransform::Identity, AdviTransform::Identity],
            num_samples: 15,
            learning_rate: 0.02,
            max_iterations: 500,
            tolerance: 1e-6,
            seed: 77,
            convergence_window: 50,
        };

        let mut advi = Advi::new(config);
        let result = advi
            .fit(
                |theta: &Array1<f64>| {
                    let diff0 = theta[0] - 1.0;
                    let diff1 = theta[1] - 2.0;
                    let log_p = -0.5 * (diff0 * diff0 + diff1 * diff1);
                    let grad = Array1::from_vec(vec![-diff0, -diff1]);
                    Ok((log_p, grad))
                },
                2,
            )
            .expect("ADVI should succeed");

        let n = result.elbo_history.len();
        assert!(n > 100, "Should run at least 100 iterations");
        let early_avg: f64 = result.elbo_history[..50].iter().sum::<f64>() / 50.0;
        let late_avg: f64 = result.elbo_history[n - 50..].iter().sum::<f64>() / 50.0;
        assert!(
            late_avg > early_avg - 1.0,
            "Late ELBO ({}) should be higher than early ({})",
            late_avg,
            early_avg
        );
    }

    /// On a correlated bivariate Gaussian the full-rank family should reach an
    /// ELBO that is not meaningfully worse than the mean-field one.
    #[test]
    fn test_advi_mean_field_vs_full_rank() {
        let rho = 0.8_f64;
        let log_joint = move |theta: &Array1<f64>| {
            let x = theta[0];
            let y = theta[1];
            let det = 1.0 - rho * rho;
            let log_p =
                -0.5 / det * (x * x - 2.0 * rho * x * y + y * y) - 0.5 * (2.0 * PI * det).ln();
            let gx = -1.0 / det * (x - rho * y);
            let gy = -1.0 / det * (y - rho * x);
            Ok((log_p, Array1::from_vec(vec![gx, gy])))
        };

        let mf_config = AdviConfig {
            approximation: AdviApproximation::MeanField,
            num_samples: 20,
            learning_rate: 0.02,
            max_iterations: 2000,
            tolerance: 1e-5,
            seed: 42,
            convergence_window: 100,
            ..Default::default()
        };
        let mut mf_advi = Advi::new(mf_config);
        let mf_result = mf_advi.fit(log_joint, 2).expect("MF should succeed");

        let fr_config = AdviConfig {
            approximation: AdviApproximation::FullRank,
            num_samples: 20,
            learning_rate: 0.02,
            max_iterations: 2000,
            tolerance: 1e-5,
            seed: 42,
            convergence_window: 100,
            ..Default::default()
        };
        let mut fr_advi = Advi::new(fr_config);
        let fr_result = fr_advi.fit(log_joint, 2).expect("FR should succeed");

        let mf_final_elbo = mf_result
            .elbo_history
            .last()
            .copied()
            .unwrap_or(f64::NEG_INFINITY);
        let fr_final_elbo = fr_result
            .elbo_history
            .last()
            .copied()
            .unwrap_or(f64::NEG_INFINITY);

        assert!(
            fr_final_elbo > mf_final_elbo - 1.0,
            "Full-rank ELBO ({}) should be >= mean-field ELBO ({}) minus tolerance",
            fr_final_elbo,
            mf_final_elbo
        );
    }

    /// A positive-only target (Gamma(3, 1) log density) fitted through the log
    /// transform must yield a positive mean in the right neighbourhood.
    #[test]
    fn test_advi_log_transform() {
        let config = AdviConfig {
            approximation: AdviApproximation::MeanField,
            transforms: vec![AdviTransform::Log],
            num_samples: 20,
            learning_rate: 0.01,
            max_iterations: 3000,
            tolerance: 1e-5,
            seed: 55,
            convergence_window: 100,
        };

        let mut advi = Advi::new(config);
        let result = advi
            .fit(
                |theta: &Array1<f64>| {
                    let x = theta[0];
                    if x <= 0.0 {
                        return Ok((f64::NEG_INFINITY, Array1::zeros(1)));
                    }
                    // Gamma(3, 1): log p(x) = 2 ln x - x - ln Gamma(3), Gamma(3) = 2.
                    let log_p = 2.0 * x.ln() - x - (2.0_f64).ln();
                    let grad = Array1::from_vec(vec![2.0 / x - 1.0]);
                    Ok((log_p, grad))
                },
                1,
            )
            .expect("ADVI with log transform should succeed");

        assert!(
            result.means[0] > 0.0,
            "Mean should be positive with log transform"
        );
        assert!(
            (result.means[0] - 3.0).abs() < 1.5,
            "Mean should be near 3 (Gamma(3,1) mean), got {}",
            result.means[0]
        );
    }

    /// A zero-dimensional model is rejected up front.
    #[test]
    fn test_advi_zero_dim_error() {
        let mut advi = Advi::new(AdviConfig::default());
        let result = advi.fit(|_theta: &Array1<f64>| Ok((0.0, Array1::zeros(0))), 0);
        assert!(result.is_err());
    }

    /// Zero samples per step and a non-positive learning rate are rejected
    /// before any optimisation happens.
    #[test]
    fn test_advi_invalid_config_errors() {
        let mut no_samples = Advi::new(AdviConfig {
            num_samples: 0,
            ..Default::default()
        });
        let result = no_samples.fit(|_theta: &Array1<f64>| Ok((0.0, Array1::zeros(1))), 1);
        assert!(result.is_err(), "num_samples == 0 must be rejected");

        let mut zero_lr = Advi::new(AdviConfig {
            learning_rate: 0.0,
            ..Default::default()
        });
        let result = zero_lr.fit(|_theta: &Array1<f64>| Ok((0.0, Array1::zeros(1))), 1);
        assert!(result.is_err(), "learning_rate == 0 must be rejected");
    }

    /// `forward(inverse(theta))` recovers `theta` for every transform.
    #[test]
    fn test_transform_roundtrip() {
        let transforms = vec![
            AdviTransform::Identity,
            AdviTransform::Log,
            AdviTransform::Logit,
            AdviTransform::Bounded {
                lower: -2.0,
                upper: 5.0,
            },
        ];
        let test_vals = vec![1.5, 2.0, 0.3, 1.0];

        for (t, v) in transforms.iter().zip(test_vals.iter()) {
            let eta = t.inverse(*v).expect("inverse should succeed");
            let recovered = t.forward(eta);
            assert!(
                (recovered - v).abs() < 1e-10,
                "Roundtrip failed for {:?}: {} -> {} -> {}",
                t,
                v,
                eta,
                recovered
            );
        }
    }

    /// `inverse` refuses values outside the open domain of each transform.
    #[test]
    fn test_inverse_rejects_out_of_domain() {
        assert!(AdviTransform::Log.inverse(0.0).is_err());
        assert!(AdviTransform::Log.inverse(-1.0).is_err());
        assert!(AdviTransform::Logit.inverse(0.0).is_err());
        assert!(AdviTransform::Logit.inverse(1.0).is_err());
        let bounded = AdviTransform::Bounded {
            lower: -2.0,
            upper: 5.0,
        };
        assert!(bounded.inverse(-2.0).is_err());
        assert!(bounded.inverse(5.0).is_err());
        assert!(bounded.inverse(7.5).is_err());
    }

    /// The log-det-Jacobian of every non-identity transform is finite,
    /// non-zero at `eta = 0.5`, and equals the log of the forward derivative.
    #[test]
    fn test_log_det_jacobian_nonzero() {
        let transforms = vec![
            AdviTransform::Log,
            AdviTransform::Logit,
            AdviTransform::Bounded {
                lower: 0.0,
                upper: 10.0,
            },
        ];
        for t in &transforms {
            let ldj = t.log_det_jacobian(0.5);
            assert!(
                ldj.is_finite(),
                "Log-det-Jacobian should be finite for {:?}",
                t
            );
            assert!(
                ldj != 0.0,
                "Log-det-Jacobian should be non-zero for {:?}",
                t
            );
            // All transforms here are increasing, so |d theta / d eta| is the
            // forward derivative itself.
            let expected = t.forward_grad(0.5).ln();
            assert!(
                (ldj - expected).abs() < 1e-12,
                "Log-det-Jacobian {} should equal ln(forward_grad) {} for {:?}",
                ldj,
                expected,
                t
            );
        }
    }
}