1use crate::error::{StatsError, StatsResult as Result};
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::validation::*;
16use std::f64::consts::PI;
17
18use super::svi::{AdamState, LearningRateSchedule};
19use super::{FullRankGaussian, MeanFieldGaussian, VariationalDiagnostics};
20
21#[derive(Debug, Clone)]
27pub enum ParameterConstraint {
28 Real,
30 Positive,
32 UnitInterval,
34 Bounded {
36 lower: f64,
38 upper: f64,
40 },
41 Simplex {
43 dim: usize,
45 },
46 LowerBounded {
48 lower: f64,
50 },
51 UpperBounded {
53 upper: f64,
55 },
56}
57
58impl ParameterConstraint {
59 pub fn forward(&self, unconstrained: f64) -> f64 {
61 match self {
62 ParameterConstraint::Real => unconstrained,
63 ParameterConstraint::Positive => unconstrained.exp(),
64 ParameterConstraint::UnitInterval => 1.0 / (1.0 + (-unconstrained).exp()),
65 ParameterConstraint::Bounded { lower, upper } => {
66 let sigmoid = 1.0 / (1.0 + (-unconstrained).exp());
67 lower + (upper - lower) * sigmoid
68 }
69 ParameterConstraint::LowerBounded { lower } => lower + unconstrained.exp(),
70 ParameterConstraint::UpperBounded { upper } => upper - (-unconstrained).exp(),
71 ParameterConstraint::Simplex { .. } => {
72 1.0 / (1.0 + (-unconstrained).exp())
75 }
76 }
77 }
78
79 pub fn inverse(&self, constrained: f64) -> Result<f64> {
81 match self {
82 ParameterConstraint::Real => Ok(constrained),
83 ParameterConstraint::Positive => {
84 if constrained <= 0.0 {
85 return Err(StatsError::InvalidArgument(format!(
86 "Positive constraint requires value > 0, got {}",
87 constrained
88 )));
89 }
90 Ok(constrained.ln())
91 }
92 ParameterConstraint::UnitInterval => {
93 if constrained <= 0.0 || constrained >= 1.0 {
94 return Err(StatsError::InvalidArgument(format!(
95 "Unit interval constraint requires 0 < value < 1, got {}",
96 constrained
97 )));
98 }
99 Ok((constrained / (1.0 - constrained)).ln())
100 }
101 ParameterConstraint::Bounded { lower, upper } => {
102 if constrained <= *lower || constrained >= *upper {
103 return Err(StatsError::InvalidArgument(format!(
104 "Bounded constraint requires {} < value < {}, got {}",
105 lower, upper, constrained
106 )));
107 }
108 let normalized = (constrained - lower) / (upper - lower);
109 Ok((normalized / (1.0 - normalized)).ln())
110 }
111 ParameterConstraint::LowerBounded { lower } => {
112 if constrained <= *lower {
113 return Err(StatsError::InvalidArgument(format!(
114 "Lower-bounded constraint requires value > {}, got {}",
115 lower, constrained
116 )));
117 }
118 Ok((constrained - lower).ln())
119 }
120 ParameterConstraint::UpperBounded { upper } => {
121 if constrained >= *upper {
122 return Err(StatsError::InvalidArgument(format!(
123 "Upper-bounded constraint requires value < {}, got {}",
124 upper, constrained
125 )));
126 }
127 Ok(-((*upper - constrained).ln()))
128 }
129 ParameterConstraint::Simplex { .. } => {
130 if constrained <= 0.0 || constrained >= 1.0 {
131 return Err(StatsError::InvalidArgument(format!(
132 "Simplex element must be in (0, 1), got {}",
133 constrained
134 )));
135 }
136 Ok((constrained / (1.0 - constrained)).ln())
137 }
138 }
139 }
140
141 pub fn log_det_jacobian(&self, unconstrained: f64) -> f64 {
150 match self {
151 ParameterConstraint::Real => 0.0,
152 ParameterConstraint::Positive => {
153 unconstrained
155 }
156 ParameterConstraint::UnitInterval => {
157 let s = 1.0 / (1.0 + (-unconstrained).exp());
159 (s * (1.0 - s)).ln()
160 }
161 ParameterConstraint::Bounded { lower, upper } => {
162 let s = 1.0 / (1.0 + (-unconstrained).exp());
163 ((upper - lower) * s * (1.0 - s)).ln()
164 }
165 ParameterConstraint::LowerBounded { .. } => unconstrained,
166 ParameterConstraint::UpperBounded { .. } => unconstrained,
167 ParameterConstraint::Simplex { .. } => {
168 let s = 1.0 / (1.0 + (-unconstrained).exp());
169 (s * (1.0 - s)).ln()
170 }
171 }
172 }
173}
174
175#[derive(Debug, Clone)]
181pub struct AdviConfig {
182 pub max_iter: usize,
184 pub tol: f64,
186 pub n_mc_samples: usize,
188 pub lr_schedule: LearningRateSchedule,
190 pub grad_clip: f64,
192 pub diagnostic_interval: usize,
194 pub seed: u64,
196 pub convergence_window: usize,
198}
199
200impl Default for AdviConfig {
201 fn default() -> Self {
202 Self {
203 max_iter: 10000,
204 tol: 1e-4,
205 n_mc_samples: 1,
206 lr_schedule: LearningRateSchedule::default_adam(),
207 grad_clip: 10.0,
208 diagnostic_interval: 100,
209 seed: 42,
210 convergence_window: 50,
211 }
212 }
213}
214
215#[derive(Debug, Clone)]
232pub struct AdviMeanField {
233 pub variational: MeanFieldGaussian,
235 pub constraints: Vec<ParameterConstraint>,
237 pub config: AdviConfig,
239 pub diagnostics: VariationalDiagnostics,
241 pub dim: usize,
243}
244
245impl AdviMeanField {
246 pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
252 let dim = constraints.len();
253 if dim == 0 {
254 return Err(StatsError::InvalidArgument(
255 "Must have at least one parameter".to_string(),
256 ));
257 }
258
259 let variational = MeanFieldGaussian::new(dim)?;
260
261 Ok(Self {
262 variational,
263 constraints,
264 config,
265 diagnostics: VariationalDiagnostics::new(),
266 dim,
267 })
268 }
269
270 pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
272 let constraints = vec![ParameterConstraint::Real; dim];
273 Self::new(constraints, config)
274 }
275
276 pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
278 if theta.len() != self.dim {
279 return Err(StatsError::DimensionMismatch(format!(
280 "theta length ({}) must match dimension ({})",
281 theta.len(),
282 self.dim
283 )));
284 }
285
286 let mut eta = Array1::zeros(self.dim);
287 for i in 0..self.dim {
288 eta[i] = self.constraints[i].inverse(theta[i])?;
289 }
290 self.variational.means = eta;
291 self.variational.log_stds = Array1::from_elem(self.dim, -1.0);
293 Ok(())
294 }
295
296 pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviResult>
305 where
306 F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
307 {
308 let n_params = self.variational.n_params();
309 let mut adam_state = if let LearningRateSchedule::Adam {
310 lr,
311 beta1,
312 beta2,
313 epsilon,
314 } = &self.config.lr_schedule
315 {
316 Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
317 } else {
318 None
319 };
320
321 for iter in 0..self.config.max_iter {
322 let (elbo, grad) = self.compute_elbo_gradient(&log_joint, iter as u64)?;
324
325 self.diagnostics.record_elbo(elbo);
326
327 let grad_norm = grad.dot(&grad).sqrt();
328 self.diagnostics.record_gradient_norm(grad_norm);
329
330 let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
332 &grad * (self.config.grad_clip / grad_norm)
333 } else {
334 grad
335 };
336
337 let current_params = self.variational.get_params();
339
340 let new_params = if let Some(ref mut adam) = adam_state {
342 let update = adam.compute_update(&clipped_grad)?;
343 ¤t_params + &update
344 } else {
345 let lr = self.config.lr_schedule.get_lr(iter);
346 ¤t_params + &(&clipped_grad * lr)
347 };
348
349 let param_change = (&new_params - ¤t_params).mapv(|x| x * x).sum().sqrt();
350 self.diagnostics.record_param_change(param_change);
351
352 self.variational.set_params(&new_params)?;
353
354 if iter > self.config.convergence_window {
356 if let Some(rel_change) = self
357 .diagnostics
358 .relative_elbo_change(self.config.convergence_window)
359 {
360 if rel_change < self.config.tol {
361 self.diagnostics.converged = true;
362 break;
363 }
364 }
365 }
366 }
367
368 let constrained_means = self.transform_to_constrained(&self.variational.means)?;
370
371 Ok(AdviResult {
372 variational: self.variational.clone(),
373 constraints: self.constraints.clone(),
374 constrained_means,
375 diagnostics: self.diagnostics.clone(),
376 dim: self.dim,
377 })
378 }
379
380 fn compute_elbo_gradient<F>(&self, log_joint: &F, seed: u64) -> Result<(f64, Array1<f64>)>
382 where
383 F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
384 {
385 let dim = self.dim;
386 let n_samples = self.config.n_mc_samples.max(1);
387 let n_params = 2 * dim;
388
389 let mut total_elbo = 0.0;
390 let mut total_grad = Array1::zeros(n_params);
391
392 let stds = self.variational.stds();
393
394 for s in 0..n_samples {
395 let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
397
398 let eta = self.variational.sample(&epsilon)?;
400
401 let theta = self.transform_to_constrained(&eta)?;
403
404 let (log_p, grad_theta) = log_joint(&theta)?;
406
407 let mut log_det_j = 0.0;
409 for i in 0..dim {
410 log_det_j += self.constraints[i].log_det_jacobian(eta[i]);
411 }
412
413 total_elbo += log_p + log_det_j;
415
416 let grad_eta = self.compute_grad_eta(&eta, &grad_theta)?;
418
419 let grad_log_det_j = self.compute_grad_log_det_j(&eta)?;
421
422 let grad_combined = &grad_eta + &grad_log_det_j;
424
425 for i in 0..dim {
427 total_grad[i] += grad_combined[i];
429 total_grad[dim + i] += grad_combined[i] * epsilon[i] * stds[i];
432 }
433 }
434
435 total_elbo /= n_samples as f64;
437 total_grad /= n_samples as f64;
438
439 let entropy = self.variational.entropy();
441 total_elbo += entropy;
442
443 for i in 0..dim {
445 total_grad[dim + i] += 1.0;
446 }
447
448 Ok((total_elbo, total_grad))
449 }
450
451 fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
453 let mut theta = Array1::zeros(self.dim);
454 for i in 0..self.dim {
455 theta[i] = self.constraints[i].forward(eta[i]);
456 }
457 Ok(theta)
458 }
459
460 fn compute_grad_eta(&self, eta: &Array1<f64>, grad_theta: &Array1<f64>) -> Result<Array1<f64>> {
462 let mut grad_eta = Array1::zeros(self.dim);
463 for i in 0..self.dim {
464 let dtheta_deta = self.compute_transform_derivative(i, eta[i]);
466 grad_eta[i] = grad_theta[i] * dtheta_deta;
467 }
468 Ok(grad_eta)
469 }
470
471 fn compute_transform_derivative(&self, i: usize, unconstrained: f64) -> f64 {
473 match &self.constraints[i] {
474 ParameterConstraint::Real => 1.0,
475 ParameterConstraint::Positive => unconstrained.exp(),
476 ParameterConstraint::UnitInterval => {
477 let s = 1.0 / (1.0 + (-unconstrained).exp());
478 s * (1.0 - s)
479 }
480 ParameterConstraint::Bounded { lower, upper } => {
481 let s = 1.0 / (1.0 + (-unconstrained).exp());
482 (upper - lower) * s * (1.0 - s)
483 }
484 ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
485 ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
486 ParameterConstraint::Simplex { .. } => {
487 let s = 1.0 / (1.0 + (-unconstrained).exp());
488 s * (1.0 - s)
489 }
490 }
491 }
492
493 fn compute_grad_log_det_j(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
495 let mut grad = Array1::zeros(self.dim);
496 for i in 0..self.dim {
497 grad[i] = self.compute_grad_log_det_j_single(i, eta[i]);
498 }
499 Ok(grad)
500 }
501
502 fn compute_grad_log_det_j_single(&self, i: usize, unconstrained: f64) -> f64 {
504 match &self.constraints[i] {
505 ParameterConstraint::Real => 0.0,
506 ParameterConstraint::Positive => 1.0,
507 ParameterConstraint::UnitInterval => {
508 let s = 1.0 / (1.0 + (-unconstrained).exp());
509 1.0 - 2.0 * s
510 }
511 ParameterConstraint::Bounded { .. } => {
512 let s = 1.0 / (1.0 + (-unconstrained).exp());
513 1.0 - 2.0 * s
514 }
515 ParameterConstraint::LowerBounded { .. } => 1.0,
516 ParameterConstraint::UpperBounded { .. } => 1.0,
517 ParameterConstraint::Simplex { .. } => {
518 let s = 1.0 / (1.0 + (-unconstrained).exp());
519 1.0 - 2.0 * s
520 }
521 }
522 }
523}
524
525#[derive(Debug, Clone)]
537pub struct AdviFullRank {
538 pub variational: FullRankGaussian,
540 pub constraints: Vec<ParameterConstraint>,
542 pub config: AdviConfig,
544 pub diagnostics: VariationalDiagnostics,
546 pub dim: usize,
548}
549
550impl AdviFullRank {
551 pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
553 let dim = constraints.len();
554 if dim == 0 {
555 return Err(StatsError::InvalidArgument(
556 "Must have at least one parameter".to_string(),
557 ));
558 }
559
560 let variational = FullRankGaussian::new(dim)?;
561
562 Ok(Self {
563 variational,
564 constraints,
565 config,
566 diagnostics: VariationalDiagnostics::new(),
567 dim,
568 })
569 }
570
571 pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
573 let constraints = vec![ParameterConstraint::Real; dim];
574 Self::new(constraints, config)
575 }
576
577 pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
579 if theta.len() != self.dim {
580 return Err(StatsError::DimensionMismatch(format!(
581 "theta length ({}) must match dimension ({})",
582 theta.len(),
583 self.dim
584 )));
585 }
586
587 let mut eta = Array1::zeros(self.dim);
588 for i in 0..self.dim {
589 eta[i] = self.constraints[i].inverse(theta[i])?;
590 }
591 self.variational.mean = eta;
592 self.variational.chol_factor = Array2::eye(self.dim) * 0.1;
594 Ok(())
595 }
596
597 pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviFullRankResult>
603 where
604 F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
605 {
606 let n_params = self.variational.n_params();
607 let mut adam_state = if let LearningRateSchedule::Adam {
608 lr,
609 beta1,
610 beta2,
611 epsilon,
612 } = &self.config.lr_schedule
613 {
614 Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
615 } else {
616 None
617 };
618
619 for iter in 0..self.config.max_iter {
620 let (elbo, grad) = self.compute_elbo_gradient_full_rank(&log_joint, iter as u64)?;
622
623 self.diagnostics.record_elbo(elbo);
624
625 let grad_norm = grad.dot(&grad).sqrt();
626 self.diagnostics.record_gradient_norm(grad_norm);
627
628 let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
630 &grad * (self.config.grad_clip / grad_norm)
631 } else {
632 grad
633 };
634
635 let current_params = self.variational.get_params();
637
638 let new_params = if let Some(ref mut adam) = adam_state {
640 let update = adam.compute_update(&clipped_grad)?;
641 ¤t_params + &update
642 } else {
643 let lr = self.config.lr_schedule.get_lr(iter);
644 ¤t_params + &(&clipped_grad * lr)
645 };
646
647 let param_change = (&new_params - ¤t_params).mapv(|x| x * x).sum().sqrt();
648 self.diagnostics.record_param_change(param_change);
649
650 self.variational.set_params(&new_params)?;
651
652 if iter > self.config.convergence_window {
654 if let Some(rel_change) = self
655 .diagnostics
656 .relative_elbo_change(self.config.convergence_window)
657 {
658 if rel_change < self.config.tol {
659 self.diagnostics.converged = true;
660 break;
661 }
662 }
663 }
664 }
665
666 let constrained_means = self.transform_to_constrained(&self.variational.mean)?;
668
669 Ok(AdviFullRankResult {
670 variational: self.variational.clone(),
671 constraints: self.constraints.clone(),
672 constrained_means,
673 diagnostics: self.diagnostics.clone(),
674 dim: self.dim,
675 })
676 }
677
678 fn compute_elbo_gradient_full_rank<F>(
680 &self,
681 log_joint: &F,
682 seed: u64,
683 ) -> Result<(f64, Array1<f64>)>
684 where
685 F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
686 {
687 let dim = self.dim;
688 let n_samples = self.config.n_mc_samples.max(1);
689 let n_params = self.variational.n_params();
690
691 let mut total_elbo = 0.0;
692 let mut total_grad = Array1::zeros(n_params);
693
694 let n_tril = dim * (dim + 1) / 2;
695
696 for s in 0..n_samples {
697 let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
699
700 let eta = self.variational.sample(&epsilon)?;
702
703 let theta = self.transform_to_constrained(&eta)?;
705
706 let (log_p, grad_theta) = log_joint(&theta)?;
708
709 let mut log_det_j = 0.0;
711 for i in 0..dim {
712 log_det_j += compute_log_det_jacobian(&self.constraints[i], eta[i]);
713 }
714
715 total_elbo += log_p + log_det_j;
716
717 let grad_eta = compute_grad_eta_from_theta(dim, &eta, &grad_theta, &self.constraints)?;
719 let grad_log_det = compute_grad_log_det(dim, &eta, &self.constraints)?;
720 let grad_combined: Array1<f64> = &grad_eta + &grad_log_det;
721
722 for i in 0..dim {
724 total_grad[i] += grad_combined[i];
725 }
726
727 let mut l_idx = dim;
730 for i in 0..dim {
731 for j in 0..=i {
732 total_grad[l_idx] += grad_combined[i] * epsilon[j];
733 l_idx += 1;
734 }
735 }
736 }
737
738 total_elbo /= n_samples as f64;
740 total_grad /= n_samples as f64;
741
742 let entropy = self.variational.entropy();
744 total_elbo += entropy;
745
746 let mut l_idx = dim;
748 for i in 0..dim {
749 for j in 0..=i {
750 if i == j {
751 let l_ii = self.variational.chol_factor[[i, i]];
752 if l_ii.abs() > 1e-15 {
753 total_grad[l_idx] += 1.0 / l_ii;
754 }
755 }
756 l_idx += 1;
757 }
758 }
759
760 Ok((total_elbo, total_grad))
761 }
762
763 fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
765 let mut theta = Array1::zeros(self.dim);
766 for i in 0..self.dim {
767 theta[i] = self.constraints[i].forward(eta[i]);
768 }
769 Ok(theta)
770 }
771}
772
773#[derive(Debug, Clone)]
779pub struct AdviResult {
780 pub variational: MeanFieldGaussian,
782 pub constraints: Vec<ParameterConstraint>,
784 pub constrained_means: Array1<f64>,
786 pub diagnostics: VariationalDiagnostics,
788 pub dim: usize,
790}
791
792impl AdviResult {
793 pub fn unconstrained_means(&self) -> &Array1<f64> {
795 &self.variational.means
796 }
797
798 pub fn unconstrained_stds(&self) -> Array1<f64> {
800 self.variational.stds()
801 }
802
803 pub fn constrained_means(&self) -> &Array1<f64> {
805 &self.constrained_means
806 }
807
808 pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
810 let eta = self.variational.sample(epsilon)?;
811 let mut theta = Array1::zeros(self.dim);
812 for i in 0..self.dim {
813 theta[i] = self.constraints[i].forward(eta[i]);
814 }
815 Ok(theta)
816 }
817
818 pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
823 check_probability(confidence, "confidence")?;
824
825 let alpha = (1.0 - confidence) / 2.0;
826 let z_critical = super::normal_ppf(1.0 - alpha)?;
827
828 let stds = self.variational.stds();
829 let mut intervals = Array2::zeros((self.dim, 2));
830
831 for i in 0..self.dim {
832 let eta_low = self.variational.means[i] - z_critical * stds[i];
833 let eta_high = self.variational.means[i] + z_critical * stds[i];
834
835 let theta_low = self.constraints[i].forward(eta_low);
837 let theta_high = self.constraints[i].forward(eta_high);
838
839 intervals[[i, 0]] = theta_low.min(theta_high);
841 intervals[[i, 1]] = theta_low.max(theta_high);
842 }
843
844 Ok(intervals)
845 }
846}
847
848#[derive(Debug, Clone)]
850pub struct AdviFullRankResult {
851 pub variational: FullRankGaussian,
853 pub constraints: Vec<ParameterConstraint>,
855 pub constrained_means: Array1<f64>,
857 pub diagnostics: VariationalDiagnostics,
859 pub dim: usize,
861}
862
863impl AdviFullRankResult {
864 pub fn unconstrained_means(&self) -> &Array1<f64> {
866 &self.variational.mean
867 }
868
869 pub fn unconstrained_covariance(&self) -> Array2<f64> {
871 self.variational.covariance()
872 }
873
874 pub fn constrained_means(&self) -> &Array1<f64> {
876 &self.constrained_means
877 }
878
879 pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
881 let eta = self.variational.sample(epsilon)?;
882 let mut theta = Array1::zeros(self.dim);
883 for i in 0..self.dim {
884 theta[i] = self.constraints[i].forward(eta[i]);
885 }
886 Ok(theta)
887 }
888
889 pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
891 check_probability(confidence, "confidence")?;
892
893 let alpha = (1.0 - confidence) / 2.0;
894 let z_critical = super::normal_ppf(1.0 - alpha)?;
895
896 let cov = self.variational.covariance();
897 let mut intervals = Array2::zeros((self.dim, 2));
898
899 for i in 0..self.dim {
900 let std_i = cov[[i, i]].sqrt();
901 let eta_low = self.variational.mean[i] - z_critical * std_i;
902 let eta_high = self.variational.mean[i] + z_critical * std_i;
903
904 let theta_low = self.constraints[i].forward(eta_low);
905 let theta_high = self.constraints[i].forward(eta_high);
906
907 intervals[[i, 0]] = theta_low.min(theta_high);
908 intervals[[i, 1]] = theta_low.max(theta_high);
909 }
910
911 Ok(intervals)
912 }
913}
914
915fn compute_log_det_jacobian(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
921 constraint.log_det_jacobian(unconstrained)
922}
923
924fn compute_grad_eta_from_theta(
926 dim: usize,
927 eta: &Array1<f64>,
928 grad_theta: &Array1<f64>,
929 constraints: &[ParameterConstraint],
930) -> Result<Array1<f64>> {
931 let mut grad_eta = Array1::zeros(dim);
932 for i in 0..dim {
933 let dtheta_deta = compute_transform_deriv(&constraints[i], eta[i]);
934 grad_eta[i] = grad_theta[i] * dtheta_deta;
935 }
936 Ok(grad_eta)
937}
938
939fn compute_grad_log_det(
941 dim: usize,
942 eta: &Array1<f64>,
943 constraints: &[ParameterConstraint],
944) -> Result<Array1<f64>> {
945 let mut grad = Array1::zeros(dim);
946 for i in 0..dim {
947 grad[i] = compute_grad_log_det_single(&constraints[i], eta[i]);
948 }
949 Ok(grad)
950}
951
952fn compute_transform_deriv(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
954 match constraint {
955 ParameterConstraint::Real => 1.0,
956 ParameterConstraint::Positive => unconstrained.exp(),
957 ParameterConstraint::UnitInterval => {
958 let s = 1.0 / (1.0 + (-unconstrained).exp());
959 s * (1.0 - s)
960 }
961 ParameterConstraint::Bounded { lower, upper } => {
962 let s = 1.0 / (1.0 + (-unconstrained).exp());
963 (upper - lower) * s * (1.0 - s)
964 }
965 ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
966 ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
967 ParameterConstraint::Simplex { .. } => {
968 let s = 1.0 / (1.0 + (-unconstrained).exp());
969 s * (1.0 - s)
970 }
971 }
972}
973
974fn compute_grad_log_det_single(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
976 match constraint {
977 ParameterConstraint::Real => 0.0,
978 ParameterConstraint::Positive => 1.0,
979 ParameterConstraint::UnitInterval => {
980 let s = 1.0 / (1.0 + (-unconstrained).exp());
981 1.0 - 2.0 * s
982 }
983 ParameterConstraint::Bounded { .. } => {
984 let s = 1.0 / (1.0 + (-unconstrained).exp());
985 1.0 - 2.0 * s
986 }
987 ParameterConstraint::LowerBounded { .. } => 1.0,
988 ParameterConstraint::UpperBounded { .. } => 1.0,
989 ParameterConstraint::Simplex { .. } => {
990 let s = 1.0 / (1.0 + (-unconstrained).exp());
991 1.0 - 2.0 * s
992 }
993 }
994}
995
996fn generate_standard_normal_advi(dim: usize, seed: u64) -> Array1<f64> {
998 let mut result = Array1::zeros(dim);
999 let golden_ratio = 1.618033988749895;
1000
1001 for i in 0..dim {
1002 let u1 = ((seed as f64 * golden_ratio + i as f64 * 0.7548776662466927) % 1.0).abs();
1003 let u2 = ((seed as f64 * 0.5698402909980532 + i as f64 * golden_ratio) % 1.0).abs();
1004
1005 let u1_safe = u1.max(1e-10).min(1.0 - 1e-10);
1006 let u2_safe = u2.max(1e-10).min(1.0 - 1e-10);
1007
1008 let r = (-2.0 * u1_safe.ln()).sqrt();
1009 let theta = 2.0 * PI * u2_safe;
1010 result[i] = r * theta.cos();
1011 }
1012
1013 result
1014}
1015
// Unit tests for the constraint transforms and smoke tests for the ADVI
// fitting loops. The ADVI tests only check that fitting runs and returns
// finite or in-support values; they do not check posterior accuracy.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Identity transform: forward and inverse are no-ops and log|J| is 0.
    #[test]
    fn test_constraint_real() {
        let c = ParameterConstraint::Real;
        assert!((c.forward(1.5) - 1.5).abs() < 1e-10);
        let inv = c.inverse(1.5).expect("should invert");
        assert!((inv - 1.5).abs() < 1e-10);
        assert!((c.log_det_jacobian(1.5)).abs() < 1e-10);
    }

    // exp transform: exp(0) = 1 and exp(1) = e; inverse of e is 1; negative
    // inputs to `inverse` are rejected.
    #[test]
    fn test_constraint_positive() {
        let c = ParameterConstraint::Positive;
        assert!((c.forward(0.0) - 1.0).abs() < 1e-10);
        assert!((c.forward(1.0) - 1.0_f64.exp()).abs() < 1e-10);
        let inv = c.inverse(1.0_f64.exp()).expect("should invert");
        assert!((inv - 1.0).abs() < 1e-10);
        assert!(c.inverse(-1.0).is_err());
    }

    // Logistic transform: sigmoid(0) = 0.5, logit(0.5) = 0, and the open
    // interval's endpoints 0 and 1 are rejected by `inverse`.
    #[test]
    fn test_constraint_unit_interval() {
        let c = ParameterConstraint::UnitInterval;
        assert!((c.forward(0.0) - 0.5).abs() < 1e-10);
        let inv = c.inverse(0.5).expect("should invert");
        assert!(inv.abs() < 1e-10);
        assert!(c.inverse(0.0).is_err());
        assert!(c.inverse(1.0).is_err());
    }

    // Symmetric bounds (-1, 1): eta = 0 maps to the midpoint 0 and back.
    #[test]
    fn test_constraint_bounded() {
        let c = ParameterConstraint::Bounded {
            lower: -1.0,
            upper: 1.0,
        };
        assert!((c.forward(0.0)).abs() < 1e-10);
        let inv = c.inverse(0.0).expect("should invert");
        assert!(inv.abs() < 1e-10);
    }

    // theta = lower + exp(eta): eta = 0 gives lower + 1; values at or below
    // the bound are rejected by `inverse`.
    #[test]
    fn test_constraint_lower_bounded() {
        let c = ParameterConstraint::LowerBounded { lower: 2.0 };
        assert!((c.forward(0.0) - 3.0).abs() < 1e-10);
        let inv = c.inverse(3.0).expect("should invert");
        assert!(inv.abs() < 1e-10);
        assert!(c.inverse(1.0).is_err());
    }

    // inverse(forward(eta)) == eta for a representative value of each of the
    // four constraints listed here.
    #[test]
    fn test_constraint_roundtrip() {
        let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
            ParameterConstraint::UnitInterval,
            ParameterConstraint::Bounded {
                lower: 0.0,
                upper: 10.0,
            },
        ];

        let unconstrained_values = vec![0.5, 1.0, -0.5, 2.0];
        for (c, &eta) in constraints.iter().zip(unconstrained_values.iter()) {
            let theta = c.forward(eta);
            let eta_back = c.inverse(theta).expect("should invert");
            assert!(
                (eta_back - eta).abs() < 1e-8,
                "Roundtrip failed for {:?}: {} -> {} -> {}",
                c,
                eta,
                theta,
                eta_back
            );
        }
    }

    // Construction records one dimension per constraint.
    #[test]
    fn test_advi_mean_field_creation() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
        let config = AdviConfig::default();
        let advi = AdviMeanField::new(constraints, config).expect("should create");
        assert_eq!(advi.dim, 2);
    }

    // Smoke test on an isotropic Gaussian target with mean (1, -2) and
    // precision 2; only iteration count and ELBO finiteness are asserted.
    #[test]
    fn test_advi_mean_field_simple_gaussian() {
        let target_mean = Array1::from_vec(vec![1.0, -2.0]);
        let target_precision = 2.0;
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
        let config = AdviConfig {
            max_iter: 500,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.05,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-6,
            convergence_window: 20,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        let tm = target_mean.clone();
        let result = advi
            // log p = -0.5 * precision * |theta - mean|^2, gradient -precision * (theta - mean).
            .fit(move |theta: &Array1<f64>| {
                let diff = theta - &tm;
                let log_p = -0.5 * target_precision * diff.dot(&diff);
                let grad = &diff * (-target_precision);
                Ok((log_p, grad))
            })
            .expect("should fit");

        assert!(
            result.diagnostics.n_iterations > 0,
            "Should have performed iterations"
        );
        assert!(
            result.diagnostics.final_elbo.is_finite(),
            "ELBO should be finite"
        );
    }

    // Full-rank construction records one dimension per constraint.
    #[test]
    fn test_advi_full_rank_creation() {
        let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
            ParameterConstraint::UnitInterval,
        ];
        let config = AdviConfig::default();
        let advi = AdviFullRank::new(constraints, config).expect("should create");
        assert_eq!(advi.dim, 3);
    }

    // Full-rank smoke test on a standard normal target N(0, I).
    #[test]
    fn test_advi_full_rank_simple() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
        let config = AdviConfig {
            max_iter: 200,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.02,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-5,
            convergence_window: 20,
            ..AdviConfig::default()
        };

        let mut advi = AdviFullRank::new(constraints, config).expect("should create");

        let result = advi
            .fit(|theta: &Array1<f64>| {
                let log_p = -0.5 * theta.dot(theta);
                let grad = theta * (-1.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        assert!(result.diagnostics.n_iterations > 0);
        assert!(result.diagnostics.final_elbo.is_finite());
    }

    // Mixed real / positive parameters: the reported constrained mean of the
    // positive parameter must lie in its support.
    #[test]
    fn test_advi_with_constrained_params() {
        let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
        ];

        let config = AdviConfig {
            max_iter: 300,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-5,
            convergence_window: 30,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        let result = advi
            // Independent quadratic log-density centred at (1, 2); the second
            // coordinate has curvature 4.
            .fit(|theta: &Array1<f64>| {
                let log_p = -0.5 * (theta[0] - 1.0).powi(2) - 2.0 * (theta[1] - 2.0).powi(2);
                let mut grad = Array1::zeros(2);
                grad[0] = -(theta[0] - 1.0);
                grad[1] = -4.0 * (theta[1] - 2.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        assert!(
            result.constrained_means[1] > 0.0,
            "Positive-constrained parameter should be > 0, got {}",
            result.constrained_means[1]
        );
    }

    // Interval array shape is (dim, 2) and each row is ordered.
    #[test]
    fn test_advi_result_credible_intervals() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
        let config = AdviConfig {
            max_iter: 100,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        let result = advi
            .fit(|theta: &Array1<f64>| {
                let log_p = -0.5 * theta.dot(theta);
                let grad = theta * (-1.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        let intervals = result
            .approximate_credible_intervals(0.95)
            .expect("should compute intervals");

        assert_eq!(intervals.nrows(), 2);
        assert_eq!(intervals.ncols(), 2);
        for i in 0..2 {
            assert!(
                intervals[[i, 0]] <= intervals[[i, 1]],
                "Lower bound should be <= upper bound at dim {}",
                i
            );
        }
    }

    // For theta = exp(eta), log|J| equals eta itself.
    #[test]
    fn test_log_det_jacobian_positive() {
        let c = ParameterConstraint::Positive;
        assert!((c.log_det_jacobian(0.0)).abs() < 1e-10);
        assert!((c.log_det_jacobian(1.0) - 1.0).abs() < 1e-10);
        assert!((c.log_det_jacobian(-1.0) - (-1.0)).abs() < 1e-10);
    }

    // For the logistic map at eta = 0: s = 0.5, so log|J| = ln(0.5 * 0.5) = ln(0.25).
    #[test]
    fn test_log_det_jacobian_unit_interval() {
        let c = ParameterConstraint::UnitInterval;
        let expected = (0.25_f64).ln();
        assert!(
            (c.log_det_jacobian(0.0) - expected).abs() < 1e-10,
            "log det J at 0 should be {}, got {}",
            expected,
            c.log_det_jacobian(0.0)
        );
    }
}