Skip to main content

scirs2_stats/bayesian/variational/
advi.rs

1//! Automatic Differentiation Variational Inference (ADVI)
2//!
3//! This module implements ADVI, which automatically transforms constrained
4//! parameters to unconstrained space and fits a Gaussian variational
5//! approximation. Key features:
6//!
7//! - Mean-field Gaussian approximation in unconstrained space
8//! - Full-rank Gaussian approximation in unconstrained space
9//! - Automatic parameter transformations (log, logit, etc.)
10//! - ELBO computation with the reparameterization trick
11//! - Support for custom log probability functions
12
13use crate::error::{StatsError, StatsResult as Result};
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::validation::*;
16use std::f64::consts::PI;
17
18use super::svi::{AdamState, LearningRateSchedule};
19use super::{FullRankGaussian, MeanFieldGaussian, VariationalDiagnostics};
20
21// ============================================================================
22// Parameter Transformations
23// ============================================================================
24
/// Type of parameter constraint and its corresponding transformation.
///
/// Each variant describes the support of one scalar parameter together with the
/// map `T` from an unconstrained real `eta` onto that support
/// (see [`ParameterConstraint::forward`] and [`ParameterConstraint::inverse`]).
#[derive(Debug, Clone, PartialEq)]
pub enum ParameterConstraint {
    /// Unconstrained real-valued parameter (identity transform)
    Real,
    /// Positive parameter (log transform: unconstrained -> positive)
    Positive,
    /// Parameter on (0, 1) (logistic transform)
    UnitInterval,
    /// Parameter on (lower, upper) (scaled logistic transform); intended for `lower < upper`
    Bounded {
        /// Lower bound
        lower: f64,
        /// Upper bound
        upper: f64,
    },
    /// Element of a simplex (components sum to 1).
    ///
    /// NOTE(review): the scalar methods on this type map each element
    /// independently through a logistic function onto (0, 1); a stick-breaking
    /// transform that enforces the sum-to-one constraint is not applied by them
    /// and, if needed, must happen elsewhere — confirm before relying on it.
    Simplex {
        /// Dimension of the simplex
        dim: usize,
    },
    /// Lower-bounded parameter (shifted log transform)
    LowerBounded {
        /// Lower bound
        lower: f64,
    },
    /// Upper-bounded parameter (reflected log transform)
    UpperBounded {
        /// Upper bound
        upper: f64,
    },
}
57
58impl ParameterConstraint {
59    /// Transform from unconstrained to constrained space
60    pub fn forward(&self, unconstrained: f64) -> f64 {
61        match self {
62            ParameterConstraint::Real => unconstrained,
63            ParameterConstraint::Positive => unconstrained.exp(),
64            ParameterConstraint::UnitInterval => 1.0 / (1.0 + (-unconstrained).exp()),
65            ParameterConstraint::Bounded { lower, upper } => {
66                let sigmoid = 1.0 / (1.0 + (-unconstrained).exp());
67                lower + (upper - lower) * sigmoid
68            }
69            ParameterConstraint::LowerBounded { lower } => lower + unconstrained.exp(),
70            ParameterConstraint::UpperBounded { upper } => upper - (-unconstrained).exp(),
71            ParameterConstraint::Simplex { .. } => {
72                // For simplex, forward transform is handled element-wise
73                // via stick-breaking; this is a per-element sigmoid
74                1.0 / (1.0 + (-unconstrained).exp())
75            }
76        }
77    }
78
79    /// Transform from constrained to unconstrained space
80    pub fn inverse(&self, constrained: f64) -> Result<f64> {
81        match self {
82            ParameterConstraint::Real => Ok(constrained),
83            ParameterConstraint::Positive => {
84                if constrained <= 0.0 {
85                    return Err(StatsError::InvalidArgument(format!(
86                        "Positive constraint requires value > 0, got {}",
87                        constrained
88                    )));
89                }
90                Ok(constrained.ln())
91            }
92            ParameterConstraint::UnitInterval => {
93                if constrained <= 0.0 || constrained >= 1.0 {
94                    return Err(StatsError::InvalidArgument(format!(
95                        "Unit interval constraint requires 0 < value < 1, got {}",
96                        constrained
97                    )));
98                }
99                Ok((constrained / (1.0 - constrained)).ln())
100            }
101            ParameterConstraint::Bounded { lower, upper } => {
102                if constrained <= *lower || constrained >= *upper {
103                    return Err(StatsError::InvalidArgument(format!(
104                        "Bounded constraint requires {} < value < {}, got {}",
105                        lower, upper, constrained
106                    )));
107                }
108                let normalized = (constrained - lower) / (upper - lower);
109                Ok((normalized / (1.0 - normalized)).ln())
110            }
111            ParameterConstraint::LowerBounded { lower } => {
112                if constrained <= *lower {
113                    return Err(StatsError::InvalidArgument(format!(
114                        "Lower-bounded constraint requires value > {}, got {}",
115                        lower, constrained
116                    )));
117                }
118                Ok((constrained - lower).ln())
119            }
120            ParameterConstraint::UpperBounded { upper } => {
121                if constrained >= *upper {
122                    return Err(StatsError::InvalidArgument(format!(
123                        "Upper-bounded constraint requires value < {}, got {}",
124                        upper, constrained
125                    )));
126                }
127                Ok(-((*upper - constrained).ln()))
128            }
129            ParameterConstraint::Simplex { .. } => {
130                if constrained <= 0.0 || constrained >= 1.0 {
131                    return Err(StatsError::InvalidArgument(format!(
132                        "Simplex element must be in (0, 1), got {}",
133                        constrained
134                    )));
135                }
136                Ok((constrained / (1.0 - constrained)).ln())
137            }
138        }
139    }
140
141    /// Compute the log absolute determinant of the Jacobian of the forward transform
142    ///
143    /// This is needed for correcting the density when transforming from
144    /// unconstrained to constrained space:
145    /// p(constrained) = p_unconstrained(inverse(constrained)) * |det J^{-1}|
146    ///
147    /// Equivalently, for the ELBO in unconstrained space:
148    /// log p(forward(eta)) + log |det J_forward(eta)|
149    pub fn log_det_jacobian(&self, unconstrained: f64) -> f64 {
150        match self {
151            ParameterConstraint::Real => 0.0,
152            ParameterConstraint::Positive => {
153                // d/d_eta exp(eta) = exp(eta), so log|det J| = eta
154                unconstrained
155            }
156            ParameterConstraint::UnitInterval => {
157                // sigmoid'(eta) = sigmoid(eta) * (1 - sigmoid(eta))
158                let s = 1.0 / (1.0 + (-unconstrained).exp());
159                (s * (1.0 - s)).ln()
160            }
161            ParameterConstraint::Bounded { lower, upper } => {
162                let s = 1.0 / (1.0 + (-unconstrained).exp());
163                ((upper - lower) * s * (1.0 - s)).ln()
164            }
165            ParameterConstraint::LowerBounded { .. } => unconstrained,
166            ParameterConstraint::UpperBounded { .. } => unconstrained,
167            ParameterConstraint::Simplex { .. } => {
168                let s = 1.0 / (1.0 + (-unconstrained).exp());
169                (s * (1.0 - s)).ln()
170            }
171        }
172    }
173}
174
175// ============================================================================
176// ADVI Configuration
177// ============================================================================
178
/// Configuration for ADVI
///
/// Shared by [`AdviMeanField`] and [`AdviFullRank`].
#[derive(Debug, Clone)]
pub struct AdviConfig {
    /// Maximum number of optimization iterations
    pub max_iter: usize,
    /// Convergence tolerance: optimization stops once the relative ELBO change
    /// reported over `convergence_window` iterations falls below this value
    pub tol: f64,
    /// Number of Monte Carlo samples per ELBO gradient estimate (0 is treated as 1)
    pub n_mc_samples: usize,
    /// Learning rate schedule. The `Adam` variant selects Adam updates; any other
    /// variant takes plain gradient-ascent steps scaled by `get_lr(iter)`
    pub lr_schedule: LearningRateSchedule,
    /// Threshold on the L2 norm of the ELBO gradient; larger gradients are
    /// rescaled to this norm (values <= 0 disable clipping)
    pub grad_clip: f64,
    /// Diagnostic output interval (0 = no diagnostics)
    pub diagnostic_interval: usize,
    /// Seed for reproducibility
    pub seed: u64,
    /// Convergence window: number of iterations passed to the relative-ELBO-change
    /// check. The check only starts after this many iterations
    pub convergence_window: usize,
}
199
200impl Default for AdviConfig {
201    fn default() -> Self {
202        Self {
203            max_iter: 10000,
204            tol: 1e-4,
205            n_mc_samples: 1,
206            lr_schedule: LearningRateSchedule::default_adam(),
207            grad_clip: 10.0,
208            diagnostic_interval: 100,
209            seed: 42,
210            convergence_window: 50,
211        }
212    }
213}
214
215// ============================================================================
216// ADVI (Mean-Field)
217// ============================================================================
218
/// Automatic Differentiation Variational Inference with mean-field Gaussian
///
/// ADVI transforms constrained parameters to unconstrained space and
/// fits a diagonal-covariance Gaussian variational approximation:
///
/// 1. Transform constrained parameters theta to unconstrained eta = T^{-1}(theta)
/// 2. Fit q(eta) = N(mu, diag(sigma^2))
/// 3. ELBO = E_q[log p(T(eta), x) + log |det J_T(eta)|] - E_q[log q(eta)]
///
/// The user provides:
/// - A log joint density function log p(theta, data) in the *constrained* space,
///   returning both its value and its gradient with respect to theta
/// - Parameter constraints for each dimension
///
/// The expectation in step 3 is estimated by Monte Carlo with the
/// reparameterization eta = mu + sigma * epsilon. The entropy term
/// -E_q[log q(eta)] is added in closed form.
#[derive(Debug, Clone)]
pub struct AdviMeanField {
    /// Variational distribution in unconstrained space (means and log standard deviations)
    pub variational: MeanFieldGaussian,
    /// Parameter constraints, one per dimension
    pub constraints: Vec<ParameterConstraint>,
    /// Configuration
    pub config: AdviConfig,
    /// Optimization diagnostics (ELBO trace, gradient norms, parameter changes, convergence flag)
    pub diagnostics: VariationalDiagnostics,
    /// Dimensionality (equals `constraints.len()`)
    pub dim: usize,
}
244
245impl AdviMeanField {
246    /// Create a new ADVI mean-field instance
247    ///
248    /// # Arguments
249    /// * `constraints` - Constraint for each parameter dimension
250    /// * `config` - ADVI configuration
251    pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
252        let dim = constraints.len();
253        if dim == 0 {
254            return Err(StatsError::InvalidArgument(
255                "Must have at least one parameter".to_string(),
256            ));
257        }
258
259        let variational = MeanFieldGaussian::new(dim)?;
260
261        Ok(Self {
262            variational,
263            constraints,
264            config,
265            diagnostics: VariationalDiagnostics::new(),
266            dim,
267        })
268    }
269
270    /// Create ADVI with all unconstrained parameters
271    pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
272        let constraints = vec![ParameterConstraint::Real; dim];
273        Self::new(constraints, config)
274    }
275
276    /// Initialize variational parameters from constrained-space values
277    pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
278        if theta.len() != self.dim {
279            return Err(StatsError::DimensionMismatch(format!(
280                "theta length ({}) must match dimension ({})",
281                theta.len(),
282                self.dim
283            )));
284        }
285
286        let mut eta = Array1::zeros(self.dim);
287        for i in 0..self.dim {
288            eta[i] = self.constraints[i].inverse(theta[i])?;
289        }
290        self.variational.means = eta;
291        // Initialize with moderate uncertainty
292        self.variational.log_stds = Array1::from_elem(self.dim, -1.0);
293        Ok(())
294    }
295
296    /// Run ADVI optimization
297    ///
298    /// # Arguments
299    /// * `log_joint` - Function computing log p(theta) in the constrained space.
300    ///   Takes a constrained parameter vector and returns (log_prob, gradient_wrt_theta).
301    ///
302    /// # Returns
303    /// * ADVI result with optimized variational distribution
304    pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviResult>
305    where
306        F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
307    {
308        let n_params = self.variational.n_params();
309        let mut adam_state = if let LearningRateSchedule::Adam {
310            lr,
311            beta1,
312            beta2,
313            epsilon,
314        } = &self.config.lr_schedule
315        {
316            Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
317        } else {
318            None
319        };
320
321        for iter in 0..self.config.max_iter {
322            // Compute stochastic ELBO gradient
323            let (elbo, grad) = self.compute_elbo_gradient(&log_joint, iter as u64)?;
324
325            self.diagnostics.record_elbo(elbo);
326
327            let grad_norm = grad.dot(&grad).sqrt();
328            self.diagnostics.record_gradient_norm(grad_norm);
329
330            // Clip gradient
331            let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
332                &grad * (self.config.grad_clip / grad_norm)
333            } else {
334                grad
335            };
336
337            // Get current parameters
338            let current_params = self.variational.get_params();
339
340            // Apply update
341            let new_params = if let Some(ref mut adam) = adam_state {
342                let update = adam.compute_update(&clipped_grad)?;
343                &current_params + &update
344            } else {
345                let lr = self.config.lr_schedule.get_lr(iter);
346                &current_params + &(&clipped_grad * lr)
347            };
348
349            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
350            self.diagnostics.record_param_change(param_change);
351
352            self.variational.set_params(&new_params)?;
353
354            // Check convergence
355            if iter > self.config.convergence_window {
356                if let Some(rel_change) = self
357                    .diagnostics
358                    .relative_elbo_change(self.config.convergence_window)
359                {
360                    if rel_change < self.config.tol {
361                        self.diagnostics.converged = true;
362                        break;
363                    }
364                }
365            }
366        }
367
368        // Transform results back to constrained space
369        let constrained_means = self.transform_to_constrained(&self.variational.means)?;
370
371        Ok(AdviResult {
372            variational: self.variational.clone(),
373            constraints: self.constraints.clone(),
374            constrained_means,
375            diagnostics: self.diagnostics.clone(),
376            dim: self.dim,
377        })
378    }
379
380    /// Compute ELBO and its gradient using Monte Carlo estimation
381    fn compute_elbo_gradient<F>(&self, log_joint: &F, seed: u64) -> Result<(f64, Array1<f64>)>
382    where
383        F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
384    {
385        let dim = self.dim;
386        let n_samples = self.config.n_mc_samples.max(1);
387        let n_params = 2 * dim;
388
389        let mut total_elbo = 0.0;
390        let mut total_grad = Array1::zeros(n_params);
391
392        let stds = self.variational.stds();
393
394        for s in 0..n_samples {
395            // Generate epsilon ~ N(0, I)
396            let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
397
398            // Reparameterize: eta = mu + sigma * epsilon
399            let eta = self.variational.sample(&epsilon)?;
400
401            // Transform to constrained space
402            let theta = self.transform_to_constrained(&eta)?;
403
404            // Compute log joint in constrained space
405            let (log_p, grad_theta) = log_joint(&theta)?;
406
407            // Compute log |det J| (sum of per-element log det Jacobians)
408            let mut log_det_j = 0.0;
409            for i in 0..dim {
410                log_det_j += self.constraints[i].log_det_jacobian(eta[i]);
411            }
412
413            // ELBO contribution (before subtracting entropy, which we handle analytically)
414            total_elbo += log_p + log_det_j;
415
416            // Gradient of log p wrt eta (chain rule through transform)
417            let grad_eta = self.compute_grad_eta(&eta, &grad_theta)?;
418
419            // Gradient of log |det J| wrt eta
420            let grad_log_det_j = self.compute_grad_log_det_j(&eta)?;
421
422            // Combined gradient wrt eta
423            let grad_combined = &grad_eta + &grad_log_det_j;
424
425            // Gradient wrt variational params (mu, log_sigma)
426            for i in 0..dim {
427                // d/d_mu = d/d_eta (since eta = mu + sigma*eps, d_eta/d_mu = 1)
428                total_grad[i] += grad_combined[i];
429                // d/d_log_sigma = d/d_eta * d_eta/d_log_sigma
430                //               = grad_combined[i] * epsilon[i] * sigma[i]
431                total_grad[dim + i] += grad_combined[i] * epsilon[i] * stds[i];
432            }
433        }
434
435        // Average over samples
436        total_elbo /= n_samples as f64;
437        total_grad /= n_samples as f64;
438
439        // Add entropy and its gradient
440        let entropy = self.variational.entropy();
441        total_elbo += entropy;
442
443        // Gradient of entropy wrt log_sigma is 1 for each dimension
444        for i in 0..dim {
445            total_grad[dim + i] += 1.0;
446        }
447
448        Ok((total_elbo, total_grad))
449    }
450
451    /// Transform unconstrained parameters to constrained space
452    fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
453        let mut theta = Array1::zeros(self.dim);
454        for i in 0..self.dim {
455            theta[i] = self.constraints[i].forward(eta[i]);
456        }
457        Ok(theta)
458    }
459
460    /// Compute gradient of log p wrt unconstrained eta using chain rule
461    fn compute_grad_eta(&self, eta: &Array1<f64>, grad_theta: &Array1<f64>) -> Result<Array1<f64>> {
462        let mut grad_eta = Array1::zeros(self.dim);
463        for i in 0..self.dim {
464            // d log_p / d eta_i = d log_p / d theta_i * d theta_i / d eta_i
465            let dtheta_deta = self.compute_transform_derivative(i, eta[i]);
466            grad_eta[i] = grad_theta[i] * dtheta_deta;
467        }
468        Ok(grad_eta)
469    }
470
471    /// Compute derivative of forward transform for parameter i
472    fn compute_transform_derivative(&self, i: usize, unconstrained: f64) -> f64 {
473        match &self.constraints[i] {
474            ParameterConstraint::Real => 1.0,
475            ParameterConstraint::Positive => unconstrained.exp(),
476            ParameterConstraint::UnitInterval => {
477                let s = 1.0 / (1.0 + (-unconstrained).exp());
478                s * (1.0 - s)
479            }
480            ParameterConstraint::Bounded { lower, upper } => {
481                let s = 1.0 / (1.0 + (-unconstrained).exp());
482                (upper - lower) * s * (1.0 - s)
483            }
484            ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
485            ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
486            ParameterConstraint::Simplex { .. } => {
487                let s = 1.0 / (1.0 + (-unconstrained).exp());
488                s * (1.0 - s)
489            }
490        }
491    }
492
493    /// Compute gradient of log |det J| wrt unconstrained parameters
494    fn compute_grad_log_det_j(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
495        let mut grad = Array1::zeros(self.dim);
496        for i in 0..self.dim {
497            grad[i] = self.compute_grad_log_det_j_single(i, eta[i]);
498        }
499        Ok(grad)
500    }
501
502    /// Compute d/d_eta log|det J_forward(eta)| for a single parameter
503    fn compute_grad_log_det_j_single(&self, i: usize, unconstrained: f64) -> f64 {
504        match &self.constraints[i] {
505            ParameterConstraint::Real => 0.0,
506            ParameterConstraint::Positive => 1.0,
507            ParameterConstraint::UnitInterval => {
508                let s = 1.0 / (1.0 + (-unconstrained).exp());
509                1.0 - 2.0 * s
510            }
511            ParameterConstraint::Bounded { .. } => {
512                let s = 1.0 / (1.0 + (-unconstrained).exp());
513                1.0 - 2.0 * s
514            }
515            ParameterConstraint::LowerBounded { .. } => 1.0,
516            ParameterConstraint::UpperBounded { .. } => 1.0,
517            ParameterConstraint::Simplex { .. } => {
518                let s = 1.0 / (1.0 + (-unconstrained).exp());
519                1.0 - 2.0 * s
520            }
521        }
522    }
523}
524
525// ============================================================================
526// ADVI (Full-Rank)
527// ============================================================================
528
/// Automatic Differentiation Variational Inference with full-rank Gaussian
///
/// Like AdviMeanField, but uses a Gaussian with a full covariance matrix,
/// parameterized by its Cholesky factor. This can capture posterior
/// correlations, at the cost of O(d^2) parameters.
///
/// q(eta) = N(mu, L L^T) where L is lower-triangular
///
/// Samples are drawn via the reparameterization eta = mu + L * epsilon. The
/// gradient w.r.t. L covers only its lower-triangular entries, packed row by row.
#[derive(Debug, Clone)]
pub struct AdviFullRank {
    /// Variational distribution in unconstrained space (mean vector and lower-triangular Cholesky factor)
    pub variational: FullRankGaussian,
    /// Parameter constraints, one per dimension
    pub constraints: Vec<ParameterConstraint>,
    /// Configuration
    pub config: AdviConfig,
    /// Optimization diagnostics (ELBO trace, gradient norms, parameter changes, convergence flag)
    pub diagnostics: VariationalDiagnostics,
    /// Dimensionality (equals `constraints.len()`)
    pub dim: usize,
}
549
550impl AdviFullRank {
551    /// Create a new full-rank ADVI instance
552    pub fn new(constraints: Vec<ParameterConstraint>, config: AdviConfig) -> Result<Self> {
553        let dim = constraints.len();
554        if dim == 0 {
555            return Err(StatsError::InvalidArgument(
556                "Must have at least one parameter".to_string(),
557            ));
558        }
559
560        let variational = FullRankGaussian::new(dim)?;
561
562        Ok(Self {
563            variational,
564            constraints,
565            config,
566            diagnostics: VariationalDiagnostics::new(),
567            dim,
568        })
569    }
570
571    /// Create full-rank ADVI with all unconstrained parameters
572    pub fn new_unconstrained(dim: usize, config: AdviConfig) -> Result<Self> {
573        let constraints = vec![ParameterConstraint::Real; dim];
574        Self::new(constraints, config)
575    }
576
577    /// Initialize variational parameters from constrained-space values
578    pub fn initialize_from_constrained(&mut self, theta: &Array1<f64>) -> Result<()> {
579        if theta.len() != self.dim {
580            return Err(StatsError::DimensionMismatch(format!(
581                "theta length ({}) must match dimension ({})",
582                theta.len(),
583                self.dim
584            )));
585        }
586
587        let mut eta = Array1::zeros(self.dim);
588        for i in 0..self.dim {
589            eta[i] = self.constraints[i].inverse(theta[i])?;
590        }
591        self.variational.mean = eta;
592        // Initialize with small identity-like covariance
593        self.variational.chol_factor = Array2::eye(self.dim) * 0.1;
594        Ok(())
595    }
596
597    /// Run full-rank ADVI optimization
598    ///
599    /// # Arguments
600    /// * `log_joint` - Function computing log p(theta) in the constrained space.
601    ///   Takes a constrained parameter vector and returns (log_prob, gradient_wrt_theta).
602    pub fn fit<F>(&mut self, log_joint: F) -> Result<AdviFullRankResult>
603    where
604        F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
605    {
606        let n_params = self.variational.n_params();
607        let mut adam_state = if let LearningRateSchedule::Adam {
608            lr,
609            beta1,
610            beta2,
611            epsilon,
612        } = &self.config.lr_schedule
613        {
614            Some(AdamState::new(n_params, *lr, *beta1, *beta2, *epsilon)?)
615        } else {
616            None
617        };
618
619        for iter in 0..self.config.max_iter {
620            // Compute stochastic ELBO gradient
621            let (elbo, grad) = self.compute_elbo_gradient_full_rank(&log_joint, iter as u64)?;
622
623            self.diagnostics.record_elbo(elbo);
624
625            let grad_norm = grad.dot(&grad).sqrt();
626            self.diagnostics.record_gradient_norm(grad_norm);
627
628            // Clip gradient
629            let clipped_grad = if self.config.grad_clip > 0.0 && grad_norm > self.config.grad_clip {
630                &grad * (self.config.grad_clip / grad_norm)
631            } else {
632                grad
633            };
634
635            // Get current parameters
636            let current_params = self.variational.get_params();
637
638            // Apply update
639            let new_params = if let Some(ref mut adam) = adam_state {
640                let update = adam.compute_update(&clipped_grad)?;
641                &current_params + &update
642            } else {
643                let lr = self.config.lr_schedule.get_lr(iter);
644                &current_params + &(&clipped_grad * lr)
645            };
646
647            let param_change = (&new_params - &current_params).mapv(|x| x * x).sum().sqrt();
648            self.diagnostics.record_param_change(param_change);
649
650            self.variational.set_params(&new_params)?;
651
652            // Check convergence
653            if iter > self.config.convergence_window {
654                if let Some(rel_change) = self
655                    .diagnostics
656                    .relative_elbo_change(self.config.convergence_window)
657                {
658                    if rel_change < self.config.tol {
659                        self.diagnostics.converged = true;
660                        break;
661                    }
662                }
663            }
664        }
665
666        // Transform results back to constrained space
667        let constrained_means = self.transform_to_constrained(&self.variational.mean)?;
668
669        Ok(AdviFullRankResult {
670            variational: self.variational.clone(),
671            constraints: self.constraints.clone(),
672            constrained_means,
673            diagnostics: self.diagnostics.clone(),
674            dim: self.dim,
675        })
676    }
677
678    /// Compute ELBO and gradient for full-rank approximation
679    fn compute_elbo_gradient_full_rank<F>(
680        &self,
681        log_joint: &F,
682        seed: u64,
683    ) -> Result<(f64, Array1<f64>)>
684    where
685        F: Fn(&Array1<f64>) -> Result<(f64, Array1<f64>)>,
686    {
687        let dim = self.dim;
688        let n_samples = self.config.n_mc_samples.max(1);
689        let n_params = self.variational.n_params();
690
691        let mut total_elbo = 0.0;
692        let mut total_grad = Array1::zeros(n_params);
693
694        let n_tril = dim * (dim + 1) / 2;
695
696        for s in 0..n_samples {
697            // Generate epsilon ~ N(0, I)
698            let epsilon = generate_standard_normal_advi(dim, seed * 1000 + s as u64);
699
700            // Reparameterize: eta = mu + L * epsilon
701            let eta = self.variational.sample(&epsilon)?;
702
703            // Transform to constrained space
704            let theta = self.transform_to_constrained(&eta)?;
705
706            // Compute log joint
707            let (log_p, grad_theta) = log_joint(&theta)?;
708
709            // Compute log |det J|
710            let mut log_det_j = 0.0;
711            for i in 0..dim {
712                log_det_j += compute_log_det_jacobian(&self.constraints[i], eta[i]);
713            }
714
715            total_elbo += log_p + log_det_j;
716
717            // Gradient wrt eta
718            let grad_eta = compute_grad_eta_from_theta(dim, &eta, &grad_theta, &self.constraints)?;
719            let grad_log_det = compute_grad_log_det(dim, &eta, &self.constraints)?;
720            let grad_combined: Array1<f64> = &grad_eta + &grad_log_det;
721
722            // Gradient wrt mean: d/d_mu = grad_combined (since eta = mu + L*eps)
723            for i in 0..dim {
724                total_grad[i] += grad_combined[i];
725            }
726
727            // Gradient wrt L (lower triangular elements)
728            // d/d L_{ij} = grad_combined[i] * epsilon[j] for j <= i
729            let mut l_idx = dim;
730            for i in 0..dim {
731                for j in 0..=i {
732                    total_grad[l_idx] += grad_combined[i] * epsilon[j];
733                    l_idx += 1;
734                }
735            }
736        }
737
738        // Average over samples
739        total_elbo /= n_samples as f64;
740        total_grad /= n_samples as f64;
741
742        // Add entropy and its gradient
743        let entropy = self.variational.entropy();
744        total_elbo += entropy;
745
746        // Gradient of entropy wrt L_{ii} (diagonal of Cholesky factor) is 1/L_{ii}
747        let mut l_idx = dim;
748        for i in 0..dim {
749            for j in 0..=i {
750                if i == j {
751                    let l_ii = self.variational.chol_factor[[i, i]];
752                    if l_ii.abs() > 1e-15 {
753                        total_grad[l_idx] += 1.0 / l_ii;
754                    }
755                }
756                l_idx += 1;
757            }
758        }
759
760        Ok((total_elbo, total_grad))
761    }
762
763    /// Transform unconstrained parameters to constrained space
764    fn transform_to_constrained(&self, eta: &Array1<f64>) -> Result<Array1<f64>> {
765        let mut theta = Array1::zeros(self.dim);
766        for i in 0..self.dim {
767            theta[i] = self.constraints[i].forward(eta[i]);
768        }
769        Ok(theta)
770    }
771}
772
773// ============================================================================
774// ADVI Results
775// ============================================================================
776
/// Results from mean-field ADVI
///
/// Returned by `AdviMeanField::fit`. It holds a snapshot of the fitted variational
/// distribution, the constraints and the optimization diagnostics.
#[derive(Debug, Clone)]
pub struct AdviResult {
    /// Optimized variational distribution in unconstrained space
    pub variational: MeanFieldGaussian,
    /// Parameter constraints
    pub constraints: Vec<ParameterConstraint>,
    /// Unconstrained posterior means mapped through each constraint's forward
    /// transform, i.e. T(E_q[eta]). For nonlinear transforms this generally
    /// differs from the constrained posterior mean E_q[T(eta)]; estimate the
    /// latter by averaging `sample_constrained` draws.
    pub constrained_means: Array1<f64>,
    /// Optimization diagnostics
    pub diagnostics: VariationalDiagnostics,
    /// Dimensionality
    pub dim: usize,
}
791
792impl AdviResult {
793    /// Get posterior means in unconstrained space
794    pub fn unconstrained_means(&self) -> &Array1<f64> {
795        &self.variational.means
796    }
797
798    /// Get posterior standard deviations in unconstrained space
799    pub fn unconstrained_stds(&self) -> Array1<f64> {
800        self.variational.stds()
801    }
802
803    /// Get posterior means in constrained space
804    pub fn constrained_means(&self) -> &Array1<f64> {
805        &self.constrained_means
806    }
807
808    /// Sample from the approximate posterior and transform to constrained space
809    pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
810        let eta = self.variational.sample(epsilon)?;
811        let mut theta = Array1::zeros(self.dim);
812        for i in 0..self.dim {
813            theta[i] = self.constraints[i].forward(eta[i]);
814        }
815        Ok(theta)
816    }
817
818    /// Compute approximate credible intervals in constrained space
819    ///
820    /// Note: These are approximate because the transform is nonlinear.
821    /// For more accurate intervals, use `sample_constrained` with many samples.
822    pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
823        check_probability(confidence, "confidence")?;
824
825        let alpha = (1.0 - confidence) / 2.0;
826        let z_critical = super::normal_ppf(1.0 - alpha)?;
827
828        let stds = self.variational.stds();
829        let mut intervals = Array2::zeros((self.dim, 2));
830
831        for i in 0..self.dim {
832            let eta_low = self.variational.means[i] - z_critical * stds[i];
833            let eta_high = self.variational.means[i] + z_critical * stds[i];
834
835            // Transform bounds to constrained space
836            let theta_low = self.constraints[i].forward(eta_low);
837            let theta_high = self.constraints[i].forward(eta_high);
838
839            // Ensure proper ordering (some transforms may flip)
840            intervals[[i, 0]] = theta_low.min(theta_high);
841            intervals[[i, 1]] = theta_low.max(theta_high);
842        }
843
844        Ok(intervals)
845    }
846}
847
/// Results from full-rank ADVI
///
/// Bundles the fitted full-rank Gaussian over the unconstrained parameters
/// (whose covariance may carry cross-parameter correlations) with the
/// per-coordinate constraints needed to map values back into the constrained
/// parameter space.
#[derive(Debug, Clone)]
pub struct AdviFullRankResult {
    /// Optimized variational distribution in unconstrained space
    pub variational: FullRankGaussian,
    /// Parameter constraints; entry `i` transforms unconstrained coordinate `i`
    /// (the accessors index this for every `i < dim`)
    pub constraints: Vec<ParameterConstraint>,
    /// Posterior means in constrained space
    // NOTE(review): presumably the forward transform applied to the
    // unconstrained mean, which differs from E[theta] under a nonlinear
    // transform — confirm at the construction site.
    pub constrained_means: Array1<f64>,
    /// Optimization diagnostics
    pub diagnostics: VariationalDiagnostics,
    /// Dimensionality (number of parameters)
    pub dim: usize,
}
862
863impl AdviFullRankResult {
864    /// Get posterior means in unconstrained space
865    pub fn unconstrained_means(&self) -> &Array1<f64> {
866        &self.variational.mean
867    }
868
869    /// Get posterior covariance in unconstrained space
870    pub fn unconstrained_covariance(&self) -> Array2<f64> {
871        self.variational.covariance()
872    }
873
874    /// Get posterior means in constrained space
875    pub fn constrained_means(&self) -> &Array1<f64> {
876        &self.constrained_means
877    }
878
879    /// Sample from the approximate posterior and transform to constrained space
880    pub fn sample_constrained(&self, epsilon: &Array1<f64>) -> Result<Array1<f64>> {
881        let eta = self.variational.sample(epsilon)?;
882        let mut theta = Array1::zeros(self.dim);
883        for i in 0..self.dim {
884            theta[i] = self.constraints[i].forward(eta[i]);
885        }
886        Ok(theta)
887    }
888
889    /// Compute approximate credible intervals in constrained space
890    pub fn approximate_credible_intervals(&self, confidence: f64) -> Result<Array2<f64>> {
891        check_probability(confidence, "confidence")?;
892
893        let alpha = (1.0 - confidence) / 2.0;
894        let z_critical = super::normal_ppf(1.0 - alpha)?;
895
896        let cov = self.variational.covariance();
897        let mut intervals = Array2::zeros((self.dim, 2));
898
899        for i in 0..self.dim {
900            let std_i = cov[[i, i]].sqrt();
901            let eta_low = self.variational.mean[i] - z_critical * std_i;
902            let eta_high = self.variational.mean[i] + z_critical * std_i;
903
904            let theta_low = self.constraints[i].forward(eta_low);
905            let theta_high = self.constraints[i].forward(eta_high);
906
907            intervals[[i, 0]] = theta_low.min(theta_high);
908            intervals[[i, 1]] = theta_low.max(theta_high);
909        }
910
911        Ok(intervals)
912    }
913}
914
915// ============================================================================
916// Helper functions
917// ============================================================================
918
919/// Compute log determinant of Jacobian for a single constraint
920fn compute_log_det_jacobian(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
921    constraint.log_det_jacobian(unconstrained)
922}
923
924/// Compute gradient of ELBO wrt unconstrained eta from constrained gradient
925fn compute_grad_eta_from_theta(
926    dim: usize,
927    eta: &Array1<f64>,
928    grad_theta: &Array1<f64>,
929    constraints: &[ParameterConstraint],
930) -> Result<Array1<f64>> {
931    let mut grad_eta = Array1::zeros(dim);
932    for i in 0..dim {
933        let dtheta_deta = compute_transform_deriv(&constraints[i], eta[i]);
934        grad_eta[i] = grad_theta[i] * dtheta_deta;
935    }
936    Ok(grad_eta)
937}
938
939/// Compute gradient of sum of log |det J| wrt eta
940fn compute_grad_log_det(
941    dim: usize,
942    eta: &Array1<f64>,
943    constraints: &[ParameterConstraint],
944) -> Result<Array1<f64>> {
945    let mut grad = Array1::zeros(dim);
946    for i in 0..dim {
947        grad[i] = compute_grad_log_det_single(&constraints[i], eta[i]);
948    }
949    Ok(grad)
950}
951
952/// Compute derivative of forward transform for a constraint
953fn compute_transform_deriv(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
954    match constraint {
955        ParameterConstraint::Real => 1.0,
956        ParameterConstraint::Positive => unconstrained.exp(),
957        ParameterConstraint::UnitInterval => {
958            let s = 1.0 / (1.0 + (-unconstrained).exp());
959            s * (1.0 - s)
960        }
961        ParameterConstraint::Bounded { lower, upper } => {
962            let s = 1.0 / (1.0 + (-unconstrained).exp());
963            (upper - lower) * s * (1.0 - s)
964        }
965        ParameterConstraint::LowerBounded { .. } => unconstrained.exp(),
966        ParameterConstraint::UpperBounded { .. } => (-unconstrained).exp(),
967        ParameterConstraint::Simplex { .. } => {
968            let s = 1.0 / (1.0 + (-unconstrained).exp());
969            s * (1.0 - s)
970        }
971    }
972}
973
974/// Compute d/d_eta log|det J_forward(eta)| for a single constraint
975fn compute_grad_log_det_single(constraint: &ParameterConstraint, unconstrained: f64) -> f64 {
976    match constraint {
977        ParameterConstraint::Real => 0.0,
978        ParameterConstraint::Positive => 1.0,
979        ParameterConstraint::UnitInterval => {
980            let s = 1.0 / (1.0 + (-unconstrained).exp());
981            1.0 - 2.0 * s
982        }
983        ParameterConstraint::Bounded { .. } => {
984            let s = 1.0 / (1.0 + (-unconstrained).exp());
985            1.0 - 2.0 * s
986        }
987        ParameterConstraint::LowerBounded { .. } => 1.0,
988        ParameterConstraint::UpperBounded { .. } => 1.0,
989        ParameterConstraint::Simplex { .. } => {
990            let s = 1.0 / (1.0 + (-unconstrained).exp());
991            1.0 - 2.0 * s
992        }
993    }
994}
995
996/// Generate approximate standard normal samples (deterministic)
997fn generate_standard_normal_advi(dim: usize, seed: u64) -> Array1<f64> {
998    let mut result = Array1::zeros(dim);
999    let golden_ratio = 1.618033988749895;
1000
1001    for i in 0..dim {
1002        let u1 = ((seed as f64 * golden_ratio + i as f64 * 0.7548776662466927) % 1.0).abs();
1003        let u2 = ((seed as f64 * 0.5698402909980532 + i as f64 * golden_ratio) % 1.0).abs();
1004
1005        let u1_safe = u1.max(1e-10).min(1.0 - 1e-10);
1006        let u2_safe = u2.max(1e-10).min(1.0 - 1e-10);
1007
1008        let r = (-2.0 * u1_safe.ln()).sqrt();
1009        let theta = 2.0 * PI * u2_safe;
1010        result[i] = r * theta.cos();
1011    }
1012
1013    result
1014}
1015
1016// ============================================================================
1017// Tests
1018// ============================================================================
1019
// Unit tests for the ADVI module: per-constraint forward/inverse/log-Jacobian
// behavior, construction of the mean-field and full-rank fitters, short
// optimization smoke runs, and credible-interval shape/ordering.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Identity transform: forward/inverse are no-ops and the log-Jacobian is 0.
    #[test]
    fn test_constraint_real() {
        let c = ParameterConstraint::Real;
        assert!((c.forward(1.5) - 1.5).abs() < 1e-10);
        let inv = c.inverse(1.5).expect("should invert");
        assert!((inv - 1.5).abs() < 1e-10);
        assert!((c.log_det_jacobian(1.5)).abs() < 1e-10);
    }

    #[test]
    fn test_constraint_positive() {
        let c = ParameterConstraint::Positive;
        // forward(0) = exp(0) = 1
        assert!((c.forward(0.0) - 1.0).abs() < 1e-10);
        // forward(1) = exp(1) = e
        assert!((c.forward(1.0) - 1.0_f64.exp()).abs() < 1e-10);
        // inverse(e) = 1
        let inv = c.inverse(1.0_f64.exp()).expect("should invert");
        assert!((inv - 1.0).abs() < 1e-10);
        // Positive constraint: inverse of non-positive should fail
        assert!(c.inverse(-1.0).is_err());
    }

    #[test]
    fn test_constraint_unit_interval() {
        let c = ParameterConstraint::UnitInterval;
        // sigmoid(0) = 0.5
        assert!((c.forward(0.0) - 0.5).abs() < 1e-10);
        // inverse(0.5) = 0
        let inv = c.inverse(0.5).expect("should invert");
        assert!(inv.abs() < 1e-10);
        // Boundary cases should fail: 0 and 1 are not in the open interval
        assert!(c.inverse(0.0).is_err());
        assert!(c.inverse(1.0).is_err());
    }

    #[test]
    fn test_constraint_bounded() {
        let c = ParameterConstraint::Bounded {
            lower: -1.0,
            upper: 1.0,
        };
        // forward(0) = -1 + 2 * sigmoid(0) = -1 + 1 = 0
        assert!((c.forward(0.0)).abs() < 1e-10);
        // inverse(0) should be 0
        let inv = c.inverse(0.0).expect("should invert");
        assert!(inv.abs() < 1e-10);
    }

    #[test]
    fn test_constraint_lower_bounded() {
        let c = ParameterConstraint::LowerBounded { lower: 2.0 };
        // forward(0) = 2 + exp(0) = 3
        assert!((c.forward(0.0) - 3.0).abs() < 1e-10);
        let inv = c.inverse(3.0).expect("should invert");
        assert!(inv.abs() < 1e-10);
        // Values below the lower bound have no preimage
        assert!(c.inverse(1.0).is_err());
    }

    // inverse(forward(eta)) must recover eta for each scalar constraint.
    #[test]
    fn test_constraint_roundtrip() {
        let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
            ParameterConstraint::UnitInterval,
            ParameterConstraint::Bounded {
                lower: 0.0,
                upper: 10.0,
            },
        ];

        let unconstrained_values = vec![0.5, 1.0, -0.5, 2.0];
        for (c, &eta) in constraints.iter().zip(unconstrained_values.iter()) {
            let theta = c.forward(eta);
            let eta_back = c.inverse(theta).expect("should invert");
            assert!(
                (eta_back - eta).abs() < 1e-8,
                "Roundtrip failed for {:?}: {} -> {} -> {}",
                c,
                eta,
                theta,
                eta_back
            );
        }
    }

    #[test]
    fn test_advi_mean_field_creation() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
        let config = AdviConfig::default();
        let advi = AdviMeanField::new(constraints, config).expect("should create");
        // Dimension is taken from the number of constraints
        assert_eq!(advi.dim, 2);
    }

    #[test]
    fn test_advi_mean_field_simple_gaussian() {
        // Fit ADVI to a simple 2D Gaussian target
        let target_mean = Array1::from_vec(vec![1.0, -2.0]);
        let target_precision = 2.0; // precision = 1/variance

        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
        let config = AdviConfig {
            max_iter: 500,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.05,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-6,
            convergence_window: 20,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        // The log-density closure is `move`, so it gets its own copy of the mean.
        let tm = target_mean.clone();
        let result = advi
            .fit(move |theta: &Array1<f64>| {
                let diff = theta - &tm;
                let log_p = -0.5 * target_precision * diff.dot(&diff);
                let grad = &diff * (-target_precision);
                Ok((log_p, grad))
            })
            .expect("should fit");

        // Smoke check only: the fitted means are NOT compared with `target_mean`
        // (stochastic optimization makes a tight tolerance brittle). This only
        // verifies that the optimizer ran and the final ELBO is finite.
        assert!(
            result.diagnostics.n_iterations > 0,
            "Should have performed iterations"
        );
        assert!(
            result.diagnostics.final_elbo.is_finite(),
            "ELBO should be finite"
        );
    }

    #[test]
    fn test_advi_full_rank_creation() {
        let constraints = vec![
            ParameterConstraint::Real,
            ParameterConstraint::Positive,
            ParameterConstraint::UnitInterval,
        ];
        let config = AdviConfig::default();
        let advi = AdviFullRank::new(constraints, config).expect("should create");
        assert_eq!(advi.dim, 3);
    }

    // Smoke run of full-rank ADVI on a standard normal target.
    #[test]
    fn test_advi_full_rank_simple() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Real];
        let config = AdviConfig {
            max_iter: 200,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.02,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-5,
            convergence_window: 20,
            ..AdviConfig::default()
        };

        let mut advi = AdviFullRank::new(constraints, config).expect("should create");

        let result = advi
            .fit(|theta: &Array1<f64>| {
                // Simple separable Gaussian
                let log_p = -0.5 * theta.dot(theta);
                let grad = theta * (-1.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        assert!(result.diagnostics.n_iterations > 0);
        assert!(result.diagnostics.final_elbo.is_finite());
    }

    #[test]
    fn test_advi_with_constrained_params() {
        // Test ADVI with mixed constraints
        let constraints = vec![
            ParameterConstraint::Real,     // unconstrained
            ParameterConstraint::Positive, // must be > 0
        ];

        let config = AdviConfig {
            max_iter: 300,
            n_mc_samples: 1,
            lr_schedule: LearningRateSchedule::Adam {
                lr: 0.01,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            tol: 1e-5,
            convergence_window: 30,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        let result = advi
            .fit(|theta: &Array1<f64>| {
                // log p = -0.5 * (theta[0] - 1)^2 - 2 * (theta[1] - 2)^2
                let log_p = -0.5 * (theta[0] - 1.0).powi(2) - 2.0 * (theta[1] - 2.0).powi(2);
                let mut grad = Array1::zeros(2);
                grad[0] = -(theta[0] - 1.0);
                grad[1] = -4.0 * (theta[1] - 2.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        // The constrained mean for the positive parameter should be > 0
        assert!(
            result.constrained_means[1] > 0.0,
            "Positive-constrained parameter should be > 0, got {}",
            result.constrained_means[1]
        );
    }

    // Checks only the shape and endpoint ordering of the intervals, not their values.
    #[test]
    fn test_advi_result_credible_intervals() {
        let constraints = vec![ParameterConstraint::Real, ParameterConstraint::Positive];
        let config = AdviConfig {
            max_iter: 100,
            ..AdviConfig::default()
        };

        let mut advi = AdviMeanField::new(constraints, config).expect("should create");

        let result = advi
            .fit(|theta: &Array1<f64>| {
                let log_p = -0.5 * theta.dot(theta);
                let grad = theta * (-1.0);
                Ok((log_p, grad))
            })
            .expect("should fit");

        let intervals = result
            .approximate_credible_intervals(0.95)
            .expect("should compute intervals");

        assert_eq!(intervals.nrows(), 2);
        assert_eq!(intervals.ncols(), 2);
        // Lower bound should be less than upper bound
        for i in 0..2 {
            assert!(
                intervals[[i, 0]] <= intervals[[i, 1]],
                "Lower bound should be <= upper bound at dim {}",
                i
            );
        }
    }

    #[test]
    fn test_log_det_jacobian_positive() {
        let c = ParameterConstraint::Positive;
        // log|det J| for exp transform is just the unconstrained value
        assert!((c.log_det_jacobian(0.0)).abs() < 1e-10);
        assert!((c.log_det_jacobian(1.0) - 1.0).abs() < 1e-10);
        assert!((c.log_det_jacobian(-1.0) - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_log_det_jacobian_unit_interval() {
        let c = ParameterConstraint::UnitInterval;
        // At eta=0, sigmoid(0)=0.5, so log|sigmoid'(0)| = log(0.25)
        let expected = (0.25_f64).ln();
        assert!(
            (c.log_det_jacobian(0.0) - expected).abs() < 1e-10,
            "log det J at 0 should be {}, got {}",
            expected,
            c.log_det_jacobian(0.0)
        );
    }
}
1305}